In [130]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/cmi-detect-behavior-with-sensor-data/train_demographics.csv
/kaggle/input/cmi-detect-behavior-with-sensor-data/test_demographics.csv
/kaggle/input/cmi-detect-behavior-with-sensor-data/train.csv
/kaggle/input/cmi-detect-behavior-with-sensor-data/test.csv
/kaggle/input/cmi-detect-behavior-with-sensor-data/kaggle_evaluation/cmi_inference_server.py
/kaggle/input/cmi-detect-behavior-with-sensor-data/kaggle_evaluation/cmi_gateway.py
/kaggle/input/cmi-detect-behavior-with-sensor-data/kaggle_evaluation/__init__.py
/kaggle/input/cmi-detect-behavior-with-sensor-data/kaggle_evaluation/core/templates.py
/kaggle/input/cmi-detect-behavior-with-sensor-data/kaggle_evaluation/core/base_gateway.py
/kaggle/input/cmi-detect-behavior-with-sensor-data/kaggle_evaluation/core/relay.py
/kaggle/input/cmi-detect-behavior-with-sensor-data/kaggle_evaluation/core/kaggle_evaluation.proto
/kaggle/input/cmi-detect-behavior-with-sensor-data/kaggle_evaluation/core/__init__.py
/kaggle/input/cmi-detect-behav

In [131]:
import numpy as np
import warnings
import pandas as pd
from sklearn.preprocessing import LabelEncoder, StandardScaler
from IPython.display import display
from scipy.spatial.transform import Rotation as R
import os, joblib
import torch
import torch.nn.functional as F
from scipy.signal import find_peaks
from scipy.signal import butter, filtfilt
import pickle
from tqdm import tqdm
from sklearn.metrics import f1_score,  recall_score
import torch
import polars as pl
from pathlib import Path
import inspect
import psutil
from scipy.signal import welch
from scipy.stats import entropy

warnings.filterwarnings('ignore')

# =============================================================================
# CONFIGURATION --
# =============================================================================

class Config: 
    """Central configuration class for training and data parameters"""

    # Paths for Kaggle environment
    TRAIN_PATH = "/kaggle/input/cmi-detect-behavior-with-sensor-data/train.csv"
    TRAIN_DEMOGRAPHICS_PATH = "/kaggle/input/cmi-detect-behavior-with-sensor-data/train_demographics.csv"
    TEST_PATH = "/kaggle/input/cmi-detect-behavior-with-sensor-data/test.csv"
    TEST_DEMOGRAPHICS_PATH = "/kaggle/input/cmi-detect-behavior-with-sensor-data/test_demographics.csv"
    EXPORT_DIR =  "/kaggle/input/data-input"
    EXPORT_MODELS_PATH = "/kaggle/input/models" #"/kaggle/working/models"  
    EXPORT_MODELS_PATH_OUTPUT = "/kaggle/working/models" #"/kaggle/working/models"  
    os.makedirs(EXPORT_DIR, exist_ok=True)                                 
    os.makedirs(EXPORT_MODELS_PATH, exist_ok=True)                                 
    os.makedirs(EXPORT_MODELS_PATH_OUTPUT, exist_ok=True)     

    # Training parameters
    SEED = 42
    N_FOLDS = 5
    PERCENTILE = 95
    PADDING = 127
    
    # Feature columns
    ACC_COLS = ['acc_x', 'acc_y', 'acc_z']
    ROT_COLS = ['rot_w', 'rot_x', 'rot_y', 'rot_z']
    
# Set reproducibility
np.random.seed(Config.SEED)

def check_gpu_availability():

    import torch
    if torch.cuda.is_available():
        #print("MPS (Apple GPU) is available.")
        return  'cuda'#'cuda:0'
    else:
        #print("MPS not available. Using CPU.")
        return 'cpu'

# Check GPU availability
DEVICE = torch.device(check_gpu_availability())
print(DEVICE)

def clean_data(data_sequences, cols, prefix = 'both'):
    
    if prefix == 'both':
        print("removing tof and thm missing data columns from sequences! Saving seq_id in a dic with cols to remove")
        tof_and_thm_cols = [col for col in cols if (col.startswith('thm') or col.startswith('tof')) ]
    else: 
        print(f"removing {prefix} missing data columns from sequences! Saving seq_id in a dic with cols to remove")
        tof_and_thm_cols = [col for col in cols if col.startswith(prefix) ]

    tof_thm_nan_prefixes = {}
    for sequence_id, sequence_data in data_sequences:
        nan_cols = sequence_data[tof_and_thm_cols].columns[sequence_data[tof_and_thm_cols].isna().any()]
        if nan_cols.any():
            if (prefix == 'both' or prefix == 'tof'):
                prefixes = set(col.rsplit("_", 1)[0] for col in nan_cols if col.startswith('tof'))
            else:
                prefixes = set()
            if (prefix == 'both' or prefix == 'thm'):
                prefixes.update(set(col for col in nan_cols if col.startswith('thm')))

            tof_thm_nan_prefixes[sequence_id] = prefixes
            cols_to_drop = [col for col in sequence_data.columns if any(col.startswith(p) for p in prefixes)]
            sequence_data = sequence_data.drop(columns=cols_to_drop)
    print(f"found {len(tof_thm_nan_prefixes)} sequences with missing data")
    return data_sequences, tof_thm_nan_prefixes

def handle_missing_values_quaternions(quaternion):
    quat_clean = quaternion.copy()
    
    number_of_nan = quaternion.isna().sum(axis = 1)
    rows_with_0_nan = number_of_nan == 0
    rows_with_1_nan = number_of_nan == 1
    rows_with_N_nan = number_of_nan > 1

    ### normalize quaternions to 1 when no NaN has been detected 
    quat_values = quaternion.loc[rows_with_0_nan].values
    norms = np.linalg.norm(quat_values, axis = 1)
    normalized_quats = np.zeros_like(quat_values)
    ## for non-zero norm, normalize to 1  
    nonzero_norms = norms > 1e-6
    normalized_quats[nonzero_norms] = quat_values[nonzero_norms] / norms[nonzero_norms, np.newaxis]
    ## for zero-norm, normalize to the unit quaternion
    normalized_quats[~nonzero_norms] = [1.0, 0.0, 0.0, 0.0]
    ##update quaternion DataFrame
    quat_clean.loc[rows_with_0_nan] = normalized_quats

    ###handle 1 missing value 
    #use |w|² + |x|² + |y|² + |z|² = 1
    if len(quaternion[rows_with_1_nan].index.tolist()) > 0:
        nan_columns_per_row = quaternion[rows_with_1_nan].isna().idxmax(axis=1)
        unnorm_quat = quaternion[rows_with_1_nan].pow(2).sum(axis =1, skipna = True)
        vals = np.sqrt(np.maximum(0, 1 - unnorm_quat))
        for row, col, val in zip(unnorm_quat.index, nan_columns_per_row, vals):
            if row > 0:
                if quat_clean.loc[row - 1, col] >= 0:
                    quat_clean.loc[row, col] = val
                else:
                    quat_clean.loc[row, col] = -val
            else:
                next_row = row + 1
                # Go forward until a non-NaN is found or reach the end
                while next_row < len(quat_clean) and np.isnan(quat_clean.loc[next_row, col]):
                    next_row += 1
                if next_row == len(quat_clean):
                    quat_clean.loc[rows_with_1_nan] = [0, 0, 0, 0]
                    quat_clean.loc[rows_with_1_nan, 'rot_w'] = 1
                    break
                else:
                    if quat_clean.loc[next_row, col] >= 0:
                        quat_clean.loc[row, col] = val
                    else:
                        quat_clean.loc[row, col] = -val
    quat_clean.loc[rows_with_N_nan] = [0, 0, 0, 0]
    quat_clean.loc[rows_with_N_nan, 'rot_w'] = 1
    return quat_clean

def check_missing_values_quaternion(data_sequences):
    seq_id_quaternion_nan = []
    check_norm_quaternion = []
    for seq_id, data_sequence in data_sequences:
        quaternion_cols = [col for col in data_sequence.columns if col.startswith('rot_')]
        nan_quat_cols = data_sequence[quaternion_cols].columns[data_sequence[quaternion_cols].isna().any()]
        normalize_quat = data_sequence[quaternion_cols].pow(2).sum(axis = 1).mean()
        if nan_quat_cols.any():
            #print(data_sequence[[col for col in data_sequence.columns if col.startswith('acc_')]])
            seq_id_quaternion_nan.append(seq_id)
        if (not nan_quat_cols.any()) and normalize_quat < 0.99:
            check_norm_quaternion.append(seq_id)
    print(f"✓ number of seq_id with missing values in quaternion: {len(seq_id_quaternion_nan)}")
    print(f"✓ number of unnormalized quaternions for complete quaternions: {len(check_norm_quaternion)}")
    return seq_id_quaternion_nan


def regularize_quaternions_per_sequence(data_sequence):
    data_clean = data_sequence.copy()
    quaternion_cols = [col for col in data_sequence.columns if col.startswith('rot_')]
    nan_quat_cols = data_sequence[quaternion_cols].columns[data_sequence[quaternion_cols].isna().any()]
    normalize_quat = data_sequence[quaternion_cols].pow(2).sum(axis = 1).mean()  
    if nan_quat_cols.any():
        data_clean[quaternion_cols] = handle_missing_values_quaternions(data_sequence[quaternion_cols])
    if (not nan_quat_cols.any()) and normalize_quat < 0.99:
        data_clean[quaternion_cols] = handle_missing_values_quaternions(data_sequence[quaternion_cols])

    ### Check failed regularization
    nan_quat_cols_clean = data_clean[quaternion_cols].columns[data_clean[quaternion_cols].isna().any()]
    normalize_quat_clean = data_clean[quaternion_cols].pow(2).sum(axis = 1).mean() 
    if nan_quat_cols_clean.any():
        print("!!NaN values have been detected after regularisation!!")
    if (not nan_quat_cols_clean.any()) and normalize_quat_clean < 0.99:
        print("!!Not normalized quaternions have been detected after regularisation!!")
    return data_clean



def clean_and_check_quaternion(data):
    data_clean = data.copy()
    data_sequences = data_clean.groupby('sequence_id')
    seq_id_quaternion_nan = check_missing_values_quaternion(data_sequences)
    if len(seq_id_quaternion_nan) > 0:
        for seq_id in seq_id_quaternion_nan:
            data_sequence = data_sequences.get_group(seq_id)
            idx = data_sequence.index  # Get the index of the group
            quaternion_cols = [col for col in data_sequence.columns if col.startswith('rot_')]
            # Apply quaternion cleaning function
            data_clean.loc[idx, quaternion_cols] = handle_missing_values_quaternions(data_sequence[quaternion_cols])
    ##Check quaternion
        data_sequences = data_clean.groupby('sequence_id')
        print("")
        print(" --- missing values in quaternions have been handled ---")
        check_missing_values_quaternion(data_sequences)
        print("")
    return data_clean

def compute_acceleration_features(sequence_data, demographics):
    sequence_data_with_acc = sequence_data.copy()
    correct_rot_order = ['rot_x', 'rot_y', 'rot_z', 'rot_w']
    correct_acc_order = ['acc_x', 'acc_y', 'acc_z']
    col_acc_world = ['acc_x_world', 'acc_y_world', 'acc_z_world']
    col_linear_acc = ['linear_acc_x', 'linear_acc_y', 'linear_acc_z']
    col_X_world = ['X_world_x', 'X_world_y', 'X_world_z']
    col_Y_world = ['Y_world_x', 'Y_world_y', 'Y_world_z']
    col_Z_world = ['Z_world_x', 'Z_world_y', 'Z_world_z']
    remove_gravity = [0, 0, 9.81]
    
    data_rot = sequence_data[correct_rot_order]
    data_acc = sequence_data[correct_acc_order]
    sensor_x = np.zeros( data_acc.to_numpy().shape )
    sensor_y = np.zeros( data_acc.to_numpy().shape )
    sensor_z = np.zeros( data_acc.to_numpy().shape )
    sensor_x[:, 0] = 1
    sensor_y[:, 1] = 1
    sensor_z[:, 2] = 1
    data_rot_scipy = data_rot.to_numpy() 

    try:
        r = R.from_quat(data_rot_scipy)
        sequence_data_with_acc[col_acc_world] = pd.DataFrame(r.apply(data_acc.to_numpy()) - remove_gravity)
        sequence_data_with_acc[col_X_world] = pd.DataFrame(r.apply(sensor_x))
        sequence_data_with_acc[col_Y_world] = pd.DataFrame(r.apply(sensor_y))
        sequence_data_with_acc[col_Z_world] = pd.DataFrame(r.apply(sensor_z))
        
        gravity_in_sensor = r.apply(remove_gravity, inverse=True)
        acc_raw = sequence_data_with_acc[correct_acc_order].values
        linear_acc = acc_raw - gravity_in_sensor
        sequence_data_with_acc[col_linear_acc] = linear_acc

    except ValueError:
        print("Warning: world accelerations failed using device accelerations, replace by device acc data")
        sequence_data_with_acc[col_linear_acc] = sequence_data_with_acc[correct_acc_order]
        sequence_data_with_acc[col_acc_world] = sequence_data_with_acc[correct_acc_order]
        sequence_data_with_acc[col_X_world] = sequence_data_with_acc[correct_acc_order]
        sequence_data_with_acc[col_Y_world] = sequence_data_with_acc[correct_acc_order]
        sequence_data_with_acc[col_Z_world] = sequence_data_with_acc[correct_acc_order]

    sequence_data_with_acc['acc_norm_world'] =sequence_data_with_acc[col_acc_world].apply(np.linalg.norm, axis=1)
    sequence_data_with_acc['acc_norm'] =sequence_data_with_acc[correct_acc_order].apply(np.linalg.norm, axis=1)
    sequence_data_with_acc['linear_acc_norm'] =sequence_data_with_acc[col_linear_acc].apply(np.linalg.norm, axis=1)
    sequence_data_with_acc['acc_norm_jerk'] = sequence_data_with_acc['acc_norm'].diff().fillna(0)
    sequence_data_with_acc['linear_acc_norm_jerk'] =  sequence_data_with_acc['linear_acc_norm'].diff().fillna(0)

    subject = sequence_data['subject'].iloc[0]
    handedness = demographics[demographics['subject'] == subject]['handedness'].iloc[0] ## (0): left, (1): right
    if handedness == 0:
        sequence_data_with_acc['acc_x'] = - sequence_data_with_acc['acc_x'] #+ (-0.8526133780336856 + 0.3518238644621146)
        sequence_data_with_acc['linear_acc_x'] = - sequence_data_with_acc['linear_acc_x'] #+ (-0.8526133780336856 + 0.3518238644621146)

    return sequence_data_with_acc

def compute_angular_features(sequence_data, demographics, time_delta = 10):
    sequence_data_with_ang_vel = sequence_data.copy()
    correct_rot_order = ['rot_x', 'rot_y', 'rot_z', 'rot_w']
    quats = sequence_data[correct_rot_order].values

    rotations = R.from_quat(quats)
    rotvecs = rotations.as_rotvec()
    sequence_data_with_ang_vel[['rotvec_x', 'rotvec_y', 'rotvec_z']] = rotvecs
    sequence_data_with_ang_vel['angle_rad'] =  sequence_data_with_ang_vel[['rotvec_x', 'rotvec_y', 'rotvec_z']].apply(np.linalg.norm, axis=1)
    rot_diff = sequence_data_with_ang_vel[['rotvec_x', 'rotvec_y', 'rotvec_z']].diff().fillna(0)
    sequence_data_with_ang_vel['angular_speed'] = rot_diff.pow(2).sum(axis=1).pow(0.5)
    sequence_data_with_ang_vel['rot_angle'] = 2 * np.arccos(sequence_data['rot_w'].clip(-1, 1))
    sequence_data_with_ang_vel['rot_angle_vel'] = sequence_data_with_ang_vel['rot_angle'].diff().fillna(0)
    
    n_samples = quats.shape[0]
    ang_vel = np.zeros( (n_samples, 3))
    ang_dist = np.zeros(n_samples)

    for i in range(n_samples - 1):
        q1 = quats[i]
        q2 = quats[i + 1]

        if np.any(np.isnan(q1)) or np.any(np.isnan(q2)):
            continue

        try:
            r1 = R.from_quat(q1)
            r2 = R.from_quat(q2)

            # Relative rotation from q1 to q2
            delta_r = r1.inv() * r2

            # Angle of rotation (in radians)
            ang_vel[i, : ] =  delta_r.as_rotvec()/time_delta
            ang_dist[i] = np.linalg.norm(delta_r.as_rotvec())
        except ValueError:
            pass

    sequence_data_with_ang_vel[['ang_vel_x', 'ang_vel_y', 'ang_vel_z']] = ang_vel
    sequence_data_with_ang_vel['ang_dist'] = ang_dist

    subject = sequence_data['subject'].iloc[0]
    handedness = demographics[demographics['subject'] == subject]['handedness'].iloc[0] ## (0): left, (1): right
    if handedness == 0:
        sequence_data_with_ang_vel['rotvec_x'] = - sequence_data_with_ang_vel['rotvec_x'] #+ (-0.8526133780336856 + 0.3518238644621146)
        sequence_data_with_ang_vel['ang_vel_y'] = - sequence_data_with_ang_vel['ang_vel_y'] #+ (-0.8526133780336856 + 0.3518238644621146)
        sequence_data_with_ang_vel['ang_vel_z'] = - sequence_data_with_ang_vel['ang_vel_z'] #+ (-0.8526133780336856 + 0.3518238644621146)

    return sequence_data_with_ang_vel

def fft_gesture(signal):
    """
    Compute the normalized power in a band around a target frequency.

    Parameters:
    - signal: 1D array-like signal
    - freq: frequency of interest (Hz)
    - sampling_rate: sampling rate in Hz
    - bandwidth_ratio: fraction of freq to define integration window (e.g., 0.05 for ±5%)

    Returns:
    - normalized_band_power: power in [freq ± bandwidth] / total power
    """
    signal = np.asarray(signal)
    n = len(signal)
    #freqs = np.fft.rfftfreq(n, d=1./sampling_rate)
    fft_vals = np.fft.rfft(signal)
    power_spectrum = np.abs(fft_vals)**2 / n
    return power_spectrum / np.sum(power_spectrum)

def compute_fft_features(sequence_data):
    sequence_data_fft = sequence_data.copy()
    fft_to_compute = [
        'acc_x', 'acc_y', 'acc_z',
        'linear_acc_x', 'linear_acc_y', 'linear_acc_z',
        'rotvec_x', 'rotvec_y', 'rotvec_z',
        'ang_vel_x', 'ang_vel_y', 'ang_vel_z',
        'acc_norm', 'angle_rad'
    ]
    check_quat = ['rot_x', 'rot_y', 'rot_z']
    phase_gesture = sequence_data['phase_adj'] == 1    
    for feat in fft_to_compute:
        signal = sequence_data.loc[phase_gesture, feat].to_numpy()
        if sequence_data[check_quat].apply(np.linalg.norm, axis=1).mean() < 1e-6:
            signal_fft_pad = np.zeros_like(sequence_data[feat])
        else:
            signal_fft = fft_gesture( (signal - np.mean(signal))/np.std(signal) )
            signal_fft_pad =np.pad(signal_fft, (0, len(phase_gesture) - len(signal_fft)), 'constant')
        sequence_data_fft[f'{feat}_FFT'] = signal_fft_pad
    
    return sequence_data_fft

def get_angles(time_series, world_coord = False):
    theta, phi = [], []
    acc_features = ['acc_norm', 'acc_x', 'acc_y', 'acc_z']
    f_phi, f_theta = 'phi', 'theta'
    if world_coord:
        add_name = '_world'
        acc_features = [f + add_name for f in acc_features]
        f_phi, f_theta = f_phi + add_name, f_theta + add_name

    #numpy_time_series = time_series[acc_features].to_numpy()
    for a, ax, ay, az in zip(*time_series[acc_features].to_numpy().T):
        # Avoid division by zero
        # if a < 0:
        #     print(a)
        th = np.arccos(np.clip(az / (a + 1e-8), -1.0, 1.0))  # polar angle
        ph = np.arctan2(ay, ax)  # azimuthal angle
        theta.append(th)
        phi.append(ph)
    time_series[f_theta] = np.array(theta)
    time_series[f_phi] = np.array(phi)
    return time_series

def autocorr_frequency(signal, sampling_rate=1.0, min_lag=2, max_lag=None):
    """
    Estimate the dominant frequency in a signal using autocorrelation.

    Parameters:
    - signal: list or np.array of values
    - sampling_rate: Hz
    - min_lag: minimum lag to consider (to skip lag 0 and noise)
    - max_lag: optional max lag to consider

    Returns:
    - dominant_freq: float or None (in Hz)
    """
    signal = np.array(signal)
    if len(signal) < min_lag + 2:
        return 0.

    # Normalize and detrend
    signal = signal - np.mean(signal)
    autocorr = np.correlate(signal, signal, mode='full')
    autocorr = autocorr[len(autocorr)//2:]  # Keep only non-negative lags
    autocorr /= autocorr[0]  # Normalize

    # Define lag range to search
    if max_lag is None:
        max_lag = len(signal) #// 2

    search_range = autocorr[min_lag:max_lag]

    # Find peaks in the autocorrelation
    peaks, _ = find_peaks(search_range)

    if len(peaks) < 2:
        return 0.

    first_peak_lag = (peaks[-1] - peaks[0])/(len(peaks)-1)  # adjust for sliced lag
    period = first_peak_lag / sampling_rate
    freq = 1.0 / period

    return freq

def remove_frequency_component(signal, freq, sampling_rate, bandwidth=1.0, order=4):
    """
    Remove a specific frequency component using a Butterworth band-stop filter.

    Parameters:
    - signal: np.array of the signal values
    - freq: the target frequency to remove (Hz)
    - sampling_rate: the sampling rate of the signal (Hz)
    - bandwidth: the width of the stop band (Hz)
    - order: filter order (higher = steeper filter)

    Returns:
    - filtered_signal: the signal with the frequency component removed
    """
    bandwidth = 1 * freq
    nyquist = 0.5 * sampling_rate
    low = (freq - bandwidth / 2) / nyquist
    high = (freq + bandwidth / 2) / nyquist

    if low <= 0 or high >= 1:
        # Invalid range – don't apply filtering
        return signal.copy()

    # Create band-stop filter
    b, a = butter(order, [low, high], btype='bandstop')

    # Calculate required padding length
    padlen = 3 * max(len(a), len(b))

    if len(signal) <= padlen:
        # Too short for reliable filtering
        b, a = butter(order, [low, high], btype='bandstop')
        padlen = 3 * max(len(a), len(b))
        filtered_signal = filtfilt(b, a, signal, padlen=min(padlen, len(signal) - 1))
        # if len(signal) <= padlen:
        #     print("short")
        #     return signal.copy()
    else:
        filtered_signal = filtfilt(b, a, signal)

    return filtered_signal


def extract_freq_features(signal, fs=10):  # signal: [T, 3] for x,y,z IMU
    features = []
    for axis in range(signal.shape[1]):
        f, Pxx = welch(signal[:, axis], fs=fs, nperseg=fs)
        Pxx /= Pxx.sum()  # Normalize power spectrum

        centroid = np.sum(f * Pxx)
        entropy_val = entropy(Pxx)
        rolloff = f[np.where(np.cumsum(Pxx) >= 0.85)[0][0]]
        peak_freq = f[np.argmax(Pxx)]
        flatness = np.exp(np.mean(np.log(Pxx + 1e-8))) / (np.mean(Pxx) + 1e-8)

        features += [centroid, entropy_val, rolloff, peak_freq, flatness]

    return features  # [5 features x 3 axes = 15 features]

def sliding_window_freq_features(data_sequence, fs=10, window_size=10, stride=10):
    """
    data: (N, T, C) - batch of sequences
    returns: (N, T_new, F_freq)
    """
    names = ['centroid', 'entropy_val', 'rolloff', 'peak_freq', 'flatness']
    data_sequence_with_FFT = data_sequence.copy()
    signal = data_sequence[['acc_x', 'acc_y', 'acc_z']].to_numpy()
    T, _ = signal.shape
    features = []
    for i in range(0, T - window_size + 1, stride):
        window = signal[i:i+window_size, :]  # (N, w, C)
        f_list = extract_freq_features(window, fs=fs)  # for each sequence
        for j in range(window_size):
            features.append(f_list)  # (T, F_freq)
    print(np.array(features).shape)
    data_sequence_with_FFT[names] = np.array(features)
    # Stack over time: (T_new, N, F) → transpose → (N, T_new, F)
    return data_sequence_with_FFT


def compute_theta_phi_features(sequence_data):
    sequence_data_theta_phi = sequence_data.copy()
        
    sequence_data_theta_phi = get_angles(sequence_data_theta_phi)
    sequence_data_theta_phi = get_angles(sequence_data_theta_phi, world_coord=True)

    signal_phi = sequence_data_theta_phi['phi_world'].to_numpy()
    
    dym_zero_cross = np.zeros(len(signal_phi))
    window_size = 10
    for i in range(window_size, len(signal_phi)):
        window_phi = signal_phi[i-window_size: i]
        dym_zero_cross[i] = np.sum(np.diff(np.signbit(window_phi)).astype(int))

    sequence_data_theta_phi['zero_crossings_phi_dyn'] = dym_zero_cross
    return sequence_data_theta_phi

def compute_corr_and_svd_features(sequence_data):
    sequence_data_with_corr_and_svd = sequence_data.copy()

    svd_axis = [
        ['acc_x', 'acc_y', 'acc_z'],
        ['linear_acc_x', 'linear_acc_y', 'linear_acc_z'],
        ['rotvec_x', 'rotvec_y', 'rotvec_z'],
        ['ang_vel_x', 'ang_vel_y', 'ang_vel_z']
    ]
    corr_features = [
        ('acc_x', 'acc_y'),
        ('acc_x', 'acc_z'),
        ('acc_y', 'acc_z'),
        ('linear_acc_x', 'linear_acc_y'),
        ('linear_acc_x', 'linear_acc_z'),
        ('linear_acc_y', 'linear_acc_z'),
        ('ang_vel_x', 'ang_vel_y'),
        ('ang_vel_x', 'ang_vel_z'),
        ('ang_vel_y', 'ang_vel_z'),
        ('rotvec_x', 'rotvec_y'),
        ('rotvec_x', 'rotvec_z'),
        ('rotvec_y', 'rotvec_z'),
        ('acc_norm', 'angle_rad'),
        ('acc_x', 'rotvec_x'),
        ('acc_x', 'rotvec_y'),
        ('acc_x', 'rotvec_z'),
        ('acc_y', 'rotvec_x'),
        ('acc_y', 'rotvec_y'),
        ('acc_y', 'rotvec_z'),
        ('acc_z', 'rotvec_x'),
        ('acc_z', 'rotvec_y'),
        ('acc_z', 'rotvec_z'),
        ('theta', 'phi'),
        ('theta_world', 'phi_world')
    ]

    for main_axes in svd_axis:
        #svd_features = [f + '_svd' for f in main_axes]
        principal_axis_features = [f + '_contribution_main_axis' for f in main_axes]

        name = '_'.join(main_axes[0].split('_')[:-1])
        svd_ratio_features = [f'{name}_ratio_svd_{i}' for i in range(len(main_axes[1:]))]
        svd_features = [f'{name}_svd_{i}' for i in range(len(main_axes))]

        acc_vec = sequence_data[main_axes].to_numpy()

        window_size = 10
        sv =  np.zeros( (3, len(acc_vec)) )
        sv_ratio = np.zeros( (2, len(acc_vec)) )
        principal_axis = np.zeros( (3, len(acc_vec)) )
        for i in range(window_size, len(acc_vec)):
            window = acc_vec[i-window_size: i]
            U, S, Vt = np.linalg.svd(window - window.mean(axis = 0))
            principal_axis[:, i] =  Vt[0] ** 2
            sv[:, i] = S
            sv_ratio[0, i] = S[1]/S[0]
            sv_ratio[1, i] = S[2]/S[0]

        sequence_data_with_corr_and_svd[svd_features] = sv.T        
        sequence_data_with_corr_and_svd[principal_axis_features] = principal_axis.T
        sequence_data_with_corr_and_svd[svd_ratio_features] = sv_ratio.T 

    phase_transition = sequence_data['phase_adj'] == 0
    phase_gesture = sequence_data['phase_adj'] == 1

    #f_freq = [f for f in sequence_data.columns if ('acc' in f) or ('rotvec' in f) or ('angle_rad' in f) or ('phi' in f) or ('theta' in f) or ('ang_vel' in f)]
    f_freq = [f for f in sequence_data.columns if any(substr in f for substr in ['acc', 'rotvec', 'angle_rad', 'phi', 'theta', 'angle_vel'])]

    for f in f_freq:
        f0_series = np.zeros(len(sequence_data))
        f1_series = np.zeros(len(sequence_data))
        ratio_freq = np.zeros(len(sequence_data))
        if phase_gesture.sum() > 1:  
            extracted_sig = sequence_data.loc[phase_gesture, f]
            f0 = autocorr_frequency(extracted_sig, sampling_rate=10)
            if f0 > 0:
                residual = remove_frequency_component(extracted_sig, f0, sampling_rate=10)
                f1 = autocorr_frequency(residual, sampling_rate=10)
                ratio_freq[phase_gesture] = f1 / f0

            else:
                f1 = 0.
                ratio_freq[phase_gesture] = 0.

            f0_series[phase_gesture] = f0
            f1_series[phase_gesture] = f1
            
        else:
            f0_series[phase_gesture] = 0.
            f1_series[phase_gesture] = 0.
            ratio_freq[phase_gesture] = 0.

        sequence_data_with_corr_and_svd[f'{f}_f0'] = f0_series
        sequence_data_with_corr_and_svd[f'{f}_f1'] = f1_series
        sequence_data_with_corr_and_svd[f'{f}_ratio_freqs'] = ratio_freq


    for sig1, sig2 in corr_features:
        # Initialize correlation series
        corr_series = np.zeros(len(sequence_data))

        if phase_transition.sum() > 1:  
            corr_trans = sequence_data.loc[phase_transition, sig1].corr(sequence_data.loc[phase_transition, sig2])
            corr_series[phase_transition] = corr_trans
        else:
            corr_series[phase_transition] = 0. 

        if phase_gesture.sum() > 1:
            corr_gest = sequence_data.loc[phase_gesture, sig1].corr(sequence_data.loc[phase_gesture, sig2])
            corr_series[phase_gesture] = corr_gest
        else:
            corr_series[phase_gesture] = 0.

        # Save in your dataframe
        sequence_data_with_corr_and_svd[f'{sig1}_{sig2}_corr'] = corr_series

    return sequence_data_with_corr_and_svd

def add_gesture_phase(sequence_data):
    sequence_data_phase = sequence_data.copy()
    length_sequence = len(sequence_data)
    idx_transition = int( 0.45 * length_sequence)
    phase = np.zeros(length_sequence)
    phase[idx_transition:] = 1.
    sequence_data_phase['phase_adj'] = phase
    return sequence_data_phase

def manage_tof(sequence_data, demographics):
    sequence_data_tof = sequence_data.copy()
    #tof_col = []
    for i in range(1, 6):
        pixel_cols = [f for f in sequence_data.columns if f'tof_{i}' in f]
        tof_data = sequence_data[pixel_cols].replace(-1, np.nan)
        sequence_data_tof[f'tof_{i}_mean'] = sequence_data[pixel_cols].mean(axis = 1)
        sequence_data_tof[f'tof_{i}_std'] = sequence_data[pixel_cols].std(axis = 1)
        sequence_data_tof[f'tof_{i}_min'] = tof_data.min(axis = 1)
        sequence_data_tof[f'tof_{i}_max'] = tof_data.max(axis = 1)

    subject = sequence_data['subject'].iloc[0]
    handedness = demographics[demographics['subject'] == subject]['handedness'].iloc[0] ## (0): left, (1): right
    if handedness == 2:
        cols_tof_3 = [col for col in sequence_data.columns if 'tof_3' in col]
        cols_thm_3 = [col for col in sequence_data.columns if 'thm_3' in col]
        cols_tof_5 = [col for col in sequence_data.columns if 'tof_5' in col]
        cols_thm_5 = [col for col in sequence_data.columns if 'thm_5' in col]
        rename_dict = {}
        # TOF3 <-> TOF5
        for c3, c5 in zip(cols_tof_3, cols_tof_5):
            rename_dict[c3] = c5
            rename_dict[c5] = c3

        # THM3 <-> THM5
        for c3, c5 in zip(cols_thm_3, cols_thm_5):
            rename_dict[c3] = c5
            rename_dict[c5] = c3

        sequence_data_tof.rename(columns=rename_dict, inplace=True)

    return sequence_data_tof

# def add_correlations_tof_imu(sequence_data):

def split_into_transition_and_gesture_phases(sequence_data, meta_cols):
    sequence_data_split = sequence_data.copy()
    df_transition = sequence_data[sequence_data['phase_adj'] == 0].drop(columns='phase_adj')
    df_gesture = sequence_data[sequence_data['phase_adj'] == 1].drop(columns='phase_adj')

    # Rename columns
    df_transition = df_transition.add_suffix('_transition')
    df_gesture = df_gesture.add_suffix('_gesture')

    # Pad shorter DataFrame with NaNs to match the longer one
    max_len = max(len(df_transition), len(df_gesture))

    df_transition = df_transition.reset_index(drop=True).reindex(range(max_len))
    df_gesture = df_gesture.reset_index(drop=True).reindex(range(max_len))

    # Concatenate along columns
    df_combined = pd.concat([df_transition, df_gesture], axis=1)
    # Drop transition versions of meta columns
    df_combined.drop(columns=[col + '_transition' for col in meta_cols], inplace=True)
    # Rename gesture versions of meta columns back to original names
    df_combined.rename(columns={col + '_gesture': col for col in meta_cols}, inplace=True)

    sequence_data_split = df_combined.fillna(0)
    return sequence_data_split


def wrapper_data( TRAIN = True, split = False):
    if TRAIN:
        train_df = pd.read_csv(Config.TRAIN_PATH)
        train_demographics = pd.read_csv(Config.TRAIN_DEMOGRAPHICS_PATH)

        label_encoder = LabelEncoder()
        train_df['gesture_id'] = label_encoder.fit_transform(train_df['gesture'].astype(str))
        joblib.dump(label_encoder, os.path.join(Config.EXPORT_DIR, "label_encoder.pkl"))

        gesture_id_to_gestures = {idx: cl for idx, cl in enumerate(label_encoder.classes_)}

        gesture_to_seq_ids = (
            train_df.groupby('gesture_id')['sequence_id']
            .unique()
            .apply(list)
            .to_dict()
        )

        seq_type_to_seq_ids = (
            train_df.groupby('sequence_type')['sequence_id']
            .unique()
            .apply(list)
            .to_dict()
        )

        train_sequence_subject = {
            seq_id: sequence['subject'].iloc[0]
            for seq_id, sequence in train_df.groupby('sequence_id')
        }

        train_sequence_ids = sorted(train_df['sequence_id'].unique())


        train_cols = set(train_df.columns)


        # Group by sequence_id for training data - need to include gesture column for labels
        train_cols = train_cols + ['gesture_id'] if 'gesture_id' not in train_cols else train_cols

        print("Handle quaternion missing values in the train dataset...")
        train_df_clean = clean_and_check_quaternion(train_df[train_cols])


        train_sequences = train_df_clean.groupby('sequence_id')


        split_ids = {
            'classes': gesture_id_to_gestures,
            'train': {
                'train_sequence_ids': train_sequence_ids, ##List of all train ids
                'train_sequence_subject': train_sequence_subject, ##List of all train subject
                'gesture_to_seq_ids': gesture_to_seq_ids, ##dic by gesture
                'seq_type_to_seq_ids': seq_type_to_seq_ids ##dic by sequence_type
            },
        }
        # Save
        with open(os.path.join(Config.EXPORT_DIR, 'split_ids.pkl'), 'wb') as f:
            pickle.dump(split_ids, f)
        

        ### FEATURES ####
        meta_cols = sorted(['gesture', 'gesture_id', 'sequence_type', 'behavior', 'orientation',
                    'row_id', 'subject', 'phase', 'sequence_id', 'sequence_counter'])
        train_df_clean[meta_cols].to_csv( os.path.join(Config.EXPORT_DIR, 'train_metadata.csv' ))

        features_cols = [c for c in train_cols if c not in meta_cols]
        print("adding new features...")
        processed_sequences = []
        for _, data_sequence in tqdm(train_sequences, desc="Processing Sequences"):
            data_sequence = data_sequence.reset_index(drop=True)
            data_sequence = add_gesture_phase(data_sequence)
            data_sequence = compute_acceleration_features(data_sequence, train_demographics)
            data_sequence = compute_angular_features(data_sequence, train_demographics)
            #data_sequence = compute_fft_features(data_sequence)
            #data_sequence = compute_theta_phi_features(data_sequence) 
            #data_sequence = compute_corr_and_svd_features(data_sequence)
            data_sequence = manage_tof(data_sequence, train_demographics)

            if split:
                data_sequence = split_into_transition_and_gesture_phases(data_sequence, meta_cols)

            #print(data_sequence[['acc_x_transition', 'acc_x_gesture']])

            processed_sequences.append(data_sequence)
    
        train_df_clean = pd.concat(processed_sequences).sort_index()

        train_cols = train_df_clean.columns
        #new_features = [c for c in cols if c not in features_cols and c not in meta_cols]
        features_cols = [c for c in train_cols if c not in meta_cols]
        imu_cols  = sorted([c for c in features_cols if not (c.startswith('thm_') or c.startswith('tof_'))])
        tof_cols  = sorted([c for c in features_cols if c.startswith('tof_')])
        thm_cols  = sorted([c for c in features_cols if c.startswith('thm_')])

        fixed_order_features = np.concatenate( (imu_cols, thm_cols, tof_cols) )


        print(f"all features have been generated")
        # global scaler
        #features_to_exclude = [f for f in fixed_order_features if ('svd' in f) or ('contribution_main_axis' in f) or ('f0' in f)]  # for example
        features_to_exclude = [f for f in fixed_order_features if any(substr in f for substr in ['phase_adj'])]
        features_to_scale = [f for f in fixed_order_features if f not in features_to_exclude]
        print(features_to_scale)
        all_features = np.concatenate( (meta_cols, fixed_order_features) )
        
        for f in train_df_clean.columns:
            if f not in all_features:
                print(f)

        train_df_clean = train_df_clean[all_features]

        scaler = StandardScaler().fit(train_df_clean[features_to_scale].to_numpy())
        joblib.dump(scaler, os.path.join(Config.EXPORT_DIR, "scaler.pkl") )

        train_sequences = train_df_clean.groupby('sequence_id')
        print(train_df_clean.columns)

        cols = {
            #'train': train_cols,
            'meta': meta_cols,
            #'features': features_cols,
            'imu': imu_cols,
            'tof': tof_cols,
            'thm': thm_cols
        }
        with open(os.path.join(Config.EXPORT_DIR, 'cols.pkl'), 'wb') as f:
            pickle.dump(cols, f)


        X, y = build_train_test_data(train_sequences, cols)
        return X, y



def get_info(data_sequences, demograph, seq_id, print_data = False):
    # Filter rows with the given sequence_id
    #seq_id = 'SEQ_051475'
    sequence_data = data_sequences.get_group(seq_id)

    subject_id = sequence_data['subject'].iloc[0]
    subject_demographics = demograph[demograph['subject'] == subject_id]

    seq_info = sequence_data[
        ["sequence_id", "subject", "orientation", "gesture", "gesture_id", "sequence_type"]
    ].head(1).squeeze() 
    demo_info = subject_demographics[
        ["adult_child", "age", "sex", "handedness", 'height_cm', 'shoulder_to_wrist_cm', 'elbow_to_wrist_cm']
    ].head(1).squeeze()
    demo_info["adult_child"] = {0: "child", 1: "adult"}.get(demo_info["adult_child"], "unknown")
    demo_info["sex"] = {0: "female", 1: "male"}.get(demo_info["sex"], "unknown")
    demo_info["handedness"] = {0: "left-handed", 1: "right-handed"}.get(demo_info["handedness"], "unknown")
    combined = pd.concat([seq_info, demo_info])
    if print_data:
        display(combined.to_frame(name='Value'))
    return combined

def get_info_v2(demograph, seq_id, seq_id_to_subject, print_data = False):
    subject_id = seq_id_to_subject[seq_id]
    subject_demographics = demograph[demograph['subject'] == subject_id]

    demo_info = subject_demographics[
        ["adult_child", "age", "sex", "handedness", 'height_cm', 'shoulder_to_wrist_cm', 'elbow_to_wrist_cm']
    ].head(1).squeeze()
    demo_info["adult_child"] = {0: "child", 1: "adult"}.get(demo_info["adult_child"], "unknown")
    demo_info["sex"] = {0: "female", 1: "male"}.get(demo_info["sex"], "unknown")
    demo_info["handedness"] = {0: "left-handed", 1: "right-handed"}.get(demo_info["handedness"], "unknown")
    if print_data:
        display(demo_info.to_frame(name='Value'))
    return demo_info


def pad_and_truncate(X_batch, maxlen, padding_value=0.0, dtype=torch.float32):
    padded_batch = []
    for seq in X_batch:
        seq = torch.tensor(seq, dtype=dtype)
        length = seq.size(0)

        # Truncate
        if length > maxlen:
            seq = seq[:maxlen]
        # Pad
        elif length < maxlen:
            pad_len = maxlen - length
            padding = torch.full((pad_len, *seq.shape[1:]), padding_value, dtype=dtype)
            seq = torch.cat([seq, padding], dim=0)

        padded_batch.append(seq)

    return torch.stack(padded_batch)  # [batch_size, maxlen, features]

def build_train_test_data(data_sequences, cols, mask_gesture = False):
    X_batch, y_batch, len_seq = [], [], []
    features = np.concatenate( (cols['imu'], cols['thm'], cols['tof']) )
    features_to_exclude = [f for f in features if any(substr in f for substr in ['phase_adj'])]
    features_to_scale = [f for f in features if f not in features_to_exclude]

    idx_to_scale = np.where(np.isin(features, features_to_scale))[0]
    #idx_to_exclude = np.where(np.isin(features, features_to_exclude))[0]

    seq_ids = []
    for seq_id, data_sequence in data_sequences:
        if mask_gesture:
            gesture_phase = data_sequence['phase'] == 'Gesture'
            sequence = data_sequence[features][gesture_phase]
        else:
            sequence = data_sequence[features]
        
        sequence = sequence.to_numpy()

        # Fit and transform only those columns
        scaler = joblib.load( os.path.join(Config.EXPORT_DIR, "scaler.pkl") )
        if len(sequence) > 0:
            sequence[:, idx_to_scale] =  scaler.transform(sequence[:, idx_to_scale])

        #print(sequence[['linear_acc_ratio_svd_0', 'linear_acc_ratio_svd_1']])
        #cols_to_scale = [c for c in cols['imu'] if c.startswith('acc_')]
        #sequence[cols_to_scale] = scaler.fit_transform(sequence[cols_to_scale])

        X_batch.append(sequence)
        seq_ids.append(seq_id)
        y_batch.append(data_sequence['gesture_id'].iloc[0])
        len_seq.append(len(sequence))

    ### labels one-hot categorical ###
    y_final = torch.tensor(y_batch)
    #y_final = F.one_hot(y_torch, num_classes = num_classes).float()
    
    ### pad and truncate sequences to the 95 percentile
    pad_len_seq = int(np.percentile(len_seq, Config.PERCENTILE))
    X_final = pad_and_truncate(X_batch, maxlen=pad_len_seq)

    return X_final, y_final #, seq_ids


### COMPETITION METRIC ###

def competition_metric(y_true, y_pred) -> tuple:
    """Calculate the competition metric (Binary F1 + Macro F1) / 2"""
    BFRB_gesture = [0, 1, 3, 4, 6, 7, 9, 10]
    #non_BFRB_gesture = [2, 5, 8, 11, 12, 13, 14, 15, 16, 17]
     
    # Binary F1: BFRB vs non-BFRB
    binary_f1 = f1_score(
        np.where(np.isin(y_true, BFRB_gesture), 1, 0),
        np.where(np.isin(y_pred, BFRB_gesture), 1, 0),
        zero_division=0.0,
    )

    binary_recall =  recall_score(
        np.where(np.isin(y_true, BFRB_gesture), 1, 0),
        np.where(np.isin(y_pred, BFRB_gesture), 1, 0),
        zero_division=0.0,
    )
    
    # Macro F1: specific gesture classification (only for BFRB gestures)
    macro_f1 = f1_score(
        np.where(np.isin(y_true, BFRB_gesture), y_true, 99),  # Map non-BFRB to 99
        np.where(np.isin(y_pred, BFRB_gesture), y_pred, 99),  # Map non-BFRB to 99
        average="macro", 
        zero_division=0.0,
    )
    
    # Final competition score
    final_score = 0.5 * (binary_f1 + macro_f1)
    
    return final_score, binary_recall, macro_f1

def reset_seed(seed=42):
    np.random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

cuda


In [132]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset
from scipy.spatial.transform import Rotation as R
import glob
from collections import Counter


class Conv1DAutoencoder(nn.Module):
    def __init__(self, input_channels, hidden_dim = 16, latent_dim=32, drop = 0.3):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv1d(input_channels, 4 * hidden_dim, kernel_size=5, stride=4, padding=2),  # -> (B, 32, L/2)
            nn.ReLU(),
            nn.Dropout(p = drop),
            nn.Conv1d(4 * hidden_dim, 2 * hidden_dim, kernel_size=5, stride=4, padding=2),           # -> (B, 64, L/4)
            nn.ReLU(),
            nn.Dropout(p = drop),
            nn.Conv1d(2* hidden_dim, hidden_dim, kernel_size=5, stride=2, padding=2),          # -> (B, 128, L/8)
            nn.ReLU(),
            nn.Dropout(p = drop),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, stride=2, padding=2),              #nn.AdaptiveAvgPool1d(1),                                         # -> (B, 128, 1)
        )
        self.latent = nn.Linear(28 * hidden_dim, latent_dim)

        # Decoder
        self.decoder_fc = nn.Linear(latent_dim, 28 * hidden_dim) # (B, 128)
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(hidden_dim, hidden_dim, kernel_size=5, stride=2, padding=2, output_padding=1),  #(B, 64, 11)
            nn.ReLU(),
            nn.Dropout(p = drop),
            nn.ConvTranspose1d(hidden_dim, 2 * hidden_dim, kernel_size=5, stride=2, padding=2, output_padding=1),   # (B, 64, 41)
            nn.ReLU(),
            nn.Dropout(p = drop),
            nn.ConvTranspose1d(2 * hidden_dim, 4 * hidden_dim, kernel_size=5, stride=4, padding=2),
            nn.ReLU(),
            nn.Dropout(p = drop),
            nn.ConvTranspose1d(4 * hidden_dim, input_channels, kernel_size=5, stride=4, padding=2, output_padding=1),
            #nn.Tanh()  # Assuming normalized input
        )
        self.hidden_dim = hidden_dim

    def forward(self, x):
        z = self.encoder(x)
        z = z.reshape(z.shape[0], -1)
        z = self.latent(z)
        x_recon = self.decoder_fc(z)
        x_recon = x_recon.unsqueeze(-1).reshape(-1, self.hidden_dim, 28)
        #print(x_recon.shape)
        x_recon = self.decoder(x_recon)
        return x_recon


class LSTMWithAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=False)
        self.attn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Dropout(p = 0.3),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        # x: [B, T, F]
        lstm_out, _ = self.lstm(x)  # lstm_out: [B, T, H]

        attn_scores = self.attn(lstm_out).squeeze(-1)  # [B, T]
        attn_weights = torch.softmax(attn_scores, dim=1)  # [B, T]

        # Weighted sum
        context = torch.sum(lstm_out * attn_weights.unsqueeze(-1), dim=1)  # [B, H]

        return context



class AttentionPooling(nn.Module):
    def __init__(self, hidden_dim, bias_strength = 5.):
        super().__init__()
        # self.attn = nn.Sequential(
        #     nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1),
        #     nn.Tanh(),
        #     nn.Conv1d(hidden_dim, 1, kernel_size=1)
        # )
        # self.attn = nn.Sequential(
        #     nn.Conv1d(hidden_dim, 1, kernel_size=1),  # [B, 1, T]
        #     nn.Softmax(dim=-1)
        # )
        self.attn = nn.Sequential(
            nn.Linear(hidden_dim, 1),  # [B, 1, T]
            nn.Tanh()
        )
        self.bias_strength = bias_strength
        self.weights = None

    def forward(self, x, phase_adj = None):
        # x: [B, hidden_dim, T]
        scores = self.attn(x.permute(0, 2, 1)).squeeze(-1)  # [B, T]

        if phase_adj is not None:
            #bias = (phase_adj.float() * self.bias_strength)  # [B, T]
            scores = scores #+ bias

        weights = F.softmax(scores, dim=1)  # [B, T]
        #weights = scores
        self.weights = weights
        pooled = torch.sum(x * weights.unsqueeze(1), dim=2)  # [B, hidden_dim]
        return pooled

class IMUEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim, kernel_size=3, padding=1),
            #nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True), #inplace=True
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(p=0.3), 
            #nn.MaxPool1d(kernel_size=2, stride=2),  # halves time length
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            #nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(hidden_dim),
            #nn.Dropout(p=0.2), 
            #nn.MaxPool1d(kernel_size=2, stride=2),   # halves again → total /4
        )

    def forward(self, x):
        # x: [B, T, input_dim] → [B, input_dim, T]
        x = x.permute(0, 2, 1)
        out = self.net(x)  # [B, hidden_dim, T]
        return out

class OptionalEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, norm = True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            #nn.BatchNorm1d(hidden_dim),
            nn.Dropout(p=0.3),
            #nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            #nn.BatchNorm1d(hidden_dim),
            #nn.Dropout(p=0.2),
            #nn.MaxPool1d(kernel_size=2, stride=2)
        )
        self.norm = norm

    def forward(self, x, mask):
        # x: [B, T, input_dim] → [B, input_dim, T]
        x = x.permute(0, 2, 1)
        out = self.net(x)  # [B, hidden_dim, T/4]

        # Adjust mask accordingly by downsampling (average pooling)
        # mask: [B, T]
        #mask = mask.unsqueeze(1).float()  # [B, 1, T]
        #mask = F.avg_pool1d(mask, kernel_size=2, stride=2)  # [B, 1, T/2]
        #mask = F.avg_pool1d(mask, kernel_size=2, stride=2)  # [B, 1, T/4]
        #mask = mask.squeeze(1)  # [B, T/4]

        #out = out * mask.unsqueeze(1)  # [B, hidden_dim, T]

        #Normalize by sum of mask per timestep (avoid div zero)
        if self.norm: 
            norm_mask = mask.sum(dim=1, keepdim=True).clamp(min=1e-6)  # [B, 1]
            out = out / norm_mask.unsqueeze(1)  # broadcast on hidden_dim

        return out

class TabularEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim)
        )
    def forward(self, x, seq_len):
        # x: [B, n_feats]
        emb = self.net(x)  # [B, hidden_dim]
        # Expand along time dimension to [B, hidden_dim, seq_len]
        emb = emb.unsqueeze(2).expand(-1, -1, seq_len)
        return emb
    

class TOFEncoder3DWithSpatialAttention(nn.Module):
    def __init__(self, in_channels=5, out_channels=64, hidden_dim=128, H=8, W=8):
        super().__init__()

        # 3D CNN to process [B, 5, T, 8, 8]
        self.conv3d = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),

            nn.Conv3d(out_channels, hidden_dim, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.BatchNorm3d(hidden_dim),
            nn.ReLU(inplace=True),
        )

        # Spatial attention: per pixel over each 8x8 grid at each time step
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim // 2, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim // 2, 1, kernel_size=1),  # attention logits per pixel
        )

        self.H = H
        self.W = W

    def forward(self, x):
        """
        x: [B, 5, T, 8, 8]
        returns: [B, hidden_dim, T] (aggregated per timestep)
        """
        B, C, T, H, W = x.shape

        # Apply 3D conv
        feat = self.conv3d(x)  # [B, hidden_dim, T, H, W]

        # Reshape for spatial attention per time step
        feat_reshaped = feat.permute(0, 2, 1, 3, 4).contiguous()  # [B, T, C, H, W]
        feat_flat = feat_reshaped.view(B * T, -1, H, W)  # [B*T, C, H, W]

        # Spatial attention scores
        attn_logits = self.spatial_attn(feat_flat)  # [B*T, 1, H, W]
        attn_scores = F.softmax(attn_logits.view(B * T, -1), dim=-1).view(B * T, 1, H, W)  # [B*T, 1, H, W]

        # Apply attention
        weighted_feat = feat_flat * attn_scores  # [B*T, C, H, W]
        aggregated = weighted_feat.view(B, T, -1, H * W).sum(dim=-1)  # [B, T, C]

        # Transpose to [B, C, T] to match other branches
        aggregated = aggregated.permute(0, 2, 1)

        return aggregated  # [B, hidden_dim, T]

class TOFEncoder(nn.Module):
    def __init__(self, hidden_dim, C, H, W, C_TOF_RAW = False, norm = False):
        super().__init__()
        if C_TOF_RAW:
            self.tof_spatial_weight = nn.Parameter(torch.ones(1, 5, H, W))  # Learnable
        else:
            self.tof_spatial_weight = nn.Parameter(torch.ones(1, 1, H, W))  # Learnable

        self.spatial_pool = nn.AdaptiveAvgPool2d( 1 )  # or MaxPool2d or Flatten
        self.tof_post = nn.Sequential(
            nn.Linear(C , hidden_dim),  # or Conv1D
            nn.ReLU(),
        )
        self.H = H
        self.W = W
        self.C = C

        self.norm = norm
    
    def forward(self, tof_raw, mask_ones):
        B, T, _ = tof_raw.shape

        tof_raw =  tof_raw.reshape(B, T, self.C, self.H, self.W) #[B, T, 5, 8, 8] 
        #tof_raw = tof_raw.permute(0, 2, 1, 3, 4)                

        tof_raw_weighted = (tof_raw * self.tof_spatial_weight)#.view(-1, self.C, self.H, self.W)  # Broadcasting over batch and channel
        #print(tof_raw_weighted.shape)
        pooled = self.spatial_pool(tof_raw_weighted)#.view(B, T, -1)  # [B, T, 5 * 2 * 2]
        #print(pooled.shape)
        pooled = pooled.squeeze(-1).squeeze(-1)#.permute(0, 2, 1)  
        tof_raw_feat = self.tof_post(pooled)  # [B, T, hidden_dim]  

        if self.norm:
            norm_mask = mask_ones.sum(dim=1, keepdim=True).clamp(min=1e-6)  # [B, 1]
            tof_raw_feat = tof_raw_feat / norm_mask.unsqueeze(1)  # broadcast on hidden_dim

        return tof_raw_feat.permute(0, 2, 1) # [B, hidden_dim, T]  


class TOFEncoderTemporalBeforePool(nn.Module):
    def __init__(self, hidden_dim, C, H, W):
        super().__init__()

        self.C = C
        self.H = H
        self.W = W

        self.temporal_attn = nn.Sequential(
            nn.Conv2d(C, 1, kernel_size=1),  # [B*T, 1, H, W]
            nn.Sigmoid()
        )

        self.spatial_pool = nn.AdaptiveAvgPool2d(1)

        self.tof_post = nn.Sequential(
            nn.Linear(C, hidden_dim),
            nn.ReLU()
        )

    def forward(self, tof_raw):
        B, T, _ = tof_raw.shape
        tof_raw = tof_raw.view(B, T, self.C, self.H, self.W)  # [B, T, C, H, W]
        tof_raw_2d = tof_raw.view(B*T, self.C, self.H, self.W)  # Merge batch & time

        # Compute per-frame spatial attention weights
        attn_maps = self.temporal_attn(tof_raw_2d)  # [B*T, 1, H, W]
        tof_weighted = tof_raw_2d * attn_maps  # Apply attention
        tof_weighted = tof_weighted.view(B, T, self.C, self.H, self.W)

        # Pool and project
        pooled = self.spatial_pool(tof_weighted).squeeze(-1).squeeze(-1)  # [B, T, C]
        tof_feat = self.tof_post(pooled)  # [B, T, hidden_dim]

        return tof_feat.permute(0, 2, 1)  # [B, hidden_dim, T]


class GatedFusion(nn.Module):
    def __init__(self, hidden_dim, num_modalities):
        super().__init__()
        self.gate = nn.Linear(hidden_dim * num_modalities, num_modalities)

    def forward(self, features_list):
        # features_list: list of [B, hidden_dim, T]
        concat = torch.cat(features_list, dim=1)  # [B, hidden_dim * M, T]
        concat_t = concat.permute(0, 2, 1)        # [B, T, hidden_dim * M]
        gate_weights = torch.sigmoid(self.gate(concat_t))  # [B, T, M]

        gated_feats = []
        for i, f in enumerate(features_list):
            f_t = f.permute(0, 2, 1)  # [B, T, hidden_dim]
            w = gate_weights[:, :, i].unsqueeze(-1)  # [B, T, 1]
            gated_feats.append(f_t * w)
        fused = sum(gated_feats)  # [B, T, hidden_dim]
        return fused.permute(0, 2, 1)  # [B, hidden_dim, T]

class AttentionFusion(nn.Module):
    def __init__(self, hidden_dim, num_modalities=3):
        super().__init__()
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.keys = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_modalities)])
        self.values = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_modalities)])
        self.scale = hidden_dim ** 0.5

    def forward(self, inputs):
        """
        inputs: list of [B, hidden_dim, T] tensors
        output: [B, hidden_dim, T] fused tensor
        """
        # Compute shared query
        #stacked_inputs = torch.stack(inputs, dim=1)  # [B, M, hidden_dim, T]
        #B, M, C, T = stacked_inputs.shape

        query = self.query(inputs[0].permute(0, 2, 1))  # [B, T, hidden_dim]
        keys = [key(x.permute(0, 2, 1)) for key, x in zip(self.keys, inputs)]   # each: [B, T, hidden_dim]
        values = [val(x.permute(0, 2, 1)) for val, x in zip(self.values, inputs)]  # each: [B, T, hidden_dim]

        key_tensor = torch.stack(keys, dim=1)     # [B, M, T, hidden_dim]
        value_tensor = torch.stack(values, dim=1) # [B, M, T, hidden_dim]

        # Attention: dot product over modalities
        attn_scores = torch.einsum('btc,bmtc->btm', query, key_tensor) / self.scale  # [B, T, M]
        attn_weights = F.softmax(attn_scores, dim=-1)  # [B, T, M]

        # Weighted sum over modalities
        fused = torch.einsum('btm,bmtc->btc', attn_weights, value_tensor)  # [B, T, hidden_dim]
        return fused.permute(0, 2, 1)  # back to [B, hidden_dim, T]

    
class GlobalGestureClassifier(nn.Module):
    def __init__(self, 
                 imu_dim, 
                 hidden_dim, 
                 num_classes, 
                 thm_tof_dim = None, 
                 tof_raw_dim = None, 
                 C_TOF_RAW = False,
                 norm_TOF_THM = True,
                 norm_TOF_RAW = False,
                 attention_for_fusion = True,
                 attention_pooled = True,
                 ): # tabular_dim = None
        super().__init__()

        if thm_tof_dim is None:
            thm_tof_dim = imu_dim
        
        self.thm_tof_dim = thm_tof_dim
        self.tof_raw_dim = tof_raw_dim

        self.attention_pooled = attention_pooled
        self.attention_for_fusion = attention_for_fusion

        self.imu_encoder = IMUEncoder(imu_dim, hidden_dim)
        if self.attention_pooled:
            self.attn_pool = AttentionPooling(hidden_dim)

        self.thm_tof_encoder = OptionalEncoder(thm_tof_dim, hidden_dim, norm=norm_TOF_THM)
        self.tof_pixels = TOFEncoder(hidden_dim, C = 4, H = 8, W = 8, C_TOF_RAW=C_TOF_RAW, norm=norm_TOF_RAW)
        #self.tof_pixels = TOFEncoderTemporalBeforePool(hidden_dim, C = 5, H = 8, W = 8)
#         self.tof_spatial_weight = nn.Parameter(torch.ones(1, 1, 8, 8))  # Learnable
#         self.spatial_pool = nn.AdaptiveAvgPool2d(1)  # or MaxPool2d or Flatten
#         self.tof_post = nn.Sequential(
#             nn.Linear(5, hidden_dim),  # or Conv1D
#             nn.ReLU(),
# )

        self.gated_fusion = GatedFusion(hidden_dim, num_modalities=3)
        if self.attention_for_fusion:
            self.attention_fusion = AttentionFusion(hidden_dim, num_modalities=3)
            self.alpha = nn.Parameter(torch.tensor(0.0))  # sigmoid(0) = 0.5
        
        self.final_fusion = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3)
        )
        #self.bilstm = nn.LSTM(hidden_dim * 2, hidden_dim, bidirectional=True, batch_first=True)
        #self.classifier_rnn = nn.GRU(hidden_dim, hidden_dim, batch_first=True, bidirectional=True)

        self.classifier_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p = 0.3),
            nn.Linear(hidden_dim, num_classes),
            #nn.Softmax()
        )

        self.imu_head = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Dropout(p = 0.3),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, imu, thm_tof=None, tof_raw = None): #, tabular_feats=None
        B, T, _ = imu.shape

        imu_feat = self.imu_encoder(imu)  # [B, hidden_dim, T]
        pooled_imu = imu_feat.mean(dim= 2)
        logits_imu = self.imu_head(imu_feat)

        if thm_tof is None:
            thm_tof = torch.zeros_like(imu)
            tof_mask = torch.zeros(B, T, device=imu.device)
        else:
            tof_mask = torch.ones( B, T, device = thm_tof.device ) # (~torch.isnan(thm_tof).any(dim=2)).float() #(thm_tof != 0).float() #

        if tof_raw is None:
            tof_raw = torch.zeros_like(imu)
            tof_raw_mask = torch.zeros(B, T, device=imu.device)
        else:
            tof_raw_mask = torch.ones( B, T, device = thm_tof.device )


        thm_tof_feat = self.thm_tof_encoder(thm_tof, tof_mask)  # [B, hidden_dim, T]
        tof_raw_feat = self.tof_pixels(tof_raw, tof_raw_mask) # [B, hidden_dim, T]

        gated =  self.gated_fusion([imu_feat, thm_tof_feat, tof_raw_feat])  # [B, hidden_dim, T]

        if self.attention_for_fusion:
            attn =  self.attention_fusion([imu_feat, thm_tof_feat, tof_raw_feat])  # [B, hidden_dim, T]
            alpha = torch.sigmoid(self.alpha).view(1, -1, 1)  # shape [1, C, 1] 
            fused = alpha * gated + (1 - alpha) * attn  # broadcast over B, T
        else:
            fused = gated

        if self.attention_pooled:     
            pooled_fused= self.attn_pool(fused) #fused.mean(dim= 2) #
        else:
            pooled_fused= fused.mean(dim= 2) #fused.mean(dim= 2) #


        pooled = self.final_fusion(torch.cat([pooled_imu, pooled_fused], dim=1))
        #pooled = torch.cat([pooled_imu, pooled_fused], dim=1)
        out = self.classifier_head(pooled)  # [B, num_classes]

        return out, logits_imu   

class MiniGestureClassifier(nn.Module):
    def __init__(self, imu_dim, hidden_dim, num_classes):
        super().__init__()

        self.imu_encoder = IMUEncoder(imu_dim, hidden_dim)
        self.attn_pool = AttentionPooling(hidden_dim)
    

        self.classifier_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(p = 0.3),
            nn.Linear(hidden_dim, num_classes),
            #nn.Softmax()
        )

    def forward(self, imu, return_attention=False): #, phase_adj = None,
        B, T, _ = imu.shape

        imu_feat = self.imu_encoder(imu)  # [B, hidden_dim, T]

        pooled = imu_feat.mean(dim=2)#   # [B, hidden_dim]

        #pooled = self.attn_pool(imu_feat)

        out = self.classifier_head(pooled)  # [B, num_classes]

        return (out, self.attn_pool.weights) if return_attention else out
    
class MiniGestureLSTMClassifier(nn.Module):
    def __init__(self, imu_dim, imu_dim_lstm, hidden_dim, lstm_hidden_dim, num_classes):
        super().__init__()

        self.imu_encoder = IMUEncoder(imu_dim, hidden_dim)
        self.lstm_attn = LSTMWithAttention(imu_dim_lstm, lstm_hidden_dim)
        
        fused_dim = hidden_dim + lstm_hidden_dim

        self.classifier_head = nn.Sequential(
            nn.Linear(fused_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(p = 0.3),
            nn.Linear(hidden_dim, num_classes),
            #nn.Softmax()
        )

    def forward(self, imu): #, phase_adj = None,
        #B, T, _ = imu.shape

        imu_cnn_out = self.imu_encoder(imu)  # [B, hidden_dim, T]
        imu_pooled = imu_cnn_out.mean(dim=2) # [B, hidden_dim]
        imu_lstm_out = self.lstm_attn(imu)  # [B, H]

        fused = torch.cat([imu_pooled, imu_lstm_out], dim=1)  # [B, hidden_dim + H]
        out = self.classifier_head(fused)

        return out
    

class EarlyStopping:
    def __init__(self, patience=5, mode='max', restore_best_weights=True, verbose=False, logger = None):
        self.patience = patience
        self.mode = mode
        self.restore_best_weights = restore_best_weights
        self.verbose = verbose
        self.best_score = None
        self.counter = 0
        self.early_stop = False
        self.best_model_state = None
        self.logger = logger

    def __call__(self, current_score, model):
        if self.mode == 'max':
            score_improved = self.best_score is None or current_score > self.best_score
        else:  # 'min'
            score_improved = self.best_score is None or current_score < self.best_score

        if score_improved:
            self.best_score = current_score
            self.counter = 0
            if self.restore_best_weights:
                self.best_model_state = model.state_dict()
            if self.verbose:
                if self.logger is not None:
                    self.logger.info(f"EarlyStopping: Improvement found, saving model with score {current_score:.4f}")
                else:
                    print(f"EarlyStopping: Improvement found, saving model with score {current_score:.4f}")
        else:
            self.counter += 1
            if self.verbose:
                if self.logger is not None:
                    self.logger.info(f"EarlyStopping: No improvement for {self.counter} epoch(s)")
                else:
                    print(f"EarlyStopping: No improvement for {self.counter} epoch(s)")
            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    if self.logger is not None:
                        self.logger.info("EarlyStopping: Stopping early.")
                    else:
                        print("EarlyStopping: Stopping early.")
                if self.restore_best_weights and self.best_model_state is not None:
                    model.load_state_dict(self.best_model_state)



class SensorDataset(Dataset):
    def __init__(self, X, y, imu_dim, alpha = 0., augment = None, training = True):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
        self.alpha = alpha
        self.augment = augment
        self.training = training
        self.imu_dim = imu_dim

    def __len__(self):
        return len(self.X) 

    def __getitem__(self, idx):
        x1, y1 = self.X[idx], self.y[idx]
        y1_onehot = torch.nn.functional.one_hot(y1, num_classes=18).float()

        if self.training and self.augment:
            x1 = x1.numpy().copy()
            x1 = self.augment(x1, imu_dim = self.imu_dim)
            x1 = torch.tensor(x1,  dtype=torch.float32)
            
        if self.alpha > 1e-6:
            rand_idx = np.random.randint(0, len(self.X) - 1)
            x2, y2 = self.X[rand_idx], self.y[rand_idx]

            if self.training and self.augment:
                x2 = x2.numpy().copy()
                x2 = self.augment(x2, imu_dim=self.imu_dim)
                x2 = torch.tensor(x2, dtype=torch.float32)

            y2_onehot = torch.nn.functional.one_hot(y2, num_classes=18).float()

            # Generate lambda from Beta distribution and ensure alpha > 0.
            lam = np.random.beta(self.alpha, self.alpha)
            lam = max(0, min(1, lam))

            x1 = lam * x1 + (1 - lam) * x2
            y1_onehot = lam * y1_onehot + (1 - lam) * y2_onehot
        
        return x1, y1_onehot

class TrackingSampler(torch.utils.data.Sampler):
    def __init__(self, base_sampler):
        self.base_sampler = base_sampler
        self.sampled_indices = []

    def __iter__(self):
        self.sampled_indices = list(self.base_sampler)  # Store for external access
        return iter(self.sampled_indices)

    def __len__(self):
        return len(self.base_sampler)

class DeviceRotationAugment:
    def __init__(self,
                X, y, seqs,       
                seqs_by_subject,
                selected_features,
                x_rot_range = (0, 30), # (0, 45)
                y_rot_range = (0, 30), # (0, 45)
                p_rotation = 0.4,
                small_rotation = 2
                ):     
        
        self.features_to_rotate = [
        ['acc_x', 'acc_y', 'acc_z'],
        ['acc_x_world', 'acc_y_world', 'acc_z_world'],
        ['linear_acc_x', 'linear_acc_y', 'linear_acc_z'],
        ['rotvec_x', 'rotvec_y', 'rotvec_z'],
        ['ang_vel_x', 'ang_vel_y', 'ang_vel_z'],
        ['X_world_x', 'X_world_y', 'X_world_z'], 
        ['Y_world_x', 'Y_world_y', 'Y_world_z'],
        ['Z_world_x', 'Z_world_y', 'Z_world_z'],
        ['rot_x', 'rot_y', 'rot_z', 'rot_w']
        ]
        
        self.seqs_by_subject = seqs_by_subject 
        self.p_rotation = p_rotation
        self.selected_features = selected_features
        self.x_rot_range = x_rot_range
        self.y_rot_range = y_rot_range
        self.small_rotation = small_rotation

        self.X =  torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
        self.seqs = seqs
        self.count = 0
        self.iter = 2

    def random_angles_by_seq(self):
        unique_subjects = list(self.seqs_by_subject.keys())
        # Assign a consistent random Y angle per subject
        subject_to_angle = {
            subj:  (np.random.uniform(*self.x_rot_range), np.random.uniform(*self.y_rot_range)) #np.random.choice(y_range)
            for subj in unique_subjects
        }

        random_small_angles_by_subject = {
            subj: np.random.uniform(-self.small_rotation, self.small_rotation, size=len(seqs))
            for subj, seqs in self.seqs_by_subject.items()
        }


        subject_for_seq = {
        seq_id: (i, subj) for subj, seq_ids in self.seqs_by_subject.items() for i, seq_id in enumerate(seq_ids)
        }

        seq_to_angle = {
            seq_id: (subject_to_angle[subj][0] + random_small_angles_by_subject[subj][i], subject_to_angle[subj][1] + random_small_angles_by_subject[subj][i])
            for seq_id, (i, subj) in subject_for_seq.items()
        }

        return seq_to_angle

    def apply_rotation(self, 
                       x: torch.tensor, 
                       ax: str, 
                       seq_id: str,
                       seqs_to_angle) -> np.ndarray:
        x_copy = x.numpy().copy()
        rot_x, rot_y = seqs_to_angle.get(seq_id, (0.0, 0.0))
        if ax == 'x':
            rot = R.from_euler(ax, rot_x, degrees=True)
        if ax == 'y':
            rot = R.from_euler(ax, rot_y, degrees=True)
        if ax == 'z':
            rot = R.from_euler(ax, 180, degrees=True)
        if ax == 'zx':
            rot = R.from_euler('z', 180, degrees=True) *  R.from_euler('x', rot_x, degrees=True)
        if ax == 'zy':
            rot = R.from_euler('z', 180, degrees=True) *  R.from_euler('y', rot_y, degrees=True)
        if ax == 'xy':
            rot = R.from_euler('x', rot_x, degrees=True) *  R.from_euler('y', rot_y, degrees=True)
        if ax == 'zxy':
            rot = R.from_euler('z', 180, degrees=True) * R.from_euler('x', rot_x, degrees=True) *  R.from_euler('y', rot_y, degrees=True)

        for feats in self.features_to_rotate:
            idx_rotate = np.where(np.isin(self.selected_features, feats))[0]
            if len(idx_rotate) == 0:
                continue

            if not any('rot_' in f for f in feats):
                rotated = rot.apply(x_copy[:, idx_rotate])
                x_copy[:, idx_rotate] = rotated
            else:
                init_quat = x_copy[:, idx_rotate]
                mask = np.linalg.norm(init_quat, axis=1) < 1e-6
                R_orig = R.from_quat(init_quat[~mask])
                R_new = rot * R_orig
                new_quat = np.zeros_like(init_quat)
                new_quat[~mask] = R_new.as_quat()
                x_copy[:, idx_rotate] = new_quat

        return x_copy
    
    def apply_specific_rotation( 
                       x, 
                       ax: str, 
                       rots,
                       selected_features,
                       features_to_rotate) -> np.ndarray:
        x_copy = x.numpy().copy()
        rot_x, rot_y, rot_z = rots
        if ax == 'x':
            rot = R.from_euler(ax, rot_x, degrees=True)
        if ax == 'y':
            rot = R.from_euler(ax, rot_y, degrees=True)
        if ax == 'z':
            rot = R.from_euler(ax, rot_z, degrees=True)
        if ax == 'zx':
            rot = R.from_euler('z', rot_z, degrees=True) *  R.from_euler('x', rot_x, degrees=True)
        if ax == 'xz':
            rot = R.from_euler('x', rot_x, degrees=True) *  R.from_euler('z', rot_z, degrees=True)
        if ax == 'zy':
            rot = R.from_euler('z', rot_z, degrees=True) *  R.from_euler('y', rot_y, degrees=True)
        if ax == 'xy':
            rot = R.from_euler('x', rot_x, degrees=True) *  R.from_euler('y', rot_y, degrees=True)
        if ax == 'zxy':
            rot = R.from_euler('z', rot_z, degrees=True) * R.from_euler('x', rot_x, degrees=True) *  R.from_euler('y', rot_y, degrees=True)

        for feats in features_to_rotate:
            idx_rotate = np.where(np.isin(selected_features, feats))[0]
            if len(idx_rotate) == 0:
                continue

            if not any('rot_' in f for f in feats):
                rotated = rot.apply(x_copy[:, idx_rotate])
                x_copy[:, idx_rotate] = rotated
            else:
                init_quat = x_copy[:, idx_rotate]
                mask = np.linalg.norm(init_quat, axis=1) < 1e-6
                R_orig = R.from_quat(init_quat[~mask])
                R_new = rot * R_orig
                new_quat = np.zeros_like(init_quat)
                new_quat[~mask] = R_new.as_quat()
                x_copy[:, idx_rotate] = new_quat

        return x_copy
    
    # ---------- master call ----------
    def __call__(self,
                 axes: list) -> np.ndarray:
    

        seqs_to_angle = self.random_angles_by_seq()
        
        augmented_X_tr = []
        augmented_y_tr = []

        for xx, yy, seq_id in zip(self.X, self.y, self.seqs):
            augmented_X_tr.append(xx)
            augmented_y_tr.append(yy)

            # Reverse time (assuming time is dimension 0)
            x_rotated = []
            axes_choice = np.array(axes)
            for ax in axes_choice: #self.iter
                #if (np.random.random() < self.p_rotation) and (len(axes_choice) > 0):
                #ax = axes_choice[i]  #np.random.choice(axes_choice) # ##
                x_rotated.append(self.apply_rotation(xx, ax, seq_id, seqs_to_angle)) # subject_id, subject_to_angle)
                    #axes_choice = np.delete(axes_choice, np.where(axes_choice == ax)) 
            if len(x_rotated) > 0:
                self.count += len(x_rotated)
                x_rotated = [torch.tensor(x) for x in x_rotated]
                augmented_X_tr.extend(x_rotated)
                augmented_y_tr.extend([yy] * len(x_rotated))


        #augmented_X_tr = torch.tensor(augmented_X_tr)
        augmented_X_tr = torch.stack(augmented_X_tr)
        augmented_y_tr = torch.tensor(augmented_y_tr)  # Or use torch.stack if already tensors

        # X_aug = my_aug.augment(self.X.numpy())  # shape preserved
        # y_aug = self.y.clone()          # labels remain the same
        
        # X_aug = torch.cat([self.X, torch.from_numpy(X_aug)], dim=0)
        # y_aug = torch.cat([self.y, y_aug], dim=0)
        return augmented_X_tr, augmented_y_tr, self.count


class Augment:
    def __init__(self,
                 p_jitter=0.8, sigma=0.02, scale_range=[0.9,1.1],
                 p_dropout=0.3,
                 p_moda=0.5,
                 drift_std=0.005,     
                 drift_max=0.25):      
        self.p_jitter  = p_jitter
        self.sigma     = sigma
        self.scale_min, self.scale_max = scale_range
        self.p_dropout = p_dropout
        self.p_moda    = p_moda

        self.drift_std = drift_std
        self.drift_max = drift_max


    # ---------- Jitter & Scaling ----------
    def jitter_scale(self, x: np.ndarray) -> np.ndarray:
        noise  = np.random.randn(*x.shape) * self.sigma
        scale  = np.random.uniform(self.scale_min,
                                   self.scale_max,
                                   size=(1, x.shape[1]))
        return (x + noise) * scale

    # ---------- Sensor Drop-out ----------
    def sensor_dropout(self,
                       x: np.ndarray,
                       imu_dim: int) -> np.ndarray:

        if np.random.random() < self.p_dropout:
            x[:, imu_dim:] = 0.0
        return x

    def motion_drift(self, x: np.ndarray, imu_dim: int) -> np.ndarray:

        T = x.shape[0]

        drift = np.cumsum(
            np.random.normal(scale=self.drift_std, size=(T, 1)),
            axis=0
        )
        drift = np.clip(drift, -self.drift_max, self.drift_max)   

        x[:, :6] += drift

        if imu_dim > 6:
            x[:, 6:imu_dim] += drift     
        return x
    

    
    # ---------- master call ----------
    def __call__(self,
                 x: np.ndarray,
                 imu_dim: int) -> np.ndarray:
        
        if np.random.random() < self.p_jitter:
            x = self.jitter_scale(x)

        if np.random.random() < self.p_moda:
            x = self.motion_drift(x, imu_dim)

        x = self.sensor_dropout(x, imu_dim)
        return x
    
class EnsemblePredictor:
    def __init__(self,  processing_dir, models_dir, device, params):
        self.device = device
        self.models = {
            'hybrid_models': [],
            'imu_only_models': [],
            'imu_tof_thm_models': []
            }
        self.params = params
        self.scaler = None
        self.features = None
        self.label_encoder = None
        self.map_classes = None
        self.inverse_map_classes = None
        self.cols = None
        self._load_models(models_dir, seed_CV_fold = params["SEED_CV_FOLD"])
        self._load_processing(processing_dir)

    def _load_models(self, models_dir, seed_CV_fold = None):
        model_files = {}
        if seed_CV_fold is None:
            model_files['hybrid_models'] = sorted(glob.glob(f"{models_dir}/best_model_fold_*.pth"))
            model_files['imu_only_models'] = sorted(glob.glob(f"{models_dir}/best_model_imu_only_fold_*.pth"))
            model_files['imu_tof_thm_models'] = sorted(glob.glob(f"{models_dir}/best_model_imu_tof_thm_fold_*.pth"))
        else:
            model_files['hybrid_models'] = sorted(glob.glob(f"{models_dir}/best_model_fold_*_{seed_CV_fold}.pth"))
            model_files['imu_only_models'] = sorted(glob.glob(f"{models_dir}/best_model_imu_only_fold_*_{seed_CV_fold}.pth"))
            model_files['imu_tof_thm_models'] = sorted(glob.glob(f"{models_dir}/best_model_imu_tof_thm_fold_*_{seed_CV_fold}.pth"))
            
        for key, models in model_files.items():
            print(f"{len(models)} {' '.join(key.split('_'))} have been found")
        
        for key, models in model_files.items():
            for model_file in models:
                checkpoint = torch.load(model_file, map_location=self.device, weights_only=True)
                
                #model = MiniGestureClassifier(imu_dim=14, hidden_dim=128, num_classes=18)
                model = GlobalGestureClassifier(imu_dim=self.params['imu_dim'], 
                                                thm_tof_dim=self.params['thm_tof_dim'], 
                                                tof_raw_dim=self.params['tof_raw_dim'],  
                                                hidden_dim=self.params['HIDDEN_DIM'], 
                                                num_classes=self.params['N_CLASSES'], 
                                                C_TOF_RAW=self.params['C_TOF_RAW'],
                                                norm_TOF_RAW=self.params['normalisation_TOF_RAW'],
                                                norm_TOF_THM=self.params['normalisation_TOF_THM'],
                                                attention_for_fusion=self.params['attention_for_fusion'],
                                                attention_pooled= self.params['attention_pooled']
                                            ) # MODEL            
                model.load_state_dict(checkpoint) #['model_state_dict']
                model.to(self.device)
                model.eval()
                self.models[key].append(model)

    def _load_processing(self, processing_dir):
        self.scaler = joblib.load(os.path.join(processing_dir, "scaler.pkl"))
        self.label_encoder = joblib.load(os.path.join(processing_dir, "label_encoder.pkl"))
        self.map_classes = {idx: cl for idx, cl in enumerate(self.label_encoder.classes_)}
        self.inverse_map_classes = {cl: idx for idx, cl in enumerate(self.label_encoder.classes_)}

        
        file_path_cols = os.path.join(processing_dir, "cols.pkl")
        with open(file_path_cols, 'rb') as f:
            self.cols = pickle.load(f)
        self.features = np.concatenate( (self.cols['imu'], self.cols['thm'], self.cols['tof']) ) 

        print("-> scaler, features, labels classes loaded")
        #print(f"features = {self.features}")
    
    def features_eng(self, df_seq: pd.DataFrame, demographics: pd.DataFrame):
        df_seq = regularize_quaternions_per_sequence(df_seq)

        ### -- ADD NEW FEATURES (IMU + AVERAGED TOF COLUMNS) --- 
        df_seq = df_seq.reset_index(drop=True)
        df_seq = add_gesture_phase(df_seq)
        df_seq = compute_acceleration_features(df_seq, demographics)
        df_seq = compute_angular_features(df_seq, demographics)
        df_seq = manage_tof(df_seq, demographics)
        return df_seq

    def scale_pad_and_transform_to_torch_sequence(self, df_seq, pad_length, is_imu_only = False):
        ### -- Columns re-ordering to match train order
        df_seq_features = df_seq[self.features].copy()

        #has_nan_tof_thm = df_seq_features[ np.concatenate( (self.cols['tof'], self.cols['thm']) ) ].isnull().all(axis=1).all()
        # if has_nan_tof_thm:
        #     print("NaN values have been found in TOF and/or THM data")
        
        has_nan_imu = df_seq_features[self.cols['imu']].isnull().any().any()
        if has_nan_imu:
            print("x IMU cols have NaN values. Shouldn't be the case! Check data!")

        ### -- Scale features and check NaN for IMU COLS  
        np_seq_features  =  df_seq_features.to_numpy()
        features_to_exclude = [f for f in self.features if any(substr in f for substr in ['phase_adj'])]
        features_to_scale = [f for f in self.features if f not in features_to_exclude]
        idx_to_scale = np.where(np.isin(self.features, features_to_scale))[0]
        if len(np_seq_features) > 0:
            np_seq_features[:, idx_to_scale] =  self.scaler.transform(np_seq_features[:, idx_to_scale])
        
        imu_features = [
            'acc_x','acc_y','acc_z', 'rotvec_x', 'rotvec_y', 'rotvec_z', 
            'linear_acc_x', 'linear_acc_y', 'linear_acc_z', 
            'ang_vel_x', 'ang_vel_y', 'ang_vel_z', 'ang_dist',
            'phase_adj'
            ] 

        if is_imu_only:
            idx_imu = [np.where(self.features == f)[0][0] for f in imu_features]    ### select features from selected_features above
            np_seq_features = np_seq_features[:, idx_imu]
        else:
            selected_tof = [f for f in self.cols['tof'] if ('v' not in f) and ('tof_5' not in f)]
            raw_tof = [f for f in self.cols['tof'] if ('v' in f) and ('tof_5' not in f)]

            raw_tof_sorted = np.array([f'tof_{i}_v{j}' for i in range(1, 5) for j in range(64)])
            check_all_pixels = np.array([f in raw_tof for f in raw_tof_sorted]   )            ### THM Features for later

            if not np.all(check_all_pixels):
                print(f"missing pixel raw data in TOF data: {np.array(raw_tof_sorted)[~check_all_pixels]}")
            
            raw_tof_sorted = list(raw_tof_sorted[check_all_pixels])
            idx_imu = [np.where(self.features == f)[0][0] for f in imu_features]    ### select features from selected_features above
            idx_tof = [np.where(self.features == f)[0][0] for f in selected_tof]                   ### TOF Features for later
            idx_raw_tof = [np.where(self.features == f)[0][0] for f in raw_tof_sorted]                   ### TOF Features for later
            idx_thm = [np.where(self.features == f)[0][0] for f in self.cols['thm'] if 'thm_5' not in f]               ### THM Features for later           ### THM Features for later
            
            idx_all = idx_imu + idx_thm + idx_tof + idx_raw_tof
            np_seq_features = np_seq_features[:, idx_all]

        seq = torch.tensor(np_seq_features, dtype=torch.float32)
        length = seq.size(0)
        # Truncate
        if length >= pad_length:
            seq = seq[:pad_length].unsqueeze(0)
        # Pad
        elif length < pad_length:
            pad_len = pad_length - length
            padding = torch.full((pad_len, *seq.shape[1:]), 0.0, dtype=torch.float32)
            seq = torch.cat([seq, padding], dim=0).unsqueeze(0)

        #print(f"sequence has been scaled and padded. shape (1, T, F): {seq.shape}")
        return seq.to(self.device)

    def predict(self, torch_seq, by_fold = None, models_to_use = ['hybrid_models', 'imu_only_models', 'imu_tof_thm_models']):
    # torch_seq: [N, ...]  (N = batch size)

        weights_models = {
            'hybrid_models': 1.,
            'imu_only_models': 0.75,
            'imu_tof_thm_models': 0.5
            }
        
        weights = {name: weights_models[name] for name in models_to_use}
        weights = {name: w/sum(weights.values()) for name, w in weights.items()}

        indices_branches = {
            'imu': np.arange(self.params['imu_dim']), 
            'thm_tof': np.arange(self.params['imu_dim'],self.params['imu_dim'] + self.params['thm_tof_dim']), 
            'tof_raw': np.arange(self.params['imu_dim'] + self.params['thm_tof_dim'], self.params['imu_dim'] + self.params['thm_tof_dim'] + self.params['tof_raw_dim'])
            }

        pred_by_model = {model_type: [] for model_type in models_to_use}
        
        if by_fold is None:
            
            for key, models in self.models.items():
                if key in models_to_use:
                    for model in models:
                        model.eval()
                        with torch.no_grad():
                            output, _ =  model(
                                torch_seq[:, :, indices_branches['imu']], 
                                torch_seq[:, :, indices_branches['thm_tof']], 
                                torch_seq[:, :, indices_branches['tof_raw']]
                                )  # [N, num_classes]
                            probs = F.softmax(output, dim=1).cpu().numpy()  # [N, num_classes]
                            pred_by_model[key].append(probs)
                else:
                    continue

            # Merge predictions
            N, num_classes = pred_by_model[models_to_use[0]][0].shape
            merged_probs = np.zeros((N, num_classes))

            for key, pred_list in pred_by_model.items():
                if not pred_list:
                    continue
                stacked = np.stack(pred_list, axis=0)  # [num_models, N, num_classes]
                avg_probs = np.mean(stacked, axis=0)   # [N, num_classes]
                merged_probs += weights[key] * avg_probs

            # Final prediction by argmax
            preds = merged_probs.argmax(axis=1)
            final_preds = [str(self.map_classes[pred]) for pred in preds]  # [N]

            #     for model in models:
            #         model.eval()
            #         with torch.no_grad():
            #             output = model(torch_seq)  # [N, num_classes]
            #             preds = output.argmax(1).cpu().numpy()  # shape: [N]
            #             pred_by_model[key].append(preds)  # list of arrays
            # pred_by_model = list(zip(*pred_by_model))  # shape: [N, num_models]
            # final_preds = []
            # for sample_preds in pred_by_model:
            #     most_common_prediction = Counter(sample_preds).most_common(1)[0][0]
            #     final_preds.append(str(self.map_classes[most_common_prediction]))
        else:      
            for key, models in self.models.items():
                if key in models_to_use:
                    model = models[by_fold]
                    model.eval()
                    with torch.no_grad():
                        output, _ =  model(
                                torch_seq[:, :, indices_branches['imu']], 
                                torch_seq[:, :, indices_branches['thm_tof']], 
                                torch_seq[:, :, indices_branches['tof_raw']]
                                )  # [N, num_classes]
                        probs = F.softmax(output, dim=1).cpu().numpy()  # [N, num_classes]
                        pred_by_model[key].append(probs)
                else:
                    continue

                # Merge predictions
            print(pred_by_model)
            N, num_classes = pred_by_model[models_to_use[0]][0].shape
            merged_probs = np.zeros((N, num_classes))

            for key, pred_list in pred_by_model.items():
                if not pred_list:
                    continue
                stacked = np.stack(pred_list, axis=0)  # [num_models, N, num_classes]
                avg_probs = np.mean(stacked, axis=0)   # [N, num_classes]
                merged_probs += weights[key] * avg_probs

                # Final prediction by argmax
            preds = merged_probs.argmax(axis=1)
            final_preds = [str(self.map_classes[pred]) for pred in preds]  # [N]

            # else:
            #     model = self.models['hybrid'][by_fold]
            #     model.eval()
            #     with torch.no_grad():
            #         output = model(torch_seq)  # [N, num_classes]
            #         preds = output.argmax(1).cpu().numpy()  # shape: [N]
            #     final_preds = [str(self.map_classes[pred]) for pred in preds]
        
        if len(final_preds) == 1:
            return final_preds[0]
        else:
            return final_preds  # length N list of mapped predictions



        
   



 


# class EnsemblePredictor:
#     def __init__(self,  processing_dir, models_dir, device):
#         self.device = device
#         self.models = []
#         self.scaler = None
#         self.features = None
#         self.label_encoder = None
#         self.map_classes = None
#         self.inverse_map_classes = None
#         self.cols = None
#         self._load_models(models_dir)
#         self._load_processing(processing_dir)

#     def _load_models(self, models_dir):
#         model_files = sorted(glob.glob(f"{models_dir}/best_model_fold_*.pth"))
#         print(f"{len(model_files)} models have been found")
        
#         for model_file in model_files:
#             checkpoint = torch.load(model_file, map_location=self.device, weights_only=True)
            
#             model = MiniGestureClassifier(imu_dim=14, hidden_dim=128, num_classes=18)

#             model.load_state_dict(checkpoint) #['model_state_dict']
#             model.to(self.device)
#             model.eval()
#             self.models.append(model)

#     def _load_processing(self, processing_dir):
#         self.scaler = joblib.load(os.path.join(processing_dir, "scaler.pkl"))
#         self.label_encoder = joblib.load(os.path.join(processing_dir, "label_encoder.pkl"))
#         self.map_classes = {idx: cl for idx, cl in enumerate(self.label_encoder.classes_)}
#         self.inverse_map_classes = {cl: idx for idx, cl in enumerate(self.label_encoder.classes_)}

        
#         file_path_cols = os.path.join(processing_dir, "cols.pkl")
#         with open(file_path_cols, 'rb') as f:
#             self.cols = pickle.load(f)
#         self.features = np.concatenate( (self.cols['imu'], self.cols['thm'], self.cols['tof']) ) 

#         print("-> scaler, features, labels classes loaded")
#         #print(f"features = {self.features}")
    
#     def features_eng(self, df_seq: pd.DataFrame):
#         df_seq = regularize_quaternions_per_sequence(df_seq)

#         ### -- ADD NEW FEATURES (IMU + AVERAGED TOF COLUMNS) --- 
#         df_seq = df_seq.reset_index(drop=True)
#         df_seq = add_gesture_phase(df_seq)
#         df_seq = compute_acceleration_features(df_seq)
#         df_seq = compute_angular_features(df_seq)
#         df_seq = manage_tof(df_seq)
#         return df_seq
    
#     def scale_pad_and_transform_to_torch_sequence(self, df_seq, pad_length, is_imu_only = True):
#         ### -- Columns re-ordering to match train order
#         df_seq_features = df_seq[self.features].copy()

#         #has_nan_tof_thm = df_seq_features[ np.concatenate( (self.cols['tof'], self.cols['thm']) ) ].isnull().all(axis=1).all()
#         # if has_nan_tof_thm:
#         #     print("NaN values have been found in TOF and/or THM data")
        
#         has_nan_imu = df_seq_features[self.cols['imu']].isnull().any().any()
#         if has_nan_imu:
#             print("x IMU cols have NaN values. Shouldn't be the case! Check data!")

#         ### -- Scale features and check NaN for IMU COLS  
#         np_seq_features  =  df_seq_features.to_numpy()
#         features_to_exclude = [f for f in self.features if any(substr in f for substr in ['phase_adj'])]
#         features_to_scale = [f for f in self.features if f not in features_to_exclude]
#         idx_to_scale = np.where(np.isin(self.features, features_to_scale))[0]
#         if len(np_seq_features) > 0:
#             np_seq_features[:, idx_to_scale] =  self.scaler.transform(np_seq_features[:, idx_to_scale])

#         if is_imu_only:
#             imu_features = [
#             'acc_x','acc_y','acc_z', 'rotvec_x', 'rotvec_y', 'rotvec_z', 
#             'linear_acc_x', 'linear_acc_y', 'linear_acc_z', 
#             'ang_vel_x', 'ang_vel_y', 'ang_vel_z', 'ang_dist',
#             'phase_adj'
#             ] 
#             idx_imu = [np.where(self.features == f)[0][0] for f in imu_features]    ### select features from selected_features above
#             np_seq_features = np_seq_features[:, idx_imu]


#         seq = torch.tensor(np_seq_features, dtype=torch.float32)
#         length = seq.size(0)
#         # Truncate
#         if length >= pad_length:
#             seq = seq[:pad_length].unsqueeze(0)
#         # Pad
#         elif length < pad_length:
#             pad_len = pad_length - length
#             padding = torch.full((pad_len, *seq.shape[1:]), 0.0, dtype=torch.float32)
#             seq = torch.cat([seq, padding], dim=0).unsqueeze(0)

#         #print(f"sequence has been scaled and padded. shape (1, T, F): {seq.shape}")
#         return seq.to(self.device)
    
#     def predict(self, torch_seq, by_fold = None):
#     # torch_seq: [N, ...]  (N = batch size)

#         if by_fold is None:
#             pred_by_model = []
    
#             for model in self.models:
#                 model.eval()
#                 with torch.no_grad():
#                     output = model(torch_seq)  # [N, num_classes]
#                     preds = output.argmax(1).cpu().numpy()  # shape: [N]
#                     pred_by_model.append(preds)  # list of arrays
        
#             # Transpose to get predictions per sample:
#             # pred_by_model: list of model predictions → shape: [num_models, N]
#             # after zip(*...), we get: [ [model1_pred_sample1, model2_pred_sample1, ...], ... ]
#             pred_by_model = list(zip(*pred_by_model))  # shape: [N, num_models]
        
#             final_preds = []
#             for sample_preds in pred_by_model:
#                 most_common_prediction = Counter(sample_preds).most_common(1)[0][0]
#                 final_preds.append(str(self.map_classes[most_common_prediction]))
#         else:
#             model = self.models[by_fold]
#             model.eval()
#             with torch.no_grad():
#                 output = model(torch_seq)  # [N, num_classes]
#                 preds = output.argmax(1).cpu().numpy()  # shape: [N]
#             final_preds = [str(self.map_classes[pred]) for pred in preds]
        
#         if len(final_preds) == 1:
#             return final_preds[0]
#         else:
#             return final_preds  # length N list of mapped predictions


# class GestureClassifier(nn.Module):
#     def __init__(self, imu_dim, hidden_dim, num_classes, tof_dim = None, thm_dim = None): # tabular_dim = None
#         super().__init__()

#         if tof_dim is None:
#             tof_dim = imu_dim
#         if thm_dim is None:
#             thm_dim = imu_dim

#         self.imu_encoder = IMUEncoder(imu_dim, hidden_dim)
#         self.tof_encoder = OptionalEncoder(tof_dim, hidden_dim)
#         self.thm_encoder = OptionalEncoder(thm_dim, hidden_dim)
#         self.fusion = GatedFusion(hidden_dim, num_modalities=3)


#         self.classifier_rnn = nn.GRU(hidden_dim, hidden_dim, batch_first=True, bidirectional=True)
#         self.classifier_head = nn.Sequential(
#             nn.Linear(hidden_dim * 2 , hidden_dim),
#             nn.ReLU(),
#             nn.Dropout(p = 0.3),
#             nn.Linear(hidden_dim, num_classes),
#             #nn.Softmax()
#         )

#     def forward(self, imu, thm=None, tof=None): #, tabular_feats=None
#         B, T, _ = imu.shape

#         imu_feat = self.imu_encoder(imu)  # [B, hidden_dim, T/4]

#         # nan_mask = torch.isnan(thm)
#         # nan_indices = torch.nonzero(nan_mask, as_tuple=True)[0].detach().cpu()
#         # print(f"number of NaN (detect possible FE errors): {len(np.unique(nan_indices.numpy()))}")

#         if tof is None:
#             tof = torch.zeros_like(imu)
#             tof_mask = torch.zeros(B, T, device=imu.device)
#         else:
#             tof_mask = (~torch.isnan(tof).any(dim=2)).float()

#         if thm is None:
#             thm = torch.zeros_like(imu)
#             thm_mask = torch.zeros(B, T, device=imu.device)
#         else:
#             thm_mask = (~torch.isnan(thm).any(dim=2)).float()

#         tof_feat = self.tof_encoder(tof, tof_mask)  # [B, hidden_dim, T/4]
#         thm_feat = self.thm_encoder(thm, thm_mask)  # [B, hidden_dim, T/4]

#         fused =  self.fusion([imu_feat, tof_feat, thm_feat])  # [B, hidden_dim, T/4]

#         fused_t = fused.permute(0, 2, 1)  # [B, T/4, hidden_dim]
#         rnn_out, _ = self.classifier_rnn(fused_t)  # [B, T/4, hidden_dim*2]

#         pooled = rnn_out.mean(dim=1)#   # [B, hidden_dim*2]
#         #pooled = F.dropout(pooled, p=0.5, training=self.training) 

#         out = self.classifier_head(pooled)  # [B, num_classes]

#         return out

In [133]:
from sklearn.metrics import recall_score
import os


def train_model(model, 
                train_loader, val_loader, 
                optimizer, criterion, 
                epochs,
                device, 
                seed_CV_fold = None,
                patience = 50, 
                fold = None, 
                logger = None, 
                split_indices = None, 
                scheduler = None, 
                hide_val_half = True,
                L_IMU = 0.2
                ):
    reset_seed(42)
    model.to(device)
    early_stopper = EarlyStopping(patience=patience, mode='max', restore_best_weights=True, verbose=True, logger = logger)
    if split_indices is not None:
        idx_thm_tof = list(split_indices['thm']) + list(split_indices['tof'])
    
    if logger is not None:
        logger.info(f"lengths features: \
                            {len(split_indices['imu'])} (IMU) \
                            {len(idx_thm_tof)} (TOF-THM) \
                            {len(split_indices['tof_raw'])} (TOF-RAW) \
                            ")
    else:
        print(f"lengths features: \
                            {len(split_indices['imu'])} (IMU) \
                            {len(idx_thm_tof)} (TOF-THM) \
                            {len(split_indices['tof_raw'])} (TOF-RAW) \
                            ")
    best_score = 0
    best_score_imu_only = 0
    best_score_imu_tof_thm = 0
    for epoch in range(1, epochs + 1):
        #check_memory()
        model.train()
        train_loss = 0
        train_preds = []
        train_targets = []

        for inputs, targets in tqdm(train_loader, f"Epoch {epoch}"):

            # if hide_val_half and split_indices is not None:
            #     half = batch_size // 2
            #     x_front = inputs[:half]               
            #     x_back  = inputs[half:].clone()   
            #     x_back[:, :, idx_thm_tof] = 0.0    
            #     inputs = torch.cat([x_front, x_back], dim=0)
            # print(targets[:5])
            # print(inputs[0, :10, 0])
            inputs, targets = inputs.to(device), targets.to(device)
            #check_memory()
            optimizer.zero_grad()
            if split_indices is not None:
                outputs, imu_logits = model(inputs[:, :, split_indices['imu']], inputs[:, :, idx_thm_tof], inputs[:, :, split_indices['tof_raw']]) #, phase_adj = inputs[:, :,  -1]
            else:
                outputs = model(inputs) #, phase_adj = inputs[:, :,  -1]
            #check_memory()
            #targets = targets * (1 - 0.1) + (0.1 / 18)
            imu_loss = criterion(imu_logits, targets)
            loss = criterion(outputs, targets) #, class_weight, bfrb_classes)
            loss += L_IMU * imu_loss
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_preds.extend(outputs.argmax(1).cpu().numpy())
            train_targets.extend(targets.argmax(1).cpu().numpy())
        

        train_acc, _, train_macro_f1  = competition_metric(train_targets, train_preds)

        # ---- Validation ----
        model.eval()
        val_loss = 0
        val_preds = {'out': [], 'imu_only': [], 'all': []}
        val_targets = {'out': [], 'imu_only': [], 'all': []}
        # bin_preds = []
        # bin_targets = []

        with torch.no_grad():
            for inputs, targets in val_loader:
                #check_memory()
                if hide_val_half and split_indices is not None:
                    B = inputs.shape[0]
                    half = B // 2
                    x_front = inputs[:half]               
                    x_back  = inputs[half:].clone()   
                    x_back[:, :, len(split_indices['imu']):] = 0.0    
                    inputs = torch.cat([x_front, x_back], dim=0)
                    x_back, x_front = x_back.to(device), x_front.to(device)


                inputs, targets = inputs.to(device), targets.to(device)
                if split_indices is not None:
                    outputs, imu_logits = model(inputs[:, :, split_indices['imu']], inputs[:, :, idx_thm_tof], inputs[:, :, split_indices['tof_raw']]) 
                    assert x_back[:, :, split_indices['imu']].shape[2] > 0, "IMU split is empty!"
                    outputs_imu_only, _ = model(x_back[:, :, split_indices['imu']], x_back[:, :, idx_thm_tof], x_back[:, :, split_indices['tof_raw']]) 
                    outputs_all, _ = model(x_front[:, :, split_indices['imu']], x_front[:, :, idx_thm_tof], x_front[:, :, split_indices['tof_raw']]) 
                else:
                    outputs = model(inputs) #, phase_adj = inputs[:, :,  -1]               
                
                loss = criterion(outputs, targets) #, class_weight, bfrb_classes)
                imu_loss = criterion(imu_logits, targets) #, class_weight, bfrb_classes)
                loss += L_IMU * imu_loss
                val_loss += loss.item()

                if split_indices is not None:
                    val_preds['all'].extend(outputs_all.argmax(1).cpu().numpy())
                    val_preds['imu_only'].extend(outputs_imu_only.argmax(1).cpu().numpy())

                    val_targets['all'].extend(targets[:half].argmax(1).cpu().numpy())
                    val_targets['imu_only'].extend(targets[half:].argmax(1).cpu().numpy())

                val_preds['out'].extend(outputs.argmax(1).cpu().numpy())
                val_targets['out'].extend(targets.argmax(1).cpu().numpy())


                # mask_bfrb_classes = np.array([idx in bfrb_classes.numpy() for idx in range(outputs.shape[1])])
                # outputs = torch.nn.functional.softmax(outputs, dim=1)
        
                # bin_pred = outputs[:, mask_bfrb_classes].sum(1) > 0.4 #torch.stack([, outputs[:, ~mask_bfrb_classes].sum(1)], dim=1) 
                # bin_preds.extend(bin_pred.cpu().numpy())

                # bin_target = targets[:, mask_bfrb_classes].sum(1) #, targets[:, ~mask_bfrb_classes].sum(1)], dim=1) 
                # bin_targets.extend(bin_target.cpu().numpy())
                
        val_acc, _, val_macro_f1 = competition_metric(val_targets['out'], val_preds['out'])     #accuracy_score(val_targets, val_preds)
        early_stopper(val_acc, model)
        if scheduler is not None:
            scheduler.step(val_acc)

        #val_binary_recall = recall_score(bin_targets, bin_preds)
        if early_stopper.best_score > best_score:
            best_score = early_stopper.best_score
            name = "best_model"
            if (fold is not None) and (seed_CV_fold is not None):
                name += f"_fold_{fold}_seed_{seed_CV_fold}.pth"
            else:
                name += ".pth"
            #torch.save(early_stopper.best_model_state, os.path.join(Config.EXPORT_MODELS_PATH, name ))

        
        
        if split_indices is not None:
            val_acc_all, _, _ = competition_metric(val_targets['all'], val_preds['all'])    
            val_acc_imu_only, _, _ = competition_metric(val_targets['imu_only'], val_preds['imu_only'])   
            if logger is not None:
                logger.info(f"Epoch {epoch}/{epochs} - Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, Macro: {train_macro_f1:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f},  Acc (imu only): {val_acc_imu_only:.4f},  Acc (imu+thm+tof): {val_acc_all:.4f}, Macro: {val_macro_f1:.4f}")
            else:
                print(f"Epoch {epoch}/{epochs} - Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, Macro: {train_macro_f1:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f},  Acc (imu only): {val_acc_imu_only:.4f},   Acc (imu+thm+tof): {val_acc_all:.4f},  Macro: {val_macro_f1:.4f}")
        
            ### BEST IMU-ONLY MODEL ###
            if  val_acc_imu_only > best_score_imu_only:
                best_score_imu_only = val_acc_imu_only
                name = "best_model_imu_only"
                if (fold is not None) and (seed_CV_fold is not None):
                    name += f"_fold_{fold}_seed_{seed_CV_fold}.pth"
                else:
                    name += ".pth"
                #torch.save(model.state_dict(), os.path.join(Config.EXPORT_MODELS_PATH, name ))
        
            ### BEST IMU-TOF-THM MODEL ###
            if  val_acc_all > best_score_imu_tof_thm:
                best_score_imu_tof_thm = val_acc_all
                name = "best_model_imu_tof_thm"
                if (fold is not None) and (seed_CV_fold is not None):
                    name += f"_fold_{fold}_seed_{seed_CV_fold}.pth"
                else:
                    name += ".pth"
                #torch.save(model.state_dict(), os.path.join(Config.EXPORT_MODELS_PATH, name ))
        
        else:
            if logger is not None:
                logger.info(f"Epoch {epoch}/{epochs} - Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, Macro: {train_macro_f1:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, Macro: {val_macro_f1:.4f}")
            else:
                print(f"Epoch {epoch}/{epochs} - Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, Macro: {train_macro_f1:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, Macro: {val_macro_f1:.4f}")

        if early_stopper.early_stop:
            if logger is not None:
                logger.info("Training stopped early.")
            else:
                print("Training stopped early.")
            break


    return best_score, best_score_imu_only, best_score_imu_tof_thm 



        




class SoftCrossEntropy:
    def __init__(self,
                 bfrb_classes = None, gamma = None, lamb = None, class_weights = None, device = DEVICE):      
        self.gamma = gamma
        self.lamb = lamb
        self.class_weights = class_weights
        self.bfrb_classes = bfrb_classes
        self.device = device

    def __call__(self,
                 preds: torch.tensor,
                 soft_targets: torch.tensor
                 ):
        
        outputs = torch.nn.functional.softmax(preds, dim=1)
        preds_log = F.log_softmax(preds, dim=1)


        if self.class_weights is not None:
            soft_targets = soft_targets * self.class_weights.to(self.device).unsqueeze(0)
            soft_targets = soft_targets / soft_targets.sum(dim=1, keepdim=True)  # re-normalize

        weighted_kl = F.kl_div(preds_log, soft_targets, reduction='batchmean')

        if self.bfrb_classes is None and (self.gamma is not None or self.lamb is not None):
            raise ValueError("bfrb_classes should not be None when lamb or gamma is specified")

        if self.bfrb_classes is not None and (self.gamma is not None or self.lamb is not None):
            mask_bfrb_classes = np.array([idx in self.bfrb_classes.numpy() for idx in range(preds.shape[1])])
            

            bfrb_pred = torch.cat( [outputs[:, mask_bfrb_classes], outputs[:, ~mask_bfrb_classes].sum(dim=1, keepdim=True)], dim=1)
            bfrb_target = torch.cat( [soft_targets[:, mask_bfrb_classes], soft_targets[:, ~mask_bfrb_classes].sum(dim=1, keepdim=True)], dim=1)

            bin_pred = torch.stack([ outputs[:, mask_bfrb_classes].sum(1), outputs[:, ~mask_bfrb_classes].sum(1)], dim=1)
            bin_target = torch.stack([soft_targets[:, mask_bfrb_classes].sum(1), soft_targets[:, ~mask_bfrb_classes].sum(1)], dim=1) 

            brfb_loss = F.kl_div(
            torch.log(bfrb_pred + 1e-8),  # log-probabilities
            bfrb_target,
            reduction='batchmean'
            )

            binary_loss = F.kl_div(
            torch.log(bin_pred + 1e-8),  # log-probabilities #torch.log(+1e-8)
            bin_target,
            reduction='batchmean'
            )


            if self.gamma is not None and self.lamb is None:
                return  weighted_kl + self.gamma * brfb_loss 
            if self.gamma is None and self.lamb is not None:
                return  weighted_kl + self.lamb * binary_loss 
            if self.gamma is not None and self.lamb is not None:
                return   weighted_kl + self.gamma * brfb_loss + self.lamb * binary_loss    
        else:
            return weighted_kl


def predict(sequence: pl.DataFrame, demographics: pl.DataFrame) -> str:
    sequence = sequence.to_pandas()
    demographics = demographics.to_pandas()
    sequence = predictor.features_eng(sequence, demographics)
    torch_seq = predictor.scale_pad_and_transform_to_torch_sequence(sequence, pad_length)
    most_common_prediction = predictor.predict(torch_seq)
    return str(most_common_prediction)

In [134]:
"""
Hierarchical macro F1 metric for the CMI 2025 Challenge.

This script defines a single entry point `score(solution, submission, row_id_column_name)`
that the Kaggle metrics orchestrator will call.
It performs validation on submission IDs and computes a combined binary & multiclass F1 score.
"""

import pandas as pd
from sklearn.metrics import f1_score


class ParticipantVisibleError(Exception):
    """Errors raised here will be shown directly to the competitor."""
    pass


class CompetitionMetric:
    """Hierarchical macro F1 for the CMI 2025 challenge."""
    def __init__(self):
        self.target_gestures = [
            'Above ear - pull hair',
            'Cheek - pinch skin',
            'Eyebrow - pull hair',
            'Eyelash - pull hair',
            'Forehead - pull hairline',
            'Forehead - scratch',
            'Neck - pinch skin',
            'Neck - scratch',
        ]
        self.non_target_gestures = [
            'Write name on leg',
            'Wave hello',
            'Glasses on/off',
            'Text on phone',
            'Write name in air',
            'Feel around in tray and pull out an object',
            'Scratch knee/leg skin',
            'Pull air toward your face',
            'Drink from bottle/cup',
            'Pinch knee/leg skin'
        ]
        self.all_classes = self.target_gestures + self.non_target_gestures

    def calculate_hierarchical_f1(
        self,
        sol: pd.DataFrame,
        sub: pd.DataFrame
    ) -> float:

        # Validate gestures
        invalid_types = {i for i in sub['gesture'].unique() if i not in self.all_classes}
        if invalid_types:
            raise ParticipantVisibleError(
                f"Invalid gesture values in submission: {invalid_types}"
            )

        # Compute binary F1 (Target vs Non-Target)
        y_true_bin = sol['gesture'].isin(self.target_gestures).values
        y_pred_bin = sub['gesture'].isin(self.target_gestures).values
        f1_binary = f1_score(
            y_true_bin,
            y_pred_bin,
            pos_label=True,
            zero_division=0,
            average='binary'
        )

        # Build multi-class labels for gestures
        y_true_mc = sol['gesture'].apply(lambda x: x if x in self.target_gestures else 'non_target')
        y_pred_mc = sub['gesture'].apply(lambda x: x if x in self.target_gestures else 'non_target')

        # Compute macro F1 over all gesture classes
        f1_macro = f1_score(
            y_true_mc,
            y_pred_mc,
            average='macro',
            zero_division=0
        )

        return 0.5 * f1_binary + 0.5 * f1_macro


def score(
    solution: pd.DataFrame,
    submission: pd.DataFrame,
    row_id_column_name: str
) -> float:
    """
    Compute hierarchical macro F1 for the CMI 2025 challenge.

    Expected input:
      - solution and submission as pandas.DataFrame
      - Column 'sequence_id': unique identifier for each sequence
      - 'gesture': one of the eight target gestures or "Non-Target"

    This metric averages:
    1. Binary F1 on SequenceType (Target vs Non-Target)
    2. Macro F1 on gesture (mapping non-targets to "Non-Target")

    Raises ParticipantVisibleError for invalid submissions,
    including invalid SequenceType or gesture values.


    Examples
    --------
    >>> import pandas as pd
    >>> row_id_column_name = "id"
    >>> solution = pd.DataFrame({'id': range(4), 'gesture': ['Eyebrow - pull hair']*4})
    >>> submission = pd.DataFrame({'id': range(4), 'gesture': ['Forehead - pull hairline']*4})
    >>> score(solution, submission, row_id_column_name=row_id_column_name)
    0.5
    >>> submission = pd.DataFrame({'id': range(4), 'gesture': ['Text on phone']*4})
    >>> score(solution, submission, row_id_column_name=row_id_column_name)
    0.0
    >>> score(solution, solution, row_id_column_name=row_id_column_name)
    1.0
    """
    # Validate required columns
    for col in (row_id_column_name, 'gesture'):
        if col not in solution.columns:
            raise ParticipantVisibleError(f"Solution file missing required column: '{col}'")
        if col not in submission.columns:
            raise ParticipantVisibleError(f"Submission file missing required column: '{col}'")

    metric = CompetitionMetric()
    return metric.calculate_hierarchical_f1(solution, submission)


In [135]:
import torch
import torch.optim as optim
from torch.utils.data import WeightedRandomSampler
from torch.utils.data import DataLoader

from sklearn.model_selection import StratifiedGroupKFold

import sys

print(f"✓ Configuration loaded for Kaggle environment (Device: {DEVICE})")
print("Device in use:", torch.cuda.current_device() if torch.cuda.is_available() else "CPU")

if torch.cuda.is_available():
    device_name = torch.cuda.get_device_name(torch.cuda.current_device())
    print(f"Using device: {device_name}")


TRAIN = False

data_file =  "train_torch_tensors_from_wrapper_left_corrected_without_TOF_correction.pt"


N_SPLITS = 5
BATCH_SIZE = 64
EPOCHS = 160
HIDDEN_DIM = 128
PATIENCE = 45
ALPHA = 0.4
LR = 1e-3
SEED_CV_FOLD = 39

p_dropout = 0.48 #0.42
p_jitter= 0.0 #0.98
p_moda = 0.4 #0.4
p_rotation = 1.1
small_rotation = 2.
x_max_angle = 30.
y_max_angle = 15.

normalisation_TOF_RAW = False
normalisation_TOF_THM = True
attention_for_fusion = False
attention_pooled = True
C_TOF_RAW = False
ADD_TOF_TO_THM = True

SCHEDULER = True
patience_scheduler = 8
factor_scheduler = 0.7


GAMMA = 0.0
LAMB = 0.0
L_IMU = 0.25


SEED = Config.SEED
reset_seed(SEED)

file_path_train = os.path.join(Config.EXPORT_DIR, data_file)
file_path_cols = os.path.join(Config.EXPORT_DIR, "cols.pkl")
file_path_splits = os.path.join(Config.EXPORT_DIR, "split_ids.pkl")


selected_features = [
        'acc_x','acc_y','acc_z',#,'rot_x', 'rot_y', 'rot_z', 'rot_w', 
        'rotvec_x', 'rotvec_y', 'rotvec_z', 
        'linear_acc_x', 'linear_acc_y', 'linear_acc_z', 
        #'linear_acc_x_FFT', 'linear_acc_y_FFT', 'linear_acc_z_FFT', 
        #'acc_norm_world', 
        # 'acc_norm', 'linear_acc_norm', 
        # 'acc_norm_jerk', 'linear_acc_norm_jerk', 
        #'angle_rad', 'angular_speed', 
        # 'rot_angle', 'rot_angle_vel', 'angular_speed', 
        'ang_vel_x', 'ang_vel_y', 'ang_vel_z', 'ang_dist',
        # 'ang_vel_x_FFT', 'ang_vel_y_FFT', 'ang_vel_z_FFT', 
        'phase_adj',
        ] 

print("Features:", selected_features)


# ---------------- LOAD DATA ------------------------


if os.path.exists(file_path_train):
    print("Loading existing tensor...")
    data = torch.load(file_path_train)
    X_train, y_train = data['X_train'], data['y_train']

    with open(file_path_cols, 'rb') as f:
        cols = pickle.load(f)

    with open(file_path_splits, 'rb') as f:
        split_ids = pickle.load(f)


else:
    print("File not found. Generating data...")
    X_train, y_train = wrapper_data(split=False)
    print(X_train.shape, y_train.shape)

    data = {'X_train': X_train, 'y_train': y_train} 
    torch.save(data, file_path_train)
    print("Data saved.")

    with open(file_path_cols, 'rb') as f:
        cols = pickle.load(f)

    with open(file_path_splits, 'rb') as f:
        split_ids = pickle.load(f)


gesture_mapping = {cl: idx for idx, cl in split_ids['classes'].items()}   ### GESTURE MAP CLASSES --> LABELS
bfrb_gesture = CompetitionMetric().target_gestures                        ### TARGET GESTURE CLASSES
bfrb_classes = torch.tensor([gesture_mapping[cl] for cl in bfrb_gesture]) ### TARGET GESTURE LABELS


# ------------------ SELECT FEATURES AND PREPARE DATA FOR TRAINING ------------------------

all_features = np.concatenate( (cols['imu'], cols['thm'], cols['tof']) ) 
selected_tof = [f for f in cols['tof'] if ('v' not in f) and ('tof_5' not in f)]
raw_tof = [f for f in cols['tof'] if ('v' in f) and ('tof_5' not in f)]
print(raw_tof)

raw_tof_sorted = np.array([f'tof_{i}_v{j}' for i in range(1, 5) for j in range(64)])
check_all_pixels = np.array([f in raw_tof for f in raw_tof_sorted]   )            ### THM Features for later

if not np.all(check_all_pixels):
    print(f"missing pixel raw data in TOF data: {np.array(raw_tof_sorted)[~check_all_pixels]}")



raw_tof_sorted = list(raw_tof_sorted[check_all_pixels])


idx_imu = [np.where(all_features == f)[0][0] for f in selected_features]    ### select features from selected_features above
idx_tof = [np.where(all_features == f)[0][0] for f in selected_tof]                   ### TOF Features for later
idx_raw_tof = [np.where(all_features == f)[0][0] for f in raw_tof_sorted]                   ### TOF Features for later
idx_thm = [np.where(all_features == f)[0][0] for f in cols['thm'] if 'thm_5' not in f]               ### THM Features for later              ### THM Features for later

idx_all = idx_imu + idx_thm + idx_tof + idx_raw_tof
indices_branches = {
    'imu': np.arange(len(idx_imu)), 
    'thm': np.arange(len(idx_imu), len(idx_imu + idx_thm)), 
    'tof': np.arange(len(idx_imu + idx_thm), len(idx_imu + idx_thm + idx_tof)),
    'tof_raw': np.arange(len(idx_imu + idx_thm + idx_tof), len(idx_all))
    }
# else:
#     idx_all = idx_imu + idx_thm + idx_raw_tof
#     indices_branches = {
#         'imu': np.arange(len(idx_imu)), 
#         'thm': np.arange(len(idx_imu), len(idx_imu + idx_thm)), 
#         'tof': [],
#         'tof_raw': np.arange(len(idx_imu + idx_thm), len(idx_all))
#         }


X = X_train[:, :, idx_all]   ## select idx features in X
y = y_train                  ## labels 


#### NaN ? in DATA #### 
nan_mask = torch.isnan(X[:, :, indices_branches['imu']])
nan_indices = torch.nonzero(nan_mask, as_tuple=True)
print(f"number of NaN (detect possible FE errors): {len(np.unique(nan_indices[0].numpy()))}")
      
if len(np.unique(nan_indices[0].numpy())) > 0:      
    X[:, :, indices_branches['imu']] = torch.tensor(np.nan_to_num(X[:, :, indices_branches['imu']], nan=0.0))

nan_mask = torch.isnan(X)
nan_indices = torch.nonzero(nan_mask, as_tuple=True)
      
if len(np.unique(nan_indices[0].numpy())) > 0:      
    X = torch.tensor(np.nan_to_num(X, nan=0.0))

########################


print(f"Data shape (X, y): {X.shape, y.shape}")

# cw_vals = compute_class_weight('balanced', classes=list(split_ids['classes'].keys()), y=y.numpy())  ## Class weights to handle imbalance
# class_weight = torch.from_numpy(cw_vals).float()                                                    ## class weights as torch tensor

class_weight = 0.7 * torch.ones(len(split_ids['classes'].keys()))
class_weight[bfrb_classes] = 2.


# ----------- ALL PARAMETERS TO SAVE IT ---------------
if TRAIN:
    all_parameters = {
        "data_file": data_file,
        "SEED": SEED,
        "SEED_CV_FOLD": SEED_CV_FOLD if SEED_CV_FOLD is not None else None,
        "N_SPLITS": N_SPLITS,
        "BATCH_SIZE": BATCH_SIZE,
        "EPOCHS": EPOCHS,
        "HIDDEN_DIM": HIDDEN_DIM,
        "PATIENCE": PATIENCE,
        "ALPHA":ALPHA,
        "LR": LR,
        "normalisation_TOF_RAW": normalisation_TOF_RAW,
        "normalisation_TOF_THM": normalisation_TOF_THM,
        "attention_for_fusion": attention_for_fusion,
        "attention_pooled": attention_pooled,
        "add_tof_features_to_thm": ADD_TOF_TO_THM,
        "C_TOF_RAW": C_TOF_RAW,
        "IMU_FEATURES": selected_features,
        "THM-TOF FEATURES": selected_tof,
        "TOF-RAW FEATURES": raw_tof_sorted,
        "loss_GAMMA": GAMMA,
        "loss_LAMBDA": LAMB,
        "additionnal_IMU_loss": L_IMU,
        "N_CLASSES": len(class_weight),
        "imu_dim":len(selected_features),
        "thm_tof_dim":len(selected_tof),
        "tof_raw_dim":len(raw_tof_sorted),
        "scheduler": SCHEDULER if SCHEDULER else None,
        "factor_scheduler": factor_scheduler if SCHEDULER else None,
        "patience_scheduler": patience_scheduler if SCHEDULER else None,
        "p_dropout": p_dropout,
        "p_jitter": p_jitter,
        "p_moda": p_moda,
        "p_rotation": p_rotation,
        "small_rotation": small_rotation, 
        "x_max_angle": x_max_angle,
        "y_max_angle": y_max_angle,
    }
else:
    file_path_params = os.path.join(Config.EXPORT_DIR, 'all_parameters.pkl')
    with open(file_path_params, 'rb') as f:
        file_params = pickle.load(f)
    all_parameters = file_params['hyperparams']
    all_parameters["thm_tof_dim"] = 20
    print(all_parameters)

# ------------------------------- DEMO DATA ---------------------------------
 
train_demographics = pd.read_csv(Config.TRAIN_DEMOGRAPHICS_PATH)

# ------------------------------- TRAINING ---------------------------------

sgkf = StratifiedGroupKFold(n_splits=N_SPLITS, shuffle=True, random_state = SEED_CV_FOLD) #STRATIFIED k-Fold by group (subject_id)

if not ADD_TOF_TO_THM:
    indices_branches['tof'] = []

train_ids = np.array(split_ids['train']['train_sequence_ids']) #seq_id of data sequences 
groups = [split_ids['train']['train_sequence_subject'][seq_id] for seq_id in train_ids] #subject_id of data_sequences
wrong_subjects = ['SUBJ_045235', 'SUBJ_019262']

# idx_spe_seq = np.where(train_ids == 'SEQ_000007')[0]


### LOOP FOR EACH TRAINING FOLD
best_scores = {
    'mixture':[],
    'imu_only':[],
    'imu_tof_thm':[], 
    }
best_scores_inference = []
for fold, (train_idx, val_idx) in enumerate(sgkf.split(X, y, groups)):
    print(f"\n===== FOLD {fold+1}/{N_SPLITS} =====\n")
    reset_seed(SEED)

    # Split data
    X_tr, X_val = X[train_idx], X[val_idx]
    y_tr, y_val = y[train_idx], y[val_idx]

    if TRAIN:

        subjects_id = np.array(groups)[train_idx]
        train_seq_ids = train_ids[train_idx]

        ###### Handedness of subjects in train and validation fold
        subjects_fold_train = np.unique(subjects_id)
        subjects_fold_val= np.unique(np.array(groups)[val_idx])
        handedness_train = [train_demographics[train_demographics['subject'] == subject]['handedness'].iloc[0] for subject in subjects_fold_train]
        handedness_val = [train_demographics[train_demographics['subject'] == subject]['handedness'].iloc[0] for subject in subjects_fold_val]
        print(f"number of left-handed (right-handed) subject in train fold {fold + 1} = {np.sum(np.array(handedness_train) == 0)} ({np.sum(np.array(handedness_train) == 1)})")
        print(f"number of left-handed (right-handed) subject in val fold {fold + 1} = {np.sum(np.array(handedness_val) == 0)} ({np.sum(np.array(handedness_val) == 1)})")
        cond_wrong = [wg_sub in subjects_fold_val for wg_sub in wrong_subjects]
        if any(cond_wrong):
            print(f"wrong wrist wearing detected in val fold {fold + 1}: {np.array(wrong_subjects)[cond_wrong]}")

        print(" ---- check for reproductibility ----")
        print(f"first 10 seq_id = {train_seq_ids[:10]}")
        print(f"first 10 train idx = {train_idx[:10]}, and val idx = {val_idx[:10]}")
        print(f"mean train idx = {np.mean(train_idx)}, and mean val idx = {np.mean(val_idx)}\n")

        df = pd.DataFrame({'subject_id': subjects_id, 'seq_id': train_seq_ids})
        seqs_by_subject = (
                df.groupby('subject_id')['seq_id']
                .unique()
                .apply(list)
                .to_dict()
            )

        #### DATA AUGMENTATION #####
        print("------ DATA AUGMENTATION: DEVICE ROTATION ------")
        rotation_augmented = DeviceRotationAugment(X_tr, y_tr, train_seq_ids,     
                            seqs_by_subject, selected_features, 
                            p_rotation=1.1, 
                            small_rotation=2., 
                            x_rot_range=(0., 30.)
                            )
        X_tr, y_tr, count = rotation_augmented(axes=['z', 'x'])
        print(f"number of additional rotated features samples: {count}")
        print(f"shape of training data after augmentation (X, y): {X_tr.shape, y_tr.shape}\n")

        # augmenter = Augment()

        # augmenter = Augment(
        #     p_jitter=0.98, sigma=0.033, scale_range=(0.75,1.16),
        #     p_dropout=0.42,
        #     p_moda=0.39, drift_std=0.004, drift_max=0.39    
        # )
        augmenter = Augment(
            p_jitter=0.98, sigma=0.033, scale_range=(0.75,1.16),
            p_dropout=0.42,
            p_moda=0.39, drift_std=0.004, drift_max=0.39    
        )
        #########################################

        train_ds = SensorDataset(X_tr, y_tr, imu_dim = len(idx_imu), alpha=ALPHA, augment=augmenter)  ### TRAINING ROTATION AUGMENTED DATA WITH MixUp \alpha 


        # CLASS IMBALANCE handling 
        print(" ----------- CLASS INBALANCE SAMPLER (WeightedRandomSampler) ---------") 
        class_counts = np.bincount(y_tr.numpy())
        print(f"Number of samples per class: {Counter(y_tr.numpy())}\n")
        class_weights_balanced = 1. / class_counts
        sample_weights = class_weights_balanced[y_tr.numpy()]
        sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights) , replacement=True)
        tracking_sampler = TrackingSampler(sampler)

        sampled_indices = list(sampler)
        sampled_labels = y_tr[sampled_indices]
        print(Counter(sampled_labels.numpy()))

        train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=tracking_sampler, pin_memory=True)
    
        val_ds = SensorDataset(X_val, y_val, imu_dim = 7, training=False) ### VALIDATION DATA (NO AUG, NO MixUp)
        val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)


    if TRAIN:
        criterion = SoftCrossEntropy(bfrb_classes=bfrb_classes, gamma = GAMMA, lamb = LAMB) # LOSS FUNCTION bfrb_classes=bfrb_classes, gamma = .5, lamb = .5 

        #model = MiniGestureClassifier(imu_dim=X_tr.shape[2], hidden_dim=128, num_classes=len(class_weight)) # MODEL
        model = GlobalGestureClassifier(imu_dim=len(indices_branches['imu']), 
                                        thm_tof_dim=len(indices_branches['thm']) + len(indices_branches['tof']), 
                                        tof_raw_dim=len(indices_branches['tof_raw']),  
                                        hidden_dim=HIDDEN_DIM, 
                                        num_classes=len(class_weight), 
                                        C_TOF_RAW=C_TOF_RAW,
                                        norm_TOF_RAW=normalisation_TOF_RAW,
                                        norm_TOF_THM=normalisation_TOF_THM,
                                        attention_for_fusion=attention_for_fusion,
                                        attention_pooled= attention_pooled
                                        ) # MODEL


        optimizer = optim.Adam(model.parameters(), lr=LR) # OPTIMIZER  weight_decay=WD

        if SCHEDULER:
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=factor_scheduler, patience=patience_scheduler, verbose=True)
        else:
            scheduler = None
        
        best_score, best_score_imu_only, best_score_imu_tof_thm = train_model(model, train_loader, val_loader, 
                                 optimizer, criterion, 
                                 EPOCHS, 
                                 DEVICE, 
                                 patience=PATIENCE, 
                                 fold = fold, 
                                 split_indices = indices_branches,
                                 scheduler = scheduler,                                            
                                 L_IMU= L_IMU,
                                 seed_CV_fold = SEED_CV_FOLD                                            
                                 )
        best_scores['mixture'].append(best_score)
        best_scores['imu_only'].append(best_score_imu_only)
        best_scores['imu_tof_thm'].append(best_score_imu_tof_thm)
    else:
        print("---- INFERENCE MODE ----")
        processing_dir = Config.EXPORT_DIR
        models_dir = Config.EXPORT_MODELS_PATH
        predictor = EnsemblePredictor(processing_dir, models_dir, DEVICE, all_parameters)
        inverse_map_classes = predictor.inverse_map_classes
        #map_classes = predictor.map_classes

        #val_ds = SensorDataset(X_val, y_val, imu_dim = 7, training=False) ### VALIDATION DATA (NO AUG, NO MixUp)
        #val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True)

        #X_val = []
        #for inputs, targets in val_loader:
        #    B = inputs.shape[0]
        #    half = B // 2
        #    x_front = inputs[:half]               
        #    x_back  = inputs[half:].clone()   
        #    x_back[:, :, 14:] = 0.0    
        #    inputs = torch.cat([x_front, x_back], dim=0)
        #    X_val.extend(inputs)
        #X_val = torch.stack(X_val)
        
        preds_str = predictor.predict(X_val.to(DEVICE), by_fold = fold) #, models_to_use = ['hybrid_models'])
        preds_int = [inverse_map_classes[pred_str] for pred_str in preds_str]
        best_score, _, _ = competition_metric(y_val, preds_int)
        print(best_score)
    
        best_scores_inference.append(best_score)

if TRAIN:
    print(np.mean(best_scores['mixture']))
    print(np.mean(best_scores['imu_only']))
    print(np.mean(best_scores['imu_tof_thm']))

else:
    print(np.mean(best_scores_inference))

✓ Configuration loaded for Kaggle environment (Device: cuda)
Device in use: 0
Using device: Tesla T4
Features: ['acc_x', 'acc_y', 'acc_z', 'rotvec_x', 'rotvec_y', 'rotvec_z', 'linear_acc_x', 'linear_acc_y', 'linear_acc_z', 'ang_vel_x', 'ang_vel_y', 'ang_vel_z', 'ang_dist', 'phase_adj']
Loading existing tensor...
['tof_1_v0', 'tof_1_v1', 'tof_1_v10', 'tof_1_v11', 'tof_1_v12', 'tof_1_v13', 'tof_1_v14', 'tof_1_v15', 'tof_1_v16', 'tof_1_v17', 'tof_1_v18', 'tof_1_v19', 'tof_1_v2', 'tof_1_v20', 'tof_1_v21', 'tof_1_v22', 'tof_1_v23', 'tof_1_v24', 'tof_1_v25', 'tof_1_v26', 'tof_1_v27', 'tof_1_v28', 'tof_1_v29', 'tof_1_v3', 'tof_1_v30', 'tof_1_v31', 'tof_1_v32', 'tof_1_v33', 'tof_1_v34', 'tof_1_v35', 'tof_1_v36', 'tof_1_v37', 'tof_1_v38', 'tof_1_v39', 'tof_1_v4', 'tof_1_v40', 'tof_1_v41', 'tof_1_v42', 'tof_1_v43', 'tof_1_v44', 'tof_1_v45', 'tof_1_v46', 'tof_1_v47', 'tof_1_v48', 'tof_1_v49', 'tof_1_v5', 'tof_1_v50', 'tof_1_v51', 'tof_1_v52', 'tof_1_v53', 'tof_1_v54', 'tof_1_v55', 'tof_1_v56', 't

IndexError: list index out of range

In [None]:
dummy_test = False

if dummy_test:

    pad_length = Config.PADDING
    processing_dir = Config.EXPORT_DIR
    models_dir = Config.EXPORT_MODELS_PATH
    test_path = Config.TEST_PATH
    train_path = Config.TRAIN_PATH
    train_path_demo = Config.TRAIN_DEMOGRAPHICS_PATH
    
    # Check GPU availability
    DEVICE = torch.device(check_gpu_availability())
    print(f"✓ Configuration loaded for Kaggle environment (Device: {DEVICE})")
    
    
    predictor = EnsemblePredictor(processing_dir, models_dir, DEVICE, all_parameters)
    inverse_map_classes = predictor.inverse_map_classes
    map_classes = predictor.map_classes

    train = pd.read_csv(train_path)
    train_demo = pd.read_csv(train_path_demo)
    
    print(f"---> Original shape = {train.shape}")
    sel_seq  = train["sequence_id"].unique()#[0 : 3500]
    seq      = sel_seq[0: 1750]
    oth_cols = sorted([c for c in train.columns if (c.startswith('thm_') or c.startswith('tof_'))]) #train.columns[16:]
    train    = train.loc[train.sequence_id.isin(sel_seq)]
    train.loc[train.sequence_id.isin(seq), oth_cols] = np.nan
    print(f"---> Truncated shape = {train.shape}")
    train_sequences = train.groupby("sequence_id")
    
    ypred = []
    ytruth = []
    for _, sequence in tqdm(train_sequences, desc="Processing Sequences"):
    #     #print(f"======== SEQUENCE {seq_id} ========")
        sequence = pl.DataFrame(sequence)
        pred = predict(sequence, train_demo)
        ypred.append(inverse_map_classes[pred])
        sequence = sequence.to_pandas()
        ytruth.append(inverse_map_classes[sequence['gesture'].iloc[0]])
    
    
    print(competition_metric(ytruth, ypred))

In [None]:
### SUBMISSION ####

