In [23]:
from pathlib import Path
import logging
import json
from typing import *
import time
from datetime import datetime
import warnings

from tqdm import tqdm
import pandas as pd
import numpy as np
from PIL import Image, ImageFile

import torch
import torch.nn as nn
import torchvision.transforms as T
from lavis.models import load_model_and_preprocess, BlipBase
from lavis.processors import load_processor
import torch.nn.functional as F
from transformers import get_cosine_schedule_with_warmup
from transformers import BatchEncoding

from src.data import CustomSplitLoader, ImageSet
from src.itm import DefaultDataset, AltNSDataset, to_device, ITMClassifier

from src.utils import evaluate, mrr
from src.validation import Validation, sum_scores, div_scores, eval_batch, metric2name
from sklearn.metrics import top_k_accuracy_score
from torch.utils.tensorboard import SummaryWriter

## Config

Versioning

Path resolution. Every dataset file lives under `PATH`, the directory for the chosen `PART` and `DATASET_VERSION`:

In [24]:
DATASET_VERSION = "v1"
PART = "train"
# NOTE(review): absolute, machine-specific data root — adjust for your environment.
DATA_ROOT = Path("/home/s1m00n/research/vwsd/data").resolve()
PATH = DATA_ROOT / f"{PART}_{DATASET_VERSION}"

# Main data file, its gold labels and the image directory for PART.
DATA_PATH = PATH / f"{PART}.data.{DATASET_VERSION}.txt"
LABELS_PATH = PATH / f"{PART}.gold.{DATASET_VERSION}.txt"
IMAGES_PATH = PATH / f"{PART}_images_{DATASET_VERSION}"

# Split files: their first column holds row labels of the main data frame (see data loading).
TRAIN_SPLIT_PATH = PATH / "split_train.txt"
VALIDATION_SPLIT_PATH = PATH / "split_valid.txt"
TEST_SPLIT_PATH = PATH / "split_test.txt"

# Extra validation/test sets with their own data and gold files; versioned like the rest.
VAL2_DATA_PATH = PATH / f"valid2.data.{DATASET_VERSION}.txt"
VAL2_GOLD_PATH = PATH / f"valid2.gold.{DATASET_VERSION}.txt"
TEST2_DATA_PATH = PATH / f"test2.data.{DATASET_VERSION}.txt"
TEST2_GOLD_PATH = PATH / f"test2.gold.{DATASET_VERSION}.txt"

# Number of candidate images per example (columns image0 .. image{NUM_PICS - 1}).
NUM_PICS = 10

Environment settings:

In [25]:
# Root logger at INFO so INFO-level messages (such as the checkpoint-loading lines below) are printed.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Some training images are huge or truncated. Without these two settings PIL refuses
# to load them or emits warnings.
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None  # disables PIL's decompression-bomb pixel limit

# NOTE(review): this silences *every* warning for the rest of the session, not only PIL's.
warnings.filterwarnings('ignore')

In [26]:
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE)
# Use the first GPU when available; fall back to CPU so the notebook still runs without CUDA.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Running on {DEVICE}")

Running on cuda:0


Model & inference settings (this notebook only runs inference; no training happens here):

In [27]:
BLIP_VARIANT: str = "base"  # BLIP backbone size: "base" or "large"

In [28]:
VALIDATION_BATCH_SIZE: int = 400  # batch size of every inference DataLoader in this notebook

## Loading data

In [29]:
# Candidate-image column names shared by every data file.
IMAGE_COLUMNS = [f"image{i}" for i in range(NUM_PICS)]


def read_examples(data_path: Path, gold_path: Path) -> pd.DataFrame:
    """Read a headerless tab-separated data file and attach its gold labels.

    The data file's columns are named ``word``, ``context`` and ``image0`` ..
    ``image{NUM_PICS - 1}``. The first column of the gold file becomes ``label``.
    """
    examples = pd.read_csv(data_path, sep='\t', header=None)
    examples.columns = ["word", "context"] + IMAGE_COLUMNS
    examples["label"] = pd.read_csv(gold_path, sep='\t', header=None)[0]
    return examples


def read_split_indices(split_path: Path) -> np.ndarray:
    """Return the row labels stored in the first column of a split file."""
    return pd.read_csv(split_path, sep='\t', header=None)[0].to_numpy()


df = read_examples(DATA_PATH, LABELS_PATH)
train_df = df.loc[read_split_indices(TRAIN_SPLIT_PATH)]
validation_df = df.loc[read_split_indices(VALIDATION_SPLIT_PATH)]
test_df = df.loc[read_split_indices(TEST_SPLIT_PATH)]

val2_df = read_examples(VAL2_DATA_PATH, VAL2_GOLD_PATH)
test2_df = read_examples(TEST2_DATA_PATH, TEST2_GOLD_PATH)

## Preprocessing

In [30]:
# BLIP image-text-matching model, plus dicts of visual and text processors
# (their "eval" entries are used below).
blip_model, vis_processors, text_processors = load_model_and_preprocess(
    "blip_image_text_matching",
    BLIP_VARIANT,
    is_eval=True,
)

INFO:root:Missing keys []
INFO:root:load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth


In [31]:
# NOTE(review): absolute, machine-specific checkpoint path.
CHECKPOINT_PATH = Path("/home/s1m00n/research/vwsd/checkpoints/BLIP-itm-28/step-1000.pt")

model = ITMClassifier(blip_model).to(DEVICE)
# map_location loads the saved tensors directly onto DEVICE rather than onto the device
# they were saved from, so this also works where that device is unavailable.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE))

<All keys matched successfully>

In [32]:
def cmp_img(
    itm_cls: "ITMClassifier",
    data: Dict[str, torch.Tensor],
    pairwise: bool = False,
) -> torch.Tensor:
    """Cosine similarity between BLIP image embeddings of ``data["img1"]`` and ``data["img2"]``.

    Each image batch is run through ``blip_model.visual_encoder.forward_features``. The
    first token of the output is projected with ``blip_model.vision_proj`` and
    L2-normalised, so dot products are cosine similarities.

    Args:
        itm_cls: classifier wrapping a BLIP model, accessed as ``itm_cls.blip_model``.
        data: dict with image batches under ``"img1"`` and ``"img2"``.
        pairwise: if False (default), return the full similarity matrix between every
            ``img1`` and every ``img2``. If True, return only the similarity of each
            ``img1[i]`` with ``img2[i]`` (the matrix diagonal); both batches must then
            have the same length.

    Returns:
        A ``(len(img1), len(img2))`` matrix, or a 1-D tensor when ``pairwise`` is True.
    """
    blip = itm_cls.blip_model

    def embed(images: torch.Tensor) -> torch.Tensor:
        # [:, 0, :] keeps the first token of the encoder output — presumably the CLS token; verify.
        first_token = blip.visual_encoder.forward_features(images)[:, 0, :]
        return F.normalize(blip.vision_proj(first_token), dim=-1)

    img1_feats = embed(data["img1"])
    img2_feats = embed(data["img2"])
    if pairwise:
        # Row-wise dot product: avoids materialising the full batch x batch matrix.
        return (img1_feats * img2_feats).sum(dim=-1)
    return img1_feats @ img2_feats.t()

In [33]:
# Short aliases for the "eval" processors returned by LAVIS.
vis_proc, text_proc = vis_processors["eval"], text_processors["eval"]

In [34]:
class ContrastiveDataset(torch.utils.data.Dataset):
    """(candidate image, gold image) pairs for every non-gold candidate of each example.

    For each row of ``df``, the gold image ``row["label"]`` is paired with every other
    candidate ``image0`` .. ``image{num_pics - 1}``. Only the first candidate equal to the
    gold image is skipped. If the gold image is not among the candidates, all of them are
    paired with it. Each item is ``{"img1": candidate, "img2": gold}``, with both images
    loaded as RGB and passed through ``vis_processor``.

    Args:
        df: examples with image file-name columns ``image0`` .. ``image{num_pics - 1}``
            and ``label``.
        images_path: directory containing the image files.
        vis_processor: callable applied to each RGB ``PIL.Image``.
        num_pics: number of candidate image columns per row (default 10).
    """

    def __init__(
        self,
        df: pd.DataFrame,
        images_path: Path,
        vis_processor,
        num_pics: int = 10,
    ) -> None:
        self.images_path = images_path
        self.vis_processor = vis_processor
        self.data = []
        image_columns = [f"image{i}" for i in range(num_pics)]
        for _, row in df.iterrows():
            candidates = [row[col] for col in image_columns]
            gold = row["label"]
            # NOTE(review): None here (gold not among candidates) likely means bad data.
            pos_pic_idx = next(
                (i for i, name in enumerate(candidates) if name == gold), None
            )
            self.data.extend(
                {"img1": name, "img2": gold}
                for i, name in enumerate(candidates)
                if i != pos_pic_idx
            )

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict:
        # Open each file in a context manager so its handle is always released; Pillow
        # only auto-closes after load() for single-frame images.
        item = {}
        for key, file_name in self.data[idx].items():
            with Image.open(self.images_path / file_name) as img:
                item[key] = self.vis_processor(img.convert("RGB"))
        return item

In [35]:
class ContrastiveDatasetFrom10DS(torch.utils.data.Dataset):
    """Expand a multi-image dataset into (negative image, gold image) pairs.

    Items of ``ds`` are expected to be dicts with an ``"images"`` sequence of
    ``num_pics`` images and a ``"label"`` that indexes the gold image in that sequence.
    Each wrapped item yields ``num_pics - 1`` pairs ``{"img1": negative, "img2": gold}``.
    Pairs are ordered by item first, then by negative position.

    Args:
        ds: the wrapped dataset.
        num_pics: number of images per wrapped item (default 10; must be >= 2).
    """

    def __init__(self, ds: torch.utils.data.Dataset, num_pics: int = 10) -> None:
        self.ds = ds
        self.num_pics = num_pics

    def __len__(self) -> int:
        return (self.num_pics - 1) * len(self.ds)

    def __getitem__(self, idx: int) -> Dict:
        # Integer divmod maps the flat index to (wrapped item, negative slot) without floats.
        item_idx, pic_idx = divmod(idx, self.num_pics - 1)
        item = self.ds[item_idx]
        img2_idx = item["label"]
        negative_indices = list(range(self.num_pics))
        # Raises ValueError if the label is not a position in range(num_pics).
        negative_indices.remove(img2_idx)
        img1_idx = negative_indices[pic_idx]
        return {
            "img1": item["images"][img1_idx],
            "img2": item["images"][img2_idx],
        }

In [36]:
NUM_WORKERS = 16

# (negative candidate, gold image) pairs built directly from the training DataFrame.
train_ds = ContrastiveDataset(
    df=train_df,
    images_path=IMAGES_PATH,
    vis_processor=vis_proc,
)
# Training pairs derived from AltNSDataset items via ContrastiveDatasetFrom10DS.
train_ds2 = ContrastiveDatasetFrom10DS(AltNSDataset(
    df=train_df,
    images_path=IMAGES_PATH,
    text_processor=text_proc,
    vis_processor=vis_proc,
    num_negatives=NUM_PICS - 1,
    num_pics=NUM_PICS,
))
val_ds, val2_ds, test_ds, test2_ds = [
    ContrastiveDataset(df=split_df, images_path=IMAGES_PATH, vis_processor=vis_proc)
    for split_df in (validation_df, val2_df, test_df, test2_df)
]

split_datasets = {
    "train_dl": train_ds,
    "train_dl2": train_ds2,  # previously built from train_ds, which left train_ds2 unused
    "val_dl": val_ds,
    "val2_dl": val2_ds,
    "test_dl": test_ds,
    "test2_dl": test2_ds,
}
# DataLoader does not shuffle by default, so predictions keep dataset order.
dls = {
    name: torch.utils.data.DataLoader(
        ds, batch_size=VALIDATION_BATCH_SIZE, num_workers=NUM_WORKERS
    )
    for name, ds in split_datasets.items()
}

In [37]:
# Per-split scores collected by the prediction loop, keyed by DataLoader name.
negative_data: Dict[str, List[float]] = {}
positive_data: Dict[str, List[float]] = {}

In [38]:
model.eval()
for split_name, dl in dls.items():
    print(f"Predicting for split: {split_name}")
    negative_data[split_name] = []
    positive_data[split_name] = []
    for batch in tqdm(dl):
        with torch.no_grad():
            mixed_preds = cmp_img(model, to_device(batch, DEVICE))
        # FIXME(review): cmp_img returns the full (batch, batch) cosine-similarity matrix
        # between every img1 and every img2 in the batch. Columns 0 and 1 therefore hold
        # each img1's similarity to the *first two* img2 of the batch, not a score for the
        # row's own pair (that is the diagonal). A batch with a single example also raises
        # IndexError on [:, 1]. Confirm what these two lists are meant to contain.
        # .tolist() copies each column to the host once instead of syncing per element.
        negative_data[split_name].extend(mixed_preds[:, 0].tolist())
        positive_data[split_name].extend(mixed_preds[:, 1].tolist())

Predicting for split: train_dl


100%|██████████| 138/138 [16:10<00:00,  7.03s/it]


Predicting for split: train_dl2


100%|██████████| 138/138 [16:08<00:00,  7.02s/it]


Predicting for split: val_dl


100%|██████████| 77/77 [09:20<00:00,  7.28s/it] 


Predicting for split: val2_dl


100%|██████████| 35/35 [04:48<00:00,  8.25s/it]


Predicting for split: test_dl


100%|██████████| 76/76 [09:14<00:00,  7.30s/it] 


Predicting for split: test2_dl


100%|██████████| 36/36 [04:44<00:00,  7.90s/it]


In [39]:
# Persist the per-split column-0 scores next to the dataset.
neg_preds_path = PATH / "itm28-pi-ni-neg-preds.json"
with neg_preds_path.open("w") as out_file:
    json.dump(negative_data, out_file)

In [40]:
# Persist the per-split column-1 scores next to the dataset.
pos_preds_path = PATH / "itm28-pi-ni-pos-preds.json"
with pos_preds_path.open("w") as out_file:
    json.dump(positive_data, out_file)