## Import necessary libraries

In [None]:
import os
import pickle
import numpy as np
import pandas as pd
import scipy.signal as scisig
from scipy.signal import welch
from collections import Counter

## Import WESAD data

In [None]:
class SubjectData:
    """Load one WESAD subject recording and expose the channels used downstream.

    Reads ``<main_path>/S<n>/S<n>.pkl`` once. The whole unpickled dict is kept
    in ``data``, and a few channels are exposed as attributes.

    Attributes:
        name: Subject stem, e.g. ``'S2'`` (folder name and file stem).
        data: The complete unpickled dict.
        bvp, acc, eda: ``data['signal']['wrist']`` entries 'BVP', 'ACC', 'EDA'.
        resp: ``data['signal']['chest']['Resp']``.
        labels: ``data['label']``.
    """

    def __init__(self, main_path, subject_number):
        self.name = f'S{subject_number}'
        pkl_path = os.path.join(main_path, self.name, f'{self.name}.pkl')

        # NOTE(review): pickle can execute arbitrary code on load; only point this at
        # trusted dataset files. encoding='latin1' is what the pickle docs prescribe for
        # NumPy arrays pickled under Python 2 (presumably how these files were written).
        with open(pkl_path, 'rb') as pkl_file:
            self.data = pickle.load(pkl_file, encoding='latin1')

        signals = self.data['signal']
        wrist = signals['wrist']
        self.bvp, self.acc, self.eda = (wrist[key] for key in ('BVP', 'ACC', 'EDA'))
        self.resp = signals['chest']['Resp']
        self.labels = self.data['label']

## Define feature calculation functions

In [None]:
def _band_power(f, pxx, lo, hi, include_lo=True):
    """Integrate PSD ``pxx`` over frequencies ``f`` in [lo, hi] (or (lo, hi] if not include_lo)."""
    lower = (f >= lo) if include_lo else (f > lo)
    mask = lower & (f <= hi)
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use whichever exists.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    return trapezoid(pxx[mask], f[mask])


def bvp_to_hrv(bvp_signal, fs, resample_fs=4.0):
    """Derive beat-level HR/IBI and segment-level HRV metrics from a BVP segment.

    Local maxima of ``bvp_signal`` at least 0.4 s apart are treated as heartbeats.
    The inter-beat intervals (IBI) between consecutive peaks drive every metric.

    Args:
        bvp_signal: 1-D BVP samples.
        fs: Sampling rate of ``bvp_signal`` in Hz.
        resample_fs: Rate (Hz) of the uniform grid the IBI series is linearly
            interpolated onto before Welch's PSD. It must exceed 0.8 Hz so the
            HF band (0.15-0.4 Hz) lies below Nyquist; 4 Hz is the usual choice.

    Returns:
        A DataFrame with one row per interval (stamped at the closing peak, i.e.
        from the 2nd peak on). It has columns 'timestamps', 'HR' (bpm) and
        'IBI' (ms). The segment-level 'RMSSD', 'SDNN', 'pNN50' and 'lf/hf' are
        repeated on every row. An empty DataFrame is returned when fewer than
        3 peaks are found.
    """
    # NOTE(review): peaks are taken from the raw signal with no filtering or prominence
    # threshold; noisy segments may produce spurious beats.
    peaks, _ = scisig.find_peaks(bvp_signal, distance=int(fs * 0.4))
    if len(peaks) < 3:
        return pd.DataFrame()

    # Beat-level series. With >= 3 peaks, ibi has >= 2 entries and rr_diff has >= 1.
    ibi = np.diff(peaks) / fs * 1000  # ms
    rr_diff = np.diff(ibi)
    hr = (60 * 1000) / ibi  # bpm

    # Time-domain HRV metrics over the whole segment.
    rmssd = np.sqrt(np.mean(rr_diff ** 2))
    sdnn = np.std(ibi)
    pnn50 = np.mean(np.abs(rr_diff) > 50) * 100

    # Frequency domain: put the unevenly spaced IBI series on a uniform grid so Welch's
    # method applies. Sampling this grid below 0.8 Hz (as 0.4 Hz did) would alias HF into LF.
    ibi_time = np.cumsum(ibi) / 1000  # s since the first peak, at the end of each interval
    # Start at the first real IBI sample rather than 0 to avoid a flat extrapolated stretch.
    interp_time = np.arange(ibi_time[0], ibi_time[-1], 1 / resample_fs)
    lf_hf_ratio = np.nan
    if interp_time.size >= 2:
        ibi_interp = np.interp(interp_time, ibi_time, ibi)
        # Same as scipy's default nperseg=256, clamped to the series length without a warning.
        f, pxx = welch(ibi_interp, fs=resample_fs, nperseg=min(256, ibi_interp.size))
        # NOTE: on short segments the LF band contains only a few PSD bins, so lf/hf is coarse.
        lf = _band_power(f, pxx, 0.04, 0.15, include_lo=True)
        hf = _band_power(f, pxx, 0.15, 0.4, include_lo=False)
        lf_hf_ratio = lf / hf if hf != 0 else np.nan

    # Align HR/IBI timestamps with the peak that closes each interval (2nd beat onward).
    timestamps = peaks[1:] / fs

    return pd.DataFrame({
        'timestamps': pd.to_datetime(timestamps, unit='s'),
        'HR': hr,
        'IBI': ibi,
        'RMSSD': rmssd,
        'SDNN': sdnn,
        'pNN50': pnn50,
        'lf/hf': lf_hf_ratio,
    })

In [None]:
def extract_resp_features(resp_signal, fs):
    """Estimate breathing rate and regularity from a respiration segment.

    Peaks at least 2 s apart are treated as breaths.

    Args:
        resp_signal: 1-D respiration samples.
        fs: Sampling rate in Hz.

    Returns:
        A dict with 'RESP_rate' (breaths per minute, from the mean breath
        interval) and 'RESP_regularity' (1 / std of breath intervals in s).
        Each value is NaN when fewer than 2 peaks are found or when its
        denominator is zero.
    """
    breath_peaks = scisig.find_peaks(resp_signal, distance=fs * 2)[0]
    if breath_peaks.size < 2:
        return {'RESP_rate': np.nan, 'RESP_regularity': np.nan}

    intervals = np.diff(breath_peaks) / fs  # seconds between breaths
    mean_interval = np.mean(intervals)
    spread = np.std(intervals)

    rate = np.nan
    if mean_interval > 0:
        rate = 60 / mean_interval
    regularity = np.nan
    if spread > 0:
        regularity = 1 / spread

    return {'RESP_rate': rate, 'RESP_regularity': regularity}

def extract_eda_features(eda_signal):
    """Summarise an EDA segment by its level, variability and linear trend.

    Args:
        eda_signal: EDA samples. The input is flattened first, because the
            caller passes wrist EDA without flattening and it may be an (N, 1)
            column (TODO confirm the stored shape).

    Returns:
        A dict with 'EDA_mean', 'EDA_std' and 'EDA_slope'. The slope is the
        least-squares slope per sample index. All values are NaN for an empty
        segment, and the slope is NaN for a single sample.
    """
    eda = np.asarray(eda_signal, dtype=float).ravel()
    if eda.size == 0:
        # Explicit NaNs instead of np.mean/np.std RuntimeWarnings on empty input.
        return {'EDA_mean': np.nan, 'EDA_std': np.nan, 'EDA_slope': np.nan}

    # On a flat array polyfit returns scalar coefficients, so float() needs no deprecated
    # size-1-array-to-scalar conversion.
    slope = float(np.polyfit(np.arange(eda.size), eda, 1)[0]) if eda.size > 1 else np.nan
    return {
        'EDA_mean': np.mean(eda),
        'EDA_std': np.std(eda),
        'EDA_slope': slope,
    }

## Feature extraction

In [None]:
# Sampling rate (Hz) of each WESAD stream used in this notebook. Window
# boundaries are computed in 'label' samples and rescaled to each signal's rate.
fs_dict = dict(ACC=32, BVP=64, EDA=4, Resp=700, label=700)

# Length of each non-overlapping analysis window, in seconds.
WINDOW_IN_SECONDS = 30

# Output directory for the per-subject feature CSVs.
# NOTE(review): this path goes up two levels ('../../data') while feature_extract
# reads raw data from '../data' — confirm both resolve from the same working dir.
save_path = "../../data/processed/WESAD/feature_extracted_label"

In [None]:
def _signal_window(signal, start, end, fs):
    """Return the rows of ``signal`` (sampled at ``fs`` Hz) spanning label samples [start, end)."""
    label_fs = fs_dict['label']
    return signal[start * fs // label_fs:end * fs // label_fs]


def feature_extract(subject_id, main_path="../data/raw/WESAD"):
    """Extract windowed ACC, BVP/HRV, EDA and RESP features for one WESAD subject.

    The label stream is cut into non-overlapping windows of ``WINDOW_IN_SECONDS``
    (at ``fs_dict['label']`` Hz), and every signal is sliced to the same time
    span. A window is skipped when it has no label in 1-4, has no ACC samples,
    or yields fewer than 3 BVP peaks. The result is written to
    ``{save_path}/S{subject_id}.csv``.

    Args:
        subject_id: WESAD subject number.
        main_path: Root folder of the raw data (default: the previously hard-coded path).

    Returns:
        The per-window feature DataFrame that was written to disk.
    """
    subject = SubjectData(main_path=main_path, subject_number=subject_id)

    # Flatten 1-D channels once (they may be stored as (N, 1) columns); ACC keeps its axes.
    bvp = np.ravel(subject.bvp)
    acc = subject.acc
    eda = np.ravel(subject.eda)
    resp = np.ravel(subject.resp)
    labels = np.ravel(subject.labels)

    window_len = fs_dict['label'] * WINDOW_IN_SECONDS
    n_windows = len(labels) // window_len
    study_labels = (1, 2, 3, 4)
    hrv_columns = ['HR', 'IBI', 'RMSSD', 'SDNN', 'pNN50', 'lf/hf']

    all_window = []
    for i in range(n_windows):
        start = i * window_len
        end = start + window_len

        # Majority-vote label over samples labelled 1-4; other labels are ignored.
        # Vectorised filter; Counter keeps first-seen tie-breaking.
        label_window = labels[start:end]
        valid = label_window[np.isin(label_window, study_labels)]
        if valid.size == 0:
            continue
        label = Counter(valid.tolist()).most_common(1)[0][0]

        # --- ACC ---
        # Cast to float so squaring cannot overflow if ACC is stored as a small int dtype.
        acc_window = np.asarray(_signal_window(acc, start, end, fs_dict['ACC']), dtype=float)
        if acc_window.shape[0] == 0:
            continue
        acc_x, acc_y, acc_z = acc_window[:, 0], acc_window[:, 1], acc_window[:, 2]
        net_acc = np.sqrt(acc_x ** 2 + acc_y ** 2 + acc_z ** 2)
        acc_features = {
            'ACC_x_mean': np.mean(acc_x),
            'ACC_y_mean': np.mean(acc_y),
            'ACC_z_mean': np.mean(acc_z),
            'net_acc_mean': np.mean(net_acc),
            'net_acc_std': np.std(net_acc),
        }

        # --- BVP / HRV ---
        hrv_df = bvp_to_hrv(_signal_window(bvp, start, end, fs_dict['BVP']), fs_dict['BVP'])
        if hrv_df.empty:
            continue
        hrv_mean = hrv_df[hrv_columns].mean()

        # --- EDA ---
        eda_features = extract_eda_features(_signal_window(eda, start, end, fs_dict['EDA']))

        # --- RESP ---
        resp_features = extract_resp_features(
            _signal_window(resp, start, end, fs_dict['Resp']), fs_dict['Resp']
        )

        all_window.append({
            **acc_features,
            **eda_features,
            **resp_features,
            **{col: hrv_mean[col] for col in hrv_columns},
            'label': label,  # not used directly for clustering; kept for inspection
            'subject': subject_id,
        })

    df = pd.DataFrame(all_window)
    os.makedirs(save_path, exist_ok=True)  # to_csv does not create missing directories
    df.to_csv(os.path.join(save_path, f'S{subject_id}.csv'), index=False)
    print(f'Subject {subject_id} processed with {len(df)} samples.')
    return df

In [None]:
# Process subjects 2-11 and 13-17 (1 and 12 are not listed — presumably unavailable; confirm).
for subject_id in [*range(2, 12), *range(13, 18)]:
    feature_extract(subject_id)