# Tutorial 3 — Reranking (Two-Stage Retrieval)

This tutorial keeps semantic chunking and dense retrieval, then adds a second-stage reranker.

```mermaid
flowchart LR
    A[Query] --> B[Dense Retriever Top-10]
    B --> C[Cross-Encoder Reranker]
    C --> D[Reordered Top-5]
    D --> E[LLM Answer]
```

## Learning checkpoint: what reranking solves and what remains

**What works better in Tutorial 3**
- Candidate chunks are re-ordered with stronger query awareness.
- More relevant context should move closer to rank 1.
- Answer quality typically improves when top context is cleaner.

**Challenges you should observe**
- Latency increases due to second-stage scoring.
- Reranking cannot recover chunks that were never retrieved in the first pass.
- Exact lexical matches can still be missed if dense retrieval under-recalls them.

**Why move to Tutorial 4**
- We now need stronger recall for exact terms and identifiers.
- Next, we combine dense retrieval with keyword retrieval (hybrid) for coverage + precision.

In [None]:
# 1-5) Setup, load handbook text, chunk, embed, index

import importlib
import os
from pathlib import Path
import shutil
import subprocess
import sys

import pandas as pd
from dotenv import load_dotenv

# Make sure the `uv` executable can be found; when it cannot, install it with
# the pip that belongs to the interpreter running this notebook.
if shutil.which("uv") is None:
    print("uv not found. Installing with pip...")
    subprocess.run([sys.executable, "-m", "pip", "install", "uv"], check=True)

# Repository root = the closest directory, starting at the current one and
# walking upwards, that contains both `pyproject.toml` and `src`. If none
# qualifies, the current directory is used unchanged.
cwd = Path.cwd().resolve()
repo_root = cwd
for candidate in (cwd, *cwd.parents):
    if (candidate / "pyproject.toml").exists() and (candidate / "src").exists():
        repo_root = candidate
        break

# Switch to the repository root so relative paths resolve from there, and put
# `src` at the front of sys.path (only once) for the project imports below.
os.chdir(repo_root)
src_path = repo_root / "src"
if str(src_path) not in sys.path:
    sys.path.insert(0, str(src_path))

# Import names that must be resolvable in this kernel.
REQUIRED_PACKAGES = [
    "openai",
    "chromadb",
    "numpy",
    "pandas",
    "rank_bm25",
    "sentence_transformers",
    "dotenv",
]
# pip distribution names for the entries whose import name differs from them.
PIP_NAME_MAP = {
    "rank_bm25": "rank-bm25",
    "sentence_transformers": "sentence-transformers",
    "dotenv": "python-dotenv",
}

def find_missing(packages: list[str]) -> list[str]:
    """Return the names in ``packages`` that cannot currently be imported.

    Args:
        packages: Import names (not pip distribution names) to look up.

    Returns:
        The subset of ``packages``, in the original order, for which no import
        spec can be found. Empty when everything is resolvable.
    """
    # A bare ``import importlib`` does not guarantee that the ``importlib.util``
    # submodule is loaded, so import it explicitly before using it.
    import importlib.util

    # Drop cached finder state so packages installed since the last call
    # (e.g. by a subprocess) become visible to the lookup below.
    importlib.invalidate_caches()

    missing: list[str] = []
    for pkg in packages:
        try:
            found = importlib.util.find_spec(pkg) is not None
        except ModuleNotFoundError:
            # Raised for a dotted name whose parent package is absent; that is
            # simply another way of being missing.
            found = False
        if not found:
            missing.append(pkg)
    return missing

# Step 1: if any required import is unresolved, run `uv sync`.
if missing := find_missing(REQUIRED_PACKAGES):
    print("Missing packages:", missing)
    print("Running: uv sync")
    subprocess.run(["uv", "sync"], check=True)

# Step 2: anything still unresolved is installed with the pip of the
# interpreter running this kernel, using the pip name where it differs from
# the import name.
if missing_after_sync := find_missing(REQUIRED_PACKAGES):
    pip_targets = [PIP_NAME_MAP.get(name, name) for name in missing_after_sync]
    print("Installing into current kernel with pip:", pip_targets)
    subprocess.run([sys.executable, "-m", "pip", "install", *pip_targets], check=True)

# Step 3: stop here with a clear error if the imports below would still fail.
if final_missing := find_missing(REQUIRED_PACKAGES):
    raise ImportError(f"Dependencies still missing in current kernel: {final_missing}")

from rag_tutorials.io_utils import load_handbook_documents, load_queries
from rag_tutorials.chunking import semantic_chunk_documents
from rag_tutorials.pipeline import build_dense_retriever
from rag_tutorials.reranking import LocalCrossEncoderReranker
from rag_tutorials.qa import answer_with_context
from rag_tutorials.evaluation import evaluate_single, summarize

# load_dotenv() runs before the key check, presumably so that a key defined in
# a local .env file counts — confirm against python-dotenv's behaviour.
load_dotenv()
if not os.environ.get("OPENAI_API_KEY"):
    raise EnvironmentError("OPENAI_API_KEY is required")

# Model names can be overridden through the environment; otherwise use defaults.
embedding_model = os.environ.get("OPENAI_EMBEDDING_MODEL", "text-embedding-3-small")
chat_model = os.environ.get("OPENAI_CHAT_MODEL", "gpt-4.1-mini")

# Both data files must exist; the error message names the script to run.
handbook_path = Path("data") / "handbook_manual.txt"
queries_path = Path("data") / "queries.jsonl"
if not (handbook_path.exists() and queries_path.exists()):
    raise FileNotFoundError("Run: uv run python scripts/generate_data.py")

# Load the handbook and the queries, then chunk the handbook documents.
documents = load_handbook_documents(handbook_path)
queries = load_queries(queries_path)
chunks = semantic_chunk_documents(documents)

# Only the first element returned by build_dense_retriever is kept; the second
# is discarded.
dense_retriever, _ = build_dense_retriever(
    chunks=chunks,
    collection_name="tutorial3_dense_semantic",
    embedding_model=embedding_model,
)
reranker = LocalCrossEncoderReranker()

In [None]:
# Chunk boundary visualization (same source text, different split strategies)

from rag_tutorials.chunking import fixed_chunk_documents

# Pick the handbook section shown in the comparison. `next` gets a None
# default so that a missing section fails with a descriptive error instead of
# a bare, message-less StopIteration.
section_doc = next((doc for doc in documents if doc.section == "International Work"), None)
if section_doc is None:
    raise LookupError('No handbook document has section "International Work"')

# Chunk the same single document with both strategies before printing anything.
fixed_view = [c.text for c in fixed_chunk_documents([section_doc], chunk_size=120)]
semantic_view = [c.text for c in semantic_chunk_documents([section_doc])]

print("Section:", section_doc.section)
# Same listing format for both strategies: a heading, then 1-based numbered chunks.
for heading, view in (("\nFixed chunks:", fixed_view), ("\nSemantic chunks:", semantic_view)):
    print(heading)
    for idx, chunk_text in enumerate(view, start=1):
        print(f"[{idx}] {chunk_text}")

In [None]:
# 6) Retriever + reranker logic and novice score inspection

def retrieve_with_rerank(question: str, first_stage_k: int = 10, final_k: int = 5):
    """Run dense retrieval for ``question`` and rerank what it returns.

    Args:
        question: Query text handed to both stages.
        first_stage_k: Passed to ``dense_retriever`` as ``top_k``.
        final_k: Passed to ``reranker.rerank`` as ``top_k``.

    Returns:
        A ``(first_pass, reranked)`` pair: the dense retriever's output and
        the reranker's output computed from it.
    """
    candidates = dense_retriever(question, top_k=first_stage_k)
    return candidates, reranker.rerank(question, candidates, top_k=final_k)

# Probe question used to compare the ordering before and after reranking.
probe = "What is the policy for working from another country?"
first_pass, reranked = retrieve_with_rerank(probe)

# One row per result with a 1-based rank; `preview` is `text[:90]`.
before_df = pd.DataFrame(
    [
        {"rank": rank, "chunk_id": hit.chunk_id, "dense_score": hit.score, "preview": hit.text[:90]}
        for rank, hit in enumerate(first_pass, start=1)
    ]
)
after_df = pd.DataFrame(
    [
        {"rank": rank, "chunk_id": hit.chunk_id, "rerank_score": hit.score, "preview": hit.text[:90]}
        for rank, hit in enumerate(reranked, start=1)
    ]
)

print("Before reranking")
display(before_df.head(10))
print("After reranking")
display(after_df.head(5))

In [None]:
# 7-8) Prompt + end-to-end RAG query

def rag_answer_reranked(question: str, top_k: int = 5):
    """Answer ``question`` from reranked context.

    Args:
        question: Query text to retrieve for and to answer.
        top_k: Passed to ``retrieve_with_rerank`` as ``final_k``.

    Returns:
        An ``(answer, ranked)`` pair: the value returned by
        ``answer_with_context`` and the reranked results whose ``text``
        attributes were used as context.
    """
    # NOTE(review): the first stage is fixed at first_stage_k=10, so a top_k
    # above 10 presumably cannot yield more than 10 results — confirm intended.
    ranked = retrieve_with_rerank(question, first_stage_k=10, final_k=top_k)[1]
    answer = answer_with_context(question, [hit.text for hit in ranked], model=chat_model)
    return answer, ranked

# Run retrieval, reranking and answering on the probe question, then print the
# answer. `ranked` keeps the reranked results that were used as context.
answer, ranked = rag_answer_reranked(probe)
print(answer)

In [None]:
# 9-10) Evaluation queries and debug output

def retrieval_fn(question: str):
    """Return the reranked results for ``question``.

    Calls ``retrieve_with_rerank`` with ``first_stage_k=10`` and ``final_k=5``
    and returns only the second element of its result (the reranked output).
    """
    reranked_hits = retrieve_with_rerank(question, first_stage_k=10, final_k=5)[1]
    return reranked_hits

def _answer_for_eval(question, context):
    """Answer ``question`` from ``context`` with the configured chat model."""
    return answer_with_context(question, context, model=chat_model)


# Evaluate at most the first 20 queries with the reranked retriever.
rows = [
    evaluate_single(query=q, retrieval_fn=retrieval_fn, answer_fn=_answer_for_eval, top_k=5)
    for q in queries[:20]
]

print("Tutorial 3 metrics:", summarize(rows))