# SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking

This notebook gives a minimal example of how to use our SPLADE model. In a nutshell, SPLADE learns **sparse**, **expansion-based** query/document representations for efficient first-stage retrieval.
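Sparsity is what makes first-stage retrieval efficient: most vocabulary entries are zero, so the representations can be stored in a classic inverted index, and a query only visits the postings of its own non-zero terms. The sketch below is a generic toy inverted index in pure Python, not the indexing code shipped with SPLADE; it assumes relevance is the dot product between query and document representations, and the documents and weights are made-up values.

```python
from collections import defaultdict

# Toy sparse representations: term -> weight (hypothetical values, for illustration only)
docs = {
    "d1": {"glass": 2.2, "stress": 2.1, "crack": 1.1},
    "d2": {"steel": 1.8, "stress": 1.5},
    "d3": {"window": 1.9, "glass": 1.0},
}

def build_inverted_index(doc_reps):
    """Map each term to a postings list of (doc_id, weight) pairs."""
    index = defaultdict(list)
    for doc_id, rep in doc_reps.items():
        for term, weight in rep.items():
            index[term].append((doc_id, weight))
    return index

def search(index, query_rep, k=10):
    """Score documents by dot product, touching only the postings of the query's terms."""
    scores = defaultdict(float)
    for term, q_weight in query_rep.items():
        # .get() avoids creating empty postings for terms absent from the index
        for doc_id, d_weight in index.get(term, []):
            scores[doc_id] += q_weight * d_weight
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)[:k]

index = build_inverted_index(docs)
print(search(index, {"glass": 1.0, "crack": 0.5}))
```

Here `d2` is never scored because it shares no term with the query; with shorter representations, fewer postings are visited per query.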

Sparsity is induced by a regularization applied to the representations, whose strength can be adjusted; this makes it possible to control the trade-off between effectiveness and efficiency. For more details, check our paper, and don't hesitate to reach out!
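As an illustration of such a regularizer, here is a minimal sketch of a FLOPS-style penalty, assuming the form sum over vocabulary terms of the squared mean activation across a batch (the "expected FLOPS" column in the table below points to this family of regularizers; see the paper for the exact training objective). `flops_regularizer` and the weight `lambda_d` are hypothetical names, and the batch is made-up data.

```python
import torch

def flops_regularizer(reps: torch.Tensor) -> torch.Tensor:
    """FLOPS-style penalty for a batch of non-negative sparse representations.

    reps: tensor of shape (batch_size, vocab_size).
    Returns sum_j (mean_i reps[i, j]) ** 2.
    """
    return torch.sum(torch.mean(reps, dim=0) ** 2)

# Hypothetical training step: the penalty is added to the ranking loss with a weight;
# a larger weight pushes toward sparser representations (more efficient, often less effective).
lambda_d = 1e-3  # illustrative value, not a tuned setting
doc_reps = torch.tensor([[0.0, 2.0, 0.0, 1.0],
                         [0.0, 0.0, 0.0, 3.0]])
reg = lambda_d * flops_regularizer(doc_reps)
```

Squaring the per-term mean penalizes terms that are active in many documents more heavily than an L1 penalty would, which favors spreading the non-zero weights over rarer terms.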

We provide weights for two models (in the `weights` folder):

| model | MRR@10 (MS MARCO dev) | recall@1000 (MS MARCO dev) | expected FLOPS | ~avg. query length | ~avg. doc length |
| --- | --- | --- | --- | --- | --- |
| `flops_best` | 32.2 | 95.5 | 0.73 | 15 | 58 |
| `flops_efficient` | 29.6 | 93.3 | 0.05 | 6 | 18 |
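The "expected FLOPS" column measures how many floating-point operations it takes, on average, to score a query against a document through an inverted index. A minimal sketch of one way to estimate it, assuming the estimator sum over terms j of p_q[j] * p_d[j], where p_q[j] (resp. p_d[j]) is the fraction of queries (resp. documents) in which term j is non-zero. The helper names and the toy representations are hypothetical.

```python
from collections import Counter

def expected_flops(query_reps, doc_reps):
    """Estimate the expected number of multiplications for a random (query, doc) pair.

    query_reps / doc_reps: lists of sparse reps, each a dict term -> weight.
    """
    q_counts = Counter(t for rep in query_reps for t, w in rep.items() if w != 0)
    d_counts = Counter(t for rep in doc_reps for t, w in rep.items() if w != 0)
    n_q, n_d = len(query_reps), len(doc_reps)
    # Counter returns 0 for terms that never occur in documents
    return sum((q_counts[t] / n_q) * (d_counts[t] / n_d) for t in q_counts)

def avg_length(reps):
    """Average number of non-zero terms per representation."""
    return sum(sum(1 for w in rep.values() if w != 0) for rep in reps) / len(reps)

queries = [{"a": 1.0}, {"a": 0.5, "b": 1.0}]
docs = [{"a": 2.0}, {"b": 1.0}, {"c": 1.0}, {"a": 1.0, "b": 0.3}]
print(expected_flops(queries, docs), avg_length(queries), avg_length(docs))
```

Shorter query and document representations, as in `flops_efficient`, lower both the per-term probabilities and the resulting FLOPS estimate.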

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
```

```python
class Splade(torch.nn.Module):

    def __init__(self, model_type_or_dir):
        super().__init__()
        self.transformer = AutoModelForMaskedLM.from_pretrained(model_type_or_dir)
    
    def forward(self, **kwargs):
        out = self.transformer(**kwargs)["logits"] # output (logits) of MLM head, shape (bs, pad_len, voc_size)
        # log-saturated ReLU of the logits, padding masked out, summed over tokens -> shape (bs, voc_size)
        return torch.sum(torch.log(1 + torch.relu(out)) * kwargs["attention_mask"].unsqueeze(-1), dim=1)
```
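For each vocabulary term j, the forward pass computes w_j = sum over non-padding input tokens i of log(1 + ReLU(w_ij)), where w_ij is the MLM logit of term j at position i. The toy example below runs that same pooling on a hand-made logits tensor (no pretrained weights needed) to show the effect of the ReLU, the log saturation, and the mask; `splade_pool` is a hypothetical helper that mirrors `Splade.forward`.

```python
import torch

def splade_pool(logits: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Same pooling as Splade.forward: log(1 + ReLU(logits)),
    masked so padding tokens contribute nothing, then summed over tokens."""
    return torch.sum(torch.log(1 + torch.relu(logits)) * attention_mask.unsqueeze(-1), dim=1)

# Toy batch: 1 sequence, 3 positions (the last one is padding), vocabulary of 4 terms.
logits = torch.tensor([[[1.0, -2.0, 0.0, 3.0],
                        [0.5, 1.0, -1.0, 0.0],
                        [9.0, 9.0, 9.0, 9.0]]])
attention_mask = torch.tensor([[1, 1, 0]])
rep = splade_pool(logits, attention_mask)
print(rep)  # equals log([[3., 2., 1., 4.]])
```

The third term gets weight 0 because none of its logits at non-padding positions is positive, and the padded position is ignored despite its large logits: this is where the zeros of the sparse representation come from.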

```python
# set the directory of the trained weights

model_type_or_dir = "weights/flops_efficient"
# model_type_or_dir = "weights/flops_best"
```

```python
# load the model and tokenizer

model = Splade(model_type_or_dir)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_type_or_dir)
reverse_voc = {v: k for k, v in tokenizer.vocab.items()}  # maps token ids back to tokens
```

```python
# example document from MS MARCO passage collection (doc_id = 8003157)

doc = "Glass and Thermal Stress. Thermal Stress is created when one area of a glass pane gets hotter than an adjacent area. If the stress is too great then the glass will crack. The stress level at which the glass will break is governed by several factors."
```

```python
# now compute the document representation
with torch.no_grad():
    doc_rep = model(**tokenizer(doc, return_tensors="pt")).squeeze()  # (sparse) doc rep in voc space, shape (30522,)

# get the number of non-zero dimensions in the rep
# (squeeze(-1) keeps a list even when there is a single non-zero entry):
col = torch.nonzero(doc_rep).squeeze(-1).cpu().tolist()
print("number of actual dimensions: ", len(col))

# now let's inspect the bag-of-words (BoW) representation, sorted by decreasing weight:
weights = doc_rep[col].cpu().tolist()
d = {k: v for k, v in zip(col, weights)}
sorted_d = {k: v for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)}
bow_rep = []
for k, v in sorted_d.items():
    bow_rep.append((reverse_voc[k], round(v, 2)))
print("SPLADE BOW rep:\n", bow_rep)
```

```text
number of actual dimensions:  15
SPLADE BOW rep:
 [('glass', 2.22), ('stress', 2.19), ('thermal', 1.91), ('pan', 1.6), ('crack', 1.1), ('created', 0.98), ('caused', 0.77), ('hotter', 0.58), ('adjacent', 0.36), ('create', 0.35), ('area', 0.34), ('break', 0.3), ('plastic', 0.29), ('shatter', 0.11), ('hot', 0.09)]
```
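Once queries are encoded the same way, ranking reduces to a dot product between sparse vectors, and only the terms shared by the query and the document contribute. A minimal sketch, assuming dot-product scoring: the document weights are a subset of the BoW output above, while the query weights are made up for illustration (they were not produced by the model), and `sparse_dot` is a hypothetical helper.

```python
# Document weights copied from a subset of the BoW output above.
doc_bow = {"glass": 2.22, "stress": 2.19, "thermal": 1.91, "crack": 1.1}
# Hypothetical query representation (made-up weights, not produced by the model).
query_bow = {"glass": 1.5, "crack": 1.2, "why": 0.4}

def sparse_dot(q, d):
    """Relevance score: dot product over the terms shared by query and document."""
    # Iterate over the smaller dict so the cost scales with the shorter representation.
    if len(q) > len(d):
        q, d = d, q
    return sum(w * d[t] for t, w in q.items() if t in d)

score = sparse_dot(query_bow, doc_bow)
print(round(score, 2))
```

With the dense tensors returned by the model, the equivalent score for a query encoded like `doc_rep` would be `torch.dot(q_rep, doc_rep)`; the dictionary form simply skips the zero entries.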
