<a href="https://colab.research.google.com/github/juliatessler/1s2023-unicamp-dl-for-search-systems/blob/main/8-splade/8_splade.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SPLADE
by Júlia Ferreira Tessler

Several parts of this notebook are based on [this YouTube video](https://www.youtube.com/watch?v=0FQ2WmM0t3w&ab_channel=JamesBriggs). Unfortunately, I couldn't find out who originally shared it with me, so I can't credit them ):

In [None]:
!pip install transformers -q
!pip install datasets -q
!pip install pyserini -q
!pip install faiss-cpu -q
!pip install ftfy -q
!pip install evaluate

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting evaluate
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.4/81.4 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: evaluate
Successfully installed evaluate-0.4.0


In [None]:
from google.colab import drive

# Mount Google Drive so the cached sparse index and the run file written under `workdir` persist.
drive.mount('/content/gdrive', force_remount=True)

Mounted at /content/gdrive


In [None]:
# Folder on the mounted Drive where the sparse document index and the run file are stored.
workdir = '/content/gdrive/MyDrive/Unicamp/DL_applied_to_IR/Notebooks'

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Thu Apr 27 00:44:51 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0    45W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
import torch

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda')

In [None]:
import pandas as pd
import numpy as np
import torch.nn.functional as F
import os
import json

from transformers import (
    DistilBertTokenizer, 
    DistilBertForMaskedLM, 
    BertTokenizer, 
    BertForMaskedLM,
    BatchEncoding
)

from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from collections import defaultdict
from evaluate import load

## Get data

In [None]:
# BEIR TREC-COVID from the Hugging Face Hub: the test queries and the document corpus.
# Each is a DatasetDict whose single split is named 'queries' / 'corpus' respectively.
trec_covid_queries = load_dataset("BeIR/trec-covid", 'queries')
trec_covid_corpus = load_dataset("BeIR/trec-covid", 'corpus')



  0%|          | 0/1 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# Show the split layout and size of the queries DatasetDict.
trec_covid_queries

DatasetDict({
    queries: Dataset({
        features: ['_id', 'title', 'text'],
        num_rows: 50
    })
})

In [None]:
# First 10 queries; slicing a Dataset returns a dict of column lists.
trec_covid_queries['queries'][:10]

{'_id': ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10'],
 'title': ['', '', '', '', '', '', '', '', '', ''],
 'text': ['what is the origin of COVID-19',
  'how does the coronavirus respond to changes in the weather',
  'will SARS-CoV2 infected people develop immunity? Is cross protection possible?',
  'what causes death from Covid-19?',
  'what drugs have been active against SARS-CoV or SARS-CoV-2 in animal studies?',
  'what types of rapid testing for Covid-19 have been developed?',
  'are there serological tests that detect antibodies to coronavirus?',
  'how has lack of testing availability led to underreporting of true incidence of Covid-19?',
  'how has COVID-19 affected Canada',
  'has social distancing had an impact on slowing the spread of COVID-19?']}

In [None]:
# First 10 corpus documents (columns: _id, title, text).
trec_covid_corpus['corpus'][:10]

{'_id': ['ug7v899j',
  '02tnwd4m',
  'ejv2xln0',
  '2b73a28n',
  '9785vg6d',
  'zjufx4fo',
  '5yhe786e',
  '8zchiykl',
  '8qnrcgnk',
  'jg13scgo'],
 'title': ['Clinical features of culture-proven Mycoplasma pneumoniae infections at King Abdulaziz University Hospital, Jeddah, Saudi Arabia',
  'Nitric oxide: a pro-inflammatory mediator in lung disease?',
  'Surfactant protein-D and pulmonary host defense',
  'Role of endothelin-1 in lung disease',
  'Gene expression in epithelial cells in response to pneumovirus infection',
  'Sequence requirements for RNA strand transfer during nidovirus discontinuous subgenomic RNA synthesis',
  'Debate: Transfusing to normal haemoglobin levels will not improve outcome',
  'The 21st International Symposium on Intensive Care and Emergency Medicine, Brussels, Belgium, 20-23 March 2001',
  'Heme oxygenase-1 and carbon monoxide in pulmonary medicine',
  'Technical Description of RODS: A Real-time Public Health Surveillance System'],
 'text': ['OBJECTIVE: T

## Get model

In [None]:
# SPLADE v2 (distil) checkpoint, loaded as a DistilBERT masked-LM so its vocabulary
# logits can be turned into sparse term weights.
model_name = 'naver/splade_v2_distil'

tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = DistilBertForMaskedLM.from_pretrained(model_name)
model = model.to(device)

## Prepare dataset

In [None]:
class SPLADEDataset(Dataset):
  """Map-style dataset that tokenizes one text per item, on access.

  Batching and padding to a common length are left to the DataLoader's collate function.

  Args:
    text: indexable sequence of strings (e.g. a list of document bodies).
    tokenizer: callable tokenizer, such as a Hugging Face tokenizer.
    max_length: truncation length in tokens.
  """

  def __init__(self, text, tokenizer, max_length):
    self.text, self.tokenizer, self.max_length = text, tokenizer, max_length

  def __len__(self):
    return len(self.text)

  def __getitem__(self, idx):
    # The SPLADE authors do not remove [CLS] and [SEP], as Leandro Carísio pointed out:
    # https://github.com/naver/splade/blob/main/inference_splade.ipynb
    # The special-tokens mask is still returned so a caller can choose to drop them.
    return self.tokenizer(
        self.text[idx],
        padding=True,
        truncation=True,
        return_special_tokens_mask=True,
        max_length=self.max_length,
    )

In [None]:
# Truncation length in tokens (also used when encoding queries) and documents per forward pass.
max_length, batch_size = 256, 32

In [None]:
def collate_fn(batch):
  """Pad a list of tokenized items to a common length and return PyTorch tensors.

  Relies on the module-level `tokenizer`.
  """
  padded = tokenizer.pad(batch, return_tensors='pt')
  return BatchEncoding(padded)

In [None]:
# Corpus bodies, tokenized lazily and batched in order (no shuffle), so row i of the
# encoded index corresponds to corpus document i.
# NOTE(review): only the 'text' field is encoded; document titles are ignored.
trec_covid_dataset = SPLADEDataset(
    trec_covid_corpus['corpus']['text'], tokenizer, max_length
)
trec_covid_dataloader = DataLoader(
    trec_covid_dataset, batch_size=batch_size, collate_fn=collate_fn
)

## Sparse representation

In [None]:
def get_sparse_representation(model, tokenizer, text, agg='sum', mask_special_tokens=True):
    """Encode one text into a SPLADE sparse vector over the vocabulary.

    Each vocabulary term j gets the weight  w_j = agg_i log(1 + ReLU(logit_ij)),
    pooled over the input positions i.

    Args:
        model: masked-LM model; its `.logits` are assumed to be (batch, seq_len, vocab).
        tokenizer: tokenizer matching `model`.
        text: string to encode (only the first sequence is used if a batch is given).
        agg: 'sum' or 'max' pooling over positions.
        mask_special_tokens: if True (default, previous behaviour) [CLS]/[SEP] are
            excluded from the pooling. Pass False to keep them, which is how the corpus is
            encoded by `get_sparse_representation_batch` (it masks with `attention_mask`).

    Returns:
        1-D sparse COO tensor of length vocab_size on `device`.

    Raises:
        ValueError: if `agg` is neither 'sum' nor 'max' (previously any other value
            silently fell back to max).
    """
    if agg not in ('sum', 'max'):
        raise ValueError(f"agg must be 'sum' or 'max', got {agg!r}")

    tokenizer_output = tokenizer(
        text,
        return_special_tokens_mask=True,
        truncation=True,
        return_tensors='pt',
        max_length=max_length,  # module-level truncation length, shared with the corpus encoder
    )

    # Half-precision forward pass. `device.type` stays a valid autocast device type even
    # when `device` carries an index such as 'cuda:0' (str(device) would not).
    with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=True):
        with torch.no_grad():
            model_output = model(
                input_ids=tokenizer_output['input_ids'].to(device),
                attention_mask=tokenizer_output['attention_mask'].to(device),
            )
    logits = model_output.logits[0]  # (seq_len, vocab) for the first sequence

    # Positions that may contribute to the pooling. Index the batch dimension explicitly:
    # .squeeze() would also drop a length-1 sequence dimension.
    if mask_special_tokens:
        valid = 1 - tokenizer_output['special_tokens_mask'][0]
    else:
        valid = tokenizer_output['attention_mask'][0]
    mask = valid.to(device).unsqueeze(-1).expand(logits.size())

    # log-saturated ReLU activations; masked positions contribute log(1 + 0) = 0
    weights = torch.log(1 + torch.relu(logits * mask))
    if agg == 'sum':
        wj = torch.sum(weights, dim=0)
    else:
        wj, _ = torch.max(weights, dim=0)

    return wj.to_sparse()

def get_tokens_and_values(wj):
    """Pair each stored entry of a 1-D sparse SPLADE vector with its vocabulary token.

    Uses the module-level `tokenizer` to map term ids back to token strings.
    Returns a list of (token, weight) tuples in the order of `wj.indices()`;
    each weight is a 0-d tensor.
    """
    term_ids = wj.indices()[0]
    weights = wj.values()
    tokens = tokenizer.convert_ids_to_tokens(term_ids)
    return [(token, weight) for token, weight in zip(tokens, weights)]

### Testing

In [None]:
# Toy sentence used to sanity-check the encoder.
text = "My name is Júlia. I really like dinosaurs."

In [None]:
# Encode the toy sentence with the default agg='sum'. The non-zero entries include
# expansion terms that do not occur in the text (e.g. 'her', 'who' in the output below).
wj = get_sparse_representation(model, tokenizer, text)
wj

tensor(indices=tensor([[ 1011,  1045,  2003,  2014,  2016,  2017,  2026,  2032,
                         2033,  2040,  2066,  2079,  2171,  2228,  2256,  2293,
                         2360,  2428,  2450,  2643,  2653,  2767,  2941,  3566,
                         3814,  4489,  5203,  5223,  6127,  6423,  7059,  7777,
                        12120, 15799, 18148, 19958]]),
       values=tensor([4.1577e-01, 7.5000e-01, 4.5752e-01, 6.4883e+00,
                      1.1951e-01, 9.1797e-01, 1.9023e+00, 4.0192e-02,
                      4.9976e-01, 3.5801e+00, 1.5938e+00, 2.8076e-01,
                      3.9297e+00, 3.2568e-01, 4.0192e-02, 1.4736e+00,
                      4.0161e-01, 1.2832e+00, 3.2202e-01, 1.1523e+00,
                      9.7609e-04, 6.3354e-02, 3.4448e-01, 7.1094e-01,
                      5.9375e-01, 1.2561e-01, 1.1523e+00, 6.2549e-01,
                      9.1162e-01, 2.8477e+00, 1.7324e+00, 2.6294e-01,
                      7.4756e-01, 2.2598e+00, 2.3691e+00, 5.7227e

In [None]:
# Map the non-zero term ids back to tokens to inspect the learned weights.
tokens_and_values = get_tokens_and_values(wj)
tokens_and_values

[('-', tensor(0.4158, device='cuda:0', dtype=torch.float16)),
 ('i', tensor(0.7500, device='cuda:0', dtype=torch.float16)),
 ('is', tensor(0.4575, device='cuda:0', dtype=torch.float16)),
 ('her', tensor(6.4883, device='cuda:0', dtype=torch.float16)),
 ('she', tensor(0.1195, device='cuda:0', dtype=torch.float16)),
 ('you', tensor(0.9180, device='cuda:0', dtype=torch.float16)),
 ('my', tensor(1.9023, device='cuda:0', dtype=torch.float16)),
 ('him', tensor(0.0402, device='cuda:0', dtype=torch.float16)),
 ('me', tensor(0.4998, device='cuda:0', dtype=torch.float16)),
 ('who', tensor(3.5801, device='cuda:0', dtype=torch.float16)),
 ('like', tensor(1.5938, device='cuda:0', dtype=torch.float16)),
 ('do', tensor(0.2808, device='cuda:0', dtype=torch.float16)),
 ('name', tensor(3.9297, device='cuda:0', dtype=torch.float16)),
 ('think', tensor(0.3257, device='cuda:0', dtype=torch.float16)),
 ('our', tensor(0.0402, device='cuda:0', dtype=torch.float16)),
 ('love', tensor(1.4736, device='cuda:0', dt

### Batch sparse representation

In [None]:
def get_sparse_representation_batch(model, tokenizer, dataloader):
    """Encode every batch of `dataloader` into SPLADE sparse vectors with max pooling.

    Only padding positions are masked (via `attention_mask`); special tokens
    ([CLS]/[SEP]) are kept, as in the SPLADE reference inference notebook.

    Args:
        model: masked-LM model; its `.logits` are assumed to be (batch, seq_len, vocab).
        tokenizer: unused; kept for interface compatibility.
        dataloader: yields batches with 'input_ids' and 'attention_mask' tensors.

    Returns:
        Sparse COO tensor of shape (num_texts, vocab) on `device`, rows in dataloader
        order, or None when the dataloader yields nothing.
    """
    chunks = []
    for batch in tqdm(dataloader):
        with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=True):
            with torch.no_grad():
                model_output = model(
                    input_ids=batch['input_ids'].to(device),
                    attention_mask=batch['attention_mask'].to(device),
                )
        logits = model_output.logits

        # Zero out padding positions so they cannot win the max
        mask = batch['attention_mask'].unsqueeze(-1).expand(logits.size()).to(device)

        # w_j = max_i log(1 + ReLU(logit_ij))
        wj, _ = torch.max(torch.log(1 + F.relu(logits * mask)), dim=1)
        chunks.append(wj.to_sparse())

    # Concatenate once at the end: concatenating inside the loop re-copies the whole
    # growing sparse index on every batch, which is quadratic in the corpus size.
    return torch.cat(chunks, dim=0) if chunks else None

In [None]:
# Encoding the whole corpus is expensive, so the sparse index is cached on Drive.
wjs_cache_path = f'{workdir}/trec_covid_wjs_with_special_tokens.pt'
if os.path.exists(wjs_cache_path):
    # map_location lets a tensor saved in a GPU session load on whatever device is available now.
    # NOTE: torch.load unpickles the file; only point it at trusted files.
    trec_covid_wjs = torch.load(wjs_cache_path, map_location=device)
else:
    trec_covid_wjs = get_sparse_representation_batch(model, tokenizer, trec_covid_dataloader)
    torch.save(trec_covid_wjs, wjs_cache_path)

In [None]:
# (number of documents, vocabulary size): one sparse row per corpus document.
trec_covid_wjs.size()

torch.Size([171332, 30522])

## Search

In [None]:
# Total term weight of each document row; get_scores_sparse_vectors divides by it when normalize_doc=True.
row_sums = torch.sparse.sum(trec_covid_wjs, dim=1).to_dense()
row_sums

tensor([222.8063, 179.4055, 235.0530,  ..., 191.2339, 223.8661,  11.9450],
       device='cuda:0')

In [None]:
# Corpus ids in the same order as the rows of trec_covid_wjs (the DataLoader is not shuffled).
ids_docs = trec_covid_corpus['corpus']['_id']

In [None]:
def get_scores_sparse_vectors(query_vec,
                              wj_matrix_doc,
                              normalize_doc = False,
                              normalize_query = True,
                              hits = 1000):
    """Score every document against a query by sparse dot product and return the top hits.

    Relies on the module-level `row_sums` (per-document weight sums) and `ids_docs`
    (document ids aligned with the rows of `wj_matrix_doc`).

    Args:
        query_vec: 1-D sparse query vector of length vocab.
        wj_matrix_doc: sparse (num_docs, vocab) document matrix.
        normalize_doc: divide each score by the document's total weight.
        normalize_query: divide all scores by the query's total weight (ranking unchanged).
        hits: maximum number of results to return.

    Returns:
        List of (doc_id, score) pairs by descending score, at most `hits` long; each
        score is a 0-d tensor. Tied documents keep their corpus order.
    """
    scores = torch.sparse.mm(wj_matrix_doc.to(torch.half),
                             query_vec.unsqueeze(-1).to(torch.half)).to_dense()
    scores = scores.squeeze(1)

    # Normalizations
    if normalize_doc:
        scores = torch.div(scores, row_sums)
    if normalize_query:
        scores = scores / query_vec.sum().item()

    # Rank on the tensor's device instead of sorting Python tuples with one .item() call
    # (a device sync) per document. A stable descending sort keeps tied documents in corpus
    # order, exactly like Python's stable sorted(..., reverse=True) did.
    order = torch.sort(scores, descending=True, stable=True).indices[:hits]
    return [(ids_docs[i], scores[i]) for i in order.tolist()]

In [None]:
# One row per query: its BEIR id and text.
trec_queries_df = pd.DataFrame({
    '_id': trec_covid_queries['queries']['_id'],
    'text': trec_covid_queries['queries']['text'],
})

In [None]:
# Build a TREC-format run: one row per (query, retrieved passage).
HITS = 1000
RUN_NAME = 'vec_query_vs_wj_passages'

run = {
    'query_id':   [],
    'passage_id': [],
    'score':      [],
    'rank':       [],
    'Q0':         [],
    'run':        []
}

# Iterate over the queries DataFrame built from the HF queries split.
for _, row in trec_queries_df.iterrows():

    # Queries are encoded with max pooling, like the documents.
    vec_query = get_sparse_representation(model, tokenizer, row['text'], agg='max')

    docs_score = get_scores_sparse_vectors(vec_query, trec_covid_wjs, hits=HITS)

    passage_ids = [doc_id for doc_id, _ in docs_score]
    scores      = [score.item() for _, score in docs_score]
    # Fewer than HITS results are possible (small corpus); derive every column's length
    # from the actual hit count so the columns stay aligned.
    n_hits = len(docs_score)

    run['query_id'].extend([row['_id']] * n_hits)
    run['passage_id'].extend(passage_ids)
    run['score'].extend(scores)
    run['rank'].extend(range(1, n_hits + 1))
    run['Q0'].extend(['Q0'] * n_hits)
    run['run'].extend([RUN_NAME] * n_hits)

df_run = pd.DataFrame(run)
# TREC run column order: query id, Q0, doc id, rank, score, run tag
df_run = df_run[['query_id', 'Q0', 'passage_id', 'rank', 'score', 'run']]

print(df_run.shape)
df_run.head()

(50000, 6)


Unnamed: 0,query_id,Q0,passage_id,rank,score,run
0,1,Q0,1mjaycee,1,1.297852,vec_query_vs_wj_passages
1,1,Q0,zrycqlvs,2,1.290039,vec_query_vs_wj_passages
2,1,Q0,0wm6u10a,3,1.287109,vec_query_vs_wj_passages
3,1,Q0,cgvj10r2,4,1.287109,vec_query_vs_wj_passages
4,1,Q0,uv4gbhbb,5,1.287109,vec_query_vs_wj_passages


In [None]:
# Tab-separated TREC run file, without header or index, for trec_eval.
df_run.to_csv(f'{workdir}/run.sparse_rank.txt', sep='\t', header=False, index=False)

In [None]:
# Qrels with all relevances
!wget https://huggingface.co/datasets/BeIR/trec-covid-qrels/raw/main/test.tsv

--2023-04-27 00:55:53--  https://huggingface.co/datasets/BeIR/trec-covid-qrels/raw/main/test.tsv
Resolving huggingface.co (huggingface.co)... 13.227.219.41, 13.227.219.63, 13.227.219.125, ...
Connecting to huggingface.co (huggingface.co)|13.227.219.41|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 980831 (958K) [text/plain]
Saving to: ‘test.tsv’


2023-04-27 00:55:53 (5.50 MB/s) - ‘test.tsv’ saved [980831/980831]



In [None]:
# Convert the BEIR qrels (query id, corpus id, relevance; first line is a header) into the
# 4-column TREC qrels layout used by trec_eval: query id, 0, doc id, relevance.
with open('test.tsv', 'r') as fin:
  qrels_lines = fin.read().splitlines(True)

with open('qrels_format.tsv', 'w') as fout:
  for line in qrels_lines[1:]:  # skip the header line
    fields = line.split()
    query_id, doc_id, relevance = fields[0], fields[1], fields[2]
    fout.write(f'{query_id}\t0\t{doc_id}\t{relevance}\n')


In [None]:
!python -m pyserini.eval.trec_eval -c -m ndcg_cut.10 qrels_format.tsv {workdir}/run.sparse_rank.txt #type: ign

Downloading https://search.maven.org/remotecontent?filepath=uk/ac/gla/dcs/terrierteam/jtreceval/0.0.5/jtreceval-0.0.5-jar-with-dependencies.jar to /root/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar...
/root/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar already exists!
Skipping download.
Running command: ['java', '-jar', '/root/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar', '-c', '-m', 'ndcg_cut.10', 'qrels_format.tsv', '/content/gdrive/MyDrive/Unicamp/DL_applied_to_IR/Notebooks/run.sparse_rank.txt']
Results:
ndcg_cut_10           	all	0.6067
