# RAG Pipeline Testing & Fine-tuning Notebook

This notebook allows you to test the complete RAG pipeline:
1. PDF Extraction
2. Chunking (namespace-specific strategies)
3. Pinecone Ingestion
4. RAG Search & Retrieval
5. Agent Routing
6. End-to-End Chat with Citations

**Setup**: Run from the `backend/` directory:
```bash
cd backend
jupyter notebook notebooks/test_rag_pipeline.ipynb
```

In [22]:
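# Restart the kernel after this cell so upgraded packages (especially NumPy) are actually loaded
%pip install langchain_core
%pip install pinecone
%pip install python-dotenv
%pip install langchain_openai
%pip install supabase
%pip install pdfplumber pymupdf
%pip install -U "numpy>=1.23"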


In [24]:
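import numpy
from supabase import create_client, Client

# The backend's compiled dependencies need NumPy >= 1.23
numpy_version = tuple(int(part) for part in numpy.__version__.split(".")[:2])
if numpy_version >= (1, 23):
    print(f"✅ NumPy version: {numpy.__version__}")
else:
    print(f"❌ NumPy {numpy.__version__} is older than 1.23: upgrade it and restart the kernel")
print("✅ Supabase client imported successfully")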

ImportError: cannot import name 'create_client' from 'supabase' (unknown location)
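`(unknown location)` in this ImportError usually means Python resolved `supabase` as a namespace package (a plain directory with no `__init__.py` found on `sys.path`) instead of the installed client. One plausible cause, which the traceback does not confirm, is a `supabase/` folder (such as the one the Supabase CLI creates for migrations) in the working directory or in `backend/`, which an earlier cell put at the front of `sys.path`. A minimal diagnostic sketch using only `importlib.util.find_spec`; `describe_module_origin` is a hypothetical helper:

```python
import importlib.util
from typing import Any, Dict


def describe_module_origin(name: str) -> Dict[str, Any]:
    """Report where `import <name>` resolves, without executing the module."""
    spec = importlib.util.find_spec(name)
    if spec is None:
        return {"found": False, "namespace": False, "origin": None, "locations": []}
    locations = list(spec.submodule_search_locations or [])
    # Namespace packages (directories without __init__.py) have no origin file
    is_namespace = spec.origin is None and bool(locations)
    return {"found": True, "namespace": is_namespace, "origin": spec.origin, "locations": locations}


info = describe_module_origin("supabase")
print(info)
if info["namespace"]:
    print("⚠️ 'supabase' resolved to a namespace package; these directories shadow the client:")
    for path in info["locations"]:
        print(f"   {path}")
```

If a shadowing directory shows up, renaming it or removing its parent from `sys.path` should let `from supabase import create_client` reach the installed package.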

## 1. Setup & Imports

In [12]:
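# Add backend to path
import sys
from pathlib import Path


def find_backend_dir(start: Path) -> Path:
    """Walk up from `start` to the backend directory (the one containing app/)."""
    for candidate in [start, *start.parents]:
        if (candidate / 'backend' / 'app').is_dir():
            return candidate / 'backend'
        if (candidate / 'app').is_dir():
            return candidate
    raise FileNotFoundError(f"Could not locate backend/ from {start}")


# Works whether the kernel starts in the project root, backend/, or backend/notebooks/
backend_dir = find_backend_dir(Path.cwd().resolve())
project_root = backend_dir.parent

# Add to sys.path if it's not already there
if str(backend_dir) not in sys.path:
    sys.path.insert(0, str(backend_dir))

print(f"Added to path: {backend_dir}")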

Added to path: /Users/dhairyapatel/Desktop/FYP/diabetes_fyp/backend


In [13]:
# Core imports
import os
import logging
from dotenv import load_dotenv
from pprint import pprint

# Load environment variables
load_dotenv(backend_dir / ".env")

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Verify environment variables (report presence only, never print secret values)
def env_status(var: str) -> str:
    return f"✓ {var}: Set" if os.getenv(var) else f"❌ {var}: Missing"

print("Environment Check:")
print(env_status("OPENAI_API_KEY"))
print(env_status("PINECONE_API_KEY"))
print(f"✓ PINECONE_INDEX: {os.getenv('PINECONE_INDEX', 'Not set')}")

Environment Check:
✓ OPENAI_API_KEY: Set
✓ PINECONE_API_KEY: Set
✓ PINECONE_INDEX: diabetes-medical-knowledge


In [20]:
# Import RAG services and agents
from app.services.rag_service import (
    get_rag_service,
    NAMESPACE_CLINICAL_SAFETY,
    NAMESPACE_CULTURAL_DIET,
    NAMESPACE_LIFESTYLE_PATTERNS,
)
from app.agents.router_agent import RouterState, route_intent
from app.agents.clinical_safety_agent import ClinicalSafetyState, check_clinical_safety
from app.agents.lifestyle_analyst_agent import LifestyleState, analyze_lifestyle
from app.schemas.patient_context import PatientContext
from app.schemas.enhanced_patient_context import EnhancedPatientContext

print("✓ Imports successful")

RuntimeError: module was compiled against NumPy C-API version 0x10 (NumPy 1.23) but the running NumPy has C-API version 0xf. Check the section C-API incompatibility at the Troubleshooting ImportError section at https://numpy.org/devdocs/user/troubleshooting-importerror.html#c-api-incompatibility for indications on how to solve this problem.

ImportError: cannot import name 'Client' from 'supabase' (unknown location)
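The NumPy `RuntimeError` says the running NumPy is older than the one a compiled dependency was built against. After `%pip install` upgrades a package, the kernel keeps using the copy it already imported until it is restarted. A small sketch (`loaded_vs_installed` is a hypothetical helper) that compares the version installed on disk, via `importlib.metadata.version`, with the `__version__` of the module already loaded in this kernel:

```python
import importlib.metadata
import sys
from typing import Optional, Tuple


def loaded_vs_installed(dist_name: str, module_name: Optional[str] = None) -> Optional[Tuple[str, Optional[str]]]:
    """Return (installed_version, loaded_version), or None if the distribution is not installed.

    loaded_version is None when the module has not been imported in this kernel yet.
    """
    try:
        installed = importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None
    module = sys.modules.get(module_name or dist_name)
    loaded = getattr(module, "__version__", None) if module is not None else None
    return installed, loaded


versions = loaded_vs_installed("numpy")
if versions is None:
    print("numpy: not installed")
else:
    installed, loaded = versions
    if loaded is None:
        print(f"numpy: installed={installed}, not imported in this kernel yet")
    elif loaded != installed:
        print(f"⚠️ numpy: installed={installed} but loaded={loaded}; restart the kernel")
    else:
        print(f"✓ numpy: {installed}")
```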

## 2. PDF Extraction Testing

In [None]:
# Import PDF extraction utilities
try:
    import pdfplumber
    PDFPLUMBER_AVAILABLE = True
    print("✓ pdfplumber available")
except ImportError:
    PDFPLUMBER_AVAILABLE = False
    print("❌ pdfplumber not available")

try:
    import fitz  # PyMuPDF
    PYMUPDF_AVAILABLE = True
    print("✓ PyMuPDF available")
except ImportError:
    PYMUPDF_AVAILABLE = False
    print("❌ PyMuPDF not available")

def extract_text_from_pdf(pdf_path: Path) -> str:
    """Extract text from a PDF, trying pdfplumber first and falling back to PyMuPDF."""
    if PDFPLUMBER_AVAILABLE:
        try:
            pages = []
            with pdfplumber.open(pdf_path) as pdf:
                for page_num, page in enumerate(pdf.pages, 1):
                    text = page.extract_text()
                    if text:
                        pages.append(f"Page {page_num}:\n{text}\n")
            if pages:
                print(f"✓ Extracted text from {len(pages)} pages using pdfplumber")
                return "\n".join(pages)
            print("pdfplumber returned no text, trying PyMuPDF")
        except Exception as e:
            print(f"pdfplumber failed: {e}")

    if PYMUPDF_AVAILABLE:
        try:
            pages = []  # start fresh so partial pdfplumber output is never mixed in
            with fitz.open(pdf_path) as doc:
                for page_num, page in enumerate(doc, 1):
                    text = page.get_text()
                    if text.strip():
                        pages.append(f"Page {page_num}:\n{text}\n")
            print(f"✓ Extracted text from {len(pages)} pages using PyMuPDF")
            return "\n".join(pages)
        except Exception as e:
            print(f"PyMuPDF failed: {e}")

    raise RuntimeError(f"Could not extract text from {pdf_path.name}. Install pdfplumber or pymupdf.")

In [None]:
# Test PDF extraction on a sample file
rag_docs_dir = backend_dir / "rag_docs"
clinical_docs_dir = rag_docs_dir / "clinical_safety_docs"

# List available PDFs
if clinical_docs_dir.exists():
    pdf_files = list(clinical_docs_dir.glob("*.pdf"))
    print(f"Found {len(pdf_files)} PDFs in clinical_safety_docs/")
    for pdf in pdf_files[:5]:  # Show first 5
        print(f"  - {pdf.name} ({pdf.stat().st_size / 1024:.0f} KB)")
    
    if pdf_files:
        # Test extraction on first small PDF
        test_pdf = min(pdf_files, key=lambda p: p.stat().st_size)  # Smallest file
        print(f"\nTesting extraction on: {test_pdf.name}")
        
        extracted_text = extract_text_from_pdf(test_pdf)
        print(f"\nExtracted {len(extracted_text)} characters")
        print("\nFirst 500 characters:")
        print("=" * 80)
        print(extracted_text[:500])
        print("=" * 80)
else:
    print("❌ rag_docs/clinical_safety_docs/ directory not found")
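Scanned PDFs often yield pages with little or no extractable text, and `extract_text_from_pdf` silently skips such pages. A sketch that relies on the `Page N:` markers that function inserts to report how many non-whitespace characters each page contributed and which page numbers are missing; `page_text_stats`, `missing_pages` and the 100-character cutoff are hypothetical choices, not part of the pipeline:

```python
import re
from typing import Dict, List

PAGE_HEADER = re.compile(r"^Page (\d+):$", re.MULTILINE)


def page_text_stats(extracted: str) -> Dict[int, int]:
    """Map page number -> count of non-whitespace characters extracted for that page."""
    matches = list(PAGE_HEADER.finditer(extracted))
    stats: Dict[int, int] = {}
    for i, match in enumerate(matches):
        end = matches[i + 1].start() if i + 1 < len(matches) else len(extracted)
        body = extracted[match.end():end]
        stats[int(match.group(1))] = len("".join(body.split()))
    return stats


def missing_pages(stats: Dict[int, int]) -> List[int]:
    """Page numbers skipped between 1 and the last page that produced text."""
    if not stats:
        return []
    return sorted(set(range(1, max(stats) + 1)) - set(stats))


if 'extracted_text' in globals():
    per_page = page_text_stats(extracted_text)
    sparse = [page for page, n in per_page.items() if n < 100]  # arbitrary cutoff
    print(f"Pages with text: {len(per_page)}")
    print(f"Pages with no extracted text: {missing_pages(per_page)}")
    print(f"Sparse pages (<100 chars): {sparse}")
```

Empty pages after the last page with text cannot be detected this way, since the extraction output carries no total page count.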

## 3. Chunking Strategy Testing

In [None]:
# Import chunking dependencies
import re
from typing import List, Dict, Any

# "Page N:" markers inserted by extract_text_from_pdf; they end with ':' and would
# otherwise be misdetected as section headers
PAGE_MARKER = re.compile(r'^Page \d+:$')


def _make_chunk(lines: List[str], header_stack: List[str], source: str) -> Dict[str, Any]:
    """Build one chunk dict, prefixing the content with its header path."""
    header_path = " > ".join(header_stack) if header_stack else "Introduction"
    return {
        "content": f"Header: {header_path}\n" + "\n".join(lines),
        "metadata": {
            "source": source,
            "header_path": header_path,
            "tags": [h.lower().replace(" ", "_") for h in header_stack],
        },
    }


def chunk_with_headers(text: str, source: str, min_chunk_size: int = 200, max_chunk_size: int = 1000) -> List[Dict[str, Any]]:
    """Chunk text preserving headers (for clinical safety documents)."""
    chunks = []
    header_stack: List[str] = []
    current_chunk: List[str] = []
    current_chunk_size = 0

    for line in text.split('\n'):
        line_stripped = line.strip()
        if not line_stripped or PAGE_MARKER.match(line_stripped):
            continue

        # Heuristic header detection
        is_header = (
            (line_stripped.isupper() and 3 < len(line_stripped) < 100)
            or re.match(r'^\d+\.?\s+[A-Z]', line_stripped)
            or re.match(r'^[A-Z][A-Z\s]{3,}', line_stripped)
            or (line_stripped.endswith(':') and len(line_stripped) < 80)
        )

        if is_header:
            if current_chunk_size >= min_chunk_size:
                # Close the previous section; the new chunk starts at this header
                chunks.append(_make_chunk(current_chunk, header_stack, source))
                current_chunk, current_chunk_size = [], 0
            # Otherwise keep accumulating, so short sections are not silently dropped

            # Keep at most 3 header levels; beyond that, replace the deepest one
            if len(header_stack) < 3:
                header_stack.append(line_stripped)
            else:
                header_stack[-1] = line_stripped

        current_chunk.append(line_stripped)
        current_chunk_size += len(line_stripped) + 1

        if current_chunk_size >= max_chunk_size:
            chunks.append(_make_chunk(current_chunk, header_stack, source))
            current_chunk, current_chunk_size = [], 0

    # Flush the remainder, even if it is shorter than min_chunk_size
    if current_chunk:
        chunks.append(_make_chunk(current_chunk, header_stack, source))

    return chunks

print("✓ Chunking function loaded")

In [None]:
# Test chunking on extracted text
if 'extracted_text' in locals():
    chunks = chunk_with_headers(extracted_text, source=test_pdf.name)

    print(f"Created {len(chunks)} chunks from {len(extracted_text)} characters")
    if chunks:
        avg_size = sum(len(c['content']) for c in chunks) / len(chunks)
        print(f"Average chunk size: {avg_size:.0f} chars")

    # Show first 3 chunks
    print("\nFirst 3 chunks:")
    for i, chunk in enumerate(chunks[:3], 1):
        print(f"\n{'='*80}")
        print(f"Chunk {i}:")
        print(f"Source: {chunk['metadata']['source']}")
        print(f"Header: {chunk['metadata']['header_path']}")
        print(f"Tags: {chunk['metadata'].get('tags', [])}")
        print(f"Length: {len(chunk['content'])} chars")
        print("\nContent preview (first 300 chars):")
        print(chunk['content'][:300])
else:
    print("No extracted text available. Run the PDF extraction cell first.")
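When tuning `min_chunk_size` and `max_chunk_size`, an average hides outliers. A sketch (`chunk_size_summary` is a hypothetical helper) that summarizes the length distribution and counts chunks outside the configured bounds. Lengths include the `Header: ...` prefix, and a chunk is flushed only after it reaches the maximum, so some overshoot is expected:

```python
import statistics
from typing import Any, Dict, List


def chunk_size_summary(chunks: List[Dict[str, Any]], min_size: int = 200, max_size: int = 1000) -> Dict[str, float]:
    """Summarize the content lengths of chunks produced by a chunker."""
    sizes = sorted(len(c["content"]) for c in chunks)
    if not sizes:
        return {"count": 0}
    return {
        "count": len(sizes),
        "min": sizes[0],
        "median": statistics.median(sizes),
        "max": sizes[-1],
        "below_min": sum(s < min_size for s in sizes),
        "above_max": sum(s > max_size for s in sizes),
    }


if 'chunks' in globals():
    print(chunk_size_summary(chunks))
```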

## 4. Pinecone Connection & Stats

In [None]:
# Initialize RAG service
rag = get_rag_service()

if rag.is_available():
    print("✓ RAG service initialized successfully")
    print(f"  Index: {rag.index_name}")

    # Get index stats
    try:
        stats = rag.index.describe_index_stats()
        # Newer Pinecone clients may return a response object; normalize to a plain dict
        if hasattr(stats, "to_dict"):
            stats = stats.to_dict()
        print("\nIndex Statistics:")
        print(f"  Total vectors: {stats.get('total_vector_count', 'N/A')}")
        print(f"  Dimensions: {stats.get('dimension', 'N/A')}")

        namespaces = stats.get('namespaces', {})
        if namespaces:
            print("\n  Namespaces:")
            for ns, ns_stats in namespaces.items():
                print(f"    - {ns}: {ns_stats.get('vector_count', 0)} vectors")
    except Exception as e:
        print(f"Could not fetch index stats: {e}")
else:
    print("❌ RAG service not available. Check PINECONE_API_KEY.")
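Before debugging retrieval quality, it is worth confirming that every namespace the agents query actually holds vectors. A sketch (`empty_namespaces` is hypothetical) that works on the plain-dict form of the stats, `{'namespaces': {name: {'vector_count': n}}}`; it assumes your Pinecone client's response either is a dict or offers `.to_dict()`:

```python
from typing import Any, Dict, Iterable, List


def empty_namespaces(stats: Dict[str, Any], expected: Iterable[str]) -> List[str]:
    """Return expected namespaces that are missing from the index or hold zero vectors."""
    namespaces = stats.get("namespaces") or {}
    return [ns for ns in expected if (namespaces.get(ns) or {}).get("vector_count", 0) == 0]


if 'rag' in globals() and rag.is_available():
    raw_stats = rag.index.describe_index_stats()
    stats_dict = raw_stats.to_dict() if hasattr(raw_stats, "to_dict") else raw_stats
    expected = [NAMESPACE_CLINICAL_SAFETY, NAMESPACE_CULTURAL_DIET, NAMESPACE_LIFESTYLE_PATTERNS]
    print("Empty or missing namespaces:", empty_namespaces(stats_dict, expected) or "none")
```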

## 5. RAG Search & Retrieval Testing

In [None]:
# Test RAG search on different namespaces
if rag.is_available():
    test_queries = [
        {
            "query": "metformin kidney disease contraindication",
            "namespace": NAMESPACE_CLINICAL_SAFETY,
            "top_k": 3
        },
        {
            "query": "glucose target Type 2 Diabetes elderly",
            "namespace": NAMESPACE_LIFESTYLE_PATTERNS,
            "top_k": 3
        },
        {
            "query": "Hainanese Chicken Rice nutritional carbs",
            "namespace": NAMESPACE_CULTURAL_DIET,
            "top_k": 2
        }
    ]
    
    for test in test_queries:
        print(f"\n{'='*80}")
        print(f"Query: {test['query']}")
        print(f"Namespace: {test['namespace']}")
        print(f"{'='*80}")
        
        results = rag.search(test['query'], namespace=test['namespace'], top_k=test['top_k'])
        
        if results:
            print(f"\nFound {len(results)} results:")
            for i, result in enumerate(results, 1):
                print(f"\nResult {i}:")
                print(f"  Score: {result.get('score', 'N/A')}")
                print(f"  Source: {result.get('metadata', {}).get('source', 'Unknown')}")
                print(f"  Tags: {result.get('metadata', {}).get('tags', [])}")
                print(f"  Content preview: {result.get('content', '')[:200]}...")
        else:
            print("❌ No results found")
else:
    print("RAG service not available")
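Eyeballing results works for a handful of queries; for tuning it helps to score retrieval against a small labelled set. A sketch of hit rate@k, assuming each result is a dict whose `metadata['source']` holds the file name, as printed in the previous cell. `hit_rate_at_k` is hypothetical, and the labelled pairs must be written by hand from your own documents:

```python
from typing import Any, Callable, Dict, List, Tuple

SearchFn = Callable[[str, int], List[Dict[str, Any]]]


def hit_rate_at_k(labelled: List[Tuple[str, str]], search_fn: SearchFn, k: int = 3) -> Tuple[float, List[str]]:
    """Fraction of queries whose expected source is in the top-k results, plus the missed queries."""
    misses = []
    for query, expected_source in labelled:
        sources = {(r.get("metadata") or {}).get("source") for r in search_fn(query, k)[:k]}
        if expected_source not in sources:
            misses.append(query)
    if not labelled:
        return 0.0, []
    return (len(labelled) - len(misses)) / len(labelled), misses


# Example wiring with the notebook's service (expected file names are placeholders):
# labelled = [("metformin kidney disease contraindication", "<expected_file>.pdf")]
# rate, misses = hit_rate_at_k(
#     labelled, lambda q, k: rag.search(q, namespace=NAMESPACE_CLINICAL_SAFETY, top_k=k))
```

Passing the search as a callable keeps the metric independent of Pinecone, so the same function can compare namespaces, `top_k` values or chunking variants.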

In [None]:
# Test formatted context for LLM (with citations)
if rag.is_available():
    query = "MOH guidelines for metformin use in elderly patients"
    
    print(f"Query: {query}")
    print(f"Namespace: {NAMESPACE_CLINICAL_SAFETY}")
    print("\n" + "="*80)
    
    context = rag.get_context_for_llm(
        query=query,
        namespace=NAMESPACE_CLINICAL_SAFETY,
        top_k=3,
        include_citations=True
    )
    
    if context:
        print("\nFormatted Context (with mandatory citations):")
        print("="*80)
        print(context)
        print("="*80)
        
        # Check for citation markers
        source_count = context.count("Source:")
        citation_instruction_count = context.count("you MUST cite")
        print(f"\nCitation Analysis:")
        print(f"  Sources mentioned: {source_count}")
        print(f"  Citation instructions: {citation_instruction_count}")
    else:
        print("❌ No context returned")
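Counting `Source:` occurrences shows how many sources were mentioned, not which ones. A sketch that extracts the distinct names, assuming each appears on a line of the form `Source: <name>` (the marker counted above; the exact layout comes from `get_context_for_llm` and may differ). `extract_sources` is hypothetical:

```python
import re
from typing import List

SOURCE_LINE = re.compile(r"Source:\s*(.+)")


def extract_sources(context: str) -> List[str]:
    """Return distinct 'Source:' values in order of first appearance."""
    seen: List[str] = []
    for match in SOURCE_LINE.finditer(context):
        name = match.group(1).strip()
        if name and name not in seen:
            seen.append(name)
    return seen


if 'context' in globals() and context:
    print("Distinct sources:", extract_sources(context))
```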

## 6. Agent Routing Testing

In [None]:
# Create test patient context
test_patient = PatientContext(
    user_id="test-user-123",
    first_name="John",
    last_name="Doe",
    age=65,
    sex="Male",
    ethnicity="Chinese",
    height=170.0,
    activity_level="moderate",
    location="Singapore",
    conditions=["Type 2 Diabetes", "Hypertension"],
    medications=["Metformin 500mg", "Lisinopril 10mg"]
)

print("Test patient created:")
print(f"  Name: {test_patient.first_name} {test_patient.last_name}")
print(f"  Age: {test_patient.age}")
print(f"  Conditions: {test_patient.conditions}")
print(f"  Medications: {test_patient.medications}")

In [None]:
# Test router agent with different queries
test_routing_queries = [
    "What are the side effects of metformin?",  # Should route to clinical_safety
    "What do MOH guidelines say about diabetes management?",  # Should route to clinical_safety
    "What was my average glucose this week?",  # Should route to lifestyle_analyst
    "Did I take my medication today?",  # Should route to lifestyle_analyst
    "Hello, how are you?",  # Should route to unmatched
]

for query in test_routing_queries:
    print(f"\n{'='*80}")
    print(f"Query: {query}")
    print(f"{'='*80}")
    
    state = RouterState(patient=test_patient, user_message=query)
    decision = route_intent.invoke(input=state)
    
    print(f"\nRouting Decision:")
    print(f"  Target Agent: {decision['target_agent']}")
    print(f"  Intent: {decision['intent']}")
    print(f"  Rationale: {decision['rationale']}")
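The expected targets in the previous cell live only in comments. A sketch that turns them into a checked test: `routing_accuracy` is hypothetical and accepts any callable that returns a decision dict with a `target_agent` key, as `route_intent` does above. The label strings are copied from those comments and are assumed to match what the router actually emits:

```python
from typing import Any, Callable, Dict, List, Tuple


def routing_accuracy(cases: List[Tuple[str, str]], route: Callable[[str], Dict[str, Any]]) -> Tuple[float, List[Tuple[str, str, Any]]]:
    """Return accuracy and a list of (query, expected, actual) mismatches."""
    mismatches = []
    for query, expected in cases:
        actual = route(query).get("target_agent")
        if actual != expected:
            mismatches.append((query, expected, actual))
    accuracy = 1 - len(mismatches) / len(cases) if cases else 0.0
    return accuracy, mismatches


routing_cases = [
    ("What are the side effects of metformin?", "clinical_safety"),
    ("What do MOH guidelines say about diabetes management?", "clinical_safety"),
    ("What was my average glucose this week?", "lifestyle_analyst"),
    ("Did I take my medication today?", "lifestyle_analyst"),
    ("Hello, how are you?", "unmatched"),
]

if 'route_intent' in globals():
    acc, wrong = routing_accuracy(
        routing_cases,
        lambda q: route_intent.invoke(input=RouterState(patient=test_patient, user_message=q)),
    )
    print(f"Routing accuracy: {acc:.0%}")
    for query, expected, actual in wrong:
        print(f"  ❌ {query!r}: expected {expected}, got {actual}")
```

Because the router is LLM-based, running the cases several times gives a better picture than a single pass.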

## 7. Clinical Safety Agent Testing (with RAG)

In [None]:
# Test Clinical Safety Agent with RAG integration
test_clinical_query = "What do MOH guidelines say about metformin use in elderly patients with kidney disease?"

print(f"Query: {test_clinical_query}")
print("="*80)

# Create agent state
clinical_state = ClinicalSafetyState(
    patient=test_patient,
    user_message=test_clinical_query,
    enhanced_context=None  # Can add enhanced context if needed
)

# Invoke agent
result = check_clinical_safety.invoke(input=clinical_state)

print("\nClinical Safety Analysis:")
print(f"  Is Safe: {result.get('is_safe', 'N/A')}")
print(f"  Warnings ({len(result.get('warnings', []))}):")
for warning in result.get('warnings', []):
    print(f"    - {warning}")
print(f"\n  Rationale: {result.get('rationale', 'N/A')}")

# Check RAG context
rag_context = result.get('rag_context', '')
rag_citations = result.get('rag_citations', [])

print(f"\n  RAG Context Length: {len(rag_context)} chars")
print(f"  RAG Citations ({len(rag_citations)}): {rag_citations}")

if rag_context:
    print("\n  RAG Context Preview (first 500 chars):")
    print("  " + "-"*78)
    print("  " + rag_context[:500])
    print("  " + "-"*78)

## 8. End-to-End Chat Simulation

In [None]:
# Simulate full chat flow: Router → Agent → LLM Response
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage
from app.core.system_prompt_builder import build_system_prompt

def test_full_chat_flow(user_query: str, patient: PatientContext):
    """Test complete chat flow with RAG citations."""
    print(f"\n{'='*80}")
    print(f"User Query: {user_query}")
    print(f"{'='*80}")
    
    # Step 1: Router
    router_state = RouterState(patient=patient, user_message=user_query)
    routing_decision = route_intent.invoke(input=router_state)
    print(f"\n[ROUTER] Target Agent: {routing_decision['target_agent']}")
    
    # Step 2: Agent (Clinical Safety)
    if routing_decision['target_agent'] == 'clinical_safety':
        clinical_state = ClinicalSafetyState(
            patient=patient,
            user_message=user_query,
            enhanced_context=None
        )
        agent_result = check_clinical_safety.invoke(input=clinical_state)
        
        rag_context = agent_result.get('rag_context', '')
        rag_citations = agent_result.get('rag_citations', [])
        agent_text = agent_result.get('rationale', '')
        
        print(f"[AGENT] RAG Citations: {rag_citations}")
        print(f"[AGENT] RAG Context Length: {len(rag_context)} chars")
        
        # Step 3: Build System Prompt
        patient_str = f"{patient.first_name} {patient.last_name}, {patient.age}yo {patient.sex}"
        system_prompt = build_system_prompt(
            patient_context_str=patient_str,
            enhanced_context=None,
            user_message=user_query,
            agent_text=agent_text,
            rag_context=rag_context
        )
        
        # Step 4: LLM Call
        llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.7)
        messages = [
            SystemMessage(content=system_prompt),
            HumanMessage(content=user_query)
        ]
        
        print(f"\n[LLM] Calling GPT-4o-mini...")
        response = llm.invoke(messages)
        
        print(f"\n{'='*80}")
        print("LLM RESPONSE:")
        print(f"{'='*80}")
        print(response.content)
        print(f"{'='*80}")
        
        # Step 5: Citation Validation
        print("\n[VALIDATION] Citation Check:")
        citations_found = []
        for citation in rag_citations:
            clean_citation = citation.replace('.pdf', '').replace('_', ' ')
            if clean_citation in response.content or citation in response.content:
                citations_found.append(citation)
                print(f"  ✓ Found citation: {citation}")
            else:
                print(f"  ❌ Missing citation: {citation}")

        if not rag_citations:
            print("\n  ⚠️ No citations were retrieved, so there is nothing to validate")
        elif citations_found:
            print(f"\n  ✓ SUCCESS: {len(citations_found)}/{len(rag_citations)} citations included")
        else:
            print("\n  ❌ FAILURE: No citations found in response")
        
        return response.content, citations_found
    else:
        print(f"[INFO] Not a clinical safety query, skipping detailed flow")
        return None, []
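The inline citation check above tries only two exact, case-sensitive spellings of each file name, so a response that writes a title with different casing, hyphens or spacing is reported as missing. A more forgiving sketch (`normalize_citation` and `find_cited_sources` are hypothetical): it lowercases, drops a `.pdf` extension and collapses underscores, hyphens and whitespace runs to single spaces before substring matching. It will still miss paraphrased titles:

```python
import re
from typing import List


def normalize_citation(text: str) -> str:
    """Lowercase, drop .pdf extensions, and collapse _, - and whitespace runs to one space."""
    text = re.sub(r"\.pdf\b", "", text.lower())
    return re.sub(r"[\s_\-]+", " ", text).strip()


def find_cited_sources(response: str, citations: List[str]) -> List[str]:
    """Return the citations whose normalized name occurs in the normalized response."""
    normalized_response = normalize_citation(response)
    found = []
    for citation in citations:
        name = normalize_citation(citation)
        if name and name in normalized_response:
            found.append(citation)
    return found
```

Inside `test_full_chat_flow`, `find_cited_sources(response.content, rag_citations)` could replace the per-citation loop.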

In [None]:
# Test full chat flow with citation validation
test_queries = [
    "What do MOH guidelines recommend for metformin dosing in elderly patients?",
    "Are there any contraindications for metformin in patients with kidney disease?",
    "What are the ADA recommendations for HbA1c targets?"
]

for query in test_queries:
    response, citations = test_full_chat_flow(query, test_patient)
    print("\n" + "#"*80 + "\n")

## 9. Fine-tuning Parameters

In [None]:
# Experiment with different RAG parameters
print("RAG Parameter Tuning")
print("="*80)

if rag.is_available():
    query = "metformin kidney disease contraindication"
    
    # Test different top_k values
    for top_k in [1, 3, 5]:
        results = rag.search(query, namespace=NAMESPACE_CLINICAL_SAFETY, top_k=top_k)
        print(f"\ntop_k={top_k}: {len(results)} results")
        if results:
            avg_score = sum(r.get('score', 0) for r in results) / len(results)
            print(f"  Average score: {avg_score:.4f}")
            print(f"  Min score: {min(r.get('score', 0) for r in results):.4f}")
            print(f"  Max score: {max(r.get('score', 0) for r in results):.4f}")
else:
    print("RAG service not available")
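The summary below recommends a relevance threshold of 0.7. Similarity scores depend on the embedding model and the index metric, so a quick sweep over the `score` field shows how many retrieved chunks each candidate cutoff would keep. `threshold_sweep` and the candidate thresholds are hypothetical:

```python
from typing import Any, Dict, List, Sequence


def threshold_sweep(results: List[Dict[str, Any]], thresholds: Sequence[float] = (0.5, 0.6, 0.7, 0.8)) -> Dict[float, int]:
    """Count how many results score at or above each threshold (missing scores count as 0)."""
    scores = [r.get("score") or 0.0 for r in results]
    return {t: sum(s >= t for s in scores) for t in thresholds}


if 'rag' in globals() and rag.is_available():
    sweep_results = rag.search("metformin kidney disease contraindication", namespace=NAMESPACE_CLINICAL_SAFETY, top_k=10)
    for t, kept in threshold_sweep(sweep_results).items():
        print(f"score >= {t:.2f}: {kept}/{len(sweep_results)} chunks kept")
```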

## 10. Summary & Recommendations

In [None]:
print("="*80)
print("PIPELINE TESTING SUMMARY")
print("="*80)
print("\n✓ Components Tested:")
print("  1. PDF Extraction (pdfplumber/PyMuPDF)")
print("  2. Chunking Strategy (header-based for clinical docs)")
print("  3. Pinecone Connection & Stats")
print("  4. RAG Search & Retrieval")
print("  5. Agent Routing (LLM-based)")
print("  6. Clinical Safety Agent with RAG")
print("  7. End-to-End Chat with Citation Validation")
print("\n⚠️ Known Issues:")
print("  - Citations may not appear consistently in LLM responses")
print("  - The LLM relies on prompt instructions; there is no programmatic enforcement")
print("  - The rag_citations array is collected but not fully used")
print("\n💡 Recommendations:")
print("  1. Add post-processing to validate citations in responses")
print("  2. Simplify citation instructions (fewer lines, clearer format)")
print("  3. Consider using JSON mode for structured output with citations")
print("  4. Monitor RAG retrieval scores (start from >0.7 as a relevance threshold, then calibrate it for the embedding model)")
print("  5. Fine-tune chunk sizes based on retrieval quality")
print("="*80)
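Recommendation 3 can be made concrete by asking the model for JSON with separate `answer` and `citations` fields and then checking the citations in code. A sketch of the validation half only, assuming that hypothetical response shape (`{"answer": str, "citations": [str, ...]}`); how JSON output is requested depends on the LLM client and is not shown. `CitationCheck` and `validate_json_response` are illustrative names:

```python
import json
from dataclasses import dataclass, field
from typing import List


@dataclass
class CitationCheck:
    answer: str
    cited: List[str]
    unknown: List[str] = field(default_factory=list)  # cited but never retrieved
    ok: bool = False


def validate_json_response(raw: str, retrieved_sources: List[str]) -> CitationCheck:
    """Parse a JSON answer and verify every citation refers to a retrieved source.

    Raises json.JSONDecodeError if the model did not return valid JSON.
    """
    data = json.loads(raw)
    answer = str(data.get("answer", ""))
    cited = [str(c) for c in data.get("citations", [])]
    unknown = [c for c in cited if c not in retrieved_sources]
    # Valid only if at least one citation is given and all of them were retrieved
    return CitationCheck(answer=answer, cited=cited, unknown=unknown, ok=bool(cited) and not unknown)
```

Rejecting answers with `ok == False` (or re-prompting) turns the citation requirement from an instruction into an enforced check.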