[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/trainer/blob/main/captioner_llava.ipynb)

In [None]:
# Clone the `v1.0` ref of the camenduru/LLaVA repository and install it with
# `pip install -e .` (editable mode) so that the `llava` package imported by the
# following cells is available.
%cd /content
!git clone -b v1.0 https://github.com/camenduru/LLaVA
%cd /content/LLaVA

!pip install -e .

In [None]:
from transformers import AutoTokenizer, BitsAndBytesConfig
from llava.model import LlavaLlamaForCausalLM
import torch

# Hub repo of the LLaVA-v1.5-13B checkpoint (a 4-bit variant, per its name).
model_path = "4bit/llava-v1.5-13b-3GB"
kwargs = {"device_map": "auto"}
# Load weights as 4-bit NF4 with double quantization; compute runs in float16.
# The quantization_config alone enables 4-bit loading. Passing
# load_in_4bit=True as well is redundant, and newer transformers releases
# reject that combination with a ValueError, so it is not set here.
kwargs['quantization_config'] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4'
)
model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

# The vision tower may not have its weights loaded yet (is_loaded is False);
# load them explicitly before use.
vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
    vision_tower.load_model()
# Keep the vision tower in float16 on the GPU. This matches the half-precision
# image tensors the captioning loop feeds it and LLaVA's own model loader, and
# halves its memory footprint compared with the float32 default.
vision_tower.to(device='cuda', dtype=torch.float16)
image_processor = vision_tower.image_processor

In [None]:
# Authenticate with the Hugging Face Hub. `--token` requires a value, so the
# bare flag fails. Define HF_TOKEN first, either as a Python variable (IPython
# expands `$HF_TOKEN`) or as an environment variable.
# NOTE: wget below does not use this login; a private repo would need an
# "Authorization: Bearer <token>" header on the wget call instead.
!huggingface-cli login --token $HF_TOKEN
# Fetch and unpack the images to caption. `mkdir -p` and `unzip -o` keep the
# cell re-runnable: no "File exists" error and no interactive overwrite prompt.
!mkdir -p /content/images
!wget https://huggingface.co/camenduru/polaroid/resolve/main/style_name_fix.zip
!unzip -o style_name_fix.zip -d /content/images

In [None]:
import os
import requests
from PIL import Image
from io import BytesIO
from llava.conversation import conv_templates, SeparatorStyle
from llava.utils import disable_torch_init
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from transformers import TextStreamer
import torch

def load_image(image_file):
    if image_file.startswith('http') or image_file.startswith('https'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image

IMAGE_DIR = '/content/images'
# Instruction sent to the model for every image.
CAPTION_PROMPT = 'Tag the image with words.'
# LLaVA-v1.5 checkpoints use the "llava_v1" conversation template; upstream
# LLaVA's CLI selects it for any model whose name contains "v1".
CONV_MODE = 'llava_v1'


def _build_prompt(question):
    """Wrap an image placeholder and *question* in the conversation template.

    Returns:
        ``(prompt, stop_str)``: the full prompt text, and the separator that
        marks the end of the assistant's turn for this template.
    """
    conv = conv_templates[CONV_MODE].copy()
    # Only checkpoints trained with explicit <im_start>/<im_end> markers expect
    # them; the others take the bare image placeholder.
    if getattr(model.config, 'mm_use_im_start_end', False):
        image_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    else:
        image_token = DEFAULT_IMAGE_TOKEN
    # The template inserts the role labels itself, so the message must not
    # repeat "<role>: ".
    conv.append_message(conv.roles[0], image_token + '\n' + question)
    conv.append_message(conv.roles[1], None)
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    return conv.get_prompt(), stop_str


def caption_image(image_path, question=CAPTION_PROMPT):
    """Run LLaVA on the image at *image_path* and return the generated text.

    Uses the module-level ``model``, ``tokenizer`` and ``image_processor``
    created in the model-loading cell.
    """
    image = load_image(image_path)
    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()
    prompt, stop_str = _build_prompt(question)
    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
    ).unsqueeze(0).cuda()
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=0.2,
            max_new_tokens=1024,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )
    # Decode only the tokens generated after the prompt.
    outputs = tokenizer.decode(
        output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
    ).strip()
    # The stop string can remain at the end of the text; drop it.
    if outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)]
    return outputs.strip()


# NOTE(review): disable_torch_init only matters for modules constructed after
# it runs, so one call suffices (the original called it for every image).
disable_torch_init()
for file_name in sorted(os.listdir(IMAGE_DIR)):
    image_path = os.path.join(IMAGE_DIR, file_name)
    print(image_path)
    try:
        print(caption_image(image_path))
    except Exception as e:
        # Best-effort batch: report the failure and continue with the next file.
        print(f"Error processing {file_name}: {e}")