In [None]:
# This notebook adapts the scripts available at https://github.com/beir-cellar/beir/wiki/Examples-and-tutorials

In [None]:
from sentence_transformers import SentenceTransformer, models, losses, InputExample
from beir import util, LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.train import TrainRetriever
from torch.utils.data import Dataset
from tqdm.autonotebook import tqdm
import pathlib, os, gzip, json
import logging
import random
# Dataset identifier: passed to GenericDataLoader here, and reused further down to
# build the hard-negatives file path and the model output folder.
dataset = "BEIR_dataset_T5AndChatgptMisinfo"

# Load the "train" split. Three values come back; only the first two
# (corpus and queries) are used, the third is discarded into `_`.
data_loader = GenericDataLoader(dataset)
corpus, queries, _ = data_loader.load(split="train")

In [None]:
#################################
#### Parameters for Training ####
#################################
# Batch size handed to TrainRetriever when the retriever is created.
train_batch_size = 64
# Maximum sequence length configured on the transformer / SentenceTransformer model.
max_seq_length = 350
# NOTE(review): not referenced by any other visible cell -- confirm whether it is still needed.
ce_score_margin = 0.1
# Upper bound on the number of new hard negatives taken from each system's list, per query.
num_negs_per_system = 32

In [None]:
# Path of the mined hard-negatives file. It is read as JSON Lines: one object per
# line, accessed below through its 'qid', 'pos' and 'neg' keys.
triplets_filepath = "generated/" + dataset + "/hard-negatives.jsonl"

# qid -> {'query': query text, 'pos': positive passage ids, 'hard_neg': negative passage ids}
train_queries = {}
with open(triplets_filepath, 'rt', encoding='utf8') as fIn:
    # No `total=` is passed to tqdm: the number of lines is not known up front, and a
    # hard-coded count would display a wrong percentage/ETA for this file.
    for line in tqdm(fIn):
        # Tolerate blank lines (e.g. a trailing newline at the end of the file),
        # which json.loads would otherwise reject.
        if not line.strip():
            continue
        data = json.loads(line)

        # Get the positive passage ids
        pos_pids = data['pos']

        # Get the hard negatives. data['neg'] holds one list of passage ids per
        # system; ids are de-duplicated across systems through the shared set.
        neg_pids = set()
        for system_negs in data['neg'].values():
            negs_added = 0
            # The first five entries of each system's list are skipped.
            # NOTE(review): presumably to avoid top-ranked hits that may be
            # unlabeled positives -- confirm the intent of this offset.
            for pid in system_negs[5:]:
                if pid not in neg_pids:
                    neg_pids.add(pid)
                    negs_added += 1
                    # Stop once this system has contributed its quota of new ids
                    if negs_added >= num_negs_per_system:
                        break

        # Keep only queries that have at least one positive and one hard negative
        if len(pos_pids) > 0 and len(neg_pids) > 0:
            train_queries[data['qid']] = {'query': queries[data['qid']], 'pos': pos_pids, 'hard_neg': list(neg_pids)}

print("Train queries: {}".format(len(train_queries)))


In [None]:
# Sanity check: number of training queries kept, and number of hard negatives stored
# for query 'genQ0' (this lookup raises KeyError if 'genQ0' is not among the kept queries).
len(train_queries), len(train_queries['genQ0']['hard_neg'])

In [None]:
# Dataset that builds (query, positive, negative) triplets on demand from the
# per-query entries parsed out of the mined-hard-negatives jsonl file.
class customDataset(Dataset):
    """Map-style dataset yielding one InputExample triplet per query.

    Each query entry keeps a list of positive ids and a list of hard-negative
    ids. Every access consumes the id at the head of each list and moves it to
    the tail, so repeated accesses of the same query cycle through its ids.
    """

    def __init__(self, queries, corpus):
        """Store the inputs and prepare the per-query id lists.

        Args:
            queries: dict mapping a query id to a dict with the keys 'query',
                'pos' and 'hard_neg'. The entries are modified in place.
            corpus: dict mapping a passage id to a dict that has a "text" key.
        """
        self.queries = queries
        self.queries_ids = list(queries.keys())
        self.corpus = corpus

        # Turn both id collections into lists (they are rotated with pop/append
        # later on) and randomise the order of the hard negatives once.
        for entry in self.queries.values():
            entry['pos'] = list(entry['pos'])
            shuffled_negatives = list(entry['hard_neg'])
            random.shuffle(shuffled_negatives)
            entry['hard_neg'] = shuffled_negatives

    def __getitem__(self, item):
        """Return the triplet for the query at position `item`.

        The result is InputExample(texts=[query text, positive text, negative text]).
        """
        entry = self.queries[self.queries_ids[item]]
        texts = [entry['query']]

        # Same treatment for the positive and then the negative list: take the
        # head id, fetch its passage text, and re-queue the id at the tail.
        for field in ('pos', 'hard_neg'):
            id_queue = entry[field]
            current_id = id_queue.pop(0)
            texts.append(self.corpus[current_id]["text"])
            id_queue.append(current_id)

        return InputExample(texts=texts)

    def __len__(self):
        """Number of queries, i.e. number of triplets produced per pass."""
        return len(self.queries)

# Build the SentenceTransformer bi-encoder that is trained with MNRL loss below
device = "cuda:3"
use_pre_trained_model = False
model_name = "distilroberta-base"
if not use_pre_trained_model:
    # Assemble a new SBERT model from a transformer encoder plus a pooling layer
    # sized to the encoder's word-embedding dimension.
    print("Create new SBERT model")
    # Alternative base checkpoint tried previously: "distilbert-base-uncased"
    word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)
else:
    # Load `model_name` directly as a SentenceTransformer and cap its input length
    print("use pretrained SBERT model")
    model = SentenceTransformer(model_name, device=device)
    model.max_seq_length = max_seq_length

#### Provide a high batch-size to train better with triplets!
retriever = TrainRetriever(model=model, batch_size=train_batch_size)

# Training needs three pieces: a dataset, a dataloader built from it, and a loss.
train_dataset = customDataset(train_queries, corpus=corpus)
train_dataloader = retriever.prepare_train(train_dataset, shuffle=True, dataset_present=True)
train_loss = losses.MultipleNegativesRankingLoss(model=retriever.model)

#### No dev set is loaded in this notebook, so a dummy evaluator is used
ir_evaluator = retriever.load_dummy_evaluator()

#### Model save path: output/<model_name>-v3-<dataset>
# datetime is only needed by the timestamped variant of the path that is disabled below.
from datetime import datetime
model_save_path = os.path.join("output", "{}-v3-{}".format(model_name, dataset))
# model_save_path = os.path.join("output_v2", "{}-v3-{}".format(model_name, dataset), '-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
os.makedirs(model_save_path, exist_ok=True)


In [None]:
# Size of the training dataset (customDataset.__len__ returns the number of queries).
len(train_dataset)

In [None]:

#### Configure Train params
num_epochs = 2
evaluation_steps = 10000
warmup_steps = 1000

# Run training with the dataloader/loss pair prepared above; the model is written
# to `model_save_path`.
# NOTE(review): 'correct_bias' is only understood by optimizers that define that
# argument -- confirm it matches the optimizer class used by TrainRetriever.fit.
retriever.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=ir_evaluator,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    evaluation_steps=evaluation_steps,
    optimizer_params={'lr': 4e-5, 'eps': 1e-6, 'correct_bias': False},
    output_path=model_save_path,
    use_amp=True,
)