<a href="https://colab.research.google.com/github/eduseiti/ia368v_dd_class_09/blob/main/DL_reranking_fine_tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TREC-COVID DL reranking fine tuning

This notebook fine-tunes a reranking model using the LLM-generated TREC-COVID queries.

## Prepare the environment

In [1]:
import os
import sys

import tqdm

In [2]:
IN_COLAB='google.colab' in sys.modules
LINK_WITH_COMET=True

In [3]:
# Environment bootstrap. On Colab: mount Google Drive, switch the working
# directory to the course folder and install the extra packages. Elsewhere:
# only define the local folder constants and the Anserini classpath.
if IN_COLAB:
    from google.colab import drive

    WORKING_FOLDER="/content/drive/MyDrive/unicamp/ia368v_dd/aula_09"

    drive.mount('/content/drive', force_remount=True)

    # Every data file below is referenced by a path relative to this folder.
    os.chdir(WORKING_FOLDER)
    
    # NOTE(review): installs are unpinned, so the package versions can differ between runs.
    !pip install transformers -q

    if LINK_WITH_COMET:
        !pip install comet_ml -q
else:
    # NOTE(review): unlike the Colab branch, this branch never calls
    # os.chdir(WORKING_FOLDER), so the relative data paths resolve against the
    # kernel's current directory -- confirm the notebook is started from WORKING_FOLDER.
    WORKING_FOLDER="/mnt/0060f889-4c27-409b-b0de-47f5427515e3/unicamp/ia368v_dd/ia368v_dd_class_09/"
    PYSERINI_FOLDER="/mnt/0060f889-4c27-409b-b0de-47f5427515e3/unicamp/ia368v_dd/pyserini/"
    
    TREC_EVAL_FULLPATH=PYSERINI_FOLDER+"tools/eval/trec_eval.9.0.4/trec_eval"
    
    os.environ["ANSERINI_CLASSPATH"]="/media/eduseiti/bigdata01/unicamp/ia368v_dd/anserini/target"

Mounted at /content/drive
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m96.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m28.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m110.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m506.5/506.5 kB[0m [31m36.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m201.7/201.7 kB[0m [31m29.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.1/510.1 kB[0m [31m52.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m137.9/137.9 kB[0m [31m21.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.3/54.3 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[2K     

In [4]:
import pandas as pd
import pickle
import numpy as np

import json

import time

import re

from datetime import datetime

from scipy import stats

if LINK_WITH_COMET:
    from comet_ml import Experiment

In [5]:
TREC_COVID_MERGED_FILE="trec_covid_merged_data.tsv"
TREC_COVID_DOCUMENTS_FILE="trec_covid_original_title_text_merged.tsv"

TREC_COVID_QUERIES="trec_covid_queries.tsv"
TREC_COVID_QRELS="trec_covid_qrels.tsv"

API_KEYS_FILE="../api_keys_20230324.json"

pd.set_option('display.max_colwidth', None)
pd.set_option('display.max_rows', 200)

In [6]:
TREC_COVID_ORIGINAL_FOLDER="trec_covid_original"
TREC_COVID_ORIGINAL_INDEX_FOLDER="trec_covid_original/index"
TREC_COVID_ORIGINAL_RUNS_FOLDER="trec_covid_original/runs"

In [7]:
TREC_COVID_LLM_0100_QUERIES="eduseiti_100_queries_expansion_20230501_01.jsonl"
TREC_COVID_LLM_1000_QUERIES="eduseiti_1000_queries_expansion_20230502_02.jsonl"

In [8]:
import torch

from transformers import get_linear_schedule_with_warmup, get_constant_schedule
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils import data
from transformers import BatchEncoding

## Set the random seed

In [9]:
RANDOM_SEED = 6

# Seed PyTorch as well, so that PyTorch-side randomness during fine-tuning
# (e.g. dropout) is repeatable and not only the NumPy-driven negative sampling.
torch.manual_seed(RANDOM_SEED)

# NumPy generator handed to the datasets below to draw the negative documents.
rng = np.random.default_rng(RANDOM_SEED)

### Link with COMET

In [10]:
# Start a Comet experiment when tracking is enabled. The API key is read from
# a JSON file kept outside the working folder rather than written in the notebook.
if LINK_WITH_COMET:
    with open(API_KEYS_FILE) as inputFile:
        api_keys = json.load(inputFile)

    # NOTE(review): "reraking" looks like a typo for "reranking", but it names the
    # Comet project the run is logged to, so renaming it would target a different project.
    experiment = Experiment(api_key=api_keys['comet_ml'], 
                            project_name="InPars reraking",
                            workspace="eduseiti")

[1;38;5;39mCOMET INFO:[0m Couldn't find a Git repository in '/content/drive/MyDrive/unicamp/ia368v_dd/aula_09' nor in any parent directory. Set `COMET_GIT_DIRECTORY` if your Git Repository is elsewhere.
[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/eduseiti/inpars-reraking/a26c235ed0254c2eb71e1876151d549d



### Initialize reranking model parameters

In [11]:
MODEL_NAME='microsoft/MiniLM-L12-H384-uncased'
MS_MARCO_PRETRAINED_MODEL="pretrain_20230315_180741"

MAX_TOKENS_LENGTH=512

In [12]:
TRAIN_OUTPUT_FOLDER="trained_models"

In [13]:
TREC_COVID_TOKENIZED_LLM_EXPANSION="trec_covid_tokenized_expansion_{}.pkl"

In [14]:
PYSERINI_TEST_RUN_RERANKED_FILENAME_FORMAT="run.trec_covid_reranking_{}_{}_{}.txt"

In [15]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [17]:
device

device(type='cuda')

## Prepare the fine tuning dataset

### Load the LLM generated questions

In [18]:
# Read the JSONL file named by TREC_COVID_LLM_0100_QUERIES: one JSON record per line.
expanded_queries = []

with open(TREC_COVID_LLM_0100_QUERIES, "r") as jsonl_file:
    expanded_queries.extend(json.loads(record) for record in jsonl_file)

print(len(expanded_queries))

463


In [19]:
expanded_queries_100_df = pd.DataFrame(expanded_queries)

In [20]:
expanded_queries_100_df

Unnamed: 0,query,positive_doc_id,negative_doc_ids
0,How can chatbots be designed to effectively share up-to-date information during a pandemic?,70hskj1o,"[mt00852w, x7ol32mz, b54dymlu, h5vh6px7, bza9agzo, eumithb4, zygepxd7, i44s4vqr, 6u1zo0f9, m3a6kl8s, z9r5i0ky, c8s0jn2z, 4fb5xnil, nt7c31ft, f1ch030o, nm30wct0, qcgc2bo3, azkamnpa, enit4rki, e2g1iu39]"
1,What strategies can be used to encourage desired health impacting behaviors through chatbots?,70hskj1o,"[et84j0qi, xsfolppr, 5t2o287y, kj2tnw8q, j68x0yd1, 1lobeca0, qwdjb7vk, ue5v55l8, 3eovj63c, 411qyubx, 4qvv1hsq, j8iawzp8, 7bh268mb, gp6gz0bw, 6gc7smqf, ur9t45yc, rgeeld8q, qqsiv6r6, m7cqlzbh, hlymyzcq]"
2,What are the risks associated with amplifying misinformation with chatbot technology?,70hskj1o,"[2c1m04je, rd93y7hu, vlmvi0tf, dbq3z982, 848fswtv, uveezi7s, pat3t7ne, aimm65cr, c45feko6, gl6ozx2o, t63ni1qn, rc65rv6r, 27kfciro, pmuo5qpf, t7tjvpxv, ak97kgj5, e0nxkyhc, rh0x9gxf, idhr2upe, u75hks4k]"
3,What research has been conducted on the effectiveness of chatbots during pandemics?,70hskj1o,"[49zlztqu, amjqr9hr, hpx4723v, e790rxq9, 95bsoea2, k41xro7c, ysa8vb9x, fkv395t5, u4di2tk7, 2swzr52p, oo0z5pb2, lkzo4y8b, i6vfr6um, 8fhpsn4n, au2je08j, mi0pmyo4, fdkbuw6e, 74joo4yr, 6lrawta5, l864lrhx]"
4,"How can collaborations between healthcare workers, companies, academics and governments help prepare for future pandemics?",70hskj1o,"[eg2lj9zc, prmf9yob, ara8bsws, zjmshwl3, apvc5mml, ridgctn4, 4dv6954b, 1k168vv0, dc6jtcz0, lt67jwyv, 7ftq02ev, hbalyfy3, v9wynk5x, 22ioujwl, 6o50m9si, ti75rrwk, 3dswdn6p, jo7ty7v4, tolikanw, t1hwh3o8]"
...,...,...,...
458,What are some common symptoms of MERS-CoV infection?,24lzevco,"[jv3425w1, yz7goivp, heui8rox, optngtwu, n9pqd30o, hkvurb2k, s95ryhiu, fnrir5nh, 4hk736ev, fxck2ain, sphk023v, 345fmq8h, futlnw88, p8luczyk, fa6kbjif, zbqfs77n, 70jg65o2, jvplobgy, 0ejs05e8, m00gcci4]"
459,Are there any existing effective anti-MERS-CoV antiviral agents or therapeutics?,24lzevco,"[rigxrvhn, gxyk9fgj, 1mpov118, 08d5cdf4, eld5svt2, mg1eg740, 0wh7x410, inibtytf, rq5gh710, cmor0wkp, mpv025c6, 0lzapk68, 3mh63vjj, qg73804g, cjimyfu4, v7asfaxc, wicc796j, z0h32jyu, klx95l1j, 3fp46sov]"
460,What are some potential Host-Directed Therapies for MERS-CoV infected patients?,24lzevco,"[wtvjjc7p, mha7zs08, uwwih8v3, 1intktsf, ysbnv3fb, s8fitxwd, n3f5pihh, 4u0appfs, qexn0nuy, do9r8q84, ipmyfxk5, dkreswvk, eisfz30c, k50qvr4w, vzyrcmu4, bwnvfs8l, 6dzo97ze, n9k0ctn6, lntg6yb8, 1mowsbjy]"
461,Could Host-Directed Therapies improve treatment outcomes for patients with MERS-CoV infection?,24lzevco,"[g4oku7wp, xnjpe1ss, uqykia6i, o9uk0y2n, j5zlismf, gxi3iwb0, g2phfpbd, 5ftql1b9, k3rqx1x0, hi3fjne4, bm0ldeue, uwwih8v3, c51eyqpi, elmrvpxd, yorqoyn9, wfcyaumm, bnh65bqg, 73xuhvll, pnp8flc3, nmdko4nl]"


In [21]:
# Read the JSONL file named by TREC_COVID_LLM_1000_QUERIES: one JSON record per line.
# The list is reset first because the same name held the previous file's records.
expanded_queries = []

with open(TREC_COVID_LLM_1000_QUERIES, "r") as jsonl_file:
    expanded_queries.extend(json.loads(record) for record in jsonl_file)

print(len(expanded_queries))

4914


In [22]:
expanded_queries_1000_df = pd.DataFrame(expanded_queries)

In [23]:
expanded_queries_1000_df

Unnamed: 0,query,positive_doc_id,negative_doc_ids
0,What are the benefits of international cooperation in the field of medicine for mass gatherings?,pxniqk3i,"[nku844kt, xilenqax, cgcuzxdt, ing711rk, dgvhzxlu, fzthd8c0, zvdblh8r, jh34klbg, 79l7wsc0, mvh4ig2g, za6x4reh, do111e5s, 4igvc039, oqmt78e8, gerhoy8w, obph3gup, rd6cqdsf, cakplpzq, fsly6gph, fcuwzpfy]"
1,How has progress in mass gathering medicine been achieved through international cooperation?,pxniqk3i,"[esggkw4u, y3fhubnc, 60u7muhc, jzyn6swh, j00jj2og, 1e4dzy64, qmcbsqse, uq80ybb0, 3n890yha, 41nebtwm, 4uzf5uue, j1h7b44j, 7lyccvm5, gqjcvxln, 6jgao58w, v3vxhqi4, r71awxo0, qhay7qao, ljjry7gt, hq8dg87u]"
2,What challenges have been overcome with the help of international cooperation in the field of mass gatherings medicine?,pxniqk3i,"[271xb1ls, s5b9q5n0, w933g6mm, efkjyd1k, 8ynk9k8r, vltscf9v, q2k2krnm, bljjcfd6, dpsgeko0, hltiva9i, tzgg27ge, s2q6g91b, 22eneg79, qh63b6sy, u592mbw6, blzjj06v, fk7k84n7, sdtiyrab, j2sq9kof, vb4eyjyu]"
3,What have been the most successful initiatives in international cooperation for mass gatherings medicine?,pxniqk3i,"[45k6tp0b, aqgauumt, iec4mvh7, e8yxnh27, 99b09xpi, l346qqwp, z6vddid9, fsojj7m5, vb4eyjyu, 2jd7aa2d, ptfhjlav, mgn5x4f5, j2sq9kof, 8ldkbl2g, aso63o3p, ncr8i4z3, fk7k84n7, 74txzvou, 6qz7bqnb, m5ptobvz]"
4,How can international cooperation be further improved to promote progress in mass gatherings medicine?,pxniqk3i,"[emnln2ix, pux58dut, aza0pzud, zaz1e2cf, ax3i3d7f, rxzsty8t, 4bh701xz, 1v4frt83, dfb3512i, lvm5ej0c, 7xxcl2ug, hnrlem31, v9y7xpnj, 294aqph9, 74qmy02h, td5xxbsl, 2z32ln7g, jxwyyv9o, trtikw3s, t35n7bk9]"
...,...,...,...
4909,What is the role of GBF1 in poliovirus replication?,lxb2otpr,"[h2cm3cge, d441jam3, 3kmqy07w, tg0tczni, j6zirpyz, dlmx12vt, 9v0z2chz, ue15waq7, kjyuxc3g, 7oat5jiw, 2ec3arfc, 4s0keus5, 4zphlpks, 0u62j7nj, ngxy58tg, 11m9uz3n, jomutdxf, frkb8iso, 00a19z5i, 3q9yr6np]"
4910,How does the presence of BFA affect polio replication?,lxb2otpr,"[0u62j7nj, 8jnrc01n, k9dl79kc, jc4ckqy7, c3g3hrm2, u8otoxn4, ahrrphm8, gzdtdmx6, ync7cwq2, g9ykr72y, m22h669g, jjocpkh9, w1q46y63, w4emoao6, xyo2l9u4, m3dt5kuy, nm8gdeyc, rebll0rk, vori28fm, j6i1hpbj]"
4911,Does the N-terminus domain of GBF1 have any role in Arf GTPase activation?,lxb2otpr,"[3jxi0dj2, bpnqz1cn, pujlr8uf, ywanv2la, u9u0y5ju, 3tfqtdbg, wrkl39sf, rnkawxrq, 3ury4hnv, 2s3x6sj8, firlsrg9, lfmtmqju, oshi71yw, d2algx16, e2qq4ngq, 0bnfugdm, 5qnayf72, gxe3kwtu, x04w3l3f, 7nh3m8l8]"
4912,How does the absence of p115 and Rab1b influence poliovirus replication?,lxb2otpr,"[dcmz4wcn, 413m7czk, qnntyqud, 8ut2p9p6, 5gp8wg9w, btbv4rq0, r01ii1bu, qsg18lbc, 73sxbpng, 1n127fkf, rsycde29, j2bb6eq8, ejtvvbck, 6vj79mqh, 8detpwvu, ghqtxrvt, 5hm0jr7u, 788zjhkn, 1fp4ck88, ixnafpv7]"


### Load the TREC COVID documents

In [24]:
# Load the corpus from a header-less, tab-separated file with two columns:
# the document id and the document text (title and body merged into one string).
trec_covid_docs_df = pd.read_csv(TREC_COVID_DOCUMENTS_FILE, sep='\t', header=None, names=['corpus-id', 'corpus-title-text'])

# Preview the first rows and report the corpus size.
display(trec_covid_docs_df.head())

print(trec_covid_docs_df.shape)

Unnamed: 0,corpus-id,corpus-title-text
0,ug7v899j,"Clinical features of culture-proven Mycoplasma pneumoniae infections at King Abdulaziz University Hospital, Jeddah, Saudi ArabiaOBJECTIVE: This retrospective chart review describes the epidemiology and clinical features of 40 patients with culture-proven Mycoplasma pneumoniae infections at King Abdulaziz University Hospital, Jeddah, Saudi Arabia. METHODS: Patients with positive M. pneumoniae cultures from respiratory specimens from January 1997 through December 1998 were identified through the Microbiology records. Charts of patients were reviewed. RESULTS: 40 patients were identified, 33 (82.5%) of whom required admission. Most infections (92.5%) were community-acquired. The infection affected all age groups but was most common in infants (32.5%) and pre-school children (22.5%). It occurred year-round but was most common in the fall (35%) and spring (30%). More than three-quarters of patients (77.5%) had comorbidities. Twenty-four isolates (60%) were associated with pneumonia, 14 (35%) with upper respiratory tract infections, and 2 (5%) with bronchiolitis. Cough (82.5%), fever (75%), and malaise (58.8%) were the most common symptoms, and crepitations (60%), and wheezes (40%) were the most common signs. Most patients with pneumonia had crepitations (79.2%) but only 25% had bronchial breathing. Immunocompromised patients were more likely than non-immunocompromised patients to present with pneumonia (8/9 versus 16/31, P = 0.05). Of the 24 patients with pneumonia, 14 (58.3%) had uneventful recovery, 4 (16.7%) recovered following some complications, 3 (12.5%) died because of M pneumoniae infection, and 3 (12.5%) died due to underlying comorbidities. The 3 patients who died of M pneumoniae pneumonia had other comorbidities. CONCLUSION: our results were similar to published data except for the finding that infections were more common in infants and preschool children and that the mortality rate of pneumonia in patients with comorbidities was high."
1,02tnwd4m,"Nitric oxide: a pro-inflammatory mediator in lung disease?Inflammatory diseases of the respiratory tract are commonly associated with elevated production of nitric oxide (NO•) and increased indices of NO• -dependent oxidative stress. Although NO• is known to have anti-microbial, anti-inflammatory and anti-oxidant properties, various lines of evidence support the contribution of NO• to lung injury in several disease models. On the basis of biochemical evidence, it is often presumed that such NO• -dependent oxidations are due to the formation of the oxidant peroxynitrite, although alternative mechanisms involving the phagocyte-derived heme proteins myeloperoxidase and eosinophil peroxidase might be operative during conditions of inflammation. Because of the overwhelming literature on NO• generation and activities in the respiratory tract, it would be beyond the scope of this commentary to review this area comprehensively. Instead, it focuses on recent evidence and concepts of the presumed contribution of NO• to inflammatory diseases of the lung."
2,ejv2xln0,"Surfactant protein-D and pulmonary host defenseSurfactant protein-D (SP-D) participates in the innate response to inhaled microorganisms and organic antigens, and contributes to immune and inflammatory regulation within the lung. SP-D is synthesized and secreted by alveolar and bronchiolar epithelial cells, but is also expressed by epithelial cells lining various exocrine ducts and the mucosa of the gastrointestinal and genitourinary tracts. SP-D, a collagenous calcium-dependent lectin (or collectin), binds to surface glycoconjugates expressed by a wide variety of microorganisms, and to oligosaccharides associated with the surface of various complex organic antigens. SP-D also specifically interacts with glycoconjugates and other molecules expressed on the surface of macrophages, neutrophils, and lymphocytes. In addition, SP-D binds to specific surfactant-associated lipids and can influence the organization of lipid mixtures containing phosphatidylinositol in vitro. Consistent with these diverse in vitro activities is the observation that SP-D-deficient transgenic mice show abnormal accumulations of surfactant lipids, and respond abnormally to challenge with respiratory viruses and bacterial lipopolysaccharides. The phenotype of macrophages isolated from the lungs of SP-D-deficient mice is altered, and there is circumstantial evidence that abnormal oxidant metabolism and/or increased metalloproteinase expression contributes to the development of emphysema. The expression of SP-D is increased in response to many forms of lung injury, and deficient accumulation of appropriately oligomerized SP-D might contribute to the pathogenesis of a variety of human lung diseases."
3,2b73a28n,"Role of endothelin-1 in lung diseaseEndothelin-1 (ET-1) is a 21 amino acid peptide with diverse biological activity that has been implicated in numerous diseases. ET-1 is a potent mitogen regulator of smooth muscle tone, and inflammatory mediator that may play a key role in diseases of the airways, pulmonary circulation, and inflammatory lung diseases, both acute and chronic. This review will focus on the biology of ET-1 and its role in lung disease."
4,9785vg6d,"Gene expression in epithelial cells in response to pneumovirus infectionRespiratory syncytial virus (RSV) and pneumonia virus of mice (PVM) are viruses of the family Paramyxoviridae, subfamily pneumovirus, which cause clinically important respiratory infections in humans and rodents, respectively. The respiratory epithelial target cells respond to viral infection with specific alterations in gene expression, including production of chemoattractant cytokines, adhesion molecules, elements that are related to the apoptosis response, and others that remain incompletely understood. Here we review our current understanding of these mucosal responses and discuss several genomic approaches, including differential display reverse transcription-polymerase chain reaction (PCR) and gene array strategies, that will permit us to unravel the nature of these responses in a more complete and systematic manner."


(171325, 2)


### Tokenization functions

In [25]:
def select_trec_covid_documents(expanded_queries_df, trec_covid_docs_df):
    """Return the corpus rows referenced by the generated queries.

    Args:
        expanded_queries_df: DataFrame with a 'positive_doc_id' column (one id
            per row) and a 'negative_doc_ids' column (a list of ids per row).
        trec_covid_docs_df: corpus DataFrame with 'corpus-id' and
            'corpus-title-text' columns.

    Returns:
        DataFrame with the columns ['corpus-id', 'corpus-title-text'], limited
        to the documents used as a positive or a negative by at least one
        query. Rows keep the corpus order and get a fresh 0..n-1 index.
    """

    # Gather every referenced id and de-duplicate once. A np.union1d call per
    # row re-sorts the growing id set each time, which is needlessly quadratic.
    doc_id_groups = [expanded_queries_df['positive_doc_id'].to_numpy()]
    doc_id_groups.extend(np.asarray(negatives) for negatives in expanded_queries_df['negative_doc_ids'])

    selected_docs = np.unique(np.concatenate(doc_id_groups, axis=None))

    print("selected_docs.shape: {}".format(selected_docs.shape))

    trec_covid_selected_docs_df = trec_covid_docs_df.merge(pd.DataFrame(selected_docs, columns=['doc-id']),
                                                           left_on='corpus-id',
                                                           right_on='doc-id', how='inner')[['corpus-id', 'corpus-title-text']]

    print("trec_covid_selected_docs_df.shape: {}".format(trec_covid_selected_docs_df.shape))

    # The inner merge silently drops ids that are absent from the corpus; report
    # them here, because any later lookup of such an id would fail.
    missing_docs = selected_docs.shape[0] - trec_covid_selected_docs_df['corpus-id'].nunique()

    if missing_docs > 0:
        print("WARNING: {} selected document ids were not found in the corpus".format(missing_docs))

    return trec_covid_selected_docs_df

In [26]:
def tokenize_queries_and_selected_docs(expanded_queries_df, trec_covid_selected_docs_df, tokenizer, tokenized_data_filename):
    """Tokenize the queries and the selected documents, then cache the result.

    Queries are truncated to MAX_TOKENS_LENGTH tokens. Documents are truncated
    to the room the longest query leaves inside MAX_TOKENS_LENGTH. Length
    statistics are printed for both, and the tokens plus the documents
    DataFrame are pickled to `tokenized_data_filename`.

    Returns:
        Tuple (query tokens, document tokens) as returned by `tokenizer`.
    """

    query_texts = expanded_queries_df['query'].tolist()
    trec_queries_tokens = tokenizer(query_texts, truncation=True, max_length=MAX_TOKENS_LENGTH, return_length=True)

    print(stats.describe(trec_queries_tokens['length']))

    # Budget left for a document once the longest query has been accounted for.
    document_max_length = MAX_TOKENS_LENGTH - np.max(trec_queries_tokens['length'])
    document_texts = trec_covid_selected_docs_df['corpus-title-text'].tolist()

    # return_overflowing_tokens is deliberately not requested here.
    trec_docs_tokens = tokenizer(document_texts, truncation=True, max_length=document_max_length, return_length=True)

    print(stats.describe(trec_docs_tokens['length']))

    # Overflow report: only meaningful when the tokenizer output carries the
    # 'overflow_to_sample_mapping' key.
    if 'overflow_to_sample_mapping' in trec_docs_tokens:
        extra_texts = len(trec_docs_tokens['overflow_to_sample_mapping']) - trec_covid_selected_docs_df.shape[0]

        if extra_texts > 0:
            print("Added {} overflowing texts...".format(extra_texts))

    # Persist everything needed to rebuild the datasets without re-tokenizing.
    cache_payload = {'trec_queries_tokens': trec_queries_tokens,
                     'trec_docs_tokens': trec_docs_tokens,
                     'trec_covid_selected_docs_df': trec_covid_selected_docs_df}

    with open(tokenized_data_filename, "wb") as cache_file:
        pickle.dump(cache_payload, cache_file, pickle.HIGHEST_PROTOCOL)

    return trec_queries_tokens, trec_docs_tokens

### Check whether the data has already been tokenized

In [27]:
# Reuse the cached validation tokens when the pickle file already exists.
# NOTE: pickle.load can execute arbitrary code; only load cache files produced by this notebook.
eval_tokens_cache_path = TREC_COVID_TOKENIZED_LLM_EXPANSION.format(TREC_COVID_LLM_0100_QUERIES)

if not os.path.exists(eval_tokens_cache_path):
    tokenized_validation_data_read = False

    print("Need to create the tokenized LLM queries validation data...")
else:
    with open(eval_tokens_cache_path, "rb") as cache_file:
        tokenized_data = pickle.load(cache_file)

    eval_trec_queries_tokens = tokenized_data['trec_queries_tokens']
    eval_trec_docs_tokens = tokenized_data['trec_docs_tokens']
    eval_trec_covid_selected_docs_df = tokenized_data['trec_covid_selected_docs_df']

    tokenized_validation_data_read = True

In [28]:
# Reuse the cached training tokens when the pickle file already exists.
# NOTE: pickle.load can execute arbitrary code; only load cache files produced by this notebook.
train_tokens_cache_path = TREC_COVID_TOKENIZED_LLM_EXPANSION.format(TREC_COVID_LLM_1000_QUERIES)

if not os.path.exists(train_tokens_cache_path):
    tokenized_train_data_read = False

    print("Need to create the tokenized LLM queries train data...")
else:
    with open(train_tokens_cache_path, "rb") as cache_file:
        tokenized_data = pickle.load(cache_file)

    train_trec_queries_tokens = tokenized_data['trec_queries_tokens']
    train_trec_docs_tokens = tokenized_data['trec_docs_tokens']
    train_trec_covid_selected_docs_df = tokenized_data['trec_covid_selected_docs_df']

    tokenized_train_data_read = True

### Build the validation and training data to be tokenized, if needed

In [29]:
# Only select and tokenize the validation documents when no cached pickle was
# loaded; tokenize_queries_and_selected_docs also writes the cache file.
if not tokenized_validation_data_read:
    eval_trec_covid_selected_docs_df = select_trec_covid_documents(expanded_queries_100_df, trec_covid_docs_df)

    eval_trec_queries_tokens, eval_trec_docs_tokens = tokenize_queries_and_selected_docs(expanded_queries_100_df, 
                                                                                         eval_trec_covid_selected_docs_df, 
                                                                                         tokenizer, 
                                                                                         TREC_COVID_TOKENIZED_LLM_EXPANSION.format(TREC_COVID_LLM_0100_QUERIES))
else:
    print("Validation data already tokenized...")

Validation data already tokenized...


In [30]:
# Only select and tokenize the training documents when no cached pickle was
# loaded; tokenize_queries_and_selected_docs also writes the cache file.
if not tokenized_train_data_read:
    train_trec_covid_selected_docs_df = select_trec_covid_documents(expanded_queries_1000_df, trec_covid_docs_df)

    train_trec_queries_tokens, train_trec_docs_tokens = tokenize_queries_and_selected_docs(expanded_queries_1000_df, 
                                                                                           train_trec_covid_selected_docs_df, 
                                                                                           tokenizer, 
                                                                                           TREC_COVID_TOKENIZED_LLM_EXPANSION.format(TREC_COVID_LLM_1000_QUERIES))
else:
    print("Train data already tokenized...")

Train data already tokenized...


### Build the training dataset

In [31]:
class InParsTrainingDataset(data.Dataset):
    """Query/document pairs for fine-tuning a binary relevance classifier.

    Every generated query contributes two consecutive samples: the query
    joined with its positive document (label 1) and the query joined with one
    negative document drawn at random from the query's negatives (label 0).

    Args:
        generated_queries_df: DataFrame with 'positive_doc_id' and
            'negative_doc_ids' columns, in the same row order as
            `tokenized_queries`.
        selected_docs_df: DataFrame with a 'corpus-id' column whose index
            labels address the entries of `tokenized_documents`.
        tokenized_queries: mapping with 'input_ids', 'token_type_ids' and
            'attention_mask' lists, one entry per query.
        tokenized_documents: same layout, one entry per selected document.
        rng: random generator exposing `choice`, used to draw the negatives.
        max_pairs: optional cap on the number of queries (positive/negative
            pairs) used; None uses every query.
        verbose: print the per-query selection details while building.
    """

    def __init__(self, generated_queries_df, selected_docs_df, tokenized_queries, tokenized_documents, rng, max_pairs=None, verbose=True):
        self.generated_queries_df = generated_queries_df
        self.selected_docs_df = selected_docs_df
        self.tokenized_queries = tokenized_queries
        self.tokenized_documents = tokenized_documents

        self.rng = rng

        self.max_pairs = max_pairs

        self.rebuild_dataset(verbose)


    def _build_doc_index_lookup(self):
        """Map each corpus id to the index label of its first row in selected_docs_df."""
        corpus_ids = self.selected_docs_df['corpus-id']
        doc_index_by_id = {}

        for index_label, doc_id in zip(corpus_ids.index, corpus_ids):
            # setdefault keeps the first occurrence, like `.index[0]` on a filtered frame.
            doc_index_by_id.setdefault(doc_id, index_label)

        return doc_index_by_id


    def _append_sample(self, query_position, doc_index, label):
        """Store one query+document sample.

        The document's leading token (presumably [CLS] -- confirm against the
        tokenizer) is dropped so the joined sequence starts with only the
        query's own leading token.
        """
        for field, target in (('input_ids', self.test_input_ids),
                              ('token_type_ids', self.test_token_type_ids),
                              ('attention_mask', self.test_attention_mask)):
            target.append(self.tokenized_queries[field][query_position] + self.tokenized_documents[field][doc_index][1:])

        self.labels.append(label)


    def rebuild_dataset(self, verbose=True):
        """(Re)build every sample, drawing a fresh random negative for each query.

        Raises:
            IndexError: when a positive or selected negative id is absent from
                `selected_docs_df`.
        """

        self.test_input_ids = []
        self.test_token_type_ids = []
        self.test_attention_mask = []
        self.labels =[]

        # One pass over the documents replaces two full DataFrame scans per query.
        doc_index_by_id = self._build_doc_index_lookup()

        def find_doc_index(doc_id):
            try:
                return doc_index_by_id[doc_id]
            except KeyError:
                raise IndexError("document id {!r} not found in selected_docs_df".format(doc_id)) from None

        # Queries are addressed by position, which is how the tokenized lists are ordered.
        for position, (row_label, row) in enumerate(self.generated_queries_df.iterrows()):

            # Checked before the pair is built so that exactly max_pairs pairs are produced.
            if (self.max_pairs is not None) and (position >= self.max_pairs):
                break

            if verbose:
                print(row_label)

            selected_negative = self.rng.choice(row['negative_doc_ids'], 1)[0]

            if verbose:
                print("negative_doc_ids: {}".format(row['negative_doc_ids']))
                print("selected_negative: {}".format(selected_negative))

            positive_doc_index = find_doc_index(row['positive_doc_id'])
            negative_doc_index = find_doc_index(selected_negative)

            if verbose:
                print("positive_doc_index={}".format(positive_doc_index))
                print("negative_doc_index={}".format(negative_doc_index))

            self._append_sample(position, positive_doc_index, True)
            self._append_sample(position, negative_doc_index, False)


    def __len__(self):
        return len(self.test_input_ids)


    def __getitem__(self, idx):
        # test_token_type_ids is collected but not part of the returned sample.
        return {'input_ids': self.test_input_ids[idx],
                'attention_mask': self.test_attention_mask[idx],
                'labels': int(self.labels[idx])}

In [32]:
def collate_fn(batch):
    """Pad the samples of a batch with the notebook-level tokenizer and return them as a BatchEncoding of PyTorch tensors."""
    return BatchEncoding(tokenizer.pad(batch, return_tensors='pt'))

### Create the dataset and the dataloader

In [33]:
# Training set built from the larger query file. It shares `rng` with the
# validation set, so the order in which the two are built affects which negatives are drawn.
train_dataset = InParsTrainingDataset(expanded_queries_1000_df,
                                      train_trec_covid_selected_docs_df,
                                      train_trec_queries_tokens,
                                      train_trec_docs_tokens,
                                      rng,
                                      verbose=False)

In [34]:
# Validation set built from the smaller query file. Its negatives are drawn once
# here; only the training set is rebuilt between epochs.
# max_pairs=500 is above the 463 records loaded from that file, so nothing is cut.
eval_dataset = InParsTrainingDataset(expanded_queries_100_df,
                                     eval_trec_covid_selected_docs_df,
                                     eval_trec_queries_tokens,
                                     eval_trec_docs_tokens,
                                     rng,
                                     max_pairs=500,
                                     verbose=False)

In [35]:
def evaluate(model, dataloader, set_name, min_valid_loss, current_training_step, current_epoch):
    """Evaluate `model` on `dataloader` and checkpoint it when the loss improves.

    Puts the model in eval mode, runs every batch without gradients, prints
    the mean batch loss and the accuracy, logs the mean loss to Comet when
    LINK_WITH_COMET is set, and saves the model when the mean loss is below
    min_valid_loss['loss'].

    Args:
        model: sequence classification model returning `.loss` and `.logits`.
        dataloader: yields batches that include a 'labels' entry.
        set_name: initial description of the progress bar.
        min_valid_loss: dict with 'loss' and 'checkpoint_name'; updated in
            place when a new checkpoint is saved.
        current_training_step: step under which the metric is logged.
        current_epoch: epoch number embedded in the checkpoint name.
    """
    batch_losses = []
    hits = 0

    model.eval()

    with torch.no_grad():

        progress = tqdm.tqdm(dataloader, mininterval=0.5, desc=set_name, disable=False, bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}', colour='GREEN', file=sys.stdout, position=0, leave=True)

        for batch in progress:
            outputs = model(**batch.to(device))

            batch_losses.append(outputs.loss.cpu().item())
            hits += (outputs.logits.argmax(dim=1) == batch['labels']).sum().item()

            progress.set_description("Loss {:0.4f}".format(batch_losses[-1]))

    # Unweighted mean over the batches; accuracy is taken over the whole dataset.
    mean_loss = np.mean(batch_losses)
    accuracy = hits / len(dataloader.dataset)

    print("Eval loss: {:0.4f}; accuracy: {}".format(mean_loss, accuracy))

    if LINK_WITH_COMET:
        experiment.log_metrics({'eval loss': mean_loss},
                                step=current_training_step)

    # Nothing else to do unless this is strictly the best validation loss so far.
    if not (min_valid_loss['loss'] > mean_loss):
        return

    print("New minimal validation loss; saving model...")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    epoch_tag = "{:02d}_epoch".format(current_epoch)

    checkpoint_name = "checkpoint_{}_{}_{}_eval_{:.4f}".format(TREC_COVID_LLM_1000_QUERIES, epoch_tag, timestamp, mean_loss)
    model.save_pretrained(os.path.join(TRAIN_OUTPUT_FOLDER, checkpoint_name))

    min_valid_loss['checkpoint_name'] = checkpoint_name
    min_valid_loss['loss'] = mean_loss

In [36]:
# Fine-tuning settings; the whole dict is also sent to Comet as the run parameters.
hyperparameters = {
    'batch_size': 16,
    'epochs': 50,
    # Not consumed by the constant schedule created below; kept for the run log.
    'num_warmup_steps': 0,
    # NOTE(review): 1e-8 is an unusually small step size -- confirm it is intentional.
    'learning_rate': 1e-8,
}

In [37]:
# The training DataLoader keeps the final partial batch (drop_last is left at
# its default), so the number of steps per epoch is the ceiling, not the floor,
# of samples / batch size.
steps_per_epoch = -(-len(train_dataset) // hyperparameters['batch_size'])

hyperparameters['num_training_steps'] = hyperparameters['epochs'] * steps_per_epoch

In [38]:
# Batches are padded on the fly by collate_fn.
# NOTE(review): shuffle=False keeps the sample order fixed, so each batch pairs the
# same queries every epoch (only the negatives are redrawn) -- confirm this is intended.
train_dataloader = data.DataLoader(train_dataset, batch_size=hyperparameters['batch_size'], shuffle=False, collate_fn=collate_fn)

In [39]:
# Validation batches use the same batch size and padding collate function as training.
eval_dataloader = data.DataLoader(eval_dataset, batch_size=hyperparameters['batch_size'], shuffle=False, collate_fn=collate_fn)

In [40]:
# Resume from a checkpoint saved by an earlier fine-tuning run. The commented
# alternatives start from the MS MARCO pre-trained folder or from the base model.
model = AutoModelForSequenceClassification.from_pretrained(os.path.join(TRAIN_OUTPUT_FOLDER, "checkpoint_eduseiti_1000_queries_expansion_20230502_02.jsonl_08_epoch_20230504_000728_0.2839")).to(device)
# model = AutoModelForSequenceClassification.from_pretrained(os.path.join(TRAIN_OUTPUT_FOLDER, MS_MARCO_PRETRAINED_MODEL)).to(device)
# model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).to(device)

print('Parameters', model.num_parameters())

Parameters 33360770


In [41]:
# AdamW over all model parameters with a constant learning rate
# (get_linear_schedule_with_warmup is imported but not used here).
optimizer = torch.optim.AdamW(model.parameters(), lr=hyperparameters['learning_rate'])
scheduler = get_constant_schedule(optimizer)

In [42]:
# Best-so-far trackers for the validation and the training loss. Infinity is
# used as the starting value so the first measured loss always counts as an
# improvement, whatever its magnitude.
min_valid_loss = {"loss": float("inf"),
                  "checkpoint_name": None}

min_training_loss = {"loss": float("inf"),
                     "checkpoint_name": None}

In [43]:
if LINK_WITH_COMET:
    experiment.log_parameters(hyperparameters)

In [44]:
# Counts optimizer steps across all epochs; used as the step of the per-batch metrics.
current_training_step = 0

for epoch in tqdm.tqdm(range(hyperparameters['epochs']), desc='Epochs', bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}', colour='GREEN', file=sys.stdout):
    # evaluate() switches the model to eval mode, so training mode is restored every epoch.
    model.train()
    train_losses = []

    tqdm_batches = tqdm.tqdm(train_dataloader, mininterval=0.5, desc='Train', disable=False, bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}', colour='GREEN', file=sys.stdout, position=0, leave=True)

    for batch in tqdm_batches:
        optimizer.zero_grad()
        outputs = model(**batch.to(device))
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        scheduler.step()
        train_losses.append(loss.cpu().item())

        tqdm_batches.set_description("Loss {:0.4f}".format(train_losses[-1]))

        if LINK_WITH_COMET:
            # get_last_lr() returns a list with one value per parameter group;
            # the optimizer has a single group, so its scalar value is logged.
            experiment.log_metrics({'train loss': train_losses[-1],
                                    'learning_rate': scheduler.get_last_lr()[0]},
                                    step=current_training_step)

        current_training_step += 1

    # Mean of the batch losses of this epoch, computed once and reused below.
    epoch_train_loss = np.mean(train_losses)

    print("Epoch: {}, Training loss: {:0.4f}".format(epoch, epoch_train_loss))

    # Checkpoint whenever the epoch's mean training loss is the lowest so far.
    if min_training_loss['loss'] > epoch_train_loss:
        print("New minimal training loss; saving model...")

        training_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        checkpoint_name = "checkpoint_{}_{}_{}_{:.4f}".format(TREC_COVID_LLM_1000_QUERIES, "{:02d}_epoch".format(epoch), training_timestamp, epoch_train_loss)
        model.save_pretrained(os.path.join(TRAIN_OUTPUT_FOLDER, checkpoint_name))

        min_training_loss['checkpoint_name'] = checkpoint_name
        min_training_loss['loss'] = epoch_train_loss

    if LINK_WITH_COMET:
        # NOTE(review): the epoch mean is logged under the same 'train loss' key as
        # the per-batch values above -- confirm that mixing the two series is intended.
        experiment.log_metrics({'train loss': epoch_train_loss},
                               epoch=epoch)

    evaluate(model=model,
             dataloader=eval_dataloader,
             set_name='Eval',
             min_valid_loss=min_valid_loss,
             current_training_step=current_training_step,
             current_epoch=epoch)

    # Draw a fresh random negative for every query before the next epoch.
    train_dataset.rebuild_dataset(verbose=False)

Train:   0%|[32m                    [0m| 0/615 [00:00<?, ?it/s][32m[0m

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Loss 0.0502: 100%|[32m████████████████████[0m| 615/615 [01:36<00:00,  6.39it/s][32m[0m
Epoch: 0, Training loss: 0.2834
New minimal training loss; saving model...
Loss 0.2636: 100%|[32m████████████████████[0m| 58/58 [00:02<00:00, 19.70it/s][32m[0m
Eval loss: 0.3104; accuracy: 0.8725701943844493
New minimal validation loss; saving model...
Loss 0.0505: 100%|[32m████████████████████[0m| 615/615 [01:34<00:00,  6.54it/s][32m[0m
Epoch: 1, Training loss: 0.2785
New minimal training loss; saving model...
Loss 0.2625: 100%|[32m████████████████████[0m| 58/58 [00:02<00:00, 19.72it/s][32m[0m
Eval loss: 0.3080; accuracy: 0.8736501079913607
New minimal validation loss; saving model...
Loss 0.0484: 100%|[32m████████████████████[0m| 615/615 [01:34<00:00,  6.53it/s][32m[0m
Epoch: 2, Training loss: 0.2775
New minimal training loss; saving model...
Loss 0.2608: 100%|[32m████████████████████[0m| 58/58 [00:02<00:00, 19.64it/s][32m[0m
Eval loss: 0.3054; accuracy: 0.8736501079913607
N

KeyboardInterrupt: ignored

In [45]:
if LINK_WITH_COMET:
    experiment.end() 

[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m Comet.ml Experiment Summary
[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m   Data:
[1;38;5;39mCOMET INFO:[0m     display_summary_level : 1
[1;38;5;39mCOMET INFO:[0m     url                   : https://www.comet.com/eduseiti/inpars-reraking/a26c235ed0254c2eb71e1876151d549d
[1;38;5;39mCOMET INFO:[0m   Metrics [count] (min, max):
[1;38;5;39mCOMET INFO:[0m     eval loss [30]     : (0.2542881483918634, 0.3104252854426359)
[1;38;5;39mCOMET INFO:[0m     learning_rate      : 1e-08
[1;38;5;39mCOMET INFO:[0m     loss [1831]        : (0.015924062579870224, 1.668872594833374)
[1;38;5;39mCOMET INFO:[0m     train loss [18480] : (0.012990033254027367, 1.67180335521698)
[1;38;5;39mCOMET INFO:[0m   Parameters:
[1;38;5;39mCOMET INFO:[0m     batch_si

In [None]:
training_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

In [None]:
# Save the model as it stands after the training loop.
# NOTE(review): the name embeds the configured hyperparameters['epochs'], not the
# number of epochs actually run, and np.mean(train_losses) only covers the batches
# of the last epoch that was started -- confirm both before relying on the label
# (the logged run above ended with a KeyboardInterrupt).
checkpoint_name = "checkpoint_{}_{}_{}_{:.4f}".format(TREC_COVID_LLM_1000_QUERIES, "{:02d}_epochs".format(hyperparameters['epochs']), training_timestamp, np.mean(train_losses))
model.save_pretrained(os.path.join(TRAIN_OUTPUT_FOLDER, checkpoint_name))