# SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking

This notebook gives a minimal example usage of SPLADE.

* We provide models via Hugging Face (https://huggingface.co/naver)
* See [Naver Labs Europe website](https://europe.naverlabs.com/research/machine-learning-and-optimization/splade-models/) for other intermediate models.

| model | MRR@10 (MS MARCO dev) | recall@1000 (MS MARCO dev) | expected FLOPS | ~ avg q length | ~ avg d length |
| --- | --- | --- | --- | --- | --- |
| `naver/splade_v2_max` (**v2** [HF](https://huggingface.co/naver/splade_v2_max)) | 34.0 | 96.5 | 1.32 | 18 | 92 |
| `naver/splade_v2_distil` (**v2** [HF](https://huggingface.co/naver/splade_v2_distil)) | 36.8 | 97.9 | 3.82 | 25 | 232 |
| `naver/splade-cocondenser-selfdistil` (**v2bis**, [HF](https://huggingface.co/naver/splade-cocondenser-selfdistil))| 37.6 | 98.4 | 2.32 | 56 | 134 |
| `naver/splade-cocondenser-ensembledistil` (**v2bis**, [HF](https://huggingface.co/naver/splade-cocondenser-ensembledistil)) | 38.3 | 98.3  | 1.85 | 44 | 120 |

In [1]:
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

Mounted at /content/gdrive


In [2]:
!nvidia-smi

Sun Jun 18 22:16:52 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   31C    P0    46W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
%%shell
pip install pytrec_eval -q
pip install transformers -q
pip install pyserini -q
pip install torch -q
pip install datasets -q
pip install evaluate -q
pip install trectools -q
pip install faiss-cpu -q
pip install sentence-transformers -q
pip install git+https://github.com/naver/splade.git -q

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for pytrec_eval (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m102.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m236.8/236.8 kB[0m [31m33.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m111.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m89.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.6/485.6 kB[0m [31m21.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m18.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.5/212.5 kB[0m [31m27.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━



In [67]:
%%shell
apt-get install maven -qq
git clone --recurse-submodules https://github.com/castorini/pyserini.git
cd pyserini
cd tools/eval && tar xvfz trec_eval.9.0.4.tar.gz && cd trec_eval.9.0.4 && make && cd ../../..
cd tools/eval/ndeval && make && cd ../../..

Extracting templates from packages: 100%
Selecting previously unselected package libapache-pom-java.
(Reading database ... 123069 files and directories currently installed.)
Preparing to unpack .../00-libapache-pom-java_18-1_all.deb ...
Unpacking libapache-pom-java (18-1) ...
Selecting previously unselected package libatinject-jsr330-api-java.
Preparing to unpack .../01-libatinject-jsr330-api-java_1.0+ds1-5_all.deb ...
Unpacking libatinject-jsr330-api-java (1.0+ds1-5) ...
Selecting previously unselected package libgeronimo-interceptor-3.0-spec-java.
Preparing to unpack .../02-libgeronimo-interceptor-3.0-spec-java_1.0.1-4fakesync_all.deb ...
Unpacking libgeronimo-interceptor-3.0-spec-java (1.0.1-4fakesync) ...
Selecting previously unselected package libcdi-api-java.
Preparing to unpack .../03-libcdi-api-java_1.2-2_all.deb ...
Unpacking libcdi-api-java (1.2-2) ...
Selecting previously unselected package libcommons-cli-java.
Preparing to unpack .../04-libcommons-cli-java_1.4-1_all.deb ...



In [68]:
!pip install pyserini -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.1/154.1 MB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m63.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m188.7/188.7 kB[0m [31m22.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.9/5.9 MB[0m [31m90.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for nmslib (setup.py) ... [?25l[?25hdone


In [4]:
user = "leonardo"
if user == "monique":
    main_dir = '/content/gdrive/MyDrive/Unicamp-projeto-final/'
else:
    main_dir = '/content/gdrive/MyDrive/Unicamp/IA368-DD/'

In [5]:
import os
import json
import numpy as np
import pandas as pd
import random
import torch
import collections
import evaluate
import shutil
import pickle
import numba

from collections import defaultdict, Counter
from datasets import load_dataset
from tqdm import tqdm
from operator import itemgetter
from time import time
from torch import nn, optim
from transformers import BatchEncoding, get_linear_schedule_with_warmup, AutoModelForMaskedLM, AutoTokenizer
from splade.models.transformer_rep import Splade

In [6]:
shutil.copyfile(f"{main_dir}Projeto Final/experiments_10m.zip", "/content/experiments.zip")

'/content/experiments.zip'

In [7]:
!unzip /content/experiments.zip
!mv /content/content/splade/experiments /content/experiments

Archive:  /content/experiments.zip
   creating: content/splade/experiments/
   creating: content/splade/experiments/pt/
   creating: content/splade/experiments/pt/checkpoint/
  inflating: content/splade/experiments/pt/checkpoint/training_perf.txt  
   creating: content/splade/experiments/pt/checkpoint/model/
  inflating: content/splade/experiments/pt/checkpoint/model/special_tokens_map.json  
  inflating: content/splade/experiments/pt/checkpoint/model/vocab.txt  
  inflating: content/splade/experiments/pt/checkpoint/model/model.tar  
  inflating: content/splade/experiments/pt/checkpoint/model/pytorch_model.bin  
  inflating: content/splade/experiments/pt/checkpoint/model/tokenizer.json  
  inflating: content/splade/experiments/pt/checkpoint/model/config.json  
  inflating: content/splade/experiments/pt/checkpoint/model/tokenizer_config.json  
   creating: content/splade/experiments/pt/checkpoint/val_full_ranking/
  inflating: content/splade/experiments/pt/checkpoint/val_full_ranking/ru

In [8]:
def restore_model(model, state_dict):
    missing_keys, unexpected_keys = model.load_state_dict(state_dict=state_dict, strict=False)
    # strict = False => it means that we just load the parameters of layers which are present in both and
    # ignores the rest
    if len(missing_keys) > 0:
        print("~~ [WARNING] MISSING KEYS WHILE RESTORING THE MODEL ~~")
        print(missing_keys)
    if len(unexpected_keys) > 0:
        print("~~ [WARNING] UNEXPECTED KEYS WHILE RESTORING THE MODEL ~~")
        print(unexpected_keys)
    print("restoring model:", model.__class__.__name__)

## mRobust

In [9]:
!export PYTORCH_CUDA_ALLOC_CONF=garbage_collection_threshold:0.6
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64"

In [10]:
random.seed(10)
np.random.seed(10)
torch.manual_seed(10)

<torch._C.Generator at 0x7fbb6e18c430>

In [11]:
collection = load_dataset('unicamp-dl/mrobust', 'collection-portuguese')
queries = load_dataset('unicamp-dl/mrobust', 'queries-portuguese')
!wget https://huggingface.co/datasets/unicamp-dl/mrobust/raw/main/qrels.robust04.txt

Downloading builder script:   0%|          | 0.00/4.03k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/21.2k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/2.27k [00:00<?, ?B/s]

Downloading and preparing dataset mrobust/collection-portuguese (download: 1.78 GiB, generated: 1.79 GiB, post-processed: Unknown size, total: 3.57 GiB) to /root/.cache/huggingface/datasets/unicamp-dl___mrobust/collection-portuguese/1.0.0/a91c748e4ba08987678ac4eebaf08238bf6ee876e4ac74d4b996cbbc7b10014d...


Downloading data:   0%|          | 0.00/1.91G [00:00<?, ?B/s]

Generating collection split:   0%|          | 0/528032 [00:00<?, ? examples/s]

Dataset mrobust downloaded and prepared to /root/.cache/huggingface/datasets/unicamp-dl___mrobust/collection-portuguese/1.0.0/a91c748e4ba08987678ac4eebaf08238bf6ee876e4ac74d4b996cbbc7b10014d. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

Downloading and preparing dataset mrobust/queries-portuguese (download: 27.75 KiB, generated: 29.22 KiB, post-processed: Unknown size, total: 56.98 KiB) to /root/.cache/huggingface/datasets/unicamp-dl___mrobust/queries-portuguese/1.0.0/a91c748e4ba08987678ac4eebaf08238bf6ee876e4ac74d4b996cbbc7b10014d...


Downloading data:   0%|          | 0.00/28.4k [00:00<?, ?B/s]

Generating queries split:   0%|          | 0/250 [00:00<?, ? examples/s]

Dataset mrobust downloaded and prepared to /root/.cache/huggingface/datasets/unicamp-dl___mrobust/queries-portuguese/1.0.0/a91c748e4ba08987678ac4eebaf08238bf6ee876e4ac74d4b996cbbc7b10014d. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

--2023-06-18 22:22:18--  https://huggingface.co/datasets/unicamp-dl/mrobust/raw/main/qrels.robust04.txt
Resolving huggingface.co (huggingface.co)... 18.155.68.116, 18.155.68.121, 18.155.68.44, ...
Connecting to huggingface.co (huggingface.co)|18.155.68.116|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6543541 (6.2M) [text/plain]
Saving to: ‘qrels.robust04.txt’


2023-06-18 22:22:19 (7.43 MB/s) - ‘qrels.robust04.txt’ saved [6543541/6543541]



In [12]:
corpus = dict(zip(collection["collection"]['id'], collection["collection"]['text']))
topics = dict(zip(queries["queries"]['id'], queries["queries"]['text']))

print(collection)
print(queries)
print(f"corpus length: {len(corpus)}, topics length: {len(topics)}")

DatasetDict({
    collection: Dataset({
        features: ['id', 'text'],
        num_rows: 528032
    })
})
DatasetDict({
    queries: Dataset({
        features: ['id', 'text'],
        num_rows: 250
    })
})
corpus length: 528032, topics length: 250


In [13]:
output_tsv = f'/content/corpus.tsv'
output_tsv = f'/content/queries.tsv'

with open(output_tsv,'w') as f_out:
    for line in tqdm(collection["collection"], desc=f'Writing file on {output_tsv}'):
        f_out.write(f'{line["id"]}\t{line["text"]}\n')

with open(output_tsv,'w') as f_out:
    for line in tqdm(queries["queries"], desc=f'Writing file on {output_tsv}'):
        f_out.write(f'{line["id"]}\t{line["text"]}\n')

Writing file on /content/queries.tsv: 100%|██████████| 528032/528032 [00:22<00:00, 23049.15it/s]
Writing file on /content/queries.tsv: 100%|██████████| 250/250 [00:00<00:00, 30410.25it/s]


In [14]:
!wc -l /content/corpus.tsv

wc: /content/corpus.tsv: No such file or directory


In [15]:
!wc -l /content/queries.tsv

250 /content/queries.tsv


In [16]:
# Open the input text file for reading
with open('/content/qrels.robust04.txt', 'r') as input_file:
    # Open the output TSV file for writing
    with open('/content/qrels.tsv', 'w', newline='') as output_file:
        # Read each line from the input file
        for line in input_file:
            output_file.write(line.replace(" ", "\t"))

### SPLADE Indexing

In [17]:
from torch.utils.data import Dataset, DataLoader

In [18]:
class BeirDataset(Dataset):
    """
    dataset to iterate over a TREC-COVID
    everything is preloaded in memory at init
    """
    def __init__(self, value_dictionary):
        self.value_dictionary = value_dictionary
        self.ids = list(value_dictionary.keys())

    def __len__(self):
        return len(self.value_dictionary)

    def __getitem__(self, idx):
        id = self.ids[idx]
        return id, self.value_dictionary.get(id)

In [19]:
class BeirDataLoader(DataLoader):
    def __init__(self, tokenizer_type, max_length, **kwargs):
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_type)
        super().__init__(collate_fn=self.collate_fn, **kwargs, pin_memory=True)

    def collate_fn(self, batch):
        """
        batch is a list of tuples, each tuple has 2 (text) items (id_, doc)
        """
        id_, d = zip(*batch)
        processed_passage = self.tokenizer(list(d),
                                           add_special_tokens=True,
                                           padding="longest",
                                           truncation="longest_first",
                                           max_length=self.max_length,
                                           return_attention_mask=True)
        return {**{k: torch.tensor(v) for k, v in processed_passage.items()}, "ids": [str(i) for i in id_]}

In [20]:
class SpladeInvertedIndex():
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.inverted_index = defaultdict(list)
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    def __call__(self, d_loader):
        count = 0
        doc_ids = []
        with torch.no_grad():
            for t, batch in enumerate(tqdm(d_loader)):
                ## inputs: tokens_ids; att_mask
                inputs = {k: v.to(self.device) for k, v in batch.items() if k not in {"ids"}}
                batch_documents = self.model(d_kwargs=inputs)["d_rep"]

                # batch_documents = (batch_size, vocab_size)
                # row = batch_size, col = vocab_id, ......vocab_n=vocab_size
                row, col = torch.nonzero(batch_documents, as_tuple=True)
                data = batch_documents[row, col]
                batch_ids = batch["ids"]
                self.add_batch_document(row.cpu().numpy(), col.cpu().numpy(), data.cpu().numpy(), batch_ids)
                if count % 10000 == 0:
                  print(f' {count} documents indexed!\n')
                count+=len(batch_ids)
        print("acabou")
        return self.inverted_index

    def add_batch_document(self, row, col, data, batch_ids):
        """add a batch of documents to the index
        """
        for doc_id, dim_id, value in zip(row, col, data):
            # For eatch dim_id == token_id, added doc_id and score as a tuple
            self.inverted_index[dim_id].append((batch_ids[doc_id], value))

In [21]:
!mkdir /content/indexes/
!mkdir /content/indexes/splade-index-mrobust

In [22]:
# loading model and tokenizer
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

ckpt = torch.load("/content/experiments/pt/checkpoint/model_ckpt/model_final_checkpoint.tar", map_location=device)
model = Splade("/content/experiments/pt/checkpoint/model", agg="max").to(device)
restore_model(model, ckpt["model_state_dict"])

model.eval()
tokenizer = AutoTokenizer.from_pretrained("neuralmind/bert-base-portuguese-cased")
reverse_voc = {v: k for k, v in tokenizer.vocab.items()}

restoring model: Splade


Downloading:   0%|          | 0.00/43.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/647 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/205k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [23]:
# get dataset
d_collection = BeirDataset(corpus)
q_collection = BeirDataset(topics)

# index BEIR collection
d_loader = BeirDataLoader(dataset=d_collection, tokenizer_type="neuralmind/bert-base-portuguese-cased", max_length=512, batch_size=32, shuffle=False, num_workers=4)

# Create Splade Indexes
splade_inverted_index = SpladeInvertedIndex(model)
splade_ii = splade_inverted_index(d_loader)

  0%|          | 2/16501 [00:03<7:08:35,  1.56s/it] 

 0 documents indexed!



  4%|▍         | 628/16501 [00:59<23:16, 11.37it/s]

 20000 documents indexed!



  8%|▊         | 1252/16501 [01:53<22:21, 11.36it/s]

 40000 documents indexed!



 11%|█▏        | 1877/16501 [02:51<28:24,  8.58it/s]

 60000 documents indexed!



 15%|█▌        | 2503/16501 [03:46<20:32, 11.36it/s]

 80000 documents indexed!



 19%|█▉        | 3127/16501 [04:41<19:46, 11.28it/s]

 100000 documents indexed!



 23%|██▎       | 3753/16501 [05:36<18:48, 11.29it/s]

 120000 documents indexed!



 27%|██▋       | 4377/16501 [06:31<18:03, 11.19it/s]

 140000 documents indexed!



 30%|███       | 5002/16501 [07:32<16:55, 11.32it/s]

 160000 documents indexed!



 34%|███▍      | 5628/16501 [08:27<16:05, 11.26it/s]

 180000 documents indexed!



 38%|███▊      | 6252/16501 [09:22<15:14, 11.21it/s]

 200000 documents indexed!



 42%|████▏     | 6878/16501 [10:17<13:54, 11.54it/s]

 220000 documents indexed!



 45%|████▌     | 7503/16501 [11:12<13:07, 11.42it/s]

 240000 documents indexed!



 49%|████▉     | 8127/16501 [12:19<12:25, 11.23it/s]

 260000 documents indexed!



 53%|█████▎    | 8752/16501 [13:22<11:22, 11.36it/s]

 280000 documents indexed!



 57%|█████▋    | 9378/16501 [14:17<10:25, 11.38it/s]

 300000 documents indexed!



 61%|██████    | 10002/16501 [15:12<09:26, 11.48it/s]

 320000 documents indexed!



 64%|██████▍   | 10628/16501 [16:07<08:37, 11.35it/s]

 340000 documents indexed!



 68%|██████▊   | 11253/16501 [17:15<07:43, 11.32it/s]

 360000 documents indexed!



 72%|███████▏  | 11877/16501 [18:10<06:45, 11.40it/s]

 380000 documents indexed!



 76%|███████▌  | 12503/16501 [19:06<05:58, 11.16it/s]

 400000 documents indexed!



 80%|███████▉  | 13127/16501 [20:01<05:03, 11.12it/s]

 420000 documents indexed!



 83%|████████▎ | 13753/16501 [20:58<04:09, 11.04it/s]

 440000 documents indexed!



 87%|████████▋ | 14377/16501 [22:13<03:09, 11.20it/s]

 460000 documents indexed!



 91%|█████████ | 15003/16501 [23:09<02:13, 11.26it/s]

 480000 documents indexed!



 95%|█████████▍| 15627/16501 [24:05<01:18, 11.20it/s]

 500000 documents indexed!



 98%|█████████▊| 16253/16501 [25:01<00:22, 11.23it/s]

 520000 documents indexed!



100%|██████████| 16501/16501 [25:23<00:00, 10.83it/s]

acabou





### Análise do índice invertido

In [40]:
tok_dist = [(tok, len(splade_ii.get(tok))) for tok in splade_ii]
tok_dist_df = pd.DataFrame(tok_dist, columns = ["token", "docs"])
tok_dist_df[["docs"]].describe()

Unnamed: 0,docs
count,20605.0
mean,14212.762922
std,43816.252325
min,1.0
25%,119.0
50%,869.0
75%,5479.0
max,482577.0


In [43]:
tok_dist_df[["docs"]].quantile([0.5, 0.75, 0.9, 0.95, 0.99])

Unnamed: 0,docs
0.5,869.0
0.75,5479.0
0.9,32675.8
0.95,80245.4
0.99,247833.96


In [None]:
#with open(f'/content/splade-inverted-index.pkl', 'wb') as f:
#    pickle.dump(splade_ii, f, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
#shutil.copyfile("/content/indexes/splade-index-mrobust/splade-inverted-index.pkl", f"{main_dir}Projeto Final/splade-inverted-index.pkl")

## Seacher

In [44]:
class Searcher():
    def __init__(self, inverted_index, doc_dictionary, model, tokenizer):
      super().__init__()
      self.inverted_index = inverted_index
      self.doc_dictionary = {y: x for x, y in doc_dictionary.items()}
      self.model = model
      self.tokenizer = tokenizer

    def __call__(self, text = None, k = 10):
      self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

      with torch.no_grad():
        tokens_input = self.tokenizer(text, return_tensors="pt")
        tokens_input = {k: v.to(self.device) for k, v in tokens_input.items()}
        query = self.model(q_kwargs=tokens_input)["q_rep"]  # we assume ONE query per batch here

        row, col = torch.nonzero(query, as_tuple=True) #!=0
        values = query[row.cpu().numpy().tolist(), col.cpu().numpy().tolist()]
        result_retrieval = self.retrieval(col, values, n_docs=len(self.doc_dictionary.keys()))
        # sorting...
        return self.select_topk(result_retrieval, k)

    def retrieval(self, tokens_query, query_scores,n_docs=0):
      score_matriz = np.zeros((len(tokens_query), n_docs), dtype=np.float16)
      #score_matriz = #len(tokens) x docs_ids (171332)
      # A ideia é apos a indexaçaõ dos documentos
      #fazer um produto escalar entre tokens scores x docs
      # Nessa redução, voltaremos a ter um array 1D (171332)

      for _idx, (q_id, q_score) in enumerate(zip(tokens_query, query_scores)):
        retrieved_indexes = self.inverted_index[int(q_id)]

        q_score = q_score.cpu().item() # convert to numpy
        if len(retrieved_indexes) > 1:
          #retrieved_indexes = [(doc_id_1, score_1), ... (doc_id_n, score_n)]
          docs_ids_retrived = np.array([i[0] for i in retrieved_indexes])
          scores_retrived = [i[1] for i in retrieved_indexes]

          #fill
          score_matriz[_idx,docs_ids_retrived] =  scores_retrived

      score_docs = np.dot(query_scores.cpu().numpy(), score_matriz)
      return score_docs

    def select_topk(self, docs_indexes, k_hits=100):
      # O argsort traz o index da ordenação
      filtered_scores = -np.sort(-docs_indexes)[0:k_hits]
      ordered_index = np.argsort(-docs_indexes)
      #remonta o dicionary
      return {self.doc_dictionary[id]:score \
              for id, score in zip(ordered_index, filtered_scores)}

In [45]:
#convert doc-string in int
def doc_id_to_int(index, docs_ids):
  # Kudos to Mirelle Bueno
  docs_ids = {id:i for i, id in enumerate(docs_ids)}
  count = 0

  index_format = defaultdict(list)

  for key, value in index.items():
    for doc_tuple in value:
      index_format[key].append((docs_ids[doc_tuple[0]], doc_tuple[1]))
  return index_format, docs_ids

In [48]:
def splade_retrieval(splade_searcher, queries, k = 1000):
    # settings
    hits = collections.OrderedDict()

    # splade retrieval
    for query_id, query_text in tqdm(queries.items(), desc="SPLADE Retrieval"):
        docs_rel = splade_searcher(query_text, k = k)
        hits[query_id] = collections.OrderedDict(docs_rel)

    return hits

In [49]:
# init
trec_eval = evaluate.load("trec_eval")

# Initialize device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# set the dir for trained weights
ckpt = torch.load("/content/experiments/pt/checkpoint/model_ckpt/model_final_checkpoint.tar", map_location=device)
model = Splade("/content/experiments/pt/checkpoint/model", agg="max").to(device)
restore_model(model, ckpt["model_state_dict"])
tokenizer = AutoTokenizer.from_pretrained("neuralmind/bert-base-portuguese-cased")
model.eval()

# get index_format and docs_id_registry
index_format, docs_id_registry = doc_id_to_int(splade_ii, list(corpus.keys()))

# splade settings
splade_searcher = Searcher(inverted_index = index_format, doc_dictionary = docs_id_registry, model = model, tokenizer = tokenizer)
splade_run = splade_retrieval(splade_searcher, topics, k = 1000)

restoring model: Splade


SPLADE Retrieval: 100%|██████████| 250/250 [1:07:33<00:00, 16.21s/it]


In [50]:
len(list(splade_run.keys()))

250

In [51]:
with open(f'/content/splade_run.pkl', 'wb') as f:
    pickle.dump(splade_run, f, protocol=pickle.HIGHEST_PROTOCOL)

In [52]:
shutil.copyfile("/content/splade_run.pkl", f"{main_dir}Projeto Final/splade_run.pkl")

'/content/gdrive/MyDrive/Unicamp/IA368-DD/Projeto Final/splade_run.pkl'

In [63]:
run_filename = f"/content/run.mrobust.splade.pt.txt"
with open(run_filename,'w') as f_out:
    for query_id in splade_run:
        rank = 1
        hits = splade_run.get(query_id)
        for doc_id, score in hits.items():
            f_out.write(f'{query_id} Q0 {doc_id} {rank} {score} SPLADE-PT\n')
            rank += 1

In [74]:
!/content/pyserini/tools/eval/trec_eval.9.0.4/trec_eval -m map -m ndcg_cut -m recall.1000 /content/qrels.tsv /content/run.mrobust.splade.pt.txt

map                   	all	0.1291
recall_1000           	all	0.4659
ndcg_cut_5            	all	0.3451
ndcg_cut_10           	all	0.3110
ndcg_cut_15           	all	0.2910
ndcg_cut_20           	all	0.2827
ndcg_cut_30           	all	0.2724
ndcg_cut_100          	all	0.2687
ndcg_cut_200          	all	0.2844
ndcg_cut_500          	all	0.3192
ndcg_cut_1000         	all	0.3477
