In [1]:
import copy
from tqdm import tqdm
import json
import re
import os
import gzip
from rank_gpt import run_retriever, sliding_windows, write_eval_file

def load_topics(topics_file):
    """Load query topics from a tab-separated text file.

    Every non-blank line is expected to have the form ``<qid><TAB><query>``.

    Args:
        topics_file: Path of the topics file to read (decoded as UTF-8).

    Returns:
        dict mapping each query id (str) to its query text (str). If a qid
        appears more than once, the last occurrence wins.

    Raises:
        ValueError: If a non-blank line contains no tab separator.
    """
    topics = {}
    with open(topics_file, 'r', encoding='utf-8') as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Blank lines (e.g. a stray newline at end of file) carry no
                # topic; skip them instead of failing to unpack.
                continue
            # Split on the first tab only, so a query that itself contains a
            # tab is kept whole rather than raising "too many values".
            qid, sep, query = line.partition('\t')
            if not sep:
                raise ValueError(
                    f"{topics_file}:{line_no}: expected '<qid>\\t<query>', got {line!r}"
                )
            topics[qid] = query
    return topics

def get_document_wrapper(doc_id):
    """Best-effort lookup of a document's segment text.

    Calls ``get_document(doc_id)`` and repackages the result.

    Args:
        doc_id: Identifier passed straight through to ``get_document``.

    Returns:
        ``{'docid': doc_id, 'content': <the record's 'segment' value>}`` on
        success, or ``None`` if the lookup (or reading the 'segment' key)
        raised any ``Exception``; in that case the error is printed.
    """
    try:
        segment_text = get_document(doc_id)['segment']
    except Exception as err:
        # Deliberately broad: one bad document must not abort a whole batch.
        print(f"Error retrieving document {doc_id}: {err}")
        return None
    return {'docid': doc_id, 'content': segment_text}

def get_document(doc_id, base_path="/root/data/msmarco_v2.1_doc_segmented/"):
    """Fetch one JSON record from a gzipped shard file, addressed by its doc id.

    The id must have the form ``msmarco_v2.1_doc_<shard>_<n>#<n>_<offset>``.
    ``<shard>`` selects the file
    ``msmarco_v2.1_doc_segmented_<shard, zero-padded to 2 digits>.json.gz``
    under ``base_path``; ``<offset>`` is the position sought to in the
    decompressed stream. The single line read from that position is parsed
    as JSON and its ``'docid'`` field is checked against ``doc_id``.

    Args:
        doc_id: Segment identifier in the format described above.
        base_path: Directory containing the gzipped shard files.

    Returns:
        The parsed JSON record (a dict) whose ``'docid'`` equals ``doc_id``.

    Raises:
        ValueError: If ``doc_id`` does not match the expected format, the line
            at the offset is not valid JSON, or the record found there has a
            different ``'docid'``.
        OSError: If the shard file cannot be opened or read.
    """
    # fullmatch (rather than match) so that ids with trailing garbage are
    # rejected up front instead of triggering a pointless shard read.
    match = re.fullmatch(r'msmarco_v2\.1_doc_(\d+)_(\d+)#(\d+)_(\d+)', doc_id)
    if not match:
        raise ValueError(f"Invalid doc_id format: {doc_id}")

    shard_number = int(match.group(1))
    byte_offset = int(match.group(4))
    file_path = os.path.join(base_path, f"msmarco_v2.1_doc_segmented_{shard_number:02d}.json.gz")

    # NOTE: seeking inside a gzip stream means decompressing everything before
    # the target position, so lookups deep into a shard are slow.
    with gzip.open(file_path, 'rb') as f:
        f.seek(byte_offset)
        line = f.readline().decode('utf-8')

    # Parse after the file is closed; only the JSON decode is guarded so the
    # mismatch error below is never confused with a parse failure.
    try:
        document = json.loads(line)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Invalid JSON at offset {byte_offset} in file {file_path}") from exc

    if document['docid'] != doc_id:
        raise ValueError(f"Document at offset does not match requested doc_id: {doc_id}")
    return document

def load_results(results_file):
    """Load a six-column, whitespace-separated run file into per-query hit lists.

    Each non-blank line must contain exactly six fields:
    ``qid <ignored> docid rank score <ignored>`` (the column layout of a TREC
    run file). The second and sixth fields are discarded.

    Args:
        results_file: Path of the run file to read (decoded as UTF-8).

    Returns:
        dict mapping each qid (str) to a list of
        ``{'docid': str, 'rank': int, 'score': float}`` dicts, in the order
        the lines appear in the file.

    Raises:
        ValueError: If a non-blank line does not have exactly six fields, or
            its rank/score cannot be converted to int/float.
    """
    results = {}
    with open(results_file, 'r', encoding='utf-8') as f:
        for line_no, raw_line in enumerate(f, start=1):
            fields = raw_line.split()
            if not fields:
                # Blank lines (e.g. a trailing newline) are not hits.
                continue
            if len(fields) != 6:
                raise ValueError(
                    f"{results_file}:{line_no}: expected 6 whitespace-separated "
                    f"fields, got {len(fields)}"
                )
            qid, _, docid, rank, score, _ = fields
            results.setdefault(qid, []).append(
                {'docid': docid, 'rank': int(rank), 'score': float(score)}
            )
    return results

# def prepare_rank_results(topics, results):
#     rank_results = []
#     for qid, query in topics.items():
#         if qid in results:
#             item = {"query": query, "hits": []}
#             for hit in results[qid]:
#                 doc = get_document_wrapper(hit['docid'])
#                 if doc:
#                     item["hits"].append({
#                         "content": doc['content'],
#                         "qid": qid,
#                         "docid": hit['docid'],
#                         "rank": hit['rank'],
#                         "score": hit['score']
#                     })
#             rank_results.append(item)
#     return rank_results


In [2]:
import multiprocessing
from functools import partial
def prepare_rank_results(topics, results, max_results=100, max_topics=2, num_processes=16):
    """Build the per-query input items for reranking, fetching document text in parallel.

    Only the first ``max_topics`` entries of ``topics`` are considered; those
    with no entry in ``results`` are skipped (but still count towards the
    limit). For each remaining topic, the first ``max_results`` hits are
    looked up with ``get_document_wrapper`` in a process pool. Hits whose
    lookup returned a falsy value (i.e. failed) are dropped.

    Args:
        topics: dict mapping qid -> query text.
        results: dict mapping qid -> list of hit dicts with 'docid', 'rank'
            and 'score' keys.
        max_results: Maximum number of hits to keep per topic.
        max_topics: Maximum number of topics (taken in dict order) to process.
        num_processes: Number of worker processes used for document lookups.

    Returns:
        list of ``{"query": str, "hits": [...]}`` dicts, one per processed
        topic, where every hit has 'content', 'qid', 'docid', 'rank' and
        'score' keys.
    """
    selected_topics = list(topics.items())[:max_topics]
    rank_results = []

    # Create the worker pool once and reuse it for every topic: starting and
    # tearing down the workers is loop-invariant work.
    with multiprocessing.Pool(processes=num_processes) as pool:
        for qid, query in tqdm(selected_topics, desc="Processing topics"):
            if qid not in results:
                continue

            hits = results[qid][:max_results]
            doc_ids = [hit['docid'] for hit in hits]

            # imap preserves input order, so documents[i] belongs to hits[i].
            documents = list(tqdm(
                pool.imap(get_document_wrapper, doc_ids),
                total=len(hits),
                desc=f"Retrieving documents for query {qid}"
            ))

            item = {
                "query": query,
                "hits": [
                    {
                        "content": doc['content'],
                        "qid": qid,
                        "docid": hit['docid'],
                        "rank": hit['rank'],
                        "score": hit['score'],
                    }
                    for hit, doc in zip(hits, documents)
                    if doc  # lookup failures come back as None and are dropped
                ],
            }
            rank_results.append(item)

    return rank_results



In [3]:
import env
# Inputs for this run: the topics file, the first-stage run file, and the API
# key read from the local `env` module.
topics_file = "topics.rag24.raggy-dev.txt"
results_file = "raggy-dev_results.txt"
api_key = env.GOOGLE_API_KEY

# Load queries and first-stage hits, then attach segment text to the top 100
# hits of (at most) the first 3 topics.
topics = load_topics(topics_file)
results = load_results(results_file)
rank_results = prepare_rank_results(topics, results, max_results=100, max_topics=3)

# Arguments shared by every rerank call. window_size equals
# rank_end - rank_start; `step` is not passed, so sliding_windows falls back
# to whatever default it defines.
rerank_kwargs = dict(
    rank_start=0,
    rank_end=100,
    window_size=100,
    model_name='gemini',
    api_key=api_key,
)

new_results = []
for item in tqdm(rank_results):
    new_item = sliding_windows(item, **rerank_kwargs)
    new_results.append(new_item)

# Persist the reranked results.
output_file = "gemini_rerank_test.txt"
write_eval_file(new_results, output_file)

print(f"Reranked results have been written to {output_file}")

Processing topics:   0%|          | 0/3 [00:00<?, ?it/s]

Retrieving documents for query 2001010: 100%|██████████| 100/100 [00:41<00:00,  2.43it/s]
Retrieving documents for query 2001459: 100%|██████████| 100/100 [00:48<00:00,  2.05it/s]
Retrieving documents for query 2002075: 100%|██████████| 100/100 [00:46<00:00,  2.13it/s]
Processing topics: 100%|██████████| 3/3 [02:16<00:00, 45.60s/it]
  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 3/3 [00:14<00:00,  4.81s/it]

Reranked results have been written to gemini_rerank_test.txt



