# Import

In [None]:
import numbers

import pandas as pd
import torch as tc
import torch.nn.functional as F
from torch.utils.data import Dataset
from transformers import BertTokenizer, BertModel

import pyterrier as pt

# Start PyTerrier only once per kernel; the guard makes re-running this cell safe.
if not pt.started():
    pt.init()

# Place the encoders on the GPU when PyTorch can see one, otherwise on the CPU.
device = 'cuda' if tc.cuda.is_available() else 'cpu'

# Number of query-document pairs encoded per forward pass.
BATCH_SIZE = 30

In [None]:
def batch(iterable, n=1):
    """Yield consecutive slices of *iterable* holding at most *n* items each.

    *iterable* must support ``len()`` and slicing. As a progress indicator,
    the start offset of every tenth slice (0, 10*n, 20*n, ...) is printed.
    """
    total = len(iterable)
    report_every = 10 * n
    for start in range(0, total, n):
        stop = min(start + n, total)
        if start % report_every == 0:
            print(start)
        yield iterable[start:stop]
def flatten(t):
    """Concatenate the items of every sub-iterable of *t* into one flat list."""
    flat = []
    for sublist in t:
        flat.extend(sublist)
    return flat

# Dataset

In [None]:
class TextsDataset(Dataset):
  '''
  Lazily reads the documents of query-document pairs from disk.

  Takes a pair ``(qids, docnos)`` of equally long, indexable sequences: the
  query id at position ``i`` and the document number at position ``i`` form
  one pair to be scored. The text of document ``docno`` is read from the file
  ``text_path + docno + '.txt'`` (``text_path`` is a plain string prefix, so a
  directory path must end with a separator).

  Indexing with an integer (including NumPy integer scalars) returns
  ``(text, qid)``. Any other index, e.g. a slice, is forwarded to the
  underlying sequences and returns ``(list_of_texts, qids[idx])``.
  '''
  def __init__(self, query_document_pairs, text_path):
    self.qids = query_document_pairs[0]
    self.docnos = query_document_pairs[1]
    self.text_path = text_path

  def __len__(self):
    return len(self.qids)

  def _read(self, docno):
    '''Return the full text of document *docno*.'''
    # Explicit encoding so multilingual documents decode the same on every
    # platform (assumes the .txt files are UTF-8 encoded).
    path = self.text_path + str(docno) + '.txt'
    with open(path, encoding='utf-8') as f:
      return f.read()

  def __getitem__(self, idx):
    # numbers.Integral also matches NumPy integer scalars. A `type(idx) == int`
    # check would send those down the batch path, which iterates the single
    # docno string and reads one file per character.
    if isinstance(idx, numbers.Integral):
      return self._read(self.docnos[idx]), self.qids[idx]
    docnos = self.docnos[idx]
    return [self._read(d) for d in docnos], self.qids[idx]



# Encode text with BERT

In [None]:
def z_norm(inputs, dim=0, eps=1e-9):
    """Standardise *inputs* to zero mean and unit variance along *dim*.

    Uses the biased (population) variance. *eps* is added under the square
    root, so a constant slice maps to 0 instead of dividing by zero. In
    particular, when ``inputs.size(dim) == 1`` the result is all zeros.

    Args:
        inputs: floating-point tensor.
        dim: dimension to standardise over (default 0, the original behaviour).
        eps: numerical-stability term added to the variance.

    Returns:
        Tensor with the same shape as *inputs*.
    """
    mean = inputs.mean(dim, keepdim=True)
    var = inputs.var(dim, unbiased=False, keepdim=True)
    return (inputs - mean) / tc.sqrt(var + eps)

In [None]:
def mean_pooling(model_output, attention_mask):
    """Average the token vectors of *model_output* over unmasked positions.

    *model_output* is the token-state tensor itself (presumably
    ``(batch, seq, hidden)``; the sum runs over dim 1), and *attention_mask*
    marks real tokens with 1 and padding with 0. The token count is clamped
    at 1e-9, so a fully masked row yields zeros instead of NaN.
    """
    mask = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
    summed = (model_output * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts

In [None]:
def encode(tokenizer, model, text, norm = False):
  '''
  Embed *text* with *model* and return L2-normalised, mean-pooled vectors.

  Args:
    tokenizer: tokenizer paired with *model*; called with return_tensors="pt".
    model: encoder whose output exposes ``last_hidden_state``.
    text: a string or a list of strings (one embedding row per string).
    norm: if True, apply ``z_norm`` to the token states before pooling.
      NOTE(review): ``z_norm`` standardises over dim 0 of the token states,
      presumably the batch dimension. For a single string (as
      ``get_queries_representations`` passes) every value would become 0.
      Confirm this is intended before using norm=True.

  Returns:
    Tensor with one unit-length (p=2) row per input string, on the model's device.
  '''
  with tc.no_grad():
    # Pad only to the longest text in the batch. The attention mask is passed
    # to the model and pads are excluded from the mean pooling, so padding
    # every input to 512 tokens only added compute.
    encoded_input = tokenizer(text, truncation=True, return_tensors="pt", max_length=512, padding=True)
    # Send the inputs to wherever the model's weights live.
    encoded_input = encoded_input.to(model.device)
    output = model(**encoded_input).last_hidden_state
    if norm:
      output = z_norm(output)
    embeddings = mean_pooling(output, encoded_input['attention_mask'])
    return F.normalize(embeddings, p=2, dim=1)

# Encode queries

In [None]:
def get_queries_representations(tokenizer, model, topics):
  '''
  Embed every topic of *topics* as one query vector.

  The query text is the topic's ``title`` and ``description`` joined with
  ``'. '``. The result maps each topic's ``identifier`` to the tensor that
  ``encode`` returns for that single text.
  '''
  return {
    row['identifier']: encode(tokenizer, model, '. '.join([row['title'], row['description']]))
    for _, row in topics.iterrows()
  }

In [None]:
def encode_queries(qids, queries_embeddings):
  '''Concatenate the cached embeddings of *qids*, in order, along dim 0.'''
  rows = []
  for qid in qids:
    rows.append(queries_embeddings[qid])
  return tc.cat(rows)

# Encode and score texts batch

In [None]:
def encode_score_texts(tokenizer, model, query_embs, texts):
  '''
  Score each text against the query embedding in the same row.

  Args:
    tokenizer, model: passed through to ``encode``.
    query_embs: tensor whose i-th row is the query embedding paired with
      ``texts[i]`` (as built by ``encode_queries``).
    texts: list of document texts.

  Returns:
    List of floats. Element i is the dot product of query row i and the
    embedding of ``texts[i]``. Because ``encode`` L2-normalises its output,
    this is a cosine similarity when the queries were also produced by ``encode``.
  '''
  texts_embs = encode(tokenizer, model, texts)
  # Row-wise dot product; avoids materialising the batch x batch matrix that
  # mm(...).diagonal() would build only to keep its diagonal.
  sim_scores = (query_embs * texts_embs).sum(dim=1)
  return sim_scores.detach().cpu().tolist()


# Load pretrained models

In [None]:
# Multilingual BERT (cased); the model is moved to `device`.
mbert_tokenizer, mbert_model = (
    BertTokenizer.from_pretrained('bert-base-multilingual-cased'),
    BertModel.from_pretrained("bert-base-multilingual-cased").to(device),
)


In [None]:
# NOTE(review): 'bert-base-m-cased_align' is presumably a local checkpoint
# directory (an aligned multilingual BERT) — confirm it exists before running.
aligned_mbert_tokenizer, aligned_mbert_model = (
    BertTokenizer.from_pretrained('bert-base-m-cased_align'),
    BertModel.from_pretrained("bert-base-m-cased_align").to(device),
)


In [None]:
# German BERT (cased); the model is moved to `device`.
german_tokenizer, german_model = (
    BertTokenizer.from_pretrained("bert-base-german-cased"),
    BertModel.from_pretrained("bert-base-german-cased").to(device),
)

In [None]:
from transformers import FlaubertModel, FlaubertTokenizer

# FlauBERT base (cased): the model (moved to `device`) is loaded first, then
# its tokenizer with lower-casing disabled.
french_model, french_tokenizer = (
    FlaubertModel.from_pretrained('flaubert/flaubert_base_cased').to(device),
    FlaubertTokenizer.from_pretrained('flaubert/flaubert_base_cased', do_lowercase=False),
)

In [None]:
# English BERT (cased); the model is moved to `device`.
english_tokenizer, english_model = (
    BertTokenizer.from_pretrained("bert-base-cased"),
    BertModel.from_pretrained("bert-base-cased").to(device),
)

In [None]:
from transformers import BertTokenizer, BertModel

# Spanish BERT (dccuchile, whole-word-masking, cased); the model is moved to `device`.
spanish_tokenizer, spanish_model = (
    BertTokenizer.from_pretrained("dccuchile/bert-base-spanish-wwm-cased"),
    BertModel.from_pretrained("dccuchile/bert-base-spanish-wwm-cased").to(device),
)

# Experiment

In [None]:
def run_experiment(tokenizer, model, dataset, topics, result_path, qrels,
                   batch_size=None,
                   metrics=("map", "ndcg_cut_5", "ndcg_cut_10")):
  '''
  Score every query-document pair of *dataset*, save the run to CSV and print its evaluation.

  Args:
    tokenizer, model: encoder used for both the queries and the documents.
    dataset: sliceable dataset (such as ``TextsDataset``) exposing ``qids``
      and ``docnos``; slicing it yields ``(texts, qids)``.
    topics: DataFrame with ``identifier``, ``title`` and ``description`` columns.
    result_path: destination of the scored run, written with ``DataFrame.to_csv``.
    qrels: relevance judgements passed to ``pt.Utils.evaluate``.
    batch_size: pairs encoded per forward pass; ``None`` uses the global
      ``BATCH_SIZE`` as read at call time.
    metrics: evaluation measures passed (as a list) to ``pt.Utils.evaluate``.
  '''
  if batch_size is None:
    batch_size = BATCH_SIZE
  queries_embeddings = get_queries_representations(tokenizer, model, topics)
  scores = []
  for texts, qids in batch(dataset, batch_size):
    query_embs = encode_queries(qids, queries_embeddings)
    scores.extend(encode_score_texts(tokenizer, model, query_embs, texts))
  sim_scores = pd.DataFrame(data={"docno": dataset.docnos, "qid": dataset.qids, "score": scores})
  sim_scores.to_csv(result_path)
  print(pt.Utils.evaluate(sim_scores, qrels, list(metrics)))