In [None]:
from pathlib import Path
import logging
import json
import os
import random
from typing import *
import time
from datetime import datetime
import warnings

from tqdm import tqdm
import pandas as pd
import numpy as np
from PIL import Image, ImageFile

import torch
import torch.nn as nn
import torchvision.transforms as T
from lavis.models import load_model_and_preprocess, BlipBase
from lavis.processors import load_processor
import torch.nn.functional as F
from transformers import get_cosine_schedule_with_warmup
from transformers import BatchEncoding

from src.data import CustomSplitLoader, ImageSet
from src.itm import ItmDataset, to_device, ITMClassifier

from src.utils import evaluate, mrr
from src.validation import Validation, sum_scores, div_scores, eval_batch, metric2name
from sklearn.metrics import top_k_accuracy_score
from torch.utils.tensorboard import SummaryWriter

## Config

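Versioning (`HEAD` and `MODEL_VERSION` name the checkpoint folder, the TensorBoard run and the prediction files):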

In [None]:
# Run identifiers; they are interpolated into the checkpoint folder name, the
# TensorBoard run directory and the prediction file prefixes.
MODEL_VERSION = 32
HEAD = "itm"

Path resolution:

In [None]:
# Root of the project tree (contains data/ and checkpoints/). Set the VWSD_ROOT
# environment variable to run elsewhere; the default is the original author's layout,
# so behaviour is unchanged when the variable is not set.
RESEARCH_ROOT = Path(os.environ.get("VWSD_ROOT", "/home/s1m00n/research/vwsd"))

DATASET_VERSION = "v1"
PART = "train"
PATH = (RESEARCH_ROOT / "data").resolve() / f"{PART}_{DATASET_VERSION}"
# Tab-separated rows: word, context, image0..image{NUM_PICS-1} (see the loading cell).
DATA_PATH = PATH / f"{PART}.data.{DATASET_VERSION}.txt"
# One gold label per row of DATA_PATH.
LABELS_PATH = PATH / f"{PART}.gold.{DATASET_VERSION}.txt"
IMAGES_PATH = PATH / f"{PART}_images_{DATASET_VERSION}"
# Row-index lists that carve DATA_PATH into train / validation / test splits.
TRAIN_SPLIT_PATH = PATH / "split_train.txt"
VALIDATION_SPLIT_PATH = PATH / "split_valid.txt"
TEST_SPLIT_PATH = PATH / "split_test.txt"
# Extra held-out sets shipped with their own data/gold files.
VAL2_DATA_PATH = PATH / "valid2.data.v1.txt"
VAL2_GOLD_PATH = PATH / "valid2.gold.v1.txt"
TEST2_DATA_PATH = PATH / "test2.data.v1.txt"
TEST2_GOLD_PATH = PATH / "test2.gold.v1.txt"
SAVE_CHECKPOINT_PATH = (RESEARCH_ROOT / "checkpoints").resolve() / f"BLIP-{HEAD}-{MODEL_VERSION}" # TODO: maybe add timestamp?
SAVE_CHECKPOINT_PATH.mkdir(parents=True, exist_ok=True)
# Candidate images per row of the data files.
NUM_PICS = 10

Environment settings:

In [None]:
# Silence every Python warning for the rest of the session (deliberately broad).
warnings.filterwarnings('ignore')

# Some training images are very large or truncated: without these PIL settings
# they either fail to load or trigger warnings.
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None

# Root logger at INFO level.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# One TensorBoard run directory per launch, tagged with head, version and start time.
TB_RUN_DIR = f"/home/s1m00n/research/vwsd/runs/blip-{HEAD}-{MODEL_VERSION} (ran at {datetime.now()})"
writer = SummaryWriter(TB_RUN_DIR)

In [None]:
RANDOM_STATE = 42
# Seed every RNG the pipeline may touch (Python `random`, NumPy, torch) so that
# stochastic parts such as augmentation or dataset-side sampling are reproducible.
random.seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)  # seeds the CPU and all CUDA generators
# Prefer the first GPU (the original author's setup) but fall back to the CPU so the
# notebook still runs on machines without CUDA.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
NUM_WORKERS = 16
PERSISTENT_WORKERS = True
print(f"Running on {DEVICE}")

Model & training settings

In [None]:
# --- backbone ---
BLIP_VARIANT = "base"  # "base" | "large"

# --- data / sampling ---
NUM_NS = 9  # passed to ItmDataset as `num_ns` (presumably: negatives per example)
TRAIN_BATCH_SIZE = 3
IMAGE_AUGMENTATION = True  # wrap the training image processor with AutoAugment

# --- optimisation: AdamW + cosine LR schedule with warmup ---
LR = 1e-5
WEIGHT_DECAY = 0.05
GRAD_ACCUM_STEPS = 40
NUM_EPOCHS = 40
WARMUP_STEPS_FRAC = 0.05  # fraction of optimizer steps spent in warmup

In [None]:
# Width of the classifier target: presumably one gold image plus NUM_NS negatives
# (equal to NUM_PICS with the current settings).
NUM_LABELS = 1 + NUM_NS
# Examples contributing to every optimizer step.
TRAIN_EFFECTIVE_BATCH_SIZE = TRAIN_BATCH_SIZE * GRAD_ACCUM_STEPS
TRAIN_EFFECTIVE_BATCH_SIZE

In [None]:
# Batch sizes for evaluation; the test sets reuse the validation batch size.
VALIDATION_BATCH_SIZE = 40
TEST_BATCH_SIZE = VALIDATION_BATCH_SIZE
# Cadence in optimizer steps: validate and checkpoint on the same schedule.
STEPS_BETWEEN_VAL = 100
SAVE_CHECKPOINT_STEPS = STEPS_BETWEEN_VAL

## Loading data

In [None]:
IMAGE_COLUMNS = [f"image{i}" for i in range(NUM_PICS)]


def read_vwsd_split(data_path: Path, gold_path: Path) -> pd.DataFrame:
    """Read a VWSD data file and attach its gold labels.

    Args:
        data_path: tab-separated file with columns word, context, image0..image{NUM_PICS-1}.
        gold_path: tab-separated file with one gold label per row of `data_path`.

    Returns:
        DataFrame with columns word, context, image0..image{NUM_PICS-1}, label.

    Raises:
        ValueError: if the two files do not have the same number of rows (the original
            code silently filled NaN or dropped extra labels in that case).
    """
    frame = pd.read_csv(data_path, sep='\t', header=None)
    frame.columns = ["word", "context"] + IMAGE_COLUMNS
    # The gold file has a single column: take it as a Series instead of assigning a frame.
    gold = pd.read_csv(gold_path, sep='\t', header=None).iloc[:, 0]
    if len(gold) != len(frame):
        raise ValueError(
            f"{gold_path} has {len(gold)} labels but {data_path} has {len(frame)} rows"
        )
    frame["label"] = gold
    return frame


def read_split_indices(split_path: Path) -> np.ndarray:
    """Row labels (first column of `split_path`) selecting a subset of the main data file."""
    return pd.read_csv(split_path, sep='\t', header=None).iloc[:, 0].values


df = read_vwsd_split(DATA_PATH, LABELS_PATH)

train_df = df.loc[read_split_indices(TRAIN_SPLIT_PATH)]
validation_df = df.loc[read_split_indices(VALIDATION_SPLIT_PATH)]
test_df = df.loc[read_split_indices(TEST_SPLIT_PATH)]

val2_df = read_vwsd_split(VAL2_DATA_PATH, VAL2_GOLD_PATH)
test2_df = read_vwsd_split(TEST2_DATA_PATH, TEST2_GOLD_PATH)

In [None]:
# Training split: rows of `df` whose index labels are listed in split_train.txt.
train_df

In [None]:
# Validation split: rows of `df` whose index labels are listed in split_valid.txt.
validation_df

In [None]:
# Test split: rows of `df` whose index labels are listed in split_test.txt.
test_df

In [None]:
# Second validation set, read from valid2.data.v1.txt / valid2.gold.v1.txt.
val2_df

In [None]:
# Second test set, read from test2.data.v1.txt / test2.gold.v1.txt.
test2_df

## Preprocessing

In [None]:
# LAVIS BLIP image-text-matching model (variant BLIP_VARIANT) together with its
# image and text processors, loaded in eval mode.
blip_model, vis_processors, text_processors = load_model_and_preprocess(
    name="blip_image_text_matching",
    model_type=BLIP_VARIANT,
    is_eval=True,
)

In [None]:
# Eval-time processors from LAVIS.
vis_proc = vis_processors["eval"]
text_proc = text_processors["eval"]

# Training-time augmentation policy (T.AugMix() is a drop-in alternative).
img_aug = T.AutoAugment()


def vis_proc_aug(p):
    """Randomly augment image `p` (presumably a PIL image) with `img_aug`, then run `vis_proc`."""
    return vis_proc(img_aug(p))

In [None]:
# Training images go through the augmenting processor when IMAGE_AUGMENTATION is on.
train_image_set = ImageSet(
    images_path = IMAGES_PATH,
    image_processor = vis_proc_aug if IMAGE_AUGMENTATION else vis_proc,
    similarity_measure = nn.CosineSimilarity(dim=1),
    enable_cache = False,
)
# Negative-sampling knobs of ItmDataset (semantics live in src.itm): NUM_NS negatives
# per example, with the extra "any" / "hard" negative counts below.
train_ds = ItmDataset(
    df = train_df,
    image_set = train_image_set,
    text_preprocessor = text_proc,
    use_context_as_text = True,
    num_src_pics = NUM_PICS,
    num_ns = NUM_NS,
    num_any_ns = 3,
    replace_any_ns = False,
    replace_default_ns = False,
    num_hard_ns = 3,
    num_any_when_no_hard_ns = 1,
)
train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_size=TRAIN_BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=True,
    # Honour PERSISTENT_WORKERS (previously configured but unused) so the worker
    # processes survive across the NUM_EPOCHS epochs instead of being re-spawned;
    # DataLoader only accepts persistent workers when num_workers > 0.
    persistent_workers=PERSISTENT_WORKERS and NUM_WORKERS > 0,
)
# Micro-batches per epoch.
train_l = len(train_dl)
train_l

In [None]:
# Validation uses the plain eval processor (no augmentation).
val_image_set = ImageSet(
    images_path=IMAGES_PATH,
    image_processor=vis_proc,
    similarity_measure=nn.CosineSimilarity(dim=1),
    enable_cache=False,
)

# All extra negative-sampling knobs are disabled for evaluation.
VAL_NEGATIVE_SAMPLING = dict(
    num_any_ns=0,
    replace_any_ns=False,
    replace_default_ns=False,
    num_hard_ns=0,
    num_any_when_no_hard_ns=0,
)
val_ds = ItmDataset(
    df=validation_df,
    image_set=val_image_set,
    text_preprocessor=text_proc,
    use_context_as_text=True,
    num_src_pics=NUM_PICS,
    num_ns=NUM_NS,
    **VAL_NEGATIVE_SAMPLING,
)

In [None]:
# Second validation set: same non-augmented processing as `val_ds`, own image set.
val2_image_set = ImageSet(
    enable_cache = False,
    similarity_measure = nn.CosineSimilarity(dim=1),
    image_processor = vis_proc,
    images_path = IMAGES_PATH,
)
val2_ds = ItmDataset(
    # data
    df = val2_df,
    image_set = val2_image_set,
    text_preprocessor = text_proc,
    use_context_as_text = True,
    # candidates
    num_src_pics = NUM_PICS,
    num_ns = NUM_NS,
    # extra negative sampling disabled for evaluation
    num_hard_ns = 0,
    num_any_ns = 0,
    num_any_when_no_hard_ns = 0,
    replace_any_ns = False,
    replace_default_ns = False,
)

## Model setup

In [None]:
# Wrap the LAVIS ITM model in the project's classifier and move it to DEVICE.
model = ITMClassifier(blip_model)
model = model.to(DEVICE)

## Training

### Train env setup

In [None]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# The training loop carries its gradient-accumulation counter across epochs, so it
# performs exactly floor(NUM_EPOCHS * train_l / GRAD_ACCUM_STEPS) optimizer steps.
# Integer arithmetic gives that count exactly; the former float expression
# int(E * (L / A)) could round down by one and desynchronise the schedule.
num_training_steps = (NUM_EPOCHS * train_l) // GRAD_ACCUM_STEPS
num_warmup_steps = int(num_training_steps * WARMUP_STEPS_FRAC)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)
print(f"{num_training_steps} training steps which include {num_warmup_steps} warmup ones")

### Validation config

In [None]:
labels_range = np.arange(NUM_PICS)

def get_batch_scores(model, batch, dev, env):
    """Compute validation metrics for one batch (callback used by `Validation`).

    Args:
        model: classifier returning a (batch, num_candidates) score matrix.
        batch: mapping with at least a "label" tensor of gold indices; moved to `dev` here.
        dev: torch device to run on.
        env: unused here; kept because it is part of the callback signature.

    Returns:
        dict mapping metric name to its value on this batch. "Loss" is returned as a
        plain Python float so no tensor (or autograd graph) is kept alive by whoever
        accumulates these scores.
    """
    # Evaluation never needs gradients; nesting inside an outer no_grad is harmless.
    with torch.no_grad():
        batch = to_device(batch, dev)
        outputs = model(batch)
        # Derive the class count from the model output instead of assuming NUM_PICS,
        # keeping this consistent with training (which uses NUM_LABELS).
        num_classes = outputs.shape[-1]
        loss = loss_fn(outputs, F.one_hot(batch["label"], num_classes).float().to(dev))
    class_range = np.arange(num_classes)
    np_labels = batch["label"].numpy(force=True)
    np_preds = outputs.numpy(force=True)
    return {
        "Loss": loss.item(),
        "Accuracy@Top1": top_k_accuracy_score(np_labels, np_preds, k=1, labels=class_range),
        "Accuracy@Top3": top_k_accuracy_score(np_labels, np_preds, k=3, labels=class_range),
        "Mean Reciprocal Rank": mrr(np_labels, np_preds),
    }

def log_score(train_step, name, metric_name, metric_value):
    """Write one metric to TensorBoard under "<metric_name>/<name>" and echo it to stdout."""
    tag = f"{metric_name}/{name}"
    writer.add_scalar(tag, metric_value, train_step)
    prefix = f"[{train_step}][{name}]"
    print(prefix, f"{metric_name}: {metric_value}")


def _should_validate(step):
    """Step predicate handed to Validation: true on step 1 and every STEPS_BETWEEN_VAL steps."""
    return step == 1 or step % STEPS_BETWEEN_VAL == 0


def _make_val_loader(dataset):
    """Shuffled DataLoader over `dataset` with the shared validation batch size and workers."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=VALIDATION_BATCH_SIZE,
        num_workers=NUM_WORKERS,
        shuffle=True,
    )


# Augmented variants ("Validation (augmented)", "Validation 2 (augmented)") can be
# registered here once the corresponding datasets exist.
validation = Validation(
    common = {
        "device": DEVICE,
        "get_batch_scores": get_batch_scores,
        "step_cond": _should_validate,
        "log_score": log_score,
    },
    configs = {
        "Validation": {"dl": _make_val_loader(val_ds)},
        "Validation 2": {"dl": _make_val_loader(val2_ds)},
    },
)

In [None]:
# step_num counts optimizer steps (the unit of the LR schedule, validation and
# checkpointing); grad_accum_step_cnt counts micro-batches in the current window.
step_num = 0
grad_accum_step_cnt = 0
save_checkpoint_step_cnt = 0
# Running sums over the current accumulation window. They are reset only after an
# optimizer step, never at epoch boundaries: the window may span two epochs (the
# counter carries over), and resetting per epoch made the first logged step of each
# epoch divide a partial sum by GRAD_ACCUM_STEPS.
train_loss = 0.0
train_scores = {"acc1": 0, "acc3": 0, "mrr": 0}
progress_bar = tqdm(range(num_training_steps))

for epoch_num in range(NUM_EPOCHS):
    model.train()
    for batch in train_dl:
        batch = to_device(batch, DEVICE)
        outputs = model(batch)
        loss = loss_fn(outputs, F.one_hot(batch["label"], NUM_LABELS).float().to(DEVICE))

        # Logged loss/metrics use the unscaled per-batch values; detach so the
        # metric helper cannot hold on to the autograd graph.
        train_loss += loss.item()
        train_scores = sum_scores(train_scores, eval_batch(batch["label"], outputs.detach(), num_labels = NUM_LABELS))

        # Average (rather than sum) gradients over the accumulation window.
        (loss / GRAD_ACCUM_STEPS).backward()
        grad_accum_step_cnt += 1

        if grad_accum_step_cnt < GRAD_ACCUM_STEPS:
            continue

        # --- one optimizer step ---
        writer.add_scalar("Learning Rate", lr_scheduler.get_last_lr()[0], step_num)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        writer.add_scalar("Loss/Train", train_loss / GRAD_ACCUM_STEPS, step_num)
        for k, v in div_scores(train_scores, GRAD_ACCUM_STEPS).items():
            writer.add_scalar(metric2name[k] + "/Train", v, step_num)
        train_loss = 0.0
        train_scores = {"acc1": 0, "acc3": 0, "mrr": 0}
        grad_accum_step_cnt = 0
        step_num += 1
        save_checkpoint_step_cnt += 1
        progress_bar.update(1)

        validation(step_num, model)
        # Validation may switch the model to eval mode; restore training mode so
        # dropout etc. stay active for the remaining batches of this epoch.
        model.train()

        if save_checkpoint_step_cnt == SAVE_CHECKPOINT_STEPS:
            save_checkpoint_step_cnt = 0
            p = SAVE_CHECKPOINT_PATH / f"step-{step_num}.pt"
            torch.save(model.state_dict(), p)

progress_bar.close()

## Evaluation

In [None]:
def predict_eval(
    model: ITMClassifier,
    dataframes: Dict[str, pd.DataFrame],
    images_path: Path,
    text_processor,
    vis_processors: Dict,
    batch_size: int = 1,
    num_workers: int = 0,
    persistent_workers: bool = True,
    device = torch.device("cpu"),
    preds_save_folder: Optional[Path] = None,
    preds_save_filename_prefix: str = "sample_predictions",
    preds_save_filename_add_timestamp: bool = True,
    verbose: bool = True,
) -> Tuple[Dict[str, List[Dict[str, np.floating]]], Dict[str, Dict[str, float]]]:
    """
    Predict with a (checkpoint-loaded) model on several test sets and evaluate the predictions.

    Args:
        model (ITMClassifier): loaded classification model; moved to `device` and put in eval mode.
        dataframes (Dict[str, pandas.DataFrame]): mapping of test set names to their dataframes
            (columns word, context, image0.., label).
        images_path (Path): folder containing the candidate images.
        text_processor: text preprocessor passed to the dataset.
        vis_processors (Dict): mapping of test set name -> image processor for that set.
        batch_size (int): DataLoader batch size.
        num_workers (int): DataLoader worker processes.
        persistent_workers (bool): keep workers alive between iterations; silently ignored
            when `num_workers == 0`, where DataLoader would otherwise raise ValueError.
        device: torch device used for inference.
        preds_save_folder (Optional[Path]): if given, predictions of each set are written as
            JSON into this folder (created if missing).
        preds_save_filename_prefix (str): prefix of the saved file names.
        preds_save_filename_add_timestamp (bool): insert `time.time()` into the file names.
        verbose (bool): enables prints of metrics and progress tracking

    Returns:
        (predictions, evaluations): predictions[name] is a list with, for every row of the
        set, a dict mapping each candidate image name to its score; evaluations[name] is the
        metric dict returned by `evaluate` for that set.
    """
    predictions = dict()
    evaluations = dict()
    model = model.to(device)
    model.eval()
    # torch DataLoader rejects persistent_workers=True when num_workers == 0 (the defaults).
    use_persistent_workers = persistent_workers and num_workers > 0
    save_folder = Path(preds_save_folder) if preds_save_folder is not None else None
    if save_folder is not None:
        save_folder.mkdir(parents=True, exist_ok=True)

    for name, df in dataframes.items():
        if verbose:
            print(f"Generating predictions for \"{name}\"")
        # NOTE(review): `DefaultDataset` is not imported in this notebook's import cell,
        # so this raises NameError unless it is defined elsewhere (presumably src.data) — confirm.
        ds = DefaultDataset(
            df=df,
            images_path=images_path,
            text_processor=text_processor,
            vis_processor=vis_processors[name],
        )
        dl = torch.utils.data.DataLoader(
            ds,
            batch_size = batch_size,
            shuffle = False,  # row order must match `df` for the mapping below
            num_workers = num_workers,
            persistent_workers = use_persistent_workers,
        )
        preds = []
        row_idx = 0
        with torch.no_grad():
            for batch in (tqdm(dl) if verbose else dl):
                batch = to_device(batch, device)
                for ps in model(batch).numpy(force=True): # ps - predictions for one row
                    row = df.iloc[row_idx]
                    # Map each candidate image file name to its score, by column position.
                    preds.append({row[f"image{j}"]: ps[j] for j in range(len(ps))})
                    row_idx += 1
        predictions[name] = preds
        if save_folder is not None:
            maybe_datetime = f"_at_{time.time()}_" if preds_save_filename_add_timestamp else "_"
            filename = f"{preds_save_filename_prefix}_on_{name}{maybe_datetime}submission.json"
            if verbose:
                print(f"Saving predictions for \"{name}\" as \"{filename}\"")
            # Write into the requested folder (previously the global PATH was used and
            # `preds_save_folder` only acted as an on/off switch).
            with open(save_folder / filename, "w") as f:
                json.dump([{k: str(v) for k, v in p.items()} for p in preds], f, indent=2)
        if verbose:
            print(f"Metrics for \"{name}\":")
        evals = evaluate(
            df.iloc[:, 2:-1].values,  # candidate image columns (between context and label)
            df["label"].values.reshape(-1, 1),
            preds,
        )
        if verbose:
            for metric_id, metric_value in evals.items():
                metric_name = metric2name[metric_id]
                print(f"    {metric_name}: {metric_value}")
        evaluations[name] = evals
    return predictions, evaluations

In [None]:
# Optimizer steps whose checkpoints (SAVE_CHECKPOINT_PATH / "step-<N>.pt") should be
# evaluated; pick them by inspecting the validation curves in TensorBoard.
CHECKPOINTS = [500, 1000, 1300, 2000]

# Every test set is evaluated twice: with the plain eval processor and with the
# augmenting one. The name suffix identifies the processor variant.
_BASE_TEST_SETS = {"Test": test_df, "Test 2": test2_df}
_PROC_VARIANTS = {"": vis_proc, " (augmented)": vis_proc_aug}
TEST_DFS = {
    base + suffix: frame
    for suffix in _PROC_VARIANTS
    for base, frame in _BASE_TEST_SETS.items()
}
TEST_VIS_PROCS = {
    base + suffix: proc
    for suffix, proc in _PROC_VARIANTS.items()
    for base in _BASE_TEST_SETS
}

In [None]:
# checkpoint step -> (predictions, evaluations) as returned by predict_eval
results = dict()
for checkpoint_num in CHECKPOINTS:
    print(f"Processing checkpoint {checkpoint_num}")
    checkpoint_path = SAVE_CHECKPOINT_PATH / f"step-{checkpoint_num}.pt"
    # map_location lets checkpoints written on one device (e.g. cuda:0) load on the
    # current DEVICE, including CPU-only machines.
    model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
    results[checkpoint_num] = predict_eval(
        model = model,
        verbose = True,
        preds_save_filename_prefix = f"blip-{HEAD}-{MODEL_VERSION}-{checkpoint_num}",
        preds_save_folder = PATH,
        device = DEVICE,
        persistent_workers = PERSISTENT_WORKERS,
        num_workers = NUM_WORKERS,
        batch_size = TEST_BATCH_SIZE,
        text_processor=text_processors["eval"],
        vis_processors=TEST_VIS_PROCS,
        dataframes = TEST_DFS,
        images_path = IMAGES_PATH,
    )

In [None]:
def _total_score(evaluations_by_set):
    """Sum of every metric value over every test set for one checkpoint."""
    return sum(sum(metrics.values()) for metrics in evaluations_by_set.values())


# results[ckpt] == (predictions, evaluations); rank checkpoints by the evaluations total.
sums = {ckpt: _total_score(preds_and_evals[1]) for ckpt, preds_and_evals in results.items()}
best_checkpoint = max(sums, key=sums.get)
print(f"Best checkpoint (by sum of all scores) is {best_checkpoint} with results:")
results[best_checkpoint][1]