# BLIP ITM finetuning to target task

Batch is formed from a single row of dataset:
$$\operatorname{batch} = ((E_t, E_{i_0}), (E_t, E_{i_1}), ..., (E_t, E_{i_9})); \operatorname{batch} : R^{10 \times (E_t + E_i)}$$
ITM predicts probas for $y = 0$, $y = 1$
$$\operatorname{ITM} : R^{10 \times (E_t + E_i)} \rightarrow R^{10 \times 2}$$
Model is defined as:
$$\operatorname{F} = \operatorname{softmax} \circ \operatorname{ITM} \circ \operatorname{batch}$$
$$\operatorname{F} : R^{10 \times (E_t + E_i)} \rightarrow R^{10}$$
So, this definition is for a single row

In [21]:
from pathlib import Path
import logging
import json
from typing import *
import time

import pandas as pd
import numpy as np
import torch
from PIL import Image, ImageFile
import torch.nn as nn
from lavis.models import load_model_and_preprocess, BlipBase
from lavis.processors import load_processor
import torch.nn.functional as F
from transformers import get_linear_schedule_with_warmup
from torchmetrics.functional import retrieval_reciprocal_rank, retrieval_hit_rate
from transformers import BatchEncoding
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from src.data import CustomSplitLoader
from src.utils import evaluate

## Config

Paths resolution:

In [2]:
DATASET_VERSION = "v1"
PART = "train"
PATH = Path("data").resolve() / f"{PART}_{DATASET_VERSION}"
DATA_PATH = PATH / f"{PART}.data.{DATASET_VERSION}.txt"
LABELS_PATH = PATH / f"{PART}.gold.{DATASET_VERSION}.txt"
IMAGES_PATH = PATH / f"{PART}_images_{DATASET_VERSION}"
TRAIN_SPLIT_PATH = PATH / "split_train.txt"
VALIDATION_SPLIT_PATH = PATH / "split_valid.txt"
TEST_SPLIT_PATH = PATH / "split_test.txt"
SAVE_CHECKPOINT_PATH = Path("checkpoints").resolve() / "BLIP-ITM-2"
SAVE_CHECKPOINT_PATH.mkdir(parents=True, exist_ok=True)

Environment settings:

In [3]:
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# some images from train might not load without the following settings or warnings would be thrown
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True

writer = SummaryWriter()

In [4]:
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on {DEVICE}")

Running on cuda


Model & training settings

In [5]:
BLIP_VARIANT = "base" # "base" | "large"
NUM_EPOCHS = 3
WARMUP_STEPS_FRAC = 0.1
STEPS_BETWEEN_EVAL = 25
GRAD_ACCUM_STEPS = 32
SAVE_CHECKPOINT_STEPS = STEPS_BETWEEN_EVAL 
LR = 1e-5

## Loading data

In [6]:
df = pd.read_csv(DATA_PATH, sep='\t', header=None)
df.columns = ["word", "context"] + [f"image{i}" for i in range(10)]
df["label"] = pd.read_csv(LABELS_PATH, sep='\t', header=None)

train_df = df.loc[pd.read_csv(TRAIN_SPLIT_PATH, sep='\t', header=None).T.values[0]]
validation_df = df.loc[pd.read_csv(VALIDATION_SPLIT_PATH, sep='\t', header=None).T.values[0]]
test_df = df.loc[pd.read_csv(TEST_SPLIT_PATH, sep='\t', header=None).T.values[0]]

## Preprocessing

In [10]:
def infinite_repeat(value):
    while True:
        yield value

def concat_iters(*iterables):
    for it in iterables:
        for value in it:
            yield value

In [11]:
class ItmDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        df: pd.DataFrame,
        images_path: Path,
        text_processor,
        vis_processor,
        use_context_as_text: bool = True,
        enable_cache: bool = True,
    ) -> None:
        self.df = df
        self.images_path = images_path
        self.text_processor = text_processor
        self.vis_processor = vis_processor
        self.tokens_cache = dict()
        self.image_batch_cache = dict()
        self.enable_cache = enable_cache
        self.text_field = "context" if use_context_as_text else "word"
        self.labels_map = self._gen_labels()

    def _gen_labels(self) -> Dict[int, int]: # index to label
        labels = self.df["label"].values
        zips = []
        for i in range(10):
            images = self.df[f"image{i}"].values
            zips.append(zip(np.argwhere(labels == images).reshape(-1), infinite_repeat(i)))
        return dict(concat_iters(*tuple(zips)))
    
    def __len__(self) -> int:
        return len(self.df)

    def _get_image(self, name: str) -> Image:
        return Image.open(self.images_path / name).convert("RGB")

    def _make_image_batch(self, idx: int) -> torch.Tensor:
        row = self.df.iloc[idx]
        return torch.stack([self.vis_processor(self._get_image(row[f"image{i}"])) for i in range(10)])

    def _get_image_batch(self, idx: int) -> torch.Tensor:
        if not self.enable_cache:
            return self._make_image_batch(idx)
        if idx in self.image_batch_cache:
            return self.image_batch_cache[idx]
        t = self._make_image_batch(idx)
        self.image_batch_cache[idx] = t
        return t
    
    def _make_tokens(self, idx: int) -> BatchEncoding:
        return self.text_processor(self.df.iloc[idx][self.text_field])
    
    def _get_tokens(self, idx: int) -> BatchEncoding:
        if not self.enable_cache:
            return self._make_tokens(idx)
        if idx in self.tokens_cache:
            return self.tokens_cache[idx]
        t = self._make_tokens(idx)
        self.tokens_cache[idx] = t
        return t
    
    
    def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, BatchEncoding, int]]:
        # makes a batch for each row!
        return {
            "text": self._get_tokens(idx),
            "images": self._get_image_batch(idx),
            "label": self.labels_map[idx],
        }

In [13]:
blip_model, vis_processors, text_processors = load_model_and_preprocess("blip_image_text_matching", BLIP_VARIANT, is_eval=True)

INFO:root:Missing keys []
INFO:root:load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth


In [14]:
train_ds = ItmDataset(
    df=train_df,
    images_path=IMAGES_PATH,
    text_processor=text_processors["eval"],
    vis_processor=vis_processors["eval"],
    enable_cache=False # eats up too much ram, whole 128GB!
)
val_ds = ItmDataset(
    df=validation_df,
    images_path=IMAGES_PATH,
    text_processor=text_processors["eval"],
    vis_processor=vis_processors["eval"],
    # here enable_cache is left as is (True), because we want fast & frequent validations
)
test_ds = ItmDataset(
    df=test_df,
    images_path=IMAGES_PATH,
    text_processor=text_processors["eval"],
    vis_processor=vis_processors["eval"],
    enable_cache=False,
)

In [15]:
def to_device(object, device):
    if not isinstance(object, dict):
        raise NotImplementedError("Implement other types than dict if needed!")
    return {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in object.items()}

def label2bool_tensor(label: int) -> torch.Tensor:
    t = torch.zeros(10, dtype=torch.bool)
    t[label] = True
    return t

## Model setup

In [16]:
class Classifier(nn.Module):
    def __init__(self, blip_model: BlipBase) -> None:
        super().__init__()
        self.blip_model = blip_model

    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        text_feats = inputs["text"]
        images_feats = inputs["images"]
        batch_outputs = self.blip_model({"text_input": [text_feats for _ in range(10)], "image": images_feats}, match_head="itm")
        return F.softmax(batch_outputs[:, 1], dim=0)

In [17]:
model = Classifier(blip_model).to(DEVICE)

## Training

In [18]:
metric2name = {
    "acc1": "Accuracy@Top1",
    "acc3": "Accuracy@Top3",
    "mrr": "Mean Reciprocal Rank",
}

def eval_single(model_outputs, one_hot_label_tensor):
    return {
        "acc1": retrieval_hit_rate(model_outputs, one_hot_label_tensor, 1).item(),
        "acc3": retrieval_hit_rate(model_outputs, one_hot_label_tensor, 3).item(),
        "mrr": retrieval_reciprocal_rank(model_outputs, one_hot_label_tensor).item(),
    }

def sum_scores(scores, new_scores):
    return {k: scores[k] + new_scores[k] for k in scores}

def div_scores(scores, n):
    return {k: v / n for k, v in scores.items()}

In [19]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
num_training_steps = int(NUM_EPOCHS * len(train_ds) / GRAD_ACCUM_STEPS)
num_warmup_steps = int(num_training_steps * WARMUP_STEPS_FRAC)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)
print(f"{num_training_steps} training steps which include {num_warmup_steps} warmup ones")

572 training steps which include 57 warmup ones


In [22]:
step_num = 0
steps_since_last_eval = 0
grad_accum_step_cnt = 0
save_checkpoint_step_cnt = 0
progress_bar = tqdm(range(num_training_steps))

for epoch_num in range(NUM_EPOCHS):
    model.train()
    train_loss = 0.0
    train_scores = {"acc1": 0, "acc3": 0, "mrr": 0}
    for batch in train_ds:
        outputs = model(to_device(batch, DEVICE))
        loss = loss_fn(outputs, torch.tensor(batch["label"]).to(DEVICE))
        train_loss += loss.item()
        new_scores = eval_single(outputs, label2bool_tensor(batch["label"]).to(DEVICE))
        train_scores = sum_scores(train_scores, new_scores)
        loss.backward()
        grad_accum_step_cnt += 1

        if grad_accum_step_cnt == GRAD_ACCUM_STEPS:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            writer.add_scalar("Loss/Train", float(train_loss / GRAD_ACCUM_STEPS), step_num)            
            for k, v in div_scores(train_scores, GRAD_ACCUM_STEPS).items():
                writer.add_scalar(metric2name[k] + "/Train", v, step_num)
            train_loss = 0.0
            train_scores = {"acc1": 0, "acc3": 0, "mrr": 0}
            grad_accum_step_cnt = 0
            step_num += 1
            steps_since_last_eval += 1
            save_checkpoint_step_cnt += 1
            progress_bar.update(1)

        if steps_since_last_eval == STEPS_BETWEEN_EVAL:
            model.eval()
            val_loss = 0.0
            val_scores = {"acc1": 0, "acc3": 0, "mrr": 0}
            with torch.no_grad():
                for batch in val_ds:
                    outputs = model(to_device(batch, DEVICE))
                    loss = loss_fn(outputs, torch.tensor(batch["label"]).to(DEVICE))
                    val_loss += loss.item()
                    new_scores = eval_single(outputs, label2bool_tensor(batch["label"]).to(DEVICE))
                    val_scores = sum_scores(val_scores, new_scores)
            writer.add_scalar("Loss/Validation", val_loss / len(val_ds), step_num)            
            for k, v in div_scores(val_scores, len(val_ds)).items():
                writer.add_scalar(metric2name[k] + "/Validation", v, step_num)
            model.train()
            steps_since_last_eval = 0
        
        if save_checkpoint_step_cnt == SAVE_CHECKPOINT_STEPS:
            save_checkpoint_step_cnt = 0
            p = SAVE_CHECKPOINT_PATH / f"step-{step_num}.pt"
            logging.info(f"[{epoch_num}:{step_num}] Saving checkpoint to \"{str(p)}\"")
            torch.save(model.state_dict(), p)

  9%|▊         | 50/572 [1:08:05<4:08:54, 28.61s/it] INFO:root:[0:50] Saved checkpoint at "/home/s1m00n/research/vwsd/checkpoints/BLIP-ITM-2/step-50.pt"
 17%|█▋        | 100/572 [1:58:06<3:44:47, 28.57s/it] INFO:root:[0:100] Saved checkpoint at "/home/s1m00n/research/vwsd/checkpoints/BLIP-ITM-2/step-100.pt"
 26%|██▌       | 150/572 [2:48:19<3:18:15, 28.19s/it]  INFO:root:[0:150] Saved checkpoint at "/home/s1m00n/research/vwsd/checkpoints/BLIP-ITM-2/step-150.pt"
 35%|███▍      | 200/572 [3:39:05<2:57:04, 28.56s/it]  INFO:root:[1:200] Saved checkpoint at "/home/s1m00n/research/vwsd/checkpoints/BLIP-ITM-2/step-200.pt"
 44%|████▎     | 250/572 [4:29:29<2:30:37, 28.07s/it]  INFO:root:[1:250] Saved checkpoint at "/home/s1m00n/research/vwsd/checkpoints/BLIP-ITM-2/step-250.pt"
 52%|█████▏    | 300/572 [5:19:32<2:05:54, 27.77s/it]  INFO:root:[1:300] Saved checkpoint at "/home/s1m00n/research/vwsd/checkpoints/BLIP-ITM-2/step-300.pt"
 61%|██████    | 350/572 [6:10:08<1:51:46, 30.21s/it]  INFO:roo

## Evaluation

Here, let's load the best checkpoint according to Tensorboard

In [None]:
checkpoint = Classifier(blip_model).to(DEVICE)
checkpoint.load_state_dict(torch.load(SAVE_CHECKPOINT_PATH / f"step-400.pt"))
checkpoint.eval()

In [35]:
predictions = []
with torch.no_grad():
    for (i, batch) in enumerate(tqdm(test_ds)):
        preds = checkpoint(to_device(batch, DEVICE)).numpy(force=True)
        row = test_df.iloc[i]
        predictions.append({row[f"image{j}"]: preds[j] for j in range(10)})

100%|██████████| 3356/3356 [42:21<00:00,  1.32it/s]


In [36]:
evaluate(
    test_df.iloc[:, 2:-1].values,
    test_df["label"].values.reshape(-1, 1),
    predictions,
)

{'acc1': 0.8268772348033373,
 'acc3': 0.9707985697258641,
 'mrr': 0.8989807404884878}

In [38]:
# creates a file in <project root>/data with submissions in target format
with open(PATH / "blip-itm-2-400_submission.json", 'w') as f:
    json.dump([{k: str(v) for k, v in p.items()} for p in predictions], f, indent=2)

Another checkpoint that might be great

In [39]:
checkpoint = Classifier(blip_model).to(DEVICE)
checkpoint.load_state_dict(torch.load(SAVE_CHECKPOINT_PATH / f"step-250.pt"))
checkpoint.eval()
predictions = []
with torch.no_grad():
    for (i, batch) in enumerate(tqdm(test_ds)):
        preds = checkpoint(to_device(batch, DEVICE)).numpy(force=True)
        row = test_df.iloc[i]
        predictions.append({row[f"image{j}"]: preds[j] for j in range(10)})
print(
    evaluate(
        test_df.iloc[:, 2:-1].values,
        test_df["label"].values.reshape(-1, 1),
        predictions,
    )
)
# creates a file in <project root>/data with submissions in target format
with open(PATH / "blip-itm-2-250_submission.json", 'w') as f:
    json.dump([{k: str(v) for k, v in p.items()} for p in predictions], f, indent=2)

100%|██████████| 3356/3356 [42:25<00:00,  1.32it/s]

{'acc1': 0.8247914183551848, 'acc3': 0.9687127532777116, 'mrr': 0.8965728285752124}



