# 🩺 Aura-Med: Pediatric Respiratory Triage

**Problem:** Pneumonia is the #1 infectious killer of children worldwide.  
**Solution:** Fast, offline, audible-to-actionable triage using **HeAR** and **MedGemma**.  

This notebook demonstrates the end-to-end clinical journey, from audio input to actionable WHO IMCI treatment recommendations, using **real model inference** on Google's Health AI Developer Foundations (HAI-DEF).

### Models Used
| Model | Role | Source |
|---|---|---|
| **HeAR** | Bioacoustic encoder (cough → 512-dim embedding) | HuggingFace (`google/hear`) |
| **MedGemma 4B-IT** | Clinical reasoning LLM (embedding + vitals → triage) | HuggingFace (`google/medgemma-4b-it`) |
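
An encoder that maps audio to a fixed-size embedding usually needs fixed-length input, so a longer recording has to be cut into clips first. The sketch below shows one way to compute clip boundaries. It assumes 2-second clips at 16 kHz mono, which is how the HeAR model card describes its input; check this against the model version you load. `clip_bounds` and the one-second hop are illustrative and not part of the Aura-Med code.

```python
# Sketch: clip boundaries for an encoder with fixed-length input.
# SAMPLE_RATE and CLIP_SECONDS are assumptions taken from the HeAR model card.
SAMPLE_RATE = 16_000   # Hz
CLIP_SECONDS = 2.0     # seconds of audio per embedding


def clip_bounds(n_samples, sample_rate=SAMPLE_RATE,
                clip_seconds=CLIP_SECONDS, hop_seconds=1.0):
    """Return (start, end) sample indices of overlapping fixed-length clips."""
    clip_len = int(sample_rate * clip_seconds)
    hop = int(sample_rate * hop_seconds)
    if n_samples < clip_len:
        # Too short for one full clip: the caller must zero-pad to clip_len.
        return [(0, n_samples)]
    return [(start, start + clip_len)
            for start in range(0, n_samples - clip_len + 1, hop)]


bounds = clip_bounds(5 * SAMPLE_RATE)  # a 5-second recording
print(len(bounds), bounds[0], bounds[-1])
```

Each clip would then yield one 512-dim embedding; how those per-clip embeddings are pooled before reaching MedGemma is a design choice this sketch does not cover.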

In [None]:
# Cell 1: Setup & Dependencies
# Version specifiers are quoted so the shell does not treat '>' as a redirect
%pip install -q torch "transformers>=4.50.0" librosa pydantic pandas psutil \
    "tensorflow>=2.15.0" huggingface_hub bitsandbytes accelerate soundfile opendatasets

import sys
import os

# Ensure src/ is importable
REPO_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))
if os.path.exists(os.path.join(REPO_ROOT, 'src')):
    sys.path.insert(0, REPO_ROOT)
elif os.path.exists(os.path.join(os.getcwd(), 'src')):
    sys.path.insert(0, os.getcwd())
    REPO_ROOT = os.getcwd()
else:
    # Stay in the current directory instead of silently moving to its parent
    REPO_ROOT = os.getcwd()
    print('⚠️ Could not find src/ directory. Please run from the repo root.')

os.chdir(REPO_ROOT)
print(f'Working directory: {os.getcwd()}')

In [None]:
# Cell 2: HuggingFace Authentication (for gated MedGemma model)
from huggingface_hub import login

# Option A: Set your token directly (for quick testing)
# login(token='hf_YOUR_TOKEN_HERE')

# Option B: Use Colab secrets (recommended for submission)
try:
    from google.colab import userdata
    hf_token = userdata.get('HF_TOKEN')
    login(token=hf_token)
    print('✅ Authenticated with HuggingFace via Colab Secrets')
except Exception:
    print('⚠️ No Colab secret found. Trying cached credentials...')
    try:
        login()
        print('✅ Using cached HuggingFace credentials')
    except Exception:
        print('❌ No HF credentials. MedGemma (gated model) may fail to download.')

In [None]:
# Cell 3: Hardware Check & Environment
import torch
import psutil

def check_hardware():
    print('--- Hardware Environment ---')
    if torch.cuda.is_available():
        # total_memory is reported in bytes
        vram = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f'✅ GPU Available: {torch.cuda.get_device_name(0)} ({vram:.1f} GB VRAM)')
        print(f'   CUDA Version: {torch.version.cuda}')
    else:
        print('⚠️ GPU Not Available - Demo will use mock inference')

    total_ram = psutil.virtual_memory().total / (1024**3)
    print(f'✅ Total System RAM: {total_ram:.1f} GB')
    print(f'📦 PyTorch: {torch.__version__}')

check_hardware()

In [None]:
# Cell 4: Load Real Models
from src.agent.core import AuraMedAgent
from src.datatypes import PatientVitals, TriageResult, TriageStatus
from src.demo.scenarios import DemoScenarios
from src.visualization.renderer import NotebookRenderer
from src.utils.latency_tracker import LatencyTracker
from src.config import IS_DEMO_MODE, HEAR_EMBEDDING_DIM, MEDGEMMA_MODEL_PATH

print(f'Demo Mode: {IS_DEMO_MODE}')
print(f'HeAR Embedding Dim: {HEAR_EMBEDDING_DIM}')
print(f'MedGemma Model: {MEDGEMMA_MODEL_PATH}')
print()

print('🔄 Initializing AuraMed Agent (loading models)...')
agent = AuraMedAgent()
renderer = NotebookRenderer()
tracker = LatencyTracker()
print()
print('✅ AuraMed Agent ready for inference.')

---
## 🏥 Part A: Pre-configured Clinical Journeys

Three clinical flows demonstrate the core capabilities: pneumonia detection, the emergency safety override, and the audio quality gate.

In [None]:
# Cell 5: Journey 1 - Clinical Success (Pneumonia Detection)
print('═' * 60)
print('JOURNEY 1: Clinical Success - Pneumonia Triage')
print('═' * 60)

audio_path_1, vitals_1, label_1 = DemoScenarios.get_journey_1_success()
print(f'Patient: {vitals_1.age_months} months old, RR: {vitals_1.respiratory_rate} bpm')
print(f'Audio: {audio_path_1}')
print()

result_1 = agent.predict(audio_path_1, vitals_1)
tracker.record(label_1, result_1)

print(f'Model Reasoning: {result_1.reasoning}')
display(renderer.render(result_1))

In [None]:
# Cell 6: Journey 2 - Emergency Safety Override
print('═' * 60)
print('JOURNEY 2: Emergency Override - Danger Signs Detected')
print('═' * 60)

audio_path_2, vitals_2, label_2 = DemoScenarios.get_journey_2_emergency()
print(f'Patient: {vitals_2.age_months} months old, Danger Signs: {vitals_2.danger_signs}')
print('⚡ Safety guard should intercept BEFORE model inference')
print()

result_2 = agent.predict(audio_path_2, vitals_2)
tracker.record(label_2, result_2)

print(f'Result: {result_2.reasoning}')
display(renderer.render(result_2))

In [None]:
# Cell 7: Journey 3 - Audio Quality Gate (Inconclusive)
print('═' * 60)
print('JOURNEY 3: Inconclusive - Low Quality Audio')
print('═' * 60)

audio_path_3, vitals_3, label_3 = DemoScenarios.get_journey_3_inconclusive()
print(f'Patient: {vitals_3.age_months} months old, RR: {vitals_3.respiratory_rate} bpm')
print(f'Audio: {audio_path_3} (near-silent, should trigger quality gate)')
print()

result_3 = agent.predict(audio_path_3, vitals_3)
tracker.record(label_3, result_3)

print(f'Result: {result_3.reasoning}')
display(renderer.render(result_3))

---
## 🎤 Part B: Upload Your Own Recording (Interactive)

Upload a `.wav` audio recording (e.g., a cough sound) and enter patient vitals to see the full AI triage pipeline in action.
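
The upload cell recommends 16 kHz mono recordings of 2 to 10 seconds. A quick header check before running triage can catch files that do not match. The sketch below uses only the standard-library `wave` module, which reads uncompressed PCM WAV files and raises `wave.Error` on other encodings. The helper names `describe_wav` and `matches_recommendation` are illustrative and do not come from the Aura-Med code.

```python
import wave


def describe_wav(path):
    """Read sample rate, channel count and duration from a PCM WAV header."""
    with wave.open(path, "rb") as wf:
        rate = wf.getframerate()
        channels = wf.getnchannels()
        duration = wf.getnframes() / rate
    return {"sample_rate": rate, "channels": channels, "duration_sec": duration}


def matches_recommendation(info, rate=16_000, min_sec=2.0, max_sec=10.0):
    """True when the file is mono, at the expected rate, and 2-10 s long."""
    return (info["sample_rate"] == rate
            and info["channels"] == 1
            and min_sec <= info["duration_sec"] <= max_sec)
```

In this notebook the check would be `matches_recommendation(describe_wav(uploaded_path))`. A file that fails it may still load if the pipeline resamples internally, so treat a failure as a warning, not a hard stop.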

In [None]:
# Cell 8: Upload Audio File
from IPython.display import display, Audio, HTML
from google.colab import files
import os

print('📁 Upload a .wav audio recording (cough, breathing, etc.):')
print('   Recommended: 16kHz mono, 2-10 seconds long')
print()

uploaded = files.upload()

if uploaded:
    uploaded_filename = list(uploaded.keys())[0]
    uploaded_path = os.path.join(os.getcwd(), uploaded_filename)
    print(f'\n✅ Uploaded: {uploaded_filename} ({len(uploaded[uploaded_filename])} bytes)')

    # Play the uploaded audio
    print('\n🔊 Preview:')
    display(Audio(uploaded_path))
else:
    uploaded_path = None
    print('⚠️ No file uploaded. Please upload a .wav file and re-run this cell.')

In [None]:
# Cell 9: Enter Patient Vitals
print('📋 Enter patient vitals for triage assessment:')
print('─' * 40)

#@markdown ### Patient Information
age_months = 7  #@param {type:"integer"}
respiratory_rate = 52  #@param {type:"integer"}
danger_signs = False  #@param {type:"boolean"}

custom_vitals = PatientVitals(
    age_months=age_months,
    respiratory_rate=respiratory_rate,
    danger_signs=danger_signs
)

# WHO IMCI respiratory rate thresholds
if age_months < 2:
    threshold = 60
elif age_months < 12:
    threshold = 50
else:
    threshold = 40

rr_status = '⚠️ FAST' if respiratory_rate >= threshold else '✅ Normal'

print(f'  Age: {age_months} months')
print(f'  Respiratory Rate: {respiratory_rate} bpm ({rr_status}, threshold: {threshold})')
print(f'  Danger Signs: {"🔴 YES" if danger_signs else "🟢 No"}')

In [None]:
# Cell 10: Run Triage on Uploaded Audio
print('═' * 60)
print('INTERACTIVE JOURNEY: Your Uploaded Recording')
print('═' * 60)

if uploaded_path and os.path.exists(uploaded_path):
    print(f'Audio: {uploaded_filename}')
    print(f'Patient: {custom_vitals.age_months}mo, RR={custom_vitals.respiratory_rate}, '
          f'Danger Signs={custom_vitals.danger_signs}')
    print()
    
    try:
        result_custom = agent.predict(uploaded_path, custom_vitals)
        tracker.record('Interactive: ' + uploaded_filename, result_custom)
        
        print(f'Model Reasoning: {result_custom.reasoning}')
        display(renderer.render(result_custom))
    except Exception as e:
        print(f'❌ Error during triage: {e}')
        print('   Tip: Ensure the file is a valid .wav audio file (16kHz mono recommended).')
else:
    print('⚠️ No audio file found. Please run Cell 8 first to upload a recording.')

---
## 📊 Part C: Dataset Validation (ICBHI 2017)

Validate Aura-Med against the **ICBHI 2017 Respiratory Sound Database**, a gold-standard medical dataset with doctor-confirmed diagnoses.
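
ICBHI recordings are named with five underscore-separated fields (patient number, recording index, chest location, acquisition mode, equipment), for example `101_1b1_Al_sc_Meditron.wav`. The sketch below parses such a name and maps a diagnosis to an expected triage level using the same YELLOW/GREEN split as the distribution table in Cell 12. `RecordingInfo`, `parse_icbhi_filename` and `expected_triage` are illustrative names; the notebook's actual loader is `ICBHIDataset`, whose internals are not shown here.

```python
from typing import NamedTuple


class RecordingInfo(NamedTuple):
    patient_id: int
    recording_index: str
    chest_location: str
    acquisition_mode: str
    equipment: str


def parse_icbhi_filename(filename):
    """Split an ICBHI file name into its five documented fields."""
    stem = filename.rsplit(".", 1)[0]
    parts = stem.split("_")
    if len(parts) != 5:
        raise ValueError(f"unexpected ICBHI file name: {filename!r}")
    return RecordingInfo(int(parts[0]), *parts[1:])


# Same split as the table in Cell 12: these diagnoses map to YELLOW.
YELLOW_DIAGNOSES = {"Pneumonia", "LRTI", "Bronchiolitis",
                    "COPD", "Bronchiectasis", "Asthma"}


def expected_triage(diagnosis):
    """Return the triage label a correct prediction should produce."""
    return "YELLOW" if diagnosis in YELLOW_DIAGNOSES else "GREEN"


info = parse_icbhi_filename("101_1b1_Al_sc_Meditron.wav")
print(info.patient_id, info.chest_location, expected_triage("URTI"))
```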

### Step 1: Download the dataset to Google Drive

In [None]:
# Cell 11: Mount Google Drive & Download ICBHI Dataset
from google.colab import drive
import os

# Mount Google Drive
drive.mount('/content/drive')

# Create dataset directory in Drive
ICBHI_DIR = '/content/drive/MyDrive/aura-med/data/icbhi'
os.makedirs(ICBHI_DIR, exist_ok=True)

# Check if already downloaded
audio_dir = os.path.join(ICBHI_DIR, 'audio_and_txt_files')
diagnosis_file = os.path.join(ICBHI_DIR, 'patient_diagnosis.csv')

if os.path.exists(diagnosis_file) and os.path.exists(audio_dir):
    n_wav = len([f for f in os.listdir(audio_dir) if f.endswith('.wav')])
    print(f'✅ ICBHI dataset already downloaded! ({n_wav} audio files found)')
    print(f'   Location: {ICBHI_DIR}')
else:
    print('⬇️ Downloading ICBHI 2017 dataset from Kaggle...')
    print()
    print('📋 INSTRUCTIONS:')
    print('   You need a Kaggle account to download. Two options:')
    print()
    print('   OPTION A - Automatic (Kaggle API):')
    print('   1. Go to kaggle.com → Your Profile → Settings → API → Create New Token')
    print('   2. Upload the kaggle.json file when prompted below')
    print()
    
    # Try automatic download via opendatasets
    try:
        import opendatasets as od
        od.download(
            'https://www.kaggle.com/datasets/vbookshelf/respiratory-sound-database',
            data_dir='/content/drive/MyDrive/aura-med/data'
        )
        
        # opendatasets saves to a subfolder; move contents to our expected path
        kaggle_dir = '/content/drive/MyDrive/aura-med/data/respiratory-sound-database'
        if os.path.exists(kaggle_dir):
            import shutil
            # Move contents into icbhi/
            for item in os.listdir(kaggle_dir):
                src = os.path.join(kaggle_dir, item)
                dst = os.path.join(ICBHI_DIR, item)
                if not os.path.exists(dst):
                    shutil.move(src, dst)
            print(f'\n✅ Dataset downloaded and organized at: {ICBHI_DIR}')

    except Exception as e:
        print(f'\n⚠️ Automatic download failed: {e}')
        print()
        print('   OPTION B - Manual Download:')
        print('   1. Go to: https://www.kaggle.com/datasets/vbookshelf/respiratory-sound-database')
        print('   2. Click "Download" (sign in if needed)')
        print('   3. Extract the ZIP file')
        print('   4. Upload the extracted folder to Google Drive at:')
        print(f'      {ICBHI_DIR}/')
        print('   5. Ensure this structure exists:')
        print(f'      {ICBHI_DIR}/audio_and_txt_files/*.wav')
        print(f'      {ICBHI_DIR}/patient_diagnosis.csv')

In [None]:
# Cell 12: Load Dataset & Show Distribution
from src.data.icbhi_loader import ICBHIDataset
import pandas as pd

ICBHI_DIR = '/content/drive/MyDrive/aura-med/data/icbhi'
dataset = ICBHIDataset(ICBHI_DIR)

# Print summary
print(dataset.summary())
print()

# Show as a table
counts = dataset.get_diagnosis_counts()
# Diagnoses expected to triage as YELLOW; everything else maps to GREEN
YELLOW_DIAGNOSES = {'Pneumonia', 'LRTI', 'Bronchiolitis', 'COPD', 'Bronchiectasis', 'Asthma'}
df_counts = pd.DataFrame([
    {'Diagnosis': diag, 'Audio Files': count,
     'Expected Triage': 'YELLOW' if diag in YELLOW_DIAGNOSES else 'GREEN'}
    for diag, count in counts.items()
])
display(df_counts)

### Step 2: Run Batch Validation

Select how many samples to test and which diagnosis to focus on.
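
When a single diagnosis is selected, overall accuracy is simply the hit rate for that one class. With "All", a per-diagnosis breakdown shows more than one pooled number. The sketch below computes that breakdown from rows shaped like the validation results table (keys `Diagnosis`, `Expected`, `Predicted`). The sample rows are made up for illustration and are not results from this notebook.

```python
from collections import defaultdict


def hit_rate_by_diagnosis(rows):
    """Return {diagnosis: (correct, total, rate)} for result rows."""
    tallies = defaultdict(lambda: [0, 0])  # diagnosis -> [correct, total]
    for row in rows:
        tally = tallies[row["Diagnosis"]]
        tally[0] += int(row["Expected"] == row["Predicted"])
        tally[1] += 1
    return {diag: (c, t, c / t) for diag, (c, t) in tallies.items()}


# Illustrative rows only; an ERROR row counts as a miss.
rows = [
    {"Diagnosis": "Pneumonia", "Expected": "YELLOW", "Predicted": "YELLOW"},
    {"Diagnosis": "Pneumonia", "Expected": "YELLOW", "Predicted": "GREEN"},
    {"Diagnosis": "Healthy", "Expected": "GREEN", "Predicted": "GREEN"},
    {"Diagnosis": "Healthy", "Expected": "GREEN", "Predicted": "ERROR"},
]
for diag, (correct, total, rate) in hit_rate_by_diagnosis(rows).items():
    print(f"{diag}: {correct}/{total} ({rate:.0%})")
```

With only a handful of samples per class these rates are very noisy, so they are best read as a smoke test, not as a performance estimate.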

In [None]:
# Cell 13: Batch Validation Configuration
#@markdown ### Validation Settings
num_samples = 5  #@param {type:"integer"}
target_diagnosis = "Pneumonia"  #@param ["Pneumonia", "LRTI", "COPD", "URTI", "Healthy", "Bronchiectasis", "Bronchiolitis", "Asthma", "All"] {allow-input: true}

if target_diagnosis == 'All':
    target_diagnosis = None

print(f'Validation config: {num_samples} samples'
      f'{" from " + target_diagnosis if target_diagnosis else " (all diagnoses)"}')

In [None]:
# Cell 14: Run Validation & Show Results
import pandas as pd
from IPython.display import display, HTML

print('═' * 60)
print('DATASET VALIDATION: ICBHI 2017')
print('═' * 60)

samples = dataset.get_samples(n=num_samples, diagnosis=target_diagnosis)
print(f'Running {len(samples)} samples through AuraMed pipeline...\n')

results = []
correct = 0
total = 0

for i, sample in enumerate(samples, 1):
    print(f'[{i}/{len(samples)}] Patient {sample.patient_id} - {sample.diagnosis}')
    try:
        result = agent.predict(sample.audio_path, sample.vitals)

        match = '✅' if result.status == sample.expected_triage else '❌'
        if result.status == sample.expected_triage:
            correct += 1
        total += 1
        
        results.append({
            'Patient': sample.patient_id,
            'Diagnosis': sample.diagnosis,
            'Expected': sample.expected_triage.value,
            'Predicted': result.status.value,
            'Confidence': f'{result.confidence:.2f}',
            'Match': match,
            'Reasoning': result.reasoning[:80] + '...' if len(result.reasoning) > 80 else result.reasoning
        })
        
        latency = result.usage_stats.get('latency_sec', 0) if result.usage_stats else 0
        print(f'   {match} Expected: {sample.expected_triage.value}, '
              f'Got: {result.status.value} (conf: {result.confidence:.2f}, {latency:.1f}s)')
        
    except Exception as e:
        print(f'   ⚠️ Error: {e}')
        total += 1
        results.append({
            'Patient': sample.patient_id,
            'Diagnosis': sample.diagnosis,
            'Expected': sample.expected_triage.value,
            'Predicted': 'ERROR',
            'Confidence': '-',
            'Match': '⚠️',
            'Reasoning': str(e)[:80]
        })

# Summary table
print()
print('═' * 60)
accuracy = (correct / total * 100) if total > 0 else 0
print(f'ACCURACY: {correct}/{total} ({accuracy:.1f}%)')
print('═' * 60)

df_results = pd.DataFrame(results)
display(df_results)

---
## 📊 Performance Telemetry

In [None]:
# Cell 15: Performance Telemetry Summary
print('═' * 60)
print('PERFORMANCE TELEMETRY')
print('═' * 60)

display(tracker.generate_summary_table())

# Named total_runtime so it does not overwrite the sample count `total` from Cell 14
total_runtime = tracker.get_total_runtime()
print(f'\nTotal pipeline runtime across all journeys: {total_runtime:.3f}s')

if torch.cuda.is_available():
    vram_used = torch.cuda.max_memory_allocated() / 1e9
    print(f'Peak GPU VRAM usage: {vram_used:.2f} GB')

## Summary

This demonstration shows Aura-Med's complete clinical pipeline:

- **HeAR** extracts bioacoustic embeddings from cough audio (512-dim)
- **MedGemma 4B** performs WHO IMCI-aligned clinical reasoning
- **Safety Guard** provides rule-based override for emergency danger signs
- **Quality Gate** rejects low-quality audio before it reaches the AI
- **ICBHI Validation** tests accuracy against real medical diagnoses
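
The order of the gates matters: the safety guard runs before any model, and the quality gate runs before audio reaches the AI. The sketch below illustrates that ordering with plain rules only. It replaces MedGemma's reasoning with the WHO IMCI fast-breathing thresholds from Cell 9, so it is a simplification and not the agent's actual logic. `Status`, `rule_based_triage`, the RMS-based quality measure and the `min_rms` cutoff are all hypothetical.

```python
from enum import Enum


class Status(Enum):  # hypothetical stand-in for the notebook's TriageStatus
    RED = "RED"
    YELLOW = "YELLOW"
    GREEN = "GREEN"
    INCONCLUSIVE = "INCONCLUSIVE"


def fast_breathing_threshold(age_months):
    """WHO IMCI fast-breathing cutoffs (breaths per minute), as in Cell 9."""
    if age_months < 2:
        return 60
    if age_months < 12:
        return 50
    return 40


def rule_based_triage(age_months, respiratory_rate, danger_signs,
                      audio_rms, min_rms=0.01):
    """Apply the gates in order: safety guard, quality gate, then RR rule."""
    if danger_signs:
        return Status.RED            # safety guard: no model inference needed
    if audio_rms < min_rms:
        return Status.INCONCLUSIVE   # quality gate: near-silent recording
    if respiratory_rate >= fast_breathing_threshold(age_months):
        return Status.YELLOW
    return Status.GREEN


print(rule_based_triage(7, 52, False, audio_rms=0.1))
```

Putting the danger-sign check first means an emergency referral never depends on audio quality or model availability.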

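All inference is intended to run within edge-deployment constraints (< 4 GB RAM, < 10 s latency); the telemetry cell above reports the measured runtime and peak GPU memory for each run.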
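
A latency target like this can be checked per call. The sketch below times a function with a monotonic clock and compares the result against the 10-second budget. `timed_call` and `within_budget` are illustrative helpers, separate from the notebook's `LatencyTracker`.

```python
import time

LATENCY_BUDGET_SEC = 10.0  # from the stated < 10 s target


def timed_call(fn, *args, **kwargs):
    """Run fn and return (result, elapsed_seconds) using a monotonic clock."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


def within_budget(elapsed_sec, budget_sec=LATENCY_BUDGET_SEC):
    """True when a single call finished inside the latency budget."""
    return elapsed_sec < budget_sec


# Stand-in workload; in the notebook this would wrap agent.predict(...)
result, elapsed = timed_call(sum, range(1_000))
print(f"{elapsed:.4f}s, within budget: {within_budget(elapsed)}")
```

For the memory target, `psutil.Process().memory_info().rss` reports the resident memory of the current process in bytes and can be sampled before and after a prediction.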