In [1]:
!pip install torchvision --quiet
!pip install bitsandbytes-cuda110 bitsandbytes --quiet
!pip install huggingface_hub --quiet
!pip install transformers --quiet
!pip install accelerate>=0.26.0
# !pip install pillow --quiet

In [2]:
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
import torch
import json
import random
from PIL import Image
import os
import argparse

In [9]:
# Run this first time
# from huggingface_hub import hf_hub_download
# hf_hub_download(repo_id="google/spiqa", filename="test-A/SPIQA_testA.json", repo_type="dataset", local_dir='datasets/test-A/')
# hf_hub_download(repo_id="google/spiqa", filename="test-A/SPIQA_testA_Images_224px.zip", repo_type="dataset", local_dir='datasets/test-A/')

SPIQA_testA.json:   0%|          | 0.00/778k [00:00<?, ?B/s]

SPIQA_testA_Images_224px.zip:   0%|          | 0.00/47.1M [00:00<?, ?B/s]

'datasets/test-A/test-A/SPIQA_testA_Images_224px.zip'

In [10]:
# Run this first time
# !unzip datasets/test-A/test-A/SPIQA_testA_Images_224px.zip -d datasets/test-A/

Archive:  datasets/test-A/test-A/SPIQA_testA_Images_224px.zip
   creating: datasets/test-A/SPIQA_testA_Images_224px/
   creating: datasets/test-A/SPIQA_testA_Images_224px/1809.03449v3/
  inflating: datasets/test-A/SPIQA_testA_Images_224px/1809.03449v3/1809.03449v3-Table3-1.png  
  inflating: datasets/test-A/SPIQA_testA_Images_224px/1809.03449v3/1809.03449v3-Table1-1.png  
  inflating: datasets/test-A/SPIQA_testA_Images_224px/1809.03449v3/1809.03449v3-Figure3-1.png  
  inflating: datasets/test-A/SPIQA_testA_Images_224px/1809.03449v3/1809.03449v3-Figure1-1.png  
  inflating: datasets/test-A/SPIQA_testA_Images_224px/1809.03449v3/1809.03449v3-Figure4-1.png  
  inflating: datasets/test-A/SPIQA_testA_Images_224px/1809.03449v3/1809.03449v3-Figure2-1.png  
  inflating: datasets/test-A/SPIQA_testA_Images_224px/1809.03449v3/1809.03449v3-Table2-1.png  
   creating: datasets/test-A/SPIQA_testA_Images_224px/1710.01507v4/
  inflating: datasets/test-A/SPIQA_testA_Images_224px/1710.01507v4/1710.01507v

In [3]:
# Don't forget to move JSON file outside 
testA_filtered_annotations_path = 'datasets/test-A/SPIQA_testA.json'
with open(testA_filtered_annotations_path, "r") as f:
  testA_data = json.load(f)

In [4]:
def prepare_inputs(paper, question_idx):
    all_figures = list(paper['all_figures'].keys())
    referred_figures = [paper['qa'][question_idx]['reference']]
    answer = paper['qa'][question_idx]['answer']

    referred_figures_captions = []
    for figure in referred_figures:
        referred_figures_captions.append(paper['all_figures'][figure]['caption'])

    return answer, all_figures, referred_figures, referred_figures_captions


_PROMPT_1 = "Caption: <caption>. Is the input image and caption helpful to answer the following question. Answer in one word - Yes or No. Question: <question>. "
_PROMPT_2 = "Caption: <caption>. Please provide a brief answer to the following question after looking into the input image and caption. Question: <question>."


def infer_instructblip(testA_data, args):

  if args['image_resolution'] == 224:
      _testA_IMAGE_ROOT = "datasets/test-A/SPIQA_testA_Images_224px"
  else:
      raise NotImplementedError

  processor = InstructBlipProcessor.from_pretrained(args['model_id'])
  model = InstructBlipForConditionalGeneration.from_pretrained(args['model_id'], load_in_4bit=True, torch_dtype=torch.float16)

  _RESPONSE_ROOT = args['response_root']
  os.makedirs(_RESPONSE_ROOT, exist_ok=True)

  for paper_id, paper in sorted(testA_data.items(), key=lambda x: random.random()):
    if os.path.exists(os.path.join(_RESPONSE_ROOT, str(paper_id) + '_response.json')):
      continue
    response_paper = {}

    try:
      for question_idx, qa in enumerate(paper['qa']):

        question = qa['question']

        answer, all_figures, referred_figures, referred_figures_captions = prepare_inputs(paper, question_idx)

        answer_dict = {}

        for _idx, figure in enumerate(referred_figures):
          
          caption = referred_figures_captions[_idx]
              
          instructblip_prompt_1 = _PROMPT_1.replace('<caption>', caption).replace('<question>', question)
          instructblip_prompt_2 = _PROMPT_2.replace('<caption>', caption).replace('<question>', question)
            
          image = Image.open(os.path.join(_testA_IMAGE_ROOT, paper['paper_id'], figure))
          image = image.resize((args['image_resolution'], args['image_resolution']))
            
          inputs_1 = processor(images=image, text=instructblip_prompt_1, return_tensors="pt").to(device="cuda", dtype=torch.float16)

          # autoregressively generate an answer
          outputs_1 = model.generate(
                  **inputs_1,
                  num_beams=1,
                  max_new_tokens=50,
                  min_length=1,
                  top_p=0.9,
                  repetition_penalty=1.5,
                  length_penalty=1.0,
                  temperature=1,
          )
          outputs_1[outputs_1 == 0] = 2 # this line can be removed once https://github.com/huggingface/transformers/pull/24492 is fixed
          generated_text_1 = processor.batch_decode(outputs_1, skip_special_tokens=True)[0].strip()

          inputs_2 = processor(images=image, text=instructblip_prompt_2, return_tensors="pt").to(device="cuda", dtype=torch.float16)

          # autoregressively generate an answer
          outputs_2 = model.generate(
                  **inputs_2,
                  num_beams=5,
                  max_new_tokens=256,
                  min_length=1,
                  top_p=0.9,
                  repetition_penalty=1.5,
                  length_penalty=1.0,
                  temperature=1,
          )
          outputs_2[outputs_2 == 0] = 2 # this line can be removed once https://github.com/huggingface/transformers/pull/24492 is fixed
          generated_text_2 = processor.batch_decode(outputs_2, skip_special_tokens=True)[0].strip()
          

          answer_dict.update({figure: [generated_text_1, generated_text_2]})
      
          print(answer_dict[figure])
          print('-----------------')

        response_paper.update({question_idx: {'question': question, 'response': answer_dict,
                                              'referred_figures_names': referred_figures, 'answer': answer}})   

    except Exception as e:
        print(f'Error in generating: {e}')
        processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
        model = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", load_in_4bit=True, torch_dtype=torch.float16)
        continue

    with open(os.path.join(_RESPONSE_ROOT, str(paper_id) + '_response.json'), 'w') as f:
      json.dump(response_paper, f)

In [None]:
args = {"model_id":"Salesforce/instructblip-vicuna-7b",
        "response_root": "response",
        "image_resolution": 224}

infer_instructblip(testA_data, args)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
`low_cpu_mem_usage` was None, now default to True since model is quantized.


model.safetensors.index.json:   0%|          | 0.00/104k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/9.90G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/9.96G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/9.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.88G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]