# BLIP finetuning to target task

TODO: rewrite this to reflect the latest changes

Sample is formed from a single row of dataset:
$$\operatorname{batch} = ((E_t, E_{i_0}), (E_t, E_{i_1}), ..., (E_t, E_{i_9})); \operatorname{batch} : R^{10 \times (E_t + E_i)}$$
ITM predicts probas for $y = 0$, $y = 1$
$$\operatorname{ITM} : R^{10 \times (E_t + E_i)} \rightarrow R^{10 \times 2}$$
Model is defined as:
$$\operatorname{F} = \operatorname{softmax} \circ \operatorname{ITM} \circ \operatorname{batch}$$
$$\operatorname{F} : R^{10 \times (E_t + E_i)} \rightarrow R^{10}$$
So, this definition is for a single row

In [4]:
from pathlib import Path
import logging
import json
from typing import *
import time

import pandas as pd
import numpy as np
import torch
from PIL import Image, ImageFile
import torch.nn as nn
from lavis.models import load_model_and_preprocess, BlipBase
from lavis.processors import load_processor
import torch.nn.functional as F
from transformers import get_linear_schedule_with_warmup
from transformers import BatchEncoding
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from sklearn.metrics import top_k_accuracy_score

from src.data import CustomSplitLoader
from src.utils import evaluate, mrr

## Config

Versioning

In [5]:
HEAD = "itm" # "itm" | "itc" | "mean"
MODEL_VERSION = 7

Paths resolution:

In [6]:
DATASET_VERSION = "v1"
PART = "train"
PATH = Path("data").resolve() / f"{PART}_{DATASET_VERSION}"
DATA_PATH = PATH / f"{PART}.data.{DATASET_VERSION}.txt"
LABELS_PATH = PATH / f"{PART}.gold.{DATASET_VERSION}.txt"
IMAGES_PATH = PATH / f"{PART}_images_{DATASET_VERSION}"
TRAIN_SPLIT_PATH = PATH / "split_train.txt"
VALIDATION_SPLIT_PATH = PATH / "split_valid.txt"
TEST_SPLIT_PATH = PATH / "split_test.txt"
SAVE_CHECKPOINT_PATH = Path("checkpoints").resolve() / f"BLIP-{HEAD}-{MODEL_VERSION}"
SAVE_CHECKPOINT_PATH.mkdir(parents=True, exist_ok=True)

Environment settings:

In [7]:
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# some images from train might not load without the following settings or warnings would be thrown
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True

writer = SummaryWriter()

In [8]:
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE)
# WARNING: this is specific to my setup
DEVICE = torch.device("cuda:1") # 3060
# DEVICE = torch.device("cuda:1")
# a more conventional way to do this is:
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_WORKERS = 32
print(f"Running on {DEVICE}")

Running on cuda:1


Model & training settings

In [9]:
BLIP_VARIANT = "base" # "base" | "large"
NUM_EPOCHS = 15
NUM_PICS = 10
WARMUP_STEPS_FRAC = 0.1
STEPS_BETWEEN_EVAL = 100
GRAD_ACCUM_STEPS = 15
SAVE_CHECKPOINT_STEPS = STEPS_BETWEEN_EVAL
LR = 1e-5
WEIGHT_DECAY = 0.001
TRAIN_BATCH_SIZE = 1
VALIDATION_BATCH_SIZE = 3
HEAD_SUM_BIAS_ENABLED = True
TRAIN_EFFECTIVE_BATCH_SIZE = GRAD_ACCUM_STEPS * TRAIN_BATCH_SIZE
TRAIN_EFFECTIVE_BATCH_SIZE

15

## Loading data

In [10]:
df = pd.read_csv(DATA_PATH, sep='\t', header=None)
df.columns = ["word", "context"] + [f"image{i}" for i in range(NUM_PICS)]
df["label"] = pd.read_csv(LABELS_PATH, sep='\t', header=None)

train_df = df.loc[pd.read_csv(TRAIN_SPLIT_PATH, sep='\t', header=None).T.values[0]]
validation_df = df.loc[pd.read_csv(VALIDATION_SPLIT_PATH, sep='\t', header=None).T.values[0]]
test_df = df.loc[pd.read_csv(TEST_SPLIT_PATH, sep='\t', header=None).T.values[0]]

## Preprocessing

In [11]:
def infinite_repeat(value):
    while True:
        yield value

def concat_iters(*iterables):
    for it in iterables:
        for value in it:
            yield value

In [12]:
class ItmDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        df: pd.DataFrame,
        images_path: Path,
        text_processor,
        vis_processor,
        use_context_as_text: bool = True,
        enable_cache: bool = False,
    ) -> None:
        self.df = df
        self.images_path = images_path
        self.text_processor = text_processor
        self.vis_processor = vis_processor
        self.tokens_cache = dict()
        self.image_tensor_cache = dict()
        self.enable_cache = enable_cache
        self.text_field = "context" if use_context_as_text else "word"
        self.labels_map = self._gen_labels()

    def _gen_labels(self) -> Dict[int, int]: # index to label
        labels = self.df["label"].values
        zips = []
        for i in range(NUM_PICS):
            images = self.df[f"image{i}"].values
            zips.append(zip(np.argwhere(labels == images).reshape(-1), infinite_repeat(i)))
        return dict(concat_iters(*tuple(zips)))
    
    def __len__(self) -> int:
        return len(self.df)
    
    def _make_image_tensor(self, name: str) -> torch.Tensor:
        return self.vis_processor(Image.open(self.images_path / name).convert("RGB"))

    def _get_image_tensor(self, name: str) -> Image:
        if not self.enable_cache:
            return self._make_image_tensor(name)
        if name in self.image_tensor_cache:
            return self.image_tensor_cache[name]
        t = self._make_image_tensor(name)
        self.image_tensor_cache[name] = t
        return t

    def _get_image_batch(self, idx: int) -> torch.Tensor:
        row = self.df.iloc[idx]
        return torch.stack([self._get_image_tensor(row[f"image{i}"]) for i in range(NUM_PICS)])

    def _make_tokens(self, idx: int) -> BatchEncoding:
        return self.text_processor(self.df.iloc[idx][self.text_field])
    
    def _get_tokens(self, idx: int) -> BatchEncoding:
        if not self.enable_cache:
            return self._make_tokens(idx)
        if idx in self.tokens_cache:
            return self.tokens_cache[idx]
        t = self._make_tokens(idx)
        self.tokens_cache[idx] = t
        return t
    
    def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, BatchEncoding, int]]:
        # makes a batch for each row!
        return {
            "text": self._get_tokens(idx),
            "images": self._get_image_batch(idx),
            "label": self.labels_map[idx],
        }

In [13]:
blip_model, vis_processors, text_processors = load_model_and_preprocess("blip_image_text_matching", BLIP_VARIANT, is_eval=True)

INFO:root:Missing keys []
INFO:root:load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth


In [14]:
train_ds = ItmDataset(
    df=train_df,
    images_path=IMAGES_PATH,
    text_processor=text_processors["eval"],
    vis_processor=vis_processors["eval"],
)
val_ds = ItmDataset(
    df=validation_df,
    images_path=IMAGES_PATH,
    text_processor=text_processors["eval"],
    vis_processor=vis_processors["eval"],
)
test_ds = ItmDataset(
    df=test_df,
    images_path=IMAGES_PATH,
    text_processor=text_processors["eval"],
    vis_processor=vis_processors["eval"],
)

In [15]:
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=TRAIN_BATCH_SIZE, num_workers=NUM_WORKERS, persistent_workers=True)
train_l = len(train_dl)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=VALIDATION_BATCH_SIZE, num_workers=NUM_WORKERS, persistent_workers=True)
val_l = len(val_dl)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=1, num_workers=NUM_WORKERS, persistent_workers=True, shuffle=False)
test_l = len(test_dl)

In [16]:
def to_device(object, device):
    if not isinstance(object, dict):
        raise NotImplementedError("Implement other types than dict if needed!")
    return {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in object.items()}

## Model setup

In [17]:
class Classifier(nn.Module):
    def __init__(
        self,
        blip_model: BlipBase,
        match_head: str = "itm",
        head_sum_bias_enabled: bool = True
    ) -> None:
        super().__init__()
        self.blip_model = blip_model
        self.match_head = match_head
        if self.match_head == "mean":
            self.head_combiner = nn.Linear(2, 1, bias=head_sum_bias_enabled)

    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        # TODO: move all this batch-dependent stuff to collate_fn?
        # TODO: optimize!
        # text: str
        # image: 
        images_shape = inputs["images"].shape # image: (B, NUM_PICS, C, H, W)
        batch_size = images_shape[0]
        text_input = []
        for t in inputs["text"]:
            for _ in range(NUM_PICS):
                text_input.append(t)
        # TODO: 10 -> X
        images_input = inputs["images"].reshape(batch_size * NUM_PICS, images_shape[2], images_shape[3], images_shape[4]) # image: (B * NUM_PICS, C, H, W)
        # (B * X, 2)
        if self.match_head == "itm":
            batch_outputs = self.blip_model({"text_input": text_input, "image": images_input}, match_head=self.match_head).reshape(batch_size, NUM_PICS, 2) # (B, NUM_PICS, 2)
            batch_probas = F.softmax(batch_outputs[:, :, 1], dim=1)
        elif self.match_head == "itc":
            batch_outputs = self.blip_model({"text_input": text_input, "image": images_input}, match_head=self.match_head).reshape(batch_size, NUM_PICS) # (B * NUM_PICS) -> (B, NUM_PICS)
            # hugginface VisionTextEncoder see cos * N before softmax
            # TODO: N as hyperparam const
            # TODO: or learnable param
            batch_probas = F.softmax(batch_outputs, dim=1) # softmax(cosine similarity) => =(
        elif self.match_head == "mean":
            raise NotImplementedError("Implement me!")
            # Warning: not tested
            # TODO: replace with mean(p_itm, p_itc)
            # itm_batch_outputs = self.blip_model({"text_input": text_input, "image": images_input}, match_head="itm").reshape(batch_size, 10, 2)
            # itc_batch_outputs = self.blip_model({"text_input": text_input, "image": images_input}, match_head="itc").reshape(batch_size, 10)
            # dual_probas = torch.stack([
            #     torch.stack([F.softmax(itm_batch_outputs[i, :, 1], dim=0) for i in range(batch_size)]).reshape(batch_size * 10),
            #     torch.stack([F.softmax(batch_outputs[i, :], dim=0) for i in range(batch_size)]).reshape(batch_size * 10),
            # ], dim=1)
            # batch_ouputs = F.softmax(self.head_combiner(dual_probas), dim=0).reshape(batch_size, 10)
        else:
            raise ValueError(f"Unexpected value for match_head parameter \"{self.match_head}\". Allowed values: \"itm\", \"itc\" or \"mean\".")
        return batch_probas 

In [18]:
model = Classifier(blip_model).to(DEVICE)

## Training

In [19]:
metric2name = {
    "acc1": "Accuracy@Top1",
    "acc3": "Accuracy@Top3",
    "mrr": "Mean Reciprocal Rank",
}

labels_range = np.arange(NUM_PICS)

def eval_batch(labels, preds):
    labels = labels.numpy(force=True)
    preds = preds.numpy(force=True)
    return {
        "acc1": top_k_accuracy_score(labels, preds, k=1, labels=labels_range), 
        "acc3": top_k_accuracy_score(labels, preds, k=3, labels=labels_range),
        "mrr": mrr(labels, preds),
    }

def sum_scores(scores, new_scores):
    return {k: scores[k] + new_scores[k] for k in scores}

def div_scores(scores, n):
    return {k: v / n for k, v in scores.items()}

In [20]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
num_training_steps = int(NUM_EPOCHS * (train_l / GRAD_ACCUM_STEPS))
num_warmup_steps = int(num_training_steps * WARMUP_STEPS_FRAC)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)
print(f"{num_training_steps} training steps which include {num_warmup_steps} warmup ones")

6103 training steps which include 610 warmup ones


In [21]:
# TODO: implement validation of untuned model here

In [22]:
step_num = 0
steps_since_last_eval = 0
grad_accum_step_cnt = 0
save_checkpoint_step_cnt = 0
progress_bar = tqdm(range(num_training_steps))

for epoch_num in range(NUM_EPOCHS):
    model.train()
    train_loss = 0.0
    train_scores = {"acc1": 0, "acc3": 0, "mrr": 0}
    for batch in train_dl:
        outputs = model(to_device(batch, DEVICE))
        loss = loss_fn(outputs, F.one_hot(batch["label"], 10).float().to(DEVICE))
        train_loss += loss.item()
        new_scores = eval_batch(batch["label"], outputs)
        train_scores = sum_scores(train_scores, new_scores)
        loss.backward()
        grad_accum_step_cnt += 1

        if grad_accum_step_cnt == GRAD_ACCUM_STEPS: 
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            writer.add_scalar("Loss/Train", train_loss / TRAIN_EFFECTIVE_BATCH_SIZE, step_num)            
            for k, v in div_scores(train_scores, GRAD_ACCUM_STEPS).items():
                writer.add_scalar(metric2name[k] + "/Train", v, step_num)
            train_loss = 0.0
            train_scores = {"acc1": 0, "acc3": 0, "mrr": 0}
            grad_accum_step_cnt = 0
            step_num += 1
            steps_since_last_eval += 1
            save_checkpoint_step_cnt += 1
            progress_bar.update(1)

        if steps_since_last_eval == STEPS_BETWEEN_EVAL: # add 0-th step
            model.eval()
            val_loss = 0.0
            val_scores = {"acc1": 0, "acc3": 0, "mrr": 0}
            with torch.no_grad():
                for batch in val_dl:
                    outputs = model(to_device(batch, DEVICE))
                    loss = loss_fn(outputs, F.one_hot(batch["label"], 10).float().to(DEVICE))
                    val_loss += loss.item()
                    new_scores = eval_batch(batch["label"], outputs)
                    val_scores = sum_scores(val_scores, new_scores)
            writer.add_scalar("Loss/Validation", val_loss / val_l, step_num)            
            for k, v in div_scores(val_scores, val_l).items():
                writer.add_scalar(metric2name[k] + "/Validation", v, step_num)
            model.train()
            steps_since_last_eval = 0
        
        if save_checkpoint_step_cnt == SAVE_CHECKPOINT_STEPS:
            save_checkpoint_step_cnt = 0
            p = SAVE_CHECKPOINT_PATH / f"step-{step_num}.pt"
            logging.info(f"[{epoch_num}:{step_num}] Saving checkpoint to \"{str(p)}\"")
            torch.save(model.state_dict(), p)

INFO:root:[0:100] Saving checkpoint to "/home/s1m00n/research/vwsd/checkpoints/BLIP-itm-7/step-100.pt"
  3%|▎         | 200/6103 [50:24<17:43:07, 10.81s/it]INFO:root:[0:200] Saving checkpoint to "/home/s1m00n/research/vwsd/checkpoints/BLIP-itm-7/step-200.pt"
  5%|▍         | 300/6103 [1:22:41<17:26:29, 10.82s/it]INFO:root:[0:300] Saving checkpoint to "/home/s1m00n/research/vwsd/checkpoints/BLIP-itm-7/step-300.pt"
  7%|▋         | 400/6103 [1:54:56<17:07:06, 10.81s/it]INFO:root:[0:400] Saving checkpoint to "/home/s1m00n/research/vwsd/checkpoints/BLIP-itm-7/step-400.pt"
  8%|▊         | 500/6103 [2:27:13<16:49:08, 10.81s/it]  INFO:root:[1:500] Saving checkpoint to "/home/s1m00n/research/vwsd/checkpoints/BLIP-itm-7/step-500.pt"
 10%|▉         | 600/6103 [2:59:32<16:32:19, 10.82s/it]  INFO:root:[1:600] Saving checkpoint to "/home/s1m00n/research/vwsd/checkpoints/BLIP-itm-7/step-600.pt"
 11%|█▏        | 700/6103 [3:31:59<16:21:03, 10.89s/it]  INFO:root:[1:700] Saving checkpoint to "/home/s1

## Evaluation

Here, let's load the best checkpoint according to Tensorboard

In [23]:
CHECKPOINT_NUM = 6100

In [None]:
checkpoint = Classifier(blip_model).to(DEVICE)
checkpoint.load_state_dict(torch.load(SAVE_CHECKPOINT_PATH / f"step-{CHECKPOINT_NUM}.pt"))
checkpoint.eval()

In [25]:
predictions = []
with torch.no_grad():
    for (i, batch) in enumerate(tqdm(test_dl)):
        preds = checkpoint(to_device(batch, DEVICE))[0].numpy(force=True)
        row = test_df.iloc[i]
        predictions.append({row[f"image{j}"]: preds[j] for j in range(10)})

100%|██████████| 3356/3356 [13:34<00:00,  4.12it/s]


In [26]:
evaluate(
    test_df.iloc[:, 2:-1].values,
    test_df["label"].values.reshape(-1, 1),
    predictions,
)

{'acc1': 0.8367103694874851,
 'acc3': 0.9737783075089392,
 'mrr': 0.9047753845280662}

In [31]:
# creates a file in <project root>/data with submissions in target format
with open(PATH / f"blip-{HEAD}-{MODEL_VERSION}-{CHECKPOINT_NUM}_submission.json", 'w') as f:
    json.dump([{k: str(v) for k, v in p.items()} for p in predictions], f, indent=2)