### Load Dataset & Train Index

In [1]:
import torch
import evaluate
import datasets
from pathlib import Path
from ettcl.searching import ColBERTSearcher, SearcherConfig
from ettcl.encoding import ColBERTEncoder
from ettcl.modeling import ColBERTModel, ColBERTTokenizer

# Experiment configuration shared by the cells below:
#   dataset      - name handed to `datasets.load_dataset`
#   checkpoint   - directory the model and tokenizer are restored from
#   index_path   - location of the search index, stored inside the checkpoint
#   label_column - dataset column that holds the class label
dataset = "imdb"
checkpoint = Path("../training/imdb/bert-base-uncased/checkpoint-12500")
index_path = checkpoint.joinpath("index")
label_column = "label"

In [2]:
# Restore the tokenizer and the ColBERT model from the same checkpoint
# directory, then bundle both into the encoder that is later handed to the
# searcher.
tokenizer = ColBERTTokenizer(checkpoint)
model = ColBERTModel.from_pretrained(checkpoint)
encoder = ColBERTEncoder(model, tokenizer)

In [3]:
def _load_split(split):
    """Load one split of `dataset` and expose only the label column as torch tensors."""
    split_data = datasets.load_dataset(dataset, split=split)
    # In-place formatting: restricts tensor conversion to the label column.
    split_data.set_format(type="pt", columns=[label_column])
    return split_data


train_dataset = _load_split("train")
test_dataset = _load_split("test")

# Display the number of examples in each split.
len(train_dataset), len(test_dataset)

Found cached dataset imdb (/home/IAIS/hiser/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)
Found cached dataset imdb (/home/IAIS/hiser/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


(25000, 25000)

#### Evaluation

In [4]:
# Number of index entries retrieved per query.
k = 1

# Accuracy metric from the `evaluate` library; does not depend on the searcher.
accuracy_metric = evaluate.load("accuracy")

# Searcher over the index at `index_path`, built with a configuration that
# overrides no settings and with the encoder created above.
searcher_config = SearcherConfig()
searcher = ColBERTSearcher(index_path, encoder, searcher_config)

In [5]:
# Query the index with every test text and keep the top-k hits per query.
queries = test_dataset["text"]
match_indices, match_scores = searcher.search(
    queries,
    k=k,
    return_tensors="pt",
    progress_bar=True,
)

# Sanity check: shapes of the indices/scores returned for the second query.
match_indices[1].shape, match_scores[1].shape

USING DEVICE 0
[Jun 19, 02:32:27] #> Loading codec...
[Jun 19, 02:32:27] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Jun 19, 02:32:27] Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Jun 19, 02:32:29] #> Loading IVF...
[Jun 19, 02:32:29] #> Loading doclens...


100%|██████████| 2/2 [00:00<00:00, 70.18it/s]

[Jun 19, 02:32:29] #> Loading codes and residuals...



100%|██████████| 2/2 [00:00<00:00,  2.60it/s]


Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Encoding: 98it [05:02,  3.09s/it]                        
Searching (device 0): 100%|██████████| 25000/25000 [14:36<00:00, 28.51it/s]


(torch.Size([1]), torch.Size([1]))

In [18]:
# Nearest-neighbour classification: each test example receives the most
# frequent label among the training examples retrieved for it. With k == 1 this
# is just the label of the single best match; with k > 1 it is a majority vote
# (on ties `torch.mode` returns one of the tied labels). Voting per query keeps
# exactly one prediction per test example for any k, whereas concatenating all
# hits would produce k predictions per example.
# NOTE(review): this assumes the ids returned by the searcher are row positions
# in `train_dataset` -- confirm the index was built from this split in this
# order. The saved output of this cell (accuracy 0.49956) is suspiciously close
# to chance level if the labels are balanced and binary.
train_labels = train_dataset[label_column]
y_pred = torch.stack([torch.mode(train_labels[indices]).values for indices in match_indices])
accuracy_metric.compute(predictions=y_pred, references=test_dataset[label_column])

{'accuracy': 0.49956}

In [1]:
# Look up the training example(s) retrieved for the first test query.
# Needs `train_dataset` (cell In [3]) and `match_indices` (cell In [5]) in the
# kernel; run on a fresh kernel without those cells it raises NameError.
first_query_matches = match_indices[0]
train_dataset[first_query_matches]

NameError: name 'train_dataset' is not defined