# RAG Retrieval Pipeline

This notebook replicates the functionality of `retrieval.ps1` for executing retrieval tasks.
It performs document retrieval using BM25 or BGE-M3 retrievers and evaluates the results.

## 1. Import Packages

In [1]:
import os
import json
import datetime
import random
from tqdm import tqdm
from loguru import logger
import pandas as pd
import numpy as np

# Import custom modules
from src.datasets.dataset import get_task_datasets
from src.llms import Mock
from src.tasks.retrieval import RetrievalTask
from src.retrievers import CustomBM25Retriever, CustomBGEM3Retriever
from src.embeddings.base import HuggingfaceEmbeddings

  from .autonotebook import tqdm as notebook_tqdm


## 2. Configuration Parameters

In [2]:
# Pipeline settings, mirroring the arguments passed by retrieval.ps1
config = {
    'ocr_type': 'gt',  # which OCR output to index: 'gt', 'paddleocr', etc.
    'retriever_type': 'bm25',  # either 'bm25' or 'bge-m3'
    'model_name': 'mock',
    'retrieve_top_k': 2,
    'data_path': 'data/qas_v2_clean.json',  # cleaned QA file - run 'python clean_data.py' first
    'docs_path': None,  # placeholder; derived from ocr_type right below
    'task': 'Retrieval',
    'evaluation_stage': 'retrieval',
    'num_threads': 1,
    'show_progress_bar': True,
    'output_path': './output',
    'chunk_size': 1024,
    'chunk_overlap': 0
}

# The document folder is selected by the OCR type
config['docs_path'] = f"data/retrieval_base/{config['ocr_type']}"

# Echo the effective settings, one "name: value" pair per line
print("Configuration:")
print("\n".join(f"  {name}: {setting}" for name, setting in config.items()))

Configuration:
  ocr_type: gt
  retriever_type: bm25
  model_name: mock
  retrieve_top_k: 2
  data_path: data/qas_v2_clean.json
  docs_path: data/retrieval_base/gt
  task: Retrieval
  evaluation_stage: retrieval
  num_threads: 1
  show_progress_bar: True
  output_path: ./output
  chunk_size: 1024
  chunk_overlap: 0


## 3. Set Random Seed

In [3]:
def setup_seed(seed=0):
    """Seed Python, NumPy and PyTorch RNGs so repeated runs are reproducible.

    Args:
        seed: Integer seed handed to every random number generator.
    """
    import torch

    # Each library keeps its own generator state, so seed them one by one.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Ask cuDNN to pick deterministic kernels.
    torch.backends.cudnn.deterministic = True


setup_seed(0)
print("Random seed set to 0")

Random seed set to 0


## 4. Load Dataset

In [4]:
# Read the QA dataset for the configured task; only the first returned dataset is used
print(f"Loading dataset from {config['data_path']}...")
datasets = get_task_datasets(config['data_path'], config['task'])
dataset = datasets[0]
print(f"Loaded {len(dataset)} data points")

# Preview the first record, clipping long string fields to 100 characters
if len(dataset) > 0:
    print("\nSample data point:")
    sample = dataset[0]
    for field, field_value in sample.items():
        is_long_text = isinstance(field_value, str) and len(field_value) > 100
        shown = f"{field_value[:100]}..." if is_long_text else field_value
        print(f"  {field}: {shown}")

Loading dataset from data/qas_v2_clean.json...
Loaded 7481 data points

Sample data point:
  doc_name: finance/JPMORGAN_2021Q1_10Q
  ID: 00073cc2-c801-467c-9039-fca63c78c6a9
  questions: What was the total amount of nonaccrual loans retained as of March 31, 2021?
  answers: 842
  doc_type: finance
  answer_form: Numeric
  evidence_source: table
  evidence_context: Nonaccrual loans retained $^{(\mathrm{a})}$ & \$ & 842 & \$ & 689 & $22 \%$
  evidence_page_no: 24


## 5. Initialize Model and Retriever

In [5]:
# The retrieval stage never queries the LLM, but the pipeline still expects one
llm = Mock()
print("Initialized Mock LLM")

# Build the retriever selected in the configuration
print(f"\nInitializing {config['retriever_type']} retriever...")

# Chunking and top-k settings are shared by both retriever implementations
shared_retriever_kwargs = {
    'chunk_size': config['chunk_size'],
    'chunk_overlap': config['chunk_overlap'],
    'similarity_top_k': config['retrieve_top_k'],
}

if config['retriever_type'] == "bm25":
    retriever = CustomBM25Retriever(config['docs_path'], **shared_retriever_kwargs)
elif config['retriever_type'] == "bge-m3":
    # The dense retriever additionally needs the embedding model and its dimension
    embed_model = HuggingfaceEmbeddings(model_name="BAAI/bge-m3")
    retriever = CustomBGEM3Retriever(
        config['docs_path'],
        embed_model=embed_model,
        embed_dim=1024,
        **shared_retriever_kwargs,
    )
else:
    raise ValueError(f"Unsupported retriever type: {config['retriever_type']}")

print("Retriever initialized successfully")

Initialized Mock LLM

Initializing bm25 retriever...


Parsing nodes: 100%|██████████| 1011/1011 [00:01<00:00, 588.33it/s]
Parsing nodes: 100%|██████████| 1341/1341 [00:06<00:00, 216.05it/s]
Parsing nodes: 100%|██████████| 2133/2133 [00:05<00:00, 408.20it/s]
Parsing nodes: 100%|██████████| 1187/1187 [00:01<00:00, 857.59it/s]
Parsing nodes: 100%|██████████| 1724/1724 [00:01<00:00, 1493.76it/s]
Parsing nodes: 100%|██████████| 487/487 [00:02<00:00, 205.14it/s]
Parsing nodes: 100%|██████████| 288/288 [00:00<00:00, 1892.83it/s]
Parsing nodes: 100%|██████████| 204/204 [00:00<00:00, 304.32it/s]
Parsing nodes: 100%|██████████| 679/679 [00:00<00:00, 994.32it/s] 


Indexing finished for all directories!
Retriever initialized successfully


## 6. Initialize Retrieval Task

In [6]:
# Results are written under <output_path>/<evaluation_stage>/<ocr_type>
output_dir_parts = [config[field] for field in ('output_path', 'evaluation_stage', 'ocr_type')]
output_dir = os.path.join(*output_dir_parts)

# Create the task and attach the LLM and retriever built above
task = RetrievalTask(output_dir=output_dir)
task.set_model(llm, retriever)
print(f"Retrieval task initialized with output directory: {output_dir}")

Retrieval task initialized with output directory: ./output\retrieval\gt


## 7. Execute Retrieval Pipeline

In [7]:
# Run retrieval and scoring for every data point.
#
# A failure in retrieval or scoring is logged, and the item is then re-scored
# with an empty retrieval list so it still appears in the results. If that
# fallback scoring fails as well, the item is recorded in `failed_ids` and
# skipped, so one bad record cannot abort the whole (multi-minute) run.
results = []
failed_ids = []

print(f"\nProcessing {len(dataset)} data points...")
for data_point in tqdm(dataset, desc="Retrieving", disable=not config['show_progress_bar']):
    # Resolve the ID once, tolerating a missing key, so the error paths below
    # cannot themselves raise KeyError.
    point_id = data_point.get('ID', 'unknown')
    try:
        # Perform retrieval and attach it to the data point for scoring
        data_point["retrieval_results"] = task.retrieve_docs(data_point)
        scored = task.scoring(data_point)
    except Exception as e:
        logger.warning(f"Error processing data point {point_id}: {e}")
        # Fallback: score the item as if nothing had been retrieved
        data_point["retrieval_results"] = []
        try:
            scored = task.scoring(data_point)
        except Exception as fallback_error:
            logger.error(
                f"Fallback scoring failed for data point {point_id}: {fallback_error}"
            )
            failed_ids.append(point_id)
            continue

    results.append({'id': point_id, **scored})

print(f"\nProcessed {len(results)} data points")
if failed_ids:
    print(f"Skipped {len(failed_ids)} data points that could not be scored")


Processing 7481 data points...


Retrieving: 100%|██████████| 7481/7481 [05:22<00:00, 23.21it/s] 


Processed 7481 data points





## 8. Compute Overall Metrics

In [8]:
# Keep only the results the task flagged as valid
valid_results = [entry for entry in results if entry['valid']]
print(f"Valid results: {len(valid_results)} out of {len(results)}")

# Aggregate metrics over the valid results (empty dict when there are none)
if not valid_results:
    overall = {}
    print("No valid results to compute metrics")
else:
    overall = task.compute_overall(valid_results)
    print("\nOverall Metrics:")
    for metric_name, metric_value in overall.items():
        # Floats are shown with four decimals; everything else is printed as-is
        if isinstance(metric_value, float):
            print(f"  {metric_name}: {metric_value:.4f}")
        else:
            print(f"  {metric_name}: {metric_value}")

Valid results: 7481 out of 7481

Overall Metrics:
  avg. lcs: 0.8031
  num: 7481


## 9. Save Results

In [9]:
# Run metadata stored alongside the metrics
retriever_class_name = retriever.__class__.__name__
info = {
    'task': task.__class__.__name__,
    'retriever': retriever_class_name,
    'ocr_type': config['ocr_type'],
    'retrieve_top_k': config['retrieve_top_k'],
    'chunk_size': config['chunk_size'],
    'chunk_overlap': config['chunk_overlap']
}

# Full payload: metadata, aggregate metrics and per-item results
output = {'info': info, 'overall': overall, 'results': results}

# Short retriever tag used in the output file name
retriever_tags = {
    "CustomBM25Retriever": "bm25",
    "CustomBGEM3Retriever": "bge-m3"
}
ret_name = retriever_tags[retriever_class_name]

# Make sure the target directory exists, then build the file path
os.makedirs(output_dir, exist_ok=True)
output_file_name = f'all_{ret_name}_top{config["retrieve_top_k"]}.json'
output_path = os.path.join(output_dir, output_file_name)

# Write the payload as UTF-8 JSON without escaping non-ASCII characters
with open(output_path, 'w', encoding='utf-8') as out_file:
    json.dump(output, out_file, ensure_ascii=False, indent=4)

print(f"\nResults saved to: {output_path}")


Results saved to: ./output\retrieval\gt\all_bm25_top2.json


## 10. Display Results Summary

In [10]:
# Index the raw QA records by ID so every result can be joined to its metadata
with open(config['data_path'], 'r', encoding='utf-8') as qa_file:
    qa_dict = {item['ID']: item for item in json.load(qa_file)}

# One flat row per result that has a matching QA record
results_df = [
    {
        'id': result['id'],
        'ocr_type': config['ocr_type'],
        'retriever': ret_name,
        'domain': qa_dict[result['id']].get('doc_type', ''),
        # doc_name looks like "<domain>/<file>"; keep only the last path segment
        'doc_name': qa_dict[result['id']].get('doc_name', '').split('/')[-1],
        'evidence_source': qa_dict[result['id']].get('evidence_source', ''),
        'answer_form': qa_dict[result['id']].get('answer_form', ''),
        'lcs': result['metrics']['lcs'],
        'valid': result['valid']
    }
    for result in results
    if result['id'] in qa_dict
]

df = pd.DataFrame(results_df)
print(f"\nResults DataFrame shape: {df.shape}")
display(df.head(10))


Results DataFrame shape: (7481, 9)


Unnamed: 0,id,ocr_type,retriever,domain,doc_name,evidence_source,answer_form,lcs,valid
0,00073cc2-c801-467c-9039-fca63c78c6a9,gt,bm25,finance,JPMORGAN_2021Q1_10Q,table,Numeric,0.0,True
1,000b6710-f8b4-4dd4-9913-90c7d424fccf,gt,bm25,finance,JPMORGAN_2021Q1_10Q,table,Numeric,0.0,True
2,00183cfe-ceb0-4220-b984-f33f61c61ae4,gt,bm25,finance,JPMORGAN_2021Q1_10Q,table,Numeric,0.0,True
3,002f9cc4-096b-4aff-b5b7-751f497e28aa,gt,bm25,finance,JPMORGAN_2021Q1_10Q,table,Numeric,1.0,True
4,003c6ab8-2d19-4cf0-8d43-8259815f9e34,gt,bm25,finance,JPMORGAN_2021Q1_10Q,table,Numeric,0.125,True
5,0042d740-0c34-439f-ad44-e0f06a9e72f8,gt,bm25,finance,JPMORGAN_2021Q1_10Q,table,Numeric,1.0,True
6,004cdaf0-0ed9-4a32-8f0f-a9db4b6a3fea,gt,bm25,finance,JPMORGAN_2021Q1_10Q,text,Numeric,1.0,True
7,0068eeac-7cfb-49f6-8de0-4a849afe5363,gt,bm25,finance,DUDE_026e416e05d6efc5f061a2165fd827c3,text,String,0.925926,True
8,006baa01-fdbc-46e7-8734-baefc2e4866f,gt,bm25,finance,JPMORGAN_2021Q1_10Q,table,Numeric,0.0,True
9,007b0a78-f278-4163-9312-8e5cbea3351d,gt,bm25,finance,JPMORGAN_2021Q1_10Q,table,Numeric,1.0,True


## 11. Results Analysis by Evidence Source

In [11]:
# Break the LCS scores down by evidence source
if len(df) > 0:
    # Note: this adds a percentage column to df in place
    df['lcs_percent'] = df['lcs'] * 100

    # Mean and count of the LCS percentage, plus the number of valid rows, per source
    evidence_agg_spec = {
        'lcs_percent': ['mean', 'count'],
        'valid': 'sum'
    }
    evidence_summary = df.groupby('evidence_source').agg(evidence_agg_spec).round(2)

    print("\nResults by Evidence Source:")
    display(evidence_summary)

    # Headline numbers across all evidence sources
    overall_avg = df['lcs_percent'].mean()
    total_valid = df['valid'].sum()
    print(f"\nOverall Average LCS: {overall_avg:.2f}%")
    print(f"Total Valid Results: {total_valid} / {len(df)}")


Results by Evidence Source:


Unnamed: 0_level_0,lcs_percent,lcs_percent,valid
Unnamed: 0_level_1,mean,count,sum
evidence_source,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2
chart,71.1,747,747
formula,80.61,1142,1142
multi,66.42,126,126
reading_order,76.92,52,52
table,75.9,2053,2053
text,85.54,3361,3361



Overall Average LCS: 80.31%
Total Valid Results: 7481 / 7481


## 12. Results Analysis by Domain

In [12]:
# Break the LCS scores down by document domain
if len(df) > 0:
    # Derive the percentage column on a local copy so this cell does not
    # depend on another cell having added 'lcs_percent' to df in place.
    domain_df = df.assign(lcs_percent=df['lcs'] * 100)

    # Mean and count of the LCS percentage, plus the number of valid rows, per domain
    domain_summary = domain_df.groupby('domain').agg({
        'lcs_percent': ['mean', 'count'],
        'valid': 'sum'
    }).round(2)

    print("\nResults by Domain:")
    display(domain_summary)


Results by Domain:


Unnamed: 0_level_0,lcs_percent,lcs_percent,valid
Unnamed: 0_level_1,mean,count,sum
domain,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2
academic,80.63,1150,1150
administration,84.8,1322,1322
finance,62.29,1365,1365
law,85.81,1142,1142
manual,84.54,1107,1107
news,87.85,546,546
textbook,84.13,849,849


## 13. Sample Retrieval Results

In [13]:
def _preview(content, limit=200):
    """Return content clipped to `limit` items with a trailing '...' when longer."""
    return f"{content[:limit]}..." if len(content) > limit else content


# Walk through the first three results and show what was retrieved for each
print("\nSample Retrieval Results (Top 3):")
for position, result in enumerate(results[:3], start=1):
    print(f"\n{'='*80}")
    print(f"Result {position} - ID: {result['id']}")
    print(f"Valid: {result['valid']}")
    print(f"LCS Score: {result['metrics']['lcs']:.4f}")

    log = result['log']
    print(f"\nQuestion: {log['quest']}")
    print(f"\nEvidence Source: {log['evidence_source']}")

    retrieved_docs = log['retrieval_context'] if 'retrieval_context' in log else []
    if len(retrieved_docs) == 0:
        print("\nNo documents retrieved")
    else:
        print(f"\nRetrieved {len(retrieved_docs)} documents:")
        for doc_number, doc in enumerate(retrieved_docs, start=1):
            print(f"\n  Document {doc_number}:")
            print(f"    File: {doc.get('file_name', 'N/A')}")
            print(f"    Page: {doc.get('page_idx', 'N/A')}")
            text = doc.get('text', '')
            print(f"    Text: {_preview(text)}")

    # Ground-truth evidence the retrieved text is compared against
    gt_context = log['ground_truth_context']
    print(f"\nGround Truth Context: {_preview(gt_context)}")


Sample Retrieval Results (Top 3):

Result 1 - ID: 00073cc2-c801-467c-9039-fca63c78c6a9
Valid: True
LCS Score: 0.0000

Question: What was the total amount of nonaccrual loans retained as of March 31, 2021?

Evidence Source: table

Retrieved 2 documents:

  Document 1:
    File: JPMORGAN_2021Q1_10Q
    Page: 28
    Text: Selected metrics (continued)
\begin{tabular}{|c|c|c|c|c|c|}
  \multirow[b]{2}{*}{(in millions, except ratios)} & \multicolumn{5}{|c|}{As of or for the three months ended March 31,} \\
  & \multicolumn...

  Document 2:
    File: JPMORGAN_2021Q1_10Q
    Page: 49
    Text: \begin{tabular}{|c|c|c|c|c|}
  (in millions) & & $$\operatorname{arch} 31, 2021$$ & \multicolumn{2}{|l|}{$$\text { December 31, } 2020$$} \\
  Retained loans ${ }^{(2)}$ & \$ & 14,943 & \$ & 15,406 \\...

Ground Truth Context: Nonaccrual loans retained $^{(\mathrm{a})}$ & \$ & 842 & \$ & 689 & $22 \%$

Result 2 - ID: 000b6710-f8b4-4dd4-9913-90c7d424fccf
Valid: True
LCS Score: 0.0000

Question: By what p