In [None]:
import os

import random

from pathlib import Path

from tqdm import tqdm

import numpy as np

from beir import util
from beir.datasets.data_loader import GenericDataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from peft import get_peft_model, LoraConfig, TaskType

from transformers import AutoModel, AutoTokenizer

from sklearn.feature_extraction.text import TfidfVectorizer

In [None]:
# Project layout: the BEIR archive is unpacked below ``inputs/beir_data`` and
# model checkpoints are later written directly into the project directory.
project_directory = os.path.join("results", "triplet-loss-fine-tuning")
input_filepath = os.path.join(project_directory, "inputs", "beir_data")
output_filepath = project_directory

Path(output_filepath).mkdir(parents=True, exist_ok=True)


# Download and unpack the BEIR dataset archive.
dataset = "fiqa"
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
output_directory = util.download_and_unzip(url, input_filepath)

# Each split yields (corpus, queries, qrels). A separate loader instance is
# used per split, exactly as before.
train_corpus, train_queries, train_qrels = GenericDataLoader(output_directory).load(split="train")
# Fix: the held-out data has to come from the "test" split. It was previously
# loaded with split="train", so evaluation ran on the training queries.
test_corpus, test_queries, test_qrels = GenericDataLoader(output_directory).load(split="test")

In [None]:
class TripletTextDataset(Dataset):
    """Dataset over (anchor, positive, negative) text triplets, tokenized on access.

    Args:
        triplets: Indexable collection of 3-tuples ``(anchor, positive, negative)``.
        tokenizer: Callable applied to the three texts of one triplet at once.
        max_length: Length every text is padded / truncated to.
    """

    def __init__(self, triplets, tokenizer, max_length=128):
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.triplets = triplets  # list of (anchor, positive, negative)

    def __len__(self):
        """Number of triplets."""
        return len(self.triplets)

    def __getitem__(self, idx):
        """Tokenize the triplet at ``idx`` and return the tokenizer fields as a plain dict."""
        query_text, positive_text, negative_text = self.triplets[idx]
        tokenizer_options = {
            "padding": "max_length",
            "truncation": True,
            "max_length": self.max_length,
            "return_tensors": "pt",
        }
        # The three texts are encoded together, so each returned field holds
        # one row per triplet member in the order anchor, positive, negative.
        encoded = self.tokenizer(
            [query_text, positive_text, negative_text],
            **tokenizer_options,
        )
        return dict(encoded.items())


class TextDataset(Dataset):
    """Dataset over single texts, tokenized on access.

    Args:
        texts: Indexable collection of strings.
        tokenizer: Callable applied to one text at a time.
        max_length: Length every text is padded / truncated to.
    """

    def __init__(self, texts, tokenizer, max_length=128):
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.texts = texts

    def __len__(self):
        """Number of texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize the text at ``idx`` and return the tokenizer output unchanged."""
        tokenizer_options = {
            "padding": "max_length",
            "truncation": True,
            "max_length": self.max_length,
            "return_tensors": "pt",
        }
        return self.tokenizer(self.texts[idx], **tokenizer_options)


def average_precision(pred, true):
    """Average precision of one ranked prediction list.

    Args:
        pred: Ranked sequence of predicted ids, best first.
        true: Sized collection of relevant ids.

    Returns:
        Sum of precision@rank over the ranks at which a relevant id is first
        retrieved, divided by ``len(true)``; ``0.0`` when ``true`` is empty.
    """
    if not true:
        return 0.0
    relevant_ids = set(true)
    credited_ids = set()
    hits = 0
    score = 0.0
    for rank, pred_id in enumerate(pred, start=1):
        # A relevant id is credited only the first time it appears. A repeated
        # prediction still occupies its rank but is not a new hit; otherwise
        # duplicates in ``pred`` could push the score above 1.0.
        if pred_id in relevant_ids and pred_id not in credited_ids:
            credited_ids.add(pred_id)
            hits += 1
            score += hits / rank
    return score / len(true)


def mean_average_precision(predictions, trues):
    """Mean of ``average_precision`` over paired prediction / ground-truth lists.

    Args:
        predictions: One ranked id list per query.
        trues: One collection of relevant ids per query, aligned with ``predictions``.

    Returns:
        The arithmetic mean of the per-query average precision, or ``0.0`` when
        there are no (prediction, truth) pairs to score.
    """
    scores = [average_precision(pred, true) for pred, true in zip(predictions, trues)]
    # Guard against ZeroDivisionError on empty input.
    if not scores:
        return 0.0
    return sum(scores) / len(scores)


def mean_reciprocal_rank(predictions, trues):
    """Mean reciprocal rank of the first relevant id in each prediction list.

    Args:
        predictions: One ranked id list per query, best first.
        trues: One collection of relevant ids per query, aligned with ``predictions``.

    Returns:
        The sum of ``1 / rank`` of the first relevant prediction per query
        (contributing 0 when none is relevant), divided by ``len(predictions)``;
        ``0.0`` when ``predictions`` is empty.
    """
    # Guard against ZeroDivisionError on empty input.
    if len(predictions) == 0:
        return 0.0
    score = 0.0
    for pred, true in zip(predictions, trues):
        for rank, pred_id in enumerate(pred, start=1):
            if pred_id in true:
                score += 1 / rank
                break  # only the first relevant hit counts
    return score / len(predictions)

In [None]:
# Seed every random number generator used by this notebook (query sampling,
# triplet sampling, DataLoader shuffling) so that a fresh
# "Restart Kernel -> Run All" reproduces the same run.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Training hyperparameters.
batch_size = 16
accumulation_steps = 4
# NOTE(review): presumably scaled by accumulation_steps so the number of
# optimizer steps matches 5 epochs without accumulation -- TODO confirm.
num_epochs = 5 * accumulation_steps
margin = 0.3
max_length = 128

# Prefer CUDA, then Apple MPS, then fall back to CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

torch.cuda.empty_cache()

model_name = "microsoft/deberta-v3-base"

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

# The full model is fine-tuned; the LoRA variant below is kept disabled.
model = AutoModel.from_pretrained(model_name)
# base_model = AutoModel.from_pretrained(model_name)
# lora_config = LoraConfig(
#     r=8,
#     lora_alpha=16,
#     target_modules=["query_proj", "value_proj"],
#     lora_dropout=0.1,
#     bias="none",
#     task_type=TaskType.FEATURE_EXTRACTION
# )
# model = get_peft_model(base_model, lora_config)
# model.print_trainable_parameters()
model = model.to(device)

In [None]:
# Sample a fixed fraction of the held-out queries for per-epoch evaluation.
eval_fraction = 0.1
# sorted() gives random.sample a deterministic population order.
eval_query_ids = random.sample(sorted(test_queries), int(len(test_queries) * eval_fraction))
eval_queries = {query_id: test_queries[query_id] for query_id in eval_query_ids}
# Per sampled query: {document_id: "<title> <text>"} for every document listed in its qrels.
eval_documents = {
    query_id: {
        document_id: (test_corpus[document_id]["title"] + " " + test_corpus[document_id]["text"]).strip()
        for document_id in test_qrels[query_id].keys()
    }
    for query_id in eval_query_ids
}

# Flatten into parallel lists:
#   eval_query_lst[i] / eval_trues[i]            -> i-th query text / its relevant document ids
#   eval_document_lst[j] / eval_document_id_lst[j] -> j-th candidate document text / its id
eval_query_lst = []
eval_trues = []
eval_document_lst = []
eval_document_id_lst = []
seen_document_ids = set()
for query_id in eval_query_ids:
    eval_trues.append(list(test_qrels[query_id].keys()))
    eval_query_lst.append(eval_queries[query_id])
    for document_id, document in eval_documents[query_id].items():
        # A document relevant to several sampled queries must enter the
        # candidate pool only once; otherwise it would be embedded repeatedly
        # and could occupy several slots of a single query's top-k ranking.
        if document_id in seen_document_ids:
            continue
        seen_document_ids.add(document_id)
        eval_document_id_lst.append(document_id)
        eval_document_lst.append(document)

# shuffle=False keeps embedding order aligned with the id lists above.
eval_query_dataset = TextDataset(eval_query_lst, tokenizer, max_length=max_length)
eval_query_dataloader = DataLoader(eval_query_dataset, batch_size=batch_size, shuffle=False)
eval_document_dataset = TextDataset(eval_document_lst, tokenizer, max_length=max_length)
eval_document_dataloader = DataLoader(eval_document_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# premining_query_lst = list(train_queries.values())
# premining_query_id_lst = list(train_queries.keys())
# premining_document_lst = []
# premining_document_id_lst = []
# for document_id, document in train_corpus.items():
#     premining_document_id_lst.append(document_id)
#     premining_document_lst.append((document["title"] + " " + document["text"]).strip())

# vectorizer = TfidfVectorizer()

# vectorizer.fit(premining_document_lst + premining_query_lst)
# corpus_tf_idf = vectorizer.transform(premining_document_lst)
# query_tf_idf = vectorizer.transform(premining_query_lst)

# premining_cosine_similarities = query_tf_idf @ corpus_tf_idf.T

###############################################################################################

#     triplets = []
#     progress = tqdm(enumerate(premining_cosine_similarities), total=len(premining_query_lst))
#     progress.set_description("Create dataset")
#     for query_idx, cosine_similarity in progress:
#         query_text = premining_query_lst[query_idx]
#         relevant_docs = list(train_qrels[premining_query_id_lst[query_idx]].keys())
#         if not relevant_docs:
#             continue
#         pos_id = random.choice(relevant_docs)
#         pos_text = (train_corpus[pos_id]["title"] + " " + train_corpus[pos_id]["text"]).strip()
    
#         ranked_document_idxs = np.argsort(cosine_similarity.toarray()).squeeze()[::-1]
#         trues_ids = train_qrels[premining_query_id_lst[query_idx]].keys()
#         neg_text = None
#         for document_idx in ranked_document_idxs:
#             if premining_document_id_lst[document_idx] not in trues_ids:
#                 neg_text = premining_document_lst[document_idx]
#                 break
#         if not neg_text:
#             continue
    
#         triplets.append((query_text, pos_text, neg_text))

In [None]:
def mean_pooling(output, attention_mask, normalized=True):
    """Average the token embeddings of each sequence, ignoring masked positions.

    Args:
        output: Model output exposing ``last_hidden_state`` ([batch, seq_len, hidden]).
        attention_mask: Mask that is broadcast over the hidden dimension;
            positions with value 0 do not contribute to the average.
        normalized: When true, L2-normalize each pooled vector.

    Returns:
        One pooled vector per sequence ([batch, hidden]).
    """
    hidden_states = output.last_hidden_state  # [batch, seq_len, hidden]
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    masked_total = (hidden_states * mask).sum(dim=1)
    # The clamp avoids division by zero for a fully masked sequence.
    token_counts = mask.sum(dim=1).clamp(min=1e-9)
    pooled = masked_total / token_counts
    return F.normalize(pooled, p=2, dim=1) if normalized else pooled


def cosine_distance(x, y):
    """Cosine distance between two batches of vectors: ``1 - cosine_similarity``."""
    return 1.0 - F.cosine_similarity(x, y)


# Triplet margin loss measured in cosine distance.
criterion = nn.TripletMarginWithDistanceLoss(
    distance_function=cosine_distance,
    margin=margin,
)

# NOTE(review): lr=1e-3 looks high for updating every parameter of a
# pretrained transformer -- confirm this value is intended.
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

def sample_negative_id(corpus_ids, relevant_ids, max_attempts=100):
    """Return a uniformly random id from ``corpus_ids`` that is not in ``relevant_ids``.

    Rejection sampling avoids rebuilding a filtered copy of the whole corpus
    for every query. If no candidate is accepted within ``max_attempts`` draws,
    fall back to explicit filtering; like the original code, this raises
    ``IndexError`` when every corpus document is relevant.
    """
    for _ in range(max_attempts):
        candidate_id = random.choice(corpus_ids)
        if candidate_id not in relevant_ids:
            return candidate_id
    return random.choice([doc_id for doc_id in corpus_ids if doc_id not in relevant_ids])


def compute_embeddings(text_dataloader, description):
    """Embed every batch of ``text_dataloader`` with ``model`` and return one NumPy array.

    Rows follow the dataloader order. The caller is responsible for putting
    the model into eval mode beforehand.
    """
    batch_embeddings = []
    progress = tqdm(text_dataloader)
    progress.set_description(description)
    with torch.no_grad():
        for batch in progress:
            # squeeze(dim=1) drops the per-item singleton dimension
            # (presumably added by return_tensors="pt" in TextDataset).
            input_ids = batch["input_ids"].squeeze(dim=1).to(device)
            attention_mask = batch["attention_mask"].squeeze(dim=1).to(device)
            output = model(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = mean_pooling(output, attention_mask, normalized=True)
            batch_embeddings.append(embeddings.detach().cpu().numpy())
    return np.concatenate(batch_embeddings)


# Loop-invariant values hoisted out of the epoch loop.
train_corpus_ids = list(train_corpus)
k = 5  # number of top-ranked documents scored per query

for epoch in range(num_epochs):
    model.train()

    # Create dataset: one freshly sampled (query, positive, negative) triplet per query
    triplets = []
    progress = tqdm(train_queries.items())
    progress.set_description("Create dataset")
    for query_id, query_text in progress:
        relevant_ids = train_qrels[query_id]
        if not relevant_ids:
            continue
        pos_id = random.choice(list(relevant_ids))
        pos_text = (train_corpus[pos_id]["title"] + " " + train_corpus[pos_id]["text"]).strip()

        # Sample a random document not relevant to this query
        neg_id = sample_negative_id(train_corpus_ids, relevant_ids)
        neg_text = (train_corpus[neg_id]["title"] + " " + train_corpus[neg_id]["text"]).strip()

        triplets.append((query_text, pos_text, neg_text))

    dataset = TripletTextDataset(triplets, tokenizer, max_length=max_length)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Training loop
    running_loss = 0
    num_batches = len(dataloader)
    # Start every epoch from clean gradients.
    optimizer.zero_grad()
    progress = tqdm(dataloader)
    for step, batch in enumerate(progress, start=1):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        # Tokens and attention masks; dim 1 indexes the triplet member
        # (0 = anchor, 1 = positive, 2 = negative)
        anchor_ids, pos_ids, neg_ids = input_ids[:, 0, :], input_ids[:, 1, :], input_ids[:, 2, :]
        anchor_mask, pos_mask, neg_mask = attention_mask[:, 0, :], attention_mask[:, 1, :], attention_mask[:, 2, :]
        # Anchor
        anchor_output = model(input_ids=anchor_ids, attention_mask=anchor_mask)
        anchor_embedding = mean_pooling(anchor_output, anchor_mask, normalized=True)
        # Positive
        pos_output = model(input_ids=pos_ids, attention_mask=pos_mask)
        pos_embed = mean_pooling(pos_output, pos_mask, normalized=True)
        # Negative
        neg_output = model(input_ids=neg_ids, attention_mask=neg_mask)
        neg_embed = mean_pooling(neg_output, neg_mask, normalized=True)
        # Triplet loss; only the backward pass uses the accumulation-scaled value
        loss = criterion(anchor_embedding, pos_embed, neg_embed)
        (loss / accumulation_steps).backward()
        # Gradient accumulation: step after every full group of
        # accumulation_steps batches, and also on the last batch so a trailing
        # partial group is applied instead of leaking into the next epoch.
        if step % accumulation_steps == 0 or step == num_batches:
            optimizer.step()
            optimizer.zero_grad()
        # Record the unscaled loss so the progress bar and the epoch average
        # are on the same scale.
        batch_loss = loss.item()
        running_loss += batch_loss
        progress.set_description(f"Loss: {batch_loss:.4f}")

    ## Evaluate model
    print(f"Epoch Loss: {running_loss/len(dataloader):.4f}")
    model.eval()

    # Queries
    query_embeddings = compute_embeddings(eval_query_dataloader, "Compute eval query embeddings")

    # Documents
    document_embeddings = compute_embeddings(eval_document_dataloader, "Compute eval document embeddings")

    # Cosine similarities (embeddings are L2-normalized, so the dot product is the cosine)
    cosine_similarities = query_embeddings @ document_embeddings.T
    prediction_ids = []
    for cosine_similarity in cosine_similarities:
        # argsort is ascending: take the last k indices and reverse for best-first order
        top_k = np.argsort(cosine_similarity).squeeze()[-k:][::-1]
        prediction_ids.append(top_k.tolist())
    eval_predictions = [
        [eval_document_id_lst[document_idx] for document_idx in query_prediction_ids]
        for query_prediction_ids in prediction_ids
    ]

    # Score model
    map_score = mean_average_precision(eval_predictions, eval_trues)
    mrr_score = mean_reciprocal_rank(eval_predictions, eval_trues)
    print(f"Mean average precision: {map_score}")
    print(f"Mean reciprocal rank: {mrr_score}")

    # Save checkpoint
    model_path = os.path.join(output_filepath, f"deberta_fiqa_fine_tuned_epoch_{epoch}.pt")
    torch.save(model.state_dict(), model_path)