In [25]:
import json

def _qid_sort_key(qid):
    """Sort key for query ids: numeric ids first (by value), then the rest by string.

    The "test", "train" and "dev" substrings are removed before the numeric
    conversion, so e.g. "test12" sorts as 12. Ids that are still not integers
    after that are ordered lexicographically after all numeric ids instead of
    raising ``ValueError``.
    """
    stripped = qid.replace("test", "").replace("train", "").replace("dev", "")
    try:
        return (0, int(stripped), "")
    except ValueError:
        return (1, 0, qid)


def load_hits_from_qrels_queries_corpus(qrels_file, queries_file, corpus_file=None):
    """Combine a qrels TSV, a queries JSONL and (optionally) a corpus JSONL into ranked hits.

    Args:
        qrels_file: Path to a tab-separated file with three columns per line
            (query id, document id, score). Lines starting with "query-id"
            are treated as a header and skipped; blank lines are skipped.
        queries_file: Path to the queries JSONL, read with ``load_qids_to_queries``.
        corpus_file: Optional path to the corpus JSONL, read with
            ``load_pids_to_passages``. When ``None`` the corpus is not loaded
            and every hit's ``'content'`` is ``None``.

    Returns:
        A list with one ``{'query': str, 'hits': list}`` dict per query id found
        in the qrels file, ordered by ``_qid_sort_key``. Each hit is a dict with
        the keys ``'qid'``, ``'docid'``, ``'score'`` (float) and ``'content'``,
        and the hits of a query are sorted by descending score.

    Raises:
        KeyError: If a query id is missing from the queries file or a document
            id is missing from the corpus.
        ValueError: If a non-blank line does not have exactly three
            tab-separated fields or its score is not a valid float.
    """
    print(f"Loading qids from '{queries_file}'")
    queries = load_qids_to_queries(queries_file)

    # Only announce (and perform) the corpus load when a corpus was actually given.
    corpus = None
    if corpus_file is not None:
        print(f"Loading corpus from '{corpus_file}'")
        corpus = load_pids_to_passages(corpus_file)

    # Group hits per query id, keeping the order in which they appear in the file.
    results = {}
    with open(qrels_file, 'r', encoding='utf-8') as f:
        for line in f:
            # Skip the header line
            if line.startswith("query-id"):
                continue

            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue

            qid, docid, score = line.split('\t')

            if qid not in results:
                results[qid] = {'query': queries[qid], 'hits': []}

            results[qid]['hits'].append({
                'qid': qid,
                'docid': docid,
                'score': float(score),
                'content': corpus[docid] if corpus is not None else None,
            })

    # Order the queries by id and each query's hits by descending score.
    rank_results = []
    for qid in sorted(results, key=_qid_sort_key):
        entry = results[qid]
        rank_results.append({
            'query': entry['query'],
            'hits': sorted(entry['hits'], key=lambda hit: -hit['score']),
        })

    return rank_results

def load_qids_to_queries(queries_file):
    """Load a JSONL queries file into a ``{query id: query text}`` dict.

    Args:
        queries_file: Path to a file with one JSON object per line, each
            having an "_id" and a "text" key. Blank lines are ignored.

    Returns:
        dict mapping each "_id" value to its "text" value. If an id occurs
        more than once, the last occurrence wins.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks "_id" or "text".
    """
    queries = {}
    # Read as UTF-8 explicitly instead of relying on the platform default encoding.
    with open(queries_file, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # Blank lines (e.g. a trailing newline) carry no record.
                continue
            record = json.loads(line)
            queries[record["_id"]] = record["text"]
    return queries

def load_pids_to_passages(corpus_file):
    """Load a JSONL corpus file into a ``{passage id: passage text}`` dict.

    Args:
        corpus_file: Path to a file with one JSON object per line, each
            having an "_id" and a "text" key and optionally a "title" key.
            Blank lines are ignored.

    Returns:
        dict mapping each "_id" to its passage. The passage is
        ``title + "\\n" + text`` when the title is non-empty after stripping
        whitespace, otherwise just ``text``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks "_id" or "text".
    """
    corpus = {}
    # Read as UTF-8 explicitly instead of relying on the platform default encoding.
    with open(corpus_file, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # Blank lines (e.g. a trailing newline) carry no record.
                continue
            data = json.loads(line)

            title = data.get("title", "")
            text = data["text"]
            # A missing, empty or whitespace-only title is not prepended.
            has_title = bool(title and title.strip())
            corpus[data["_id"]] = title + "\n" + text if has_title else text
    return corpus

def load_qid_to_pid_to_score(qrels_file):
    """Load a qrels TSV into a nested ``{qid: {pid: score}}`` dict.

    Args:
        qrels_file: Path to a tab-separated file with three columns per line
            (query id, passage id, score). Lines starting with "query-id"
            are treated as a header and skipped; blank lines are skipped.

    Returns:
        dict mapping each query id to a dict of passage id -> float score.
        If a (qid, pid) pair occurs more than once, the last score wins.

    Raises:
        ValueError: If a non-blank line does not have exactly three
            tab-separated fields or its score is not a valid float.
    """
    qid_to_pid_to_score = {}
    # Read as UTF-8 explicitly instead of relying on the platform default encoding.
    with open(qrels_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.startswith("query-id"):
                continue

            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue

            qid, pid, score = line.split('\t')
            qid_to_pid_to_score.setdefault(qid, {})[pid] = float(score)
    return qid_to_pid_to_score


In [28]:

# load ground truth positives
# load top 30-50 reranked (negative) hits
# treat dataset as one long list of hits
#  - expand each query into its set of top k hits
#  - len(dataset) = total number of hits summed over every (query, positive) entry
#  - getitem(idx) looks idx up in a precomputed list of (query_idx, hit_idx) pairs

# load qid --> queries
# load pid --> passages
# load rank results as qid --> pid --> score
# load ground truth as qid --> pid --> score

# get top 1000 embedding rank results (DONE)
# send all embedding rank results to reranker (DONE)
# remove false negatives (within 95% of lowest ground truth rerank score) (DONE)

# need script that loads all query<>positive scores and sends to reranker for score (TODO)
# pass positive qid->pid->score to dataset so it can put positive score onto rank result (TODO)

# NOTE: we don't remove false negatives from rerank stage because we may still want to observe their behavior when sent through reranker or measure scoring

from torch.utils.data import Dataset
# from data_utils import load_qid_to_pid_to_score, load_pids_to_passages, load_hits_from_qrels_queries_corpus

class TeacherTriplesDataset(Dataset):
    """Flat dataset of (query, positive, negative) triples built from ranked hits.

    Every query in the rank results that also has ground-truth positives is
    expanded into one entry per positive, and every entry is paired with each
    of the query's ranked hits, which are returned as the negatives. The
    dataset length is therefore the total number of (entry, hit) pairs.

    NOTE(review): no filtering happens in this class. If a ground-truth
    positive also appears among a query's hits it is emitted as a negative
    too; confirm that such hits are removed upstream.

    Attributes:
        ground_truth: ``{qid: {pid: score}}`` loaded from ``ground_truth_path``.
        corpus: ``{pid: passage}`` loaded from ``corpus_path``.
        negative_rank_results_with_positives: list of dicts with the keys
            "query_id", "query", "positive_id" and "hits".
        index_mapping: list of ``(entry_idx, hit_idx)`` pairs, one per item.
    """

    def __init__(self, queries_path, corpus_path, negative_rank_results_path, ground_truth_path):
        """Load all inputs and precompute the flat item index.

        Args:
            queries_path: Path to the queries JSONL file.
            corpus_path: Path to the corpus JSONL file.
            negative_rank_results_path: Path to the rank-results TSV whose hits
                are used as negatives.
            ground_truth_path: Path to the ground-truth qrels TSV.

        Raises:
            KeyError: If a hit's document id is not present in the corpus.
        """
        self.ground_truth = load_qid_to_pid_to_score(ground_truth_path)
        self.corpus = load_pids_to_passages(corpus_path)

        # Load the hits without a corpus so the corpus file is not parsed a
        # second time; the passage text is attached below from self.corpus,
        # which also lets hits share the passage strings already in memory.
        ranked_queries = load_hits_from_qrels_queries_corpus(negative_rank_results_path, queries_path)

        self.negative_rank_results_with_positives = []
        for ranked_query in ranked_queries:
            hits = ranked_query['hits']
            if not hits:
                # Nothing to pair with, and no hit to read the query id from.
                continue

            for hit in hits:
                hit['content'] = self.corpus[hit['docid']]

            # All hits of one query carry the same qid.
            qid = hits[0]['qid']
            # One entry per positive; the entries of a query share one hits list.
            for positive_id in self.ground_truth.get(qid, ()):
                self.negative_rank_results_with_positives.append({
                    "query_id": qid,
                    "query": ranked_query['query'],
                    "positive_id": positive_id,
                    # "positive_score": "TODO",
                    "hits": hits
                })

        # Create index mapping: [(query_idx, hit_idx)]
        self.index_mapping = [
            (query_idx, hit_idx)
            for query_idx, rank_result in enumerate(self.negative_rank_results_with_positives)
            for hit_idx in range(len(rank_result['hits']))
        ]

    def __len__(self):
        """Return the total number of (entry, hit) pairs."""
        return len(self.index_mapping)

    def __getitem__(self, idx):
        """Return the triple at flat position ``idx``.

        Returns:
            dict with the keys "query", "positive_id", "positive",
            "negative_id", "negative" and "negative_score".

        Raises:
            IndexError: If ``idx`` is outside the index mapping.
            KeyError: If the positive id is not present in the corpus.
        """
        query_idx, hit_idx = self.index_mapping[idx]
        rank_result = self.negative_rank_results_with_positives[query_idx]
        hit = rank_result['hits'][hit_idx]
        positive_id = rank_result['positive_id']
        return {
            "query": rank_result['query'],
            "positive_id": positive_id,
            "positive": self.corpus[positive_id],
            # TODO: positive_score
            "negative_id": hit['docid'],
            "negative": hit['content'],
            "negative_score": hit['score']
        }


In [30]:
# Input files for the NQ experiment (paths are relative to the notebook).
queries_path = "../data/nq/queries.jsonl"
corpus_path = "../data/nq/corpus.jsonl"
negative_rank_results_path = "../data/nq/bge_en_icl_qrels_1000_ip.tsv"
ground_truth_path = "../data/nq/qrels/test.tsv"

# Build the triples dataset; this reads all four files eagerly.
dataset = TeacherTriplesDataset(
    queries_path=queries_path,
    corpus_path=corpus_path,
    negative_rank_results_path=negative_rank_results_path,
    ground_truth_path=ground_truth_path,
)

Loading qids from '../data/nq/queries.jsonl'
Loading corpus from '../data/nq/corpus.jsonl'


In [31]:
len(dataset)

4201000

In [32]:
dataset[4088653]

{'query': 'where was the tv show high chaparral filmed',
 'positive_id': 'doc114945',
 'positive': 'The High Chaparral\nAll the exterior filming was done at Old Tucson Studios in Arizona and in the nearby Saguaro National Park, although in a few later episodes there was some filming in California and (in season 3) in the Coronado National Forest south of Tucson. The interiors were generally filmed at the NBC television studios in Burbank, Los Angeles.[1]',
 'negative_id': 'doc1855892',
 'negative': "Hill Valley (Back to the Future)\nFor Back to the Future Part III, Hill Valley 1885 was filmed in Sonora, California. The producers were able to use the land rent-free as long as they left the buildings there. They agreed to leave everything except the Clock Tower. Interestingly, on August 10, 1996, a lightning bolt struck the town and it burned down[citation needed]. An arson fire on the Universal Studios Hollywood backlot on November 6, 1990, had previously destroyed much of Courthouse Sq