In [None]:
# Mount Google Drive into the Colab runtime so that the dataset CSV and the
# per-fold checkpoints under /content/drive/MyDrive/... are readable/writable.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
# All 18 gesture names, kept in alphabetical order. A gesture's position in
# this list is its class id, which agrees with the sorted ``gesture2idx``
# mapping built in ``main`` when the training CSV contains exactly these names.
gestures = [
    'Above ear - pull hair',
    'Cheek - pinch skin',
    'Drink from bottle/cup',
    'Eyebrow - pull hair',
    'Eyelash - pull hair',
    'Feel around in tray and pull out an object',
    'Forehead - pull hairline',
    'Forehead - scratch',
    'Glasses on/off',
    'Neck - pinch skin',
    'Neck - scratch',
    'Pinch knee/leg skin',
    'Pull air toward your face',
    'Scratch knee/leg skin',
    'Text on phone',
    'Wave hello',
    'Write name in air',
    'Write name on leg',
]

# Reverse lookup: class id -> gesture name.
gesture_dict = dict(enumerate(gestures))

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from scipy.spatial.transform import Rotation as R
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.preprocessing import StandardScaler
import json
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import random

# ==============================
# Feature Engineering Functions
# ==============================

def remove_gravity_from_acc(acc_data, rot_data, gravity=9.802):
    """Remove the gravity component from accelerometer readings.

    For every sample with a usable orientation quaternion, the world-frame
    gravity vector ``[0, 0, gravity]`` is rotated into the sensor frame
    (``Rotation.apply(..., inverse=True)``) and subtracted from the raw
    acceleration. Samples whose quaternion is unusable are returned unchanged.
    A quaternion is unusable if it is all ~0 or contains any NaN/inf.

    Args:
        acc_data: DataFrame with ``acc_x/acc_y/acc_z`` columns, or an
            array-like of shape (N, 3).
        rot_data: DataFrame with ``rot_x/rot_y/rot_z/rot_w`` columns, or an
            array-like of shape (N, 4) in (x, y, z, w) order. That is the
            scalar-last layout ``Rotation.from_quat`` expects by default.
        gravity: gravity magnitude along the world z axis (default 9.802).

    Returns:
        float ndarray of shape (N, 3): acceleration with gravity removed.
    """
    if isinstance(acc_data, pd.DataFrame):
        acc_data = acc_data[['acc_x', 'acc_y', 'acc_z']].values
    if isinstance(rot_data, pd.DataFrame):
        rot_data = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values

    # Work in float so the subtraction is never truncated by an int dtype.
    acc_values = np.asarray(acc_data, dtype=float)
    quat_values = np.asarray(rot_data, dtype=float)

    # Rows without a usable quaternion keep their raw acceleration.
    linear_accel = acc_values.copy()

    valid = (np.isfinite(quat_values).all(axis=1)
             & ~np.isclose(quat_values, 0).all(axis=1))
    if valid.any():
        # One vectorized scipy call instead of one Rotation per sample.
        rotations = R.from_quat(quat_values[valid])
        gravity_sensor_frame = rotations.apply(
            np.array([0.0, 0.0, gravity]), inverse=True
        )
        linear_accel[valid] = acc_values[valid] - gravity_sensor_frame

    return linear_accel

def calculate_angular_velocity_from_quat(rot_data, time_delta=1/200):
    """Estimate angular velocity from successive orientation quaternions.

    For each pair (t, t+1) the relative rotation ``R_t^-1 * R_{t+1}`` is
    converted to a rotation vector and divided by ``time_delta``. The result is
    stored at row t. The last row, and any row whose pair contains an unusable
    quaternion, stays zero. A quaternion is unusable if it is all ~0 or
    contains any NaN/inf.

    Args:
        rot_data: DataFrame with ``rot_x/rot_y/rot_z/rot_w`` columns, or an
            array-like of shape (N, 4) in (x, y, z, w) order.
        time_delta: time between consecutive samples (default 1/200).

    Returns:
        ndarray of shape (N, 3) with angular velocity (rotation-vector units
        per ``time_delta`` unit).
    """
    if isinstance(rot_data, pd.DataFrame):
        rot_data = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
    quat_values = np.asarray(rot_data, dtype=float)

    num_samples = quat_values.shape[0]
    angular_vel = np.zeros((num_samples, 3))
    if num_samples < 2:
        return angular_vel

    usable = (np.isfinite(quat_values).all(axis=1)
              & ~np.isclose(quat_values, 0).all(axis=1))
    # Row t is computed only when both endpoints of the pair are usable.
    pair_rows = np.flatnonzero(usable[:-1] & usable[1:])
    if pair_rows.size:
        rot_t = R.from_quat(quat_values[pair_rows])
        rot_next = R.from_quat(quat_values[pair_rows + 1])
        delta_rot = rot_t.inv() * rot_next
        angular_vel[pair_rows] = delta_rot.as_rotvec() / time_delta

    return angular_vel

def calculate_angular_distance(rot_data):
    """Angle (radians) of the relative rotation between successive quaternions.

    Row t holds ``|rotvec(R_t^-1 * R_{t+1})|``. The last row, and rows whose
    pair contains an unusable quaternion, are 0. A quaternion is unusable if it
    is all ~0 or contains any NaN/inf.

    Args:
        rot_data: DataFrame with ``rot_x/rot_y/rot_z/rot_w`` columns, or an
            array-like of shape (N, 4) in (x, y, z, w) order.

    Returns:
        ndarray of shape (N,) with the angular distance per step.
    """
    if isinstance(rot_data, pd.DataFrame):
        rot_data = rot_data[['rot_x', 'rot_y', 'rot_z', 'rot_w']].values
    quat_values = np.asarray(rot_data, dtype=float)

    num_samples = quat_values.shape[0]
    angular_dist = np.zeros(num_samples)
    if num_samples < 2:
        return angular_dist

    usable = (np.isfinite(quat_values).all(axis=1)
              & ~np.isclose(quat_values, 0).all(axis=1))
    pair_rows = np.flatnonzero(usable[:-1] & usable[1:])
    if pair_rows.size:
        r1 = R.from_quat(quat_values[pair_rows])
        r2 = R.from_quat(quat_values[pair_rows + 1])
        relative_rotation = r1.inv() * r2
        angular_dist[pair_rows] = np.linalg.norm(relative_rotation.as_rotvec(), axis=1)

    return angular_dist

def _imu_quaternion_features(group, rotation_cols, feature_cols):
    """Gravity-removed acceleration, angular velocity and angular distance for one sequence."""
    acc = group[['acc_x', 'acc_y', 'acc_z']].to_numpy()
    quat = group[rotation_cols].to_numpy()
    values = np.column_stack([
        remove_gravity_from_acc(acc, quat),
        calculate_angular_velocity_from_quat(quat),
        calculate_angular_distance(quat),
    ])
    return pd.DataFrame(values, columns=feature_cols, index=group.index)


def apply_imu_feature_engineering(df):
    """Add engineered IMU features to a per-timestep sensor DataFrame.

    Expects columns ``sequence_id``, ``acc_x/y/z`` and ``rot_x/y/z/w``. When
    ``sequence_counter`` is present, rows are first ordered by
    (``sequence_id``, ``sequence_counter``) so that per-sequence differences
    and quaternion-pair features are computed in time order. Missing rotation
    values are replaced by 0. The quaternion helpers treat an all-zero
    quaternion as "no orientation", but ``rot_angle`` becomes pi for such rows.

    Added columns:
        acc_mag, rot_angle                  magnitude / rotation angle
        acc_mag_jerk, rot_angle_vel         first differences within a sequence
        linear_acc_x/y/z                    gravity-removed acceleration
        linear_acc_mag, linear_acc_mag_jerk its magnitude and first difference
        angular_vel_x/y/z                   from successive quaternions
        angular_distance                    angle between successive quaternions

    Args:
        df: input DataFrame (not modified).

    Returns:
        ``(features_df, imu_cols)``, where ``imu_cols`` lists the 20 IMU input
        columns in a fixed order. Row order of ``features_df`` may differ from
        ``df`` because of the time sort; the original index is preserved.
    """
    df = df.copy()

    # diff()-based features are only meaningful when rows are time-ordered.
    if 'sequence_counter' in df.columns:
        df = df.sort_values(['sequence_id', 'sequence_counter'], kind='stable')

    rotation_cols = ['rot_x', 'rot_y', 'rot_z', 'rot_w']
    df[rotation_cols] = df[rotation_cols].fillna(0)

    # Base features (rot_w is the scalar part of the quaternion).
    df['acc_mag'] = np.sqrt(df['acc_x']**2 + df['acc_y']**2 + df['acc_z']**2)
    df['rot_angle'] = 2 * np.arccos(df['rot_w'].clip(-1, 1))

    # Per-sequence first differences; the first row of each sequence is 0.
    df['acc_mag_jerk'] = df.groupby('sequence_id')['acc_mag'].diff().fillna(0)
    df['rot_angle_vel'] = df.groupby('sequence_id')['rot_angle'].diff().fillna(0)

    # Quaternion-derived features, computed in a single pass over sequences.
    quat_feature_cols = [
        'linear_acc_x', 'linear_acc_y', 'linear_acc_z',
        'angular_vel_x', 'angular_vel_y', 'angular_vel_z',
        'angular_distance',
    ]
    frames = [
        _imu_quaternion_features(group, rotation_cols, quat_feature_cols)
        for _, group in df.groupby('sequence_id', sort=False)
    ]
    if frames:
        quat_features = pd.concat(frames)
    else:
        quat_features = pd.DataFrame(columns=quat_feature_cols, index=df.index, dtype=float)
    df = pd.concat([df, quat_features], axis=1)

    # Linear acceleration magnitude and its per-sequence difference.
    df['linear_acc_mag'] = np.sqrt(df['linear_acc_x']**2 + df['linear_acc_y']**2 + df['linear_acc_z']**2)
    df['linear_acc_mag_jerk'] = df.groupby('sequence_id')['linear_acc_mag'].diff().fillna(0)

    imu_cols = ['acc_x', 'acc_y', 'acc_z', 'rot_x', 'rot_y', 'rot_z', 'rot_w'] + [
        # Engineered features
        'acc_mag', 'rot_angle',
        'acc_mag_jerk', 'rot_angle_vel',
        'linear_acc_x', 'linear_acc_y', 'linear_acc_z',
        'linear_acc_mag', 'linear_acc_mag_jerk',
        'angular_vel_x', 'angular_vel_y', 'angular_vel_z',
        'angular_distance'
    ]

    return df, imu_cols

# ==============================
# Model Architecture Components
# ==============================

class ResidualBlock1D(nn.Module):
    """Two (Conv1d k=3 -> BatchNorm1d) stages with an identity shortcut.

    Input and output are both (batch, channels, time); ``padding=1`` keeps the
    time length unchanged so the shortcut can be added directly.
    """

    def __init__(self, channels):
        super().__init__()

        def same_length_conv():
            return nn.Conv1d(channels, channels, kernel_size=3, padding=1)

        # Creation order (and attribute names) kept so initialization and
        # state_dict keys match previously saved checkpoints.
        self.conv1 = same_length_conv()
        self.bn1 = nn.BatchNorm1d(channels)
        self.relu = nn.ReLU()
        self.conv2 = same_length_conv()
        self.bn2 = nn.BatchNorm1d(channels)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return self.relu(hidden + x)

class FPN1D(nn.Module):
    """Three-level 1-D feature pyramid that returns the finest level.

    The input (batch, in_channels, T) is max-pooled twice (stride 2,
    ``ceil_mode=True``). Each level gets a 1x1 lateral conv. Coarser levels are
    linearly upsampled and added top-down, with a k=3 smoothing conv after
    each merge. The output has shape (batch, out_channels, T).

    ``smooth_conv2`` is constructed but not used in ``forward``. It is kept so
    parameter initialization order and state_dict keys stay unchanged.
    """

    def __init__(self, in_channels=256, out_channels=256):
        super().__init__()
        for level in range(3):
            setattr(self, f'lateral_conv{level}',
                    nn.Conv1d(in_channels, out_channels, kernel_size=1))
        for level in range(3):
            setattr(self, f'smooth_conv{level}',
                    nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1))

    def forward(self, c0):
        """Top-down merge of the pyramid built from ``c0``."""
        def halve(t):
            return F.max_pool1d(t, kernel_size=2, stride=2, ceil_mode=True)

        def resize_like(src, ref):
            return F.interpolate(src, size=ref.size(2), mode='linear', align_corners=True)

        c1 = halve(c0)
        c2 = halve(c1)

        lat0 = self.lateral_conv0(c0)
        lat1 = self.lateral_conv1(c1)
        lat2 = self.lateral_conv2(c2)

        mid = self.smooth_conv1(lat1 + resize_like(lat2, lat1))
        return self.smooth_conv0(lat0 + resize_like(mid, lat0))

class SensorBranch(nn.Module):
    """Encode one padded multivariate sensor sequence into a single vector.

    Pipeline:
    1. 1x1 Conv1d projection to 256 channels.
    2. Three ``ResidualBlock1D`` layers, then ``FPN1D``.
    3. Three stacked bidirectional GRUs joined by residual connections; the
       first residual is a linear projection of the conv features.
    4. Take the hidden vector at each sequence's last valid timestep.

    Args:
        input_dim: number of features per timestep.
        hidden_dim: GRU hidden size; the output has ``2 * hidden_dim``
            features because the GRUs are bidirectional.

    NOTE(review): the conv/FPN stages also see the zero-padded tail of shorter
    sequences, so padding can influence the last valid steps (k=3 convs,
    pooling, and interpolation whose grid depends on the padded length).
    Confirm whether masking is needed.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # 1x1 projection to the fixed 256-channel width of the conv stack.
        self.input_proj = nn.Conv1d(input_dim, 256, kernel_size=1)

        # Residual blocks
        self.res_cnn = nn.Sequential(
            ResidualBlock1D(256),
            ResidualBlock1D(256),
            ResidualBlock1D(256),
        )

        # Feature Pyramid Network
        self.fpn = FPN1D(256, 256)

        # GRU layers
        self.gru1 = nn.GRU(256, hidden_dim, num_layers=3, batch_first=True, bidirectional=True)
        self.gru2 = nn.GRU(hidden_dim*2, hidden_dim, num_layers=2, batch_first=True, bidirectional=True)
        self.gru3 = nn.GRU(hidden_dim*2, hidden_dim, num_layers=1, batch_first=True, bidirectional=True)

        # Projects conv features to the GRU output width for the first residual.
        self.skip_proj = nn.Linear(256, hidden_dim*2)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform init for GRU and Linear weights; zero their biases.

        Conv and BatchNorm layers keep PyTorch's default initialization.
        Iterating over modules (not parameters) is what makes the
        ``nn.Linear`` branch reachable.
        """
        for module in self.modules():
            if isinstance(module, nn.GRU):
                for name, param in module.named_parameters():
                    if 'weight' in name:
                        nn.init.xavier_uniform_(param)
                    elif 'bias' in name:
                        nn.init.constant_(param, 0.0)
            elif isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    @staticmethod
    def _run_gru(gru, seq, lengths_cpu):
        """Run ``gru`` over padded ``seq`` (B, T, C) respecting per-sample lengths.

        The output is re-padded to ``seq.size(1)`` so it can always be added to
        ``seq``-shaped residuals, even when T exceeds the longest length.
        """
        packed = pack_padded_sequence(seq, lengths_cpu, batch_first=True, enforce_sorted=False)
        out, _ = gru(packed)
        padded, _ = pad_packed_sequence(out, batch_first=True, total_length=seq.size(1))
        return padded

    def forward(self, x, lengths):
        """Encode ``x`` into one feature vector per sample.

        Args:
            x: (B, T, input_dim) padded input.
            lengths: (B,) integer tensor of valid lengths; each must be >= 1.

        Returns:
            (B, 2 * hidden_dim) features taken at timestep ``lengths - 1``.
        """
        x = x.transpose(1, 2)  # (B, C, T) for Conv1d
        x = self.input_proj(x)
        x = self.res_cnn(x)
        x = self.fpn(x)
        x = x.transpose(1, 2)  # back to (B, T, C)

        lengths_cpu = lengths.cpu()
        prev = self._run_gru(self.gru1, x, lengths_cpu) + self.skip_proj(x)
        prev = self._run_gru(self.gru2, prev, lengths_cpu) + prev
        prev = self._run_gru(self.gru3, prev, lengths_cpu) + prev

        # Gather the vector at each sample's last valid timestep.
        idx = (lengths.to(prev.device) - 1).view(-1, 1, 1).expand(-1, 1, prev.size(2))
        return prev.gather(1, idx).squeeze(1)

class MultiSensorClassifier(nn.Module):
    """Gesture classifier fusing IMU, thermopile and five ToF sensor branches.

    Each modality is encoded by its own ``SensorBranch`` into a
    ``2 * hidden_dim`` vector. The 7 vectors are treated as tokens and pass
    through 3 self-attention layers (residual + LayerNorm). They are then
    mean-pooled over kept tokens, fed through an FFN + LayerNorm, and
    classified.

    Regularization during training:
    * branch masking: each token is dropped for the whole batch with
      probability ``p_branch_mask``, with at least one token always kept;
    * dropout ``p_feat_dropout`` on the stacked tokens.
    """

    def __init__(self,
                 imu_input_dim=20,  # Updated for engineered features
                 thm_input_dim=5,
                 tof_input_dim=64,
                 hidden_dim=256,
                 num_heads=8,
                 ffn_hidden=256,
                 num_classes=18,
                 p_branch_mask=0.1,
                 p_feat_dropout=0.2):
        super().__init__()
        # Sensor branches
        self.imu_branch = SensorBranch(imu_input_dim, hidden_dim)
        self.thm_branch = SensorBranch(thm_input_dim, hidden_dim)
        self.tof_branches = nn.ModuleList([
            SensorBranch(tof_input_dim, hidden_dim) for _ in range(5)
        ])

        # Feature fusion
        self.pre_norm = nn.LayerNorm(hidden_dim*2)
        self.attention_layers = nn.ModuleList([
            nn.ModuleDict({
                'attention': nn.MultiheadAttention(
                    embed_dim=hidden_dim*2,
                    num_heads=num_heads,
                    batch_first=True
                ),
                'norm': nn.LayerNorm(hidden_dim*2)
            }) for _ in range(3)
        ])
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim*2, ffn_hidden),
            nn.ReLU(),
            nn.Linear(ffn_hidden, hidden_dim*2)
        )
        self.post_norm = nn.LayerNorm(hidden_dim*2)

        # Classifier
        self.classifier = nn.Linear(hidden_dim*2, num_classes)
        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)

        # Regularization
        self.p_branch_mask = p_branch_mask
        self.feat_dropout = nn.Dropout(p_feat_dropout)

    def forward(self,
                imu_seq, imu_len,
                thm_seq, thm_len,
                tof_inputs, tof_len,
                tof_attention_masks,
                return_features=False):
        """Classify a batch.

        Args:
            imu_seq, thm_seq: padded (B, T, features) tensors with lengths
                ``imu_len`` / ``thm_len``.
            tof_inputs: list of 5 padded (B, T, 64) tensors sharing ``tof_len``.
            tof_attention_masks: list of 5 bool (B, T) tensors. Frames marked
                False are zeroed before the ToF branch sees them.
            return_features: also return the pre-classifier features.

        Returns:
            logits (B, num_classes), or ``(logits, features)`` when
            ``return_features`` is True.
        """
        imu_feat = self.imu_branch(imu_seq, imu_len)
        thm_feat = self.thm_branch(thm_seq, thm_len)

        tof_feats = []
        for i, branch in enumerate(self.tof_branches):
            # Zero invalid ToF frames (mask broadcast over the 64 pixels).
            masked_tof = tof_inputs[i].masked_fill(~tof_attention_masks[i].unsqueeze(-1), 0)
            tof_feats.append(branch(masked_tof, tof_len))

        tokens = [imu_feat, thm_feat] + tof_feats
        x = torch.stack(tokens, dim=1)  # (B, num_tokens, 2*hidden_dim)

        # Branch masking: True in attention_mask means "ignore this token".
        attention_mask = None
        if self.training and self.p_branch_mask > 0:
            keep = torch.bernoulli(
                torch.full((len(tokens),), 1 - self.p_branch_mask, device=x.device)
            ).bool()
            if not keep.any():
                # Dropping every branch leaves attention without valid keys
                # (NaN) and makes the masked mean divide by zero.
                keep[torch.randint(len(tokens), (1,), device=x.device)] = True
            attention_mask = (~keep).expand(x.size(0), -1)

        # Feature fusion
        x = self.pre_norm(x)
        x = self.feat_dropout(x)

        for attn_layer in self.attention_layers:
            attn_out, _ = attn_layer['attention'](
                x, x, x,
                key_padding_mask=attention_mask
            )
            x = x + attn_out
            x = attn_layer['norm'](x)

        # Mean over the tokens that were not masked out.
        if attention_mask is not None:
            mask_expanded = (~attention_mask).float().unsqueeze(-1)
            fused = (x * mask_expanded).sum(dim=1) / mask_expanded.sum(dim=1)
        else:
            fused = x.mean(dim=1)

        out = self.ffn(fused)
        out = self.post_norm(out)
        logits = self.classifier(out)

        if return_features:
            return logits, out
        return logits

# ==============================
# Data Loading & Augmentation
# ==============================

class MultiSensorDataset(Dataset):
    """Per-sequence multi-sensor dataset with optional same-gesture mixup.

    Each sequence is grouped by ``sequence_id`` and ordered by
    ``sequence_counter``. It is stored as:

    * ``imu``: (T, len(imu_cols)) float32, NaN -> 0
    * ``thm``: (T, 5) float32 from ``thm_1..thm_5``, NaN -> 0
    * ``tofs``: list of 5 (T, 64) float32 arrays; -1 readings -> 512, NaN -> 1000
    * ``tof_masks``: list of 5 bool (T,) arrays; False where a frame is all
      NaN or all -1

    With ``augment=True`` the length is ``3 * n``. Indices below ``n`` return
    the original samples. Higher indices return a mixup of sample ``idx % n``
    with a different sequence of the same gesture; the weight is drawn from
    Beta(0.4, 0.4) and the label is unchanged.
    """

    def __init__(self, df, gesture2idx, augment=False, imu_cols=None):
        self.samples = []
        self.gesture2idx = gesture2idx
        self.augment = augment
        self.imu_cols = imu_cols or ['acc_x','acc_y','acc_z','rot_w','rot_x','rot_y','rot_z']

        thm_columns = [f'thm_{k}' for k in range(1, 6)]
        tof_columns = [[f'tof_{s}_v{p}' for p in range(64)] for s in range(1, 6)]

        by_sequence = df.groupby('sequence_id')
        self.sequence_ids = list(by_sequence.groups.keys())

        for seq_id, rows in by_sequence:
            rows = rows.sort_values('sequence_counter')
            encoded_tof = [self._encode_tof(rows, cols) for cols in tof_columns]
            gesture = rows['gesture'].iloc[0]
            self.samples.append({
                'imu': np.nan_to_num(rows[self.imu_cols].values.astype(np.float32)),
                'thm': np.nan_to_num(rows[thm_columns].values.astype(np.float32)),
                'tofs': [values for values, _ in encoded_tof],
                'tof_masks': [valid for _, valid in encoded_tof],
                'label': self.gesture2idx[gesture],
                'gesture': gesture,
                'sequence_id': seq_id,
                'subject': rows['subject'].iloc[0],
            })

        # gesture -> list of samples, used to pick mixup partners.
        if augment:
            self.gesture_map = {}
            for sample in self.samples:
                self.gesture_map.setdefault(sample['gesture'], []).append(sample)

    @staticmethod
    def _encode_tof(rows, columns):
        """Return (filled float32 values, per-frame validity mask) for one ToF sensor."""
        values = rows[columns].values.astype(np.float32)
        # Validity is judged on the raw readings, before any sentinel replacement.
        valid = ~(np.isnan(values).all(axis=1) | (values == -1).all(axis=1))
        values[values == -1] = 512
        return np.nan_to_num(values, nan=1000), valid

    @staticmethod
    def _as_tuple(sample):
        """Return the (imu, thm, tofs, tof_masks, label) tuple consumed by collate_fn."""
        return (
            sample['imu'],
            sample['thm'],
            sample['tofs'],
            sample['tof_masks'],
            sample['label']
        )

    def __len__(self):
        copies = 3 if self.augment else 1
        return copies * len(self.samples)

    def __getitem__(self, idx):
        sample_type, base_idx = divmod(idx, len(self.samples))
        base = self.samples[base_idx]

        if sample_type == 0:
            return self._as_tuple(base)

        pool = self.gesture_map.get(base['gesture'], [])
        if len(pool) < 2:
            # No other sequence of this gesture: fall back to the original.
            return self._as_tuple(base)

        # Draw until the partner is a different sequence.
        partner = base
        while partner['sequence_id'] == base['sequence_id']:
            partner = random.choice(pool)

        lam = np.random.beta(0.4, 0.4)

        def blend(a, b):
            # Truncate both sequences to the shorter length, then mix linearly.
            span = min(a.shape[0], b.shape[0])
            return lam * a[:span] + (1 - lam) * b[:span]

        mixed_imu = blend(base['imu'], partner['imu'])
        mixed_thm = blend(base['thm'], partner['thm'])
        mixed_tofs = [blend(base['tofs'][k], partner['tofs'][k]) for k in range(5)]

        # A mixed frame is valid if it was valid in either source.
        mixed_masks = []
        for k in range(5):
            span = min(base['tof_masks'][k].shape[0], partner['tof_masks'][k].shape[0])
            mixed_masks.append(base['tof_masks'][k][:span] | partner['tof_masks'][k][:span])

        return (
            mixed_imu,
            mixed_thm,
            mixed_tofs,
            mixed_masks,
            base['label']  # Same label
        )

# ==============================
# Collate Function (No Normalization)
# ==============================

def collate_fn(batch):
    """Pad a list of dataset samples into batched tensors (no normalization).

    Args:
        batch: non-empty list of ``(imu, thm, tofs, tof_masks, label)`` tuples,
            where imu/thm are (T, features) arrays, and tofs/tof_masks are
            equal-length lists of (T, 64) arrays and (T,) boolean arrays.

    Returns:
        ``(imu_padded, imu_len, thm_padded, thm_len, tof_padded, tof_len,
        tof_masks_padded, labels)``:
        * padded float32 tensors are (B, T_max, features) with zero padding;
        * ``*_len`` are int64 (B,) original lengths (ToF uses the first sensor);
        * ``tof_padded`` / ``tof_masks_padded`` are per-sensor lists. Each mask
          is bool (B, T) padded with False to exactly the length of the
          matching ToF tensor, so the two always align.
        * ``labels`` is int64 (B,).
    """
    imu_seqs, thm_seqs, tof_seqs_list, tof_masks_list, labels = zip(*batch)

    def pad_float(seqs):
        # as_tensor avoids an extra copy; pad_sequence writes into a new tensor.
        return nn.utils.rnn.pad_sequence(
            [torch.as_tensor(s, dtype=torch.float32) for s in seqs],
            batch_first=True
        )

    imu_len = torch.tensor([s.shape[0] for s in imu_seqs], dtype=torch.long)
    thm_len = torch.tensor([s.shape[0] for s in thm_seqs], dtype=torch.long)
    tof_len = torch.tensor([tofs[0].shape[0] for tofs in tof_seqs_list], dtype=torch.long)

    imu_padded = pad_float(imu_seqs)
    thm_padded = pad_float(thm_seqs)

    num_sensors = len(tof_seqs_list[0])
    tof_padded = [pad_float([tofs[i] for tofs in tof_seqs_list]) for i in range(num_sensors)]

    tof_masks_padded = []
    for i, sensor_data in enumerate(tof_padded):
        target_len = sensor_data.size(1)
        sensor_masks = []
        for masks in tof_masks_list:
            m = np.asarray(masks[i], dtype=bool)
            padded_mask = np.pad(m, (0, target_len - m.shape[0]), constant_values=False)
            sensor_masks.append(torch.from_numpy(padded_mask))
        tof_masks_padded.append(torch.stack(sensor_masks, dim=0))

    labels = torch.tensor(labels, dtype=torch.long)

    return (
        imu_padded,
        imu_len,
        thm_padded,
        thm_len,
        tof_padded,
        tof_len,
        tof_masks_padded,
        labels
    )

# ==============================
# Training & Evaluation
# ==============================

def train_one_epoch(model, loader, optimizer, device):
    """Run one training epoch and return the sample-weighted mean loss.

    Each batch is the 8-tuple produced by ``collate_fn``. Gradients are
    clipped to a global L2 norm of 1.0 before every optimizer step. The
    denominator is ``len(loader.dataset)``, i.e. the dataset size including
    any augmented copies.
    """
    model.train()
    loss_fn = nn.CrossEntropyLoss()
    running_loss = 0.0

    for imu, imu_len, thm, thm_len, tof_inputs, tof_len, tof_masks, labels in loader:
        imu, thm, labels = imu.to(device), thm.to(device), labels.to(device)
        tof_inputs = [t.to(device) for t in tof_inputs]
        tof_masks = [m.to(device) for m in tof_masks]
        imu_len, thm_len, tof_len = (t.to(device) for t in (imu_len, thm_len, tof_len))

        optimizer.zero_grad()
        batch_loss = loss_fn(
            model(imu, imu_len, thm, thm_len, tof_inputs, tof_len, tof_masks),
            labels,
        )
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += batch_loss.item() * imu.size(0)

    return running_loss / len(loader.dataset)

def evaluate(model, loader, device, non_target_idxs):
    """Evaluate ``model`` on ``loader`` and return ``(f1_binary, f1_macro, score)``.

    * ``f1_binary``: F1 of "target vs non-target", where a class counts as
      non-target when its index is in ``non_target_idxs``.
    * ``f1_macro``: sklearn macro F1 over the raw class indices.
    * ``score``: the mean of the two.

    NOTE(review): if this is meant to mirror the CMI competition metric,
    confirm whether non-target gestures should first be collapsed into a
    single class before the macro F1.
    """
    model.eval()
    pred_batches, true_batches = [], []

    with torch.no_grad():
        for imu, imu_len, thm, thm_len, tof_inputs, tof_len, tof_masks, labels in loader:
            logits = model(
                imu.to(device), imu_len.to(device),
                thm.to(device), thm_len.to(device),
                [t.to(device) for t in tof_inputs], tof_len.to(device),
                [m.to(device) for m in tof_masks],
            )
            pred_batches.append(logits.argmax(1).cpu().numpy())
            true_batches.append(labels.numpy())

    y_pred = np.concatenate(pred_batches)
    y_true = np.concatenate(true_batches)

    is_target_true = ~np.isin(y_true, non_target_idxs)
    is_target_pred = ~np.isin(y_pred, non_target_idxs)
    f1_binary = f1_score(is_target_true, is_target_pred)
    f1_macro = f1_score(y_true, y_pred, average='macro')

    return f1_binary, f1_macro, (f1_binary + f1_macro) / 2

# ==============================
# Main Training Pipeline
# ==============================

def _seed_everything(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU and CUDA)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def main(seed=42):
    """Train one model per StratifiedGroupKFold fold and save each fold's best checkpoint.

    Steps:
    1. Read ``train.csv`` from Drive and add engineered IMU features.
    2. Split sequences into 5 folds stratified by gesture and grouped by subject.
    3. Train for ``num_epochs`` per fold; whenever the validation score
       improves, save ``best_model_fold{fold}.pth`` to the same directory.

    Args:
        seed: seeds Python/NumPy/PyTorch RNGs (model init, shuffling, mixup)
            and the fold splitter. The default 42 matches the previous
            splitter ``random_state``.
    """
    _seed_everything(seed)

    # Configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data_dir = '/content/drive/MyDrive/cmi-detect-behavior-with-sensor-data'
    num_epochs = 50
    batch_size = 64
    val_batch_size = 128
    hidden_dim = 256
    lr = 1e-4

    # Load data
    df = pd.read_csv(f'{data_dir}/train.csv')

    # Class ids are assigned in alphabetical gesture order.
    gestures = sorted(df['gesture'].unique())
    gesture2idx = {g: i for i, g in enumerate(gestures)}

    # Define non-target gestures
    non_target_list = [
        'Drink from bottle/cup','Glasses on/off','Pull air toward your face',
        'Pinch knee/leg skin','Scratch knee/leg skin','Write name on leg',
        'Text on phone','Feel around in tray and pull out an object',
        'Write name in air','Wave hello'
    ]
    non_target_idxs = [gesture2idx[g] for g in non_target_list if g in gesture2idx]

    # Apply IMU feature engineering to the entire dataset
    print("Applying IMU feature engineering...")
    df, imu_cols = apply_imu_feature_engineering(df)
    print(f"IMU features expanded to {len(imu_cols)} dimensions")

    # One row per sequence, used for stratified grouping by subject.
    seq_info = df.groupby('sequence_id').agg(
        gesture=('gesture', 'first'),
        subject=('subject', 'first')
    ).reset_index()

    # Cross-validation setup
    splitter = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=seed)

    for fold, (tr_idx, vl_idx) in enumerate(
        splitter.split(np.zeros(len(seq_info)), seq_info.gesture, groups=seq_info.subject),
        start=1
    ):
        print(f'\n=== Fold {fold}/5 ===')

        # Split data
        train_ids = seq_info.sequence_id.iloc[tr_idx]
        val_ids = seq_info.sequence_id.iloc[vl_idx]
        train_df = df[df.sequence_id.isin(train_ids)].copy()
        val_df = df[df.sequence_id.isin(val_ids)].copy()

        # Build datasets
        train_ds = MultiSensorDataset(train_df, gesture2idx, augment=True, imu_cols=imu_cols)
        val_ds = MultiSensorDataset(val_df, gesture2idx, augment=False, imu_cols=imu_cols)

        print(f"Training samples: {len(train_ds)}")
        print(f"Validation samples: {len(val_ds)}")

        # Create data loaders
        train_loader = DataLoader(
            train_ds, batch_size=batch_size, shuffle=True,
            collate_fn=collate_fn, num_workers=4, pin_memory=True
        )
        val_loader = DataLoader(
            val_ds, batch_size=val_batch_size, shuffle=False,
            collate_fn=collate_fn, num_workers=4, pin_memory=True
        )

        # Initialize model
        model = MultiSensorClassifier(
            imu_input_dim=len(imu_cols),  # Updated for engineered features
            thm_input_dim=5,
            tof_input_dim=64,
            hidden_dim=hidden_dim,
            num_classes=len(gesture2idx)
        ).to(device)

        # Optimizer
        optimizer = optim.Adam(model.parameters(), lr=lr)

        # Start at -inf so the first epoch always produces a checkpoint,
        # even if its score is 0.
        best_score = float('-inf')
        for epoch in range(1, num_epochs + 1):
            loss = train_one_epoch(model, train_loader, optimizer, device)
            f1b, f1m, score = evaluate(model, val_loader, device, non_target_idxs)

            print(f'E{epoch}: Loss={loss:.4f} | F1-bin={f1b:.4f} | F1-mac={f1m:.4f} | S={score:.4f}')

            if score > best_score:
                best_score = score
                torch.save(model.state_dict(), f'{data_dir}/best_model_fold{fold}.pth')
                print(f'Saved new best model with score: {score:.4f}')

        print(f'Best score fold {fold}: {best_score:.4f}')

# Inside a Jupyter/Colab kernel __name__ is "__main__", so running this cell
# starts the full cross-validation training.
if __name__ == "__main__":
    main()

Applying IMU feature engineering...
IMU features expanded to 20 dimensions

=== Fold 1/5 ===
Training samples: 19869
Validation samples: 1528
E1: Loss=1.7340 | F1-bin=0.9692 | F1-mac=0.5894 | S=0.7793
Saved new best model with score: 0.7793
E2: Loss=1.1764 | F1-bin=0.9774 | F1-mac=0.6682 | S=0.8228
Saved new best model with score: 0.8228
E3: Loss=0.9584 | F1-bin=0.9881 | F1-mac=0.7044 | S=0.8462
Saved new best model with score: 0.8462
E4: Loss=0.8249 | F1-bin=0.9830 | F1-mac=0.7106 | S=0.8468
Saved new best model with score: 0.8468
E5: Loss=0.7318 | F1-bin=0.9922 | F1-mac=0.7273 | S=0.8597
Saved new best model with score: 0.8597
E6: Loss=0.6462 | F1-bin=0.9870 | F1-mac=0.7198 | S=0.8534
E7: Loss=0.6148 | F1-bin=0.9932 | F1-mac=0.7146 | S=0.8539
E8: Loss=0.4910 | F1-bin=0.9922 | F1-mac=0.7484 | S=0.8703
Saved new best model with score: 0.8703
E9: Loss=0.4778 | F1-bin=0.9860 | F1-mac=0.7209 | S=0.8534
E10: Loss=0.4290 | F1-bin=0.9927 | F1-mac=0.7379 | S=0.8653
E11: Loss=0.4351 | F1-bin=0