In [1]:
import os
import sys

In [2]:
# True when the google.colab module is already loaded in this interpreter.
IN_COLAB = "google.colab" in sys.modules

In [3]:
# Colab-only setup: mount Google Drive, make WORKING_FOLDER the current directory
# and install the transformers package.
if IN_COLAB:
    from google.colab import drive

    # Folder inside the mounted Drive. Note that this name is only defined when IN_COLAB is true.
    WORKING_FOLDER="/content/drive/MyDrive/unicamp/ia368v_dd/aula_08"

    drive.mount('/content/drive', force_remount=True)

    # Relative paths used by the rest of the notebook are resolved from this folder.
    os.chdir(WORKING_FOLDER)
    
    !pip install transformers -q

### Check if pyserini is already installed in the drive folder

In [4]:
if not os.path.exists("pyserini"):
    os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64"

    !pip install pyserini -q
    !git clone --recurse-submodules https://github.com/castorini/pyserini.git

    os.chdir("pyserini/tools/eval")
    !tar xvfz trec_eval.9.0.4.tar.gz && cd trec_eval.9.0.4 && make && cd ../../..

    os.chdir("ndeval")
    !make && cd ../../..

    os.chdir(WORKING_FOLDER)
else:
    !chmod +x pyserini/tools/eval/trec_eval.9.0.4/trec_eval

    print("Pyserini already installed...")

Pyserini already installed...


In [5]:
import pandas as pd
import pickle
import numpy as np

import torch

from scipy import stats, sparse

from datetime import datetime

from transformers import (AutoTokenizer, 
                          AutoModel,
                          AutoModelForMaskedLM,
                          BatchEncoding
)

from tqdm.auto import tqdm

import time

In [6]:
import gc

from multiprocessing import Pool

import glob

In [7]:
# Pre-processed TREC-COVID inputs; both are read as tab-separated files.
TREC_COVID_MERGED_FILE = "trec_covid_merged_data.tsv"
TREC_COVID_DOCUMENTS_FILE = "trec_covid_original_title_text_merged.tsv"

# Relevance judgments handed to trec_eval at the end of the notebook.
TREC_COVID_QRELS = "trec_covid_qrels.tsv"

In [8]:
# Run-file name template with two "{}" placeholders.
PYSERINI_TEST_RUN_DPR_FILENAME_FORMAT = "run.trec-covid_DPR_{}_{}.txt"

# Directory where run files are written.
RUNS_FOLDER = "runs"

In [9]:
# Checkpoint identifier passed to the transformers from_pretrained loaders.
MODEL_NAME = "naver/splade_v2_distil"

In [10]:
# Pickle cache templates; the "{}" placeholder is filled with the model's base name.
TOKENIZED_DATA_FILE = "trec_covid_tokenized_data_{}.pkl"
ENCODED_DATA_FILE = "trec_covid_encoded_data_{}.pkl"
INVERTED_INDEX_FILE = "trec_covid_splade_inverted_index_{}.pkl"

In [11]:
class TextToEncodeDataset(torch.utils.data.Dataset):
    """Dataset that tokenizes a list of texts once and serves the tokenized entries.

    Args:
        texts_list: list of strings to tokenize.
        tokenizer: callable invoked as ``tokenizer(texts_list, truncation=True,
            return_overflowing_tokens=..., max_length=..., return_length=True)``;
            its result must expose 'input_ids', 'attention_mask' and 'length'.
        return_overflow: forwarded as ``return_overflowing_tokens``.
        max_length: forwarded as ``max_length`` (``None`` is passed through as is).
    """

    def __init__(self, texts_list, tokenizer, return_overflow=True, max_length=None):

        self.max_length = max_length

        self.tokenized_texts = tokenizer(texts_list, 
                                         truncation=True, 
                                         return_overflowing_tokens=return_overflow, 
                                         max_length=max_length, 
                                         return_length=True)

        self.original_length = len(texts_list)
        self.length_stats = stats.describe(self.tokenized_texts['length'])

        print("Text tokens size stats:\n{}\n".format(self.length_stats))

        # When overflow chunks were produced there are more entries than input texts.
        if (max_length is not None) and 'overflow_to_sample_mapping' in self.tokenized_texts:
            total_entries = len(self.tokenized_texts['overflow_to_sample_mapping'])

            if self.original_length < total_entries:
                print("Added {} overflowing texts...".format(total_entries - self.original_length))


    def __len__(self):
        """Number of tokenized entries (may exceed the number of input texts)."""
        return len(self.tokenized_texts['input_ids'])


    def __getitem__(self, index):
        """Return the unpadded 'input_ids' and 'attention_mask' of one entry."""
        return {'input_ids': self.tokenized_texts['input_ids'][index],
                'attention_mask': self.tokenized_texts['attention_mask'][index]}

    def get_original_index(self, tokenized_documents_indexes):
        """Map tokenized-entry indexes back to indexes into the original ``texts_list``.

        Returns:
            A numpy array with one original index per requested entry.
        """
        if 'overflow_to_sample_mapping' in self.tokenized_texts:
            return np.array(self.tokenized_texts['overflow_to_sample_mapping'])[tokenized_documents_indexes]

        # No overflow mapping: every text produced exactly one entry, so the
        # mapping is the identity (previously this path returned None).
        return np.asarray(tokenized_documents_indexes)

In [12]:
class DataCollator:
    """Collate function that pads a list of tokenized samples into one tensor batch."""

    def __init__(self, tokenizer=None):
        # The tokenizer's pad() method does the actual padding.
        self.tokenizer = tokenizer

    def __call__(self, batch):
        """Pad ``batch`` to PyTorch tensors and wrap the result in a BatchEncoding."""
        return BatchEncoding(self.tokenizer.pad(batch, return_tensors='pt'))

### Check whether the data has already been encoded

In [13]:
# Reuse the encodings cached on disk when they exist.
# NOTE: pickle.load can execute arbitrary code; only load cache files written by this notebook.
encoded_data_path = ENCODED_DATA_FILE.format(os.path.basename(MODEL_NAME))

if not os.path.exists(encoded_data_path):
    print("Need to encode the data...")

    has_already_encoded = False
else:
    with open(encoded_data_path, 'rb') as inputFile:
        encoded_data = pickle.load(inputFile)

    encoded_queries = encoded_data['encoded_queries']
    encoded_corpus = encoded_data['encoded_corpus']

    has_already_encoded = True

### Check whether the inverted index has already been created

In [14]:
# Reuse the inverted index cached on disk when it exists.
# NOTE: pickle.load can execute arbitrary code; only load cache files written by this notebook.
inverted_index_path = INVERTED_INDEX_FILE.format(os.path.basename(MODEL_NAME))

if not os.path.exists(inverted_index_path):
    print("Need to create the inverted index...")

    has_already_created_index = False
else:
    with open(inverted_index_path, "rb") as inputFile:
        inverted_index = pickle.load(inputFile)

    has_already_created_index = True

### Load TREC COVID data

This data has already been preprocessed for the previous exercises.

In [32]:
# Merged data: the first row of the file provides the column names.
trec_covid_merged_df = pd.read_csv(TREC_COVID_MERGED_FILE, sep='\t')

# Documents file: read with header=None, so the two column names are supplied here.
corpus_columns = ['corpus-id', 'corpus-title-text']
trec_covid_corpus_df = pd.read_csv(TREC_COVID_DOCUMENTS_FILE,
                                   sep='\t',
                                   header=None,
                                   names=corpus_columns)

### Prepare the queries to be processed

In [33]:
# One row per distinct (query-id, query-text) pair, ordered by query-id with a fresh 0..n-1 index.
valid_queries_df = (
    trec_covid_merged_df[['query-id', 'query-text']]
    .drop_duplicates()
    .sort_values('query-id')
    .reset_index(drop=True)
)

### Load the tokenizer and pretrained model

In [None]:
# Use CUDA when PyTorch reports it as available, otherwise the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

print(device)

In [None]:
# The tokenizer and model are only loaded when there is no encoded-data cache.
if not has_already_encoded:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
    model = model.to(device)

In [None]:
def encode(device,
           which_model, 
           which_dataloader):
    """Encode every batch of ``which_dataloader`` into sparse SPLADE representations.

    Args:
        device: device each batch is moved to before the forward pass.
        which_model: model whose output exposes ``logits``; it is put in eval mode.
        which_dataloader: iterable of batches that support ``.to(device)``,
            ``**`` unpacking into the model and an 'attention_mask' entry.

    Returns:
        A list with one ``[term_indexes, term_scores]`` pair per row of every
        batch, in iteration order. Both are 1-D tensors of equal length holding
        the positions and values of the non-zero entries of that row's pooled
        vector, including for rows with exactly one or zero non-zero entries.
    """
    which_model.eval()

    encoded_values = []

    with torch.no_grad():
        for batch in tqdm(which_dataloader, mininterval=0.5, desc="Encode", disable=False):

            batch = batch.to(device)
            encoded_outputs = which_model(**batch)

            # SPLADE v2 pooling takes the maximum value of each vocabulary token for each document.
            # Computation found at https://github.com/naver/splade/blob/94f941da7cf96ffdbb57758ce0f5c676136024ca/splade/models/transformer_rep.py#L192
            term_activations = torch.log(1 + torch.relu(encoded_outputs.logits))

            # Multiplying by the attention mask zeroes the padded positions before the max.
            masked_activations = term_activations * batch["attention_mask"].unsqueeze(-1)

            sample_sparse_encodings, _ = torch.max(masked_activations, dim=1)

            for document in sample_sparse_encodings.cpu():
                # as_tuple=True always yields a 1-D index tensor. The former
                # nonzero(...).squeeze() collapsed single-term documents to 0-d
                # tensors, which cannot be iterated when building the index.
                non_zero_vocabulary_terms_indexes = torch.nonzero(document, as_tuple=True)[0]
                non_zero_vocabulary_terms_scores = document[non_zero_vocabulary_terms_indexes]

                encoded_values.append([non_zero_vocabulary_terms_indexes, non_zero_vocabulary_terms_scores])

    return encoded_values

In [None]:
# Whether texts longer than MAX_INPUT_LENGTH also yield their overflowing chunks.
RETURN_OVERFLOW = False

# The tokenizer is only loaded when the data still has to be encoded, so only
# read its limit in that case (the value is unused otherwise).
MAX_INPUT_LENGTH = None if has_already_encoded else tokenizer.model_max_length

# Batch size used by both encoding dataloaders.
batch_size = 32

### Check whether the data has already been tokenized

In [None]:
# Load the tokenized queries and corpus from the cache, or tokenize the queries now.
# NOTE: pickle.load can execute arbitrary code; only load cache files written by this notebook.
if has_already_encoded:
    print("Has already encoded the data...")
else:
    tokenized_data_path = TOKENIZED_DATA_FILE.format(os.path.basename(MODEL_NAME))

    if os.path.exists(tokenized_data_path):
        with open(tokenized_data_path, "rb") as inputFile:
            tokenized_data = pickle.load(inputFile)

        queries_to_encode = tokenized_data['queries_to_encode']
        corpus_to_encode = tokenized_data['corpus_to_encode']

        already_tokenized = True

        print("Already tokenized the data...")
    else:
        already_tokenized = False

        queries_to_encode = TextToEncodeDataset(texts_list=valid_queries_df['query-text'].tolist(),
                                                tokenizer=tokenizer,
                                                return_overflow=RETURN_OVERFLOW,
                                                max_length=MAX_INPUT_LENGTH)

In [None]:
# Tokenize the corpus (timed) and cache both tokenized datasets in a single pickle.
if has_already_encoded:
    print("Has already encoded the data...")
elif already_tokenized:
    print("Already tokenized the data...")
else:
    start_time = time.time()

    corpus_texts = trec_covid_corpus_df['corpus-title-text'].tolist()
    corpus_to_encode = TextToEncodeDataset(corpus_texts, tokenizer, RETURN_OVERFLOW, MAX_INPUT_LENGTH)

    print("Time to tokenize the corpus: {}".format(time.time() - start_time))

    tokenized_data_path = TOKENIZED_DATA_FILE.format(os.path.basename(MODEL_NAME))

    with open(tokenized_data_path, "wb") as outputFile:
        tokenized_payload = {'queries_to_encode': queries_to_encode,
                             'corpus_to_encode': corpus_to_encode}
        pickle.dump(tokenized_payload, outputFile, pickle.HIGHEST_PROTOCOL)

In [None]:
# Sequential (unshuffled) dataloaders, so encodings come out in dataset order.
if has_already_encoded:
    print("Has already encoded the data...")
else:
    loader_options = dict(batch_size=batch_size, shuffle=False)

    encode_queries_dataloader = torch.utils.data.DataLoader(queries_to_encode,
                                                            collate_fn=DataCollator(tokenizer),
                                                            **loader_options)

    encode_corpus_dataloader = torch.utils.data.DataLoader(corpus_to_encode,
                                                           collate_fn=DataCollator(tokenizer),
                                                           **loader_options)

In [None]:
# Encode the whole corpus and report the wall-clock time it took.
if has_already_encoded:
    print("Has already encoded the data...")
else:
    start_time = time.time()

    encoded_corpus = encode(device, model, encode_corpus_dataloader)

    print("Time to encode the corpus: {}".format(time.time() - start_time))

In [None]:
# Encode the queries, then cache the query and corpus encodings in a single pickle.
if has_already_encoded:
    print("Has already encoded the data...")
else:
    encoded_queries = encode(device, model, encode_queries_dataloader)

    encoded_data_path = ENCODED_DATA_FILE.format(os.path.basename(MODEL_NAME))

    with open(encoded_data_path, 'wb') as outputFile:
        encoded_payload = {'encoded_queries': encoded_queries,
                           'encoded_corpus': encoded_corpus}
        pickle.dump(encoded_payload, outputFile, pickle.HIGHEST_PROTOCOL)

### Create the inverted index

In [None]:
def create_inverted_index(encoded_documents, verbose=False):
    """Build a term -> per-document score mapping from encoded documents.

    Args:
        encoded_documents: sequence of ``[term_indexes, term_scores]`` pairs,
            one per document. ``term_indexes`` holds the non-zero vocabulary
            token ids and ``term_scores`` the score at the same offset. 0-d
            (single value) tensors are accepted as well as 1-D ones.
        verbose: when True, print the elapsed time and position every 1000 documents.

    Returns:
        A dict keyed by token id (the integer value, not its character
        sequence). Each value is a ``scipy.sparse.lil_matrix`` of shape
        ``(1, len(encoded_documents))`` and dtype float32, holding that
        token's score for every document.
    """
    total_start_time = time.time()
    start_time = total_start_time

    total_documents = len(encoded_documents)

    inverted_index = {}

    for doc_index, document in enumerate(encoded_documents):

        if verbose and doc_index % 1000 == 0:
            print(time.time() - start_time)
            print(doc_index)
            start_time = time.time()

        # reshape(-1) turns 0-d tensors (documents with a single non-zero term)
        # into 1-element sequences; tolist() converts everything to Python
        # numbers in one call instead of one .item() per element.
        tokens = torch.as_tensor(document[0]).reshape(-1).tolist()
        scores = torch.as_tensor(document[1]).reshape(-1).tolist()

        for token, score in zip(tokens, scores):
            posting_row = inverted_index.get(token)

            if posting_row is None:
                posting_row = sparse.lil_matrix((1, total_documents), dtype=np.float32)
                inverted_index[token] = posting_row

            posting_row[0, doc_index] = score

    print("Time to create the inverted index: {}".format(time.time() - total_start_time))

    return inverted_index

In [None]:
# Only build the index when it was not loaded from the cache file above;
# otherwise the loaded index would be thrown away and recomputed.
if not has_already_created_index:
    inverted_index = create_inverted_index(encoded_corpus)
else:
    print("Has already created the inverted index...")

### Finally, save the inverted index

In [None]:
# Only write the cache file when the index was not loaded from it in the first place.
if not has_already_created_index:
    with open(INVERTED_INDEX_FILE.format(os.path.basename(MODEL_NAME)), "wb") as outputFile:
        pickle.dump(inverted_index, outputFile, pickle.HIGHEST_PROTOCOL)
else:
    print("Inverted index file already exists...")

## Process the queries

In [107]:
def find_related_documents(encoded_query, inverted_indexes):
    """Score every document against one encoded query.

    Args:
        encoded_query: ``[term_indexes, term_scores]`` pair. ``term_indexes``
            holds the query's non-zero vocabulary token ids and ``term_scores``
            the score at the same offset.
        inverted_indexes: dict mapping a token id to a ``(1, n_documents)``
            sparse row with that token's score in every document.

    Returns:
        ``(weighted_scores, elapsed_seconds)``. ``weighted_scores`` is a 1-D
        array with, per document, the sum over the query tokens found in the
        index of query score times document score. It is all zeros when no
        query token is in the index.
    """
    query_start_time = time.time()

    # reshape(-1) also accepts 0-d tensors (queries with a single non-zero term).
    query_tokens = torch.as_tensor(encoded_query[0]).reshape(-1).tolist()
    query_weights = torch.as_tensor(encoded_query[1]).reshape(-1).tolist()

    # Use the index passed in as argument, not a notebook-level global.
    related_documents = [
        inverted_indexes[token].toarray()[0] * weight
        for token, weight in zip(query_tokens, query_weights)
        if token in inverted_indexes
    ]

    if related_documents:
        weighted_scores = np.sum(np.stack(related_documents), axis=0)
    else:
        # np.stack raises on an empty list; no matching token means no document scores.
        total_documents = next(iter(inverted_indexes.values())).shape[1] if inverted_indexes else 0
        weighted_scores = np.zeros(total_documents, dtype=np.float32)

    elapsed_time = time.time() - query_start_time

    print("Time to process query: {}".format(elapsed_time))

    return weighted_scores, elapsed_time

In [108]:
# Score the corpus against every encoded query, keeping the scores and the time each query took.
related_documents_per_query, time_per_query = [], []

for encoded_query in encoded_queries:
    query_scores, query_seconds = find_related_documents(encoded_query, inverted_index)

    related_documents_per_query.append(query_scores)
    time_per_query.append(query_seconds)

Time to process query: 0.4458603858947754
Time to process query: 0.40161800384521484
Time to process query: 0.564791202545166
Time to process query: 0.3619050979614258
Time to process query: 0.45990657806396484
Time to process query: 0.5195517539978027
Time to process query: 0.40600037574768066
Time to process query: 0.5381870269775391
Time to process query: 0.5310907363891602
Time to process query: 0.5816469192504883
Time to process query: 0.6061816215515137
Time to process query: 0.39930224418640137
Time to process query: 0.3310422897338867
Time to process query: 0.5846042633056641
Time to process query: 0.28842949867248535
Time to process query: 0.2158670425415039
Time to process query: 0.5996396541595459
Time to process query: 0.44867944717407227
Time to process query: 0.48566508293151855
Time to process query: 0.4152188301086426
Time to process query: 0.22468161582946777
Time to process query: 0.466876745223999
Time to process query: 0.4515516757965088
Time to process query: 0.427

In [109]:
# Average of the per-query times collected in the scoring loop.
mean_query_seconds = np.mean(time_per_query)
print("Mean time per query: {}".format(mean_query_seconds))

Mean time per query: 0.470242075920105


In [40]:
# Maximum number of documents written per query.
MAX_RESULTS_TO_SAVE = 1000

# One run-file line: query id, literal "Q0", document id, rank, score and the run tag.
TREC_RESULT_LINE_FORMAT = "{}\tQ0\t{}\t{}\t{}\tSPLADE_v2\n"

# Run-file name template; it is filled with the model's base name and a timestamp.
PYSERINI_TEST_RUN_DPR_FILENAME_FORMAT = "run.trec-covid_SPLADE_v2_{}_{}.txt"

In [88]:
# The timestamp makes each run file name unique down to the second.
test_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

run_basename = PYSERINI_TEST_RUN_DPR_FILENAME_FORMAT.format(os.path.basename(MODEL_NAME), test_timestamp)
run_filename = os.path.join(RUNS_FOLDER, run_basename)

In [30]:
# exist_ok=True creates the folder only when missing, without a separate existence check.
os.makedirs(RUNS_FOLDER, exist_ok=True)

In [89]:
def generate_trec_format(queries_matches, queries_df, documents_df, output_filename, verbose=False):
    """Write the ranked results of every query to ``output_filename``.

    Args:
        queries_matches: one array of per-document scores per query, in the
            same order as the rows of ``queries_df``.
        queries_df: frame with a 'query-id' column.
        documents_df: frame with a 'corpus-id' column; row positions match the
            positions of the score arrays.
        output_filename: path of the run file to (over)write.
        verbose: when True, print progress details for every query.

    Only documents with a score above zero are written, highest score first,
    at most MAX_RESULTS_TO_SAVE per query; the rank column starts at 0.
    """
    with open(output_filename, 'w') as run_file:
        for query_pos, scores in enumerate(queries_matches):

            if verbose:
                print("Saving query {} | query-id=={}:".format(query_pos, queries_df.iloc[query_pos]['query-id']))

            # Positions of the documents with a positive score, and those scores.
            positive_docs = np.nonzero(scores > 0)[0]
            positive_scores = scores[positive_docs]

            # argsort is ascending; reversing it gives the highest score first.
            descending = np.argsort(positive_scores)[::-1]

            if verbose:
                print("relevant_docs.shape={}".format(positive_docs.shape))

            ranked_docs = positive_docs[descending]
            ranked_scores = positive_scores[descending]

            if verbose:
                print("relevant_docs_final_result: {}\n\n".format(ranked_docs))

            top_docs = ranked_docs[:MAX_RESULTS_TO_SAVE]

            for rank, (doc_pos, doc_score) in enumerate(zip(top_docs, ranked_scores)):
                result_line = TREC_RESULT_LINE_FORMAT.format(queries_df.iloc[query_pos]['query-id'],
                                                             documents_df.iloc[doc_pos]['corpus-id'],
                                                             rank,
                                                             doc_score)
                run_file.write(result_line)

In [110]:
# Write the run file for all queries, printing per-query details along the way.
generate_trec_format(queries_matches=related_documents_per_query,
                     queries_df=valid_queries_df,
                     documents_df=trec_covid_corpus_df,
                     output_filename=run_filename,
                     verbose=True)

Saving query 0 | query-id==1:
relevant_docs.shape=(171181,)
relevant_docs_final_result: [145417 106891  30254 ...   8109 166335  12034]


Saving query 1 | query-id==2:
relevant_docs.shape=(170937,)
relevant_docs_final_result: [127769 162571  76247 ...  16555  24589  10141]


Saving query 2 | query-id==3:
relevant_docs.shape=(171092,)
relevant_docs_final_result: [142069 124015 102328 ...  97154  45423  17227]


Saving query 3 | query-id==4:
relevant_docs.shape=(170977,)
relevant_docs_final_result: [ 68569 138474 116049 ...  15062  25524  18412]


Saving query 4 | query-id==5:
relevant_docs.shape=(170984,)
relevant_docs_final_result: [101920 127161  69553 ...  10174  10141  24494]


Saving query 5 | query-id==6:
relevant_docs.shape=(171192,)
relevant_docs_final_result: [ 68580  68122 111656 ...  35295  12931  86423]


Saving query 6 | query-id==7:
relevant_docs.shape=(170431,)
relevant_docs_final_result: [149708 169667  79853 ... 145410  12588  31503]


Saving query 7 | query-id==8:
rele

## Now compute the evaluation metrics

In [111]:
# Run the trec_eval binary built under pyserini/tools/eval on the qrels file and the run
# file written above, requesting the recall.1000, map, ndcg_cut.10 and recip_rank.100
# measures. {TREC_COVID_QRELS} and {run_filename} are filled in from the Python variables.
!pyserini/tools/eval/trec_eval.9.0.4/trec_eval -c -mrecall.1000 -mmap -mndcg_cut.10 -mrecip_rank.100 \
    {TREC_COVID_QRELS} {run_filename}

map                   	all	0.2191
recip_rank            	all	0.8833
recall_1000           	all	0.4290
ndcg_cut_10           	all	0.7063
