### Load Dataset & Train Index

In [4]:
import torch
import evaluate
import datasets
from ettcl.searching import ColBERTSearcher
from ettcl.encoding import ColBERTEncoder, EncoderFactory

# --- Experiment configuration ------------------------------------------------
# Dataset name handed to `datasets.load_dataset` and the column used as target.
dataset = "trec"
label_column = "coarse_label"

# Locations of the encoder checkpoint and of the index opened by the searcher.
checkpoint = "../models/colbertv2.0"
index_path = "../indexes/colbertv2.0.2bits"

# Factory passed to the searcher together with `index_path`.
encoder_factory = EncoderFactory(ColBERTEncoder, checkpoint=checkpoint)

In [5]:
# Load both splits first, then apply the same formatting to each of them.
train_dataset = datasets.load_dataset(dataset, split="train")
test_dataset = datasets.load_dataset(dataset, split="test")

# Only the label column is given the torch ("pt") format; the labels are later
# indexed and compared as tensors.
for split_data in (train_dataset, test_dataset):
    split_data.set_format("pt", columns=[label_column])

Found cached dataset trec (/home/hiser/.cache/huggingface/datasets/trec/default/2.0.0/f2469cab1b5fceec7249fda55360dfdbd92a7a5b545e91ea0f78ad108ffac1c2)
Found cached dataset trec (/home/hiser/.cache/huggingface/datasets/trec/default/2.0.0/f2469cab1b5fceec7249fda55360dfdbd92a7a5b545e91ea0f78ad108ffac1c2)


#### Evaluation

In [6]:
k = 5  # number of hits requested per query
searcher = ColBERTSearcher(index_path, encoder_factory)
accuracy_metric = evaluate.load("accuracy")

# Run every test question through the searcher in a single call.
queries = test_dataset["text"]
indices, scores = searcher.search(queries, k=k)

# Sanity check: shapes of the first query's results.
(indices[0].shape, scores[0].shape)

[Jun 07, 21:14:52] Loading segmented_maxsim_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Jun 07, 21:14:52] #> Loading codec...
[Jun 07, 21:14:52] #> Loading IVF...
[Jun 07, 21:14:52] Loading segmented_lookup_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...




[Jun 07, 21:14:52] #> Loading doclens...


100%|██████████| 2/2 [00:00<00:00, 10512.04it/s]

[Jun 07, 21:14:53] #> Loading codes and residuals...



100%|██████████| 2/2 [00:00<00:00, 1134.21it/s]

[Jun 07, 21:14:53] Loading filter_pids_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...





[Jun 07, 21:14:53] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...


100%|██████████| 4/4 [00:12<00:00,  3.12s/it]
100%|██████████| 500/500 [00:00<00:00, 1088.29it/s]


(torch.Size([5]), torch.Size([5]))

In [7]:
# Majority vote over the labels of the k retrieved neighbours of each query.
#
# `torch.stack` keeps exactly one row per query and raises if any query came
# back with a different number of hits. The previous `torch.cat(...).view(-1, k)`
# flattened all hits first, so a short result list could silently shift
# neighbours into another query's row.
# NOTE(review): assumes the ids in `indices` are row positions in
# `train_dataset` -- the index build is not shown here; confirm.
neighbor_ids = torch.stack(indices)  # one row of k ids per query
voting = train_dataset[label_column][neighbor_ids]

# torch.mode returns (values, indices); the values are the predicted labels.
y_pred = torch.mode(voting, dim=1, keepdim=False)[0]

accuracy_metric.compute(predictions=y_pred, references=test_dataset[label_column])

{'accuracy': 0.76}