In [1]:
# Consolidate all installations needed for Task 2 (Embeddings and FAISS)
!pip install sentence-transformers langchain-huggingface faiss-cpu langchain-community



In [2]:
# Complete clean installation with compatible versions
!pip uninstall -y langchain langchain-community langchain-core langchain-huggingface
!pip install --upgrade pip
!pip install langchain==0.1.20 langchain-community==0.0.38 langchain-core==0.1.52 langchain-huggingface transformers accelerate bitsandbytes sentence-transformers faiss-cpu

# RESTART RUNTIME after installation

Found existing installation: langchain 0.1.20
Uninstalling langchain-0.1.20:
  Successfully uninstalled langchain-0.1.20
Found existing installation: langchain-community 0.0.38
Uninstalling langchain-community-0.0.38:
  Successfully uninstalled langchain-community-0.0.38
Found existing installation: langchain-core 0.1.52
Uninstalling langchain-core-0.1.52:
  Successfully uninstalled langchain-core-0.1.52
Found existing installation: langchain-huggingface 0.0.3
Uninstalling langchain-huggingface-0.0.3:
  Successfully uninstalled langchain-huggingface-0.0.3
Collecting langchain==0.1.20
  Using cached langchain-0.1.20-py3-none-any.whl.metadata (13 kB)
Collecting langchain-community==0.0.38
  Using cached langchain_community-0.0.38-py3-none-any.whl.metadata (8.7 kB)
Collecting langchain-core==0.1.52
  Using cached langchain_core-0.1.52-py3-none-any.whl.metadata (5.9 kB)
Collecting langchain-huggingface
  Using cached langchain_huggingface-1.1.0-py3-none-any.whl.metadata (2.8 kB)
INFO: pip 

In [None]:
# Complete clean installation with compatible versions
!pip uninstall -y langchain langchain-community langchain-core langchain-huggingface
!pip install --upgrade pip
!pip install langchain==0.1.20 langchain-community==0.0.38 langchain-core==0.1.52 langchain-huggingface transformers accelerate bitsandbytes sentence-transformers faiss-cpu

# RESTART RUNTIME after installation

# Data Exploration


In [3]:
# ==========================================================
#   FULL EXTRACTION + VALIDATION PIPELINE (ONE CELL)
# ==========================================================

import os
import zipfile
import rarfile

# ----------------------------------------------------------
# STEP 1 — Extract Main ZIP File
# ----------------------------------------------------------

MAIN_ZIP_FILE_PATH = '/content/mimic-iv-ext-direct-1.0.0.zip'
EXTRACT_ROOT_DIR = './MIMIC_RAG_Project_Data/'
os.makedirs(EXTRACT_ROOT_DIR, exist_ok=True)

print(f"Using ZIP file: {MAIN_ZIP_FILE_PATH}")
print(f"Starting extraction of {MAIN_ZIP_FILE_PATH.split('/')[-1]}...")

try:
    with zipfile.ZipFile(MAIN_ZIP_FILE_PATH, 'r') as zip_ref:
        zip_ref.extractall(EXTRACT_ROOT_DIR)
    print("Main ZIP extraction successful.\n")

except FileNotFoundError:
    print(f"ERROR: File not found at {MAIN_ZIP_FILE_PATH}")
except Exception as e:
    print(f"Unexpected error during ZIP extraction: {e}")


# ----------------------------------------------------------
# STEP 2 — Check Inner Folder Contents
# ----------------------------------------------------------

INNER_FOLDER_PATH = './MIMIC_RAG_Project_Data/mimic-iv-ext-direct-1.0.0'
print(f"--- Contents inside the inner folder ({INNER_FOLDER_PATH}) ---")

try:
    inner_contents = os.listdir(INNER_FOLDER_PATH)

    for item in inner_contents:
        full_path = os.path.join(INNER_FOLDER_PATH, item)
        if os.path.isdir(full_path):
            print(f"FOLDER: {item}")
        elif os.path.isfile(full_path):
            if item.endswith('.rar') or item.endswith('.zip'):
                print(f"COMPRESSED FILE: {item}")
            else:
                print(f"FILE: {item}")
        else:
            print(f"OTHER: {item}")

    key_files_found = {
        'Finished_Folder': 'Finished' in inner_contents,
        'samples.rar_file': 'samples.rar' in inner_contents,
        'diagnostic_kg.rar_file': 'diagnostic_kg.rar' in inner_contents
    }

    if all(key_files_found.values()):
        print("\nConfirmed: All required components are present.")
    else:
        print("\nWarning: Missing one or more required files.")

except FileNotFoundError:
    print(f"ERROR: Inner folder not found at {INNER_FOLDER_PATH}")
except Exception as e:
    print(f"Error while checking inner folder: {e}")


# ----------------------------------------------------------
# STEP 3 — Extract RAR Files (Samples + Knowledge Graph)
# ----------------------------------------------------------

SAMPLES_RAR_PATH = os.path.join(INNER_FOLDER_PATH, 'samples.rar')
KG_RAR_PATH = os.path.join(INNER_FOLDER_PATH, 'diagnostic_kg.rar')

SAMPLES_EXTRACT_DIR = './DiReCT_Notes_Sample/'
KG_EXTRACT_DIR = './DiReCT_Knowledge_Graph/'

os.makedirs(SAMPLES_EXTRACT_DIR, exist_ok=True)
os.makedirs(KG_EXTRACT_DIR, exist_ok=True)


def extract_rar(rar_path, extract_dir):
    file_name = os.path.basename(rar_path)
    print(f"\nExtracting {file_name}...")

    try:
        with rarfile.RarFile(rar_path) as rf:
            rf.extractall(extract_dir)
        print(f"Extraction complete: {file_name}")

        extracted_contents = os.listdir(extract_dir)
        print(f"Items extracted to {extract_dir} (first 5): {extracted_contents[:5]}")

    except FileNotFoundError:
        print(f"ERROR: RAR file not found at {rar_path}")
    except rarfile.RarCannotExec:
        print("ERROR: 'unrar' tool missing. Install using: !apt-get install unrar")
    except Exception as e:
        print(f"Error during extraction of {file_name}: {e}")


extract_rar(SAMPLES_RAR_PATH, SAMPLES_EXTRACT_DIR)
extract_rar(KG_RAR_PATH, KG_EXTRACT_DIR)

print("\nAll compressed files extracted successfully. Ready for data loading.")


Using ZIP file: /content/mimic-iv-ext-direct-1.0.0.zip
Starting extraction of mimic-iv-ext-direct-1.0.0.zip...
Main ZIP extraction successful.

--- Contents inside the inner folder (./MIMIC_RAG_Project_Data/mimic-iv-ext-direct-1.0.0) ---
FILE: .DS_Store
COMPRESSED FILE: diagnostic_kg.rar
COMPRESSED FILE: samples.rar
FOLDER: Finished
FILE: LICENSE.txt
FILE: README.md
FILE: SHA256SUMS.txt

Confirmed: All required components are present.

Extracting samples.rar...
Extraction complete: samples.rar
Items extracted to ./DiReCT_Notes_Sample/ (first 5): ['Finished']

Extracting diagnostic_kg.rar...
Extraction complete: diagnostic_kg.rar
Items extracted to ./DiReCT_Knowledge_Graph/ (first 5): ['Diagnosis_flowchart']

All compressed files extracted successfully. Ready for data loading.


In [None]:
# Install the command-line utility for handling .rar files
!apt-get install unrar

# Install the Python wrapper library
!pip install rarfile

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
unrar is already the newest version (1:6.1.5-1ubuntu0.1).
0 upgraded, 0 newly installed, 0 to remove and 41 not upgraded.


## Loading and Parsing Clinical Notes

In [4]:
import os
import json
import pandas as pd

# Correct shared path
CORRECTED_ROOT_DIR = './DiReCT_Notes_Sample/Finished'


# ===============================================================
#  PART 1 — LOADER VERSION 1
#  (Subjective / Objective / Assessment / Plan)
# ===============================================================

def load_clinical_notes_v1(root_dir):
    """
    Loads JSON clinical notes and extracts SOAP-style narrative fields.
    """
    all_notes = []

    for root, _, files in os.walk(root_dir):
        for file in files:
            if file.endswith('.json'):
                file_path = os.path.join(root, file)
                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        note_data = json.load(f)

                    subjective = note_data.get('Subjective', '')
                    objective = note_data.get('Objective', '')
                    assessment = note_data.get('Assessment', '')
                    plan = note_data.get('Plan', '')

                    full_text = (
                        f"Subjective: {subjective}\n"
                        f"Objective: {objective}\n"
                        f"Assessment: {assessment}\n"
                        f"Plan: {plan}"
                    )

                    all_notes.append({
                        'note_id': note_data.get('ID', os.path.basename(file)),
                        'full_text': full_text,
                        'diagnosis': note_data.get('Diagnosis', 'N/A'),
                        'abstract': note_data.get('Abstract', ''),
                        'reasoning_steps': note_data.get('Reasoning_steps', [])
                    })

                except Exception:
                    continue

    return pd.DataFrame(all_notes)


print(f"Loading Version 1 Notes from: {CORRECTED_ROOT_DIR}")
df_notes_v1 = load_clinical_notes_v1(CORRECTED_ROOT_DIR)
print("Version 1 Loaded:", len(df_notes_v1), "notes")



# ===============================================================
#  PART 2 — JSON KEY INSPECTION
#  (Finds keys from one sample file)
# ===============================================================

def find_note_body_key(root_dir):
    """
    Prints all top-level keys in the first JSON file found.
    """
    for root, _, files in os.walk(root_dir):
        for file in files:
            if file.endswith('.json'):
                fpath = os.path.join(root, file)
                try:
                    with open(fpath, 'r', encoding='utf-8') as f:
                        sample = json.load(f)

                    print(f"\n--- Keys in sample file: {file} ---")
                    for key, value in sample.items():
                        snippet = str(value)[:150].replace("\n", " ")
                        print(f"Key: {key} | Type: {type(value).__name__} | Snippet: {snippet}...")
                    return

                except Exception:
                    continue
    print("No JSON files found.")


print("\nInspecting JSON Keys...")
find_note_body_key(CORRECTED_ROOT_DIR)



# ===============================================================
#  PART 3 — LOADER VERSION 2
#  (input1–input6 structured format + diagnosis from key)
# ===============================================================

def load_clinical_notes_v2(root_dir):
    """
    Loads JSON notes using 'input1' to 'input6'
    and extracts diagnosis from complex keys.
    """
    all_notes = []

    for root, _, files in os.walk(root_dir):
        for file in files:
            if file.endswith('.json'):
                fpath = os.path.join(root, file)
                try:
                    with open(fpath, 'r', encoding='utf-8') as f:
                        note_data = json.load(f)

                    complaint = note_data.get('input1', '')
                    hpi = note_data.get('input2', '')
                    pmh = note_data.get('input3', '')
                    fh = note_data.get('input4', '')
                    pe = note_data.get('input5', '')
                    labs = note_data.get('input6', '')

                    full_text = (
                        f"CHIEF COMPLAINT: {complaint}\n"
                        f"HISTORY OF PRESENT ILLNESS: {hpi}\n"
                        f"PAST MEDICAL HISTORY: {pmh}\n"
                        f"FAMILY HISTORY: {fh}\n"
                        f"PHYSICAL EXAMINATION: {pe}\n"
                        f"LABS/IMAGING: {labs}"
                    )

                    # Extract diagnosis from the first non-input key
                    diag_key = [k for k in note_data if not k.startswith("input")][0]
                    diagnosis = diag_key.split("$")[0]

                    all_notes.append({
                        'note_id': os.path.basename(fpath),
                        'full_text': full_text,
                        'diagnosis': diagnosis
                    })

                except Exception as e:
                    print(f"Error processing {fpath}. Skipping. Error: {e}")
                    continue

    return pd.DataFrame(all_notes)


print("\nLoading Version 2 Notes...")
df_notes_v2 = load_clinical_notes_v2(CORRECTED_ROOT_DIR)
print("Version 2 Loaded:", len(df_notes_v2), "notes")


Loading Version 1 Notes from: ./DiReCT_Notes_Sample/Finished
Version 1 Loaded: 511 notes

Inspecting JSON Keys...

--- Keys in sample file: 18427803-DS-5.json ---
Key: Migraine With Aura$Intermedia_3 | Type: dict | Snippet: {'Difficulty expressing language may be associated with migraine, especially when migraine is accompanied by neurological symptoms$Cause_1': {'Difficu...
Key: input1 | Type: str | Snippet: Difficulty producing speech  ...
Key: input2 | Type: str | Snippet: Patient woke up at about 0900 AM. Patient felt that she had a migraine. Patient took acetaminophen and went back to bed. Patient woke back up one hour...
Key: input3 | Type: str | Snippet: Hypothyroidism  HLD  History of breast CA  Left paramedian pontine stroke  ...
Key: input4 | Type: str | Snippet: Mother had heart disease in elderly age. Paternal grandmother had stroke in ___.  ...
Key: input5 | Type: str | Snippet: PHYSICAL EXAMINATION:  Presentation vitals:  Temperature: 98.5  Heart rate: 113  Blood pressure

## Corrected Data Loading and Parsing

In [5]:
# =============================================================================
# KNOWLEDGE GRAPH DATA EXPLORATION (CLEAN VERSION, NO EMOJIS, NO REDUNDANCY)
# =============================================================================

import os
import json
import pandas as pd

KG_ROOT_DIR = './DiReCT_Knowledge_Graph'

print("=" * 70)
print("EXPLORING KNOWLEDGE GRAPH STRUCTURE")
print("=" * 70)

# List directory structure
print(f"\nExploring directory: {KG_ROOT_DIR}")
print("-" * 70)

for root, dirs, files in os.walk(KG_ROOT_DIR):
    level = root.replace(KG_ROOT_DIR, '').count(os.sep)
    indent = ' ' * 2 * level
    print(f"{indent}{os.path.basename(root)}/")

    sub_indent = ' ' * 2 * (level + 1)
    for file in files:
        file_size = os.path.getsize(os.path.join(root, file))
        print(f"{sub_indent}{file} ({file_size} bytes)")

print("\n" + "=" * 70)
print("INSPECTING SAMPLE KNOWLEDGE GRAPH FILE")
print("=" * 70)

# Find the first JSON file
sample_kg_file = None
for root, _, files in os.walk(KG_ROOT_DIR):
    for file in files:
        if file.endswith('.json'):
            sample_kg_file = os.path.join(root, file)
            break
    if sample_kg_file:
        break

if sample_kg_file:
    print(f"\nSample file: {os.path.basename(sample_kg_file)}")
    print("-" * 70)

    try:
        with open(sample_kg_file, 'r', encoding='utf-8') as f:
            sample_kg_data = json.load(f)

        print("File loaded successfully\n")
        print("Top-level structure:\n")

        def inspect_structure(data, depth=0, max_depth=3):
            """Recursively inspect JSON structure"""
            indent = "  " * depth
            if depth > max_depth:
                return

            if isinstance(data, dict):
                for i, (key, value) in enumerate(data.items()):
                    if i >= 5:
                        print(f"{indent}... ({len(data)-5} more keys)")
                        break

                    clean_key = key.split('$')[0] if '$' in key else key
                    if isinstance(value, dict):
                        print(f"{indent}{clean_key} (dict with {len(value)} keys)")
                        inspect_structure(value, depth + 1)
                    elif isinstance(value, list):
                        print(f"{indent}{clean_key} (list with {len(value)} items)")
                        if value:
                            inspect_structure(value[0], depth + 1)
                    elif isinstance(value, str):
                        print(f"{indent}{clean_key}: {value[:80].replace(chr(10), ' ')}...")
                    else:
                        print(f"{indent}{clean_key}: {value}")

            elif isinstance(data, list) and data:
                print(f"{indent}List with {len(data)} items")
                inspect_structure(data[0], depth + 1)

        inspect_structure(sample_kg_data)

        print("\nRaw JSON preview (first 500 chars):")
        print("-" * 70)
        print(json.dumps(sample_kg_data, indent=2)[:500] + "...")
        print("-" * 70)

    except Exception as e:
        print(f"Error loading sample file: {e}")

else:
    print("No JSON files found in Knowledge Graph directory")

# =============================================================================
# LOAD ALL KNOWLEDGE GRAPH FILES
# =============================================================================

print("\n" + "=" * 70)
print("LOADING ALL KNOWLEDGE GRAPH FILES")
print("=" * 70)

def load_knowledge_graph_files(root_dir):
    """Load and flatten all knowledge graph JSON files."""
    kg_data = []

    for root, _, files in os.walk(root_dir):
        for file in files:
            if file.endswith('.json'):
                file_path = os.path.join(root, file)
                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        data = json.load(f)

                    diagnosis_name = file.replace('.json', '')
                    kg_text = extract_kg_text(data)

                    kg_data.append({
                        'source_file': file,
                        'diagnosis': diagnosis_name,
                        'raw_data': data,
                        'text_content': kg_text,
                        'num_keys': count_keys(data)
                    })

                except Exception as e:
                    print(f"Error loading {file}: {e}")
                    continue

    return pd.DataFrame(kg_data)

def count_keys(data):
    """Count keys in nested dict/list."""
    if isinstance(data, dict):
        return len(data) + sum(count_keys(v) for v in data.values())
    if isinstance(data, list):
        return sum(count_keys(i) for i in data)
    return 0

def extract_kg_text(data):
    """Flatten nested knowledge graph into readable text."""
    parts = []

    def traverse(obj, depth=0):
        indent = "  " * depth
        if isinstance(obj, dict):
            for key, value in obj.items():
                clean_key = key.split('$')[0]
                parts.append(f"{indent}• {clean_key}")
                traverse(value, depth + 1)

        elif isinstance(obj, list):
            for item in obj:
                traverse(item, depth)

        elif isinstance(obj, str) and obj.strip():
            parts.append(f"{indent}- {obj.strip()}")

    traverse(data)
    return "\n".join(parts)

print(f"\nLoading knowledge graph files from: {KG_ROOT_DIR}")
df_kg = load_knowledge_graph_files(KG_ROOT_DIR)

# =============================================================================
# KNOWLEDGE GRAPH SUMMARY
# =============================================================================

print("\n" + "=" * 70)
print("KNOWLEDGE GRAPH SUMMARY")
print("=" * 70)

if not df_kg.empty:
    print(f"\nTotal knowledge graph files loaded: {len(df_kg)}")

    print("\nLoaded Diagnoses:")
    for idx, row in df_kg.iterrows():
        print(f"{idx+1}. {row['diagnosis']} ({row['num_keys']} keys, {len(row['text_content'])} chars)")

    print(f"\nTotal files: {len(df_kg)}")
    print(f"Total keys: {df_kg['num_keys'].sum()}")
    print(f"Avg keys per file: {df_kg['num_keys'].mean():.1f}")
    print(f"Total text length: {df_kg['text_content'].str.len().sum():,}")

    print("\nSAMPLE KNOWLEDGE GRAPH ENTRY")
    sample = df_kg.iloc[0]
    print(f"\nDiagnosis: {sample['diagnosis']}")
    print(f"Source: {sample['source_file']}")
    print(f"Number of keys: {sample['num_keys']}")
    print("\nContent Preview:")
    print("-" * 70)
    print(sample['text_content'][:800])
    print("\n... (content continues)")
else:
    print("No knowledge graph files were loaded")

# =============================================================================
# DEPENDENCY: df_notes MUST EXIST (FROM EARLIER CLINICAL NOTE LOADER)
# =============================================================================

print("\n" + "=" * 70)
print("DATASET COMPARISON SUMMARY")
print("=" * 70)

if 'df_notes' in globals() and not df_notes.empty and not df_kg.empty:
    print(f"\nClinical Notes:    {len(df_notes)} documents")
    print(f"Knowledge Graph:   {len(df_kg)} documents")

    kg_diagnoses = set(df_kg['diagnosis'])
    note_diagnoses = set(df_notes['diagnosis'])

    overlap = kg_diagnoses & note_diagnoses

    print(f"\nUnique diagnoses in KG:    {len(kg_diagnoses)}")
    print(f"Unique diagnoses in Notes: {len(note_diagnoses)}")
    print(f"Overlapping diagnoses:     {len(overlap)}")

    if overlap:
        print("\nCommon Diagnoses (sample):")
        for diag in list(overlap)[:5]:
            print(f"{diag}")

    print("\nDATA TYPE COMPARISON")
    print("-" * 70)

    sample_note = df_notes.iloc[0]
    print("\nCLINICAL NOTE SAMPLE:")
    print(f"Note ID: {sample_note['note_id']}")
    print(f"Diagnosis: {sample_note['diagnosis']}")
    print(sample_note['full_text'][:300], "...")

    sample_kg = df_kg.iloc[0]
    print("\nKNOWLEDGE GRAPH SAMPLE:")
    print(f"Diagnosis: {sample_kg['diagnosis']}")
    print(sample_kg['text_content'][:300], "...")

    print("\nKEY DIFFERENCES")
    print("- Clinical Notes: patient-specific, narrative, single case")
    print("- Knowledge Graph: generalized diagnostic rules, structured")
else:
    print("df_notes missing or empty. Cannot compare.")

print("\nDATA LOADING COMPLETE")
print("=" * 70)


EXPLORING KNOWLEDGE GRAPH STRUCTURE

Exploring directory: ./DiReCT_Knowledge_Graph
----------------------------------------------------------------------
DiReCT_Knowledge_Graph/
  Diagnosis_flowchart/
    Upper Gastrointestinal Bleeding.json (1067 bytes)
    Acute Coronary Syndrome.json (1020 bytes)
    Heart Failure.json (1726 bytes)
    Multiple Sclerosis.json (3675 bytes)
    Stroke.json (2833 bytes)
    Adrenal Insufficiency.json (2583 bytes)
    Epilepsy.json (1992 bytes)
    Hyperlipidemia.json (1059 bytes)
    Cardiomyopathy.json (2493 bytes)
    Hypertension.json (658 bytes)
    Tuberculosis.json (2633 bytes)
    Aortic Dissection.json (824 bytes)
    Peptic Ulcer Disease.json (1560 bytes)
    COPD.json (1038 bytes)
    Pneumonia.json (1961 bytes)
    Migraine.json (1512 bytes)
    Pulmonary Embolism.json (3324 bytes)
    Gastro-oesophageal Reflux Disease.json (1906 bytes)
    Diabetes.json (3107 bytes)
    Alzheimer.json (2522 bytes)
    Asthma.json (2502 bytes)
    Thyroid Di

New


In [6]:
# =============================================================================
# STANDARDIZE DATAFRAME
# =============================================================================

print("\n" + "="*70)
print("STANDARDIZING DATA")
print("="*70)

# Use df_notes_v2 (input1-6 format) as the standard
df_notes = df_notes_v2.copy()
print(f"✓ Standardized on df_notes_v2")
print(f"✓ Total clinical notes: {len(df_notes)}")
print(f"✓ Total knowledge graph entries: {len(df_kg)}")

# Display sample
print("\nSample clinical note:")
print(f"  Note ID: {df_notes.iloc[0]['note_id']}")
print(f"  Diagnosis: {df_notes.iloc[0]['diagnosis']}")
print(f"  Text preview: {df_notes.iloc[0]['full_text'][:100]}...")


STANDARDIZING DATA
✓ Standardized on df_notes_v2
✓ Total clinical notes: 511
✓ Total knowledge graph entries: 24

Sample clinical note:
  Note ID: 18427803-DS-5.json
  Diagnosis: Migraine With Aura
  Text preview: CHIEF COMPLAINT: Difficulty producing speech 

HISTORY OF PRESENT ILLNESS: Patient woke up at about ...


In [7]:
# =============================================================================
# CREATE TEXT CHUNKS FROM CLINICAL NOTES AND KNOWLEDGE GRAPH
# =============================================================================

from langchain.text_splitter import RecursiveCharacterTextSplitter

print("\n" + "="*70)
print("CREATING TEXT CHUNKS")
print("="*70)

# Initialize text splitter
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
    separators=["\n\n", "\n", ". ", " "],
    length_function=len
)

print("✓ Text splitter initialized")
print(f"  Chunk size: 500")
print(f"  Chunk overlap: 50")


# ---------------------------------------------
# Create chunks from clinical notes
# ---------------------------------------------
def create_chunks_from_notes(df_notes):
    """Create chunks from clinical notes with enhanced metadata"""
    all_chunks = []

    print("\nProcessing clinical notes...")
    for idx, row in df_notes.iterrows():
        note_id = row['note_id']
        diagnosis = row.get('diagnosis', 'Unknown')
        full_text = row['full_text']

        # Create enhanced text with metadata header
        enhanced_text = f"NOTE_ID: {note_id}\nDIAGNOSIS: {diagnosis}\n\n{full_text}"

        # Split into chunks
        chunks = text_splitter.create_documents(
            texts=[enhanced_text],
            metadatas=[{
                'note_id': note_id,
                'diagnosis': diagnosis,
                'type': 'clinical_note'
            }]
        )

        # Add chunk indices
        for i, chunk in enumerate(chunks):
            chunk.metadata['chunk_index'] = i

        all_chunks.extend(chunks)

        if (idx + 1) % 100 == 0:
            print(f"  Processed {idx + 1}/{len(df_notes)} notes...")

    print(f"✓ Created {len(all_chunks)} chunks from clinical notes")
    return all_chunks


# ---------------------------------------------
# Create chunks from knowledge graph
# ---------------------------------------------
def create_chunks_from_kg(df_kg):
    """Create chunks from knowledge graph data"""
    kg_chunks = []

    print("\nProcessing knowledge graph...")
    for idx, row in df_kg.iterrows():
        diagnosis = row['diagnosis']
        text_content = row['text_content']

        # Add header
        enhanced_text = f"KNOWLEDGE: {diagnosis}\n\n{text_content}"

        chunks = text_splitter.create_documents(
            texts=[enhanced_text],
            metadatas=[{
                'diagnosis': diagnosis,
                'type': 'knowledge_graph',
                'source_file': row['source_file']
            }]
        )

        # Add chunk indices
        for i, chunk in enumerate(chunks):
            chunk.metadata['chunk_index'] = i

        kg_chunks.extend(chunks)

    print(f"✓ Created {len(kg_chunks)} chunks from knowledge graph")
    return kg_chunks


# Create all chunks
clinical_chunks = create_chunks_from_notes(df_notes)
kg_chunks = create_chunks_from_kg(df_kg)

# Combine all chunks
all_combined_chunks = clinical_chunks + kg_chunks

print("\n" + "-"*70)
print("CHUNK CREATION SUMMARY")
print("-"*70)
print(f"Clinical note chunks:     {len(clinical_chunks)}")
print(f"Knowledge graph chunks:   {len(kg_chunks)}")
print(f"Total combined chunks:    {len(all_combined_chunks)}")

# Display sample chunks
print("\n" + "-"*70)
print("SAMPLE CHUNKS")
print("-"*70)

print("\nClinical Note Chunk:")
print(f"Content: {clinical_chunks[0].page_content[:200]}...")
print(f"Metadata: {clinical_chunks[0].metadata}")

if kg_chunks:
    print("\nKnowledge Graph Chunk:")
    print(f"Content: {kg_chunks[0].page_content[:200]}...")
    print(f"Metadata: {kg_chunks[0].metadata}")

print("\n✓ Chunk creation complete")


CREATING TEXT CHUNKS
✓ Text splitter initialized
  Chunk size: 500
  Chunk overlap: 50

Processing clinical notes...
  Processed 100/511 notes...
  Processed 200/511 notes...
  Processed 300/511 notes...
  Processed 400/511 notes...
  Processed 500/511 notes...
✓ Created 5582 chunks from clinical notes

Processing knowledge graph...
✓ Created 159 chunks from knowledge graph

----------------------------------------------------------------------
CHUNK CREATION SUMMARY
----------------------------------------------------------------------
Clinical note chunks:     5582
Knowledge graph chunks:   159
Total combined chunks:    5741

----------------------------------------------------------------------
SAMPLE CHUNKS
----------------------------------------------------------------------

Clinical Note Chunk:
Content: NOTE_ID: 18427803-DS-5.json
DIAGNOSIS: Migraine With Aura

CHIEF COMPLAINT: Difficulty producing speech 

HISTORY OF PRESENT ILLNESS: Patient woke up at about 0900 AM. Patient 

In [8]:
# =============================================================================
# INITIALIZE EMBEDDINGS AND CREATE FAISS VECTOR STORE
# =============================================================================

from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

print("\n" + "="*70)
print("INITIALIZING EMBEDDINGS AND VECTOR STORE")
print("="*70)

# ---------------------------------------------
# Initialize embedding model
# ---------------------------------------------
print("\nLoading embedding model: sentence-transformers/all-MiniLM-L6-v2")

embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={'device': 'cpu'},
    encode_kwargs={'normalize_embeddings': True}
)

print("✓ Embedding model loaded successfully")
print("  Model: sentence-transformers/all-MiniLM-L6-v2")
print("  Embedding dimension: 384")

# Test embeddings
test_text = "Patient has migraine with aura"
test_embedding = embeddings.embed_query(test_text)
print(f"  Test embedding length: {len(test_embedding)}")


# ---------------------------------------------
# Create FAISS vector store
# ---------------------------------------------
print("\n" + "-"*70)
print("Creating FAISS vector store...")
print("This may take 2-5 minutes depending on data size...")
print("-"*70)

import time
start_time = time.time()

# Create vector store from all chunks
vectorstore = FAISS.from_documents(
    documents=all_combined_chunks,
    embedding=embeddings
)

elapsed_time = time.time() - start_time

print(f"\n✓ Vector store created successfully")
print(f"  Time taken: {elapsed_time:.2f} seconds")
print(f"  Total vectors: {vectorstore.index.ntotal}")
print(f"  Clinical note vectors: ~{len(clinical_chunks)}")
print(f"  Knowledge graph vectors: ~{len(kg_chunks)}")


# ---------------------------------------------
# Test vector store with sample queries
# ---------------------------------------------
print("\n" + "-"*70)
print("TESTING VECTOR STORE")
print("-"*70)

test_queries = [
    "migraine symptoms",
    "patient 18427803-DS-5",
    "headache treatment"
]

for query in test_queries:
    print(f"\nTest query: '{query}'")
    results = vectorstore.similarity_search(query, k=3)
    print(f"  Retrieved {len(results)} documents:")
    for i, doc in enumerate(results):
        doc_type = doc.metadata.get('type', 'unknown')
        diagnosis = doc.metadata.get('diagnosis', 'N/A')
        note_id = doc.metadata.get('note_id', 'N/A')
        preview = doc.page_content[:80].replace('\n', ' ')

        if doc_type == 'clinical_note':
            print(f"    {i+1}. [Clinical] {note_id} | {diagnosis}")
        else:
            print(f"    {i+1}. [Knowledge] {diagnosis}")
        print(f"       {preview}...")

print("\n✓ Vector store is working correctly")


INITIALIZING EMBEDDINGS AND VECTOR STORE

Loading embedding model: sentence-transformers/all-MiniLM-L6-v2


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


✓ Embedding model loaded successfully
  Model: sentence-transformers/all-MiniLM-L6-v2
  Embedding dimension: 384
  Test embedding length: 384

----------------------------------------------------------------------
Creating FAISS vector store...
This may take 2-5 minutes depending on data size...
----------------------------------------------------------------------

✓ Vector store created successfully
  Time taken: 284.11 seconds
  Total vectors: 5741
  Clinical note vectors: ~5582
  Knowledge graph vectors: ~159

----------------------------------------------------------------------
TESTING VECTOR STORE
----------------------------------------------------------------------

Test query: 'migraine symptoms'
  Retrieved 3 documents:
    1. [Knowledge] Migraine
       • Migraine Without Aura     - Headache attacks often begin suddenly, with no app...
    2. [Knowledge] Migraine
       KNOWLEDGE: Migraine...
    3. [Clinical] 18805216-DS-21.json | Migraine With Aura
       NOTE_ID: 1880521

In [9]:
# =============================================================================
# INITIALIZE LANGUAGE MODEL (LLM)
# =============================================================================

from langchain_community.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

print("\n" + "="*70)
print("INITIALIZING LANGUAGE MODEL")
print("="*70)

# Model configuration
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

print(f"\nLoading model: {MODEL_NAME}")
print("This may take 3-5 minutes on first run...")
print("(Model will be cached for future runs)")

# ---------------------------------------------
# Load tokenizer
# ---------------------------------------------
print("\nStep 1: Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Set pad token if not present
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("✓ Tokenizer loaded")


# ---------------------------------------------
# Load model
# ---------------------------------------------
print("\nStep 2: Loading model...")

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
    low_cpu_mem_usage=True
)

print("✓ Model loaded")
print(f"  Device: {'GPU' if torch.cuda.is_available() else 'CPU'}")
print(f"  Dtype: {model.dtype}")


# ---------------------------------------------
# Create text generation pipeline
# ---------------------------------------------
print("\nStep 3: Creating generation pipeline...")

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.15,
    do_sample=True
)

print("✓ Pipeline created")


# ---------------------------------------------
# Wrap in LangChain
# ---------------------------------------------
print("\nStep 4: Wrapping in LangChain...")

llm = HuggingFacePipeline(pipeline=pipe)

print("✓ LLM ready for use")


# ---------------------------------------------
# Test LLM
# ---------------------------------------------
print("\n" + "-"*70)
print("TESTING LLM")
print("-"*70)

test_prompt = "What are the symptoms of migraine? Answer briefly:"
print(f"\nTest prompt: '{test_prompt}'")
print("Generating response...\n")

test_output = llm.invoke(test_prompt)

print(f"Response: {test_output[:200]}...")

print("\n✓ LLM is working correctly")

print("\n" + "="*70)
print("ALL COMPONENTS INITIALIZED")
print("="*70)
print(f"✓ Embeddings:    Ready ({embeddings.model_name})")
print(f"✓ Vector Store:  Ready ({vectorstore.index.ntotal} vectors)")
print(f"✓ LLM:           Ready ({MODEL_NAME})")
print(f"✓ Chunks:        Ready ({len(all_combined_chunks)} chunks)")
print("\nSystem ready for RAG pipeline creation!")


INITIALIZING LANGUAGE MODEL

Loading model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
This may take 3-5 minutes on first run...
(Model will be cached for future runs)

Step 1: Loading tokenizer...
✓ Tokenizer loaded

Step 2: Loading model...


`torch_dtype` is deprecated! Use `dtype` instead!
Device set to use cuda:0


✓ Model loaded
  Device: GPU
  Dtype: torch.float16

Step 3: Creating generation pipeline...
✓ Pipeline created

Step 4: Wrapping in LangChain...
✓ LLM ready for use

----------------------------------------------------------------------
TESTING LLM
----------------------------------------------------------------------

Test prompt: 'What are the symptoms of migraine? Answer briefly:'
Generating response...

Response: What are the symptoms of migraine? Answer briefly:
- Migraine headaches can cause severe pain and discomfort on one side of the head or face, accompanied by sensitivity to light, sound, smells, tastes...

✓ LLM is working correctly

ALL COMPONENTS INITIALIZED
✓ Embeddings:    Ready (sentence-transformers/all-MiniLM-L6-v2)
✓ Vector Store:  Ready (5741 vectors)
✓ LLM:           Ready (TinyLlama/TinyLlama-1.1B-Chat-v1.0)
✓ Chunks:        Ready (5741 chunks)

System ready for RAG pipeline creation!


In [10]:
# =============================================================================
# VERIFY ALL COMPONENTS BEFORE CREATING RETRIEVER
# =============================================================================

print("\n" + "="*70)
print("SYSTEM COMPONENT VERIFICATION")
print("="*70)

components_status = {
    'df_notes': 'df_notes' in globals() and not df_notes.empty,
    'df_kg': 'df_kg' in globals() and not df_kg.empty,
    'all_combined_chunks': 'all_combined_chunks' in globals() and len(all_combined_chunks) > 0,
    'embeddings': 'embeddings' in globals(),
    'vectorstore': 'vectorstore' in globals() and hasattr(vectorstore, 'index'),
    'llm': 'llm' in globals()
}

print("\nComponent Status:")
for component, status in components_status.items():
    status_icon = "✓" if status else "✗"
    print(f"  {status_icon} {component}: {'Ready' if status else 'MISSING'}")

if all(components_status.values()):
    print("\n" + "="*70)
    print("✓ ALL COMPONENTS READY - PROCEEDING TO RETRIEVER CREATION")
    print("="*70)

    print("\nDetailed Statistics:")
    print(f"  Clinical Notes:        {len(df_notes)}")
    print(f"  Knowledge Graph:       {len(df_kg)}")
    print(f"  Total Chunks:          {len(all_combined_chunks)}")
    print(f"  Vector Store Size:     {vectorstore.index.ntotal}")
    print(f"  Embedding Model:       {embeddings.model_name}")
    print(f"  LLM Model:             TinyLlama-1.1B-Chat")
else:
    print("\n" + "="*70)
    print("✗ ERROR: MISSING COMPONENTS")
    print("="*70)
    print("\nPlease run the previous cells to initialize:")
    for component, status in components_status.items():
        if not status:
            print(f"  - {component}")

    raise RuntimeError("Cannot proceed without all components initialized")



SYSTEM COMPONENT VERIFICATION

Component Status:
  ✓ df_notes: Ready
  ✓ df_kg: Ready
  ✓ all_combined_chunks: Ready
  ✓ embeddings: Ready
  ✓ vectorstore: Ready
  ✓ llm: Ready

✓ ALL COMPONENTS READY - PROCEEDING TO RETRIEVER CREATION

Detailed Statistics:
  Clinical Notes:        511
  Knowledge Graph:       24
  Total Chunks:          5741
  Vector Store Size:     5741
  Embedding Model:       sentence-transformers/all-MiniLM-L6-v2
  LLM Model:             TinyLlama-1.1B-Chat


before

In [11]:
# =============================================================================
# FINAL DEFINITIVE FIX - Direct Metadata Filtering (CLEAN VERSION)
# =============================================================================

print("\n" + "="*70)
print("FINAL FIX: DIRECT SEARCH BY METADATA")
print("="*70)

# =============================================================================
# STEP 1: Inspect exact metadata format
# =============================================================================

print("\nInspecting metadata format...")

# Get actual chunks for our target patient
target_chunks = [
    c for c in all_combined_chunks
    if '18427803-DS-5' in c.metadata.get('note_id', '')
]

if target_chunks:
    sample = target_chunks[0]
    print("\nTarget patient chunks found.")
    print(f"   Total chunks: {len(target_chunks)}")
    print("   Metadata format:")
    print(f"     note_id: '{sample.metadata.get('note_id')}'")
    print(f"     type: '{sample.metadata.get('type')}'")
    print(f"     diagnosis: '{sample.metadata.get('diagnosis')}'")
    print("\n   Content preview:")
    print(f"     {sample.page_content[:150]}")

    EXACT_NOTE_ID = sample.metadata.get('note_id')
    print(f"\n   Exact note_id format: '{EXACT_NOTE_ID}'")
else:
    print("ERROR: Target chunks not found in all_combined_chunks")
    EXACT_NOTE_ID = None


# =============================================================================
# STEP 2: Create Working Retriever with Direct Filtering
# =============================================================================

from langchain.schema import BaseRetriever, Document
from typing import List
import re

class WorkingRetriever(BaseRetriever):
    """
    Retriever that uses direct list filtering for guaranteed deterministic lookup.
    """
    vectorstore: object
    all_chunks: List[Document]
    k: int = 5

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, vectorstore, all_chunks, k=5):
        super().__init__(
            vectorstore=vectorstore,
            all_chunks=all_chunks,
            k=k
        )

    def extract_note_id(self, query):
        """Extract note_id pattern from query."""
        pattern = r'\b(\d{8}-[A-Z]{2}-\d+)\b'
        match = re.search(pattern, query)
        return match.group(1) if match else None

    def _get_relevant_documents(self, query: str) -> List[Document]:
        """Deterministic filtering, fallback to semantic search."""
        note_id_pattern = self.extract_note_id(query)

        if note_id_pattern:
            print(f"Searching for note_id pattern: {note_id_pattern}")

            matching_chunks = [
                chunk for chunk in self.all_chunks
                if note_id_pattern in chunk.metadata.get('note_id', '')
            ]

            if matching_chunks:
                print(f"   Found {len(matching_chunks)} chunks for this patient.")

                if len(matching_chunks) > self.k:
                    ranked = self.vectorstore.similarity_search(query, k=50)
                    matching_note_ids = set([c.metadata.get('note_id') for c in matching_chunks])

                    ranked_matches = [
                        doc for doc in ranked
                        if doc.metadata.get('note_id') in matching_note_ids
                    ]

                    if ranked_matches:
                        return ranked_matches[:self.k]

                sorted_matches = sorted(
                    matching_chunks,
                    key=lambda x: x.metadata.get('chunk_index', 0)
                )
                return sorted_matches[:self.k]

            else:
                print("   No chunks found matching that patient ID.")

        # Fallback mode
        print("Using semantic search instead.")
        return self.vectorstore.similarity_search(query, k=self.k)

    async def _aget_relevant_documents(self, query: str) -> List[Document]:
        return self._get_relevant_documents(query)


# Instantiate retriever
print("\nCreating working retriever...")
working_retriever = WorkingRetriever(
    vectorstore=vectorstore,
    all_chunks=all_combined_chunks,
    k=5
)
print("Working retriever created.")


# =============================================================================
# STEP 3: Test Direct Filtering
# =============================================================================

print("\n" + "="*70)
print("TESTING DIRECT FILTERING")
print("="*70)

test_query = "What is the chief complaint for patient 18427803-DS-5?"
print(f"\nQuery: {test_query}")

retrieved_docs = working_retriever._get_relevant_documents(test_query)

print(f"\nRetrieved {len(retrieved_docs)} documents:")
for i, doc in enumerate(retrieved_docs):
    note_id = doc.metadata.get('note_id')
    chunk_idx = doc.metadata.get('chunk_index')
    preview = doc.page_content[:100].replace('\n', ' ')
    print(f"   {i+1}. {note_id} [chunk {chunk_idx}]")
    print(f"      {preview}...")

if any('18427803-DS-5' in doc.metadata.get('note_id', '') for doc in retrieved_docs):
    print("\nSUCCESS: Target patient retrieved.")
else:
    print("\nFAILURE: Target patient not retrieved.")
    for doc in retrieved_docs:
        print(f"   Retrieved: {doc.metadata.get('note_id')}")


# =============================================================================
# STEP 4: Rebuild QA Chain with Working Retriever
# =============================================================================

print("\n" + "="*70)
print("REBUILDING QA CHAIN WITH WORKING RETRIEVER")
print("="*70)

from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate

CLEAN_PROMPT = """You are a clinical assistant. Use ONLY the context below to answer.

CONTEXT:
{context}

QUESTION: {question}

ANSWER:"""

prompt = PromptTemplate(
    template=CLEAN_PROMPT,
    input_variables=["context", "question"]
)

final_qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=working_retriever,
    chain_type_kwargs={"prompt": prompt},
    return_source_documents=True
)

print("Final QA chain ready.")


# =============================================================================
# STEP 5: Final Comprehensive Testing
# =============================================================================

print("\n" + "="*70)
print("FINAL COMPREHENSIVE TEST")
print("="*70)

def final_test(query, expected_pattern=None):
    print("\n" + "-"*70)
    print(f"Query: {query}")
    print("-"*70)

    result = final_qa_chain.invoke({"query": query})
    retrieved_notes = [doc.metadata.get('note_id') for doc in result['source_documents']]

    if expected_pattern:
        found = any(expected_pattern in note for note in retrieved_notes)
        status = "SUCCESS" if found else "FAILED"
        print(f"\n{status}: Expected '{expected_pattern}'")
        print(f"Retrieved: {retrieved_notes[:3]}")

    answer = result["result"].strip()
    if "ANSWER:" in answer:
        answer = answer.split("ANSWER:")[-1].strip()

    print("\nGenerated Answer:")
    print(answer[:300])

    print("\nSource Documents:")
    for i, doc in enumerate(result['source_documents'][:3]):
        note_id = doc.metadata.get('note_id')
        preview = doc.page_content[:80].replace('\n', ' ')
        print(f"   {i+1}. {note_id}")
        print(f"      {preview}...")

    return result

print("\n" + "="*70)
print("TEST 1: Target Patient (18427803-DS-5)")
final_test("What is the chief complaint for patient 18427803-DS-5?", expected_pattern="18427803-DS-5")

print("\n" + "="*70)
print("TEST 2: Detailed Patient Query")
final_test("What are all the clinical findings for patient 18427803-DS-5?", expected_pattern="18427803-DS-5")

print("\n" + "="*70)
print("TEST 3: General Diagnostic")
final_test("What are the key features of migraine with aura?")


# =============================================================================
# SYSTEM SUMMARY
# =============================================================================

print("\n" + "="*70)
print("FINAL SYSTEM STATUS")
print("="*70)

target_test = working_retriever._get_relevant_documents("patient 18427803-DS-5")
success = any('18427803-DS-5' in doc.metadata.get('note_id', '') for doc in target_test)

status = f"""
SYSTEM STATUS:
   • Vector Store: {vectorstore.index.ntotal} vectors
   • All Chunks: {len(all_combined_chunks)} available
   • Direct Filtering: {'Working' if success else 'Needs Debugging'}
   • Target Patient Retrievable: {'YES' if success else 'NO'}

{'System ready for use.' if success else 'Check chunk metadata format.'}
"""

print(status)



FINAL FIX: DIRECT SEARCH BY METADATA

Inspecting metadata format...

Target patient chunks found.
   Total chunks: 14
   Metadata format:
     note_id: '18427803-DS-5.json'
     type: 'clinical_note'
     diagnosis: 'Migraine With Aura'

   Content preview:
     NOTE_ID: 18427803-DS-5.json
DIAGNOSIS: Migraine With Aura

CHIEF COMPLAINT: Difficulty producing speech 

HISTORY OF PRESENT ILLNESS: Patient woke up 

   Exact note_id format: '18427803-DS-5.json'

Creating working retriever...
Working retriever created.

TESTING DIRECT FILTERING

Query: What is the chief complaint for patient 18427803-DS-5?
Searching for note_id pattern: 18427803-DS-5
   Found 14 chunks for this patient.

Retrieved 5 documents:
   1. 18427803-DS-5.json [chunk 0]
      NOTE_ID: 18427803-DS-5.json DIAGNOSIS: Migraine With Aura  CHIEF COMPLAINT: Difficulty producing spe...
   2. 18427803-DS-5.json [chunk 1]
      Patient's husband called ___ and talked to nurse in primary...
   3. 18427803-DS-5.json [chunk 2]
 

# Integrating the Generative Model

In [12]:
# =============================================================================
# FIXED EVALUATION METRICS FOR RAG SYSTEM
# =============================================================================

from sklearn.metrics import precision_score, recall_score
import numpy as np
import re

print("\n" + "="*70)
print("RAG SYSTEM EVALUATION - FIXED VERSION")
print("="*70)

# ---------------------------------------------
# IMPROVED Test Cases with Better Ground Truth
# ---------------------------------------------
test_cases = [
    {
        "query": "What is the chief complaint for patient 18427803-DS-5?",
        "expected": ["difficulty producing speech", "difficulty", "speech"],
        "note_id": "18427803-DS-5"
    },
    {
        "query": "What is the diagnosis for patient 18427803-DS-5?",
        "expected": ["migraine", "migraine with aura"],
        "note_id": "18427803-DS-5"
    },
    {
        "query": "What are the symptoms of migraine with aura?",
        "expected": ["aura", "headache", "visual", "neurological"],
        "note_id": None
    }
]

print("\nLoaded Test Cases:")
for i, case in enumerate(test_cases):
    print(f"   {i+1}. {case['query']}")
    print(f"      Expected: {case['expected']}")


# ---------------------------------------------
# Helper: Improved Answer Cleaning
# ---------------------------------------------
def clean_llm_response(raw_answer):
    """Remove prompt artifacts and extract clean answer"""
    answer = raw_answer.strip()

    # Remove common prompt artifacts
    remove_phrases = [
        "You are a clinical assistant",
        "Use ONLY the context below to answer",
        "CONTEXT:",
        "QUESTION:",
        "ANSWER:",
        "Direct Answer:",
        "NOTE_ID:"
    ]

    for phrase in remove_phrases:
        if phrase in answer:
            # Take content after the phrase
            parts = answer.split(phrase)
            if len(parts) > 1:
                answer = parts[-1].strip()

    # Remove repeated context snippets
    if "CHIEF COMPLAINT:" in answer and answer.index("CHIEF COMPLAINT:") > 100:
        answer = answer.split("CHIEF COMPLAINT:")[0].strip()

    # Extract just the final answer after all metadata
    lines = answer.split('\n')
    clean_lines = [line for line in lines if not line.strip().startswith('NOTE_ID:')]
    answer = '\n'.join(clean_lines).strip()

    # If answer is still too long, extract first meaningful sentence
    sentences = answer.split('.')
    if len(sentences) > 0 and len(answer) > 500:
        answer = sentences[0].strip() + '.'

    return answer


def contains_expected(pred, expected_list):
    """Check if any expected term is in prediction"""
    pred = pred.lower()
    if isinstance(expected_list, str):
        expected_list = [expected_list]

    for expected in expected_list:
        if expected.lower() in pred:
            return 1
    return 0


# ---------------------------------------------
# Improved Grounding Check
# ---------------------------------------------
def check_grounding_improved(answer, sources, threshold=0.4):
    """Check if answer is grounded in source documents"""
    # Combine all source content
    combined_source = " ".join(doc.page_content.lower() for doc in sources)

    # Extract key terms from answer (ignore common words)
    answer_lower = answer.lower()
    stop_words = {'the', 'a', 'an', 'is', 'are', 'was', 'were', 'in', 'on', 'at', 'to', 'for', 'of', 'and', 'or'}
    answer_terms = set(answer_lower.split()) - stop_words

    if not answer_terms:
        return 0

    # Check how many answer terms appear in sources
    source_terms = set(combined_source.split())
    overlap = len(answer_terms & source_terms) / len(answer_terms)

    return 1 if overlap >= threshold else 0


# ---------------------------------------------
# Retrieval Evaluation (Same as before)
# ---------------------------------------------
def evaluate_retrieval(retriever, test_cases, k=5):
    print("\n" + "="*70)
    print("EVALUATING RETRIEVAL")
    print("="*70)

    precision_vals, recall_vals, mrr_vals = [], [], []
    patient_accuracy = []

    for i, case in enumerate(test_cases):
        query = case["query"]
        expected = case["expected"]
        expected_note = case["note_id"]

        print(f"\n{i+1}. Query: {query}")

        retrieved_docs = retriever._get_relevant_documents(query)

        # Check if correct patient note found
        if expected_note:
            note_ids = [doc.metadata.get('note_id', '') for doc in retrieved_docs]
            found = any(expected_note in n for n in note_ids)
            patient_accuracy.append(int(found))
            print(f"   Note Match: {found}")

        # Check relevance
        retrieved_texts = [doc.page_content.lower() for doc in retrieved_docs]
        relevant_vector = [contains_expected(txt, expected) for txt in retrieved_texts]

        precision_vals.append(sum(relevant_vector) / len(relevant_vector))
        recall_vals.append(1 if sum(relevant_vector) > 0 else 0)

        # Reciprocal Rank
        rr = 0
        for idx, val in enumerate(relevant_vector):
            if val == 1:
                rr = 1 / (idx + 1)
                break
        mrr_vals.append(rr)

        print(f"   Precision: {precision_vals[-1]:.2f}")
        print(f"   Recall:    {recall_vals[-1]:.2f}")
        print(f"   MRR:       {rr:.2f}")

    return {
        "precision": np.mean(precision_vals),
        "recall": np.mean(recall_vals),
        "mrr": np.mean(mrr_vals),
        "patient_accuracy": np.mean(patient_accuracy) if patient_accuracy else None
    }


# ---------------------------------------------
# FIXED Generation Evaluation
# ---------------------------------------------
def evaluate_generation_fixed(qa_chain, test_cases):
    print("\n" + "="*70)
    print("EVALUATING GENERATION (FIXED)")
    print("="*70)

    accuracy_list, hallucination_list, coherence_list, citation_list = [], [], [], []

    for i, case in enumerate(test_cases):
        query = case["query"]
        expected = case["expected"]

        print(f"\n{i+1}. Query: {query}")

        result = qa_chain.invoke({"query": query})
        raw_answer = result["result"]
        sources = result["source_documents"]

        # Clean the answer
        answer = clean_llm_response(raw_answer)

        print(f"   Expected: {expected}")
        print(f"   Raw Answer (first 100 chars): {raw_answer[:100]}...")
        print(f"   Cleaned Answer: {answer[:150]}...")

        # Accuracy
        acc = contains_expected(answer, expected)
        accuracy_list.append(acc)
        print(f"   Accuracy: {acc}")

        # Improved grounding check
        grounded = check_grounding_improved(answer, sources)
        hallucination_list.append(0 if grounded else 1)
        print(f"   Grounded: {grounded}")

        # Coherence (answer should be substantial)
        coherent = len(answer.split()) > 3 and len(answer) < 1000
        coherence_list.append(int(coherent))
        print(f"   Coherent: {coherent}")

        # Citation check
        contains_expected_source = any(contains_expected(doc.page_content, expected) for doc in sources)
        citation_list.append(int(contains_expected_source))
        print(f"   Relevant Sources: {contains_expected_source}")

    return {
        "accuracy": np.mean(accuracy_list),
        "hallucination_rate": np.mean(hallucination_list),
        "coherence": np.mean(coherence_list),
        "citation_quality": np.mean(citation_list)
    }


# =============================================================================
# RUN THE FIXED EVALUATION
# =============================================================================

print("\n" + "="*70)
print("RUNNING FIXED EVALUATION")
print("="*70)

retrieval_metrics = evaluate_retrieval(working_retriever, test_cases)
generation_metrics = evaluate_generation_fixed(final_qa_chain, test_cases)


# =============================================================================
# FINAL SUMMARY
# =============================================================================

print("\n" + "="*70)
print("FIXED RESULTS SUMMARY")
print("="*70)

print("\nRetrieval:")
print(f"   Precision:        {retrieval_metrics['precision']:.2%}")
print(f"   Recall:           {retrieval_metrics['recall']:.2%}")
print(f"   MRR:              {retrieval_metrics['mrr']:.2%}")
print(f"   Patient Accuracy: {retrieval_metrics['patient_accuracy']:.2%}")

print("\nGeneration:")
print(f"   Accuracy:           {generation_metrics['accuracy']:.2%}")
print(f"   Hallucination Rate: {generation_metrics['hallucination_rate']:.2%}")
print(f"   Coherence:          {generation_metrics['coherence']:.2%}")
print(f"   Citation Quality:   {generation_metrics['citation_quality']:.2%}")

print("\n" + "="*70)
print("EVALUATION COMPLETE - FIXED VERSION")
print("="*70)


RAG SYSTEM EVALUATION - FIXED VERSION

Loaded Test Cases:
   1. What is the chief complaint for patient 18427803-DS-5?
      Expected: ['difficulty producing speech', 'difficulty', 'speech']
   2. What is the diagnosis for patient 18427803-DS-5?
      Expected: ['migraine', 'migraine with aura']
   3. What are the symptoms of migraine with aura?
      Expected: ['aura', 'headache', 'visual', 'neurological']

RUNNING FIXED EVALUATION

EVALUATING RETRIEVAL

1. Query: What is the chief complaint for patient 18427803-DS-5?
Searching for note_id pattern: 18427803-DS-5
   Found 14 chunks for this patient.
   Note Match: True
   Precision: 0.20
   Recall:    1.00
   MRR:       1.00

2. Query: What is the diagnosis for patient 18427803-DS-5?
Searching for note_id pattern: 18427803-DS-5
   Found 14 chunks for this patient.
   Note Match: True
   Precision: 0.20
   Recall:    1.00
   MRR:       1.00

3. Query: What are the symptoms of migraine with aura?
Using semantic search instead.
   Precis

# Gradio


Certainly! It sounds like you're looking for a clearer picture of how all these components work together to perform RAG (Retrieval-Augmented Generation). Let me break it down for you:

Understanding RAG: Retrieval-Augmented Generation
RAG is a powerful technique that combines the strengths of information retrieval systems with the generative capabilities of large language models (LLMs). It helps LLMs provide more accurate, up-to-date, and attributable answers by grounding their responses in specific, verifiable sources.

Here's how the RAG process works in this notebook:

Retrieval: When you ask a question, the system first looks up relevant information from a predefined knowledge base. In this case, our knowledge base is composed of clinical notes.

How it works: Your question is converted into a numerical representation (an "embedding"). This embedding is then used to quickly search through a database of pre-computed embeddings of our clinical note chunks. The goal is to find the chunks that are most semantically similar to your question.
Augmentation: The pieces of information (the most relevant clinical note chunks) found during the retrieval step are then added to your original question.

How it works: These retrieved chunks become the "context" for the language model. Instead of just answering based on its general training data, the LLM is explicitly given the specific clinical notes that might contain the answer.
Generation: Finally, the augmented prompt (your question + the retrieved context) is fed to the Large Language Model.

How it works: The LLM (TinyLlama in this case) uses this provided context, along with its inherent understanding of language and medical concepts, to formulate a concise and relevant answer. It's instructed to only use the provided context for its answer, and to state if the information isn't available.
Key Models and Libraries Explained:
TinyLlama/TinyLlama-1.1B-Chat-v1.0 (LLM): This is our Generative Model. It's a smaller, efficient Large Language Model that takes the question and the retrieved context as input and generates the human-readable answer. Think of it as the brain that synthesizes information.

sentence-transformers/all-MiniLM-L6-v2 (Embedding Model): This is our Embedding Model. Its job is to convert text (both our clinical note chunks and your questions) into dense numerical vectors. These vectors capture the semantic meaning of the text, allowing us to find similar pieces of information very quickly. If two pieces of text have similar meanings, their embeddings will be numerically close.

FAISS (Vector Store): This is our Vector Store (or vector database). After the clinical note chunks are converted into embeddings, FAISS is used to store and efficiently search these embeddings. When a query embedding comes in, FAISS can rapidly find the k (e.g., 3) most similar document embeddings in its index.

LangChain: This is an Orchestration Framework. It doesn't perform RAG itself, but it provides a structured way to connect all the different components (LLMs, embedding models, vector stores, prompt templates) into a cohesive "chain" or "pipeline" that performs the RAG process automatically. It simplifies building complex LLM applications.

End-to-End RAG Flow (Data Preprocessing to Output):
Data Ingestion & Parsing: Raw clinical notes (JSON files from the unRARred archives) are loaded. We specifically extracted input1 through input6 fields to construct the full_text for each note.

Text Splitting (Chunking): Long full_text clinical notes are broken down into smaller, overlapping chunks (using RecursiveCharacterTextSplitter). Each chunk is easier for the embedding model to process and more granular for retrieval. Metadata like note_id and diagnosis are attached to each chunk.

Embedding Generation: Each chunk is fed into the all-MiniLM-L6-v2 embedding model, which converts it into a numerical vector (embedding).

Vector Store Indexing: These embeddings, along with their original text content and metadata, are stored in a FAISS index. This index is optimized for fast similarity searches.

User Query: You type a question into the Gradio chat interface.

Query Embedding: Your question is also converted into a numerical embedding using the same all-MiniLM-L6-v2 embedding model.

Retrieval: The query's embedding is used to search the FAISS vector store for the k (e.g., 3) most relevant clinical note chunks.

Context Augmentation: The retrieved chunks are formatted and inserted into a PromptTemplate as the CONTEXT for the LLM.

LLM Generation: The complete prompt (including the question and context) is sent to the TinyLlama LLM. The LLM generates an answer, instructed to use only the provided context.

Output: The LLM's generated answer, along with the source documents it used, is displayed in the Gradio chat interface.

This entire process ensures that the TinyLlama model provides answers that are not only coherent but also directly supported by the provided clinical data, minimizing hallucinations and improving trustworthiness. The Gradio interface further allows you to dynamically add more notes, which immediately become part of this searchable knowledge base.

To make this explanation a permanent part of your notebook, I can add a new markdown cell outlining this process. Would you like me to do that?

Add RAG Explanation to Notebook: Add a new markdown cell in the notebook that explains the RAG process, the role of each model and library (LLM, Embeddings, FAISS, LangChain), and the end-to-end data preprocessing to output flow.