In [51]:
import torch
import torch.nn.functional as F
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification #Can use the texar equivalent here
from elasticsearch import Elasticsearch
import csv
import os

### Configuration - Goes into config.yaml

In [72]:
# Tab-separated input queries and the relevance judgments used for evaluation.
input_file = 'data/collectionandqueries/queries.dev.small.tsv'
ground_truth_file = 'data/collectionandqueries/qrels.dev.small.tsv'
# Ranked results are written here as TSV rows of (query id, doc id, rank).
output_file = 'output/results_dev.tsv'

# Elasticsearch settings.
host = 'localhost:9200'
index_name = 'elastic_index'
size = 100 # Number of hits requested per query. For testing purposes - Use 1000 for full-ranking

# Pretrained checkpoint used for re-ranking and the max_length passed to its tokenizer.
model_name = 'amberoad/bert-multilingual-passage-reranking-msmarco'
max_seq_length = 512

### Full-ranking (BM25 retrieval via Elasticsearch)

In [42]:
# Connect to the Elasticsearch node named by `host` in the configuration
# instead of relying on the client's implicit default address, so that
# editing the configuration actually changes where queries are sent.
es = Elasticsearch(hosts=[host])

In [43]:
def es_search(query_text):
    """Run a `match` query on the `content` field and return the raw hits.

    Uses the module-level `es` client together with the `index_name` and
    `size` configuration values.

    Args:
        query_text: Text of the query to search for.

    Returns:
        The list stored under ``results['hits']['hits']`` of the search
        response; at most `size` hits are requested.
    """
    query_body = {"query": {"match": {'content': query_text}}, "size": size}
    # Use the configured index rather than a hard-coded literal so that
    # changing `index_name` in the configuration takes effect.
    results = es.search(index=index_name, body=query_body)
    return results['hits']['hits']

### Re-ranking (BERT passage re-ranker)

In [44]:
# Tokenizer and Model
# Both are loaded once from the checkpoint named by `model_name` and are
# then used as module-level globals inside process_query.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

In [73]:
def process_query(query_pack, batch_size=32):
    """Retrieve candidates for one query, re-rank them and return their ranks.

    Candidates come from `es_search`; each (query, document) pair is then
    scored with the module-level `tokenizer` and `model`, and the documents
    are ordered by the softmax probability of class index 1.

    Args:
        query_pack: Sequence whose first item is the query id and whose
            second item is the query text.
        batch_size: Number of (query, document) pairs scored per forward
            pass. Scoring in chunks keeps memory bounded when `size` is
            large; it does not change which documents are scored.

    Returns:
        A list of ``[query_id, doc_id, rank]`` rows ordered by descending
        score, with ranks starting at 1. The list is empty when the search
        returns no hits.
    """
    query_id = query_pack[0]
    query_text = query_pack[1]

    # First-stage retrieval: candidate documents from Elasticsearch.
    hits = es_search(query_text)
    if not hits:
        # Nothing to re-rank; the original unzip would raise IndexError here.
        return []

    docs_id = [hit['_source']['doc_id'] for hit in hits]
    docs_content = [hit['_source']['content'] for hit in hits]

    model.eval()
    scores = []
    for start in range(0, len(docs_content), batch_size):
        batch_docs = docs_content[start:start + batch_size]

        # truncation=True makes max_length an enforced limit, so pairs longer
        # than max_seq_length are cut instead of relying on implicit behavior.
        encodings = tokenizer([query_text] * len(batch_docs), batch_docs,
                              padding=True, truncation=True,
                              max_length=max_seq_length, return_tensors='pt')

        with torch.no_grad():
            logits = model(**encodings)[0]

        # Keep plain Python floats so sorting does not compare tensors.
        scores.extend(F.softmax(logits, dim=1)[:, 1].tolist())

    ranked = sorted(zip(docs_id, scores), key=lambda pair: pair[1], reverse=True)
    doc_ranks = [[query_id, doc_id, rank]
                 for rank, (doc_id, _score) in enumerate(ranked, start=1)]

    return doc_ranks

### Main function

In [74]:
# Run limits for this cell.
max_queries = 10  # Set to None to rank every query in the input file (~7k queries)
log_every = 1     # Print a progress line every `log_every` queries

# Create the directory that will hold output_file (derived from the path
# itself rather than a hard-coded folder name).
output_dir = os.path.dirname(output_file)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

counter = 0

# The output file is opened once in write mode, which both truncates any
# previous results and avoids reopening the file for every query.
with open(input_file, 'r', encoding='utf-8') as file, \
        open(output_file, 'w', newline='', encoding='utf-8') as f:

    tsv_writer = csv.writer(f, delimiter='\t')

    for line in file:

        # Drop the line terminator so it does not end up in the query text.
        line = line.rstrip('\r\n')
        if not line:
            continue

        # Each line is "<query id>\t<query text>".
        query_pack = line.split('\t', 1)

        # Get the ranks after full-ranker and re-ranker
        doc_ranks = process_query(query_pack)

        # Write this query's rows and flush so finished queries are on disk
        # even if a later query fails.
        tsv_writer.writerows(doc_ranks)
        f.flush()

        counter += 1

        if counter % log_every == 0:
            print(f'Ranked {counter} queries')

        if max_queries is not None and counter >= max_queries:
            break

print(f'Completed ranking {counter} queries')

Ranked 1 queries
Ranked 2 queries
Ranked 3 queries
Ranked 4 queries
Ranked 5 queries
Ranked 6 queries
Ranked 7 queries
Ranked 8 queries
Ranked 9 queries
Ranked 10 queries
Completed ranking 10 queries


### Evaluation

In [75]:
# Compute evaluation metrics by comparing the ranked results in output_file
# (candidate) against ground_truth_file (reference).
from ms_marco_eval import compute_metrics_from_files
metrics = compute_metrics_from_files(path_to_reference = ground_truth_file, path_to_candidate = output_file)

In [76]:
# Display the metrics dictionary returned by compute_metrics_from_files.
# NOTE(review): when only a subset of queries is ranked, the reported MRR may
# be averaged over every query in the reference file rather than over the
# ranked ones — confirm against ms_marco_eval before interpreting the value.
metrics

{'MRR @10': 0.0005027971073816346, 'QueriesRanked': 10}