# Imports and setup

In [None]:
import numbers
import os

import pandas as pd
import torch as tc
import torch.nn.functional as F
from torch.utils.data import Dataset

# NOTE(review): not referenced anywhere in this section; presumably a length cap
# for documents/inputs used elsewhere — confirm before relying on it.
MAX_LENGTH = 10_240

import pyterrier as pt
# Skip pt.init() when PyTerrier reports it has already been started, so this
# cell can be re-run without initialising PyTerrier a second time.
pyterrier_already_started = pt.started()
if not pyterrier_already_started:
    pt.init()


# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = 'cuda' if tc.cuda.is_available() else 'cpu'


In [None]:
def batch(iterable, n=1, verbose=True):
    """Yield consecutive slices of ``iterable`` holding at most ``n`` items each.

    ``iterable`` must support ``len()`` and slicing; each yielded chunk is
    ``iterable[start:stop]``, so its type is whatever slicing returns.
    When ``verbose`` is true, the start offset is printed every 10 batches as a
    lightweight progress indicator.

    Raises:
        ValueError: if ``n`` is less than 1 (raised on the first iteration,
            since this is a generator). Previously a negative ``n`` silently
            produced no batches at all.
    """
    if n < 1:
        raise ValueError("batch size must be >= 1, got %r" % (n,))
    total = len(iterable)
    report_every = 10 * n
    for start in range(0, total, n):
        if verbose and start % report_every == 0:
            print(start)
        yield iterable[start:min(start + n, total)]
def flatten(t):
    """Concatenate an iterable of iterables into one list (one level deep)."""
    flat = []
    for chunk in t:
        flat.extend(chunk)
    return flat

# Dataset

In [None]:
class TextsDataset(Dataset):
  '''
  Dataset of query-document pairs whose document texts are read from disk.

  Requires a tuple ``(qids, docnos)`` where the query id at index ``idx`` and
  the document number at index ``idx`` form one query-document pair to be
  scored. Each document text is read on access from
  ``text_path + str(docno) + '.txt'``.
  '''
  def __init__(self, query_document_pairs, text_path):
    self.qids = query_document_pairs[0]
    self.docnos = query_document_pairs[1]
    # Used as a plain string prefix, so it normally needs a trailing path separator.
    self.text_path = text_path

  def __len__(self):
    return len(self.qids)

  def _read(self, docno):
    '''Return the full text of document ``docno``, decoded as UTF-8.'''
    path = self.text_path + str(docno) + '.txt'
    # Explicit encoding: relying on the platform default (e.g. cp1252 on
    # Windows) breaks on non-ASCII text.
    with open(path, encoding='utf-8') as f:
      return f.read()

  def __getitem__(self, idx):
    '''
    Return ``(text, qid)`` for an integer index, or ``(texts, qids)`` for a
    slice/sequence index, where ``texts`` is a list and ``qids`` is whatever
    ``self.qids[idx]`` returns.
    '''
    # numbers.Integral also matches NumPy integer scalars; with the old
    # ``type(idx) == int`` check they fell into the batch branch and the single
    # docno string was iterated character by character.
    if isinstance(idx, numbers.Integral):
      return self._read(self.docnos[idx]), self.qids[idx]
    texts = [self._read(d) for d in self.docnos[idx]]
    return texts, self.qids[idx]

In [None]:
def get_queries_text(topics):
  '''
  Map each query id in the ``topics`` DataFrame to its query text.

  If ``topics`` has a ``query`` column, ids come from ``qid`` and text from
  ``query``. Otherwise ids come from ``identifier`` and the text is
  ``title + '. ' + description``. Later rows overwrite earlier duplicates.
  '''
  if len(topics) == 0:
    return dict()
  if 'query' in topics.columns:
    pairs = zip(topics['qid'], topics['query'])
  else:
    pairs = (
      (ident, title + '. ' + desc)
      for ident, title, desc in zip(topics['identifier'], topics['title'], topics['description'])
    )
  return {qid: text for qid, text in pairs}

In [None]:
def encode_queries(qids, queries_dict):
  '''Concatenate the tensor stored for each qid, in ``qids`` order, via ``tc.cat``.'''
  tensors = []
  for qid in qids:
    tensors.append(queries_dict[qid])
  return tc.cat(tensors)

# Encode and score query-text pairs

In [None]:
def get_score_seq_cls(tokenizer, model, queries, texts, max_length=512):
  '''
  Score query-text pairs with a sequence-classification model.

  ``queries[i]`` is paired with ``texts[i]``. Every pair is truncated and padded
  to ``max_length`` tokens (previously hard-coded to 512, still the default).
  Inputs are moved to the module-level ``device`` before the forward pass.

  Returns a list with, for each pair, the softmax probability of class index 1.
  NOTE(review): this presumes a binary model whose label 1 means "relevant" —
  confirm for models other than the ones loaded in this notebook.
  '''
  features = tokenizer(queries, texts, truncation=True, padding='max_length',
                       max_length=max_length, return_tensors="pt").to(device)
  with tc.no_grad():
    logits = model(**features).logits
  # Computed under no_grad, so the tensor needs no .detach() before leaving torch.
  probs = logits.softmax(dim=-1).cpu().tolist()
  return [p[1] for p in probs]


In [None]:
def encode_score_texts(tokenizer, model, queries_dict, qids, texts):
  '''Look up the query text for each qid and score it against the text at the same position.'''
  query_texts = list(map(queries_dict.__getitem__, qids))
  return get_score_seq_cls(tokenizer, model, query_texts, texts)

# Experiment

In [None]:
def run_experiment(tokenizer, model, dataset, topics, result_path, qrels, batch_size=30):
  '''
  Score every (qid, docno) pair in ``dataset``, save the scores, and print metrics.

  Args:
    tokenizer, model: passed through to the scoring functions.
    dataset: a ``TextsDataset``; it is consumed in slices of ``batch_size`` pairs.
    topics: DataFrame accepted by ``get_queries_text``.
    result_path: CSV destination for the score table.
    qrels: relevance judgements handed to ``pt.Utils.evaluate``.
    batch_size: pairs scored per forward pass (previously hard-coded to 30).

  Returns:
    The DataFrame (columns ``docno``, ``qid``, ``score``) that was written to
    ``result_path``; previously nothing was returned.
  '''
  queries = get_queries_text(topics)
  batch_scores = [
    encode_score_texts(tokenizer, model, queries, qids, texts)
    for texts, qids in batch(dataset, batch_size)
  ]
  sim_scores = pd.DataFrame(data={"docno": dataset.docnos, "qid": dataset.qids, "score": flatten(batch_scores)})
  sim_scores.to_csv(result_path)
  print(pt.Utils.evaluate(sim_scores, qrels, ["map", "ndcg_cut_5", "ndcg_cut_10"]))
  return sim_scores

# mBERT passage reranking

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Hugging Face hub id of the reranker; judging by its name, a multilingual BERT
# fine-tuned for MS MARCO passage reranking.
MBERT_RERANKER = "amberoad/bert-multilingual-passage-reranking-msmarco"

tokenizer = AutoTokenizer.from_pretrained(MBERT_RERANKER)
model = AutoModelForSequenceClassification.from_pretrained(MBERT_RERANKER).to(device)

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Local directory or hub id of the fine-tuned reranker; must be set before running.
fine_tuned_model_path = ""

# Fail fast with an actionable message instead of passing an empty path to
# from_pretrained, which fails later with a less obvious error.
if not fine_tuned_model_path:
  raise ValueError("fine_tuned_model_path is empty: set it to the fine-tuned model directory before running this cell.")

ft_tokenizer = AutoTokenizer.from_pretrained(fine_tuned_model_path)

ft_model = AutoModelForSequenceClassification.from_pretrained(fine_tuned_model_path).to(device)