# SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking

This notebook gives a minimal usage example for our SPLADE model. In a nutshell, SPLADE learns **sparse**, **expansion-based** query/doc representations for efficient first-stage retrieval.

Sparsity is induced by a regularization term applied to the representations. Its strength can be adjusted, which makes it possible to control the trade-off between effectiveness and efficiency. For more details, see our papers, and don't hesitate to reach out!
* v1 (SIGIR21 short paper): **SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking**, https://arxiv.org/abs/2107.05720
* v2 (arXiv), new pooling + distillation: **SPLADE v2: Sparse Lexical and Expansion Model for Information Retrieval**, https://arxiv.org/abs/2109.10086
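
To make the sparsity regularization more concrete, here is a minimal sketch of a FLOPS-style penalty, the regularizer that the `flops_*` model names refer to. It assumes a batch of non-negative representations of shape `(batch_size, vocab_size)`. The function name `flops_regularizer` and the toy values are illustrative only and are not part of this notebook or of the released training code.

```python
import torch


def flops_regularizer(batch_reps: torch.Tensor) -> torch.Tensor:
    """FLOPS-style sparsity penalty (illustrative sketch, hypothetical helper).

    batch_reps: tensor of shape (batch_size, vocab_size) holding term weights.
    Returns the sum over vocabulary terms of the squared mean activation.
    """
    # average activation of each vocabulary term over the batch
    mean_activation = batch_reps.abs().mean(dim=0)
    # squaring penalizes terms that are active in many representations
    return (mean_activation ** 2).sum()


# toy batch: 2 representations over a 3-term vocabulary (made-up values)
reps = torch.tensor([[1.0, 0.0, 2.0],
                     [3.0, 0.0, 0.0]])
penalty = flops_regularizer(reps)
print(penalty.item())
```

During training, a penalty of this kind would be multiplied by a strength coefficient and added to the ranking loss; a larger coefficient pushes the model towards sparser, cheaper representations at some cost in effectiveness.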

We provide weights for 4 models (in the `weights` folder):

| model | MRR@10 (MS MARCO dev) | recall@1000 (MS MARCO dev) | expected FLOPS | ~ avg query length | ~ avg doc length |
| --- | --- | --- | --- | --- | --- |
| `flops_best` (**v1**) | 32.2 | 95.5 | 0.73 | 15 | 58 |
| `flops_efficient` (**v1**) | 29.6 | 93.3 | 0.05 | 6 | 18 |
| `splade_max` (**v2**) | 34.0 | 96.5 | 1.32 | 18 | 92 |
| `distilsplade_max` (**v2**) | 36.8 | 97.9 | 3.82 | 25 | 232 |

In [1]:
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

In [2]:
class Splade(torch.nn.Module):

    def __init__(self, model_type_or_dir, agg="max"):
        super().__init__()
        self.transformer = AutoModelForMaskedLM.from_pretrained(model_type_or_dir)
        assert agg in ("sum", "max")
        self.agg = agg
    
    def forward(self, **kwargs):
        # output (logits) of the MLM head, shape (bs, pad_len, voc_size)
        out = self.transformer(**kwargs)["logits"]
        # log-saturated, non-negative term activations; padded positions are zeroed out
        activations = torch.log(1 + torch.relu(out)) * kwargs["attention_mask"].unsqueeze(-1)
        if self.agg == "max":
            # zero-masking also works with max because all activations are non-negative
            values, _ = torch.max(activations, dim=1)
            return values
        return torch.sum(activations, dim=1)
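
For reference, the pooling computed in `forward` can be written as a formula. Here $w_{ij}$ denotes the MLM logit for vocabulary term $j$ at input position $i$, and $t$ is the set of non-padded positions of the input (this notation is an assumption borrowed from the papers, not something defined in this notebook):

$$
\begin{aligned}
w_j &= \sum_{i \in t} \log\bigl(1 + \mathrm{ReLU}(w_{ij})\bigr) && \text{(v1, sum pooling)} \\
w_j &= \max_{i \in t} \log\bigl(1 + \mathrm{ReLU}(w_{ij})\bigr) && \text{(v2, max pooling)}
\end{aligned}
$$

Because every term inside the sum or max is non-negative, multiplying padded positions by zero never changes the result of either pooling.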

In [3]:
# set the dir for trained weights 
# NOTE: because we switched the pooling mechanism between v1 and v2 (max gives better results), we need to provide
# the agg argument that matches the set of weights we want to use

#### v1
# agg = "sum"
# model_type_or_dir = "weights/flops_efficient"
# model_type_or_dir = "weights/flops_best"

#### v2
agg = "max"
model_type_or_dir = "weights/splade_max"
# model_type_or_dir = "weights/distilsplade_max"

In [4]:
# loading model and tokenizer

model = Splade(model_type_or_dir, agg=agg)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_type_or_dir)
reverse_voc = {v: k for k, v in tokenizer.vocab.items()}

In [5]:
# example document from MS MARCO passage collection (doc_id = 8003157)

doc = "Glass and Thermal Stress. Thermal Stress is created when one area of a glass pane gets hotter than an adjacent area. If the stress is too great then the glass will crack. The stress level at which the glass will break is governed by several factors."

In [6]:
# now compute the document representation
with torch.no_grad():
    doc_rep = model(**tokenizer(doc, return_tensors="pt")).squeeze()  # (sparse) doc rep in voc space, shape (30522,)

# get the number of non-zero dimensions in the rep:
col = torch.nonzero(doc_rep).squeeze(-1).cpu().tolist()  # squeeze(-1) keeps a list even with a single non-zero dim
print("number of actual dimensions: ", len(col))

# now let's inspect the bow representation:
weights = doc_rep[col].cpu().tolist()
d = {k: v for k, v in zip(col, weights)}
sorted_d = {k: v for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)}
bow_rep = []
for k, v in sorted_d.items():
    bow_rep.append((reverse_voc[k], round(v, 2)))
print("SPLADE BOW rep:\n", bow_rep)

number of actual dimensions:  78
SPLADE BOW rep:
 [('glass', 14.96), ('stress', 14.23), ('cause', 9.77), ('thermal', 6.1), ('stressed', 5.82), ('window', 5.46), ('glasses', 4.68), ('crack', 3.99), ('happen', 3.44), ('why', 3.42), ('material', 3.34), ('break', 2.36), ('shatter', 2.27), ('meaning', 2.24), ('materials', 1.98), ('heat', 1.95), ('caused', 1.74), ('do', 1.49), ('pan', 1.43), ('when', 1.39), ('strike', 1.37), ('too', 1.2), ('create', 1.18), ('it', 1.12), ('temperature', 1.09), ('created', 1.05), ('collapse', 1.04), ('generated', 1.02), ('result', 1.0), ('hot', 0.99), ('area', 0.97), ('formed', 0.92), ('fracture', 0.82), ('later', 0.8), ('factor', 0.7), ('produced', 0.69), ('hotter', 0.67), ('adjacent', 0.67), ('cooler', 0.65), ('occur', 0.65), ('determined', 0.64), ('because', 0.64), ('level', 0.63), ('difference', 0.63), ('if', 0.56), ('and', 0.53), ('than', 0.51), ('one', 0.51), ('factors', 0.51), ('pain', 0.49), ('problem', 0.49), ('related', 0.49), ('form', 0.48), ('gener