In [4]:
#!/usr/bin/env python3
"""
PPG-based Blood Pressure Model Training Script
Usage: python train_ppg_model.py --person <name> --dir <path>
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import find_peaks
import pickle
from sklearn.linear_model import LinearRegression
import os
import argparse
import json
from dataclasses import dataclass, asdict
from typing import Dict, List, Tuple, Optional


@dataclass
class PersonModel:
    """Data class to store person-specific model parameters"""

    name: str
    baseline_hr: float
    baseline_mnpv: float
    baseline_sbp: float
    baseline_dbp: float
    baseline_map: float
    sbp_coeffs: Tuple[float, float, float]  # (a, b, c)
    dbp_coeffs: Tuple[float, float, float]
    map_coeffs: Tuple[float, float, float]
    r2_scores: Dict[str, float]  # R² scores for each model


class PPGBloodPressureModel:
    def __init__(self, person_name: str, data_dir: str):
        self.person_name = person_name
        self.data_dir = data_dir
        self.baseline_file = os.path.join(data_dir, f"{person_name}_baseline.csv")
        self.model_file = os.path.join(data_dir, f"{person_name}_model.pkl")

        # Model parameters
        self.baseline_hr = None
        self.baseline_mnpv = None
        self.baseline_sbp = None
        self.baseline_dbp = None
        self.baseline_map = None

        # Regression models
        self.sbp_model = LinearRegression()
        self.dbp_model = LinearRegression()
        self.map_model = LinearRegression()

    def load_ppg_data(self, file_path: str) -> np.ndarray:
        """Load PPG data from CSV file"""
        time_vals = []
        signal_vals = []
        ground_truth = None

        with open(file_path, "r") as f:
            for i, line in enumerate(f):
                parts = line.strip().split(",")
                if len(parts) < 2:
                    continue

                # First line contains ground truth values
                if i == 0 and len(parts) >= 5:
                    try:
                        _, _, sbp, dbp, hr = parts[:5]
                        ground_truth = {
                            "sbp": float(sbp),
                            "dbp": float(dbp),
                            "hr": float(hr),
                        }
                    except (ValueError, IndexError):
                        pass

                try:
                    t = float(parts[0])
                    s = float(parts[1])
                    time_vals.append(t * 1e-3)  # convert ms to s
                    signal_vals.append(s)
                except ValueError:
                    continue

        data = np.column_stack((time_vals, signal_vals))
        return data, ground_truth

    def calculate_ppg_features(self, time: np.ndarray, signal: np.ndarray) -> Dict:
        """Calculate PPG features from time and signal arrays"""
        # Sampling frequency
        fs = 1.0 / np.mean(np.diff(time))

        # DC level
        dc = np.mean(signal)

        # Find peaks and troughs
        peaks, _ = find_peaks(signal, prominence=4, distance=fs * 0.3)
        troughs, _ = find_peaks(-signal, prominence=4, distance=fs * 0.3)

        # AC amplitude
        n = min(len(peaks), len(troughs))
        if n == 0:
            ac = np.nan
        else:
            ac_vals = signal[peaks[:n]] - signal[troughs[:n]]
            ac = np.mean(np.abs(ac_vals))

        # Heart rate
        if len(peaks) > 1:
            intervals = np.diff(time[peaks])
            hr = 60.0 / np.mean(intervals)
        else:
            hr = np.nan

        # mNPV
        mnpv = ac / dc if dc != 0 else np.nan

        return {
            "hr": hr,
            "ac": ac,
            "dc": dc,
            "mnpv": mnpv,
            "peaks": peaks,
            "troughs": troughs,
        }

    def load_baseline(self) -> None:
        """Load baseline values from baseline CSV file"""
        if not os.path.exists(self.baseline_file):
            raise FileNotFoundError(f"Baseline file not found: {self.baseline_file}")

        ppg_data, ground_truth = self.load_ppg_data(self.baseline_file)
        time, signal = ppg_data[:, 0], ppg_data[:, 1]

        features = self.calculate_ppg_features(time, signal)

        self.baseline_hr = features["hr"]
        self.baseline_mnpv = features["mnpv"]
        self.baseline_sbp = ground_truth["sbp"]
        self.baseline_dbp = ground_truth["dbp"]
        self.baseline_map = self.baseline_dbp + (1 / 3) * (
            self.baseline_sbp - self.baseline_dbp
        )

        print(f"Baseline loaded for {self.person_name}:")
        print(f"  HR: {self.baseline_hr:.1f} bpm")
        print(f"  mNPV: {self.baseline_mnpv:.6f}")
        print(f"  BP: {self.baseline_sbp:.0f}/{self.baseline_dbp:.0f} mmHg")
        print(f"  MAP: {self.baseline_map:.1f} mmHg")

    def calculate_deltas(
        self,
        current_hr: float,
        current_mnpv: float,
        current_sbp: float,
        current_dbp: float,
    ) -> Tuple[float, ...]:
        """Calculate delta values from baseline"""
        current_map = current_dbp + (1 / 3) * (current_sbp - current_dbp)

        delta_ln_hr = np.log(current_hr) - np.log(self.baseline_hr)
        delta_ln_mnpv = np.log(current_mnpv) - np.log(self.baseline_mnpv)
        delta_ln_sbp = np.log(current_sbp) - np.log(self.baseline_sbp)
        delta_ln_dbp = np.log(current_dbp) - np.log(self.baseline_dbp)
        delta_ln_map = np.log(current_map) - np.log(self.baseline_map)

        return delta_ln_hr, delta_ln_mnpv, delta_ln_sbp, delta_ln_dbp, delta_ln_map

    def process_training_data(self) -> Tuple[np.ndarray, Dict[str, np.ndarray]]:
        """Process all training CSV files in the directory"""
        delta_data = {"hr": [], "mnpv": [], "sbp": [], "dbp": [], "map": []}

        # Get all CSV files except baseline
        csv_files = [
            f
            for f in os.listdir(self.data_dir)
            if f.endswith(".csv") and f != f"{self.person_name}_baseline.csv"
        ]

        print(f"\nProcessing {len(csv_files)} training files...")

        for i, csv_file in enumerate(csv_files):
            file_path = os.path.join(self.data_dir, csv_file)
            ppg_data, ground_truth = self.load_ppg_data(file_path)

            if ground_truth is None:
                print(f"  Skipping {csv_file}: No ground truth data")
                continue

            time, signal = ppg_data[:, 0], ppg_data[:, 1]
            features = self.calculate_ppg_features(time, signal)

            if np.isnan(features["hr"]) or np.isnan(features["mnpv"]):
                print(f"  Skipping {csv_file}: Invalid features")
                continue

            # Calculate deltas
            deltas = self.calculate_deltas(
                features["hr"],
                features["mnpv"],
                ground_truth["sbp"],
                ground_truth["dbp"],
            )

            delta_data["hr"].append(deltas[0])
            delta_data["mnpv"].append(deltas[1])
            delta_data["sbp"].append(deltas[2])
            delta_data["dbp"].append(deltas[3])
            delta_data["map"].append(deltas[4])

            print(f"  Processed {csv_file} ({i+1}/{len(csv_files)})")

        # Convert to numpy arrays
        for key in delta_data:
            delta_data[key] = np.array(delta_data[key])

        X = np.column_stack((delta_data["hr"], delta_data["mnpv"]))

        return X, delta_data

    def train_models(self) -> None:
        """Train regression models on the processed data"""
        X, delta_data = self.process_training_data()

        if len(X) == 0:
            raise ValueError("No valid training data found")

        print(f"\nTraining models with {len(X)} data points...")

        # Train SBP model
        self.sbp_model.fit(X, delta_data["sbp"])
        sbp_r2 = self.sbp_model.score(X, delta_data["sbp"])

        # Train DBP model
        self.dbp_model.fit(X, delta_data["dbp"])
        dbp_r2 = self.dbp_model.score(X, delta_data["dbp"])

        # Train MAP model
        self.map_model.fit(X, delta_data["map"])
        map_r2 = self.map_model.score(X, delta_data["map"])

        print("\nModel training complete:")
        print(f"  SBP model R²: {sbp_r2:.3f}")
        print(f"  DBP model R²: {dbp_r2:.3f}")
        print(f"  MAP model R²: {map_r2:.3f}")

        return {"sbp": sbp_r2, "dbp": dbp_r2, "map": map_r2}

    def save_model(self) -> None:
        """Save trained model to file"""
        model_data = PersonModel(
            name=self.person_name,
            baseline_hr=self.baseline_hr,
            baseline_mnpv=self.baseline_mnpv,
            baseline_sbp=self.baseline_sbp,
            baseline_dbp=self.baseline_dbp,
            baseline_map=self.baseline_map,
            sbp_coeffs=(
                self.sbp_model.coef_[0],
                self.sbp_model.coef_[1],
                self.sbp_model.intercept_,
            ),
            dbp_coeffs=(
                self.dbp_model.coef_[0],
                self.dbp_model.coef_[1],
                self.dbp_model.intercept_,
            ),
            map_coeffs=(
                self.map_model.coef_[0],
                self.map_model.coef_[1],
                self.map_model.intercept_,
            ),
            r2_scores=self.train_models(),
        )

        with open(self.model_file, "wb") as f:
            pickle.dump(model_data, f)

        # Also save as JSON for readability
        json_file = self.model_file.replace(".pkl", ".json")
        with open(json_file, "w") as f:
            json.dump(asdict(model_data), f, indent=2)

        print(f"\nModel saved to {self.model_file}")
        print(f"JSON copy saved to {json_file}")

    def load_model(self) -> PersonModel:
        """Load trained model from file"""
        if not os.path.exists(self.model_file):
            raise FileNotFoundError(f"Model file not found: {self.model_file}")

        with open(self.model_file, "rb") as f:
            model_data = pickle.load(f)

        # Update instance variables
        self.baseline_hr = model_data.baseline_hr
        self.baseline_mnpv = model_data.baseline_mnpv
        self.baseline_sbp = model_data.baseline_sbp
        self.baseline_dbp = model_data.baseline_dbp
        self.baseline_map = model_data.baseline_map

        # Update regression models
        self.sbp_model.coef_ = np.array(model_data.sbp_coeffs[:2])
        self.sbp_model.intercept_ = model_data.sbp_coeffs[2]

        self.dbp_model.coef_ = np.array(model_data.dbp_coeffs[:2])
        self.dbp_model.intercept_ = model_data.dbp_coeffs[2]

        self.map_model.coef_ = np.array(model_data.map_coeffs[:2])
        self.map_model.intercept_ = model_data.map_coeffs[2]

        print(f"Model loaded for {self.person_name}")
        return model_data

    def estimate_blood_pressure(
        self, hr: float, mnpv: float
    ) -> Tuple[float, float, float]:
        """Estimate blood pressure from HR and mNPV values"""
        delta_ln_hr = np.log(hr) - np.log(self.baseline_hr)
        delta_ln_mnpv = np.log(mnpv) - np.log(self.baseline_mnpv)

        X = np.array([[delta_ln_hr, delta_ln_mnpv]])

        est_delta_ln_sbp = self.sbp_model.predict(X)[0]
        est_delta_ln_dbp = self.dbp_model.predict(X)[0]
        est_delta_ln_map = self.map_model.predict(X)[0]

        est_sbp = self.baseline_sbp * np.exp(est_delta_ln_sbp)
        est_dbp = self.baseline_dbp * np.exp(est_delta_ln_dbp)
        est_map = self.baseline_map * np.exp(est_delta_ln_map)

        return est_sbp, est_dbp, est_map


def main(person_name, person_folder):
    model = PPGBloodPressureModel(person_name, person_folder)

    model.load_baseline()
    model.train_models()
    model.save_model()

    print("\nModel ready for inference!")

    # Example usage
    example_hr = 80
    example_mnpv = 0.015
    est_sbp, est_dbp, est_map = model.estimate_blood_pressure(example_hr, example_mnpv)
    print(f"\nExample prediction for HR={example_hr:.1f}, mNPV={example_mnpv:.4f}:")
    print(f"  Estimated BP: {est_sbp:.1f}/{est_dbp:.1f} mmHg, MAP: {est_map:.1f} mmHg")


if __name__ == "__main__":
    main("Bryan", "/Users/bryanjangeesingh/Documents/6.1820/KardIoT/data/Bryan")

Baseline loaded for Bryan:
  HR: 86.2 bpm
  mNPV: 0.012308
  BP: 120/62 mmHg
  MAP: 81.3 mmHg

Processing 13 training files...
  Processed bryan_baseline.csv (1/13)
  Processed 30_apr_bryan_6.csv (2/13)
  Processed 30_apr_bryan_7.csv (3/13)
  Processed 30_apr_bryan_5.csv (4/13)
  Processed 30_apr_bryan_4.csv (5/13)
  Processed 30_apr_bryan_1.csv (6/13)
  Processed 30_apr_bryan_3.csv (7/13)
  Skipping 30_apr_bryan_2.csv: No ground truth data
  Skipping test8_bryan_rest.csv: No ground truth data
  Processed 23apr_bryan_5.csv (10/13)
  Processed 23apr_bryan_4.csv (11/13)
  Processed 23apr_bryan_3.csv (12/13)
  Processed 23apr_bryan_2.csv (13/13)

Training models with 11 data points...

Model training complete:
  SBP model R²: 0.320
  DBP model R²: 0.373
  MAP model R²: 0.137

Processing 13 training files...
  Processed bryan_baseline.csv (1/13)
  Processed 30_apr_bryan_6.csv (2/13)
  Processed 30_apr_bryan_7.csv (3/13)
  Processed 30_apr_bryan_5.csv (4/13)
  Processed 30_apr_bryan_4.csv (