# SPLADE on TREC COVID Corpus using PyTerrier

This notebook demonstrates how to create a SPLADE index of the TREC COVID corpus using PyTerrier, then retrieve from it, evaluate it, and inspect its contents.

## Installation

Install `pyt_splade` (from its `naverless-branch` branch) using pip:

In [None]:
!pip install -q git+https://github.com/tonellotto/pyt_splade@naverless-branch

## Setup

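We create a factory object, `splade`, that gives us access to the PyTerrier transformers needed to use SPLADE. The model is placed on the first CUDA device (`cuda:0`), and we obtain the document encoder from the factory.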

In [None]:
import pyterrier as pt
import pyt_splade

splade = pyt_splade.Splade(device='cuda:0')
doc_encoder = splade.doc_encoder()

## Indexing demo

Let's see which terms the SPLADE model generates for a document during indexing. Here we encode a toy document whose text is just `ww2`.

In [None]:
# Encode one toy document; each output record has a 'toks' field mapping terms to weights
df = doc_encoder([{'docno' : 'd1', 'text' : 'ww2'}])
df[0]['toks']
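The `toks` output above is easier to read when sorted by weight, so the strongest expansion terms come first. The helper below is a hypothetical sketch (not part of `pyt_splade`), assuming `toks` is a plain dict mapping each term to a float weight; the example weights are invented for illustration.

```python
def top_terms(toks, k=10):
    """Return the k highest-weighted (term, weight) pairs from a toks dict."""
    # Sort by descending weight, breaking ties alphabetically for a stable order
    return sorted(toks.items(), key=lambda kv: (-kv[1], kv[0]))[:k]


# Hypothetical weights for illustration only; real ones come from df[0]['toks']
example_toks = {'ww': 2.1, '##2': 1.3, 'war': 1.7, 'world': 1.5, 'ii': 0.9}
print(top_terms(example_toks, k=3))  # → [('ww', 2.1), ('war', 1.7), ('world', 1.5)]
```

In the notebook, `top_terms(df[0]['toks'])` would list the ten strongest terms for the toy document, under the same dict assumption.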

## Indexing TREC COVID

Let's now create an index for the TREC COVID corpus. The following cell gives access to the dataset through PyTerrier's `ir_datasets` integration (the `irds:` prefix):

In [None]:
dataset = pt.get_dataset('irds:beir/trec-covid')

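This is the actual indexing code. The SPLADE document encoder transforms each document into tokens and weights, which are passed to a pretokenised Terrier indexer. The term pipeline (stopword removal and stemming) is disabled and a whitespace tokeniser is used, so the SPLADE tokens are indexed unchanged. This code took approximately 1 hour to run on Google Colab.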
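To make "tokens and weights" concrete, here is a toy sketch of how SPLADE derives a sparse weight vector from a masked-language-model head. It assumes the max-pooling formulation used by SPLADE v2 and later, where the weight of vocabulary term j is the maximum over input positions i of log(1 + ReLU(y_ij)), with y_ij the MLM logit; the checkpoint `pyt_splade` loads may differ, and the vocabulary and logits below are invented.

```python
import math


def splade_weights(logits, vocab):
    """Pool per-position MLM logits into one sparse term-weight dict.

    logits: one row per input position, each with len(vocab) floats.
    Assumes SPLADE's max-pooling variant: w_j = max_i log(1 + relu(logits[i][j])).
    """
    weights = {}
    for j, term in enumerate(vocab):
        # ReLU zeroes negative logits; log1p dampens large positive ones
        w = max(math.log1p(max(row[j], 0.0)) for row in logits)
        if w > 0.0:  # keep only non-zero weights, so the vector stays sparse
            weights[term] = w
    return weights


vocab = ['war', 'world', 'covid', 'ii']
logits = [
    [2.0, -1.0, -3.0, 0.5],  # position 0 (toy values)
    [1.0, 0.7, -2.0, 1.5],   # position 1 (toy values)
]
print(splade_weights(logits, vocab))  # 'covid' is absent: both its logits are negative
```

Terms never mentioned in the text (such as `world` here) can still receive weight, which is how SPLADE performs document expansion.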

In [None]:
import os

if not os.path.exists('./trec_covid'): # skip if already created
    indexer = pt.IterDictIndexer('./trec_covid', pretokenised=True)
    indexer.setProperty("termpipelines", "")
    indexer.setProperty("tokeniser", "WhitespaceTokeniser")

    indexer_pipe = doc_encoder >> indexer
    index_ref = indexer_pipe.index(dataset.get_corpus_iter())
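Terrier postings store integer term frequencies (the `posting.getFrequency()` call used later in this notebook returns one), so float SPLADE weights must become integers at some point before they reach the index. A common approach is to scale and round; the sketch below assumes a hypothetical multiplier of 100 and is not taken from `pyt_splade`, whose source should be checked for the scaling it actually applies.

```python
def quantise(toks, mult=100):
    """Turn float term weights into integer pseudo-frequencies.

    mult=100 is an assumed scale factor for illustration. Terms whose scaled
    weight rounds to 0 are dropped, since frequency 0 means the term is absent.
    """
    out = {}
    for term, weight in toks.items():
        freq = int(round(weight * mult))
        if freq > 0:
            out[term] = freq
    return out


print(quantise({'war': 1.0986, 'world': 0.5306, 'ii': 0.004}))  # → {'war': 110, 'world': 53}
```

A larger multiplier preserves more precision at the cost of larger stored frequencies; very small weights are lost either way.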

## Retrieval

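We can now conduct retrieval using PyTerrier. The SPLADE query encoder turns the query text into weighted terms, and a Terrier retriever scores documents using the `Tf` weighting model, so each document's stored term frequencies (derived from its SPLADE weights) act as the term scores.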

In [None]:
retr = pt.terrier.Retriever('./trec_covid', wmodel='Tf', verbose=True)

retr_pipe = splade.query_encoder() >> retr
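With `wmodel='Tf'`, retrieval should amount to a sparse dot product between the query's term weights and each document's stored frequencies. The sketch below assumes Terrier multiplies each term's `Tf` score by the query term weight and sums over query terms; it is an illustration of that assumption with invented data, not Terrier's implementation, which uses an inverted index rather than scanning every document.

```python
def sparse_dot(query_weights, doc_freqs):
    """Score one document: sum over query terms of query weight * doc frequency."""
    # Terms missing from the document contribute nothing
    return sum(w * doc_freqs.get(t, 0) for t, w in query_weights.items())


def rank(query_weights, docs):
    """Return (docno, score) pairs sorted by descending score, ties by docno."""
    scored = [(docno, sparse_dot(query_weights, freqs)) for docno, freqs in docs.items()]
    return sorted(scored, key=lambda p: (-p[1], p[0]))


# Toy query weights and stored document frequencies (invented values)
query = {'chemical': 1.2, 'reaction': 0.8}
docs = {
    'd1': {'chemical': 50, 'reaction': 30},
    'd2': {'reaction': 90},
    'd3': {'virus': 70},
}
print(rank(query, docs))
```

Document `d3` shares no terms with the query and scores zero; a real retriever would simply not return it.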

Let's check that retrieval works; the results also let us see the generated query.

In [None]:
retr_pipe.search('chemical reactions')

Finally, let's run the experiment on the TREC COVID topics and relevance judgements (qrels) and see the resulting performance.

In [None]:
from pyterrier.measures import *
pt.Experiment(
    [retr_pipe],
    dataset.get_topics(),
    dataset.get_qrels(),
    eval_metrics=[RR(rel=2), nDCG@10, nDCG@100, AP(rel=2)],
    names=['splade']
)
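TREC COVID judgements are graded (0, 1 or 2), which is why some measures above set `rel=2`: only documents with grade 2 or higher count as relevant. The sketch below computes reciprocal rank with such a threshold and nDCG@k for a single toy query. It assumes trec_eval-style nDCG (linear gains equal to the grade, discount log2(rank + 1)); the actual numbers in the experiment come from PyTerrier's evaluation backend, not from this code.

```python
import math


def rr(ranking, qrels, rel=1):
    """Reciprocal rank of the first document whose grade is >= rel (0 if none)."""
    for i, docno in enumerate(ranking, start=1):
        if qrels.get(docno, 0) >= rel:
            return 1.0 / i
    return 0.0


def ndcg_at_k(ranking, qrels, k):
    """nDCG@k with linear gains and a log2(rank + 1) discount (assumed convention)."""
    def dcg(grades):
        return sum(g / math.log2(i + 1) for i, g in enumerate(grades, start=1))

    gains = [qrels.get(d, 0) for d in ranking[:k]]
    # The ideal ranking places all judged-relevant documents in grade order
    ideal = sorted((g for g in qrels.values() if g > 0), reverse=True)[:k]
    idcg = dcg(ideal)
    return dcg(gains) / idcg if idcg > 0 else 0.0


qrels = {'d1': 1, 'd2': 2, 'd3': 0}  # toy graded judgements
ranking = ['d1', 'd2', 'd4']         # toy system output; d4 is unjudged
print(rr(ranking, qrels, rel=2))     # → 0.5 (first grade-2 document is at rank 2)
print(ndcg_at_k(ranking, qrels, k=10))
```

With `rel=1` the same ranking would get a reciprocal rank of 1.0, since `d1` has grade 1; the threshold can change results substantially.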

## Exploring the Index

We first obtain the underlying Terrier index object from the retriever.

In [None]:
index = pt.java.cast("org.terrier.querying.LocalManager", retr.manager).index

Let's explore the lexicon: which tokens were indexed? The next cell prints the first 100 entries, each followed by its lexicon entry.

In [None]:
for i, entry in enumerate(index.getLexicon()):
    if i == 100:
        break
    print(entry.getKey() + " " + entry.getValue().toString())

We can also print the collection statistics, which summarise the index as a whole:

In [None]:
print(index.getCollectionStatistics().toString())

We can even look into a particular document in the index and list its terms with their stored frequencies.

In [None]:
di = index.getDirectIndex()
doi = index.getDocumentIndex()
lex = index.getLexicon()
docid = 77_000 # docids are 0-based

# Map each term in the document to its stored frequency (derived from its SPLADE weight)
dictrep = {}
postings = di.getPostings(doi.getDocumentEntry(docid))
if postings is None:  # NB: postings are null if the document is empty
    print("document %d is empty" % docid)
else:
    for posting in postings:
        termid = posting.getId()
        lee = lex.getLexiconEntry(termid)
        dictrep[lee.getKey()] = posting.getFrequency()

for k in sorted(dictrep.keys()):
    print(k, dictrep[k])