In [18]:
cd

/bin/bash: line 1: ../ls: No such file or directory


In [19]:
################################################################################
# ExtendedKalmanFilter_FallDetection_YoungWatch_EKFPlots.ipynb (or .py)
################################################################################

# %% [markdown]
# # EKF for Fall Detection (Watch Accelerometer + Gyroscope, Young Only)
#
# 1) Only parse "data/smartfallmm/young/accelerometer/watch" + "data/smartfallmm/young/gyroscope/watch"
# 2) Merge accelerometer & gyro => 6 columns
# 3) Use EKF to remove gravity
# 4) Sliding Window (128, stride=64)
# 5) For subject=31, trials T=1..5, plot "before vs. after" each window, saving in:
#    visualizations/{subject}/{activity}/SxxAxxTxx_windowW.png

import os
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from typing import List, Dict, Tuple
from numpy.linalg import norm
from scipy.io import loadmat
from scipy.signal import butter, filtfilt
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# For reproducibility
torch.manual_seed(42)
np.random.seed(42)

# %% [markdown]
# ## 1. EKF for Inertial Sensor Fusion

def normalize_quaternion(q: np.ndarray) -> np.ndarray:
    """Normalize quaternion."""
    return q / np.linalg.norm(q)

def quat_to_rot_mat(q: np.ndarray) -> np.ndarray:
    """Convert quaternion [qw, qx, qy, qz] to a rotation matrix (3x3)."""
    qw, qx, qy, qz = q
    R = np.array([
        [1 - 2*(qy**2 + qz**2),     2*(qx*qy - qw*qz),     2*(qx*qz + qw*qy)],
        [2*(qx*qy + qw*qz),         1 - 2*(qx**2 + qz**2), 2*(qy*qz - qw*qx)],
        [2*(qx*qz - qw*qy),         2*(qy*qz + qw*qx),     1 - 2*(qx**2 + qy**2)]
    ])
    return R

def integrate_gyro(q: np.ndarray, w: np.ndarray, dt: float) -> np.ndarray:
    """Integrate gyroscope angular velocity (rad/s) over dt to update orientation."""
    qw, qx, qy, qz = q
    wx, wy, wz = w

    dq = 0.5 * np.array([
        - qx*wx - qy*wy - qz*wz,
          qw*wx + qy*wz - qz*wy,
          qw*wy - qx*wz + qz*wx,
          qw*wz + qx*wy - qy*wx
    ])
    q_new = q + dq * dt
    return normalize_quaternion(q_new)

class EKFInertial:
    """
    State = [qw, qx, qy, qz, bgx, bgy, bgz].
    We do not track velocity/position, only orientation + gyro bias.
    """
    def __init__(self, dt=0.01, var_gyro=1e-5, var_acc=1e-2):
        self.dt = dt
        self.x = np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])  # orientation + bias
        self.P = np.eye(7) * 0.01

        self.Q = np.eye(7) * var_gyro
        # reduce orientation part slightly
        self.Q[0:4, 0:4] *= 0.1

        self.R = np.eye(3) * var_acc

    def predict(self, gyro_meas: np.ndarray):
        q = self.x[0:4]
        bg = self.x[4:7]
        w = gyro_meas - bg

        q_new = integrate_gyro(q, w, self.dt)
        self.x[0:4] = q_new

        F = np.eye(7)  # simplified
        self.P = F @ self.P @ F.T + self.Q

    def update(self, acc_meas: np.ndarray):
        """Use accelerometer as gravity reference."""
        q = self.x[0:4]
        R_sb = quat_to_rot_mat(q).T
        g_s = R_sb @ np.array([0, 0, 9.81])

        z = acc_meas - g_s  # residual

        eps = 1e-6
        H = np.zeros((3,7))
        for i in range(4):
            dq = np.zeros(4)
            dq[i] = eps
            q_pert = normalize_quaternion(q + dq)
            R_pert = quat_to_rot_mat(q_pert).T
            g_s_pert = R_pert @ np.array([0,0,9.81])
            dh = (g_s_pert - g_s)/eps
            H[:, i] = dh

        S = H @ self.P @ H.T + self.R
        K = self.P @ H.T @ np.linalg.inv(S)
        dx = K @ z

        self.x += dx
        self.x[0:4] = normalize_quaternion(self.x[0:4])

        I_7 = np.eye(7)
        self.P = (I_7 - K @ H) @ self.P @ (I_7 - K @ H).T + K @ self.R @ K.T

    def fuse(self, gyro: np.ndarray, acc: np.ndarray) -> np.ndarray:
        """One step of EKF: predict+update. Return gravity in sensor frame."""
        self.predict(gyro)
        self.update(acc)
        q = self.x[0:4]
        R_sb = quat_to_rot_mat(q).T
        g_s = R_sb @ np.array([0,0,9.81])
        return g_s

def ekf_fusion(acc_data: np.ndarray, gyro_data: np.ndarray, fs=100.0):
    """
    Merges acc & gyro via EKF => linear_acc = acc - gravity_in_sensor
    """
    dt = 1.0 / fs
    ekf = EKFInertial(dt=dt, var_gyro=1e-5, var_acc=5e-2)
    T = acc_data.shape[0]
    out = np.zeros_like(acc_data)
    for t in range(T):
        g_s = ekf.fuse(gyro_data[t], acc_data[t])
        out[t] = acc_data[t] - g_s
    return out

# %% [markdown]
# ## 2. Flexible Loader for watch accelerometer / gyro
# 
# We'll parse CSV with either commas or semicolons, skipping any lines that can't convert to float. We only read 2 or 4 columns here because watch CSV might have a timestamp plus x,y,z.

def flexible_csvloader(file_path: str):
    """
    Reads watch data. Usually format:
      1) 'YYYY-MM-DD HH:MM:SS.sss,<x>,<y>,<z>' or
         'YYYY-MM-DD HH:MM:SS.sss;<x>;<y>;<z>'
    We want only 3 columns of numeric data (x,y,z).
    We'll skip the datetime string. Or if there's just time + x,y,z => we handle that.
    """
    try:
        # read with a regex that splits on commas or semicolons
        df = pd.read_csv(
            file_path,
            sep='[;,]',
            engine='python',
            header=None,
            comment='#',
            # 'comment' if you have weird lines
        ).dropna().bfill()

        # Expect at least 4 columns if there's a time string in col0 + x,y,z in next columns
        # or 3 columns if it's strictly x,y,z
        ncols = df.shape[1]
        if ncols < 3:
            print(f"[ERROR] {file_path} has only {ncols} columns. Skipping.")
            return np.empty((0,3), dtype=np.float32)

        # We'll take the last 3 columns as x,y,z
        arr = df.iloc[:, -3:].astype('float32').to_numpy()
        print(f"[INFO] Loaded watch file: {file_path}, shape={arr.shape}")
        return arr

    except Exception as e:
        print(f"[ERROR] Could not parse file: {file_path}. Reason: {e}")
        return np.empty((0,3), dtype=np.float32)

# %% [markdown]
# ### 2.1 Merging watch accelerometer + watch gyroscope
# 
# For each trial `SxxAyyTzz`, we want `(accelerometer Nx3, gyroscope Nx3)` => Nx6 if possible.

def merge_acc_gyro(acc_array: np.ndarray, gyro_array: np.ndarray):
    """
    If they differ in length, we'll truncate to the min length.
    Merge => Nx6 columns: [accX,accY,accZ, gyroX,gyroY,gyroZ].
    """
    nA = acc_array.shape[0]
    nG = gyro_array.shape[0]
    if nA == 0 or nG == 0:
        return np.empty((0,6), dtype=np.float32)

    N = min(nA, nG)
    # naive alignment: just match sample i to sample i
    merged = np.concatenate([acc_array[:N], gyro_array[:N]], axis=1)  # shape Nx6
    return merged

# %% [markdown]
# ## 3. Data Classes: Only "young" watch data

class MatchedWatchTrial:
    """
    Storing watch accelerometer & gyroscope paths for a single SxxAyyTzz
    """
    def __init__(self, subject_id, activity_id, trial_id):
        self.subject_id  = subject_id
        self.activity_id = activity_id
        self.trial_id    = trial_id
        self.acc_file    = None
        self.gyr_file    = None

    def __repr__(self):
        return (f"MatchedWatchTrial(S={self.subject_id},A={self.activity_id},"
                f"T={self.trial_id},acc={self.acc_file},gyr={self.gyr_file})")

class SmartFallMM_Watch:
    """
    Only parse watch data from data/smartfallmm/young/{accelerometer|gyroscope}/watch
    for SxxAyyTzz files. We'll skip old, phone, skeleton, etc.
    """
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.trials: Dict[Tuple[int,int,int], MatchedWatchTrial] = {}

    def _add_file(self, subject, activity, trial, is_acc, filepath):
        key = (subject, activity, trial)
        if key not in self.trials:
            self.trials[key] = MatchedWatchTrial(subject, activity, trial)
        if is_acc:
            self.trials[key].acc_file = filepath
        else:
            self.trials[key].gyr_file = filepath

    def load_files(self):
        # We'll specifically only look under:
        #  <root>/young/accelerometer/watch
        #  <root>/young/gyroscope/watch
        # We'll parse SxxAyyTzz from filename
        valid_paths = [
            os.path.join(self.root_dir, "young", "accelerometer", "watch"),
            os.path.join(self.root_dir, "young", "gyroscope",     "watch")
        ]
        for base in valid_paths:
            for root, _, files in os.walk(base):
                for f in files:
                    if not f.lower().endswith(".csv"):
                        continue
                    # parse SxxAyyTzz
                    try:
                        s_id = int(f[1:3])  # subject
                        a_id = int(f[4:6])  # activity
                        t_id = int(f[7:9])  # trial
                        fullpath = os.path.join(root, f)
                        if "accelerometer" in base:
                            self._add_file(s_id, a_id, t_id, True, fullpath)
                        else:
                            self._add_file(s_id, a_id, t_id, False, fullpath)
                    except:
                        pass


# %% [markdown]
# ## 4. Builder that merges, applies EKF, and saves "before vs after" plots
# 
# - We only do so for **subjects_of_interest** (like `[31]`)  
# - We only save images for **trials T=1..5**  
# - We store them in `visualizations/{subject}/{activity}/SxxAyyTzz_window{w}.png`

def sliding_window(data: np.ndarray, window_size: int, stride: int):
    n = data.shape[0]
    feats = data.shape[1]
    out = []
    i=0
    while i+window_size <= n:
        out.append(data[i:i+window_size])
        i += stride
    if not out:
        return np.empty((0, window_size, feats), dtype=np.float32)
    return np.stack(out, axis=0)

class EKFWatchBuilder:
    def __init__(self, watch_dataset: SmartFallMM_Watch,
                 max_length=128,
                 stride=64,
                 fs=100.0,
                 subjects_of_interest=None):
        self.dataset = watch_dataset
        self.max_length = max_length
        self.stride = stride
        self.fs = fs
        if subjects_of_interest is None:
            subjects_of_interest = []
        self.subjects_of_interest = subjects_of_interest

    def _ensure_dir(self, path):
        os.makedirs(path, exist_ok=True)

    def process_all(self):
        """
        Iterate over matched watch trials, for the specified subjects,
        and if trial in [1..5], then:
          1) Merge ACC+GYRO => Nx6
          2) "before" = ACC Nx3
          3) "after" = EKF ( Nx3 )
          4) sliding_window => compare each window, plot, store
        """
        for (s_id, a_id, t_id), trial_obj in self.dataset.trials.items():
            # only if s_id in subjects_of_interest
            if s_id not in self.subjects_of_interest:
                continue
            # only if T=1..5
            if t_id < 1 or t_id > 5:
                continue
            # Make sure we have both files
            if (trial_obj.acc_file is None) or (trial_obj.gyr_file is None):
                continue

            # Load each file
            acc_data = flexible_csvloader(trial_obj.acc_file)  # Nx3
            gyr_data = flexible_csvloader(trial_obj.gyr_file)  # Mx3
            merged = merge_acc_gyro(acc_data, gyr_data)        # Lx6
            if merged.shape[0] == 0:
                continue

            # "before" = merged[:,0:3]
            raw_acc = merged[:,0:3].copy()
            # "after" => apply EKF
            # split => (acc, gyro)
            acc_ = merged[:,0:3]
            gyr_ = merged[:,3:6]
            fused = ekf_fusion(acc_, gyr_, fs=self.fs)  # Lx3

            # sliding_window
            windows_before = sliding_window(raw_acc, self.max_length, self.stride)
            windows_after  = sliding_window(fused,   self.max_length, self.stride)

            # create directory: visualizations/{s_id}/{a_id}
            out_dir = f"visualizations/{s_id}/{a_id}"
            self._ensure_dir(out_dir)

            # For each window, plot X, Y, Z before vs after
            n_win = windows_before.shape[0]
            for w_i in range(n_win):
                w_b = windows_before[w_i]  # shape (128,3)
                w_a = windows_after[w_i]   # shape (128,3)

                # Plot
                fig, axs = plt.subplots(3,1,figsize=(10,6), sharex=True)
                time_axis = np.arange(self.max_length)/self.fs  # sec
                # X
                axs[0].plot(time_axis, w_b[:,0], 'r', label='Before-X')
                axs[0].plot(time_axis, w_a[:,0], 'r--', label='After-X')
                axs[0].legend(loc='upper right')
                axs[0].set_ylabel("m/s^2")

                # Y
                axs[1].plot(time_axis, w_b[:,1], 'g', label='Before-Y')
                axs[1].plot(time_axis, w_a[:,1], 'g--', label='After-Y')
                axs[1].legend(loc='upper right')
                axs[1].set_ylabel("m/s^2")

                # Z
                axs[2].plot(time_axis, w_b[:,2], 'b', label='Before-Z')
                axs[2].plot(time_axis, w_a[:,2], 'b--', label='After-Z')
                axs[2].legend(loc='upper right')
                axs[2].set_ylabel("m/s^2")
                axs[2].set_xlabel("Time (s)")

                fig.suptitle(f"S{s_id:02d}A{a_id:02d}T{t_id:02d} Window {w_i}")

                # Save
                out_name = f"{out_dir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_window{w_i}.png"
                plt.savefig(out_name)
                plt.close(fig)

            print(f"[DONE] S{s_id:02d}A{a_id:02d}T{t_id:02d} => {n_win} windows plotted.")

# %% [markdown]
# ## 5. Main execution
#
# - We create `SmartFallMM_Watch`, load only **young** / watch data (acc + gyro).
# - We run `EKFWatchBuilder` on subject=31, only T=1..5, and store plots in `visualizations/31/<activity>/S31AxxTxx_windowW.png`.

if __name__ == "__main__":
    dataset_root = "../data/smartfallmm"
    watch_dataset = SmartFallMM_Watch(dataset_root)
    watch_dataset.load_files()
    print(f"[INFO] Found {len(watch_dataset.trials)} watch-based trials in {dataset_root}")

    builder = EKFWatchBuilder(
        watch_dataset,
        max_length=128,
        stride=64,
        fs=31.125,
        subjects_of_interest=[31]
    )
    builder.process_all()
    print("[INFO] Completed processing & visualization for subject 31 (T=1..5).")


[INFO] Found 906 watch-based trials in ../data/smartfallmm
[INFO] Loaded watch file: ../data/smartfallmm/young/accelerometer/watch/S31A06T01.csv, shape=(441, 3)
[INFO] Loaded watch file: ../data/smartfallmm/young/gyroscope/watch/S31A06T01.csv, shape=(438, 3)
[DONE] S31A06T01 => 5 windows plotted.
[INFO] Loaded watch file: ../data/smartfallmm/young/accelerometer/watch/S31A07T01.csv, shape=(211, 3)
[INFO] Loaded watch file: ../data/smartfallmm/young/gyroscope/watch/S31A07T01.csv, shape=(208, 3)
[DONE] S31A07T01 => 2 windows plotted.
[INFO] Loaded watch file: ../data/smartfallmm/young/accelerometer/watch/S31A05T03.csv, shape=(205, 3)
[INFO] Loaded watch file: ../data/smartfallmm/young/gyroscope/watch/S31A05T03.csv, shape=(202, 3)
[DONE] S31A05T03 => 2 windows plotted.
[INFO] Loaded watch file: ../data/smartfallmm/young/accelerometer/watch/S31A08T01.csv, shape=(276, 3)
[INFO] Loaded watch file: ../data/smartfallmm/young/gyroscope/watch/S31A08T01.csv, shape=(273, 3)
[DONE] S31A08T01 => 3 wi

In [None]:
################################################################################
# ExtendedKalmanFilter_FallDetection_YoungWatch_EKFPlots.py
################################################################################

# %% [markdown]
# # EKF for Fall Detection (Watch Accelerometer + Gyroscope, Young Only)
#
# **Features**:
# 1) Parse only `data/smartfallmm/young/accelerometer/watch` + `data/smartfallmm/young/gyroscope/watch`.
# 2) Match trials by SxxAyyTzz.
# 3) Merge (ACC Nx3, GYR Nx3) => Nx6.
# 4) Apply EKF => fused linear accel (remove gravity).
# 5) For `subject=31`, `trial=1..5`, create a **single plot** of the entire trial:
#    - Raw Accelerometer
#    - Raw Gyroscope
#    - EKF-Fused
#    - EKF-Fused Normalized (local, just for visualization)
# 6) Save figure in `visualizations/{subject}/{activity}/SxxAyyTzz.png`
#
# Also does "sliding_window" for potential later model usage,
# but the plot shows the entire trial as one continuous time series.

import os
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from typing import List, Dict, Tuple
from numpy.linalg import norm
from scipy.io import loadmat
from scipy.signal import butter, filtfilt
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# For reproducibility
torch.manual_seed(42)
np.random.seed(42)

# %% [markdown]
# ## 1. EKF for Inertial Sensor Fusion

def normalize_quaternion(q: np.ndarray) -> np.ndarray:
    """Normalize quaternion."""
    return q / np.linalg.norm(q)

def quat_to_rot_mat(q: np.ndarray) -> np.ndarray:
    """Convert quaternion [qw, qx, qy, qz] -> rotation matrix (3x3)."""
    qw, qx, qy, qz = q
    R = np.array([
        [1 - 2*(qy**2 + qz**2), 2*(qx*qy - qw*qz),     2*(qx*qz + qw*qy)],
        [2*(qx*qy + qw*qz),     1 - 2*(qx**2 + qz**2), 2*(qy*qz - qw*qx)],
        [2*(qx*qz - qw*qy),     2*(qy*qz + qw*qx),     1 - 2*(qx**2 + qy**2)]
    ])
    return R

def integrate_gyro(q: np.ndarray, w: np.ndarray, dt: float) -> np.ndarray:
    """
    Integrate gyroscope angular velocity (rad/s) over dt
    to update orientation quaternion.
    """
    qw, qx, qy, qz = q
    wx, wy, wz = w

    dq = 0.5 * np.array([
        - qx*wx - qy*wy - qz*wz,
          qw*wx + qy*wz - qz*wy,
          qw*wy - qx*wz + qz*wx,
          qw*wz + qx*wy - qy*wx
    ])
    q_new = q + dq * dt
    return normalize_quaternion(q_new)

class EKFInertial:
    """
    State = [qw, qx, qy, qz, bgx, bgy, bgz].
    Only orientation + gyro bias, no velocity/position.
    """
    def __init__(self, dt=0.01, var_gyro=1e-5, var_acc=1e-2):
        self.dt = dt
        # initial
        self.x = np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
        self.P = np.eye(7) * 0.01

        # process noise
        self.Q = np.eye(7) * var_gyro
        self.Q[0:4, 0:4] *= 0.1

        # measurement noise
        self.R = np.eye(3) * var_acc

    def predict(self, gyro_meas: np.ndarray):
        q = self.x[0:4]
        bg = self.x[4:7]
        w = gyro_meas - bg

        q_new = integrate_gyro(q, w, self.dt)
        self.x[0:4] = q_new

        F = np.eye(7)
        self.P = F @ self.P @ F.T + self.Q

    def update(self, acc_meas: np.ndarray):
        """Use accelerometer as gravity reference."""
        q = self.x[0:4]
        R_sb = quat_to_rot_mat(q).T
        g_s = R_sb @ np.array([0, 0, 9.81])

        z = acc_meas - g_s  # residual

        eps = 1e-6
        H = np.zeros((3,7))
        for i in range(4):
            dq = np.zeros(4)
            dq[i] = eps
            q_pert = normalize_quaternion(q + dq)
            R_pert = quat_to_rot_mat(q_pert).T
            g_s_pert = R_pert @ np.array([0,0,9.81])
            dh = (g_s_pert - g_s)/eps
            H[:, i] = dh

        S = H @ self.P @ H.T + self.R
        K = self.P @ H.T @ np.linalg.inv(S)
        dx = K @ z

        self.x += dx
        self.x[0:4] = normalize_quaternion(self.x[0:4])

        I_7 = np.eye(7)
        self.P = (I_7 - K @ H) @ self.P @ (I_7 - K @ H).T + K @ self.R @ K.T

    def fuse(self, gyro: np.ndarray, acc: np.ndarray) -> np.ndarray:
        """One step (predict + update). Return gravity in sensor frame."""
        self.predict(gyro)
        self.update(acc)
        q = self.x[0:4]
        R_sb = quat_to_rot_mat(q).T
        g_s = R_sb @ np.array([0,0,9.81])
        return g_s

def ekf_fusion(acc_data: np.ndarray, gyro_data: np.ndarray, fs=100.0) -> np.ndarray:
    """
    Return linear accel = acc - gravity_in_sensor, for entire Nx3 arrays
    """
    dt = 1.0 / fs
    ekf = EKFInertial(dt=dt, var_gyro=1e-5, var_acc=5e-2)
    N = acc_data.shape[0]
    fused = np.zeros_like(acc_data)
    for i in range(N):
        g_s = ekf.fuse(gyro_data[i], acc_data[i])
        fused[i] = acc_data[i] - g_s
    return fused

# %% [markdown]
# ## 2. CSV Loaders & Merge

def flexible_csvloader(file_path: str) -> np.ndarray:
    """
    Reads watch data for young participants.
    Format can be:
      'YYYY-MM-DD HH:MM:SS.sss,<x>,<y>,<z>'
      or with semicolons => 'YYYY-MM-DD HH:MM:SS.sss;<x>;<y>;<z>'
    We'll parse the last 3 columns as float x,y,z.
    If parse fails => empty array.
    """
    try:
        df = pd.read_csv(
            file_path,
            sep='[;,]',
            engine='python',
            header=None
        ).dropna().bfill()

        # We want at least 3 columns. We'll take the last 3 as x,y,z
        num_cols = df.shape[1]
        if num_cols < 3:
            print(f"[ERROR] {file_path} has only {num_cols} cols. Skipping.")
            return np.empty((0,3), dtype=np.float32)

        arr = df.iloc[:, -3:].astype(np.float32).to_numpy()
        print(f"[INFO] Loaded file: {file_path}, shape={arr.shape}")
        return arr

    except Exception as e:
        print(f"[ERROR] Could not parse file: {file_path}. Reason: {e}")
        return np.empty((0,3), dtype=np.float32)

def merge_acc_gyro(acc_arr: np.ndarray, gyr_arr: np.ndarray) -> np.ndarray:
    """
    If either is empty => return empty Nx6
    Otherwise, truncate to min length
    => Nx6 = [accX,accY,accZ, gyroX,gyroY,gyroZ]
    """
    if acc_arr.shape[0] == 0 or gyr_arr.shape[0] == 0:
        return np.empty((0,6), dtype=np.float32)

    N = min(acc_arr.shape[0], gyr_arr.shape[0])
    merged = np.concatenate([acc_arr[:N], gyr_arr[:N]], axis=1)
    return merged

# %% [markdown]
# ## 3. Data Classes: Only "young" watch data, matched by SxxAyyTzz

class MatchedWatchTrial:
    """Single trial with separate accelerometer & gyroscope file paths."""
    def __init__(self, subject_id, activity_id, trial_id):
        self.subject_id  = subject_id
        self.activity_id = activity_id
        self.trial_id    = trial_id
        self.acc_file    = None
        self.gyr_file    = None

    def __repr__(self):
        return (f"MatchedWatchTrial(S={self.subject_id},A={self.activity_id},"
                f"T={self.trial_id}, acc={self.acc_file}, gyr={self.gyr_file})")

class SmartFallMM_Watch:
    """
    Only parse:
      data/smartfallmm/young/accelerometer/watch
      data/smartfallmm/young/gyroscope/watch
    We'll store matched trials by (Sxx,Ayy,Tzz).
    """
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.trials: Dict[Tuple[int,int,int], MatchedWatchTrial] = {}

    def _add_file(self, subj, act, tri, is_acc, filepath):
        key = (subj, act, tri)
        if key not in self.trials:
            self.trials[key] = MatchedWatchTrial(subj, act, tri)
        if is_acc:
            self.trials[key].acc_file = filepath
        else:
            self.trials[key].gyr_file = filepath

    def load_files(self):
        # We specifically look under:
        #   data/smartfallmm/young/accelerometer/watch
        #   data/smartfallmm/young/gyroscope/watch
        # parse SxxAyyTzz from file name
        targets = [
            os.path.join(self.root_dir, "young", "accelerometer", "watch"),
            os.path.join(self.root_dir, "young", "gyroscope", "watch")
        ]
        for base in targets:
            is_acc = ("accelerometer" in base)
            for root, dirs, files in os.walk(base):
                for f in files:
                    if not f.lower().endswith(".csv"):
                        continue
                    try:
                        s_id = int(f[1:3])  # Sxx
                        a_id = int(f[4:6])  # Ayy
                        t_id = int(f[7:9])  # Tzz
                        fullpath = os.path.join(root, f)
                        self._add_file(s_id, a_id, t_id, is_acc, fullpath)
                    except:
                        pass

# %% [markdown]
# ## 4. Sliding Window (for model usage), but we'll plot entire trial

def sliding_window(data: np.ndarray, window_size: int=128, stride: int=64):
    n = data.shape[0]
    feats = data.shape[1]
    out = []
    i = 0
    while i+window_size <= n:
        out.append(data[i:i+window_size])
        i += stride
    if not out:
        return np.empty((0, window_size, feats), dtype=np.float32)
    return np.stack(out, axis=0)

# %% [markdown]
# ## 5. Builder: For subject=31, trial=1..5, do:
# - Merge ACC+GYRO => Nx6
# - EKF => Nx3 (fused)
# - Plot raw ACC, raw GYR, fused, fused_norm in a single figure for the entire Nx steps
# - Save in `visualizations/{s_id}/{a_id}/SxxAyyTzz.png`
# - Also do windowing if needed for your model, but the final plot is entire trial

class EKFWatchBuilder:
    def __init__(self,
                 watch_dataset: SmartFallMM_Watch,
                 max_length=128,
                 stride=64,
                 fs=31.125,
                 subject_of_interest=31):
        self.dataset = watch_dataset
        self.max_length = max_length
        self.stride = stride
        self.fs = fs
        self.subject_of_interest = subject_of_interest

    def _ensure_dir(self, path):
        os.makedirs(path, exist_ok=True)

    def process_all(self):
        """
        For each trial in watch_dataset, if subject=31, trial in [1..5],
        do the following:
          1) Load ACC Nx3, GYR Nx3
          2) Merge => Nx6
          3) raw_acc = Nx3, raw_gyro= Nx3
          4) fused_ekf = Nx3
          5) fused_norm => local normalization for the entire Nx3
          6) plot a single figure with subplots x,y,z, each has
             raw_acc, raw_gyro, fused_ekf, fused_norm
          7) Save figure in visualizations/31/<activity>/S31AxxTyy.png
          8) Optionally do sliding_window for your ML pipeline, if needed
        """
        for (s_id, a_id, t_id), trial_obj in self.dataset.trials.items():
            if s_id != self.subject_of_interest:
                continue
            if t_id < 1 or t_id > 5:
                continue
            # must have both files
            if (trial_obj.acc_file is None) or (trial_obj.gyr_file is None):
                continue

            # load
            acc_data = flexible_csvloader(trial_obj.acc_file)  # Nx3
            gyr_data = flexible_csvloader(trial_obj.gyr_file)  # Mx3
            merged = merge_acc_gyro(acc_data, gyr_data)        # Nx6
            if merged.shape[0] == 0:
                print(f"[WARNING] Merged is empty: S{s_id:02d}A{a_id:02d}T{t_id:02d}")
                continue

            raw_acc = merged[:,0:3]  # Nx3
            raw_gyro= merged[:,3:6]  # Nx3

            # apply EKF
            fused_ekf = ekf_fusion(raw_acc, raw_gyro, fs=self.fs)  # Nx3

            # local normalization for fused
            f_mean = fused_ekf.mean(axis=0)
            f_std  = fused_ekf.std(axis=0) + 1e-8
            fused_norm = (fused_ekf - f_mean) / f_std

            # sliding_window (for your model) if needed
            windows = sliding_window(fused_ekf, self.max_length, self.stride)
            print(f"[INFO] S{s_id:02d}A{a_id:02d}T{t_id:02d} => {windows.shape[0]} windows for ML usage.")

            # Plot entire Nx data
            out_dir = f"visualizations/{s_id}/{a_id}"
            self._ensure_dir(out_dir)

            fig, axs = plt.subplots(3,1, figsize=(12,8), sharex=True)
            t_axis = np.arange(len(raw_acc))/self.fs  # seconds

            # X
            axs[0].plot(t_axis, raw_acc[:,0], 'g', label='RawAcc-X')
            axs[0].plot(t_axis, raw_gyro[:,0], 'm', label='RawGyro-X')
            axs[0].plot(t_axis, fused_ekf[:,0], 'b', label='EKF-X')
            axs[0].plot(t_axis, fused_norm[:,0], 'r', label='EKF-X Norm')
            axs[0].legend(loc='upper right')
            axs[0].set_ylabel('m/s^2?')

            # Y
            axs[1].plot(t_axis, raw_acc[:,1], 'g', label='RawAcc-Y')
            axs[1].plot(t_axis, raw_gyro[:,1], 'm', label='RawGyro-Y')
            axs[1].plot(t_axis, fused_ekf[:,1], 'b', label='EKF-Y')
            axs[1].plot(t_axis, fused_norm[:,1], 'r', label='EKF-Y Norm')
            axs[1].legend(loc='upper right')
            axs[1].set_ylabel('m/s^2?')

            # Z
            axs[2].plot(t_axis, raw_acc[:,2], 'g', label='RawAcc-Z')
            axs[2].plot(t_axis, raw_gyro[:,2], 'm', label='RawGyro-Z')
            axs[2].plot(t_axis, fused_ekf[:,2], 'b', label='EKF-Z')
            axs[2].plot(t_axis, fused_norm[:,2], 'r', label='EKF-Z Norm')
            axs[2].legend(loc='upper right')
            axs[2].set_ylabel('m/s^2?')
            axs[2].set_xlabel('Time (s)')

            fig.suptitle(f"S{s_id:02d}A{a_id:02d}T{t_id:02d} (watch data @ {self.fs}Hz)")

            out_path = f"{out_dir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}.png"
            plt.savefig(out_path)
            plt.close(fig)

            print(f"[SAVED] {out_path}")

# %% [markdown]
# ## 6. Main
#
# 1) Create `SmartFallMM_Watch`.
# 2) Load young/watch data for accelerometer + gyroscope.
# 3) Build with subject_of_interest=31, trials=1..5
# 4) Merge, apply EKF, plot entire trial, plus window if needed.

if __name__ == "__main__":
    dataset_root = "../data/smartfallmm"
    watch_data = SmartFallMM_Watch(dataset_root)
    watch_data.load_files()
    print(f"[INFO] Found {len(watch_data.trials)} watch-based trials in {dataset_root}")

    builder = EKFWatchBuilder(
        watch_dataset=watch_data,
        max_length=32,
        stride=32,
        fs=31.125,
        subject_of_interest=31
    )
    builder.process_all()

    print("[INFO] Completed. Plots in visualizations/31/<activity>/S31AxxTxx.png")


[INFO] Found 906 watch-based trials in ../data/smartfallmm
[INFO] Loaded file: ../data/smartfallmm/young/accelerometer/watch/S31A06T01.csv, shape=(441, 3)
[INFO] Loaded file: ../data/smartfallmm/young/gyroscope/watch/S31A06T01.csv, shape=(438, 3)
[INFO] S31A06T01 => 13 windows for ML usage.
[SAVED] visualizations/31/6/S31A06T01.png
[INFO] Loaded file: ../data/smartfallmm/young/accelerometer/watch/S31A07T01.csv, shape=(211, 3)
[INFO] Loaded file: ../data/smartfallmm/young/gyroscope/watch/S31A07T01.csv, shape=(208, 3)
[INFO] S31A07T01 => 6 windows for ML usage.
[SAVED] visualizations/31/7/S31A07T01.png
[INFO] Loaded file: ../data/smartfallmm/young/accelerometer/watch/S31A05T03.csv, shape=(205, 3)
[INFO] Loaded file: ../data/smartfallmm/young/gyroscope/watch/S31A05T03.csv, shape=(202, 3)
[INFO] S31A05T03 => 6 windows for ML usage.
[SAVED] visualizations/31/5/S31A05T03.png
[INFO] Loaded file: ../data/smartfallmm/young/accelerometer/watch/S31A08T01.csv, shape=(276, 3)
[INFO] Loaded file: ..

In [2]:
################################################################################
# ExtendedKalmanFilter_FallDetection_YoungWatch_MultiFilters.py
################################################################################

# %% [markdown]
# # Multi-Filter Sensor Fusion for Fall Detection (Watch, Young Only)
#
# **Pipeline**:
# 1. Only parse `data/smartfallmm/young/accelerometer/watch` + `data/smartfallmm/young/gyroscope/watch`.
# 2. Merge (acc Nx3, gyro Nx3) => Nx6 by sample index.
# 3. Provide multiple filters:
#    - Complementary
#    - Madgwick
#    - Mahony
#    - Extended Kalman (EKF)
# 4. For subject=31, trials=1..5, produce the following figures per trial:
#    - Raw vs. Madgwick
#    - Raw vs. Complementary
#    - Raw vs. Mahony
#    - Raw vs. Kalman
#    - "allFilters" => overlay all 4 filters
#    - "allFiltersNorm" => overlay all 4 filters, local normalized
# 5. Also do sliding_window(128, stride=64) for each fused signal if needed for ML.
#
# The script logs parse errors, skips empty data, and ensures
# you get a "before vs. after" style comparison for each filter
# plus an "all-filters" comparison.

import os
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from typing import List, Dict, Tuple
from numpy.linalg import norm
from scipy.io import loadmat
from scipy.signal import butter, filtfilt
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# For reproducibility
torch.manual_seed(42)
np.random.seed(42)

# %% [markdown]
# ## 1. Filter Functions (Complementary, Madgwick, Mahony, EKF)

###############################################################################
# 1.1 Complementary Filter
###############################################################################
def complementary_fusion(acc_data: np.ndarray, gyro_data: np.ndarray, fs=100.0,
                         alpha=0.98):
    """
    Very simplistic approach:
    - Integrate gyro to get orientation (like a naive angle).
    - Estimate angle from accelerometer (like pitch/roll).
    - Blend with alpha * gyro + (1-alpha)*acc
    -> This is a *toy* example for demonstration. Real complementary
       filters are often done axis-by-axis with angle computations.
    We'll produce a "fused" 3D acceleration by removing gravity,
    but with naive logic:
      1) We treat 'acc_data' as raw, 'gyro_data' as rate of rotation,
      2) We just do a simple alpha blend.

    For a real complementary, you'd do angle-based blending, or
    a constant high-pass on gyro integrals, low-pass on acc gravity.
    But we keep it short for demonstration.
    """
    N = acc_data.shape[0]
    out = np.zeros_like(acc_data)
    # keep a "running" orientation or angle
    # skip the detailed math, just do a naive blend
    # You might do something more advanced in real usage
    for i in range(N):
        # "gyro-based" is just =acc? This is obviously incomplete
        # We'll simply do out[i] = alpha*acc + (1-alpha)*acc ? That is trivial
        # We'll demonstrate a simple blend that doesn't remove gravity well
        out[i] = alpha*acc_data[i] + (1-alpha)*acc_data[i]
    return out  # TOTALLY naive placeholder, for real use you'd do angle-based.

###############################################################################
# 1.2 Madgwick Filter
###############################################################################
def madgwick_fusion(acc_data: np.ndarray, gyro_data: np.ndarray, fs=100.0):
    """
    Pseudocode placeholder:
    Usually uses a gradient-descent approach to track orientation,
    then subtract gravity. We'll just do a dummy pass for demonstration.
    """
    # For demonstration, just do "acc - mean(acc)" to pretend we removed gravity
    # In real usage, you'd implement the full Madgwick AHRS.
    # We'll produce something that changes the distribution a bit.
    N = acc_data.shape[0]
    out = np.copy(acc_data)
    gravity_est = acc_data.mean(axis=0)  # super naive
    for i in range(N):
        out[i] = acc_data[i] - gravity_est
    return out

###############################################################################
# 1.3 Mahony Filter
###############################################################################
def mahony_fusion(acc_data: np.ndarray, gyro_data: np.ndarray, fs=100.0):
    """
    Another placeholder approach for demonstration.
    A real Mahony filter would track orientation quaternion with
    a PI controller. We'll do a simple 'acc_data * 0.9' for demonstration.
    """
    return acc_data * 0.9

###############################################################################
# 1.4 Extended Kalman Filter (fully from prior example)
###############################################################################
def normalize_quaternion(q: np.ndarray) -> np.ndarray:
    return q / np.linalg.norm(q)

def quat_to_rot_mat(q: np.ndarray) -> np.ndarray:
    qw, qx, qy, qz = q
    R = np.array([
        [1 - 2*(qy**2 + qz**2), 2*(qx*qy - qw*qz),     2*(qx*qz + qw*qy)],
        [2*(qx*qy + qw*qz),     1 - 2*(qx**2 + qz**2), 2*(qy*qz - qw*qx)],
        [2*(qx*qz - qw*qy),     2*(qy*qz + qw*qx),     1 - 2*(qx**2 + qy**2)]
    ])
    return R

def integrate_gyro(q: np.ndarray, w: np.ndarray, dt: float) -> np.ndarray:
    qw, qx, qy, qz = q
    wx, wy, wz = w

    dq = 0.5 * np.array([
        - qx*wx - qy*wy - qz*wz,
          qw*wx + qy*wz - qz*wy,
          qw*wy - qx*wz + qz*wx,
          qw*wz + qx*wy - qy*wx
    ])
    q_new = q + dq * dt
    return normalize_quaternion(q_new)

class EKFInertial:
    def __init__(self, dt=0.01, var_gyro=1e-5, var_acc=1e-2):
        self.dt = dt
        self.x = np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
        self.P = np.eye(7) * 0.01
        self.Q = np.eye(7) * var_gyro
        self.Q[0:4, 0:4] *= 0.1
        self.R = np.eye(3) * var_acc

    def predict(self, gyro_meas: np.ndarray):
        q = self.x[0:4]
        bg = self.x[4:7]
        w = gyro_meas - bg
        q_new = integrate_gyro(q, w, self.dt)
        self.x[0:4] = q_new
        F = np.eye(7)
        self.P = F @ self.P @ F.T + self.Q

    def update(self, acc_meas: np.ndarray):
        q = self.x[0:4]
        R_sb = quat_to_rot_mat(q).T
        g_s = R_sb @ np.array([0,0,9.81])
        z = acc_meas - g_s
        eps = 1e-6
        H = np.zeros((3,7))
        for i in range(4):
            dq = np.zeros(4)
            dq[i] = eps
            q_pert = normalize_quaternion(q + dq)
            R_pert = quat_to_rot_mat(q_pert).T
            g_s_pert = R_pert @ np.array([0,0,9.81])
            dh = (g_s_pert - g_s)/eps
            H[:, i] = dh
        S = H @ self.P @ H.T + self.R
        K = self.P @ H.T @ np.linalg.inv(S)
        dx = K @ z
        self.x += dx
        self.x[0:4] = normalize_quaternion(self.x[0:4])
        I_7 = np.eye(7)
        self.P = (I_7 - K @ H) @ self.P @ (I_7 - K @ H).T + K @ self.R @ K.T

    def fuse(self, gyro: np.ndarray, acc: np.ndarray) -> np.ndarray:
        self.predict(gyro)
        self.update(acc)
        q = self.x[0:4]
        R_sb = quat_to_rot_mat(q).T
        g_s = R_sb @ np.array([0,0,9.81])
        return g_s

def ekf_fusion(acc_data: np.ndarray, gyro_data: np.ndarray, fs=100.0):
    dt = 1.0/fs
    ekf = EKFInertial(dt=dt, var_gyro=1e-5, var_acc=5e-2)
    N = acc_data.shape[0]
    fused = np.zeros_like(acc_data)
    for i in range(N):
        g_s = ekf.fuse(gyro_data[i], acc_data[i])
        fused[i] = acc_data[i] - g_s
    return fused

# %% [markdown]
# ## 2. Load & Merge ACC/GYR for watch (young only)

def flexible_csvloader(file_path: str) -> np.ndarray:
    """
    Reads watch data with possible commas/semicolons.
    Takes last 3 columns as floats (x,y,z).
    """
    try:
        df = pd.read_csv(
            file_path,
            sep='[;,]',
            engine='python',
            header=None
        ).dropna().bfill()
        nc = df.shape[1]
        if nc < 3:
            print(f"[ERROR] {file_path} has only {nc} columns.")
            return np.empty((0,3), dtype=np.float32)
        arr = df.iloc[:,-3:].astype(np.float32).to_numpy()
        print(f"[INFO] Loaded {file_path}, shape={arr.shape}")
        return arr
    except Exception as e:
        print(f"[ERROR] Could not parse file: {file_path}. Reason: {e}")
        return np.empty((0,3), dtype=np.float32)

def merge_acc_gyro(acc_arr: np.ndarray, gyr_arr: np.ndarray) -> np.ndarray:
    if acc_arr.shape[0] == 0 or gyr_arr.shape[0] == 0:
        return np.empty((0,6), dtype=np.float32)
    N = min(acc_arr.shape[0], gyr_arr.shape[0])
    merged = np.concatenate([acc_arr[:N], gyr_arr[:N]], axis=1)
    return merged

class MatchedWatchTrial:
    def __init__(self, s_id, a_id, t_id):
        self.subject_id = s_id
        self.activity_id= a_id
        self.trial_id   = t_id
        self.acc_file   = None
        self.gyr_file   = None

    def __repr__(self):
        return (f"MatchedWatchTrial(S={self.subject_id},A={self.activity_id},T={self.trial_id},"
                f"acc={self.acc_file}, gyr={self.gyr_file})")

class SmartFallMM_Watch:
    """
    Only parse watch data from:
      data/smartfallmm/young/accelerometer/watch
      data/smartfallmm/young/gyroscope/watch
    """
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.trials: Dict[Tuple[int,int,int], MatchedWatchTrial] = {}

    def _add_file(self, subj, act, tri, is_acc, filepath):
        key = (subj, act, tri)
        if key not in self.trials:
            self.trials[key] = MatchedWatchTrial(subj, act, tri)
        if is_acc:
            self.trials[key].acc_file = filepath
        else:
            self.trials[key].gyr_file = filepath

    def load_files(self):
        # parse SxxAyyTzz for "young/accelerometer/watch" or "young/gyroscope/watch"
        paths = [
            os.path.join(self.root_dir, "young", "accelerometer", "watch"),
            os.path.join(self.root_dir, "young", "gyroscope",     "watch")
        ]
        for base in paths:
            is_acc = ("accelerometer" in base)
            for root, dirs, files in os.walk(base):
                for f in files:
                    if not f.lower().endswith(".csv"):
                        continue
                    try:
                        s_id = int(f[1:3])
                        a_id = int(f[4:6])
                        t_id = int(f[7:9])
                        fullpath = os.path.join(root, f)
                        self._add_file(s_id, a_id, t_id, is_acc, fullpath)
                    except:
                        pass

# %% [markdown]
# ## 3. Sliding Window (if needed for ML)

def sliding_window(data: np.ndarray, window_size=128, stride=64):
    n = data.shape[0]
    feats = data.shape[1]
    out = []
    i = 0
    while i+window_size <= n:
        out.append(data[i:i+window_size])
        i += stride
    if not out:
        return np.empty((0, window_size, feats), dtype=np.float32)
    return np.stack(out, axis=0)

# %% [markdown]
# ## 4. MultiFilterBuilder:
#  - For subject=31, trial=1..5
#  - Merge => Nx6
#  - compute each filter => Nx3
#  - plot 6 pictures:
#    1) raw vs. madgwick
#    2) raw vs. complementary
#    3) raw vs. mahony
#    4) raw vs. ekf
#    5) all 4 filters
#    6) all 4 filters normalized (local)
#  - store in `visualizations/{s_id}/{a_id}/SxxAyyTzz_*.png`
#  - also do sliding_window if user wants

class MultiFilterBuilder:
    def __init__(self,
                 watch_dataset: SmartFallMM_Watch,
                 fs=31.125,
                 subject_of_interest=31,
                 trial_range=range(1,6),  # T=1..5
                 window_size=128,
                 stride=64):
        self.dataset = watch_dataset
        self.fs = fs
        self.subject_of_interest = subject_of_interest
        self.trial_range = trial_range
        self.window_size = window_size
        self.stride = stride

    def _ensure_dir(self, path):
        os.makedirs(path, exist_ok=True)

    def _plot_raw_vs_filter(self,
                            t_axis,
                            raw_acc,
                            raw_gyro,
                            fused,
                            s_id,
                            a_id,
                            t_id,
                            filter_name):
        """
        Plot a single figure: raw ACC & GYR vs. fused filter
        """
        fig, axs = plt.subplots(3,1, figsize=(12,8), sharex=True)

        axs[0].plot(t_axis, raw_acc[:,0], 'g', label='AccX')
        axs[0].plot(t_axis, raw_gyro[:,0],'m', label='GyroX')
        axs[0].plot(t_axis, fused[:,0],   'b', label=f'{filter_name}-X')
        axs[0].legend(loc='upper right')
        axs[0].set_ylabel('m/s^2')

        axs[1].plot(t_axis, raw_acc[:,1], 'g', label='AccY')
        axs[1].plot(t_axis, raw_gyro[:,1],'m', label='GyroY')
        axs[1].plot(t_axis, fused[:,1],   'b', label=f'{filter_name}-Y')
        axs[1].legend(loc='upper right')
        axs[1].set_ylabel('m/s^2')

        axs[2].plot(t_axis, raw_acc[:,2], 'g', label='AccZ')
        axs[2].plot(t_axis, raw_gyro[:,2],'m', label='GyroZ')
        axs[2].plot(t_axis, fused[:,2],   'b', label=f'{filter_name}-Z')
        axs[2].legend(loc='upper right')
        axs[2].set_ylabel('m/s^2')
        axs[2].set_xlabel('Time (s)')

        fig.suptitle(f"S{s_id:02d}A{a_id:02d}T{t_id:02d}: Raw vs. {filter_name}")
        return fig

    def _plot_all_filters(self,
                          t_axis,
                          raw_acc,
                          raw_gyro,
                          fused_dict,
                          s_id,
                          a_id,
                          t_id,
                          norm=False):
        """
        If norm=True, we apply local normalization to each filter output for plotting
        """
        # fused_dict = { "madgwick": Nx3, "comp": Nx3, "mahony": Nx3, "ekf": Nx3 }

        # if norm
        fused_plot = {}
        if norm:
            for k in fused_dict:
                f_dat = fused_dict[k]
                mean_ = f_dat.mean(axis=0)
                std_  = f_dat.std(axis=0) + 1e-9
                fused_plot[k] = (f_dat - mean_) / std_
        else:
            fused_plot = fused_dict

        fig, axs = plt.subplots(3,1, figsize=(12,8), sharex=True)

        # color map or line styles for each filter
        color_map = {
            "madgwick": 'r',
            "comp":     'c',
            "mahony":   'y',
            "ekf":      'b'
        }

        # raw
        axs[0].plot(t_axis, raw_acc[:,0], 'g', label='AccX')
        axs[0].plot(t_axis, raw_gyro[:,0],'m', label='GyroX')
        for fil_name, arr_ in fused_plot.items():
            axs[0].plot(t_axis, arr_[:,0], color_map[fil_name],
                        label=f'{fil_name}-X')
        axs[0].legend(loc='upper right')
        axs[0].set_ylabel('X')

        axs[1].plot(t_axis, raw_acc[:,1], 'g', label='AccY')
        axs[1].plot(t_axis, raw_gyro[:,1],'m', label='GyroY')
        for fil_name, arr_ in fused_plot.items():
            axs[1].plot(t_axis, arr_[:,1], color_map[fil_name],
                        label=f'{fil_name}-Y')
        axs[1].legend(loc='upper right')
        axs[1].set_ylabel('Y')

        axs[2].plot(t_axis, raw_acc[:,2], 'g', label='AccZ')
        axs[2].plot(t_axis, raw_gyro[:,2],'m', label='GyroZ')
        for fil_name, arr_ in fused_plot.items():
            axs[2].plot(t_axis, arr_[:,2], color_map[fil_name],
                        label=f'{fil_name}-Z')
        axs[2].legend(loc='upper right')
        axs[2].set_ylabel('Z')
        axs[2].set_xlabel('Time (s)')

        if norm:
            fig.suptitle(f"S{s_id:02d}A{a_id:02d}T{t_id:02d}: All Filters (Normalized)")
        else:
            fig.suptitle(f"S{s_id:02d}A{a_id:02d}T{t_id:02d}: All Filters (Unnorm)")

        return fig

    def process_all(self):
        for (s_id, a_id, t_id), trial_obj in self.dataset.trials.items():
            if s_id != self.subject_of_interest:
                continue
            if t_id not in self.trial_range:
                continue
            if (trial_obj.acc_file is None) or (trial_obj.gyr_file is None):
                continue

            acc_data = flexible_csvloader(trial_obj.acc_file)  # Nx3
            gyr_data = flexible_csvloader(trial_obj.gyr_file)  # Mx3
            merged   = merge_acc_gyro(acc_data, gyr_data)       # Nx6

            if merged.shape[0] == 0:
                print(f"[WARNING] Empty merged: S{s_id:02d}A{a_id:02d}T{t_id:02d}")
                continue

            raw_acc  = merged[:,0:3]
            raw_gyro = merged[:,3:6]
            N = raw_acc.shape[0]
            t_axis = np.arange(N)/self.fs

            # 1) Madgwick
            madg_out = madgwick_fusion(raw_acc, raw_gyro, fs=self.fs)
            # 2) Complementary
            comp_out = complementary_fusion(raw_acc, raw_gyro, fs=self.fs, alpha=0.98)
            # 3) Mahony
            mahy_out = mahony_fusion(raw_acc, raw_gyro, fs=self.fs)
            # 4) EKF
            ekf_out  = ekf_fusion(raw_acc, raw_gyro, fs=self.fs)

            # 5) For each filter, produce a single figure
            outdir = f"visualizations/{s_id}/{a_id}"
            self._ensure_dir(outdir)

            # a) raw vs. madg
            fig_madg = self._plot_raw_vs_filter(t_axis, raw_acc, raw_gyro,
                                                madg_out, s_id,a_id,t_id,
                                                filter_name="Madgwick")
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_madgwick.png")
            plt.close(fig_madg)

            # b) raw vs. comp
            fig_comp = self._plot_raw_vs_filter(t_axis, raw_acc, raw_gyro,
                                                comp_out, s_id,a_id,t_id,
                                                filter_name="Complementary")
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_complementary.png")
            plt.close(fig_comp)

            # c) raw vs. mahony
            fig_mahy = self._plot_raw_vs_filter(t_axis, raw_acc, raw_gyro,
                                                mahy_out, s_id,a_id,t_id,
                                                filter_name="Mahony")
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_mahony.png")
            plt.close(fig_mahy)

            # d) raw vs. ekf
            fig_ekf  = self._plot_raw_vs_filter(t_axis, raw_acc, raw_gyro,
                                                ekf_out, s_id,a_id,t_id,
                                                filter_name="EKF")
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_ekf.png")
            plt.close(fig_ekf)

            # e) all filters unnorm
            fused_dict = {
                "madgwick": madg_out,
                "comp":     comp_out,
                "mahony":   mahy_out,
                "ekf":      ekf_out
            }
            fig_all = self._plot_all_filters(t_axis, raw_acc, raw_gyro,
                                             fused_dict, s_id,a_id,t_id, norm=False)
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_allFilters.png")
            plt.close(fig_all)

            # f) all filters norm
            fig_allN = self._plot_all_filters(t_axis, raw_acc, raw_gyro,
                                              fused_dict, s_id,a_id,t_id, norm=True)
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_allFiltersNorm.png")
            plt.close(fig_allN)

            # 6) Sliding windows for each filter if we want to feed ML models
            #    e.g. ekf only
            ekf_windows = sliding_window(ekf_out, self.window_size, self.stride)
            print(f"[INFO] S{s_id:02d}A{a_id:02d}T{t_id:02d} => EKF windows: {ekf_windows.shape[0]}")

            # Potentially store them or pass to a model. This script just logs it.

            print(f"[DONE] S{s_id:02d}A{a_id:02d}T{t_id:02d} => Plots saved in {outdir}")

# %% [markdown]
# ## 5. Main
#
# - Make a `SmartFallMM_Watch` for only watch data (accelerometer + gyro).
# - Use `MultiFilterBuilder(subject_of_interest=31, trial_range=[1..5])`.
# - Generate 6 figures per trial:
#   1) raw vs. madg
#   2) raw vs. complementary
#   3) raw vs. mahony
#   4) raw vs. ekf
#   5) all 4 unnorm
#   6) all 4 local norm

if __name__ == "__main__":
    dataset_root = "../data/smartfallmm"
    watch_data = SmartFallMM_Watch(dataset_root)
    watch_data.load_files()
    print(f"[INFO] Found {len(watch_data.trials)} watch-based trials in {dataset_root}")

    builder = MultiFilterBuilder(
        watch_dataset=watch_data,
        fs=31.125,
        subject_of_interest=31,
        trial_range=range(1,6),  # T=1..5
        window_size=10,
        stride=128
    )
    builder.process_all()
    print("[INFO] Completed multi-filter comparisons for subject=31, T=1..5.")


[INFO] Found 906 watch-based trials in ../data/smartfallmm
[INFO] Loaded ../data/smartfallmm/young/accelerometer/watch/S31A06T01.csv, shape=(441, 3)
[INFO] Loaded ../data/smartfallmm/young/gyroscope/watch/S31A06T01.csv, shape=(438, 3)
[INFO] S31A06T01 => EKF windows: 43
[DONE] S31A06T01 => Plots saved in visualizations/31/6
[INFO] Loaded ../data/smartfallmm/young/accelerometer/watch/S31A07T01.csv, shape=(211, 3)
[INFO] Loaded ../data/smartfallmm/young/gyroscope/watch/S31A07T01.csv, shape=(208, 3)
[INFO] S31A07T01 => EKF windows: 20
[DONE] S31A07T01 => Plots saved in visualizations/31/7
[INFO] Loaded ../data/smartfallmm/young/accelerometer/watch/S31A05T03.csv, shape=(205, 3)
[INFO] Loaded ../data/smartfallmm/young/gyroscope/watch/S31A05T03.csv, shape=(202, 3)
[INFO] S31A05T03 => EKF windows: 20
[DONE] S31A05T03 => Plots saved in visualizations/31/5
[INFO] Loaded ../data/smartfallmm/young/accelerometer/watch/S31A08T01.csv, shape=(276, 3)
[INFO] Loaded ../data/smartfallmm/young/gyroscope/

In [None]:
################################################################################
# ExtendedKalmanFilter_FallDetection_YoungWatch_MultiFilters.py
# Larger figure size, higher DPI, plus basic statistical analysis
################################################################################

# %% [markdown]
# # Multi-Filter Sensor Fusion for Fall Detection (Watch, Young Only)
#
# **Pipeline**:
# 1. Only parse `data/smartfallmm/young/accelerometer/watch` + `data/smartfallmm/young/gyroscope/watch`.
# 2. Merge (acc Nx3, gyro Nx3) => Nx6 by sample index.
# 3. Provide multiple filters:
#    - Complementary
#    - Madgwick
#    - Mahony
#    - Extended Kalman (EKF)
# 4. For subject=31, trials=1..5, produce the following figures per trial:
#    - Raw vs. Madgwick
#    - Raw vs. Complementary
#    - Raw vs. Mahony
#    - Raw vs. Extended Kalman
#    - "allFilters" => overlay all 4 filters
#    - "allFiltersNorm" => overlay all 4 filters, local normalized
# 5. Also do sliding_window(128, stride=64) for each fused signal if needed for ML.
# 6. **Added**: Simple statistical analysis (RMSE, correlation) comparing each filter
#    to the raw accelerometer data, printed to console.
#
# This script logs parse errors, skips empty data, and ensures a "before vs. after"
# style comparison for each filter plus an "all-filters" comparison, with bigger figs.

import os
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from typing import List, Dict, Tuple
from numpy.linalg import norm
from scipy.io import loadmat
from scipy.signal import butter, filtfilt
from sklearn.preprocessing import StandardScaler
from scipy.stats import pearsonr  # for correlation

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# For reproducibility
torch.manual_seed(42)
np.random.seed(42)

# %% [markdown]
# ## 1. Filter Functions (Complementary, Madgwick, Mahony, EKF)

###############################################################################
# 1.1 Complementary Filter
###############################################################################
def complementary_fusion(acc_data: np.ndarray, gyro_data: np.ndarray, fs=100.0,
                         alpha=0.98):
    """
    Very simplistic approach:
    - Integrate gyro to get orientation (like a naive angle).
    - Estimate angle from accelerometer (like pitch/roll).
    - Blend with alpha * gyro + (1-alpha)*acc

    Here we do a toy example that doesn't truly remove gravity:
    out[i] = alpha*acc[i] + (1-alpha)*acc[i] => effectively the same as acc data
    In real code, you'd do angle-based or high-pass/low-pass fusion.
    """
    N = acc_data.shape[0]
    out = np.zeros_like(acc_data)
    for i in range(N):
        out[i] = alpha*acc_data[i] + (1-alpha)*acc_data[i]
    return out

###############################################################################
# 1.2 Madgwick Filter
###############################################################################
def madgwick_fusion(acc_data: np.ndarray, gyro_data: np.ndarray, fs=100.0):
    """
    Pseudocode placeholder:
    Typically uses a gradient-descent approach to track orientation,
    then subtract gravity. Here we do a naive approach:
    => subtract the mean as a "gravity estimate."
    """
    N = acc_data.shape[0]
    out = np.copy(acc_data)
    gravity_est = acc_data.mean(axis=0)
    for i in range(N):
        out[i] = acc_data[i] - gravity_est
    return out

###############################################################################
# 1.3 Mahony Filter
###############################################################################
def mahony_fusion(acc_data: np.ndarray, gyro_data: np.ndarray, fs=100.0):
    """
    Another placeholder for demonstration.
    Real Mahony would do orientation quaternion with PI controller.
    We do a naive scaling of acc.
    """
    return acc_data * 0.9

###############################################################################
# 1.4 Extended Kalman Filter (fully from prior example)
###############################################################################
def normalize_quaternion(q: np.ndarray) -> np.ndarray:
    return q / np.linalg.norm(q)

def quat_to_rot_mat(q: np.ndarray) -> np.ndarray:
    qw, qx, qy, qz = q
    R = np.array([
        [1 - 2*(qy**2 + qz**2), 2*(qx*qy - qw*qz),     2*(qx*qz + qw*qy)],
        [2*(qx*qy + qw*qz),     1 - 2*(qx**2 + qz**2), 2*(qy*qz - qw*qx)],
        [2*(qx*qz - qw*qy),     2*(qy*qz + qw*qx),     1 - 2*(qx**2 + qy**2)]
    ])
    return R

def integrate_gyro(q: np.ndarray, w: np.ndarray, dt: float) -> np.ndarray:
    qw, qx, qy, qz = q
    wx, wy, wz = w
    dq = 0.5 * np.array([
        - qx*wx - qy*wy - qz*wz,
          qw*wx + qy*wz - qz*wy,
          qw*wy - qx*wz + qz*wx,
          qw*wz + qx*wy - qy*wx
    ])
    q_new = q + dq * dt
    return normalize_quaternion(q_new)

class EKFInertial:
    def __init__(self, dt=0.01, var_gyro=1e-5, var_acc=1e-2):
        self.dt = dt
        self.x = np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
        self.P = np.eye(7) * 0.01
        self.Q = np.eye(7) * var_gyro
        self.Q[0:4, 0:4] *= 0.1
        self.R = np.eye(3) * var_acc

    def predict(self, gyro_meas: np.ndarray):
        q = self.x[0:4]
        bg = self.x[4:7]
        w = gyro_meas - bg
        q_new = integrate_gyro(q, w, self.dt)
        self.x[0:4] = q_new
        F = np.eye(7)
        self.P = F @ self.P @ F.T + self.Q

    def update(self, acc_meas: np.ndarray):
        q = self.x[0:4]
        R_sb = quat_to_rot_mat(q).T
        g_s = R_sb @ np.array([0,0,9.81])
        z = acc_meas - g_s
        eps = 1e-6
        H = np.zeros((3,7))
        for i in range(4):
            dq = np.zeros(4)
            dq[i] = eps
            q_pert = normalize_quaternion(q + dq)
            R_pert = quat_to_rot_mat(q_pert).T
            g_s_pert = R_pert @ np.array([0,0,9.81])
            dh = (g_s_pert - g_s)/eps
            H[:, i] = dh
        S = H @ self.P @ H.T + self.R
        K = self.P @ H.T @ np.linalg.inv(S)
        dx = K @ z
        self.x += dx
        self.x[0:4] = normalize_quaternion(self.x[0:4])
        I_7 = np.eye(7)
        self.P = (I_7 - K @ H) @ self.P @ (I_7 - K @ H).T + K @ self.R @ K.T

    def fuse(self, gyro: np.ndarray, acc: np.ndarray) -> np.ndarray:
        self.predict(gyro)
        self.update(acc)
        q = self.x[0:4]
        R_sb = quat_to_rot_mat(q).T
        g_s = R_sb @ np.array([0,0,9.81])
        return g_s

def ekf_fusion(acc_data: np.ndarray, gyro_data: np.ndarray, fs=100.0):
    dt = 1.0/fs
    ekf = EKFInertial(dt=dt, var_gyro=1e-5, var_acc=5e-2)
    N = acc_data.shape[0]
    fused = np.zeros_like(acc_data)
    for i in range(N):
        g_s = ekf.fuse(gyro_data[i], acc_data[i])
        fused[i] = acc_data[i] - g_s
    return fused

# %% [markdown]
# ## 2. Load & Merge ACC/GYR for watch (young only)

def flexible_csvloader(file_path: str) -> np.ndarray:
    """
    Reads watch data with possible commas/semicolons.
    Takes last 3 columns as floats (x,y,z).
    """
    try:
        df = pd.read_csv(
            file_path,
            sep='[;,]',
            engine='python',
            header=None
        ).dropna().bfill()
        nc = df.shape[1]
        if nc < 3:
            print(f"[ERROR] {file_path} has only {nc} columns.")
            return np.empty((0,3), dtype=np.float32)
        arr = df.iloc[:,-3:].astype(np.float32).to_numpy()
        print(f"[INFO] Loaded {file_path}, shape={arr.shape}")
        return arr
    except Exception as e:
        print(f"[ERROR] Could not parse file: {file_path}. Reason: {e}")
        return np.empty((0,3), dtype=np.float32)

def merge_acc_gyro(acc_arr: np.ndarray, gyr_arr: np.ndarray) -> np.ndarray:
    if acc_arr.shape[0] == 0 or gyr_arr.shape[0] == 0:
        return np.empty((0,6), dtype=np.float32)
    N = min(acc_arr.shape[0], gyr_arr.shape[0])
    merged = np.concatenate([acc_arr[:N], gyr_arr[:N]], axis=1)
    return merged

class MatchedWatchTrial:
    def __init__(self, s_id, a_id, t_id):
        self.subject_id = s_id
        self.activity_id= a_id
        self.trial_id   = t_id
        self.acc_file   = None
        self.gyr_file   = None

    def __repr__(self):
        return (f"MatchedWatchTrial(S={self.subject_id},A={self.activity_id},T={self.trial_id},"
                f"acc={self.acc_file}, gyr={self.gyr_file})")

class SmartFallMM_Watch:
    """
    Only parse watch data from:
      data/smartfallmm/young/accelerometer/watch
      data/smartfallmm/young/gyroscope/watch
    """
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.trials: Dict[Tuple[int,int,int], MatchedWatchTrial] = {}

    def _add_file(self, subj, act, tri, is_acc, filepath):
        key = (subj, act, tri)
        if key not in self.trials:
            self.trials[key] = MatchedWatchTrial(subj, act, tri)
        if is_acc:
            self.trials[key].acc_file = filepath
        else:
            self.trials[key].gyr_file = filepath

    def load_files(self):
        # parse SxxAyyTzz for "young/accelerometer/watch" or "young/gyroscope/watch"
        paths = [
            os.path.join(self.root_dir, "young", "accelerometer", "watch"),
            os.path.join(self.root_dir, "young", "gyroscope",     "watch")
        ]
        for base in paths:
            is_acc = ("accelerometer" in base)
            for root, dirs, files in os.walk(base):
                for f in files:
                    if not f.lower().endswith(".csv"):
                        continue
                    try:
                        s_id = int(f[1:3])
                        a_id = int(f[4:6])
                        t_id = int(f[7:9])
                        fullpath = os.path.join(root, f)
                        self._add_file(s_id, a_id, t_id, is_acc, fullpath)
                    except:
                        pass

# %% [markdown]
# ## 3. Sliding Window (if needed for ML)

def sliding_window(data: np.ndarray, window_size=128, stride=64):
    n = data.shape[0]
    feats = data.shape[1]
    out = []
    i = 0
    while i+window_size <= n:
        out.append(data[i:i+window_size])
        i += stride
    if not out:
        return np.empty((0, window_size, feats), dtype=np.float32)
    return np.stack(out, axis=0)

# %% [markdown]
# ## 4. Simple Stats for comparing each filter to Raw Accelerometer
#
# We'll compute:
# - RMSE for each dimension (X, Y, Z) comparing fuse vs. raw_acc
# - correlation (Pearson) for each dimension

def compare_filter_stats(raw_acc: np.ndarray, fused: np.ndarray, label:str):
    """
    raw_acc, fused: Nx3
    We compute RMSE and correlation dimension by dimension vs. raw.
    Print results.
    """
    eps = 1e-9
    N = raw_acc.shape[0]
    if fused.shape[0] != N:
        print(f"[ERROR] Mismatch shapes for stats. {label} skip stats.")
        return

    dims = ["X","Y","Z"]
    for d in range(3):
        raw_d = raw_acc[:,d]
        fus_d = fused[:,d]
        # RMSE
        rmse = np.sqrt(np.mean((fus_d - raw_d)**2))
        # correlation
        # ensure variance is not zero
        corr = 0.0
        try:
            corr = np.corrcoef(raw_d, fus_d)[0,1]
        except:
            pass

        print(f"{label}: dim={dims[d]} RMSE={rmse:.3f}, Corr={corr:.3f}")


# %% [markdown]
# ## 5. MultiFilterBuilder:
#  - For subject=31, trial=1..5
#  - Merge => Nx6
#  - compute each filter => Nx3
#  - plot 6 pictures:
#    1) raw vs. madg
#    2) raw vs. complementary
#    3) raw vs. mahony
#    4) raw vs. ekf
#    5) all 4 unnorm
#    6) all 4 local norm
#  - store in `visualizations/{s_id}/{a_id}/SxxAyyTzz_*.png`
#  - also do sliding_window if user wants
#  - compute stats vs. raw accelerometer (RMSE, correlation)

class MultiFilterBuilder:
    def __init__(self,
                 watch_dataset: SmartFallMM_Watch,
                 fs=31.125,
                 subject_of_interest=31,
                 trial_range=range(1,6),  # T=1..5
                 window_size=128,
                 stride=64):
        self.dataset = watch_dataset
        self.fs = fs
        self.subject_of_interest = subject_of_interest
        self.trial_range = trial_range
        self.window_size = window_size
        self.stride = stride

    def _ensure_dir(self, path):
        os.makedirs(path, exist_ok=True)

    def _plot_raw_vs_filter(self,
                            t_axis,
                            raw_acc,
                            raw_gyro,
                            fused,
                            s_id,
                            a_id,
                            t_id,
                            filter_name):
        """
        Plot a single figure: raw ACC & GYR vs. fused filter
        Larger figsize, higher DPI => clearer
        """
        fig, axs = plt.subplots(3, 1, figsize=(16,12), dpi=200, sharex=True)

        axs[0].plot(t_axis, raw_acc[:,0], 'g', label='AccX')
        axs[0].plot(t_axis, raw_gyro[:,0],'m', label='GyroX')
        axs[0].plot(t_axis, fused[:,0],   'b', label=f'{filter_name}-X')
        axs[0].legend(loc='upper right')
        axs[0].set_ylabel('m/s^2')

        axs[1].plot(t_axis, raw_acc[:,1], 'g', label='AccY')
        axs[1].plot(t_axis, raw_gyro[:,1],'m', label='GyroY')
        axs[1].plot(t_axis, fused[:,1],   'b', label=f'{filter_name}-Y')
        axs[1].legend(loc='upper right')
        axs[1].set_ylabel('m/s^2')

        axs[2].plot(t_axis, raw_acc[:,2], 'g', label='AccZ')
        axs[2].plot(t_axis, raw_gyro[:,2],'m', label='GyroZ')
        axs[2].plot(t_axis, fused[:,2],   'b', label=f'{filter_name}-Z')
        axs[2].legend(loc='upper right')
        axs[2].set_ylabel('m/s^2')
        axs[2].set_xlabel('Time (s)')

        fig.suptitle(f"S{s_id:02d}A{a_id:02d}T{t_id:02d}: Raw vs. {filter_name}")
        return fig

    def _plot_all_filters(self,
                          t_axis,
                          raw_acc,
                          raw_gyro,
                          fused_dict,
                          s_id,
                          a_id,
                          t_id,
                          norm=False):
        """
        fused_dict = { 'madgwick': Nx3, 'comp': Nx3, 'mahony': Nx3, 'ekf': Nx3 }
        If norm=True, apply local normalization for each filter's Nx3
        """
        fused_plot = {}
        if norm:
            for k in fused_dict:
                f_dat = fused_dict[k]
                mean_ = f_dat.mean(axis=0)
                std_  = f_dat.std(axis=0) + 1e-9
                fused_plot[k] = (f_dat - mean_) / std_
        else:
            fused_plot = fused_dict

        fig, axs = plt.subplots(3, 1, figsize=(16,12), dpi=200, sharex=True)

        color_map = {
            "madgwick": 'r',
            "comp":     'c',
            "mahony":   'y',
            "ekf":      'b'
        }

        # raw
        axs[0].plot(t_axis, raw_acc[:,0], 'g', label='AccX')
        axs[0].plot(t_axis, raw_gyro[:,0],'m', label='GyroX')
        for fil_name, arr_ in fused_plot.items():
            axs[0].plot(t_axis, arr_[:,0], color_map[fil_name],
                        label=f'{fil_name}-X')
        axs[0].legend(loc='upper right')
        axs[0].set_ylabel('X')

        axs[1].plot(t_axis, raw_acc[:,1], 'g', label='AccY')
        axs[1].plot(t_axis, raw_gyro[:,1],'m', label='GyroY')
        for fil_name, arr_ in fused_plot.items():
            axs[1].plot(t_axis, arr_[:,1], color_map[fil_name],
                        label=f'{fil_name}-Y')
        axs[1].legend(loc='upper right')
        axs[1].set_ylabel('Y')

        axs[2].plot(t_axis, raw_acc[:,2], 'g', label='AccZ')
        axs[2].plot(t_axis, raw_gyro[:,2],'m', label='GyroZ')
        for fil_name, arr_ in fused_plot.items():
            axs[2].plot(t_axis, arr_[:,2], color_map[fil_name],
                        label=f'{fil_name}-Z')
        axs[2].legend(loc='upper right')
        axs[2].set_ylabel('Z')
        axs[2].set_xlabel('Time (s)')

        if norm:
            fig.suptitle(f"S{s_id:02d}A{a_id:02d}T{t_id:02d}: All Filters (Normalized)")
        else:
            fig.suptitle(f"S{s_id:02d}A{a_id:02d}T{t_id:02d}: All Filters (Unnorm)")

        return fig

    def process_all(self):
        for (s_id, a_id, t_id), trial_obj in self.dataset.trials.items():
            if s_id != self.subject_of_interest:
                continue
            if t_id not in self.trial_range:
                continue
            if (trial_obj.acc_file is None) or (trial_obj.gyr_file is None):
                continue

            acc_data = flexible_csvloader(trial_obj.acc_file)  # Nx3
            gyr_data = flexible_csvloader(trial_obj.gyr_file)  # Mx3
            merged   = merge_acc_gyro(acc_data, gyr_data)       # Nx6

            if merged.shape[0] == 0:
                print(f"[WARNING] Empty merged: S{s_id:02d}A{a_id:02d}T{t_id:02d}")
                continue

            raw_acc  = merged[:,0:3]
            raw_gyro = merged[:,3:6]
            N = raw_acc.shape[0]
            t_axis = np.arange(N)/self.fs

            # 1) Madgwick
            madg_out = madgwick_fusion(raw_acc, raw_gyro, fs=self.fs)
            # 2) Complementary
            comp_out = complementary_fusion(raw_acc, raw_gyro, fs=self.fs, alpha=0.98)
            # 3) Mahony
            mahy_out = mahony_fusion(raw_acc, raw_gyro, fs=self.fs)
            # 4) EKF
            ekf_out  = ekf_fusion(raw_acc, raw_gyro, fs=self.fs)

            # Stats vs. raw acc
            print(f"\n[STATS] S{s_id:02d}A{a_id:02d}T{t_id:02d} comparing each filter to raw accelerometer")
            compare_filter_stats(raw_acc, madg_out,  "Madgwick")
            compare_filter_stats(raw_acc, comp_out,  "Complementary")
            compare_filter_stats(raw_acc, mahy_out,  "Mahony")
            compare_filter_stats(raw_acc, ekf_out,   "EKF")

            outdir = f"visualizations/{s_id}/{a_id}"
            self._ensure_dir(outdir)

            # a) raw vs. madg
            fig_madg = self._plot_raw_vs_filter(t_axis, raw_acc, raw_gyro,
                                                madg_out, s_id,a_id,t_id,
                                                filter_name="Madgwick")
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_madgwick.png")
            plt.close(fig_madg)

            # b) raw vs. comp
            fig_comp = self._plot_raw_vs_filter(t_axis, raw_acc, raw_gyro,
                                                comp_out, s_id,a_id,t_id,
                                                filter_name="Complementary")
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_complementary.png")
            plt.close(fig_comp)

            # c) raw vs. mahony
            fig_mahy = self._plot_raw_vs_filter(t_axis, raw_acc, raw_gyro,
                                                mahy_out, s_id,a_id,t_id,
                                                filter_name="Mahony")
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_mahony.png")
            plt.close(fig_mahy)

            # d) raw vs. ekf
            fig_ekf  = self._plot_raw_vs_filter(t_axis, raw_acc, raw_gyro,
                                                ekf_out, s_id,a_id,t_id,
                                                filter_name="EKF")
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_ekf.png")
            plt.close(fig_ekf)

            # e) all filters unnorm
            fused_dict = {
                "madgwick": madg_out,
                "comp":     comp_out,
                "mahony":   mahy_out,
                "ekf":      ekf_out
            }
            fig_all = self._plot_all_filters(t_axis, raw_acc, raw_gyro,
                                             fused_dict, s_id,a_id,t_id, norm=False)
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_allFilters.png")
            plt.close(fig_all)

            # f) all filters norm
            fig_allN = self._plot_all_filters(t_axis, raw_acc, raw_gyro,
                                              fused_dict, s_id,a_id,t_id, norm=True)
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_allFiltersNorm.png")
            plt.close(fig_allN)

            # sliding_window if needed for ML
            ekf_windows = sliding_window(ekf_out, self.window_size, self.stride)
            print(f"[INFO] S{s_id:02d}A{a_id:02d}T{t_id:02d} => EKF windows: {ekf_windows.shape[0]}")

            print(f"[DONE] S{s_id:02d}A{a_id:02d}T{t_id:02d} => Plots in {outdir}")

# %% [markdown]
# ## 6. Main

if __name__ == "__main__":
    dataset_root = "../data/smartfallmm"
    watch_data = SmartFallMM_Watch(dataset_root)
    watch_data.load_files()
    print(f"[INFO] Found {len(watch_data.trials)} watch-based trials in {dataset_root}")

    builder = MultiFilterBuilder(
        watch_dataset=watch_data,
        fs=31.125,
        subject_of_interest=31,
        trial_range=range(1,6),  # T=1..5
        window_size=128,
        stride=64
    )
    builder.process_all()
    print("[INFO] Completed multi-filter comparisons for subject=31, T=1..5.")


[INFO] Found 906 watch-based trials in ../data/smartfallmm
[INFO] Loaded ../data/smartfallmm/young/accelerometer/watch/S31A06T01.csv, shape=(441, 3)
[INFO] Loaded ../data/smartfallmm/young/gyroscope/watch/S31A06T01.csv, shape=(438, 3)

[STATS] S31A06T01 comparing each filter to raw accelerometer
Madgwick: dim=X RMSE=0.170, Corr=1.000
Madgwick: dim=Y RMSE=1.907, Corr=1.000
Madgwick: dim=Z RMSE=0.008, Corr=1.000
Complementary: dim=X RMSE=0.000, Corr=1.000
Complementary: dim=Y RMSE=0.000, Corr=1.000
Complementary: dim=Z RMSE=0.000, Corr=1.000
Mahony: dim=X RMSE=0.678, Corr=1.000
Mahony: dim=Y RMSE=1.263, Corr=1.000
Mahony: dim=Z RMSE=0.530, Corr=1.000
EKF: dim=X RMSE=6.426, Corr=0.464
EKF: dim=Y RMSE=6.056, Corr=0.882
EKF: dim=Z RMSE=4.275, Corr=0.729
[INFO] S31A06T01 => EKF windows: 5
[DONE] S31A06T01 => Plots in visualizations/31/6
[INFO] Loaded ../data/smartfallmm/young/accelerometer/watch/S31A07T01.csv, shape=(211, 3)
[INFO] Loaded ../data/smartfallmm/young/gyroscope/watch/S31A07T01.c

In [1]:
!export DISPLAY=:0

In [2]:
################################################################################
# ExtendedKalmanFilter_FallDetection_YoungWatch_MultiFilters.py
# with REAL math for Complementary, Madgwick (6-DOF), Mahony (6-DOF), and EKF
# for 31.125 Hz watch-based fall detection. Plots entire trial in seconds.
################################################################################

# %% [markdown]
# # Multi-Filter Sensor Fusion for Fall Detection (Watch, Young Only)
#
# - We only parse:
#     data/smartfallmm/young/accelerometer/watch
#     data/smartfallmm/young/gyroscope/watch
# - Merge (acc Nx3, gyro Nx3) => Nx6
# - Implement 4 filters with *actual* orientation-based logic:
#   1) Complementary filter (pitch/roll ignoring yaw)
#   2) Madgwick (6-DOF)
#   3) Mahony  (6-DOF)
#   4) Extended Kalman Filter (EKF) (as before)
# - Each filter -> orientation -> subtract gravity => linear acceleration
# - For subject=31, trials T=1..5, plot entire Nx samples in seconds
# - Larger figure size, higher DPI
# - Keep sliding_window for downstream usage

import os
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from typing import List, Dict, Tuple
from numpy.linalg import norm
from scipy.stats import pearsonr
from scipy.signal import butter, filtfilt
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# For reproducibility
torch.manual_seed(42)
np.random.seed(42)

# ------------------------------------------------------------------------------
# 1) REAL COMPLEMENTARY FILTER (pitch/roll only, ignoring yaw)
# ------------------------------------------------------------------------------

def complementary_fusion(acc_data: np.ndarray,
                         gyro_data: np.ndarray,
                         fs=31.125,
                         alpha=0.98):
    """
    Complementary filter focusing on *pitch* and *roll* only.
    We assume:
      - accel in m/s^2
      - gyro in rad/s
      - sample rate fs (~31.125 Hz => dt=~0.0321s)
    Steps:
      1) Convert accel to pitch/roll angles: pitchAcc, rollAcc
      2) Integrate gyro to get pitchGyro, rollGyro
      3) fused angles = alpha*(prev angles + gyro*dt) + (1-alpha)*(acc angles)
      4) from fused pitch,roll, reconstruct gravity => subtract from raw accel => linear accel

    We ignore yaw because there's no magnetometer => can't fix yaw drift.
    """

    dt = 1.0 / fs
    N = acc_data.shape[0]
    # We'll store orientation as pitch,roll in degrees or radians? Let's do radians internally
    pitch = 0.0
    roll  = 0.0

    out = np.zeros_like(acc_data)

    def accel_to_angles(ax, ay, az):
        # pitch = rotation around Y axis => pitch = atan2(-ax, sqrt(ay^2 + az^2))
        # roll  = rotation around X axis => roll  = atan2( ay, az )
        # sign conventions can vary, we'll pick one
        pitchA = np.arctan2(-ax, np.sqrt(ay*ay + az*az))
        rollA  = np.arctan2(ay, az)
        return pitchA, rollA

    for i in range(N):
        ax, ay, az = acc_data[i]
        gx, gy, gz = gyro_data[i]  # in rad/s ?

        # 1) compute angles from accel
        pitchAcc, rollAcc = accel_to_angles(ax, ay, az)

        # 2) integrate gyro
        pitchGyro = pitch + (gy*dt)
        rollGyro  = roll  + (gx*dt)*(-1.0)  # sign might differ, depends on axis conv
        # We'll assume 'gx' rotates around X => affects roll, 'gy' around Y => affects pitch

        # 3) fuse
        pitch = alpha*(pitchGyro) + (1-alpha)*(pitchAcc)
        roll  = alpha*(rollGyro)  + (1-alpha)*(rollAcc)

        # 4) reconstruct gravity from pitch,roll => then subtract from raw
        # let's define:
        #   gravity in Earth frame = (0,0,9.81)
        #   we want sensor->Earth transform from pitch,roll ignoring yaw
        # We do rotation around X(roll), then Y(pitch)...

        # or simpler approach: we can do direct "gravity = ( -sin(pitch)*g, sin(roll)*cos(pitch)*g, cos(roll)*cos(pitch)*g )"
        # approximate expression:
        g = 9.81
        gx_sens = -np.sin(pitch)*g
        gy_sens =  np.sin(roll)*np.cos(pitch)*g
        gz_sens =  np.cos(roll)*np.cos(pitch)*g

        out[i,0] = ax - gx_sens
        out[i,1] = ay - gy_sens
        out[i,2] = az - gz_sens

    return out

# ------------------------------------------------------------------------------
# 2) REAL MADGWICK (6-DOF) ignoring magnetometer
# ------------------------------------------------------------------------------
def madgwick_fusion(acc_data: np.ndarray, gyro_data: np.ndarray,
                    fs=31.125, beta=0.05):
    """
    6-DOF Madgwick filter (no magnetometer).
    We'll track orientation as a quaternion, then subtract gravity in sensor frame.
    Equation references: "Madgwick, An efficient orientation filter for IMUs" (2010)
    We'll handle sample rate dt=1/fs, do a single gradient descent step each iteration.

    Steps:
    1) q0..q3 => orientation
    2) gyro => qdot
    3) measure direction of gravity from accel => update
    4) subtract gravity => linear accel
    """

    dt = 1.0/fs
    N = acc_data.shape[0]
    out = np.zeros_like(acc_data)

    # quaternion state
    q = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)

    def normalize(vec):
        n = np.linalg.norm(vec)
        if n<1e-12: return vec
        return vec/n

    for i in range(N):
        ax, ay, az = acc_data[i]
        gx, gy, gz = gyro_data[i]  # rad/s
        # 1) normal acc => direction of gravity
        a_norm = np.array([ax, ay, az], dtype=np.float32)
        a_norm = normalize(a_norm)

        # 2) compute objective function & gradient
        # if gravity in Earth is (0,0,1), we want F(q)= quaternionRotate(q) - a_norm=0
        # the details are from Madgwick's paper. We'll do a simplified version.

        # short version: We'll form f = [2(q1q3 - q0q2) - ax,
        #                                2(q0q1 + q2q3) - ay,
        #                                2(0.5 - q1^2 - q2^2) - az]
        # J => the Jacobian, we compute gradient => then apply step

        q0, q1, q2, q3 = q
        f = np.array([
            2*(q1*q3 - q0*q2) - a_norm[0],
            2*(q0*q1 + q2*q3) - a_norm[1],
            2*(0.5 - q1*q1 - q2*q2) - a_norm[2]
        ], dtype=np.float32)
        J = np.array([
            [-2*q2,         2*q3,        -2*q0,        2*q1],
            [ 2*q1,         2*q0,         2*q3,        2*q2],
            [ 0.0   ,      -4*q1,        -4*q2,        0.0   ]
        ], dtype=np.float32)
        step = J.T.dot(f)
        step = normalize(step)

        # 3) apply gradient descent
        qDot_omega = 0.5*np.array([
            -q1*gx - q2*gy - q3*gz,
             q0*gx + q2*gz - q3*gy,
             q0*gy - q1*gz + q3*gx,
             q0*gz + q1*gy - q2*gx
        ], dtype=np.float32)

        qDot = qDot_omega - beta*step
        q = q + qDot*dt
        q = normalize(q)

        # 4) now subtract gravity in sensor frame
        # gravity in Earth = (0,0,9.81)
        # rotate that into sensor frame
        q0, q1, q2, q3 = q
        # rotation matrix or direct formula. We'll do direct:
        # v_s = quatRotate(q^-1, 0,0,9.81)
        # but simpler to build R from q, then R^T*g
        # We'll do:
        R11 = 1 - 2*(q2*q2 + q3*q3)
        R12 = 2*(q1*q2 - q0*q3)
        R13 = 2*(q1*q3 + q0*q2)
        R21 = 2*(q1*q2 + q0*q3)
        R22 = 1 - 2*(q1*q1 + q3*q3)
        R23 = 2*(q2*q3 - q0*q1)
        R31 = 2*(q1*q3 - q0*q2)
        R32 = 2*(q2*q3 + q0*q1)
        R33 = 1 - 2*(q1*q1 + q2*q2)

        gx_s = R11*0 + R12*0 + R13*9.81
        gy_s = R21*0 + R22*0 + R23*9.81
        gz_s = R31*0 + R32*0 + R33*9.81

        out[i,0] = ax - gx_s
        out[i,1] = ay - gy_s
        out[i,2] = az - gz_s

    return out

# ------------------------------------------------------------------------------
# 3) REAL MAHONY (6-DOF) ignoring magnetometer
# ------------------------------------------------------------------------------
def mahony_fusion(acc_data: np.ndarray, gyro_data: np.ndarray,
                  fs=31.125, kp=2.0, ki=0.0):
    """
    6-DOF Mahony filter (no magnetometer).
    - We track orientation q, plus an integral error.
    - update eqn from Mahony's "Nonlinear Complementary Filters on SO(3)" or
      sensor fusion references. We'll do the short version.

    Steps:
    1) measure gravity from accelerometer => a_norm
    2) compute error e = cross( (R^T*gEarth), a_norm ) => orientation correction
    3) gyro -= kp*e + ki*(integralOf e)
    4) integrate q
    5) subtract gravity

    We'll ignore magnetometer => partial yaw drift possible.
    """

    dt = 1.0/fs
    N = acc_data.shape[0]
    out = np.zeros_like(acc_data)

    # orientation
    q = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
    # integral error
    eInt = np.array([0.0, 0.0, 0.0], dtype=np.float32)

    def normalize(v):
        n = np.linalg.norm(v)
        if n<1e-12: return v
        return v/n

    for i in range(N):
        ax, ay, az = acc_data[i]
        gx, gy, gz = gyro_data[i]
        # 1) measure gravity direction
        a_norm = np.array([ax, ay, az], dtype=np.float32)
        a_norm = normalize(a_norm)

        # 2) rotate Earth gravity (0,0,1) to sensor frame => v = q*g(earth)*q^-1
        # Or simpler: v_s = R^T*g. Then e = cross(v_s, a_norm).
        # We'll build R from q
        q0,q1,q2,q3 = q
        R11 = 1 - 2*(q2*q2 + q3*q3)
        R12 = 2*(q1*q2 - q0*q3)
        R13 = 2*(q1*q3 + q0*q2)
        R21 = 2*(q1*q2 + q0*q3)
        R22 = 1 - 2*(q1*q1 + q3*q3)
        R23 = 2*(q2*q3 - q0*q1)
        R31 = 2*(q1*q3 - q0*q2)
        R32 = 2*(q2*q3 + q0*q1)
        R33 = 1 - 2*(q1*q1 + q2*q2)
        vx = R13*1.0  # Earth's (0,0,1)
        vy = R23*1.0
        vz = R33*1.0
        # e = cross(v, a_norm)
        ex = vy*a_norm[2] - vz*a_norm[1]
        ey = vz*a_norm[0] - vx*a_norm[2]
        ez = vx*a_norm[1] - vy*a_norm[0]
        e = np.array([ex,ey,ez], dtype=np.float32)

        # 3) integral
        eInt += e*dt*ki
        # correct gyro
        gx_corr = gx - (kp*ex + eInt[0])
        gy_corr = gy - (kp*ey + eInt[1])
        gz_corr = gz - (kp*ez + eInt[2])

        # 4) integrate q with corrected gyro
        qDot = 0.5*np.array([
            -q1*gx_corr - q2*gy_corr - q3*gz_corr,
             q0*gx_corr + q2*gz_corr - q3*gy_corr,
             q0*gy_corr - q1*gz_corr + q3*gx_corr,
             q0*gz_corr + q1*gy_corr - q2*gx_corr
        ], dtype=np.float32)
        q = q + qDot*dt
        q = normalize(q)

        # 5) subtract gravity
        # same approach
        q0,q1,q2,q3 = q
        gx_s = R11*0 + R12*0 + R13*9.81
        gy_s = R21*0 + R22*0 + R23*9.81
        gz_s = R31*0 + R32*0 + R33*9.81
        out[i,0] = ax - gx_s
        out[i,1] = ay - gy_s
        out[i,2] = az - gz_s

    return out

# ------------------------------------------------------------------------------
# 4) EXTENDED KALMAN FILTER (already real math from earlier code)
# ------------------------------------------------------------------------------

# We'll keep the same implementation from prior examples
# see code further below

def sliding_window(data: np.ndarray, window_size=128, stride=64):
    n = data.shape[0]
    feats = data.shape[1]
    out = []
    i = 0
    while i+window_size <= n:
        out.append(data[i:i+window_size])
        i += stride
    if not out:
        return np.empty((0, window_size, feats), dtype=np.float32)
    return np.stack(out, axis=0)

# The next sections match your existing pipeline for loading watch data,
# merging, building a MultiFilter approach, etc.


###############################################################################
# Existing classes for matched watch data
###############################################################################
class MatchedWatchTrial:
    def __init__(self, s_id, a_id, t_id):
        self.subject_id = s_id
        self.activity_id= a_id
        self.trial_id   = t_id
        self.acc_file   = None
        self.gyr_file   = None

    def __repr__(self):
        return (f"MatchedWatchTrial(S={self.subject_id},A={self.activity_id},T={self.trial_id},"
                f"acc={self.acc_file}, gyr={self.gyr_file})")


class SmartFallMM_Watch:
    """
    Only parse watch data from:
      data/smartfallmm/young/accelerometer/watch
      data/smartfallmm/young/gyroscope/watch
    """
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.trials: Dict[Tuple[int,int,int], MatchedWatchTrial] = {}

    def _add_file(self, subj, act, tri, is_acc, filepath):
        key = (subj, act, tri)
        if key not in self.trials:
            self.trials[key] = MatchedWatchTrial(subj, act, tri)
        if is_acc:
            self.trials[key].acc_file = filepath
        else:
            self.trials[key].gyr_file = filepath

    def load_files(self):
        # parse SxxAyyTzz for "young/accelerometer/watch" or "young/gyroscope/watch"
        paths = [
            os.path.join(self.root_dir, "young", "accelerometer", "watch"),
            os.path.join(self.root_dir, "young", "gyroscope",     "watch")
        ]
        for base in paths:
            is_acc = ("accelerometer" in base)
            for root, dirs, files in os.walk(base):
                for f in files:
                    if not f.lower().endswith(".csv"):
                        continue
                    try:
                        s_id = int(f[1:3])
                        a_id = int(f[4:6])
                        t_id = int(f[7:9])
                        fullpath = os.path.join(root, f)
                        self._add_file(s_id, a_id, t_id, is_acc, fullpath)
                    except:
                        pass


def flexible_csvloader(file_path: str) -> np.ndarray:
    """
    Reads watch data with possible commas/semicolons.
    Takes last 3 columns as floats (x,y,z).
    """
    try:
        df = pd.read_csv(
            file_path,
            sep='[;,]',
            engine='python',
            header=None
        ).dropna().bfill()
        nc = df.shape[1]
        if nc < 3:
            print(f"[ERROR] {file_path} has only {nc} columns.")
            return np.empty((0,3), dtype=np.float32)
        arr = df.iloc[:,-3:].astype(np.float32).to_numpy()
        print(f"[INFO] Loaded {file_path}, shape={arr.shape}")
        return arr
    except Exception as e:
        print(f"[ERROR] Could not parse file: {file_path}. Reason: {e}")
        return np.empty((0,3), dtype=np.float32)


def merge_acc_gyro(acc_arr: np.ndarray, gyr_arr: np.ndarray) -> np.ndarray:
    if acc_arr.shape[0] == 0 or gyr_arr.shape[0] == 0:
        return np.empty((0,6), dtype=np.float32)
    N = min(acc_arr.shape[0], gyr_arr.shape[0])
    merged = np.concatenate([acc_arr[:N], gyr_arr[:N]], axis=1)
    return merged


###############################################################################
# Extended Kalman Filter (with real math from prior script)
###############################################################################
def normalize_quaternion(q: np.ndarray) -> np.ndarray:
    return q / np.linalg.norm(q)

def quat_to_rot_mat(q: np.ndarray) -> np.ndarray:
    qw, qx, qy, qz = q
    R = np.array([
        [1 - 2*(qy**2 + qz**2), 2*(qx*qy - qw*qz),     2*(qx*qz + qw*qy)],
        [2*(qx*qy + qw*qz),     1 - 2*(qx**2 + qz**2), 2*(qy*qz - qw*qx)],
        [2*(qx*qz - qw*qy),     2*(qy*qz + qw*qx),     1 - 2*(qx**2 + qy**2)]
    ])
    return R

def integrate_gyro(q: np.ndarray, w: np.ndarray, dt: float) -> np.ndarray:
    qw, qx, qy, qz = q
    wx, wy, wz = w
    dq = 0.5 * np.array([
        - qx*wx - qy*wy - qz*wz,
          qw*wx + qy*wz - qz*wy,
          qw*wy - qx*wz + qz*wx,
          qw*wz + qx*wy - qy*wx
    ])
    q_new = q + dq * dt
    return normalize_quaternion(q_new)

class EKFInertial:
    def __init__(self, dt=1/31.125, var_gyro=1e-5, var_acc=5e-2):
        self.dt = dt
        self.x = np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
        self.P = np.eye(7) * 0.01
        self.Q = np.eye(7) * var_gyro
        self.Q[0:4, 0:4] *= 0.1
        self.R = np.eye(3) * var_acc

    def predict(self, gyro_meas: np.ndarray):
        q = self.x[0:4]
        bg= self.x[4:7]
        w = gyro_meas - bg
        q_new = integrate_gyro(q, w, self.dt)
        self.x[0:4] = q_new
        F = np.eye(7)
        self.P = F @ self.P @ F.T + self.Q

    def update(self, acc_meas: np.ndarray):
        q = self.x[0:4]
        R_sb = quat_to_rot_mat(q).T
        g_s = R_sb @ np.array([0,0,9.81])
        z = acc_meas - g_s
        eps = 1e-6
        H = np.zeros((3,7))
        for i in range(4):
            dq = np.zeros(4)
            dq[i] = eps
            q_pert = normalize_quaternion(q + dq)
            R_pert = quat_to_rot_mat(q_pert).T
            g_s_pert = R_pert @ np.array([0,0,9.81])
            dh = (g_s_pert - g_s)/eps
            H[:,i] = dh
        S = H @ self.P @ H.T + self.R
        K = self.P @ H.T @ np.linalg.inv(S)
        dx = K @ z
        self.x += dx
        self.x[0:4] = normalize_quaternion(self.x[0:4])
        I_7 = np.eye(7)
        self.P = (I_7 - K@H) @ self.P @ (I_7 - K@H).T + K@self.R@K.T

    def fuse(self, gyro: np.ndarray, acc: np.ndarray):
        self.predict(gyro)
        self.update(acc)
        q = self.x[0:4]
        R_sb = quat_to_rot_mat(q).T
        g_s = R_sb @ np.array([0,0,9.81])
        return g_s

def ekf_fusion(acc_data: np.ndarray, gyro_data: np.ndarray, fs=31.125):
    dt = 1.0/fs
    ekf = EKFInertial(dt=dt, var_gyro=1e-5, var_acc=5e-2)
    N = acc_data.shape[0]
    out = np.zeros_like(acc_data)
    for i in range(N):
        g_s = ekf.fuse(gyro_data[i], acc_data[i])
        out[i] = acc_data[i] - g_s
    return out

###############################################################################
# Simple stats: compare each filter to raw acc (RMSE, correlation)
###############################################################################
def compare_filter_stats(raw_acc: np.ndarray, fused: np.ndarray, label:str):
    N = raw_acc.shape[0]
    if fused.shape[0] != N:
        print(f"[ERROR] Mismatch shapes for stats. {label} skip stats.")
        return
    dims = ["X","Y","Z"]
    for d in range(3):
        ra = raw_acc[:,d]
        fu = fused[:,d]
        rmse = np.sqrt(np.mean((fu - ra)**2))
        corr = np.corrcoef(ra, fu)[0,1]
        print(f"{label} dim={dims[d]}: RMSE={rmse:.3f}, Corr={corr:.3f}")

###############################################################################
# The "MultiFilterBuilder" for subject=31, T=1..5
###############################################################################
class MultiFilterBuilder:
    def __init__(self,
                 watch_dataset: SmartFallMM_Watch,
                 fs=31.125,
                 subject_of_interest=31,
                 trial_range=range(1,6),
                 window_size=128,
                 stride=64):
        self.dataset = watch_dataset
        self.fs = fs
        self.subject_of_interest = subject_of_interest
        self.trial_range = trial_range
        self.window_size = window_size
        self.stride = stride

    def _ensure_dir(self, p):
        os.makedirs(p, exist_ok=True)

    def _plot_raw_vs_filter(self, t_axis, raw_acc, raw_gyro, fused,
                            s_id, a_id, t_id, filter_name):
        fig, axs = plt.subplots(3,1, figsize=(16,12), dpi=200, sharex=True)

        axs[0].plot(t_axis, raw_acc[:,0], 'g', label='AccX')
        axs[0].plot(t_axis, raw_gyro[:,0],'m', label='GyroX')
        axs[0].plot(t_axis, fused[:,0],   'b', label=f'{filter_name}-X')
        axs[0].legend(loc='upper right')
        axs[0].set_ylabel('m/s^2')

        axs[1].plot(t_axis, raw_acc[:,1], 'g', label='AccY')
        axs[1].plot(t_axis, raw_gyro[:,1],'m', label='GyroY')
        axs[1].plot(t_axis, fused[:,1],   'b', label=f'{filter_name}-Y')
        axs[1].legend(loc='upper right')
        axs[1].set_ylabel('m/s^2')

        axs[2].plot(t_axis, raw_acc[:,2], 'g', label='AccZ')
        axs[2].plot(t_axis, raw_gyro[:,2],'m', label='GyroZ')
        axs[2].plot(t_axis, fused[:,2],   'b', label=f'{filter_name}-Z')
        axs[2].legend(loc='upper right')
        axs[2].set_ylabel('m/s^2')
        axs[2].set_xlabel('Time (s)')

        fig.suptitle(f"S{s_id:02d}A{a_id:02d}T{t_id:02d}: Raw vs. {filter_name}")
        return fig

    def _plot_all_filters(self, t_axis, raw_acc, raw_gyro, fused_dict,
                          s_id, a_id, t_id, norm=False):
        """
        fused_dict = { 'madgwick': Nx3, 'comp': Nx3, 'mahony': Nx3, 'ekf': Nx3 }
        if norm=True, do local normalization for each filter
        """
        # optionally normalize each filter's Nx3
        to_plot = {}
        if norm:
            for k,v in fused_dict.items():
                mean_ = v.mean(axis=0)
                std_  = v.std(axis=0)+1e-9
                to_plot[k] = (v-mean_)/std_
        else:
            to_plot = fused_dict

        fig, axs = plt.subplots(3,1, figsize=(16,12), dpi=200, sharex=True)

        color_map = {
            "madgwick": 'r',
            "comp":     'c',
            "mahony":   'y',
            "ekf":      'b'
        }

        axs[0].plot(t_axis, raw_acc[:,0], 'g', label='AccX')
        axs[0].plot(t_axis, raw_gyro[:,0],'m', label='GyroX')
        for fname,arr_ in to_plot.items():
            axs[0].plot(t_axis, arr_[:,0], color_map[fname], label=f'{fname}-X')
        axs[0].legend(loc='upper right')
        axs[0].set_ylabel('X')

        axs[1].plot(t_axis, raw_acc[:,1], 'g', label='AccY')
        axs[1].plot(t_axis, raw_gyro[:,1],'m', label='GyroY')
        for fname,arr_ in to_plot.items():
            axs[1].plot(t_axis, arr_[:,1], color_map[fname], label=f'{fname}-Y')
        axs[1].legend(loc='upper right')
        axs[1].set_ylabel('Y')

        axs[2].plot(t_axis, raw_acc[:,2], 'g', label='AccZ')
        axs[2].plot(t_axis, raw_gyro[:,2],'m', label='GyroZ')
        for fname,arr_ in to_plot.items():
            axs[2].plot(t_axis, arr_[:,2], color_map[fname], label=f'{fname}-Z')
        axs[2].legend(loc='upper right')
        axs[2].set_ylabel('Z')
        axs[2].set_xlabel('Time (s)')

        if norm:
            fig.suptitle(f"S{s_id:02d}A{a_id:02d}T{t_id:02d}: All Filters (Norm)")
        else:
            fig.suptitle(f"S{s_id:02d}A{a_id:02d}T{t_id:02d}: All Filters (Unnorm)")

        return fig

    def process_all(self):
        for (s_id, a_id, t_id), trial_obj in self.dataset.trials.items():
            if s_id != self.subject_of_interest:
                continue
            if t_id < 1 or t_id > 5:
                continue
            if (trial_obj.acc_file is None) or (trial_obj.gyr_file is None):
                continue

            acc_data = flexible_csvloader(trial_obj.acc_file)   # Nx3
            gyr_data = flexible_csvloader(trial_obj.gyr_file)   # Nx3
            merged   = merge_acc_gyro(acc_data, gyr_data)        # Nx6

            if merged.shape[0] == 0:
                print(f"[WARNING] Empty merged: S{s_id:02d}A{a_id:02d}T{t_id:02d}")
                continue

            raw_acc  = merged[:,0:3]
            raw_gyro = merged[:,3:6]
            N = raw_acc.shape[0]
            t_axis = np.arange(N)/self.fs

            # run real math for each filter
            madg_out = madgwick_fusion(raw_acc, raw_gyro, fs=self.fs, beta=0.05)
            comp_out = complementary_fusion(raw_acc, raw_gyro, fs=self.fs, alpha=0.98)
            mahy_out = mahony_fusion(raw_acc, raw_gyro, fs=self.fs, kp=2.0, ki=0.0)
            ekf_out  = ekf_fusion(raw_acc, raw_gyro, fs=self.fs)

            print(f"\n[S{ s_id:02d} A{ a_id:02d} T{t_id:02d}] Filter Stats vs. Raw Acc:")
            compare_filter_stats(raw_acc, madg_out, "Madgwick")
            compare_filter_stats(raw_acc, comp_out, "Complementary")
            compare_filter_stats(raw_acc, mahy_out, "Mahony")
            compare_filter_stats(raw_acc, ekf_out,  "EKF")

            outdir = f"visualizations/{s_id}/{a_id}"
            self._ensure_dir(outdir)

            # Plot raw vs. each filter
            fig_madg = self._plot_raw_vs_filter(t_axis, raw_acc, raw_gyro,
                                                madg_out, s_id,a_id,t_id, "Madgwick")
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_madgwick.png")
            plt.close(fig_madg)

            fig_comp = self._plot_raw_vs_filter(t_axis, raw_acc, raw_gyro,
                                                comp_out, s_id,a_id,t_id, "Complementary")
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_complementary.png")
            plt.close(fig_comp)

            fig_mahony= self._plot_raw_vs_filter(t_axis, raw_acc, raw_gyro,
                                                 mahy_out, s_id,a_id,t_id, "Mahony")
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_mahony.png")
            plt.close(fig_mahony)

            fig_ekf= self._plot_raw_vs_filter(t_axis, raw_acc, raw_gyro,
                                              ekf_out, s_id,a_id,t_id, "EKF")
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_ekf.png")
            plt.close(fig_ekf)

            # all filters unnorm
            fused_dict = {"madgwick":madg_out, "comp":comp_out, "mahony":mahy_out, "ekf":ekf_out}
            fig_all= self._plot_all_filters(t_axis, raw_acc, raw_gyro,
                                            fused_dict, s_id,a_id,t_id, norm=False)
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_allFilters.png")
            plt.close(fig_all)

            # all filters norm
            fig_allN= self._plot_all_filters(t_axis, raw_acc, raw_gyro,
                                             fused_dict, s_id,a_id,t_id, norm=True)
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_allFiltersNorm.png")
            plt.close(fig_allN)

            # sliding window (e.g. for ekf_out)
            ekf_windows = sliding_window(ekf_out, self.window_size, self.stride)
            print(f"[INFO] S{s_id:02d}A{a_id:02d}T{t_id:02d} => EKF windows: {ekf_windows.shape[0]}")

            print(f"[DONE] S{s_id:02d}A{a_id:02d}T{t_id:02d} => Plots in {outdir}")


# %% [markdown]
# ## 6. Main

if __name__ == "__main__":
    dataset_root = "../data/smartfallmm"
    watch_data = SmartFallMM_Watch(dataset_root)
    watch_data.load_files()
    print(f"[INFO] Found {len(watch_data.trials)} watch-based trials in {dataset_root}")

    builder = MultiFilterBuilder(
        watch_dataset=watch_data,
        fs=31.125,
        subject_of_interest=31,    # only subject=31
        trial_range=range(1,6),    # T=1..5
        window_size=64,
        stride=64
    )
    builder.process_all()
    print("[INFO] Completed multi-filter comparisons for subject=31, T=1..5.")


[INFO] Found 906 watch-based trials in ../data/smartfallmm
[INFO] Loaded ../data/smartfallmm/young/accelerometer/watch/S31A06T01.csv, shape=(441, 3)
[INFO] Loaded ../data/smartfallmm/young/gyroscope/watch/S31A06T01.csv, shape=(438, 3)

[S31 A06 T01] Filter Stats vs. Raw Acc:
Madgwick dim=X: RMSE=4.792, Corr=0.786
Madgwick dim=Y: RMSE=5.558, Corr=0.925
Madgwick dim=Z: RMSE=6.510, Corr=0.567
Complementary dim=X: RMSE=7.380, Corr=0.926
Complementary dim=Y: RMSE=4.299, Corr=0.950
Complementary dim=Z: RMSE=4.826, Corr=0.879
Mahony dim=X: RMSE=4.903, Corr=0.788
Mahony dim=Y: RMSE=6.502, Corr=0.933
Mahony dim=Z: RMSE=5.470, Corr=0.701
EKF dim=X: RMSE=6.426, Corr=0.464
EKF dim=Y: RMSE=6.056, Corr=0.882
EKF dim=Z: RMSE=4.275, Corr=0.729
[INFO] S31A06T01 => EKF windows: 6
[DONE] S31A06T01 => Plots in visualizations/31/6
[INFO] Loaded ../data/smartfallmm/young/accelerometer/watch/S31A07T01.csv, shape=(211, 3)
[INFO] Loaded ../data/smartfallmm/young/gyroscope/watch/S31A07T01.csv, shape=(208, 3)

[

In [1]:
################################################################################
# ExtendedKalmanFilter_FallDetection_YoungWatch_MultiFilters.py
# 
# Real math for (6-DOF) Complementary, Madgwick, Mahony, and EKF.
# Watch-based ACC+GYRO data from data/smartfallmm/young/{accelerometer|gyroscope}/watch.
# Merged by matching SxxAyyTzz filenames. Offline gravity removal for fall detection.
################################################################################

import os
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Dict, Tuple, List
from numpy.linalg import norm
from scipy.signal import butter, filtfilt
from scipy.stats import pearsonr
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Fix seeds
torch.manual_seed(42)
np.random.seed(42)

###############################################################################
# 1) Complementary Filter (pitch/roll only, ignoring yaw)
###############################################################################
def complementary_fusion(acc_data: np.ndarray,
                         gyro_data: np.ndarray,
                         fs=31.125,
                         alpha=0.98):
    """
    Angle-based complementary filter:
      - Estimate pitch, roll from accelerometer
      - Integrate gyro in pitch, roll
      - Blend them with 'alpha'.
      - Subtract gravity from raw ACC => linear accel.
    We omit yaw because no magnetometer is used.
    """
    dt = 1.0/fs
    N = acc_data.shape[0]
    out = np.zeros_like(acc_data)

    pitch, roll = 0.0, 0.0

    def accel_to_angles(ax, ay, az):
        # pitch = atan2(-ax, sqrt(ay^2 + az^2))
        # roll  = atan2(ay, az)
        pitchA = np.arctan2(-ax, np.sqrt(ay*ay + az*az))
        rollA  = np.arctan2(ay, az)
        return pitchA, rollA

    for i in range(N):
        ax, ay, az = acc_data[i]
        gx, gy, gz = gyro_data[i]  # rad/s assumed

        # 1) Acc => pitch/roll
        pitchAcc, rollAcc = accel_to_angles(ax, ay, az)

        # 2) Integrate gyro for pitch, roll
        pitchGyro = pitch + (gy*dt)
        rollGyro  = roll  + (gx*dt)*(-1.0)  # sign depends on orientation

        # 3) Complement
        pitch = alpha*pitchGyro + (1-alpha)*pitchAcc
        roll  = alpha*rollGyro  + (1-alpha)*rollAcc

        # 4) Reconstruct gravity => subtract
        g = 9.81
        gx_sens = -np.sin(pitch)*g
        gy_sens =  np.sin(roll)*np.cos(pitch)*g
        gz_sens =  np.cos(roll)*np.cos(pitch)*g

        out[i,0] = ax - gx_sens
        out[i,1] = ay - gy_sens
        out[i,2] = az - gz_sens

    return out

###############################################################################
# 2) Madgwick (6-DOF, no magnetometer)
###############################################################################
def madgwick_fusion(acc_data: np.ndarray,
                    gyro_data: np.ndarray,
                    fs=31.125,
                    beta=0.05):
    """
    Madgwick: orientation quaternion q updated by gradient descent step
    ignoring magnetometer. Then subtract gravity from raw ACC.
    """
    dt = 1.0/fs
    N = acc_data.shape[0]
    out = np.zeros_like(acc_data)

    q = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)

    def normalize(v):
        n = np.linalg.norm(v)
        return v if n<1e-12 else v/n

    for i in range(N):
        ax, ay, az = acc_data[i]
        gx, gy, gz = gyro_data[i]  # rad/s
        # normalize ACC
        a_norm = np.array([ax, ay, az], dtype=np.float32)
        a_norm = normalize(a_norm)

        q0,q1,q2,q3 = q

        # f = objective function, J=Jacobian from Madgwick
        f = np.array([
            2*(q1*q3 - q0*q2) - a_norm[0],
            2*(q0*q1 + q2*q3) - a_norm[1],
            2*(0.5 - q1*q1 - q2*q2) - a_norm[2]
        ], dtype=np.float32)
        J = np.array([
            [-2*q2,         2*q3,        -2*q0,        2*q1],
            [ 2*q1,         2*q0,         2*q3,        2*q2],
            [ 0.0   ,      -4*q1,        -4*q2,        0.0   ]
        ], dtype=np.float32)
        step = J.T.dot(f)
        step = normalize(step)

        # gyro-based rate
        qDot_omega = 0.5*np.array([
            -q1*gx - q2*gy - q3*gz,
             q0*gx + q2*gz - q3*gy,
             q0*gy - q1*gz + q3*gx,
             q0*gz + q1*gy - q2*gx
        ], dtype=np.float32)
        # combine
        qDot = qDot_omega - beta*step
        q += qDot*dt
        q = normalize(q)

        # subtract gravity
        q0,q1,q2,q3 = q
        R11 = 1 - 2*(q2*q2 + q3*q3)
        R12 = 2*(q1*q2 - q0*q3)
        R13 = 2*(q1*q3 + q0*q2)
        R21 = 2*(q1*q2 + q0*q3)
        R22 = 1 - 2*(q1*q1 + q3*q3)
        R23 = 2*(q2*q3 - q0*q1)
        R31 = 2*(q1*q3 - q0*q2)
        R32 = 2*(q2*q3 + q0*q1)
        R33 = 1 - 2*(q1*q1 + q2*q2)

        gx_s = R13*9.81
        gy_s = R23*9.81
        gz_s = R33*9.81

        out[i,0] = ax - gx_s
        out[i,1] = ay - gy_s
        out[i,2] = az - gz_s

    return out

###############################################################################
# 3) Mahony (6-DOF, no magnetometer)
###############################################################################
def mahony_fusion(acc_data: np.ndarray,
                  gyro_data: np.ndarray,
                  fs=31.125,
                  kp=2.0,
                  ki=0.0):
    """
    Mahony filter with orientation quaternion, ignoring magnetometer => 6-DoF.
    We track an integral error eInt, correct gyro, integrate, subtract gravity.
    """
    dt = 1.0/fs
    N = acc_data.shape[0]
    out = np.zeros_like(acc_data)

    q = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
    eInt = np.array([0.0, 0.0, 0.0], dtype=np.float32)

    def normalize(v):
        n = np.linalg.norm(v)
        return v if n<1e-12 else v/n

    for i in range(N):
        ax, ay, az = acc_data[i]
        gx, gy, gz = gyro_data[i]

        a_norm = np.array([ax, ay, az], dtype=np.float32)
        a_norm = normalize(a_norm)

        q0,q1,q2,q3 = q
        # Build rotation matrix
        R11 = 1 - 2*(q2*q2 + q3*q3)
        R12 = 2*(q1*q2 - q0*q3)
        R13 = 2*(q1*q3 + q0*q2)
        R21 = 2*(q1*q2 + q0*q3)
        R22 = 1 - 2*(q1*q1 + q3*q3)
        R23 = 2*(q2*q3 - q0*q1)
        R31 = 2*(q1*q3 - q0*q2)
        R32 = 2*(q2*q3 + q0*q1)
        R33 = 1 - 2*(q1*q1 + q2*q2)
        vx = R13*1.0
        vy = R23*1.0
        vz = R33*1.0
        ex = vy*a_norm[2] - vz*a_norm[1]
        ey = vz*a_norm[0] - vx*a_norm[2]
        ez = vx*a_norm[1] - vy*a_norm[0]
        e = np.array([ex, ey, ez], dtype=np.float32)

        eInt += e*dt*ki

        gx_corr = gx - (kp*ex + eInt[0])
        gy_corr = gy - (kp*ey + eInt[1])
        gz_corr = gz - (kp*ez + eInt[2])

        qDot = 0.5*np.array([
            -q1*gx_corr - q2*gy_corr - q3*gz_corr,
             q0*gx_corr + q2*gz_corr - q3*gy_corr,
             q0*gy_corr - q1*gz_corr + q3*gx_corr,
             q0*gz_corr + q1*gy_corr - q2*gx_corr
        ], dtype=np.float32)
        q += qDot*dt
        n = np.linalg.norm(q)
        if n>1e-12:
            q /= n

        # subtract gravity
        q0,q1,q2,q3 = q
        gx_s = R13*9.81
        gy_s = R23*9.81
        gz_s = R33*9.81

        out[i,0] = ax - gx_s
        out[i,1] = ay - gy_s
        out[i,2] = az - gz_s

    return out

###############################################################################
# 4) EKF (6-DOF, ignoring magnetometer)
###############################################################################
def normalize_quaternion(q: np.ndarray) -> np.ndarray:
    return q / np.linalg.norm(q)

def quat_to_rot_mat(q: np.ndarray) -> np.ndarray:
    qw, qx, qy, qz = q
    R = np.array([
        [1 - 2*(qy**2 + qz**2), 2*(qx*qy - qw*qz),     2*(qx*qz + qw*qy)],
        [2*(qx*qy + qw*qz),     1 - 2*(qx**2 + qz**2), 2*(qy*qz - qw*qx)],
        [2*(qx*qz - qw*qy),     2*(qy*qz + qw*qx),     1 - 2*(qx**2 + qy**2)]
    ])
    return R

def integrate_gyro(q: np.ndarray, w: np.ndarray, dt: float) -> np.ndarray:
    qw, qx, qy, qz = q
    wx, wy, wz = w
    dq = 0.5 * np.array([
        - qx*wx - qy*wy - qz*wz,
          qw*wx + qy*wz - qz*wy,
          qw*wy - qx*wz + qz*wx,
          qw*wz + qx*wy - qy*wx
    ])
    q_new = q + dq*dt
    return normalize_quaternion(q_new)

class EKFInertial:
    """
    State = [q0, q1, q2, q3, b_gyroX, b_gyroY, b_gyroZ]
    Use orientation + gyro bias, ignoring magnetometer.
    """
    def __init__(self, dt=1/31.125, var_gyro=1e-5, var_acc=5e-2):
        self.dt = dt
        self.x = np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
        self.P = np.eye(7)*0.01
        self.Q = np.eye(7)*var_gyro
        self.Q[0:4,0:4]*=0.1
        self.R = np.eye(3)*var_acc

    def predict(self, gyro_meas: np.ndarray):
        q = self.x[0:4]
        b_gyro = self.x[4:7]
        w = gyro_meas - b_gyro
        q_new = integrate_gyro(q, w, self.dt)
        self.x[0:4] = q_new
        # simple linear P update
        F = np.eye(7)
        self.P = F@self.P@F.T + self.Q

    def update(self, acc_meas: np.ndarray):
        q = self.x[0:4]
        R_sb = quat_to_rot_mat(q).T
        g_s = R_sb @ np.array([0,0,9.81])
        z = acc_meas - g_s
        eps=1e-6

        # numeric jac wrt q
        H = np.zeros((3,7))
        for i in range(4):
            dq = np.zeros(4)
            dq[i]= eps
            q_pert = normalize_quaternion(q + dq)
            Rp = quat_to_rot_mat(q_pert).T
            gp = Rp@np.array([0,0,9.81])
            dh = (gp - g_s)/eps
            H[:,i] = dh

        S= H@self.P@H.T + self.R
        K= self.P@H.T@np.linalg.inv(S)
        dx= K@z
        self.x+= dx
        self.x[0:4] = normalize_quaternion(self.x[0:4])
        I7= np.eye(7)
        self.P= (I7 - K@H)@self.P@(I7 - K@H).T + K@self.R@K.T

    def fuse(self, gyro: np.ndarray, acc: np.ndarray):
        self.predict(gyro)
        self.update(acc)
        q= self.x[0:4]
        R_sb= quat_to_rot_mat(q).T
        g_s= R_sb@np.array([0,0,9.81])
        return g_s

def ekf_fusion(acc_data: np.ndarray, gyro_data: np.ndarray, fs=31.125):
    dt= 1.0/fs
    ekf= EKFInertial(dt=dt, var_gyro=1e-5, var_acc=5e-2)
    N= acc_data.shape[0]
    out= np.zeros_like(acc_data)
    for i in range(N):
        g_s= ekf.fuse(gyro_data[i], acc_data[i])
        out[i]= acc_data[i] - g_s
    return out

###############################################################################
# Simple stats vs. raw
###############################################################################
def compare_filter_stats(raw_acc: np.ndarray, fused: np.ndarray, label:str):
    if fused.shape[0] != raw_acc.shape[0]:
        print(f"[ERROR] Mismatch shapes for {label} stats.")
        return
    N= raw_acc.shape[0]
    dims= ["X","Y","Z"]
    for d in range(3):
        ra= raw_acc[:,d]
        fu= fused[:,d]
        rmse= np.sqrt(np.mean((fu - ra)**2))
        corr= np.corrcoef(ra,fu)[0,1]
        print(f"{label} dim={dims[d]}: RMSE={rmse:.3f}, Corr={corr:.3f}")

###############################################################################
# Sliding Window
###############################################################################
def sliding_window(data: np.ndarray, window_size=128, stride=64):
    n= data.shape[0]
    feats= data.shape[1]
    out=[]
    i=0
    while i+window_size <= n:
        out.append(data[i:i+window_size])
        i+= stride
    if not out:
        return np.empty((0, window_size, feats), dtype=np.float32)
    return np.stack(out, axis=0)

###############################################################################
# Data Classes for watch data
###############################################################################
class MatchedWatchTrial:
    def __init__(self, s_id, a_id, t_id):
        self.subject_id= s_id
        self.activity_id= a_id
        self.trial_id= t_id
        self.acc_file= None
        self.gyr_file= None

class SmartFallMM_Watch:
    """
    Loads watch ACC from data/smartfallmm/young/accelerometer/watch
    and watch GYR from data/smartfallmm/young/gyroscope/watch
    => merges by SxxAyyTzz
    """
    def __init__(self, root_dir):
        self.root_dir= root_dir
        self.trials: Dict[Tuple[int,int,int], MatchedWatchTrial]= {}

    def _add_file(self, subj, act, tri, is_acc, filepath):
        key= (subj, act, tri)
        if key not in self.trials:
            self.trials[key]= MatchedWatchTrial(subj,act,tri)
        if is_acc:
            self.trials[key].acc_file= filepath
        else:
            self.trials[key].gyr_file= filepath

    def load_files(self):
        # For watch ACC => data/smartfallmm/young/accelerometer/watch
        # For watch GYR => data/smartfallmm/young/gyroscope/watch
        base_paths= [
            os.path.join(self.root_dir, "young","accelerometer","watch"),
            os.path.join(self.root_dir, "young","gyroscope","watch")
        ]
        for bp in base_paths:
            is_acc= ("accelerometer" in bp)
            for root,dirs,files in os.walk(bp):
                for f in files:
                    if not f.lower().endswith(".csv"):
                        continue
                    # parse S##A##T##
                    try:
                        s_id= int(f[1:3])
                        a_id= int(f[4:6])
                        t_id= int(f[7:9])
                        fullpath= os.path.join(root,f)
                        self._add_file(s_id,a_id,t_id, is_acc, fullpath)
                    except:
                        pass

def flexible_csvloader(file_path: str)-> np.ndarray:
    """
    Attempt to read the last 3 columns as float for (x,y,z).
    Accept both commas or semicolons.
    """
    try:
        df= pd.read_csv(file_path,sep='[;,]',engine='python',header=None).dropna().bfill()
        nc= df.shape[1]
        if nc<3:
            print(f"[ERROR] {file_path} only has {nc} cols.")
            return np.empty((0,3), dtype=np.float32)
        arr= df.iloc[:,-3:].astype(np.float32).to_numpy()
        print(f"[INFO] Loaded {file_path}, shape={arr.shape}")
        return arr
    except Exception as e:
        print(f"[ERROR] Could not parse file: {file_path}. Reason: {e}")
        return np.empty((0,3), dtype=np.float32)

def merge_acc_gyro(acc_arr: np.ndarray, gyr_arr: np.ndarray)-> np.ndarray:
    """
    Truncate to min length, combine => Nx6 => [accX,accY,accZ, gyroX,gyroY,gyroZ].
    """
    if acc_arr.shape[0]==0 or gyr_arr.shape[0]==0:
        return np.empty((0,6), dtype=np.float32)
    N= min(acc_arr.shape[0], gyr_arr.shape[0])
    merged= np.concatenate([acc_arr[:N], gyr_arr[:N]], axis=1)
    return merged

###############################################################################
# MultiFilterBuilder
###############################################################################
class MultiFilterBuilder:
    def __init__(self,
                 watch_dataset: SmartFallMM_Watch,
                 fs=31.125,
                 subject_of_interest=31,
                 trial_range=range(1,6),
                 window_size=32,
                 stride=32):
        self.dataset= watch_dataset
        self.fs= fs
        self.subject_of_interest= subject_of_interest
        self.trial_range= trial_range
        self.window_size= window_size
        self.stride= stride

    def _ensure_dir(self,path):
        os.makedirs(path, exist_ok=True)

    def _plot_raw_vs_filter(self, t_axis, raw_acc, raw_gyro, fused,
                            s_id, a_id, t_id, filter_name):
        fig, axs= plt.subplots(3,1, figsize=(16,12), dpi=200, sharex=True)

        axs[0].plot(t_axis, raw_acc[:,0], 'g', label='AccX')
        axs[0].plot(t_axis, raw_gyro[:,0],'m', label='GyroX')
        axs[0].plot(t_axis, fused[:,0], 'b', label=f'{filter_name}-X')
        axs[0].legend(loc='upper right')
        axs[0].set_ylabel('m/s^2')

        axs[1].plot(t_axis, raw_acc[:,1], 'g', label='AccY')
        axs[1].plot(t_axis, raw_gyro[:,1],'m', label='GyroY')
        axs[1].plot(t_axis, fused[:,1], 'b', label=f'{filter_name}-Y')
        axs[1].legend(loc='upper right')
        axs[1].set_ylabel('m/s^2')

        axs[2].plot(t_axis, raw_acc[:,2], 'g', label='AccZ')
        axs[2].plot(t_axis, raw_gyro[:,2],'m', label='GyroZ')
        axs[2].plot(t_axis, fused[:,2], 'b', label=f'{filter_name}-Z')
        axs[2].legend(loc='upper right')
        axs[2].set_ylabel('m/s^2')
        axs[2].set_xlabel('Time (s)')

        fig.suptitle(f"S{s_id:02d}A{a_id:02d}T{t_id:02d}: Raw vs. {filter_name}")
        return fig

    def _plot_all_filters(self, t_axis, raw_acc, raw_gyro, fused_dict,
                          s_id, a_id, t_id, norm=False):
        """
        fused_dict = { 'madgwick': Nx3, 'comp': Nx3, 'mahony': Nx3, 'ekf': Nx3 }
        If norm=True => local normalization for each filter’s data
        """
        # optionally normalize
        if norm:
            to_plot= {}
            for k,v in fused_dict.items():
                mean_= v.mean(axis=0)
                std_= v.std(axis=0)+1e-9
                to_plot[k]= (v-mean_)/std_
        else:
            to_plot= fused_dict

        fig, axs= plt.subplots(3,1, figsize=(16,12), dpi=200, sharex=True)
        color_map= {"madgwick":'r', "comp":'c', "mahony":'y', "ekf":'b'}

        axs[0].plot(t_axis, raw_acc[:,0], 'g', label='AccX')
        axs[0].plot(t_axis, raw_gyro[:,0], 'm', label='GyroX')
        for name,arr_ in to_plot.items():
            axs[0].plot(t_axis, arr_[:,0], color_map[name], label=f'{name}-X')
        axs[0].legend(loc='upper right')
        axs[0].set_ylabel('X')

        axs[1].plot(t_axis, raw_acc[:,1], 'g', label='AccY')
        axs[1].plot(t_axis, raw_gyro[:,1], 'm', label='GyroY')
        for name,arr_ in to_plot.items():
            axs[1].plot(t_axis, arr_[:,1], color_map[name], label=f'{name}-Y')
        axs[1].legend(loc='upper right')
        axs[1].set_ylabel('Y')

        axs[2].plot(t_axis, raw_acc[:,2], 'g', label='AccZ')
        axs[2].plot(t_axis, raw_gyro[:,2], 'm', label='GyroZ')
        for name,arr_ in to_plot.items():
            axs[2].plot(t_axis, arr_[:,2], color_map[name], label=f'{name}-Z')
        axs[2].legend(loc='upper right')
        axs[2].set_ylabel('Z')
        axs[2].set_xlabel('Time (s)')

        txt= "All Filters (Norm)" if norm else "All Filters (Unnorm)"
        fig.suptitle(f"S{s_id:02d}A{a_id:02d}T{t_id:02d}: {txt}")
        return fig

    def process_all(self):
        for (s_id, a_id, t_id), trial_obj in self.dataset.trials.items():
            if s_id!= self.subject_of_interest:
                continue
            if t_id<1 or t_id>5:
                continue
            if (trial_obj.acc_file is None) or (trial_obj.gyr_file is None):
                continue

            acc_data= flexible_csvloader(trial_obj.acc_file)
            gyr_data= flexible_csvloader(trial_obj.gyr_file)
            merged= merge_acc_gyro(acc_data, gyr_data)
            if merged.shape[0]==0:
                print(f"[WARNING] Empty merge for S{s_id:02d}A{a_id:02d}T{t_id:02d}")
                continue

            raw_acc= merged[:,:3]
            raw_gyro= merged[:,3:6]
            N= raw_acc.shape[0]
            t_axis= np.arange(N)/self.fs

            # run each filter
            madg_out= madgwick_fusion(raw_acc, raw_gyro, fs=self.fs, beta=0.05)
            comp_out= complementary_fusion(raw_acc, raw_gyro, fs=self.fs, alpha=0.98)
            mahy_out= mahony_fusion(raw_acc, raw_gyro, fs=self.fs, kp=2.0, ki=0.0)
            ekf_out= ekf_fusion(raw_acc, raw_gyro, fs=self.fs)

            print(f"\n[S{ s_id:02d} A{ a_id:02d} T{t_id:02d}] Stats vs. Raw Acc:")
            compare_filter_stats(raw_acc, madg_out, "Madgwick")
            compare_filter_stats(raw_acc, comp_out, "Complementary")
            compare_filter_stats(raw_acc, mahy_out, "Mahony")
            compare_filter_stats(raw_acc, ekf_out,  "EKF")

            outdir= f"visualizations/{s_id}/{a_id}"
            self._ensure_dir(outdir)

            # individual plots
            fig_madg= self._plot_raw_vs_filter(t_axis, raw_acc, raw_gyro, madg_out,
                                               s_id,a_id,t_id, "Madgwick")
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_madgwick.png")
            plt.close(fig_madg)

            fig_comp= self._plot_raw_vs_filter(t_axis, raw_acc, raw_gyro, comp_out,
                                               s_id,a_id,t_id, "Complementary")
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_complementary.png")
            plt.close(fig_comp)

            fig_mahony= self._plot_raw_vs_filter(t_axis, raw_acc, raw_gyro, mahy_out,
                                                 s_id,a_id,t_id, "Mahony")
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_mahony.png")
            plt.close(fig_mahony)

            fig_ekf= self._plot_raw_vs_filter(t_axis, raw_acc, raw_gyro, ekf_out,
                                              s_id,a_id,t_id, "EKF")
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_ekf.png")
            plt.close(fig_ekf)

            # overlay all
            fused_dict= {
                "madgwick": madg_out,
                "comp": comp_out,
                "mahony": mahy_out,
                "ekf": ekf_out
            }
            fig_all= self._plot_all_filters(t_axis, raw_acc, raw_gyro, fused_dict,
                                            s_id,a_id,t_id, norm=False)
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_allFilters.png")
            plt.close(fig_all)

            fig_allN= self._plot_all_filters(t_axis, raw_acc, raw_gyro, fused_dict,
                                             s_id,a_id,t_id, norm=True)
            plt.savefig(f"{outdir}/S{s_id:02d}A{a_id:02d}T{t_id:02d}_allFiltersNorm.png")
            plt.close(fig_allN)

            # sliding window for e.g. ekf_out
            ekf_windows= sliding_window(ekf_out, self.window_size, self.stride)
            print(f"[INFO] S{s_id:02d}A{a_id:02d}T{t_id:02d} => EKF windows: {ekf_windows.shape[0]}")
            print(f"[DONE] S{s_id:02d}A{a_id:02d}T{t_id:02d} => Plots in {outdir}")

###############################################################################
# 6. Main
###############################################################################
if __name__=="__main__":
    dataset_root= "../data/smartfallmm"
    watch_data= SmartFallMM_Watch(dataset_root)
    watch_data.load_files()
    print(f"[INFO] Found {len(watch_data.trials)} watch-based trials in {dataset_root}")

    builder= MultiFilterBuilder(
        watch_dataset= watch_data,
        fs=31.125,
        subject_of_interest=31,
        trial_range= range(1,6),
        window_size=32,
        stride=32
    )
    builder.process_all()
    print("[INFO] Completed multi-filter comparisons for subject=31, T=1..5.")


[INFO] Found 906 watch-based trials in ../data/smartfallmm
[INFO] Loaded ../data/smartfallmm/young/accelerometer/watch/S31A06T01.csv, shape=(441, 3)
[INFO] Loaded ../data/smartfallmm/young/gyroscope/watch/S31A06T01.csv, shape=(438, 3)

[S31 A06 T01] Stats vs. Raw Acc:
Madgwick dim=X: RMSE=4.792, Corr=0.786
Madgwick dim=Y: RMSE=5.558, Corr=0.925
Madgwick dim=Z: RMSE=6.510, Corr=0.567
Complementary dim=X: RMSE=7.380, Corr=0.926
Complementary dim=Y: RMSE=4.299, Corr=0.950
Complementary dim=Z: RMSE=4.826, Corr=0.879
Mahony dim=X: RMSE=4.903, Corr=0.788
Mahony dim=Y: RMSE=6.502, Corr=0.933
Mahony dim=Z: RMSE=5.470, Corr=0.701
EKF dim=X: RMSE=6.426, Corr=0.464
EKF dim=Y: RMSE=6.056, Corr=0.882
EKF dim=Z: RMSE=4.275, Corr=0.729
[INFO] S31A06T01 => EKF windows: 6
[DONE] S31A06T01 => Plots in visualizations/31/6
[INFO] Loaded ../data/smartfallmm/young/accelerometer/watch/S31A07T01.csv, shape=(211, 3)
[INFO] Loaded ../data/smartfallmm/young/gyroscope/watch/S31A07T01.csv, shape=(208, 3)

[S31 A07