In [None]:
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForVision2Seq, AutoProcessor
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
from PIL import Image

from tqdm import tqdm
import csv

In [None]:
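# Optional: uncomment to pin which GPUs are visible to this process.
# These variables must be set before the first CUDA call to take effect.
# import os
# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"]="0,1"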

In [None]:
# Checkpoint identifier to evaluate. Downloads are cached in a local directory
# named after the last path component of the identifier.
model_name_or_path = 'Salesforce/instructblip-flan-t5-xxl'
cache_dir = "./" + model_name_or_path.rsplit('/', 1)[-1]

# Load the processor and the model through the `transformers` auto classes;
# the model weights are requested in float16.
processor = AutoProcessor.from_pretrained(
    model_name_or_path,
    cache_dir=cache_dir,
)
model = AutoModelForVision2Seq.from_pretrained(
    model_name_or_path,
    cache_dir=cache_dir,
    torch_dtype=torch.float16,
)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Move the model to the selected device and put it in evaluation mode.
# Both calls return the module itself, so the cell still displays the model.
model.to(device).eval()


In [None]:
import pandas as pd
import json
import numpy as np

NUM_TEST_EXAMPLES = 32  # number of held-out questions sampled for testing


def _collect_question_entities(example):
    """Return the object names that occur verbatim in the ambiguous question.

    Candidates are the values of ``example['irrelated_object_names']``; a
    candidate is kept when it is a substring of ``example['ambiguous_question']``.
    Duplicates are dropped while keeping the mapping's iteration order, so the
    result is the same on every run (a plain ``set`` of strings would be
    ordered by randomized hashes and change between kernel sessions).
    """
    question = example['ambiguous_question']
    matching_names = (
        name
        for name in example['irrelated_object_names'].values()
        if name in question
    )
    return list(dict.fromkeys(matching_names))


few_shot_examples = {}

# Hand-written intermediate questions; `qid` is read as text so it can be used
# directly as a key into the JSON dictionaries below.
intermediate_questions =  pd.read_excel("intermediate_questions_samples_32.xlsx", dtype={'qid' : str})

with open("./ambiguous_questions_rebuilt_ptb.json") as f:
    ambiguous_questions = json.load(f)

with open("./ambiguous_questions.json") as f:
    original_questions = json.load(f)

print(len(ambiguous_questions))

# Few-shot pool: every annotated row is removed from `original_questions`
# (so it cannot be sampled as a test example later) and enriched with the
# ambiguous question, the annotated intermediate question and its entities.
# Rows are iterated directly instead of feeding index labels to `.iloc`,
# which is only correct for a default RangeIndex.
for _, annotated_row in intermediate_questions.iterrows():
    qid = annotated_row['qid']
    ambiguous_question = ambiguous_questions[qid]['question']

    example = original_questions.pop(qid)
    # "addtional_question" (sic) is the key spelling used by the data file;
    # tolerate records that do not carry it instead of raising KeyError.
    example.pop("addtional_question", None)
    example['ambiguous_question'] = ambiguous_question
    example['intermediate_question'] = annotated_row['intermediate question']
    example['question_entities'] = _collect_question_entities(example)
    few_shot_examples[qid] = example

# Test set: a seeded random sample of the questions left after the few-shot
# examples were removed above.
np.random.seed(42)
test_keys = np.random.choice(list(original_questions.keys()), NUM_TEST_EXAMPLES, replace=False)
test_examples = {}

for qid in test_keys:
    example = original_questions[qid]
    example.pop("addtional_question", None)
    example["ambiguous_question"] = ambiguous_questions[qid]['question']
    example['question_entities'] = _collect_question_entities(example)
    test_examples[qid] = example

In [None]:

# NOTE(review): `promt` is not used anywhere in this cell; it is kept only in
# case another cell reads it.
promt = "Instructions : Classify the following into ambigous and definite question. An ambiguous question is unanswerable question considering image. "
# Prefix prepended to every prompt (empty here).
text = ""

predictions = []
csv_lines = [["AQ", "IQ"]]  # header row: ambiguous question, generated intermediate question

# Deterministic beam search. `top_p` and `temperature` were dropped because
# they are only used by sampling-based generation and `do_sample` is False.
generation_kwargs = dict(
    do_sample=False,
    num_beams=5,
    max_length=256,
    min_length=3,
    repetition_penalty=1.5,
    length_penalty=1.0,
)

# Every test example is processed. The previous version skipped the first
# one, which left it without an 'intermediate_question' in the saved results.
# Passing the dict items directly lets tqdm show the total.
for qid, example in tqdm(test_examples.items()):
    question = example["ambiguous_question"]
    image_path = "./images/" + str(example['imageId']) + '.jpg'
    # Close the image file as soon as the RGB copy has been made.
    with Image.open(image_path) as image_file:
        image = image_file.convert('RGB')

    input_prompt = (
        text
        + f"Main question is {question} "
        + "The ambiguous entities of the main question: "
        + " ".join(example['question_entities'])
        + ". "
        + "Write a intermediate question to clarify the ambiguous entities of the main question. "
        + "Question: "
    )

    print(input_prompt)
    inputs = processor(images=image, text=input_prompt, return_tensors="pt").to(device, torch.float16)
    out = model.generate(**inputs, **generation_kwargs)
    prediction = processor.decode(out[0], skip_special_tokens=True)

    # Store the prediction on the example itself so it is saved with the rest
    # of the example's fields, and keep flat copies for the CSV export.
    example['intermediate_question'] = prediction
    print(prediction)
    predictions.append(prediction)
    csv_lines.append([question, prediction])
    
# Save the full test examples (each processed example carries its generated
# 'intermediate_question') as JSON.
with open(f"./test_fewshot_{model_name_or_path.split('/')[-1]}.json", 'w') as f:
    json.dump(test_examples, f)

# Save the (ambiguous question, generated question) pairs as CSV.
# newline='' is what the csv module requires to avoid blank rows on Windows;
# an explicit UTF-8 encoding keeps non-ASCII text from failing on platforms
# whose default encoding cannot represent it.
with open(f"./test_fewshot_{model_name_or_path.split('/')[-1]}.csv", 'w', newline='', encoding='utf-8') as f:
    writer = csv.writer(f)
    writer.writerows(csv_lines)
            
                