# SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking

This notebook gives a minimal example of how to use SPLADE.

* We provide models via Hugging Face (https://huggingface.co/naver)
* See [Naver Labs Europe website](https://europe.naverlabs.com/research/machine-learning-and-optimization/splade-models/) for other intermediate models.

| model | MRR@10 (MS MARCO dev) | recall@1000 (MS MARCO dev) | expected FLOPS | ~ avg query length | ~ avg doc length |
| --- | --- | --- | --- | --- | --- |
| `naver/splade_v2_max` (**v2** [HF](https://huggingface.co/naver/splade_v2_max)) | 34.0 | 96.5 | 1.32 | 18 | 92 |
| `naver/splade_v2_distil` (**v2** [HF](https://huggingface.co/naver/splade_v2_distil)) | 36.8 | 97.9 | 3.82 | 25 | 232 |
| `naver/splade-cocondenser-selfdistil` (**v2bis**, [HF](https://huggingface.co/naver/splade-cocondenser-selfdistil))| 37.6 | 98.4 | 2.32 | 56 | 134 |
| `naver/splade-cocondenser-ensembledistil` (**v2bis**, [HF](https://huggingface.co/naver/splade-cocondenser-ensembledistil)) | 38.3 | 98.3  | 1.85 | 44 | 120 |
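
For readers unfamiliar with the two effectiveness columns, here is a minimal sketch of how MRR@10 and recall@1000 are commonly computed. The function names, the `run`/`qrels` dictionaries, and the toy ids are assumptions made for illustration; this is not the evaluation script behind the table. The table appears to report both metrics as percentages, whereas the functions return fractions.

```python
def mrr_at_k(run, qrels, k=10):
    """Mean reciprocal rank of the first relevant document within the top k."""
    total = 0.0
    for qid, ranked in run.items():
        relevant = qrels.get(qid, set())
        for rank, doc_id in enumerate(ranked[:k], start=1):
            if doc_id in relevant:
                total += 1.0 / rank
                break  # only the first relevant hit counts
    return total / len(run)


def recall_at_k(run, qrels, k=1000):
    """Average fraction of relevant documents found in the top k."""
    total, counted = 0.0, 0
    for qid, ranked in run.items():
        relevant = qrels.get(qid, set())
        if not relevant:
            continue  # skip queries without relevance judgments
        total += len(relevant & set(ranked[:k])) / len(relevant)
        counted += 1
    return total / counted if counted else 0.0


# hypothetical toy data: ranked doc ids per query, and relevant ids per query
run = {"q1": ["d3", "d1", "d2"], "q2": ["d5", "d4"]}
qrels = {"q1": {"d1"}, "q2": {"d9"}}

print(mrr_at_k(run, qrels))     # q1 hits at rank 2, q2 never hits
print(recall_at_k(run, qrels))
```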

In [1]:
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
from splade.models.transformer_rep import Splade



In [2]:
# set the dir for trained weights

##### v2
# model_type_or_dir = "naver/splade_v2_max"
# model_type_or_dir = "naver/splade_v2_distil"

### v2bis, directly download from Hugging Face
model_type_or_dir = "naver/splade-cocondenser-selfdistil"
# model_type_or_dir = "/scratch/lamdo/splade_checkpoints/experiments_combined_references_v6/debug/checkpoint/model"
# model_type_or_dir = "/home/lamdo/keyphrase_informativeness_test/splade/experiments_unarxive_intro_relatedwork_1citationpersentence+scirepeval_search_v2/debug/checkpoint/model"

In [None]:
# loading model and tokenizer

model = Splade(model_type_or_dir, agg="max")
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_type_or_dir)
reverse_voc = {v: k for k, v in tokenizer.vocab.items()}
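
To make the `agg="max"` argument less opaque, the sketch below shows the pooling step as described for SPLADE with max aggregation: each vocabulary weight is the maximum over input positions of `log(1 + ReLU(logit))`. The function name `splade_max_pool` and the random toy tensors are assumptions for illustration; the real `Splade` class obtains the logits from a masked-language-model head and may differ in detail.

```python
import torch


def splade_max_pool(logits, attention_mask):
    """Pool MLM logits (batch, seq_len, vocab) into sparse reps (batch, vocab)."""
    # log-saturated, non-negative term weights per input position
    weights = torch.log1p(torch.relu(logits))
    # zero out padding positions so they cannot win the max
    weights = weights * attention_mask.unsqueeze(-1)
    # max over the sequence dimension: one weight per vocabulary entry
    return weights.max(dim=1).values


# toy input: 2 sequences, 5 positions, vocabulary of 11 entries (hypothetical sizes)
torch.manual_seed(0)
logits = torch.randn(2, 5, 11)
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])

rep = splade_max_pool(logits, attention_mask)
print(rep.shape)
```

Because of the ReLU, every weight is non-negative, and a vocabulary entry stays at exactly zero whenever its logit is non-positive at every unmasked position, which is where the sparsity comes from.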

In [None]:
# example input text (a short question, encoded below as a document)

doc = """How much will it cost to go to college to become a detective"""

In [None]:
# now compute the document representation
doc_tokens = tokenizer(doc, return_tensors="pt")
with torch.no_grad():
    doc_rep = model(d_kwargs=doc_tokens)["d_rep"].squeeze()  # (sparse) doc rep in voc space, shape (30522,)

# get the number of non-zero dimensions in the rep:
col = torch.nonzero(doc_rep).squeeze(-1).cpu().tolist()  # squeeze(-1) keeps a list even if only one dimension is non-zero
print("number of actual dimensions: ", len(col))

# now let's inspect the bag-of-words (BOW) representation, sorted by decreasing weight:
weights = doc_rep[col].cpu().tolist()
d = dict(zip(col, weights))
sorted_d = dict(sorted(d.items(), key=lambda item: item[1], reverse=True))
bow_rep = [(reverse_voc[k], round(v, 2)) for k, v in sorted_d.items()]
print("SPLADE BOW rep:\n", bow_rep)

number of actual dimensions:  57
SPLADE BOW rep:
 [('detective', 2.27), ('college', 1.66), ('police', 1.39), ('cost', 1.37), ('detectives', 1.28), ('price', 1.17), ('degree', 1.13), ('crime', 1.08), ('tuition', 0.93), ('fee', 0.92), ('become', 0.88), ('salary', 0.85), ('money', 0.8), ('interview', 0.79), ('training', 0.68), ('hunter', 0.68), ('lawyer', 0.63), ('colleges', 0.58), ('education', 0.56), ('go', 0.53), ('pay', 0.53), ('university', 0.51), ('hire', 0.51), ('student', 0.44), ('qualification', 0.43), ('late', 0.43), ('investigation', 0.41), ('research', 0.4), ('career', 0.39), ('officer', 0.32), ('investigator', 0.29), ('attend', 0.23), ('murder', 0.2), ('ask', 0.2), ('clerk', 0.19), ('investment', 0.19), ('young', 0.19), ('will', 0.16), ('psychiatrist', 0.14), ('graduate', 0.14), ('graduation', 0.13), ('recruit', 0.13), ('doctor', 0.11), ('required', 0.11), ('audition', 0.1), ('inspector', 0.09), ('priest', 0.09), ('funding', 0.08), ('much', 0.08), ('spending', 0.07), ('certif
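
SPLADE ranks documents by the dot product between the sparse query and document representations, so only terms that are non-zero on both sides contribute. A minimal sketch, assuming representations stored as token-to-weight dictionaries: `sparse_dot` is a hypothetical helper, the document weights are a few entries copied from the output above, and the query weights are made up for illustration.

```python
def sparse_dot(query_rep, doc_rep):
    """Dot product of two sparse vectors stored as {term: weight} dicts."""
    # iterate over the smaller dict and look terms up in the larger one
    if len(query_rep) > len(doc_rep):
        query_rep, doc_rep = doc_rep, query_rep
    return sum(w * doc_rep.get(term, 0.0) for term, w in query_rep.items())


# a few (term, weight) pairs taken from the printed BOW representation
doc_rep_example = {"detective": 2.27, "college": 1.66, "police": 1.39, "cost": 1.37}
# hypothetical query weights, not produced by the model
query_rep_example = {"detective": 1.0, "tuition": 0.5, "cost": 2.0}

score = sparse_dot(query_rep_example, doc_rep_example)
print(round(score, 2))  # only "detective" and "cost" overlap
```

In a real first-stage retriever the same computation is usually carried out through an inverted index rather than by pairwise dictionary lookups.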

In [None]:
# fraction of unique input token ids (including [CLS] and [SEP]) that keep a non-zero weight
input_ids = {int(x) for x in doc_tokens.input_ids[0]}
len(input_ids & set(sorted_d.keys())) / len(input_ids)

0.5384615384615384

In [None]:
# input tokens that received zero weight in the representation
[reverse_voc[k] for k in set([int(x) for x in doc_tokens.input_ids[0]]) - set(sorted_d.keys())]

['[CLS]', '[SEP]', 'a', 'to', 'how', 'it']
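
The two cells above can be folded into one helper that separates input terms that were kept, terms added by expansion, and input terms that were dropped. This is a sketch on plain Python data: `split_terms` is a hypothetical name, the weights are a subset copied from the printed output, and the token list assumes the tokenizer maps each word of the example to a single token.

```python
def split_terms(input_tokens, bow_weights, special_tokens=("[CLS]", "[SEP]")):
    """Split a BOW rep into kept input terms, expansion terms, and dropped input terms."""
    input_set = set(input_tokens) - set(special_tokens)
    kept = {t: w for t, w in bow_weights.items() if t in input_set}
    expanded = {t: w for t, w in bow_weights.items() if t not in input_set}
    dropped = sorted(input_set - set(bow_weights))
    return kept, expanded, dropped


# assumed tokenization of the example text (one token per word)
tokens = ["[CLS]", "how", "much", "will", "it", "cost", "to", "go", "to",
          "college", "to", "become", "a", "detective", "[SEP]"]
# subset of the weights shown in the output above
bow = {"detective": 2.27, "college": 1.66, "police": 1.39, "cost": 1.37,
       "tuition": 0.93, "become": 0.88, "go": 0.53, "will": 0.16, "much": 0.08}

kept, expanded, dropped = split_terms(tokens, bow)
print("kept:", sorted(kept))
print("expanded:", sorted(expanded))
print("dropped:", dropped)
```

Unlike the ratio computed earlier, this helper excludes the special tokens, so here 7 of the 11 distinct input words are kept rather than 7 of 13 ids.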