In [None]:
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image

from tqdm import tqdm
import csv

In [None]:
import os 
# Order CUDA devices by PCI bus id and make only device "0" visible to this process.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0",
    }
)

In [None]:
import pandas as pd
import json
import numpy as np

# Number of held-out examples sampled for evaluation.
NUM_TEST_EXAMPLES = 32


def _question_entities(example):
    """Return the object names of `example` that occur in its ambiguous question.

    Reads ``example['irrelated_object_names']`` (only its ``.values()`` are used) and
    ``example['ambiguous_question']``. A name is kept when it is a substring of the
    question. Duplicates are collapsed through a set, so the order of the returned
    list is not guaranteed.
    """
    question = example['ambiguous_question']
    return list({name for name in example['irrelated_object_names'].values() if name in question})


few_shot_examples = {}

intermediate_questions = pd.read_excel("intermediate_questions_samples_32.xlsx", dtype={'qid' : str})

with open("./ambiguous_questions_rebuilt_ptb.json", encoding="utf-8") as f:
    ambiguous_questions = json.load(f)

with open("./ambiguous_questions.json", encoding="utf-8") as f:
    original_questions = json.load(f)

print(len(ambiguous_questions))

# Walk the spreadsheet by *position*: `.iloc` is positional, so pairing it with the
# index labels (as `for idx in df.index`) is only correct for a default RangeIndex.
for row_position in range(len(intermediate_questions)):
    intermediate_question_example = intermediate_questions.iloc[row_position]
    qid = intermediate_question_example['qid']

    ambiguous_question = ambiguous_questions[qid]['question']
    intermediate_question = intermediate_question_example['intermediate question']

    # pop() removes the example from `original_questions`, so it cannot be
    # sampled again as a test example below.
    few_shot_example = original_questions.pop(qid)

    few_shot_examples[qid] = few_shot_example
    few_shot_example['ambiguous_question'] = ambiguous_question
    few_shot_example['intermediate_question'] = intermediate_question
    # Key spelling kept from the original code (presumably mirrors the JSON schema —
    # verify); the default avoids a KeyError when the field is absent.
    few_shot_example.pop("addtional_question", None)
    few_shot_example['question_entities'] = _question_entities(few_shot_example)

# Fixed seed so the same test subset is drawn on every run.
np.random.seed(42)
test_keys = np.random.choice(list(original_questions.keys()), NUM_TEST_EXAMPLES, replace=False)
test_examples = {}

for qid in test_keys:
    test_example = original_questions[qid]
    test_example.pop("addtional_question", None)
    test_examples[qid] = test_example

    test_example["ambiguous_question"] = ambiguous_questions[qid]['question']
    test_example['question_entities'] = _question_entities(test_example)
    
# dataset = load_dataset("csv", data_files={"train" : "./ambiguous_questions_train.csv", "test" : "./ambiguous_questions_test.csv"})

In [None]:
from open_flamingo import create_model_and_transforms

# Settings handed to open_flamingo's factory, gathered in one place so they are
# easy to inspect or tweak.
flamingo_settings = dict(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-7b",
    tokenizer_path="anas-awadalla/mpt-7b",
    cross_attn_every_n_layers=4,
)

# The factory returns the model together with its image preprocessor and tokenizer.
model, image_processor, tokenizer = create_model_and_transforms(**flamingo_settings)

In [None]:
# grab model checkpoint from huggingface hub
from huggingface_hub import hf_hub_download
import torch

checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-9B-vitl-mpt7b", "checkpoint.pt")

# Deserialize onto the CPU: without `map_location`, tensors are restored to the
# device they were saved from, which fails on a CPU-only machine and otherwise
# occupies GPU memory before the model itself has been moved to a device.
state_dict = torch.load(checkpoint_path, map_location="cpu")

# strict=False: do not raise when checkpoint keys and model keys do not match
# exactly. Left as the last expression so the cell shows the call's result.
model.load_state_dict(state_dict, strict=False)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Move the model to the chosen device, then switch it to evaluation mode.
model.to(device)
model.eval()


In [None]:
from collections import Counter

num_few_shot_examples = len(few_shot_examples)

few_shot_image_ids = [example["imageId"] for example in few_shot_examples.values()]

# Load every demonstration image as RGB. The context manager closes the source
# file right after conversion; convert() returns an independent in-memory image.
few_shot_images = []
for image_id in few_shot_image_ids:
    with Image.open("./images/" + str(image_id) + ".jpg") as image_file:
        few_shot_images.append(image_file.convert('RGB'))

# One preprocessed tensor per demonstration image, each with a leading axis of size 1
# so they can later be concatenated along dim 0.
vision_x = [image_processor(demo_image).unsqueeze(0) for demo_image in few_shot_images]

tokenizer.padding_side = "left" # For generation padding tokens should be on the left

# Few-shot context: one "<image> ... <|endofchunk|>" segment per demonstration,
# concatenated in the iteration order of `few_shot_examples` (the same order as
# `vision_x`). The prompt wording is kept byte-for-byte, since the inference
# prompt must use the same template.
# A previously disabled prompt segment, kept for reference:
#   "The entitites of question: " + " ".join(few_shot_example['question_entities']) + ". "
text = "".join(
    f"<image>Main question is {few_shot_example['ambiguous_question']}. "
    "Write a intermediate question to clarify entities of the main question. "
    f"Intermediate question: {few_shot_example['intermediate_question']} <|endofchunk|>"
    for few_shot_example in few_shot_examples.values()
)

print(text)

In [None]:
# Generate one intermediate question per test example, in the iteration order of
# `test_examples`, so that predictions[i] corresponds to the i-th test example.
predictions = []

# tqdm wraps the items (which have a length) rather than the enumerate object,
# so the progress bar can show a total.
for idx, (qid, example) in enumerate(tqdm(test_examples.items())):
    question = example["ambiguous_question"]
    image_id = example['imageId']

    # Close the image file as soon as the RGB copy has been made.
    with Image.open("./images/" + str(image_id) + '.jpg') as image_file:
        image = image_file.convert('RGB')

    # Few-shot context followed by the query in the same template, left open
    # after "Intermediate question: " for the model to complete.
    input_prompt = (
        text
        + f"<image>Main question is {question}. "
        + "Write a intermediate question to clarify entities of the main question. "
        + "Intermediate question: "
    )

    # The prompt holds one <image> slot per demonstration plus one for the query,
    # so the preprocessed query image must be appended to the demonstration
    # tensors. (Previously the query image was dropped and never preprocessed.)
    query_vision_x = image_processor(image).unsqueeze(0)
    input_images = torch.cat(vision_x + [query_vision_x], dim=0)
    # Add a singleton axis after the image axis and a leading singleton axis.
    # Presumably this yields (batch, num_images, frames, C, H, W) as OpenFlamingo
    # expects — assumes image_processor returns a (C, H, W) tensor; TODO confirm.
    input_images = input_images.unsqueeze(1).unsqueeze(0)

    lang_x = tokenizer(
        [input_prompt],
        return_tensors="pt",
    )

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        generated_text = model.generate(
            vision_x=input_images.to(device),
            lang_x=lang_x["input_ids"].to(device),
            attention_mask=lang_x["attention_mask"].to(device),
            max_new_tokens=20,
            num_beams=3,
        )

    decoded_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)
    # Keep only what follows the last "Intermediate question:" marker in the
    # decoded text.
    intermediate_question = decoded_text.split("Intermediate question:")[-1].strip()

    # Show the first few generations as a sanity check.
    if idx < 5:
        print("question: ", question)
        print("Generated text: ", intermediate_question)

    predictions.append(intermediate_question)


print(predictions)
                
            



In [None]:
# Print the ambiguous question of every test example, with a tqdm progress display.
numbered_test_items = tqdm(enumerate(test_examples.items()))
for idx, (qid, example) in numbered_test_items:
    ambiguous_question_text = example['ambiguous_question']
    print(ambiguous_question_text)