# SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking

This notebook gives a minimal example of how to use our SPLADE model with BEIR.

We will soon release the weights and details of the distilled model (`weights/splade_distil_v2`). Its BEIR performance is as follows:

|      dataset      | NDCG@10 | Recall@100 |
|:-----------------:|:-------:|:----------:|
|      arguana      |  0.479  |   97.23%   |
|   climate-fever   |  0.235  |   52.43%   |
|      DBPedia      |  0.435  |   57.52%   |
|       fever       |  0.786  |   95.14%   |
|        fiqa       |  0.336  |   62.10%   |
|      hotpotqa     |  0.684  |   82.03%   |
|      nfcorpus     |  0.334  |   27.71%   |
|         nq        |  0.521  |   93.05%   |
|       quora       |  0.838  |   98.69%   |
|      scidocs      |  0.158  |   36.43%   |
|      scifact      |  0.693  |   92.03%   |
|     trec-covid    |  0.710  |   54.98%   |
|  webis-touche2020 |  0.364  |   35.39%   |
| Average zero shot |  0.506  |   66.89%   |

## Versions

* Transformers: 4.2.2
* PyTorch: 1.7.0
* BEIR: 0.1.8

```python
from models import Splade, BEIRSpladeModel
from transformers import AutoTokenizer
```


```python
# Directory containing the trained weights
model_type_or_dir = "weights/splade_distil_v2"
# Alternative checkpoint:
# model_type_or_dir = "weights/flops_best"
```

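```python
# Load the SPLADE model and its tokenizer
model = Splade(model_type_or_dir)
model.eval()  # inference mode (disables dropout)
tokenizer = AutoTokenizer.from_pretrained(model_type_or_dir)
# Wrap model and tokenizer so that BEIR can encode queries and documents
beir_splade = BEIRSpladeModel(model, tokenizer)
```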
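`Splade` and `BEIRSpladeModel` come from the repository's local `models` module, which is not reproduced here. As a rough, library-free illustration of what such a model computes, assume the max-pooling formulation from the SPLADE v2 paper: each vocabulary term j receives the weight w_j = max_i log(1 + ReLU(logit_ij)) over the input tokens i, and a query is scored against a document by the dot product of their sparse vectors (the `score_function="dot"` used later in this notebook). The helper names and toy logits below are hypothetical, and the sketch ignores attention masking and batching.

```python
import math
from typing import Dict, List


def splade_pool(logits: List[List[float]]) -> Dict[int, float]:
    """Hypothetical helper: max-pool log(1 + ReLU(logit)) over tokens.

    `logits` is a [num_tokens][vocab_size] matrix of MLM logits; the result
    is a sparse vector {vocab_id: weight} holding only non-zero weights.
    """
    sparse: Dict[int, float] = {}
    for token_logits in logits:
        for vocab_id, logit in enumerate(token_logits):
            weight = math.log1p(max(logit, 0.0))  # log(1 + ReLU(x))
            if weight > sparse.get(vocab_id, 0.0):  # max over tokens
                sparse[vocab_id] = weight
    return sparse


def sparse_dot(query: Dict[int, float], doc: Dict[int, float]) -> float:
    """Hypothetical helper: dot product over the terms both vectors share."""
    if len(query) > len(doc):
        query, doc = doc, query  # iterate over the smaller vector
    return sum(w * doc[t] for t, w in query.items() if t in doc)


# Toy 4-term vocabulary; the logits are made up for illustration.
q = splade_pool([[2.0, -1.0, 0.0, 0.5], [0.3, -2.0, 1.0, 0.0]])
d = splade_pool([[1.0, 0.0, 3.0, -0.5]])
print(q)
print(sparse_dot(q, d))
```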

```python
from beir import util
from beir.datasets.data_loader import GenericDataLoader

dataset = "nfcorpus"

# Download the dataset and unzip it into ./dataset/nfcorpus
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
out_dir = "dataset"
data_path = util.download_and_unzip(url, out_dir)

# Pass the folder where nfcorpus was downloaded and unzipped to the data loader.
# The data folder contains these files:
# (1) nfcorpus/corpus.jsonl  (format: jsonlines)
# (2) nfcorpus/queries.jsonl (format: jsonlines)
# (3) nfcorpus/qrels/test.tsv (format: tsv ("\t"))
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")
```
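`GenericDataLoader.load` returns three plain dictionaries. As a minimal stdlib sketch of their shape, assuming the standard BEIR file layout (`_id`/`title`/`text` fields in `corpus.jsonl`, `_id`/`text` in `queries.jsonl`, and a qrels TSV whose first row is a header), the hypothetical helper `load_beir_split` below parses in-memory versions of those files into the same structures. It is not BEIR's loader; it only keeps the queries that have relevance judgements for the split, as the loaded `queries` are expected to.

```python
import csv
import io
import json
from typing import Dict, Tuple

Corpus = Dict[str, Dict[str, str]]  # {doc_id: {"title": ..., "text": ...}}
Queries = Dict[str, str]            # {query_id: text}
Qrels = Dict[str, Dict[str, int]]   # {query_id: {doc_id: relevance}}


def load_beir_split(corpus_jsonl: str, queries_jsonl: str,
                    qrels_tsv: str) -> Tuple[Corpus, Queries, Qrels]:
    """Hypothetical helper mirroring the shape of BEIR's loaded data."""
    corpus: Corpus = {}
    for line in corpus_jsonl.splitlines():
        if line.strip():
            doc = json.loads(line)
            corpus[doc["_id"]] = {"title": doc.get("title", ""), "text": doc["text"]}

    queries: Queries = {}
    for line in queries_jsonl.splitlines():
        if line.strip():
            query = json.loads(line)
            queries[query["_id"]] = query["text"]

    qrels: Qrels = {}
    reader = csv.reader(io.StringIO(qrels_tsv), delimiter="\t")
    next(reader)  # skip the "query-id  corpus-id  score" header
    for row in reader:
        if not row:
            continue  # ignore blank lines
        query_id, doc_id, score = row[0], row[1], int(row[2])
        qrels.setdefault(query_id, {})[doc_id] = score

    # Keep only the queries that are judged in this split
    queries = {qid: queries[qid] for qid in qrels if qid in queries}
    return corpus, queries, qrels


# Made-up toy files in the BEIR format
corpus_text = (
    '{"_id": "d1", "title": "Vitamin D", "text": "Vitamin D and bone health."}\n'
    '{"_id": "d2", "title": "", "text": "Statins and cholesterol."}\n'
)
queries_text = (
    '{"_id": "q1", "text": "vitamin d bones"}\n'
    '{"_id": "q2", "text": "unjudged query"}\n'
)
qrels_text = "query-id\tcorpus-id\tscore\nq1\td1\t2\n"

corpus, queries, qrels = load_beir_split(corpus_text, queries_text, qrels_text)
print(corpus, queries, qrels, sep="\n")
```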



```python
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from beir.retrieval.evaluation import EvaluateRetrieval

# Exact (brute-force) search over the SPLADE representations
dres = DRES(beir_splade)

retriever = EvaluateRetrieval(dres, score_function="dot")  # or "cos_sim" for cosine similarity
results = retriever.retrieve(corpus, queries)
ndcg, map_, recall, p = EvaluateRetrieval.evaluate(qrels, results, [1, 10, 100, 1000])
# Capped recall (R_cap@k), computed by BEIR's custom metrics
results2 = EvaluateRetrieval.evaluate_custom(qrels, results, [1, 10, 100, 1000], metric="r_cap")
res = {
    "NDCG@10": ndcg["NDCG@10"],
    "Recall@100": recall["Recall@100"],
    "R_cap@100": results2["R_cap@100"],
}
print(res, flush=True)
```



```
{'NDCG@10': 0.33409, 'Recall@100': 0.27705, 'R_cap@100': 0.29269}
```
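The three reported metrics can be unpacked with a small per-query sketch. BEIR computes NDCG@k and Recall@k through `pytrec_eval` and R_cap@k in its own custom-metrics code, then averages each over all queries. The hypothetical functions below are a simplified stand-in, not BEIR's implementation: they assume linear relevance gains with a log2 rank discount for NDCG (trec_eval's convention), treat only judgements with a score above 0 as relevant, and break score ties arbitrarily. R_cap@k differs from Recall@k only in its denominator, min(k, number of relevant documents), so a query with more relevant documents than k can still reach 1.0.

```python
import math
from typing import Dict, List, Set


def rank(scores: Dict[str, float]) -> List[str]:
    """Document ids sorted by descending retrieval score."""
    return sorted(scores, key=lambda doc: scores[doc], reverse=True)


def relevant_docs(qrel: Dict[str, int]) -> Set[str]:
    """Judged documents with a positive relevance score."""
    return {doc for doc, grade in qrel.items() if grade > 0}


def ndcg_at_k(qrel: Dict[str, int], scores: Dict[str, float], k: int) -> float:
    """NDCG@k with linear gains and a log2(rank + 1) discount (rank from 1)."""
    dcg = sum(max(qrel.get(doc, 0), 0) / math.log2(i + 2)
              for i, doc in enumerate(rank(scores)[:k]))
    ideal = sorted((g for g in qrel.values() if g > 0), reverse=True)[:k]
    idcg = sum(g / math.log2(i + 2) for i, g in enumerate(ideal))
    return dcg / idcg if idcg > 0 else 0.0


def recall_at_k(qrel: Dict[str, int], scores: Dict[str, float], k: int) -> float:
    """Fraction of all relevant documents found in the top k."""
    relevant = relevant_docs(qrel)
    hits = len(relevant.intersection(rank(scores)[:k]))
    return hits / len(relevant) if relevant else 0.0


def capped_recall_at_k(qrel: Dict[str, int], scores: Dict[str, float], k: int) -> float:
    """R_cap@k: like recall, but divided by min(k, number of relevant docs)."""
    relevant = relevant_docs(qrel)
    hits = len(relevant.intersection(rank(scores)[:k]))
    return hits / min(len(relevant), k) if relevant else 0.0


# One made-up query: graded judgements and retrieval scores
qrel = {"d1": 2, "d2": 1, "d3": 1, "d4": 0}
scores = {"d1": 9.0, "d5": 7.5, "d2": 6.0, "d4": 5.0, "d3": 1.0}
for k in (2, 5):
    print(k, ndcg_at_k(qrel, scores, k), recall_at_k(qrel, scores, k),
          capped_recall_at_k(qrel, scores, k))
```

Because R_cap@k is never below Recall@k for a query and exceeds it only when that query has more than k relevant documents, the gap between R_cap@100 (0.29269) and Recall@100 (0.27705) above implies that some nfcorpus test queries have more than 100 relevant documents.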
