In [1]:
from pathlib import Path
import logging
import json
from typing import *
import time
from datetime import datetime
import warnings

from tqdm import tqdm
import pandas as pd
import numpy as np
from PIL import Image, ImageFile

import torch
import torch.nn as nn
import torchvision.transforms as T
from lavis.models import load_model_and_preprocess, BlipBase
from lavis.processors import load_processor
import torch.nn.functional as F
from transformers import get_cosine_schedule_with_warmup
from transformers import BatchEncoding

from src.data import CustomSplitLoader
from src.itm import AltNSDataset, to_device, ITMClassifier, DefaultDataset, DefaultDatasetMultiaug

from src.utils import evaluate, mrr
from src.validation import Validation, sum_scores, div_scores, eval_batch, metric2name
from sklearn.metrics import top_k_accuracy_score
from torch.utils.tensorboard import SummaryWriter

## Config

Versioning

In [2]:
HEAD = "itm"
MODEL_VERSION = "monster-v1"

Paths resolution:

In [3]:
DATASET_VERSION = "v1"
PART = "train"
PATH = Path("/home/s1m00n/research/vwsd/data").resolve() / f"{PART}_{DATASET_VERSION}"
DATA_PATH = PATH / f"{PART}.data.{DATASET_VERSION}.txt"
LABELS_PATH = PATH / f"{PART}.gold.{DATASET_VERSION}.txt"
IMAGES_PATH = PATH / f"{PART}_images_{DATASET_VERSION}"
TRAIN_SPLIT_PATH = PATH / "split_train.txt"
VALIDATION_SPLIT_PATH = PATH / "split_valid.txt"
VAL2_DATA_PATH = PATH / "valid2.data.v1.txt"
VAL2_GOLD_PATH = PATH / "valid2.gold.v1.txt"
TEST_SPLIT_PATH = PATH / "split_test.txt"
TEST2_DATA_PATH = PATH / "test2.data.v1.txt"
TEST2_GOLD_PATH = PATH / "test2.gold.v1.txt"
SAVE_CHECKPOINT_PATH = Path("/home/s1m00n/research/vwsd/checkpoints").resolve() / f"BLIP-{HEAD}-{MODEL_VERSION}" # TODO: maybe add timestamp?
SAVE_CHECKPOINT_PATH.mkdir(parents=True, exist_ok=True)
NUM_PICS = 10

Environment settings:

In [4]:
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# some images from train might not load without the following settings or warnings would be thrown
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True

warnings.filterwarnings('ignore')


In [5]:
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE)
# WARNING: this is specific to my setup
DEVICE = torch.device("cuda")
# a more conventional way to do this is:
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_WORKERS = 0
PERSISTENT_WORKERS = True
print(f"Running on {DEVICE}")

Running on cuda


Model & training settings

In [6]:
BLIP_VARIANT = "large" # "base" | "large"
NUM_NS = 9
NUM_EPOCHS = 15
WARMUP_STEPS_FRAC = 0.15
GRAD_ACCUM_STEPS = 16
LR = 5e-6
WEIGHT_DECAY = 1e-4
TRAIN_BATCH_SIZE = 1
IMAGE_AUGMENTATION = True
VALIDATE_UNAUGMENTED = True
VALIDATE_AUGMENTED = True
# cos lr scheduler

In [7]:
TRAIN_EFFECTIVE_BATCH_SIZE = GRAD_ACCUM_STEPS * TRAIN_BATCH_SIZE
NUM_LABELS = NUM_NS + 1
# NUM_LABELS = NUM_PICS
TRAIN_EFFECTIVE_BATCH_SIZE

16

In [8]:
STEPS_BETWEEN_VAL = 1000
SAVE_CHECKPOINT_STEPS = STEPS_BETWEEN_VAL
VALIDATION_BATCH_SIZE = 20
TEST_BATCH_SIZE = VALIDATION_BATCH_SIZE

## Loading data

In [9]:
en_df = pd.read_csv("/home/s1m00n/research/vwsd/data/test.data.v1.1/en.test.data.v1.1.txt", sep="\t", header=None)
en_df.columns = ["word", "context"] + [f"image{i}" for i in range(10)]
en_df

Unnamed: 0,word,context,image0,image1,image2,image3,image4,image5,image6,image7,image8,image9
0,goal,football goal,image.4418.jpg,image.4416.jpg,image.4417.jpg,image.4413.jpg,image.4412.jpg,image.4415.jpg,image.4419.jpg,image.4414.jpg,image.2166.jpg,image.1150.jpg
1,mustard,mustard seed,image.4429.png,image.4422.jpg,image.4423.jpg,image.4424.jpg,image.4421.jpg,image.4427.jpg,image.4426.jpg,image.4420.jpg,image.4425.jpg,image.4428.jpg
2,seat,eating seat,image.4435.jpg,image.4436.jpg,image.1166.jpg,image.4430.jpg,image.4433.jpg,image.4432.jpg,image.4438.jpg,image.4434.jpg,image.4431.jpg,image.4437.jpg
3,navigate,navigate the web,image.4439.jpg,image.4440.jpg,image.4441.jpg,image.4442.jpg,image.4444.jpg,image.4445.jpg,image.1435.jpg,image.4446.png,image.1434.jpg,image.4443.jpg
4,butterball,butterball person,image.4454.jpg,image.4450.jpg,image.4455.jpg,image.4453.jpg,image.4448.jpg,image.1253.jpg,image.4451.jpg,image.4452.jpg,image.4447.jpg,image.4449.jpg
...,...,...,...,...,...,...,...,...,...,...,...,...
458,cannabis,cannabis drug,image.8063.jpg,image.8064.jpg,image.4891.jpg,image.7450.jpg,image.8066.jpg,image.4454.jpg,image.8065.jpg,image.6775.jpg,image.1604.jpg,image.2540.jpg
459,crossroads,crossroads cars,image.8073.jpg,image.8076.jpg,image.8075.jpg,image.8070.jpg,image.8068.jpg,image.8074.jpg,image.8069.jpg,image.8071.jpg,image.8067.jpg,image.8072.jpg
460,clocks,time clocks,image.8082.jpg,image.8079.jpg,image.2094.jpg,image.8081.jpg,image.8077.jpg,image.8080.jpg,image.4995.jpg,image.8083.jpg,image.5251.jpg,image.8078.jpg
461,columba,columba stars,image.8087.jpg,image.8084.jpg,image.7279.jpg,image.192.jpg,image.93.jpg,image.8085.jpg,image.8088.jpg,image.4126.jpg,image.8086.jpg,image.8089.jpg


## Preprocessing

In [10]:
blip_model, vis_processors, text_processors = load_model_and_preprocess("blip_image_text_matching", BLIP_VARIANT, is_eval=True)
img_aug = T.TrivialAugmentWide()
vis_proc = vis_processors["eval"]
vis_proc_aug = lambda p: vis_proc(img_aug(p))
test_ds = DefaultDataset(
    df = en_df,
    images_path=Path("/home/s1m00n/research/vwsd/data/test_images/test_images").resolve(),
    text_processor=text_processors["eval"],
    vis_processor=vis_proc,
    ignore_labels=True,
)

test_aug_ds = DefaultDatasetMultiaug(
    df = en_df,
    images_path=Path("/home/s1m00n/research/vwsd/data/test_images/test_images").resolve(),
    text_processor=text_processors["eval"],
    vis_processor=vis_proc,
    vis_aug = img_aug,
    n_aug = 9,
    include_original = True,
    ignore_labels=True,
)

INFO:root:Missing keys []
INFO:root:load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth


## Model setup

In [11]:
model = ITMClassifier(blip_model).to(DEVICE)

In [12]:
model.load_state_dict(torch.load("/home/s1m00n/research/vwsd/checkpoints/BLIP-itm-monster-v1/step-12000.pt"))

<All keys matched successfully>

In [13]:
def multiaug_extract(
    imgs: torch.Tensor # (bs, 1 + n_aug, num_pics, ...)
) -> torch.Tensor: # [(bs, num_pics, ...)] with len = 1 + n_aug
    return [imgs[:, i] for i in range(imgs.shape[1])]

In [14]:
dl = torch.utils.data.DataLoader(
    test_ds,
    batch_size = 20,
    shuffle = False,
    num_workers = 16,
)

In [15]:
aug_dl = torch.utils.data.DataLoader(
    test_aug_ds,
    batch_size = 20,
    shuffle = False,
    num_workers = 16,
)

In [16]:
preds = []
model.eval()
with torch.no_grad():
    i = 0
    for batch in tqdm(aug_dl, total=len(aug_dl)):
        imgs = multiaug_extract(batch["images"])
        np_preds = None
        for bi in imgs:
            batch["images"] = bi
            out = model(to_device(batch, DEVICE))
            if np_preds is None:
                np_preds = out.numpy(force=True)
            else:
                np_preds += out.numpy(force=True)
        np_preds = np_preds / len(imgs)
        for ps in np.argsort(-np_preds, axis=1):
            row = en_df.iloc[i]
            preds.append([row[f"image{j}"] for j in ps])
            i += 1

 12%|█▎        | 3/24 [05:25<31:49, 90.92s/it]   ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
 17%|█▋        | 4/24 [06:55<34:36, 103.84s/it]


RuntimeError: DataLoader worker (pid 286575) is killed by signal: Bus error. It is possible that dataloader's workers are out of shared memory. Please try to raise your shared memory limit.

In [None]:
with open("blip-itm-monster-v1-real-test-9aug.en.txt", "w") as f:
    for row in preds:
        f.write("\t".join(row) + "\n")

In [None]:
labels_range = np.arange(NUM_PICS)

def get_batch_scores(model, batch, dev, env):
    batch = to_device(batch, dev)
    outputs = model(batch)
    np_labels = batch["label"].numpy(force=True)
    np_preds = outputs.numpy(force=True)
    return {
        "Loss": loss_fn(outputs, F.one_hot(batch["label"], NUM_PICS).float().to(dev)),
        "Accuracy@Top1": top_k_accuracy_score(np_labels, np_preds, k=1, labels=labels_range),
        "Accuracy@Top3": top_k_accuracy_score(np_labels, np_preds, k=3, labels=labels_range),
        "Mean Reciprocal Rank": mrr(np_labels, np_preds),
    }

def get_batch_scores_aug(model, batch, dev, env):
    np_labels = batch["label"].numpy(force=True)
    imgs = multiaug_extract(batch["images"])
    np_preds = None
    loss_accum = 0
    for bi in imgs:
        batch["images"] = bi
        out = model(to_device(batch, dev))
        if np_preds is None:
            np_preds = out.numpy(force=True)
        else:
            np_preds += out.numpy(force=True)
        loss_accum += loss_fn(out, F.one_hot(batch["label"], NUM_PICS).float().to(dev)).item()
    np_preds = np_preds / len(imgs)
    return {
        "Loss": loss_accum / len(imgs),
        "Accuracy@Top1": top_k_accuracy_score(np_labels, np_preds, k=1, labels=labels_range),
        "Accuracy@Top3": top_k_accuracy_score(np_labels, np_preds, k=3, labels=labels_range),
        "Mean Reciprocal Rank": mrr(np_labels, np_preds),
    }

def log_score(train_step, name, metric_name, metric_value):
    writer.add_scalar(f"{metric_name}/{name}", metric_value, train_step)
    # print(f"[{train_step}][{name}]", f"{metric_name}: {metric_value}")


validation = Validation(
    common = {
        "device": DEVICE,
        "get_batch_scores": get_batch_scores_aug,
        "step_cond": lambda s: (s % STEPS_BETWEEN_VAL == 0) or (s == 1),
        "log_score": log_score,
    },
    configs = {
        "Validation": {
            "dl": torch.utils.data.DataLoader(val_ds, batch_size=VALIDATION_BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True),
        },
        "Validation 2": {
            "dl": torch.utils.data.DataLoader(val2_ds, batch_size=VALIDATION_BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True),
        },
    },
)

In [None]:
step_num = 0
grad_accum_step_cnt = 0
save_checkpoint_step_cnt = 0
progress_bar = tqdm(range(num_training_steps))

for epoch_num in range(NUM_EPOCHS):
    model.train()
    train_loss = 0.0
    train_scores = {"acc1": 0, "acc3": 0, "mrr": 0}
    for batch in train_dl:
        batch = to_device(batch, DEVICE)
        outputs = model(batch)
        loss = loss_fn(outputs, F.one_hot(batch["label"], NUM_LABELS).float().to(DEVICE))

        train_loss += loss.item()
        train_scores = sum_scores(train_scores, eval_batch(batch["label"], outputs, num_labels = NUM_LABELS))

        loss.backward()
        grad_accum_step_cnt += 1

        if grad_accum_step_cnt == GRAD_ACCUM_STEPS: 
            writer.add_scalar("Learning Rate", lr_scheduler.get_last_lr()[0], step_num)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            writer.add_scalar("Loss/Train", train_loss / GRAD_ACCUM_STEPS, step_num)            
            for k, v in div_scores(train_scores, GRAD_ACCUM_STEPS).items():
                writer.add_scalar(metric2name[k] + "/Train", v, step_num)
            train_loss = 0.0
            train_scores = {"acc1": 0, "acc3": 0, "mrr": 0}
            grad_accum_step_cnt = 0
            step_num += 1
            save_checkpoint_step_cnt += 1
            progress_bar.update(1)
            validation(step_num, model)

        if save_checkpoint_step_cnt == SAVE_CHECKPOINT_STEPS:
            save_checkpoint_step_cnt = 0
            p = SAVE_CHECKPOINT_PATH / f"step-{step_num}.pt"
            torch.save(model.state_dict(), p)

## Evaluation

In [None]:
def predict_eval(
    model: ITMClassifier,
    dataframes: Dict[str, pd.DataFrame],
    images_path: Path,
    text_processor,
    vis_processors: Dict,
    batch_size: int = 1,
    num_workers: int = 0,
    persistent_workers: bool = True,
    device = torch.device("cpu"),
    preds_save_folder: Optional[Path] = None,
    preds_save_filename_prefix: str = "sample_predictions",
    preds_save_filename_add_timestamp: bool = True,
    verbose: bool = True,
) -> Tuple[Dict[str, Dict[str, np.ndarray]], Dict[str, Dict[str, float]]]:
    """
    Combines predictions for dataloader using checkpoint model with evaluation.

    Args:
        model (ITMClassifier): loaded classification model
        dataframes (pandas.DataFrame)): mapping of test set names to the dataframes
        verbose (bool): enables prints of metrics and progress tracking

    Returns:
        Tuple[Dict[str, Dict[str, np.ndarray]], Dict[str, Dict[str, float]]]: predictions and scores for the corresponding test sets
    """
    predictions = dict()
    evaluations = dict()
    for name, df in dataframes.items():
        if verbose:
            print(f"Generating predictions for \"{name}\"")
        ds = DefaultDataset(
            df=df,
            images_path=images_path,
            text_processor=text_processor,
            vis_processor=vis_processors[name],
        )
        dl = torch.utils.data.DataLoader(
            ds,
            batch_size = batch_size,
            shuffle = False,
            num_workers = num_workers,
            persistent_workers = persistent_workers,
        )
        preds = [] # list: 
        model = model.to(device)
        model.eval()
        i = 0
        with torch.no_grad():
            for batch in (tqdm(dl) if verbose else dl):
                batch = to_device(batch, device)
                for ps in model(batch).numpy(force=True): # ps - predictions for one row
                    row = df.iloc[i]
                    preds.append({row[f"image{j}"]: ps[j] for j in range(len(ps))})
                    i += 1
        predictions[name] = preds
        if preds_save_folder is not None:
            maybe_datetime = f"_at_{time.time()}_" if preds_save_filename_add_timestamp else "_"
            filename = f"{preds_save_filename_prefix}_on_{name}{maybe_datetime}submission.json"
            if verbose:
                print(f"Saving predictions for \"{name}\" as \"{filename}\"")
            with open(PATH / filename, "w") as f:
                json.dump([{k: str(v) for k, v in p.items()} for p in preds], f, indent=2)
        if verbose:
            print(f"Metrics for \"{name}\":")
        evals = evaluate(
            df.iloc[:, 2:-1].values,
            df["label"].values.reshape(-1, 1),
            preds,
        )
        if verbose:
            for metric_id, metric_value in evals.items():
                metric_name = metric2name[metric_id]
                print(f"    {metric_name}: {metric_value}")
        evaluations[name] = evals
    return predictions, evaluations

In [None]:
CHECKPOINTS = [500, 1000, 1300, 2000] # fill this out with checkpoints of interest (use Tensorboard)
TEST_DFS = {
    "Test": test_df,
    "Test 2": test2_df,
    "Test (augmented)": test_df,
    "Test 2 (augmented)": test2_df,
}
TEST_VIS_PROCS = {
    "Test": vis_proc,
    "Test 2": vis_proc,
    "Test (augmented)": vis_proc_aug,
    "Test 2 (augmented)": vis_proc_aug,
}

In [None]:
results = dict() # int -> tuple(preds, evals)
for checkpoint_num in CHECKPOINTS:
    print(f"Processing checkpoint {checkpoint_num}")
    model.load_state_dict(torch.load(SAVE_CHECKPOINT_PATH / f"step-{checkpoint_num}.pt"))
    results[checkpoint_num] = predict_eval(
        model = model,
        verbose = True,
        preds_save_filename_prefix = f"blip-{HEAD}-{MODEL_VERSION}-{checkpoint_num}",
        preds_save_folder = PATH,
        device = DEVICE,
        persistent_workers = PERSISTENT_WORKERS,
        num_workers = NUM_WORKERS,
        batch_size = TEST_BATCH_SIZE,
        text_processor=text_processors["eval"],
        vis_processors=TEST_VIS_PROCS,
        dataframes = TEST_DFS,
        images_path = IMAGES_PATH,
    )

In [None]:
sums = {c: sum([sum(ms.values()) for ms in d[1].values()]) for c, d in results.items()}
best_checkpoint = max(sums, key=sums.get)
print(f"Best checkpoint (by sum of all scores) is {best_checkpoint} with results:")
results[best_checkpoint][1]