In [1]:
from pathlib import Path
import logging
import json
from typing import *
import time
from datetime import datetime

import pandas as pd
import numpy as np
import torch
from PIL import Image, ImageFile
import torch.nn as nn
from lavis.models import load_model_and_preprocess, BlipBase
from lavis.processors import load_processor
import torch.nn.functional as F
from transformers import get_cosine_schedule_with_warmup
from transformers import BatchEncoding
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from sklearn.metrics import top_k_accuracy_score

from src.data import CustomSplitLoader
from src.utils import evaluate, mrr
# from src.itc import ClsITC, ClsITCBatchData, Temperature
from src.itm import AltNSDataset, to_device, ITMClassifier, DefaultDataset

## Config

Versioning (used to name checkpoint folders and TensorBoard runs)

In [2]:
# Identifiers baked into the checkpoint folder name and the TensorBoard run name.
HEAD, MODEL_VERSION = "itm", 19

Path resolution:

In [3]:
DATASET_VERSION = "v1"
PART = "train"
PATH = Path("data").resolve() / f"{PART}_{DATASET_VERSION}"
DATA_PATH = PATH / f"{PART}.data.{DATASET_VERSION}.txt"
LABELS_PATH = PATH / f"{PART}.gold.{DATASET_VERSION}.txt"
IMAGES_PATH = PATH / f"{PART}_images_{DATASET_VERSION}"
TRAIN_SPLIT_PATH = PATH / "split_train.txt"
VALIDATION_SPLIT_PATH = PATH / "split_valid.txt"
VAL2_DATA_PATH = PATH / "valid2.data.v1.txt"
VAL2_GOLD_PATH = PATH / "valid2.gold.v1.txt"
TEST_SPLIT_PATH = PATH / "split_test.txt"
TEST2_DATA_PATH = PATH / "test2.data.v1.txt"
TEST2_GOLD_PATH = PATH / "test2.gold.v1.txt"
SAVE_CHECKPOINT_PATH = Path("checkpoints").resolve() / f"BLIP-{HEAD}-{MODEL_VERSION}"
SAVE_CHECKPOINT_PATH.mkdir(parents=True, exist_ok=True)
NUM_PICS = 10

Environment settings:

In [4]:
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Some train images might not load (or would emit warnings) without the following settings.
# NOTE: MAX_IMAGE_PIXELS = None disables PIL's decompression-bomb guard; only do this for trusted images.
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True

import warnings
# NOTE(review): this silences *all* warnings for the session, including library deprecation notices.
warnings.filterwarnings('ignore')

# The timestamp is formatted without ':' so the run directory name is valid on every filesystem (e.g. Windows).
writer = SummaryWriter(f"runs/blip-{HEAD}-{MODEL_VERSION} (ran at {datetime.now():%Y-%m-%d_%H-%M-%S})")

In [5]:
RANDOM_STATE = 42
# Seed every RNG the notebook can touch; torch.manual_seed also seeds all CUDA devices.
torch.manual_seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
# Prefer the first GPU, but fall back to CPU instead of crashing on machines without CUDA.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE: 32 workers is specific to the author's machine; lower it on smaller hosts.
NUM_WORKERS = 32
PERSISTENT_WORKERS = True
print(f"Running on {DEVICE}")

Running on cuda:0


Model & training settings

In [6]:
# Optimisation hyperparameters. The LR schedule (cosine with warmup) is built in the Training section.
BLIP_VARIANT = "base"  # one of "base" | "large"
NUM_EPOCHS = 20
WARMUP_STEPS_FRAC = 0.1  # fraction of optimizer steps spent in warmup
GRAD_ACCUM_STEPS = 30    # micro-batches per optimizer step
TRAIN_BATCH_SIZE = 3     # rows per micro-batch
LR = 1e-5
WEIGHT_DECAY = 0.1

In [7]:
# Rows contributing to each optimizer step: micro-batch size times accumulation steps.
TRAIN_EFFECTIVE_BATCH_SIZE = TRAIN_BATCH_SIZE * GRAD_ACCUM_STEPS
NUM_LABELS = NUM_PICS  # one class per candidate image column
TRAIN_EFFECTIVE_BATCH_SIZE

90

In [8]:
# Evaluation cadence, measured in optimizer steps; checkpoints are saved on the same cadence as validation.
STEPS_BETWEEN_VAL = STEPS_BETWEEN_VAL2 = 100
SAVE_CHECKPOINT_STEPS = STEPS_BETWEEN_VAL
# Inference-only passes share one batch size.
VALIDATION_BATCH_SIZE = TEST_BATCH_SIZE = 40

## Loading data

In [9]:
def _read_split_indices(path: Path) -> np.ndarray:
    """Return the first column of a header-less TSV: the row labels of `df` that form a split."""
    return pd.read_csv(path, sep='\t', header=None).iloc[:, 0].values


def _read_labelled_tsv(data_path: Path, gold_path: Path) -> pd.DataFrame:
    """Read a `word, context, image0..image{NUM_PICS-1}` TSV and append its gold file as the last column `label`."""
    frame = pd.read_csv(data_path, sep='\t', header=None)
    frame.columns = ["word", "context"] + [f"image{i}" for i in range(NUM_PICS)]
    gold = pd.read_csv(gold_path, sep='\t', header=None).iloc[:, 0]
    # Assignment aligns on the index, so a length mismatch would silently produce NaN labels.
    assert len(gold) == len(frame), f"{gold_path} has {len(gold)} rows, {data_path} has {len(frame)}"
    frame["label"] = gold
    return frame


df = _read_labelled_tsv(DATA_PATH, LABELS_PATH)
train_df = df.loc[_read_split_indices(TRAIN_SPLIT_PATH)]
validation_df = df.loc[_read_split_indices(VALIDATION_SPLIT_PATH)]
test_df = df.loc[_read_split_indices(TEST_SPLIT_PATH)]

val2_df = _read_labelled_tsv(VAL2_DATA_PATH, VAL2_GOLD_PATH)
test2_df = _read_labelled_tsv(TEST2_DATA_PATH, TEST2_GOLD_PATH)

## Preprocessing

In [10]:
# Pretrained BLIP image-text-matching model plus its processor dicts (the "eval" entries are used below).
blip_model, vis_processors, text_processors = load_model_and_preprocess(
    name="blip_image_text_matching",
    model_type=BLIP_VARIANT,
    is_eval=True,
)

INFO:root:Missing keys []
INFO:root:load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth


In [11]:
def _make_dataset(frame: pd.DataFrame) -> DefaultDataset:
    """Wrap `frame` in a DefaultDataset that uses BLIP's "eval" text and visual processors."""
    return DefaultDataset(
        df=frame,
        images_path=IMAGES_PATH,
        text_processor=text_processors["eval"],
        vis_processor=vis_processors["eval"],
    )


def _make_loader(dataset, batch_size: int, shuffle: bool) -> torch.utils.data.DataLoader:
    """Build a DataLoader that honours the notebook-wide worker settings."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        # DataLoader rejects persistent workers when no worker processes are used.
        persistent_workers=PERSISTENT_WORKERS and NUM_WORKERS > 0,
    )


# NOTE(review): training also uses the "eval" processors; confirm no train-time augmentation is intended.
train_ds = _make_dataset(train_df)
val_ds = _make_dataset(validation_df)
val2_ds = _make_dataset(val2_df)

train_dl = _make_loader(train_ds, TRAIN_BATCH_SIZE, shuffle=True)
# Shuffling validation data only makes batch composition differ between evaluations; keep it fixed.
val_dl = _make_loader(val_ds, VALIDATION_BATCH_SIZE, shuffle=False)
val2_dl = _make_loader(val2_ds, VALIDATION_BATCH_SIZE, shuffle=False)

# Number of batches per pass over each loader.
train_l, val_l, val2_l = len(train_dl), len(val_dl), len(val2_dl)
train_l, val_l, val2_l

(2035, 86, 39)

## Model setup

In [12]:
# Wrap the pretrained BLIP ITM model in the project's ITMClassifier head, then move it to DEVICE.
model = ITMClassifier(blip_model)
model = model.to(DEVICE)

## Training

In [16]:
# Human-readable names for each metric id; used in TensorBoard tags and printed evaluation reports.
metric2name = dict(
    acc1="Accuracy@Top1",
    acc3="Accuracy@Top3",
    mrr="Mean Reciprocal Rank",
)

def eval_batch(labels, preds, num_labels = NUM_PICS):
    """Score one batch: top-1 accuracy, top-3 accuracy and MRR over `num_labels` candidate classes.

    Both tensors are copied to NumPy with `.numpy(force=True)`, which detaches them and moves them to CPU.
    Presumably `labels` holds class indices (B,) and `preds` per-class scores (B, num_labels); TODO confirm.
    """
    y_true, y_score = (tensor.numpy(force=True) for tensor in (labels, preds))
    candidates = np.arange(num_labels)

    def top_k(k):
        # Passing `labels` keeps sklearn from inferring the class set from this batch alone.
        return top_k_accuracy_score(y_true, y_score, k=k, labels=candidates)

    return {"acc1": top_k(1), "acc3": top_k(3), "mrr": mrr(y_true, y_score)}

def sum_scores(scores, new_scores):
    """Return a new dict with `new_scores[k]` added to `scores[k]` for every key of `scores`."""
    totals = {}
    for key, value in scores.items():
        totals[key] = value + new_scores[key]
    return totals

def div_scores(scores, n):
    """Return a new dict with every value of `scores` divided by `n`."""
    return dict(zip(scores.keys(), (value / n for value in scores.values())))

In [20]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# The training loop never resets its accumulation counter per epoch, so the number of optimizer
# steps is floor(epochs * batches_per_epoch / accum_steps). Integer arithmetic computes that exactly,
# without float rounding.
num_training_steps = NUM_EPOCHS * train_l // GRAD_ACCUM_STEPS
num_warmup_steps = int(num_training_steps * WARMUP_STEPS_FRAC)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)
print(f"{num_training_steps} training steps which include {num_warmup_steps} warmup ones")

1356 training steps which include 135 warmup ones


In [21]:
def _log_scores(scores, tag, step):
    """Log each metric in `scores` to TensorBoard as "<display name>/<tag>" at `step`."""
    for key, value in scores.items():
        writer.add_scalar(metric2name[key] + "/" + tag, value, step)


def _weighted(scores, weight):
    """Return `scores` with every value multiplied by `weight` (used to weight batch means by batch size)."""
    return {key: value * weight for key, value in scores.items()}


def _run_validation(dataloader, tag, step):
    """Evaluate `model` on `dataloader` and log loss and metrics under `tag` at `step`.

    Each batch mean is weighted by its batch size, so a smaller final batch does not skew the results.
    The model is put in eval mode for the pass and returned to train mode afterwards.
    """
    model.eval()
    loss_sum = 0.0
    score_sums = {key: 0.0 for key in metric2name}
    n_seen = 0
    with torch.no_grad():
        for val_batch in dataloader:
            val_batch = to_device(val_batch, DEVICE)
            val_labels = val_batch["label"].to(DEVICE)
            val_outputs = model(val_batch)
            n = len(val_labels)
            loss_sum += loss_fn(val_outputs, val_labels).item() * n
            batch_scores = eval_batch(val_labels, val_outputs, num_labels=NUM_LABELS)
            score_sums = sum_scores(score_sums, _weighted(batch_scores, n))
            n_seen += n
    model.train()
    if n_seen:
        writer.add_scalar(f"Loss/{tag}", loss_sum / n_seen, step)
        _log_scores(div_scores(score_sums, n_seen), tag, step)


step_num = 0
steps_since_last_val = 0
steps_since_last_val2 = 0
grad_accum_step_cnt = 0
save_checkpoint_step_cnt = 0
# Accumulators for the optimizer step in progress. They are initialised once, NOT per epoch:
# `grad_accum_step_cnt` carries over between epochs, so an accumulation window can straddle an
# epoch boundary. Resetting these per epoch under-reported the train loss/metrics of that step.
train_loss = 0.0
train_scores = {key: 0.0 for key in metric2name}
train_samples = 0
progress_bar = tqdm(total=num_training_steps)

for epoch_num in range(NUM_EPOCHS):
    model.train()
    for batch in train_dl:
        batch = to_device(batch, DEVICE)
        labels = batch["label"].to(DEVICE)
        outputs = model(batch)
        # Class-index targets give the same loss as one-hot probability targets.
        loss = loss_fn(outputs, labels)

        n = len(labels)
        train_loss += loss.item() * n
        batch_scores = eval_batch(labels, outputs, num_labels=NUM_LABELS)
        train_scores = sum_scores(train_scores, _weighted(batch_scores, n))
        train_samples += n

        # NOTE(review): the loss is not divided by GRAD_ACCUM_STEPS, so gradients are summed (not
        # averaged) over the accumulation window; kept as-is to preserve the tuned optimisation.
        loss.backward()
        grad_accum_step_cnt += 1
        if grad_accum_step_cnt < GRAD_ACCUM_STEPS:
            continue

        # --- one optimizer step ---
        writer.add_scalar("Learning Rate", lr_scheduler.get_last_lr()[0], step_num)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        writer.add_scalar("Loss/Train", train_loss / train_samples, step_num)
        _log_scores(div_scores(train_scores, train_samples), "Train", step_num)
        train_loss = 0.0
        train_scores = {key: 0.0 for key in metric2name}
        train_samples = 0
        grad_accum_step_cnt = 0
        step_num += 1
        steps_since_last_val += 1
        steps_since_last_val2 += 1
        save_checkpoint_step_cnt += 1
        progress_bar.update(1)

        # These counters only change on an optimizer step, so checking them here is sufficient.
        if steps_since_last_val == STEPS_BETWEEN_VAL:
            _run_validation(val_dl, "Validation", step_num)
            steps_since_last_val = 0
        if steps_since_last_val2 == STEPS_BETWEEN_VAL2:
            _run_validation(val2_dl, "Validation 2", step_num)
            steps_since_last_val2 = 0
        if save_checkpoint_step_cnt == SAVE_CHECKPOINT_STEPS:
            save_checkpoint_step_cnt = 0
            torch.save(model.state_dict(), SAVE_CHECKPOINT_PATH / f"step-{step_num}.pt")

# Gradients from micro-batches after the last full accumulation window are never applied; drop them.
optimizer.zero_grad()
progress_bar.close()

 31%|███▏      | 424/1356 [3:39:00<6:31:35, 25.21s/it]  

## Evaluation

In [13]:
def predict_eval(
    model: ITMClassifier,
    dataframes: Dict[str, pd.DataFrame],
    images_path: Path,
    text_processor,
    vis_processor,
    batch_size: int = 1,
    num_workers: int = 0,
    persistent_workers: bool = True,
    device = torch.device("cpu"),
    preds_save_folder: Optional[Path] = None,
    preds_save_filename_prefix: str = "sample_predictions",
    preds_save_filename_add_timestamp: bool = True,
    verbose: bool = True,
) -> Tuple[Dict[str, List[Dict[str, Any]]], Dict[str, Dict[str, float]]]:
    """
    Predicts image scores for every dataframe with `model`, optionally saves them as JSON, and evaluates them.

    Args:
        model (ITMClassifier): loaded classification model; moved to `device` and switched to eval mode
        dataframes (Dict[str, pandas.DataFrame]): mapping of test set names to dataframes whose columns are
            `word`, `context`, the image columns `image0..imageN`, and finally `label`
        images_path (Path): folder the images are loaded from
        text_processor, vis_processor: processors passed to `DefaultDataset`
        batch_size, num_workers, persistent_workers: DataLoader settings; `persistent_workers` is
            ignored when `num_workers == 0`, because DataLoader rejects that combination
        device: device inference runs on
        preds_save_folder (Optional[Path]): when given, predictions of each set are saved there as JSON
        preds_save_filename_prefix (str): prefix of the saved JSON file names
        preds_save_filename_add_timestamp (bool): insert a Unix timestamp into the saved file names
        verbose (bool): enables prints of metrics and progress tracking

    Returns:
        Tuple[Dict[str, List[Dict[str, Any]]], Dict[str, Dict[str, float]]]: per-set predictions (one
        `{image name: score}` dict per row, in dataframe order) and per-set scores from `evaluate`
    """
    persistent_workers = persistent_workers and num_workers > 0
    # The model does not change between sets, so move it and set eval mode once.
    model = model.to(device)
    model.eval()

    predictions = dict()
    evaluations = dict()
    for name, df in dataframes.items():
        if verbose:
            print(f"Generating predictions for \"{name}\"")
        ds = DefaultDataset(
            df=df,
            images_path=images_path,
            text_processor=text_processor,
            vis_processor=vis_processor,
        )
        dl = torch.utils.data.DataLoader(
            ds,
            batch_size = batch_size,
            shuffle = False,  # row order must match `df` for the row lookup below
            num_workers = num_workers,
            persistent_workers = persistent_workers,
        )
        preds = []  # one {image name: score} dict per dataframe row
        row_idx = 0
        with torch.no_grad():
            for batch in (tqdm(dl) if verbose else dl):
                batch = to_device(batch, device)
                for ps in model(batch).numpy(force=True):  # ps - predictions for one row
                    row = df.iloc[row_idx]
                    preds.append({row[f"image{j}"]: ps[j] for j in range(len(ps))})
                    row_idx += 1
        predictions[name] = preds
        if preds_save_folder is not None:
            maybe_timestamp = f"_at_{time.time()}_" if preds_save_filename_add_timestamp else "_"
            filename = f"{preds_save_filename_prefix}_on_{name}{maybe_timestamp}submission.json"
            if verbose:
                print(f"Saving predictions for \"{name}\" as \"{filename}\"")
            # Save into the requested folder (previously this ignored the argument and used the global PATH).
            with open(Path(preds_save_folder) / filename, "w") as f:
                json.dump([{k: str(v) for k, v in p.items()} for p in preds], f, indent=2)
        if verbose:
            print(f"Metrics for \"{name}\":")
        evals = evaluate(
            df.iloc[:, 2:-1].values,  # image columns: everything between `context` and `label`
            df["label"].values.reshape(-1, 1),
            preds,
        )
        if verbose:
            for metric_id, metric_value in evals.items():
                metric_name = metric2name[metric_id]
                print(f"    {metric_name}: {metric_value}")
        evaluations[name] = evals
    return predictions, evaluations

In [14]:
# Optimizer steps whose saved checkpoints get evaluated; choose them from the TensorBoard validation curves.
CHECKPOINTS = [200, 600, 700, 1000, 1200]
# Held-out sets evaluated for every checkpoint, keyed by the name used in the saved prediction files.
TEST_DFS = dict(test=test_df, test2=test2_df)

In [17]:
results = dict()  # checkpoint step -> (predictions, evaluations) as returned by predict_eval
for checkpoint_num in CHECKPOINTS:
    print(f"Processing checkpoint {checkpoint_num}")
    checkpoint_path = SAVE_CHECKPOINT_PATH / f"step-{checkpoint_num}.pt"
    # map_location loads tensors straight onto DEVICE, even if the checkpoint was saved on another device.
    model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
    results[checkpoint_num] = predict_eval(
        model = model,
        verbose = True,
        preds_save_filename_prefix = f"blip-{HEAD}-{MODEL_VERSION}-{checkpoint_num}",
        preds_save_folder = PATH,
        device = DEVICE,
        persistent_workers = PERSISTENT_WORKERS,
        num_workers = NUM_WORKERS,
        batch_size = TEST_BATCH_SIZE,
        text_processor=text_processors["eval"],
        vis_processor=vis_processors["eval"],
        dataframes = TEST_DFS,
        images_path = IMAGES_PATH,
    )

Processing checkpoint 200
Generating predictions for "test"


100%|██████████| 84/84 [06:24<00:00,  4.58s/it] 


Saving predictions for "test" as "blip-itm-19-200_on_test_at_1671282958.672083_submission.json"
Metrics for "test":
    Accuracy@Top1: 0.8224076281287247
    Accuracy@Top3: 0.966626936829559
    Mean Reciprocal Rank: 0.8949462701250543
Generating predictions for "test2"


100%|██████████| 39/39 [03:15<00:00,  5.02s/it]


Saving predictions for "test2" as "blip-itm-19-200_on_test2_at_1671283155.266982_submission.json"
Metrics for "test2":
    Accuracy@Top1: 0.7626683771648493
    Accuracy@Top3: 0.9236690186016677
    Mean Reciprocal Rank: 0.8499302564729121
Processing checkpoint 600
Generating predictions for "test"


100%|██████████| 84/84 [06:24<00:00,  4.57s/it] 


Saving predictions for "test" as "blip-itm-19-600_on_test_at_1671283540.8092375_submission.json"
Metrics for "test":
    Accuracy@Top1: 0.8295589988081049
    Accuracy@Top3: 0.9693087008343266
    Mean Reciprocal Rank: 0.8998445097148912
Generating predictions for "test2"


100%|██████████| 39/39 [03:17<00:00,  5.07s/it]


Saving predictions for "test2" as "blip-itm-19-600_on_test2_at_1671283739.4501984_submission.json"
Metrics for "test2":
    Accuracy@Top1: 0.7562540089801154
    Accuracy@Top3: 0.9236690186016677
    Mean Reciprocal Rank: 0.8467963794455136
Processing checkpoint 700
Generating predictions for "test"


100%|██████████| 84/84 [06:25<00:00,  4.59s/it] 


Saving predictions for "test" as "blip-itm-19-700_on_test_at_1671284126.3496537_submission.json"
Metrics for "test":
    Accuracy@Top1: 0.8253873659117997
    Accuracy@Top3: 0.9681168057210966
    Mean Reciprocal Rank: 0.8970350426622774
Generating predictions for "test2"


100%|██████████| 39/39 [03:17<00:00,  5.07s/it]


Saving predictions for "test2" as "blip-itm-19-700_on_test2_at_1671284325.2299027_submission.json"
Metrics for "test2":
    Accuracy@Top1: 0.7626683771648493
    Accuracy@Top3: 0.9211032713277743
    Mean Reciprocal Rank: 0.8509624097661301
Processing checkpoint 1000
Generating predictions for "test"


100%|██████████| 84/84 [06:28<00:00,  4.62s/it] 


Saving predictions for "test" as "blip-itm-19-1000_on_test_at_1671284714.8536747_submission.json"
Metrics for "test":
    Accuracy@Top1: 0.8238974970202623
    Accuracy@Top3: 0.9657330154946365
    Mean Reciprocal Rank: 0.8957359006375694
Generating predictions for "test2"


100%|██████████| 39/39 [03:17<00:00,  5.05s/it]


Saving predictions for "test2" as "blip-itm-19-1000_on_test2_at_1671284913.0068724_submission.json"
Metrics for "test2":
    Accuracy@Top1: 0.7556125721616421
    Accuracy@Top3: 0.9211032713277743
    Mean Reciprocal Rank: 0.8462516672266512
Processing checkpoint 1200
Generating predictions for "test"


100%|██████████| 84/84 [06:27<00:00,  4.61s/it] 


Saving predictions for "test" as "blip-itm-19-1200_on_test_at_1671285302.0396414_submission.json"
Metrics for "test":
    Accuracy@Top1: 0.8238974970202623
    Accuracy@Top3: 0.9663289630512515
    Mean Reciprocal Rank: 0.8959423538982537
Generating predictions for "test2"


 79%|███████▉  | 31/39 [02:45<00:42,  5.35s/it]


KeyboardInterrupt: 

In [None]:
# Rank checkpoints by the unweighted sum of every metric over every test set.
sums = {
    checkpoint: sum(sum(set_scores.values()) for set_scores in evals.values())
    for checkpoint, (_preds, evals) in results.items()
}
best_checkpoint = max(sums, key=lambda checkpoint: sums[checkpoint])
print(f"Best checkpoint (by sum of all scores) is {best_checkpoint} with results:")
results[best_checkpoint][1]