<a href="https://colab.research.google.com/github/eduseiti/ia368v_dd_class_07/blob/main/exhaustive_search.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import pickle

import numpy as np
import torch
from scipy import stats

In [None]:
from datetime import datetime

In [None]:
import pandas as pd

In [None]:
# Pickle holding the encoded queries/corpus evaluated by this notebook.
# Files used in earlier runs:
#   trec_covid_encoded_data_checkpoint_20230416_230224_0.0495.pkl
#   trec_covid_encoded_data_checkpoint_20230417_153108_0.0491.pkl
RESULT_FILE = "trec_covid_encoded_data_checkpoint_20230418_050008_0.0294.pkl"

In [None]:
class TextToEncodeDataset(torch.utils.data.Dataset):
    """Tokenized texts exposed as a torch ``Dataset`` of input_ids / attention_mask.

    The tokenizer is asked to return overflowing tokens, so a long text may be split
    into several chunks; ``overflow_to_sample_mapping`` (when the tokenizer provides it)
    gives, for each chunk, the index of the original text it came from.

    NOTE(review): the ``corpus_dataset`` entry of the loaded pickle is used through this
    interface (``tokenized_texts``, ``get_original_index``) and is presumably an
    instance of this class, so the class must be defined before unpickling.
    """

    def __init__(self, texts_list, tokenizer, max_length=None):
        """Tokenize ``texts_list`` and print token-length statistics.

        Args:
            texts_list: texts to tokenize.
            tokenizer: callable invoked as ``tokenizer(texts_list, truncation=True,
                return_overflowing_tokens=True, max_length=..., return_length=True)``;
                its result must provide 'input_ids', 'attention_mask' and 'length'
                (presumably a Hugging Face tokenizer).
            max_length: truncation length forwarded to the tokenizer.
        """
        self.max_length = max_length

        self.tokenized_texts = tokenizer(texts_list,
                                         truncation=True,
                                         return_overflowing_tokens=True,
                                         max_length=max_length,
                                         return_length=True)

        self.original_length = len(texts_list)
        # Summary statistics (count, min/max, mean, variance, ...) of chunk token lengths.
        self.length_stats = stats.describe(self.tokenized_texts['length'])

        print("Text tokens size stats:\n{}\n".format(self.length_stats))

        if (max_length is not None) and 'overflow_to_sample_mapping' in self.tokenized_texts:
            chunk_count = len(self.tokenized_texts['overflow_to_sample_mapping'])
            if self.original_length < chunk_count:
                print("Added {} overflowing texts...".format(chunk_count - self.original_length))

    def __len__(self):
        """Number of tokenized chunks (can exceed the number of original texts)."""
        return len(self.tokenized_texts['input_ids'])

    def __getitem__(self, index):
        """Return the token ids and attention mask of chunk ``index``."""
        return {'input_ids': self.tokenized_texts['input_ids'][index],
                'attention_mask': self.tokenized_texts['attention_mask'][index]}

    def get_original_index(self, tokenized_documents_indexes):
        """Map tokenized-chunk indexes back to indexes of the original texts.

        Args:
            tokenized_documents_indexes: chunk indexes (sequence, numpy array or CPU tensor).

        Returns:
            numpy array with the original-text index of each given chunk. When the
            tokenizer produced no ``overflow_to_sample_mapping``, each chunk is its own
            text, so the indexes are returned unchanged instead of ``None``.
        """
        chunk_indexes = np.asarray(tokenized_documents_indexes)
        if 'overflow_to_sample_mapping' not in self.tokenized_texts:
            return chunk_indexes
        return np.asarray(self.tokenized_texts['overflow_to_sample_mapping'])[chunk_indexes]

In [None]:
# NOTE: pickle.load can execute arbitrary code; only open result files you produced yourself.
with open(RESULT_FILE, mode="rb") as pickled_results:
    encoded_data = pickle.load(pickled_results)

In [None]:
# Top-level entries stored in the loaded pickle.
encoded_data.keys()

dict_keys(['encoded_queries', 'encoded_corpus', 'queries_dataset', 'corpus_dataset'])

In [None]:
# Shape of the encoded query matrix as stored in the pickle.
encoded_data['encoded_queries'].shape

(50, 384)

In [None]:
# Shape of the encoded corpus matrix; its second dimension must match the queries'.
encoded_data['encoded_corpus'].shape

(182372, 384)

In [None]:
# Prefer the GPU when one is visible to torch, otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [None]:
# Wrap the pickled numpy matrices as torch tensors and grab the tokenized corpus dataset.
encoded_queries, encoded_corpus = (
    torch.from_numpy(encoded_data[key]) for key in ('encoded_queries', 'encoded_corpus')
)
corpus_to_encode = encoded_data['corpus_dataset']

In [None]:
# NOTE(review): Tensor.to() returns a new tensor and does not modify its operand, so both
# results are discarded; encoded_queries / encoded_corpus stay on the CPU (where
# torch.from_numpy created them) and the matmul below runs on the CPU. The CUDA tensor
# shown in the output is only the displayed value of the last expression.
# Reassigning (x = x.to(device)) would additionally require converting the argsort
# indexes to numpy before the ranking loop, which uses them to index numpy arrays.
encoded_queries.to(device)
encoded_corpus.to(device)

tensor([[-0.7183,  0.1911, -0.5156,  ..., -0.0823,  0.0640, -0.3524],
        [-0.1416,  0.5256, -0.0711,  ...,  0.4958, -0.1396,  0.1306],
        [-0.6621,  0.3241, -0.1063,  ...,  0.6668, -0.1499,  0.0558],
        ...,
        [-0.6558,  0.1820, -0.5903,  ...,  0.1376,  0.0290,  0.0267],
        [-0.2587,  0.0220, -0.6069,  ...,  0.0166,  0.0696, -0.1629],
        [-0.1113,  0.0995, -0.0934,  ...,  0.3468, -0.1816, -0.0892]],
       device='cuda:0')

In [None]:
# Same shape as the numpy array it was created from.
encoded_queries.shape

torch.Size([50, 384])

In [None]:
# Same shape as the numpy array it was created from.
encoded_corpus.shape

torch.Size([182372, 384])

In [None]:
# Exhaustive dot-product scoring: one row per query, one column per corpus entry.
all_scores = encoded_queries @ encoded_corpus.T

In [None]:
# (number of queries, number of corpus entries).
all_scores.shape

torch.Size([50, 182372])

In [None]:
# Scores of the first query against every corpus entry.
all_scores[0]

tensor([56.7614, 53.4069, 56.4023,  ..., 58.6083, 63.8573, 67.1564])

In [None]:
# For every query, corpus indexes ordered from highest to lowest score.
all_scores_descending_indexes = all_scores.argsort(dim=1, descending=True)

In [None]:
# Full ranking of the corpus for every query (same shape as all_scores).
all_scores_descending_indexes.shape

torch.Size([50, 182372])

In [None]:
# Ten best scores for the first query; should be non-increasing.
all_scores[0][all_scores_descending_indexes[0][:10]]

tensor([78.1097, 77.9512, 77.7971, 77.6878, 77.6312, 77.6312, 77.6091, 77.5471,
        77.5236, 77.2872])

In [None]:
# Pyserini checkout providing tools/eval/trec_eval. Override with the PYSERINI_FOLDER
# environment variable instead of editing this machine-specific default path.
PYSERINI_FOLDER = os.environ.get("PYSERINI_FOLDER", "/mnt/0060f889-4c27-409b-b0de-47f5427515e3/unicamp/ia368v_dd/pyserini")

In [None]:
# Run file name: checkpoint tag and timestamp are filled in when the run is written.
PYSERINI_TEST_RUN_DPR_FILENAME_FORMAT = "run.trec-covid_DPR_exhaustive_{}_{}.txt"

RUNS_FOLDER = "runs"

# Checkpoint tag for the run file name, derived from RESULT_FILE so the run is always
# labelled with the encodings actually loaded (a hand-maintained copy drifted out of sync),
# e.g. "trec_covid_encoded_data_checkpoint_X.pkl" -> "checkpoint_X".
BEST_MODEL_CHECKPOINT = os.path.splitext(os.path.basename(RESULT_FILE))[0].replace("trec_covid_encoded_data_", "", 1)

# When a document was split into several chunks, score it by the mean of its chunks' scores.
COMPUTE_MEAN_OF_REPEATED_DOCS = True

In [None]:
# One TREC run line: query id, literal Q0, document id, rank, score, run tag.
TREC_RESULT_LINE_FORMAT = "{}\tQ0\t{}\t{}\t{}\tminiLM_DPR\n"

test_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# The run file is opened for writing later; make sure its folder exists first.
os.makedirs(RUNS_FOLDER, exist_ok=True)

run_filename = os.path.join(RUNS_FOLDER, PYSERINI_TEST_RUN_DPR_FILENAME_FORMAT.format(BEST_MODEL_CHECKPOINT, test_timestamp))

In [None]:
# Tab-separated inputs: merged query/document data, corpus texts, and the qrels given to trec_eval.
TREC_COVID_MERGED_FILE = "trec_covid_merged_data.tsv"
TREC_COVID_DOCUMENTS_FILE = "trec_covid_original_title_text_merged.tsv"
TREC_COVID_QRELS = "trec_covid_qrels.tsv"

In [None]:
trec_covid_merged_df = pd.read_csv(TREC_COVID_MERGED_FILE, sep="\t")

# One row per distinct query, sorted by query-id; row i supplies the id of the i-th
# encoded query when the run file is written.
# NOTE(review): assumes the queries were encoded in ascending query-id order — confirm.
valid_queries_df = (
    trec_covid_merged_df[["query-id", "query-text"]]
    .drop_duplicates()
    .sort_values("query-id")
    .reset_index(drop=True)
)

In [None]:
# Header-less two-column corpus file; row position is used as the document index later on.
trec_covid_corpus_df = pd.read_csv(
    TREC_COVID_DOCUMENTS_FILE,
    sep="\t",
    header=None,
    names=["corpus-id", "corpus-title-text"],
)

In [None]:
# Number of tokenized corpus chunks; should equal the number of rows of encoded_corpus.
len(corpus_to_encode.tokenized_texts['input_ids'])

182372

In [None]:
# Bring scores and rankings to host memory as numpy arrays for the run-writing loop.
# Both results are kept (the ranking conversion used to be discarded), and the guards
# make re-running this cell harmless once the values are already numpy arrays.
if torch.is_tensor(all_scores):
    all_scores = all_scores.cpu().numpy()
if torch.is_tensor(all_scores_descending_indexes):
    all_scores_descending_indexes = all_scores_descending_indexes.cpu().numpy()
all_scores_descending_indexes

array([[100049,  31107, 153966, ...,  19362,  20271,  20379],
       [ 96503, 123459,  65673, ...,  66621,   8942,  22943],
       [132658, 119805, 111487, ...,  21539,   8942,  99309],
       ...,
       [ 47128, 110295,  23656, ...,  75902, 104825, 104824],
       [140987, 118241,  58782, ...,  26673,  21205,  99309],
       [106647, 102087, 158532, ..., 165673,  20379,  21205]])

In [None]:
# Number of best-scoring corpus chunks taken per query before mapping them back to
# documents (after merging chunks of the same document, fewer documents may remain).
TOP_K_CHUNKS = 1000


def _to_numpy(values):
    """Return ``values`` as a numpy array, copying torch tensors to host memory first."""
    if torch.is_tensor(values):
        return values.detach().cpu().numpy()
    return np.asarray(values)


def _average_repeated_documents(document_indexes, chunk_scores):
    """Collapse chunks of the same document into one entry scored by the chunks' mean.

    Args:
        document_indexes: original-document index of each chunk (1-D).
        chunk_scores: score of each chunk, aligned with ``document_indexes``.

    Returns:
        ``(unique_documents, mean_scores)`` ordered by descending mean score; ties keep
        ascending document index. Every score is a scalar, including single-chunk docs.
    """
    unique_documents, chunk_to_unique = np.unique(np.asarray(document_indexes).ravel(), return_inverse=True)
    score_sums = np.bincount(chunk_to_unique, weights=np.asarray(chunk_scores).ravel())
    chunk_counts = np.bincount(chunk_to_unique)
    mean_scores = score_sums / chunk_counts
    order = np.argsort(-mean_scores, kind="stable")
    return unique_documents[order], mean_scores[order]


scores_per_query = _to_numpy(all_scores)
ranked_chunks_per_query = _to_numpy(all_scores_descending_indexes)

with open(run_filename, 'w') as outputFile:
    for query_index, query_scores in enumerate(scores_per_query):
        top_chunks = ranked_chunks_per_query[query_index][:TOP_K_CHUNKS]
        top_scores = query_scores[top_chunks]

        top_documents = corpus_to_encode.get_original_index(top_chunks)
        if top_documents is None:
            # No overflow mapping available: every chunk is a whole document.
            top_documents = top_chunks
        top_documents = np.asarray(top_documents)

        if COMPUTE_MEAN_OF_REPEATED_DOCS:
            top_documents, top_scores = _average_repeated_documents(top_documents, top_scores)

        query_id = valid_queries_df.iloc[query_index]['query-id']
        included_docs = set()

        for document_index, score in zip(top_documents, top_scores):
            # Without averaging, a document appears once per retrieved chunk; keep only
            # its best-ranked occurrence.
            if document_index in included_docs:
                print("Ignoring document={} as it is already in the answers set for query={}".format(document_index, query_index))
                continue
            included_docs.add(document_index)

            # Rank = number of lines written so far for this query (no gaps for skipped
            # duplicates); float() keeps the score column a plain number.
            outputFile.write(TREC_RESULT_LINE_FORMAT.format(query_id,
                                                            trec_covid_corpus_df.iloc[document_index]['corpus-id'],
                                                            len(included_docs),
                                                            float(score)))

In [None]:
# Evaluate the run file against the qrels with Pyserini's bundled trec_eval:
# MAP, reciprocal rank (-mrecip_rank.100), recall@1000 and nDCG@10.
!{PYSERINI_FOLDER}/tools/eval/trec_eval.9.0.4/trec_eval -c -mrecall.1000 -mmap -mndcg_cut.10 -mrecip_rank.100 \
    {TREC_COVID_QRELS} {run_filename}

map                   	all	0.0532
recip_rank            	all	0.5791
recall_1000           	all	0.1842
ndcg_cut_10           	all	0.3104
