# Dense Retriever

Author: Monique Monteiro - moniquelouise@gmail.com

In [None]:
from google.colab import drive
drive.mount('/content/gdrive', force_remount=True)

Mounted at /content/gdrive


## Libraries Installation

In [None]:
!pip install transformers -q

In [None]:
!pip install jsonlines

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
!pip install trectools

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
!pip install evaluate

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting evaluate
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.4/81.4 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets>=2.0.0
  Downloading datasets-2.11.0-py3-none-any.whl (468 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m468.7/468.7 kB[0m [31m21.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess
  Downloading multiprocess-0.70.14-py39-none-any.whl (132 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m132.9/132.9 kB[0m [31m19.5 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash
  Downloading xxhash-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.2/212.2 kB[0m [31m29.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting responses<0.19
  Downloading respo

In [None]:
!pip install hnswlib

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting hnswlib
  Downloading hnswlib-0.7.0.tar.gz (33 kB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: hnswlib
  Building wheel for hnswlib (pyproject.toml) ... [?25l[?25hdone
  Created wheel for hnswlib: filename=hnswlib-0.7.0-cp39-cp39-linux_x86_64.whl size=2119213 sha256=0eb2375383f8d7ece64f153ad6c27a68cff662fab9ad41f5cb1f25a55078266b
  Stored in directory: /root/.cache/pip/wheels/ba/26/61/fface6c407f56418b3140cd7645917f20ba6b27d4e32b2bd20
Successfully built hnswlib
Installing collected packages: hnswlib
Successfully installed hnswlib-0.7.0


In [None]:
import random
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import BatchEncoding
from torch.utils import data
from torch import nn
from torch import optim
from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup
from statistics import mean
import jsonlines
import pickle
from transformers import AutoTokenizer, AutoModel
from collections import defaultdict

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Fix the random seeds so that training runs are reproducible.

In [None]:
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)

<torch._C.Generator at 0x7fb7d029d330>

## Pretrained model and tokenizer

In [None]:
model_name = 'microsoft/MiniLM-L12-H384-uncased' 
tokenizer = AutoTokenizer.from_pretrained(model_name)

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [None]:
max_length = 256

In [None]:
main_dir = "/content/gdrive/MyDrive/Unicamp-aula-7"
models_dir = f"{main_dir}/models_cls_pooling"

## Dataset definition

In [None]:
from torch.utils import data

class Dataset(data.Dataset):
    """Map-style dataset over an already tokenized collection of texts.

    `tokenized_texts` must support `tokenized_texts['input_ids']` (used for the
    length) and integer indexing that yields an object exposing `.ids` and
    `.attention_mask` (used to build each example).
    """

    def __init__(self, tokenized_texts):
        self.tokenized_texts = tokenized_texts

    def __len__(self):
        # One entry in 'input_ids' per text.
        input_ids = self.tokenized_texts['input_ids']
        return len(input_ids)

    def __getitem__(self, idx):
        # Only the token ids and the attention mask are handed to the collate
        # function; any other field of the encoding is dropped.
        encoding = self.tokenized_texts[idx]
        example = {}
        example['input_ids'] = encoding.ids
        example['attention_mask'] = encoding.attention_mask
        return example

In [None]:
from transformers import BatchEncoding

def collate_fn(batch):
  """Pad a list of examples with the global `tokenizer` and return them as
  a `BatchEncoding` of PyTorch tensors."""
  padded = tokenizer.pad(batch, return_tensors='pt')
  return BatchEncoding(padded)

## Dataset download and processing

In [None]:
!wget https://storage.googleapis.com/unicamp-dl/ia368dd_2023s1/msmarco/msmarco_triples.train.tiny.tsv

--2023-04-19 23:13:06--  https://storage.googleapis.com/unicamp-dl/ia368dd_2023s1/msmarco/msmarco_triples.train.tiny.tsv
Resolving storage.googleapis.com (storage.googleapis.com)... 64.233.170.128, 74.125.200.128, 74.125.68.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|64.233.170.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 8076179 (7.7M) [text/tab-separated-values]
Saving to: ‘msmarco_triples.train.tiny.tsv’


2023-04-19 23:13:08 (5.32 MB/s) - ‘msmarco_triples.train.tiny.tsv’ saved [8076179/8076179]



In [None]:
# Each row of the TSV is a (query, relevant passage, non-relevant passage) triple.
df = pd.read_csv(
    'msmarco_triples.train.tiny.tsv',
    sep='\t',
    header=None,
    names=["query", "relevant", "non_relevant"],
)

# 90% / 10% train / validation split with a fixed seed.
train_df, val_df = train_test_split(df, test_size=0.1, random_state=42)

# Only the query and its relevant passage are kept; the "non_relevant" column
# is not used below.
queries_train, docs_train = train_df['query'].tolist(), train_df['relevant'].tolist()
queries_val, docs_val = val_df["query"].tolist(), val_df["relevant"].tolist()

In [None]:
query_model = AutoModel.from_pretrained(model_name)
doc_model = AutoModel.from_pretrained(model_name)

Downloading pytorch_model.bin:   0%|          | 0.00/133M [00:00<?, ?B/s]

In [None]:
# Truncate only; do NOT pad here. With padding=True every text would be padded
# to the longest text of the whole split, so every batch would be processed at
# (near) maximum length. collate_fn already pads each batch to its own longest
# sequence via tokenizer.pad.
# NOTE: compute_loss_with_mean_pooling averages over every sequence position,
# padding included, so the amount of padding influences that variant.
train_queries_tokenized = tokenizer(queries_train, max_length=max_length, truncation=True)
train_passages_tokenized = tokenizer(docs_train, max_length=max_length, truncation=True)
val_queries_tokenized = tokenizer(queries_val, max_length=max_length, truncation=True)
val_passages_tokenized = tokenizer(docs_val, max_length=max_length, truncation=True)

In [None]:
dataset_queries_train = Dataset(train_queries_tokenized)
dataset_docs_train = Dataset(train_passages_tokenized)

dataset_queries_val = Dataset(val_queries_tokenized)
dataset_docs_val = Dataset(val_passages_tokenized)

In [None]:
len(dataset_queries_train), len(dataset_docs_train), len(dataset_queries_val), len(dataset_docs_val)

(9900, 9900, 1100, 1100)

In [None]:
batch_size=32

In [None]:
# shuffle=False is required: the training and evaluation loops zip the query
# loader with the document loader, so batch i of one must line up with batch i
# of the other for the positives to sit on the diagonal.
(dataloader_queries_train,
 dataloader_docs_train,
 dataloader_queries_val,
 dataloader_docs_val) = (
    data.DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    for dataset in (dataset_queries_train, dataset_docs_train,
                    dataset_queries_val, dataset_docs_val)
)

## Training loop

In [None]:
def compute_loss(query_outputs, doc_outputs):
  """In-batch-negatives contrastive loss using [CLS] pooling.

  The positive for query i is the document at the same batch index i; every
  other document of the batch acts as a negative.

  Args:
    query_outputs: model output exposing `last_hidden_state` of shape
      (batch_size, sequence_length, hidden_size).
    doc_outputs: same, for the documents (same batch size as the queries).

  Returns:
    Scalar tensor: mean over the queries of
    -log(exp(s_ii) / sum_j exp(s_ij)), with s the inner-product similarity.
  """
  # The vector at position 0 of each sequence is the sequence representation.
  query_cls = query_outputs.last_hidden_state[:, 0, :]
  doc_cls = doc_outputs.last_hidden_state[:, 0, :]

  # Full (batch_size, batch_size) similarity matrix.
  similarity = torch.matmul(query_cls, doc_cls.transpose(0, 1))

  # The loss above is exactly cross-entropy with target i for row i.
  # F.cross_entropy goes through log-softmax, so large similarities do not
  # overflow to inf/nan the way an explicit torch.exp does.
  targets = torch.arange(similarity.shape[0], device=similarity.device)
  return F.cross_entropy(similarity, targets)


In [None]:
def compute_loss_with_mean_pooling(query_outputs, doc_outputs):
  """In-batch-negatives contrastive loss using mean pooling.

  The positive for query i is the document at the same batch index i; every
  other document of the batch acts as a negative.

  Args:
    query_outputs: model output exposing `last_hidden_state` of shape
      (batch_size, sequence_length, hidden_size).
    doc_outputs: same, for the documents (same batch size as the queries).

  Returns:
    Scalar tensor: mean over the queries of
    -log(exp(s_ii) / sum_j exp(s_ij)), with s the inner-product similarity.
  """
  # Plain average over the sequence dimension. No attention mask is available
  # here, so every position of the sequence contributes to the average.
  query_mean_vector = query_outputs.last_hidden_state.mean(dim=1)
  doc_mean_vector = doc_outputs.last_hidden_state.mean(dim=1)

  # Full (batch_size, batch_size) similarity matrix.
  similarity = torch.matmul(query_mean_vector, doc_mean_vector.transpose(0, 1))

  # The loss above is exactly cross-entropy with target i for row i.
  # F.cross_entropy goes through log-softmax, so large similarities do not
  # overflow to inf/nan the way an explicit torch.exp does.
  targets = torch.arange(similarity.shape[0], device=similarity.device)
  return F.cross_entropy(similarity, targets)

In [None]:
def evaluate(query_model, doc_model, dataloader_queries, dataloader_docs, 
             set_name='Valid', loss_fn=compute_loss):
  """Compute the per-batch loss of the two encoders without gradient tracking.

  Both models are switched to eval mode and are left in that mode. The two
  loaders are iterated in lockstep, so batch i of the queries must correspond
  to batch i of the documents. Batches are moved to the global `device`.

  Args:
    query_model: encoder applied to the query batches.
    doc_model: encoder applied to the document batches.
    dataloader_queries: iterable of query batches.
    dataloader_docs: iterable of document batches, aligned with the queries.
    set_name: label shown on the progress bar.
    loss_fn: callable(query_outputs, doc_outputs) returning a scalar tensor.

  Returns:
    List with one float loss per batch.
  """
  query_model.eval()
  doc_model.eval()

  # zip() stops at the shorter loader; that count is only used to size the
  # progress bar. Iterables without a length simply get an open-ended bar.
  try:
    n_batches = min(len(dataloader_queries), len(dataloader_docs))
  except TypeError:
    n_batches = None

  losses = []
  with torch.no_grad():
    # Iterate lazily: wrapping the zip in list() would build every batch up
    # front and keep all of them alive for the whole pass.
    for query_batch, doc_batch in tqdm(zip(dataloader_queries, dataloader_docs),
                                       total=n_batches, mininterval=0.5,
                                       desc=set_name, disable=False):
      query_outputs = query_model(**query_batch.to(device))
      doc_outputs = doc_model(**doc_batch.to(device))
      loss = loss_fn(query_outputs, doc_outputs)
      losses.append(loss.cpu().item())

  return losses

Same hyperparameters for batch size, number of epochs, learning rate, learning rate scheduler and maximum input length used in https://colab.research.google.com/drive/1fJ9Xx4v8eiF0wrbMBw8tGs5JZhX86Fkz?usp=sharing (by Leandro Fernandes).

Other variations of those hyperparameters led to worse results.

In [None]:
def train(query_model, doc_model, epochs = 20, lr=2e-5, loss_fn=compute_loss):
  """Fine-tune the query and document encoders jointly.

  Reads the global training/validation loaders, `device` and `models_dir`.
  Each encoder has its own AdamW optimizer and its own cosine-with-hard-restarts
  schedule whose warmup covers the first 10% of the steps. After every epoch
  both models are saved under `models_dir` and the validation loss is printed.

  Args:
    query_model: encoder for the queries.
    doc_model: encoder for the documents.
    epochs: number of passes over the training loaders.
    lr: learning rate used for both optimizers.
    loss_fn: callable(query_outputs, doc_outputs) returning a scalar tensor.

  Returns:
    Tuple (query_model, doc_model) after training, both on `device`.
  """
  query_model = query_model.to(device)
  doc_model = doc_model.to(device)

  num_training_steps = epochs * len(dataloader_queries_train)
  num_warmup_steps = int(num_training_steps * 0.1)

  query_optimizer = optim.AdamW(query_model.parameters(), lr)
  query_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
      query_optimizer, num_warmup_steps, num_training_steps)
  doc_optimizer = optim.AdamW(doc_model.parameters(), lr)
  doc_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
      doc_optimizer, num_warmup_steps, num_training_steps)

  # zip() stops at the shorter loader; used only to size the progress bar.
  batches_per_epoch = min(len(dataloader_queries_train),
                          len(dataloader_docs_train))

  # Training loop
  for epoch in tqdm(range(epochs), desc='Epochs'):
    # evaluate() leaves the models in eval mode, so re-enable train mode.
    query_model.train()
    doc_model.train()
    train_losses = []
    # Iterate lazily: wrapping the zip in list() would build every batch of
    # the epoch up front and keep all of them alive until the epoch ends.
    for query_batch, doc_batch in tqdm(zip(dataloader_queries_train,
                                           dataloader_docs_train),
                                       total=batches_per_epoch,
                                       mininterval=0.5, desc='Train',
                                       disable=False):
      query_optimizer.zero_grad()
      doc_optimizer.zero_grad()

      query_outputs = query_model(**query_batch.to(device))
      doc_outputs = doc_model(**doc_batch.to(device))

      # A single loss drives both encoders.
      loss = loss_fn(query_outputs, doc_outputs)
      loss.backward()

      query_optimizer.step()
      query_scheduler.step()

      doc_optimizer.step()
      doc_scheduler.step()

      train_losses.append(loss.cpu().item())

    # One checkpoint per epoch, numbered from 1.
    query_model.save_pretrained(f"{models_dir}/query_model_{epoch+1}")
    doc_model.save_pretrained(f"{models_dir}/doc_model_{epoch+1}")

    print(f'Epoch: {epoch + 1} Training loss: {mean(train_losses)}')
    val_losses = evaluate(query_model, doc_model, dataloader_queries_val,
                          dataloader_docs_val, set_name='Valid',
                          loss_fn=loss_fn)
    print(f'Epoch: {epoch + 1} Valid loss: {mean(val_losses)}')

  return query_model, doc_model

#### Using [CLS] token pooling

In [None]:
query_model, doc_model = train(query_model, doc_model)

Epochs:   0%|          | 0/20 [00:00<?, ?it/s]

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 1 Training loss: 2.2128054949545093


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 1 Valid loss: 0.3924016376025975


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 2 Training loss: 0.3520540765456615


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 2 Valid loss: 0.09862756029968815


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 3 Training loss: 0.14031915396992717


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 3 Valid loss: 0.09686023268316474


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 4 Training loss: 0.08428099621988593


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 4 Valid loss: 0.0774912363167719


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 5 Training loss: 0.05307565355286633


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 5 Valid loss: 0.09290167569249336


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 6 Training loss: 0.039610639750777235


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 6 Valid loss: 0.07958409792211439


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 7 Training loss: 0.03453548175673331


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 7 Valid loss: 0.050535159098217264


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 8 Training loss: 0.02310355017224567


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 8 Valid loss: 0.06486951914079588


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 9 Training loss: 0.01911619869266771


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 9 Valid loss: 0.05258942633206191


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 10 Training loss: 0.013168442254643578


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 10 Valid loss: 0.04764369627747718


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 11 Training loss: 0.011637378661332824


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 11 Valid loss: 0.04244550009960741


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 12 Training loss: 0.009707040713152721


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 12 Valid loss: 0.05449711303018765


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 13 Training loss: 0.009789569806232447


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 13 Valid loss: 0.04895339830171516


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 14 Training loss: 0.005562897400253156


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 14 Valid loss: 0.04338003609851252


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 15 Training loss: 0.005323349779486776


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 15 Valid loss: 0.04134480717704199


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 16 Training loss: 0.0050752832900534195


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 16 Valid loss: 0.040785993432343404


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 17 Training loss: 0.003783591696687396


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 17 Valid loss: 0.03934763063630921


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 18 Training loss: 0.004035086338309714


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 18 Valid loss: 0.037259927356821466


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 19 Training loss: 0.005408151376333638


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 19 Valid loss: 0.037728191715411544


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 20 Training loss: 0.0025071259981653002


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 20 Valid loss: 0.03772947734635506


#### Using mean pooling

In [None]:
# Start the mean-pooling run from the pretrained weights; otherwise it would
# continue from the encoders already fine-tuned with [CLS] pooling above.
query_model = AutoModel.from_pretrained(model_name)
doc_model = AutoModel.from_pretrained(model_name)

# train() saves its checkpoints under the global models_dir. Point it to a
# separate folder for this run so the [CLS] checkpoints are not overwritten,
# then restore it so later cells keep loading the [CLS] checkpoints.
cls_models_dir = models_dir
models_dir = f"{main_dir}/models_mean_pooling"
try:
  query_model, doc_model = train(query_model, doc_model, loss_fn=compute_loss_with_mean_pooling)
finally:
  models_dir = cls_models_dir

Epochs:   0%|          | 0/20 [00:00<?, ?it/s]

Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 1 Training loss: 2.971965960725661


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 1 Valid loss: 0.42906245408313615


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 2 Training loss: 0.2924361018403884


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 2 Valid loss: 0.12487358653119632


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 3 Training loss: 0.10497448202372799


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 3 Valid loss: 0.09627644350818758


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 4 Training loss: 0.04523182156901326


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 4 Valid loss: 0.09394651864256177


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 5 Training loss: 0.0244456467641172


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 5 Valid loss: 0.0768546123564842


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 6 Training loss: 0.01989409980925566


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 6 Valid loss: 0.06747974996861948


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 7 Training loss: 0.017408944299865153


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 7 Valid loss: 0.08061064978412885


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 8 Training loss: 0.012827451013717725


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 8 Valid loss: 0.057656416805860186


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 9 Training loss: 0.012605927476374013


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 9 Valid loss: 0.07042944801043112


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 10 Training loss: 0.00582514199176653


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 10 Valid loss: 0.0476159046360408


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 11 Training loss: 0.00474605352772473


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 11 Valid loss: 0.038084980488877464


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 12 Training loss: 0.004501709471761358


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 12 Valid loss: 0.04067563498101663


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 13 Training loss: 0.0037174494841066565


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 13 Valid loss: 0.03795276731727881


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 14 Training loss: 0.002169210042792392


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 14 Valid loss: 0.03527747765508918


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 15 Training loss: 0.0027781045557129535


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 15 Valid loss: 0.046814531799230154


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 16 Training loss: 0.001842613017690324


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 16 Valid loss: 0.03550767882532974


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 17 Training loss: 0.001123973956600821


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 17 Valid loss: 0.03162502061781132


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 18 Training loss: 0.0012077248278875496


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 18 Valid loss: 0.031332056983425054


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 19 Training loss: 0.0011621363368307827


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 19 Valid loss: 0.03215569139700522


Train:   0%|          | 0/310 [00:00<?, ?it/s]

Epoch: 20 Training loss: 0.0007838695662529414


Valid:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch: 20 Valid loss: 0.03212840379316698


## Evaluation on TREC-COVID

### Passages download

In [None]:
!wget https://huggingface.co/datasets/BeIR/trec-covid/resolve/main/corpus.jsonl.gz

--2023-04-17 04:14:00--  https://huggingface.co/datasets/BeIR/trec-covid/resolve/main/corpus.jsonl.gz
Resolving huggingface.co (huggingface.co)... 13.224.249.63, 13.224.249.89, 13.224.249.86, ...
Connecting to huggingface.co (huggingface.co)|13.224.249.63|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/a8/10/a810e88b0e7b233be82b89c1fa6ec2d75efc6d55784c2ada9dcac8434a634f3a/e9e97686e3138eaff989f67c04cd32e8f8f4c0d4857187e3f180275b23e24e85?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27corpus.jsonl.gz%3B+filename%3D%22corpus.jsonl.gz%22%3B&response-content-type=application%2Fgzip&Expires=1681964041&Policy=eyJTdGF0ZW1lbnQiOlt7IlJlc291cmNlIjoiaHR0cHM6Ly9jZG4tbGZzLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2E4LzEwL2E4MTBlODhiMGU3YjIzM2JlODJiODljMWZhNmVjMmQ3NWVmYzZkNTU3ODRjMmFkYTlkY2FjODQzNGE2MzRmM2EvZTllOTc2ODZlMzEzOGVhZmY5ODlmNjdjMDRjZDMyZThmOGY0YzBkNDg1NzE4N2UzZjE4MDI3NWIyM2UyNGU4NT9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9u

In [None]:
!mv corpus.jsonl.gz {main_dir}/trec-covid

In [None]:
!gunzip {main_dir}/trec-covid/corpus.jsonl.gz

gzip: /content/gdrive/MyDrive/Unicamp-aula-7/trec-covid/corpus.jsonl already exists; do you wish to overwrite (y or n)? ^C


### Queries download

In [None]:
!wget https://huggingface.co/datasets/BeIR/trec-covid/resolve/main/queries.jsonl.gz

--2023-04-17 04:16:22--  https://huggingface.co/datasets/BeIR/trec-covid/resolve/main/queries.jsonl.gz
Resolving huggingface.co (huggingface.co)... 13.224.249.89, 13.224.249.63, 13.224.249.86, ...
Connecting to huggingface.co (huggingface.co)|13.224.249.89|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/a8/10/a810e88b0e7b233be82b89c1fa6ec2d75efc6d55784c2ada9dcac8434a634f3a/9eadcc2cdf140addc9dae83648bb2c6611f5e4b66eaed7475fa5a0ca48eda371?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27queries.jsonl.gz%3B+filename%3D%22queries.jsonl.gz%22%3B&response-content-type=application%2Fgzip&Expires=1681964182&Policy=eyJTdGF0ZW1lbnQiOlt7IlJlc291cmNlIjoiaHR0cHM6Ly9jZG4tbGZzLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2E4LzEwL2E4MTBlODhiMGU3YjIzM2JlODJiODljMWZhNmVjMmQ3NWVmYzZkNTU3ODRjMmFkYTlkY2FjODQzNGE2MzRmM2EvOWVhZGNjMmNkZjE0MGFkZGM5ZGFlODM2NDhiYjJjNjYxMWY1ZTRiNjZlYWVkNzQ3NWZhNWEwY2E0OGVkYTM3MT9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0a

In [None]:
!mv queries.jsonl.gz {main_dir}/trec-covid

In [None]:
!gunzip {main_dir}/trec-covid/queries.jsonl.gz

gzip: /content/gdrive/MyDrive/Unicamp-aula-7/trec-covid/queries.jsonl already exists; do you wish to overwrite (y or n)? ^C


### Relevance judgments (query–passage qrels) download

In [None]:
!wget https://huggingface.co/datasets/BeIR/trec-covid-qrels/raw/main/test.tsv

--2023-04-17 04:16:29--  https://huggingface.co/datasets/BeIR/trec-covid-qrels/raw/main/test.tsv
Resolving huggingface.co (huggingface.co)... 13.224.249.89, 13.224.249.86, 13.224.249.35, ...
Connecting to huggingface.co (huggingface.co)|13.224.249.89|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 980831 (958K) [text/plain]
Saving to: ‘test.tsv’


2023-04-17 04:16:30 (1.11 MB/s) - ‘test.tsv’ saved [980831/980831]



In [None]:
!mv test.tsv {main_dir}/trec-covid

### Queries and passages vectorization

In [None]:
query_model = AutoModel.from_pretrained(f"{models_dir}/query_model_18").to(device)
doc_model = AutoModel.from_pretrained(f"{models_dir}/doc_model_18").to(device)

In [None]:
batch_size = 256

In [None]:

# Parallel lists: query_ids[i] is the id of query_texts[i].
query_ids = []
query_texts = []

with jsonlines.open(f"{main_dir}/trec-covid/queries.jsonl") as reader:
  for item in reader:
    # Read the fields directly instead of binding a global named `id`,
    # which would shadow the builtin.
    query_ids.append(item["_id"])
    query_texts.append(item["text"])

# Truncate only; collate_fn pads each batch to its own longest sequence, so
# padding the whole set to its longest text here would be redundant.
test_queries_tokenized = tokenizer(query_texts, max_length=max_length, 
                                   truncation=True)
dataset_queries_test = Dataset(test_queries_tokenized)
# shuffle=False keeps the encoded rows in the same order as query_ids.
dataloader_queries_test = data.DataLoader(dataset_queries_test, 
                                          batch_size=batch_size, shuffle=False, 
                                          collate_fn=collate_fn)



In [None]:
# Parallel lists: passage_ids[i] is the id of passage_texts[i].
passage_ids = []
passage_texts = []

with jsonlines.open(f"{main_dir}/trec-covid/corpus.jsonl") as reader:
  for item in reader:
    # Read the fields directly instead of binding a global named `id`,
    # which would shadow the builtin.
    passage_ids.append(item["_id"])
    # A passage is represented by its title followed by its body text.
    passage_texts.append(item["title"] + ' ' + item["text"])

# Truncate only; collate_fn pads each batch to its own longest sequence, so
# padding the whole corpus to its longest passage here would be redundant.
test_passages_tokenized = tokenizer(passage_texts, max_length=max_length, 
                                    truncation=True)
dataset_passages_test = Dataset(test_passages_tokenized)
# shuffle=False keeps the encoded rows in the same order as passage_ids.
dataloader_passages_test = data.DataLoader(dataset_passages_test, 
                                           batch_size=batch_size, shuffle=False, 
                                           collate_fn=collate_fn)

In [None]:
len(dataset_passages_test)

171332

Queries vectorization

In [None]:
cls_pooling = True

In [None]:
query_model.eval()

# Collect the per-batch vectors and concatenate once at the end; calling
# torch.cat inside the loop would re-copy the whole matrix on every batch.
query_vectors = []

with torch.no_grad():
  for query_batch in tqdm(dataloader_queries_test, mininterval=0.5, desc='Test', 
                          disable=False):
    query_outputs = query_model(**query_batch.to(device))

    if cls_pooling:
      # Vector at the first position of each sequence.
      query_vector = query_outputs.last_hidden_state[:, 0, :]
    else:
      # Average over all sequence positions.
      query_vector = query_outputs.last_hidden_state.mean(dim=1)

    query_vectors.append(query_vector)

# One row per query, in loader order; None when the loader yielded nothing.
queries_matrix = torch.cat(query_vectors, dim=0) if query_vectors else None

Test:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
queries_matrix.shape

torch.Size([50, 384])

Passages vectorization

In [None]:
doc_model.eval()

# Collect the per-batch vectors and concatenate once at the end; calling
# torch.cat inside the loop would re-copy the whole matrix on every batch.
doc_vectors = []

with torch.no_grad():
  for doc_batch in tqdm(dataloader_passages_test, mininterval=0.5, desc='Test', 
                        disable=False):
    doc_outputs = doc_model(**doc_batch.to(device))

    if cls_pooling:
      # Vector at the first position of each sequence.
      doc_vector = doc_outputs.last_hidden_state[:, 0, :]
    else:
      # Average over all sequence positions.
      doc_vector = doc_outputs.last_hidden_state.mean(dim=1)

    doc_vectors.append(doc_vector)

# One row per passage, in loader order; None when the loader yielded nothing.
passages_matrix = torch.cat(doc_vectors, dim=0) if doc_vectors else None

Test:   0%|          | 0/670 [00:00<?, ?it/s]

In [None]:
# The file name follows the pooling mode, so the [CLS] matrix is written to
# passages_matrix.pickle and the mean-pooling matrix to
# passages_matrix_mean.pickle.
pooling_suffix = "" if cls_pooling else "_mean"
with open(f"{main_dir}/trec-covid/passages_matrix{pooling_suffix}.pickle", "wb") as f:
  pickle.dump(passages_matrix, f)

In [None]:
# Release the reference to any matrix already in memory before loading.
passages_matrix = None

# Author's note from an earlier run: "mean? NDCG@10 = 0.1968".
# The file name follows the pooling mode: passages_matrix.pickle for [CLS]
# pooling, passages_matrix_mean.pickle for mean pooling.
pooling_suffix = "" if cls_pooling else "_mean"
with open(f"{main_dir}/trec-covid/passages_matrix{pooling_suffix}.pickle", "rb") as f:
  passages_matrix = pickle.load(f)

In [None]:
passages_matrix.shape

torch.Size([171332, 384])

In [None]:
#Check memory availability
%time similarity = torch.matmul(queries_matrix, torch.transpose(passages_matrix, 0, 1))

CPU times: user 381 µs, sys: 886 µs, total: 1.27 ms
Wall time: 726 µs


In [None]:
similarity.shape

torch.Size([50, 171332])

In [None]:
type(similarity)

torch.Tensor

Get the top k documents for each query

In [None]:
k=1000

In [None]:
# Column-oriented run: every key holds one list, all of the same length.
run = defaultdict(list)

# Never ask for more hits than there are passages, so the bookkeeping
# columns below always match the number of hits actually returned.
top_k = min(k, similarity.shape[1])

# Best top_k passages of every query at once, highest score first.
top_scores, top_indices = torch.topk(similarity, top_k, dim=1)

# tolist() converts each tensor in one call instead of one .item() per hit.
for query_id, doc_scores, indices in zip(query_ids, top_scores.tolist(),
                                         top_indices.tolist()):
  run["query"] += [query_id] * top_k
  run["docid"] += [passage_ids[j] for j in indices]
  run["score"] += doc_scores
  run["q0"] += ["q0"] * top_k
  run["rank"] += list(range(1, top_k + 1))
  run["system"] += ["dense_ret"] * top_k


### Evaluation

In [None]:
import pandas as pd

# Relevance judgments as a column-oriented dict. The first line of the file is
# skipped and the three columns are named explicitly; a constant "q0" column
# is appended after them.
qrel = (
    pd.read_csv(f"{main_dir}/trec-covid/test.tsv", sep="\t", header=None,
                skiprows=1, names=["query", "docid", "rel"])
    .assign(q0="q0")
    .to_dict(orient="list")
)

In [None]:
from evaluate import load

def eval_ndcg10(run):
  """Score `run` against the global `qrel` with the "trec_eval" metric and
  return its NDCG@10 value."""
  metric = load("trec_eval")
  return metric.compute(predictions=[run], references=[qrel])['NDCG@10']

In [None]:
eval_ndcg10(run)

0.4186288143932398

### Approximate Nearest Neighbor

#### Option 1: hnswlib

In [None]:
import os
import csv
import pickle
import time
import hnswlib

In [None]:
embedding_size = query_model.config.hidden_size
embedding_size

384

In [None]:
%time
#Defining our hnswlib index
index_path = f"{main_dir}/hnswlib.index"
index = hnswlib.Index(space = 'ip', dim = embedding_size)

if os.path.exists(index_path):
    print("Loading index...")
    index.load_index(index_path)
else:
    ### Create the HNSWLIB index
    print("Start creating HNSWLIB index")
    index.init_index(max_elements = passages_matrix.shape[0])

    corpus_embeddings = passages_matrix.cpu().numpy()

    # Then we train the index to find a suitable clustering
    index.add_items(corpus_embeddings, list(range(len(corpus_embeddings))))

    print("Saving index to:", index_path)
    index.save_index(index_path)

# Controlling the recall by setting ef:
index.set_ef(k+1)  # ef should always be > top_k_hits

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 5.48 µs
Start creating HNSWLIB index
Saving index to: /content/gdrive/MyDrive/Unicamp-aula-7/hnswlib.index


In [None]:
# Column-oriented run built from the approximate neighbours found by hnswlib.
run_hnswlib_knn = defaultdict(list)

for row in range(queries_matrix.shape[0]):
  query_vector = queries_matrix[row].cpu().numpy()

  # Approximate top-k search for this single query vector.
  labels, distances = index.knn_query(query_vector, k=k)

  # The results of the only query passed in sit in row 0. The index was built
  # with space='ip', whose distance is 1 - inner product, so 1 - distance
  # gives the inner-product score back. Labels are row numbers of the
  # passage matrix and therefore positions in passage_ids.
  ranked = sorted(
      ((passage_ids[label], 1 - distance)
       for label, distance in zip(labels[0], distances[0])),
      key=lambda hit: hit[1],
      reverse=True,
  )

  run_hnswlib_knn["query"] += [query_ids[row]] * k
  run_hnswlib_knn["docid"] += [doc_id for doc_id, _ in ranked]
  run_hnswlib_knn["score"] += [score for _, score in ranked]
  run_hnswlib_knn["q0"] += ["q0"] * k
  run_hnswlib_knn["rank"] += list(range(1, k + 1))
  run_hnswlib_knn["system"] += ["dense_ret"] * k

In [None]:
eval_ndcg10(run_hnswlib_knn)

0.4186288143932398

#### Option 2: Manual implementation

Empirically, it has been found that a good rule of thumb for choosing the number of clusters is to set it proportional to the square root of the size of the dataset; for small datasets this typically means between 5 and 50 clusters.  

Source: ChatGPT

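So, let's test 5, 50 and 400 clusters (400 ≈ √171332, the square root of the corpus size). Note that this number of clusters is unrelated to the top-k cut-off `k` used for retrieval.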

In [None]:
from sklearn.cluster import KMeans
import numpy as np

def k_means(n_clusters):
  """Cluster the global `passages_matrix` into `n_clusters` groups.

  Returns:
    Tuple (labels, centroids): the cluster assigned to every passage and the
    centre of every cluster.
  """
  embeddings = passages_matrix.cpu().numpy()

  # Fixed random_state so repeated runs start from the same initialisation.
  model = KMeans(n_clusters=n_clusters, random_state=0)
  model.fit(embeddings)

  # Assign every passage to its cluster, then read the cluster centres.
  return model.predict(embeddings), model.cluster_centers_

In [None]:
index_k_means_5 = k_means(5)
index_k_means_5



(array([0, 2, 2, ..., 2, 3, 4], dtype=int32),
 array([[-0.1691598 , -0.00462118, -0.15312356, ..., -0.02624302,
          0.25497615,  0.34593326],
        [-0.20944507,  0.07483591, -0.18408433, ...,  0.07129455,
          0.32713333,  0.49148273],
        [-0.43535185,  0.08170702, -0.45992637, ...,  0.04341308,
          0.43524653,  0.37533662],
        [-0.0979168 ,  0.17215057, -0.18739302, ...,  0.07751448,
          0.17153639,  0.30724558],
        [-0.17082076,  0.17823392, -0.21806201, ...,  0.18289155,
          0.33964255,  0.34775198]], dtype=float32))

In [None]:
index_k_means_50 = k_means(50)
index_k_means_50



(array([23, 23, 33, ...,  9,  0, 10], dtype=int32),
 array([[-0.12976333,  0.12458751, -0.20559192, ...,  0.11754234,
          0.24010107,  0.39828628],
        [-0.402933  , -0.08727767, -0.4558951 , ...,  0.14338899,
          0.29990456,  0.3930119 ],
        [-0.07299861,  0.210586  , -0.1557284 , ...,  0.15109593,
          0.24581847,  0.49769914],
        ...,
        [-0.1641145 ,  0.1724554 , -0.23842067, ...,  0.16965581,
          0.18852165,  0.32002014],
        [-0.15046954, -0.01348849, -0.09646183, ..., -0.1025316 ,
          0.12200475,  0.3579614 ],
        [-0.09824128,  0.08079532, -0.30863985, ...,  0.03976616,
          0.4228579 ,  0.40720406]], dtype=float32))

In [None]:
index_k_means_400 = k_means(400)
index_k_means_400

KeyboardInterrupt: ignored

Exception ignored in: 'sklearn.cluster._k_means_common._relocate_empty_clusters_dense'
Traceback (most recent call last):
  File "<__array_function__ internals>", line 177, in where
KeyboardInterrupt: 


(array([158, 305, 174, ..., 135,  23, 290], dtype=int32),
 array([[-0.22248112, -0.00641324, -0.3365814 , ...,  0.18084395,
          0.39857948,  0.3823406 ],
        [-0.1904184 ,  0.14054233, -0.25812137, ...,  0.13551325,
          0.3880657 ,  0.528619  ],
        [-0.27297834,  0.1167797 ,  0.03306125, ..., -0.02051023,
          0.29838032,  0.3724155 ],
        ...,
        [ 0.19364676,  0.19551428, -0.13123266, ..., -0.04318923,
          0.07917188,  0.18082348],
        [ 0.15300828,  0.11337031,  0.03025441, ..., -0.1815958 ,
          0.10925046,  0.21533307],
        [-0.47812337, -0.22355929, -0.27350357, ...,  0.2160744 ,
          0.34264266,  0.42233694]], dtype=float32))

In [None]:
index_k_means_400[0].shape

(171332,)

In [None]:
index_k_means_400[1].shape

(50, 384)

In [None]:
def search(query_embedding, index, passages_matrix):
  """Approximate top-k search restricted to the query's best cluster.

  Args:
    query_embedding: 1-D query vector.
    index: pair (labels, centroids) as returned by k_means; labels[i] is the
      cluster of row i of `passages_matrix`.
    passages_matrix: one row per passage, on the same device as the query.

  Returns:
    Tuple (doc_ids, scores), best first, with min(k, number of passages)
    entries, where `k` is the global cut-off. Passages outside the selected
    cluster receive the lowest representable score, so they can only show up
    after every passage of the cluster (this keeps the result length stable
    for callers when the cluster holds fewer than k passages).
  """
  labels, centroids = index[0], index[1]

  # Inner product between the query and every centroid; keep the best cluster.
  centroid_matrix = torch.as_tensor(centroids, dtype=query_embedding.dtype,
                                    device=query_embedding.device)
  centroid_scores = torch.matmul(query_embedding,
                                 centroid_matrix.transpose(0, 1))
  cluster = torch.argmax(centroid_scores).item()

  # Rows of passages_matrix that belong to that cluster.
  doc_indices = torch.where(torch.as_tensor(labels == cluster))[0]
  doc_indices = doc_indices.to(passages_matrix.device)

  # Score only the passages of the cluster.
  cluster_scores = torch.matmul(query_embedding,
                                passages_matrix[doc_indices].transpose(0, 1))

  # Everything else gets the lowest finite score. Zeroing the other passages
  # instead would give them score 0, which outranks any passage of the
  # cluster whose inner product is negative.
  scores = torch.full((passages_matrix.shape[0],),
                      torch.finfo(cluster_scores.dtype).min,
                      dtype=cluster_scores.dtype,
                      device=cluster_scores.device)
  scores[doc_indices] = cluster_scores

  top_k = min(k, scores.shape[0])
  doc_scores, indices = torch.topk(scores, top_k)
  doc_ids = [passage_ids[j] for j in indices.tolist()]

  return doc_ids, doc_scores.tolist()

In [None]:
from tqdm import tqdm

def search_with_app_knn(index):
  """Run `search` for every row of the global `queries_matrix`.

  Args:
    index: value handed through to `search` as its `index` argument.

  Returns:
    Column-oriented run (dict of equally long lists) with the keys
    "query", "docid", "score", "q0", "rank" and "system".
  """
  run_app_knn = defaultdict(list)

  for i in tqdm(range(queries_matrix.shape[0])):
    question_embedding = queries_matrix[i].to(device)
    doc_ids, scores = search(question_embedding, index, passages_matrix)

    # Size the bookkeeping columns by the hits actually returned, so all
    # columns stay aligned even when a search yields fewer than k hits.
    n_hits = len(doc_ids)
    run_app_knn["query"] += [query_ids[i]] * n_hits
    run_app_knn["docid"] += doc_ids
    run_app_knn["score"] += scores
    run_app_knn["q0"] += ["q0"] * n_hits
    run_app_knn["rank"] += list(range(1, n_hits + 1))
    run_app_knn["system"] += ["dense_ret"] * n_hits

  return run_app_knn

In [None]:
run_app_knn = search_with_app_knn(index_k_means_5)
eval_ndcg10(run_app_knn)

0.4030904689670826

In [None]:
run_app_knn = search_with_app_knn(index_k_means_50)
eval_ndcg10(run_app_knn)

100%|██████████| 50/50 [00:08<00:00,  5.70it/s]


Downloading builder script:   0%|          | 0.00/5.51k [00:00<?, ?B/s]

0.3337930897698208

In [None]:
run_app_knn = search_with_app_knn(index_k_means_400)
eval_ndcg10(run_app_knn)