## This notebook enables the user to train their own model using our pipeline or load our already trained model.
## The notebook is divided into two parts.
### First part
Download the necessary data and scripts to run training
### Second part
Load in the pretrained model and run evaluation with some example queries we have used.

### Part 1

In [8]:
import sys
import json
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, LoggingHandler, util, models, evaluation, losses, InputExample
import logging
from datetime import datetime
import gzip
import os
import tarfile
import tqdm
from torch.utils.data import Dataset
import random
from shutil import copyfile
import pickle
import argparse
import torch, numpy
from preprocessing import preprocess_text

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout

# Used when not in jupyter shitbook

"""parser = argparse.ArgumentParser()
parser.add_argument("--train_batch_size", default=64, type=int)
parser.add_argument("--max_seq_length", default=300, type=int)
parser.add_argument("--model_name", required=True)
parser.add_argument("--max_passages", default=0, type=int)
parser.add_argument("--epochs", default=30, type=int)
parser.add_argument("--pooling", default="mean")
parser.add_argument("--negs_to_use", default=None, help="From which systems should negatives be used? Multiple systems seperated by comma. None = all")
parser.add_argument("--warmup_steps", default=500, type=int)
parser.add_argument("--lr", default=2e-5, type=float)
parser.add_argument("--num_negs_per_system", default=5, type=int)
parser.add_argument("--use_pre_trained_model", default=False, action="store_true")
parser.add_argument("--use_all_queries", default=False, action="store_true")
parser.add_argument("--eval_steps", default=500, type=int)
args = parser.parse_args()

logging.info(str(args))"""



# The  model we want to fine-tune
train_batch_size = 64 #args.train_batch_size          #Increasing the train batch size improves the model performance, but requires more GPU memory
model_name = 'bert-base-uncased' #args.model_name
max_passages = 0 #args.max_passages
max_seq_length = 300 #args.max_seq_length            #Max length for passages. Increasing it, requires more GPU memory

num_negs_per_system = 5 #args.num_negs_per_system  # We used different systems to mine hard negatives. Number of hard negatives to add from each system
num_epochs = 50 #args.epochs         # Number of epochs we want to train
use_pre_trained_model = False
pooling = "mean"
warmup_steps = 500 
lr = 2e-5

# Load our embedding model
if use_pre_trained_model:
    logging.info("use pretrained SBERT model")
    model = SentenceTransformer(model_name, device='cuda')
    model.max_seq_length = max_seq_length
else:
    logging.info("Create new SBERT model")
    word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling)
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='cuda')

model_save_path = f'output/train_bi-encoder-margin_mse-{model_name.replace("/", "-")}-batch_size_{train_batch_size}-{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'


# Write self to path
os.makedirs(model_save_path, exist_ok=True)

train_script_path = os.path.join(model_save_path, 'train_script.py')
#copyfile(__file__, train_script_path)
#with open(train_script_path, 'a') as fOut:
   # fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))



"""
DATA VALIDATION LOADING
"""
valid_data_folder = 'data/valid'

#### Read the corpus files, that contain all the passages. Store them in the corpus dict
val_corpus = {}         #dict in the format: passage_id -> passage. Stores all existent passages
val_corpus_path = os.path.join(valid_data_folder, 'valid_corpus.csv')

logging.info("Loading validation: validation corpus")
with open(val_corpus_path, 'r', encoding='utf8') as fIn:
    for line in fIn:
        qid, passage = line.strip().split(";")
        qid = int(qid)
        val_corpus[qid] = passage
        
#### Read the corpus files, that contain all the passages. Store them in the corpus dict
val_queries = {}         #dict in the format: passage_id -> passage. Stores all existent passages
val_queries_path =  os.path.join(valid_data_folder, 'valid_queries.csv')

logging.info("Loading validation: validation queries")
with open(val_queries_path, 'r', encoding='utf8') as fIn:
    for line in fIn:
        try:
            qid, passage, q = line.strip().split(";")
            qid = int(qid)
            val_queries[qid] = passage
        except:
            continue
      

val_keywords = {}
keywords_filepath =  os.path.join(valid_data_folder, 'valid_keywords.csv')
logging.info("Loading validation: validation keywords")
with open(keywords_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        row = line.strip().split(";")
        pid, qid, keywordss = row[0], row[1], row[2:]
        qid = int(qid)
        val_keywords[qid] = keywordss


"""
DATA TRAINING LOADING
"""

train_data_folder = 'data/train'

#### Read the corpus files, that contain all the passages. Store them in the corpus dict
train_corpus = {}         #dict in the format: passage_id -> passage. Stores all existent passages
collection_filepath = os.path.join(train_data_folder, 'train_corpus.csv')

logging.info("Loading trianing: training corpus")
with open(collection_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        pid, passage = line.strip().split(";")
        pid = int(pid)
        train_corpus[pid] = passage


### Read the train queries, store in queries dict
queries = {}        #dict in the format: query_id -> query. Stores all training queries
queries_filepath = os.path.join(train_data_folder, 'train_queries.csv')
logging.info("Loading training: training queries")
with open(queries_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        try:
            qid, passage, q = line.strip().split(";")
            qid = int(qid)
            queries[qid] = passage
        except:
            continue
       
        

train_keywords = {}
keywords_filepath = os.path.join(train_data_folder, 'train_keywords.csv')
logging.info("Loading training: keywords corpus")
with open(keywords_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        row = line.strip().split(";")
        pid, qid, keywordss = row[0], row[1], row[2:]
        qid = int(qid)
        train_keywords[qid] = keywordss

    
train_queries = {}
hard_negatives_filepath = os.path.join('data', 'hard_negs.jsonl.gz')
logging.info("Read hard negatives train file")

train_keys = queries.keys()
use_all_queries = False # These two are normally sat in argparser
negs_to_use = None
with gzip.open(hard_negatives_filepath, 'rt') as fIn:
    for line in tqdm.tqdm(fIn):
        if max_passages > 0 and len(train_queries) >= max_passages:
            break
        
        data = json.loads(line)
        
        if not data['qid'] in queries:
            continue
        
        #Get the positive passage ids
        pos_pids = data['pos']
        
        #Get the hard negatives
        neg_pids = set()
        if negs_to_use is None:
            if negs_to_use is not None:    #Use specific system for negatives
                negs_to_use = negs_to_use.split(",")
            else:   #Use all systems
                negs_to_use = list(data['neg'].keys())
            logging.info("Using negatives from the following systems:", negs_to_use)

        for system_name in negs_to_use:
            if system_name not in data['neg']:
                continue

            system_negs = data['neg'][system_name]
            negs_added = 0
            for pid in system_negs:
                if pid not in neg_pids:
                    neg_pids.add(pid)
                    negs_added += 1
                    if negs_added >= num_negs_per_system:
                        break
                    
        if use_all_queries or (len(pos_pids) > 0 and len(neg_pids) > 0):   
            train_queries[data['qid']] = {'qid': data['qid'], 'query': queries[data['qid']], 'pos': pos_pids, 'neg': neg_pids}
                

"""
SETUP DATASET
"""

logging.info("Train queries: {}".format(len(train_queries)))

# We create a custom MSMARCO dataset that returns triplets (query, positive, negative)
# on-the-fly based on the information from the mined-hard-negatives jsonl file.
class MSMARCODataset(Dataset):
    def __init__(self, queries, corpus, keywords):
        self.queries = queries
        self.queries_ids = list(queries.keys())
        self.keywords = keywords
        self.corpus = corpus
        
        for qid in self.queries:
            self.queries[qid]['pos'] = list(self.queries[qid]['pos'])
            self.queries[qid]['neg'] = list(self.queries[qid]['neg'])
            random.shuffle(self.queries[qid]['neg'])
            
    def __getitem__(self, item):
        query = self.queries[self.queries_ids[item]]
        query_text = query['query']
        qid = query['qid']
        keywords = ' '.join(self.keywords[qid])
        query_text = query_text + keywords.replace('"',' ')
        if len(query['pos']) > 0:
            pos_id = query['pos'].pop(0)    #Pop positive and add at end
            pos_text = self.corpus[pos_id]
            query['pos'].append(pos_id)
        else:   #We only have negatives, use two negs
            pos_id = query['neg'].pop(0)    #Pop negative and add at end
            pos_text = self.corpus[pos_id]
            query['neg'].append(pos_id)

        #Get a negative passage
        neg_id = query['neg'].pop(0)    #Pop negative and add at end
        neg_text = self.corpus[neg_id]
        query['neg'].append(neg_id)
       
        return InputExample(texts=[query_text, pos_text, neg_text], label=1)
    

    def __len__(self):
        return len(self.queries)


"""
SETUP VALIDATION
"""
valid_queries = {}
hard_negatives_filepath = os.path.join('data', 'valid_hard_negs.jsonl.gz')
logging.info("Read hard negatives valid file")
use_all_queries = False
negs_to_use = None
with gzip.open(hard_negatives_filepath, 'rt') as fIn:
    for line in tqdm.tqdm(fIn):
        if max_passages > 0 and len(valid_queries) >= max_passages:
            break
        
        data = json.loads(line)
        
        if not data['qid'] in val_queries:
            continue
        
        #Get the positive passage ids
        pos_pids = data['pos']
        
        #Get the hard negatives
        neg_pids = set()
        if negs_to_use is None:
            if negs_to_use is not None:    #Use specific system for negatives
                negs_to_use = negs_to_use.split(",")
            else:   #Use all systems
                negs_to_use = list(data['neg'].keys())
            logging.info("Using negatives from the following systems:", negs_to_use)

        for system_name in negs_to_use:
            if system_name not in data['neg']:
                continue

            system_negs = data['neg'][system_name]
            negs_added = 0
            for pid in system_negs:
                if pid not in neg_pids:
                    neg_pids.add(pid)
                    negs_added += 1
                    if negs_added >= num_negs_per_system:
                        break
                    
        if use_all_queries or (len(pos_pids) > 0 and len(neg_pids) > 0):
            valid_queries[data['qid']] = {'qid': data['qid'], 'query': val_queries[data['qid']], 'pos': pos_pids, 'neg': neg_pids}
                


class GenerateValidationTriplets():
    def __init__(self, validation_queries, validation_corpus, validation_keywords):
        self.queries = validation_queries
        self.queries_ids = list(self.queries.keys())
        self.keywords = validation_keywords
        self.corpus = validation_corpus
   
        
        for qid in self.queries:
            self.queries[qid]['pos'] = list(self.queries[qid]['pos'])
            self.queries[qid]['neg'] = list(self.queries[qid]['neg'])
            random.shuffle(self.queries[qid]['neg'])
            
    def getdata(self):
        val_anchor = []
        pos_sentence = []
        neg_sentence = []
        for item in range(len(self.queries)):
            query = self.queries[self.queries_ids[item]]
            query_text = query['query']
            qid = query['qid']
            keywords = ' '.join(self.keywords[qid])
            query_text = query_text + keywords.replace('"',' ')
            if len(query['pos']) > 0:
                pos_id = query['pos'].pop(0)    #Pop positive and add at end
                pos_text = self.corpus[pos_id]
                query['pos'].append(pos_id)
            else:   #We only have negatives, use two negs
                pos_id = query['neg'].pop(0)    #Pop negative and add at end
                pos_text = self.corpus[pos_id]
                query['neg'].append(pos_id)

            #Get a negative passage
            neg_id = query['neg'].pop(0)    #Pop negative and add at end
            neg_text = self.corpus[neg_id]
            query['neg'].append(neg_id)
            
            val_anchor.append(query_text)
            pos_sentence.append(pos_text)
            neg_sentence.append(neg_text)
        
        return val_anchor, pos_sentence, neg_sentence
    
    
val_anchor, pos_sentence, neg_sentence = GenerateValidationTriplets(valid_queries, val_corpus, val_keywords).getdata()
"""
TRAIN MODEL
"""
# For training the SentenceTransformer model, we need a dataset, a dataloader, and a loss used for training.
train_dataset = MSMARCODataset(queries=train_queries, corpus=train_corpus, keywords=train_keywords)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size, drop_last=True)
train_loss = losses.TripletLoss(model=model)

# Train the model
model.fit(train_objectives=[(train_dataloader, train_loss)],
          evaluator=evaluation.TripletEvaluator(anchors=val_anchor, positives=pos_sentence, negatives=neg_sentence),
          #evaluation_steps=args.eval_steps,
          epochs=num_epochs,
          warmup_steps=warmup_steps,
          use_amp=True,
          #checkpoint_path=model_save_path,
          #checkpoint_save_steps=500,
          optimizer_params = {'lr': lr},
          output_path='test/'
          )

# Train latest model
model.save(model_save_path)

2023-12-21 19:42:46 - Create new SBERT model
2023-12-21 19:42:47 - Loading validation: validation corpus
2023-12-21 19:42:47 - Loading validation: validation queries
2023-12-21 19:42:47 - Loading validation: validation keywords
2023-12-21 19:42:47 - Loading trianing: training corpus
2023-12-21 19:42:47 - Loading training: training queries
2023-12-21 19:42:47 - Loading training: keywords corpus
2023-12-21 19:42:47 - Read hard negatives train file


0it [00:00, ?it/s]--- Logging error ---
Traceback (most recent call last):
  File "c:\Users\hasse\Skrivebord\02456_DL_SBERT\venv\lib\site-packages\sentence_transformers\LoggingHandler.py", line 10, in emit
    msg = self.format(record)
  File "C:\Users\hasse\AppData\Local\Programs\Python\Python39\lib\logging\__init__.py", line 927, in format
    return fmt.format(record)
  File "C:\Users\hasse\AppData\Local\Programs\Python\Python39\lib\logging\__init__.py", line 663, in format
    record.message = record.getMessage()
  File "C:\Users\hasse\AppData\Local\Programs\Python\Python39\lib\logging\__init__.py", line 367, in getMessage
    msg = msg % self.args
TypeError: not all arguments converted during string formatting
Call stack:
  File "C:\Users\hasse\AppData\Local\Programs\Python\Python39\lib\runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "C:\Users\hasse\AppData\Local\Programs\Python\Python39\lib\runpy.py", line 87, in _run_code
    exe

2023-12-21 19:42:47 - Train queries: 19409
2023-12-21 19:42:47 - Read hard negatives valid file


0it [00:00, ?it/s]--- Logging error ---
Traceback (most recent call last):
  File "c:\Users\hasse\Skrivebord\02456_DL_SBERT\venv\lib\site-packages\sentence_transformers\LoggingHandler.py", line 10, in emit
    msg = self.format(record)
  File "C:\Users\hasse\AppData\Local\Programs\Python\Python39\lib\logging\__init__.py", line 927, in format
    return fmt.format(record)
  File "C:\Users\hasse\AppData\Local\Programs\Python\Python39\lib\logging\__init__.py", line 663, in format
    record.message = record.getMessage()
  File "C:\Users\hasse\AppData\Local\Programs\Python\Python39\lib\logging\__init__.py", line 367, in getMessage
    msg = msg % self.args
TypeError: not all arguments converted during string formatting
Call stack:
  File "C:\Users\hasse\AppData\Local\Programs\Python\Python39\lib\runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "C:\Users\hasse\AppData\Local\Programs\Python\Python39\lib\runpy.py", line 87, in _run_code
    exe

KeyboardInterrupt: 

### Part 2


Download the model using this url: https://dtudk-my.sharepoint.com/:f:/g/personal/hbype_dtu_dk/EgskgXe0X1BKnLEbKdJ7Iv4BFI7jySRs-VZNZn76340aGQ?e=FrWpfi

If you dont have access for some weird reason, write an email to hbype@space.dtu.dk

In [2]:
import os 
from sentence_transformers import SentenceTransformer
from postprocessing import embed
import torch
import numpy
from main import search_articles

  from .autonotebook import tqdm as notebook_tqdm


In [3]:

corpus_path = 'data/test/test_corpus.csv' ### You can change to train/train_corpus.csv or data/corpus.csv to try with the train set or the whole dataset - remember to encode!
model_path = 'model' ### Insert model 
corpus_embeddings_path = os.path.join(os.getcwd(), 'data/embeddings')
corpus_path = os.path.join(os.getcwd(), corpus_path)
model = SentenceTransformer(model_path)
corpus_ids = []
with open(corpus_path, 'r', encoding='utf8') as fIn:
    for line in fIn:
        pid, passage = line.strip().split(";")
        corpus_ids.append(pid)
        
if not os.path.exists(corpus_embeddings_path + '.npy'):
    embed(corpus_path, model, corpus_embeddings_path)

corpus_embeddings = torch.from_numpy(numpy.load('data/embeddings.npy'))


while True:
    user_query = input("Enter your research query (or 'exit' to quit): ")
    if user_query.lower() == 'exit':
        break
    search_articles(user_query, model, corpus_embeddings,corpus_ids)



Encoding sentences: 100%|██████████| 4164/4164 [01:10<00:00, 59.32it/s]



Top 5 most similar articles in the corpus:
1187
Title: Surfing Injuries: A Review for the Orthopaedic Surgeon 
 (Score: 0.2391) 
 url: https://findit.dtu.dk/en/catalog/606d9336d9001d01c01386f8 

824
Title: Clinical and Histologic Characterization of Co-infection with Astrovirus and Goose Parvovirus in Goslings 
 (Score: 0.2286) 
 url: https://findit.dtu.dk/en/catalog/5e01ffded9001d0457169e29 

2859
Title: Metagenomic analysis of oral microbiota among oral cancer patients and tobacco chewers in Rajasthan, India 
 (Score: 0.2171) 
 url: https://findit.dtu.dk/en/catalog/64ac9dc7bcb69d910bcfd479 

1208
Title: Eumycetoma, A Neglected Tropical Disease in the United States 
 (Score: 0.2122) 
 url: https://findit.dtu.dk/en/catalog/623477dfbc81d1c155980319 

4060
Title: Head injuries in professional football (soccer): Results of video analysis verified by an accident insurance registry 
 (Score: 0.2091) 
 url: https://findit.dtu.dk/en/catalog/61165320d9001d01b32ff15d 

