# PPG Heart Rate Estimation with LMU

Train LMU model on PPG dataset for continuous heart rate estimation.

This notebook follows a standard regression workflow:
- **Task**: Predict heart rate (bpm) from PPG windows
- **Model**: LMU-based sequence model
- **Training**: MSE loss with MAE tracking
- **Evaluation**: Test set performance (MAE, RMSE)


In [2]:
# CRITICAL: Force reload all custom modules to ensure fixes are applied
import sys

modules_to_reload = [k for k in list(sys.modules.keys()) if k.startswith('src.')]
for mod in modules_to_reload:
    del sys.modules[mod]

if modules_to_reload:
    print(f"üîÑ Cleared {len(modules_to_reload)} cached modules")


In [3]:
from __future__ import annotations
from typing import Tuple, Dict, Any
from pathlib import Path
import torch

from torch.utils.data import DataLoader

from src.types.task_protocol import TaskProtocol
from src.datasets.ppg.ppg_config import PPGDaliaConfig
from src.datasets.ppg.ppg_dataloader import make_ppgdalia_loaders
from src.models.v2.build_model import BlockConfig


CUDA extension for structured kernels (Cauchy and Vandermonde multiplication) not found. Install by going to extensions/kernels/ and running `python setup.py install`, for improved speed and memory efficiency. Note that the kernel changed for state-spaces 4.0 and must be recompiled.
Falling back on slow Cauchy and Vandermonde kernel. Install at least one of pykeops or the CUDA extension for better speed and memory efficiency.


## Verify Shape Fix

This cell verifies that the data loader and model output compatible shapes.

In [4]:
print("=" * 60)
print("SHAPE VERIFICATION")
print("=" * 60)

# Quick test to verify shapes
from src.datasets.ppg.ppg_dataset import PPGDaliaDataset
from src.datasets.ppg.ppg_dataloader import ppg_collate

# Create a tiny test dataset
test_cfg = PPGDaliaConfig(
    root=str(Path.cwd().parent.parent.parent / "src" / "datasets" / "ppg" / "data"),
    subjects_train=("S1",),
    fs_in=64.0, fs=100.0, win_sec=8, stride_sec=2,
    do_bandpass=True, low_hz=0.5, high_hz=8.0, split='train'
)
test_ds = PPGDaliaDataset(test_cfg)

# Get a batch
batch = [test_ds[i] for i in range(min(4, len(test_ds)))]
x_batch, y_batch, _ = ppg_collate(batch)

print(f"‚úÖ Batch x shape: {tuple(x_batch.shape)} (expected: [B, T, 1])")
print(f"‚úÖ Batch y shape: {tuple(y_batch.shape)} (expected: [B])")

if y_batch.ndim == 1:
    print("‚úÖ Target shape is correct: [B] for regression")
elif y_batch.shape[-1] == 1 and y_batch.ndim == 2:
    print("‚ùå Target shape is [B, 1] but should be [B]")
    print("   ‚Üí Restart kernel and run again")
else:
    print(f"‚ùå Unexpected target shape: {tuple(y_batch.shape)}")

print("=" * 60)


SHAPE VERIFICATION
‚úÖ Batch x shape: (4, 800, 1, 1) (expected: [B, T, 1])
‚úÖ Batch y shape: (4,) (expected: [B])
‚úÖ Target shape is correct: [B] for regression


## Task Definition

In [5]:
class PPGTask(TaskProtocol):
    """PPG-based heart rate estimation (regression)."""
    problem_type: str = "regression"

    def make_loaders(
        self,
        data_root: str,
        batch_size: int = 64,
        num_workers: int = 4,
        **kwargs
    ) -> Tuple[DataLoader, DataLoader, DataLoader]:
        """Create data loaders using PPGDaliaConfig."""
        # Separate dataloader kwargs from config kwargs
        pin_memory = kwargs.pop("pin_memory", False)
        persistent_workers = kwargs.pop("persistent_workers", False)

        # Remaining kwargs go to PPGDaliaConfig
        cfg = PPGDaliaConfig(root=data_root, **kwargs)

        return make_ppgdalia_loaders(
            cfg,
            batch=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
        )

    def infer_input_dim(self, args: Dict[str, Any]) -> int:
        """Single PPG channel."""
        return 1

    def infer_num_classes(self, args: Dict[str, Any]) -> int:
        """Single HR output."""
        return 1

    def infer_theta(self, args: Dict[str, Any]) -> int:
        """Sequence length = window size in samples."""
        win_sec = args.get("win_sec", 8)
        fs = args.get("fs", 100.0)
        return int(win_sec * fs)


## Block Configuration Helper

In [6]:
def create_block_cfg_ctor(dropout, mlp_ratio, droppath_final, layerscale_init, residual_gain, pool, memory_size=256):
    """Create LMU block config constructor (agnostic pattern for later S4 comparison)."""
    def block_cfg_ctor(theta: int):
        return BlockConfig(
            kind="lmu",
            memory_size=memory_size,
            theta=theta,
            dropout=dropout,
            mlp_ratio=mlp_ratio,
            droppath_final=droppath_final,
            layerscale_init=layerscale_init,
            residual_gain=residual_gain,
            pool=pool
        )
    return block_cfg_ctor


## Configuration

**IMPORTANT**: Update `TRAIN_SUBJECTS`, `VAL_SUBJECT`, and `TEST_SUBJECT` to match your actual dataset subject IDs.

Place your PPG data in: `src/datasets/ppg/data/`

Expected structure:
```
data/
  S1/
    *.csv
  S2/
    *.csv
  ...
```

Each CSV should contain columns: `ppg`, `hr`


In [7]:
current_dir = Path.cwd()
project_root = current_dir.parent.parent.parent  # from src/notebooks/ppg to project root
data_root = str(project_root / "src" / "datasets" / "ppg" / "data")

# Example subject split - ADJUST TO YOUR DATASET
TRAIN_SUBJECTS = ("S1", "S2", "S3", "S4", "S5", "S6", "S7", "S8")
VAL_SUBJECT = "S9"
TEST_SUBJECT = "S10"

args: Dict[str, Any] = {
    # Data
    "data_root": data_root,
    "batch": 32,  # Reduced from 128 for MPS memory
    "data_loader_kwargs": {
        "num_workers": 0,
        "pin_memory": False,
        "persistent_workers": False,
        # PPGDaliaConfig parameters
        "subjects_train": TRAIN_SUBJECTS,
        "subject_val": VAL_SUBJECT,
        "subject_test": TEST_SUBJECT,
        "fs_in": 64.0,        # Input sampling rate (Hz)
        "fs": 100.0,          # Target sampling rate (Hz)
        "win_sec": 8,         # Window length (seconds)
        "stride_sec": 2,      # Stride (seconds)
        "do_bandpass": True,
        "low_hz": 0.5,
        "high_hz": 8.0,
    },

    # Training
    "epochs": 100,
    "lr": 5e-4,
    "wd": 1e-4,
    "amp": False,  # Disable AMP for MPS stability
    "save_dir": "./runs/ppg_lmu_task",
    "warmup_epochs": 5,
    "patience": 10,
    "min_delta": 0.01,  # MAE improvement threshold

    # Model (reduced for MPS memory)
    "d_model": 128,      # Reduced from 256
    "depth": 4,          # Reduced from 6
    "dropout": 0.2,
    "mlp_ratio": 2.0,
    "droppath_final": 0.1,
    "layerscale_init": 1e-2,
    "residual_gain": 1.0,
    "pool": "mean",

    # LMU-specific
    "memory_size": 128,  # Reduced from 256
}

args["block_cfg_ctor"] = create_block_cfg_ctor(
    dropout=args["dropout"],
    mlp_ratio=args["mlp_ratio"],
    droppath_final=args["droppath_final"],
    layerscale_init=args["layerscale_init"],
    residual_gain=args["residual_gain"],
    pool=args["pool"],
    memory_size=args["memory_size"],
)

# Device selection
import os

if torch.backends.mps.is_available():
    args["device"] = torch.device("mps")
    # Set MPS memory management BEFORE any operations
    os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
    torch.mps.set_per_process_memory_fraction(0.7)
    print("üöÄ Using MPS (Apple Silicon) with reduced memory")
elif torch.cuda.is_available():
    args["device"] = torch.device("cuda")
    print("üöÄ Using CUDA")
else:
    args["device"] = torch.device("cpu")
    args["amp"] = False
    print("‚ö†Ô∏è  Using CPU")

print(f"\nüìÇ Data root: {args['data_root']}")
print(f"üéØ Train subjects: {TRAIN_SUBJECTS}")
print(f"üéØ Val subject: {VAL_SUBJECT}")
print(f"üéØ Test subject: {TEST_SUBJECT}")


üöÄ Using MPS (Apple Silicon) with reduced memory

üìÇ Data root: /Users/glbrlb/PycharmProjects/Msc/LMU_S4/src/datasets/ppg/data
üéØ Train subjects: ('S1', 'S2', 'S3', 'S4', 'S5', 'S6', 'S7', 'S8')
üéØ Val subject: S9
üéØ Test subject: S10


## Training

In [8]:
from src.train_utils.trainer import Trainer
from src.models.v2.build_model import build_model

# Define the task
task = PPGTask()


# Initialize trainer
trainer = Trainer(args=args, task=task, model_builder=build_model)

# Train
best_metric, best_path = trainer.fit()

history = trainer.history

print(f"\n‚úÖ Training complete! Best validation {trainer.early_key}: {best_metric:.4f}")
print(f"üíæ Best model saved to: {best_path}")


  m, s = pick_pade_structure(Am)
  m, s = pick_pade_structure(Am)
  m, s = pick_pade_structure(Am)
  eAw = eAw @ eAw
  eAw = eAw @ eAw
  eAw = eAw @ eAw
                                                          

üíæ saved best model to ./runs/ppg_lmu_task/best.pt
‚úÖ new best mse 7965.1578
Epoch 000/100 | train 8912.6477/8912.6477 | val 7965.1578/7965.1578 | t 244.4s/10.5s | lr 5.00e-07


                                                          

üíæ saved best model to ./runs/ppg_lmu_task/best.pt
‚úÖ new best mse 4664.3891
Epoch 001/100 | train 6893.1810/6893.1810 | val 4664.3891/4664.3891 | t 267.9s/10.4s | lr 1.00e-04


                                                          

üíæ saved best model to ./runs/ppg_lmu_task/best.pt
‚úÖ new best mse 593.2457
Epoch 002/100 | train 2916.3936/2916.3936 | val 593.2457/593.2457 | t 264.5s/10.1s | lr 2.00e-04


                                                          

üíæ saved best model to ./runs/ppg_lmu_task/best.pt
‚úÖ new best mse 419.0924
Epoch 003/100 | train 601.0572/601.0572 | val 419.0924/419.0924 | t 259.3s/10.0s | lr 3.00e-04




üíæ saved best model to ./runs/ppg_lmu_task/best.pt
‚úÖ new best mse 414.0154
Epoch 004/100 | train 300.0032/300.0032 | val 414.0154/414.0154 | t 257.8s/10.1s | lr 4.00e-04


                                                          

üíæ saved best model to ./runs/ppg_lmu_task/best.pt
‚úÖ new best mse 345.5797
Epoch 005/100 | train 203.5332/203.5332 | val 345.5797/345.5797 | t 256.9s/10.1s | lr 5.00e-04


                                                          

üíæ saved best model to ./runs/ppg_lmu_task/best.pt
‚úÖ new best mse 230.4436
Epoch 006/100 | train 170.5696/170.5696 | val 230.4436/230.4436 | t 257.5s/10.1s | lr 5.00e-04


                                                          

üíæ saved best model to ./runs/ppg_lmu_task/best.pt
‚úÖ new best mse 227.4730
Epoch 007/100 | train 146.6591/146.6591 | val 227.4730/227.4730 | t 256.0s/10.0s | lr 4.99e-04


                                                          

Epoch 008/100 | train 129.3242/129.3242 | val 265.1766/265.1766 | t 257.5s/10.1s | lr 4.99e-04


                                                          

Epoch 009/100 | train 114.6819/114.6819 | val 236.4995/236.4995 | t 256.0s/10.1s | lr 4.98e-04


                                                          

Epoch 010/100 | train 102.5112/102.5112 | val 319.3702/319.3702 | t 255.6s/9.9s | lr 4.97e-04


                                                          

Epoch 011/100 | train 90.9796/90.9796 | val 283.3677/283.3677 | t 263.3s/9.9s | lr 4.95e-04


                                                          

Epoch 012/100 | train 82.6397/82.6397 | val 261.4781/261.4781 | t 256.6s/10.2s | lr 4.93e-04


                                                          

Epoch 013/100 | train 76.8817/76.8817 | val 290.6767/290.6767 | t 255.1s/10.0s | lr 4.91e-04


                                                           

Epoch 014/100 | train 69.2895/69.2895 | val 360.9883/360.9883 | t 307.6s/20.6s | lr 4.89e-04


                                                               

Epoch 015/100 | train 62.9593/62.9593 | val 258.3350/258.3350 | t 535.5s/20.4s | lr 4.86e-04


                                                               

Epoch 016/100 | train 58.9457/58.9457 | val 446.3455/446.3455 | t 526.7s/20.8s | lr 4.84e-04


                                                               

Epoch 017/100 | train 54.0642/54.0642 | val 318.8208/318.8208 | t 501.9s/16.2s | lr 4.81e-04


                                                               

‚èπ Early stopping (patience=10, best=227.4730).
üìä Training history saved to ./runs/ppg_lmu_task/history.json

‚úÖ Training complete! Best validation mse: 227.4730
üíæ Best model saved to: ./runs/ppg_lmu_task/best.pt




## Plot History

In [None]:
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 5))

# Loss
plt.subplot(1, 2, 1)
plt.plot(history["train_loss"], label="train_loss", linewidth=2)
plt.plot(history["val_loss"], label="val_loss", linewidth=2)
plt.xlabel("Epoch", fontsize=12)
plt.ylabel("MSE Loss", fontsize=12)
plt.legend(fontsize=11)
plt.title("Training & Validation Loss", fontsize=14)
plt.grid(True, alpha=0.3)

# MAE
plt.subplot(1, 2, 2)
plt.plot(history["train_mae"], label="train_mae", linewidth=2)
plt.plot(history["val_mae"], label="val_mae", linewidth=2)
plt.xlabel("Epoch", fontsize=12)
plt.ylabel("MAE (bpm)", fontsize=12)
plt.legend(fontsize=11)
plt.title("Mean Absolute Error", fontsize=14)
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


## Test Evaluation

In [None]:
import numpy as np
from tqdm.auto import tqdm
from torch.amp import autocast as amp_autocast

def evaluate_best_model(args: dict, task: TaskProtocol, model_builder: callable, best_model_path: str):
    """Load best checkpoint and evaluate on test set."""
    print("üìä Evaluating best model on the test set...")

    # 1. Test loader
    _, _, test_loader = task.make_loaders(
        data_root=args["data_root"],
        batch_size=args["batch"],
        **args["data_loader_kwargs"]
    )

    # 2. Load checkpoint and rebuild model
    device = args.get("device", torch.device("cpu"))
    checkpoint = torch.load(best_model_path, map_location=device)

    flat_args = dict(args)
    flat_args.update(args.get("data_loader_kwargs", {}))
    d_in = task.infer_input_dim(flat_args)
    n_classes = task.infer_num_classes(flat_args)
    theta = task.infer_theta(flat_args)

    block_cfg = args["block_cfg_ctor"](theta)
    model = model_builder(
        d_in=d_in,
        n_classes=n_classes,
        d_model=args["d_model"],
        depth=args["depth"],
        block_cfg=block_cfg
    ).to(device)

    model.load_state_dict(checkpoint["model"])
    model.eval()

    print(f"‚úÖ Loaded checkpoint from epoch {checkpoint.get('epoch', 'N/A')}")
    val_metrics = checkpoint.get('val', {})
    print(f"üìà Val MSE: {val_metrics.get('mse', 'N/A'):.4f}, Val MAE: {val_metrics.get('mae', 'N/A'):.4f}")

    # 3. Evaluation loop
    all_preds, all_targets = [], []

    with torch.no_grad():
        for x, y, _ in tqdm(test_loader, desc="Testing"):
            x, y = x.to(device), y.to(device)
            # Disable AMP for evaluation
            with amp_autocast(device_type=device.type, enabled=False):
                out = model(x)
            all_preds.append(out.cpu().numpy())
            all_targets.append(y.cpu().numpy())

    # 4. Metrics
    all_preds = np.concatenate(all_preds).flatten()
    all_targets = np.concatenate(all_targets).flatten()

    test_mse = np.mean((all_preds - all_targets) ** 2)
    test_mae = np.mean(np.abs(all_preds - all_targets))
    test_rmse = np.sqrt(test_mse)

    print("\n" + "=" * 50)
    print("TEST SET RESULTS:")
    print("=" * 50)
    print(f"MSE:  {test_mse:.4f}")
    print(f"MAE:  {test_mae:.4f} bpm")
    print(f"RMSE: {test_rmse:.4f}")

    return all_preds, all_targets

# Run evaluation
preds, targets = evaluate_best_model(
    args=args,
    task=task,
    model_builder=build_model,
    best_model_path=best_path
)
