# M-CLIP for VWSD

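WARNING: this notebook was run against a locally patched copy of ~/miniconda3/envs/lavis/lib/python3.9/site-packages/multilingual_clip/pt_multilingual_clip.py; the stock `multilingual_clip` package may behave differently.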

In [1]:
from multilingual_clip import pt_multilingual_clip
import transformers
import torch
import clip
import requests
from PIL import Image
import torch.nn as nn
from typing import *
from torch.utils.data import Dataset
from transformers import DataCollatorWithPadding

from pathlib import Path
import logging
import json
import os
import random
from typing import *
import time
from datetime import datetime
import warnings

from tqdm import tqdm
import pandas as pd
import numpy as np
from PIL import Image, ImageFile

import torch
import torch.nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
from transformers import get_cosine_schedule_with_warmup
from transformers import BatchEncoding

from src.data import CustomSplitLoader
from src.itm import ItmDataset, to_device, ITMClassifier

from src.utils import evaluate, mrr
from src.validation import Validation, sum_scores, div_scores, eval_batch, metric2name
from sklearn.metrics import top_k_accuracy_score
from torch.utils.tensorboard import SummaryWriter

## Configuration

In [2]:
# GENERAL:
MODEL_NAME: str = "clip-vitl14-xlmrl"  # run name; used in checkpoint and TensorBoard paths
MODEL_VERSION: Any = 0
DEBUG: bool = False

# MODEL:
PRETRAINED_TEXT_MODEL_NAME: str = "M-CLIP/XLM-Roberta-Large-Vit-L-14"
PRETRAINED_CLIP_MODEL_NAME: str = "ViT-L/14"
# WARNING: not supported yet; intended to load "$CHECKPOINT_PATH/$FROM_CHECKPOINT".
FROM_CHECKPOINT: Optional[str] = None

# TRAINING:
NUM_EPOCHS: int = 10
WARMUP_FRAC: float = 0.1  # fraction of optimizer steps spent in LR warmup
GRAD_ACCUM: int = 15  # micro-batches per optimizer step; >= 1, 1 => no accumulation
LR: float = 1e-5
TRAIN_BATCH_SIZE: int = 1
TRAIN_IMG_AUG: Optional[Any] = None  # augmenter

# VALIDATION:
STEPS_BETWEEN_VAL: int = 250  # optimizer steps between validation runs
STEPS_BETWEEN_CHECKPOINT: int = STEPS_BETWEEN_VAL
VAL_BATCH_SIZE: int = 30

# IMAGE NEGATIVE SAMPLING:
NUM_SRC_PICS: int = 10  # number of pics in source table ("image{i}")
NUM_NS: int = 9  # total number of negative samples for one positive
NUM_RAND_NS: int = 0  # number of random negative samples
NUM_HARD_NS: int = 0  # WARNING: not supported yet
NUM_RAND_WHEN_NO_HARD_NS: int = 0  # WARNING: not supported yet
REPLACE_DEFAULT_NS: bool = False  # sampling default ns with replacement or not
REPLACE_RAND_NS: bool = False  # sampling rand ns with replacement or not

# PATHS
# Root of this project on the author's machine; everything below hangs off it.
PROJECT_ROOT = Path("/home/s1m00n/research/vwsd")
DATASET_VERSION = "v1"
PART = "train"
PATH = (PROJECT_ROOT / "data").resolve() / f"{PART}_{DATASET_VERSION}"
DATA_PATH = PATH / f"{PART}.data.{DATASET_VERSION}.txt"
LABELS_PATH = PATH / f"{PART}.gold.{DATASET_VERSION}.txt"
IMAGES_PATH = PATH / f"{PART}_images_{DATASET_VERSION}"
TRAIN_SPLIT_PATH = PATH / "split_train.txt"
VALIDATION_SPLIT_PATH = PATH / "split_valid.txt"
# The extra valid2/test2 files follow the same versioned naming as the main data.
VAL2_DATA_PATH = PATH / f"valid2.data.{DATASET_VERSION}.txt"
VAL2_GOLD_PATH = PATH / f"valid2.gold.{DATASET_VERSION}.txt"
TEST_SPLIT_PATH = PATH / "split_test.txt"
TEST2_DATA_PATH = PATH / f"test2.data.{DATASET_VERSION}.txt"
TEST2_GOLD_PATH = PATH / f"test2.gold.{DATASET_VERSION}.txt"
CHECKPOINT_PATH = (PROJECT_ROOT / "checkpoints").resolve() / f"{MODEL_NAME}-{MODEL_VERSION}"
CHECKPOINT_PATH.mkdir(parents=True, exist_ok=True)

# SYSTEM:
RANDOM_STATE = 42
# WARNING: this is very dependent on available RAM
NUM_WORKERS = 32
PERSISTENT_WORKERS = True
# First GPU when CUDA is available; otherwise fall back to CPU instead of
# failing on the first `.to(DEVICE)`.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# AUTO DERIVED:
TEST_BATCH_SIZE = VAL_BATCH_SIZE
TRAIN_EFFECTIVE_BATCH_SIZE = GRAD_ACCUM * TRAIN_BATCH_SIZE
NUM_LABELS = NUM_NS + 1  # images scored per sample: 1 positive + NUM_NS negatives

In [3]:
# Seed every RNG the notebook draws from: torch (weights, DataLoader shuffling),
# plus NumPy and Python's `random`, which the dataset code uses for negative
# sampling (e.g. VWSDDataset.__getitem__).
torch.manual_seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
random.seed(RANDOM_STATE)

logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Some train images fail to load (or raise warnings) with PIL's defaults:
# lift the decompression-bomb pixel limit and allow truncated files.
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings('ignore')

writer = SummaryWriter(f"/home/s1m00n/research/vwsd/runs/{MODEL_NAME}-{MODEL_VERSION}")

# Turn off HF tokenizers' internal parallelism; it does not mix well with the
# multi-process DataLoader workers used below.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

## Source data

In [4]:
def read_vwsd_table(
    data_path: Path,
    gold_path: Path,
    num_src_pics: int = NUM_SRC_PICS,
) -> pd.DataFrame:
    """Load a tab-separated VWSD data file and its gold file into one DataFrame.

    The data file has no header; its columns are named ``word``, ``context``,
    ``image0`` .. ``image{num_src_pics-1}``. The first column of the gold file
    (also headerless) becomes the ``label`` column, aligned by row position.
    """
    table = pd.read_csv(data_path, sep="\t", header=None)
    table.columns = ["word", "context"] + [f"image{i}" for i in range(num_src_pics)]
    table["label"] = pd.read_csv(gold_path, sep="\t", header=None).iloc[:, 0]
    return table


def read_split_indices(split_path: Path) -> np.ndarray:
    """Return the row labels listed in the first column of a split file."""
    return pd.read_csv(split_path, sep="\t", header=None).iloc[:, 0].to_numpy()


# Main table, carved into train/validation/test by the split files' row labels.
df = read_vwsd_table(DATA_PATH, LABELS_PATH)
train_df = df.loc[read_split_indices(TRAIN_SPLIT_PATH)]
validation_df = df.loc[read_split_indices(VALIDATION_SPLIT_PATH)]
test_df = df.loc[read_split_indices(TEST_SPLIT_PATH)]

# Separate held-out tables with the same layout.
val2_df = read_vwsd_table(VAL2_DATA_PATH, VAL2_GOLD_PATH)
test2_df = read_vwsd_table(TEST2_DATA_PATH, TEST2_GOLD_PATH)

## Model

In [5]:
class MCLIPClassifier(nn.Module):
    """Score candidate images against a text with an M-CLIP text tower and CLIP.

    Each text is embedded by ``text_model(texts, tokenizer)`` and each image by
    ``clip_model.encode_image``. Every (text, image) pair in a row is scored by
    cosine similarity times ``logit_scale``.

    Args:
        text_model: callable as ``text_model(texts, tokenizer)``, returning one
            embedding row per text.
        clip_model: object exposing ``encode_image(images)``.
        tokenizer: passed through to ``text_model``.
        logit_scale: multiplier applied to cosine similarities (default 6.0,
            the value previously hard-coded).
        apply_softmax: if True (default), ``forward`` returns per-row softmax
            probabilities. If False, it returns the scaled similarities
            (logits). Note that ``nn.CrossEntropyLoss`` expects logits, so
            feeding it the softmaxed output applies softmax twice.
    """

    def __init__(
        self,
        text_model,
        clip_model,
        tokenizer,
        logit_scale: float = 6.0,
        apply_softmax: bool = True,
    ) -> None:
        super().__init__()
        self.text_model = text_model
        self.clip_model = clip_model
        # Kept as an attribute for compatibility; forward uses a broadcast cosine.
        self.sim = nn.CosineSimilarity(dim=1)
        self.tokenizer = tokenizer
        self.logit_scale = logit_scale
        self.apply_softmax = apply_softmax

    def forward(self, inputs: Dict[str, Any]) -> torch.Tensor:
        """Return a (bs, n) tensor of scores, one per candidate image.

        ``inputs["text"]`` is handed to the text model. ``inputs["images"]``
        must be a 5-D tensor (bs, n, c, h, w).
        """
        text_out = self.text_model(inputs["text"], self.tokenizer)  # (bs, dim)
        images = inputs["images"]
        bs, n, c, h, w = images.shape
        # Encode all bs * n images in one call, then regroup per row.
        imgs_out = self.clip_model.encode_image(images.reshape(bs * n, c, h, w)).reshape(bs, n, -1)
        if DEBUG:
            print("text out:", text_out, "text out shape:", text_out.shape)
            print("img out:", imgs_out, "img out shape:", imgs_out.shape)
        # (bs, 1, dim) vs (bs, n, dim) -> cosine of each row's text with its n images.
        scores = torch.cosine_similarity(text_out.unsqueeze(1), imgs_out, dim=-1) * self.logit_scale
        if self.apply_softmax:
            scores = torch.softmax(scores, dim=1)
        if DEBUG:
            print("stacked batches:", scores, "shape:", scores.shape)
        return scores

# Multilingual text tower and its tokenizer come from the same checkpoint name.
tokenizer = transformers.AutoTokenizer.from_pretrained(PRETRAINED_TEXT_MODEL_NAME)
text_model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(PRETRAINED_TEXT_MODEL_NAME)

# CLIP image tower plus its image preprocessing transform; cast all floating
# weights to fp32 (clip.load may hand back half-precision weights).
clip_model, preprocess = clip.load(PRETRAINED_CLIP_MODEL_NAME)
clip_model = clip_model.float()

model = MCLIPClassifier(
    text_model=text_model,
    clip_model=clip_model,
    tokenizer=tokenizer,
).to(DEVICE)

## Data preprocessing

In [6]:
class ImageSet:
    """Load and cache processed images from one directory, and keep an embedding store.

    Indexing with a file name opens ``images_path / name`` and runs
    ``image_processor`` on it. Indexing with a list of names stacks the
    results into one tensor. Processed tensors are cached when
    ``enable_cache`` is True.

    Separately, callers can register per-image embeddings with ``update_emb``
    and get a pairwise similarity matrix over them with ``get_sims``.

    Args:
        images_path: directory containing the image files.
        image_processor: maps a PIL image to a tensor.
        similarity_measure: called as ``similarity_measure(a, b)`` with a
            (1, D) row and an (N, D) matrix, returning N similarities.
            Defaults to ``nn.CosineSimilarity(dim=1)``.
        enable_cache: keep processed image tensors in memory.
    """

    def __init__(
        self,
        images_path: Path,
        image_processor: Callable[["Image.Image"], torch.Tensor],
        similarity_measure: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        enable_cache: bool = True,
    ) -> None:
        self.images_path = images_path
        self.image_processor = image_processor
        self.enable_cache = enable_cache
        # Build the default per instance rather than sharing one module-level default.
        self.similarity_measure = nn.CosineSimilarity(dim=1) if similarity_measure is None else similarity_measure
        self.tensor_cache: Dict[str, torch.Tensor] = dict()  # <file name> -> <data>
        self.embedding_cache: Dict[str, torch.Tensor] = dict()  # <file name> -> <embedding>
        self.similarities_cache: Dict[str, Dict[str, float]] = dict()  # fn1 -> fn2 -> sim(fn1, fn2)

    def __getitem__(self, file_name: Union[str, List[str]]) -> torch.Tensor:
        if isinstance(file_name, list):
            return torch.stack([self[name] for name in file_name])

        if file_name in self.tensor_cache:
            return self.tensor_cache[file_name]
        # `with` closes the file handle as soon as the processor is done with it.
        with Image.open(self.images_path / file_name) as img:
            loaded = self.image_processor(img)
        if self.enable_cache:
            self.tensor_cache[file_name] = loaded
        return loaded

    @property
    def known_embs(self) -> List[str]:
        """Names with a registered embedding, in registration order."""
        return list(self.embedding_cache.keys())

    def update_emb(self, file_name: str, vec: torch.Tensor):
        self.embedding_cache[file_name] = vec

    def get_emb(self, file_name: str) -> Optional[torch.Tensor]:
        """Return the registered embedding for ``file_name``, or None."""
        return self.embedding_cache.get(file_name)

    def get_sims(self, file_names: List[str]) -> Optional[torch.Tensor]:
        """Return the (N, N) similarity matrix over the named embeddings.

        Entry [i, j] is ``similarity_measure`` between the embeddings of
        ``file_names[i]`` and ``file_names[j]``. Returns None when the list is
        empty or any name has no registered embedding.
        """
        embeddings = [self.get_emb(name) for name in file_names]
        if not embeddings or any(emb is None for emb in embeddings):
            return None
        stacked = torch.stack(embeddings)  # (N, D), assuming 1-D embeddings
        # nn.CosineSimilarity takes two tensors, so build the matrix row by row:
        # each row is compared against every stacked embedding via broadcasting.
        return torch.stack([self.similarity_measure(row.unsqueeze(0), stacked) for row in stacked])

class VWSDDataset(Dataset):
    """Dataset pairing each row's text with its gold image and sampled negatives.

    Each item holds the row's text (``context`` or ``word`` column) and
    ``num_ns + 1`` images. The gold image (``label`` column) is inserted at a
    random position among ``num_ns`` negatives, which are drawn from:

    * the row's own non-gold candidates (``image{i}`` columns),
    * ``num_any_ns`` random images taken from anywhere in ``df``,
    * ``num_hard_ns`` images whose embeddings are most similar to the gold
      image's. This applies only when ``image_set`` knows the gold image's
      embedding. Otherwise ``num_any_when_no_hard_ns`` random images are used,
      and the remainder is filled with extra row negatives.

    Sampling uses NumPy's global RNG and Python's ``random`` module.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        image_set: "ImageSet",
        text_preprocessor,
        use_context_as_text: bool = True,
        num_src_pics: int = 10,

        num_ns: int = 9,
        num_any_ns: int = 0,
        replace_any_ns: bool = False,
        replace_default_ns: bool = False,
        num_hard_ns: int = 0,
        num_any_when_no_hard_ns: int = 0,
    ) -> None:
        self.df = df
        self.image_set = image_set
        self.text_preprocessor = text_preprocessor
        self.text_field = "context" if use_context_as_text else "word"
        self.num_src_pics = num_src_pics
        self.num_ns = num_ns
        self.num_any_ns = num_any_ns
        self.replace_any_ns = replace_any_ns
        self.replace_default_ns = replace_default_ns
        self.num_hard_ns = num_hard_ns
        self.num_any_when_no_hard_ns = num_any_when_no_hard_ns

        # Pool for "random from the whole dataset" negatives.
        self.all_image_names: np.ndarray = np.unique(
            self.df[[f"image{i}" for i in range(self.num_src_pics)]].values.ravel("K")
        )

        self.num_default_ns = self.num_ns - self.num_any_ns - self.num_hard_ns
        logging.info(f"Total pics in sample: 1 positive, {self.num_any_ns} random from all dataset, {self.num_hard_ns} hard negative samples, {self.num_default_ns} from default rows = {self.num_ns + 1} total samples")

    def __len__(self) -> int:
        return len(self.df)

    def _sample_hard_names(self, pos_img_name: str) -> Optional[List[str]]:
        """Return the ``num_hard_ns`` known images most similar to the positive.

        The positive itself is excluded. Returns None when the positive has
        no registered embedding or no similarity matrix is available.
        """
        known_embs = self.image_set.known_embs
        if pos_img_name not in known_embs:
            return None
        sim_mat = self.image_set.get_sims(known_embs)
        if sim_mat is None:
            return None
        pos_index = known_embs.index(pos_img_name)
        order = torch.argsort(sim_mat[pos_index], descending=True).tolist()
        # Self-similarity is maximal, so the positive would otherwise be picked
        # as its own "hard negative".
        return [known_embs[i] for i in order if i != pos_index][: self.num_hard_ns]

    def __getitem__(self, index: int) -> Dict:
        row = self.df.iloc[index]
        pos_img_name = row["label"]

        # Column positions in this row that hold a non-gold image.
        negative_row_indices = np.array([
            i for i in range(self.num_src_pics) if row[f"image{i}"] != pos_img_name
        ])

        # Hard negatives; if unavailable, reassign their quota to random
        # ("alt") and in-row ("default") negatives.
        mb_hard_ns_names = self._sample_hard_names(pos_img_name)
        if mb_hard_ns_names is None:
            add_alt_ns_num = self.num_any_when_no_hard_ns
            add_default_ns_num = self.num_hard_ns - add_alt_ns_num
            hard_ns_names = []
        else:
            add_default_ns_num = 0
            add_alt_ns_num = 0
            hard_ns_names = mb_hard_ns_names

        default_ns_names = [row[f"image{i}"] for i in np.random.choice(
            negative_row_indices,
            self.num_default_ns + add_default_ns_num,
            replace=self.replace_default_ns,
        )]
        alt_ns_names = list(np.random.choice(
            self.all_image_names[self.all_image_names != pos_img_name],
            self.num_any_ns + add_alt_ns_num,
            replace=self.replace_any_ns,
        ))

        # Combine and shuffle negatives, then insert the positive at a random
        # position; that position is the label.
        names = default_ns_names + alt_ns_names + hard_ns_names
        assert len(names) == self.num_ns
        random.shuffle(names)
        label = random.randint(0, self.num_ns)  # inclusive: num_ns + 1 slots
        names.insert(label, pos_img_name)

        return {
            "text": self.text_preprocessor(row[self.text_field]),
            "images": self.image_set[names],
            "label": label,
            "image_names": names,
        }

In [7]:
# text_collator = DataCollatorWithPadding(tokenizer)

# def collate(samples):
#     texts = [s.pop("text") for s in samples]
#     collated = torch.utils.data.default_collate(samples)
#     collated["text"] = text_collator(texts)
#     if DEBUG:
#         for k, v in collated.items():
#             print(k, v.shape if isinstance(v, torch.Tensor) else len(v))
#         print("COLLATED TEXTS:", collated["text"])
#     return collated

In [8]:
# Training images are re-read from disk on every access (no tensor cache).
train_image_set = ImageSet(
    images_path=IMAGES_PATH,
    image_processor=preprocess,
    similarity_measure=nn.CosineSimilarity(dim=1),
    enable_cache=False,
)
train_ds = ItmDataset(
    df=train_df,
    image_set=train_image_set,
    # Raw strings: MCLIPClassifier hands the tokenizer to the text model itself.
    text_preprocessor=lambda x: x,
    use_context_as_text=True,
    num_src_pics=NUM_SRC_PICS,
    num_ns=NUM_NS,
    num_any_ns=NUM_RAND_NS,
    replace_any_ns=REPLACE_RAND_NS,
    replace_default_ns=REPLACE_DEFAULT_NS,
    num_hard_ns=NUM_HARD_NS,
    num_any_when_no_hard_ns=NUM_RAND_WHEN_NO_HARD_NS,
)
train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_size=TRAIN_BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=True,
    # Honor the PERSISTENT_WORKERS setting. DataLoader rejects it with 0 workers.
    persistent_workers=PERSISTENT_WORKERS and NUM_WORKERS > 0,
)
train_l = len(train_dl)  # micro-batches per epoch
train_l

INFO:root:Total pics in sample: 1 positive, 0 random from all dataset, 0 hard negative samples, 9 from default rows = 10 total samples


6103

In [9]:
def _make_eval_image_set() -> ImageSet:
    """Uncached image set over IMAGES_PATH, as used for evaluation data."""
    return ImageSet(
        images_path=IMAGES_PATH,
        image_processor=preprocess,
        similarity_measure=nn.CosineSimilarity(dim=1),
        enable_cache=False,
    )


def _make_eval_dataset(frame, image_set):
    """ItmDataset over ``frame`` with the global negative-sampling settings."""
    return ItmDataset(
        df=frame,
        image_set=image_set,
        text_preprocessor=lambda x: x,
        use_context_as_text=True,
        num_src_pics=NUM_SRC_PICS,
        num_ns=NUM_NS,
        num_any_ns=NUM_RAND_NS,
        replace_any_ns=REPLACE_RAND_NS,
        replace_default_ns=REPLACE_DEFAULT_NS,
        num_hard_ns=NUM_HARD_NS,
        num_any_when_no_hard_ns=NUM_RAND_WHEN_NO_HARD_NS,
    )


val_image_set = _make_eval_image_set()
val_ds = _make_eval_dataset(validation_df, val_image_set)

val2_image_set = _make_eval_image_set()
val2_ds = _make_eval_dataset(val2_df, val2_image_set)

INFO:root:Total pics in sample: 1 positive, 0 random from all dataset, 0 hard negative samples, 9 from default rows = 10 total samples
INFO:root:Total pics in sample: 1 positive, 0 random from all dataset, 0 hard negative samples, 9 from default rows = 10 total samples


## Training
### Train settings

In [10]:
model.train()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
# One optimizer step per GRAD_ACCUM micro-batches. The accumulation counter in
# the training loop is not reset between epochs, so the total is
# floor(NUM_EPOCHS * train_l / GRAD_ACCUM). Integer arithmetic avoids float
# rounding such as 100 * (57 / 100) == 56.999..., which would drop a step.
num_training_steps = NUM_EPOCHS * train_l // GRAD_ACCUM
num_warmup_steps = int(num_training_steps * WARMUP_FRAC)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)
print(f"{num_training_steps} training steps which include {num_warmup_steps} warmup ones")

4068 training steps which include 406 warmup ones


### Validation settings

In [11]:
# The model emits one score per sampled image (NUM_LABELS = NUM_NS + 1), so the
# metric label set and one-hot targets use NUM_LABELS, matching the training loop.
labels_range = np.arange(NUM_LABELS)
if DEBUG:
    print("labels range:", labels_range)


def get_batch_scores(model, batch, dev, env):
    """Run ``model`` on one evaluation batch and return its metrics.

    Args:
        model: classifier, called as ``model(batch)``.
        batch: collated batch with a ``label`` tensor; moved to ``dev`` first.
        dev: target device.
        env: not used here; presumably part of the callback signature that
            ``Validation`` expects — confirm in src.validation.

    Returns:
        Dict with the loss tensor, top-1/top-3 accuracy and mean reciprocal rank.
    """
    batch = to_device(batch, dev)
    # Evaluation never back-propagates; skip building the autograd graph.
    with torch.no_grad():
        outputs = model(batch)
    np_labels = batch["label"].numpy(force=True)
    np_preds = outputs.numpy(force=True)
    if DEBUG:
        print("outputs:", outputs)
        print("np labels:", np_labels)
        print("np preds:", np_preds)
    return {
        "Loss": loss_fn(outputs, F.one_hot(batch["label"], NUM_LABELS).float().to(dev)),
        "Accuracy@Top1": top_k_accuracy_score(np_labels, np_preds, k=1, labels=labels_range),
        "Accuracy@Top3": top_k_accuracy_score(np_labels, np_preds, k=3, labels=labels_range),
        "Mean Reciprocal Rank": mrr(np_labels, np_preds),
    }


def log_score(train_step, name, metric_name, metric_value):
    """Write one metric to TensorBoard and echo it to stdout."""
    writer.add_scalar(f"{metric_name}/{name}", metric_value, train_step)
    print(f"[{train_step}][{name}]", f"{metric_name}: {metric_value}")


validation = Validation(
    common = {
        "device": DEVICE,
        "get_batch_scores": get_batch_scores,
        # Validate after the first optimizer step and then every STEPS_BETWEEN_VAL steps.
        "step_cond": lambda s: (s % STEPS_BETWEEN_VAL == 0) or (s == 1),
        "log_score": log_score,
    },
    configs = {
        "Validation": { "dl": torch.utils.data.DataLoader(val_ds, batch_size=VAL_BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True), },
        # "Validation (augmented)": { "dl": torch.utils.data.DataLoader(val_aug_ds, batch_size=VAL_BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True), },
        "Validation 2": { "dl": torch.utils.data.DataLoader(val2_ds, batch_size=VAL_BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True), },
        # "Validation 2 (augmented)": { "dl": torch.utils.data.DataLoader(val2_aug_ds, batch_size=VAL_BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True), },
    },
)

### Train-validation loop

In [12]:
step_num = 0  # optimizer steps taken so far
grad_accum_step_cnt = 0  # micro-batches accumulated since the last optimizer step
save_checkpoint_step_cnt = 0  # optimizer steps since the last checkpoint
progress_bar = tqdm(range(num_training_steps))

# Gradients (and the accumulation counter) carry across epoch boundaries, so the
# logged loss/metric accumulators must as well; reset them only after a step.
train_loss = 0.0
train_scores = {"acc1": 0, "acc3": 0, "mrr": 0}

for epoch_num in range(NUM_EPOCHS):
    model.train()
    for batch in train_dl:
        batch = to_device(batch, DEVICE)
        outputs = model(batch)
        loss = loss_fn(outputs, F.one_hot(batch["label"], NUM_LABELS).float().to(DEVICE))

        train_loss += loss.item()
        train_scores = sum_scores(train_scores, eval_batch(batch["label"], outputs, num_labels = NUM_LABELS))

        # Scale so the accumulated gradient is the mean over the effective batch
        # rather than the sum of GRAD_ACCUM micro-batch gradients.
        (loss / GRAD_ACCUM).backward()
        grad_accum_step_cnt += 1
        if grad_accum_step_cnt < GRAD_ACCUM:
            continue

        writer.add_scalar("Learning Rate", lr_scheduler.get_last_lr()[0], step_num)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        writer.add_scalar("Loss/Train", train_loss / GRAD_ACCUM, step_num)
        for k, v in div_scores(train_scores, GRAD_ACCUM).items():
            writer.add_scalar(metric2name[k] + "/Train", v, step_num)
        train_loss = 0.0
        train_scores = {"acc1": 0, "acc3": 0, "mrr": 0}
        grad_accum_step_cnt = 0
        step_num += 1
        save_checkpoint_step_cnt += 1
        progress_bar.update(1)

        validation(step_num, model)
        # Validation may leave the model in eval mode (dropout off); make sure
        # training resumes in train mode. No-op if it already restores it.
        model.train()

        if save_checkpoint_step_cnt == STEPS_BETWEEN_CHECKPOINT:
            save_checkpoint_step_cnt = 0
            torch.save(model.state_dict(), CHECKPOINT_PATH / f"step-{step_num}.pt")

progress_bar.close()
writer.flush()

  0%|          | 1/4068 [00:08<9:51:32,  8.73s/it]

[1][Validation] Loss: 2.2632877826690674
[1][Validation] Accuracy@Top1: 0.5709064327485384
[1][Validation] Accuracy@Top3: 0.8036549707602337
[1][Validation] Mean Reciprocal Rank: 0.7077473777035181
[1][Validation 2] Loss: 2.2597692012786865
[1][Validation 2] Accuracy@Top1: 0.6276923076923077
[1][Validation 2] Accuracy@Top3: 0.8220512820512822
[1][Validation 2] Mean Reciprocal Rank: 0.7459452584452584


  6%|▌         | 250/4068 [33:33<5:57:31,  5.62s/it] 

[250][Validation] Loss: 1.845352053642273
[250][Validation] Accuracy@Top1: 0.7786549707602337
[250][Validation] Accuracy@Top3: 0.9573099415204682
[250][Validation] Mean Reciprocal Rank: 0.8694843590457625
[250][Validation 2] Loss: 1.9382201433181763
[250][Validation 2] Accuracy@Top1: 0.6939743589743588
[250][Validation 2] Accuracy@Top3: 0.9119230769230764
[250][Validation 2] Mean Reciprocal Rank: 0.808325702075702


 12%|█▏        | 500/4068 [1:07:10<5:34:22,  5.62s/it] 

[500][Validation] Loss: 1.815487027168274
[500][Validation] Accuracy@Top1: 0.7549707602339183
[500][Validation] Accuracy@Top3: 0.948976608187135
[500][Validation] Mean Reciprocal Rank: 0.8537425740276621
[500][Validation 2] Loss: 1.9150879383087158
[500][Validation 2] Accuracy@Top1: 0.6623076923076924
[500][Validation 2] Accuracy@Top3: 0.8947435897435894
[500][Validation 2] Mean Reciprocal Rank: 0.7856789275539271


 18%|█▊        | 750/4068 [1:40:43<5:11:20,  5.63s/it]   

[750][Validation] Loss: 1.769071340560913
[750][Validation] Accuracy@Top1: 0.7894736842105265
[750][Validation] Accuracy@Top3: 0.9656432748538016
[750][Validation] Mean Reciprocal Rank: 0.8784939199851476
[750][Validation 2] Loss: 1.8763978481292725
[750][Validation 2] Accuracy@Top1: 0.6984615384615384
[750][Validation 2] Accuracy@Top3: 0.9182051282051277
[750][Validation 2] Mean Reciprocal Rank: 0.8124615893365894


 25%|██▍       | 1000/4068 [2:14:28<4:47:31,  5.62s/it]  

[1000][Validation] Loss: 1.7807267904281616
[1000][Validation] Accuracy@Top1: 0.7779239766081866
[1000][Validation] Accuracy@Top3: 0.9599415204678367
[1000][Validation] Mean Reciprocal Rank: 0.8692169080107672
[1000][Validation 2] Loss: 1.883785367012024
[1000][Validation 2] Accuracy@Top1: 0.6967948717948717
[1000][Validation 2] Accuracy@Top3: 0.8998717948717947
[1000][Validation 2] Mean Reciprocal Rank: 0.8077988909238907


 31%|███       | 1250/4068 [2:48:01<4:23:48,  5.62s/it]   

[1250][Validation] Loss: 1.7630964517593384
[1250][Validation] Accuracy@Top1: 0.7871345029239766
[1250][Validation] Accuracy@Top3: 0.9631578947368425
[1250][Validation] Mean Reciprocal Rank: 0.8755490578297601
[1250][Validation 2] Loss: 1.861405372619629
[1250][Validation 2] Accuracy@Top1: 0.7046153846153845
[1250][Validation 2] Accuracy@Top3: 0.9111538461538462
[1250][Validation 2] Mean Reciprocal Rank: 0.8151474867724867


 37%|███▋      | 1500/4068 [3:21:29<4:00:38,  5.62s/it]   

[1500][Validation] Loss: 1.7696207761764526
[1500][Validation] Accuracy@Top1: 0.7748538011695904
[1500][Validation] Accuracy@Top3: 0.9578947368421058
[1500][Validation] Mean Reciprocal Rank: 0.867789380859556
[1500][Validation 2] Loss: 1.8642746210098267
[1500][Validation 2] Accuracy@Top1: 0.7089743589743589
[1500][Validation 2] Accuracy@Top3: 0.9189743589743589
[1500][Validation 2] Mean Reciprocal Rank: 0.8197591066341066


 43%|████▎     | 1750/4068 [3:55:07<3:37:23,  5.63s/it]   

[1750][Validation] Loss: 1.7805321216583252
[1750][Validation] Accuracy@Top1: 0.7666666666666668
[1750][Validation] Accuracy@Top3: 0.9523391812865503
[1750][Validation] Mean Reciprocal Rank: 0.8612544091710762
[1750][Validation 2] Loss: 1.8668131828308105
[1750][Validation 2] Accuracy@Top1: 0.6988461538461539
[1750][Validation 2] Accuracy@Top3: 0.9194871794871792
[1750][Validation 2] Mean Reciprocal Rank: 0.8133238705738706


 49%|████▉     | 2000/4068 [4:28:43<3:14:01,  5.63s/it]   

[2000][Validation] Loss: 1.7599775791168213
[2000][Validation] Accuracy@Top1: 0.7922514619883035
[2000][Validation] Accuracy@Top3: 0.9589181286549715
[2000][Validation] Mean Reciprocal Rank: 0.8778649169219344
[2000][Validation 2] Loss: 1.8498072624206543
[2000][Validation 2] Accuracy@Top1: 0.7203846153846153
[2000][Validation 2] Accuracy@Top3: 0.9221794871794871
[2000][Validation 2] Mean Reciprocal Rank: 0.8266691595441594


 55%|█████▌    | 2250/4068 [5:02:35<2:50:28,  5.63s/it]   

[2250][Validation] Loss: 1.7746481895446777
[2250][Validation] Accuracy@Top1: 0.7728070175438597
[2250][Validation] Accuracy@Top3: 0.9516081871345038
[2250][Validation] Mean Reciprocal Rank: 0.8650011603081781
[2250][Validation 2] Loss: 1.859097957611084
[2250][Validation 2] Accuracy@Top1: 0.7057692307692307
[2250][Validation 2] Accuracy@Top3: 0.9146153846153845
[2250][Validation 2] Mean Reciprocal Rank: 0.8169820919820917


 61%|██████▏   | 2500/4068 [5:36:12<2:26:18,  5.60s/it]  

[2500][Validation] Loss: 1.7911733388900757
[2500][Validation] Accuracy@Top1: 0.7492690058479526
[2500][Validation] Accuracy@Top3: 0.9409356725146203
[2500][Validation] Mean Reciprocal Rank: 0.8488746170983014
[2500][Validation 2] Loss: 1.8671339750289917
[2500][Validation 2] Accuracy@Top1: 0.6911538461538461
[2500][Validation 2] Accuracy@Top3: 0.9107692307692302
[2500][Validation 2] Mean Reciprocal Rank: 0.8094856532356531


 68%|██████▊   | 2750/4068 [6:09:43<2:03:15,  5.61s/it]  

[2750][Validation] Loss: 1.7849836349487305
[2750][Validation] Accuracy@Top1: 0.7567251461988301
[2750][Validation] Accuracy@Top3: 0.9442982456140355
[2750][Validation] Mean Reciprocal Rank: 0.8538307574491788
[2750][Validation 2] Loss: 1.8607451915740967
[2750][Validation 2] Accuracy@Top1: 0.6993589743589746
[2750][Validation 2] Accuracy@Top3: 0.9123076923076923
[2750][Validation 2] Mean Reciprocal Rank: 0.8139547212047213


 74%|███████▎  | 3000/4068 [6:43:21<1:39:47,  5.61s/it]  

[3000][Validation] Loss: 1.8049863576889038
[3000][Validation] Accuracy@Top1: 0.7299707602339186
[3000][Validation] Accuracy@Top3: 0.9318713450292401
[3000][Validation] Mean Reciprocal Rank: 0.8357177666388188
[3000][Validation 2] Loss: 1.8718235492706299
[3000][Validation 2] Accuracy@Top1: 0.6889743589743593
[3000][Validation 2] Accuracy@Top3: 0.911153846153846
[3000][Validation 2] Mean Reciprocal Rank: 0.8069339641839641


 80%|███████▉  | 3250/4068 [7:16:51<1:16:25,  5.61s/it]  

[3250][Validation] Loss: 1.794877290725708
[3250][Validation] Accuracy@Top1: 0.7418128654970758
[3250][Validation] Accuracy@Top3: 0.9343567251461993
[3250][Validation] Mean Reciprocal Rank: 0.8432189269469968
[3250][Validation 2] Loss: 1.8612120151519775
[3250][Validation 2] Accuracy@Top1: 0.6957692307692306
[3250][Validation 2] Accuracy@Top3: 0.9164102564102563
[3250][Validation 2] Mean Reciprocal Rank: 0.8123857855107856


 86%|████████▌ | 3500/4068 [7:50:17<53:02,  5.60s/it]    

[3500][Validation] Loss: 1.8153852224349976
[3500][Validation] Accuracy@Top1: 0.7175438596491233
[3500][Validation] Accuracy@Top3: 0.921052631578948
[3500][Validation] Mean Reciprocal Rank: 0.8266316253596951
[3500][Validation 2] Loss: 1.8746826648712158
[3500][Validation 2] Accuracy@Top1: 0.6778205128205129
[3500][Validation 2] Accuracy@Top3: 0.905
[3500][Validation 2] Mean Reciprocal Rank: 0.7997234940984939


 92%|█████████▏| 3750/4068 [8:23:52<29:45,  5.61s/it]    

[3750][Validation] Loss: 1.815322756767273
[3750][Validation] Accuracy@Top1: 0.7144736842105265
[3750][Validation] Accuracy@Top3: 0.9211988304093568
[3750][Validation] Mean Reciprocal Rank: 0.8248077369349299
[3750][Validation 2] Loss: 1.8746932744979858
[3750][Validation 2] Accuracy@Top1: 0.6808974358974363
[3750][Validation 2] Accuracy@Top3: 0.9071794871794867
[3750][Validation 2] Mean Reciprocal Rank: 0.8011983109483111


 98%|█████████▊| 4000/4068 [8:57:28<06:21,  5.61s/it]    

[4000][Validation] Loss: 1.8159812688827515
[4000][Validation] Accuracy@Top1: 0.7143274853801173
[4000][Validation] Accuracy@Top3: 0.9200292397660822
[4000][Validation] Mean Reciprocal Rank: 0.8244319131161233
[4000][Validation 2] Loss: 1.8752110004425049
[4000][Validation 2] Accuracy@Top1: 0.6802564102564103
[4000][Validation 2] Accuracy@Top3: 0.9071794871794869
[4000][Validation 2] Mean Reciprocal Rank: 0.8009365588115589


100%|██████████| 4068/4068 [9:13:51<00:00,  5.59s/it]   