# WESAD Model Inference Test
This notebook verifies the trained model by:
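1. Generating fresh feature vectors (one per 60-second sliding window) from raw WESAD data (Subject S2).
2. Normalizing them against the subject's own baseline and saving them as `artifacts/sample_features.csv`.
3. Reloading the saved features and running inference with the saved model.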

In [9]:
import joblib, json, os
import numpy as np
import pandas as pd
from scipy import stats
import warnings
warnings.filterwarnings('ignore')

# Import WESAD Helper pipelines
import sys
sys.path.append('../src')
from wesad.load_wesad import load_subject_data, extract_labels, extract_chest_signals
from wesad.features_eda import extract_eda_features
from wesad.features_hrv import extract_hrv_features
from wesad.features_acc import extract_acc_features
from wesad.normalization import normalize_subjects

# Configuration: sampling rate, windowing geometry and raw-data location.
FS = 700                      # sampling rate (Hz) passed to every feature extractor
WINDOW_SAMPLES = 60 * FS      # 60-second analysis window, expressed in samples
STEP_SAMPLES = 3 * FS         # 3-second hop between consecutive window starts
WESAD_PATH = '../data/raw/'   # directory handed to load_subject_data

## 1. Generate Sample Features (Robust Pipeline)

In [10]:
# Read the raw recording for subject S2, then pull out the chest signals and
# the per-sample labels through the project helpers.
print("Loading S2 data...")
data = load_subject_data("S2", WESAD_PATH)
chest_signals = extract_chest_signals(data)
labels = extract_labels(data)

# EDA and ECG are flattened once so that every window below is a plain 1-D slice.
eda_signal = chest_signals['EDA'].flatten()
ecg_signal = chest_signals['ECG'].flatten()
n_samples = len(labels)

# Keep the test quick: stop as soon as 20 windows have been *accepted*.
# Windows that are skipped (no non-zero label, or invalid EDA/HRV features)
# do not count towards this cap.
max_windows = 20
window_count = 0
all_features = []

print("Extracting features...")
for start_idx in range(0, n_samples - WINDOW_SAMPLES + 1, STEP_SAMPLES):
    if window_count >= max_windows:
        break

    # One slice object addresses the same sample range in every signal.
    window = slice(start_idx, start_idx + WINDOW_SAMPLES)
    eda_window = eda_signal[window]
    ecg_window = ecg_signal[window]
    acc_window = chest_signals['ACC'][window]
    labels_window = labels[window]

    # The window label is the most frequent non-zero label; a window that
    # contains only zeros is skipped before any feature extraction happens.
    labels_nonzero = labels_window[labels_window != 0]
    if not len(labels_nonzero):
        continue
    window_label = stats.mode(labels_nonzero, keepdims=False).mode

    # All three extractors run before the validity decision is taken.
    eda_feats = extract_eda_features(eda_window, sampling_rate=FS)
    hrv_feats = extract_hrv_features(ecg_window, sampling_rate=FS)
    acc_feats = extract_acc_features(acc_window, sampling_rate=FS)

    # Only the EDA and HRV validity flags gate the window.
    if not eda_feats['valid_eda'] or not hrv_feats['valid_hrv']:
        continue

    # Drop the bookkeeping flags so that only feature values remain.
    eda_feats.pop('valid_eda')
    hrv_feats.pop('valid_hrv')
    acc_feats.pop('valid_acc', None)

    # One flat record per accepted window: features plus label and subject id.
    all_features.append({
        **eda_feats,
        **hrv_feats,
        **acc_feats,
        'label': window_label,
        'subject': 'S2',
    })
    window_count += 1

df_sample = pd.DataFrame(all_features)
print(f"Extracted {len(df_sample)} valid windows.")

Loading S2 data...
Extracting features...
Extracted 20 valid windows.


## 2. Normalization (Subject Baseline)

In [11]:
# Normalize the freshly extracted S2 windows and persist them for the inference test.
# Note: In production you'd load pre-computed baseline stats, but here we calculate on the fly for the test.
if os.path.exists('../artifacts/feature_spec.json'):
    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    with open('../artifacts/feature_spec.json') as spec_file:
        feature_spec = json.load(spec_file)
    feature_cols = feature_spec['feature_cols']

    # Any feature listed in the spec but absent from this extraction is reported
    # and zero-filled so the column set handed to the normalizer is complete.
    # NOTE(review): this adds columns to df_sample in place.
    missing = [c for c in feature_cols if c not in df_sample.columns]
    if missing:
        print(f"Warning: Missing columns {missing}")
        for c in missing:
            df_sample[c] = 0.0

    # The same frame is passed as both arguments because S2 is the only subject here.
    # Baseline label = 1.
    # NOTE(review): presumably the first returned value is the normalized feature
    # matrix in feature_cols order -- confirm against wesad.normalization.
    X_norm, _ = normalize_subjects(df_sample, df_sample, feature_cols, baseline_label=1)

    # Re-attach the column names before writing so the CSV is self-describing.
    df_norm = pd.DataFrame(X_norm, columns=feature_cols)
    df_norm.to_csv("../artifacts/sample_features.csv", index=False)
    print("Saved normalized sample features to ../artifacts/sample_features.csv")
else:
    print("Please generate artifacts/feature_spec.json from the training notebook first!")

Saved normalized sample features to ../artifacts/sample_features.csv


## 3. Inference Test

In [13]:
# Reload the saved features and the trained model, then run a smoke-test prediction.
if os.path.exists("../models/wesad_linear_svm_3class.joblib") and os.path.exists("../artifacts/feature_spec.json"):
    # NOTE: joblib.load unpickles arbitrary Python objects; only load model
    # files that come from a trusted source.
    model_container = joblib.load("../models/wesad_linear_svm_3class.joblib")

    # The artifact may be either the estimator itself or a dict wrapping it
    # under the 'pipeline' key; handle both.
    if isinstance(model_container, dict) and 'pipeline' in model_container:
        print("Loaded model artifact (dictionary wrapper detected). Extracting pipeline...")
        model = model_container['pipeline']
    else:
        model = model_container

    # Context manager closes the spec file instead of leaking the handle.
    with open("../artifacts/feature_spec.json") as spec_file:
        feat_spec = json.load(spec_file)
    feature_cols = feat_spec["feature_cols"]  # ordered list

    X = pd.read_csv("../artifacts/sample_features.csv")  # you generate this from WESAD
    # Fail with the names of the absent columns rather than a bare AssertionError.
    missing_cols = [c for c in feature_cols if c not in X.columns]
    assert not missing_cols, f"sample_features.csv is missing feature columns: {missing_cols}"

    # Select the columns in the order given by the feature spec.
    X_ordered = X[feature_cols].to_numpy()

    # Prefer probabilities when the model exposes them; column 1 is the second
    # class in the model's class ordering.
    proba = model.predict_proba(X_ordered)[:, 1] if hasattr(model, "predict_proba") else None
    if proba is None and hasattr(model, "decision_function"):
        # Fallback to decision function for SVM
        print("Model uses decision_function (LinearSVC default).")
        proba = model.decision_function(X_ordered)
        # A 2-D result means one score column per class; only the first column
        # is kept, purely so a score range can be printed below.
        if proba.ndim > 1:
            proba = proba[:, 0]  # Example for printing

    pred = model.predict(X_ordered)

    print("n_windows:", len(X_ordered))
    print("pred unique:", np.unique(pred, return_counts=True))
    print("proba range/scores:", (float(proba.min()), float(proba.max())) if proba is not None else None)
else:
    print("Model or Feature Spec missing. Please run 01_wesad_stress_model.ipynb first.")

Loaded model artifact (dictionary wrapper detected). Extracting pipeline...
Model uses decision_function (LinearSVC default).
n_windows: 20
pred unique: (array(['other'], dtype=object), array([20]))
proba range/scores: (-1.1873572650404416, -0.015055313141423388)
