<a href="https://colab.research.google.com/github/eduseiti/ia368v_dd_class_07/blob/main/DPR_TREC_COVID.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os

from google.colab import drive

In [2]:
WORKING_FOLDER="drive/MyDrive/unicamp/ia368v_dd/aula_07"

In [3]:
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [4]:
os.chdir(WORKING_FOLDER)

### Check if pyserini is already installed in the drive folder

In [5]:
if not os.path.exists("pyserini"):
    os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64"

    !pip install pyserini -q
    !git clone --recurse-submodules https://github.com/castorini/pyserini.git
    !cd pyserini
    !cd tools/eval && tar xvfz trec_eval.9.0.4.tar.gz && cd trec_eval.9.0.4 && make && cd ../../..
    !cd tools/eval/ndeval && make && cd ../../..
else:
    !chmod +x pyserini/tools/eval/trec_eval.9.0.4/trec_eval

    print("Pyserini already installed...")

Pyserini already installed...


In [6]:
!pip install transformers -q
!pip install faiss-gpu -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m55.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m200.1/200.1 kB[0m [31m26.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m102.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m15.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [7]:
import pandas as pd
import pickle
import numpy as np

import torch

from scipy import stats

from datetime import datetime

from transformers import (AutoTokenizer, 
                          AutoModel,
                          BatchEncoding
)

from tqdm.auto import tqdm

import time

import faiss

In [8]:
TREC_COVID_MERGED_FILE="trec_covid_merged_data.tsv"
TREC_COVID_DOCUMENTS_FILE="trec_covid_original_title_text_merged.tsv"

TREC_COVID_QRELS="trec_covid_qrels.tsv"

TRAIN_OUTPUT_FOLDER="./trained_model"

ENCODED_DATA_FILE="trec_covid_encoded_data_{}.pkl"

In [9]:
PYSERINI_TEST_RUN_DPR_FILENAME_FORMAT="run.trec-covid_DPR_{}.txt"

In [10]:
BEST_MODEL_CHECKPOINT="checkpoint_20230414_170604_0.0772"
# BEST_MODEL_CHECKPOINT="checkpoint_20230415_124055_2.7061"

MODEL_NAME='microsoft/MiniLM-L12-H384-uncased'

MODEL_MAX_INPUT_LENGTH=512

## Explore TREC COVID merged data

This dataframe contains the merge of the TREC COVID qrels and the document corpus. Hence, it only lists the documents which are indeed mentioned in the qrels.

In [11]:
trec_covid_merged_df = pd.read_csv(TREC_COVID_MERGED_FILE, sep='\t')

In [12]:
trec_covid_merged_df

Unnamed: 0,query-id,corpus-id,score,query-text,corpus-title,corpus-text,query-metadata,corpus-metadata
0,1,005b2j4b,2,what is the origin of COVID-19,Monophyletic Relationship between Severe Acute...,Although primary genomic analysis has revealed...,"{'query': 'coronavirus origin', 'narrative': ""...",{'url': 'https://www.ncbi.nlm.nih.gov/pubmed/1...
1,16,005b2j4b,0,how long does coronavirus remain stable on su...,Monophyletic Relationship between Severe Acute...,Although primary genomic analysis has revealed...,{'query': 'how long does coronavirus survive o...,{'url': 'https://www.ncbi.nlm.nih.gov/pubmed/1...
2,32,005b2j4b,0,"Does SARS-CoV-2 have any subtypes, and if so w...",Monophyletic Relationship between Severe Acute...,Although primary genomic analysis has revealed...,"{'query': 'coronavirus subtypes', 'narrative':...",{'url': 'https://www.ncbi.nlm.nih.gov/pubmed/1...
3,37,005b2j4b,0,What is the result of phylogenetic analysis of...,Monophyletic Relationship between Severe Acute...,Although primary genomic analysis has revealed...,"{'query': 'SARS-CoV-2 phylogenetic analysis', ...",{'url': 'https://www.ncbi.nlm.nih.gov/pubmed/1...
4,1,00fmeepz,1,what is the origin of COVID-19,Comprehensive overview of COVID-19 based on cu...,"In December 2019, twenty-seven pneumonia patie...","{'query': 'coronavirus origin', 'narrative': ""...","{'url': '', 'pubmed_id': ''}"
...,...,...,...,...,...,...,...,...
66331,50,zn10rnrm,1,what is known about an mRNA vaccine for the SA...,Characterization of RNA in Saliva,Background: We have previously shown that huma...,"{'query': 'mRNA vaccine coronavirus', 'narrati...",{'url': 'https://www.ncbi.nlm.nih.gov/pmc/arti...
66332,50,zstmdt4n,0,what is known about an mRNA vaccine for the SA...,Coordinate induction of IFN-α and -γ by SARS-C...,Abstract Background: Severe acute respiratory ...,"{'query': 'mRNA vaccine coronavirus', 'narrati...",{'url': 'https://api.elsevier.com/content/arti...
66333,50,zth8ffy3,0,what is known about an mRNA vaccine for the SA...,Vasculopathy and Coagulopathy Associated with ...,The emergence of severe acute respiratory synd...,"{'query': 'mRNA vaccine coronavirus', 'narrati...","{'url': '', 'pubmed_id': ''}"
66334,50,zv4nbz9p,2,what is known about an mRNA vaccine for the SA...,"Emerging Technologies for Use in the Study, Di...",INTRODUCTION: The COVID-19 pandemic has caused...,"{'query': 'mRNA vaccine coronavirus', 'narrati...",{'url': 'https://doi.org/10.1007/s12195-020-00...


#### Check the number of documents

In [13]:
trec_covid_merged_df.drop_duplicates(['corpus-id'])

Unnamed: 0,query-id,corpus-id,score,query-text,corpus-title,corpus-text,query-metadata,corpus-metadata
0,1,005b2j4b,2,what is the origin of COVID-19,Monophyletic Relationship between Severe Acute...,Although primary genomic analysis has revealed...,"{'query': 'coronavirus origin', 'narrative': ""...",{'url': 'https://www.ncbi.nlm.nih.gov/pubmed/1...
4,1,00fmeepz,1,what is the origin of COVID-19,Comprehensive overview of COVID-19 based on cu...,"In December 2019, twenty-seven pneumonia patie...","{'query': 'coronavirus origin', 'narrative': ""...","{'url': '', 'pubmed_id': ''}"
8,1,g7dhmyyo,2,what is the origin of COVID-19,"The SARS, MERS and novel coronavirus (COVID-19...",OBJECTIVES: To provide an overview of the thre...,"{'query': 'coronavirus origin', 'narrative': ""...","{'url': '', 'pubmed_id': ''}"
17,1,0194oljo,1,what is the origin of COVID-19,Evidence for zoonotic origins of Middle East r...,Middle East respiratory syndrome (MERS) is an ...,"{'query': 'coronavirus origin', 'narrative': ""...",{'url': 'https://www.ncbi.nlm.nih.gov/pubmed/2...
18,1,021q9884,1,what is the origin of COVID-19,Deadly virus effortlessly hops species,Genetic engineering helps reveal origin of dea...,"{'query': 'coronavirus origin', 'narrative': ""...",{'url': 'https://www.ncbi.nlm.nih.gov/pmc/arti...
...,...,...,...,...,...,...,...,...
66331,50,zn10rnrm,1,what is known about an mRNA vaccine for the SA...,Characterization of RNA in Saliva,Background: We have previously shown that huma...,"{'query': 'mRNA vaccine coronavirus', 'narrati...",{'url': 'https://www.ncbi.nlm.nih.gov/pmc/arti...
66332,50,zstmdt4n,0,what is known about an mRNA vaccine for the SA...,Coordinate induction of IFN-α and -γ by SARS-C...,Abstract Background: Severe acute respiratory ...,"{'query': 'mRNA vaccine coronavirus', 'narrati...",{'url': 'https://api.elsevier.com/content/arti...
66333,50,zth8ffy3,0,what is known about an mRNA vaccine for the SA...,Vasculopathy and Coagulopathy Associated with ...,The emergence of severe acute respiratory synd...,"{'query': 'mRNA vaccine coronavirus', 'narrati...","{'url': '', 'pubmed_id': ''}"
66334,50,zv4nbz9p,2,what is known about an mRNA vaccine for the SA...,"Emerging Technologies for Use in the Study, Di...",INTRODUCTION: The COVID-19 pandemic has caused...,"{'query': 'mRNA vaccine coronavirus', 'narrati...",{'url': 'https://doi.org/10.1007/s12195-020-00...


#### Check the number of documents referred by each question

In [14]:
trec_covid_merged_df.groupby('query-id').count()

Unnamed: 0_level_0,corpus-id,score,query-text,corpus-title,corpus-text,query-metadata,corpus-metadata
query-id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
1,1565,1565,1565,1565,1354,1565,1565
2,1251,1251,1251,1251,1091,1251,1251
3,1649,1649,1649,1649,1457,1649,1649
4,1793,1793,1793,1793,1511,1793,1793
5,1643,1643,1643,1643,1479,1643,1643
6,1562,1562,1562,1562,1372,1562,1562
7,1331,1331,1331,1331,1156,1331,1331
8,1837,1837,1837,1837,1607,1837,1837
9,1606,1606,1606,1606,1327,1606,1606
10,1100,1100,1100,1100,941,1100,1100


#### Check the available document scores

In [15]:
trec_covid_merged_df['score'].unique()

array([ 2,  0,  1, -1])

## Explore TREC COVID documents data

In [16]:
trec_covid_corpus_df = pd.read_csv(TREC_COVID_DOCUMENTS_FILE, sep='\t', header=None, names=['corpus-id', 'corpus-title-text'])

In [17]:
trec_covid_corpus_df.head()

Unnamed: 0,corpus-id,corpus-title-text
0,ug7v899j,Clinical features of culture-proven Mycoplasma...
1,02tnwd4m,Nitric oxide: a pro-inflammatory mediator in l...
2,ejv2xln0,Surfactant protein-D and pulmonary host defens...
3,2b73a28n,Role of endothelin-1 in lung diseaseEndothelin...
4,9785vg6d,Gene expression in epithelial cells in respons...


In [18]:
trec_covid_corpus_df.shape

(171325, 2)

## Create dense representations for the documents and requires

In [19]:
class TextToEncodeDataset(torch.utils.data.Dataset):

    def __init__(self, texts_list, tokenizer, max_length=None):

        self.max_length = max_length

        self.tokenized_texts = tokenizer(texts_list, 
                                         truncation=True, 
                                         return_overflowing_tokens=True, 
                                         max_length=max_length, 
                                         return_length=True)
        
        self.original_length = len(texts_list)
        self.length_stats = stats.describe(self.tokenized_texts['length'])

        print("Text tokens size stats:\n{}\n".format(self.length_stats))

        if (max_length is not None) and 'overflow_to_sample_mapping' in self.tokenized_texts:
            if self.original_length < len(self.tokenized_texts['overflow_to_sample_mapping']):
                print("Added {} overflowing texts...".format(len(self.tokenized_texts['overflow_to_sample_mapping']) - self.original_length))


    def __len__(self):
        return len(self.tokenized_texts['input_ids'])


    def __getitem__(self, index):
        return {'input_ids': self.tokenized_texts['input_ids'][index],
                'attention_mask': self.tokenized_texts['attention_mask'][index]}

    def get_original_index(self, tokenized_documents_indexes):
        if 'overflow_to_sample_mapping' in self.tokenized_texts:
            return np.array(self.tokenized_texts['overflow_to_sample_mapping'])[tokenized_documents_indexes]

In [20]:
class DPRCollator(object):
    def __init__(self, tokenizer=None):
        self.tokenizer = tokenizer


    def __call__(self, batch):
        padded_batch = self.tokenizer.pad(batch, return_tensors='pt')

        return BatchEncoding(padded_batch)

In [21]:
def encode(device,
           which_model, 
           which_dataloader):
    
    encoded_text = []

    which_model.eval()

    with torch.no_grad():
        for batch in tqdm(which_dataloader, mininterval=0.5, desc="Encode", disable=False):
            
            encoded_outputs = which_model(**batch.to(device))

            encoded_cls = encoded_outputs.last_hidden_state[:, 0, :]

            encoded_text.append(encoded_cls.cpu().numpy())

    return np.concatenate(encoded_text, axis=0)

### Instantiate the required elements

In [22]:
batch_size = 256

In [23]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

cuda


In [24]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [25]:
passages_model = AutoModel.from_pretrained(os.path.join(TRAIN_OUTPUT_FOLDER, BEST_MODEL_CHECKPOINT, "_passages")).to(device)
topics_model = AutoModel.from_pretrained(os.path.join(TRAIN_OUTPUT_FOLDER, BEST_MODEL_CHECKPOINT, "_topics")).to(device)

### Prepare the queries to be tokenized

In [26]:
valid_queries_df = trec_covid_merged_df[['query-id', 'query-text']].drop_duplicates().sort_values('query-id').reset_index(drop=True)

### Check if has already computed the encoded values

In [27]:
if os.path.exists(ENCODED_DATA_FILE.format(BEST_MODEL_CHECKPOINT)):
    with open(ENCODED_DATA_FILE.format(BEST_MODEL_CHECKPOINT), 'rb') as inputFile:
        encoded_data = pickle.load(inputFile)

    encoded_queries = encoded_data['encoded_queries']
    encoded_corpus = encoded_data['encoded_corpus']
    queries_to_encode = encoded_data['queries_dataset']
    corpus_to_encode = encoded_data['corpus_dataset']

    already_encoded = True
else:
    already_encoded = False

    print("Need to encode the data...")

### Tokenize queries and corpus

In [28]:
if not already_encoded:
    queries_to_encode = TextToEncodeDataset(valid_queries_df['query-text'].tolist(), tokenizer, MODEL_MAX_INPUT_LENGTH)
else:
    print("Already encoded the data...")    

Already encoded the data...


In [29]:
if not already_encoded:
    start_time = time.time()

    corpus_to_encode = TextToEncodeDataset(trec_covid_corpus_df['corpus-title-text'].tolist(), tokenizer, MODEL_MAX_INPUT_LENGTH)

    print("Time to tokenize the corpus: {}".format(time.time() - start_time))
else:
    print("Already encoded the data...")    

Already encoded the data...


In [30]:
encode_queries_dataloader = torch.utils.data.DataLoader(queries_to_encode, 
                                                        batch_size=batch_size, 
                                                        shuffle=False, 
                                                        collate_fn=DPRCollator(tokenizer))

encode_corpus_dataloader = torch.utils.data.DataLoader(corpus_to_encode, 
                                                       batch_size=batch_size, 
                                                       shuffle=False, 
                                                       collate_fn=DPRCollator(tokenizer))

In [31]:
if not already_encoded:
    encoded_queries = encode(device, topics_model, encode_queries_dataloader)
else:
    print("Already encoded the data...")    

Already encoded the data...


In [32]:
if not already_encoded:
    start_time = time.time()

    encoded_corpus = encode(device, passages_model, encode_corpus_dataloader)

    print("Time to encode the corpus: {}".format(time.time() - start_time))
else:
    print("Already encoded the data...")    

Already encoded the data...


In [33]:
if not already_encoded:
    with open(ENCODED_DATA_FILE.format(BEST_MODEL_CHECKPOINT), 'wb') as outputFile:
        pickle.dump({'encoded_queries': encoded_queries,
                    'encoded_corpus': encoded_corpus,
                    'queries_dataset': queries_to_encode,
                    'corpus_dataset': corpus_to_encode}, outputFile, pickle.HIGHEST_PROTOCOL)
else:
    print("Already encoded the data...")        

Already encoded the data...


## Now, create an search index

In [34]:
corpus_index = faiss.IndexFlatIP(384)

In [35]:
start_time = time.time()

corpus_index.add(encoded_corpus)

print("FAISS index creation time: {}".format(time.time() - start_time))

FAISS index creation time: 0.20330166816711426


## Search the queries in the index

In [36]:
search_results = corpus_index.search(encoded_queries, 1000)

In [37]:
search_results[0].shape

(50, 1000)

### Save the results in the TREC file format

In [38]:
TREC_RESULT_LINE_FORMAT="{}\tQ0\t{}\t{}\t{}\tminiLM_DPR\n"

In [39]:
test_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

In [40]:
run_filename = PYSERINI_TEST_RUN_DPR_FILENAME_FORMAT.format(test_timestamp)

In [41]:
# with open(run_filename, 'w') as outputFile:
#     for query_index, query_scores in enumerate(search_results[0]):
#         document_ascending_order = np.argsort(query_scores)
#         tokenized_documents_ordered_indexes = search_results[1][query_index][document_ascending_order]

#         original_documents_ordered_indexes = corpus_to_encode.get_original_index(tokenized_documents_ordered_indexes)

#         max_score = np.max(query_scores)

#         included_docs = set()

#         for i, document_index in enumerate(original_documents_ordered_indexes):
#             if document_index not in included_docs:
#                 included_docs.add(document_index)

#                 outputFile.write(TREC_RESULT_LINE_FORMAT.format(valid_queries_df.iloc[query_index]['query-id'], 
#                                                                 trec_covid_corpus_df.iloc[document_index]['corpus-id'], 
#                                                                 i + 1,
#                                                                 np.abs(query_scores[document_ascending_order[i]] - max_score)))
#             else:
#                 print("Ignoring document={} as it is already in the answers set for query={}".format(document_index, query_index))

In [53]:
COMPUTE_MEAN_OF_REPEATED_DOCS=True

In [56]:
with open(run_filename, 'w') as outputFile:
    for query_index, query_scores in enumerate(search_results[0]):

        # print("\n\nQUERY {}\n".format(query_index))

        document_descending_order = np.argsort(query_scores)[::-1]
        tokenized_documents_ordered_indexes = search_results[1][query_index][document_descending_order]

        original_documents_ordered_indexes = corpus_to_encode.get_original_index(tokenized_documents_ordered_indexes)

        # print("original_documents_ordered_indexes={}".format(original_documents_ordered_indexes))

        remaining_query_scores = []
        remaining_original_documents_indexes = []

        if COMPUTE_MEAN_OF_REPEATED_DOCS:
            doc_counts = np.unique(original_documents_ordered_indexes, return_counts=True)

            # print(doc_counts)

            if np.sum(doc_counts[1] > 1) > 0:
                for document_index, document_count in zip(doc_counts[0], doc_counts[1]):

                    # print("document_index={}, document_count={}".format(document_index, document_count))

                    document_pos = np.where(original_documents_ordered_indexes == document_index)[0]

                    # print(document_pos)

                    if document_count > 1:
                        # print("{} = {}".format(document_index, document_pos))
                        # print(query_scores[document_descending_order[document_pos]])
                        # print(np.mean(query_scores[document_descending_order[document_pos]]))

                        remaining_query_scores.append(np.mean(query_scores[document_descending_order[document_pos]]))
                    else:
                        remaining_query_scores.append(query_scores[document_descending_order[document_pos]][0])

                    remaining_original_documents_indexes.append(document_index)

                # print(remaining_original_documents_indexes)
                # print(remaining_query_scores)

                query_scores = remaining_query_scores
                document_descending_order = np.argsort(remaining_query_scores)[::-1]
                original_documents_ordered_indexes = np.array(remaining_original_documents_indexes)[document_descending_order]

                # print("original_documents_ordered_indexes={}".format(original_documents_ordered_indexes))

        included_docs = set()

        for i, document_index in enumerate(original_documents_ordered_indexes):
            if document_index not in included_docs:
                included_docs.add(document_index)

                outputFile.write(TREC_RESULT_LINE_FORMAT.format(valid_queries_df.iloc[query_index]['query-id'], 
                                                                trec_covid_corpus_df.iloc[document_index]['corpus-id'], 
                                                                i + 1,
                                                                query_scores[document_descending_order[i]]))
            else:
                print("Ignoring document={} as it is already in the answers set for query={}".format(document_index, query_index))

## Now run the metric

In [57]:
!pyserini/tools/eval/trec_eval.9.0.4/trec_eval -c -mrecall.1000 -mmap -mndcg_cut.10 -mrecip_rank.100 \
    {TREC_COVID_QRELS} {run_filename}

map                   	all	0.0439
recip_rank            	all	0.5149
recall_1000           	all	0.1732
ndcg_cut_10           	all	0.3197
