# RAG Retrieval Pipeline

This notebook replicates the functionality of `retrieval.ps1` for executing retrieval tasks.
It performs document retrieval using BM25 or BGE-M3 retrievers and evaluates the results.

## 1. Import Packages

In [None]:
import os
import json
import datetime
import random
from tqdm import tqdm
from loguru import logger
import pandas as pd
import numpy as np

# Import custom modules
from src.datasets.dataset import get_task_datasets
from src.llms import Mock
from src.tasks.retrieval import RetrievalTask
from src.retrievers import CustomBM25Retriever, CustomBGEM3Retriever
from src.embeddings.base import HuggingfaceEmbeddings

## 2. Configuration Parameters

In [None]:
# Run settings for this notebook (matching retrieval.ps1).
config = dict(
    ocr_type='gt',  # OCR type: 'gt', 'paddleocr', etc.
    retriever_type='bm25',  # Retriever type: 'bm25' or 'bge-m3'
    model_name='mock',
    retrieve_top_k=2,
    data_path='data/qas_v2.json',
    docs_path=None,  # placeholder; overwritten just below from ocr_type
    task='Retrieval',
    evaluation_stage='retrieval',
    num_threads=1,
    show_progress_bar=True,
    output_path='./output',
    chunk_size=1024,
    chunk_overlap=0,
)

# The document base is looked up in a sub-directory named after the OCR type.
config['docs_path'] = "data/retrieval_base/{}".format(config['ocr_type'])

# Echo the effective settings, one "key: value" pair per line.
print("\n".join(
    ["Configuration:"]
    + [f"  {name}: {setting}" for name, setting in config.items()]
))

## 3. Set Random Seed

In [None]:
def setup_seed(seed=0):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), then puts cuDNN into deterministic mode.

    Args:
        seed: Integer seed applied to all generators. Defaults to 0.
    """
    # Imported lazily so the cell defining this function does not itself
    # require torch to be importable.
    import torch

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Per the PyTorch docs this call is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    # With benchmark mode on, cuDNN may pick a different convolution algorithm
    # from run to run, which defeats the seeding above; pin it off explicitly
    # in case an earlier cell or library enabled it.
    torch.backends.cudnn.benchmark = False

setup_seed(0)
print("Random seed set to 0")

## 4. Load Dataset

In [None]:
# Read the records for the configured task from the QA file.
print(f"Loading dataset from {config['data_path']}...")
datasets = get_task_datasets(config['data_path'], config['task'])
dataset = datasets[0]  # only the first returned dataset is used from here on
print(f"Loaded {len(dataset)} data points")

# Preview the first record; string fields longer than 100 characters are clipped.
if len(dataset) > 0:
    print("\nSample data point:")
    sample = dataset[0]
    for field, content in sample.items():
        needs_clipping = isinstance(content, str) and len(content) > 100
        shown = f"{content[:100]}..." if needs_clipping else content
        print(f"  {field}: {shown}")

## 5. Initialize Model and Retriever

In [None]:
def _chunking_options():
    """Return the keyword arguments that both retriever types receive from config."""
    return {
        'chunk_size': config['chunk_size'],
        'chunk_overlap': config['chunk_overlap'],
        'similarity_top_k': config['retrieve_top_k'],
    }


# A mock stands in for the LLM: the retrieval stage does not use it, but the
# pipeline requires a model to be set.
llm = Mock()
print("Initialized Mock LLM")

# Build the retriever selected in the configuration.
retriever_kind = config['retriever_type']
print(f"\nInitializing {retriever_kind} retriever...")

# Reject anything other than the two supported retriever types up front.
if retriever_kind not in ("bge-m3", "bm25"):
    raise ValueError(f"Unsupported retriever type: {retriever_kind}")

if retriever_kind == "bge-m3":
    # The dense retriever needs an embedding model; embed_dim is fixed at 1024 here.
    embed_model = HuggingfaceEmbeddings(model_name="BAAI/bge-m3")
    retriever = CustomBGEM3Retriever(
        config['docs_path'],
        embed_model=embed_model,
        embed_dim=1024,
        **_chunking_options(),
    )
else:
    retriever = CustomBM25Retriever(config['docs_path'], **_chunking_options())

print("Retriever initialized successfully")

## 6. Initialize Retrieval Task

In [None]:
# Results for this run are written under <output_path>/<evaluation_stage>/<ocr_type>.
output_dir = os.path.join(
    *(config[part] for part in ('output_path', 'evaluation_stage', 'ocr_type'))
)

# Create the task and hand it the LLM and retriever built earlier.
task = RetrievalTask(output_dir=output_dir)
task.set_model(llm, retriever)
print(f"Retrieval task initialized with output directory: {output_dir}")

## 7. Execute Retrieval Pipeline

In [None]:
def _score_data_point(point):
    """Return the scoring output for `point`, prefixed with the point's ID."""
    return {'id': point['ID'], **task.scoring(point)}


# One scored entry per data point, in dataset order.
results = []

print(f"\nProcessing {len(dataset)} data points...")
progress = tqdm(dataset, desc="Retrieving", disable=not config['show_progress_bar'])
for data_point in progress:
    try:
        # Retrieve, attach the hits to the data point, then score it.
        data_point["retrieval_results"] = task.retrieve_docs(data_point)
        scored = _score_data_point(data_point)
    except Exception as e:
        # Best effort: log the failure and score the point as if nothing
        # had been retrieved, so it still appears in the results.
        logger.warning(f"Error processing data point {data_point.get('ID', 'unknown')}: {e}")
        data_point["retrieval_results"] = []
        scored = _score_data_point(data_point)
    results.append(scored)

print(f"\nProcessed {len(results)} data points")

## 8. Compute Overall Metrics

In [None]:
# Only entries whose 'valid' flag is truthy take part in the aggregate metrics.
valid_results = list(filter(lambda scored: scored['valid'], results))
print(f"Valid results: {len(valid_results)} out of {len(results)}")

if not valid_results:
    # Nothing to aggregate; keep `overall` defined for the cells that follow.
    overall = {}
    print("No valid results to compute metrics")
else:
    overall = task.compute_overall(valid_results)
    print("\nOverall Metrics:")
    for metric_name, metric_value in overall.items():
        # Floats are shown with four decimals; everything else is printed as-is.
        if isinstance(metric_value, float):
            print(f"  {metric_name}: {metric_value:.4f}")
        else:
            print(f"  {metric_name}: {metric_value}")

## 9. Save Results

In [None]:
# Describe the run so the saved file records how it was produced.
retriever_class = retriever.__class__.__name__
info = dict(
    task=task.__class__.__name__,
    retriever=retriever_class,
    ocr_type=config['ocr_type'],
    retrieve_top_k=config['retrieve_top_k'],
    chunk_size=config['chunk_size'],
    chunk_overlap=config['chunk_overlap'],
)

# Everything written to disk: run description, aggregate metrics, per-point results.
output = dict(info=info, overall=overall, results=results)

# Short retriever tag used in the output file name (and by later cells).
ret_name = {
    "CustomBM25Retriever": "bm25",
    "CustomBGEM3Retriever": "bge-m3",
}[retriever_class]

os.makedirs(output_dir, exist_ok=True)
output_file_name = "all_{}_top{}.json".format(ret_name, config["retrieve_top_k"])
output_path = os.path.join(output_dir, output_file_name)

# Write UTF-8 JSON, keeping non-ASCII characters readable.
with open(output_path, 'w', encoding='utf-8') as out_file:
    json.dump(output, out_file, ensure_ascii=False, indent=4)

print(f"\nResults saved to: {output_path}")

## 10. Display Results Summary

In [None]:
def _summary_row(scored, qa_item):
    """Combine one scored result with its QA record into a flat row for the DataFrame."""
    return {
        'id': scored['id'],
        'ocr_type': config['ocr_type'],
        'retriever': ret_name,
        'domain': qa_item.get('doc_type', ''),
        # Keep only the last path component of the document name.
        'doc_name': qa_item.get('doc_name', '').split('/')[-1],
        'evidence_source': qa_item.get('evidence_source', ''),
        'answer_form': qa_item.get('answer_form', ''),
        'lcs': scored['metrics']['lcs'],
        'valid': scored['valid'],
    }


# Index the QA records by ID so each result can be joined with its metadata.
with open(config['data_path'], 'r', encoding='utf-8') as qa_file:
    qa_dict = {item['ID']: item for item in json.load(qa_file)}

# One row per result that has a matching QA record; unmatched results are skipped.
results_df = [
    _summary_row(scored, qa_dict[scored['id']])
    for scored in results
    if scored['id'] in qa_dict
]

df = pd.DataFrame(results_df)
print(f"\nResults DataFrame shape: {df.shape}")
display(df.head(10))

## 11. Results Analysis by Evidence Source

In [None]:
# Break the LCS scores down by evidence source (skipped when there are no rows).
if len(df) > 0:
    # Store LCS as a percentage; the new column is added to df in place.
    df['lcs_percent'] = df['lcs'] * 100

    # Per group: mean and count of the LCS percentage, plus the sum of 'valid'.
    per_group_stats = {'lcs_percent': ['mean', 'count'], 'valid': 'sum'}
    evidence_summary = df.groupby('evidence_source').agg(per_group_stats).round(2)

    print("\nResults by Evidence Source:")
    display(evidence_summary)

    # Figures across all rows, regardless of group.
    overall_avg = df['lcs_percent'].mean()
    valid_total = df['valid'].sum()
    print(f"\nOverall Average LCS: {overall_avg:.2f}%")
    print(f"Total Valid Results: {valid_total} / {len(df)}")

## 12. Results Analysis by Domain

In [None]:
# Break the LCS scores down by domain (skipped when there are no rows).
# Reads the 'lcs_percent' column of df, so that column must already exist.
if len(df) > 0:
    # Per group: mean and count of the LCS percentage, plus the sum of 'valid'.
    per_domain_stats = {'lcs_percent': ['mean', 'count'], 'valid': 'sum'}
    domain_summary = df.groupby('domain').agg(per_domain_stats).round(2)

    print("\nResults by Domain:")
    display(domain_summary)

## 13. Sample Retrieval Results

In [None]:
def _clip_for_display(value, limit=200):
    """Format `value` for printing, cut to `limit` elements plus '...' when it is longer."""
    if len(value) > limit:
        return f"{value[:limit]}..."
    return f"{value}"


# Print the first three scored results in full detail.
print("\nSample Retrieval Results (Top 3):")
for position, scored in enumerate(results[:3], start=1):
    print("\n" + "=" * 80)
    print(f"Result {position} - ID: {scored['id']}")
    print(f"Valid: {scored['valid']}")
    print(f"LCS Score: {scored['metrics']['lcs']:.4f}")

    entry_log = scored['log']
    print(f"\nQuestion: {entry_log['quest']}")
    print(f"\nEvidence Source: {entry_log['evidence_source']}")

    # A missing 'retrieval_context' key is treated the same as an empty list.
    retrieved = entry_log['retrieval_context'] if 'retrieval_context' in entry_log else []
    if len(retrieved) == 0:
        print("\nNo documents retrieved")
    else:
        print(f"\nRetrieved {len(retrieved)} documents:")
        for rank, hit in enumerate(retrieved, start=1):
            print(f"\n  Document {rank}:")
            print(f"    File: {hit.get('file_name', 'N/A')}")
            print(f"    Page: {hit.get('page_idx', 'N/A')}")
            print(f"    Text: {_clip_for_display(hit.get('text', ''))}")

    print(f"\nGround Truth Context: {_clip_for_display(entry_log['ground_truth_context'])}")