In [None]:
import os, json, joblib
# Core data science libraries
import numpy as np
import pandas as pd
from math import ceil
from scipy import stats

# Visualization libraries
import matplotlib.pyplot as plt
import seaborn as sns

from pathlib import Path
import warnings 
import random
warnings.filterwarnings("ignore")

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import StratifiedGroupKFold

import tensorflow as tf
from tensorflow.keras.utils import Sequence, to_categorical, pad_sequences
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import (
    Input, Conv1D, BatchNormalization, Activation, add, MaxPooling1D, Dropout,
    Bidirectional, LSTM, GlobalAveragePooling1D, Dense, Multiply, Reshape,
    Lambda, Concatenate, GRU, GaussianNoise, Add, GlobalMaxPooling1D,
    MultiHeadAttention, LayerNormalization
)
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, LearningRateScheduler
from tensorflow.keras import backend as K
from tensorflow.keras import mixed_precision

import polars as pl
from scipy.spatial.transform import Rotation as R
from joblib import Parallel, delayed
import multiprocessing

# Progress bar
from tqdm import tqdm

from scipy import stats

In [None]:
def seed_everything(seed):
    """Seed the Python, NumPy and TensorFlow RNGs and set the TF determinism env flags.

    Args:
        seed: Integer seed handed to every generator below; also stored (as a
            string) in ``PYTHONHASHSEED``.
    """
    # Environment switches first: hash seed plus the two TensorFlow flags.
    os.environ.update({
        'PYTHONHASHSEED': str(seed),
        'TF_CUDNN_DETERMINISTIC': '1',
        'TF_DETERMINISTIC_OPS': '1',
    })
    # Then every library-level generator, all fed the same seed.
    for seed_fn in (random.seed, np.random.seed, tf.random.set_seed,
                    tf.experimental.numpy.random.seed):
        seed_fn(seed)


seed_everything(seed=42)

In [None]:
# Configuration
# Run mode: the mixed-precision cell is guarded by this flag.
train = True

# Filesystem locations (both model directories currently point at the same folder).
raw_dir = Path("input")
output_dir = Path("best_model")
pretrained_dir = Path("best_model")  # replace with my trained weights

# Batching / padding settings.
batch_size = 64
pad_percentile = 90

# Training hyper-parameters (names suggest learning rate, weight decay, MixUp alpha).
lr_init = 5e-4
wd = 3e-3
mixup_alpha = 0.4
epochs = 120
patience = 50

# Number of folds — presumably for the StratifiedGroupKFold imported above; confirm at the call site.
n_splits = 6

print("Imports ready")

In [None]:
# Switch Keras to the 'mixed_float16' global policy for training runs only.
# `mixed_precision` comes from the notebook's import cell, so it is not re-imported here.
if train:
    policy = mixed_precision.Policy('mixed_float16')
    mixed_precision.set_global_policy(policy)

In [None]:
#Tensor Manipulations
def time_sum(x):
    """Return `x` summed over axis 1 (that axis is removed from the result)."""
    summed = K.sum(x, axis=1)
    return summed

def squeeze_last_axis(x):
    """Drop the trailing axis of `x` (it must have size 1 for `tf.squeeze` to succeed)."""
    squeezed = tf.squeeze(x, axis=-1)
    return squeezed

def expand_last_axis(x):
    """Append a new trailing axis of size 1 to `x`."""
    expanded = tf.expand_dims(x, axis=-1)
    return expanded

def se_block(x, reduction=8):
    """Squeeze-and-excitation style channel gating for a 1-D feature map.

    `x` is average-pooled over its middle axis, passed through a two-layer
    Dense bottleneck (relu, then sigmoid), and the resulting per-channel gates
    are multiplied back onto `x`.

    Args:
        x: Keras tensor whose last dimension is the channel count.
        reduction: Divisor applied to the channel count for the bottleneck width.

    Returns:
        Keras tensor produced by multiplying `x` with the gates.
    """
    n_channels = x.shape[-1]
    pooled = GlobalAveragePooling1D()(x)
    bottleneck = Dense(n_channels // reduction, activation='relu')(pooled)
    gates = Dense(n_channels, activation='sigmoid')(bottleneck)
    # Reshape to (1, channels) so the gates broadcast along the pooled axis.
    gates = Reshape((1, n_channels))(gates)
    return Multiply()([x, gates])

# Residual CNN Block with SE
def residual_se_cnn_block(x, filters, kernel_size, pool_size=2, drop=0.3, wd=1e-4):
    """Two Conv1D-BN-ReLU stages, SE gating, a residual add, then pooling and dropout.

    Args:
        x: Input Keras tensor.
        filters: Output channel count of both convolutions (and of the projected skip path).
        kernel_size: Kernel width of the two main convolutions.
        pool_size: Window of the trailing MaxPooling1D.
        drop: Rate of the trailing Dropout.
        wd: L2 factor applied to every convolution kernel in the block.

    Returns:
        The block's output Keras tensor.
    """
    skip = x
    out = x
    for _ in range(2):
        out = Conv1D(filters, kernel_size, padding='same', use_bias=False,
                     kernel_regularizer=l2(wd))(out)
        out = BatchNormalization()(out)
        out = Activation('relu')(out)
    out = se_block(out)

    # The skip path needs a 1x1 projection only when its channel count differs.
    needs_projection = skip.shape[-1] != filters
    if needs_projection:
        skip = Conv1D(filters, 1, padding='same', use_bias=False,
                      kernel_regularizer=l2(wd))(skip)
        skip = BatchNormalization()(skip)

    out = add([out, skip])
    out = Activation('relu')(out)
    out = MaxPooling1D(pool_size)(out)
    return Dropout(drop)(out)

def attention_layer(inputs):
    """Collapse axis 1 of `inputs` into a single softmax-weighted sum.

    A Dense(1, tanh) layer scores each position, the scores are softmax-normalised,
    and `inputs` is multiplied by those weights and summed over axis 1.

    Args:
        inputs: Keras tensor; assumed to be (batch, time, features) — confirm at the call site.

    Returns:
        Keras tensor holding the weighted sum.
    """
    logits = Dense(1, activation='tanh')(inputs)
    logits = Lambda(squeeze_last_axis)(logits)
    attn = Activation('softmax')(logits)
    # Restore the trailing axis so the weights broadcast across features.
    attn = Lambda(expand_last_axis)(attn)
    weighted = Multiply()([inputs, attn])
    return Lambda(time_sum)(weighted)

In [None]:
# Optimized physics calculations using vectorization
def remove_gravity_from_acc_vectorized(acc_values, quat_values):
    """Subtract gravity, expressed in the sensor frame, from accelerometer rows.

    For every row whose quaternion is usable, the world-frame gravity vector
    ``[0, 0, 9.81]`` is rotated with the inverse of that row's rotation and
    subtracted from the acceleration. Rows whose quaternion contains NaN or is
    (numerically) all zeros are returned unchanged.

    Args:
        acc_values: Array of shape (n, 3) with acceleration components.
        quat_values: Array of shape (n, 4) with quaternions in the order
            ``R.from_quat`` expects by default (scalar last).

    Returns:
        A new floating-point array of shape (n, 3); the inputs are not modified.
    """
    acc_values = np.asarray(acc_values)
    quat_values = np.asarray(quat_values)

    # The result must be floating point: writing the gravity-corrected values
    # into an integer copy would silently truncate them.
    if np.issubdtype(acc_values.dtype, np.floating):
        linear_accel = acc_values.copy()
    else:
        linear_accel = acc_values.astype(float)

    # A quaternion is usable when it has no NaN component and is not all ~0.
    valid_mask = ~(np.any(np.isnan(quat_values), axis=1) |
                   np.all(np.isclose(quat_values, 0), axis=1))
    if not np.any(valid_mask):
        return linear_accel

    try:
        rotations = R.from_quat(quat_values[valid_mask])
        gravity_world = np.array([0, 0, 9.81])
        # One batched call rotates gravity into every valid sensor frame.
        gravity_sensor_frames = rotations.apply(gravity_world, inverse=True)
        linear_accel[valid_mask] = acc_values[valid_mask] - gravity_sensor_frames
    except ValueError:
        # Best effort: if SciPy rejects the quaternions, keep the raw acceleration.
        pass

    return linear_accel

In [None]:
def calculate_angular_velocity_vectorized(quat_values, time_delta=1/200):
    """Estimate angular velocity from consecutive quaternions.

    For each pair of neighbouring rows ``(t, t+1)`` the relative rotation
    ``rot_t.inv() * rot_t_plus_dt`` is converted to a rotation vector and
    divided by ``time_delta``. The value is stored at row ``t``; the last row,
    and every row whose pair contains a NaN or all-zero quaternion, stays zero.

    Args:
        quat_values: Array of shape (n, 4) with quaternions in the order
            ``R.from_quat`` expects by default (scalar last).
        time_delta: Time between consecutive rows (default 1/200).

    Returns:
        Float array of shape (n, 3).
    """
    quat_values = np.asarray(quat_values)
    num_samples = quat_values.shape[0]
    angular_vel = np.zeros((num_samples, 3))

    if num_samples < 2:
        return angular_vel

    def usable(quats):
        # Usable = no NaN component and not (numerically) the zero quaternion.
        return ~(np.any(np.isnan(quats), axis=1) |
                 np.all(np.isclose(quats, 0), axis=1))

    # Process in chunks to bound the size of the temporary Rotation objects.
    chunk_size = 1000
    for start in range(0, num_samples - 1, chunk_size):
        end = min(start + chunk_size, num_samples - 1)

        q_t = quat_values[start:end]
        q_t_plus_dt = quat_values[start + 1:end + 1]

        # A pair counts only if both of its quaternions are usable.
        valid_mask = usable(q_t) & usable(q_t_plus_dt)
        if not np.any(valid_mask):
            continue

        try:
            rot_t = R.from_quat(q_t[valid_mask])
            rot_t_plus_dt = R.from_quat(q_t_plus_dt[valid_mask])
            delta_rot = rot_t.inv() * rot_t_plus_dt
            rotvec = delta_rot.as_rotvec()
        except ValueError:
            # Best effort: leave this chunk at zero if SciPy rejects the input.
            continue

        angular_vel[start + np.flatnonzero(valid_mask)] = rotvec / time_delta

    return angular_vel

In [None]:
# Simplified MixUp generator
class FastMixupGenerator(Sequence):
    """Keras `Sequence` that serves MixUp-blended mini-batches.

    Every batch is blended with a permuted copy of itself using a single
    coefficient drawn from Beta(alpha, alpha). The sample order starts as
    ``0..len(X)-1`` and is reshuffled in `on_epoch_end`.

    Args:
        X: Indexable array of inputs; indexed with an integer index array.
        y: Indexable array of targets aligned with `X`.
        batch_size: Number of samples per batch (the last batch may be smaller).
        alpha: Both shape parameters of the Beta distribution.
        **kwargs: Forwarded to the Keras base-class constructor.
    """

    def __init__(self, X, y, batch_size, alpha=0.4, **kwargs):
        # Newer Keras versions expect dataset subclasses to run the base constructor.
        super().__init__(**kwargs)
        self.X, self.y = X, y
        self.batch = batch_size
        self.alpha = alpha
        self.indices = np.arange(len(X))

    def __len__(self):
        """Number of batches per epoch, counting a trailing partial batch."""
        return int(np.ceil(len(self.X) / self.batch))

    def __getitem__(self, i):
        """Return the i-th batch as a `(X_mix, y_mix)` pair of float32 arrays."""
        idx = self.indices[i*self.batch:(i+1)*self.batch]
        Xb, yb = self.X[idx], self.y[idx]

        # One mixing coefficient and one partner permutation per batch.
        lam = np.random.beta(self.alpha, self.alpha)
        perm = np.random.permutation(len(Xb))
        X_mix = lam * Xb + (1-lam) * Xb[perm]
        y_mix = lam * yb + (1-lam) * yb[perm]

        return X_mix.astype('float32'), y_mix.astype('float32')

    def on_epoch_end(self):
        """Shuffle the sample order in place for the next epoch."""
        np.random.shuffle(self.indices)

In [None]:
# Simplified model architecture (slightly reduced complexity)
def build_two_branch_model_optimized(pad_len, imu_dim, tof_dim, n_classes, wd=1e-4):
    """Build the two-branch (IMU / ToF) sequence classifier.

    The input has shape ``(pad_len, imu_dim + tof_dim)``. Its first `imu_dim`
    channels go through the IMU branch and the rest through the ToF branch;
    each branch ends after two residual SE blocks. The branch outputs are
    concatenated, fed to a bidirectional GRU, pooled by attention and by global
    averaging, and classified by a two-layer dense head.

    Args:
        pad_len: Length of the (padded) input sequences.
        imu_dim: Number of leading channels routed to the IMU branch.
        tof_dim: Number of trailing channels routed to the ToF branch.
        n_classes: Width of the softmax output.
        wd: L2 factor used for every regularised kernel.

    Returns:
        An uncompiled Keras `Model` mapping the input to class probabilities.
    """
    inputs = Input(shape=(pad_len, imu_dim+tof_dim))
    imu_slice = Lambda(lambda t: t[:, :, :imu_dim])(inputs)
    tof_slice = Lambda(lambda t: t[:, :, imu_dim:])(inputs)

    # IMU branch: two residual SE blocks with kernel width 5.
    imu_feat = imu_slice
    for n_filters in (64, 128):
        imu_feat = residual_se_cnn_block(imu_feat, n_filters, 5, drop=0.1, wd=wd)

    # ToF branch: 1x1 convolution stem, then two residual SE blocks with kernel width 3.
    tof_feat = Conv1D(128, 1, padding='same', use_bias=False,
                      kernel_regularizer=l2(wd))(tof_slice)
    tof_feat = BatchNormalization()(tof_feat)
    tof_feat = Activation('relu')(tof_feat)
    for n_filters in (192, 256):
        tof_feat = residual_se_cnn_block(tof_feat, n_filters, 3, drop=0.2, wd=wd)

    # Fuse the branches along the channel axis.
    fused = Concatenate()([imu_feat, tof_feat])

    # Single bidirectional GRU over the fused sequence.
    recurrent = Bidirectional(GRU(128, return_sequences=True, kernel_regularizer=l2(wd),
                                  dropout=0.2, recurrent_dropout=0.2))(fused)
    recurrent = Dropout(0.3)(recurrent)

    # Two sequence summaries: attention-weighted sum and plain average.
    attended = attention_layer(recurrent)
    averaged = GlobalAveragePooling1D()(recurrent)
    head = Concatenate()([attended, averaged])

    # Classifier head: Dense -> BN -> ReLU -> Dropout, twice.
    for units, rate in ((256, 0.4), (128, 0.3)):
        head = Dense(units, use_bias=False, kernel_regularizer=l2(wd))(head)
        head = BatchNormalization()(head)
        head = Activation('relu')(head)
        head = Dropout(rate)(head)

    # Output layers are pinned to float32 for stability with mixed precision.
    logits = Dense(n_classes, kernel_regularizer=l2(wd), dtype='float32')(head)
    probabilities = Activation('softmax', dtype='float32')(logits)

    return Model(inputs, probabilities)

In [None]:
# Parallel processing for sequence features
def _add_nan_row_stats(frame, columns, prefix, out, stat_names):
    """Store per-row NaN-aware statistics of ``frame[columns]`` in ``out``.

    Cells equal to -1 are treated as missing before reducing. For every name in
    ``stat_names`` (each must have a matching ``np.nan<name>`` reducer, e.g.
    mean/std/var/min/max) the entry ``f"{prefix}_{name}"`` is written to ``out``.
    """
    # Private float copy: the -1 -> NaN substitution must never write into
    # `frame`, and must also work when the columns arrive as integers.
    pixels = frame[columns].to_numpy(dtype=float, copy=True)
    pixels[pixels == -1] = np.nan

    with warnings.catch_warnings():
        # All-NaN rows make the nan* reducers emit RuntimeWarnings; the NaN result is kept.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        for stat in stat_names:
            out[f'{prefix}_{stat}'] = getattr(np, f'nan{stat}')(pixels, axis=1)


def process_sequence(seq_data):
    """Process a single sequence with all feature engineering.

    Args:
        seq_data: pandas DataFrame holding the rows of one sequence. It must
            contain ``acc_x/y/z``, ``rot_x/y/z/w``, ``thm_1..thm_5`` and
            ``tof_{1..5}_v{0..63}``.

    Returns:
        A new DataFrame: a copy of `seq_data` with the engineered feature
        columns appended (same index, same row order). The input is not modified.
    """
    seq_data = seq_data.copy()
    n_rows = len(seq_data)

    # All new columns are collected here and attached with a single concat.
    new_columns = {}

    # NumPy views of the raw sensor channels.
    acc_values = seq_data[['acc_x', 'acc_y', 'acc_z']].values
    quat_values = seq_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values

    # Euler angles; if SciPy rejects the quaternions the whole sequence gets zeros.
    try:
        euler_angles = R.from_quat(quat_values).as_euler('xyz', degrees=True)
    except ValueError:
        euler_angles = np.zeros((n_rows, 3))
    for axis_idx, axis_name in enumerate('xyz'):
        new_columns[f'euler_{axis_name}'] = euler_angles[:, axis_idx]

    # Linear acceleration
    linear_accel = remove_gravity_from_acc_vectorized(acc_values, quat_values)
    for axis_idx, axis_name in enumerate('xyz'):
        new_columns[f'linear_acc_{axis_name}'] = linear_accel[:, axis_idx]

    # Acceleration features focusing on z and y axes
    new_columns['acc_yz_mag'] = np.sqrt(seq_data['acc_y']**2 + seq_data['acc_z']**2)
    new_columns['acc_y_z_ratio'] = seq_data['acc_y'] / (seq_data['acc_z'] + 1e-8)

    # Magnitudes
    new_columns['acc_mag'] = np.linalg.norm(acc_values, axis=1)
    new_columns['linear_acc_mag'] = np.linalg.norm(linear_accel, axis=1)

    # Jerk (simplified). np.gradient needs at least two samples, so shorter
    # sequences get a zero jerk instead of raising.
    if n_rows >= 2:
        new_columns['linear_acc_mag_jerk'] = np.gradient(new_columns['linear_acc_mag']) * 200
    else:
        new_columns['linear_acc_mag_jerk'] = np.zeros(n_rows)

    # Angular velocity
    angular_vel = calculate_angular_velocity_vectorized(quat_values)
    for axis_idx, axis_name in enumerate('xyz'):
        new_columns[f'angular_vel_{axis_name}'] = angular_vel[:, axis_idx]

    # ToF aggregations over all 64 pixels of each sensor
    for i in range(1, 6):
        pixel_cols = [f"tof_{i}_v{p}" for p in range(64)]
        _add_nan_row_stats(seq_data, pixel_cols, f'tof_{i}', new_columns,
                           ('mean', 'std', 'min', 'max'))

    # Add rotation-specific features for rot_z and rot_w
    new_columns['rot_z_w_product'] = seq_data['rot_z'] * seq_data['rot_w']
    new_columns['rot_z_w_ratio'] = seq_data['rot_z'] / (seq_data['rot_w'] + 1e-8)

    # Rolling statistics for important features (window of 10 samples = 50ms)
    for col in ['rot_z', 'rot_w', 'acc_z', 'acc_y', 'thm_2']:
        new_columns[f'{col}_rolling_mean'] = seq_data[col].rolling(10, center=True).mean()
        new_columns[f'{col}_rolling_std'] = seq_data[col].rolling(10, center=True).std()
        new_columns[f'{col}_diff'] = seq_data[col].diff()

    # Thermopile 2 specific features
    thm_2_mean = seq_data['thm_2'].mean()
    thm_2_std = seq_data['thm_2'].std()
    new_columns['thm_2_normalized'] = (seq_data['thm_2'] - thm_2_mean) / (thm_2_std + 1e-8)
    new_columns['thm_2_delta_from_mean'] = seq_data['thm_2'] - seq_data[['thm_1', 'thm_2', 'thm_3', 'thm_4', 'thm_5']].mean(axis=1)

    # ToF band statistics: inclusive (first_pixel, last_pixel) ranges per sensor.
    # Sensors 1 and 2 share the same band layout.
    bands_tof_1_2 = [
        (3, 7), (11, 15), (19, 23), (27, 31),
        (35, 39), (43, 47), (51, 55), (59, 63),
    ]
    bands_by_sensor = {
        1: bands_tof_1_2,
        2: bands_tof_1_2,
        3: [(0, 47), (50, 55)],
        4: [(0, 3), (7, 9), (15, 16), (21, 23), (28, 31),
            (35, 39), (43, 47), (50, 55), (58, 63)],
        5: [(1, 7), (9, 15), (18, 23), (48, 49), (56, 61)],
    }
    for tof_num, bands in bands_by_sensor.items():
        for band_idx, (start, end) in enumerate(bands, 1):
            band_cols = [f"tof_{tof_num}_v{p}" for p in range(start, end + 1)]
            _add_nan_row_stats(seq_data, band_cols, f'tof_{tof_num}_band{band_idx}',
                               new_columns, ('mean', 'std', 'var', 'min', 'max'))

    # Create DataFrame from new columns and concatenate with original
    new_columns_df = pd.DataFrame(new_columns, index=seq_data.index)
    result_df = pd.concat([seq_data, new_columns_df], axis=1)

    return result_df

In [None]:
# Prediction function for individual sequences
def predict(sequence: pl.DataFrame, demographics: pl.DataFrame) -> str:
    """Predict the gesture class name for one sequence.

    Feature engineering is delegated to `process_sequence`, the same function
    used to build the training frame, so training and inference features cannot
    drift apart. The result is scaled, padded/truncated to `pad_len`, scored by
    every model in `models`, and the averaged prediction decides the class.

    Reads these notebook-level names: `final_feature_cols`, `scaler`,
    `pad_len`, `models`, `gesture_classes`.

    Args:
        sequence: polars DataFrame with the rows of a single sequence.
        demographics: Accepted for interface compatibility; not used here.

    Returns:
        The entry of `gesture_classes` at the index of the highest averaged score.
    """
    # Convert to pandas and run the shared feature engineering.
    df_seq = process_sequence(sequence.to_pandas())

    # Mirror the training-side cleaning: +/-inf and NaN both become 0 before scaling.
    features = df_seq[final_feature_cols].replace([np.inf, -np.inf], np.nan).fillna(0)
    mat_unscaled = features.values.astype('float32')
    mat_scaled = scaler.transform(mat_unscaled)

    # Pad (or truncate) at the end of the sequence to the fixed model length.
    pad_input = pad_sequences([mat_scaled], maxlen=pad_len, padding='post',
                              truncating='post', dtype='float32')

    # Ensemble: average the per-model predictions.
    all_preds = [model.predict(pad_input, verbose=0)[0] for model in models]
    avg_pred = np.mean(all_preds, axis=0)
    predicted_class_idx = avg_pred.argmax()

    # Return gesture class name
    return gesture_classes[predicted_class_idx]

## Adding Separate Data Load Function

In [None]:
# Load data
# Sensor rows get each subject's demographics attached through a left join on 'subject'.
train_dem_df = pd.read_csv(raw_dir / "train_demographics.csv")
train_df = pd.read_csv(raw_dir / "train.csv").merge(train_dem_df, on='subject', how='left')

In [None]:
def load_data(df, n_jobs=None):
    """Run `process_sequence` on every sequence of `df` in parallel.

    Args:
        df: DataFrame with a ``sequence_id`` column plus the columns
            `process_sequence` needs. It is not modified: each group handed to
            a worker is copied again inside `process_sequence`.
        n_jobs: Number of joblib workers; defaults to the machine's CPU count.

    Returns:
        One DataFrame with all processed sequences concatenated in
        ``groupby('sequence_id')`` order and a fresh 0..n-1 index.
    """
    print("Processing sequences with parallel computation...")

    # One sub-frame per sequence. No up-front copy of the whole frame is needed.
    sequences = [group for _, group in df.groupby('sequence_id')]

    n_cores = multiprocessing.cpu_count() if n_jobs is None else n_jobs
    print(f"Using {n_cores} cores for parallel processing")

    processed_sequences = Parallel(n_jobs=n_cores)(
        delayed(process_sequence)(seq) for seq in sequences
    )

    # Combine processed sequences
    return pd.concat(processed_sequences, ignore_index=True)


processed_df = load_data(train_df)

In [None]:
def feature_engineering(df):
    """Add the combined orientation/gesture label and four behaviour flags.

    Args:
        df: DataFrame with ``orientation``, ``gesture`` and ``behavior`` columns.

    Returns:
        A copy of `df` with these extra columns:
        ``orientation_gesture`` (category; ``"<orientation>_<gesture>"``, or
        ``"N/A"`` when either part contains ``"N/A"``) and the booleans
        ``performs_gesture``, ``move_hand_to_target``, ``hand_at_target``,
        ``relaxes_moves_hand_to_target``. The input frame is left untouched.
    """
    df = df.copy()

    # Column-wise string work instead of a Python call per row.
    # `map(str)` keeps the original str() semantics (e.g. NaN -> 'nan').
    orientation = df['orientation'].astype(object).map(str)
    gesture = df['gesture'].astype(object).map(str)
    has_na = (orientation.str.contains('N/A', regex=False)
              | gesture.str.contains('N/A', regex=False))
    combined = orientation + '_' + gesture
    df['orientation_gesture'] = combined.where(~has_na, 'N/A').astype('category')

    # Behavioural boolean columns (case-insensitive substring match, missing -> False).
    behavior = df['behavior']
    df['performs_gesture'] = behavior.str.contains('Performs gesture', case=False, na=False)
    # NOTE(review): case-insensitively, this phrase is a substring of
    # 'Relaxes and moves hand to target location', so those rows set this flag
    # too — confirm that overlap is intended.
    df['move_hand_to_target'] = behavior.str.contains('Moves hand to target location', case=False, na=False)
    df['hand_at_target'] = behavior.str.contains('Hand at target location', case=False, na=False)
    df['relaxes_moves_hand_to_target'] = behavior.str.contains('Relaxes and moves hand to target location', case=False, na=False)

    return df

# Derive the combined orientation/gesture label and the behaviour flags from the processed frame.
fe_train_df = feature_engineering(processed_df)

In [None]:
def fill_missing_values(df):
    """Fill gaps in `df` in place and return the same frame.

    Object-dtype columns get ``"N/A"`` for missing entries. Numeric columns have
    +/-inf, ``''`` and ``None`` turned into NaN first, then every NaN becomes 0.

    Args:
        df: DataFrame to clean; it is modified directly (no copy is made).

    Returns:
        The same `df` object, after filling.
    """
    # Text columns: missing -> 'N/A'
    object_columns = df.select_dtypes(include=['object']).columns
    df[object_columns] = df[object_columns].fillna("N/A")

    # Numeric columns: non-finite / placeholder values -> NaN -> 0
    numeric_columns = df.select_dtypes(include=[np.number]).columns
    placeholders = [np.inf, -np.inf, '', None]
    numeric_part = df[numeric_columns].replace(placeholders, np.nan)
    df[numeric_columns] = numeric_part.fillna(0)

    return df

# fill_missing_values edits the frame it receives and returns that same object.
fe_train_df = fill_missing_values(fe_train_df)

In [None]:
 # Encode labels
# Integer-encode the gesture names; gesture_classes[i] is the name behind class index i.
le = LabelEncoder()
fe_train_df['gesture_int'] = le.fit_transform(fe_train_df['gesture'])
gesture_classes = le.classes_
# np.save does not create missing folders, so make sure the target directory exists first.
output_dir.mkdir(parents=True, exist_ok=True)
np.save(output_dir / "gesture_classes.npy", gesture_classes)

In [None]:
# Names of every object-, category- or bool-typed column; these get label-encoded next.
cat_features = list(
    fe_train_df.select_dtypes(include=['object', 'category', 'bool']).columns
)
cat_features

In [None]:
# Per-column lookup {encoded integer -> original label}, filled while encoding.
encoder_mappings = {}

print("Label Encoding Categorical Features: ",end="")
for c in cat_features:
    print(f"{c}, ",end="")

    # Cast through category, then to str, and fit a dedicated encoder on the
    # resulting strings; the column itself is overwritten with the integer codes.
    full_le = LabelEncoder()
    fe_train_df[c] = full_le.fit_transform(
        fe_train_df[c].astype('category').astype(str)
    )

    # classes_ is ordered by code, so enumerating it yields the code -> label map.
    encoder_mappings[c] = dict(enumerate(full_le.classes_))

# Persist the mappings next to the notebook so codes can be decoded later.
with open('new_categorical_encoder_mappings.json', 'w') as f:
    json.dump(encoder_mappings, f, indent=2)

In [None]:
# Folder (relative to the working directory) used as the plot output path.
plot_output_path = Path('plots')

# Create the folder if needed; exist_ok=True makes re-running this cell a no-op.
plot_output_path.mkdir(exist_ok=True)
print(f"Directory '{plot_output_path}' is ready.")

In [None]:
# Ttest analysis
def perform_ttest_analysis(df, group_col, group1_value, group2_value, output_path):
    """Compare two groups of rows on every numeric column with a t-test.

    For each int32/int64/float32/float64 column (except ``group_col``) the
    function runs Levene's test, then an independent t-test (pooled when
    Levene's p > 0.05, Welch otherwise), computes Cohen's d and a 95%
    confidence interval for the mean difference, and calls
    ``create_comparison_plot`` to save a figure.

    Sign conventions: ``t_statistic`` and ``cohens_d`` are group1 - group2,
    while ``mean_difference`` and its confidence interval are group2 - group1.

    Args:
        df: DataFrame holding the grouping column and the variables to test.
        group_col: Column whose values define the two groups.
        group1_value: Value of ``group_col`` selecting group 1.
        group2_value: Value of ``group_col`` selecting group 2.
        output_path: Directory passed through to ``create_comparison_plot``.

    Returns:
        DataFrame with one row per tested column, sorted by p-value, including
        ``effect_size_interpretation`` and ``significance_level`` columns.
        Empty if no column could be tested.

    Columns with fewer than two non-null values in either group are skipped
    with a warning; any exception for a column is printed and that column is
    skipped (best-effort batch analysis).
    """
    # Filter data for the two groups (boolean indexing already yields new frames)
    group1_data = df[df[group_col] == group1_value]
    group2_data = df[df[group_col] == group2_value]

    print(f"Group 1 ({group1_value}): {len(group1_data)} records")
    print(f"Group 2 ({group2_value}): {len(group2_data)} records")

    # Get numerical columns
    numerical_cols = df.select_dtypes(include=['int64', 'float64', 'int32', 'float32']).columns.tolist()

    # Remove the group column if it's numerical
    if group_col in numerical_cols:
        numerical_cols.remove(group_col)

    print(f"Testing {len(numerical_cols)} numerical variables")

    results = []

    for col in tqdm(numerical_cols, desc="Running t-tests and creating plots"):
        try:
            # Per-group samples for this column, NaN removed
            group1_values = group1_data[col].dropna()
            group2_values = group2_data[col].dropna()
            n1, n2 = len(group1_values), len(group2_values)

            # Skip if insufficient data
            if n1 < 2 or n2 < 2:
                print(f"Warning: Insufficient data for {col} (Group1: {n1}, Group2: {n2})")
                continue

            mean1, mean2 = group1_values.mean(), group2_values.mean()
            var1, var2 = group1_values.var(), group2_values.var()

            # Levene's test decides between the pooled and the Welch t-test
            _, levene_p = stats.levene(group1_values, group2_values)
            equal_var = levene_p > 0.05

            t_stat, p_value = stats.ttest_ind(group1_values, group2_values, equal_var=equal_var)

            # Standardiser for Cohen's d, plus the standard error and degrees
            # of freedom for the confidence interval. The SE/dof pair follows
            # the same variant of the test that produced t_stat/p_value.
            if equal_var:
                ci_dof = n1 + n2 - 2
                pooled_std = np.sqrt(((n1 - 1) * var1 + (n2 - 1) * var2) / ci_dof)
                se_diff = pooled_std * np.sqrt(1.0 / n1 + 1.0 / n2)
            else:
                pooled_std = np.sqrt((var1 + var2) / 2)
                se_diff = np.sqrt(var1 / n1 + var2 / n2)
                # Welch-Satterthwaite degrees of freedom; fall back to the
                # pooled value when both variances are zero (0/0).
                welch_den = (var1 / n1) ** 2 / (n1 - 1) + (var2 / n2) ** 2 / (n2 - 1)
                if welch_den > 0:
                    ci_dof = (var1 / n1 + var2 / n2) ** 2 / welch_den
                else:
                    ci_dof = n1 + n2 - 2

            cohens_d = (mean1 - mean2) / pooled_std if pooled_std != 0 else 0

            # 95% confidence interval for the difference in means (group2 - group1)
            diff_mean = mean2 - mean1
            t_critical = stats.t.ppf(0.975, ci_dof)
            ci_lower = diff_mean - t_critical * se_diff
            ci_upper = diff_mean + t_critical * se_diff

            # Store results
            results.append({
                'variable': col,
                'group1_mean': mean1,
                'group1_std': group1_values.std(),
                'group1_median': group1_values.median(),
                'group1_count': n1,
                'group2_mean': mean2,
                'group2_std': group2_values.std(),
                'group2_median': group2_values.median(),
                'group2_count': n2,
                'mean_difference': diff_mean,
                'ci_lower': ci_lower,
                'ci_upper': ci_upper,
                't_statistic': t_stat,
                'p_value': p_value,
                'cohens_d': cohens_d,
                'equal_variances': equal_var,
                'levene_p_value': levene_p,
                'significant_005': p_value < 0.05,
                'significant_001': p_value < 0.01,
                'significant_0001': p_value < 0.001
            })

            # Create comparison plot
            create_comparison_plot(group_col, group1_values, group2_values, col,
                                 group1_value, group2_value,
                                 t_stat, p_value, cohens_d, output_path)

        except Exception as e:
            # Best-effort: report the failing column and move on to the next one
            print(f"Error processing {col}: {str(e)}")
            continue

    # Create results DataFrame
    results_df = pd.DataFrame(results)

    if len(results_df) > 0:
        # Sort by p-value
        results_df = results_df.sort_values('p_value').reset_index(drop=True)

        # Add interpretation columns
        results_df['effect_size_interpretation'] = results_df['cohens_d'].abs().apply(interpret_cohens_d)
        results_df['significance_level'] = results_df.apply(get_significance_level, axis=1)

    return results_df

def interpret_cohens_d(d):
    """Map a Cohen's d value to a qualitative effect-size label.

    The sign is ignored. Magnitudes below 0.2 are "negligible", below 0.5
    "small", below 0.8 "medium"; anything else is "large".
    """
    magnitude = abs(d)
    # (label, exclusive upper bound), checked from smallest to largest
    bands = (("negligible", 0.2), ("small", 0.5), ("medium", 0.8))
    return next((label for label, upper in bands if magnitude < upper), "large")

def get_significance_level(row):
    """Return a star-annotated significance label for ``row['p_value']``.

    ``row`` is anything indexable by the key 'p_value' (a dict or a DataFrame
    row). Thresholds are checked from strictest to loosest.
    """
    p = row['p_value']
    thresholds = (
        (0.001, "*** (p < 0.001)"),
        (0.01, "** (p < 0.01)"),
        (0.05, "* (p < 0.05)"),
    )
    for cutoff, label in thresholds:
        if p < cutoff:
            return label
    return "Not significant"

def create_comparison_plot(group_col, group1_values, group2_values, col_name, 
                          group1_name, group2_name, t_stat, p_value, cohens_d, output_path):
    """Create and save a 2x2 comparison figure for one variable and two groups.

    Panels: overlaid density histograms with mean lines (top left), box plots
    (top right), KDE curves (bottom left) and a text panel with the test
    statistics and per-group summaries (bottom right).

    Args:
        group_col: Name of the grouping column; used only in the file name.
        group1_values, group2_values: Series-like samples exposing
            ``mean``/``std``/``median``/``min``/``max``.
        col_name: Name of the compared variable (titles, labels, file name).
        group1_name, group2_name: Labels for the two groups.
        t_stat, p_value, cohens_d: Statistics shown in the text panel;
            ``p_value`` also selects the significance suffix of the file name.
        output_path: Directory the PNG is written into.

    Returns:
        None. The figure is saved to disk and always closed, even when
        drawing or saving raises, so failures inside a long batch do not
        accumulate open figures.
    """

    # Set up the figure with subplots
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    try:
        fig.suptitle(f'Statistical Comparison: {col_name}', fontsize=16, fontweight='bold')

        # Color scheme
        color1 = '#ff7f7f'  # Light red for group 1
        color2 = '#7fbfff'  # Light blue for group 2

        # 1. Histogram comparison (top left)
        ax1 = axes[0, 0]

        # Shared bin edges so the two histograms are directly comparable
        all_values = np.concatenate([group1_values, group2_values])
        bins = np.histogram_bin_edges(all_values, bins=30)

        ax1.hist(group1_values, bins=bins, alpha=0.7, label=f'{group1_name} (n={len(group1_values)})', 
                 color=color1, density=True, edgecolor='black', linewidth=0.5)
        ax1.hist(group2_values, bins=bins, alpha=0.7, label=f'{group2_name} (n={len(group2_values)})', 
                 color=color2, density=True, edgecolor='black', linewidth=0.5)

        # Add mean lines
        ax1.axvline(group1_values.mean(), color='darkred', linestyle='--', linewidth=2, 
                    label=f'{group1_name} Mean: {group1_values.mean():.2f}')
        ax1.axvline(group2_values.mean(), color='darkblue', linestyle='--', linewidth=2,
                    label=f'{group2_name} Mean: {group2_values.mean():.2f}')

        ax1.set_title('Distribution Comparison (Histogram)')
        ax1.set_xlabel(col_name)
        ax1.set_ylabel('Density')
        ax1.legend(fontsize=9)
        ax1.grid(True, alpha=0.3)

        # 2. Box plot comparison (top right)
        ax2 = axes[0, 1]

        box_data = [group1_values, group2_values]
        box_labels = [f'{group1_name}\n(n={len(group1_values)})', f'{group2_name}\n(n={len(group2_values)})']

        # NOTE(review): newer Matplotlib releases rename `labels` to
        # `tick_labels`; kept as-is because the installed version is unknown.
        bp = ax2.boxplot(box_data, labels=box_labels, patch_artist=True, 
                         boxprops=dict(alpha=0.7), showfliers=True)
        bp['boxes'][0].set_facecolor(color1)
        bp['boxes'][1].set_facecolor(color2)

        ax2.set_title('Box Plot Comparison')
        ax2.set_ylabel(col_name)
        ax2.grid(True, alpha=0.3)

        # 3. KDE plot (bottom left)
        ax3 = axes[1, 0]

        try:
            # Only plot KDE if we have enough data points
            if len(group1_values) > 3:
                sns.kdeplot(data=group1_values, label=f'{group1_name}', 
                           color='darkred', ax=ax3, linewidth=2)
            if len(group2_values) > 3:
                sns.kdeplot(data=group2_values, label=f'{group2_name}', 
                           color='darkblue', ax=ax3, linewidth=2)
        except Exception:
            # Fallback to plain density histograms if KDE fails. Only ordinary
            # errors are caught, so Ctrl-C still interrupts a long batch.
            ax3.hist(group1_values, alpha=0.5, density=True, color=color1, label=group1_name)
            ax3.hist(group2_values, alpha=0.5, density=True, color=color2, label=group2_name)

        ax3.set_title('Density Comparison (KDE)')
        ax3.set_xlabel(col_name)
        ax3.set_ylabel('Density')
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # 4. Statistics summary (bottom right)
        ax4 = axes[1, 1]

        effect_size_interp = interpret_cohens_d(cohens_d)
        sig_level = get_significance_level({'p_value': p_value})

        def _group_summary(name, values):
            """Format the descriptive-statistics paragraph for one group."""
            return (
                f'{name}:\n'
                f'  Mean ± SD: {values.mean():.2f} ± {values.std():.2f}\n'
                f'  Median: {values.median():.2f}\n'
                f'  Range: [{values.min():.2f}, {values.max():.2f}]\n'
                f'  Count: {len(values)}\n\n'
            )

        stats_text = f'T-Test Results:\n'
        stats_text += f'━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n'
        stats_text += f'T-statistic: {t_stat:.4f}\n'
        stats_text += f'P-value: {p_value:.2e}\n'
        stats_text += f'Cohen\'s d: {cohens_d:.4f} ({effect_size_interp})\n'
        stats_text += f'Significance: {sig_level}\n\n'

        stats_text += _group_summary(group1_name, group1_values)
        stats_text += _group_summary(group2_name, group2_values)

        # Difference
        mean_diff = group2_values.mean() - group1_values.mean()
        stats_text += f'Mean Difference: {mean_diff:.2f}\n'
        stats_text += f'({group2_name} - {group1_name})\n\n'

        # Effect size interpretation: reuse the label from interpret_cohens_d
        # ("negligible"/"small"/"medium"/"large") instead of repeating its thresholds.
        stats_text += f'Effect Size Interpretation:\n'
        stats_text += f'{effect_size_interp.capitalize()} practical difference'

        ax4.text(0.05, 0.95, stats_text,
                 transform=ax4.transAxes,
                 verticalalignment='top',
                 horizontalalignment='left',
                 bbox=dict(boxstyle='round,pad=0.5', facecolor='lightgray', alpha=0.8),
                 fontsize=10, fontfamily='monospace')
        ax4.set_xlim(0, 1)
        ax4.set_ylim(0, 1)
        ax4.axis('off')

        fig.tight_layout()

        # Save with significance indicator in filename
        if p_value < 0.001:
            sig_suffix = "_highly_significant"
        elif p_value < 0.05:
            sig_suffix = "_significant"
        else:
            sig_suffix = "_not_significant"

        fig.savefig(f'{output_path}/{group_col}_{col_name}_ttest_{group1_name}_vs_{group2_name}_{sig_suffix}_{effect_size_interp}.png', 
                    dpi=300, bbox_inches='tight', facecolor='white')
    finally:
        # Close this specific figure whether or not drawing/saving succeeded
        plt.close(fig)


# Perform the analysis
#ttest_results = perform_ttest_analysis(
#    df=fe_train_df, 
#    group_col='performs_gesture',
#    group1_value=True,
#    group2_value=False,
#    output_path=plot_output_path
#)

In [None]:
# Feature column definitions live outside the train/inference switch so both
# branches agree on which columns feed the IMU branch of the model.
imu_cols = [
    'acc_x',
    'acc_y',
    'acc_z',
    'linear_acc_x',
    'linear_acc_y',
    'linear_acc_z',
    'rot_x',
    'rot_y',
    'rot_z',
    'rot_w',
    'acc_mag',
    'linear_acc_mag',
    'linear_acc_mag_jerk',
    'angular_vel_x',
    'angular_vel_y',
    'angular_vel_z',
    # added columns
    'acc_y_rolling_mean',
    'acc_y_rolling_std',
    'acc_z_rolling_mean',
    'acc_z_rolling_std',
    'rot_w_rolling_mean',
    'rot_w_rolling_std',
    'rot_z_rolling_mean',
    'rot_z_w_product',
    'euler_x',
    'euler_y',
    'euler_z',
]

thm_cols = [
    'thm_1',
    'thm_2_delta_from_mean',
    'thm_2_normalized',
    'thm_2_rolling_std',
    'thm_2',
    'thm_3',
    'thm_4',
    'thm_5',
]

tof_aggregated_cols = [
    'tof_5_band5_mean',
    'tof_5_band4_mean',
    'tof_5_band1_mean',
    'tof_4_band9_mean',
    'tof_4_band8_mean',
    'tof_4_band7_mean',
    'tof_4_band6_mean',
    'tof_4_band5_mean',
    'tof_4_band4_mean',
    'tof_4_band1_mean',
    'tof_3_band2_mean',
    'tof_3_band1_mean',
    'tof_2_band8_mean',
    'tof_2_band7_mean',
    'tof_2_band6_mean',
    'tof_2_band5_mean',
    'tof_2_band4_mean',
    'tof_2_band3_mean',
    'tof_2_band2_mean',
    'tof_2_band1_mean',
    'tof_1_band8_mean',
    'tof_1_band7_mean',
    'tof_1_band6_mean',
    'tof_1_band5_mean',
    'tof_1_band4_mean',
    'tof_1_band3_mean',
    'tof_1_band2_mean',
]
# Per-sensor summary statistics for the five ToF sensors
for i in range(1, 6):
    tof_aggregated_cols.extend([
        f'tof_{i}_mean', f'tof_{i}_std', f'tof_{i}_min', f'tof_{i}_max'
    ])

if train:
    print("TRAINING MODE")

    # Column order matters: IMU features come first, then THM + ToF features.
    final_feature_cols = imu_cols + thm_cols + tof_aggregated_cols
    imu_dim = len(imu_cols)
    tof_thm_dim = len(thm_cols) + len(tof_aggregated_cols)

    print(f"IMU features: {imu_dim} | THM + ToF features: {tof_thm_dim}")
    np.save(output_dir / "feature_cols.npy", np.array(final_feature_cols))

    # Build sequences efficiently: one (timesteps, features) array per sequence_id
    print("Building sequences...")
    seq_gp = fe_train_df.groupby('sequence_id')
    X_list_unscaled = []
    y_list_int = []
    groups_list = []
    lens = []

    for seq_id, seq_df in seq_gp:
        X_list_unscaled.append(
            seq_df[final_feature_cols].fillna(0).values.astype('float32')
        )
        # Label and subject are taken from the first row of each sequence
        y_list_int.append(seq_df['gesture_int'].iloc[0])
        groups_list.append(seq_df['subject'].iloc[0])
        lens.append(len(seq_df))

    # Scaling
    # NOTE(review): the scaler is fitted on every timestep of every sequence,
    # i.e. before the CV split, so validation folds contribute to the
    # statistics. The same scaler is saved and reused at inference.
    print("Fitting StandardScaler...")
    all_steps_concatenated = np.concatenate(X_list_unscaled, axis=0)
    scaler = StandardScaler().fit(all_steps_concatenated)
    joblib.dump(scaler, output_dir / "scaler.pkl")

    # Scale and pad/truncate every sequence to a percentile of the length distribution
    X_scaled_list = [scaler.transform(x_seq) for x_seq in X_list_unscaled]
    pad_len = int(np.percentile(lens, pad_percentile))
    np.save(output_dir / "sequence_maxlen.npy", pad_len)

    X = pad_sequences(X_scaled_list, maxlen=pad_len, padding='post', 
                      truncating='post', dtype='float32')
    y_stratify = np.array(y_list_int)
    groups = np.array(groups_list)
    y = to_categorical(y_list_int, num_classes=len(le.classes_))

    print(f"Final data shape: X={X.shape}, y={y.shape}")

    # Cross-validation: stratified on gesture, grouped by subject so that a
    # subject never appears in both the train and validation part of a fold.
    print(f"\nStarting training with {n_splits}-fold CV...")
    sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=42)

    for fold, (train_idx, val_idx) in enumerate(sgkf.split(X, y_stratify, groups)):
        print(f"\n{'='*20} FOLD {fold+1}/{n_splits} {'='*20}")

        X_tr, X_val = X[train_idx], X[val_idx]
        y_tr, y_val = y[train_idx], y[val_idx]

        # Build model
        model = build_two_branch_model_optimized(
            pad_len=pad_len, 
            imu_dim=imu_dim, 
            tof_dim=tof_thm_dim, 
            n_classes=len(le.classes_), 
            wd=wd
        )

        # Compile with mixed precision optimizer
        opt = Adam(learning_rate=lr_init)
        opt = mixed_precision.LossScaleOptimizer(opt)

        model.compile(
            optimizer=opt,
            loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
            metrics=['accuracy']
        )

        # Simplified callbacks
        callbacks = [
            EarlyStopping(
                monitor='val_loss', 
                patience=patience, 
                restore_best_weights=True, 
                verbose=1
            ),
            tf.keras.callbacks.ReduceLROnPlateau(
                factor=0.5, 
                patience=10, 
                min_lr=1e-6, 
                verbose=1
            ),
            tf.keras.callbacks.ModelCheckpoint(
                str(output_dir / f'model_fold_{fold}_best.h5'),
                save_best_only=True,
                monitor='val_loss'
            )
        ]

        # Train with larger batch size
        train_gen = FastMixupGenerator(
            X_tr, y_tr, 
            batch_size=batch_size, 
            alpha=mixup_alpha
        )

        history = model.fit(
            train_gen,
            epochs=epochs,
            validation_data=(X_val, y_val),
            callbacks=callbacks,
            verbose=1
        )

        # Save model; the inference branch loads these "*_final.h5" files
        model.save(output_dir / f"model_fold_{fold}_final.h5")

        # Clear session so graph state from this fold is released before the next one
        tf.keras.backend.clear_session()

    print("\nTraining completed!")
else:
    # INFERENCE MODE
    print("INFERENCE MODE – Loading artifacts from", pretrained_dir)

    # Load pretrained artifacts
    # NOTE(review): allow_pickle=True and joblib.load execute pickled content;
    # only point `pretrained_dir` at trusted artifacts.
    final_feature_cols = np.load(pretrained_dir / "feature_cols.npy", allow_pickle=True).tolist()
    pad_len = int(np.load(pretrained_dir / "sequence_maxlen.npy"))
    scaler = joblib.load(pretrained_dir / "scaler.pkl")
    gesture_classes = np.load(pretrained_dir / "gesture_classes.npy", allow_pickle=True)

    # Derive the IMU/non-IMU split from the columns the artifacts were
    # actually trained with, instead of a separately hard-coded list that can
    # drift from the training definition.
    known_imu_cols = set(imu_cols)
    imu_cols = [col for col in final_feature_cols if col in known_imu_cols]
    imu_dim = len(imu_cols)
    tof_thm_dim = len(final_feature_cols) - imu_dim

    print(f"Loaded artifacts:")
    print(f"  - Feature columns: {len(final_feature_cols)}")
    print(f"  - Sequence max length: {pad_len}")
    print(f"  - Gesture classes: {len(gesture_classes)}")
    print(f"  - IMU dim: {imu_dim}, ToF+THM dim: {tof_thm_dim}")

    # Disable mixed precision for inference to avoid Cast layer issues
    mixed_precision.set_global_policy('float32')

    # Define Cast layer for loading models trained with mixed precision
    class Cast(tf.keras.layers.Layer):
        """Layer that casts its input to a target dtype (float32 by default)."""

        def __init__(self, dtype='float32', **kwargs):
            # `dtype` is deliberately not forwarded to the base Layer; it is
            # stored as the cast target instead.
            super().__init__(**kwargs)
            self.target_dtype = dtype

        def call(self, inputs):
            return tf.cast(inputs, self.target_dtype)

        def get_config(self):
            # Serialise the cast target under the 'dtype' key so that
            # __init__ receives it again on reload.
            config = super().get_config()
            config.update({'dtype': self.target_dtype})
            return config

    # Custom objects for model loading
    custom_objs = {
        'time_sum': time_sum,
        'squeeze_last_axis': squeeze_last_axis,
        'expand_last_axis': expand_last_axis,
        'se_block': se_block,
        'residual_se_cnn_block': residual_se_cnn_block,
        'attention_layer': attention_layer,
        'Cast': Cast  # Add proper Cast layer support
    }

    # Load one model per CV fold for ensemble inference
    models = []
    print(f"\n▶ Loading {n_splits} models for ensemble inference...")
    for fold in range(n_splits):
        model_path = pretrained_dir / f"model_fold_{fold}_final.h5"
        print(f"  Loading model: {model_path}")
        model = load_model(model_path, compile=False, custom_objects=custom_objs)
        models.append(model)
    print(f"[INFO] Successfully loaded {len(models)} models")

In [None]:
# Kaggle evaluation server setup
# import kaggle_evaluation.cmi_inference_server
# inference_server = kaggle_evaluation.cmi_inference_server.CMIInferenceServer(predict)

# if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
#    inference_server.serve()
#else:
#    print("\n▶ Running local gateway for testing...")
#    inference_server.run_local_gateway(
#        data_paths=(
#            '/kaggle/input/cmi-detect-behavior-with-sensor-data/test.csv',
#            '/kaggle/input/cmi-detect-behavior-with-sensor-data/test_demographics.csv',
#        )
#    )