In [1]:
from pathlib import Path
import logging
import json
from typing import *
import time
from datetime import datetime
import warnings

from tqdm import tqdm
import pandas as pd
import numpy as np
from PIL import Image, ImageFile

import torch
import torch.nn as nn
import torchvision.transforms as T
from lavis.models import load_model_and_preprocess, BlipBase
from lavis.processors import load_processor
import torch.nn.functional as F
from transformers import get_cosine_schedule_with_warmup
from transformers import BatchEncoding

from src.data import CustomSplitLoader
from src.itm import AltNSDataset, to_device, ITMClassifier, DefaultDataset

from src.utils import evaluate, mrr
from src.validation import Validation, sum_scores, div_scores, eval_batch, metric2name
from sklearn.metrics import top_k_accuracy_score
from torch.utils.tensorboard import SummaryWriter

In [None]:
# TODO: all

## Config

Versioning

In [2]:
# Run identity: both values are embedded in the checkpoint directory name,
# the TensorBoard run name and the prefix of the saved prediction files.
HEAD = "itm"
MODEL_VERSION = 28

Path resolution:

In [3]:
# Dataset release and partition; both are substituted into the paths below.
DATASET_VERSION = "v1"
PART = "train"
# Root folder of the selected partition (absolute path on the author's machine).
PATH = Path("/home/s1m00n/research/vwsd/data").resolve() / f"{PART}_{DATASET_VERSION}"
# Tab-separated table of word, context and candidate image names, and the gold image per row.
DATA_PATH = PATH / f"{PART}.data.{DATASET_VERSION}.txt"
LABELS_PATH = PATH / f"{PART}.gold.{DATASET_VERSION}.txt"
# Folder passed to the datasets as `images_path`.
IMAGES_PATH = PATH / f"{PART}_images_{DATASET_VERSION}"
# Files listing the row ids of the main table that make up each split.
TRAIN_SPLIT_PATH = PATH / "split_train.txt"
VALIDATION_SPLIT_PATH = PATH / "split_valid.txt"
# Second validation/test sets stored as standalone data/gold file pairs.
# NOTE(review): these four filenames hard-code "v1" instead of using DATASET_VERSION.
VAL2_DATA_PATH = PATH / "valid2.data.v1.txt"
VAL2_GOLD_PATH = PATH / "valid2.gold.v1.txt"
TEST_SPLIT_PATH = PATH / "split_test.txt"
TEST2_DATA_PATH = PATH / "test2.data.v1.txt"
TEST2_GOLD_PATH = PATH / "test2.gold.v1.txt"
# Per-run checkpoint folder; it is created right away if it does not exist yet.
SAVE_CHECKPOINT_PATH = Path("/home/s1m00n/research/vwsd/checkpoints").resolve() / f"BLIP-{HEAD}-{MODEL_VERSION}" # TODO: maybe add timestamp?
SAVE_CHECKPOINT_PATH.mkdir(parents=True, exist_ok=True)
# Number of candidate image columns per row (image0 .. image9).
NUM_PICS = 10

Environment settings:

In [4]:
# Configure the root logger so INFO-level messages are shown.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# some images from train might not load without the following settings or warnings would be thrown
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Silences every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

# TensorBoard writer; the run directory name carries the run identity and the launch timestamp.
writer = SummaryWriter(f"/home/s1m00n/research/vwsd/runs/blip-{HEAD}-{MODEL_VERSION} (ran at {datetime.now()})")

In [5]:
RANDOM_STATE = 42
# Seed torch and NumPy; NumPy is seeded as well in case project code
# (datasets, samplers) draws random numbers through it.
torch.manual_seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
# Prefer the first GPU (the setup this notebook was written for) and fall back
# to the CPU so the notebook still starts on a machine without CUDA.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# DataLoader worker settings shared by training, validation and evaluation.
NUM_WORKERS = 32
PERSISTENT_WORKERS = True
print(f"Running on {DEVICE}")

Running on cuda:0


Model & training settings

In [6]:
# Which pretrained BLIP image-text-matching model to load.
BLIP_VARIANT = "base" # "base" | "large"
# Passed to AltNSDataset as `num_negatives`.
NUM_NS = 9
NUM_EPOCHS = 20
# Fraction of all optimizer steps spent in learning-rate warmup.
WARMUP_STEPS_FRAC = 0.05
# Micro-batches whose gradients are accumulated before each optimizer step.
GRAD_ACCUM_STEPS = 40
# AdamW hyper-parameters.
LR = 1e-5
WEIGHT_DECAY = 0.1
# Micro-batch size of the training DataLoader.
TRAIN_BATCH_SIZE = 3
# When True, the training datasets use the augmenting visual processor.
IMAGE_AUGMENTATION = True
# Toggles for the plain and the augmented validation configurations.
VALIDATE_UNAUGMENTED = True
VALIDATE_AUGMENTED = True
# The learning rate follows a cosine schedule with warmup (built in the Training section).

In [7]:
# Number of samples that contribute to a single optimizer step.
TRAIN_EFFECTIVE_BATCH_SIZE = GRAD_ACCUM_STEPS * TRAIN_BATCH_SIZE
# Width of the one-hot training targets: the negatives plus one more class
# (presumably the gold image - TODO confirm against the dataset classes).
NUM_LABELS = NUM_NS + 1
# NUM_LABELS = NUM_PICS
TRAIN_EFFECTIVE_BATCH_SIZE

120

In [8]:
# Validation runs every this many optimizer steps (and additionally at step 1).
STEPS_BETWEEN_VAL = 100
# Checkpoints are written at the same cadence as validation.
SAVE_CHECKPOINT_STEPS = STEPS_BETWEEN_VAL
# Batch sizes for the validation loaders and for the final evaluation.
VALIDATION_BATCH_SIZE = 40
TEST_BATCH_SIZE = VALIDATION_BATCH_SIZE

## Loading data

In [9]:
def _read_vwsd_table(data_path, gold_path, num_pics=NUM_PICS):
    """Read one data/gold file pair into a single DataFrame.

    Args:
        data_path: tab-separated file without header holding the word, the
            context and ``num_pics`` candidate image names per row.
        gold_path: tab-separated file without header holding the gold image
            name for each row of ``data_path``.
        num_pics: number of candidate image columns in ``data_path``.

    Returns:
        DataFrame with columns ``word``, ``context``, ``image0`` ..
        ``image{num_pics - 1}`` and ``label``.
    """
    table = pd.read_csv(data_path, sep="\t", header=None)
    table.columns = ["word", "context"] + [f"image{i}" for i in range(num_pics)]
    table["label"] = pd.read_csv(gold_path, sep="\t", header=None)
    return table


def _read_split_index(split_path):
    """Return the first column of a headerless split file as an array of row ids."""
    return pd.read_csv(split_path, sep="\t", header=None).T.values[0]


# Main table; the train/validation/test splits are row selections of it,
# so each split keeps the row ids of the main table as its index.
df = _read_vwsd_table(DATA_PATH, LABELS_PATH)

train_df = df.loc[_read_split_index(TRAIN_SPLIT_PATH)]
validation_df = df.loc[_read_split_index(VALIDATION_SPLIT_PATH)]
test_df = df.loc[_read_split_index(TEST_SPLIT_PATH)]

# Second validation/test sets come as standalone data/gold pairs.
val2_df = _read_vwsd_table(VAL2_DATA_PATH, VAL2_GOLD_PATH)
test2_df = _read_vwsd_table(TEST2_DATA_PATH, TEST2_GOLD_PATH)

In [10]:
# Training split: rows of the main table, still indexed by their original row ids.
train_df

Unnamed: 0,word,context,image0,image1,image2,image3,image4,image5,image6,image7,image8,image9,label
0,moorhen,moorhen swamphen,image.3.jpg,image.8.jpg,image.4.jpg,image.1.jpg,image.2.jpg,image.0.jpg,image.5.jpg,image.6.jpg,image.7.jpg,image.9.jpg,image.0.jpg
1,serinus,serinus genus,image.3.jpg,image.23.jpg,image.4.jpg,image.1.jpg,image.2.jpg,image.20.jpg,image.5.jpg,image.24.jpg,image.22.jpg,image.21.jpg,image.20.jpg
2,pegmatite,pegmatite igneous,image.41.jpg,image.39.jpg,image.42.jpg,image.43.jpg,image.40.jpg,image.44.jpg,image.37.jpg,image.38.jpg,image.36.jpg,image.35.jpg,image.35.jpg
4,bonxie,bonxie skua,image.3.jpg,image.77.jpg,image.78.jpg,image.4.jpg,image.1.jpg,image.2.jpg,image.5.jpg,image.79.jpg,image.76.jpg,image.75.jpg,image.75.jpg
5,ixia,ixia genus,image.90.jpg,image.3.jpg,image.91.jpg,image.4.jpg,image.92.jpg,image.1.jpg,image.2.jpg,image.94.jpg,image.93.jpg,image.5.jpg,image.90.jpg
...,...,...,...,...,...,...,...,...,...,...,...,...,...
12861,ducking,ducking hunting,image.964.jpg,image.6176.jpg,image.6742.jpg,image.12919.jpg,image.9996.jpg,image.966.jpg,image.967.jpg,image.12662.jpg,image.4312.jpg,image.965.jpg,image.12919.jpg
12862,tarnish,tarnish discoloration,image.7862.jpg,image.11086.jpg,image.11714.jpg,image.5269.jpg,image.2789.jpg,image.11230.jpg,image.3341.jpg,image.224.jpg,image.222.jpg,image.220.jpg,image.11714.jpg
12865,tragopogon,tragopogon genus,image.3.jpg,image.6250.jpg,image.15001.jpg,image.4.jpg,image.1.jpg,image.2.jpg,image.12074.jpg,image.5.jpg,image.4087.jpg,image.12806.jpg,image.12074.jpg
12866,illustrator,illustrator artist,image.10633.jpg,image.723.jpg,image.13372.jpg,image.881.jpg,image.12635.jpg,image.726.jpg,image.5985.jpg,image.722.jpg,image.724.jpg,image.725.jpg,image.10633.jpg


In [11]:
# Validation split: rows of the main table, still indexed by their original row ids.
validation_df

Unnamed: 0,word,context,image0,image1,image2,image3,image4,image5,image6,image7,image8,image9,label
18,maja,maja genus,image.3.jpg,image.310.jpg,image.4.jpg,image.309.jpg,image.1.jpg,image.2.jpg,image.312.jpg,image.5.jpg,image.311.jpg,image.313.jpg,image.309.jpg
24,entoloma,entoloma genus,image.3.jpg,image.405.jpg,image.404.jpg,image.4.jpg,image.1.jpg,image.2.jpg,image.5.jpg,image.406.jpg,image.407.jpg,image.408.jpg,image.404.jpg
25,foulard,foulard fabric,image.340.jpg,image.418.jpg,image.423.jpg,image.343.jpg,image.421.jpg,image.344.jpg,image.422.jpg,image.342.jpg,image.208.jpg,image.420.jpg,image.418.jpg
27,biryani,biryani dish,image.454.jpg,image.436.jpg,image.455.jpg,image.453.jpg,image.451.jpg,image.456.jpg,image.437.jpg,image.457.jpg,image.434.jpg,image.433.jpg,image.451.jpg
28,sobriquet,sobriquet appellation,image.466.jpg,image.478.jpg,image.477.jpg,image.475.jpg,image.476.jpg,image.471.jpg,image.474.jpg,image.473.jpg,image.472.jpg,image.329.jpg,image.466.jpg
...,...,...,...,...,...,...,...,...,...,...,...,...,...
12851,marattia,marattia genus,image.3.jpg,image.11008.jpg,image.6217.jpg,image.4.jpg,image.1.jpg,image.2.jpg,image.11414.jpg,image.5.jpg,image.9223.jpg,image.9977.jpg,image.6217.jpg
12852,tragulus,tragulus genus,image.3.jpg,image.10621.jpg,image.8594.jpg,image.4.jpg,image.1.jpg,image.2.jpg,image.5.jpg,image.4061.jpg,image.14197.jpg,image.12606.jpg,image.4061.jpg
12853,barbwire,barbwire wire,image.58.jpg,image.59.jpg,image.57.jpg,image.5849.jpg,image.56.jpg,image.4112.jpg,image.2784.jpg,image.151.jpg,image.60.jpg,image.6659.jpg,image.4112.jpg
12855,sample,sample distribution,image.65.jpg,image.13362.jpg,image.3623.jpg,image.6254.jpg,image.12852.jpg,image.290.jpg,image.10966.jpg,image.3473.jpg,image.3474.jpg,image.3472.jpg,image.10966.jpg


In [12]:
# Test split: rows of the main table, still indexed by their original row ids.
test_df

Unnamed: 0,word,context,image0,image1,image2,image3,image4,image5,image6,image7,image8,image9,label
3,bangalores,bangalores torpedo,image.58.jpg,image.59.jpg,image.64.jpg,image.57.jpg,image.55.jpg,image.56.jpg,image.62.jpg,image.63.jpg,image.61.jpg,image.60.jpg,image.55.jpg
6,leucaena,leucaena genus,image.105.jpg,image.3.jpg,image.106.jpg,image.109.jpg,image.4.jpg,image.1.jpg,image.2.jpg,image.108.jpg,image.5.jpg,image.107.jpg,image.105.jpg
7,mahonia,mahonia genus,image.3.jpg,image.124.jpg,image.122.jpg,image.4.jpg,image.120.jpg,image.123.jpg,image.1.jpg,image.2.jpg,image.121.jpg,image.5.jpg,image.120.jpg
10,gangster,gangster outlaw,image.166.jpg,image.173.jpg,image.172.jpg,image.165.jpg,image.174.jpg,image.170.jpg,image.171.jpg,image.167.jpg,image.168.jpg,image.169.jpg,image.165.jpg
12,brevicipitidae,brevicipitidae family,image.3.jpg,image.207.jpg,image.206.jpg,image.4.jpg,image.1.jpg,image.2.jpg,image.5.jpg,image.205.jpg,image.208.jpg,image.209.jpg,image.205.jpg
...,...,...,...,...,...,...,...,...,...,...,...,...,...
12854,make,make persuade,image.442.jpg,image.9126.jpg,image.7574.jpg,image.7582.jpg,image.5015.jpg,image.5704.jpg,image.4933.jpg,image.1022.jpg,image.2288.jpg,image.1208.jpg,image.9126.jpg
12856,gunboat,gunboat boat,image.615.jpg,image.364.jpg,image.58.jpg,image.59.jpg,image.122.jpg,image.57.jpg,image.4149.jpg,image.56.jpg,image.680.jpg,image.60.jpg,image.615.jpg
12857,francisella,francisella bacteria,image.3.jpg,image.4.jpg,image.1147.jpg,image.2798.jpg,image.1.jpg,image.2.jpg,image.14303.jpg,image.5.jpg,image.8422.jpg,image.4973.jpg,image.1147.jpg
12863,lookout,lookout watcher,image.5338.jpg,image.11952.jpg,image.58.jpg,image.59.jpg,image.57.jpg,image.56.jpg,image.10445.jpg,image.15132.jpg,image.4060.jpg,image.60.jpg,image.4060.jpg


In [13]:
# Second validation set, read from its own data/gold files.
val2_df

Unnamed: 0,word,context,image0,image1,image2,image3,image4,image5,image6,image7,image8,image9,label
0,maja,maja genus,image.309.jpg,image.2499.jpg,image.9528.jpg,image.3312.jpg,image.11360.jpg,image.7933.jpg,image.13433.jpg,image.3402.jpg,image.1488.jpg,image.1206.jpg,image.309.jpg
1,sobriquet,sobriquet appellation,image.1402.jpg,image.10671.jpg,image.12672.jpg,image.7395.jpg,image.10488.jpg,image.1341.jpg,image.466.jpg,image.3816.jpg,image.6705.jpg,image.8732.jpg,image.466.jpg
2,pigiron,pigiron iron,image.485.jpg,image.3312.jpg,image.2829.jpg,image.6694.jpg,image.627.jpg,image.1724.jpg,image.5215.jpg,image.9603.jpg,image.10601.jpg,image.7395.jpg,image.485.jpg
3,paddle,paddle beat,image.2626.jpg,image.432.jpg,image.545.jpg,image.3482.jpg,image.2708.jpg,image.373.jpg,image.16015.jpg,image.9314.jpg,image.4231.jpg,image.7411.jpg,image.545.jpg
4,gourmand,gourmand feeder,image.2443.jpg,image.7923.jpg,image.9567.jpg,image.15660.jpg,image.634.jpg,image.9426.jpg,image.6064.jpg,image.4139.jpg,image.6676.jpg,image.1964.jpg,image.634.jpg
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1550,caco3,caco3 carbonate,image.8703.jpg,image.2859.jpg,image.14885.jpg,image.10018.jpg,image.11685.jpg,image.271.jpg,image.15113.jpg,image.8715.jpg,image.8338.jpg,image.15640.jpg,image.8338.jpg
1551,marattia,marattia genus,image.1596.jpg,image.9510.jpg,image.6363.jpg,image.1122.jpg,image.1647.jpg,image.5751.jpg,image.3599.jpg,image.2078.jpg,image.6217.jpg,image.7841.jpg,image.6217.jpg
1552,tragulus,tragulus genus,image.7219.jpg,image.4252.jpg,image.12298.jpg,image.4061.jpg,image.13118.jpg,image.14801.jpg,image.6824.jpg,image.7813.jpg,image.9131.jpg,image.9338.jpg,image.4061.jpg
1553,barbwire,barbwire wire,image.10115.jpg,image.4112.jpg,image.11677.jpg,image.5307.jpg,image.498.jpg,image.11176.jpg,image.9644.jpg,image.2320.jpg,image.9184.jpg,image.6242.jpg,image.4112.jpg


In [14]:
# Second test set, read from its own data/gold files.
test2_df

Unnamed: 0,word,context,image0,image1,image2,image3,image4,image5,image6,image7,image8,image9,label
0,leucaena,leucaena genus,image.11561.jpg,image.9445.jpg,image.10639.jpg,image.3186.jpg,image.9619.jpg,image.105.jpg,image.8520.jpg,image.14730.jpg,image.14089.jpg,image.13274.jpg,image.105.jpg
1,mahonia,mahonia genus,image.1995.jpg,image.120.jpg,image.385.jpg,image.605.jpg,image.1993.jpg,image.3456.jpg,image.4563.jpg,image.2947.jpg,image.8532.jpg,image.3763.jpg,image.120.jpg
2,breakdown,breakdown failure,image.5190.jpg,image.13274.jpg,image.2285.jpg,image.5666.jpg,image.239.jpg,image.6912.jpg,image.7786.jpg,image.9475.jpg,image.13051.jpg,image.8279.jpg,image.239.jpg
3,boletellus,boletellus genus,image.2193.jpg,image.324.jpg,image.8584.jpg,image.7562.jpg,image.7509.jpg,image.14485.jpg,image.13328.jpg,image.4791.jpg,image.13053.jpg,image.10828.jpg,image.324.jpg
4,capparis,capparis genus,image.15880.jpg,image.7367.jpg,image.8603.jpg,image.359.jpg,image.10262.jpg,image.116.jpg,image.2622.jpg,image.2896.jpg,image.481.jpg,image.7507.jpg,image.359.jpg
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1554,aspergillus,aspergillus genus,image.7847.jpg,image.5827.jpg,image.7330.jpg,image.1632.jpg,image.7323.jpg,image.10614.jpg,image.2062.jpg,image.12440.jpg,image.9796.jpg,image.5521.jpg,image.10614.jpg
1555,mantophasmatodea,mantophasmatodea order,image.777.jpg,image.13482.jpg,image.1892.jpg,image.14892.jpg,image.6142.jpg,image.13245.jpg,image.3355.jpg,image.5598.jpg,image.2285.jpg,image.3602.jpg,image.3355.jpg
1556,make,make persuade,image.12886.jpg,image.1263.jpg,image.9126.jpg,image.9526.jpg,image.12047.jpg,image.2597.jpg,image.9154.jpg,image.8859.jpg,image.14239.jpg,image.6168.jpg,image.9126.jpg
1557,lookout,lookout watcher,image.9975.jpg,image.593.jpg,image.4049.jpg,image.4060.jpg,image.2628.jpg,image.7152.jpg,image.11177.jpg,image.9415.jpg,image.6207.jpg,image.10997.jpg,image.4060.jpg


## Preprocessing

In [15]:
# Load the LAVIS BLIP image-text-matching model of the configured variant,
# together with its dictionaries of visual and text processors.
blip_model, vis_processors, text_processors = load_model_and_preprocess("blip_image_text_matching", BLIP_VARIANT, is_eval=True)

INFO:root:Missing keys []
INFO:root:load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth


In [16]:
# Plain pipeline: the "eval" visual processor that came with the model.
vis_proc = vis_processors["eval"]
# Augmentation applied in front of it (T.AugMix() is the alternative noted here).
img_aug = T.AutoAugment()


def vis_proc_aug(p):
    """Augment image ``p`` with ``img_aug`` and then run it through ``vis_proc``."""
    return vis_proc(img_aug(p))

In [17]:
# Arguments every dataset below has in common.
_shared_ds_kwargs = dict(
    images_path=IMAGES_PATH,
    text_processor=text_processors["eval"],
)
# Training images are augmented only when IMAGE_AUGMENTATION is switched on.
_train_vis_proc = vis_proc_aug if IMAGE_AUGMENTATION else vis_proc

# Two views of the training split: the default dataset and the AltNSDataset variant.
train_ds = DefaultDataset(df=train_df, vis_processor=_train_vis_proc, **_shared_ds_kwargs)
train_ds2 = AltNSDataset(
    df=train_df,
    vis_processor=_train_vis_proc,
    num_negatives=NUM_NS,
    num_pics=NUM_PICS,
    **_shared_ds_kwargs,
)

# Validation sets with the plain visual processor ...
val_ds = DefaultDataset(df=validation_df, vis_processor=vis_proc, **_shared_ds_kwargs)
val2_ds = DefaultDataset(df=val2_df, vis_processor=vis_proc, **_shared_ds_kwargs)
# ... and the same sets with the augmenting one.
val_aug_ds = DefaultDataset(df=validation_df, vis_processor=vis_proc_aug, **_shared_ds_kwargs)
val2_aug_ds = DefaultDataset(df=val2_df, vis_processor=vis_proc_aug, **_shared_ds_kwargs)

In [18]:
# Both training datasets are concatenated and shuffled together.
_train_concat_ds = torch.utils.data.ConcatDataset([train_ds, train_ds2])
train_dl = torch.utils.data.DataLoader(
    _train_concat_ds,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)
# Number of micro-batches per epoch; the step budget of the scheduler is derived from it.
train_l = len(train_dl)
train_l

4069

## Model setup

In [19]:
# Wrap the loaded BLIP model in the project's ITMClassifier and move it to DEVICE.
model = ITMClassifier(blip_model).to(DEVICE)

## Training

In [20]:
# Step budget: one optimizer step per GRAD_ACCUM_STEPS micro-batches,
# a fixed fraction of which is warmup.
num_training_steps = int(NUM_EPOCHS * (train_l / GRAD_ACCUM_STEPS))
num_warmup_steps = int(num_training_steps * WARMUP_STEPS_FRAC)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# Cosine learning-rate schedule with warmup over the budget computed above.
lr_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
print(f"{num_training_steps} training steps which include {num_warmup_steps} warmup ones")

2034 training steps which include 101 warmup ones


In [21]:
# Class ids handed to sklearn so every candidate position counts as a known label.
labels_range = np.arange(NUM_PICS)


def get_batch_scores(model, batch, dev, env):
    """Score one batch with the model.

    Args:
        model: callable applied to the batch after it has been moved to ``dev``.
        batch: mapping that contains at least a ``"label"`` tensor.
        dev: device the batch and the one-hot targets are moved to.
        env: accepted for the callback signature; not used here.

    Returns:
        Dict with the keys ``"Loss"``, ``"Accuracy@Top1"``, ``"Accuracy@Top3"``
        and ``"Mean Reciprocal Rank"``.
    """
    moved = to_device(batch, dev)
    model_out = model(moved)
    gold = moved["label"]
    gold_np = gold.numpy(force=True)
    out_np = model_out.numpy(force=True)
    targets = F.one_hot(gold, NUM_PICS).float().to(dev)

    scores = {}
    scores["Loss"] = loss_fn(model_out, targets)
    for top_k in (1, 3):
        scores[f"Accuracy@Top{top_k}"] = top_k_accuracy_score(gold_np, out_np, k=top_k, labels=labels_range)
    scores["Mean Reciprocal Rank"] = mrr(gold_np, out_np)
    return scores

def log_score(train_step, name, metric_name, metric_value):
    """Write one scalar to TensorBoard under the tag ``<metric_name>/<name>``."""
    tag = f"{metric_name}/{name}"
    writer.add_scalar(tag, metric_value, global_step=train_step)


def _validation_loader(dataset):
    """Build a DataLoader for a validation dataset with the shared batch size and worker count."""
    # NOTE(review): shuffle=True is kept from the original setup; validation
    # would be reproducible batch-for-batch with shuffle=False.
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=VALIDATION_BATCH_SIZE,
        num_workers=NUM_WORKERS,
        shuffle=True,
    )


# One entry per validation set. Plain sets are switched by VALIDATE_UNAUGMENTED
# and augmented sets by VALIDATE_AUGMENTED.
validation = Validation(
    common = {
        "device": DEVICE,
        "get_batch_scores": get_batch_scores,
        # Validate every STEPS_BETWEEN_VAL optimizer steps and once right after the first step.
        "step_cond": lambda s: (s % STEPS_BETWEEN_VAL == 0) or (s == 1),
        "log_score": log_score,
    },
    configs = {
        "Validation": {
            "enable": VALIDATE_UNAUGMENTED,
            "dl": _validation_loader(val_ds),
        },
        "Validation (augmented)": {
            # This flag was missing, so VALIDATE_AUGMENTED could not disable this set.
            "enable": VALIDATE_AUGMENTED,
            "dl": _validation_loader(val_aug_ds),
        },
        "Validation 2": {
            "enable": VALIDATE_UNAUGMENTED,
            "dl": _validation_loader(val2_ds),
        },
        "Validation 2 (augmented)": {
            "enable": VALIDATE_AUGMENTED,
            "dl": _validation_loader(val2_aug_ds),
        },
    },
)

In [22]:
# Counters: optimizer steps taken, micro-batches accumulated since the last
# optimizer step, and optimizer steps since the last checkpoint.
step_num = 0
grad_accum_step_cnt = 0
save_checkpoint_step_cnt = 0
progress_bar = tqdm(range(num_training_steps))

# Running train metrics over the current accumulation window. They are
# initialised once, outside the epoch loop: the accumulation counter carries
# over epoch boundaries, so resetting the metrics per epoch would log a partial
# sum divided by GRAD_ACCUM_STEPS for the window that straddles two epochs.
train_loss = 0.0
train_scores = {"acc1": 0, "acc3": 0, "mrr": 0}

# Start from clean gradients even if this cell is re-run after an interruption.
optimizer.zero_grad()

for epoch_num in range(NUM_EPOCHS):
    model.train()
    for batch in train_dl:
        batch = to_device(batch, DEVICE)
        outputs = model(batch)
        loss = loss_fn(outputs, F.one_hot(batch["label"], NUM_LABELS).float().to(DEVICE))

        train_loss += loss.item()
        train_scores = sum_scores(train_scores, eval_batch(batch["label"], outputs, num_labels = NUM_LABELS))

        # Gradients are summed over the window (the loss is not divided by GRAD_ACCUM_STEPS).
        loss.backward()
        grad_accum_step_cnt += 1

        if grad_accum_step_cnt == GRAD_ACCUM_STEPS:
            # Log the learning rate that is about to be applied, then take the step.
            writer.add_scalar("Learning Rate", lr_scheduler.get_last_lr()[0], step_num)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # Log train metrics averaged over the micro-batches of this window.
            writer.add_scalar("Loss/Train", train_loss / GRAD_ACCUM_STEPS, step_num)
            for k, v in div_scores(train_scores, GRAD_ACCUM_STEPS).items():
                writer.add_scalar(metric2name[k] + "/Train", v, step_num)
            train_loss = 0.0
            train_scores = {"acc1": 0, "acc3": 0, "mrr": 0}
            grad_accum_step_cnt = 0

            step_num += 1
            save_checkpoint_step_cnt += 1
            progress_bar.update(1)

            validation(step_num, model)
            # Whether Validation restores train mode is not visible from this
            # notebook, so re-assert it before the next micro-batch.
            model.train()

            # The checkpoint counter only changes on optimizer steps, so it is checked here.
            if save_checkpoint_step_cnt == SAVE_CHECKPOINT_STEPS:
                save_checkpoint_step_cnt = 0
                p = SAVE_CHECKPOINT_PATH / f"step-{step_num}.pt"
                torch.save(model.state_dict(), p)

progress_bar.close()

100%|██████████| 2034/2034 [26:32:18<00:00, 33.32s/it]     

## Evaluation

In [24]:
def predict_eval(
    model: ITMClassifier,
    dataframes: Dict[str, pd.DataFrame],
    images_path: Path,
    text_processor,
    vis_processors: Dict,
    batch_size: int = 1,
    num_workers: int = 0,
    persistent_workers: bool = True,
    device = torch.device("cpu"),
    preds_save_folder: Optional[Path] = None,
    preds_save_filename_prefix: str = "sample_predictions",
    preds_save_filename_add_timestamp: bool = True,
    verbose: bool = True,
) -> Tuple[Dict[str, List[Dict[str, Any]]], Dict[str, Dict[str, float]]]:
    """
    Run the model over several test sets, optionally save the predictions, and score them.

    Args:
        model (ITMClassifier): classification model with the weights to evaluate already loaded;
            it is moved to ``device`` and left in eval mode
        dataframes (Dict[str, pandas.DataFrame]): mapping of test set names to their dataframes
            (columns ``word``, ``context``, ``image0`` .. ``imageN``, ``label``)
        images_path (Path): folder handed to ``DefaultDataset`` as ``images_path``
        text_processor: text processor handed to ``DefaultDataset``
        vis_processors (Dict): mapping of test set names (same keys as ``dataframes``)
            to the visual processor used for that set
        batch_size (int): batch size of the evaluation DataLoader
        num_workers (int): number of DataLoader workers
        persistent_workers (bool): keep DataLoader workers alive between iterations;
            ignored when ``num_workers`` is 0 because DataLoader rejects that combination
        device: device to run inference on
        preds_save_folder (Optional[Path]): when given, predictions of every test set are
            written there as JSON (the folder is created if missing)
        preds_save_filename_prefix (str): prefix of the saved file names
        preds_save_filename_add_timestamp (bool): add ``time.time()`` to the saved file names
        verbose (bool): enables prints of metrics and progress tracking

    Returns:
        Tuple of two dicts keyed by test set name:
        predictions - one dict per dataframe row mapping each candidate image name to the
        model output for it; evaluations - the metric dict returned by ``evaluate``.
    """
    predictions = dict()
    evaluations = dict()

    # DataLoader raises if persistent_workers=True is combined with num_workers=0,
    # which is exactly what the defaults of this function ask for.
    keep_workers = persistent_workers and num_workers > 0

    if preds_save_folder is not None:
        preds_save_folder = Path(preds_save_folder)
        preds_save_folder.mkdir(parents=True, exist_ok=True)

    # Moving the model and switching to eval mode is independent of the test set.
    model = model.to(device)
    model.eval()

    for name, df in dataframes.items():
        if verbose:
            print(f"Generating predictions for \"{name}\"")
        ds = DefaultDataset(
            df=df,
            images_path=images_path,
            text_processor=text_processor,
            vis_processor=vis_processors[name],
        )
        # shuffle=False keeps model outputs in the same order as the rows of df.
        dl = torch.utils.data.DataLoader(
            ds,
            batch_size = batch_size,
            shuffle = False,
            num_workers = num_workers,
            persistent_workers = keep_workers,
        )
        preds = []  # one {image name: score} dict per dataframe row
        row_idx = 0
        with torch.no_grad():
            for batch in (tqdm(dl) if verbose else dl):
                batch = to_device(batch, device)
                for ps in model(batch).numpy(force=True): # ps - predictions for one row
                    row = df.iloc[row_idx]
                    preds.append({row[f"image{j}"]: ps[j] for j in range(len(ps))})
                    row_idx += 1
        predictions[name] = preds
        if preds_save_folder is not None:
            maybe_datetime = f"_at_{time.time()}_" if preds_save_filename_add_timestamp else "_"
            filename = f"{preds_save_filename_prefix}_on_{name}{maybe_datetime}submission.json"
            if verbose:
                print(f"Saving predictions for \"{name}\" as \"{filename}\"")
            # Write into the folder the caller asked for, not the notebook-global PATH.
            with open(preds_save_folder / filename, "w") as f:
                json.dump([{k: str(v) for k, v in p.items()} for p in preds], f, indent=2)
        if verbose:
            print(f"Metrics for \"{name}\":")
        # Candidate image names (all columns between context and label), gold labels, predictions.
        evals = evaluate(
            df.iloc[:, 2:-1].values,
            df["label"].values.reshape(-1, 1),
            preds,
        )
        if verbose:
            for metric_id, metric_value in evals.items():
                metric_name = metric2name[metric_id]
                print(f"    {metric_name}: {metric_value}")
        evaluations[name] = evals
    return predictions, evaluations

In [25]:
# Optimizer-step numbers of the saved checkpoints to evaluate.
CHECKPOINTS = [500, 1000, 1300, 2000] # fill this out with checkpoints of interest (use Tensorboard)
# Test sets to evaluate, keyed by display name. The "(augmented)" entries reuse
# the same dataframes and differ only in the visual processor chosen below.
TEST_DFS = {
    "Test": test_df,
    "Test 2": test2_df,
    "Test (augmented)": test_df,
    "Test 2 (augmented)": test2_df,
}
# Visual processor per test set; the keys must match TEST_DFS because
# predict_eval looks the processor up by test set name.
TEST_VIS_PROCS = {
    "Test": vis_proc,
    "Test 2": vis_proc,
    "Test (augmented)": vis_proc_aug,
    "Test 2 (augmented)": vis_proc_aug,
}

In [31]:
# Evaluate every checkpoint of interest on every configured test set.
results = dict() # int -> tuple(preds, evals)
for checkpoint_num in CHECKPOINTS:
    print(f"Processing checkpoint {checkpoint_num}")
    checkpoint_file = SAVE_CHECKPOINT_PATH / f"step-{checkpoint_num}.pt"
    # map_location puts the weights on the configured device regardless of the
    # device the checkpoint was saved from.
    model.load_state_dict(torch.load(checkpoint_file, map_location=DEVICE))
    results[checkpoint_num] = predict_eval(
        model = model,
        verbose = True,
        preds_save_filename_prefix = f"blip-{HEAD}-{MODEL_VERSION}-{checkpoint_num}",
        preds_save_folder = PATH,
        device = DEVICE,
        persistent_workers = PERSISTENT_WORKERS,
        num_workers = NUM_WORKERS,
        batch_size = TEST_BATCH_SIZE,
        text_processor=text_processors["eval"],
        vis_processors=TEST_VIS_PROCS,
        dataframes = TEST_DFS,
        images_path = IMAGES_PATH,
    )

Processing checkpoint 500
Generating predictions for "Test"


100%|██████████| 84/84 [06:23<00:00,  4.56s/it]


Saving predictions for "Test" as "blip-itm-28-500_on_Test_at_1673514398.2059329_submission.json"
Metrics for "Test":
    Accuracy@Top1: 0.833134684147795
    Accuracy@Top3: 0.968414779499404
    Mean Reciprocal Rank: 0.9012530270352082
Generating predictions for "Test 2"


100%|██████████| 39/39 [03:15<00:00,  5.00s/it]


Saving predictions for "Test 2" as "blip-itm-28-500_on_Test 2_at_1673514594.1078389_submission.json"
Metrics for "Test 2":
    Accuracy@Top1: 0.7735728030788968
    Accuracy@Top3: 0.9403463758819757
    Mean Reciprocal Rank: 0.8600776851257929
Generating predictions for "Test (augmented)"


100%|██████████| 84/84 [07:10<00:00,  5.12s/it]


Saving predictions for "Test (augmented)" as "blip-itm-28-500_on_Test (augmented)_at_1673515024.8104289_submission.json"
Metrics for "Test (augmented)":
    Accuracy@Top1: 0.8218116805721096
    Accuracy@Top3: 0.9627532777115614
    Mean Reciprocal Rank: 0.8936837836048206
Generating predictions for "Test 2 (augmented)"


100%|██████████| 39/39 [03:45<00:00,  5.79s/it]


Saving predictions for "Test 2 (augmented)" as "blip-itm-28-500_on_Test 2 (augmented)_at_1673515251.4449437_submission.json"
Metrics for "Test 2 (augmented)":
    Accuracy@Top1: 0.7626683771648493
    Accuracy@Top3: 0.9352148813341886
    Mean Reciprocal Rank: 0.854160939552216
Processing checkpoint 1000
Generating predictions for "Test"


100%|██████████| 84/84 [06:23<00:00,  4.56s/it]


Saving predictions for "Test" as "blip-itm-28-1000_on_Test_at_1673515635.405556_submission.json"
Metrics for "Test":
    Accuracy@Top1: 0.8322407628128725
    Accuracy@Top3: 0.9702026221692491
    Mean Reciprocal Rank: 0.9013119123294928
Generating predictions for "Test 2"


100%|██████████| 39/39 [03:15<00:00,  5.03s/it]


Saving predictions for "Test 2" as "blip-itm-28-1000_on_Test 2_at_1673515832.2138894_submission.json"
Metrics for "Test 2":
    Accuracy@Top1: 0.7831943553559975
    Accuracy@Top3: 0.940987812700449
    Mean Reciprocal Rank: 0.8654942627040124
Generating predictions for "Test (augmented)"


100%|██████████| 84/84 [07:11<00:00,  5.13s/it]


Saving predictions for "Test (augmented)" as "blip-itm-28-1000_on_Test (augmented)_at_1673516263.5992057_submission.json"
Metrics for "Test (augmented)":
    Accuracy@Top1: 0.8182359952324195
    Accuracy@Top3: 0.966030989272944
    Mean Reciprocal Rank: 0.8920281372382086
Generating predictions for "Test 2 (augmented)"


100%|██████████| 39/39 [03:46<00:00,  5.80s/it]


Saving predictions for "Test 2 (augmented)" as "blip-itm-28-1000_on_Test 2 (augmented)_at_1673516490.6303256_submission.json"
Metrics for "Test 2 (augmented)":
    Accuracy@Top1: 0.7684413085311097
    Accuracy@Top3: 0.9377806286080821
    Mean Reciprocal Rank: 0.8566961422157061
Processing checkpoint 1300
Generating predictions for "Test"


100%|██████████| 84/84 [06:25<00:00,  4.59s/it]


Saving predictions for "Test" as "blip-itm-28-1300_on_Test_at_1673516877.0888796_submission.json"
Metrics for "Test":
    Accuracy@Top1: 0.8292610250297974
    Accuracy@Top3: 0.9636471990464839
    Mean Reciprocal Rank: 0.8976545206878938
Generating predictions for "Test 2"


100%|██████████| 39/39 [03:15<00:00,  5.00s/it]


Saving predictions for "Test 2" as "blip-itm-28-1300_on_Test 2_at_1673517073.0814276_submission.json"
Metrics for "Test 2":
    Accuracy@Top1: 0.7806286080821039
    Accuracy@Top3: 0.935856318152662
    Mean Reciprocal Rank: 0.8624314018958021
Generating predictions for "Test (augmented)"


100%|██████████| 84/84 [07:38<00:00,  5.46s/it]


Saving predictions for "Test (augmented)" as "blip-itm-28-1300_on_Test (augmented)_at_1673517532.246607_submission.json"
Metrics for "Test (augmented)":
    Accuracy@Top1: 0.8197258641239571
    Accuracy@Top3: 0.9591775923718713
    Mean Reciprocal Rank: 0.8914533552793386
Generating predictions for "Test 2 (augmented)"


100%|██████████| 39/39 [03:54<00:00,  6.02s/it]


Saving predictions for "Test 2 (augmented)" as "blip-itm-28-1300_on_Test 2 (augmented)_at_1673517767.889199_submission.json"
Metrics for "Test 2 (augmented)":
    Accuracy@Top1: 0.7697241821680565
    Accuracy@Top3: 0.9345734445157152
    Mean Reciprocal Rank: 0.854833939134773
Processing checkpoint 2000
Generating predictions for "Test"


100%|██████████| 84/84 [06:38<00:00,  4.74s/it]


Saving predictions for "Test" as "blip-itm-28-2000_on_Test_at_1673518167.2713418_submission.json"
Metrics for "Test":
    Accuracy@Top1: 0.8325387365911799
    Accuracy@Top3: 0.9588796185935637
    Mean Reciprocal Rank: 0.8984233403333521
Generating predictions for "Test 2"


100%|██████████| 39/39 [03:23<00:00,  5.22s/it]


Saving predictions for "Test 2" as "blip-itm-28-2000_on_Test 2_at_1673518371.8233898_submission.json"
Metrics for "Test 2":
    Accuracy@Top1: 0.7831943553559975
    Accuracy@Top3: 0.9364977549711353
    Mean Reciprocal Rank: 0.8642149526049462
Generating predictions for "Test (augmented)"




 73%|███████▎  | 61/84 [05:41<02:08,  5.61s/it]


KeyboardInterrupt: 

100%|██████████| 2034/2034 [32:11:32<00:00, 56.98s/it]

In [30]:
# Rank checkpoints by the plain sum of every metric over every test set.
sums = {}
for checkpoint_id, (_preds, evals_by_set) in results.items():
    total = 0
    for metric_values in evals_by_set.values():
        total += sum(metric_values.values())
    sums[checkpoint_id] = total

best_checkpoint = max(sums, key=lambda checkpoint_id: sums[checkpoint_id])
print(f"Best checkpoint (by sum of all scores) is {best_checkpoint} with results:")
results[best_checkpoint][1]

Best checkpoint (by sum of all scores) is 1000 with results:


{'Test': {'acc1': 0.8322407628128725,
  'acc3': 0.9702026221692491,
  'mrr': 0.9013119123294928},
 'Test 2': {'acc1': 0.7831943553559975,
  'acc3': 0.940987812700449,
  'mrr': 0.8654942627040124},
 'Test (augmented)': {'acc1': 0.8256853396901073,
  'acc3': 0.965435041716329,
  'mrr': 0.8961898376752369},
 'Test 2 (augmented)': {'acc1': 0.7761385503527902,
  'acc3': 0.9390635022450289,
  'mrr': 0.8595882077440768}}