In [1]:
# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image

from peft import LoraConfig, get_peft_model
from tqdm import tqdm
import csv


Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin /opt/conda/envs/torch2.0/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda117.so
CUDA SETUP: CUDA runtime path found: /opt/conda/envs/torch2.0/lib/libcudart.so.11.0
CUDA SETUP: Highest compute capability among GPUs detected: 8.6
CUDA SETUP: Detected CUDA version 117
CUDA SETUP: Loading binary /opt/conda/envs/torch2.0/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...


In [2]:
import os 
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [3]:
from torch import nn

class Model(nn.Module):
    
    def __init__(self, multi_modal_model):
        
        super(Model, self).__init__()
        self.multi_modal_model = multi_modal_model
        
        self.text_model_name = self.multi_modal_model.config._name_or_path
        
        self.hidden_size = self.multi_modal_model.config.text_config.hidden_size
        #self.classifier = nn.Linear(self.hidden_size, 2, dtype=torch.float16, bias=False)
        
        self.classifier = nn.Sequential(
            nn.Linear(self.hidden_size, 512  , dtype=torch.float16),
            nn.LayerNorm(512, dtype=torch.float16),
            nn.GELU(),
            nn.Linear(512, 2, dtype=torch.float16),
        )
    
    def forward(self, inputs):
        
       
        if "t5" in self.text_model_name:
            decoder_input_ids = inputs.pop("decoder_input_ids")
            outputs = self.multi_modal_model(pixel_values=pixel_values, input_ids=input_ids, decoder_input_ids=decoder_input_ids, output_hidden_states=True, return_dict=True)
        
            last_hidden_state = outputs.language_model_outputs.decoder_hidden_states[0]
            last_token = last_hidden_state[:,-1]
            
            logits = self.classifier(last_token)
            
            outputs = logits
        
        else:
            outputs = self.multi_modal_model(pixel_values=pixel_values, input_ids=input_ids, output_hidden_states=True, return_dict=True)
            last_hidden_state = outputs.language_model_outputs.hidden_states[0] 
            
            logits = self.classifier(last_hidden_state)
            
            batch_size, sequence_length = input_ids.shape[:2]
            sequence_lengths = (torch.ne(input_ids, self.multi_modal_model.config.text_config.pad_token_id).sum(-1) - 1).to(logits.device)
            
            pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
            
            outputs = pooled_logits
        
        
        return outputs


In [4]:
dataset = load_dataset("csv", data_files={"train" : "./ambiguous_questions_train.csv", "test" : "./ambiguous_questions_test.csv"})

Using custom data configuration default-b4c98d92954de12c
Found cached dataset csv (/root/.cache/huggingface/datasets/csv/default-b4c98d92954de12c/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)


  0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# Let's define the LoraConfig
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

model_name_or_path = 'Salesforce/blip2-flan-t5-xxl'
cache_dir = "./" + model_name_or_path.split('/')[-1]

# We load our model and processor using `transformers`
processor = AutoProcessor.from_pretrained(model_name_or_path,cache_dir=cache_dir)
model = AutoModelForVision2Seq.from_pretrained(model_name_or_path,cache_dir=cache_dir, device_map="sequential",torch_dtype=torch.float16)

# Get our peft model and print the number of trainable parameters
model = get_peft_model(model, config)
model.print_trainable_parameters()

#print(model.get_base_model)

model = Model(model)


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

trainable params: 22388736 || all params: 12251985408 || trainable%: 0.1827355751287555


In [6]:
class ImageTextClassificationDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        
        image = Image.open("./images/"+str(item['image_id'])+".jpg")
        encoding = self.processor(images=image, padding="max_length", return_tensors="pt")
        # remove batch dimension
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        encoding["text"] = item["question"]
        if item['is_ambiguous']:
            encoding['label'] = torch.tensor(1) if item['is_ambiguous'] == "O" else torch.tensor(0)
        
        if "t5" in self.processor.tokenizer.name_or_path:
            encoding['decoder_input_ids'] = torch.tensor([self.processor.tokenizer.pad_token_id])
        
        return encoding


def collator(batch):
    # pad the input_ids and attention_mask
    processed_batch = {}
    for key in batch[0].keys():
        if key == "text":
            text_inputs = processor.tokenizer(
                [example["text"] for example in batch], padding=True, return_tensors="pt"
            )
            processed_batch["input_ids"] = text_inputs["input_ids"]
            processed_batch["attention_mask"] = text_inputs["attention_mask"]
        else:
            processed_batch[key] = torch.stack([example[key] for example in batch])
     
    
    return processed_batch

In [7]:

train_dataset = ImageTextClassificationDataset(dataset['train'], processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=16, collate_fn=collator)

test_dataset = ImageTextClassificationDataset(dataset['test'], processor)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=1, collate_fn=collator)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"

model.to(device)
model.train()


Model(
  (multi_modal_model): PeftModel(
    (base_model): LoraModel(
      (model): Blip2ForConditionalGeneration(
        (vision_model): Blip2VisionModel(
          (embeddings): Blip2VisionEmbeddings(
            (patch_embedding): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
          )
          (encoder): Blip2Encoder(
            (layers): ModuleList(
              (0-38): 39 x Blip2EncoderLayer(
                (self_attn): Blip2Attention(
                  (dropout): Dropout(p=0.0, inplace=False)
                  (qkv): Linear(
                    in_features=1408, out_features=4224, bias=True
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.05, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=1408, out_features=16, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (default): Linear(in_features

In [8]:
import numpy as np
epoch_loss_list = []

criterion = nn.CrossEntropyLoss(reduction='mean')

with open("./ambiguous_questions_test.csv", 'r') as f:
    reader = csv.reader(f)
    lines = [line for line in reader]

for epoch in range(10):
    print("Epoch:", epoch)
    epoch_loss = []
    for idx, batch in enumerate(tqdm(train_dataloader)):
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device, dtype=torch.float16)
        labels = batch.pop("label").to(device)
        outputs = model(batch)
        
        #print(labels)
        #print(outputs)
        
        loss = criterion(outputs, labels)
        print(loss.item())
        #loss = torch.mean(outputs)
        
        epoch_loss.append(loss.item())

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        #if idx % 10 == 0:
        #    generated_output = model.generate(pixel_values=pixel_values, input_ids=input_ids)
        #    print(processor.batch_decode(generated_output, skip_special_tokens=True))
    
    print(np.mean(epoch_loss))
    
    
    model.eval()
    with torch.no_grad():
        epoch_outputs = []
        for idx, batch in enumerate(tqdm(test_dataloader)):
            # input_ids = batch.pop("input_ids").to(device)
            # pixel_values = batch.pop("pixel_values").to(device, dtype=torch.float16)
            #generated_output = model.generate(pixel_values=pixel_values, input_ids=input_ids)
            
            logits = model(batch)
            predictions = torch.argmax(logits, dim=1).item()
            predictions = [predictions] if type(predictions) != list else predictions
            
            epoch_outputs += predictions #processor.batch_decode(generated_output, skip_special_tokens=True)
    
    with open ("./test_{}.csv".format(epoch), 'w') as f:
        
        writer = csv.writer(f)
        for idx, line in enumerate(lines):
            if idx == 0:
                writer.writerow(line)
            else:
                line.append(epoch_outputs[idx-1])
                writer.writerow(line)
                
    model.train()            
                

Epoch: 0


  0%|          | 0/7 [00:00<?, ?it/s]

5t5t5t5t5t5t5t


 14%|█▍        | 1/7 [00:01<00:09,  1.52s/it]

0.84423828125
5t5t5t5t5t5t5t


 29%|██▊       | 2/7 [00:02<00:04,  1.06it/s]

nan
5t5t5t5t5t5t5t


 43%|████▎     | 3/7 [00:02<00:03,  1.28it/s]

nan
5t5t5t5t5t5t5t


 57%|█████▋    | 4/7 [00:03<00:02,  1.45it/s]

nan
5t5t5t5t5t5t5t


 71%|███████▏  | 5/7 [00:03<00:01,  1.57it/s]

nan
5t5t5t5t5t5t5t


 86%|████████▌ | 6/7 [00:04<00:00,  1.66it/s]

nan
5t5t5t5t5t5t5t


100%|██████████| 7/7 [00:04<00:00,  1.55it/s]


nan
nan


  1%|          | 1/100 [00:00<00:13,  7.20it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


  3%|▎         | 3/100 [00:00<00:12,  7.55it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


  5%|▌         | 5/100 [00:00<00:12,  7.44it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


  7%|▋         | 7/100 [00:00<00:12,  7.52it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


  9%|▉         | 9/100 [00:01<00:12,  7.54it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 11%|█         | 11/100 [00:01<00:12,  7.28it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 13%|█▎        | 13/100 [00:01<00:11,  7.64it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 15%|█▌        | 15/100 [00:01<00:10,  7.81it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 17%|█▋        | 17/100 [00:02<00:11,  7.51it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 19%|█▉        | 19/100 [00:02<00:11,  7.20it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 21%|██        | 21/100 [00:02<00:10,  7.26it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 23%|██▎       | 23/100 [00:03<00:10,  7.33it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 25%|██▌       | 25/100 [00:03<00:09,  7.56it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 27%|██▋       | 27/100 [00:03<00:09,  7.69it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 29%|██▉       | 29/100 [00:03<00:09,  7.73it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 31%|███       | 31/100 [00:04<00:09,  7.08it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 33%|███▎      | 33/100 [00:04<00:09,  7.29it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 35%|███▌      | 35/100 [00:04<00:09,  7.17it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 37%|███▋      | 37/100 [00:04<00:08,  7.42it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 39%|███▉      | 39/100 [00:05<00:08,  7.42it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 41%|████      | 41/100 [00:05<00:08,  7.33it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 43%|████▎     | 43/100 [00:05<00:07,  7.37it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 45%|████▌     | 45/100 [00:06<00:07,  7.59it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 47%|████▋     | 47/100 [00:06<00:06,  7.68it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 49%|████▉     | 49/100 [00:06<00:06,  7.60it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 51%|█████     | 51/100 [00:06<00:06,  7.74it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 53%|█████▎    | 53/100 [00:07<00:06,  7.49it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 55%|█████▌    | 55/100 [00:07<00:06,  7.48it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 57%|█████▋    | 57/100 [00:07<00:05,  7.65it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 59%|█████▉    | 59/100 [00:07<00:05,  7.72it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 61%|██████    | 61/100 [00:08<00:05,  7.57it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 63%|██████▎   | 63/100 [00:08<00:04,  7.69it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 65%|██████▌   | 65/100 [00:08<00:04,  7.83it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 67%|██████▋   | 67/100 [00:08<00:04,  7.87it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 69%|██████▉   | 69/100 [00:09<00:03,  7.87it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 71%|███████   | 71/100 [00:09<00:03,  7.59it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 73%|███████▎  | 73/100 [00:09<00:03,  7.51it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 75%|███████▌  | 75/100 [00:09<00:03,  7.69it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 77%|███████▋  | 77/100 [00:10<00:03,  7.63it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 79%|███████▉  | 79/100 [00:10<00:02,  7.40it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 81%|████████  | 81/100 [00:10<00:02,  6.83it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 83%|████████▎ | 83/100 [00:11<00:02,  7.29it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 85%|████████▌ | 85/100 [00:11<00:01,  7.54it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 87%|████████▋ | 87/100 [00:11<00:01,  7.56it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 89%|████████▉ | 89/100 [00:11<00:01,  7.61it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 91%|█████████ | 91/100 [00:12<00:01,  7.62it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 93%|█████████▎| 93/100 [00:12<00:00,  7.59it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 95%|█████████▌| 95/100 [00:12<00:00,  7.68it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 97%|█████████▋| 97/100 [00:12<00:00,  7.76it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 99%|█████████▉| 99/100 [00:13<00:00,  7.81it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


100%|██████████| 100/100 [00:13<00:00,  7.52it/s]


Epoch: 1


  0%|          | 0/7 [00:00<?, ?it/s]

5t5t5t5t5t5t5t


 14%|█▍        | 1/7 [00:00<00:03,  1.76it/s]

nan
5t5t5t5t5t5t5t


 29%|██▊       | 2/7 [00:01<00:02,  1.78it/s]

nan
5t5t5t5t5t5t5t


 43%|████▎     | 3/7 [00:01<00:02,  1.80it/s]

nan
5t5t5t5t5t5t5t


 57%|█████▋    | 4/7 [00:02<00:01,  1.82it/s]

nan
5t5t5t5t5t5t5t


 71%|███████▏  | 5/7 [00:02<00:01,  1.83it/s]

nan
5t5t5t5t5t5t5t


 86%|████████▌ | 6/7 [00:03<00:00,  1.79it/s]

nan
5t5t5t5t5t5t5t


100%|██████████| 7/7 [00:03<00:00,  1.95it/s]


nan
nan


  1%|          | 1/100 [00:00<00:12,  7.82it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


  3%|▎         | 3/100 [00:00<00:12,  7.82it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


  5%|▌         | 5/100 [00:00<00:13,  6.96it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


  7%|▋         | 7/100 [00:00<00:13,  6.98it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


  9%|▉         | 9/100 [00:01<00:12,  7.47it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 11%|█         | 11/100 [00:01<00:12,  7.39it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 13%|█▎        | 13/100 [00:01<00:11,  7.34it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 15%|█▌        | 15/100 [00:02<00:11,  7.56it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 17%|█▋        | 17/100 [00:02<00:10,  7.72it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 19%|█▉        | 19/100 [00:02<00:10,  7.70it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 21%|██        | 21/100 [00:02<00:10,  7.82it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 23%|██▎       | 23/100 [00:03<00:09,  7.90it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 25%|██▌       | 25/100 [00:03<00:09,  7.66it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 27%|██▋       | 27/100 [00:03<00:09,  7.60it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 29%|██▉       | 29/100 [00:03<00:09,  7.70it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 31%|███       | 31/100 [00:04<00:09,  7.66it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 33%|███▎      | 33/100 [00:04<00:08,  7.62it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 35%|███▌      | 35/100 [00:04<00:08,  7.43it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 37%|███▋      | 37/100 [00:04<00:08,  7.68it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 39%|███▉      | 39/100 [00:05<00:08,  7.24it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 41%|████      | 41/100 [00:05<00:07,  7.50it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 43%|████▎     | 43/100 [00:05<00:07,  7.61it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 45%|████▌     | 45/100 [00:05<00:07,  7.79it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 47%|████▋     | 47/100 [00:06<00:06,  7.57it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 49%|████▉     | 49/100 [00:06<00:06,  7.61it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 51%|█████     | 51/100 [00:06<00:06,  7.68it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 53%|█████▎    | 53/100 [00:07<00:06,  7.64it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 55%|█████▌    | 55/100 [00:07<00:05,  7.81it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 57%|█████▋    | 57/100 [00:07<00:05,  7.91it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 59%|█████▉    | 59/100 [00:07<00:05,  7.93it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 61%|██████    | 61/100 [00:08<00:04,  7.95it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 63%|██████▎   | 63/100 [00:08<00:04,  7.66it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 65%|██████▌   | 65/100 [00:08<00:04,  7.40it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 67%|██████▋   | 67/100 [00:08<00:04,  7.63it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 69%|██████▉   | 69/100 [00:09<00:03,  7.81it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t


 71%|███████   | 71/100 [00:09<00:03,  7.53it/s]

5t5t5t5t5t5t5t
5t5t5t5t5t5t5t





KeyboardInterrupt: 