In [1]:
! pip install transformers datasets



In [2]:
from datasets import load_dataset

# Load HotpotQA dataset
hotpot_dataset = load_dataset("hotpot_qa", "distractor", trust_remote_code=True)

# Display HotpotQA dataset structure
print(hotpot_dataset)

Downloading builder script:   0%|          | 0.00/6.42k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/9.19k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/566M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/46.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/90447 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/7405 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'question', 'answer', 'type', 'level', 'supporting_facts', 'context'],
        num_rows: 90447
    })
    validation: Dataset({
        features: ['id', 'question', 'answer', 'type', 'level', 'supporting_facts', 'context'],
        num_rows: 7405
    })
})


In [3]:
# Load WikiNQ dataset
wiki_dataset = load_dataset("Tevatron/wikipedia-nq")

# Display WikiNQ dataset structure
print(wiki_dataset)

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/97.6M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/242M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/241k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/58622 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/6489 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3610 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['query_id', 'query', 'answers', 'positive_passages', 'negative_passages'],
        num_rows: 58622
    })
    dev: Dataset({
        features: ['query_id', 'query', 'answers', 'positive_passages', 'negative_passages'],
        num_rows: 6489
    })
    test: Dataset({
        features: ['query_id', 'query', 'answers', 'positive_passages', 'negative_passages'],
        num_rows: 3610
    })
})


In [4]:
hotpot_train_set = hotpot_dataset['train']
hotpot_test_set = hotpot_dataset['validation']

wiki_train_set = wiki_dataset['train']
wiki_test_set = wiki_dataset['dev']

In [5]:
small_hotpot_train_dataset = hotpot_train_set.shuffle(seed=42).select(range(9000))
small_hotpot_test_dataset = hotpot_test_set.shuffle(seed=42)

small_wiki_train_dataset = wiki_train_set.shuffle(seed=43).select(range(9000))
small_wiki_test_dataset = wiki_test_set.shuffle(seed=43)

In [6]:
import random

def get_random_negatives(golden_doc_titles, corpus, num_negatives):
    """Randomly sample negatives excluding golden_docs."""
    available_doc_titles = [doc_title for doc_title in corpus if doc_title not in golden_doc_titles]
    sampled_titles = random.sample(available_doc_titles, min(len(available_doc_titles), num_negatives))
    negatives = [''.join(corpus[doc_title]) for doc_title in sampled_titles]
    return negatives

In [7]:
def create_corpus_hotpot(data, sample_method, batch_size):
    """Preprocess data to create datapoints with negatives."""
    if sample_method == 'global':
        global_corpus = {}
        for entry in data:
            for title, sentences in zip(entry['context']['title'], entry['context']['sentences']):
                if title not in global_corpus:
                    global_corpus[title] = ''.join(sentences)
        return global_corpus
    elif sample_method == 'inbatch':
        data_list = list(data)
        batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
        batch_corpuses = []
        for batch in batches:
            batch_corpus = {}
            for entry in batch:
                for title, sentences in zip(entry['context']['title'], entry['context']['sentences']):
                    if title not in batch_corpus:
                        batch_corpus[title] = ''.join(sentences)
            batch_corpuses.append(batch_corpus)
        return batch_corpuses
    else:
        assert(0)


In [8]:
from tqdm import tqdm

# Preprocessing function Hotpot
def preprocess_data_train_hotpot(data, num_negatives, sample_method='global', batch_size=32):
    """Preprocess data to create datapoints with negatives."""
    corpuses = create_corpus_hotpot(data, sample_method, batch_size)

    processed_data = []

    for idx, entry in tqdm(enumerate(data), total=len(data)):
        query = entry['question']
        golden_doc_titles = entry['supporting_facts']['title']
        context = entry['context']
        if sample_method == 'global':
            corpus = corpuses
        else:
            corpus = corpuses[idx // batch_size]

        for golden_doc_title in golden_doc_titles:
            golden_doc = ''.join(context['sentences'][context['title'].index(golden_doc_title)])
            negatives = get_random_negatives(golden_doc_titles, corpus, num_negatives)

            datapoint = {
                'query': query,
                'positives': [golden_doc],
                'negatives': negatives,
            }
            processed_data.append(datapoint)

    return processed_data

In [9]:
# Preprocessing function Hotpot
def preprocess_data_test_hotpot(data, num_negatives, sample_method='global', batch_size=32):
    """Preprocess data to create datapoints with negatives."""
    corpuses = create_corpus_hotpot(data, sample_method, batch_size)

    processed_data = []

    for idx, entry in tqdm(enumerate(data), total=len(data)):
        query = entry['question']
        golden_doc_titles = entry['supporting_facts']['title']
        context = entry['context']
        if sample_method == 'global':
            corpus = corpuses
        else:
            corpus = corpuses[idx // batch_size]

        positives = [''.join(context['sentences'][context['title'].index(golden_doc_title)]) for golden_doc_title in golden_doc_titles]
        negatives = get_random_negatives(golden_doc_titles, corpus, num_negatives * len(positives))

        datapoint = {
            'query': query,
            'positives': positives,
            'negatives': negatives,
        }
        processed_data.append(datapoint)

    return processed_data

In [10]:
# Preprocessing function WikiNQ
def preprocess_data_train_wiki(data, num_negatives):
    processed_data = []

    for idx, entry in tqdm(enumerate(data), total=len(data)):
        query = entry['query']
        golden_docs = entry['positive_passages']
        negative_docs = entry['negative_passages']

        for golden_doc in golden_docs:
            negatives = random.sample(negative_docs, min(len(negative_docs), num_negatives))
            negatives = [neg['text'] for neg in negatives]

            datapoint = {
                'query': query,
                'positives': [golden_doc['text']],
                'negatives': negatives,
            }
            processed_data.append(datapoint)

    return processed_data

In [11]:
# Preprocessing function WikiNQ
def preprocess_data_test_wiki(data, num_negatives):
    processed_data = []

    for idx, entry in tqdm(enumerate(data), total=len(data)):
        query = entry['query']
        golden_docs = entry['positive_passages']
        negative_docs = entry['negative_passages']

        positives = [golden_doc['text'] for golden_doc in golden_docs]
        negatives = random.sample(negative_docs, min(len(negative_docs), num_negatives * len(positives)))
        negatives = [neg['text'] for neg in negatives]

        datapoint = {
            'query': query,
            'positives': positives,
            'negatives': negatives,
        }
        processed_data.append(datapoint)

    return processed_data

In [12]:
hotpot_train_data_random = preprocess_data_train_hotpot(small_hotpot_train_dataset, 7, 'global')[:9000]
hotpot_train_data_inbatch = preprocess_data_train_hotpot(small_hotpot_train_dataset, 7, 'inbatch')[:9000]

hotpot_test_data_random = preprocess_data_test_hotpot(small_hotpot_test_dataset, 15, 'global')[:1200]
hotpot_test_data_inbatch = preprocess_data_test_hotpot(small_hotpot_test_dataset, 15, 'inbatch')[:1200]

100%|██████████| 9000/9000 [03:00<00:00, 49.98it/s]
100%|██████████| 9000/9000 [00:06<00:00, 1330.57it/s]
100%|██████████| 7405/7405 [00:59<00:00, 123.67it/s]
100%|██████████| 7405/7405 [00:08<00:00, 888.04it/s]


In [13]:
wiki_train_data = preprocess_data_train_wiki(small_wiki_train_dataset, 7)[:9000]
wiki_test_data = preprocess_data_test_wiki(small_wiki_test_dataset, 15)[:1200]

100%|██████████| 9000/9000 [00:13<00:00, 652.42it/s]
100%|██████████| 6489/6489 [00:09<00:00, 692.35it/s]


In [14]:
import warnings
import logging
warnings.filterwarnings('ignore')
logging.getLogger("transformers").setLevel(logging.ERROR)

In [15]:
import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

def tokenize_input(query, document):
    return tokenizer.encode_plus(
        query,
        document,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors="pt"
    )

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [16]:
from torch.utils.data import Dataset, DataLoader

def collate_fn(batch):
    return batch[0]

class TokenizedDataset(Dataset):
    def __init__(self, data_list, tokenizer_function, device):
        self.data = []
        for data in tqdm(data_list, total=len(data_list)):
            query = data['query']
            positives = data['positives']
            negatives = data['negatives']

            encoding_poses = [tokenizer_function(query, pos).to(device) for pos in positives]
            encoding_negs = [tokenizer_function(query, neg).to(device) for neg in negatives]
            
            all_input_ids = torch.cat([encoding_pos['input_ids'] for encoding_pos in encoding_poses] + [encoding_neg['input_ids'] for encoding_neg in encoding_negs], dim=0)
            all_attention_masks = torch.cat([encoding_pos['attention_mask'] for encoding_pos in encoding_poses] + [encoding_neg['attention_mask'] for encoding_neg in encoding_negs], dim=0)
            
            self.data.append((all_input_ids, all_attention_masks, len(positives), len(negatives)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [17]:
hotpot_train_data_random_loader = DataLoader(TokenizedDataset(hotpot_train_data_random, tokenize_input, 'cuda'), collate_fn=collate_fn)
hotpot_train_data_inbatch_loader = DataLoader(TokenizedDataset(hotpot_train_data_inbatch, tokenize_input, 'cuda'), collate_fn=collate_fn)

hotpot_test_data_random_loader = DataLoader(TokenizedDataset(hotpot_test_data_random, tokenize_input, 'cuda'), collate_fn=collate_fn)
hotpot_test_data_inbatch_loader = DataLoader(TokenizedDataset(hotpot_test_data_inbatch, tokenize_input, 'cuda'), collate_fn=collate_fn)

100%|██████████| 9000/9000 [04:53<00:00, 30.63it/s]
100%|██████████| 9000/9000 [04:58<00:00, 30.19it/s]
100%|██████████| 1200/1200 [03:09<00:00,  6.34it/s]
100%|██████████| 1200/1200 [03:12<00:00,  6.23it/s]


In [18]:
wiki_train_data_loader = DataLoader(TokenizedDataset(wiki_train_data, tokenize_input, 'cuda'), collate_fn=collate_fn)

wiki_test_data_loader = DataLoader(TokenizedDataset(wiki_test_data, tokenize_input, 'cuda'), collate_fn=collate_fn)

100%|██████████| 9000/9000 [05:10<00:00, 29.01it/s]
100%|██████████| 1200/1200 [05:19<00:00,  3.76it/s]


In [19]:
import torch.nn as nn

class DocumentLikelihoodModel(nn.Module):
    def __init__(self, bert_model_name='bert-base-cased', hidden_size=768, ffl_depth=1, dropout_prob=0.3):
        super(DocumentLikelihoodModel, self).__init__()

        # Load pre-trained BERT model
        self.bert = BertModel.from_pretrained(bert_model_name)

        # Define a feedforward layer stack (FFL) over the CLS token
        layers = []
        input_size = hidden_size
        for _ in range(ffl_depth):
            layers.append(nn.Linear(input_size, hidden_size))
            layers.append(nn.BatchNorm1d(hidden_size))
            layers.append(nn.ReLU())                     # Non-linearity
            layers.append(nn.Dropout(dropout_prob))      # Dropout
            input_size = hidden_size

        self.ffl = nn.Sequential(*layers)

        # Final layer to output a single score
        self.output_layer = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        # Pass inputs through BERT
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        # Get the [CLS] token representation
        cls_output = outputs.last_hidden_state[:, 0, :]  # CLS token is the first one in the sequence

        # Pass CLS representation through the FFL
        ffl_output = self.ffl(cls_output)

        # Final linear layer to output a score
        score = self.output_layer(ffl_output).view(-1)
        return score

In [20]:
from torch.nn import functional as F

def compute_document_likelihood(model, input_ids, attention_masks):
    # Forward pass to get the logits
    return model(input_ids=input_ids, attention_mask=attention_masks)

In [21]:
from torch.nn import CrossEntropyLoss

CELoss = CrossEntropyLoss()

def train_document_likelihood_model(train_data_loader, model, optimizer, epochs=3, device='cuda'):
    model.to(device)
    model.train()
    total_queries = len(train_data_loader)

    for epoch in range(epochs):
        total_loss = 0

        for all_input_ids, all_attention_masks, num_positive, num_negative in tqdm(train_data_loader, desc=f"Epoch {epoch+1}", total=total_queries):
            
            all_similarities = compute_document_likelihood(model, all_input_ids, all_attention_masks)
            target = torch.cat([torch.ones(num_positive), torch.zeros(num_negative)], dim=0).to(device)
            
            loss = CELoss(all_similarities, target)
 
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / total_queries
        print(f'Epoch {epoch+1}, Average Loss: {avg_loss}')

In [22]:
def evaluate_document_likelihood_model(eval_data_loader, model, device='cuda'):
    model.to(device)
    model.eval()

    precision_at_1 = 0
    precision_at_10 = 0
    mrr_total = 0
    map_total = 0
    total_queries = len(eval_data_loader)
    
    with torch.no_grad():
        for all_input_ids, all_attention_masks, num_positive, num_negative in tqdm(eval_data_loader, desc="Evaluating", total=total_queries):
            split_input_ids = [all_input_ids[i:i+8] for i in range(0, all_input_ids.size(0), 8)]
            split_attention_masks = [all_attention_masks[i:i+8] for i in range(0, all_attention_masks.size(0), 8)]
            all_similarities = torch.cat([compute_document_likelihood(model, input_ids, attention_masks) for input_ids, attention_masks in zip(split_input_ids, split_attention_masks)], dim=0)

            # Get the ranking of the positive document (rank is based on the similarity score)
            rankings = torch.argsort(torch.argsort(all_similarities, descending=True))
            rank_of_positives = rankings[:num_positive]

            if 0 in rank_of_positives:
                precision_at_1 += 1
            precision_at_10 += len(list(filter(lambda x: x < 10, rank_of_positives))) / min(len(all_similarities), 10)

            # Compute Reciprocal Rank for MRR
            rank_of_first_positive = min(rank_of_positives) + 1
            mrr_total += 1 / rank_of_first_positive

            # Compute Average Precision for MAP
            ap = 0.0
            for i, rank in enumerate(sorted(rank_of_positives)):
                ap += (i + 1) / (rank + 1)  # Precision at each rank where a positive document appears
            map_total += ap / len(rank_of_positives)  # Average precision for this query

    # Average the metrics over all queries
    precision_at_1 /= total_queries
    precision_at_10 /= total_queries
    mrr_total /= total_queries
    map_total /= total_queries
    
    print(f'Precision@1: {precision_at_1}')
    print(f'Precision@10: {precision_at_10}')
    print(f'MRR: {mrr_total}')
    print(f'MAP: {map_total}')

In [23]:
def freeze_bert_layers(model, num_layers_to_freeze=10):
    # Freeze the embedding layer
    for param in model.bert.embeddings.parameters():
        param.requires_grad = False
    
    # Freeze the first `num_layers_to_freeze` encoder layers
    for layer in model.bert.encoder.layer[:num_layers_to_freeze]:
        for param in layer.parameters():
            param.requires_grad = False

In [24]:
# Train with Random/Global Negatives Hotpot
model_hotpot_random = DocumentLikelihoodModel(bert_model_name='bert-base-cased', ffl_depth=1, dropout_prob=0.3)
freeze_bert_layers(model_hotpot_random, num_layers_to_freeze=9)
# Optimizer
optimizer_hotpot_random = torch.optim.AdamW(model_hotpot_random.parameters(), lr=5e-5)

train_document_likelihood_model(hotpot_train_data_random_loader, model_hotpot_random, optimizer_hotpot_random)

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

Epoch 1: 100%|██████████| 9000/9000 [29:58<00:00,  5.01it/s]


Epoch 1, Average Loss: 0.15396682263218306


Epoch 2: 100%|██████████| 9000/9000 [29:56<00:00,  5.01it/s]


Epoch 2, Average Loss: 0.05573871461482793


Epoch 3: 100%|██████████| 9000/9000 [30:00<00:00,  5.00it/s]

Epoch 3, Average Loss: 0.0392910169060394





In [25]:
evaluate_document_likelihood_model(hotpot_test_data_random_loader, model_hotpot_random)

Evaluating: 100%|██████████| 1200/1200 [11:11<00:00,  1.79it/s]

Precision@1: 0.9975
Precision@10: 0.24158333333333076
MRR: 0.9983974695205688
MAP: 0.9856060147285461





In [26]:
# Train with Inbatch Negatives Hotpot
model_hotpot_inbatch = DocumentLikelihoodModel(bert_model_name='bert-base-cased', ffl_depth=1, dropout_prob=0.3)
freeze_bert_layers(model_hotpot_inbatch, num_layers_to_freeze=9)
# Optimizer
optimizer_hotpot_inbatch = torch.optim.AdamW(model_hotpot_inbatch.parameters(), lr=5e-5)

train_document_likelihood_model(hotpot_train_data_inbatch_loader, model_hotpot_inbatch, optimizer_hotpot_inbatch)

Epoch 1: 100%|██████████| 9000/9000 [30:00<00:00,  5.00it/s]


Epoch 1, Average Loss: 0.209549385939978


Epoch 2: 100%|██████████| 9000/9000 [30:00<00:00,  5.00it/s]


Epoch 2, Average Loss: 0.10758508360001139


Epoch 3: 100%|██████████| 9000/9000 [30:01<00:00,  5.00it/s]

Epoch 3, Average Loss: 0.07743959869013275





In [27]:
evaluate_document_likelihood_model(hotpot_test_data_inbatch_loader, model_hotpot_inbatch)

Evaluating: 100%|██████████| 1200/1200 [11:11<00:00,  1.79it/s]

Precision@1: 0.9708333333333333
Precision@10: 0.241416666666664
MRR: 0.9851388931274414
MAP: 0.9590916037559509





In [28]:
# Train with WikiNQ
model_wiki = DocumentLikelihoodModel(bert_model_name='bert-base-cased', ffl_depth=1, dropout_prob=0.3)
freeze_bert_layers(model_wiki, num_layers_to_freeze=9)
# Optimizer
optimizer_wiki = torch.optim.AdamW(model_wiki.parameters(), lr=5e-5)

train_document_likelihood_model(wiki_train_data_loader, model_wiki, optimizer_wiki)

Epoch 1: 100%|██████████| 9000/9000 [30:04<00:00,  4.99it/s]


Epoch 1, Average Loss: 1.7166420945514822


Epoch 2: 100%|██████████| 9000/9000 [30:02<00:00,  4.99it/s]


Epoch 2, Average Loss: 1.3940939492808087


Epoch 3: 100%|██████████| 9000/9000 [30:01<00:00,  5.00it/s]

Epoch 3, Average Loss: 0.9949655522181031





In [29]:
evaluate_document_likelihood_model(wiki_test_data_loader, model_wiki)

Evaluating: 100%|██████████| 1200/1200 [18:17<00:00,  1.09it/s]

Precision@1: 0.2891666666666667
Precision@10: 0.2232499999999985
MRR: 0.45968538522720337
MAP: 0.38898196816444397



