### Colab Setting
- `device` must be set to the same value used in torchscript.ipynb and torchscript_inference.ipynb (torchscript.ipynb, torchscript_inference.ipynb 의 device 와 동일하게 설정 필수)

In [None]:
# Mount Google Drive so the project folder under MyDrive becomes reachable.
from google.colab import drive

drive.mount("/content/drive")

# Runtime configuration shared by the cells below: the device to run on and
# the local directory that holds the Visual7W data.
device = "cuda"
visual7w_path = "/content/visual7w"

In [None]:
# Download the Visual7W images and the "telling" QA annotations into
# /content/visual7w, then normalise the file/directory names.
!mkdir -p /content/visual7w
!wget -P /content/visual7w http://vision.stanford.edu/yukezhu/visual7w_images.zip
!wget -P /content/visual7w https://ai.stanford.edu/~yukez/papers/resources/dataset_v7w_telling.zip
# -n: never overwrite existing files, -q: quiet; the archive is removed once extracted.
!unzip -nq /content/visual7w/visual7w_images   -d /content/visual7w
!rm /content/visual7w/visual7w_images.zip
!unzip -nq /content/visual7w/dataset_v7w_telling -d /content/visual7w
# Rename the extracted annotation JSON and image folder to fixed names.
!mv /content/visual7w/*telling*.json /content/visual7w/dataset_v7w_telling.json
!mv /content/visual7w/visual7w_images /content/visual7w/images
!rm /content/visual7w/dataset_v7w_telling.zip

# Switch to the project folder on the mounted Drive and install its requirements.
%cd /content/drive/MyDrive/bridgeblip
!pip install -r requirements.txt

In [None]:
from transformers import InstructBlipForConditionalGeneration, InstructBlipProcessor

import torch
import torch.nn as nn

from peft import LoraConfig, get_peft_model

import random
import pandas as pd
import json
from PIL import Image

from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from scipy.ndimage import uniform_filter1d

### Seed Setting

In [None]:
# Reproducibility: request deterministic cuDNN behaviour, switch off the cuDNN
# benchmark mode, and seed the PyTorch (CUDA + default) and Python RNGs with 0.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

torch.cuda.manual_seed_all(0)
torch.manual_seed(0)
random.seed(0)

### Create Visual7W Dataset

In [None]:

# Flatten the Visual7W "telling" annotations into one row per question:
# image id, question text, shuffled answer options and a one-hot label.
with open(visual7w_path + '/dataset_v7w_telling.json', 'r') as fp:
    annotated_images = json.load(fp)['images']

records = []
for entry in annotated_images:
    for qa in entry['qa_pairs']:
        image_id = qa['image_id']
        question_text = qa['question']

        # Candidate options = the listed multiple choices plus the correct
        # answer, shuffled so the correct answer's position varies.
        options = list(qa['multiple_choices'])
        options.append(qa['answer'])
        random.shuffle(options)

        # One-hot vector over four slots marking where the answer ended up.
        one_hot = [0] * 4
        one_hot[options.index(qa['answer'])] = 1

        records.append(dict(
            image_id=image_id,
            question=question_text,
            choices=options,
            label=one_hot,
        ))

dataframe = pd.DataFrame(records)

### Create **LoRA-Bridge** InstructBlip

In [None]:
def count_parameters(model):
    """Print the total and trainable parameter counts of ``model``.

    Args:
        model: Any object whose ``parameters()`` yields items exposing
            ``numel()`` and ``requires_grad`` (e.g. a ``torch.nn.Module``).

    Returns:
        None. Three lines are written to stdout: total count, trainable
        count, and the trainable share in percent.

    A model without any parameters reports ``0.00%`` instead of raising
    ``ZeroDivisionError``.
    """
    total_params = 0
    trainable_params = 0

    # Single pass over the parameters collects both counts.
    for p in model.parameters():
        n = p.numel()
        total_params += n
        if p.requires_grad:
            trainable_params += n

    # Guard the ratio so an empty model does not divide by zero.
    percent = 100. * trainable_params / total_params if total_params else 0.0

    print(f'total params      : {total_params:,}')
    print(f'trainable params  : {trainable_params:,}')
    print(f'trainable percent : {percent:.2f}%')

In [None]:
class ClassificationHead(nn.Module):
    """Small MLP head: Linear -> GELU -> Dropout -> Linear.

    Args:
        in_features: Size of the input (and of the hidden layer).
        out_features: Number of output logits.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, in_features, out_features, dropout):
        super().__init__()
        hidden = nn.Linear(in_features, in_features)
        activation = nn.GELU()
        regularizer = nn.Dropout(dropout)
        projection = nn.Linear(in_features, out_features)
        # The attribute name `layer` is part of the state_dict keys.
        self.layer = nn.Sequential(hidden, activation, regularizer, projection)

    def forward(self, x):
        """Apply the head to ``x`` and return the resulting logits."""
        logits = self.layer(x)
        return logits

In [None]:
# Load the pretrained InstructBLIP (Flan-T5-XL) processor and model.
processor = InstructBlipProcessor.from_pretrained("Salesforce/instructblip-flan-t5-xl")
base_blip = InstructBlipForConditionalGeneration.from_pretrained("Salesforce/instructblip-flan-t5-xl")

# Freeze every pretrained parameter; only the pieces added below are trained.
base_blip.requires_grad_(False)

# Replace the language-model head with a trainable 4-way classification head.
base_blip.language_model.lm_head = ClassificationHead(
    base_blip.language_model.config.d_model,
    4,
    0.0,
)
base_blip.language_model.lm_head.requires_grad_(True)

# Keep every third decoder block (indices 0, 3, ..., 21): 24 blocks -> 8.
base_blip.language_model.decoder.block = nn.ModuleList(
    [base_blip.language_model.decoder.block[i] for i in range(0, 24, 3)]
)

# "LoRA bridge": wrap the thinned decoder with LoRA adapters on the q/k/v
# projections so the remaining blocks can be re-connected through training.
lora_config = LoraConfig(
    r=96,
    lora_alpha=192,
    target_modules=['q', 'k', 'v'],
    lora_dropout=0.1,
)
base_blip.language_model.decoder = get_peft_model(
    base_blip.language_model.decoder,
    lora_config,
)

In [None]:
# Print total vs. trainable parameter counts of the assembled model.
count_parameters(base_blip)

### Create Visual7W Dataset class

In [None]:
class Visual7W(Dataset):
    """Dataset view over the module-level ``dataframe`` of Visual7W questions.

    Each item is a dict with the processor output for the image and prompt
    under ``'inputs'`` and the row's label as a float tensor under ``'label'``.
    """

    def __len__(self):
        return len(dataframe)

    def __getitem__(self, idx):
        sample = dataframe.iloc[idx]

        image_file = f'{visual7w_path}/images/v7w_{sample["image_id"]}.jpg'
        image = Image.open(image_file).convert("RGB")

        # Prompt layout follows the instructions from the InstructBLIP paper;
        # options are lettered (a), (b), (c), ...
        lettered = [
            f"({chr(97 + pos)}) {choice}"
            for pos, choice in enumerate(sample["choices"])
        ]
        options = " ".join(lettered)
        instructions = f'Question: {sample["question"]} Options: {options} Short answer:'

        inputs = processor(
            images=image,
            text=instructions,
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=128,
        )
        label = torch.tensor(sample['label'], dtype=torch.float)

        return {'inputs': inputs, 'label': label}

## **Train**

In [None]:
def train_epoch(model, loader, opt, criterion, device, start_token_id):
    """Run one training pass over ``loader`` and return the per-step losses.

    Args:
        model: Module called as ``model(**inputs, decoder_input_ids=...)``;
            its result must expose ``.logits``.
        loader: Iterable of batches, each a dict with ``'inputs'`` (supports
            ``.to(device)`` and ``.items()``) and ``'label'`` (a tensor).
        opt: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(logits, target)``.
        device: Device the batch tensors are moved to.
        start_token_id: Token id used as the single decoder input token.

    Returns:
        list[float]: The loss value of every step, in order.
    """
    model.train()

    progress = tqdm(loader)
    losses = []

    for batch in progress:
        # Drop dimension 1 (when it has size 1) from every input tensor.
        moved = batch['inputs'].to(device)
        inputs = {name: tensor.squeeze(1) for name, tensor in moved.items()}
        labels = batch['label'].to(device)

        # One start token per sample as the only decoder input.
        n_samples = labels.shape[0]
        decoder_input_ids = torch.full(
            (n_samples, 1), start_token_id, dtype=torch.long
        ).to(device)

        opt.zero_grad()

        outputs = model(**inputs, decoder_input_ids=decoder_input_ids)
        logits = outputs.logits.squeeze(1)
        # Labels arrive as per-class scores; the loss takes class indices.
        target = labels.argmax(dim=1)

        loss = criterion(logits, target)
        loss.backward()
        opt.step()

        step_loss = loss.item()
        progress.set_postfix(loss=f'{step_loss:.4f}')
        losses.append(step_loss)

    return losses

In [None]:
# Training configuration: a single epoch with cross-entropy over the 4 options.
num_epochs = 1
criterion = nn.CrossEntropyLoss()

# Move the model to the target device before building the optimizer.
base_blip = base_blip.to(device)

dataset = Visual7W()
dataloader = DataLoader(dataset, batch_size=12, shuffle=True)

# Separate learning rates: 5e-5 for the decoder's parameters and 1e-4 for
# the classification head.
optimizer = torch.optim.AdamW(
    [
        dict(params=base_blip.language_model.decoder.parameters(), lr=5e-5),
        dict(params=base_blip.language_model.lm_head.parameters(), lr=1e-4),
    ]
)

In [None]:
import os

# Train for `num_epochs`; after every epoch save a checkpoint and plot the
# smoothed training loss.

# Create the output directory up front so torch.save cannot fail on a missing
# folder after a whole epoch of training has already been spent.
os.makedirs("checkpoint", exist_ok=True)

for ep in range(num_epochs):
    history = train_epoch(base_blip, dataloader, optimizer, criterion, device, base_blip.language_model.config.decoder_start_token_id)

    # Model and optimizer state plus the epoch index, for resuming later.
    state = {
        "model": base_blip.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": ep,
    }

    torch.save(state, f"checkpoint/weights-{ep}.pt")

    # Smooth the per-step loss with a window of 100 steps; the curve is
    # labelled so plt.legend() has an entry to display.
    plt.plot(uniform_filter1d(history, size=100), label=f"train loss (epoch {ep})")
    plt.ylim(0.3, 1.5)
    plt.legend()
    plt.show()
    