In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize
from sklearn.model_selection import train_test_split
import json
import pandas as pd
import os.path
from os.path import join, expanduser
from torch.utils.data import Dataset
from torch import nn
import torch.optim
from fit_mteb_pytorch import *

In [None]:
# Build the MiniMarco dataset (query and corpus embeddings), then hold out
# 20% of the query ids as a test split; random_state=1 keeps the split fixed
# across kernel restarts.
dset = MiniMarcoDataset()
query_ids_train, query_ids_test = train_test_split(
    dset.query_ids,
    test_size=0.2,
    random_state=1,
)

In [None]:
# Baselines on the held-out queries: rank the corpus with the raw TF-IDF
# embeddings and with the raw QA embeddings (no learned mapping involved).
embs_qa, embs_tfidf, labels = dset[query_ids_test]
baseline_runs = (
    ('TF-IDF', embs_tfidf, dset.embs_tfidf_corpus_df),
    ('QA', embs_qa, dset.embs_qa_corpus_df),
)
for baseline_name, query_embs, corpus_df in baseline_runs:
    mrr, mtop1 = evaluate_retrieval(
        query_embs, corpus_df.values,
        labels=labels, corpus_ids=dset.corpus_ids)
    print(f'{baseline_name} Test: {mrr=:.2f}, {mtop1=:.2f}')

In [None]:
# Train a LinearMapping on the training queries' QA embeddings. Each output is
# pulled towards the QA embedding of the query's first labelled corpus document
# (labels[i][0]). It is also pushed, more weakly, away from one negative
# document that is resampled every epoch.
N_EPOCHS = 100
LEARNING_RATE = 0.01
# Weight of the "push away from the negative" term. Keep it < 1: with weight
# >= 1 the per-element loss is no longer bounded below in the model output.
NEG_WEIGHT = 0.1
# Fixes model initialisation. Negative sampling goes through
# dset.get_random_neg_corpus_id, whose RNG is not controlled here (TODO confirm).
torch.manual_seed(1)

embs_qa, _, labels = dset[query_ids_train]
# One positive target row per training query, fetched in a single .loc call.
embs_qa_similar = dset.embs_qa_corpus_df.loc[[lab[0] for lab in labels]].values

# Keep every tensor on one device; fall back to CPU when no GPU is present.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
embs_qa = torch.tensor(embs_qa, dtype=torch.float, device=device)
embs_qa_similar = torch.tensor(embs_qa_similar, dtype=torch.float, device=device)
embs_qa_test, _, labels_test = dset[query_ids_test]
embs_qa_test = torch.tensor(embs_qa_test, dtype=torch.float, device=device)

# The corpus matrix is the same for every evaluation, so fetch it once.
corpus_qa = dset.embs_qa_corpus_df.values

model = LinearMapping(embs_qa.shape[1], embs_qa_similar.shape[1]).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)


def _report_mapping_retrieval(mapper, query_embs, split_labels, split_name):
    """Score `mapper(query_embs)` against the QA corpus and print the metrics.

    Switches `mapper` to eval mode and runs it without gradient tracking, so
    the result does not depend on any train-only layer behaviour.

    Returns:
        The (mrr, mtop1) pair produced by `evaluate_retrieval`.
    """
    mapper.eval()
    with torch.no_grad():
        mapped = mapper(query_embs).cpu().numpy()
    mrr, mtop1 = evaluate_retrieval(
        mapped, corpus_qa, labels=split_labels, corpus_ids=dset.corpus_ids)
    print(f'\tQA {split_name}: {mrr=:.3f}, {mtop1=:.3f}')
    return mrr, mtop1


for epoch in range(N_EPOCHS):
    # Metrics for the model as it stands before this epoch's update.
    _report_mapping_retrieval(model, embs_qa, labels, 'Train')
    mrr, mtop1 = _report_mapping_retrieval(model, embs_qa_test, labels_test, 'Test')

    # Sample one fresh negative corpus document per training query.
    neg_corpus_ids = [dset.get_random_neg_corpus_id(q) for q in query_ids_train]
    embs_qa_dissimilar = torch.tensor(
        dset.embs_qa_corpus_df.loc[neg_corpus_ids].values,
        dtype=torch.float, device=device)

    model.train()
    optimizer.zero_grad()
    output = model(embs_qa)
    loss = (criterion(output, embs_qa_similar)
            - NEG_WEIGHT * criterion(output, embs_qa_dissimilar))
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch} Loss {loss.item()}')

# The in-loop evaluation runs before each update, so the model produced by the
# final optimizer step can only be scored here.
print('After training:')
_report_mapping_retrieval(model, embs_qa, labels, 'Train')
mrr, mtop1 = _report_mapping_retrieval(model, embs_qa_test, labels_test, 'Test')