In [8]:
import torch
from transformers import AutoModelForSequenceClassification
import re
import os
import gzip
import json
import tqdm
import multiprocessing

def get_document(doc_id, base_path="/root/data/msmarco_v2.1_doc_segmented/"):
    match = re.match(r'msmarco_v2\.1_doc_(\d+)_(\d+)#(\d+)_(\d+)', doc_id)
    if not match:
        raise ValueError(f"Invalid doc_id format: {doc_id}")
    
    shard_number = int(match.group(1))
    byte_offset = int(match.group(4))
    file_path = os.path.join(base_path, f"msmarco_v2.1_doc_segmented_{shard_number:02d}.json.gz")

    with gzip.open(file_path, 'rb') as f:
        f.seek(byte_offset)
        line = f.readline().decode('utf-8')
        
        try:
            document = json.loads(line)
            if document['docid'] == doc_id:
                return document
            else:
                raise ValueError(f"Document at offset does not match requested doc_id: {doc_id}")
        except json.JSONDecodeError:
            raise ValueError(f"Invalid JSON at offset {byte_offset} in file {file_path}")

def get_document_wrapper(doc_id):
    try:
        document = get_document(doc_id)
        return {'docid': doc_id, 'content': document['segment']}
    except Exception as e:
        print(f"Error retrieving document {doc_id}: {e}")
        return None

def rerank_results(query, results, batch_size=32):
    with multiprocessing.Pool(processes=16) as pool:
        documents = list(tqdm.tqdm(
            pool.imap(get_document_wrapper, [result['doc_id'] for result in results]),
            total=len(results),
            desc="Gathering documents"
        ))
    
    documents = [doc for doc in documents if doc is not None]
    
    all_scores = []
    for i in range(0, len(documents), batch_size):
        batch = documents[i:i+batch_size]
        sentence_pairs = [[query, doc['content']] for doc in batch]
        
        with torch.inference_mode():
            scores = model.compute_score(sentence_pairs, max_length=1024)
        
        all_scores.extend(scores.tolist())
    
    # Combine scores with original results and sort
    for result, score in zip(results[:len(all_scores)], all_scores):
        result['new_score'] = score
    
    reranked_results = sorted(results[:len(all_scores)], key=lambda x: x['new_score'], reverse=True)
    return reranked_results
def parse_results(results_string):
    parsed_results = []
    for line in results_string.strip().split('\n'):
        parts = line.split()
        parsed_results.append({
            'query_id': parts[0],
            'q0': parts[1],
            'doc_id': parts[2],
            'rank': int(parts[3]),
            'score': float(parts[4]),
            'run_id': parts[5]
        })
    return parsed_results

# Main reranking function
def rerank_and_save(input_file, output_file):
    with open(input_file, 'r') as f:
        input_data = f.read()
    
    results = parse_results(input_data)
    
    # Group results by query_id
    grouped_results = {}
    for result in results:
        query_id = result['query_id']
        if query_id not in grouped_results:
            grouped_results[query_id] = []
        grouped_results[query_id].append(result)
    
    # Rerank each group
    reranked_groups = {}
    for query_id, group in tqdm.tqdm(grouped_results.items()):
        query = f"Query for {query_id}"  # Replace with actual query if available
        reranked_groups[query_id] = rerank_results(query, group)
    
    # Write reranked results
    with open(output_file, 'w') as f:
        for query_id, reranked_results in reranked_groups.items():
            for new_rank, result in enumerate(reranked_results, start=1):
                f.write(f"{query_id} Q0 {result['doc_id']} {new_rank} {result['new_score']:.8f} cross-encoder\n")


In [14]:
model = AutoModelForSequenceClassification.from_pretrained(
    'jinaai/jina-reranker-v2-base-multilingual',
    torch_dtype="auto",
    trust_remote_code=True,
)

model.to('cuda')  # Use 'cpu' if no GPU is available
model.eval()

flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.
flash_attn is not installed. Using PyTorch native attention implementation.


XLMRobertaForSequenceClassification(
  (roberta): XLMRobertaModel(
    (embeddings): XLMRobertaEmbeddings(
      (word_embeddings): Embedding(250002, 768, padding_idx=1)
      (position_embeddings): Embedding(1026, 768)
      (token_type_embeddings): Embedding(1, 768)
    )
    (emb_drop): Dropout(p=0.1, inplace=False)
    (emb_ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (encoder): XLMRobertaEncoder(
      (layers): ModuleList(
        (0-11): 12 x Block(
          (mixer): MHA(
            (Wqkv): LinearResidual(in_features=768, out_features=2304, bias=True)
            (inner_attn): SelfAttention(
              (drop): Dropout(p=0.1, inplace=False)
            )
            (inner_cross_attn): CrossAttention(
              (drop): Dropout(p=0.1, inplace=False)
            )
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (dropout1): Dropout(p=0.1, inplace=False)
          (drop_path1): StochasticDepth(p=0.0, mode=r

In [15]:
model = torch.compile(model)

In [9]:
input_file = "test_results.txt"
output_file = "reranked_results.txt"
rerank_and_save(input_file, output_file)

  0%|          | 0/301 [00:00<?, ?it/s]

Gathering documents: 100%|██████████| 125/125 [01:04<00:00,  1.94it/s]
  0%|          | 0/301 [01:07<?, ?it/s]


AttributeError: 'list' object has no attribute 'tolist'

In [16]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def rerank_results(query, results, batch_size=32):
    print(f"Reranking for query: {query}")
    print(f"Number of results: {len(results)}")
    
    with multiprocessing.Pool(processes=32) as pool:
        documents = list(tqdm.tqdm(
            pool.imap(get_document_wrapper, [result['doc_id'] for result in results]),
            total=len(results),
            desc="Gathering documents"
        ))
    
    documents = [doc for doc in documents if doc is not None]
    print(f"Number of valid documents: {len(documents)}")
    
    all_scores = []
    for i in range(0, len(documents), batch_size):
        batch = documents[i:i+batch_size]
        sentence_pairs = [[query, doc['content']] for doc in batch]
        
        print(f"Processing batch {i//batch_size + 1}, size: {len(batch)}")
        with torch.no_grad():
            scores = model.compute_score(sentence_pairs, max_length=1024)
        
        print(f"Scores type: {type(scores)}")
        if isinstance(scores, list):
            print(f"Scores length: {len(scores)}")
            all_scores.extend(scores)
        else:
            print(f"Scores shape: {scores.shape}")
            all_scores.extend(scores.squeeze().tolist())
    
    print(f"Total scores: {len(all_scores)}")
    
    # Combine scores with original results and sort
    for result, score in zip(results[:len(all_scores)], all_scores):
        result['new_score'] = score
    
    reranked_results = sorted(results[:len(all_scores)], key=lambda x: x['new_score'], reverse=True)
    return reranked_results
# ... (keep the existing parse_results function)
def parse_results(results_string):
    parsed_results = []
    for line in results_string.strip().split('\n'):
        parts = line.split()
        parsed_results.append({
            'query_id': parts[0],
            'q0': parts[1],
            'doc_id': parts[2],
            'rank': int(parts[3]),
            'score': float(parts[4]),
            'run_id': parts[5]
        })
    return parsed_results


def rerank_and_save(input_file, output_file, num_queries=3):
    with open(input_file, 'r') as f:
        input_data = f.read()
    
    results = parse_results(input_data)
    
    # Group results by query_id
    grouped_results = {}
    for result in results:
        query_id = result['query_id']
        if query_id not in grouped_results:
            grouped_results[query_id] = []
        grouped_results[query_id].append(result)
    
    # Rerank a subset of queries
    reranked_groups = {}
    for query_id, group in list(grouped_results.items())[:num_queries]:
        print(f"\nProcessing query_id: {query_id}")
        query = f"Query for {query_id}"  # Replace with actual query if available
        reranked_groups[query_id] = rerank_results(query, group)
    
    # Write reranked results
    with open(output_file, 'w') as f:
        for query_id, reranked_results in reranked_groups.items():
            for new_rank, result in enumerate(reranked_results, start=1):
                f.write(f"{query_id} Q0 {result['doc_id']} {new_rank} {result['new_score']:.8f} cross-encoder\n")


In [17]:
input_file = "test_results.txt"
output_file = "reranked_results.txt"
rerank_and_save(input_file, output_file, num_queries=3)


Processing query_id: 2024-145979
Reranking for query: Query for 2024-145979
Number of results: 125


Gathering documents: 100%|██████████| 125/125 [00:46<00:00,  2.66it/s]


Number of valid documents: 125
Processing batch 1, size: 32
Scores type: <class 'list'>
Scores length: 32
Processing batch 2, size: 32
Scores type: <class 'list'>
Scores length: 32
Processing batch 3, size: 32
Scores type: <class 'list'>
Scores length: 32
Processing batch 4, size: 29
Scores type: <class 'list'>
Scores length: 29
Total scores: 125

Processing query_id: 2024-36935
Reranking for query: Query for 2024-36935
Number of results: 125


Gathering documents: 100%|██████████| 125/125 [00:36<00:00,  3.46it/s]


Number of valid documents: 125
Processing batch 1, size: 32
Scores type: <class 'list'>
Scores length: 32
Processing batch 2, size: 32
Scores type: <class 'list'>
Scores length: 32
Processing batch 3, size: 32
Scores type: <class 'list'>
Scores length: 32
Processing batch 4, size: 29
Scores type: <class 'list'>
Scores length: 29
Total scores: 125

Processing query_id: 2024-216592
Reranking for query: Query for 2024-216592
Number of results: 125


Gathering documents: 100%|██████████| 125/125 [00:37<00:00,  3.31it/s]


Number of valid documents: 125
Processing batch 1, size: 32
Scores type: <class 'list'>
Scores length: 32
Processing batch 2, size: 32
Scores type: <class 'list'>
Scores length: 32
Processing batch 3, size: 32
Scores type: <class 'list'>
Scores length: 32
Processing batch 4, size: 29
Scores type: <class 'list'>
Scores length: 29
Total scores: 125
