# SPLADE on MSMARCO v1 Passage Corpus using PyTerrier

This notebook demonstrates how to create a SPLADE index of the MSMARCO v1 passage corpus using PyTerrier, and how to retrieve from and evaluate that index.

## Installation

Install using pip:

In [1]:
!pip install git+https://github.com/tonellotto/pyt_splade@naverless-branch

Collecting git+https://github.com/tonellotto/pyt_splade@naverless-branch
  Cloning https://github.com/tonellotto/pyt_splade (to revision naverless-branch) to /tmp/pip-req-build-zcd3emia
  Running command git clone --filter=blob:none --quiet https://github.com/tonellotto/pyt_splade /tmp/pip-req-build-zcd3emia
  Running command git checkout -b naverless-branch --track origin/naverless-branch
  Switched to a new branch 'naverless-branch'
  Branch 'naverless-branch' set up to track remote branch 'naverless-branch' from 'origin'.
  Resolved https://github.com/tonellotto/pyt_splade to commit 77a2ab7964e1c3297c5412bc60e3c34c87b1faae
  Preparing metadata (setup.py) ... done


## Setup

We create a factory object `splade` that gives us access to the SPLADE document and query encoders as PyTerrier transformers. Here the model is placed on the second GPU (`cuda:1`).

In [1]:
import pyterrier as pt
import pyt_splade

splade = pyt_splade.Splade(device='cuda:1')
doc_encoder = splade.doc_encoder()

## Indexing demo

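Let's see which terms the SPLADE model generates for a document during indexing. Note how the short text `ww2` is expanded with related terms such as `war`, `wwii` and `germany`.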

In [2]:
df = doc_encoder([{'docno' : 'd1', 'text' : 'ww2'}])
df[0]['toks']

{'w': 199,
 '##2': 193,
 'war': 167,
 'wwii': 150,
 '##w': 130,
 'ii': 110,
 '2': 94,
 'germany': 86,
 'army': 76,
 'battle': 70,
 'was': 66,
 'bomb': 48,
 'event': 43,
 'wilson': 43,
 'conflict': 38,
 'marshall': 33,
 'allied': 23,
 'surrender': 22,
 'peace': 16,
 'military': 12,
 'era': 10,
 'alliance': 10,
 'weapon': 10,
 'wars': 8,
 'camp': 7,
 'were': 6,
 'france': 6,
 'invasion': 6,
 'nazi': 4,
 'zombie': 2,
 'german': 1,
 'japan': 1,
 'patton': 1}
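The weights above come from SPLADE's masked-language-model head. With max pooling (as in SPLADE v2), the weight of each vocabulary term is the maximum over input positions of log(1 + ReLU(logit)). The integer values suggest the weights are scaled and rounded before indexing; the factor of 100 below is an assumption, and the toy logits and vocabulary are hypothetical inputs, so treat this as a sketch of the technique rather than the `pyt_splade` implementation.

```python
import numpy as np


def splade_weights(logits, vocab, scale=100.0):
    """Turn MLM logits of shape (seq_len, vocab_size) into a sparse {token: int weight} dict."""
    logits = np.asarray(logits, dtype=float)
    # log-saturated ReLU, then max-pool over token positions (SPLADE max aggregation)
    activations = np.log1p(np.maximum(logits, 0.0)).max(axis=0)
    # ASSUMPTION: scale and round so weights can be stored as integer term frequencies
    quantised = np.rint(activations * scale).astype(int)
    # keep only non-zero entries, giving a sparse representation
    return {vocab[j]: int(w) for j, w in enumerate(quantised) if w > 0}


# toy example: two input positions over a four-token vocabulary
print(splade_weights([[1.0, -2.0, 0.5, -1.0], [0.2, 1.5, -0.3, -0.5]],
                     ['war', 'ww', '##2', 'the']))
```

Terms whose logits are never positive get a weight of zero and are dropped, which is what keeps the representation sparse.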

## Indexing MSMARCO

Let's now create an index for the MSMARCO v1 passage corpus. The following provides access to the dataset:

In [3]:
dataset = pt.get_dataset('irds:msmarco-passage')

This is the actual indexing code. The SPLADE document encoder transforms each passage into tokens and weights, which Terrier then indexes. Indexing took around 4 hours on an RTX 4090.

In [4]:
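import os

if not os.path.exists('./msmarco_psg'): # skip if already created
    # pretokenised=True: index the 'toks' dict (term -> weight) from the encoder instead of raw text
    indexer = pt.IterDictIndexer('./msmarco_psg', pretokenised=True)
    # no stemming or stopword removal: SPLADE tokens are indexed as-is
    indexer.setProperty("termpipelines", "")
    # whitespace tokenisation keeps wordpiece tokens such as '##2' intact
    indexer.setProperty("tokeniser", "WhitespaceTokeniser")

    # encode each passage with SPLADE, then pass the result to the Terrier indexer
    indexer_pipe = doc_encoder >> indexer
    index_ref = indexer_pipe.index(dataset.get_corpus_iter())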

Java started (triggered by TerrierIndexer.__init__) and loaded: pyterrier.java, pyterrier.terrier.java [version=5.11 (build: craig.macdonald 2025-01-13 21:29), helper_version=0.0.8]
msmarco-passage documents: 100%|████████████████████████████████████████████████████████████████████████████████████████| 8841823/8841823 [3:57:10<00:00, 621.34it/s]
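With `pretokenised=True`, each passage's SPLADE weights are stored in place of term frequencies, producing an impact-scored inverted index. The toy sketch below builds that structure in memory; the dict-of-lists layout and the `build_impact_index` name are hypothetical and do not reflect Terrier's on-disk format.

```python
from collections import defaultdict


def build_impact_index(docs):
    """Build a toy inverted index mapping term -> list of (docno, weight)."""
    index = defaultdict(list)
    for doc in docs:
        for term, weight in doc['toks'].items():
            if weight > 0:  # zero-weight terms produce no posting
                index[term].append((doc['docno'], weight))
    return dict(index)


docs = [
    {'docno': 'd1', 'toks': {'war': 167, 'wwii': 150}},
    {'docno': 'd2', 'toks': {'war': 20, 'peace': 16}},
]
print(build_impact_index(docs)['war'])  # [('d1', 167), ('d2', 20)]
```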


## Retrieval

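We can now conduct retrieval using PyTerrier. The SPLADE query encoder turns each query into weighted tokens, and the `Tf` weighting model scores documents using the SPLADE weights stored in the index, without IDF or document length normalisation.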

In [5]:
retr = pt.terrier.Retriever('./msmarco_psg', wmodel='Tf', verbose=True)

retr_pipe = splade.query_encoder() >> retr
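Conceptually, SPLADE ranks documents by the dot product of the query and document sparse vectors. The sketch below implements that over a hypothetical in-memory `{term: [(docno, weight), ...]}` index with a term-at-a-time accumulator. Absolute scores need not match those reported by Terrier, which may rescale the query weights.

```python
import heapq


def dot_product_search(query_toks, index, k=10):
    """Score documents by sum over shared terms of query_weight * doc_weight; return the top k."""
    scores = {}
    for term, q_weight in query_toks.items():
        # only documents containing the term contribute to its score
        for docno, d_weight in index.get(term, []):
            scores[docno] = scores.get(docno, 0.0) + q_weight * d_weight
    # keep the k highest-scoring documents, best first
    return heapq.nlargest(k, scores.items(), key=lambda item: item[1])


index = {'war': [('d1', 167), ('d2', 20)], 'peace': [('d2', 16)]}
print(dot_product_search({'war': 1.0, 'peace': 2.0}, index))  # [('d1', 167.0), ('d2', 52.0)]
```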

Let's check that retrieval works. The `query_toks` column shows the weighted tokens generated for the query.

In [6]:
retr_pipe.search('chemical reactions')

TerrierRetr(Tf): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.48s/q]


|     | qid | docid   | docno   | rank | score      | query              | query_toks |
|-----|-----|---------|---------|------|------------|--------------------|------------|
| 0   | 1   | 758284  | 758284  | 0    | 759.949764 | chemical reactions | {'reactions': 269.4256896972656, 'reaction': 2... |
| 1   | 1   | 5913794 | 5913794 | 1    | 758.851035 | chemical reactions | {'reactions': 269.4256896972656, 'reaction': 2... |
| 2   | 1   | 742206  | 742206  | 2    | 757.623774 | chemical reactions | {'reactions': 269.4256896972656, 'reaction': 2... |
| 3   | 1   | 8572191 | 8572191 | 3    | 750.944456 | chemical reactions | {'reactions': 269.4256896972656, 'reaction': 2... |
| 4   | 1   | 129901  | 129901  | 4    | 748.763385 | chemical reactions | {'reactions': 269.4256896972656, 'reaction': 2... |
| ... | ... | ...     | ...     | ...  | ...        | ...                | ... |
| 995 | 1   | 5094605 | 5094605 | 995  | 543.317372 | chemical reactions | {'reactions': 269.4256896972656, 'reaction': 2... |
| 996 | 1   | 2226467 | 2226467 | 996  | 543.290851 | chemical reactions | {'reactions': 269.4256896972656, 'reaction': 2... |
| 997 | 1   | 7866969 | 7866969 | 997  | 543.282898 | chemical reactions | {'reactions': 269.4256896972656, 'reaction': 2... |
| 998 | 1   | 3732415 | 3732415 | 998  | 543.242431 | chemical reactions | {'reactions': 269.4256896972656, 'reaction': 2... |


Finally, let's run an experiment on the judged TREC Deep Learning 2019 passage queries and see the resulting performance.

In [7]:
from pyterrier.measures import *

# judged queries from the TREC Deep Learning 2019 passage track
dl19 = pt.get_dataset('irds:msmarco-passage/trec-dl-2019/judged')

pt.Experiment(
    [retr_pipe],
    dl19.get_topics(),
    dl19.get_qrels(),
    # rel=2: only passages judged 2 or higher count as relevant for RR and AP
    eval_metrics=[RR(rel=2), nDCG@10, nDCG@100, AP(rel=2)],
    names=['splade']
)

TerrierRetr(Tf): 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 43/43 [00:43<00:00,  1.01s/q]


|   | name   | RR(rel=2) | nDCG@10  | nDCG@100 | AP(rel=2) |
|---|--------|-----------|----------|----------|-----------|
| 0 | splade | 0.918605  | 0.731091 | 0.672569 | 0.504771  |
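`RR(rel=2)` and `AP(rel=2)` treat only passages judged 2 or higher as relevant, whereas nDCG uses the graded labels directly as gains. The sketch below computes RR and nDCG@k for a single query, assuming linear gains, a log2 rank discount and a `{docno: label}` qrels dict; PyTerrier delegates these measures to ir_measures, whose implementation details may differ.

```python
import math


def reciprocal_rank(ranking, qrels, rel=2):
    """1/rank of the first document with label >= rel, or 0 if none is retrieved."""
    for rank, docno in enumerate(ranking, start=1):
        if qrels.get(docno, 0) >= rel:
            return 1.0 / rank
    return 0.0


def ndcg_at_k(ranking, qrels, k=10):
    """nDCG@k with linear gains (the relevance label) and a log2(rank + 1) discount."""
    dcg = sum(qrels.get(d, 0) / math.log2(i + 2) for i, d in enumerate(ranking[:k]))
    # ideal ranking: all judged labels sorted from highest to lowest
    ideal = sorted(qrels.values(), reverse=True)[:k]
    idcg = sum(g / math.log2(i + 2) for i, g in enumerate(ideal))
    return dcg / idcg if idcg > 0 else 0.0


qrels = {'a': 1, 'b': 3, 'c': 0, 'd': 2}
print(reciprocal_rank(['a', 'b', 'c'], qrels))  # 0.5: 'a' is judged 1, below the rel=2 threshold
```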


## Exploring the Index

In [8]:
# obtain the underlying Terrier Index object from the retriever's query manager
index = pt.java.cast("org.terrier.querying.LocalManager", retr.manager).index

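Let's explore the lexicon: which tokens were indexed? The loop below prints the first 100 entries. Each entry shows the term, its term id, `Nt` (the number of documents containing the term), `TF` (its total frequency in the collection, i.e. the sum of its SPLADE weights) and a pointer into the inverted index.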

In [9]:
for i, entry in enumerate(index.getLexicon()):
    if i == 100:
        break
    print(entry.getKey() + " " + entry.getValue().toString())

! term8768 Nt=35737 TF=1147883 maxTF=2147483647 @{0 0 0}
" term908 Nt=861221 TF=35398179 maxTF=2147483647 @{0 194287 4}
# term5228 Nt=69467 TF=3962679 maxTF=2147483647 @{0 5071126 4}
##0 term5242 Nt=68206 TF=4326264 maxTF=2147483647 @{0 5658804 6}
##00 term19501 Nt=14675 TF=972519 maxTF=2147483647 @{0 6285192 0}
##01 term12382 Nt=7860 TF=535730 maxTF=2147483647 @{0 6431703 6}
##0s term26590 Nt=390 TF=27064 maxTF=2147483647 @{0 6513717 2}
##1 term5497 Nt=105146 TF=5876691 maxTF=2147483647 @{0 6518228 4}
##10 term17384 Nt=21166 TF=1574624 maxTF=2147483647 @{0 7370202 5}
##100 term12383 Nt=13688 TF=807672 maxTF=2147483647 @{0 7601374 3}
##11 term9506 Nt=17113 TF=1083192 maxTF=2147483647 @{0 7725680 3}
##12 term8684 Nt=8396 TF=695098 maxTF=2147483647 @{0 7889352 4}
##13 term12419 Nt=13620 TF=714544 maxTF=2147483647 @{0 7992063 4}
##14 term17856 Nt=6107 TF=427526 maxTF=2147483647 @{0 8105092 2}
##15 term13479 Nt=14618 TF=926535 maxTF=2147483647 @{0 8170719 5}
##16 term8683 Nt=6142 TF=441899
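Because the SPLADE weights are stored as term frequencies, `Nt` and `TF` in the lexicon can be derived from the encoded documents as sketched below. The `lexicon_statistics` helper and the toy documents are hypothetical, for illustration only.

```python
from collections import Counter


def lexicon_statistics(docs):
    """Return {term: (Nt, TF)}: document frequency and total weight of each term."""
    nt = Counter()  # number of documents containing the term
    tf = Counter()  # sum of the term's weights across the collection
    for doc in docs:
        for term, weight in doc['toks'].items():
            if weight > 0:
                nt[term] += 1
                tf[term] += weight
    return {term: (nt[term], tf[term]) for term in nt}


docs = [
    {'docno': 'd1', 'toks': {'war': 167, 'wwii': 150}},
    {'docno': 'd2', 'toks': {'war': 20, 'peace': 16}},
]
print(lexicon_statistics(docs)['war'])  # (2, 187)
```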

In [10]:
print(index.getCollectionStatistics().toString())

Number of documents: 8841823
Number of terms: 28679
Number of postings: 1038252055
Number of fields: 0
Number of tokens: 51858349642
Field names: []
Positions:   false
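These statistics can be turned into per-document averages. Since the index stores SPLADE weights as frequencies, "Number of tokens" is the sum of all weights rather than a word count, and "Number of postings" counts the (term, document) pairs with a non-zero weight. The arithmetic below uses the figures printed above.

```python
def collection_averages(num_docs, num_postings, num_tokens):
    """Derive simple per-document and per-posting averages from collection statistics."""
    return {
        'terms_per_doc': num_postings / num_docs,     # distinct non-zero SPLADE terms per passage
        'weight_per_doc': num_tokens / num_docs,      # summed SPLADE weight per passage
        'weight_per_posting': num_tokens / num_postings,  # mean stored weight of a posting
    }


stats = collection_averages(8841823, 1038252055, 51858349642)
for name, value in stats.items():
    print(f"{name}: {value:.1f}")
# terms_per_doc: 117.4
# weight_per_doc: 5865.1
# weight_per_posting: 49.9
```

So each passage is represented by roughly 117 non-zero terms out of a vocabulary of 28,679 indexed terms.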



We can also look inside a particular document in the index: the direct index gives its terms and their SPLADE weights.
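Terrier's direct index stores each document's postings (term id and frequency), so a document's terms can be read directly. Without it, recovering a document vector would require scanning every posting list of the inverted index, as this toy sketch over a hypothetical in-memory `{term: [(docno, weight), ...]}` index illustrates.

```python
def document_vector(inverted_index, docno):
    """Recover one document's {term: weight} by scanning every posting list."""
    return {
        term: weight
        for term, postings in inverted_index.items()
        for d, weight in postings
        if d == docno
    }


inv = {'war': [('d1', 167), ('d2', 20)], 'peace': [('d2', 16)]}
print(document_vector(inv, 'd2'))  # {'war': 20, 'peace': 16}
```

The cost of this scan grows with the size of the whole index, which is why a per-document structure is worth keeping when documents need to be inspected or re-scored.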

In [11]:
di = index.getDirectIndex()
doi = index.getDocumentIndex()
lex = index.getLexicon()
docid = 7700000  # docids are 0-based

# map each term in the document to its SPLADE weight (stored as the term frequency)
dictrep = {}
postings = di.getPostings(doi.getDocumentEntry(docid))
if postings is not None:  # postings are null (None) if the document is empty
    for posting in postings:
        termid = posting.getId()
        lee = lex.getLexiconEntry(termid)  # (term, LexiconEntry) pair
        dictrep[lee.getKey()] = posting.getFrequency()

for k in sorted(dictrep.keys()):
    print(k, dictrep[k])

" 13
##uation 37
000 59
25 43
30 70
35 72
40 35
accountant 39
accounting 70
advice 11
amount 71
amounts 88
applicants 12
ask 108
asked 86
asking 33
asks 6
assessment 19
average 38
bargaining 48
bart 20
bottom 94
briggs 20
burke 10
business 13
businesses 28
calculate 62
calculated 2
candidacy 30
candidate 110
candidates 115
chart 26
companies 43
company 15
considered 36
corporate 10
dave 17
davis 5
desk 7
difference 1
diversity 2
employee 95
employees 17
employment 40
engineer 20
example 107
examples 98
excel 14
executive 5
finance 43
fisher 8
flat 134
flex 50
flexibility 129
flexible 92
gage 6
give 98
given 17
giving 29
highest 1
hr 62
improvisation 9
include 57
included 28
income 15
interview 19
job 80
jobs 33
kelly 1
letter 9
low 76
management 4
marketing 14
matching 23
math 56
max 44
maximum 84
median 101
mid 102
middle 32
money 32
murray 21
negotiate 98
negotiating 49
negotiation 82
negotiations 59
normal 13
numbers 9
often 19
pay 139
payroll 7
point 59
points 50
post 50
posted 72
