In [None]:
! pip install transformers datasets

In [None]:
from datasets import load_dataset

# Load the "distractor" configuration of the HotpotQA dataset.
# NOTE(review): trust_remote_code=True presumably allows the dataset's own
# loading script to run locally - confirm this is acceptable in your environment.
hotpot_dataset = load_dataset("hotpot_qa", "distractor", trust_remote_code=True)

# Display HotpotQA dataset structure
print(hotpot_dataset)

In [None]:
# Load the WikiNQ dataset (dataset identifier "Tevatron/wikipedia-nq")
wiki_dataset = load_dataset("Tevatron/wikipedia-nq")

# Display WikiNQ dataset structure
print(wiki_dataset)

In [None]:
# HotpotQA: the 'validation' split is used as the test set in this notebook.
hotpot_train_set = hotpot_dataset['train']
hotpot_test_set = hotpot_dataset['validation']

# WikiNQ: use the 'train' and 'test' splits directly.
wiki_train_set = wiki_dataset['train']
wiki_test_set = wiki_dataset['test']

In [None]:
# Shuffle every split with a fixed seed, then keep the first 2000 training and
# 1000 test examples so the later cells work on a small subset.
# NOTE(review): select(range(n)) presumably needs the split to hold at least n rows - confirm.
small_hotpot_train_dataset = hotpot_train_set.shuffle(seed=42).select(range(2000))
small_hotpot_test_dataset = hotpot_test_set.shuffle(seed=42).select(range(1000))

# The WikiNQ subsets use seed 43 instead of 42.
small_wiki_train_dataset = wiki_train_set.shuffle(seed=43).select(range(2000))
small_wiki_test_dataset = wiki_test_set.shuffle(seed=43).select(range(1000))

In [None]:
import random

def get_random_negatives(golden_doc_titles, corpus, num_negatives):
    """Randomly sample negative documents whose titles are not golden.

    Args:
        golden_doc_titles: Iterable of document titles that must never be
            returned as negatives.
        corpus: Mapping from document title to document text. A text may be a
            string or a sequence of sentence strings; sequences are
            concatenated without a separator.
        num_negatives: Maximum number of negatives to return. Values below
            zero are treated as zero.

    Returns:
        A list with at most ``num_negatives`` document texts, drawn without
        replacement using the global ``random`` module. The list is shorter
        when the corpus holds fewer non-golden documents.
    """
    # A set keeps the exclusion check O(1) per corpus title
    excluded_titles = set(golden_doc_titles)
    candidate_titles = [title for title in corpus if title not in excluded_titles]

    # Clamp to the valid range so random.sample never sees a negative or oversized k
    sample_size = max(0, min(len(candidate_titles), num_negatives))
    sampled_titles = random.sample(candidate_titles, sample_size)

    return [''.join(corpus[title]) for title in sampled_titles]

In [None]:
def _build_title_corpus(entries):
    """Map every context title in ``entries`` to its concatenated sentences (first occurrence wins)."""
    title_to_text = {}
    for entry in entries:
        context = entry['context']
        for title, sentences in zip(context['title'], context['sentences']):
            if title not in title_to_text:
                title_to_text[title] = ''.join(sentences)
    return title_to_text


def create_corpus_hotpot(data, sample_method, batch_size):
    """Build the document corpus (or corpora) that negatives are sampled from.

    Args:
        data: Iterable of entries, each with a 'context' mapping that holds
            parallel 'title' and 'sentences' lists.
        sample_method: 'global' builds a single corpus over all entries;
            'inbatch' builds one corpus per batch of consecutive entries.
        batch_size: Number of consecutive entries per batch. Only used (and
            validated) for 'inbatch'.

    Returns:
        For 'global': a dict mapping title -> document text.
        For 'inbatch': a list of such dicts, one per batch, in batch order.

    Raises:
        ValueError: If ``sample_method`` is unknown, or ``batch_size`` is not
            positive when ``sample_method`` is 'inbatch'.
    """
    if sample_method == 'global':
        return _build_title_corpus(data)

    if sample_method == 'inbatch':
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive for 'inbatch' sampling, got {batch_size!r}")
        # Materialize once so the data can be sliced into consecutive batches
        entries = list(data)
        return [
            _build_title_corpus(entries[start:start + batch_size])
            for start in range(0, len(entries), batch_size)
        ]

    # An explicit exception (unlike a bare assert) carries a message and survives `python -O`
    raise ValueError(f"Unknown sample_method {sample_method!r}; expected 'global' or 'inbatch'")
    

In [None]:
# Preprocessing function Hotpot
def preprocess_data_train_hotpot(data, num_negatives, sample_method='global', batch_size=32):
    """Build HotpotQA training datapoints: one per (question, distinct golden document).

    Args:
        data: Sequence of HotpotQA entries. Each entry provides 'question',
            'supporting_facts' (with a 'title' list) and 'context' (with
            parallel 'title' and 'sentences' lists).
        num_negatives: Number of negatives requested for every datapoint.
        sample_method: 'global' samples negatives from one corpus built over
            all of ``data``; otherwise the corpus of the batch containing the
            entry is used (see ``create_corpus_hotpot``).
        batch_size: Number of consecutive entries per batch for per-batch
            corpora.

    Returns:
        A list of dicts with keys 'query', 'doc' (text of one golden
        document) and 'negatives' (list of negative document texts).
    """
    corpuses = create_corpus_hotpot(data, sample_method, batch_size)
    use_global_corpus = sample_method == 'global'

    processed_data = []

    for idx, entry in enumerate(data):
        query = entry['question']
        context = entry['context']
        # 'supporting_facts' can name the same title several times (HotpotQA
        # lists one entry per supporting sentence). Keep each golden document
        # once, in first-seen order, so it does not yield duplicate datapoints.
        golden_doc_titles = list(dict.fromkeys(entry['supporting_facts']['title']))
        corpus = corpuses if use_global_corpus else corpuses[idx // batch_size]

        for golden_doc_title in golden_doc_titles:
            doc_position = context['title'].index(golden_doc_title)
            golden_doc = ''.join(context['sentences'][doc_position])
            # Negatives are re-drawn for every positive and exclude all golden documents
            negatives = get_random_negatives(golden_doc_titles, corpus, num_negatives)

            processed_data.append({
                'query': query,
                'doc': golden_doc,
                'negatives': negatives,
            })

    return processed_data

In [None]:
# Preprocessing function Hotpot
def preprocess_data_test_hotpot(data, num_negatives, sample_method='global', batch_size=32):
    """Build HotpotQA evaluation datapoints: one per question, holding all its golden documents.

    Args:
        data: Sequence of HotpotQA entries. Each entry provides 'question',
            'supporting_facts' (with a 'title' list) and 'context' (with
            parallel 'title' and 'sentences' lists).
        num_negatives: Number of negatives requested per distinct golden
            document; the datapoint asks for
            ``num_negatives * len(positives)`` negatives in total.
        sample_method: 'global' samples negatives from one corpus built over
            all of ``data``; otherwise the corpus of the batch containing the
            entry is used (see ``create_corpus_hotpot``).
        batch_size: Number of consecutive entries per batch for per-batch
            corpora.

    Returns:
        A list of dicts with keys 'query', 'positives' (texts of the distinct
        golden documents) and 'negatives' (list of negative document texts).
    """
    corpuses = create_corpus_hotpot(data, sample_method, batch_size)
    use_global_corpus = sample_method == 'global'

    processed_data = []

    for idx, entry in enumerate(data):
        query = entry['question']
        context = entry['context']
        # 'supporting_facts' can name the same title several times (HotpotQA
        # lists one entry per supporting sentence). Deduplicate in first-seen
        # order so a document is not scored and counted twice as a positive.
        golden_doc_titles = list(dict.fromkeys(entry['supporting_facts']['title']))
        corpus = corpuses if use_global_corpus else corpuses[idx // batch_size]

        positives = [
            ''.join(context['sentences'][context['title'].index(golden_doc_title)])
            for golden_doc_title in golden_doc_titles
        ]
        negatives = get_random_negatives(golden_doc_titles, corpus, num_negatives * len(positives))

        processed_data.append({
            'query': query,
            'positives': positives,
            'negatives': negatives,
        })

    return processed_data

In [None]:
# Preprocessing function WikiNQ
def preprocess_data_train_wiki(data, num_negatives):
    """Build WikiNQ training datapoints: one per (query, positive passage).

    Args:
        data: Iterable of entries with 'query', 'positive_passages' and
            'negative_passages'; each passage is a mapping with a 'text' key.
        num_negatives: Maximum number of negatives drawn for every datapoint.

    Returns:
        A list of dicts with keys 'query', 'doc' (positive passage text) and
        'negatives' (texts of passages sampled without replacement from the
        entry's own negative passages).
    """
    datapoints = []

    for record in data:
        question = record['query']
        positive_passages = record['positive_passages']
        negative_pool = record['negative_passages']

        for positive_passage in positive_passages:
            # A fresh sample is drawn for every positive passage
            sample_size = min(len(negative_pool), num_negatives)
            negative_texts = [
                passage['text'] for passage in random.sample(negative_pool, sample_size)
            ]
            datapoints.append({
                'query': question,
                'doc': positive_passage['text'],
                'negatives': negative_texts,
            })

    return datapoints

In [None]:
# Preprocessing function WikiNQ
def preprocess_data_test_wiki(data, num_negatives):
    """Build WikiNQ evaluation datapoints: one per query, holding all its positives.

    Args:
        data: Iterable of entries with 'query', 'positive_passages' and
            'negative_passages'; each passage is a mapping with a 'text' key.
        num_negatives: Negatives requested per positive passage; at most
            ``num_negatives * len(positives)`` negatives are drawn in total.

    Returns:
        A list of dicts with keys 'query', 'positives' (texts of all positive
        passages) and 'negatives' (texts of passages sampled without
        replacement from the entry's negative passages).
    """
    datapoints = []

    for record in data:
        question = record['query']
        positive_passages = record['positive_passages']
        negative_pool = record['negative_passages']

        positive_texts = [passage['text'] for passage in positive_passages]
        # The negative budget scales with the number of positives, capped by the pool size
        sample_size = min(len(negative_pool), num_negatives * len(positive_texts))
        drawn_passages = random.sample(negative_pool, sample_size)
        negative_texts = [passage['text'] for passage in drawn_passages]

        datapoints.append({
            'query': question,
            'positives': positive_texts,
            'negatives': negative_texts,
        })

    return datapoints

In [None]:
# Negatives are drawn with the global `random` module; seed it so that the
# generated datapoints are identical on every Restart & Run All.
random.seed(42)

# HotpotQA training datapoints: 7 negatives per positive, sampled from the global
# corpus or from 16-entry batches, capped at the first 2500 datapoints.
hotpot_train_data_random = preprocess_data_train_hotpot(small_hotpot_train_dataset, 7, 'global')[:2500]
hotpot_train_data_inbatch = preprocess_data_train_hotpot(small_hotpot_train_dataset, 7, 'inbatch', 16)[:2500]

# HotpotQA evaluation datapoints: 15 negatives per positive, same two sampling schemes.
hotpot_test_data_random = preprocess_data_test_hotpot(small_hotpot_test_dataset, 15, 'global')[:2500]
hotpot_test_data_inbatch = preprocess_data_test_hotpot(small_hotpot_test_dataset, 15, 'inbatch', 16)[:2500]

In [None]:
# Negatives are drawn with the global `random` module; seed it so that the
# generated datapoints are identical on every Restart & Run All.
random.seed(43)

# WikiNQ training datapoints: up to 7 negatives per positive, capped at the first 2500 datapoints.
wiki_train_data = preprocess_data_train_wiki(small_wiki_train_dataset, 7)[:2500]

# WikiNQ evaluation datapoints: up to 15 negatives per positive.
wiki_test_data = preprocess_data_test_wiki(small_wiki_test_dataset, 15)[:2500]

In [None]:
import torch
from transformers import BertTokenizer, BertForPreTraining

# Module-level tokenizer for the "bert-base-cased" checkpoint; tokenize_input and
# compute_document_likelihood both read this global name.
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

def tokenize_input(query, document):
    """Encode a (query, document) pair as a single padded sequence-pair input.

    Args:
        query: Query text, passed as the first sequence.
        document: Document text, passed as the second sequence.

    Returns:
        The tokenizer's encoding of the pair with special tokens added,
        truncated and padded to 512 tokens, holding PyTorch tensors.
    """
    # Call the tokenizer directly: `encode_plus` is deprecated in favour of
    # `__call__`, which accepts the same arguments for a single text pair.
    return tokenizer(
        query,
        document,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors="pt",
    )

In [None]:
from torch.nn import functional as F

def compute_document_likelihood(model, encoding):
    """Score a document as the summed log-probability the model assigns to its own tokens.

    The input is expected to be one encoded (query, document) pair. The
    document is taken to be the tokens strictly between the first and the
    second separator token of the first sequence in the batch.

    Args:
        model: Callable accepting ``input_ids``, ``attention_mask`` and
            ``token_type_ids`` keyword arguments and returning an object with
            a ``prediction_logits`` tensor.
        encoding: Mapping with 'input_ids' and 'attention_mask' tensors and,
            optionally, 'token_type_ids'. Only the first sequence of the
            batch is scored.

    Returns:
        A 0-dim tensor: the sum over document positions of the log-softmax
        score of the token actually present at that position.
    """
    input_ids = encoding['input_ids']

    # Forward the segment ids when the tokenizer produced them, so the model
    # can tell the query segment from the document segment.
    outputs = model(
        input_ids=input_ids,
        attention_mask=encoding['attention_mask'],
        token_type_ids=encoding.get('token_type_ids'),
    )
    logits = outputs.prediction_logits  # Shape: (batch_size, seq_len, vocab_size)

    # Locate the two separators that delimit the document segment
    sep_columns = (input_ids == tokenizer.sep_token_id).nonzero(as_tuple=True)[1]
    sep_token_start_idx = sep_columns[0].item()  # Index of first SEP (end of query)
    sep_token_end_idx = sep_columns[1].item()  # Index of second SEP (end of document)

    # Logits and token ids at the document positions only
    document_logits = logits[:, sep_token_start_idx + 1: sep_token_end_idx, :]
    document_encodings = input_ids[0, sep_token_start_idx + 1: sep_token_end_idx]

    # Turn logits into log-probabilities over the vocabulary
    normalized_logits = F.log_softmax(document_logits, dim=-1)

    # Sum of log-probabilities of the observed document tokens
    document_likelihood = torch.sum(
        normalized_logits[0][range(len(document_encodings)), document_encodings]
    )

    return document_likelihood

In [None]:
from torch.nn import BCEWithLogitsLoss
from tqdm import tqdm

# NOTE(review): despite its name this is a binary cross-entropy-with-logits loss.
# The scores fed to it are summed log-probabilities (never positive), so the
# sigmoid of a score is at most 0.5 - confirm a binary loss is what is intended.
# The name is kept because code outside this cell may reference it.
CrossEntropyLoss = BCEWithLogitsLoss()

def train_document_likelihood_model(train_data, model, optimizer, epochs=3, device='cuda'):
    """Fine-tune ``model`` so positives score higher than negatives, one datapoint per step.

    Args:
        train_data: Sized sequence of dicts with keys 'query', 'doc' (positive
            document text) and 'negatives' (list of negative document texts;
            may be empty).
        model: Model passed to ``compute_document_likelihood``; moved to
            ``device`` and put in training mode.
        optimizer: Optimizer over the model's parameters; stepped once per
            datapoint.
        epochs: Number of passes over ``train_data``.
        device: Device the model and the encoded inputs are moved to.

    Returns:
        None. The average loss of every epoch is printed (``nan`` when
        ``train_data`` is empty).
    """
    model.to(device)
    model.train()

    num_datapoints = len(train_data)

    for epoch in range(epochs):
        total_loss = 0

        for data in tqdm(train_data, desc=f"Epoch {epoch+1}", total=num_datapoints):
            query = data['query']
            positive_doc = data['doc']
            negatives = data['negatives']

            # Move inputs to device (e.g., GPU if available)
            encoding_pos = tokenize_input(query, positive_doc).to(device)
            encoding_negs = [tokenize_input(query, neg_doc).to(device) for neg_doc in negatives]

            # Score the positive first, then every negative. Stacking one list
            # (instead of stacking the negatives separately) also works when a
            # datapoint has no negatives.
            scores = [compute_document_likelihood(model, encoding_pos)]
            scores.extend(compute_document_likelihood(model, encoding_neg) for encoding_neg in encoding_negs)
            all_similarities = torch.stack(scores)

            # Target 1 for the positive (index 0) and 0 for each negative
            target = torch.cat([torch.ones(1), torch.zeros(len(encoding_negs))], dim=0).to(device)

            loss = CrossEntropyLoss(all_similarities, target)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Avoid a ZeroDivisionError when there is nothing to train on
        avg_loss = total_loss / num_datapoints if num_datapoints else float('nan')
        print(f'Epoch {epoch+1}, Average Loss: {avg_loss}')

In [None]:
def _ranking_metrics(positive_ranks, num_candidates, cutoff=10):
    """Return (hit@1, precision@cutoff, reciprocal rank, average precision) for one query from 0-based ranks."""
    ordered_ranks = sorted(positive_ranks)
    best_rank = ordered_ranks[0]

    hit_at_1 = 1 if best_rank == 0 else 0
    hits_within_cutoff = sum(1 for rank in ordered_ranks if rank < cutoff)
    precision_at_cutoff = hits_within_cutoff / min(num_candidates, cutoff)
    reciprocal_rank = 1 / (best_rank + 1)
    # Precision at each rank where a positive appears, averaged over the positives
    average_precision = sum(
        (i + 1) / (rank + 1) for i, rank in enumerate(ordered_ranks)
    ) / len(ordered_ranks)

    return hit_at_1, precision_at_cutoff, reciprocal_rank, average_precision


def evaluate_document_likelihood_model(eval_data, model, device='cuda'):
    """Rank each query's positives against its negatives and print retrieval metrics.

    Args:
        eval_data: Sized sequence of dicts with keys 'query', 'positives' and
            'negatives' (lists of document texts).
        model: Model passed to ``compute_document_likelihood``; moved to
            ``device`` and put in evaluation mode.
        device: Device the model and the encoded inputs are moved to.

    Returns:
        None. Prints Precision@1, Precision@10, MRR and MAP as plain numbers,
        averaged over the queries that have at least one positive. Queries
        without positives cannot be ranked and are skipped.
    """
    model.to(device)
    model.eval()

    precision_at_1 = 0
    precision_at_10 = 0
    mrr_total = 0
    map_total = 0
    scored_queries = 0

    with torch.no_grad():
        for data in tqdm(eval_data, desc="Evaluating", total=len(eval_data)):
            query = data['query']
            positives = data['positives']
            negatives = data['negatives']

            if not positives:
                # Nothing relevant to rank for this query
                continue

            # Move inputs to device (e.g., GPU if available)
            encoding_poses = [tokenize_input(query, pos_doc).to(device) for pos_doc in positives]
            encoding_negs = [tokenize_input(query, neg_doc).to(device) for neg_doc in negatives]

            # Positives first, then negatives; one stack also works with no negatives
            scores = [compute_document_likelihood(model, encoding_pos) for encoding_pos in encoding_poses]
            scores.extend(compute_document_likelihood(model, encoding_neg) for encoding_neg in encoding_negs)
            all_similarities = torch.stack(scores)

            # Double argsort yields the 0-based rank of every candidate (0 = highest score)
            rankings = torch.argsort(torch.argsort(all_similarities, descending=True))
            # Plain Python ints keep the accumulated metrics off the device and printable as numbers
            rank_of_positives = rankings[:len(encoding_poses)].tolist()

            hit_at_1, p_at_10, reciprocal_rank, average_precision = _ranking_metrics(
                rank_of_positives, len(all_similarities)
            )
            precision_at_1 += hit_at_1
            precision_at_10 += p_at_10
            mrr_total += reciprocal_rank
            map_total += average_precision
            scored_queries += 1

    # Average the metrics over all scored queries (guarding against an empty evaluation set)
    denominator = max(scored_queries, 1)
    precision_at_1 /= denominator
    precision_at_10 /= denominator
    mrr_total /= denominator
    map_total /= denominator

    print(f'Precision@1: {precision_at_1}')
    print(f'Precision@10: {precision_at_10}')
    print(f'MRR: {mrr_total}')
    print(f'MAP: {map_total}')

In [None]:
import warnings
import logging
# Silence every Python warning and restrict the "transformers" logger to errors
# to reduce the output of the cells that follow.
# NOTE(review): the blanket 'ignore' filter also hides deprecation warnings - consider narrowing it.
warnings.filterwarnings('ignore')
logging.getLogger("transformers").setLevel(logging.ERROR)

In [None]:
# Train with Random/Global Negatives Hotpot
# Start from a fresh "bert-base-cased" checkpoint so this run does not share weights with other experiments.
model_hotpot_random = BertForPreTraining.from_pretrained("bert-base-cased")
# Optimizer: AdamW over all model parameters with learning rate 5e-5
optimizer_hotpot_random = torch.optim.AdamW(model_hotpot_random.parameters(), lr=5e-5)

# Uses the function's defaults: 3 epochs on the 'cuda' device
train_document_likelihood_model(hotpot_train_data_random, model_hotpot_random, optimizer_hotpot_random)

In [None]:
# Print ranking metrics for the model trained with global negatives on the matching HotpotQA test datapoints
evaluate_document_likelihood_model(hotpot_test_data_random, model_hotpot_random)

In [None]:
# Train with Inbatch Negatives Hotpot
# Start from a fresh "bert-base-cased" checkpoint so this run does not share weights with other experiments.
model_hotpot_inbatch = BertForPreTraining.from_pretrained("bert-base-cased")
# Optimizer: AdamW over all model parameters with learning rate 5e-5
optimizer_hotpot_inbatch = torch.optim.AdamW(model_hotpot_inbatch.parameters(), lr=5e-5)

# Uses the function's defaults: 3 epochs on the 'cuda' device
train_document_likelihood_model(hotpot_train_data_inbatch, model_hotpot_inbatch, optimizer_hotpot_inbatch)

In [None]:
# Print ranking metrics for the model trained with in-batch negatives on the matching HotpotQA test datapoints
evaluate_document_likelihood_model(hotpot_test_data_inbatch, model_hotpot_inbatch)

In [None]:
# Train with WikiNQ
# Start from a fresh "bert-base-cased" checkpoint so this run does not share weights with other experiments.
model_wiki = BertForPreTraining.from_pretrained("bert-base-cased")
# Optimizer: AdamW over all model parameters with learning rate 5e-5
optimizer_wiki = torch.optim.AdamW(model_wiki.parameters(), lr=5e-5)

# Uses the function's defaults: 3 epochs on the 'cuda' device
train_document_likelihood_model(wiki_train_data, model_wiki, optimizer_wiki)

In [None]:
# Print ranking metrics for the WikiNQ-trained model on the WikiNQ test datapoints
evaluate_document_likelihood_model(wiki_test_data, model_wiki)