In [4]:
# For T5 based model
from MIC.model.instructblip import InstructBlipConfig, InstructBlipModel, InstructBlipPreTrainedModel,InstructBlipForConditionalGeneration,InstructBlipProcessor
import datasets
import json
import transformers
from PIL import Image
import torch
import os

os.environ['HF_HOME'] = '/scratch3/wenyan/cache'
DEVICE = "cuda:0"

In [12]:
## load mutli-image vqa questions
data_dir = "/scratch3/wenyan/data/foodie"
question_file = os.path.join(data_dir, "multi-image-vqa-ann.json")
# mivqa = datasets.load_dataset('json', data_files=question_file)['train']
with open(question_file, 'r') as f:
    mivqa = json.load(f)

{'question': '哪一道菜中含有土豆？',
 'choices': '',
 'user-answer': '3',
 'gt': '3',
 'rationale': 'D中含有土豆，其他均没有',
 'bad_q': 'No',
 'num_hops': 'single',
 'question_id': '316ee2c7ac04c1a3e636175a8d42c6ec_2',
 'session_id': 'val-a6b03c47b0cb48b1a7c03504b4a7bf4c',
 'ann_group': '湘',
 'images': ['14456664_all_202404292352223293/104_IMG_0277.jpeg',
  '14456664_all_202404292352223293/132_IMG_20220702_130156.jpg',
  '14456664_all_202404292352223293/147_IMG_20190225_184723.jpg',
  '14456664_all_202404292352223293/151_IMG_20240414_200337.jpg']}

In [11]:
model_type="instructblip"
# use large model
# model_ckpt="BleachNick/MMICL-Instructblip-T5-xxl"
# processor_ckpt = "Salesforce/instructblip-flan-t5-xxl"

# use small model
model_ckpt = "BleachNick/MMICL-Instructblip-T5-xl"
processor_ckpt = "Salesforce/instructblip-flan-t5-xl"

config = InstructBlipConfig.from_pretrained(model_ckpt)

if 'instructblip' in model_type:
    model = InstructBlipForConditionalGeneration.from_pretrained(
        model_ckpt,
        config=config, cache_dir=os.environ["HF_HOME"]).to(DEVICE,dtype=torch.bfloat16) 

processor = InstructBlipProcessor.from_pretrained(
    processor_ckpt,cache_dir=os.environ["HF_HOME"]
)

image_palceholder="图"
sp = [image_palceholder]+[f"<image{i}>" for i in range(20)]

sp = sp+processor.tokenizer.additional_special_tokens[len(sp):]
processor.tokenizer.add_special_tokens({'additional_special_tokens':sp})
if model.qformer.embeddings.word_embeddings.weight.shape[0] != len(processor.qformer_tokenizer):
    model.qformer.resize_token_embeddings(len(processor.qformer_tokenizer))
replace_token="".join(32*[image_palceholder])



In [14]:
def load_images(question):
    image_choices = question["images"]
    return [Image.open(os.path.join(data_dir, image_choice)) for image_choice in image_choices]

In [29]:
def format_question(sample, is_shot=False):
    question = sample["question"]
    if is_shot:
        answer = sample["gt"]
        prompt = [f'Use the image 0: <image0>{replace_token}, image 1: <image1>{replace_token}, image 2: <image2>{replace_token}, and image 3: <image3>{replace_token} as visual choices to answer the question. Question: {question} Answer: image {answer}\n']
    else:
        prompt = [f'Use the image 0: <image0>{replace_token}, image 1: <image1>{replace_token}, image 2: <image2>{replace_token}, and image 3: <image3>{replace_token} as visual choices to answer the question. Question: {question} Answer: image ']
    return prompt

def format_nshot_question(questions, cur_question, n=1):
    final_prompt = f''
    for i in range(n):
        question = questions[i]
        prompt = format_question(question, is_shot=True)
        final_prompt += prompt[0]
    
    final_prompt += format_question(cur_question)[0]
    return [final_prompt]


## 1 shot
n_shot = 1
n_shot_images = []
for i in range(n_shot):
    n_shot_images.extend(load_images(mivqa[i]))

prompt = format_nshot_question(mivqa, mivqa[n_shot], n_shot)
question_images = load_images(mivqa[n_shot])
print(prompt)

# images = load_images(mivqa[0])
# prompt = format_question(mivqa[0]["question"])

n_shot_images.extend(question_images)
print(len(n_shot_images))
# max_length = max([len(f) for f in images ])
# print(max_length)
images = n_shot_images
inputs = processor(images=images, text=prompt[0], return_tensors="pt")

inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
inputs['img_mask'] = torch.tensor([[1 for i in range(len(images))]])
inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)

inputs = inputs.to(DEVICE)
outputs = model.generate(
        pixel_values = inputs['pixel_values'],
        input_ids = inputs['input_ids'],
        attention_mask = inputs['attention_mask'],
        img_mask = inputs['img_mask'],
        do_sample=False,
        max_length=160,
        min_length=50,
        num_beams=8,
        set_min_padding_size =False,
)
generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
print("prompt: ", prompt)
print(generated_text)

['Use the image 0: <image0>图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图, image 1: <image1>图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图, image 2: <image2>图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图, and image 3: <image3>图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图 as visual choices to answer the question. Question: 哪一道菜中含有土豆？ Answer: image 3\nUse the image 0: <image0>图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图, image 1: <image1>图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图, image 2: <image2>图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图, and image 3: <image3>图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图 as visual choices to answer the question. Question: 下面哪道菜有家禽肉? Answer: image ']
8


prompt:  ['Use the image 0: <image0>图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图, image 1: <image1>图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图, image 2: <image2>图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图, and image 3: <image3>图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图 as visual choices to answer the question. Question: 哪一道菜中含有土豆？ Answer: image 3\nUse the image 0: <image0>图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图, image 1: <image1>图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图, image 2: <image2>图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图, and image 3: <image3>图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图图 as visual choices to answer the question. Question: 下面哪道菜有家禽肉? Answer: image ']
chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chinese chines


In [None]:
image = Image.open ("images/flamingo_photo.png")
image1 = Image.open ("images/flamingo_cartoon.png")
image2 = Image.open ("images/flamingo_3d.png")

images = [image,image1,image2]
prompt = [f'Use the image 0: <image0>{replace_token}, image 1: <image1>{replace_token} and image 2: <image2>{replace_token} as a visual aids to help you answer the question. Question: Give the reason why image 0, image 1 and image 2 are different? Answer:']

prompt = " ".join(prompt)

inputs = processor(images=images, text=prompt, return_tensors="pt")

inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
inputs['img_mask'] = torch.tensor([[1 for i in range(len(images))]])
inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)

inputs = inputs.to(DEVICE)
outputs = model.generate(
        pixel_values = inputs['pixel_values'],
        input_ids = inputs['input_ids'],
        attention_mask = inputs['attention_mask'],
        img_mask = inputs['img_mask'],
        do_sample=False,
        max_length=80,
        min_length=50,
        num_beams=8,
        set_min_padding_size =False,
)
generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
print(generated_text)


In [4]:
from torch.nn.functional import pad
def padd_images( image, max_length):
    image = torch.tensor(image)
    mask = torch.zeros(max_length).bool()
    pad_len = max_length - image.shape[0]
    mask[:image.shape[0]] = True
    image = pad(image,(0,0,0,0,0,0,0,pad_len)) # padding behind the first dim
    return image,mask

image = Image.open ("images/chinchilla.png")
image1 = Image.open ("images/shiba.png")
image2 = Image.open ("images/flamingo.png")

image4 = Image.open ("images/shiba.png")
image5 = Image.open ("images/flamingo.png")

images =[ 
[image,image1,image2], [image4 ,image5]
]

prompt = [f'image 0 is <image0>{replace_token},image 1 is <image1>{replace_token},image 2 is <image2>{replace_token}. Question: <image0> is a chinchilla. They are mainly found in Chile.\n Question: <image1> is a shiba. They are very popular in Japan.\nQuestion: image 2 is',
f'image 0 is <image0>{replace_token}, image 0 is a shiba. They are very popular in Japan.\n image 1 is <image1>{replace_token}, image 1 is a',
]

max_image_length = max([len(f) for f in images ])


inputs = processor( text=prompt, return_tensors="pt",padding=True) 

pixel_values= [ processor(images=img, return_tensors="pt")['pixel_values'] for img in images]

image_list=[]
mask_list= []
for img in pixel_values:
    image,img_mask = padd_images(img,max_image_length)
    image_list.append(image)
    mask_list.append(img_mask)
inputs['pixel_values'] = torch.stack(image_list).to(torch.bfloat16)
inputs['img_mask'] = torch.stack(mask_list)
inputs = inputs.to(DEVICE)
outputs = model.generate(
        pixel_values = inputs['pixel_values'],
        input_ids = inputs['input_ids'],
        attention_mask = inputs['attention_mask'],
        img_mask = inputs['img_mask'],
        do_sample=False,
        max_length=50,
        min_length=1,
        set_min_padding_size =False,
)
generated_text = processor.batch_decode(outputs, skip_special_tokens=True)
print(generated_text)



  image = torch.tensor(image)


['a flamingo.', 'flamingo standing in the water.']


In [5]:
inputs['pixel_values'].shape

torch.Size([2, 3, 3, 224, 224])

In [6]:
image = Image.open ("images/cal_num1.png")
image1 = Image.open ("images/cal_num2.png")
image2 = Image.open ("images/cal_num3.png")
images = [image,image1,image2]

prompt = [f'Use the image 0: <image0>{replace_token},image 1: <image1>{replace_token} and image 2: <image2>{replace_token} as a visual aid to help you calculate the equation accurately. image 0 is 2+1=3.\nimage 1 is 5+6=11.\nimage 2 is"']
prompt = " ".join(prompt)

inputs = processor(images=images, text=prompt, return_tensors="pt")

inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
inputs['img_mask'] = torch.tensor([[1 for i in range(len(images))]])
inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)

inputs = inputs.to(DEVICE)
outputs = model.generate(
        pixel_values = inputs['pixel_values'],
        input_ids = inputs['input_ids'],
        attention_mask = inputs['attention_mask'],
        img_mask = inputs['img_mask'],
        do_sample=False,
        max_length=50,
        min_length=1,
        set_min_padding_size =False,
)
generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
print(generated_text)


3x6


In [42]:

image = Image.open ("images/chinchilla.png")
image1 = Image.open ("images/shiba.png")
image2 = Image.open ("images/flamingo.png")
images = [image,image1,image2]
images = [image,image1,image2]
prompt = [f'image 0 is <image0>{replace_token},image 1 is <image1>{replace_token},image 2 is <image2>{replace_token}. Question: <image0> is a chinchilla. They are mainly found in Chile.\n Question: <image1> is a shiba. They are very popular in Japan.\nQuestion: image 2 is']

prompt = " ".join(prompt)

inputs = processor(images=images, text=prompt, return_tensors="pt")

inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
inputs['img_mask'] = torch.tensor([[1 for i in range(len(images))]])
inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)

inputs = inputs.to('cuda:0')
outputs = model.generate(
        pixel_values = inputs['pixel_values'],
        input_ids = inputs['input_ids'],
        attention_mask = inputs['attention_mask'],
        img_mask = inputs['img_mask'],
        do_sample=False,
        max_length=50,
        min_length=1,
        set_min_padding_size =False,
)
generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
print(generated_text)


a flamingo. They are native to South America and are known for their bright red plumage and distinctive call.


In [7]:
image = Image.open ("images/flamingo_photo.png")
image1 = Image.open ("images/flamingo_cartoon.png")
image2 = Image.open ("images/flamingo_3d.png")

images = [image,image1,image2]
prompt = [f'Use the image 0: <image0>{replace_token}, image 1: <image1>{replace_token} and image 2: <image2>{replace_token} as a visual aids to help you answer the question. Question: Give the reason why image 0, image 1 and image 2 are different? Answer:']

prompt = " ".join(prompt)

inputs = processor(images=images, text=prompt, return_tensors="pt")

inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
inputs['img_mask'] = torch.tensor([[1 for i in range(len(images))]])
inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)

inputs = inputs.to(DEVICE)
outputs = model.generate(
        pixel_values = inputs['pixel_values'],
        input_ids = inputs['input_ids'],
        attention_mask = inputs['attention_mask'],
        img_mask = inputs['img_mask'],
        do_sample=False,
        max_length=80,
        min_length=50,
        num_beams=8,
        set_min_padding_size =False,
)
generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
print(generated_text)


image 0 is a flamingo and image 1 is a cartoon of a flamingo and image 2 is a cartoon of a flamingo a flamingo a flamingo a flamingo a flamingo a flamingo a flamingo a flaming
