# TremorTrace — Complete ML Training Pipeline

**AI-powered Parkinson's disease screening using spiral and wave drawing analysis.**

This notebook trains two MobileNetV2 CNNs (Spiral + Wave) using two-phase transfer learning,
evaluates them on the held-out test sets (metrics, Grad-CAM, ensemble), and exports a production-ready inference package.

**Models:** Spiral CNN (50%) + Wave CNN (50%) → Weighted Ensemble → Probability Percentage (0–100%)

**Runtime:** Google Colab with T4 GPU

## Cell 1: Environment Setup

In [None]:
"""Environment setup: mount Drive, install deps, verify GPU, set seeds."""
import time
cell_start = time.time()

# Mount Google Drive for persistent storage
from google.colab import drive
drive.mount('/content/drive')

# Install requirements
!pip install -q opendatasets scikit-learn seaborn opencv-python-headless

# Core imports
import os
import sys
import random
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# Set up project paths
PROJECT_DIR = '/content/drive/MyDrive/TremorTrace'
os.makedirs(PROJECT_DIR, exist_ok=True)

SAVE_DIR = os.path.join(PROJECT_DIR, 'outputs')
os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(os.path.join(SAVE_DIR, 'plots'), exist_ok=True)

# Clone or upload src/ to Colab — adjust this path to where your src/ lives
# Option A: If src/ is in Google Drive
SRC_DIR = os.path.join(PROJECT_DIR, 'src')
if os.path.isdir(SRC_DIR):
    sys.path.insert(0, os.path.dirname(SRC_DIR))
    print(f'\u2705 src/ found at {SRC_DIR}')
else:
    # Option B: Upload src/ files to /content/src/
    SRC_DIR = '/content/src'
    os.makedirs(SRC_DIR, exist_ok=True)
    sys.path.insert(0, '/content')
    print(f'\u26a0\ufe0f  src/ not in Drive. Upload src/ files to {SRC_DIR}')

# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Verify GPU
gpus = tf.config.list_physical_devices('GPU')
print(f'\nTensorFlow version: {tf.__version__}')
if gpus:
    print(f'\u2705 GPU available: {gpus[0].name}')
    # Prevent TF from allocating all GPU memory at once
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
else:
    print('\u26a0\ufe0f  No GPU detected! Training will be slow.')
    print('   Go to Runtime > Change runtime type > Hardware accelerator > T4 GPU')

print(f'\n\u2705 Environment setup complete ({time.time() - cell_start:.1f}s)')

## Cell 2: Download Dataset

In [None]:
"""Fetch the Parkinson's Drawings dataset from Kaggle and point the config module at it."""
cell_start = time.time()

import opendatasets as od

# Kaggle source for the drawings; opendatasets asks for credentials on first use
DATASET_URL = 'https://www.kaggle.com/datasets/kmader/parkinsons-drawings'
DOWNLOAD_DIR = '/content'
od.download(DATASET_URL, data_dir=DOWNLOAD_DIR)

# Folder the notebook expects the dataset to live in after download
DATA_ROOT = os.path.join(DOWNLOAD_DIR, 'parkinsons-drawings')

# Propagate the resolved paths into the shared src.config module
from src import config
for config_attr, config_value in (('DATA_ROOT', DATA_ROOT), ('PROJECT_DIR', PROJECT_DIR)):
    setattr(config, config_attr, config_value)

# Fail fast when the expected folder is missing
assert os.path.isdir(DATA_ROOT), f'Dataset not found at {DATA_ROOT}'
print(f'\n\u2705 Dataset downloaded to: {DATA_ROOT}')
print(f'   Contents: {os.listdir(DATA_ROOT)}')
print(f'   ({time.time() - cell_start:.1f}s)')

## Cell 3: Explore Dataset

In [None]:
"""Tabulate how many healthy / Parkinson images each drawing type has per split."""

# Layout used below: <DATA_ROOT>/<drawing>/<training|testing>/<healthy|parkinson>/
PATHS = {
    drawing_type: {
        split_key: os.path.join(DATA_ROOT, drawing_type, split_folder)
        for split_key, split_folder in (('train', 'training'), ('test', 'testing'))
    }
    for drawing_type in ['spiral', 'wave']
}


def _count_images(folder):
    """Return how many .png/.jpg/.jpeg entries (case-insensitive) sit directly in *folder*."""
    return sum(
        1 for entry in os.listdir(folder)
        if entry.lower().endswith(('.png', '.jpg', '.jpeg'))
    )


rule = '=' * 65
print(rule)
print(f'{"Drawing":>10} | {"Split":>8} | {"Healthy":>8} | {"Parkinson":>10} | {"Total":>6}')
print(rule)

total_images = 0
for drawing_type, splits in PATHS.items():
    for split_name, split_path in splits.items():
        n_healthy = _count_images(os.path.join(split_path, 'healthy'))
        n_parkinson = _count_images(os.path.join(split_path, 'parkinson'))
        n_total = n_healthy + n_parkinson
        total_images += n_total
        print(f'{drawing_type:>10} | {split_name:>8} | {n_healthy:>8} | {n_parkinson:>10} | {n_total:>6}')

print(rule)
print(f'{"TOTAL":>10} | {"":>8} | {"":>8} | {"":>10} | {total_images:>6}')
print(f'\n\u2705 Dataset exploration complete — {total_images} total images')

## Cell 4: Visualize Sample Images

In [None]:
"""Visualize sample images: 4x5 grid of spiral/wave × healthy/parkinson.

Each row shows the first (sorted) training images of one drawing type / class,
labelled on the left. Cells stay blank if a class has fewer than n_samples images.
"""
from PIL import Image as PILImage

categories = [
    ('spiral', 'healthy',   'Spiral — Healthy'),
    ('spiral', 'parkinson', 'Spiral — Parkinson'),
    ('wave',   'healthy',   'Wave — Healthy'),
    ('wave',   'parkinson', 'Wave — Parkinson'),
]

n_samples = 5
fig, axes = plt.subplots(len(categories), n_samples, figsize=(20, 16))
fig.suptitle('Dataset Samples', fontsize=20, fontweight='bold')

for row, (drawing_type, class_name, row_label) in enumerate(categories):
    img_dir = os.path.join(PATHS[drawing_type]['train'], class_name)
    img_files = sorted([f for f in os.listdir(img_dir)
                        if f.lower().endswith(('.png', '.jpg', '.jpeg'))])[:n_samples]

    for col in range(n_samples):
        ax = axes[row, col]
        # Hide ticks and frame instead of calling ax.axis('off'): axis('off')
        # also hides the y-label that is used as the row label below.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

        if col < len(img_files):
            # Context manager closes the underlying file handle after decoding
            with PILImage.open(os.path.join(img_dir, img_files[col])) as img:
                ax.imshow(img.convert('RGB'))

        if col == 0:
            ax.set_ylabel(row_label, fontsize=12, fontweight='bold',
                          rotation=0, labelpad=120, va='center')

plt.tight_layout()
plt.savefig(os.path.join(SAVE_DIR, 'plots', 'dataset_samples.png'), dpi=150, bbox_inches='tight')
plt.show()
print('\u2705 Sample visualization complete')

## Cell 5: Create Data Generators

In [None]:
"""Build the train / validation / test generators for each drawing type."""
from src.data_pipeline import create_generators


def _build_generators(drawing_type):
    """Call create_generators for *drawing_type* using its train/test folders from PATHS."""
    split_dirs = PATHS[drawing_type]
    return create_generators(drawing_type, split_dirs['train'], split_dirs['test'])


print('Creating Spiral generators...')
spiral_train_gen, spiral_val_gen, spiral_test_gen = _build_generators('spiral')

print('\nCreating Wave generators...')
wave_train_gen, wave_val_gen, wave_test_gen = _build_generators('wave')

print('\n\u2705 All data generators created')

## Cell 6: Verify Augmentation

In [None]:
"""Visualize augmentation effects on spiral and wave images.

For each drawing type, shows the first healthy training image next to
n_augmented random augmentations drawn from get_augmentation_config().
"""
from src.data_pipeline import get_augmentation_config
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array

n_augmented = 6
fig, axes = plt.subplots(2, n_augmented + 1, figsize=(20, 6))
fig.suptitle('Augmentation Verification', fontsize=18, fontweight='bold')

for row, drawing_type in enumerate(['spiral', 'wave']):
    # Pick the first (sorted) healthy training image; filter by extension so
    # stray non-image entries (e.g. .DS_Store) are never chosen.
    class_dir = os.path.join(PATHS[drawing_type]['train'], 'healthy')
    img_file = sorted(
        f for f in os.listdir(class_dir)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )[0]
    img_path = os.path.join(class_dir, img_file)

    # Scale to [0, 1] here; 'rescale' is dropped from the augmentation config
    # below so the scaling is not applied a second time.
    img = load_img(img_path, target_size=(224, 224))
    img_array = img_to_array(img) / 255.0

    # Show original. Hide ticks/frame rather than axis('off'), which would
    # also hide the y-label naming the row.
    orig_ax = axes[row, 0]
    orig_ax.imshow(img_array)
    orig_ax.set_title('Original', fontsize=10, fontweight='bold')
    orig_ax.set_ylabel(f'{drawing_type.capitalize()}', fontsize=12, fontweight='bold')
    orig_ax.set_xticks([])
    orig_ax.set_yticks([])
    for spine in orig_ax.spines.values():
        spine.set_visible(False)

    # Generate augmented versions
    aug_config = get_augmentation_config(drawing_type)
    aug_config_no_rescale = {k: v for k, v in aug_config.items() if k != 'rescale'}
    datagen = ImageDataGenerator(**aug_config_no_rescale)

    img_batch = img_array[np.newaxis, ...]  # Add batch dim
    aug_iter = datagen.flow(img_batch, batch_size=1, seed=None)

    for col in range(1, n_augmented + 1):
        aug_img = next(aug_iter)[0]
        aug_img = np.clip(aug_img, 0, 1)
        axes[row, col].imshow(aug_img)
        axes[row, col].set_title(f'Aug {col}', fontsize=10)
        axes[row, col].axis('off')

plt.tight_layout()
plt.savefig(os.path.join(SAVE_DIR, 'plots', 'augmentation_verification.png'), dpi=150, bbox_inches='tight')
plt.show()
print('\u2705 Augmentation verification complete')
# NOTE(review): the two descriptions below are hard-coded, not read from
# get_augmentation_config(); keep them in sync with src/data_pipeline.py.
print('   Spiral: 360° rotation (rotationally invariant)')
print('   Wave: 15° rotation + horizontal flip (orientation-sensitive)')

## Cell 7: Build Models

In [None]:
"""Construct one MobileNetV2-based classifier per drawing type."""
from src.model_builder import build_model

print('Building Spiral CNN...')
spiral_model, spiral_base = build_model('spiral')

print('\nBuilding Wave CNN...')
wave_model, wave_base = build_model('wave')

# Only the spiral summary is printed: both models come from the same
# build_model() function, so their structure is assumed identical.
banner = '=' * 60
print('\n' + banner)
print('Model Architecture (identical for both):')
print(banner)
spiral_model.summary()

print('\n\u2705 Both models built and compiled')

## Cell 8: Train Spiral CNN

In [None]:
"""Fit the spiral classifier with the two-phase transfer-learning routine."""
from src.trainer import train_two_phase

# Keyword arguments for train_two_phase, gathered so the call stays compact
spiral_training_args = {
    'model': spiral_model,
    'base_model': spiral_base,
    'train_gen': spiral_train_gen,
    'val_gen': spiral_val_gen,
    'name': 'spiral',
    'save_dir': SAVE_DIR,
}
spiral_history = train_two_phase(**spiral_training_args)

## Cell 9: Train Wave CNN

In [None]:
"""Fit the wave classifier with the same two-phase routine used for spirals."""

# Relies on train_two_phase having been imported by the spiral training cell
wave_training_args = {
    'model': wave_model,
    'base_model': wave_base,
    'train_gen': wave_train_gen,
    'val_gen': wave_val_gen,
    'name': 'wave',
    'save_dir': SAVE_DIR,
}
wave_history = train_two_phase(**wave_training_args)

## Cell 10: Plot Training History

In [None]:
"""Plot the recorded training curves for the spiral and wave CNNs."""
from src.evaluator import plot_training_history

history_panels = (
    ('', 'Spiral CNN', spiral_history),
    ('\n', 'Wave CNN', wave_history),
)
for lead, title, history in history_panels:
    print(f'{lead}{title} Training History:')
    plot_training_history(history, title, SAVE_DIR)

## Cell 11: Evaluate Both CNNs

In [None]:
"""Evaluate both CNNs on their test sets and collect comparison metrics."""
from src.evaluator import evaluate_model, plot_prediction_distribution
from sklearn.metrics import precision_score, recall_score


def _print_section(title):
    """Print *title* framed by 60-character '=' rules, preceded by a blank line."""
    print('\n' + '=' * 60)
    print(title)
    print('=' * 60)


_print_section('EVALUATING SPIRAL CNN')
spiral_pred, spiral_true, spiral_acc, spiral_auc = evaluate_model(
    spiral_model, spiral_test_gen, 'Spiral CNN', SAVE_DIR
)
plot_prediction_distribution(spiral_pred, spiral_true, 'Spiral CNN', SAVE_DIR)

_print_section('EVALUATING WAVE CNN')
wave_pred, wave_true, wave_acc, wave_auc = evaluate_model(
    wave_model, wave_test_gen, 'Wave CNN', SAVE_DIR
)
plot_prediction_distribution(wave_pred, wave_true, 'Wave CNN', SAVE_DIR)

# Hard labels at the 0.5 cut-off (strictly greater-than) for precision/recall
spiral_hard = (spiral_pred > 0.5).astype(int)
wave_hard = (wave_pred > 0.5).astype(int)
spiral_prec = precision_score(spiral_true, spiral_hard, zero_division=0)
spiral_rec = recall_score(spiral_true, spiral_hard, zero_division=0)
wave_prec = precision_score(wave_true, wave_hard, zero_division=0)
wave_rec = recall_score(wave_true, wave_hard, zero_division=0)

# Per-model summary, later passed to plot_model_comparison()
metric_names = ('accuracy', 'auc', 'precision', 'recall')
all_metrics = {
    model_label: dict(zip(metric_names, values))
    for model_label, values in (
        ('Spiral CNN', (spiral_acc, spiral_auc, spiral_prec, spiral_rec)),
        ('Wave CNN', (wave_acc, wave_auc, wave_prec, wave_rec)),
    )
}

print('\n\u2705 Both models evaluated')

## Cell 12: Sample Predictions with Percentages

In [None]:
"""Show sample test predictions as percentages alongside their risk tiers."""
from src.evaluator import show_sample_predictions

prediction_panels = (
    ('', 'Spiral CNN', spiral_model, spiral_test_gen),
    ('\n', 'Wave CNN', wave_model, wave_test_gen),
)
for lead, title, net, test_gen in prediction_panels:
    print(f'{lead}{title} — Sample Predictions:')
    show_sample_predictions(net, test_gen, title, SAVE_DIR)

## Cell 13: Grad-CAM Visualization

In [None]:
"""Render Grad-CAM grids (8 test samples each) for the spiral and wave CNNs."""
from src.gradcam import visualize_gradcam_grid

gradcam_panels = (
    ('', 'Spiral CNN', spiral_model, spiral_test_gen),
    ('\n', 'Wave CNN', wave_model, wave_test_gen),
)
for lead, title, net, test_gen in gradcam_panels:
    print(f'{lead}{title} — Grad-CAM:')
    visualize_gradcam_grid(net, test_gen, title, SAVE_DIR, n_samples=8)

## Cell 14: Test Ensemble

In [None]:
"""Test the 2-model ensemble on matched spiral/wave test samples.

Spiral and wave predictions are paired by index. This is only meaningful if
both test generators yield samples in the same class order, so the cell checks
that the paired ground-truth labels agree and warns when they do not.
"""
from src.ensemble import ensemble_predict
import json

# Flatten so both (N,) and (N, 1) prediction arrays index to plain scalars
spiral_probs = np.asarray(spiral_pred, dtype=float).ravel()
wave_probs = np.asarray(wave_pred, dtype=float).ravel()
n_ensemble_samples = min(len(spiral_probs), len(wave_probs))

# Verify the index-pairing assumption instead of trusting it
if len(spiral_probs) != len(wave_probs):
    print(f'\u26a0\ufe0f  Test set sizes differ (spiral={len(spiral_probs)}, '
          f'wave={len(wave_probs)}); only the first {n_ensemble_samples} are paired.')
label_mismatches = int(np.count_nonzero(
    np.asarray(spiral_true).ravel()[:n_ensemble_samples]
    != np.asarray(wave_true).ravel()[:n_ensemble_samples]
))
if label_mismatches:
    print(f'\u26a0\ufe0f  {label_mismatches} of {n_ensemble_samples} paired samples have '
          f'different ground-truth labels — spiral/wave test sets are not aligned, '
          f'so the ensemble rows below mix classes.')

print('=' * 80)
print(f'{"#":>3} | {"Spiral %":>9} | {"Wave %":>7} | {"Ensemble %":>10} | {"Risk Tier":>15} | {"Agree":>5}')
print('=' * 80)

ensemble_results = []
for i, (spiral_p, wave_p) in enumerate(zip(spiral_probs[:n_ensemble_samples],
                                           wave_probs[:n_ensemble_samples])):
    result = ensemble_predict(
        spiral_cnn_prob=float(spiral_p),
        wave_cnn_prob=float(wave_p),
        input_mode='drawn',
    )
    ensemble_results.append(result)

    print(f'{i+1:>3} | {result["spiral_cnn_percent"]:>8.1f}% | {result["wave_cnn_percent"]:>6.1f}% | '
          f'{result["pd_probability_percent"]:>9.1f}% | {result["risk_tier"]:>15} | '
          f'{"Yes" if result["unanimous"] else "No":>5}')

# Show a few full result dictionaries
print('\n' + '=' * 80)
print('Sample Full Output (first 3):')
print('=' * 80)
for i, full_result in enumerate(ensemble_results[:3]):
    # Exclude disclaimer from print for brevity
    display_result = {k: v for k, v in full_result.items() if k != 'disclaimer'}
    print(f'\nSample {i+1}:')
    print(json.dumps(display_result, indent=2))

print(f'\n\u2705 Ensemble tested on {n_ensemble_samples} paired samples')

## Cell 15: Risk Tier Distribution

In [None]:
"""Plot how spiral, wave and ensemble probabilities fall into the risk tiers."""
from src.evaluator import plot_risk_tier_distribution

single_model_panels = (
    ('', 'Spiral CNN', spiral_pred, spiral_true),
    ('\n', 'Wave CNN', wave_pred, wave_true),
)
for lead, title, probs, labels in single_model_panels:
    print(f'{lead}{title} — Risk Tier Distribution:')
    plot_risk_tier_distribution(probs, labels, title, SAVE_DIR)

# Ensemble risk tier distribution
print('\nEnsemble — Risk Tier Distribution:')
# NOTE(review): the 50/50 weights are hard-coded here (and in the export
# metadata) rather than read from src.ensemble — keep them in sync.
ensemble_probs = np.array([
    0.5 * s + 0.5 * w
    for s, w in zip(spiral_pred[:n_ensemble_samples], wave_pred[:n_ensemble_samples])
])
# Ground truth taken from the spiral set; assumes the paired test sets are aligned
plot_risk_tier_distribution(ensemble_probs, spiral_true[:n_ensemble_samples], 'Ensemble', SAVE_DIR)

## Cell 16: Model Comparison

In [None]:
"""Compare the spiral and wave CNNs side by side (printed summary plus plot)."""
from src.evaluator import plot_model_comparison

summary_lines = [
    'Model Comparison — Spiral CNN vs Wave CNN:',
    f'  Spiral CNN: Accuracy={spiral_acc:.1%}, AUC={spiral_auc:.4f}',
    f'  Wave CNN:   Accuracy={wave_acc:.1%}, AUC={wave_auc:.4f}',
]
print('\n'.join(summary_lines))

# all_metrics holds accuracy / auc / precision / recall for each model
plot_model_comparison(all_metrics, SAVE_DIR)

print('\n\u2705 Model comparison complete')

## Cell 17: Test Input Handler End-to-End

In [None]:
"""Drive process_input() end to end, exactly as the backend entry point would be called."""
import base64
from src.input_handler import process_input, preprocess_image_from_base64

# Sample inputs: the first (sorted) Parkinson test drawing of each type
spiral_test_dir, wave_test_dir = (
    os.path.join(PATHS[kind]['test'], 'parkinson') for kind in ('spiral', 'wave')
)
spiral_img_path = os.path.join(spiral_test_dir, sorted(os.listdir(spiral_test_dir))[0])
wave_img_path = os.path.join(wave_test_dir, sorted(os.listdir(wave_test_dir))[0])


def _read_as_base64(path):
    """Return the file's raw bytes as a base64 text string (what the frontend sends)."""
    with open(path, 'rb') as handle:
        return base64.b64encode(handle.read()).decode('utf-8')


spiral_b64 = _read_as_base64(spiral_img_path)
wave_b64 = _read_as_base64(wave_img_path)

# Arguments shared by both calls; only input_mode differs between them
shared_inputs = dict(
    spiral_image_base64=spiral_b64,
    wave_image_base64=wave_b64,
    spiral_cnn_model=spiral_model,
    wave_cnn_model=wave_model,
)

print('Testing process_input() with input_mode="drawn"...')
result_drawn = process_input(**shared_inputs, input_mode='drawn')

# The same images sent as "uploaded" are expected to yield the same prediction
print('Testing process_input() with input_mode="uploaded"...')
result_uploaded = process_input(**shared_inputs, input_mode='uploaded')

# Pretty-printer for process_input() results (base64 payloads summarised, disclaimer truncated)
def display_result(result, label):
    """Print every entry of *result* under a 60-character banner titled *label*.

    Keys containing 'base64' are reduced to "present/missing" plus their length,
    and the 'disclaimer' value is cut to its first 50 characters followed by "...".
    """
    rule = '=' * 60
    print(f'\n{rule}')
    print(label)
    print(rule)
    for key, value in result.items():
        if 'base64' in key:
            status = 'present' if value else 'missing'
            size = len(value) if value else 0
            print(f'  {key}: [{status}, {size} chars]')
            continue
        if key == 'disclaimer':
            print(f'  {key}: "{value[:50]}..."')
            continue
        print(f'  {key}: {value}')

for shown_result, caption in ((result_drawn, 'Result — Drawn Mode'),
                              (result_uploaded, 'Result — Uploaded Mode')):
    display_result(shown_result, caption)

# The same images through both input modes must give the same ensemble percentage
drawn_pct = result_drawn['pd_probability_percent']
assert drawn_pct == result_uploaded['pd_probability_percent'], \
    'ERROR: Drawn and uploaded modes produced different predictions!'
print(f'\n\u2705 Both modes produce identical predictions: '
      f'{drawn_pct}% ({result_drawn["risk_tier"]})')

# Render the Grad-CAM overlays side by side.
# NOTE(review): used directly as <img src>, so the *_gradcam_base64 values are
# presumably full data URIs (data:image/...;base64,...) — confirm in input_handler.
spiral_cam = result_drawn.get('spiral_gradcam_base64')
wave_cam = result_drawn.get('wave_gradcam_base64')
if spiral_cam and wave_cam:
    from IPython.display import HTML, display
    display(HTML(f'''
    <div style="display: flex; gap: 20px;">
        <div><h4>Spiral Grad-CAM</h4>
            <img src="{spiral_cam}" width="224"/></div>
        <div><h4>Wave Grad-CAM</h4>
            <img src="{wave_cam}" width="224"/></div>
    </div>
    '''))

## Cell 18: Export Package

In [None]:
"""Export models and inference files for backend integration.

Copies the trained .keras models and the backend source files into
PROJECT_DIR/exports and writes metadata.json describing them. Missing
source files are reported and skipped rather than aborting the export.
"""
import shutil
import json
from datetime import datetime, timezone

from src.config import RISK_TIERS

EXPORT_DIR = os.path.join(PROJECT_DIR, 'exports')
os.makedirs(EXPORT_DIR, exist_ok=True)

print('Exporting TremorTrace package...')
print('=' * 60)

# 1. Copy trained models
for model_name in ['spiral', 'wave']:
    src_path = os.path.join(SAVE_DIR, f'{model_name}_final.keras')
    dst_path = os.path.join(EXPORT_DIR, f'{model_name}_final.keras')
    if os.path.exists(src_path):
        shutil.copy2(src_path, dst_path)
        size_mb = os.path.getsize(dst_path) / (1024 * 1024)
        print(f'  \u2705 {model_name}_final.keras ({size_mb:.1f} MB)')
    else:
        print(f'  \u26a0\ufe0f  {model_name}_final.keras not found at {src_path}')

# 2. Copy backend inference files
# NOTE(review): these modules are imported as src.* in this notebook; if they
# import src.config (or each other through the src package), the backend will
# also need config.py and a matching package layout — confirm before shipping.
backend_files = ['ensemble.py', 'gradcam.py', 'input_handler.py']
for filename in backend_files:
    src_path = os.path.join(SRC_DIR, filename)
    dst_path = os.path.join(EXPORT_DIR, filename)
    if os.path.exists(src_path):
        shutil.copy2(src_path, dst_path)
        print(f'  \u2705 {filename}')
    else:
        print(f'  \u26a0\ufe0f  {filename} not found at {src_path}')

# 3. Save metadata JSON
# Metrics are cast with float() before round(): round() on a numpy float32
# returns a numpy float32, which json.dump cannot serialize.
metadata = {
    'project': 'TremorTrace',
    'version': '1.0.0',
    'exported_at': datetime.now(timezone.utc).isoformat(),
    'models': {
        'spiral_cnn': {
            'file': 'spiral_final.keras',
            'architecture': 'MobileNetV2 + classification head',
            'input_shape': [224, 224, 3],
            'accuracy': round(float(spiral_acc), 4),
            'auc': round(float(spiral_auc), 4),
        },
        'wave_cnn': {
            'file': 'wave_final.keras',
            'architecture': 'MobileNetV2 + classification head',
            'input_shape': [224, 224, 3],
            'accuracy': round(float(wave_acc), 4),
            'auc': round(float(wave_auc), 4),
        },
    },
    'ensemble': {
        'weights': {'spiral_cnn': 0.5, 'wave_cnn': 0.5},
        'output': 'probability_percentage_0_to_100',
    },
    # Derived from config.RISK_TIERS (percent bounds, as binned in the final
    # dashboard) so the exported description cannot drift from the code.
    'risk_tiers': [
        {'range': f'{lower:g}-{upper:g}%', 'label': label, 'color': color}
        for lower, upper, label, color in RISK_TIERS
    ],
    'entry_point': 'input_handler.process_input()',
    'backend_files': backend_files,
}

metadata_path = os.path.join(EXPORT_DIR, 'metadata.json')
with open(metadata_path, 'w', encoding='utf-8') as f:
    json.dump(metadata, f, indent=2)
print('  \u2705 metadata.json')

print(f'\n\u2705 Export complete! Package at: {EXPORT_DIR}')
print(f'   Files: {os.listdir(EXPORT_DIR)}')

## Cell 19: Final Dashboard

In [None]:
"""Final summary dashboard with all key metrics and visualizations.

Six panels: metrics table, one Grad-CAM overlay per model, ensemble agreement
stats, ensemble risk-tier pie, and a text summary of the pipeline.
"""
from src.gradcam import generate_gradcam, find_target_layer
from src.ensemble import get_risk_tier

fig = plt.figure(figsize=(20, 12))
fig.suptitle('TremorTrace — Final Dashboard', fontsize=22, fontweight='bold')

# Grid: 2 rows x 3 columns
gs = fig.add_gridspec(2, 3, hspace=0.35, wspace=0.3)

# --- Panel 1: Model Metrics Table ---
ax1 = fig.add_subplot(gs[0, 0])
ax1.axis('off')
ax1.set_title('Model Performance', fontsize=14, fontweight='bold')

table_data = [
    ['Metric', 'Spiral CNN', 'Wave CNN'],
    ['Accuracy', f'{spiral_acc:.1%}', f'{wave_acc:.1%}'],
    ['AUC', f'{spiral_auc:.4f}', f'{wave_auc:.4f}'],
    ['Precision', f'{spiral_prec:.4f}', f'{wave_prec:.4f}'],
    ['Recall', f'{spiral_rec:.4f}', f'{wave_rec:.4f}'],
]
table = ax1.table(cellText=table_data[1:], colLabels=table_data[0],
                  cellLoc='center', loc='center')
table.auto_set_font_size(False)
table.set_fontsize(11)
table.scale(1, 1.5)
for i in range(3):
    table[0, i].set_facecolor('#3498DB')
    table[0, i].set_text_props(color='white', fontweight='bold')

# --- Panel 2: Spiral Grad-CAM sample (first image of the first test batch) ---
ax2 = fig.add_subplot(gs[0, 1])
spiral_test_gen.reset()
sample_batch, _ = next(spiral_test_gen)
sample_img = sample_batch[0]
_, spiral_overlay = generate_gradcam(spiral_model, sample_img)
ax2.imshow(spiral_overlay)
ax2.set_title('Spiral CNN — Grad-CAM', fontsize=14, fontweight='bold')
ax2.axis('off')

# --- Panel 3: Wave Grad-CAM sample (first image of the first test batch) ---
ax3 = fig.add_subplot(gs[0, 2])
wave_test_gen.reset()
sample_batch, _ = next(wave_test_gen)
sample_img = sample_batch[0]
_, wave_overlay = generate_gradcam(wave_model, sample_img)
ax3.imshow(wave_overlay)
ax3.set_title('Wave CNN — Grad-CAM', fontsize=14, fontweight='bold')
ax3.axis('off')

# --- Panel 4: Ensemble Agreement Stats ---
ax4 = fig.add_subplot(gs[1, 0])
ax4.axis('off')
ax4.set_title('Ensemble Statistics', fontsize=14, fontweight='bold')

n_results = len(ensemble_results)
n_unanimous = sum(1 for r in ensemble_results if r['unanimous'])
n_split = n_results - n_unanimous
# Guard the empty case (no paired samples): avoids ZeroDivisionError and a
# mean-of-empty-slice warning; shares then display as 0%.
share_denominator = n_results or 1
avg_confidence = (np.mean([r['confidence_score'] for r in ensemble_results])
                  if n_results else float('nan'))

stats_text = (
    f'Total test samples: {n_results}\n'
    f'Models unanimous: {n_unanimous} ({n_unanimous / share_denominator:.0%})\n'
    f'Models split: {n_split} ({n_split / share_denominator:.0%})\n'
    f'Avg confidence: {avg_confidence:.3f}\n'
    f'Ensemble weights: 50/50'
)
ax4.text(0.1, 0.5, stats_text, transform=ax4.transAxes,
         fontsize=12, verticalalignment='center', fontfamily='monospace',
         bbox=dict(boxstyle='round', facecolor='#ECF0F1', alpha=0.8))

# --- Panel 5: Risk Tier Pie Chart ---
ax5 = fig.add_subplot(gs[1, 1])
from src.config import RISK_TIERS

ensemble_pcts = [r['pd_probability_percent'] for r in ensemble_results]
tier_counts = []
tier_labels_list = []
tier_colors_list = []
for lower, upper, label, color in RISK_TIERS:
    # Half-open [lower, upper) bins, with exactly 100% folded into the top tier
    count = sum(1 for p in ensemble_pcts if lower <= p < upper or (upper == 100 and p == 100))
    if count > 0:
        tier_counts.append(count)
        tier_labels_list.append(f'{label}\n({count})')
        tier_colors_list.append(color)

if tier_counts:
    ax5.pie(tier_counts, labels=tier_labels_list, colors=tier_colors_list,
            autopct='%1.0f%%', startangle=90, textprops={'fontsize': 10})
ax5.set_title('Ensemble Risk Tier Distribution', fontsize=14, fontweight='bold')

# --- Panel 6: Architecture Summary ---
ax6 = fig.add_subplot(gs[1, 2])
ax6.axis('off')
ax6.set_title('Architecture', fontsize=14, fontweight='bold')

# Built with join: in the previous form, adjacent string literals were fused
# before `* 25` applied, so the title line was repeated 25 times.
arch_text = '\n'.join([
    'TremorTrace Pipeline',
    '\u2500' * 25,
    'Input: 2 drawings (spiral + wave)',
    '\u2193',
    'Spiral CNN (MobileNetV2) \u2192 50%',
    'Wave CNN (MobileNetV2)   \u2192 50%',
    '\u2193',
    'Weighted Ensemble',
    '\u2193',
    'PD Probability (0\u2013100%)',
    '+ Risk Tier + Grad-CAM',
])
ax6.text(0.1, 0.5, arch_text, transform=ax6.transAxes,
         fontsize=11, verticalalignment='center', fontfamily='monospace',
         bbox=dict(boxstyle='round', facecolor='#ECF0F1', alpha=0.8))

plt.savefig(os.path.join(SAVE_DIR, 'plots', 'final_dashboard.png'), dpi=150, bbox_inches='tight')
plt.show()

print('\n' + '=' * 60)
print('\u2705 TremorTrace ML Pipeline Complete!')
print('=' * 60)
print(f'  Spiral CNN: Accuracy={spiral_acc:.1%}, AUC={spiral_auc:.4f}')
print(f'  Wave CNN:   Accuracy={wave_acc:.1%}, AUC={wave_auc:.4f}')
print(f'  Export dir: {EXPORT_DIR}')
print(f'  Plots dir:  {os.path.join(SAVE_DIR, "plots")}')
print('\n  Next steps:')
print('  1. Download exports/ folder for backend integration')
# The export cell does not generate a README, so only point to one if it exists
if os.path.exists(os.path.join(EXPORT_DIR, 'README.md')):
    print('  2. See exports/README.md for integration instructions')
else:
    print('  2. See exports/metadata.json for model files, ensemble weights and risk tiers')
print('  3. Backend calls input_handler.process_input() — that\'s it!')
print('=' * 60)