In [1]:
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image

from tqdm import tqdm
import csv

In [2]:
import os 
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [3]:
dataset = load_dataset("csv", data_files={"train" : "./ambiguous_questions_train.csv", "test" : "./ambiguous_questions_test.csv"})

Using custom data configuration default-b4c98d92954de12c
Found cached dataset csv (/root/.cache/huggingface/datasets/csv/default-b4c98d92954de12c/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)


  0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
model_name_or_path = 'Salesforce/blip2-flan-t5-xxl'
cache_dir = "./" + model_name_or_path.split('/')[-1]

# We load our model and processor using `transformers`
processor = AutoProcessor.from_pretrained(model_name_or_path,cache_dir=cache_dir)
model = AutoModelForVision2Seq.from_pretrained(model_name_or_path,cache_dir=cache_dir, device_map="sequential",torch_dtype=torch.float16)

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model.to(device)
model.eval()


Blip2ForConditionalGeneration(
  (vision_model): Blip2VisionModel(
    (embeddings): Blip2VisionEmbeddings(
      (patch_embedding): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
    )
    (encoder): Blip2Encoder(
      (layers): ModuleList(
        (0-38): 39 x Blip2EncoderLayer(
          (self_attn): Blip2Attention(
            (dropout): Dropout(p=0.0, inplace=False)
            (qkv): Linear(in_features=1408, out_features=4224, bias=True)
            (projection): Linear(in_features=1408, out_features=1408, bias=True)
          )
          (layer_norm1): LayerNorm((1408,), eps=1e-05, elementwise_affine=True)
          (mlp): Blip2MLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=1408, out_features=6144, bias=True)
            (fc2): Linear(in_features=6144, out_features=1408, bias=True)
          )
          (layer_norm2): LayerNorm((1408,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (post_layernorm): LayerNorm((

In [10]:
with open("./ambiguous_questions_train.csv", 'r') as f:
    reader = csv.reader(f)
    train_lines = [line for line in reader]


with open("./ambiguous_questions_test.csv", 'r') as f:
    reader = csv.reader(f)
    test_lines = [line for line in reader]
    
    
  

promt = f"Instructions : Classify the following into ambigous and definite question. An ambiguous question is unanswerable question considering image. "

few_shot_image_ids = [train_lines[idx][5] for idx in range(1,6)]
images = [Image.open("./images/" + str(id) + '.jpg').convert('RGB') for id in few_shot_image_ids]

predictions = []
for idx, line in tqdm(enumerate(test_lines)):
    if idx == 0:
        continue
    question = line[1]
    image_id = line[5]
    image = Image.open("./images/" + str(image_id) + '.jpg').convert('RGB')
    input_promt = promt + "qustion : {} label : ".format(question)
   

    inputs = processor(image, input_promt, return_tensors="pt").to(device, torch.float16)
    out = model.generate(**inputs)
    prediction = processor.decode(out[0], skip_special_tokens=True)
    print(prediction)
    predictions.append(prediction)
    
with open ("./test_fewshot.csv", 'w') as f:
        
    writer = csv.writer(f)
    for idx, line in enumerate(test_lines):
        if idx == 0:
            writer.writerow(line)
        else:
            line.append(predictions[idx-1])
            writer.writerow(line)
                
            
                

3it [00:00,  7.26it/s]

definite
definite


5it [00:00,  5.84it/s]

definite
definite


7it [00:01,  5.59it/s]

definite
definite


9it [00:01,  5.35it/s]

definite
definite


11it [00:01,  5.41it/s]

definite
definite


13it [00:02,  5.29it/s]

definite
definite


15it [00:02,  5.42it/s]

definite
definite


17it [00:03,  5.46it/s]

definite
definite


19it [00:03,  5.47it/s]

definite
definite


20it [00:03,  5.29it/s]

definite


22it [00:03,  5.22it/s]

a sandwich
definite


24it [00:04,  5.26it/s]

definite
definite


26it [00:04,  5.25it/s]

definite
definite


28it [00:05,  5.29it/s]

definite
definite


30it [00:05,  5.40it/s]

definite
definite


31it [00:05,  5.18it/s]

definite


33it [00:06,  5.06it/s]

definite
definite


35it [00:06,  5.27it/s]

definite
definite


37it [00:06,  5.15it/s]

definite
definite


39it [00:07,  5.28it/s]

definite
definite


41it [00:07,  5.20it/s]

definite
definite


43it [00:07,  5.34it/s]

definite
definite


45it [00:08,  5.31it/s]

definite
definite


47it [00:08,  5.41it/s]

definite
definite


49it [00:09,  5.44it/s]

definite
definite


51it [00:09,  5.46it/s]

definite
definite


53it [00:09,  5.49it/s]

definite
definite


55it [00:10,  5.43it/s]

definite
definite


57it [00:10,  5.43it/s]

definite
definite


59it [00:10,  5.33it/s]

definite
definite


60it [00:11,  5.30it/s]

definite


62it [00:11,  5.21it/s]

definite
definite


64it [00:11,  5.32it/s]

definite
definite


66it [00:12,  5.41it/s]

definite
definite


68it [00:12,  5.35it/s]

definite
definite


70it [00:13,  5.41it/s]

definite
definite


72it [00:13,  5.36it/s]

definite
definite


74it [00:13,  5.44it/s]

definite
definite


75it [00:13,  5.45it/s]

definite


77it [00:14,  4.54it/s]

a sundae
definite


79it [00:14,  4.99it/s]

definite
definite


81it [00:15,  5.08it/s]

definite
definite


83it [00:15,  5.10it/s]

definite
definite


85it [00:15,  5.31it/s]

definite
definite


87it [00:16,  5.41it/s]

definite
definite


89it [00:16,  5.45it/s]

definite
definite


91it [00:17,  5.47it/s]

definite
definite


93it [00:17,  5.46it/s]

definite
definite


95it [00:17,  5.37it/s]

definite
definite


97it [00:18,  5.42it/s]

definite
definite


99it [00:18,  5.45it/s]

definite
definite


101it [00:18,  5.34it/s]

definite
definite



