In [1]:
! pip install transformers datasets



In [2]:
from datasets import load_dataset

# Load HotpotQA dataset
hotpot_dataset = load_dataset("hotpot_qa", "distractor", trust_remote_code=True)

# Display HotpotQA dataset structure
print(hotpot_dataset)

Downloading builder script:   0%|          | 0.00/6.42k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/9.19k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/566M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/46.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/90447 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/7405 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'question', 'answer', 'type', 'level', 'supporting_facts', 'context'],
        num_rows: 90447
    })
    validation: Dataset({
        features: ['id', 'question', 'answer', 'type', 'level', 'supporting_facts', 'context'],
        num_rows: 7405
    })
})


In [3]:
# Load WikiNQ dataset
wiki_dataset = load_dataset("Tevatron/wikipedia-nq")

# Display WikiNQ dataset structure
print(wiki_dataset)

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/298M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/97.6M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/242M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/241k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/58622 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/6489 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3610 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['query_id', 'query', 'answers', 'positive_passages', 'negative_passages'],
        num_rows: 58622
    })
    dev: Dataset({
        features: ['query_id', 'query', 'answers', 'positive_passages', 'negative_passages'],
        num_rows: 6489
    })
    test: Dataset({
        features: ['query_id', 'query', 'answers', 'positive_passages', 'negative_passages'],
        num_rows: 3610
    })
})


In [4]:
hotpot_train_set = hotpot_dataset['train']
hotpot_test_set = hotpot_dataset['validation']

wiki_train_set = wiki_dataset['train']
wiki_test_set = wiki_dataset['dev']

In [5]:
small_hotpot_train_dataset = hotpot_train_set.shuffle(seed=42).select(range(6500))
small_hotpot_test_dataset = hotpot_test_set.shuffle(seed=42)

small_wiki_train_dataset = wiki_train_set.shuffle(seed=43).select(range(6500))
small_wiki_test_dataset = wiki_test_set.shuffle(seed=43)

In [6]:
import random

def get_random_negatives(positive_queries, corpus, num_negatives):
    """Randomly sample negatives excluding golden_docs."""
    available_queries = [query for query in corpus if query not in positive_queries]
    return random.sample(available_queries, min(len(available_queries), num_negatives))

In [7]:
def create_document_data(data, dataset):
    document_data = {}
    
    if dataset == 'hotpot':
        for entry in data:
            query = entry['question']
            golden_docs = entry['supporting_facts']
            for golden_doc_title in golden_docs['title']:
                if golden_doc_title not in document_data:
                    document_data[golden_doc_title] = {}
                    golden_doc_sentences = entry['context']['sentences'][entry['context']['title'].index(golden_doc_title)]
                    document_data[golden_doc_title]['doc'] = ''.join(golden_doc_sentences)
                    document_data[golden_doc_title]['positive_queries'] = []
                document_data[golden_doc_title]['positive_queries'].append(query)
    elif dataset == 'wiki':
        for entry in data:
            query = entry['query']
            golden_docs = entry['positive_passages']
            for golden_doc in golden_docs:
                if golden_doc['title'] not in document_data:
                    document_data[golden_doc['title']] = {}
                    document_data[golden_doc['title']]['doc'] = golden_doc['text']
                    document_data[golden_doc['title']]['positive_queries'] = []
                document_data[golden_doc['title']]['positive_queries'].append(query)
    else:
        assert(0)
    
    return document_data
    

In [8]:
def create_corpus(data, document_data, document_keys, dataset, sample_method, batch_size):
    if sample_method == 'global':
        if dataset == 'hotpot':
            global_corpus = [entry['question'] for entry in data]
        else:
            global_corpus = [entry['query'] for entry in data]
        return global_corpus
    elif sample_method == 'inbatch':
        batches = [document_keys[i:i + batch_size] for i in range(0, len(document_keys), batch_size)]
        batch_corpuses = []
        for batch in batches:
            batch_corpus = set()
            for doc_title in batch:
                batch_corpus.update(document_data[doc_title]['positive_queries'])
            batch_corpuses.append(list(batch_corpus))
        return batch_corpuses
    else:
        assert(0)

In [9]:
from tqdm import tqdm

# Preprocessing function
def preprocess_data_train(data, num_negatives, dataset='hotpot', sample_method='global', batch_size=32):
    """Preprocess data to create datapoints with negatives."""
    document_data = create_document_data(data, dataset)
    
    document_keys = list(document_data.keys())
    random.shuffle(document_keys)
    corpuses = create_corpus(data, document_data, document_keys, dataset, sample_method, batch_size)
  
    processed_data = []
    
    for idx, document_title in tqdm(enumerate(document_keys), total=len(document_keys)):
        document_info = document_data[document_title]
        doc = document_info['doc']
        positive_queries = document_info['positive_queries']
        
        if sample_method == 'global':
            corpus = corpuses
        else:
            corpus = corpuses[idx // batch_size]
        
        for positive_query in positive_queries:
            negatives = get_random_negatives(positive_queries, corpus, num_negatives)
            
            datapoint = {
                'doc': doc,
                'positives': [positive_query],
                'negatives': negatives,
            }
            processed_data.append(datapoint)

    return processed_data

In [10]:
def get_random_negatives_test(golden_doc_titles, corpus, num_negatives):
    """Randomly sample negatives excluding golden_docs."""
    available_doc_titles = [doc_title for doc_title in corpus if doc_title not in golden_doc_titles]
    sampled_titles = random.sample(available_doc_titles, min(len(available_doc_titles), num_negatives))
    negatives = [''.join(corpus[doc_title]) for doc_title in sampled_titles]
    return negatives

In [11]:
def create_corpus_hotpot_test(data, sample_method, batch_size):
    """Preprocess data to create datapoints with negatives."""
    if sample_method == 'global':
        global_corpus = {}
        for entry in data:
            for title, sentences in zip(entry['context']['title'], entry['context']['sentences']):
                if title not in global_corpus:
                    global_corpus[title] = ''.join(sentences)
        return global_corpus
    elif sample_method == 'inbatch':
        data_list = list(data)
        batches = [data_list[i:i + batch_size] for i in range(0, len(data_list), batch_size)]
        batch_corpuses = []
        for batch in batches:
            batch_corpus = {}
            for entry in batch:
                for title, sentences in zip(entry['context']['title'], entry['context']['sentences']):
                    if title not in batch_corpus:
                        batch_corpus[title] = ''.join(sentences)
            batch_corpuses.append(batch_corpus)
        return batch_corpuses
    else:
        assert(0)


In [12]:
# Preprocessing function Hotpot
def preprocess_data_test_hotpot(data, num_negatives, sample_method='global', batch_size=32):
    """Preprocess data to create datapoints with negatives."""
    corpuses = create_corpus_hotpot_test(data, sample_method, batch_size)

    processed_data = []

    for idx, entry in tqdm(enumerate(data), total=len(data)):
        query = entry['question']
        golden_doc_titles = entry['supporting_facts']['title']
        context = entry['context']
        if sample_method == 'global':
            corpus = corpuses
        else:
            corpus = corpuses[idx // batch_size]

        positives = [''.join(context['sentences'][context['title'].index(golden_doc_title)]) for golden_doc_title in golden_doc_titles]
        negatives = get_random_negatives_test(golden_doc_titles, corpus, num_negatives * len(positives))

        datapoint = {
            'query': query,
            'positives': positives,
            'negatives': negatives,
        }
        processed_data.append(datapoint)

    return processed_data

In [13]:
# Preprocessing function WikiNQ
def preprocess_data_test_wiki(data, num_negatives):
    processed_data = []

    for idx, entry in tqdm(enumerate(data), total=len(data)):
        query = entry['query']
        golden_docs = entry['positive_passages']
        negative_docs = entry['negative_passages']

        positives = [golden_doc['text'] for golden_doc in golden_docs]
        negatives = random.sample(negative_docs, min(len(negative_docs), num_negatives * len(positives)))
        negatives = [neg['text'] for neg in negatives]

        datapoint = {
            'query': query,
            'positives': positives,
            'negatives': negatives,
        }
        processed_data.append(datapoint)

    return processed_data

In [14]:
hotpot_train_data_random = preprocess_data_train(small_hotpot_train_dataset, 7, 'hotpot', 'global')[:6500]
hotpot_train_data_inbatch = preprocess_data_train(small_hotpot_train_dataset, 7, 'hotpot', 'inbatch')[:6500]

hotpot_test_data_random = preprocess_data_test_hotpot(small_hotpot_test_dataset, 15, 'global')[:850]
hotpot_test_data_inbatch = preprocess_data_test_hotpot(small_hotpot_test_dataset, 15, 'inbatch')[:850]

100%|██████████| 12247/12247 [00:10<00:00, 1178.36it/s]
100%|██████████| 12247/12247 [00:00<00:00, 42406.61it/s]
100%|██████████| 7405/7405 [01:08<00:00, 107.48it/s]
100%|██████████| 7405/7405 [00:08<00:00, 827.06it/s]


In [15]:
wiki_train_data_random = preprocess_data_train(small_wiki_train_dataset, 7, 'wiki', 'global')[:6500]
wiki_train_data_inbatch = preprocess_data_train(small_wiki_train_dataset, 7, 'wiki', 'inbatch')[:6500]

wiki_test_data = preprocess_data_test_wiki(small_wiki_test_dataset, 15)[:850]

100%|██████████| 30162/30162 [00:40<00:00, 737.91it/s]
100%|██████████| 30162/30162 [00:01<00:00, 25011.92it/s]
100%|██████████| 6489/6489 [00:10<00:00, 632.96it/s]


In [16]:
import warnings
import logging
warnings.filterwarnings('ignore')
logging.getLogger("transformers").setLevel(logging.ERROR)

In [17]:
import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

def tokenize_input(document, query):
    return tokenizer.encode_plus(
        document,
        query,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors="pt"
    )

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [18]:
from torch.utils.data import Dataset, DataLoader

def collate_fn(batch):
    return batch[0]

class TokenizedDataset(Dataset):
    def __init__(self, data_list, tokenizer_function, device):
        self.data = []
        for data in tqdm(data_list, total=len(data_list)):
            doc = data['doc']
            positives = data['positives']
            negatives = data['negatives']

            encoding_poses = [tokenize_input(doc, pos_query).to(device) for pos_query in positives]
            encoding_negs = [tokenize_input(doc, neg_query).to(device) for neg_query in negatives]

            all_input_ids = torch.cat([encoding_pos['input_ids'] for encoding_pos in encoding_poses] + [encoding_neg['input_ids'] for encoding_neg in encoding_negs], dim=0)
            all_attention_masks = torch.cat([encoding_pos['attention_mask'] for encoding_pos in encoding_poses] + [encoding_neg['attention_mask'] for encoding_neg in encoding_negs], dim=0)
            
            self.data.append((all_input_ids, all_attention_masks, len(positives), len(negatives)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    

class TokenizedDatasetTest(Dataset):
    def __init__(self, data_list, tokenizer_function, device):
        self.data = []
        for data in tqdm(data_list, total=len(data_list)):
            query = data['query']
            positives = data['positives']
            negatives = data['negatives']

            encoding_poses = [tokenize_input(pos_doc, query).to(device) for pos_doc in positives]
            encoding_negs = [tokenize_input(neg_doc, query).to(device) for neg_doc in negatives]

            all_input_ids = torch.cat([encoding_pos['input_ids'] for encoding_pos in encoding_poses] + [encoding_neg['input_ids'] for encoding_neg in encoding_negs], dim=0)
            all_attention_masks = torch.cat([encoding_pos['attention_mask'] for encoding_pos in encoding_poses] + [encoding_neg['attention_mask'] for encoding_neg in encoding_negs], dim=0)
            
            self.data.append((all_input_ids, all_attention_masks, len(positives), len(negatives)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [19]:
hotpot_train_data_random_loader = DataLoader(TokenizedDataset(hotpot_train_data_random, tokenize_input, 'cuda'), collate_fn=collate_fn)
hotpot_train_data_inbatch_loader = DataLoader(TokenizedDataset(hotpot_train_data_inbatch, tokenize_input, 'cuda'), collate_fn=collate_fn)

hotpot_test_data_random_loader = DataLoader(TokenizedDatasetTest(hotpot_test_data_random, tokenize_input, 'cuda'), collate_fn=collate_fn)
hotpot_test_data_inbatch_loader = DataLoader(TokenizedDatasetTest(hotpot_test_data_inbatch, tokenize_input, 'cuda'), collate_fn=collate_fn)

100%|██████████| 6500/6500 [03:33<00:00, 30.44it/s]
 41%|████      | 2673/6500 [01:28<02:06, 30.36it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
 75%|███████▍  | 4874/6500 [02:39<00:50, 32.43it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list wi

In [20]:
wiki_train_data_random_loader = DataLoader(TokenizedDataset(wiki_train_data_random, tokenize_input, 'cuda'), collate_fn=collate_fn)
wiki_train_data_inbatch_loader = DataLoader(TokenizedDataset(wiki_train_data_inbatch, tokenize_input, 'cuda'), collate_fn=collate_fn)

wiki_test_data_loader = DataLoader(TokenizedDatasetTest(wiki_test_data, tokenize_input, 'cuda'), collate_fn=collate_fn)

100%|██████████| 6500/6500 [04:13<00:00, 25.67it/s]
100%|██████████| 6500/6500 [04:13<00:00, 25.64it/s]
  2%|▏         | 14/850 [00:04<04:17,  3.25it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
 71%|███████   | 603/850 [03:04<01:32,  2.68it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
100%|██████████| 850/850 [04:14<00:00,  3.34it/s]


In [21]:
import torch.nn as nn

class QueryLikelihoodModel(nn.Module):
    def __init__(self, bert_model_name='bert-base-cased', hidden_size=768, ffl_depth=1, dropout_prob=0.3):
        super(QueryLikelihoodModel, self).__init__()

        # Load pre-trained BERT model
        self.bert = BertModel.from_pretrained(bert_model_name)

        # Define a feedforward layer stack (FFL) over the CLS token
        layers = []
        input_size = hidden_size
        for _ in range(ffl_depth):
            layers.append(nn.Linear(input_size, hidden_size))
            layers.append(nn.BatchNorm1d(hidden_size))
            layers.append(nn.ReLU())                     # Non-linearity
            layers.append(nn.Dropout(dropout_prob))      # Dropout
            input_size = hidden_size

        self.ffl = nn.Sequential(*layers)

        # Final layer to output a single score
        self.output_layer = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        # Pass inputs through BERT
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        # Get the [CLS] token representation
        cls_output = outputs.last_hidden_state[:, 0, :]  # CLS token is the first one in the sequence

        # Pass CLS representation through the FFL
        ffl_output = self.ffl(cls_output)

        # Final linear layer to output a score
        score = self.output_layer(ffl_output).view(-1)
        return score

In [22]:
def compute_query_likelihood(model, input_ids, attention_masks):
    return model(input_ids=input_ids, attention_mask=attention_masks)

In [23]:
from torch.nn import CrossEntropyLoss

CELoss = CrossEntropyLoss()

def train_query_likelihood_model(train_data_loader, model, optimizer, epochs=3, device='cuda'):
    model.to(device)
    model.train()
    total_documents = len(train_data_loader)

    for epoch in range(epochs):
        total_loss = 0

        for all_input_ids, all_attention_masks, num_positive, num_negative in tqdm(train_data_loader, desc=f"Epoch {epoch+1}", total=total_documents):

            all_similarities = compute_query_likelihood(model, all_input_ids, all_attention_masks)
            target = torch.cat([torch.ones(num_positive), torch.zeros(num_negative)], dim=0).to(device)

            loss = CELoss(all_similarities, target)
            
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / total_documents
        print(f'Epoch {epoch+1}, Average Loss: {avg_loss}')

In [24]:
def evaluate_query_likelihood_model(eval_data_loader, model, device='cuda'):
    model.to(device)
    model.eval()

    precision_at_1 = 0
    precision_at_10 = 0
    mrr_total = 0
    map_total = 0
    total_documents = len(eval_data_loader)
    
    with torch.no_grad():
        for all_input_ids, all_attention_masks, num_positive, num_negative in tqdm(eval_data_loader, desc="Evaluating", total=total_documents):
            split_input_ids = [all_input_ids[i:i+8] for i in range(0, all_input_ids.size(0), 8)]
            split_attention_masks = [all_attention_masks[i:i+8] for i in range(0, all_attention_masks.size(0), 8)]
            all_similarities = torch.cat([compute_query_likelihood(model, input_ids, attention_masks) for input_ids, attention_masks in zip(split_input_ids, split_attention_masks)], dim=0)

            # Get the ranking of the positive query (rank is based on the similarity score)
            rankings = torch.argsort(torch.argsort(all_similarities, descending=True))
            rank_of_positives = rankings[:num_positive]

            if 0 in rank_of_positives:
                precision_at_1 += 1
            precision_at_10 += len(list(filter(lambda x: x < 10, rank_of_positives))) / min(len(all_similarities), 10)

            # Compute Reciprocal Rank for MRR
            rank_of_first_positive = min(rank_of_positives) + 1
            mrr_total += 1 / rank_of_first_positive

            # Compute Average Precision for MAP
            ap = 0.0
            for i, rank in enumerate(sorted(rank_of_positives)):
                ap += (i + 1) / (rank + 1)  # Precision at each rank where a positive query appears
            map_total += ap / len(rank_of_positives)  # Average precision for this document

    # Average the metrics over all documents
    precision_at_1 /= total_documents
    precision_at_10 /= total_documents
    mrr_total /= total_documents
    map_total /= total_documents
    
    print(f'Precision@1: {precision_at_1}')
    print(f'Precision@10: {precision_at_10}')
    print(f'MRR: {mrr_total}')
    print(f'MAP: {map_total}')

In [25]:
def freeze_bert_layers(model, num_layers_to_freeze=10):
    # Freeze the embedding layer
    for param in model.bert.embeddings.parameters():
        param.requires_grad = False
    
    # Freeze the first `num_layers_to_freeze` encoder layers
    for layer in model.bert.encoder.layer[:num_layers_to_freeze]:
        for param in layer.parameters():
            param.requires_grad = False

In [26]:
# Train with Random/Global Negatives Hotpot
model_hotpot_random = QueryLikelihoodModel(bert_model_name='bert-base-cased', ffl_depth=1, dropout_prob=0.3)
freeze_bert_layers(model_hotpot_random, num_layers_to_freeze=9)
# Optimizer
optimizer_hotpot_random = torch.optim.AdamW(model_hotpot_random.parameters(), lr=5e-5)

train_query_likelihood_model(hotpot_train_data_random_loader, model_hotpot_random, optimizer_hotpot_random)

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

Epoch 1: 100%|██████████| 6500/6500 [21:53<00:00,  4.95it/s]


Epoch 1, Average Loss: 0.11328912297901433


Epoch 2: 100%|██████████| 6500/6500 [21:54<00:00,  4.95it/s]


Epoch 2, Average Loss: 0.04486722542699367


Epoch 3: 100%|██████████| 6500/6500 [21:55<00:00,  4.94it/s]

Epoch 3, Average Loss: 0.03437688052639605





In [27]:
evaluate_query_likelihood_model(hotpot_test_data_random_loader, model_hotpot_random)

Evaluating: 100%|██████████| 850/850 [08:01<00:00,  1.77it/s]

Precision@1: 0.98117182480116852
Precision@10: 0.24104647058823368
MRR: 0.9823273205757141
MAP: 0.9728959465026855





In [28]:
# Train with Inbatch Negatives Hotpot
model_hotpot_inbatch = QueryLikelihoodModel(bert_model_name='bert-base-cased', ffl_depth=1, dropout_prob=0.3)
freeze_bert_layers(model_hotpot_inbatch, num_layers_to_freeze=9)
# Optimizer
optimizer_hotpot_inbatch = torch.optim.AdamW(model_hotpot_inbatch.parameters(), lr=5e-5)

train_query_likelihood_model(hotpot_train_data_inbatch_loader, model_hotpot_inbatch, optimizer_hotpot_inbatch)

Epoch 1: 100%|██████████| 6500/6500 [21:56<00:00,  4.94it/s]


Epoch 1, Average Loss: 0.13144763162061157


Epoch 2: 100%|██████████| 6500/6500 [21:56<00:00,  4.94it/s]


Epoch 2, Average Loss: 0.048327771609159305


Epoch 3: 100%|██████████| 6500/6500 [21:56<00:00,  4.94it/s]

Epoch 3, Average Loss: 0.03278356602511896





In [29]:
evaluate_query_likelihood_model(hotpot_test_data_inbatch_loader, model_hotpot_inbatch)

Evaluating: 100%|██████████| 850/850 [07:59<00:00,  1.77it/s]

Precision@1: 0.938235294117647
Precision@10: 0.2302352941176454
MRR: 0.968541886806488
MAP: 0.920303111076355





In [30]:
# Train with Random/Global Negatives WikiNQ
model_wiki_random = QueryLikelihoodModel(bert_model_name='bert-base-cased', ffl_depth=1, dropout_prob=0.3)
freeze_bert_layers(model_wiki_random, num_layers_to_freeze=9)
# Optimizer
optimizer_wiki_random = torch.optim.AdamW(model_wiki_random.parameters(), lr=5e-5)

train_query_likelihood_model(wiki_train_data_random_loader, model_wiki_random, optimizer_wiki_random)

Epoch 1: 100%|██████████| 6500/6500 [21:57<00:00,  4.93it/s]


Epoch 1, Average Loss: 0.3502412207183986


Epoch 2: 100%|██████████| 6500/6500 [21:57<00:00,  4.93it/s]


Epoch 2, Average Loss: 0.1584025215466554


Epoch 3: 100%|██████████| 6500/6500 [21:58<00:00,  4.93it/s]

Epoch 3, Average Loss: 0.08991518079438457





In [31]:
evaluate_query_likelihood_model(wiki_test_data_loader, model_wiki_random)

Evaluating: 100%|██████████| 850/850 [13:07<00:00,  1.08it/s]

Precision@1: 0.1952941176470588
Precision@10: 0.16952941176470532
MRR: 0.36243546962738037
MAP: 0.2948113179206848





In [32]:
# Train with Inbatch Negatives WikiNQ
model_wiki_inbatch = QueryLikelihoodModel(bert_model_name='bert-base-cased', ffl_depth=1, dropout_prob=0.3)
freeze_bert_layers(model_wiki_inbatch, num_layers_to_freeze=9)
# Optimizer
optimizer_wiki_inbatch = torch.optim.AdamW(model_wiki_inbatch.parameters(), lr=5e-5)

train_query_likelihood_model(wiki_train_data_inbatch_loader, model_wiki_inbatch, optimizer_wiki_inbatch)

Epoch 1: 100%|██████████| 6500/6500 [21:57<00:00,  4.93it/s]


Epoch 1, Average Loss: 0.36544184492183784


Epoch 2: 100%|██████████| 6500/6500 [21:56<00:00,  4.94it/s]


Epoch 2, Average Loss: 0.16433344044881507


Epoch 3: 100%|██████████| 6500/6500 [21:56<00:00,  4.94it/s]

Epoch 3, Average Loss: 0.12027742949814962





In [33]:
evaluate_query_likelihood_model(wiki_test_data_loader, model_wiki_inbatch)

Evaluating: 100%|██████████| 850/850 [13:07<00:00,  1.08it/s]

Precision@1: 0.24705882352941176
Precision@10: 0.1789411764705875
MRR: 0.4355828046798706
MAP: 0.36232834815979004



