In [1]:
from pathlib import Path
import logging
import json
from typing import *
import time
from datetime import datetime
import warnings

from tqdm import tqdm
import pandas as pd
import numpy as np
from PIL import Image, ImageFile

import torch
import torch.nn as nn
import torchvision.transforms as T
from lavis.models import load_model_and_preprocess, BlipBase
from lavis.processors import load_processor
import torch.nn.functional as F
from transformers import get_cosine_schedule_with_warmup
from transformers import BatchEncoding

from src.data import CustomSplitLoader, ImageSet
from src.itm import DefaultDataset, AltNSDataset, to_device, ITMClassifier

from src.utils import evaluate, mrr
from src.validation import Validation, sum_scores, div_scores, eval_batch, metric2name
from sklearn.metrics import top_k_accuracy_score
from torch.utils.tensorboard import SummaryWriter

## Config

Versioning

Dataset version and path resolution:

In [2]:
# Which release of the dataset, and which part of it, this notebook works on.
DATASET_VERSION = "v1"
PART = "train"

# Everything below lives in one `<part>_<version>` folder under the data root.
PATH = Path("/home/s1m00n/research/vwsd/data").resolve().joinpath(f"{PART}_{DATASET_VERSION}")

# Main table, gold labels and image folder of the selected part/version.
DATA_PATH = PATH.joinpath(f"{PART}.data.{DATASET_VERSION}.txt")
LABELS_PATH = PATH.joinpath(f"{PART}.gold.{DATASET_VERSION}.txt")
IMAGES_PATH = PATH.joinpath(f"{PART}_images_{DATASET_VERSION}")

# Files listing which rows of the main table belong to each split.
TRAIN_SPLIT_PATH = PATH.joinpath("split_train.txt")
VALIDATION_SPLIT_PATH = PATH.joinpath("split_valid.txt")
TEST_SPLIT_PATH = PATH.joinpath("split_test.txt")

# Additional stand-alone validation/test sets (data table + gold labels each).
VAL2_DATA_PATH = PATH.joinpath("valid2.data.v1.txt")
VAL2_GOLD_PATH = PATH.joinpath("valid2.gold.v1.txt")
TEST2_DATA_PATH = PATH.joinpath("test2.data.v1.txt")
TEST2_GOLD_PATH = PATH.joinpath("test2.gold.v1.txt")

# Number of candidate image columns per row of a data table.
NUM_PICS = 10

Environment settings:

In [3]:
# Use the root logger and let INFO-level messages through.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Silence all Python warnings for the remainder of the session.
warnings.filterwarnings('ignore')

# Relax Pillow's loading limits: without these two settings some of the
# training images either fail to load or trigger warnings.
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None

In [4]:
# Fixed seed for torch's default random generator.
RANDOM_STATE = 42
torch.manual_seed(RANDOM_STATE)

# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs (slowly) on a machine without a GPU instead of failing on `.to(DEVICE)`.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Running on {DEVICE}")

Running on cuda:0


Model & training settings

In [5]:
# Which BLIP checkpoint variant to load; handed to `load_model_and_preprocess` as its second argument.
BLIP_VARIANT = "base" # "base" | "large"

In [6]:
# Batch size used by every DataLoader built in this notebook (the train splits too, despite the name).
VALIDATION_BATCH_SIZE = 420

## Loading data

In [7]:
def _read_vwsd_frame(data_path: Path, gold_path: Path, num_pics: int) -> pd.DataFrame:
    """Read one tab-separated data table and attach its gold labels.

    Args:
        data_path: Header-less TSV whose columns are word, context and then
            `num_pics` image file names.
        gold_path: Header-less TSV; its first column is stored as "label".
        num_pics: Number of image columns in `data_path`.

    Returns:
        Frame with columns "word", "context", "image0".."image{num_pics-1}", "label".

    Raises:
        ValueError: If the two files do not contain the same number of rows
            (labels would otherwise end up missing or misaligned).
    """
    frame = pd.read_csv(data_path, sep="\t", header=None)
    frame.columns = ["word", "context"] + [f"image{i}" for i in range(num_pics)]
    gold = pd.read_csv(gold_path, sep="\t", header=None).iloc[:, 0]
    if len(gold) != len(frame):
        raise ValueError(
            f"{gold_path} has {len(gold)} rows but {data_path} has {len(frame)}"
        )
    frame["label"] = gold
    return frame


def _read_split_indices(split_path: Path) -> np.ndarray:
    """Return the first column of a header-less split file (row labels into `df`)."""
    return pd.read_csv(split_path, sep="\t", header=None).iloc[:, 0].to_numpy()


# Full table for the selected part, then the three splits carved out of it by row label.
df = _read_vwsd_frame(DATA_PATH, LABELS_PATH, NUM_PICS)

train_df = df.loc[_read_split_indices(TRAIN_SPLIT_PATH)]
validation_df = df.loc[_read_split_indices(VALIDATION_SPLIT_PATH)]
test_df = df.loc[_read_split_indices(TEST_SPLIT_PATH)]

# Stand-alone second validation / test sets, read from their own files.
val2_df = _read_vwsd_frame(VAL2_DATA_PATH, VAL2_GOLD_PATH, NUM_PICS)
test2_df = _read_vwsd_frame(TEST2_DATA_PATH, TEST2_GOLD_PATH, NUM_PICS)

## Preprocessing

In [8]:
# Load the BLIP image-text-matching model of the configured variant together
# with its image and text preprocessors (each indexed by "eval" further down).
blip_model, vis_processors, text_processors = load_model_and_preprocess(
    "blip_image_text_matching",
    BLIP_VARIANT,
    is_eval=True,
)

INFO:root:Missing keys []
INFO:root:load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth


In [9]:
# Wrap the loaded BLIP model in the project's ITMClassifier (from src.itm) and move it to DEVICE.
model = ITMClassifier(blip_model).to(DEVICE)

In [10]:
# Take the wrapped BLIP model back out of the classifier; the prediction loop calls it directly.
# NOTE(review): `model` was already moved to DEVICE, so this `.to(DEVICE)` is presumably a no-op — confirm.
itm = model.blip_model.to(DEVICE)

In [11]:
# Short aliases for the evaluation-time image and text preprocessors.
vis_proc, text_proc = vis_processors["eval"], text_processors["eval"]

In [12]:
class SimpleBinaryDataset(torch.utils.data.Dataset):
    """One (text, image, is-gold) sample per candidate image of each table row.

    Sample `idx` maps to row `idx // num_pics` (by position) and to that row's
    `image{idx % num_pics}` column.

    Args:
        df: Table with "word", "context", "image0".."image{num_pics-1}" and
            "label" columns; "label" holds the file name of the gold image.
        images_path: Directory that contains the image files.
        text_processor: Callable applied to the row's text.
        vis_processor: Callable applied to the RGB-converted PIL image.
        use_context_as_text: Use the "context" column as text when True,
            otherwise the "word" column.
        num_pics: Number of candidate image columns per row.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        images_path: Path,
        text_processor,
        vis_processor,
        use_context_as_text: bool = True,
        num_pics: int = 10,
    ) -> None:
        self.df = df
        self.images_path = images_path
        self.text_processor = text_processor
        self.vis_processor = vis_processor
        self.text_field = "context" if use_context_as_text else "word"
        self.num_pics = num_pics

    def __len__(self) -> int:
        """Total number of (row, candidate image) pairs."""
        return self.num_pics * len(self.df)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the processed text, processed image and gold flag for pair `idx`."""
        row_idx, pic_idx = divmod(idx, self.num_pics)
        # Index the frame owned by this dataset. Reading a module-level `df`
        # here would make every split serve rows of the full table.
        # `iloc` is positional because split frames keep the parent's row labels.
        row = self.df.iloc[row_idx]
        img_name = row[f"image{pic_idx}"]
        text_input = self.text_processor(row[self.text_field])
        # The context manager closes the file once the RGB copy has been made.
        with Image.open(self.images_path / img_name) as raw_image:
            rgb_image = raw_image.convert("RGB")
        return {
            "text_input": text_input,
            "image": self.vis_processor(rgb_image),
            "label": img_name == row["label"],
        }
        

In [13]:
class SimpleBinaryDatasetFrom10DS(torch.utils.data.Dataset):
    """Expose a ten-candidates-per-item dataset as individual binary samples.

    Every item of the wrapped dataset is read through its "text", "images"
    and "label" keys and expands into ten consecutive samples, one per
    candidate position.
    """

    def __init__(self, ds: torch.utils.data.Dataset) -> None:
        self.ds = ds

    def __len__(self) -> int:
        """Ten samples for each wrapped item."""
        return len(self.ds) * 10

    def __getitem__(self, idx: int) -> Dict:
        """Return text, the selected candidate image and whether it is the labelled one."""
        group_pos, candidate_pos = divmod(idx, 10)
        group = self.ds[group_pos]
        return dict(
            text_input=group["text"],
            image=group["images"][candidate_pos],
            label=candidate_pos == group["label"],
        )

In [14]:
# Worker processes per DataLoader.
NUM_WORKERS = 32


def _make_binary_ds(frame: pd.DataFrame) -> SimpleBinaryDataset:
    """Build the per-candidate dataset for `frame` with the eval-time processors."""
    # NOTE(review): every frame, including val2/test2 which come from their own
    # data files, is resolved against IMAGES_PATH — confirm those images live there.
    return SimpleBinaryDataset(
        df=frame,
        images_path=IMAGES_PATH,
        text_processor=text_processors["eval"],
        vis_processor=vis_processors["eval"],
    )


train_ds = _make_binary_ds(train_df)
# Alternative training view: AltNSDataset items flattened to one sample per candidate.
train_ds2 = SimpleBinaryDatasetFrom10DS(AltNSDataset(
    df=train_df,
    images_path=IMAGES_PATH,
    text_processor=text_processors["eval"],
    vis_processor=vis_processors["eval"],
    num_negatives=9,
    num_pics=NUM_PICS,
))
val_ds = _make_binary_ds(validation_df)
val2_ds = _make_binary_ds(val2_df)
test_ds = _make_binary_ds(test_df)
test2_ds = _make_binary_ds(test2_df)

# Loader name -> dataset. "train_dl2" must wrap `train_ds2`; pointing it at
# `train_ds` would just evaluate the first training set twice.
_datasets_by_loader = {
    "train_dl": train_ds,
    "train_dl2": train_ds2,
    "val_dl": val_ds,
    "val2_dl": val2_ds,
    "test_dl": test_ds,
    "test2_dl": test2_ds,
}

dls = {
    loader_name: torch.utils.data.DataLoader(
        dataset,
        batch_size=VALIDATION_BATCH_SIZE,
        num_workers=NUM_WORKERS,
    )
    for loader_name, dataset in _datasets_by_loader.items()
}

In [15]:
# Per-split score stores filled by the prediction loop: one for the scores
# taken from output column 0, one for those taken from output column 1.
negative_data, positive_data = {}, {}

In [16]:
# Score every split with the BLIP model and file each sample's two output
# values under "pi_t" (label is true) or "ni_t" (label is false).
itm.eval()
with torch.no_grad():  # inference only, no gradients needed
    for split_name, dl in dls.items():
        print(f"Predicting for split: {split_name}")
        negative_data[split_name] = {"pi_t": [], "ni_t": []}
        positive_data[split_name] = {"pi_t": [], "ni_t": []}
        for batch in tqdm(dl):
            mixed_preds = itm(to_device(batch, DEVICE))
            # Column 0 is kept as the "negative" score, column 1 as the "positive" one.
            # One `.tolist()` per batch yields the same Python floats as calling
            # `.item()` on every element, without a device round-trip per value.
            negative_preds = mixed_preds[:, 0].tolist()
            positive_preds = mixed_preds[:, 1].tolist()
            labels = torch.as_tensor(batch["label"]).tolist()
            for neg_score, pos_score, is_gold in zip(negative_preds, positive_preds, labels):
                cat = "pi_t" if is_gold else "ni_t"
                negative_data[split_name][cat].append(neg_score)
                positive_data[split_name][cat].append(pos_score)


Predicting for split: train_dl


100%|██████████| 146/146 [10:38<00:00,  4.37s/it]


Predicting for split: train_dl2


100%|██████████| 146/146 [10:37<00:00,  4.36s/it]


Predicting for split: val_dl


100%|██████████| 82/82 [06:17<00:00,  4.61s/it] 


Predicting for split: val2_dl


100%|██████████| 38/38 [03:13<00:00,  5.10s/it]


Predicting for split: test_dl


100%|██████████| 80/80 [06:12<00:00,  4.66s/it] 


Predicting for split: test2_dl


100%|██████████| 38/38 [03:13<00:00,  5.08s/it]


In [None]:
# Persist the column-0 ("negative") scores of every split next to the dataset.
neg_preds_file = PATH / "itm0-neg-preds.json"
with neg_preds_file.open("w") as out:
    json.dump(negative_data, out)

In [None]:
# Persist the column-1 ("positive") scores of every split next to the dataset.
pos_preds_file = PATH / "itm0-pos-preds.json"
with pos_preds_file.open("w") as out:
    json.dump(positive_data, out)