# Sleep Stage Model Training & Evaluation

This notebook trains the 4-class sleep stage model and evaluates its performance,
with special focus on **Wake/Light detection** for smart wake-up alarm applications.

## Goal: Smart Wake-Up Alarm

The alarm should trigger during **Wake** or **Light Sleep** stages, avoiding:
- **Deep Sleep (N3)** - Most restorative; waking here causes grogginess (sleep inertia)
- **REM Sleep** - Dream sleep, important for memory consolidation

### Key Metrics

| Metric | Why It Matters |
|--------|----------------|
| Wake/Light Recall | Don't miss optimal wake windows |
| Deep/REM Recall (1 − False Alarm Rate) | Don't falsely trigger the alarm during Deep or REM sleep |
| "Safe to Wake" Accuracy | Combined Wake+Light vs Deep+REM binary accuracy |
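All of these metrics come from the four cells of the binary confusion matrix used in Section 5, with "safe to wake" (Wake/Light) as the positive class. The helper below is a hypothetical name with made-up counts; it simply spells out each formula:

```python
def alarm_metrics(tn: int, fp: int, fn: int, tp: int) -> dict:
    """Alarm metrics from binary confusion-matrix counts.

    Positive class (1) = "safe to wake" (Wake or Light);
    negative class (0) = "don't wake" (Deep or REM).
    """
    total = tn + fp + fn + tp
    return {
        # Fraction of truly safe epochs the model flags as safe
        "safe_recall": tp / (tp + fn),
        # When the model says "safe", how often it is right
        "safe_precision": tp / (tp + fp),
        # Fraction of Deep/REM epochs correctly held back (specificity)
        "deep_rem_recall": tn / (tn + fp),
        # Fraction of Deep/REM epochs wrongly flagged as safe
        "false_alarm_rate": fp / (fp + tn),
        "binary_accuracy": (tp + tn) / total,
    }

# Illustrative counts only (not results from this notebook)
m = alarm_metrics(tn=80, fp=20, fn=10, tp=90)
for name, value in m.items():
    print(f"{name:>17}: {value:.1%}")
```

Deep/REM recall and the false alarm rate always sum to 1, so minimizing false alarms is the same as maximizing Deep/REM recall.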


## 1. Setup


In [None]:
import sys
from pathlib import Path

# Add src to path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root / 'src'))

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

# Our modules
from data.loader import DREAMTLoader
from features.extractor import FeatureExtractor
from models.tflite_model import SleepStageMLP

# Sklearn metrics
from sklearn.metrics import (
    confusion_matrix, classification_report, 
    accuracy_score, precision_recall_fscore_support,
    cohen_kappa_score
)

%matplotlib inline
plt.rcParams['figure.dpi'] = 100
plt.rcParams['figure.figsize'] = [14, 6]

print("✓ Imports successful!")


In [None]:
# Configuration
DATA_DIR = project_root / 'data' / 'dreamt'
OUTPUT_DIR = project_root / 'models' / 'tflite_4class'
RESOLUTION = '64Hz'
EPOCH_DURATION = 30.0  # seconds

# Create output directory
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print(f"Data directory: {DATA_DIR}")
print(f"Output directory: {OUTPUT_DIR}")
print(f"Data exists: {DATA_DIR.exists()}")


## 2. Load and Extract Features
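
With `EPOCH_DURATION = 30.0` and `fs = 64.0`, each epoch spans 30 × 64 = 1,920 samples. The sketch below shows the general idea of cutting a signal into non-overlapping epochs and computing one per-epoch feature. It is a hypothetical illustration (the helpers `epoch_signal` and `activity_feature` are ours), not the actual `FeatureExtractor` implementation:

```python
import numpy as np

def epoch_signal(signal: np.ndarray, fs: float, epoch_duration: float) -> np.ndarray:
    """Split a 1-D signal into non-overlapping epochs: shape (n_epochs, samples_per_epoch).

    Trailing samples that do not fill a whole epoch are dropped.
    """
    samples_per_epoch = int(round(epoch_duration * fs))
    n_epochs = len(signal) // samples_per_epoch
    return signal[: n_epochs * samples_per_epoch].reshape(n_epochs, samples_per_epoch)

def activity_feature(acc_x, acc_y, acc_z, fs, epoch_duration):
    """Hypothetical IMU feature: standard deviation of acceleration magnitude per epoch."""
    mag = np.sqrt(acc_x**2 + acc_y**2 + acc_z**2)
    return epoch_signal(mag, fs, epoch_duration).std(axis=1)

fs, epoch_duration = 64.0, 30.0
rng = np.random.default_rng(0)
n = int(fs * epoch_duration * 3) + 100  # three full epochs plus a partial one
x, y, z = rng.normal(size=(3, n))       # synthetic accelerometer axes
feat = activity_feature(x, y, z, fs, epoch_duration)
print(epoch_signal(x, fs, epoch_duration).shape, feat.shape)
```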


In [None]:
# Initialize loader and feature extractor
loader = DREAMTLoader(DATA_DIR, resolution=RESOLUTION)
print(f"Found {len(loader.participants)} participants")

fs = 64.0 if RESOLUTION == '64Hz' else 100.0
extractor = FeatureExtractor(epoch_duration=EPOCH_DURATION, fs=fs)

# Extract features from all participants (may take a few minutes)
all_features = []

for pid in tqdm(loader.participants, desc="Extracting features"):
    try:
        df = loader.load_participant(pid)
        features = extractor.extract_all_features(df, include_imu=True, include_ppg=True)
        features['participant'] = pid
        all_features.append(features)
    except Exception as e:
        print(f"Error with {pid}: {e}")

df_all = pd.concat(all_features, ignore_index=True)
print(f"\nTotal epochs: {len(df_all)}")


In [None]:
# Filter to IMU + PPG features only (matching what ESP32 can compute)
imu_prefixes = ['imu_x', 'imu_y', 'imu_z', 'imu_mag', 'imu_activity', 'imu_movement']
ppg_prefixes = ['ppg_', 'hr_', 'hrv_']

feature_cols = [col for col in df_all.columns 
                if any(col.startswith(p) for p in imu_prefixes + ppg_prefixes)]
print(f"Selected {len(feature_cols)} features for training")

# Keep only valid sleep stages
valid_stages = ['W', 'N1', 'N2', 'N3', 'R']
df_clean = df_all[df_all['Sleep_Stage'].isin(valid_stages)].copy()
df_clean = df_clean.dropna(subset=feature_cols)
print(f"Epochs after cleaning: {len(df_clean)}")

# Create 4-class mapping (merge N1 + N2 → Light)
stage_mapping = {'W': 'Wake', 'N1': 'Light', 'N2': 'Light', 'N3': 'Deep', 'R': 'REM'}
df_clean['Stage_4class'] = df_clean['Sleep_Stage'].map(stage_mapping)

print("\n4-class distribution:")
print(df_clean['Stage_4class'].value_counts())
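
The 4-class counts printed above are often imbalanced: N2 alone typically covers a large share of a night, and it is merged into Light here. One common remedy (idea 1 in Section 8) is to weight classes inversely to their frequency. A minimal sketch using scikit-learn's `compute_class_weight`; whether `SleepStageMLP.fit` accepts such weights is not shown in this notebook, so passing them in is left as an assumption:

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Illustrative labels only; in the notebook this would be df_clean['Stage_4class'].values
labels = np.array(['Light'] * 6 + ['Wake'] * 2 + ['Deep'] + ['REM'])

classes = np.array(['Wake', 'Light', 'Deep', 'REM'])
# 'balanced' gives weight = n_samples / (n_classes * count_of_class)
weights = compute_class_weight(class_weight='balanced', classes=classes, y=labels)
class_weight = {str(c): float(w) for c, w in zip(classes, weights)}
print(class_weight)
```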


## 3. Train/Test Split & Model Training
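
Splitting by participant keeps every person's epochs on one side of the split, so the test score reflects performance on people the model has never seen. The next cell does this with a manual shuffle. An equivalent using scikit-learn's `GroupShuffleSplit` is sketched here on toy data, with a leakage check at the end; the exact participants chosen will differ from the manual split:

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

# Toy data: 10 participants with 5 epochs each (illustrative only)
rng = np.random.default_rng(42)
groups = np.repeat([f"P{i:02d}" for i in range(10)], 5)
X = rng.normal(size=(len(groups), 4))
y = rng.choice(['Wake', 'Light', 'Deep', 'REM'], size=len(groups))

gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, test_idx = next(gss.split(X, y, groups=groups))

train_groups = set(groups[train_idx])
test_groups = set(groups[test_idx])
# No participant may appear on both sides of the split
assert train_groups.isdisjoint(test_groups)
print(f"train participants: {len(train_groups)}, test participants: {len(test_groups)}")
```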


In [None]:
# Split by participant (not epoch) to avoid data leakage
participants = df_clean['participant'].unique()
np.random.seed(42)
np.random.shuffle(participants)

n_train = int(len(participants) * 0.8)
train_pids, test_pids = participants[:n_train], participants[n_train:]

train_mask = df_clean['participant'].isin(train_pids)
test_mask = df_clean['participant'].isin(test_pids)

X_train = df_clean.loc[train_mask, feature_cols].values
y_train = df_clean.loc[train_mask, 'Sleep_Stage'].values  # raw W/N1/N2/N3/R labels; SleepStageMLP is assumed to map them to 4 classes
X_test = df_clean.loc[test_mask, feature_cols].values
y_test = df_clean.loc[test_mask, 'Sleep_Stage'].values
y_test_4class = df_clean.loc[test_mask, 'Stage_4class'].values

print(f"Train: {len(train_pids)} participants, {len(X_train)} epochs")
print(f"Test: {len(test_pids)} participants, {len(X_test)} epochs")


In [None]:
# Create and train model
model = SleepStageMLP(
    input_dim=len(feature_cols),
    hidden_layers=[64, 32, 16],
    dropout_rate=0.3,
    l2_reg=0.001
)
model.feature_names = feature_cols
model.compile(learning_rate=0.001)

print("Model Architecture:")
model.summary()

# Train
history = model.fit(
    X_train, y_train,
    validation_split=0.15,
    epochs=100,
    batch_size=64,
    early_stopping_patience=15,
    verbose=1
)


## 4. Evaluate Model Performance
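
The metrics returned by `model.evaluate` include Cohen's κ (printed in the final summary), which corrects accuracy for chance agreement: κ = (p_o − p_e) / (1 − p_e), where p_o is the observed agreement and p_e the agreement expected from the class marginals. This matters for imbalanced sleep data, where always predicting the majority class can give respectable accuracy but κ = 0. A small sketch computing κ from a confusion matrix and checking it against scikit-learn's `cohen_kappa_score` (toy labels, illustrative only):

```python
import numpy as np
from sklearn.metrics import cohen_kappa_score, confusion_matrix

def kappa_from_cm(cm: np.ndarray) -> float:
    """Cohen's kappa, (p_o - p_e) / (1 - p_e), from a square confusion matrix."""
    cm = np.asarray(cm, dtype=float)
    n = cm.sum()
    p_o = np.trace(cm) / n                              # observed agreement
    p_e = (cm.sum(axis=0) @ cm.sum(axis=1)) / n**2      # chance agreement from marginals
    return (p_o - p_e) / (1 - p_e)

y_true = ['Light', 'Light', 'Deep', 'REM', 'Wake', 'Light', 'Deep', 'Light']
y_pred = ['Light', 'Deep', 'Deep', 'Light', 'Wake', 'Light', 'Deep', 'REM']
labels = ['Wake', 'Light', 'Deep', 'REM']
cm = confusion_matrix(y_true, y_pred, labels=labels)
print(kappa_from_cm(cm), cohen_kappa_score(y_true, y_pred))
```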


In [None]:
# Get predictions
y_pred_names = model.predict_stage_names(X_test)

# Standard 4-class evaluation
metrics = model.evaluate(X_test, y_test, verbose=True)

# Confusion matrix
class_names = ['Wake', 'Light', 'Deep', 'REM']
cm = confusion_matrix(y_test_4class, y_pred_names, labels=class_names)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=class_names, yticklabels=class_names, ax=axes[0])
axes[0].set_xlabel('Predicted'); axes[0].set_ylabel('True')
axes[0].set_title('Confusion Matrix (Counts)')

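# Row-normalize (per true class); rows with no samples stay 0 instead of NaN
row_sums = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(cm, row_sums, out=np.zeros_like(cm, dtype=float), where=row_sums > 0)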
sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=axes[1])
axes[1].set_xlabel('Predicted'); axes[1].set_ylabel('True')
axes[1].set_title('Confusion Matrix (Normalized)')
plt.tight_layout()
plt.show()


## 5. Smart Wake-Up Alarm Analysis ⏰

For the alarm, we care about a **binary decision**:
- **Safe to wake**: Wake or Light sleep → Alarm can trigger
- **Don't wake**: Deep or REM sleep → Alarm should wait
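
Because the alarm only needs this binary decision, it can also be made from class probabilities instead of the argmax label (idea 3 in Section 8): add P(Wake) + P(Light) and trigger only when the sum clears a threshold. Raising the threshold trades missed wake windows for fewer false alarms. The sketch assumes a probability array whose columns are ordered Wake, Light, Deep, REM; this notebook does not show whether `SleepStageMLP` exposes such probabilities or in what column order, so treat that as an assumption:

```python
import numpy as np

CLASS_ORDER = ['Wake', 'Light', 'Deep', 'REM']  # assumed column order

def safe_to_wake(probs: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Return 1 where P(Wake) + P(Light) >= threshold, else 0."""
    p_safe = probs[:, CLASS_ORDER.index('Wake')] + probs[:, CLASS_ORDER.index('Light')]
    return (p_safe >= threshold).astype(int)

probs = np.array([
    [0.10, 0.50, 0.30, 0.10],  # p_safe = 0.60
    [0.05, 0.30, 0.60, 0.05],  # p_safe = 0.35
    [0.40, 0.45, 0.05, 0.10],  # p_safe = 0.85
])
print(safe_to_wake(probs, threshold=0.5))  # [1 0 1]
print(safe_to_wake(probs, threshold=0.7))  # [0 0 1]
```

Summing probabilities can disagree with argmax: a row like `[0.3, 0.3, 0.4, 0.0]` has argmax Deep, yet P(safe) = 0.6.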


In [None]:
# Create binary labels: Safe to Wake (1) vs Don't Wake (0)
y_test_binary = np.array([1 if s in ['Wake', 'Light'] else 0 for s in y_test_4class])
y_pred_binary = np.array([1 if s in ['Wake', 'Light'] else 0 for s in y_pred_names])

# Binary metrics
binary_accuracy = accuracy_score(y_test_binary, y_pred_binary)
precision, recall, f1, _ = precision_recall_fscore_support(
    y_test_binary, y_pred_binary, average='binary', pos_label=1
)

print("=" * 60)
print("SMART WAKE-UP ALARM METRICS")
print("=" * 60)
print(f"\n'Safe to Wake' = Wake or Light Sleep")
print(f"'Don't Wake' = Deep or REM Sleep\n")
print(f"Binary Accuracy:        {binary_accuracy:.1%}")
print(f"Safe-to-Wake Precision: {precision:.1%}  (When alarm says 'safe', how often correct?)")
print(f"Safe-to-Wake Recall:    {recall:.1%}  (What % of safe windows do we detect?)")
print(f"F1 Score:               {f1:.3f}")


In [None]:
# Binary confusion matrix for alarm decision
# Fix labels=[0, 1] so the matrix is always 2x2 and the ravel() below never fails
cm_binary = confusion_matrix(y_test_binary, y_pred_binary, labels=[0, 1])

fig, ax = plt.subplots(figsize=(8, 6))
labels = ["Don't Wake\n(Deep/REM)", "Safe to Wake\n(Wake/Light)"]
sns.heatmap(cm_binary, annot=True, fmt='d', cmap='RdYlGn',
            xticklabels=labels, yticklabels=labels, ax=ax, annot_kws={'size': 16})
ax.set_xlabel('Predicted', fontsize=12)
ax.set_ylabel('True', fontsize=12)
ax.set_title('Wake-Up Alarm Binary Classification', fontsize=14)
plt.tight_layout()
plt.show()

# Interpretation
tn, fp, fn, tp = cm_binary.ravel()
print(f"\nAlarm Decision Analysis:")
print(f"  ✓ Correctly detected safe windows: {tp} epochs")
print(f"  ✓ Correctly avoided deep/REM: {tn} epochs")
print(f"  ✗ Missed safe windows: {fn} epochs (alarm could have triggered earlier)")
print(f"  ✗ False alarms during deep/REM: {fp} epochs (would cause grogginess)")
print(f"\n  False alarm rate: {fp / (fp + tn):.1%}")


## 6. Visualize Predictions Over a Night

Let's see how the model tracks sleep stages and identifies wake windows.
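
The bottom panel marks safe epochs one at a time; for an alarm it is often more useful to turn them into contiguous windows with start and end times. A minimal run-length sketch (the helper `safe_windows` is hypothetical, not part of the project's modules), assuming epochs are contiguous and `EPOCH_DURATION` seconds apart:

```python
import numpy as np

def safe_windows(safe: np.ndarray, epoch_duration: float = 30.0):
    """Return (start_h, end_h) tuples, in hours, for each run of consecutive True epochs."""
    safe = np.asarray(safe, dtype=bool)
    # Pad with False on both sides so every run has a rising and a falling edge
    padded = np.concatenate(([False], safe, [False]))
    edges = np.diff(padded.astype(int))
    starts = np.flatnonzero(edges == 1)   # index of the first epoch in each run
    ends = np.flatnonzero(edges == -1)    # index one past the last epoch in each run
    to_hours = epoch_duration / 3600
    return [(s * to_hours, e * to_hours) for s, e in zip(starts, ends)]

safe = np.array([0, 1, 1, 1, 0, 0, 1, 1, 0, 1], dtype=bool)
for start, end in safe_windows(safe, epoch_duration=30.0):
    print(f"{start * 60:.1f}-{end * 60:.1f} min")
```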


In [None]:
# Pick a test participant and visualize their night
test_participant = test_pids[0]
mask = df_clean['participant'] == test_participant
X_p = df_clean.loc[mask, feature_cols].values
y_true_p = df_clean.loc[mask, 'Stage_4class'].values
y_pred_p = model.predict_stage_names(X_p)

# Assumes epochs are in time order with no gaps (the dropna above may remove some)
time_hours = np.arange(len(y_true_p)) * (EPOCH_DURATION / 3600)

# Plot hypnograms
stage_to_num = {'Wake': 3, 'Light': 2, 'Deep': 1, 'REM': 0}
y_true_num = [stage_to_num[s] for s in y_true_p]
y_pred_num = [stage_to_num[s] for s in y_pred_p]

fig, axes = plt.subplots(2, 1, figsize=(16, 8), sharex=True)

axes[0].step(time_hours, y_true_num, where='mid', linewidth=2, color='#2c3e50', label='True')
axes[0].step(time_hours, y_pred_num, where='mid', linewidth=2, color='#e74c3c', alpha=0.7, label='Predicted')
axes[0].set_yticks([0, 1, 2, 3])
axes[0].set_yticklabels(['REM', 'Deep', 'Light', 'Wake'])
axes[0].set_ylabel('Sleep Stage')
axes[0].set_title(f'Sleep Staging: {test_participant}')
axes[0].legend(loc='upper right')
axes[0].grid(True, alpha=0.3)

# Safe to wake windows
safe_true = np.array([s in ['Wake', 'Light'] for s in y_true_p])
safe_pred = np.array([s in ['Wake', 'Light'] for s in y_pred_p])

axes[1].fill_between(time_hours, 0, safe_true.astype(int), step='mid', alpha=0.3, color='#2ecc71', label='True Safe')
axes[1].step(time_hours, safe_pred.astype(int) * 0.9, where='mid', linewidth=2, color='#e74c3c', label='Predicted Safe')
axes[1].set_xlabel('Time (hours)')
axes[1].set_ylabel('Safe to Wake?')
axes[1].set_yticks([0, 1])
axes[1].set_yticklabels(['No', 'Yes'])
axes[1].set_title('Smart Alarm: Detected Wake Windows')
axes[1].legend(loc='upper right')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Participant accuracy: {np.mean(np.array(y_true_num) == np.array(y_pred_num)):.1%}")


## 7. Save Model for Device Deployment
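
For an ESP32 target it helps to estimate model size before converting. Assuming `SleepStageMLP` builds plain dense layers with biases (the notebook does not show its internals; dropout and L2 regularization add no parameters), the count for `[input_dim → 64 → 32 → 16 → 4]` is the sum of `fan_in × fan_out + fan_out` over layers. With int8 quantization each weight takes roughly one byte, plus TFLite overhead:

```python
def mlp_param_count(input_dim: int, hidden_layers=(64, 32, 16), n_classes: int = 4) -> int:
    """Parameters (weights + biases) of a stack of fully connected layers."""
    sizes = [input_dim, *hidden_layers, n_classes]
    return sum(fan_in * fan_out + fan_out for fan_in, fan_out in zip(sizes[:-1], sizes[1:]))

# Hypothetical feature count; use len(feature_cols) in the notebook
n_params = mlp_param_count(input_dim=40)
print(n_params)
print(f"~{n_params / 1024:.1f} KiB at about 1 byte per parameter (int8), before TFLite overhead")
```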


In [None]:
# Save model, convert to TFLite, and export C++ scaler
model.save(str(OUTPUT_DIR / 'keras_model'))

tflite_path = str(OUTPUT_DIR / 'sleep_model.tflite')
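# Calibrate quantization on a random sample of training epochs; the first 1000 rows
# come from just the first few training participants in file order
rng = np.random.default_rng(42)
rep_idx = rng.choice(len(X_train), size=min(1000, len(X_train)), replace=False)
model.convert_to_tflite(tflite_path, quantize=True, representative_data=X_train[rep_idx])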

scaler_path = str(OUTPUT_DIR / 'scaler_params.h')
model.export_scaler_for_cpp(scaler_path)

# Save feature list
with open(OUTPUT_DIR / 'feature_list.txt', 'w') as f:
    for i, name in enumerate(feature_cols):
        f.write(f"{i:3d}: {name}\n")

print(f"\n✓ All files saved to: {OUTPUT_DIR}")
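
`export_scaler_for_cpp` writes feature-scaling parameters to a C++ header so the ESP32 can normalize features the same way as in training. Assuming the scaler is a standard z-score, `(x - mean) / scale` as in scikit-learn's `StandardScaler` (the notebook does not confirm which scaler `SleepStageMLP` uses), the on-device arithmetic can be checked against the Python transform. The `to_c_array` helper and its output format are hypothetical and not necessarily what `export_scaler_for_cpp` emits:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

def to_c_array(name: str, values: np.ndarray) -> str:
    """Hypothetical C/C++ declaration for a float array."""
    body = ", ".join(f"{v:.8f}f" for v in values)
    return f"const float {name}[{len(values)}] = {{{body}}};"

rng = np.random.default_rng(0)
X = rng.normal(loc=5.0, scale=2.0, size=(100, 3))  # illustrative feature matrix

scaler = StandardScaler().fit(X)
mean, scale = scaler.mean_, scaler.scale_  # the values a C++ header would need

# What the device would compute for one feature vector, written out by hand
x = X[0]
z_manual = (x - mean) / scale
assert np.allclose(z_manual, scaler.transform(x.reshape(1, -1))[0])

print(to_c_array("feature_mean", mean))
print(to_c_array("feature_scale", scale))
```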


## 8. Summary & Next Steps

### Model Performance Summary

Run the cell below to see the final summary.

### Ideas for Improvement

1. **Adjust class weights** - Prioritize Wake/Light recall if you want to catch more wake windows
2. **Add temporal context** - Use previous N epochs as additional features
3. **Tune threshold** - Instead of argmax, use probability threshold for "safe to wake"
4. **Post-processing** - Smooth predictions to avoid rapid stage switching
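
Idea 2 can be prototyped in pandas by shifting each feature within a participant, so epoch t also sees epochs t-1 through t-N. Shifting per participant (via `groupby`) keeps one person's features from leaking into the next person's rows; the first N epochs of each night get NaN lags and would need dropping or filling. This sketch assumes rows are already in time order within each participant, and the column name is illustrative. On the device, the ESP32 would need to buffer the previous N feature vectors:

```python
import pandas as pd

def add_lag_features(df: pd.DataFrame, cols, n_lags: int, group_col: str = 'participant') -> pd.DataFrame:
    """Append cols shifted by 1..n_lags epochs, computed within each participant."""
    out = df.copy()
    grouped = out.groupby(group_col)[cols]
    for lag in range(1, n_lags + 1):
        shifted = grouped.shift(lag)  # same index as out; NaN at the start of each group
        shifted.columns = [f"{c}_lag{lag}" for c in cols]
        out = pd.concat([out, shifted], axis=1)
    return out

df = pd.DataFrame({
    'participant': ['A', 'A', 'A', 'B', 'B'],
    'hr_mean': [60.0, 62.0, 61.0, 70.0, 72.0],
})
lagged = add_lag_features(df, ['hr_mean'], n_lags=2)
print(lagged)
```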
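
Idea 4 can be as simple as requiring the "safe" decision to hold for K consecutive epochs before the alarm fires, which suppresses one-epoch flickers at the cost of K-1 epochs of latency. A causal sketch, with `k` as a hypothetical tuning parameter:

```python
import numpy as np

def debounce_safe(safe: np.ndarray, k: int = 3) -> np.ndarray:
    """Mark epoch t as safe only if epochs t-k+1 .. t are all predicted safe.

    Uses only past and current epochs, so it could run in real time.
    """
    if k < 1:
        raise ValueError("k must be >= 1")
    safe = np.asarray(safe, dtype=int)
    out = np.zeros(len(safe), dtype=int)
    if len(safe) >= k:
        csum = np.concatenate(([0], np.cumsum(safe)))
        window = csum[k:] - csum[:-k]  # sum of safe[t-k+1 .. t] for t = k-1 .. n-1
        out[k - 1:] = (window == k).astype(int)
    return out

safe = np.array([1, 0, 1, 1, 1, 0, 1, 1, 1, 1])
print(debounce_safe(safe, k=3))
```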


In [None]:
print("=" * 60)
print("FINAL MODEL SUMMARY")
print("=" * 60)
print(f"\nArchitecture: MLP [{len(feature_cols)} → 64 → 32 → 16 → 4]")
print(f"Features: {len(feature_cols)} (IMU + PPG)")
print(f"Classes: Wake, Light, Deep, REM")
print(f"\n4-Class Performance:")
print(f"  Accuracy:   {metrics['accuracy']:.1%}")
print(f"  Macro F1:   {metrics['f1_macro']:.3f}")
print(f"  Cohen's κ:  {metrics['kappa']:.3f}")
print(f"\nSmart Alarm (Wake/Light vs Deep/REM):")
print(f"  Binary Accuracy:  {binary_accuracy:.1%}")
print(f"  Precision:        {precision:.1%}")
print(f"  Recall:           {recall:.1%}")
print(f"  False Alarm Rate: {fp / (fp + tn):.1%}")
print(f"\nOutput: {OUTPUT_DIR}")
