In [61]:
"""wavenet_hrv_model.py

WaveNet‑style 1‑D CNN for HRV‑based SCD vs NSR classification
=============================================================

**Update 2025‑04‑18** – The input now contains **8 HRV features** per 5‑minute
segment (the target label is *not* part of the network input).  The default
`input_length` argument is therefore changed from 9 to 8.

The network structure remains the same:
  • Input (8) → Reshape (8, 1)
  • Dilated causal Conv1D stack with gated activation & residual / skip
    connections (dilation rates 1, 2, 4, 8, 16, 32, 64; 64 filters; kernel 2)
  • TimeDistributed Dense (64, ReLU)
  • Skip‑sum → ReLU → Flatten → Dropout(0.5) → Dense(512, ReLU) → Dropout(0.5)
  • Soft‑max output (2 classes) or Sigmoid (binary‑logit)

Use it like this:
>>> from wavenet_hrv_model import build_wavenet_hrv
>>> model = build_wavenet_hrv(input_length=8, n_classes=2)
>>> model.fit(X_train, y_train, epochs=50, batch_size=32,
...           validation_data=(X_val, y_val))
"""

import tensorflow as tf
from tensorflow.keras import layers, Model, Input


def build_wavenet_hrv(input_length: int = 8,
                      n_classes: int = 2,
                      n_filters: int = 64,
                      kernel_size: int = 2,
                      dilation_rates: list | None = None,
                      dropout_rate: float = 0.5,
                      dense_units: int = 512,
                      learning_rate: float = 1e-3) -> Model:
    """Create and compile the WaveNet HRV classifier.

    Parameters
    ----------
    input_length : int, default 8
        Number of HRV features per sample (MeanRR, RMSDD, pNN50, SDRR, CVRR,
        NN50, MinRR, MaxRR).
    n_classes : int, default 2
        Output classes (binary SCD vs NSR).
    n_filters : int, default 64
        Filters in each dilated Conv1D layer.
    kernel_size : int, default 2
        Convolution kernel size.
    dilation_rates : list[int] | None
        Dilation rates; default powers‑of‑two up to 64.
    dropout_rate : float, default 0.5
        Dropout probability after flatten and dense.
    dense_units : int, default 512
        Units in the penultimate dense layer.
    learning_rate : float, default 1e‑3
        Adam optimiser learning rate.
    """

    if dilation_rates is None:
        dilation_rates = [1, 2, 4, 8, 16, 32, 64]

    # Input → Reshape (batch, 8, 1)
    inp = Input(shape=(input_length,), name="Input")
    x = layers.Reshape((input_length, 1), name="Reshape_to_1D")(inp)

    skips = []

    for d in dilation_rates:
        tanh_out = layers.Conv1D(n_filters, kernel_size, padding="causal",
                                 dilation_rate=d, activation="tanh",
                                 name=f"Tanh_d{d}")(x)
        sig_out = layers.Conv1D(n_filters, kernel_size, padding="causal",
                                dilation_rate=d, activation="sigmoid",
                                name=f"Sig_d{d}")(x)
        gated = layers.Multiply(name=f"Gated_mul_d{d}")([tanh_out, sig_out])
        td = layers.TimeDistributed(layers.Dense(n_filters, activation="relu"),
                                    name=f"TD_Dense_d{d}")(gated)
        skips.append(td)
        x = layers.Add(name=f"Residual_add_d{d}")([x, gated])

    x = layers.Add(name="Skip_sum")(skips) if len(skips) > 1 else skips[0]
    x = layers.Activation("relu", name="Skip_relu")(x)

    x = layers.Flatten(name="Flatten")(x)
    x = layers.Dropout(dropout_rate, name="Dropout1")(x)
    x = layers.Dense(dense_units, activation="relu", name="Dense512")(x)
    x = layers.Dropout(dropout_rate, name="Dropout2")(x)

    act = "sigmoid" if n_classes == 1 else "softmax"
    out_units = 1 if n_classes == 1 else n_classes
    output = layers.Dense(out_units, activation=act, name="Output")(x)

    model = Model(inp, output, name="WaveNet_HRV")

    loss = "binary_crossentropy" if n_classes == 1 else "categorical_crossentropy"
    model.compile(tf.keras.optimizers.Adam(learning_rate), loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    return model

In [62]:
model = build_wavenet_hrv(input_length=8, n_classes=2, dilation_rates=[64])

In [63]:
model.summary()

In [64]:
def load_feature_csvs(feature_dir):
    """Scan feature_dir for all CSV files, build a DataFrame with label."""
    rows = []
    for csv_path in glob.glob(os.path.join(feature_dir, '*.csv')):
        df = pd.read_csv(csv_path)
        # Determine label from filename
        fname = os.path.basename(csv_path)
        if 'First' in fname:
            label = 1 if 'First' in fname and fname.startswith('SCD') else 0  # SCD positive class
            df['Label'] = label
            rows.append(df)
            
    if not rows:
        raise ValueError(f"No CSV files found in {feature_dir}")
    full_df = pd.concat(rows, ignore_index=True)
    return full_df

In [65]:
import os, glob
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


feature_dir_scd = 'SCD_Features_CSV_17apr'  # adjust if different
feature_dir_nsr = 'NSR_Features_CSV_17apr'

# Combine both directories
feature_df = pd.concat([
    load_feature_csvs(feature_dir_scd),
    load_feature_csvs(feature_dir_nsr)
], ignore_index=True)

In [66]:
feature_df.shape

(36, 9)

In [67]:
feature_df.head()

Unnamed: 0,MeanRR,RMSDD,pNN50,SDRR,CVRR,NN50,MinRR,MaxRR,Label
0,0.676552,0.532766,0.091591,0.359284,0.053105,0.403,0.257812,2.328125,1
1,0.677597,0.749536,0.091364,0.501929,0.074075,0.402,0.257812,6.898438,1
2,0.591739,0.535856,0.081944,0.32909,0.055614,0.413,0.257812,1.890625,1
3,0.780396,0.590277,0.092167,0.412564,0.052866,0.353,0.28125,2.109375,1
4,0.708481,0.54593,0.092891,0.383095,0.054073,0.392,0.257812,3.664062,1


In [68]:
feature_cols = [
    "MeanRR", "RMSDD", "pNN50", "SDRR",
    "CVRR", "NN50", "MinRR", "MaxRR"
]

X = feature_df[feature_cols].values.astype("float32")
y = feature_df["Label"].values.astype("int")  


scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

In [69]:
X_train = X_scaled
y_train = y

In [70]:
y_train

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [72]:
history = model.fit(X_train, y_train, epochs=50, batch_size=32)

Epoch 1/50
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 210ms/step - accuracy: 0.9711 - loss: 0.0753
Epoch 2/50
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 84ms/step - accuracy: 0.9711 - loss: 0.0619 
Epoch 3/50
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 83ms/step - accuracy: 0.9711 - loss: 0.0518 
Epoch 4/50
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 67ms/step - accuracy: 0.9711 - loss: 0.0620 
Epoch 5/50
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 106ms/step - accuracy: 0.9421 - loss: 0.0827
Epoch 6/50
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 103ms/step - accuracy: 1.0000 - loss: 0.0312
Epoch 7/50
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 84ms/step - accuracy: 1.0000 - loss: 0.0256 
Epoch 8/50
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 94ms/step - accuracy: 1.0000 - loss: 0.0255 
Epoch 9/50
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37

In [74]:
history.__dict__

{'params': {'verbose': 'auto', 'epochs': 50, 'steps': 2},
 '_model': <Functional name=WaveNet_HRV, built=True>,
 'history': {'accuracy': [0.9722222089767456,
   0.9722222089767456,
   0.9722222089767456,
   0.9722222089767456,
   0.9444444179534912,
   1.0,
   1.0,
   1.0,
   0.9722222089767456,
   0.9444444179534912,
   0.9722222089767456,
   1.0,
   1.0,
   1.0,
   0.9722222089767456,
   1.0,
   0.9722222089767456,
   0.9722222089767456,
   1.0,
   1.0,
   0.9722222089767456,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   0.9722222089767456,
   1.0,
   1.0,
   1.0,
   0.9722222089767456,
   1.0,
   0.9722222089767456,
   0.9722222089767456,
   0.9722222089767456,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   1.0,
   0.9722222089767456,
   1.0,
   1.0],
  'loss': [0.07866330444812775,
   0.06344245374202728,
   0.05071103200316429,
   0.060455866158008575,
   0.08153235912322998,
   0.03136797621846199,
   0.026886122301220894,
   0.02576674148