In [2]:
import json, torch
import pandas as pd
from pyserini.search.lucene import LuceneImpactSearcher
from pyserini.pyclass import JFloat, JInt, JHashMap
from scipy.sparse import csr_matrix, vstack, save_npz, load_npz
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Tokenizer whose vocabulary defines the column space of the sparse matrices
# built by create_sparse_matrix.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

SPLADE_VOCAB = tokenizer.vocab  # token string -> integer vocabulary id
MAX_LENGTH = 256  # number of rows (token positions) in each sparse matrix
VOCAB_LENGTH = len(SPLADE_VOCAB)  # number of columns in each sparse matrix

In [4]:
# Sanity check: vocabulary id of the token "health".
SPLADE_VOCAB["health"]

2740

In [36]:
def create_sparse_matrix(vector, token_indices):
    """Build a (MAX_LENGTH, VOCAB_LENGTH) CSR matrix from one encoded record.

    For every ``token -> position`` pair in ``token_indices`` one entry is
    written at row ``position`` and column ``SPLADE_VOCAB[token]`` holding
    the weight ``vector[token]``.

    Args:
        vector: mapping from token string to its weight.
        token_indices: mapping from token string to a row index.

    Returns:
        A ``scipy.sparse.csr_matrix`` of shape (MAX_LENGTH, VOCAB_LENGTH).

    Raises:
        KeyError: if a token of ``token_indices`` is missing from
            ``SPLADE_VOCAB`` or from ``vector``.
    """
    positions, vocab_ids, weights = [], [], []
    for token, position in token_indices.items():
        positions.append(position)
        vocab_ids.append(SPLADE_VOCAB[token])
        weights.append(vector[token])

    return csr_matrix(
        (weights, (positions, vocab_ids)),
        shape=(MAX_LENGTH, VOCAB_LENGTH),
    )


def compute_score(query_line, doc_line, tokenized_query = None, tokenized_doc = None):
    """Score a query against a document with a max-over-document-positions sum.

    Both records are expanded with ``create_sparse_matrix`` into
    (MAX_LENGTH, VOCAB_LENGTH) matrices. Their product gives one similarity
    per (query position, document position) pair; the score is the sum over
    query positions of the largest similarity in that row.

    Args:
        query_line: record with "vector" and "token_indices" entries.
        doc_line: record with "vector" and "token_indices" entries.
        tokenized_query: optional list of query tokens, used only to label
            the printed matches.
        tokenized_doc: optional list of document tokens. When given, the
            best-matching document token of every query position with a
            positive similarity is printed.

    Returns:
        The summed row-wise maxima of the similarity matrix.
    """
    query_full_rep = create_sparse_matrix(query_line["vector"], query_line["token_indices"])
    doc_full_rep = create_sparse_matrix(doc_line["vector"], doc_line["token_indices"]).transpose()

    # MAX_LENGTH x MAX_LENGTH: rows are query positions, columns doc positions.
    max_scores = (query_full_rep @ doc_full_rep).todense()
    score = max_scores.max(1).sum()

    # Diagnostics only: the torch conversion is needed solely to obtain the
    # argmax per row, so skip it when there is nothing to print.
    if tokenized_doc:
        values, indices = torch.max(torch.tensor(max_scores), dim = 1)

        for j, (v, i) in enumerate(zip(values, indices)):
            # Check the bound before anything else and with >=, so a skipped
            # row can never let the loop index past the end of the query.
            if tokenized_query and j >= len(tokenized_query):
                break
            if not v > 0:
                continue

            if tokenized_query:
                print(tokenized_query[j], "->", tokenized_doc[i], v)
            else:
                print(tokenized_doc[i], v)

    return score


def dot_product(dict1, dict2):
    """Return the dot product of two sparse ``key -> weight`` dictionaries.

    As a side effect, prints the per-key contributions of every key present
    in both dictionaries, in ``dict1`` iteration order.

    Args:
        dict1: first mapping; its keys drive the iteration.
        dict2: second mapping; keys missing here contribute 0.

    Returns:
        Sum of ``dict1[key] * dict2[key]`` with missing ``dict2`` keys
        counted as 0 (``0`` for an empty ``dict1``).
    """
    # Show which shared keys contribute, and how much.
    contributions = []
    for term, weight in dict1.items():
        if term in dict2:
            contributions.append([term, weight * dict2[term]])
    print(contributions)

    total = 0
    for term, weight in dict1.items():
        total = total + weight * dict2.get(term, 0)
    return total

def create_jquery(encoded_query, searcher, fields=None):
    """Convert an encoded query into a Java ``HashMap`` of token -> int weight.

    Args:
        encoded_query: mapping from token string to weight; every weight is
            wrapped with ``JInt`` before insertion.
        searcher: currently unused. Only the disabled IDF filter below would
            read ``searcher.idf`` / ``searcher.min_idf``.
        fields: currently unused and kept so existing calls keep working.
            The returned map never depended on it; a Java field-boost map
            used to be built from it and then discarded.

    Returns:
        A ``JHashMap`` holding one ``token -> JInt(weight)`` entry per item of
        ``encoded_query``.
    """
    jquery = JHashMap()
    for (token, weight) in encoded_query.items():
        # The IDF filter is disabled, so every token is kept:
        # if token in searcher.idf and searcher.idf[token] >= searcher.min_idf:
        jquery.put(token, JInt(weight))

    return jquery

In [19]:
# Per-model metadata files (loaded as JSON below; they provide the
# "encoded_queries" and "predictions" entries used throughout the notebook).
# NOTE(review): these are absolute, machine-specific paths.
splade_maxsim_metadata_path = "/home/lamdo/splade/pyserini_evaluation/metadata/trec_dl_2019__splade_maxsim_150k_lowregv3"
normal_splade_metadata_path = "/home/lamdo/splade/pyserini_evaluation/metadata/trec_dl_2019__splade_normal_150k_lowreg"

# Relevance judgments (tab-separated: query-id, corpus-id, score) and the
# query texts (one JSON object per line with "_id" and "text").
qrel_path = "/home/lamdo/splade/data/msmarco/trec_dl_2019/qrels/test.tsv"
queries_path = "/home/lamdo/splade/data/msmarco/trec_dl_2019/queries.jsonl"

In [7]:
# Open one Lucene impact index per model; they are used below to fetch the
# stored "raw" document records.
# NOTE(review): query_encoder=None presumably because queries are already
# encoded in the metadata files -- confirm.
splade_maxsim_searcher = LuceneImpactSearcher("/scratch/lamdo/beir_splade/indexes/msmarco__splade_maxsim_150k_lowregv3", query_encoder=None)
normal_splade_searcher = LuceneImpactSearcher("/scratch/lamdo/beir_splade/indexes/msmarco__splade_normal_150k_lowreg", query_encoder=None)

Apr 09, 2025 8:30:37 PM org.apache.lucene.store.MemorySegmentIndexInputProvider <init>
INFO: Using MemorySegmentIndexInput with Java 21; to disable start with -Dorg.apache.lucene.store.MMapDirectory.enableMemorySegments=false


In [8]:
def _read_json(path):
    """Parse and return the JSON document stored at ``path``."""
    with open(path) as handle:
        return json.load(handle)


# One metadata document per model, loaded in the same order as before.
splade_maxsim_metadata = _read_json(splade_maxsim_metadata_path)
normal_splade_metadata = _read_json(normal_splade_metadata_path)

In [21]:
# Map query id -> query text from the JSONL queries file.
queries = {}
# JSON text is UTF-8; be explicit rather than relying on the locale default.
with open(queries_path, encoding="utf-8") as f:
    for line in f:
        # Tolerate blank lines (e.g. a trailing newline at end of file),
        # which json.loads would otherwise reject.
        if not line.strip():
            continue
        jline = json.loads(line)
        # Named `qid`, not `id`, so the built-in id() is not shadowed for the
        # rest of the kernel session.
        qid = jline["_id"]
        query_text = jline["text"]
        queries[qid] = query_text

In [22]:
# Group the relevance judgments by query:
#   {query_id (str): [{"docid": ..., "score": ...}, ...]}
_qrels = pd.read_csv(qrel_path, sep='\t').to_dict("records")

qrel_metadata = {}
for line in _qrels:
    # Query ids are used as strings (dictionary keys) in the rest of the notebook.
    query_id = str(line["query-id"])
    doc_id, score = line["corpus-id"], line["score"]

    qrel_metadata.setdefault(query_id, []).append({"docid": doc_id, "score": score})

In [23]:
# Pick one query by its position in each model's "encoded_queries" list and
# print its term weights under both models, heaviest term first.
chosen_index = 5

splade_maxsim_query_line = splade_maxsim_metadata["encoded_queries"][chosen_index]
normal_splade_query_line = normal_splade_metadata["encoded_queries"][chosen_index]

splade_maxsim_query_rep = splade_maxsim_query_line["vector"]
normal_splade_query_rep = normal_splade_query_line["vector"]

# The query id is taken from the normal-SPLADE record.
query_id = normal_splade_query_line["query_id"]

print("QueryId", query_id, queries[query_id])
print(dict(sorted(splade_maxsim_query_rep.items(), key=lambda pair: pair[1], reverse=True)))
print()
print(dict(sorted(normal_splade_query_rep.items(), key=lambda pair: pair[1], reverse=True)))

QueryId 573724 what are the social determinants of health
{'social': 264, 'health': 243, 'deter': 218, '##mina': 174, 'sick': 133, 'are': 93, 'of': 84, '##nt': 40, 'help': 26}

{'social': 241, 'health': 223, 'deter': 185, '##mina': 166, '##nts': 144, 'are': 93, '##nt': 84, 'meaning': 81, 'of': 77, 'the': 49, 'cause': 49, 'quality': 46, 'socially': 46, 'medical': 29, 'symptoms': 29, 'emissions': 20, 'help': 17, 'disease': 14, '-': 12, '.': 4, '##ents': 2}


In [11]:
# Judged documents of the chosen query with a non-zero (truthy) score, as
# string ids so they can be looked up in the "predictions" maps.
chosen_query_id = query_id
doc_ids = [
    str(judgment["docid"])
    for judgment in qrel_metadata[chosen_query_id]
    if judgment["score"]
]

In [12]:
# Stored retrieval scores (doc id -> score) of each model for the chosen query.
normal_predictions = normal_splade_metadata["predictions"][chosen_query_id]
maxsim_predictions = splade_maxsim_metadata["predictions"][chosen_query_id]

min_normal = min(normal_predictions.values())
max_normal = max(normal_predictions.values())
min_maxsim = min(maxsim_predictions.values())
max_maxsim = max(maxsim_predictions.values())

print("score range normal", min_normal, max_normal)
print("score range maxsim", min_maxsim, max_maxsim)

# A printed 0 means the document is absent from that model's stored predictions.
for docid in doc_ids:
    print(docid)
    print("normal splade", normal_predictions.get(docid, 0))
    print("maxsim splade", maxsim_predictions.get(docid, 0))

    print()

score range normal 96934.0 243662.0
score range maxsim 167151.0 242481.0
1005338
normal splade 160164.0
maxsim splade 171333.0

104856
normal splade 182289.0
maxsim splade 179668.0

1165129
normal splade 240139.0
maxsim splade 241362.0

1509940
normal splade 225304.0
maxsim splade 221122.0

1509942
normal splade 227229.0
maxsim splade 230671.0

1816038
normal splade 223807.0
maxsim splade 223254.0

2002448
normal splade 171204.0
maxsim splade 182627.0

2021252
normal splade 105352.0
maxsim splade 0

2133806
normal splade 195449.0
maxsim splade 212634.0

2819390
normal splade 0
maxsim splade 0

3059586
normal splade 0
maxsim splade 0

3059589
normal splade 101102.0
maxsim splade 0

3146238
normal splade 237371.0
maxsim splade 237352.0

3365631
normal splade 195982.0
maxsim splade 201754.0

3365633
normal splade 224299.0
maxsim splade 233094.0

3365635
normal splade 178652.0
maxsim splade 193600.0

3365636
normal splade 221929.0
maxsim splade 225223.0

3365638
normal splade 234004.0
maxs

In [52]:
def _load_raw_doc(searcher, docid):
    """Return the parsed JSON stored in the "raw" field of ``docid``.

    Raises:
        KeyError: if the searcher has no document with that id. pyserini's
            ``doc()`` is documented to return None in that case, which would
            otherwise surface as an opaque AttributeError.
    """
    stored = searcher.doc(docid)
    if stored is None:
        raise KeyError(f"document {docid!r} not found in index")
    return json.loads(stored.lucene_document().get("raw"))


chosen_doc_id = "7589690"

# Fetch and parse the maxsim record once; it provides both the display text
# and the maxsim representation (previously the same lookup was done twice).
splade_maxsim_doc_line = _load_raw_doc(splade_maxsim_searcher, chosen_doc_id)
raw = splade_maxsim_doc_line
chosen_doc = f"{raw['title']} | {raw['text']}"
print(chosen_doc)

normal_splade_doc_line = _load_raw_doc(normal_splade_searcher, chosen_doc_id)

splade_maxsim_doc_rep = splade_maxsim_doc_line["vector"]
normal_splade_doc_rep = normal_splade_doc_line["vector"]

# Term weights of the document under both models, heaviest term first.
print({k: v for k, v in sorted(splade_maxsim_doc_rep.items(), key=lambda item: -item[1])})
print()
print({k: v for k, v in sorted(normal_splade_doc_rep.items(), key=lambda item: -item[1])})

 | This raises concerns of worsening rural health disparities in the future. For additional information about the causes of health disparities in rural areas, see RHIhub's topic guide: Social Determinants of Health for Rural People.
{'rural': 232, '##spar': 210, 'di': 185, 'health': 156, 'future': 156, 'worse': 148, '##hi': 147, '##hu': 147, 'deter': 143, 'social': 138, 'r': 137, 'concern': 135, 'guide': 132, '##mina': 130, 'countryside': 125, 'agricultural': 122, 'topic': 121, 'sick': 107, '##b': 105, '##ity': 94, 'urban': 93, 'area': 91, 'people': 83, 'farm': 83, 'suburban': 74, '##ance': 70, 'likely': 66, 'cause': 64, 'past': 61, 'local': 50, 'illness': 50, 'harmful': 49, 'theme': 45, '##ning': 42, '##ities': 42, 'regarding': 41, 'case': 39, 'person': 39, 'source': 38, 'subject': 36, '|': 35, 'potential': 34, '##ting': 32, 'worry': 32, 'affect': 32, 'context': 31, 'bad': 29, 'situation': 28, 'against': 26, 'problem': 26, 'should': 24, 'poverty': 22, 'concerning': 19, 'upcoming': 18,

In [53]:
# Score the chosen query against the chosen document using the maxsim
# records, printing which document token each query token matched best.
# NOTE(review): assumes the stored token_indices line up with this
# tokenization (same tokenizer, special tokens included) -- confirm.
compute_score(
    splade_maxsim_query_line, 
    splade_maxsim_doc_line, 
    tokenized_doc = tokenizer.tokenize(chosen_doc, add_special_tokens=True), 
    tokenized_query = tokenizer.tokenize(queries[query_id], add_special_tokens=True))

[CLS] -> [CLS] tensor(14231)
social -> social tensor(36432)
deter -> deter tensor(31174)
##mina -> ##mina tensor(22620)
of -> of tensor(1008)
health -> health tensor(37908)


np.int64(143373)

In [51]:
# Dot product of the normal-SPLADE query and document vectors; dot_product
# also prints the contribution of every shared term.
dot_product(normal_splade_query_rep, normal_splade_doc_rep)

[['.', 76], ['of', 4620], ['help', 680], ['social', 47236], ['health', 27206], ['medical', 1363], ['##nt', 1260], ['cause', 5782], ['meaning', 7776], ['disease', 140], ['##nts', 6336], ['symptoms', 1160], ['##ents', 72], ['emissions', 660], ['socially', 3404], ['##mina', 23240], ['deter', 25900]]


156911

In [16]:
# Inspect the full encoded maxsim record of the chosen query.
splade_maxsim_query_line

{'query_id': '573724',
 'vector': {'of': 84,
  'are': 93,
  'help': 26,
  'social': 264,
  'health': 243,
  '##nt': 40,
  'sick': 133,
  '##mina': 174,
  'deter': 218},
 'vector_onehot': {},
 'token_indices': {'of': 8,
  'are': 2,
  'help': 0,
  'social': 4,
  'health': 9,
  '##nt': 7,
  'sick': 0,
  '##mina': 6,
  'deter': 5},
 'token_indices_onehot': None,
 'pad_len': 15}

In [39]:
# Stored maxsim predictions (doc id -> score) for the chosen query.
splade_maxsim_metadata["predictions"][query_id]

{'394142': 242481.0,
 '1165129': 241362.0,
 '5721829': 239515.0,
 '87405': 238580.0,
 '8833192': 237809.0,
 '3146238': 237352.0,
 '5339620': 237099.0,
 '3365638': 236113.0,
 '7674859': 234263.0,
 '8833194': 234128.0,
 '162151': 233951.0,
 '3365633': 233094.0,
 '394136': 232956.0,
 '6776717': 231955.0,
 '7199669': 231252.0,
 '1509942': 230671.0,
 '7199673': 228498.0,
 '7104825': 226558.0,
 '8441642': 226024.0,
 '7609928': 225807.0,
 '3365636': 225223.0,
 '394138': 224943.0,
 '2635262': 224865.0,
 '5310530': 224755.0,
 '1816038': 223254.0,
 '5962423': 222843.0,
 '1509940': 221122.0,
 '3617636': 220985.0,
 '7384332': 220308.0,
 '394139': 216002.0,
 '8444728': 215453.0,
 '5185671': 213362.0,
 '5047796': 212871.0,
 '2133806': 212634.0,
 '7199675': 212358.0,
 '7104827': 210060.0,
 '7964601': 206690.0,
 '5721831': 206515.0,
 '2417985': 206025.0,
 '7609927': 205853.0,
 '8444727': 205747.0,
 '87409': 205461.0,
 '8441637': 205359.0,
 '5962428': 205081.0,
 '4017883': 204410.0,
 '7609926': 203994.

In [31]:
# Scratch arithmetic.
# NOTE(review): 243 equals the maxsim query weight printed for 'health';
# where 240 comes from is not shown in this notebook -- confirm or remove.
240* 243

58320

In [32]:
# Token -> position map stored with the maxsim document record.
splade_maxsim_doc_line["token_indices"]

{',': 37,
 '.': 1,
 'e': 11,
 '|': 1,
 'the': 5,
 'of': 7,
 'in': 10,
 'is': 24,
 'by': 26,
 'but': 2,
 'are': 33,
 'have': 33,
 'also': 42,
 'other': 29,
 '##e': 11,
 'if': 2,
 'like': 32,
 'some': 3,
 'do': 42,
 'through': 26,
 'most': 3,
 'while': 2,
 'such': 32,
 'part': 3,
 '...': 1,
 'work': 32,
 'much': 3,
 'too': 42,
 'own': 21,
 'think': 2,
 'government': 33,
 'public': 29,
 'body': 35,
 'say': 2,
 'different': 17,
 'law': 0,
 'help': 40,
 'include': 32,
 'control': 6,
 'education': 34,
 'others': 29,
 'human': 8,
 'role': 6,
 'society': 40,
 'political': 9,
 'important': 32,
 'social': 9,
 'lead': 15,
 'outside': 21,
 'person': 0,
 'rest': 20,
 'care': 30,
 'health': 23,
 'word': 0,
 'else': 21,
 'towards': 7,
 'action': 6,
 'hospital': 18,
 'changed': 25,
 'bad': 24,
 'involved': 25,
 'medical': 18,
 '##ity': 13,
 'attention': 13,
 'reason': 14,
 'covered': 25,
 'related': 25,
 'active': 32,
 'economic': 9,
 'property': 6,
 'significant': 3,
 'understand': 4,
 'particular': 

In [33]:
# Which token sits at position 23 of the tokenized document text.
# NOTE(review): execution counts are out of order, so the saved outputs of
# the surrounding cells may belong to a different chosen_doc -- re-run in
# order before reading anything into the position/token pairing.
tokenizer.tokenize(chosen_doc, add_special_tokens=True)[23]

'is'