# SPLADE for Portuguese

Inspired by https://github.com/naver/splade

(Até o momento, tentei podar parte do código. Falta refatorar a parte relativa a configurações e entender o código em mais detalhes - ex.: a parte relativa aos dados, avaliação e indexação. Pode ser útil também quebrar este notebook em etapas, ou então mover as classes para arquivos .py.)

## Libraries installation

In [3]:
!pip install transformers -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m71.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m19.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m97.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [19]:
import array
import contextlib
import json
import os
import pickle
from abc import ABC
from collections import defaultdict

import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModel
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

## Training loop

In [5]:
class TransformerRep(torch.nn.Module):
  """Wraps a Hugging Face transformer and returns one chosen representation of its output.
  """

  def __init__(self, model_type_or_dir, output, fp16=False):
    """
    output indicates which representation(s) to output from transformer ("MLM" for MLM model)
    model_type_or_dir is either the name of a pre-trained model (e.g. bert-base-uncased), or the path to
    directory containing model weights, vocab etc.
    fp16 enables CUDA autocast around the forward pass
    """
    super().__init__()
    assert output in ("mean", "cls", "hidden_states", "MLM"), "provide valid output"
    model_class = AutoModel if output != "MLM" else AutoModelForMaskedLM
    self.transformer = model_class.from_pretrained(model_type_or_dir)
    self.tokenizer = AutoTokenizer.from_pretrained(model_type_or_dir)
    self.output = output
    self.fp16 = fp16

  def forward(self, **tokens):
    """Run the transformer once on `tokens` (output of the HF tokenizer) and pool as configured.

    Returns the raw model output for "MLM", a pooled tensor for "mean" / "cls", or the tuple
    (hidden_states, attention_mask) for "hidden_states".
    """
    with torch.cuda.amp.autocast() if self.fp16 else contextlib.nullcontext():
      out = self.transformer(**tokens)
      if self.output == "MLM":
        return out
      # reuse the output computed above: calling the transformer a second time would double the cost
      # => first element of the AutoModel output is the hidden states, shape (bs, seq_len, hidden_dim)
      hidden_states = out[0]
      if self.output == "mean":
        # average over non-padded positions only
        mask = tokens["attention_mask"]
        return torch.sum(hidden_states * mask.unsqueeze(-1), dim=1) / torch.sum(mask, dim=-1, keepdim=True)
      elif self.output == "cls":
        return hidden_states[:, 0, :]  # returns [CLS] representation
      else:
        # no pooling, we return all the hidden states (+ the attention mask)
        return hidden_states, tokens["attention_mask"]
        # no pooling, we return all the hidden states (+ the attention mask)

In [11]:
def normalize(tensor, eps=1e-9):
  """Divide the input by its norm taken over the last dimension.

  eps is added to the norm so that an all-zero vector does not cause a division by zero.
  """
  length = torch.norm(tensor, dim=-1, keepdim=True)
  return tensor / (length + eps)

In [12]:
class SiameseBase(torch.nn.Module, ABC):
  """Base class for siamese (bi-encoder) models: queries and documents are encoded independently.

  Subclasses implement `encode`; `forward` encodes whatever inputs are given and, when both a query
  and a document batch are present, also computes their matching score.
  """

  def __init__(self, model_type_or_dir, output, match="dot_product", model_type_or_dir_q=None, freeze_d_model=False,
                fp16=False):
    """
    model_type_or_dir: model name or directory for the document encoder (shared with queries by default)
    output: representation type forwarded to TransformerRep
    match: "dot_product" or "cosine_sim" (the latter normalizes the representations)
    model_type_or_dir_q: optional separate model for the query encoder
    freeze_d_model: disable gradients for the document encoder (requires a separate query encoder)
    fp16: enable CUDA autocast
    """
    super().__init__()
    self.output = output
    assert match in ("dot_product", "cosine_sim"), "specify right match argument"
    self.cosine = match == "cosine_sim"
    self.match = match
    self.fp16 = fp16
    self.transformer_rep = TransformerRep(model_type_or_dir, output, fp16)
    self.transformer_rep_q = TransformerRep(model_type_or_dir_q,
                                            output, fp16) if model_type_or_dir_q is not None else None
    assert not (freeze_d_model and model_type_or_dir_q is None)
    self.freeze_d_model = freeze_d_model
    if freeze_d_model:
      self.transformer_rep.requires_grad_(False)

  def encode(self, kwargs, is_q):
    """Turn tokenizer outputs into a representation; must be implemented by subclasses."""
    raise NotImplementedError

  def encode_(self, tokens, is_q=False):
    """Run the query encoder when is_q is set and one exists, otherwise the document encoder."""
    transformer = self.transformer_rep
    if is_q and self.transformer_rep_q is not None:
      transformer = self.transformer_rep_q
    return transformer(**tokens)

  def train(self, mode=True):
    """Set training mode; a frozen document encoder always stays in eval mode.

    Follows the torch.nn.Module.train contract: updates self.training and returns self, so that
    `model.train()` and `model.eval()` (which calls `train(False)`) can be chained.
    """
    self.training = mode
    if self.transformer_rep_q is None:  # only one model, life is simple
      self.transformer_rep.train(mode)
    else:  # possibly freeze d model
      self.transformer_rep_q.train(mode)
      mode_d = False if not mode else not self.freeze_d_model
      self.transformer_rep.train(mode_d)
    return self

  def forward(self, **kwargs):
    """forward takes as inputs 1 or 2 dict
    "d_kwargs" => contains all inputs for document encoding
    "q_kwargs" => contains all inputs for query encoding ([OPTIONAL], e.g. for indexing)

    Optional keys: "nb_negatives" (several documents per query) and "score_batch" (score every query
    against every document). Returns a dict with "d_rep" / "q_rep" and, when both are given, "score".
    """
    with torch.cuda.amp.autocast() if self.fp16 else contextlib.nullcontext():
      out = {}
      do_d, do_q = "d_kwargs" in kwargs, "q_kwargs" in kwargs
      if do_d:
        d_rep = self.encode(kwargs["d_kwargs"], is_q=False)
        if self.cosine:  # normalize embeddings
          d_rep = normalize(d_rep)
        out.update({"d_rep": d_rep})
      if do_q:
        q_rep = self.encode(kwargs["q_kwargs"], is_q=True)
        if self.cosine:  # normalize embeddings
          q_rep = normalize(q_rep)
        out.update({"q_rep": q_rep})
      if do_d and do_q:
        if "nb_negatives" in kwargs:
          # in the case of negative scoring, where there are several negatives per query
          bs = q_rep.shape[0]
          d_rep = d_rep.reshape(bs, kwargs["nb_negatives"], -1)  # shape (bs, nb_neg, out_dim)
          q_rep = q_rep.unsqueeze(1)  # shape (bs, 1, out_dim)
          score = torch.sum(q_rep * d_rep, dim=-1)  # shape (bs, nb_neg)
        else:
          if "score_batch" in kwargs:
            score = torch.matmul(q_rep, d_rep.t())  # shape (bs_q, bs_d)
          else:
            score = torch.sum(q_rep * d_rep, dim=1, keepdim=True)  # shape (bs, 1)
        out.update({"score": score})
    return out

In [7]:
class Splade(SiameseBase):
  """SPLADE model: sparse vocabulary-sized representations built from masked-LM logits.
  """

  def __init__(self, model_type_or_dir, model_type_or_dir_q=None, freeze_d_model=False, agg="max", fp16=True):
    """agg selects how per-token activations are pooled over the sequence: "sum" or "max"."""
    super().__init__(model_type_or_dir=model_type_or_dir,
                     output="MLM",
                     match="dot_product",
                     model_type_or_dir_q=model_type_or_dir_q,
                     freeze_d_model=freeze_d_model,
                     fp16=fp16)
    # the representation has one dimension per vocabulary entry (e.g. 30522 for BERT)
    self.output_dim = self.transformer_rep.transformer.config.vocab_size
    assert agg in ("sum", "max")
    self.agg = agg

  def encode(self, tokens, is_q):
    """Return log(1 + relu(logits)), zeroed on padded positions and pooled over the sequence."""
    logits = self.encode_(tokens, is_q)["logits"]  # shape (bs, pad_len, voc_size)
    padding_mask = tokens["attention_mask"].unsqueeze(-1)
    activations = torch.log(1 + torch.relu(logits)) * padding_mask
    if self.agg == "sum":
      return torch.sum(activations, dim=1)
    # masking with 0 is also valid for max since every activation is non-negative
    pooled, _ = torch.max(activations, dim=1)
    return pooled
        # 0 masking also works with max because all activations are positive

In [8]:
# presumably the seed for random number generators; no visible cell uses it yet - TODO confirm
random_seed = 123

In [13]:
def init_simple_bert_optim(model, lr, weight_decay, warmup_steps, num_training_steps):
  """Create an AdamW optimizer over all parameters of `model` and its linear-with-warmup scheduler.

  inspired from https://github.com/ArthurCamara/bert-axioms/blob/master/scripts/bert.py

  Returns the pair (optimizer, scheduler).
  """
  adamw = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
  lr_schedule = get_linear_schedule_with_warmup(
      optimizer=adamw,
      num_warmup_steps=warmup_steps,
      num_training_steps=num_training_steps,
  )
  return adamw, lr_schedule


In [14]:

class InBatchPairwiseNLL:
  """Negative log-likelihood loss using in-batch negatives plus one extra negative score per row.
  """

  def __init__(self):
    self.logsoftmax = torch.nn.LogSoftmax(dim=1)

  def __call__(self, out_d):
    """out_d holds "pos_score" (in-batch score matrix) and "neg_score" (extra negative column)."""
    pos = out_d["pos_score"]  # matrix of size bs * (bs / nb_gpus)
    neg = out_d["neg_score"]
    n_rows, n_cols = pos.shape[0], pos.shape[1]
    n_replicas = int(n_rows / n_cols)
    # append the negative score (from BM25 sampling) as an extra column
    # => shape (batch_size, batch_size/nb_gpus + 1)
    log_probs = self.logsoftmax(torch.cat([pos, neg], dim=1))
    row_idx = torch.arange(n_rows)
    # the positive of row i sits on the diagonal of its own (per-GPU) block
    target_idx = torch.arange(n_cols).repeat(n_replicas)
    return torch.mean(-log_probs[row_idx, target_idx])

In [15]:
def get_loss(config):
  """Instantiate the loss named by config["loss"]; only "InBatchPairwiseNLL" is supported.

  Raises NotImplementedError for any other name.
  """
  if config["loss"] != "InBatchPairwiseNLL":
    raise NotImplementedError("provide valid loss")
  return InBatchPairwiseNLL()

In [16]:
class FLOPS:
  """FLOPS regularizer, from Minimizing FLOPs to Learn Efficient Sparse Representations
  https://arxiv.org/abs/2004.05665
  """

  def __call__(self, batch_rep):
    """Sum over dimensions of the squared batch-mean absolute activation."""
    mean_abs_activation = torch.mean(torch.abs(batch_rep), dim=0)
    return torch.sum(mean_abs_activation ** 2)

In [17]:
class RegWeightScheduler:
  """Regularization weight schedule, as in: Minimizing FLOPs to Learn Efficient Sparse Representations
  https://arxiv.org/abs/2004.05665
  """

  def __init__(self, lambda_, T):
    """lambda_ is the final weight, T the number of steps needed to reach it."""
    self.lambda_ = lambda_
    self.T = T
    self.t = 0
    self.lambda_t = 0

  def step(self):
    """Advance one step (quadratic increase until time T) and return the current weight."""
    if self.t < self.T:
      self.t += 1
      self.lambda_t = self.lambda_ * (self.t / self.T) ** 2
    return self.lambda_t

  def get_lambda(self):
    """Return the current weight without advancing the schedule."""
    return self.lambda_t


In [None]:
# Initialize the regularizer dict: "eval" holds metrics that are always computed, "train" holds the
# regularizers listed in the config together with their weight schedulers.
# NOTE(review): `config`, `regularizer`, `model` and `init_regularizer` are not defined in the visible
# cells - they must come from elsewhere before this cell can run.
if "regularizer" in config and regularizer is None:  # else regularizer is loaded
  # unwrap a wrapper exposing `.module` (e.g. DataParallel) to reach output_dim
  output_dim = model.module.output_dim if hasattr(model, "module") else model.output_dim
  regularizer = {"eval": {"L0": {"loss": init_regularizer("L0")},
                          "sparsity_ratio": {"loss": init_regularizer("sparsity_ratio",
                                                                      output_dim=output_dim)}},
                  "train": {}}
  if config["regularizer"] == "eval_only":
    # just in the case we train a model without reg but still want the eval metrics like L0
    pass
  else:
    for reg in config["regularizer"]:
      temp = {"loss": init_regularizer(config["regularizer"][reg]["reg"]),
              "targeted_rep": config["regularizer"][reg]["targeted_rep"]}
      d_ = {}  # one weight scheduler per side (query / document) that has a lambda configured
      if "lambda_q" in config["regularizer"][reg]:
          d_["lambda_q"] = RegWeightScheduler(config["regularizer"][reg]["lambda_q"],
                                              config["regularizer"][reg]["T"])
      if "lambda_d" in config["regularizer"][reg]:
          d_["lambda_d"] = RegWeightScheduler(config["regularizer"][reg]["lambda_d"],
                                              config["regularizer"][reg]["T"])
      temp["lambdas"] = d_  # it is possible to have reg only on q or d if e.g. you only specify lambda_q
      # in the reg config
      # targeted_rep is just used to indicate which rep to constrain (if e.g. the model outputs several
      # representations)
      # the common case: model outputs "rep" (in forward) and this should be the value for this targeted_rep
      regularizer["train"][reg] = temp


Estou assumindo por enquanto que vamos iniciar o treinamento apenas com triplets.  Ainda não olhei os datasets.

In [None]:
class PairsDatasetPreLoad(Dataset):
  """
  dataset to iterate over a collection of pairs, format per line: q \t d_pos \t d_neg
  we preload everything in memory at init
  """

  def __init__(self, data_dir):
    """Read <data_dir>/raw.tsv; lines holding only a newline are skipped.

    Raises ValueError (from tuple unpacking) if a non-empty line does not have exactly 3 columns.
    """
    self.data_dir = data_dir
    self.id_style = "row_id"

    self.data_dict = {}  # => maps a contiguous example index to its (query, positive, negative) triplet
    print("Preloading dataset")
    self.data_dir = os.path.join(self.data_dir, "raw.tsv")
    with open(self.data_dir) as reader:
      for line in tqdm(reader):
        if len(line) > 1:
          query, pos, neg = line.split("\t")
          # key by the number of examples kept so far rather than by the raw line number: skipped
          # lines would otherwise leave gaps and make __getitem__ fail for indices below len(self)
          self.data_dict[len(self.data_dict)] = (query.strip(), pos.strip(), neg.strip())
    self.nb_ex = len(self.data_dict)

  def __len__(self):
    return self.nb_ex

  def __getitem__(self, idx):
    """Return the (query, positive, negative) triplet at position idx, 0 <= idx < len(self)."""
    return self.data_dict[idx]

In [20]:
def rename_keys(d, prefix):
  """Return a copy of d in which every key k becomes prefix + "_" + k (values are untouched)."""
  renamed = {}
  for key, value in d.items():
    renamed[prefix + "_" + key] = value
  return renamed


class DataLoaderWrapper(DataLoader):
  """DataLoader that owns a tokenizer and uses the subclass's `collate_fn` to build batches."""

  def __init__(self, tokenizer_type, max_length, **kwargs):
    """
    tokenizer_type: name or path given to AutoTokenizer.from_pretrained
    max_length: maximum sequence length, for use by collate_fn
    kwargs: forwarded to torch DataLoader; pin_memory defaults to True but can now be overridden
    """
    self.max_length = max_length
    self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_type)
    # keep the historical default, but do not fail with a duplicate keyword if the caller sets it
    kwargs.setdefault("pin_memory", True)
    super().__init__(collate_fn=self.collate_fn, **kwargs)

  def collate_fn(self, batch):
    """Turn a list of dataset items into a batch; must be implemented by subclasses."""
    raise NotImplementedError("must implement this method")

class SiamesePairsDataLoader(DataLoaderWrapper):
  """Siamese encoding (query and document independent)
  train mode (pairs)
  """

  def _tokenize(self, texts):
    """Tokenize one column of the batch, padded to its longest sequence and truncated to max_length."""
    return self.tokenizer(list(texts),
                          add_special_tokens=True,
                          padding="longest",  # pad to max sequence length in batch
                          truncation="longest_first",  # truncates to self.max_length
                          max_length=self.max_length,
                          return_attention_mask=True)

  def collate_fn(self, batch):
    """
    batch is a list of tuples, each tuple has 3 (text) items (q, d_pos, d_neg)
    returns a dict of tensors whose keys are the tokenizer keys prefixed by "q_", "pos_" and "neg_"
    """
    queries, positives, negatives = zip(*batch)
    sample = {}
    for prefix, texts in (("q", queries), ("pos", positives), ("neg", negatives)):
      sample.update(rename_keys(self._tokenize(texts), prefix))
    return {name: torch.tensor(values) for name, values in sample.items()}



In [None]:
# Preload the training triplets in memory.
# NOTE(review): `exp_dict` is not defined in the visible cells - it must be created before this runs.
data_train = PairsDatasetPreLoad(data_dir=exp_dict["data"]["TRAIN_DATA_DIR"])

In [None]:
# Optionally carve a validation split out of the training data to monitor the loss.
# NOTE(review): `exp_dict`, `config` and `drop_last` are not defined in the visible cells.
val_loss_loader = None  # default
if "VALIDATION_SIZE_FOR_LOSS" in exp_dict["data"]:
  print("initialize loader for validation loss")
  print("split train, originally {} pairs".format(len(data_train)))
  # random split: the last VALIDATION_SIZE_FOR_LOSS examples of the shuffled order go to validation;
  # data_train is rebound to the remaining subset
  data_train, data_val = torch.utils.data.random_split(data_train, lengths=[
      len(data_train) - exp_dict["data"]["VALIDATION_SIZE_FOR_LOSS"],
      exp_dict["data"]["VALIDATION_SIZE_FOR_LOSS"]])
  print("train: {} pairs ~~ val: {} pairs".format(len(data_train), len(data_val)))
  val_loss_loader = SiamesePairsDataLoader(dataset=data_val, batch_size=config["eval_batch_size"],
                                                shuffle=False,
                                                num_workers=4,
                                                tokenizer_type=config["tokenizer_type"],
                                                max_length=config["max_length"], drop_last=drop_last)


In [None]:
# Shuffled loader over the training triplets.
# NOTE(review): `config` and `drop_last` are not defined in the visible cells.
train_loader = SiamesePairsDataLoader(dataset=data_train, batch_size=config["train_batch_size"], shuffle=True,
                                              num_workers=4,
                                              tokenizer_type=config["tokenizer_type"],
                                              max_length=config["max_length"], drop_last=drop_last)

OBS.: Ainda não olhei o que o código abaixo está fazendo (o que seria full ranking?)

In [24]:
import os
import h5py
from tqdm.auto import tqdm
import numpy as np

class IndexDictOfArray:
  """Inverted index: for each dimension id, one posting list of document ids and one of values.

  The index is either built in memory (and optionally saved to an HDF5 file under index_path) or
  loaded from an existing file.
  """

  def __init__(self, index_path=None, force_new=False, filename="array_index.h5py", dim_voc=None):
    """
    index_path: directory of the index on disk; None keeps the index in memory only
    force_new: start an empty index even if a file already exists at index_path
    filename: name of the HDF5 file inside index_path
    dim_voc: number of posting lists to load; read from the file's "dim" entry when None
    """
    self.index_path = index_path
    self.filename = None
    if index_path is not None:
      os.makedirs(index_path, exist_ok=True)
      self.filename = os.path.join(index_path, filename)
      if os.path.exists(self.filename) and not force_new:
        self._load(dim_voc)
        return
    self.n = 0
    print("initializing new index...")
    # compact typed arrays ("I": unsigned int, "f": float) that can be appended to cheaply
    self.index_doc_id = defaultdict(lambda: array.array("I"))
    self.index_doc_value = defaultdict(lambda: array.array("f"))

  def _load(self, dim_voc):
    """Load the posting lists from self.filename and the document count from doc_ids.pkl."""
    print("index already exists, loading...")
    self.index_doc_id = dict()
    self.index_doc_value = dict()
    with h5py.File(self.filename, "r") as h5_file:
      dim = dim_voc if dim_voc is not None else h5_file["dim"][()]
      for key in tqdm(range(dim)):
        try:
          # ideally we would not convert to np.array() but we cannot give pool an object with hdf5
          doc_id = np.array(h5_file["index_doc_id_{}".format(key)], dtype=np.int32)
          doc_value = np.array(h5_file["index_doc_value_{}".format(key)], dtype=np.float32)
        except KeyError:
          # dimension without posting list in the file => empty posting list
          doc_id = np.array([], dtype=np.int32)
          doc_value = np.array([], dtype=np.float32)
        self.index_doc_id[key] = doc_id
        self.index_doc_value[key] = doc_value
    print("done loading index...")
    # NOTE: pickle must only be used on files produced by this pipeline (never on untrusted data)
    with open(os.path.join(self.index_path, "doc_ids.pkl"), "rb") as handler:
      doc_ids = pickle.load(handler)
    self.n = len(doc_ids)

  def add_batch_document(self, row, col, data, n_docs=-1):
    """add a batch of documents to the index

    row, col, data are parallel sequences: document id, dimension id and value of each entry.
    n_docs is the number of documents in the batch; when negative it is the number of distinct
    ids in row.
    """
    if n_docs < 0:
      self.n += len(set(row))
    else:
      self.n += n_docs
    for doc_id, dim_id, value in zip(row, col, data):
      self.index_doc_id[dim_id].append(doc_id)
      self.index_doc_value[dim_id].append(value)

  def __len__(self):
    """Number of posting lists currently in the index."""
    return len(self.index_doc_id)

  def nb_docs(self):
    """Number of documents added to (or loaded into) the index."""
    return self.n

  def save(self, dim=None):
    """Convert the posting lists to numpy, write them to the HDF5 file and dump their sizes to JSON.

    dim is stored as the "dim" entry of the file; the number of posting lists is used when it is
    not given. Raises ValueError if the index was created without index_path.
    """
    if self.filename is None:
      raise ValueError("cannot save an index created without index_path")
    print("converting to numpy")
    for key in tqdm(list(self.index_doc_id.keys())):
      self.index_doc_id[key] = np.array(self.index_doc_id[key], dtype=np.int32)
      self.index_doc_value[key] = np.array(self.index_doc_value[key], dtype=np.float32)
    print("save to disk")
    with h5py.File(self.filename, "w") as f:
      if dim:
        f.create_dataset("dim", data=int(dim))
      else:
        f.create_dataset("dim", data=len(self.index_doc_id.keys()))
      for key in tqdm(self.index_doc_id.keys()):
        f.create_dataset("index_doc_id_{}".format(key), data=self.index_doc_id[key])
        f.create_dataset("index_doc_value_{}".format(key), data=self.index_doc_value[key])
    print("saving index distribution...")  # => size of each posting list in a dict
    index_dist = {int(k): len(v) for k, v in self.index_doc_id.items()}
    with open(os.path.join(self.index_path, "index_dist.json"), "w") as handler:
      json.dump(index_dist, handler)


In [25]:
class L0:
  """Average number of non-zero entries per representation (non-differentiable).
  """

  def __call__(self, batch_rep):
    nonzero_per_rep = torch.count_nonzero(batch_rep, dim=-1)
    return nonzero_per_rep.float().mean()

In [26]:
def to_list(tensor):
  """Detach the tensor from the graph, move it to CPU and convert it to (nested) Python lists."""
  host_tensor = tensor.detach().cpu()
  return host_tensor.tolist()

In [None]:
from collections import defaultdict
import pickle
import json

class Evaluator:
  """Base class for model evaluation (inference).

  NOTE(review): a second, identical definition of this class appears further down in the notebook.
  """

  def __init__(self, model, config=None, restore=True):
    """
    model: model to evaluate; always left in eval mode
    config: mapping read for "pretrained_no_yamlconfig" and "checkpoint_dir" when restore is set
    restore: load the checkpoint <checkpoint_dir>/model/model.tar (unless "pretrained_no_yamlconfig"
             is set) and, when CUDA is available, move the model to GPU
    """
    self.model = model
    self.config = config
    on_gpu = torch.cuda.is_available()
    self.device = torch.device("cuda") if on_gpu else torch.device("cpu")
    if not restore:
      print("WARNING: init evaluator, NOT restoring the model, NOT placing on device")
    else:
      if "pretrained_no_yamlconfig" not in config or not config["pretrained_no_yamlconfig"]:
        checkpoint_path = os.path.join(config["checkpoint_dir"], "model/model.tar")
        if on_gpu:
          checkpoint = torch.load(checkpoint_path)
        else:
          checkpoint = torch.load(checkpoint_path, map_location=self.device)
        restore_model(model, checkpoint["model_state_dict"])
        if on_gpu:
          print("restore model on GPU at {}".format(checkpoint_path))
        else:
          print("restore model on CPU at {}".format(checkpoint_path))
      if on_gpu:
        self.model.eval()
        if torch.cuda.device_count() > 1:
          print(" --- use {} GPUs --- ".format(torch.cuda.device_count()))
          self.model = torch.nn.DataParallel(self.model)
        self.model.to(self.device)
    self.model.eval()  # => put in eval mode


# Is this the part related to the inverted index? TODO: not reviewed yet
class SparseIndexing(Evaluator):
  """sparse indexing: encodes a collection and stores the non-zero entries in an IndexDictOfArray
  """

  def __init__(self, model, config, compute_stats=False, dim_voc=None, is_query=False, force_new=True,**kwargs):
    """
    config: read for "index_dir"; with config=None the index is kept in memory and returned by index()
    compute_stats: also compute the average L0 of the encoded representations
    dim_voc, force_new: forwarded to IndexDictOfArray
    is_query: encode the collection with the query encoder instead of the document encoder
    kwargs: forwarded to Evaluator
    """
    super().__init__(model, config, **kwargs)
    self.index_dir = config["index_dir"] if config is not None else None
    self.sparse_index = IndexDictOfArray(self.index_dir, dim_voc=dim_voc, force_new=force_new)
    self.compute_stats = compute_stats
    self.is_query = is_query
    if self.compute_stats:
      self.l0 = L0()

  def index(self, collection_loader, id_dict=None):
    """Encode every batch of collection_loader and add its non-zero entries to the index.

    id_dict optionally maps the ids found in the batches to other ids. When index_dir is set, the
    index, doc_ids.pkl and (with compute_stats) index_stats.json are written there and nothing is
    returned; otherwise a dict with "index", "ids_mapping" and optionally "stats" is returned.
    """
    doc_ids = []
    if self.compute_stats:
      stats = defaultdict(float)
    count = 0  # number of documents indexed so far = row offset of the current batch
    with torch.no_grad():
      for batch in tqdm(collection_loader):
        inputs = {k: v.to(self.device) for k, v in batch.items() if k not in {"id"}}
        if self.is_query:
          batch_documents = self.model(q_kwargs=inputs)["q_rep"]
        else:
          batch_documents = self.model(d_kwargs=inputs)["d_rep"]
        if self.compute_stats:
          stats["L0_d"] += self.l0(batch_documents).item()
        row, col = torch.nonzero(batch_documents, as_tuple=True)
        data = batch_documents[row, col]
        row = row + count  # position within the batch -> position within the collection
        batch_ids = to_list(batch["id"])
        if id_dict:
          batch_ids = [id_dict[x] for x in batch_ids]
        count += len(batch_ids)
        doc_ids.extend(batch_ids)
        self.sparse_index.add_batch_document(row.cpu().numpy(), col.cpu().numpy(), data.cpu().numpy(),
                                             n_docs=len(batch_ids))
    if self.compute_stats:
      stats = {key: value / len(collection_loader) for key, value in stats.items()}
    if self.index_dir is not None:
      self.sparse_index.save()
      with open(os.path.join(self.index_dir, "doc_ids.pkl"), "wb") as handler:
        pickle.dump(doc_ids, handler)
      print("done iterating over the corpus...")
      print("index contains {} posting lists".format(len(self.sparse_index)))
      print("index contains {} documents".format(len(doc_ids)))
      if self.compute_stats:
        with open(os.path.join(self.index_dir, "index_stats.json"), "w") as handler:
          json.dump(stats, handler)
    else:
      # if no index_dir, we do not write the index to disk but return it
      for key in list(self.sparse_index.index_doc_id.keys()):
        # convert to numpy
        self.sparse_index.index_doc_id[key] = np.array(self.sparse_index.index_doc_id[key], dtype=np.int32)
        self.sparse_index.index_doc_value[key] = np.array(self.sparse_index.index_doc_value[key],
                                                          dtype=np.float32)
      out = {"index": self.sparse_index, "ids_mapping": doc_ids}
      if self.compute_stats:
        out["stats"] = stats
      return out


In [None]:
class CollectionDatasetPreLoad(Dataset):
  """
  dataset to iterate over a document/query collection, format per line: doc_id \t doc
  we preload everything in memory at init
  """

  def __init__(self, data_dir, id_style):
    """Read <data_dir>/raw.tsv; lines holding only a newline are skipped.

    id_style indicates how we access the doc/q: "row_id" (position among the kept lines) or
    "content_id" (the id found in the first column).
    """
    self.data_dir = data_dir
    assert id_style in ("row_id", "content_id"), "provide valid id_style"
    self.id_style = id_style
    self.data_dict = {}  # row index (or content id) -> text
    self.line_dict = {}  # row index -> content id (row_id style only)
    print("Preloading dataset")
    with open(os.path.join(self.data_dir, "raw.tsv")) as reader:
      for line in tqdm(reader):
        if len(line) > 1:
          id_, *data = line.split("\t")  # first column is id
          # extra tabs / line breaks inside the text are replaced by single spaces
          data = " ".join(" ".join(data).splitlines())
          if self.id_style == "row_id":
            # number the kept lines contiguously: using the raw line number would leave gaps
            # after skipped lines and make __getitem__ fail for indices below len(self)
            row = len(self.data_dict)
            self.data_dict[row] = data
            self.line_dict[row] = id_.strip()
          else:
            self.data_dict[id_] = data.strip()
    self.nb_ex = len(self.data_dict)

  def __len__(self):
    return self.nb_ex

  def __getitem__(self, idx):
    """Return the pair (id, text) for row idx ("row_id") or for content id str(idx) ("content_id")."""
    if self.id_style == "row_id":
      return self.line_dict[idx], self.data_dict[idx]
    else:
      return str(idx), self.data_dict[str(idx)]

In [21]:
class CollectionDataLoader(DataLoaderWrapper):
  """Loader for a document/query collection: tokenizes the texts and carries their integer ids.
  """

  def collate_fn(self, batch):
    """
    batch is a list of tuples, each tuple has 2 (text) items (id_, doc)
    returns the tokenizer outputs as tensors plus an "id" tensor of the ids converted to int
    """
    ids, docs = zip(*batch)
    encoded = self.tokenizer(list(docs),
                             add_special_tokens=True,
                             padding="longest",  # pad to max sequence length in batch
                             truncation="longest_first",  # truncates to self.max_length
                             max_length=self.max_length,
                             return_attention_mask=True)
    tensors = {name: torch.tensor(values) for name, values in encoded.items()}
    tensors["id"] = torch.tensor([int(raw_id) for raw_id in ids], dtype=torch.long)
    return tensors

In [27]:
def makedir(dir_):
    """Create directory dir_ together with missing parents; do nothing if it already exists.

    Uses exist_ok=True instead of a separate existence check, so concurrent callers creating the
    same directory cannot fail between the check and the creation.
    """
    os.makedirs(dir_, exist_ok=True)

In [28]:
def restore_model(model, state_dict):
    """Load state_dict into model non-strictly and report the keys that did not match.

    strict=False: parameters present on both sides are loaded, the rest is ignored; missing and
    unexpected keys are only printed as warnings.
    """
    load_result = model.load_state_dict(state_dict=state_dict, strict=False)
    missing_keys, unexpected_keys = load_result
    if missing_keys:
        print("~~ [WARNING] MISSING KEYS WHILE RESTORING THE MODEL ~~")
        print(missing_keys)
    if unexpected_keys:
        print("~~ [WARNING] UNEXPECTED KEYS WHILE RESTORING THE MODEL ~~")
        print(unexpected_keys)
    print("restoring model:", model.__class__.__name__)

In [29]:
class Evaluator:
  """Base class for model evaluation (inference).

  NOTE(review): this repeats an identical definition of the class from an earlier cell.
  """

  def __init__(self, model, config=None, restore=True):
    """
    model: model to evaluate; always left in eval mode
    config: mapping read for "pretrained_no_yamlconfig" and "checkpoint_dir" when restore is set
    restore: load the checkpoint <checkpoint_dir>/model/model.tar (unless "pretrained_no_yamlconfig"
             is set) and, when CUDA is available, move the model to GPU
    """
    self.model = model
    self.config = config
    use_cuda = torch.cuda.is_available()
    self.device = torch.device("cuda") if use_cuda else torch.device("cpu")
    if not restore:
      print("WARNING: init evaluator, NOT restoring the model, NOT placing on device")
    else:
      if "pretrained_no_yamlconfig" not in config or not config["pretrained_no_yamlconfig"]:
        model_file = os.path.join(config["checkpoint_dir"], "model/model.tar")
        if use_cuda:
          checkpoint = torch.load(model_file)
        else:
          checkpoint = torch.load(model_file, map_location=self.device)
        restore_model(model, checkpoint["model_state_dict"])
        if use_cuda:
          print("restore model on GPU at {}".format(model_file))
        else:
          print("restore model on CPU at {}".format(model_file))
      if use_cuda:
        self.model.eval()
        if torch.cuda.device_count() > 1:
          print(" --- use {} GPUs --- ".format(torch.cuda.device_count()))
          self.model = torch.nn.DataParallel(self.model)
        self.model.to(self.device)
    self.model.eval()  # => put in eval mode


In [30]:
import numba

class SparseRetrieval(Evaluator):
  """retrieval from SparseIndexing
  """

  @staticmethod
  def select_topk(filtered_indexes, scores, k):
    """Keep the k best candidates (in no particular order).

    scores are expected negated (lower is better), as returned by numba_score_float; the returned
    scores are flipped back to their original sign.
    """
    if len(filtered_indexes) > k:
      sorted_ = np.argpartition(scores, k)[:k]
      filtered_indexes, scores = filtered_indexes[sorted_], -scores[sorted_]
    else:
      scores = -scores
    return filtered_indexes, scores

  @staticmethod
  @numba.njit(nogil=True, parallel=True, cache=True)
  def numba_score_float(inverted_index_ids: numba.typed.Dict,
                        inverted_index_floats: numba.typed.Dict,
                        indexes_to_retrieve: np.ndarray,
                        query_values: np.ndarray,
                        threshold: float,
                        size_collection: int):
      """Score the collection against one query by accumulating over the query's posting lists.

      Returns the indexes of the documents whose score exceeds threshold and their NEGATED scores.
      """
      scores = np.zeros(size_collection, dtype=np.float32)  # initialize array with size = size of collection
      n = len(indexes_to_retrieve)
      for _idx in range(n):
          local_idx = indexes_to_retrieve[_idx]  # which posting list to search
          query_float = query_values[_idx]  # what is the value of the query for this posting list
          retrieved_indexes = inverted_index_ids[local_idx]  # get indexes from posting list
          retrieved_floats = inverted_index_floats[local_idx]  # get values from posting list
          for j in numba.prange(len(retrieved_indexes)):
              scores[retrieved_indexes[j]] += query_float * retrieved_floats[j]
      filtered_indexes = np.argwhere(scores > threshold)[:, 0]  # ideally we should have a threshold to filter
      # unused documents => this should be tuned, currently it is set to 0
      return filtered_indexes, -scores[filtered_indexes]

  def __init__(self, model, config, dim_voc, dataset_name=None, index_d=None, compute_stats=False, is_beir=False,
                **kwargs):
    """
    config: read for "out_dir" and, when the index lives on disk, "index_dir"
    dim_voc: number of posting lists (dimension of the representations)
    dataset_name: sub-directory of out_dir for the results (ignored when is_beir is set)
    index_d: in-memory index as returned by SparseIndexing.index; exactly one of index_d and
             config["index_dir"] must be provided
    compute_stats: also compute the average L0 of the queries
    kwargs: forwarded to Evaluator
    """
    super().__init__(model, config, **kwargs)
    assert ("index_dir" in config and index_d is None) or (
            "index_dir" not in config and index_d is not None)
    if "index_dir" in config:
      self.sparse_index = IndexDictOfArray(config["index_dir"], dim_voc=dim_voc)
      # NOTE: pickle must only be used on files produced by this pipeline (never on untrusted data)
      with open(os.path.join(config["index_dir"], "doc_ids.pkl"), "rb") as handler:
        self.doc_ids = pickle.load(handler)
    else:
      self.sparse_index = index_d["index"]
      self.doc_ids = index_d["ids_mapping"]
      for i in range(dim_voc):
        # missing keys (== posting lists), causing issues for retrieval => fill with empty
        if i not in self.sparse_index.index_doc_id:
          self.sparse_index.index_doc_id[i] = np.array([], dtype=np.int32)
          self.sparse_index.index_doc_value[i] = np.array([], dtype=np.float32)
    # convert to numba
    self.numba_index_doc_ids = numba.typed.Dict()
    self.numba_index_doc_values = numba.typed.Dict()
    for key, value in self.sparse_index.index_doc_id.items():
      self.numba_index_doc_ids[key] = value
    for key, value in self.sparse_index.index_doc_value.items():
      self.numba_index_doc_values[key] = value
    self.out_dir = os.path.join(config["out_dir"], dataset_name) if (dataset_name is not None and not is_beir) \
        else config["out_dir"]
    self.doc_stats = index_d["stats"] if (index_d is not None and compute_stats) else None
    self.compute_stats = compute_stats
    if self.compute_stats:
      self.l0 = L0()

  def retrieve(self, q_loader, top_k, name=None, return_d=False, id_dict=False, threshold=0):
    """Retrieve the top_k documents of every query and write the run to <out_dir>/run*.json.

    q_loader must yield ONE query per batch. name is used as a suffix of the output files, id_dict
    optionally maps the query ids found in the batches to other ids. With return_d, a dict with
    "retrieval" (and "stats" when compute_stats is set) is returned.
    """
    makedir(self.out_dir)
    if self.compute_stats:
      makedir(os.path.join(self.out_dir, "stats"))
    res = defaultdict(dict)
    if self.compute_stats:
      stats = defaultdict(float)
    with torch.no_grad():
      for batch in tqdm(q_loader):
        q_id = to_list(batch["id"])[0]
        if id_dict:
          q_id = id_dict[q_id]
        inputs = {k: v for k, v in batch.items() if k not in {"id"}}
        for k, v in inputs.items():
          inputs[k] = v.to(self.device)
        query = self.model(q_kwargs=inputs)["q_rep"]  # we assume ONE query per batch here
        if self.compute_stats:
          stats["L0_q"] += self.l0(query).item()
        # TODO: batched version for retrieval
        row, col = torch.nonzero(query, as_tuple=True)
        # index with the tensors directly: no need to go through Python lists
        values = query[row, col]
        filtered_indexes, scores = self.numba_score_float(self.numba_index_doc_ids,
                                                          self.numba_index_doc_values,
                                                          col.cpu().numpy(),
                                                          values.cpu().numpy().astype(np.float32),
                                                          threshold=threshold,
                                                          size_collection=self.sparse_index.nb_docs())
        # threshold set to 0 by default, could be better
        filtered_indexes, scores = self.select_topk(filtered_indexes, scores, k=top_k)
        for id_, sc in zip(filtered_indexes, scores):
          res[str(q_id)][str(self.doc_ids[id_])] = float(sc)
    if self.compute_stats:
      stats = {key: value / len(q_loader) for key, value in stats.items()}
    if self.compute_stats:
      with open(os.path.join(self.out_dir, "stats",
                              "q_stats{}.json".format("_iter_{}".format(name) if name is not None else "")),
                "w") as handler:
        json.dump(stats, handler)
      if self.doc_stats is not None:
        with open(os.path.join(self.out_dir, "stats",
                                "d_stats{}.json".format("_iter_{}".format(name) if name is not None else "")),
                  "w") as handler:
          json.dump(self.doc_stats, handler)
    with open(os.path.join(self.out_dir, "run{}.json".format("_iter_{}".format(name) if name is not None else "")),
              "w") as handler:
      json.dump(res, handler)
    if return_d:
      out = {"retrieval": res}
      if self.compute_stats:
        out["stats"] = stats if self.doc_stats is None else {**stats, **self.doc_stats}
      return out



In [31]:
class SparseApproxEvalWrapper(Evaluator):
  """
  Evaluation helper for use during training: builds a sparse index of a document
  collection with the current model and immediately retrieves over it.

  Args:
    model: forwarded to the Evaluator base class; later read back as self.model
      (presumably stored by Evaluator.__init__ -- verify against the base class).
    config: forwarded to the Evaluator base class; later read back as self.config,
      from which the "top_k" entry is taken at retrieval time.
    collection_loader: loader handed to SparseIndexing.index().
    q_loader: loader handed to SparseRetrieval.retrieve().
    **kwargs: passed through unchanged to Evaluator.__init__.
  """

  def __init__(self, model, config, collection_loader, q_loader, **kwargs):
    super().__init__(model, config, **kwargs)
    self.collection_loader = collection_loader
    self.q_loader = q_loader
    # When the model exposes a `.module` attribute (presumably a DataParallel-style
    # wrapper), output_dim is read from the wrapped module instead of the wrapper.
    if hasattr(self.model, "module"):
      self.model_output_dim = self.model.module.output_dim
    else:
      self.model_output_dim = self.model.output_dim

  def index_and_retrieve(self, i):
    """
    Index the document collection, then run retrieval for the query loader.

    Args:
      i: tag forwarded as `name` to SparseRetrieval.retrieve() (e.g. an iteration id).

    Returns:
      Whatever SparseRetrieval.retrieve() returns when called with return_d=True.
    """
    # Step 1: index the collection; nothing is restored and stats are computed.
    doc_indexer = SparseIndexing(
        self.model,
        config=None,
        restore=False,
        compute_stats=True,
    )
    doc_index = doc_indexer.index(self.collection_loader)

    # Step 2: retrieve against the index that was just built.
    query_retriever = SparseRetrieval(
        self.model,
        self.config,
        dim_voc=self.model_output_dim,
        index_d=doc_index,
        restore=False,
        compute_stats=True,
    )
    return query_retriever.retrieve(
        self.q_loader,
        top_k=self.config["top_k"],
        name=i,
        return_d=True,
    )

In [None]:
# Full-ranking validation is optional: val_evaluator stays None unless the experiment
# data section declares a "VALIDATION_FULL_RANKING" entry.
val_evaluator = None
if "VALIDATION_FULL_RANKING" in exp_dict["data"]:
    full_ranking_cfg = exp_dict["data"]["VALIDATION_FULL_RANKING"]

    # Record the qrel path on the config; open_dict is used around the assignment
    # because a new key is being added.
    with open_dict(config):
        config["val_full_rank_qrel_path"] = full_ranking_cfg["QREL_PATH"]

    # Document side: collection + loader.
    full_ranking_d_collection = CollectionDatasetPreLoad(
        data_dir=full_ranking_cfg["D_COLLECTION_PATH"],
        id_style="row_id",
    )
    full_ranking_d_loader = CollectionDataLoader(
        dataset=full_ranking_d_collection,
        tokenizer_type=config["tokenizer_type"],
        max_length=config["max_length"],
        batch_size=config["eval_batch_size"],
        shuffle=False,
        num_workers=4,
    )

    # Query side: collection + loader.
    full_ranking_q_collection = CollectionDatasetPreLoad(
        data_dir=full_ranking_cfg["Q_COLLECTION_PATH"],
        id_style="row_id",
    )
    full_ranking_q_loader = CollectionDataLoader(
        dataset=full_ranking_q_collection,
        tokenizer_type=config["tokenizer_type"],
        max_length=config["max_length"],
        batch_size=1,  # TODO fix: bs currently set to 1
        shuffle=False,
        num_workers=4,
    )

    # The evaluator gets its own small config: the retrieval depth and an output
    # directory located under the checkpoint directory.
    val_evaluator = SparseApproxEvalWrapper(
        model,
        config={
            "top_k": full_ranking_cfg["TOP_K"],
            "out_dir": os.path.join(config["checkpoint_dir"], "val_full_ranking"),
        },
        collection_loader=full_ranking_d_loader,
        q_loader=full_ranking_q_loader,
        restore=False,
    )

In [None]:
# -----------------------------------------------------------------
# TRAIN
# -----------------------------------------------------------------
# Build the trainer from the objects prepared in the earlier cells, then run it.
# validation_evaluator may be None when no full-ranking validation was configured.
print("+++++ BEGIN TRAINING +++++")
trainer = SiameseTransformerTrainer(
    model=model,
    iterations=iterations,
    loss=loss,
    optimizer=optimizer,
    config=config,
    scheduler=scheduler,
    train_loader=train_loader,
    validation_loss_loader=val_loss_loader,
    validation_evaluator=val_evaluator,
    regularizer=regularizer,
)
trainer.train()