In [3]:
from datasets import load_dataset
from typing import List
import random
from dataclasses import dataclass
import time

@dataclass
class Result:
    """A single retrieved document together with the score its retriever assigned."""

    document_id: str  # key of the document in the retriever's corpus
    document_text: str  # text stored under that key
    score: float  # retriever-assigned score; RandomRetriever.search orders results by it, highest first


class RandomRetriever:
    """Baseline retriever returning uniformly random documents with random scores.

    Useful as a chance-level lower bound when comparing real retrievers.
    """

    def __init__(self, corpus):
        # Mapping of document_id -> document text (search() iterates .items()).
        self.corpus = corpus

    def search(self, query: str, k: int) -> List[Result]:
        """Return up to ``k`` distinct random documents, ignoring ``query``.

        Args:
            query: the search query (unused by this baseline).
            k: number of results requested; clamped to the corpus size so a small
                corpus returns everything instead of raising ``ValueError``.

        Returns:
            Results ordered by descending score, like a ranked retriever's output.
        """
        items = list(self.corpus.items())
        sampled = random.sample(items, min(k, len(items)))
        results = [Result(doc_id, text, random.random()) for doc_id, text in sampled]
        # Rank by score so callers can rely on results[0] being the top hit.
        results.sort(key=lambda result: result.score, reverse=True)
        return results

def evaluate_retriever(retriever, queries, corpus, k=1, num_queries=1000):
    """Score a retriever by top-k hit rate and mean search latency.

    A query counts as a hit when any result returned by ``retriever.search(query, k)``
    has a ``document_id`` in that query's relevant set. With exactly one relevant
    document per query this equals recall@k.

    Args:
        retriever: object exposing ``search(query, k)`` returning items with a
            ``document_id`` attribute (e.g. ``Result``).
        queries: sequence of ``(query, relevant_docs)`` pairs; ``relevant_docs`` is a
            collection of document ids, or a single id string.
        corpus: unused; kept for call compatibility.
        k: number of results requested per query.
        num_queries: evaluate at most this many queries from the front of ``queries``.

    Returns:
        ``(hit_rate_at_k, avg_seconds_per_query)`` averaged over the queries actually
        evaluated; ``(0.0, 0.0)`` when no query was evaluated.
    """
    evaluated = 0
    hits = 0
    total_time = 0.0

    for query, relevant_docs in queries[:num_queries]:
        # A bare id string would make `in` a substring test; treat it as a one-element set.
        if isinstance(relevant_docs, str):
            relevant_docs = {relevant_docs}

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        results = retriever.search(query, k)
        total_time += time.perf_counter() - start_time
        evaluated += 1

        if any(result.document_id in relevant_docs for result in results):
            hits += 1

    if evaluated == 0:
        return 0.0, 0.0
    # Divide by the number actually evaluated: `queries` may hold fewer than num_queries.
    return hits / evaluated, total_time / evaluated

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
# Load the three MS MARCO v2 configurations from the MTEB mirror:
# the document corpus, the query texts, and the query-id -> corpus-id pairs ("default").
corpus, queries, default = (
    load_dataset("mteb/msmarco-v2", config_name)
    for config_name in ("corpus", "queries", "default")
)

Downloading data: 100%|██████████| 18.7M/18.7M [00:01<00:00, 10.9MB/s]
Downloading data: 100%|██████████| 265k/265k [00:00<00:00, 1.46MB/s]
Downloading data: 100%|██████████| 292k/292k [00:00<00:00, 1.04MB/s]
Generating train split: 100%|██████████| 284212/284212 [00:00<00:00, 6649958.32 examples/s]
Generating dev split: 100%|██████████| 4009/4009 [00:00<00:00, 1416474.16 examples/s]
Generating dev2 split: 100%|██████████| 4411/4411 [00:00<00:00, 1694547.99 examples/s]


In [13]:
# Inspect the relevance splits (train / dev / dev2): each row pairs a query-id with a corpus-id and a score.
default

DatasetDict({
    train: Dataset({
        features: ['query-id', 'corpus-id', 'score'],
        num_rows: 284212
    })
    dev: Dataset({
        features: ['query-id', 'corpus-id', 'score'],
        num_rows: 4009
    })
    dev2: Dataset({
        features: ['query-id', 'corpus-id', 'score'],
        num_rows: 4411
    })
})

In [10]:
# Inspect the document corpus (columns: _id, title, text).
corpus

DatasetDict({
    corpus: Dataset({
        features: ['_id', 'title', 'text'],
        num_rows: 138364198
    })
})

In [11]:
# Inspect the query table (columns: _id, text).
queries

DatasetDict({
    queries: Dataset({
        features: ['_id', 'text'],
        num_rows: 285328
    })
})

In [14]:
import json

# Sample a fixed set of distinct dev queries and save it, so every retriever
# is scored on exactly the same queries.
NUM_EVAL_QUERIES = 1000
# A private, seeded RNG keeps the sample reproducible without touching global `random` state.
EVAL_SEED = 0

# Get dev set from default dataset
dev_set = default['dev']

# Map query_id -> query_text
query_dict = {q['_id']: q['text'] for q in queries['queries']}

# Collect every relevant corpus-id per dev query (a query may appear in several rows);
# queries whose text is missing from the query table are skipped.
relevant_docs_by_query = {}
for row in dev_set:
    if row['query-id'] in query_dict:
        relevant_docs_by_query.setdefault(row['query-id'], []).append(row['corpus-id'])

# Sampling distinct ids directly (rather than rejection-sampling rows) cannot loop
# forever when fewer than NUM_EVAL_QUERIES usable queries exist.
num_to_sample = min(NUM_EVAL_QUERIES, len(relevant_docs_by_query))
if num_to_sample < NUM_EVAL_QUERIES:
    print(f"Warning: only {num_to_sample} usable dev queries available")

eval_rng = random.Random(EVAL_SEED)
sampled_query_ids = eval_rng.sample(list(relevant_docs_by_query), num_to_sample)
used_query_ids = set(sampled_query_ids)

eval_queries = [
    {
        'query_id': query_id,
        'query_text': query_dict[query_id],
        # One relevant doc picked at random, as before, for existing consumers of this key.
        'relevant_doc_id': eval_rng.choice(relevant_docs_by_query[query_id]),
        # All relevant docs, so a hit on any of them can be counted as correct.
        'relevant_doc_ids': relevant_docs_by_query[query_id],
    }
    for query_id in sampled_query_ids
]

# Save the evaluation queries to a JSON file
with open('eval_queries.json', 'w') as f:
    json.dump(eval_queries, f)

print(f"Generated and saved {len(eval_queries)} evaluation queries to eval_queries.json")

Generated and saved 1000 evaluation queries to eval_queries.json


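## Simple bi-encoder

Fine-tune `prajjwal1/bert-tiny` as a single shared query/document encoder. Each query is scored against its positive document and three random negatives, using 1% of the MS MARCO v2 training pairs.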

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizerFast, BertModel
import torch.nn.functional as F
from datasets import load_dataset
import random
import json

# Run on the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the training code draws from: torch, and Python's `random`
# (used for negative sampling and for subsampling the training data).
SEED = 0
torch.manual_seed(SEED)
random.seed(SEED)

# Constants
BATCH_SIZE = 10          # training examples per optimizer step
TRAIN_FRACTION = 0.01    # fraction of usable training pairs kept for training
ACCURACY_INTERVAL = 10   # print a running average every this many batches

class BiEncoder(nn.Module):
    """Shared BERT encoder mapping token ids to L2-normalised embeddings.

    The same weights encode queries and documents. The embedding is BERT's
    pooler output, L2-normalised so dot products are cosine similarities.
    """

    def __init__(self, model_name='prajjwal1/bert-tiny'):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Encode token ids into unit-length embeddings.

        Args:
            input_ids: integer tensor shaped ``(..., seq_len)``; any leading dims
                (e.g. ``(batch, num_negatives, seq_len)``) are supported.
            attention_mask: tensor with the same shape as ``input_ids``.
            token_type_ids: optional, same shape as ``input_ids``. Accepted so that
                HF tokenizer output can be splatted straight in with ``**``.

        Returns:
            Tensor shaped ``(..., hidden_size)`` with unit L2 norm along the last dim.
        """
        leading_shape = input_ids.shape[:-1]
        seq_len = input_ids.shape[-1]

        # BertModel only accepts 2-D (batch, seq_len) inputs, so fold extra leading dims into the batch.
        flat_token_types = None if token_type_ids is None else token_type_ids.reshape(-1, seq_len)
        outputs = self.bert(
            input_ids=input_ids.reshape(-1, seq_len),
            attention_mask=attention_mask.reshape(-1, seq_len),
            token_type_ids=flat_token_types,
        )
        embeddings = F.normalize(outputs.pooler_output, p=2, dim=-1)
        # Restore the caller's leading dims.
        return embeddings.reshape(*leading_shape, embeddings.shape[-1])

def _draw_negative_ids(doc_ids, positive_id, count):
    """Draw ``count`` random ids (with replacement) from ``doc_ids``, avoiding ``positive_id`` when possible."""
    negatives = []
    # Dict keys are unique, so with more than one id there is always a non-positive choice.
    can_avoid_positive = len(doc_ids) > 1
    while len(negatives) < count:
        candidate = random.choice(doc_ids)
        if can_avoid_positive and candidate == positive_id:
            continue
        negatives.append(candidate)
    return negatives


def load_and_prepare_data(train_fraction=None, num_negatives=3):
    """Build (query, positive, random negatives) training examples from MS MARCO v2.

    Args:
        train_fraction: fraction of usable training pairs to keep; ``None`` means
            use the module-level ``TRAIN_FRACTION`` at call time.
        num_negatives: number of random negative documents per example.

    Returns:
        List of dicts with keys ``'query'`` (text), ``'positive'`` (document text)
        and ``'negatives'`` (list of ``num_negatives`` document texts).
    """
    if train_fraction is None:
        train_fraction = TRAIN_FRACTION

    corpus = load_dataset("mteb/msmarco-v2", "corpus")
    queries = load_dataset("mteb/msmarco-v2", "queries")
    default = load_dataset("mteb/msmarco-v2", "default")

    # Create a mapping of query_id to query_text
    query_dict = {q['_id']: q['text'] for q in queries['queries']}

    # Create a mapping of document_id to document_text.
    # NOTE(review): the corpus cell shows ~138M rows; holding them all in a Python dict
    # needs a very large amount of RAM. Consider restricting this to the ids actually used.
    doc_dict = {d['_id']: d['text'] for d in corpus['corpus']}

    # Built once: rebuilding this list per negative draw costs O(|corpus|) every time.
    doc_ids = list(doc_dict)

    # Keep only pairs whose query text and document text are both available.
    usable_pairs = [
        (item['query-id'], item['corpus-id'])
        for item in default['train']
        if item['query-id'] in query_dict and item['corpus-id'] in doc_dict
    ]

    # Subsample first so negatives are only drawn for the examples we keep.
    sampled_pairs = random.sample(usable_pairs, int(len(usable_pairs) * train_fraction))

    train_data = []
    for query_id, positive_id in sampled_pairs:
        negative_ids = _draw_negative_ids(doc_ids, positive_id, num_negatives)
        train_data.append({
            'query': query_dict[query_id],
            'positive': doc_dict[positive_id],
            'negatives': [doc_dict[neg_id] for neg_id in negative_ids],
        })

    return train_data

def tokenize_batch(batch, tokenizer, target_device=None):
    """Tokenize a training batch's queries, positives and negatives and move them to a device.

    Args:
        batch: list of dicts with ``'query'``, ``'positive'`` and ``'negatives'``
            (list of texts). Every item must have the same number of negatives,
            otherwise the reshape below fails.
        tokenizer: HF-style tokenizer callable returning a mapping of tensors.
        target_device: device for the returned tensors; ``None`` means the
            module-level ``device``.

    Returns:
        Dict with ``'queries'`` and ``'positives'`` (tensors shaped ``(batch, seq_len)``)
        and ``'negatives'`` (tensors shaped ``(batch, num_negatives, seq_len)``). Each
        entry maps ``input_ids`` / ``attention_mask`` to tensors.
    """
    if target_device is None:
        target_device = device

    def encode(texts):
        # Only input_ids/attention_mask are fed to the encoder. BERT's token_type_ids are
        # all zeros for single-segment inputs, so they are not requested.
        return tokenizer(texts, padding=True, truncation=True, return_tensors='pt', return_token_type_ids=False)

    queries = encode([item['query'] for item in batch])
    positives = encode([item['positive'] for item in batch])
    negatives = encode([neg for item in batch for neg in item['negatives']])

    return {
        'queries': {k: v.to(target_device) for k, v in queries.items()},
        'positives': {k: v.to(target_device) for k, v in positives.items()},
        # Negatives were tokenized as one flat list; regroup them per query.
        'negatives': {k: v.reshape(len(batch), -1, v.shape[-1]).to(target_device) for k, v in negatives.items()},
    }

def cross_entropy_loss(similarities):
    """Softmax cross-entropy where the correct candidate is always column 0.

    Args:
        similarities: 2-D score tensor ``(batch, num_candidates)`` with each row's
            positive document in column 0.

    Returns:
        Scalar tensor: mean cross-entropy over the batch.
    """
    # Build the all-zeros target on the logits' own device rather than a global one.
    targets = torch.zeros(similarities.shape[0], dtype=torch.long, device=similarities.device)
    return F.cross_entropy(similarities, targets)

def recall_at_1(similarities):
    """Fraction of rows whose highest score is in column 0 (the positive).

    Args:
        similarities: 2-D score tensor ``(batch, num_candidates)``.

    Returns:
        Python float in ``[0, 1]``.
    """
    top_candidate = torch.argmax(similarities, dim=1)
    hit_mask = top_candidate.eq(0).to(torch.float32)
    return hit_mask.mean().item()

def _encode(model, encoded):
    """Embed tokenized inputs whose tensors may carry extra leading dims.

    Only ``input_ids`` and ``attention_mask`` are passed to the model. Leading dims are
    flattened into the batch and restored afterwards, so ``(B, N, L)`` inputs yield
    ``(B, N, hidden)`` embeddings.
    """
    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']
    leading_shape = input_ids.shape[:-1]
    seq_len = input_ids.shape[-1]
    embeddings = model(
        input_ids=input_ids.reshape(-1, seq_len),
        attention_mask=attention_mask.reshape(-1, seq_len),
    )
    return embeddings.reshape(*leading_shape, embeddings.shape[-1])


def train_and_evaluate(num_epochs=3, lr=2e-4, final_window=100):
    """Train a BiEncoder on (query, positive, random negatives) examples.

    Args:
        num_epochs: passes over the training data.
        lr: Adam learning rate.
        final_window: number of most recent batches averaged in the final report.

    Returns:
        ``(losses, recalls)``: per-batch training loss and Recall@1. Recall@1 is measured
        on each batch's own candidates, using the scores computed before that batch's
        optimizer step.
    """
    model = BiEncoder().to(device)
    # HF from_pretrained() leaves the wrapped BERT in eval mode (dropout off); enable training mode.
    model.train()
    tokenizer = BertTokenizerFast.from_pretrained('prajjwal1/bert-tiny')
    optimizer = optim.Adam(model.parameters(), lr=lr)

    train_data = load_and_prepare_data()

    losses = []
    recalls = []

    for epoch in range(num_epochs):
        # Shuffled copy per epoch so batches differ between epochs.
        epoch_data = random.sample(train_data, len(train_data))
        for batch_idx, start in enumerate(range(0, len(epoch_data), BATCH_SIZE)):
            batch = epoch_data[start:start + BATCH_SIZE]
            tokenized_batch = tokenize_batch(batch, tokenizer)

            optimizer.zero_grad()

            query_embeddings = _encode(model, tokenized_batch['queries'])        # (B, H)
            positive_embeddings = _encode(model, tokenized_batch['positives'])   # (B, H)
            negative_embeddings = _encode(model, tokenized_batch['negatives'])   # (B, N, H)

            # Candidate 0 is the positive, which is why the loss target is always index 0.
            doc_embeddings = torch.cat([positive_embeddings.unsqueeze(1), negative_embeddings], dim=1)
            similarities = torch.bmm(query_embeddings.unsqueeze(1), doc_embeddings.transpose(1, 2)).squeeze(1)

            loss = cross_entropy_loss(similarities)
            loss.backward()
            optimizer.step()

            losses.append(loss.item())
            recalls.append(recall_at_1(similarities.detach()))

            if batch_idx % ACCURACY_INTERVAL == 0:
                recent_losses = losses[-ACCURACY_INTERVAL:]
                recent_recalls = recalls[-ACCURACY_INTERVAL:]
                avg_loss = sum(recent_losses) / len(recent_losses)
                avg_recall = sum(recent_recalls) / len(recent_recalls)
                print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {avg_loss:.4f}, Recall@1: {avg_recall:.4f}")

    # Average over the batches that actually exist (may be fewer than final_window).
    final_losses = losses[-final_window:]
    final_recalls = recalls[-final_window:]
    if final_losses:
        final_loss = sum(final_losses) / len(final_losses)
        final_recall = sum(final_recalls) / len(final_recalls)
        print(f"Final Loss: {final_loss:.4f}, Final Recall@1: {final_recall:.4f}")
    else:
        print("No training batches were run (empty training data).")

    return losses, recalls

if __name__ == "__main__":
    # Train once and keep the per-batch curves in the notebook namespace for later inspection.
    (losses, recalls) = train_and_evaluate()
    # TODO: plot `losses` and `recalls` (e.g. via a `plot(losses, recalls)` helper) when curves are needed.