In [1]:
import json

def load_hits_from_qrels_queries_corpus(qrels_file, queries_file, corpus_file=None):
    print(f"Loading qids from '{queries_file}'")
    queries = load_qids_to_queries(queries_file)

    print(f"Loading corpus from '{corpus_file}'")
    corpus = load_pids_to_passages(corpus_file) if corpus_file is not None else None

    # Step 3: Load qrels and combine all data
    results = {}
    with open(qrels_file, 'r') as f:
        for line in f:
            # Skip if the first line is the header
            if line.startswith("query-id"):
                continue

            qid, docid, score = line.strip().split('\t')
            score = float(score)

            # Initialize query entry if not already present
            if qid not in results:
                results[qid] = {'query': queries[qid], 'hits': []}

            # Create a hit entry
            hit = {
                'qid': qid,
                'docid': docid,
                'score': score,
                'content': corpus[docid] if corpus_file is not None else None
            }

            results[qid]['hits'].append(hit)

    # Step 4: Sort the queries by numeric qid and their hits by score
    rank_results = []
    for qid in sorted(results.keys(), key=lambda x: int(x.replace("test", "").replace("train", "").replace("dev", ""))):  # Sort by numeric qid
        sorted_hits = sorted(
            results[qid]['hits'], 
            key=lambda x: -x['score']  # Sort hits by score in descending order
        )
        rank_results.append({
            'query': results[qid]['query'],
            'hits': sorted_hits
        })

    return rank_results

def load_qids_to_queries(queries_file):
    queries = {}
    with open(queries_file, 'r') as f:
        for line in f:
            line = json.loads(line)
            qid, query = line["_id"], line["text"]
            queries[qid] = query
    return queries

def load_pids_to_passages(corpus_file):
    corpus = {}
    with open(corpus_file, 'r') as f:
        for line in f:
            data = json.loads(line)
            pid = data["_id"]
            
            # Extract title and text, combining them if the title exists
            title = data.get("title", "")
            text = data["text"]
            passage = title + "\n" + text if title and title.strip() else text
            
            corpus[pid] = passage
    return corpus

def load_qid_to_pid_to_score(qrels_file):
    qid_to_pid_to_score = {}
    with open(qrels_file, 'r') as f:
        for line in f:
            if line.startswith("query-id"):
                continue

            qid, pid, score = line.strip().split('\t')
            score = float(score)
            
            if qid not in qid_to_pid_to_score:
                qid_to_pid_to_score[qid] = {}
            qid_to_pid_to_score[qid][pid] = score
    return qid_to_pid_to_score


In [4]:
import torch
from torch.utils.data import Dataset
# from data_utils import load_qid_to_pid_to_score, load_pids_to_passages, load_hits_from_qrels_queries_corpus, strip_prefixes
import random

class PositiveNegativeDataset(Dataset):
    def __init__(self, queries_path, corpus_path, negative_rank_results_path, positive_rank_results_path, tokenizer, max_seq_len=None, num_neg_per_pos=8, seed=43):
        self.tokenizer = tokenizer
        self.positive_rank_results = load_qid_to_pid_to_score(positive_rank_results_path)
        self.corpus = load_pids_to_passages(corpus_path)
        negative_rank_results = load_hits_from_qrels_queries_corpus(negative_rank_results_path, queries_path, corpus_path)
        self.max_seq_len = max_seq_len
        self.truncation = max_seq_len is not None
        self.num_neg_per_pos = num_neg_per_pos  # Number of negatives to sample per positive
        self.seed = seed  # Global seed for reproducibility
        
        local_rng = random.Random(seed)
        self.negative_rank_results_with_positives = []
        for rank_result in negative_rank_results:
            hits = rank_result['hits']
            qid = hits[0]['qid']
            if qid in self.positive_rank_results:
                for positive_id in self.positive_rank_results[qid]:
                    positive_score = self.positive_rank_results[qid][positive_id]
                    
                    # Shuffle hits once for each query before creating the dataset
                    local_rng.shuffle(hits)
                    
                    self.negative_rank_results_with_positives.append({
                        "query_id": qid,
                        "query": rank_result['query'],
                        "positive_id": positive_id,
                        "positive_score": positive_score,
                        "hits": hits  # All hits for negative sampling
                    })

        # Create index mapping: [(query_idx, neg_group_idx)]
        self.index_mapping = []
        for query_idx, rank_result in enumerate(self.negative_rank_results_with_positives):
            num_hits = len([hit for hit in rank_result['hits'] if hit['docid'] != rank_result['positive_id']])
            num_groups = num_hits // self.num_neg_per_pos
            self.index_mapping.extend([(query_idx, group_idx) for group_idx in range(num_groups)])

    def __len__(self):
        return len(self.index_mapping)

    def __getitem__(self, idx):
        query_idx, group_idx = self.index_mapping[idx]
        rank_result = self.negative_rank_results_with_positives[query_idx]
        query = rank_result['query']

        # Positive passage
        positive_id = rank_result['positive_id']
        positive_passage = self.corpus[positive_id]

        # Determine negative samples for the current group
        start_idx = group_idx * self.num_neg_per_pos
        end_idx = start_idx + self.num_neg_per_pos
        negative_candidates = [hit for hit in rank_result['hits'] if hit['docid'] != positive_id]
        hard_negatives = negative_candidates[start_idx:end_idx]

        return {
            "query_id": rank_result['query_id'],
            "query": query,
            "positive_id": positive_id,
            "positive": positive_passage,
            "positive_label": rank_result['positive_score'],
            "negative_ids": [neg['docid'] for neg in hard_negatives],
            "negatives": [self.corpus[neg['docid']] for neg in hard_negatives],
            "negative_labels": [neg['score'] for neg in hard_negatives]
        }

    def collate_fn(self, batch):
        queries = [item['query'] for item in batch]
        positive_passages = [item['positive'] for item in batch]
        positive_labels = [item['positive_label'] for item in batch]
        negatives_flattened = [neg for item in batch for neg in item['negatives']]
        negative_labels_flattened = [label for item in batch for label in item['negative_labels']]
        
        # Tokenize positives and negatives
        tokenized_positives = self.tokenizer(queries, positive_passages, padding=True, truncation=self.truncation, return_tensors="pt", max_length=self.max_seq_len)
        repeated_queries = [query for query in queries for _ in range(self.num_neg_per_pos)]
        tokenized_negatives = self.tokenizer(repeated_queries, negatives_flattened, padding=True, truncation=self.truncation, return_tensors="pt", max_length=self.max_seq_len)
        
        return {
            "positives": tokenized_positives,
            "positive_labels": torch.tensor(positive_labels),
            "negatives": tokenized_negatives,
            "negative_labels": torch.tensor(negative_labels_flattened)
        }


In [30]:
from transformers import AutoTokenizer

queries_path = "../data/nq-train/queries_sampled_10000.jsonl"
corpus_path = "../data/nq/corpus.jsonl"
negative_rank_results_path = "../data/nq-train/nv_rerank_negatives_top100_sampled_10000_filtered.tsv"
positive_rank_results_path = "../data/nq-train/nv_rerank_positives_train_sampled_10000.tsv"
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")
dataset = PositiveNegativeDataset(queries_path, corpus_path, negative_rank_results_path, positive_rank_results_path, tokenizer, num_neg_per_pos=2)

Loading qids from '../data/nq-train/queries_sampled_10000.jsonl'
Loading corpus from '../data/nq/corpus.jsonl'


In [27]:
len(dataset)

966458

In [29]:
dataset[966350]

{'query_id': 'train128285',
 'query': 'rome was sacked in 410 by the goths who were led by',
 'positive_id': 'doc1638137',
 'positive': 'Sack of Rome (410)\nThe Sack of Rome occurred on August 24, 410. The city was attacked by the Visigoths led by King Alaric. At that time, Rome was no longer the capital of the Western Roman Empire, having been replaced in that position first by Mediolanum in 286 and then by Ravenna in 402. Nevertheless, the city of Rome retained a paramount position as "the eternal city" and a spiritual center of the Empire. The sack was a major shock to contemporaries, friends and foes of the Empire alike.',
 'positive_label': 27.171875,
 'negative_ids': ['doc1638161'],
 'negatives': ["Sack of Rome (410)\nInfuriated, Alaric broke off negotiations, and Jovius returned to Ravenna to strengthen his relationship with the Emperor. Honorius was now firmly committed to war, and Jovius swore on the Emperor's head to never to make peace with Alaric. Alaric himself soon change

In [33]:
dataloaders = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=dataset.collate_fn)

In [34]:
for batch in dataloaders:
    print(batch)
    break

{'positives': {'input_ids': tensor([[     1,    328,    490,  24936,    452,  15968,   9586,    267,    262,
            550,      2,  59801,  17841,  59801,    269,   7939,    311,   1066,
            335,    373,    269,    411,   6294,    270,    266,  13504,    272,
            373,    269,    264,   9586,  77212,    272,    406,    260,  97681,
          75544,  59801,    263,   3836,    277,   3910,    265,    342,   1503,
            264,    527,    342,    557,    482,    262,   4271,    264,    800,
            839,  17444,    260,  59801,  25021,  77212,    263,  13936,    264,
          10979,    283,    313,   4461,    264,    552,    315,  25845,    441,
            342,    261,    299,    539,   2855,   1969,    267,  82756,    260,
            344,    930,    261,  97681,   8466,  13634,   1310,    725,    268,
            264,   1727,    283,    266,  13725,    324,    272,  77212,    295,
          25845,    315,   9147,    267,   9239,    265,    315,    782,  24772,


In [27]:
batch["positive_labels"]

tensor([20.3906,  8.1562])

In [26]:
batch["positives"].keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])