# SPLADE on MSMARCO v1 Passage Corpus using PyTerrier

This notebook demonstrates the creation of a SPLADE index using PyTerrier.

## Installation

We install the dependencies using pip.

In [None]:
# we use a github version of PyTerrier
%pip install --force-reinstall --upgrade git+https://github.com/terrier-org/pyterrier.git#egg=python-terrier

# we also install SPLADE and the PyTerrier SPLADE plugin (pyt_splade)
%pip install -q git+https://github.com/naver/splade.git
%pip install -q git+https://github.com/cmacdonald/pyt_splade.git

In [1]:
import pyterrier as pt

# we use the github version of Terrier too
pt.init(tqdm='notebook', version='snapshot')

Downloading terrier-assemblies 5.x-SNAPSHOT jar-with-dependencies to /home/me/.pyterrier...
Done


PyTerrier 0.8.1 has loaded Terrier 5.6 (built by jitpack on 2022-07-30 20:16)

No etc/terrier.properties, using terrier.default.properties for bootstrap configuration.


## SPLADE setup

We create a factory object that provides the transformers needed to use SPLADE: a document encoder for indexing and a query encoder for retrieval.

In [2]:
import pyt_splade
factory = pyt_splade.SpladeFactory()
doc_encoder = factory.indexing()

## Indexing demonstration

Let's see what terms the SPLADE model generates during indexing.

In [4]:
df = (doc_encoder >> pyt_splade.toks2doc()).transform_iter([{'docno' : 'd1', 'text' : 'ww2'}])
df.iloc[0].text

'2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w w was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was was were were were were were were war war war war war war war war war war war war war war war war war war war war war war war war war war war war war war
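
The output above suggests that each generated token is repeated in proportion to its SPLADE weight, so that a standard indexer can store the weight as a term frequency. The following is a minimal sketch of that idea, not the pyt_splade implementation: the function name `weights_to_pseudo_doc` and the `scale` multiplier are hypothetical and chosen only for illustration.

```python
from typing import Dict


def weights_to_pseudo_doc(weights: Dict[str, float], scale: int = 2) -> str:
    """Repeat each token round(weight * scale) times (illustrative only)."""
    parts = []
    for token, weight in weights.items():
        # The scale factor is an assumption; it turns a real-valued
        # weight into an integer number of repetitions.
        count = round(weight * scale)
        parts.extend([token] * count)
    return " ".join(parts)


pseudo_doc = weights_to_pseudo_doc({"war": 1.5, "ww": 0.5})
print(pseudo_doc)
```

With a whitespace tokeniser and no stemming or stopword removal, the term frequency of each token in such a pseudo-document reflects its scaled weight.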

## Indexing MSMARCO

Let's create an index for the MSMARCO v1 passage corpus.

In [6]:
# this will provide access to the corpus
dataset = pt.get_dataset('irds:msmarco-passage')

This is the actual indexing code. We use the SPLADE model to transform the passages into tokens and weights, and then regenerate documents from them using the `toks2doc()` transformer. The indexer is configured with a whitespace tokeniser and an empty term pipeline, so the generated tokens are indexed as they are.

Indexing the 8.8M passages of MSMARCO using a GeForce RTX 3090 took 10 hours. `batch_size=128` worked for a 24GB GPU. If you have a GPU with more memory, you could probably increase the batch size.

In [None]:
indexer = pt.IterDictIndexer('./msmarco_psg', overwrite=True)
indexer.setProperty("termpipelines", "")
indexer.setProperty("tokeniser", "WhitespaceTokeniser")

indxr_pipe = (doc_encoder >> pyt_splade.toks2doc() >> indexer)
index_ref = indxr_pipe.index(dataset.get_corpus_iter(), batch_size=128)

## Retrieval

We can now conduct retrieval using PyTerrier.

In [3]:
br = pt.BatchRetrieve('./msmarco_psg', wmodel='Tf', verbose=True)

# query_splade is our query encoder. 
query_splade = factory.query()

retr_pipe = query_splade >> br

Let's check that retrieval works and look at the generated query.

The output is in Terrier's matchop query language (which is very similar to the Indri QL). This allows us to:

1. Set weights on query terms, using `#combine()`.
2. Prevent terms being tokenised by Terrier's normal (end-user-facing) query parser.
3. Encode tokens such as `##2` using `#base64()`.
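
As a rough illustration of the three points above, the sketch below builds a matchop-style query string from a dictionary of token weights. The function name `to_matchop` and the rule for deciding which tokens need `#base64()` (here, any token that is not purely alphanumeric) are assumptions made for illustration; this is not necessarily how pyt_splade generates its queries.

```python
import base64
from typing import Dict


def to_matchop(weights: Dict[str, float]) -> str:
    """Build a matchop-style query string from token weights (illustrative only)."""
    clauses = []
    for token, weight in weights.items():
        if not token.isalnum():
            # Assumed rule: encode any token containing non-alphanumeric
            # characters so the query parser does not interpret them.
            encoded = base64.b64encode(token.encode("utf-8")).decode("ascii")
            token = f"#base64({encoded})"
        # Each term gets its own weighted #combine clause.
        clauses.append(f"#combine:0={weight}({token})")
    return " ".join(clauses)


query = to_matchop({"reaction": 2.5, "##ation": 1.0})
print(query)
```

The base64 encoding of `##ation` is `IyNhdGlvbg==`, which is the form that appears in the query generated by this notebook.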

In [5]:
retr_pipe.search('chemical reactions').iloc[0]['query']

BR(Tf):   0%|          | 0/1 [00:00<?, ?q/s]

'#combine:0=38.44116926193237(science) #combine:0=60.64908504486084(process) #combine:0=44.81603503227234(#base64(IyNhdGlvbg==)) #combine:0=5.139094591140747(#base64(IyN0aW9u)) #combine:0=3.4552380442619324(adam) #combine:0=37.64009773731232(brian) #combine:0=257.92694091796875(reaction) #combine:0=233.56268405914307(chemical) #combine:0=9.495500475168228(element) #combine:0=185.97100973129272(chemistry) #combine:0=4.767346754670143(owen) #combine:0=55.40198087692261(mechanism) #combine:0=60.75552701950073(experiment) #combine:0=35.60888469219208(equation) #combine:0=23.633737862110138(hammer) #combine:0=21.45126610994339(enzyme) #combine:0=269.3887948989868(reactions) #combine:0=6.886265426874161(synthesis) #combine:0=2.412080205976963(darwin) #combine:0=6.337817758321762(spark) #combine:0=182.8196883201599(chemicals)'

Finally, let's run the experiment and see the resulting performance.

In [6]:
from pyterrier.measures import *
pt.Experiment(
    [retr_pipe],
    pt.get_dataset('msmarco_passage').get_topics('test-2019'),
    pt.get_dataset('msmarco_passage').get_qrels('test-2019'),
    batch_size=200,
    filter_by_qrels=True,
    eval_metrics=[RR(rel=2), nDCG@10, nDCG@100, AP(rel=2)],
    names=['splade']
)        

BR(Tf):   0%|          | 0/43 [00:00<?, ?q/s]

|   | name | RR(rel=2) | nDCG@10 | nDCG@100 | AP(rel=2) |
|---|------|-----------|---------|----------|-----------|
| 0 | splade | 0.918605 | 0.730178 | 0.671292 | 0.503885 |
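
The `rel=2` argument used for `RR` and `AP` sets the minimum relevance label a passage needs in order to count as relevant. The sketch below shows the effect of that threshold on reciprocal rank for a single query. It is a simplified illustration with a hypothetical function name and made-up labels, not the implementation used by `pyterrier.measures`.

```python
from typing import Sequence


def reciprocal_rank(ranked_labels: Sequence[int], rel: int = 2) -> float:
    """Return 1/rank of the first result whose label is at least `rel`."""
    for rank, label in enumerate(ranked_labels, start=1):
        if label >= rel:
            return 1.0 / rank
    # No sufficiently relevant result was retrieved.
    return 0.0


# Made-up relevance labels of the top-ranked passages for one query.
labels = [1, 2, 0, 3]
rr_strict = reciprocal_rank(labels, rel=2)
rr_lenient = reciprocal_rank(labels, rel=1)
print(rr_strict, rr_lenient)
```

With the stricter threshold, the passage labelled 1 at rank 1 does not count, so the first relevant passage is found at rank 2.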


## Exploring the Index

In [9]:
index = pt.cast("org.terrier.querying.LocalManager", br.manager).index

Let's explore the lexicon: what tokens were used?

In [10]:
for me in index.getLexicon():
    print(me.getKey() + " " + me.getValue().toString())

! term8764 Nt=35502 TF=1143953 maxTF=2147483647 @{0 0 0} TFf=1143953
" term907 Nt=860874 TF=35394676 maxTF=2147483647 @{0 340959 0} TFf=35394676
# term5228 Nt=69437 TF=3961583 maxTF=2147483647 @{0 9749193 6} TFf=3961583
##0 term5251 Nt=68105 TF=4323510 maxTF=2147483647 @{0 10840565 4} TFf=4323510
##00 term19503 Nt=14646 TF=971558 maxTF=2147483647 @{0 12015450 6} TFf=971558
##01 term12374 Nt=7844 TF=535424 maxTF=2147483647 @{0 12285070 0} TFf=535424
##0s term26588 Nt=390 TF=27045 maxTF=2147483647 @{0 12434922 6} TFf=27045
##1 term5496 Nt=105063 TF=5872766 maxTF=2147483647 @{0 12442861 0} TFf=5872766
##10 term17384 Nt=21134 TF=1573542 maxTF=2147483647 @{0 14041515 0} TFf=1573542
##100 term12376 Nt=13623 TF=806120 maxTF=2147483647 @{0 14471838 4} TFf=806120
##11 term9503 Nt=17088 TF=1082035 maxTF=2147483647 @{0 14698320 2} TFf=1082035
##12 term8679 Nt=8391 TF=695027 maxTF=2147483647 @{0 14999209 2} TFf=695027
##13 term12411 Nt=13611 TF=714057 maxTF=2147483647 @{0 15189831 4} TFf=714057
##

In [11]:
print(index.getCollectionStatistics().toString())

Number of documents: 8841823
Number of terms: 28679
Number of postings: 1037528393
Number of fields: 1
Number of tokens: 51850433834
Field names: [text]
Positions:   false
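
A few derived figures help to interpret these statistics. The sketch below only does arithmetic on the numbers printed above; the variable names are illustrative. It assumes that each posting corresponds to one (term, document) pair, and that the token count is the sum of the stored term frequencies, which in this index are scaled SPLADE weights rather than real word counts.

```python
# Values copied from the collection statistics above.
num_docs = 8_841_823
num_postings = 1_037_528_393
num_tokens = 51_850_433_834

# Average number of distinct SPLADE tokens per passage.
avg_unique_terms = num_postings / num_docs

# Average pseudo-document length (sum of term frequencies per passage).
avg_doc_length = num_tokens / num_docs

# Average term frequency stored in a single posting.
avg_freq_per_posting = num_tokens / num_postings

print(f"{avg_unique_terms:.1f} distinct tokens per passage")
print(f"{avg_doc_length:.1f} tokens per passage")
print(f"{avg_freq_per_posting:.1f} average frequency per posting")
```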



We can even look into a particular document in the index.

In [12]:
di = index.getDirectIndex()
doi = index.getDocumentIndex()
lex = index.getLexicon()
docid = 7700000  # docids are 0-based
# NB: postings will be null if the document is empty
dictrep = {}
for posting in di.getPostings(doi.getDocumentEntry(docid)):
    termid = posting.getId()
    lee = lex.getLexiconEntry(termid)
    dictrep[lee.getKey()] = posting.getFrequency()

for k in sorted(dictrep.keys()):
    print(k, dictrep[k])

" 13
##uation 37
000 59
25 43
30 70
35 72
40 35
accountant 39
accounting 70
advice 11
amount 71
amounts 88
applicants 12
ask 108
asked 86
asking 33
asks 6
assessment 19
average 38
bargaining 48
bart 20
bottom 94
briggs 21
burke 10
business 13
businesses 28
calculate 62
calculated 2
candidacy 30
candidate 110
candidates 115
chart 26
companies 43
company 15
considered 36
corporate 10
dave 17
davis 5
desk 7
diversity 2
employee 95
employees 17
employment 40
engineer 20
example 107
examples 98
excel 14
executive 5
finance 43
fisher 8
flat 134
flex 50
flexibility 129
flexible 92
gage 6
give 97
given 17
giving 29
highest 1
hr 62
improvisation 9
include 57
included 28
income 15
interview 19
job 80
jobs 33
letter 9
low 76
management 4
marketing 14
matching 23
math 56
max 44
maximum 84
median 101
mid 101
middle 32
money 31
murray 21
negotiate 98
negotiating 49
negotiation 82
negotiations 59
normal 13
numbers 9
often 19
pay 139
payroll 7
point 59
points 50
post 50
posted 72
posting 95
practical 