In [7]:
import torch
from ettcl.encoding import ColBERTEncoder
from ettcl.utils.utils import split_into_sentences
from ettcl.modeling.tokenization_sentence_colbert import SentenceTokenizer
from ettcl.modeling.modeling_sentence_colbert import SentenceColBERTForReranking
from datasets import load_dataset

In [5]:
# Load the IMDB train and test splits (in that order) and switch each one to
# the "torch" output format right after it has been loaded.
imdb_splits = {}
for split_name in ("train", "test"):
    imdb_splits[split_name] = load_dataset("imdb", split=split_name)
    imdb_splits[split_name].set_format("torch")

train_dataset = imdb_splits["train"]
test_dataset = imdb_splits["test"]

# Cell output: summary of the train split.
train_dataset

Found cached dataset imdb (/home/IAIS/hiser/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)
Found cached dataset imdb (/home/IAIS/hiser/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


Dataset({
    features: ['text', 'label'],
    num_rows: 25000
})

In [6]:
def _with_sentences(text):
    """Build the ``sents`` column value by passing one row's ``text`` through ``split_into_sentences``."""
    return {"sents": split_into_sentences(text)}


# Add the "sents" column to both splits; the train split is mapped first,
# then the test split, each reading only its "text" column.
train_dataset, test_dataset = (
    split.map(_with_sentences, input_columns="text")
    for split in (train_dataset, test_dataset)
)

# Cell output: summary of the train split, now including "sents".
train_dataset

Loading cached processed dataset at /home/IAIS/hiser/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-ffa864a77e2b118e.arrow


Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Dataset({
    features: ['text', 'label', 'sents'],
    num_rows: 25000
})

In [4]:
# Model and tokenizer are both initialised from the same pretrained checkpoint name.
checkpoint = "bert-base-uncased"
tokenizer_options = {"doc_maxlen": 32, "query_maxlen": 32}

model = SentenceColBERTForReranking.from_pretrained(checkpoint)
tokenizer = SentenceTokenizer.from_pretrained(checkpoint, **tokenizer_options)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing SentenceColBERTForReranking: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing SentenceColBERTForReranking from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SentenceColBERTForReranking from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Check triple processing

In [8]:
from ettcl.core.triple_sampling import DataCollatorForTriples

In [9]:
def _tokenize_sentences(sents):
    """Run ``tokenizer`` with truncation enabled over one batch of ``sents`` values."""
    return tokenizer(sents, truncation=True)


# Batched map over the "sents" column of the train split only; the test split
# is not tokenized in this cell.
train_dataset = train_dataset.map(
    _tokenize_sentences,
    batched=True,
    input_columns="sents",
)

# Cell output: summary of the train split with the tokenizer's columns added.
train_dataset

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Dataset({
    features: ['text', 'label', 'sents', 'input_ids', 'attention_mask'],
    num_rows: 25000
})

In [10]:
# Build fixed-size groups of consecutive training rows to feed the triple collator.
NWAY = 4            # rows per group
NUM_EXAMPLES = 400  # rows consumed in total -> NUM_EXAMPLES // NWAY groups

# Restrict to the model-input columns once. The column selection does not
# depend on the row index, so there is no reason to rebuild the selected
# dataset for every single row lookup.
model_inputs = train_dataset.select_columns(["input_ids", "attention_mask"])

triples = [
    [model_inputs[row] for row in range(start, start + NWAY)]
    for start in range(0, NUM_EXAMPLES, NWAY)
]

# Cell output: number of groups built.
len(triples)

100

In [11]:
data_collator = DataCollatorForTriples(tokenizer)

# Seed the shuffle explicitly so that re-running this cell draws the same
# first batch instead of a different random one each time.
shuffle_generator = torch.Generator().manual_seed(0)
dl = torch.utils.data.DataLoader(
    triples,
    batch_size=8,
    shuffle=True,
    collate_fn=data_collator,
    generator=shuffle_generator,
)
batch = next(iter(dl))

# Report the size of each nesting level of the collated input ids, then the
# full shape in one line.
input_ids = batch["input_ids"]
print("batch_size:", len(input_ids))
print("nway:", len(input_ids[0]))
print("num_sentences:", len(input_ids[0][0]))
print("sentence_length:", len(input_ids[0][0][0]))
print(input_ids.shape)

# Cell output: the loss of one forward pass over the collated batch.
model(**batch).loss

batch_size: 8
nway: 4
num_sentences: 16
sentence_length: 32
torch.Size([8, 4, 16, 32])


tensor([1.1147], grad_fn=<UnsqueezeBackward0>)

### Check Indexing & Searching

In [14]:
from ettcl.indexing import ColBERTIndexer
from ettcl.searching import ColBERTSearcher
from ettcl.utils.multiprocessing import run_multiprocessed

# Location handed to `indexer.index(...)` and to `ColBERTSearcher(...)` in the cells below.
# NOTE(review): this is a relative path — presumably resolved against the kernel's
# working directory; confirm before running from a different directory.
index_path = "models/imdb/sentence_colbert/index"

In [12]:
# The encoder is built from the model's `colbert` attribute plus the tokenizer.
# The same `encoder` object is reused when the searcher is constructed later on.
colbert_module = model.colbert
encoder = ColBERTEncoder(colbert_module, tokenizer)
indexer = ColBERTIndexer(encoder)

In [None]:
# Hand the train split's "sents" column to the indexer together with `index_path`.
train_sentences = train_dataset["sents"]
indexer.index(index_path, train_sentences)

In [16]:
searcher = ColBERTSearcher(index_path, encoder)

# Wrap the bound search method with `run_multiprocessed` before giving it to
# `map`, and keep the map options together in one place.
search_fn = run_multiprocessed(searcher.search)
map_options = dict(
    input_columns="sents",
    fn_kwargs={"k": 50},
    batched=True,
    num_proc=2,
    with_rank=True,
)
test_dataset = test_dataset.map(search_fn, **map_options)

# Cell output: summary of the test split after the search results were added.
test_dataset

Setting TOKENIZERS_PARALLELISM=false for forked processes.


Map (num_proc=2):   0%|          | 0/25000 [00:00<?, ? examples/s]

Dataset({
    features: ['text', 'label', 'sents', 'match_pids', 'match_scores'],
    num_rows: 25000
})

### Evaluate Metrics

In [17]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [18]:
# k-nearest-neighbour classification metrics: for every test row, take the
# labels of its first k retrieved training rows and predict their mode.
prefix = ""
ks = [1, 3, 5, 10]

match_pids = test_dataset["match_pids"]
if isinstance(match_pids, list):
    # A list is handled as a sequence of tensors that need padding to a common
    # length; -1 marks the padded "no match" slots.
    print("WARNING")
    match_pids = torch.nn.utils.rnn.pad_sequence(match_pids, batch_first=True, padding_value=-1)
match_pids = match_pids.long()

# Look up the training label of every retrieved pid. A padded pid of -1 must
# NOT be used as an index: negative indexing would silently return the label
# of the last training row. Clamp for the lookup, then write -1 back into the
# padded slots so the "Not enough matches" check below can actually trigger.
train_labels = train_dataset["label"]
is_padding = match_pids < 0
match_labels = train_labels[match_pids.clamp(min=0)].masked_fill(is_padding, -1)

# Read the ground-truth column once instead of once per metric call.
y_true = test_dataset["label"]

# Metrics that are reported for both averaging modes, in output order.
averaged_scorers = (
    ("precision", precision_score),
    ("recall", recall_score),
    ("f1", f1_score),
)

metrics = {}
for k in ks:
    knn = match_labels[:, :k]
    # NOTE(review): for even k (e.g. 10) a vote can tie; the winner is then
    # whatever `torch.mode` returns for a tie — confirm this is acceptable.
    y_pred = torch.mode(knn)[0]
    assert -1 not in y_pred, "Not enough matches"

    metrics[f"{prefix}/accuracy/{k}"] = accuracy_score(y_pred=y_pred, y_true=y_true)
    for metric_name, scorer in averaged_scorers:
        for average in ("micro", "macro"):
            metrics[f"{prefix}/{metric_name}/{average}/{k}"] = scorer(
                y_pred=y_pred, y_true=y_true, average=average
            )

# Cell output: the full metric dictionary.
metrics

{'/accuracy/1': 0.70936,
 '/precision/micro/1': 0.70936,
 '/precision/macro/1': 0.7096982066359365,
 '/recall/micro/1': 0.70936,
 '/recall/macro/1': 0.70936,
 '/f1/micro/1': 0.70936,
 '/f1/macro/1': 0.7092427648219298,
 '/accuracy/3': 0.74564,
 '/precision/micro/3': 0.74564,
 '/precision/macro/3': 0.74645996797576,
 '/recall/micro/3': 0.74564,
 '/recall/macro/3': 0.74564,
 '/f1/micro/3': 0.7456399999999999,
 '/f1/macro/3': 0.7454282610762658,
 '/accuracy/5': 0.764,
 '/precision/micro/5': 0.764,
 '/precision/macro/5': 0.7656060180232452,
 '/recall/micro/5': 0.764,
 '/recall/macro/5': 0.764,
 '/f1/micro/5': 0.764,
 '/f1/macro/5': 0.7636427094617858,
 '/accuracy/10': 0.77416,
 '/precision/micro/10': 0.77416,
 '/precision/macro/10': 0.7855452910860196,
 '/recall/micro/10': 0.77416,
 '/recall/macro/10': 0.77416,
 '/f1/micro/10': 0.77416,
 '/f1/macro/10': 0.771886155356883}