# SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking

This notebook gives a minimal usage example of SPLADE.

* We provide models via Hugging Face (https://huggingface.co/naver)
* See [Naver Labs Europe website](https://europe.naverlabs.com/research/machine-learning-and-optimization/splade-models/) for other intermediate models.

| model | MRR@10 (MS MARCO dev) | recall@1000 (MS MARCO dev) | expected FLOPS | ~ avg q length | ~ avg d length |
| --- | --- | --- | --- | --- | --- |
| `naver/splade_v2_max` (**v2** [HF](https://huggingface.co/naver/splade_v2_max)) | 34.0 | 96.5 | 1.32 | 18 | 92 |
| `naver/splade_v2_distil` (**v2** [HF](https://huggingface.co/naver/splade_v2_distil)) | 36.8 | 97.9 | 3.82 | 25 | 232 |
| `naver/splade-cocondenser-selfdistil` (**v2bis**, [HF](https://huggingface.co/naver/splade-cocondenser-selfdistil))| 37.6 | 98.4 | 2.32 | 56 | 134 |
| `naver/splade-cocondenser-ensembledistil` (**v2bis**, [HF](https://huggingface.co/naver/splade-cocondenser-ensembledistil)) | 38.3 | 98.3  | 1.85 | 44 | 120 |

In [None]:
# Mount Google Drive so the experiment archive stored under MyDrive can be copied locally.
# force_remount=True re-mounts even if a previous session left the drive mounted.
from google.colab import drive

drive.mount(mountpoint='/content/gdrive', force_remount=True)

Mounted at /content/gdrive


In [None]:
# Show which GPU (if any) this Colab runtime was assigned, with its memory usage.
!nvidia-smi

Fri Jun 16 05:26:31 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0    46W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
%%shell
# Install every dependency in ONE pip call so the resolver chooses mutually compatible
# versions. Separate sequential installs resolve each package in isolation and can silently
# upgrade/downgrade something an earlier line just installed (e.g. torch via
# transformers / sentence-transformers).
# TODO: pin versions (pkg==x.y.z) for reproducibility once a working set is known.
pip install -q \
    pytrec_eval \
    transformers \
    torch \
    datasets \
    evaluate \
    trectools \
    faiss-cpu \
    sentence-transformers \
    git+https://github.com/naver/splade.git
# pyserini is optional and currently disabled:
# pip install pyserini -q

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for pytrec_eval (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m95.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m236.8/236.8 kB[0m [31m31.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m119.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m88.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.6/485.6 kB[0m [31m17.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m16.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.5/212.5 kB[0m [31m25.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━



In [None]:
# Root folder of the project on the mounted Google Drive, chosen per user.
# The trailing slash matters: paths are built below as f"{main_dir}...".
user = "leonardo"
_user_drive_roots = {"monique": '/content/gdrive/MyDrive/Unicamp-projeto-final/'}
main_dir = _user_drive_roots.get(user, '/content/gdrive/MyDrive/Unicamp/IA368-DD/')

In [None]:
import os
import json
import numpy as np
import pandas as pd
import random
import torch
import collections
import evaluate
import shutil
import pickle
import numba

from collections import defaultdict, Counter
from datasets import load_dataset
from tqdm import tqdm
from operator import itemgetter
from time import time
from torch import nn, optim
from transformers import BatchEncoding, get_linear_schedule_with_warmup, AutoModelForMaskedLM, AutoTokenizer
from splade.models.transformer_rep import Splade

In [None]:
# Copy the zipped training run from Drive to the local Colab disk; the cell displays
# the destination path returned by copyfile.
shutil.copyfile(
    src=f"{main_dir}Projeto Final/experiments_10m.zip",
    dst="/content/experiments.zip",
)

'/content/experiments.zip'

In [None]:
!unzip /content/experiments.zip
!mv /content/content/splade/experiments /content/experiments

Archive:  /content/experiments.zip
   creating: content/splade/experiments/
   creating: content/splade/experiments/pt/
   creating: content/splade/experiments/pt/checkpoint/
  inflating: content/splade/experiments/pt/checkpoint/training_perf.txt  
   creating: content/splade/experiments/pt/checkpoint/model/
  inflating: content/splade/experiments/pt/checkpoint/model/special_tokens_map.json  
  inflating: content/splade/experiments/pt/checkpoint/model/vocab.txt  
  inflating: content/splade/experiments/pt/checkpoint/model/model.tar  
  inflating: content/splade/experiments/pt/checkpoint/model/pytorch_model.bin  
  inflating: content/splade/experiments/pt/checkpoint/model/tokenizer.json  
  inflating: content/splade/experiments/pt/checkpoint/model/config.json  
  inflating: content/splade/experiments/pt/checkpoint/model/tokenizer_config.json  
   creating: content/splade/experiments/pt/checkpoint/val_full_ranking/
  inflating: content/splade/experiments/pt/checkpoint/val_full_ranking/ru

In [None]:
def restore_model(model, state_dict):
    """Load ``state_dict`` into ``model`` non-strictly and report any key mismatch.

    With ``strict=False``, ``load_state_dict`` copies only the entries whose keys exist in
    both the model and ``state_dict`` and skips the rest instead of raising. The skipped
    keys are printed so that a partial restore is visible in the notebook output.

    Args:
        model: module whose parameters/buffers are overwritten in place.
        state_dict: mapping of parameter/buffer names to tensors (e.g. a checkpoint's
            ``"model_state_dict"`` entry).

    Returns:
        None.
    """
    missing_keys, unexpected_keys = model.load_state_dict(state_dict=state_dict, strict=False)

    # (warning banner, offending keys), reported in this order.
    mismatch_report = (
        ("~~ [WARNING] MISSING KEYS WHILE RESTORING THE MODEL ~~", missing_keys),
        ("~~ [WARNING] UNEXPECTED KEYS WHILE RESTORING THE MODEL ~~", unexpected_keys),
    )
    for banner, keys in mismatch_report:
        if keys:
            print(banner)
            print(keys)

    print("restoring model:", type(model).__name__)

In [None]:
# Load the trained SPLADE checkpoint and the tokenizer used to encode its inputs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# `device` only controls where torch.load materialises the checkpoint tensors
# (map_location); `model` itself is never moved with `.to(device)` in this cell.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
ckpt = torch.load(
    "/content/experiments/pt/checkpoint/model_ckpt/model_final_checkpoint.tar",
    map_location=device,
)
model = Splade("/content/experiments/pt/checkpoint/model", agg="max")
restore_model(model, ckpt["model_state_dict"])
model.eval()  # switch to inference mode (e.g. disables dropout)

tokenizer = AutoTokenizer.from_pretrained("neuralmind/bert-base-portuguese-cased")
# token id -> token string: maps sparse vocabulary dimensions back to readable terms
reverse_voc = {token_id: token for token, token_id in tokenizer.vocab.items()}

restoring model: Splade


Downloading:   0%|          | 0.00/43.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/647 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/205k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

### Portuguese (PT) model: encoding an example document

In [None]:
# Example document: a Portuguese version of an MS MARCO passage (doc_id = 8003157).
# NOTE(review): presumably taken from a Portuguese translation of MS MARCO — confirm the source.
doc = "Vidro e Estresse Térmico. O estresse térmico é criado quando uma área de um painel de vidro fica mais quente do que uma área adjacente. Se a tensão for muito grande, o vidro rachará. O nível de tensão no qual o vidro quebrará é governado por vários fatores."

In [None]:
# Encode `doc` with SPLADE and inspect the resulting sparse bag-of-words representation.
# Send the tokenized inputs to wherever the model's weights live, so this cell keeps
# working if the model is moved to a GPU.
model_device = next(model.parameters()).device
with torch.no_grad():
    doc_inputs = tokenizer(doc, return_tensors="pt").to(model_device)
    # (sparse) doc rep in vocabulary space, shape (vocab_size,) after squeezing the batch dim.
    # NOTE: vocab_size is presumably len(tokenizer.vocab) of this Portuguese BERT,
    # not BERT-base-uncased's 30522 — confirm against the model config.
    doc_rep = model(d_kwargs=doc_inputs)["d_rep"].squeeze()

# Indices of the non-zero (active) vocabulary dimensions. as_tuple=True always yields a
# 1-D index tensor; `nonzero().squeeze()` would collapse to a 0-d scalar when exactly one
# dimension is active, making `.tolist()` return an int and `len(col)` raise.
col = torch.nonzero(doc_rep, as_tuple=True)[0].cpu().tolist()
print("number of actual dimensions: ", len(col))

# now let's inspect the bow representation, heaviest terms first:
weights = doc_rep[col].cpu().tolist()
d = dict(zip(col, weights))
sorted_d = dict(sorted(d.items(), key=lambda item: item[1], reverse=True))
bow_rep = [(reverse_voc[k], round(v, 2)) for k, v in sorted_d.items()]
print("SPLADE BOW rep:\n", bow_rep)

number of actual dimensions:  192
SPLADE BOW rep:
 [('vidro', 1.65), ('estre', 1.39), ('porque', 1.37), ('quebra', 1.33), ('ra', 1.29), ('##ér', 1.14), ('painel', 1.13), ('clima', 1.11), ('agressivo', 1.09), ('causas', 1.09), ('quebrar', 1.06), ('temperatura', 1.03), ('tér', 1.02), ('tensão', 0.99), ('quente', 0.91), ('##mico', 0.9), ('##se', 0.87), ('pressão', 0.85), ('gerado', 0.8), (';', 0.8), ('T', 0.8), ('térmica', 0.79), ('causado', 0.78), ('calor', 0.78), ('queda', 0.76), ('energia', 0.76), ('edifício', 0.75), ('cria', 0.72), ('desgas', 0.7), ('significa', 0.66), ('t', 0.65), ('frio', 0.65), ('##ro', 0.64), ('devido', 0.62), ('vi', 0.62), ('painéis', 0.61), ('e', 0.58), ('quebrado', 0.58), ('##s', 0.57), ('emocional', 0.57), ('impacto', 0.55), ('criado', 0.55), ('derrubar', 0.55), ('plástico', 0.51), ('gerar', 0.5), ('nível', 0.5), ('ruptura', 0.49), ('recuperação', 0.47), ('aler', 0.47), ('perder', 0.45), ('resulta', 0.45), ('influenciado', 0.44), ('ambiental', 0.43), ('ocorre'