In [None]:
import pandas as pd
import torch

from accelerate import Accelerator

from src.dataloaders.vec_dataset import VecDataset

In [None]:
from config.config import CONFIG

In [None]:
from src.models.embedder import Embedder
from transformers import AutoTokenizer, AutoModel

# Load the tokenizer and the pretrained model named in the config.
model_name = CONFIG['model']
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=CONFIG['tokenizer_use_fast'],
)
# Move the pretrained model onto the configured device.
embedder = AutoModel.from_pretrained(model_name)
embedder = embedder.to(CONFIG['device'])

In [None]:
# Read the train and dev splits from the CSV paths given in the config.
train_df, dev_df = (
    pd.read_csv(CONFIG[path_key])
    for path_key in ('train_path', 'dev_path')
)

In [None]:
from src.models.utils import query_embedding

def _embed_query(text):
    """Embed one query string with the shared embedder, tokenizer and config."""
    return query_embedding(text, embedder, tokenizer, CONFIG)


# Add a 'query_embed' column to both splits, one embedding per 'query' value.
for frame in (train_df, dev_df):
    frame['query_embed'] = frame['query'].apply(_embed_query)

In [None]:
def upsample_data(df, random_state=None):
    """Balance ``df`` across languages by upsampling every minority language.

    Rows are grouped by the ``lang`` column. Each language with fewer rows than
    the most frequent one gets extra rows drawn with replacement from its own
    rows until all languages have the same count. Rows whose ``lang`` is
    missing are not counted by ``value_counts`` and therefore do not appear in
    the result.

    Args:
        df: DataFrame with a ``lang`` column.
        random_state: Optional seed (or generator) forwarded to
            ``DataFrame.sample`` so the drawn rows are reproducible. The
            default ``None`` keeps the previous non-deterministic behaviour.

    Returns:
        A new DataFrame holding, for each language, its original rows followed
        by its resampled rows; languages are ordered by decreasing original
        count. Original index labels are kept, so resampled rows repeat
        labels. When there is no language to balance, an empty frame with the
        same columns is returned.
    """
    lang_counts = df['lang'].value_counts()
    if lang_counts.empty:
        # pd.concat raises ValueError on an empty list, so short-circuit here.
        return df.iloc[0:0].copy()

    max_count = lang_counts.max()
    balanced_parts = []
    # .items() pairs each language with its count without an ambiguous [] lookup.
    for lang, count in lang_counts.items():
        lang_df = df[df['lang'] == lang]
        # The majority language draws zero extra rows and is kept unchanged.
        extra_rows = lang_df.sample(
            n=int(max_count - count),
            replace=True,
            random_state=random_state,
        )
        balanced_parts.append(pd.concat([lang_df, extra_rows]))
    return pd.concat(balanced_parts)

In [None]:
# Rebalance the training split across languages; dev_df is left untouched.
train_df = upsample_data(train_df)

In [None]:
# Persist both splits (with their query embeddings) as CSV, without the index.
# NOTE(review): CSV stores each 'query_embed' value as text, so reading the file
# back yields strings rather than vectors — confirm the downstream loader parses them.
for frame, path_key in ((train_df, 'train_emb_path'), (dev_df, 'dev_emb_path')):
    frame.to_csv(CONFIG[path_key], index=False)

In [None]:
# Reload the splits (with their query embeddings) from the saved CSV files.
train_df, dev_df = (
    pd.read_csv(CONFIG[path_key])
    for path_key in ('train_emb_path', 'dev_emb_path')
)

In [None]:
import pickle

# Load the document embeddings from the pickle file named in the config.
# NOTE(review): pickle.load can execute arbitrary code — only point this at a trusted file.
with open(CONFIG['doc_embeds_path'], 'rb') as embeds_file:
    doc_embeds = pickle.load(embeds_file)

In [None]:
from src.dataloaders.utils import get_train_val_dataloaders

# Build the train/validation dataloaders (project helper) from the config, the two
# query frames and the loaded document embeddings.
train_dl, val_dl = get_train_val_dataloaders(CONFIG, train_df, dev_df, doc_embeds)

In [None]:
# Accelerator configured with the gradient-accumulation step count from the config.
accelerator = Accelerator(gradient_accumulation_steps=CONFIG['gradient_accumulation_steps'])

In [None]:
from src.training.trainner import Trainer
from src.models.dpr import DPRModel

# Size the DPR model from the pretrained embedder's hidden dimension.
embed_size = embedder.config.hidden_size
model = DPRModel(embed_size)
model = model.to(CONFIG['device'])
# The trainer receives the (train, validation) dataloader pair, the config and the accelerator.
trainer = Trainer(model, (train_dl, val_dl), CONFIG, accelerator)

In [None]:
# Run training; later cells read trainer.train_losses and trainer.val_losses.
trainer.train()

In [None]:
# One row per epoch: the 1-based epoch number with its train and validation loss.
# NOTE(review): assumes both loss sequences hold exactly CONFIG['epochs'] entries;
# a length mismatch makes the DataFrame constructor fail — confirm against Trainer.
epoch_numbers = list(range(1, CONFIG['epochs'] + 1))
loss_columns = {
    'epoch': epoch_numbers,
    'train_loss': trainer.train_losses,
    'val_loss': trainer.val_losses,
}
losses_df = pd.DataFrame(loss_columns)
losses_df.to_csv(CONFIG['losses_path'], index=False)

In [None]:
from matplotlib import pyplot as plt

# Draw the train curve first and the validation curve second, so the legend
# entries below line up with the curves in that order.
for loss_series, line_colour in (
    (trainer.train_losses, 'red'),
    (trainer.val_losses, 'orange'),
):
    plt.plot(loss_series, color=line_colour)
plt.legend(['Train', 'Validation'], loc='upper right')
plt.title('Loss')
plt.show()

In [None]:
from src.models.encoder import Encoder

# Rebuild the document encoder with the embedder's hidden size and restore its saved weights.
embed_size = embedder.config.hidden_size
doc_encoder = Encoder(embed_size).to(CONFIG['device'])
# map_location remaps the checkpoint tensors onto the configured device, so a
# checkpoint saved on a GPU still loads when CONFIG['device'] is the CPU (or another GPU).
encoder_state = torch.load(
    f"{CONFIG['load_path']}/doc_encoder.pth",
    map_location=CONFIG['device'],
)
doc_encoder.load_state_dict(encoder_state)

In [None]:
from src.dataloaders.vec_dataset import VecDataset

# Flatten doc_embeds into one (doc id, language, vector) record per stored embedding,
# so a document with several embeds contributes several records.
records = [
    (doc_key, entry['lang'], vec)
    for doc_key, entry in doc_embeds.items()
    for vec in entry['embeds']
]
docids = [doc_key for doc_key, _, _ in records]
langs = [doc_lang for _, doc_lang, _ in records]
vecs = [vec for _, _, vec in records]

vec_ds = VecDataset(docids, langs, vecs)

In [None]:
from tqdm import tqdm

# Batch the flattened vectors in their stored order (shuffle=False) so the ids,
# languages and encodings collected below stay aligned.
vec_dataloader = torch.utils.data.DataLoader(
    vec_ds,
    batch_size=256,
    collate_fn=vec_ds.collate_fn,
    shuffle=False,
    num_workers=4,
)

doc_encodes = []
doc_ids = []
doc_langs = []

# Switch to evaluation mode before encoding: a freshly built module starts in
# training mode, where any dropout / batch-norm layers would make the saved
# encodings non-deterministic.
doc_encoder.eval()
with torch.no_grad():
    for batch in tqdm(vec_dataloader, desc="Encoding documents"):
        doc_ids.extend(batch['doc_id'])
        doc_langs.extend(batch['lang'])
        batch_encodes = doc_encoder(batch['vec'].to(CONFIG['device']))
        # Extending with the array appends one entry per row of the batch.
        doc_encodes.extend(batch_encodes.cpu().numpy())
        

In [None]:
from tqdm import tqdm

# Group the per-chunk encodings back under their document id, keeping the
# language seen on the first chunk of each document.
encode_dict = {}

progress = tqdm(zip(doc_ids, doc_langs, doc_encodes), desc="Creating encode dictionary")
for doc_key, doc_lang, chunk in progress:
    entry = encode_dict.setdefault(doc_key, {'lang': doc_lang, 'encodes': []})
    entry['encodes'].append(chunk)


In [None]:
import pickle

# Serialise the per-document encodings to the output path named in the config.
with open(CONFIG['doc_encodes_path'], 'wb') as out_file:
    pickle.dump(encode_dict, out_file)