In [1]:
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image

from tqdm import tqdm
import csv

In [2]:
import os 
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [3]:
import pandas as pd
import json
import numpy as np

few_shot_examples = {}

intermediate_questions =  pd.read_excel("intermediate_questions_samples_32.xlsx", dtype={'qid' : str})

with open("./ambiguous_questions_rebuilt_ptb.json") as f:
    ambiguous_questions = json.load(f)

with open("./ambiguous_questions.json") as f:
    original_questions = json.load(f)

print(len(ambiguous_questions))
for idx in intermediate_questions.index:
    intermediate_question_example = intermediate_questions.iloc[idx]
    qid = intermediate_question_example['qid']

    ambiguous_question_example = ambiguous_questions[qid]
    
    ambiguous_question = ambiguous_question_example['question']
    intermediate_question = intermediate_question_example['intermediate question']
    original_question_example = original_questions.pop(qid)
    
    few_shot_examples[qid] = original_question_example
    few_shot_examples[qid]['ambiguous_question'] = ambiguous_question
    few_shot_examples[qid]['intermediate_question'] = intermediate_question
    few_shot_examples[qid].pop("addtional_question")
    
    entities_list = set()
    for object_name in few_shot_examples[qid]['irrelated_object_names'].values():
        if object_name in few_shot_examples[qid]['ambiguous_question']:
             entities_list.add(object_name)
    few_shot_examples[qid]['question_entities'] = list(entities_list)
   
np.random.seed(42) 
test_keys = np.random.choice(list(original_questions.keys()), 32, replace=False)
test_examples = {}

for qid in test_keys:
    original_questions[qid].pop("addtional_question")
    test_examples[qid] =  original_questions[qid]
    
    ambiguous_question_example = ambiguous_questions[qid]
    ambiguous_question = ambiguous_question_example['question']
    
    test_examples[qid]["ambiguous_question"] = ambiguous_question
    
    entities_list = set()
    for object_name in test_examples[qid]['irrelated_object_names'].values():
        if object_name in test_examples[qid]['ambiguous_question']:
             entities_list.add(object_name)
    test_examples[qid]['question_entities'] = list(entities_list)
    
# dataset = load_dataset("csv", data_files={"train" : "./ambiguous_questions_train.csv", "test" : "./ambiguous_questions_test.csv"})

125854


In [4]:
from open_flamingo import create_model_and_transforms

model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-7b",
    tokenizer_path="anas-awadalla/mpt-7b",
    cross_attn_every_n_layers=4
)

Using pad_token, but it is not set yet.
Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.


You are using config.init_device='cpu', but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Flamingo model initialized with 1384781840 trainable parameters


In [5]:
# grab model checkpoint from huggingface hub
from huggingface_hub import hf_hub_download
import torch

checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-9B-vitl-mpt7b", "checkpoint.pt")
model.load_state_dict(torch.load(checkpoint_path), strict=False)

_IncompatibleKeys(missing_keys=['vision_encoder.class_embedding', 'vision_encoder.positional_embedding', 'vision_encoder.proj', 'vision_encoder.conv1.weight', 'vision_encoder.ln_pre.weight', 'vision_encoder.ln_pre.bias', 'vision_encoder.transformer.resblocks.0.ln_1.weight', 'vision_encoder.transformer.resblocks.0.ln_1.bias', 'vision_encoder.transformer.resblocks.0.attn.in_proj_weight', 'vision_encoder.transformer.resblocks.0.attn.in_proj_bias', 'vision_encoder.transformer.resblocks.0.attn.out_proj.weight', 'vision_encoder.transformer.resblocks.0.attn.out_proj.bias', 'vision_encoder.transformer.resblocks.0.ln_2.weight', 'vision_encoder.transformer.resblocks.0.ln_2.bias', 'vision_encoder.transformer.resblocks.0.mlp.c_fc.weight', 'vision_encoder.transformer.resblocks.0.mlp.c_fc.bias', 'vision_encoder.transformer.resblocks.0.mlp.c_proj.weight', 'vision_encoder.transformer.resblocks.0.mlp.c_proj.bias', 'vision_encoder.transformer.resblocks.1.ln_1.weight', 'vision_encoder.transformer.resbloc

In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model.to(device)
model.eval()


Flamingo(
  (vision_encoder): VisionTransformer(
    (patchnorm_pre_ln): Identity()
    (conv1): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
    (patch_dropout): Identity()
    (ln_pre): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): ModuleList(
        (0-23): 24 x ResidualAttentionBlock(
          (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
          )
          (ls_1): Identity()
          (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=1024, out_features=4096, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=4096, out_features=1024, bias=True)
          )
          (ls_2): Identity()
        )
      )
    )
    (ln_post): Layer

In [14]:
from collections import Counter

num_few_shot_examples = len(few_shot_examples)

few_shot_image_ids = [example["imageId"] for example in few_shot_examples.values()]
few_shot_images = [Image.open("./images/" + str(image_id) + ".jpg").convert('RGB') for image_id in few_shot_image_ids]

vision_x = [image_processor(demo_image).unsqueeze(0) for demo_image in few_shot_images]
                
tokenizer.padding_side = "left" # For generation padding tokens should be on the left

text = [f"<image>Main question is {few_shot_example['ambiguous_question']}. " +   
        "Write a intermediate question to clarify entities of the main question. " + 
        f"Intermediate question: {few_shot_example['intermediate_question']} <|endofchunk|>"  for qid, few_shot_example in few_shot_examples.items()]

        # "The entitites of question: " +  " ".join(few_shot_example['question_entities']) + ". " + 
        

text = "".join(text)

print(text)

<image>main question is What is on the wall?. Write a intermediate question to clarify entities of the main question. Intermediate question: Which part of the wall? <|endofchunk|><image>main question is Is the woman to the left or to the right of the man?. Write a intermediate question to clarify entities of the main question. Intermediate question: What is the man wearing? <|endofchunk|><image>main question is Are the bananas?. Write a intermediate question to clarify entities of the main question. Intermediate question: Is the bananas is in the bottom of the image?  <|endofchunk|><image>main question is On which side of the image is the cup?. Write a intermediate question to clarify entities of the main question. Intermediate question: What is the color of the cup? <|endofchunk|><image>main question is Is the man to the right of the surfboard?. Write a intermediate question to clarify entities of the main question. Intermediate question: where is the surfboard in the picture? <|endof

In [15]:
predictions = []
for idx, (qid, example)in tqdm(enumerate(test_examples.items())):
    if idx == 0:
        continue
    question = example["ambiguous_question"]
    image_id = example['imageId']
    image = Image.open("./images/" + str(image_id) + '.jpg').convert('RGB')
    input_prompt = text + \
    f"<image>Main question is {example['ambiguous_question']}. " +  \
    "Write a intermediate question to clarify entities of the main question. " + \
    f"Intermediate question: "
    
    
    # "The entitites of question : " +  " ".join(example['question_entities']) + ". " + \
    
    # print(input_prompt)
    
    input_images = vision_x + [image]
    input_images = torch.cat(vision_x, dim=0)
    input_images = input_images.unsqueeze(1).unsqueeze(0)
    
    lang_x = tokenizer(
        [input_prompt] ,
        return_tensors="pt",
    )
    
    generated_text = model.generate(
    vision_x=input_images.to(device),
    lang_x=lang_x["input_ids"].to(device),
    attention_mask=lang_x["attention_mask"].to(device),
    max_new_tokens=20,
    num_beams=3,
    )
    
    #print(tokenizer.batch_decode(lang_x["input_ids"],skip_special_tokens=False))
    #print("Input prompt: ", input_prompt)
    
    decoded_text =  tokenizer.decode(generated_text[0], skip_special_tokens=True)
    if idx < 5:
        print("question: ", question)
        print("Generated text: ",decoded_text.split("Intermediate question:")[-1].strip())
        # print()
    intermediate_question = decoded_text.split("Intermediate question:")[-1].strip()

    #inputs = processor(image, input_promt, return_tensors="pt").to(device)
    
    #out = model.generate(**inputs)
    #prediction = processor.decode(out[0], skip_special_tokens=True)
    #print(prediction)
    predictions.append(intermediate_question)
    

print(predictions)
                
            



0it [00:00, ?it/s]Setting `pad_token_id` to `eos_token_id`:50277 for open-end generation.
2it [01:16, 38.23s/it]Setting `pad_token_id` to `eos_token_id`:50277 for open-end generation.


question:  Is the bus to the left or to the right of the pedestrians?
Generated text:  What is the color of the bus? main question is Is the man to the left or to


3it [02:35, 55.18s/it]Setting `pad_token_id` to `eos_token_id`:50277 for open-end generation.


question:  What is the food to the left of the doughnut the bucket is to the right of?
Generated text:  What is the color of the bucket?  main question is What is the color of the


4it [03:54, 63.95s/it]Setting `pad_token_id` to `eos_token_id`:50277 for open-end generation.


question:  How large is the bear?
Generated text:  Is the bear in the middle of the image? main question is What is the color of the


5it [05:12, 69.13s/it]Setting `pad_token_id` to `eos_token_id`:50277 for open-end generation.


question:  Is the bus to the right or to the left of the driver?
Generated text:  What is the color of the bus? main question is Is the man to the right or to


6it [06:31, 72.40s/it]Setting `pad_token_id` to `eos_token_id`:50277 for open-end generation.
7it [07:50, 74.33s/it]Setting `pad_token_id` to `eos_token_id`:50277 for open-end generation.
8it [09:08, 75.67s/it]Setting `pad_token_id` to `eos_token_id`:50277 for open-end generation.
9it [10:27, 76.70s/it]Setting `pad_token_id` to `eos_token_id`:50277 for open-end generation.
10it [11:46, 77.32s/it]Setting `pad_token_id` to `eos_token_id`:50277 for open-end generation.
11it [13:05, 77.83s/it]Setting `pad_token_id` to `eos_token_id`:50277 for open-end generation.
12it [14:24, 78.26s/it]Setting `pad_token_id` to `eos_token_id`:50277 for open-end generation.
13it [15:43, 78.45s/it]Setting `pad_token_id` to `eos_token_id`:50277 for open-end generation.
14it [17:02, 78.61s/it]Setting `pad_token_id` to `eos_token_id`:50277 for open-end generation.
15it [18:21, 78.58s/it]Setting `pad_token_id` to `eos_token_id`:50277 for open-end generation.
16it [19:39, 78.55s/it]Setting `pad_token_id` to `eos_

['What is the color of the bus? main question is Is the man to the left or to', 'What is the color of the bucket? \xa0main question is What is the color of the', 'Is the bear in the middle of the image? main question is What is the color of the', 'What is the color of the bus? main question is Is the man to the right or to', 'What is the color of the car? main question is Is the man to the right or to', 'Is the elephant is in the middle of the image? main question is Is the man to the', 'Is the picture in the middle of the image? main question is Which side of the image is', 'What is the color of the phone? main question is Is the man to the right or to', 'What is the color of the plate? main question is What is the man doing?. Write a', 'What is the color of the television? main question is Is the person to the right or to', 'Is the chair to the left of the remote control that is to the left of the girl?', 'What is the man wearing? main question is Is the man to the right or to the le




In [16]:
for idx, (qid, example)in tqdm(enumerate(test_examples.items())):
    print(example['ambiguous_question'])

32it [00:00, 74153.44it/s]

Is the man 's hair?
Is the bus to the left or to the right of the pedestrians?
What is the food to the left of the doughnut the bucket is to the right of?
How large is the bear?
Is the bus to the right or to the left of the driver?
Is the blue car to the right or to the left of the van?
Is the elephant?
Which side is the picture on?
What device is the girl holding , a phone or a Wii controller?
What is the food on the plate called?
Is the person to the left of the laptop looking at a television?
Is the chair to the left of the remote control that is to the left of the girl?
What is the person to the left of the man holding?
Is the house to the right or to the left of the cow?
What color is the zebra?
What is the elephant?
What is the vegetable to the left of the cucumber called?
Is she to the left of a picture?
What shape does the table have?
What is the item of furniture to the left of the speaker?
Are the glasses to the right or to the left of the person?
Which side of the picture is


