In [1]:
from transformers import AutoTokenizer, BitsAndBytesConfig
from llava.model import LlavaLlamaForCausalLM
import torch

[2023-12-06 19:52:28,865] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
%%time
# Model identifier handed to both the model and tokenizer `from_pretrained` calls below.
model_path = "4bit/llava-v1.5-13b-3GB"

# 4-bit NF4 quantization with double quantization; matmuls are computed in float16.
# `load_in_4bit` is set only here: passing it a second time as a separate
# `from_pretrained` kwarg is redundant, and recent transformers releases refuse the
# combination of `load_in_4bit=` with `quantization_config=`.
kwargs = {
    "device_map": "auto",
    "quantization_config": BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4',
    ),
}
model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

# Make sure the vision tower's weights are loaded, then move it to the GPU and
# grab the image processor that `caption_image` uses for preprocessing.
vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
    vision_tower.load_model()
vision_tower.to(device='cuda')
image_processor = vision_tower.image_processor

Loading checkpoint shards:   0%|          | 0/9 [00:00<?, ?it/s]

CPU times: user 18.6 s, sys: 29.9 s, total: 48.5 s
Wall time: 3min 55s


In [3]:
import os
import requests
from PIL import Image
from io import BytesIO
from llava.conversation import conv_templates, SeparatorStyle
from llava.utils import disable_torch_init
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from transformers import TextStreamer

def _load_rgb_image(image_file, timeout=30):
    """Open `image_file` (an http(s) URL or a local path) and return it as an RGB PIL image."""
    if image_file.startswith(('http://', 'https://')):
        # Bound the wait and fail loudly on HTTP errors instead of handing an
        # error page to PIL.
        response = requests.get(image_file, timeout=timeout)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    # `convert` returns a new image, so the source handle can be released right away.
    with Image.open(source) as img:
        return img.convert('RGB')


def _clean_generation(text, stop_str):
    """Cut `text` at the last '</s>' and drop a trailing `stop_str`, trimming whitespace."""
    text = text.rsplit('</s>', 1)[0].strip()
    if stop_str and text.endswith(stop_str):
        text = text[:-len(stop_str)].strip()
    return text


def caption_image(image_file, prompt, *, conv_mode="llava_v0", temperature=0.2, max_new_tokens=1024):
    """Ask the loaded LLaVA model a question about one image.

    Relies on the notebook-level globals `model`, `tokenizer` and `image_processor`.

    Args:
        image_file: http(s) URL or local filesystem path of the image.
        prompt: text of the question / instruction for the model.
        conv_mode: key into `conv_templates` selecting the conversation template.
        temperature: sampling temperature; a value <= 0 disables sampling.
        max_new_tokens: upper bound on the number of generated tokens.

    Returns:
        A tuple `(image, output)`: the RGB PIL image and the generated text with the
        end-of-sequence marker and the conversation stop string removed.

    Raises:
        requests.HTTPError: if a URL is given and the server answers with an error status.
    """
    image = _load_rgb_image(image_file)
    disable_torch_init()

    # NOTE(review): LLaVA v1.5 checkpoints are normally paired with the "llava_v1"
    # template; the default is left at "llava_v0" so results stay comparable with
    # predictions that were already cached. Confirm which one is intended.
    conv = conv_templates[conv_mode].copy()
    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()

    # NOTE(review): the role name is embedded in the message text and is also passed
    # as the message role on the append below, so it presumably shows up twice in the
    # rendered prompt. Left unchanged on purpose (see note above) - confirm.
    inp = f"{conv.roles[0]}: {prompt}"
    # NOTE(review): the image start/end tokens are added unconditionally; verify this
    # matches the checkpoint's `mm_use_im_start_end` setting.
    inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)  # empty assistant turn for the model to fill
    raw_prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(raw_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    # Only pass a temperature when sampling is actually enabled.
    if temperature > 0:
        sampling_kwargs = {'do_sample': True, 'temperature': temperature}
    else:
        sampling_kwargs = {'do_sample': False}

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            **sampling_kwargs,
        )

    # Decode only the tokens that follow the prompt.
    outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
    return image, _clean_generation(outputs, stop_str)

In [4]:
#image, output = caption_image('https://llava-vl.github.io/static/images/view.jpg', 'Describe the image and color details.')
#print(output)
#image

In [5]:
%%time
#image, output = caption_image('https://llava-vl.github.io/static/images/view.jpg', 'Would it be a good idea to swim in this lake? Answer with 1 word.')
#print(output)
#image

CPU times: user 2 µs, sys: 2 µs, total: 4 µs
Wall time: 8.58 µs


In [6]:
import json

In [7]:
%%time
# Question file to evaluate on. Alternative question set:
#   "/scratch/chaijy_root/chaijy0/heyandy/gqa_eval/dev_val_questions.json"
all_questions_filepath = "/scratch/chaijy_root/chaijy0/heyandy/gqa_eval/val_balanced_questions.json"

# Parse the whole file into memory; the inference loop later iterates it with `.items()`.
with open(all_questions_filepath) as f:
    all_questions = json.load(f)

CPU times: user 3.3 s, sys: 311 ms, total: 3.61 s
Wall time: 3.75 s


In [8]:
# Predictions saved by an earlier run are reused so an interrupted evaluation can resume.
predictions_file = '/scratch/chaijy_root/chaijy0/heyandy/gqa_eval/val_predictions.json'
if not os.path.exists(predictions_file):
    predictions_array = []  # nothing cached yet: start from an empty list of records
else:
    with open(predictions_file) as f:
        predictions_array = json.load(f)

In [9]:
# Index the cached records by question id for O(1) "already answered?" lookups.
# An empty (or otherwise falsy) cache yields an empty dict; if an id appears more
# than once, the last record wins.
predictions = {
    entry['questionId']: entry['prediction']
    for entry in (predictions_array or [])
}

In [15]:
%%time
from tqdm import tqdm
# Directory holding one "<imageId>.jpg" file per question image.
images_dir = '/scratch/chaijy_root/chaijy0/heyandy/gqa_eval/images'

# Pick out the unanswered questions up front so the progress bar's total and ETA
# describe the remaining work, not the whole question set (matters when resuming).
pending_questions = [
    (question_id, question_info)
    for question_id, question_info in all_questions.items()
    if question_id not in predictions
]

for question_id, question_info in tqdm(pending_questions):
    image_id = question_info['imageId']
    question = question_info['question']
    prompt = f'{question} Answer with one word.'
    _, output = caption_image(f'{images_dir}/{image_id}.jpg', prompt)
    # Answers are stored lower-cased.
    predictions[question_id] = output.lower()

  0%|          | 446/132062 [03:07<15:23:28,  2.38it/s]


KeyboardInterrupt: 

In [16]:
# Rebuild the record list from the dict instead of appending to it: `predictions`
# already contains every entry that was loaded from `predictions_array`, so appending
# would write each cached prediction a second time (and again on every re-run of
# this cell). Rebuilding keeps exactly one record per question id.
predictions_array = [
    {'questionId': question_id, 'prediction': prediction}
    for question_id, prediction in predictions.items()
]

In [17]:
# Dump to a temporary sibling file first and then swap it into place, so an
# interrupted write cannot leave a truncated predictions file that would break
# the next resume.
tmp_predictions_file = predictions_file + '.tmp'
with open(tmp_predictions_file, 'w') as f:
    json.dump(predictions_array, f)
os.replace(tmp_predictions_file, predictions_file)