In [17]:
import pickle
import numpy as np
from scipy import interpolate
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import os
from typing import Dict, Tuple, List, Optional
import warnings
warnings.filterwarnings('ignore')


In [18]:
data_dir = 'data/WESAD/'

subjects = []

for subject in os.listdir(data_dir):
    if subject.startswith('S'):
        subjects.append(subject)
        
subjects = sorted(subjects, key=lambda x: int(x[1:]))
print(subjects)


['S2', 'S3', 'S4', 'S5', 'S6', 'S7', 'S8', 'S9', 'S10', 'S11', 'S13', 'S14', 'S15', 'S16', 'S17']


In [19]:
def load_subject_data(data_dir, subject_id: str) -> Dict:
    """
    Load data for a specific subject
    
    Args:
        subject_id: Subject ID (e.g., 'S2', 'S3', etc.)
        
    Returns:
        Dictionary containing the loaded data
    """
    pkl_path = os.path.join(data_dir, subject_id, f"{subject_id}.pkl")
    
    if not os.path.exists(pkl_path):
        raise FileNotFoundError(f"Data file not found: {pkl_path}")
        
    with open(pkl_path, 'rb') as f:
        data = pickle.load(f, encoding='latin1')
        
    return data

subject_data = load_subject_data(data_dir, 'S2')

In [20]:
print(subject_data.keys())  # ['signal', 'label', 'subject'], label values are 0, 1, 2, 3, 4, 5, 6, 7
#  0 = not defined / transient, 1 = baseline, 2 = stress, 3 = amusement,
# 4 = meditation, 5/6/7 = should be ignored in this dataset
print(subject_data['signal'].keys())  # ['chest', 'wrist']
print(subject_data['signal']['chest'].keys())  # ['ACC', 'ECG', 'EDA', 'EMG', 'Temp', 'Resp']
print(subject_data['signal']['wrist'].keys())  # ['ACC', 'BVP', 'EDA', 'TEMP']

dict_keys(['signal', 'label', 'subject'])
dict_keys(['chest', 'wrist'])
dict_keys(['ACC', 'ECG', 'EMG', 'EDA', 'Temp', 'Resp'])
dict_keys(['ACC', 'BVP', 'EDA', 'TEMP'])


In [21]:
def extract_chest_data(subject_data):
    chest_dict = subject_data['signal']['chest']
    labels = subject_data['label']
    
    # Combine all chest sensor signals into one array
    chest_signals = []
    for sensor_name, sensor_data in chest_dict.items():
        if len(sensor_data.shape) == 1:
            chest_signals.append(sensor_data.reshape(-1, 1))
        else:
            chest_signals.append(sensor_data)
    
    # Concatenate all signals horizontally
    combined_signals = np.concatenate(chest_signals, axis=1)
    
    return combined_signals, labels

# Example usage
combined_signals, labels = extract_chest_data(subject_data)
combined_signals.shape

(4255300, 8)

In [22]:
def extract_wrist_data(subject_data):
    wrist_dict = subject_data['signal']['wrist']
    labels = subject_data['label']
    
    # Step 1: Find the wrist sensor with the longest time axis (usually ACC)
    max_len = 0
    for sensor_data in wrist_dict.values():
        max_len = max(max_len, len(sensor_data))

    # Step 2: Resample each wrist sensor to the reference length (max_len)
    wrist_signals = []
    for sensor_name, sensor_data in wrist_dict.items():
        # Ensure shape is (T, D)
        if len(sensor_data.shape) == 1:
            sensor_data = sensor_data.reshape(-1, 1)
        T, D = sensor_data.shape
        resampled = np.zeros((max_len, D))
        for d in range(D):
            f = interpolate.interp1d(np.linspace(0, 1, T), sensor_data[:, d], kind='linear', fill_value="extrapolate")
            resampled[:, d] = f(np.linspace(0, 1, max_len))
        wrist_signals.append(resampled)

    # Step 3: Concatenate all wrist signals horizontally → shape: (max_len, num_features)
    combined_signals = np.concatenate(wrist_signals, axis=1)

    # Step 4: Resample labels to match wrist data length
    label_interp = interpolate.interp1d(
        np.linspace(0, 1, len(labels)), labels, kind='nearest', fill_value="extrapolate"
    )
    resampled_labels = label_interp(np.linspace(0, 1, max_len)).astype(int)

    return combined_signals, resampled_labels

# Example usage
combined_signals, labels = extract_wrist_data(subject_data)
combined_signals.shape

(389056, 6)

In [23]:
all_chest_X, all_chest_y = [], []
all_wrist_X, all_wrist_y = [], []

for sid in subjects:
    print(f"Processing {sid}...")
    subject_data = load_subject_data(data_dir, sid)
    
    chest_X, chest_y = extract_chest_data(subject_data)
    wrist_X, wrist_y = extract_wrist_data(subject_data)

    # Filter invalid labels (0, 5, 6, 7)
    valid_mask_chest = np.isin(chest_y, [1, 2, 3, 4])
    valid_mask_wrist = np.isin(wrist_y, [1, 2, 3, 4])

    all_chest_X.append(chest_X[valid_mask_chest])
    all_chest_y.append(chest_y[valid_mask_chest])

    all_wrist_X.append(wrist_X[valid_mask_wrist])
    all_wrist_y.append(wrist_y[valid_mask_wrist])

Processing S2...
Processing S3...
Processing S4...
Processing S5...
Processing S6...
Processing S7...
Processing S8...
Processing S9...
Processing S10...
Processing S11...
Processing S13...
Processing S14...
Processing S15...
Processing S16...
Processing S17...


In [24]:
X_chest = np.vstack(all_chest_X)
y_chest = np.hstack(all_chest_y)

X_wrist = np.vstack(all_wrist_X)
y_wrist = np.hstack(all_wrist_y)

print("Chest data:", X_chest.shape, y_chest.shape)
print("Wrist data:", X_wrist.shape, y_wrist.shape)

Chest data: (31470603, 8) (31470603,)
Wrist data: (2877310, 6) (2877310,)


In [25]:
from sklearn.preprocessing import StandardScaler

scaler_chest = StandardScaler()
X_chest_norm = scaler_chest.fit_transform(X_chest)

scaler_wrist = StandardScaler()
X_wrist_norm = scaler_wrist.fit_transform(X_wrist)

In [26]:
from scipy import stats

def create_windows(X, y, window_size, stride):
    Xw, yw = [], []
    for i in range(0, len(X) - window_size, stride):
        xw = X[i:i+window_size]
        yw_ = y[i:i+window_size]
        # stats.mode may return a scalar or 1-element array
        m = stats.mode(yw_, axis=None)
        lbl = m.mode[0] if hasattr(m.mode, "__len__") else m.mode
        Xw.append(xw)
        yw.append(lbl)
    return np.array(Xw), np.array(yw)

In [27]:
# # chest windows: 5s @700Hz ⇒ 3500, stride 1s⇒700
# Xc_win, yc_win = create_windows(X_chest_norm, y_chest, window_size=3500, stride=700)
# wrist windows: 5s @32Hz ⇒ 160, stride 1s⇒32
Xc_win, yc_win = create_windows(X_wrist_norm, y_wrist, window_size=160, stride=32)

def to_categorical_manual(y, num_classes):
    """Manual implementation of to_categorical"""
    y = np.array(y, dtype='int')
    input_shape = y.shape
    y = y.ravel()
    n = y.shape[0]
    categorical = np.zeros((n, num_classes))
    categorical[np.arange(n), y] = 1
    output_shape = input_shape + (num_classes,)
    categorical = np.reshape(categorical, output_shape)
    return categorical

# Use it like this:
yc_cat = to_categorical_manual(yc_win - 1, num_classes=4)

In [30]:
from sklearn.model_selection import train_test_split

X_train, X_val, y_train, y_val = train_test_split(
    Xc_win, yc_cat, 
    test_size=0.2, 
    stratify=yc_win, 
    random_state=42
)

In [31]:
import tensorflow as tf
from tensorflow.keras import layers, models

window_size, num_feats = X_train.shape[1], X_train.shape[2]
num_classes = y_train.shape[1]

model = models.Sequential([
    layers.Input((window_size, num_feats)),
    layers.Conv1D(32, 5, padding='same', activation='relu'),
    layers.MaxPooling1D(2),
    layers.Conv1D(64, 5, padding='same', activation='relu'),
    layers.MaxPooling1D(2),
    layers.Conv1D(128, 3, padding='same', activation='relu'),
    layers.GlobalAveragePooling1D(),
    layers.Dense(64, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(num_classes, activation='softmax'),
])

model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

model.summary()

In [32]:
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

checkpoint_cb = ModelCheckpoint(
    "best_model.keras",    # .keras ⇒ native Keras format
    monitor="val_accuracy",
    mode="max",
    save_best_only=True,
    verbose=1
)

earlystop_cb = EarlyStopping(
    monitor="val_accuracy",
    mode="max",
    patience=3,
    restore_best_weights=True,
    verbose=1
)

history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=50,            # you can set this high
    batch_size=64,
    callbacks=[checkpoint_cb, earlystop_cb]
)

Epoch 1/50
[1m1112/1124[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 4ms/step - accuracy: 0.7643 - loss: 0.6110
Epoch 1: val_accuracy improved from -inf to 0.93850, saving model to best_model.keras
[1m1124/1124[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 5ms/step - accuracy: 0.7655 - loss: 0.6082 - val_accuracy: 0.9385 - val_loss: 0.1462
Epoch 2/50
[1m1120/1124[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 5ms/step - accuracy: 0.9553 - loss: 0.1321
Epoch 2: val_accuracy improved from 0.93850 to 0.96897, saving model to best_model.keras
[1m1124/1124[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 5ms/step - accuracy: 0.9553 - loss: 0.1320 - val_accuracy: 0.9690 - val_loss: 0.1138
Epoch 3/50
[1m1124/1124[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.9744 - loss: 0.0812
Epoch 3: val_accuracy improved from 0.96897 to 0.97748, saving model to best_model.keras
[1m1124/1124[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s

In [14]:
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

checkpoint_cb = ModelCheckpoint(
    "best_model.keras",    # .keras ⇒ native Keras format
    monitor="val_accuracy",
    mode="max",
    save_best_only=True,
    verbose=1
)

earlystop_cb = EarlyStopping(
    monitor="val_accuracy",
    mode="max",
    patience=3,
    restore_best_weights=True,
    verbose=1
)

history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=50,            # you can set this high
    batch_size=64,
    callbacks=[checkpoint_cb, earlystop_cb]
)

Epoch 1/50
[1m562/562[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 51ms/step - accuracy: 0.6895 - loss: 0.7385
Epoch 1: val_accuracy improved from -inf to 0.96953, saving model to best_model.keras
[1m562/562[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 70ms/step - accuracy: 0.6897 - loss: 0.7380 - val_accuracy: 0.9695 - val_loss: 0.1235
Epoch 2/50
[1m561/562[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 49ms/step - accuracy: 0.9616 - loss: 0.1430
Epoch 2: val_accuracy improved from 0.96953 to 0.97753, saving model to best_model.keras
[1m562/562[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 53ms/step - accuracy: 0.9616 - loss: 0.1430 - val_accuracy: 0.9775 - val_loss: 0.0817
Epoch 3/50
[1m561/562[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 50ms/step - accuracy: 0.9738 - loss: 0.0888
Epoch 3: val_accuracy improved from 0.97753 to 0.98365, saving model to best_model.keras
[1m562/562[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m31s[0m 

In [33]:
import os
from tensorflow.keras.models import load_model

# Check that the file exists
print("Exists:", os.path.exists("best_model.keras"))

# Load the saved model
best_model = load_model("best_model.keras")
best_model.summary()

Exists: True


In [None]:
# preds = best_model.predict(X_new)  