In [1]:
from PIL import Image
import requests
import torch
from transformers import Blip2Processor, DefaultDataCollator, TrainingArguments, Trainer, Blip2ForConditionalGeneration

# model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base", cache_dir="/home/congnguyen/drive/.cache/")
# processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base", cache_dir="/home/congnguyen/drive/.cache/")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b", cache_dir="/home/congnguyen/drive/.cache/")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", cache_dir="/home/congnguyen/drive/.cache/", torch_dtype=torch.float64)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
model.to(device)

Blip2ForConditionalGeneration(
  (vision_model): Blip2VisionModel(
    (embeddings): Blip2VisionEmbeddings(
      (patch_embedding): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
    )
    (encoder): Blip2Encoder(
      (layers): ModuleList(
        (0): Blip2EncoderLayer(
          (self_attn): Blip2Attention(
            (dropout): Dropout(p=0.0, inplace=False)
            (qkv): Linear(in_features=1408, out_features=4224, bias=True)
            (projection): Linear(in_features=1408, out_features=1408, bias=True)
          )
          (layer_norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
          (mlp): Blip2MLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=1408, out_features=6144, bias=True)
            (fc2): Linear(in_features=6144, out_features=1408, bias=True)
          )
          (layer_norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        )
        (1): Blip2EncoderLayer(
          (self_attn): 

In [3]:
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.data import Dataset
import torch 

class VQADataset(torch.utils.data.Dataset):
    def __init__(self, images, questions, answers, processor):
        self.images = images
        self.questions = questions
        self.answers = answers
        self.processor = processor
        self.max_length = 8
        # self.image_height = 128
        # self.image_width = 128

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        # get image + text
        answers = self.answers[idx]
        questions = self.questions[idx]
        image = Image.open(self.images[idx]).convert("RGB")
        text = self.questions[idx]

        image_encoding = self.processor(image,
                                  do_resize=True,
                                  # size=(self.image_height,self.image_width),
                                  return_tensors="pt")

        encoding = self.processor(
                                  None,
                                  text,
                                  padding="max_length",
                                  truncation=True,
                                  max_length = self.max_length,
                                  return_tensors="pt"
                                  )
        # # remove batch dimension
        for k,v in encoding.items():
            encoding[k] = v.squeeze()
        encoding["pixel_values"] = image_encoding["pixel_values"][0]
        # # add labels
        labels = self.processor.tokenizer.encode(
            answers,
            max_length= self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors='pt'
        )[0]
        encoding["labels"] = labels

        return encoding

def collate_fn(batch):
    input_ids = [item['input_ids'] for item in batch]
    pixel_values = [item['pixel_values'] for item in batch]
    attention_mask = [item['attention_mask'] for item in batch]
    labels = [item['labels'] for item in batch]
    # create new batch
    batch = {}
    batch['input_ids'] = torch.stack(input_ids)
    batch['attention_mask'] = torch.stack(attention_mask)
    batch['pixel_values'] = torch.stack(pixel_values)
    batch['labels'] = torch.stack(labels)

    return batch

In [8]:
from datasets import load_dataset
dataset = load_dataset("json", data_files="data/train.jsonl", split="train")

questions = [item for item in dataset["question"]]
images = [
        f"data/train_fill_in_blank/train_fill_in_blank/{pid}/image.png" for pid in dataset["pid"]
    ] 
answers = [item for item in dataset["answer"]]

dataset = VQADataset(questions = questions,
                          answers = answers,
                          images = images,
                          processor=processor)
# train_set, val_set = torch.utils.data.random_split(dataset, [13549, 1000])

# test_dataset = VQADataset(questions = questions,
#                           answers = answers,
#                           image_paths = images,
#                           processor=processor)

batch_size = 1
# train_dataloader = DataLoader(train_set, collate_fn=collate_fn, batch_size=batch_size, shuffle=False)
train_dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False)
# test_dataloader = DataLoader(val_set, collate_fn=collate_fn, batch_size=batch_size, shuffle=False, num_workers=0)

In [9]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
    print(k, v.shape)

input_ids torch.Size([1, 8])
attention_mask torch.Size([1, 8])
pixel_values torch.Size([1, 3, 224, 224])
labels torch.Size([1, 8])


In [10]:
from tqdm.notebook import tqdm
import torch
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.train()

Blip2ForConditionalGeneration(
  (vision_model): Blip2VisionModel(
    (embeddings): Blip2VisionEmbeddings(
      (patch_embedding): Conv2d(3, 1408, kernel_size=(14, 14), stride=(14, 14))
    )
    (encoder): Blip2Encoder(
      (layers): ModuleList(
        (0): Blip2EncoderLayer(
          (self_attn): Blip2Attention(
            (dropout): Dropout(p=0.0, inplace=False)
            (qkv): Linear(in_features=1408, out_features=4224, bias=True)
            (projection): Linear(in_features=1408, out_features=1408, bias=True)
          )
          (layer_norm1): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
          (mlp): Blip2MLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=1408, out_features=6144, bias=True)
            (fc2): Linear(in_features=6144, out_features=1408, bias=True)
          )
          (layer_norm2): LayerNorm((1408,), eps=1e-06, elementwise_affine=True)
        )
        (1): Blip2EncoderLayer(
          (self_attn): 

In [11]:
for epoch in range(100):
    print(f"Epoch: {epoch}")
    total_loss = []
    for batch in tqdm(train_dataloader):
        # get the inputs;
        batch = {k:v.to(device) for k,v in batch.items()}

        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        with torch.autocast(device_type='cuda', dtype=torch.float64):
            outputs = model(**batch)
        loss = outputs.loss
        total_loss.append(loss.item())
        loss.backward()
        optimizer.step()
    print("Loss:", sum(total_loss))
    if epoch % 5 == 0 or epoch > 90:
        torch.save(model, "./model/BLIP2/checkpoint_"+str(epoch))

Epoch: 0


  0%|          | 0/14549 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 47.46 GiB total capacity; 46.10 GiB already allocated; 3.56 MiB free; 46.41 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF