In [None]:
! pip install transformers datasets

In [None]:
from datasets import load_dataset

# Load the "distractor" configuration of the hotpot_qa dataset.
hotpot_dataset = load_dataset(
    "hotpot_qa",
    "distractor",
    trust_remote_code=True,
)

# Print the loaded object to inspect its structure.
print(hotpot_dataset)

In [None]:
# Load the Tevatron/wikipedia-nq dataset (referred to as "WikiNQ" in this notebook).
wiki_dataset = load_dataset(
    "Tevatron/wikipedia-nq",
)

# Print the loaded object to inspect its structure.
print(wiki_dataset)

In [None]:
# HotpotQA: the 'validation' split is used as the test set below.
hotpot_train_set, hotpot_test_set = (
    hotpot_dataset['train'],
    hotpot_dataset['validation'],
)

# WikiNQ: the 'dev' split is used as the test set below.
wiki_train_set, wiki_test_set = (
    wiki_dataset['train'],
    wiki_dataset['dev'],
)

In [None]:
# HotpotQA: shuffle with seed 42; keep the first 9000 shuffled training examples,
# and every (shuffled) test example.
small_hotpot_train_dataset = (
    hotpot_train_set
    .shuffle(seed=42)
    .select(range(9000))
)
small_hotpot_test_dataset = hotpot_test_set.shuffle(seed=42)

# WikiNQ: same scheme, with seed 43.
small_wiki_train_dataset = (
    wiki_train_set
    .shuffle(seed=43)
    .select(range(9000))
)
small_wiki_test_dataset = wiki_test_set.shuffle(seed=43)

In [None]:
import random

def get_random_negatives(golden_doc_titles, corpus, num_negatives):
    """Randomly sample negative documents from ``corpus``, excluding the golden ones.

    Args:
        golden_doc_titles: Iterable of (hashable) titles that must never be
            returned as negatives.
        corpus: Mapping from document title to document text. A value may be a
            string or a sequence of strings; it is concatenated with ``''.join``
            (a no-op for a plain string).
        num_negatives: Maximum number of negatives to draw. Zero or a negative
            value yields an empty list.

    Returns:
        A list of at most ``num_negatives`` document texts, drawn without
        replacement with ``random.sample``; the result therefore depends on the
        state of the global ``random`` generator.
    """
    # A set gives constant-time membership tests while scanning the corpus.
    excluded_titles = set(golden_doc_titles)
    candidate_titles = [title for title in corpus if title not in excluded_titles]

    # Clamp to [0, len(candidates)]: random.sample raises on a negative or
    # too-large sample size.
    sample_size = max(0, min(len(candidate_titles), num_negatives))
    sampled_titles = random.sample(candidate_titles, sample_size)
    return [''.join(corpus[title]) for title in sampled_titles]

In [None]:
def _build_title_corpus(entries):
    """Map every context title found in ``entries`` to its concatenated sentences.

    When a title appears more than once, the text of its first occurrence is kept.
    """
    title_to_text = {}
    for entry in entries:
        context = entry['context']
        for title, sentences in zip(context['title'], context['sentences']):
            if title not in title_to_text:
                title_to_text[title] = ''.join(sentences)
    return title_to_text


def create_corpus_hotpot(data, sample_method, batch_size):
    """Build the title -> text corpus (or corpora) that negatives are sampled from.

    Args:
        data: Iterable of entries, each with ``entry['context']['title']`` and
            ``entry['context']['sentences']`` as parallel sequences.
        sample_method: ``'global'`` builds one corpus over all of ``data``;
            ``'inbatch'`` builds one corpus per consecutive chunk of
            ``batch_size`` entries.
        batch_size: Chunk size for ``'inbatch'``; ignored for ``'global'``.

    Returns:
        A dict (``'global'``) or a list of dicts (``'inbatch'``, one per chunk,
        in order), each mapping a title to its sentences joined with ``''``.

    Raises:
        ValueError: If ``sample_method`` is not recognised, or if ``batch_size``
            is not positive in ``'inbatch'`` mode.
    """
    if sample_method == 'global':
        return _build_title_corpus(data)

    if sample_method == 'inbatch':
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        # Materialise once so the data can be sliced into consecutive chunks.
        entries = list(data)
        return [
            _build_title_corpus(entries[start:start + batch_size])
            for start in range(0, len(entries), batch_size)
        ]

    # An explicit exception instead of `assert(0)`: assertions are stripped under
    # `python -O`, which would make this function silently return None.
    raise ValueError(
        f"Unknown sample_method {sample_method!r}; expected 'global' or 'inbatch'"
    )
    

In [None]:
from tqdm import tqdm

# Preprocessing function for HotpotQA training data
def preprocess_data_train_hotpot(data, num_negatives, sample_method='global', batch_size=32):
    """Turn HotpotQA-style entries into training datapoints with sampled negatives.

    One datapoint is produced per distinct golden document of each entry.

    Args:
        data: Sized iterable of entries with ``'question'``, ``'context'``
            (``'title'`` / ``'sentences'``) and ``'supporting_facts'`` (``'title'``).
        num_negatives: Number of negatives to sample for every datapoint.
        sample_method: ``'global'`` samples negatives from a corpus built over all
            of ``data``; any other accepted value (``'inbatch'``) samples from the
            corpus of the chunk of ``batch_size`` entries the entry belongs to.
        batch_size: Chunk size used by the in-batch mode.

    Returns:
        A list of dicts ``{'query', 'positives', 'negatives'}`` where
        ``'positives'`` holds exactly one document text.
    """
    corpuses = create_corpus_hotpot(data, sample_method, batch_size)

    processed_data = []

    for idx, entry in tqdm(enumerate(data), total=len(data)):
        query = entry['question']
        context = entry['context']
        # supporting_facts['title'] may list the same title several times
        # (presumably one entry per supporting sentence - confirm against the
        # dataset card). Deduplicate, keeping first-seen order, so a document
        # does not yield several identical positives.
        golden_doc_titles = list(dict.fromkeys(entry['supporting_facts']['title']))

        # Global mode shares one corpus; otherwise use the corpus of this entry's chunk.
        corpus = corpuses if sample_method == 'global' else corpuses[idx // batch_size]

        for golden_doc_title in golden_doc_titles:
            # Locate the golden paragraph in the entry's own context and join its sentences.
            golden_doc = ''.join(context['sentences'][context['title'].index(golden_doc_title)])
            # Every golden title of the entry is excluded, not only the current one.
            negatives = get_random_negatives(golden_doc_titles, corpus, num_negatives)

            processed_data.append({
                'query': query,
                'positives': [golden_doc],
                'negatives': negatives,
            })

    return processed_data

In [None]:
# Preprocessing function for HotpotQA test data
def preprocess_data_test_hotpot(data, num_negatives, sample_method='global', batch_size=32):
    """Turn HotpotQA-style entries into evaluation datapoints with sampled negatives.

    One datapoint is produced per entry, holding all of its distinct golden
    documents as positives.

    Args:
        data: Sized iterable of entries with ``'question'``, ``'context'``
            (``'title'`` / ``'sentences'``) and ``'supporting_facts'`` (``'title'``).
        num_negatives: Number of negatives to sample per positive; the datapoint
            receives up to ``num_negatives * len(positives)`` negatives.
        sample_method: ``'global'`` samples negatives from a corpus built over all
            of ``data``; any other accepted value (``'inbatch'``) samples from the
            corpus of the chunk of ``batch_size`` entries the entry belongs to.
        batch_size: Chunk size used by the in-batch mode.

    Returns:
        A list of dicts ``{'query', 'positives', 'negatives'}``.
    """
    corpuses = create_corpus_hotpot(data, sample_method, batch_size)

    processed_data = []

    for idx, entry in tqdm(enumerate(data), total=len(data)):
        query = entry['question']
        context = entry['context']
        # supporting_facts['title'] may list the same title several times
        # (presumably one entry per supporting sentence - confirm against the
        # dataset card). Deduplicate, keeping first-seen order: a repeated
        # positive would appear twice among the ranked candidates and distort
        # the ranking metrics computed from this datapoint.
        golden_doc_titles = list(dict.fromkeys(entry['supporting_facts']['title']))

        # Global mode shares one corpus; otherwise use the corpus of this entry's chunk.
        corpus = corpuses if sample_method == 'global' else corpuses[idx // batch_size]

        # Locate each golden paragraph in the entry's own context and join its sentences.
        positives = [
            ''.join(context['sentences'][context['title'].index(golden_doc_title)])
            for golden_doc_title in golden_doc_titles
        ]
        negatives = get_random_negatives(golden_doc_titles, corpus, num_negatives * len(positives))

        processed_data.append({
            'query': query,
            'positives': positives,
            'negatives': negatives,
        })

    return processed_data

In [None]:
# Preprocessing function for WikiNQ training data
def preprocess_data_train_wiki(data, num_negatives):
    """Turn WikiNQ-style entries into training datapoints.

    Each entry provides ``'query'``, ``'positive_passages'`` and
    ``'negative_passages'``; passages are dicts whose ``'text'`` field is used.
    One datapoint is emitted per positive passage, each with its own fresh
    sample of at most ``num_negatives`` negatives (``random.sample``, without
    replacement) taken from the entry's negative passages.

    Returns:
        A list of dicts ``{'query', 'positives', 'negatives'}`` where
        ``'positives'`` holds exactly one passage text.
    """
    datapoints = []

    for entry in tqdm(data, total=len(data)):
        query = entry['query']
        positive_passages = entry['positive_passages']
        negative_pool = entry['negative_passages']

        for positive_passage in positive_passages:
            # Never ask for more negatives than the entry actually has.
            sample_size = min(len(negative_pool), num_negatives)
            sampled_passages = random.sample(negative_pool, sample_size)

            datapoints.append({
                'query': query,
                'positives': [positive_passage['text']],
                'negatives': [passage['text'] for passage in sampled_passages],
            })

    return datapoints

In [None]:
# Preprocessing function for WikiNQ test data
def preprocess_data_test_wiki(data, num_negatives):
    """Turn WikiNQ-style entries into evaluation datapoints.

    Each entry provides ``'query'``, ``'positive_passages'`` and
    ``'negative_passages'``; passages are dicts whose ``'text'`` field is used.
    One datapoint is emitted per entry: all positive texts, plus at most
    ``num_negatives * len(positives)`` negatives sampled without replacement
    (``random.sample``) from the entry's negative passages.

    Returns:
        A list of dicts ``{'query', 'positives', 'negatives'}``.
    """
    datapoints = []

    for entry in tqdm(data, total=len(data)):
        query = entry['query']
        positive_passages = entry['positive_passages']
        negative_pool = entry['negative_passages']

        positives = [passage['text'] for passage in positive_passages]

        # The negative budget scales with the number of positives, capped by
        # how many negatives the entry actually has.
        sample_size = min(len(negative_pool), num_negatives * len(positives))
        sampled_passages = random.sample(negative_pool, sample_size)

        datapoints.append({
            'query': query,
            'positives': positives,
            'negatives': [passage['text'] for passage in sampled_passages],
        })

    return datapoints

In [None]:
# Seed Python's global random generator so the negatives sampled by the
# preprocessing helpers are identical on every Restart & Run All.
random.seed(42)

# Training data: 7 negatives per datapoint; keep the first 9000 datapoints.
hotpot_train_data_random = preprocess_data_train_hotpot(
    small_hotpot_train_dataset, num_negatives=7, sample_method='global'
)[:9000]
hotpot_train_data_inbatch = preprocess_data_train_hotpot(
    small_hotpot_train_dataset, num_negatives=7, sample_method='inbatch'
)[:9000]

# Test data: 15 negatives per positive; keep the first 1200 datapoints.
hotpot_test_data_random = preprocess_data_test_hotpot(
    small_hotpot_test_dataset, num_negatives=15, sample_method='global'
)[:1200]
hotpot_test_data_inbatch = preprocess_data_test_hotpot(
    small_hotpot_test_dataset, num_negatives=15, sample_method='inbatch'
)[:1200]

In [None]:
# Training data: 7 negatives per datapoint; keep the first 9000 datapoints.
wiki_train_data = preprocess_data_train_wiki(
    small_wiki_train_dataset,
    num_negatives=7,
)[:9000]

# Test data: 15 negatives per positive; keep the first 1200 datapoints.
wiki_test_data = preprocess_data_test_wiki(
    small_wiki_test_dataset,
    num_negatives=15,
)[:1200]

In [None]:
import warnings
import logging
# Ignore every Python warning from here on, to keep the cell outputs readable.
warnings.filterwarnings(action='ignore')
# Only let ERROR-and-above records from the "transformers" logger through.
logging.getLogger(name="transformers").setLevel(level=logging.ERROR)

In [None]:
import torch
from transformers import BertTokenizer, BertForPreTraining

# Tokenizer of the "bert-base-cased" checkpoint; also read as a global by the
# scoring code further down (for its [SEP] token id).
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

def tokenize_input(query, document):
    """Encode a (query, document) pair as one fixed-length model input.

    Args:
        query: Query text, passed as the first sequence of the pair.
        document: Document text, passed as the second sequence of the pair.

    Returns:
        The tokenizer's encoding for the pair, with special tokens added,
        padded / truncated to 512 tokens and returned as PyTorch tensors
        (presumably of shape (1, 512) per field - confirm).
    """
    # Call the tokenizer object directly: this is the documented entry point,
    # whereas `encode_plus` is deprecated. The arguments are unchanged.
    return tokenizer(
        query,
        document,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors="pt",
    )

In [None]:
from torch.utils.data import Dataset, DataLoader

def collate_fn(batch):
    """Return the first element of ``batch`` unchanged.

    The dataset items used with this function already bundle every document of
    one query, so the loader's batch is unwrapped instead of being stacked.
    Any further elements of ``batch`` are ignored.
    """
    first_item = batch[0]
    return first_item

class TokenizedDataset(Dataset):
    """Dataset that tokenizes every (query, document) pair up front.

    Each item is a tuple ``(input_ids, attention_masks, num_positives,
    num_negatives)`` for one query: the two tensors concatenate, along dim 0,
    the encodings of all positives followed by all negatives.
    """

    def __init__(self, data_list, tokenizer_function, device):
        """Tokenize all datapoints eagerly and keep the results on ``device``.

        Args:
            data_list: Sized iterable of dicts with ``'query'``, ``'positives'``
                and ``'negatives'``.
            tokenizer_function: Callable ``(query, document) -> encoding``; the
                encoding must support ``.to(device)`` and expose ``'input_ids'``
                and ``'attention_mask'``. Any other field it returns is discarded.
            device: Device every encoding is moved to.
        """
        self.data = []
        for example in tqdm(data_list, total=len(data_list)):
            query = example['query']
            positives, negatives = example['positives'], example['negatives']

            # Positives come first so that the leading len(positives) rows of
            # the concatenated tensors are the relevant documents.
            encodings = [
                tokenizer_function(query, document).to(device)
                for document in [*positives, *negatives]
            ]

            all_input_ids = torch.cat([encoding['input_ids'] for encoding in encodings], dim=0)
            all_attention_masks = torch.cat([encoding['attention_mask'] for encoding in encodings], dim=0)

            self.data.append((all_input_ids, all_attention_masks, len(positives), len(negatives)))

    def __len__(self):
        """Number of tokenized datapoints (one per query)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the pre-tokenized tuple stored at position ``idx``."""
        return self.data[idx]

In [None]:
# HotpotQA training loaders: datapoints are tokenized eagerly and stored on the GPU;
# collate_fn makes every batch a single datapoint (one query).
hotpot_train_data_random_loader = DataLoader(
    TokenizedDataset(data_list=hotpot_train_data_random, tokenizer_function=tokenize_input, device='cuda'),
    collate_fn=collate_fn,
)
hotpot_train_data_inbatch_loader = DataLoader(
    TokenizedDataset(data_list=hotpot_train_data_inbatch, tokenizer_function=tokenize_input, device='cuda'),
    collate_fn=collate_fn,
)

# HotpotQA test loaders, built the same way.
hotpot_test_data_random_loader = DataLoader(
    TokenizedDataset(data_list=hotpot_test_data_random, tokenizer_function=tokenize_input, device='cuda'),
    collate_fn=collate_fn,
)
hotpot_test_data_inbatch_loader = DataLoader(
    TokenizedDataset(data_list=hotpot_test_data_inbatch, tokenizer_function=tokenize_input, device='cuda'),
    collate_fn=collate_fn,
)

In [None]:
# WikiNQ loaders: datapoints are tokenized eagerly and stored on the GPU;
# collate_fn makes every batch a single datapoint (one query).
wiki_train_data_loader = DataLoader(
    TokenizedDataset(data_list=wiki_train_data, tokenizer_function=tokenize_input, device='cuda'),
    collate_fn=collate_fn,
)

wiki_test_data_loader = DataLoader(
    TokenizedDataset(data_list=wiki_test_data, tokenizer_function=tokenize_input, device='cuda'),
    collate_fn=collate_fn,
)

In [None]:
from torch.nn import functional as F

def compute_document_likelihood(model, input_ids, attention_masks):
    """Score each row by the summed log-probability of its document tokens.

    Args:
        model: Callable accepting ``input_ids`` / ``attention_mask`` keyword
            arguments and returning an object with ``prediction_logits``.
        input_ids: 2-D tensor of token ids, one (query, document) pair per row.
            Every row must contain exactly two ``[SEP]`` tokens (the reshape
            below relies on it); the document is the span strictly between them.
        attention_masks: Attention mask matching ``input_ids``.

    Returns:
        1-D tensor with one score per row: the sum, over the document positions,
        of the log-probability the model assigns to the token present there.
    """
    # Per-position logits; the last axis is indexed by token id in the gather
    # below, i.e. it ranges over the vocabulary.
    prediction_logits = model(input_ids=input_ids, attention_mask=attention_masks).prediction_logits
    log_probs = F.log_softmax(prediction_logits, dim=-1)

    # Log-probability of the token actually present at every position.
    token_log_probs = log_probs.gather(2, input_ids.unsqueeze(-1)).squeeze(-1)

    # Column indices of the [SEP] tokens, regrouped as (first, second) per row.
    # Uses the module-level `tokenizer` for the [SEP] id.
    sep_columns = (input_ids == tokenizer.sep_token_id).nonzero(as_tuple=True)[1].reshape(-1, 2)
    first_sep = sep_columns[:, 0].unsqueeze(1)
    second_sep = sep_columns[:, 1].unsqueeze(1)

    # True exactly on the positions strictly between the two [SEP] tokens.
    positions = torch.arange(input_ids.size(1), device=input_ids.device).unsqueeze(0)
    document_mask = (positions >= first_sep + 1) & (positions < second_sep)

    # Zero out everything outside the document span, then sum per row.
    return torch.sum(token_log_probs * document_mask, dim=1)

In [None]:
from torch.nn import CrossEntropyLoss

# Shared loss object used by the training loop below.
CELoss = CrossEntropyLoss()

def train_document_likelihood_model(train_data_loader, model, optimizer, epochs=3, device='cuda'):
    """Fine-tune ``model`` so positive documents out-score negatives for each query.

    Every item of ``train_data_loader`` is one query:
    ``(input_ids, attention_masks, num_positive, num_negative)`` with the
    positive rows first. The per-document scores from
    ``compute_document_likelihood`` are fed to the cross-entropy loss against a
    target that is 1 for positives and 0 for negatives.

    Args:
        train_data_loader: Sized iterable yielding the tuples described above.
        model: Model to train (moved to ``device`` and put in train mode).
        optimizer: Optimizer stepped once per query.
        epochs: Number of passes over ``train_data_loader``.
        device: Device the model, the batch tensors and the target are placed on.

    Returns:
        None. The average loss per query is printed after every epoch
        (``nan`` if the loader is empty).
    """
    model.to(device)
    model.train()
    total_queries = len(train_data_loader)

    for epoch in range(epochs):
        total_loss = 0.0

        for all_input_ids, all_attention_masks, num_positive, num_negative in tqdm(train_data_loader, desc=f"Epoch {epoch+1}", total=total_queries):
            # Honour `device` for the inputs too (a no-op when they already live there),
            # so the function also works when the data was prepared on another device.
            all_input_ids = all_input_ids.to(device)
            all_attention_masks = all_attention_masks.to(device)

            all_similarities = compute_document_likelihood(model, all_input_ids, all_attention_masks)

            # 1 for every positive row, 0 for every negative row, built directly on `device`.
            # NOTE(review): with a 1-D score vector and a float target of the same shape,
            # CrossEntropyLoss is expected to treat the scores as one sample over
            # "document classes" with the target as class probabilities - confirm.
            target = torch.cat([
                torch.ones(num_positive, device=device),
                torch.zeros(num_negative, device=device),
            ], dim=0)

            loss = CELoss(all_similarities, target)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Avoid a ZeroDivisionError on an empty loader.
        avg_loss = total_loss / total_queries if total_queries else float('nan')
        print(f'Epoch {epoch+1}, Average Loss: {avg_loss}')

In [None]:
def evaluate_document_likelihood_model(eval_data_loader, model, device='cuda'):
    """Rank each query's documents by likelihood score and print ranking metrics.

    Every item of ``eval_data_loader`` is one query:
    ``(input_ids, attention_masks, num_positive, num_negative)`` with the
    positive rows first. Documents are ranked by descending score and the
    0-based ranks of the positives drive the metrics.

    Args:
        eval_data_loader: Sized iterable yielding the tuples described above.
        model: Model to evaluate (moved to ``device`` and put in eval mode).
        device: Device the model and the batch tensors are placed on.

    Returns:
        None. Precision@1, Precision@10, MRR and MAP, averaged over all queries
        of the loader, are printed as plain floats (``nan`` for an empty loader).
        A query without positives contributes 0 to every metric.
    """
    model.to(device)
    model.eval()

    precision_at_1 = 0
    precision_at_10 = 0
    mrr_total = 0
    map_total = 0
    total_queries = len(eval_data_loader)

    with torch.no_grad():
        for all_input_ids, all_attention_masks, num_positive, num_negative in tqdm(eval_data_loader, desc="Evaluating", total=total_queries):
            # Honour `device` for the inputs too (a no-op when they already live there).
            all_input_ids = all_input_ids.to(device)
            all_attention_masks = all_attention_masks.to(device)

            all_similarities = compute_document_likelihood(model, all_input_ids, all_attention_masks)

            # argsort of argsort gives each document's 0-based rank under
            # descending score; the first `num_positive` rows are the positives.
            rankings = torch.argsort(torch.argsort(all_similarities, descending=True))
            # Convert once to sorted Python ints: the metrics then accumulate as
            # plain floats instead of 0-d tensors.
            rank_of_positives = sorted(rankings[:num_positive].tolist())

            if not rank_of_positives:
                # No positive to rank: nothing to add for this query.
                continue

            best_rank = rank_of_positives[0]
            if best_rank == 0:
                precision_at_1 += 1

            positives_in_top_10 = sum(1 for rank in rank_of_positives if rank < 10)
            precision_at_10 += positives_in_top_10 / min(len(all_similarities), 10)

            # Compute Reciprocal Rank for MRR
            mrr_total += 1 / (best_rank + 1)

            # Compute Average Precision for MAP: precision at each rank where a
            # positive document appears, averaged over the positives.
            ap = sum((i + 1) / (rank + 1) for i, rank in enumerate(rank_of_positives))
            map_total += ap / len(rank_of_positives)

    # Average the metrics over all queries
    if total_queries:
        precision_at_1 /= total_queries
        precision_at_10 /= total_queries
        mrr_total /= total_queries
        map_total /= total_queries
    else:
        # Empty loader: report nan instead of raising ZeroDivisionError.
        precision_at_1 = precision_at_10 = mrr_total = map_total = float('nan')

    print(f'Precision@1: {precision_at_1}')
    print(f'Precision@10: {precision_at_10}')
    print(f'MRR: {mrr_total}')
    print(f'MAP: {map_total}')

In [None]:
def freeze_bert_layers(model, num_layers_to_freeze=10):
    """Disable gradients for the embeddings and the first encoder layers.

    Args:
        model: Object exposing ``model.bert.embeddings`` and the sliceable
            ``model.bert.encoder.layer``.
        num_layers_to_freeze: How many leading encoder layers to freeze.

    Returns:
        None. ``requires_grad`` is set to ``False`` in place on every parameter
        of the selected modules.
    """
    def _modules_to_freeze():
        # Embeddings first, then the leading encoder layers, produced lazily.
        yield model.bert.embeddings
        yield from model.bert.encoder.layer[:num_layers_to_freeze]

    for module in _modules_to_freeze():
        for weight in module.parameters():
            weight.requires_grad = False

In [None]:
# HotpotQA with random (global) negatives: start from the pretrained checkpoint
# and freeze the embeddings plus the first 9 encoder layers.
model_hotpot_random = BertForPreTraining.from_pretrained("bert-base-cased")
freeze_bert_layers(model_hotpot_random, num_layers_to_freeze=9)

# Optimizer over the model's parameters.
optimizer_hotpot_random = torch.optim.AdamW(
    params=model_hotpot_random.parameters(),
    lr=5e-5,
)

train_document_likelihood_model(
    train_data_loader=hotpot_train_data_random_loader,
    model=model_hotpot_random,
    optimizer=optimizer_hotpot_random,
)

In [None]:
# Ranking metrics of the random-negatives model on its HotpotQA test loader.
evaluate_document_likelihood_model(
    eval_data_loader=hotpot_test_data_random_loader,
    model=model_hotpot_random,
)

In [None]:
# HotpotQA with in-batch negatives: start from the pretrained checkpoint
# and freeze the embeddings plus the first 9 encoder layers.
model_hotpot_inbatch = BertForPreTraining.from_pretrained("bert-base-cased")
freeze_bert_layers(model_hotpot_inbatch, num_layers_to_freeze=9)

# Optimizer over the model's parameters.
optimizer_hotpot_inbatch = torch.optim.AdamW(
    params=model_hotpot_inbatch.parameters(),
    lr=5e-5,
)

train_document_likelihood_model(
    train_data_loader=hotpot_train_data_inbatch_loader,
    model=model_hotpot_inbatch,
    optimizer=optimizer_hotpot_inbatch,
)

In [None]:
# Ranking metrics of the in-batch-negatives model on its HotpotQA test loader.
evaluate_document_likelihood_model(
    eval_data_loader=hotpot_test_data_inbatch_loader,
    model=model_hotpot_inbatch,
)

In [None]:
# WikiNQ: start from the pretrained checkpoint and freeze the embeddings
# plus the first 9 encoder layers.
model_wiki = BertForPreTraining.from_pretrained("bert-base-cased")
freeze_bert_layers(model_wiki, num_layers_to_freeze=9)

# Optimizer over the model's parameters.
optimizer_wiki = torch.optim.AdamW(
    params=model_wiki.parameters(),
    lr=5e-5,
)

train_document_likelihood_model(
    train_data_loader=wiki_train_data_loader,
    model=model_wiki,
    optimizer=optimizer_wiki,
)

In [None]:
# Ranking metrics of the WikiNQ model on the WikiNQ test loader.
evaluate_document_likelihood_model(
    eval_data_loader=wiki_test_data_loader,
    model=model_wiki,
)