In [3]:
! pip install transformers datasets



In [4]:
from datasets import load_dataset

# Load HotpotQA dataset
hotpot_dataset = load_dataset("hotpot_qa", "distractor", trust_remote_code=True)

# Display HotpotQA dataset structure
print(hotpot_dataset)

Downloading builder script:   0%|          | 0.00/6.42k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/9.19k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/566M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/46.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/90447 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/7405 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'question', 'answer', 'type', 'level', 'supporting_facts', 'context'],
        num_rows: 90447
    })
    validation: Dataset({
        features: ['id', 'question', 'answer', 'type', 'level', 'supporting_facts', 'context'],
        num_rows: 7405
    })
})


In [5]:
# Load WikiNQ dataset
wiki_dataset = load_dataset("Tevatron/wikipedia-nq")

# Display WikiNQ dataset structure
print(wiki_dataset)

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/97.6M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/242M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/241k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/58622 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/6489 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3610 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['query_id', 'query', 'answers', 'positive_passages', 'negative_passages'],
        num_rows: 58622
    })
    dev: Dataset({
        features: ['query_id', 'query', 'answers', 'positive_passages', 'negative_passages'],
        num_rows: 6489
    })
    test: Dataset({
        features: ['query_id', 'query', 'answers', 'positive_passages', 'negative_passages'],
        num_rows: 3610
    })
})


In [6]:
hotpot_train_set = hotpot_dataset['train']
hotpot_test_set = hotpot_dataset['validation']

wiki_train_set = wiki_dataset['train']
wiki_test_set = wiki_dataset['test']

In [7]:
small_hotpot_train_dataset = hotpot_train_set.shuffle(seed=42).select(range(2000))
small_hotpot_test_dataset = hotpot_test_set.shuffle(seed=42).select(range(1000))

small_wiki_train_dataset = wiki_train_set.shuffle(seed=43).select(range(2000))
small_wiki_test_dataset = wiki_test_set.shuffle(seed=43).select(range(1000))

In [8]:
import random

def get_random_negatives(positive_queries, corpus, num_negatives):
    """Randomly sample negatives excluding golden_docs."""
    available_queries = [query for query in corpus if query not in positive_queries]
    return random.sample(available_queries, min(len(available_queries), num_negatives))

In [9]:
def create_document_data(data, dataset):
    document_data = {}
    
    if dataset == 'hotpot':
        for entry in data:
            query = entry['question']
            golden_docs = entry['supporting_facts']
            for golden_doc_title in golden_docs['title']:
                if golden_doc_title not in document_data:
                    document_data[golden_doc_title] = {}
                    golden_doc_sentences = entry['context']['sentences'][entry['context']['title'].index(golden_doc_title)]
                    document_data[golden_doc_title]['doc'] = ''.join(golden_doc_sentences)
                    document_data[golden_doc_title]['positive_queries'] = []
                document_data[golden_doc_title]['positive_queries'].append(query)
    elif dataset == 'wiki':
        for entry in data:
            query = entry['query']
            golden_docs = entry['positive_passages']
            for golden_doc in golden_docs:
                if golden_doc['title'] not in document_data:
                    document_data[golden_doc['title']] = {}
                    document_data[golden_doc['title']]['doc'] = golden_doc['text']
                    document_data[golden_doc['title']]['positive_queries'] = []
                document_data[golden_doc['title']]['positive_queries'].append(query)
    else:
        assert(0)
    
    return document_data
    

In [14]:
def create_corpus(data, document_data, document_keys, dataset, sample_method, batch_size):
    if sample_method == 'global':
        if dataset == 'hotpot':
            global_corpus = [entry['question'] for entry in data]
        else:
            global_corpus = [entry['query'] for entry in data]
        return global_corpus
    elif sample_method == 'inbatch':
        batches = [document_keys[i:i + batch_size] for i in range(0, len(document_keys), batch_size)]
        batch_corpuses = []
        for batch in batches:
            batch_corpus = set()
            for doc_title in batch:
                batch_corpus.update(document_data[doc_title]['positive_queries'])
            batch_corpuses.append(list(batch_corpus))
        return batch_corpuses
    else:
        assert(0)

In [15]:
# Preprocessing function
def preprocess_data_train(data, num_negatives, dataset='hotpot', sample_method='global', batch_size=32):
    """Preprocess data to create datapoints with negatives."""
    document_data = create_document_data(data, dataset)
    
    document_keys = list(document_data.keys())
    random.shuffle(document_keys)
    corpuses = create_corpus(data, document_data, document_keys, dataset, sample_method, batch_size)
  
    processed_data = []
    
    for idx, document_title in enumerate(document_keys):
        document_info = document_data[document_title]
        doc = document_info['doc']
        positive_queries = document_info['positive_queries']
        
        if sample_method == 'global':
            corpus = corpuses
        else:
            corpus = corpuses[idx // batch_size]
        
        for positive_query in positive_queries:
            negatives = get_random_negatives(positive_queries, corpus, num_negatives)
            
            datapoint = {
                'doc': doc,
                'positive_query': positive_query,
                'negatives': negatives,
            }
            processed_data.append(datapoint)

    return processed_data

In [22]:
# Preprocessing function
def preprocess_data_test(data, num_negatives, dataset='hotpot', sample_method='global', batch_size=32):
    """Preprocess data to create datapoints with negatives."""
    document_data = create_document_data(data, dataset)

    document_keys = list(document_data.keys())
    random.shuffle(document_keys)
    corpuses = create_corpus(data, document_data, document_keys, dataset, sample_method, batch_size)

    processed_data = []
    
    for idx, document_title in enumerate(document_keys):
        document_info = document_data[document_title]
        doc = document_info['doc']        
        if sample_method == 'global':
            corpus = corpuses
        else:
            corpus = corpuses[idx // batch_size]

        positive_queries = document_info['positive_queries']
        negatives = get_random_negatives(positive_queries, corpus, num_negatives * len(positive_queries))
        datapoint = {
            'doc': doc,
            'positives': positive_queries,
            'negatives': negatives,
        }
        processed_data.append(datapoint)

    return processed_data

In [23]:
hotpot_train_data_random = preprocess_data_train(small_hotpot_train_dataset, 7, 'hotpot', 'global')[:2500]
hotpot_train_data_inbatch = preprocess_data_train(small_hotpot_train_dataset, 7, 'hotpot', 'inbatch', 32)[:2500]

hotpot_test_data_random = preprocess_data_test(small_hotpot_test_dataset, 15, 'hotpot', 'global')[:2500]
hotpot_test_data_inbatch = preprocess_data_test(small_hotpot_test_dataset, 15, 'hotpot', 'inbatch', 32)[:2500]

In [24]:
wiki_train_data_random = preprocess_data_train(small_wiki_train_dataset, 7, 'wiki', 'global')[:2500]
wiki_train_data_inbatch = preprocess_data_train(small_wiki_train_dataset, 7, 'wiki', 'inbatch', 32)[:2500]

wiki_test_data_random = preprocess_data_test(small_wiki_test_dataset, 15, 'wiki', 'global')[:2500]
wiki_test_data_inbatch = preprocess_data_test(small_wiki_test_dataset, 15, 'wiki', 'inbatch', 32)[:2500]

In [26]:
import torch
from transformers import BertTokenizer, BertForPreTraining

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

def tokenize_input(document, query):
    return tokenizer.encode_plus(
        document,
        query,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors="pt"
    )

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]



In [28]:
from torch.nn import functional as F

def compute_query_likelihood(model, encoding):
    # Forward pass to get the logits
    outputs = model(input_ids=encoding['input_ids'], attention_mask=encoding['attention_mask'])
    logits = outputs.prediction_logits  # Shape: (batch_size, seq_len, hidden_size)

    # Get the query logits (after the [SEP] token)
    sep_token_start_idx = (encoding['input_ids'] == tokenizer.sep_token_id).nonzero(as_tuple=True)[1][0].item()  # Index of first SEP
    sep_token_end_idx = (encoding['input_ids'] == tokenizer.sep_token_id).nonzero(as_tuple=True)[1][1].item()  # Index of second SEP

    query_logits = logits[:, sep_token_start_idx + 1: sep_token_end_idx, :]  # Query logits after the second SEP

    # Normalize logits
    normalized_logits = F.log_softmax(query_logits, dim=-1)

    query_encodings = encoding['input_ids'][0, sep_token_start_idx + 1: sep_token_end_idx]

    # Sum of log likelihoods for query tokens
    query_likelihood = torch.sum(normalized_logits[0][range(len(query_encodings)), query_encodings])

    return query_likelihood

In [29]:
from torch.nn import BCEWithLogitsLoss
from tqdm import tqdm

CrossEntropyLoss = BCEWithLogitsLoss()

def train_query_likelihood_model(train_data, model, optimizer, epochs=3, device='cuda'):
    model.to(device)
    model.train()

    for epoch in range(epochs):
        total_loss = 0

        for i, data in tqdm(enumerate(train_data), desc=f"Epoch {epoch+1}", total=len(train_data)):
            doc = data['doc']
            positive_query = data['positive_query']
            negatives = data['negatives']

            # Move inputs to device (e.g., GPU if available)
            encoding_pos = tokenize_input(doc, positive_query).to(device)
            encoding_negs = [tokenize_input(doc, neg_query).to(device) for neg_query in negatives]
            
            # Compute likelihoods for the document and negatives
            sim_pos = compute_query_likelihood(model, encoding_pos)
            sim_negs = torch.stack([compute_query_likelihood(model, encoding_neg) for encoding_neg in encoding_negs])

            all_similarities = torch.cat([sim_pos.unsqueeze(0), sim_negs], dim=0)
            target = torch.cat([torch.ones(1), torch.zeros(len(sim_negs))], dim=0).to(device)
            
            loss = CrossEntropyLoss(all_similarities, target)
            
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_data)
        print(f'Epoch {epoch+1}, Average Loss: {avg_loss}')

In [30]:
def evaluate_query_likelihood_model(eval_data, model, device='cuda'):
    model.to(device)
    model.eval()

    precision_at_1 = 0
    precision_at_10 = 0
    mrr_total = 0
    map_total = 0
    total_documents = len(eval_data)
    
    with torch.no_grad():
        for data in tqdm(eval_data, desc="Evaluating", total=total_documents):
            doc = data['doc']
            positives = data['positives']
            negatives = data['negatives']

            # Move inputs to device (e.g., GPU if available)
            encoding_poses = [tokenize_input(doc, pos_query).to(device) for pos_query in positives]
            encoding_negs = [tokenize_input(doc, neg_query).to(device) for neg_query in negatives]

            # Compute likelihoods for the query and negatives
            sim_poses = torch.stack([compute_query_likelihood(model, encoding_pos) for encoding_pos in encoding_poses])
            sim_negs = torch.stack([compute_query_likelihood(model, encoding_neg) for encoding_neg in encoding_negs])

            # Concatenate positive and negative likelihoods
            all_similarities = torch.cat([sim_poses, sim_negs], dim=0)

            # Get the ranking of the positive query (rank is based on the similarity score)
            rankings = torch.argsort(torch.argsort(all_similarities, descending=True))
            rank_of_positives = rankings[:len(sim_poses)]

            if 0 in rank_of_positives:
                precision_at_1 += 1
            precision_at_10 += len(list(filter(lambda x: x < 10, rank_of_positives))) / min(len(all_similarities), 10)

            # Compute Reciprocal Rank for MRR
            rank_of_first_positive = min(rank_of_positives) + 1
            mrr_total += 1 / rank_of_first_positive

            # Compute Average Precision for MAP
            ap = 0.0
            for i, rank in enumerate(sorted(rank_of_positives)):
                ap += (i + 1) / (rank + 1)  # Precision at each rank where a positive query appears
            map_total += ap / len(rank_of_positives)  # Average precision for this document

    # Average the metrics over all documents
    precision_at_1 /= total_documents
    precision_at_10 /= total_documents
    mrr_total /= total_documents
    map_total /= total_documents
    
    print(f'Precision@1: {precision_at_1}')
    print(f'Precision@10: {precision_at_10}')
    print(f'MRR: {mrr_total}')
    print(f'MAP: {map_total}')

In [31]:
import warnings
import logging
warnings.filterwarnings('ignore')
logging.getLogger("transformers").setLevel(logging.ERROR)

In [33]:
# Train with Random/Global Negatives Hotpot
model_hotpot_random = BertForPreTraining.from_pretrained("bert-base-cased")
# Optimizer
optimizer_hotpot_random = torch.optim.AdamW(model_hotpot_random.parameters(), lr=5e-5)

train_query_likelihood_model(hotpot_train_data_random, model_hotpot_random, optimizer_hotpot_random)

Epoch 1: 100%|██████████| 100/100 [01:38<00:00,  1.02it/s]


Epoch 1, Average Loss: 0.3747033605724573


Epoch 2: 100%|██████████| 100/100 [01:44<00:00,  1.05s/it]


Epoch 2, Average Loss: 0.17464602433145046


Epoch 3: 100%|██████████| 100/100 [01:45<00:00,  1.05s/it]

Epoch 3, Average Loss: 0.17955359034240245





In [34]:
evaluate_query_likelihood_model(hotpot_test_data_random, model_hotpot_random)

Evaluating: 100%|██████████| 100/100 [01:27<00:00,  1.14it/s]

Precision@1: 0.47
Precision@10: 0.09799999999999986
MRR: 0.5962307453155518
MAP: 0.5974721908569336





In [None]:
# Train with Inbatch Negatives Hotpot
model_hotpot_inbatch = BertForPreTraining.from_pretrained("bert-base-cased")
# Optimizer
optimizer_hotpot_inbatch = torch.optim.AdamW(model_hotpot_inbatch.parameters(), lr=5e-5)

train_query_likelihood_model(hotpot_train_data_inbatch, model_hotpot_inbatch, optimizer_hotpot_inbatch)

In [None]:
evaluate_query_likelihood_model(hotpot_test_data_inbatch, model_hotpot_inbatch)

In [None]:
# Train with Random/Global Negatives WikiNQ
model_wiki_random = BertForPreTraining.from_pretrained("bert-base-cased")
# Optimizer
optimizer_wiki_random = torch.optim.AdamW(model_wiki_random.parameters(), lr=5e-5)

train_query_likelihood_model(wiki_train_data_random, model_wiki_random, optimizer_wiki_random)

In [None]:
evaluate_query_likelihood_model(wiki_test_data_random, model_wiki_random)

In [None]:
# Train with Inbatch Negatives WikiNQ
model_wiki_inbatch = BertForPreTraining.from_pretrained("bert-base-cased")
# Optimizer
optimizer_wiki_inbatch = torch.optim.AdamW(model_wiki_inbatch.parameters(), lr=5e-5)

train_query_likelihood_model(wiki_train_data_inbatch, model_wiki_inbatch, optimizer_wiki_inbatch)

In [None]:
evaluate_query_likelihood_model(wiki_test_data_inbatch, model_wiki_inbatch)