# WESAD Stress Detection Model

Here's the steps of training a stress detection model using the WESAD dataset:
1. Load & Explore
2. Extract
3. Train
4. Evaluate

In [3]:
import sys
sys.path.append('../src')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


## 1. Load & Explore WESAD Data

In [None]:
from pathlib import Path
import pandas as pd
import numpy as np

def load_subject(subject_dir):
    """
    Load one WESAD subject folder (e.g., .../WESAD/S2/) and return a dict:
      signals["EDA"], ["BVP"], ["ACC"], ["TEMP"], ["HR"] -> {"time": np.ndarray, "data": np.ndarray, "fs": float, "start": float}
      signals["IBI"] -> pd.DataFrame with columns ["t_rel", "ibi", "time"]
      signals["tags"] -> pd.DataFrame with column ["time"]
      signals["_paths"] -> {"subject_dir": Path, "e4_dir": Path}
    """
    subject_dir = Path(subject_dir).expanduser().resolve()
    if not subject_dir.exists():
        raise FileNotFoundError(f"Subject directory not found: {subject_dir}")

    # Find *_E4_data (case-insensitive)
    e4_candidates = [p for p in subject_dir.iterdir()
                     if p.is_dir() and p.name.lower().endswith("_e4_data")]
    if not e4_candidates:
        contents = [p.name for p in subject_dir.iterdir()]
        raise FileNotFoundError(
            f"No '*_E4_data' folder found inside: {subject_dir}\n"
            f"Contents: {contents}"
        )
    if len(e4_candidates) > 1:
        e4_candidates.sort()
    e4_dir = e4_candidates[0]

    def _load_standard_signal(csv_name, expect_cols=None):
        p = e4_dir / csv_name
        if not p.exists():
            raise FileNotFoundError(f"Missing file: {p}")

        raw = pd.read_csv(p, header=None)
        if raw.shape[0] < 3:
            raise ValueError(f"File too short: {p} (rows={raw.shape[0]})")

        start = float(raw.iloc[0, 0])
        fs = float(raw.iloc[1, 0])
        data = raw.iloc[2:].to_numpy(dtype=float)

        if expect_cols is not None and data.shape[1] != expect_cols:
            raise ValueError(f"{p.name}: expected {expect_cols} columns, got {data.shape[1]}")

        t = start + np.arange(data.shape[0], dtype=float) / fs
        return {"time": t, "data": data, "fs": fs, "start": start}

    signals = {
        "_paths": {"subject_dir": subject_dir, "e4_dir": e4_dir}
    }

    # Standard sampled signals
    signals["EDA"]  = _load_standard_signal("EDA.csv",  expect_cols=1)
    signals["BVP"]  = _load_standard_signal("BVP.csv",  expect_cols=1)
    signals["TEMP"] = _load_standard_signal("TEMP.csv", expect_cols=1)
    signals["HR"]   = _load_standard_signal("HR.csv",   expect_cols=1)
    signals["ACC"]  = _load_standard_signal("ACC.csv",  expect_cols=3)

    # IBI: first row start unix time; remaining rows are [t_rel, ibi]
    ibi_path = e4_dir / "IBI.csv"
    if not ibi_path.exists():
        raise FileNotFoundError(f"Missing file: {ibi_path}")

    ibi_raw = pd.read_csv(ibi_path, header=None)
    if ibi_raw.shape[0] < 2 or ibi_raw.shape[1] < 2:
        raise ValueError(f"Unexpected IBI.csv shape: {ibi_raw.shape} at {ibi_path}")

    ibi_start = float(ibi_raw.iloc[0, 0])
    ibi = ibi_raw.iloc[1:, :2].copy()
    ibi.columns = ["t_rel", "ibi"]
    ibi["t_rel"] = pd.to_numeric(ibi["t_rel"], errors="coerce")
    ibi["ibi"] = pd.to_numeric(ibi["ibi"], errors="coerce")
    ibi = ibi.dropna().reset_index(drop=True)
    ibi["time"] = ibi_start + ibi["t_rel"].to_numpy(dtype=float)
    signals["IBI"] = ibi

    # tags: each row is an absolute unix timestamp
    tags_path = e4_dir / "tags.csv"
    if not tags_path.exists():
        raise FileNotFoundError(f"Missing file: {tags_path}")

    tags = pd.read_csv(tags_path, header=None, names=["time"])
    tags["time"] = pd.to_numeric(tags["time"], errors="coerce")
    tags = tags.dropna().reset_index(drop=True)
    signals["tags"] = tags

    return signals

WESAD_ROOT = Path("/Users/maggiebowen/Documents/GitHub/doomstopping/data/raw")

# test calls
signals = load_subject(WESAD_ROOT / "S2")
print(signals.keys())
print("EDA data shape:", signals["EDA"]["data"].shape)
print("IBI head:\n", signals["IBI"].head())

dict_keys(['_paths', 'EDA', 'BVP', 'TEMP', 'HR', 'ACC', 'IBI', 'tags'])
EDA data shape: (31494, 1)
IBI head:
        t_rel       ibi          time
0  14.313155  0.765660  1.495437e+09
1  15.203821  0.890666  1.495437e+09
2  15.985107  0.781286  1.495437e+09
3  16.797644  0.812537  1.495437e+09
4  17.578930  0.781286  1.495437e+09


## Label exploration and mapping
Quick helpers to inspect WESAD label segments and map them to tasks without touching existing cells.


In [1]:
# Inspect one subject's label distribution (labels are aligned to the chest sample rate)
import pickle

from pathlib import Path
import numpy as np

def load_pickle_subject(subject_id, root=None):
    root_path = Path(root) if root is not None else Path('/Users/maggiebowen/Documents/GitHub/doomstopping/data/raw')
    pkl_path = root_path / subject_id / f'{subject_id}.pkl'
    with pkl_path.open('rb') as f:
        return pickle.load(f, encoding='latin1')

s2 = load_pickle_subject('S2')
labels = s2['label'].astype(int)
vals, counts = np.unique(labels, return_counts=True)
print('Label counts:', dict(zip(vals.tolist(), counts.tolist())))


Label counts: {0: 2142701, 1: 800800, 2: 430500, 3: 253400, 4: 537599, 6: 45500, 7: 44800}


In [2]:
# Run-length encode labels to see segment order and durations (seconds assume 700 Hz chest rate)
def label_segments(labels, fs=700):
    segments = []
    start = 0
    cur = int(labels[0])
    for idx, val in enumerate(labels[1:], 1):
        v = int(val)
        if v != cur:
            seg_len = idx - start
            segments.append({'label': cur, 'start_idx': start, 'len_samples': seg_len, 'dur_s': seg_len / fs})
            cur, start = v, idx
    seg_len = len(labels) - start
    segments.append({'label': cur, 'start_idx': start, 'len_samples': seg_len, 'dur_s': seg_len / fs})
    return segments

segments = label_segments(labels)
print(f'Found {len(segments)} segments')
for i, seg in enumerate(segments):
    print(f"{i:02d}: label={seg['label']} start={seg['start_idx']} len={seg['len_samples']} dur_s={seg['dur_s']:.1f}")


Found 15 segments
00: label=0 start=0 len=214583 dur_s=306.5
01: label=1 start=214583 len=800800 dur_s=1144.0
02: label=0 start=1015383 len=576099 dur_s=823.0
03: label=2 start=1591482 len=430500 dur_s=615.0
04: label=0 start=2021982 len=190401 dur_s=272.0
05: label=6 start=2212383 len=45500 dur_s=65.0
06: label=0 start=2257883 len=610400 dur_s=872.0
07: label=4 start=2868283 len=273699 dur_s=391.0
08: label=0 start=3141982 len=192501 dur_s=275.0
09: label=3 start=3334483 len=253400 dur_s=362.0
10: label=0 start=3587883 len=100800 dur_s=144.0
11: label=7 start=3688683 len=44800 dur_s=64.0
12: label=0 start=3733483 len=114100 dur_s=163.0
13: label=4 start=3847583 len=263900 dur_s=377.0
14: label=0 start=4111483 len=143817 dur_s=205.5


**WESAD label mapping (from official schedule):**
- 1 = Baseline
- 2 = TSST stress
- 3 = Amusement
- 4 = Meditation
- 6 = sRead (self-report/reading)
- 7 = fRead (final reading)
- 0 = transition / not-worn / between blocks

Use the quest CSV in each subject folder (e.g., `data/raw/S2/S2_quest.csv`) to line up minute-level START/END times with the run-length segments above. For binary stress modeling, collapse to stress=2 vs non-stress={1,3,4,6,7} and drop/ignore 0.

## Batch label summary (all subjects)
Run the reusable script and load its outputs for quick inspection.


In [5]:
# Generate per-subject label summaries to data/processed/
import subprocess, sys
subprocess.run([sys.executable, '../scripts/summarize_wesad.py'], check=True)


/Users/maggiebowen/miniconda3/envs/doomstopping/bin/python: can't open file '/Users/maggiebowen/Documents/GitHub/doomstopping/notebooks/../scripts/summarize_wesad.py': [Errno 2] No such file or directory


CalledProcessError: Command '['/Users/maggiebowen/miniconda3/envs/doomstopping/bin/python', '../scripts/summarize_wesad.py']' returned non-zero exit status 2.

In [None]:
# Load and preview summary CSVs without pandas
import csv
from pathlib import Path

counts_path = Path('../data/processed/wesad_label_counts.csv')
segments_path = Path('../data/processed/wesad_label_segments.csv')

def head_csv(path, n=5):
    rows = []
    with path.open() as f:
        reader = csv.reader(f)
        for i, row in enumerate(reader):
            rows.append(row)
            if i >= n:
                break
    return rows

print('Counts:', counts_path)
for row in head_csv(counts_path):
    print(row)

print('
Segments:', segments_path)
for row in head_csv(segments_path):
    print(row)
