In [1]:
! pip install transformers datasets



In [2]:
from datasets import load_dataset

# Load HotpotQA dataset
hotpot_dataset = load_dataset("hotpot_qa", "distractor", trust_remote_code=True)

# Display HotpotQA dataset structure
print(hotpot_dataset)

Downloading builder script:   0%|          | 0.00/6.42k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/9.19k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/566M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/46.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/90447 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/7405 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'question', 'answer', 'type', 'level', 'supporting_facts', 'context'],
        num_rows: 90447
    })
    validation: Dataset({
        features: ['id', 'question', 'answer', 'type', 'level', 'supporting_facts', 'context'],
        num_rows: 7405
    })
})


In [3]:
# Load WikiNQ dataset
wiki_dataset = load_dataset("Tevatron/wikipedia-nq")

# Display WikiNQ dataset structure
print(wiki_dataset)

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/97.6M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/242M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/241k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/58622 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/6489 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3610 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['query_id', 'query', 'answers', 'positive_passages', 'negative_passages'],
        num_rows: 58622
    })
    dev: Dataset({
        features: ['query_id', 'query', 'answers', 'positive_passages', 'negative_passages'],
        num_rows: 6489
    })
    test: Dataset({
        features: ['query_id', 'query', 'answers', 'positive_passages', 'negative_passages'],
        num_rows: 3610
    })
})


In [4]:
hotpot_train_set = hotpot_dataset['train']
hotpot_test_set = hotpot_dataset['validation']

wiki_train_set = wiki_dataset['train']
wiki_test_set = wiki_dataset['dev']

In [5]:
small_hotpot_train_dataset = hotpot_train_set.shuffle(seed=42).select(range(9000))
small_hotpot_test_dataset = hotpot_test_set.shuffle(seed=42)

small_wiki_train_dataset = wiki_train_set.shuffle(seed=43).select(range(9000))
small_wiki_test_dataset = wiki_test_set.shuffle(seed=43)

In [6]:
import random

def get_random_negatives(positive_queries, corpus, num_negatives):
    """Randomly sample negatives excluding golden_docs."""
    available_queries = [query for query in corpus if query not in positive_queries]
    return random.sample(available_queries, min(len(available_queries), num_negatives))

In [7]:
def create_document_data(data, dataset):
    document_data = {}
    
    if dataset == 'hotpot':
        for entry in data:
            query = entry['question']
            golden_docs = entry['supporting_facts']
            for golden_doc_title in golden_docs['title']:
                if golden_doc_title not in document_data:
                    document_data[golden_doc_title] = {}
                    golden_doc_sentences = entry['context']['sentences'][entry['context']['title'].index(golden_doc_title)]
                    document_data[golden_doc_title]['doc'] = ''.join(golden_doc_sentences)
                    document_data[golden_doc_title]['positive_queries'] = []
                document_data[golden_doc_title]['positive_queries'].append(query)
    elif dataset == 'wiki':
        for entry in data:
            query = entry['query']
            golden_docs = entry['positive_passages']
            for golden_doc in golden_docs:
                if golden_doc['title'] not in document_data:
                    document_data[golden_doc['title']] = {}
                    document_data[golden_doc['title']]['doc'] = golden_doc['text']
                    document_data[golden_doc['title']]['positive_queries'] = []
                document_data[golden_doc['title']]['positive_queries'].append(query)
    else:
        assert(0)
    
    return document_data
    

In [8]:
def create_corpus(data, document_data, document_keys, dataset, sample_method, batch_size):
    if sample_method == 'global':
        if dataset == 'hotpot':
            global_corpus = [entry['question'] for entry in data]
        else:
            global_corpus = [entry['query'] for entry in data]
        return global_corpus
    elif sample_method == 'inbatch':
        batches = [document_keys[i:i + batch_size] for i in range(0, len(document_keys), batch_size)]
        batch_corpuses = []
        for batch in batches:
            batch_corpus = set()
            for doc_title in batch:
                batch_corpus.update(document_data[doc_title]['positive_queries'])
            batch_corpuses.append(list(batch_corpus))
        return batch_corpuses
    else:
        assert(0)

In [9]:
from tqdm import tqdm

# Preprocessing function
def preprocess_data_train(data, num_negatives, dataset='hotpot', sample_method='global', batch_size=32):
    """Preprocess data to create datapoints with negatives."""
    document_data = create_document_data(data, dataset)
    
    document_keys = list(document_data.keys())
    random.shuffle(document_keys)
    corpuses = create_corpus(data, document_data, document_keys, dataset, sample_method, batch_size)
  
    processed_data = []
    
    for idx, document_title in tqdm(enumerate(document_keys), total=len(document_keys)):
        document_info = document_data[document_title]
        doc = document_info['doc']
        positive_queries = document_info['positive_queries']
        
        if sample_method == 'global':
            corpus = corpuses
        else:
            corpus = corpuses[idx // batch_size]
        
        for positive_query in positive_queries:
            negatives = get_random_negatives(positive_queries, corpus, num_negatives)
            
            datapoint = {
                'doc': doc,
                'positives': [positive_query],
                'negatives': negatives,
            }
            processed_data.append(datapoint)

    return processed_data

In [10]:
# Preprocessing function
def preprocess_data_test(data, num_negatives, dataset='hotpot', sample_method='global', batch_size=32):
    """Preprocess data to create datapoints with negatives."""
    document_data = create_document_data(data, dataset)

    document_keys = list(document_data.keys())
    random.shuffle(document_keys)
    corpuses = create_corpus(data, document_data, document_keys, dataset, sample_method, batch_size)

    processed_data = []
    
    for idx, document_title in tqdm(enumerate(document_keys), total=len(document_keys)):
        document_info = document_data[document_title]
        doc = document_info['doc']        
        if sample_method == 'global':
            corpus = corpuses
        else:
            corpus = corpuses[idx // batch_size]

        positive_queries = document_info['positive_queries']
        negatives = get_random_negatives(positive_queries, corpus, num_negatives * len(positive_queries))
        datapoint = {
            'doc': doc,
            'positives': positive_queries,
            'negatives': negatives,
        }
        processed_data.append(datapoint)

    return processed_data

In [11]:
hotpot_train_data_random = preprocess_data_train(small_hotpot_train_dataset, 7, 'hotpot', 'global')[:9000]
hotpot_train_data_inbatch = preprocess_data_train(small_hotpot_train_dataset, 7, 'hotpot', 'inbatch')[:9000]

hotpot_test_data_random = preprocess_data_test(small_hotpot_test_dataset, 15, 'hotpot', 'global')[:1200]
hotpot_test_data_inbatch = preprocess_data_test(small_hotpot_test_dataset, 15, 'hotpot', 'inbatch')[:1200]

100%|██████████| 16573/16573 [00:19<00:00, 846.57it/s]
100%|██████████| 16573/16573 [00:00<00:00, 35706.81it/s]
100%|██████████| 13783/13783 [00:09<00:00, 1484.51it/s]
100%|██████████| 13783/13783 [00:00<00:00, 34888.62it/s]


In [12]:
wiki_train_data_random = preprocess_data_train(small_wiki_train_dataset, 7, 'wiki', 'global')[:9000]
wiki_train_data_inbatch = preprocess_data_train(small_wiki_train_dataset, 7, 'wiki', 'inbatch')[:9000]

wiki_test_data_random = preprocess_data_test(small_wiki_test_dataset, 15, 'wiki', 'global')[:1200]
wiki_test_data_inbatch = preprocess_data_test(small_wiki_test_dataset, 15, 'wiki', 'inbatch')[:1200]

100%|██████████| 40491/40491 [01:18<00:00, 515.17it/s]
100%|██████████| 40491/40491 [00:01<00:00, 23701.46it/s]
100%|██████████| 30365/30365 [00:19<00:00, 1523.41it/s]
100%|██████████| 30365/30365 [00:01<00:00, 27127.41it/s]


In [13]:
import warnings
import logging
warnings.filterwarnings('ignore')
logging.getLogger("transformers").setLevel(logging.ERROR)

In [14]:
import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

def tokenize_input(document, query):
    return tokenizer.encode_plus(
        document,
        query,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors="pt"
    )

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [15]:
from torch.utils.data import Dataset, DataLoader

def collate_fn(batch):
    return batch[0]

class TokenizedDataset(Dataset):
    def __init__(self, data_list, tokenizer_function, device):
        self.data = []
        for data in tqdm(data_list, total=len(data_list)):
            doc = data['doc']
            positives = data['positives']
            negatives = data['negatives']

            encoding_poses = [tokenize_input(doc, pos_query).to(device) for pos_query in positives]
            encoding_negs = [tokenize_input(doc, neg_query).to(device) for neg_query in negatives]

            all_input_ids = torch.cat([encoding_pos['input_ids'] for encoding_pos in encoding_poses] + [encoding_neg['input_ids'] for encoding_neg in encoding_negs], dim=0)
            all_attention_masks = torch.cat([encoding_pos['attention_mask'] for encoding_pos in encoding_poses] + [encoding_neg['attention_mask'] for encoding_neg in encoding_negs], dim=0)
            
            self.data.append((all_input_ids, all_attention_masks, len(positives), len(negatives)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [16]:
hotpot_train_data_random_loader = DataLoader(TokenizedDataset(hotpot_train_data_random, tokenize_input, 'cuda'), collate_fn=collate_fn)
hotpot_train_data_inbatch_loader = DataLoader(TokenizedDataset(hotpot_train_data_inbatch, tokenize_input, 'cuda'), collate_fn=collate_fn)

hotpot_test_data_random_loader = DataLoader(TokenizedDataset(hotpot_test_data_random, tokenize_input, 'cuda'), collate_fn=collate_fn)
hotpot_test_data_inbatch_loader = DataLoader(TokenizedDataset(hotpot_test_data_inbatch, tokenize_input, 'cuda'), collate_fn=collate_fn)

100%|██████████| 9000/9000 [04:39<00:00, 32.15it/s]
100%|██████████| 9000/9000 [04:38<00:00, 32.34it/s]
100%|██████████| 1200/1200 [01:37<00:00, 12.28it/s]
100%|██████████| 1200/1200 [01:38<00:00, 12.24it/s]


In [17]:
wiki_train_data_random_loader = DataLoader(TokenizedDataset(wiki_train_data_random, tokenize_input, 'cuda'), collate_fn=collate_fn)
wiki_train_data_inbatch_loader = DataLoader(TokenizedDataset(wiki_train_data_inbatch, tokenize_input, 'cuda'), collate_fn=collate_fn)

wiki_test_data_random_loader = DataLoader(TokenizedDataset(wiki_test_data_random, tokenize_input, 'cuda'), collate_fn=collate_fn)
wiki_test_data_inbatch_loader = DataLoader(TokenizedDataset(wiki_test_data_inbatch, tokenize_input, 'cuda'), collate_fn=collate_fn)

100%|██████████| 9000/9000 [05:59<00:00, 25.01it/s]
100%|██████████| 9000/9000 [05:58<00:00, 25.13it/s]
100%|██████████| 1200/1200 [02:47<00:00,  7.17it/s]
100%|██████████| 1200/1200 [02:13<00:00,  8.99it/s]


In [18]:
import torch.nn as nn

class QueryLikelihoodModel(nn.Module):
    def __init__(self, bert_model_name='bert-base-cased', hidden_size=768, ffl_depth=1, dropout_prob=0.3):
        super(QueryLikelihoodModel, self).__init__()

        # Load pre-trained BERT model
        self.bert = BertModel.from_pretrained(bert_model_name)

        # Define a feedforward layer stack (FFL) over the CLS token
        layers = []
        input_size = hidden_size
        for _ in range(ffl_depth):
            layers.append(nn.Linear(input_size, hidden_size))
            layers.append(nn.BatchNorm1d(hidden_size))
            layers.append(nn.ReLU())                     # Non-linearity
            layers.append(nn.Dropout(dropout_prob))      # Dropout
            input_size = hidden_size

        self.ffl = nn.Sequential(*layers)

        # Final layer to output a single score
        self.output_layer = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        # Pass inputs through BERT
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        # Get the [CLS] token representation
        cls_output = outputs.last_hidden_state[:, 0, :]  # CLS token is the first one in the sequence

        # Pass CLS representation through the FFL
        ffl_output = self.ffl(cls_output)

        # Final linear layer to output a score
        score = self.output_layer(ffl_output).view(-1)
        return score

In [19]:
def compute_query_likelihood(model, input_ids, attention_masks):
    return model(input_ids=input_ids, attention_mask=attention_masks)

In [20]:
from torch.nn import CrossEntropyLoss

CELoss = CrossEntropyLoss()

def train_query_likelihood_model(train_data_loader, model, optimizer, epochs=3, device='cuda'):
    model.to(device)
    model.train()
    total_documents = len(train_data_loader)

    for epoch in range(epochs):
        total_loss = 0

        for all_input_ids, all_attention_masks, num_positive, num_negative in tqdm(train_data_loader, desc=f"Epoch {epoch+1}", total=total_documents):

            all_similarities = compute_query_likelihood(model, all_input_ids, all_attention_masks)
            target = torch.cat([torch.ones(num_positive), torch.zeros(num_negative)], dim=0).to(device)

            loss = CELoss(all_similarities, target)
            
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / total_documents
        print(f'Epoch {epoch+1}, Average Loss: {avg_loss}')

In [21]:
def evaluate_query_likelihood_model(eval_data_loader, model, device='cuda'):
    model.to(device)
    model.eval()

    precision_at_1 = 0
    precision_at_10 = 0
    mrr_total = 0
    map_total = 0
    total_documents = len(eval_data_loader)
    
    with torch.no_grad():
        for all_input_ids, all_attention_masks, num_positive, num_negative in tqdm(eval_data_loader, desc="Evaluating", total=total_documents):
            split_input_ids = [all_input_ids[i:i+8] for i in range(0, all_input_ids.size(0), 8)]
            split_attention_masks = [all_attention_masks[i:i+8] for i in range(0, all_attention_masks.size(0), 8)]
            all_similarities = torch.cat([compute_query_likelihood(model, input_ids, attention_masks) for input_ids, attention_masks in zip(split_input_ids, split_attention_masks)], dim=0)

            # Get the ranking of the positive query (rank is based on the similarity score)
            rankings = torch.argsort(torch.argsort(all_similarities, descending=True))
            rank_of_positives = rankings[:num_positive]

            if 0 in rank_of_positives:
                precision_at_1 += 1
            precision_at_10 += len(list(filter(lambda x: x < 10, rank_of_positives))) / min(len(all_similarities), 10)

            # Compute Reciprocal Rank for MRR
            rank_of_first_positive = min(rank_of_positives) + 1
            mrr_total += 1 / rank_of_first_positive

            # Compute Average Precision for MAP
            ap = 0.0
            for i, rank in enumerate(sorted(rank_of_positives)):
                ap += (i + 1) / (rank + 1)  # Precision at each rank where a positive query appears
            map_total += ap / len(rank_of_positives)  # Average precision for this document

    # Average the metrics over all documents
    precision_at_1 /= total_documents
    precision_at_10 /= total_documents
    mrr_total /= total_documents
    map_total /= total_documents
    
    print(f'Precision@1: {precision_at_1}')
    print(f'Precision@10: {precision_at_10}')
    print(f'MRR: {mrr_total}')
    print(f'MAP: {map_total}')

In [22]:
def freeze_bert_layers(model, num_layers_to_freeze=10):
    # Freeze the embedding layer
    for param in model.bert.embeddings.parameters():
        param.requires_grad = False
    
    # Freeze the first `num_layers_to_freeze` encoder layers
    for layer in model.bert.encoder.layer[:num_layers_to_freeze]:
        for param in layer.parameters():
            param.requires_grad = False

In [23]:
# Train with Random/Global Negatives Hotpot
model_hotpot_random = QueryLikelihoodModel(bert_model_name='bert-base-cased', ffl_depth=1, dropout_prob=0.3)
freeze_bert_layers(model_hotpot_random, num_layers_to_freeze=9)
# Optimizer
optimizer_hotpot_random = torch.optim.AdamW(model_hotpot_random.parameters(), lr=5e-5)

train_query_likelihood_model(hotpot_train_data_random_loader, model_hotpot_random, optimizer_hotpot_random)

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

Epoch 1: 100%|██████████| 9000/9000 [30:04<00:00,  4.99it/s]


Epoch 1, Average Loss: 0.09442817672426031


Epoch 2: 100%|██████████| 9000/9000 [30:06<00:00,  4.98it/s]


Epoch 2, Average Loss: 0.04296438797212121


Epoch 3: 100%|██████████| 9000/9000 [30:08<00:00,  4.98it/s]

Epoch 3, Average Loss: 0.028652233966728165





In [24]:
evaluate_query_likelihood_model(hotpot_test_data_random_loader, model_hotpot_random)

Evaluating: 100%|██████████| 1200/1200 [06:00<00:00,  3.33it/s]

Precision@1: 0.9775
Precision@10: 0.12916666666666463
MRR: 0.9868611097335815
MAP: 0.9870397448539734





In [25]:
# Train with Inbatch Negatives Hotpot
model_hotpot_inbatch = QueryLikelihoodModel(bert_model_name='bert-base-cased', ffl_depth=1, dropout_prob=0.3)
freeze_bert_layers(model_hotpot_inbatch, num_layers_to_freeze=9)
# Optimizer
optimizer_hotpot_inbatch = torch.optim.AdamW(model_hotpot_inbatch.parameters(), lr=5e-5)

train_query_likelihood_model(hotpot_train_data_inbatch_loader, model_hotpot_inbatch, optimizer_hotpot_inbatch)

Epoch 1: 100%|██████████| 9000/9000 [30:09<00:00,  4.97it/s]


Epoch 1, Average Loss: 0.1092655900080757


Epoch 2: 100%|██████████| 9000/9000 [30:08<00:00,  4.98it/s]


Epoch 2, Average Loss: 0.04616112476184431


Epoch 3: 100%|██████████| 9000/9000 [30:08<00:00,  4.98it/s]

Epoch 3, Average Loss: 0.029670300069398765





In [26]:
evaluate_query_likelihood_model(hotpot_test_data_inbatch_loader, model_hotpot_inbatch)

Evaluating: 100%|██████████| 1200/1200 [05:48<00:00,  3.44it/s]

Precision@1: 0.9775
Precision@10: 0.12966666666666427
MRR: 0.9875277876853943
MAP: 0.9872578978538513





In [27]:
# Train with Random/Global Negatives WikiNQ
model_wiki_random = QueryLikelihoodModel(bert_model_name='bert-base-cased', ffl_depth=1, dropout_prob=0.3)
freeze_bert_layers(model_wiki_random, num_layers_to_freeze=9)
# Optimizer
optimizer_wiki_random = torch.optim.AdamW(model_wiki_random.parameters(), lr=5e-5)

train_query_likelihood_model(wiki_train_data_random_loader, model_wiki_random, optimizer_wiki_random)

Epoch 1: 100%|██████████| 9000/9000 [30:09<00:00,  4.97it/s]


Epoch 1, Average Loss: 0.33243923589527336


Epoch 2: 100%|██████████| 9000/9000 [30:08<00:00,  4.98it/s]


Epoch 2, Average Loss: 0.177686263608698


Epoch 3: 100%|██████████| 9000/9000 [30:06<00:00,  4.98it/s]

Epoch 3, Average Loss: 0.11829557911606892





In [28]:
evaluate_query_likelihood_model(wiki_test_data_random_loader, model_wiki_random)

Evaluating: 100%|██████████| 1200/1200 [08:16<00:00,  2.42it/s]

Precision@1: 0.9016666666666666
Precision@10: 0.17033333333333084
MRR: 0.9412730932235718
MAP: 0.9366087913513184





In [29]:
# Train with Inbatch Negatives WikiNQ
model_wiki_inbatch = QueryLikelihoodModel(bert_model_name='bert-base-cased', ffl_depth=1, dropout_prob=0.3)
freeze_bert_layers(model_wiki_inbatch, num_layers_to_freeze=9)
# Optimizer
optimizer_wiki_inbatch = torch.optim.AdamW(model_wiki_inbatch.parameters(), lr=5e-5)

train_query_likelihood_model(wiki_train_data_inbatch_loader, model_wiki_inbatch, optimizer_wiki_inbatch)

Epoch 1: 100%|██████████| 9000/9000 [31:15<00:00,  4.80it/s]


Epoch 1, Average Loss: 0.3586725292428075


Epoch 2: 100%|██████████| 9000/9000 [32:52<00:00,  4.56it/s]


Epoch 2, Average Loss: 0.17495657060183267


Epoch 3: 100%|██████████| 9000/9000 [32:52<00:00,  4.56it/s]

Epoch 3, Average Loss: 0.11952758331382422





In [30]:
evaluate_query_likelihood_model(wiki_test_data_inbatch_loader, model_wiki_inbatch)

Evaluating: 100%|██████████| 1200/1200 [07:35<00:00,  2.64it/s]

Precision@1: 0.8958333333333334
Precision@10: 0.16641666666666413
MRR: 0.934960126876831
MAP: 0.9318392872810364



