# Open-Domain QA: Train & Inference Walkthrough


This notebook walks through the minimal commands to install dependencies, prepare the data, train DPR (optional, but required for the dense and hybrid retrieval tests), train the MRC model, compare retrieval methods on the validation set, and run ODQA inference to produce `predictions.json`.


> Run cells top-to-bottom. Adjust hyperparameters/paths as needed.

In [None]:
# Record the interpreter and working directory used for this run (run this once at the start).
import os
import sys

print("\n".join(
    f"{label}: {value}"
    for label, value in (
        ("Working directory", os.getcwd()),
        ("Using Python", sys.executable),
        ("Python version", sys.version),
    )
))

In [None]:
# Environment setup (install dependencies)
!pip install -r requirements.txt

In [None]:
# Verify data layout (expects ../data with train/test + wikipedia_documents.json)
# If `ls -l ../data` fails (directory missing), extract data.tar.gz from the parent directory
# and list the result. NOTE(review): assumes the archive unpacks into `data/` — confirm.
!ls -l ../data || (cd .. && tar -xzf data.tar.gz && ls -l data)

In [None]:
# Train the DPR question/passage encoders (needed for the dense and hybrid retrieval tests below).
# Set `--device cpu` if no GPU is available (slower).
# NOTE(review): the test cells load ./models/dpr/q_encoder and ./models/dpr/p_encoder; presumably
# DPR_train.py saves them there under --output_dir — confirm.
# NOTE(review): if DPR_train.py parses --use_hard_negatives with `type=bool`, any non-empty
# string (including "False") is truthy — confirm how the flag is parsed.
# NOTE(review): effective batch size is presumably batch_size * gradient_accumulation_steps — confirm.
!python retrieval/DPR_train.py \
  --model_name_or_path klue/bert-base \
  --output_dir ./models/dpr \
  --device cuda \
  --learning_rate 2e-5 \
  --num_epochs 10 \
  --batch_size 32 \
  --gradient_accumulation_steps 2 \
  --max_q_length 64 \
  --max_p_length 256 \
  --use_hard_negatives True \
  --num_neg 2 \
  --warmup_steps 500 \
  --save_steps 500 \
  --eval_steps 500 \
  --overwrite_output_dir

In [None]:
# Train MRC model (with retrieval-enabled eval)
# The trained model goes to ./models/train_dataset, which the inference cell below loads via
# --model_name_or_path. No --model_name_or_path is passed here, so train.py's default base model applies.
!python train.py \
  --output_dir ./models/train_dataset \
  --do_train --do_eval \
  --overwrite_output_dir

## Test Retrieval Performance

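Compare the retrieval methods (sparse BM25, dense DPR, and a BM25+DPR hybrid) on the validation set to see which performs best. "Accuracy" below is top-10 retrieval accuracy: the fraction of questions whose ground-truth passage appears among the 10 retrieved passages.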

In [None]:
# Shared setup for the retrieval tests: imports, the validation split, and the tokenizer.
import pandas as pd
from datasets import load_from_disk
from transformers import AutoTokenizer

from retrieval.Dense_retrieval import DenseRetrieval
from retrieval.retrieval import Retrieval
from retrieval.Sparse_retrieval import SparseRetrieval

# The validation split lives alongside the training split in ../data/train_dataset.
print("Loading validation dataset...")
datasets = load_from_disk("../data/train_dataset")
val_dataset = datasets["validation"]
print("Validation set size: {}".format(len(val_dataset)))

# Fast tokenizer; its `.tokenize` is handed to the sparse/hybrid retrievers as `tokenize_fn`.
tokenizer = AutoTokenizer.from_pretrained("klue/bert-base", use_fast=True)

In [None]:
# Test 1: Sparse Retrieval (BM25) only
# Requires `tokenizer` and `val_dataset` from the setup cell above.
print("\n" + "="*60)
print("Testing Sparse Retrieval (BM25)")
print("="*60)

# Discard any accuracy from an earlier run of this cell. Otherwise, if accuracy cannot be
# computed this time, the comparison cell would silently report the stale value.
globals().pop("sparse_accuracy", None)

sparse_retriever = SparseRetrieval(
    tokenize_fn=tokenizer.tokenize,
    data_path="../data",
    context_path="wikipedia_documents.json",
)
# NOTE(review): unlike the dense cell, no explicit embedding/index build is called here;
# presumably SparseRetrieval builds or loads its index on its own — confirm.

# Retrieve top-k passages for each question
sparse_results = sparse_retriever.retrieve(val_dataset, topk=10)

# Calculate accuracy: a question counts as correct when its ground-truth passage
# ("original_context") is contained in the retrieved "context" value.
if "original_context" in sparse_results.columns and "context" in sparse_results.columns:
    sparse_results["correct"] = sparse_results.apply(
        lambda row: row["original_context"] in row["context"], axis=1
    )
    sparse_accuracy = sparse_results["correct"].sum() / len(sparse_results)
    print(f"Sparse Retrieval Accuracy: {sparse_accuracy:.2%}")
    print(f"Correctly retrieved: {sparse_results['correct'].sum()}/{len(sparse_results)}")
else:
    print("Cannot calculate accuracy - missing ground truth context")

In [None]:
# Test 2: Dense Retrieval (DPR) only - requires trained encoders
# Requires `val_dataset` from the setup cell above.
print("\n" + "="*60)
print("Testing Dense Retrieval (DPR)")
print("="*60)

# Discard any accuracy from an earlier run of this cell. If this run fails (e.g. encoders
# missing), the comparison cell must not report a stale number as if it were current.
globals().pop("dense_accuracy", None)

try:
    dense_retriever = DenseRetrieval(
        model_name_or_path="klue/bert-base",
        data_path="../data",
        context_path="wikipedia_documents.json",
        q_encoder_path="./models/dpr/q_encoder",
        p_encoder_path="./models/dpr/p_encoder",
    )
    
    # Build dense embeddings
    print("Building dense embeddings...")
    dense_retriever.get_dense_embedding()
    print(f"Dense embeddings shape: {dense_retriever.passage_embeddings.shape if hasattr(dense_retriever, 'passage_embeddings') else 'Not computed'}")
    
    # Retrieve using dense method
    print("Retrieving passages...")
    dense_results = dense_retriever.retrieve(val_dataset, topk=10)
    
    # Calculate accuracy: correct when the ground-truth passage is contained in the retrieved "context"
    if "original_context" in dense_results.columns and "context" in dense_results.columns:
        dense_results["correct"] = dense_results.apply(
            lambda row: row["original_context"] in row["context"], axis=1
        )
        dense_accuracy = dense_results["correct"].sum() / len(dense_results)
        print(f"Dense Retrieval Accuracy: {dense_accuracy:.2%}")
        print(f"Correctly retrieved: {dense_results['correct'].sum()}/{len(dense_results)}")
        
        # Show sample results
        print("\nSample DPR Retrieval Results (first 3 examples):")
        for idx in range(min(3, len(dense_results))):
            print(f"\nExample {idx + 1}:")
            print(f"  Question: {dense_results.iloc[idx]['question'][:80]}...")
            print(f"  Original context found: {dense_results.iloc[idx]['correct']}")
            print(f"  Retrieved passage preview: {dense_results.iloc[idx]['context'][:100]}...")
    else:
        print("Cannot calculate accuracy - missing ground truth context")
        
except Exception as e:
    # Best-effort: report and keep the notebook running so the other methods can still be compared.
    print(f"Dense retrieval failed: {e}")
    print("Make sure DPR encoders are trained (run DPR training cell first)")
    import traceback
    traceback.print_exc()

In [None]:
# Test 3: Hybrid Retrieval (BM25 + DPR)
# Requires `tokenizer` and `val_dataset` from the setup cell above.
print("\n" + "="*60)
print("Testing Hybrid Retrieval (BM25 + DPR)")
print("="*60)

# Discard any accuracy from an earlier run of this cell. If this run fails, the comparison
# cell must not report a stale number as if it were current.
globals().pop("hybrid_accuracy", None)

try:
    # Equal weighting of the sparse and dense scores.
    hybrid_retriever = Retrieval(
        tokenize_fn=tokenizer.tokenize,
        data_path="../data",
        context_path="wikipedia_documents.json",
        use_sparse=True,
        use_dense=True,
        dense_model_path="klue/bert-base",
        q_encoder_path="./models/dpr/q_encoder",
        p_encoder_path="./models/dpr/p_encoder",
        sparse_weight=0.5,
        dense_weight=0.5,
    )
    
    # Retrieve using hybrid method
    hybrid_results = hybrid_retriever.retrieve(val_dataset, topk=10)
    
    # Calculate accuracy: correct when the ground-truth passage is contained in the retrieved "context"
    if "original_context" in hybrid_results.columns and "context" in hybrid_results.columns:
        hybrid_results["correct"] = hybrid_results.apply(
            lambda row: row["original_context"] in row["context"], axis=1
        )
        hybrid_accuracy = hybrid_results["correct"].sum() / len(hybrid_results)
        print(f"Hybrid Retrieval Accuracy: {hybrid_accuracy:.2%}")
        print(f"Correctly retrieved: {hybrid_results['correct'].sum()}/{len(hybrid_results)}")
    else:
        print("Cannot calculate accuracy - missing ground truth context")
        
except Exception as e:
    # Best-effort: report and keep going, but show the traceback (as the dense cell does)
    # so failures that are not "encoders missing" remain diagnosable.
    print(f"Hybrid retrieval failed: {e}")
    print("Make sure DPR encoders are trained (run DPR training cell first)")
    import traceback
    traceback.print_exc()

In [None]:
# Compare results
print("\n" + "="*60)
print("RETRIEVAL PERFORMANCE COMPARISON")
print("="*60)

# One entry per retrieval test cell: (display label, accuracy variable, results-frame variable).
# A method is listed only if its test cell defined the accuracy variable in this kernel session.
_RETRIEVAL_METHODS = [
    ("Sparse (BM25)", "sparse_accuracy", "sparse_results"),
    ("Dense (DPR)", "dense_accuracy", "dense_results"),
    ("Hybrid (BM25+DPR)", "hybrid_accuracy", "hybrid_results"),
]

# This cell runs at notebook top level, so globals() is the interactive namespace.
_notebook_ns = globals()
comparison_data = []
for _label, _acc_var, _res_var in _RETRIEVAL_METHODS:
    if _acc_var not in _notebook_ns:
        continue
    _results = _notebook_ns[_res_var]
    comparison_data.append({
        'Method': _label,
        'Accuracy': f"{_notebook_ns[_acc_var]:.2%}",
        'Correct': f"{_results['correct'].sum()}/{len(_results)}",
    })

if comparison_data:
    comparison_df = pd.DataFrame(comparison_data)
    print(comparison_df.to_string(index=False))
    print("\n✓ Higher accuracy means the retrieval method finds the correct context more often")
else:
    print("No results to compare. Run the test cells above first.")

In [None]:
# Run ODQA inference on test set (produces predictions.json in output_dir)
# Loads the MRC model written by the training cell (./models/train_dataset).
# NOTE(review): --top_k_retrieval 20 differs from the topk=10 used in the validation retrieval
# tests above; align them if the comparison should reflect inference-time retrieval.
!python inference.py \
  --output_dir ./outputs/test_dataset \
  --dataset_name ../data/test_dataset \
  --model_name_or_path ./models/train_dataset \
  --do_predict \
  --eval_retrieval True \
  --top_k_retrieval 20 \
  --overwrite_output_dir

## Notes

- Ensure a GPU is available for DPR/MRC training; if none is available, pass `--device cpu` to `DPR_train.py` (slower).

- `DPR_train.py` uses BM25 hard negatives; requires `rank_bm25` installed (in requirements).

- If you change the retrieval code, delete the cached `sparse_embedding.bin` / `tfidfv.bin` files and rerun.

- `predictions.json` will be saved to `./outputs/test_dataset` in the inference step.