# **Model Training and Feature Engineering**

In [1]:
#### TEMPORARY
import sys
sys.path.append('/home/bac/code/kaggle/kaggle-cmi-detect-behavior/')

In [2]:
import os
import warnings
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import lightgbm as lgb
import xgboost as xgb
import catboost as cat
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import f1_score, balanced_accuracy_score

warnings.filterwarnings('ignore')
pd.set_option('display.max_columns', None)
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 8)

In [3]:
from src.config import PROJECT_PATH, DATA_PATH, USE_WANDB, WANDB_PROJECT, WANDB_ENTITY
from src.tracking import ExperimentTracker

In [4]:
def display_all(df):
    with pd.option_context("display.max_rows", 100, "display.max_columns", 1000):
        display(df)

In [5]:
# Initialize the experiment tracker
tracker = ExperimentTracker(
    project_path=PROJECT_PATH,
    use_wandb=USE_WANDB,
    wandb_project_name=WANDB_PROJECT,
    wandb_entity=WANDB_ENTITY
)

In [6]:
# Load data
train_sensor = pd.read_csv(os.path.join(os.path.expanduser(DATA_PATH), 'train.csv'))
train_demos = pd.read_csv(os.path.join(os.path.expanduser(DATA_PATH), 'train_demographics.csv'))

# Merge
train_df = pd.merge(train_sensor, train_demos, on='subject', how='left')
display_all(train_df.head())

Unnamed: 0,row_id,sequence_type,sequence_id,sequence_counter,subject,orientation,behavior,phase,gesture,acc_x,acc_y,acc_z,rot_w,rot_x,rot_y,rot_z,thm_1,thm_2,thm_3,thm_4,thm_5,tof_1_v0,tof_1_v1,tof_1_v2,tof_1_v3,tof_1_v4,tof_1_v5,tof_1_v6,tof_1_v7,tof_1_v8,tof_1_v9,tof_1_v10,tof_1_v11,tof_1_v12,tof_1_v13,tof_1_v14,tof_1_v15,tof_1_v16,tof_1_v17,tof_1_v18,tof_1_v19,tof_1_v20,tof_1_v21,tof_1_v22,tof_1_v23,tof_1_v24,tof_1_v25,tof_1_v26,tof_1_v27,tof_1_v28,tof_1_v29,tof_1_v30,tof_1_v31,tof_1_v32,tof_1_v33,tof_1_v34,tof_1_v35,tof_1_v36,tof_1_v37,tof_1_v38,tof_1_v39,tof_1_v40,tof_1_v41,tof_1_v42,tof_1_v43,tof_1_v44,tof_1_v45,tof_1_v46,tof_1_v47,tof_1_v48,tof_1_v49,tof_1_v50,tof_1_v51,tof_1_v52,tof_1_v53,tof_1_v54,tof_1_v55,tof_1_v56,tof_1_v57,tof_1_v58,tof_1_v59,tof_1_v60,tof_1_v61,tof_1_v62,tof_1_v63,tof_2_v0,tof_2_v1,tof_2_v2,tof_2_v3,tof_2_v4,tof_2_v5,tof_2_v6,tof_2_v7,tof_2_v8,tof_2_v9,tof_2_v10,tof_2_v11,tof_2_v12,tof_2_v13,tof_2_v14,tof_2_v15,tof_2_v16,tof_2_v17,tof_2_v18,tof_2_v19,tof_2_v20,tof_2_v21,tof_2_v22,tof_2_v23,tof_2_v24,tof_2_v25,tof_2_v26,tof_2_v27,tof_2_v28,tof_2_v29,tof_2_v30,tof_2_v31,tof_2_v32,tof_2_v33,tof_2_v34,tof_2_v35,tof_2_v36,tof_2_v37,tof_2_v38,tof_2_v39,tof_2_v40,tof_2_v41,tof_2_v42,tof_2_v43,tof_2_v44,tof_2_v45,tof_2_v46,tof_2_v47,tof_2_v48,tof_2_v49,tof_2_v50,tof_2_v51,tof_2_v52,tof_2_v53,tof_2_v54,tof_2_v55,tof_2_v56,tof_2_v57,tof_2_v58,tof_2_v59,tof_2_v60,tof_2_v61,tof_2_v62,tof_2_v63,tof_3_v0,tof_3_v1,tof_3_v2,tof_3_v3,tof_3_v4,tof_3_v5,tof_3_v6,tof_3_v7,tof_3_v8,tof_3_v9,tof_3_v10,tof_3_v11,tof_3_v12,tof_3_v13,tof_3_v14,tof_3_v15,tof_3_v16,tof_3_v17,tof_3_v18,tof_3_v19,tof_3_v20,tof_3_v21,tof_3_v22,tof_3_v23,tof_3_v24,tof_3_v25,tof_3_v26,tof_3_v27,tof_3_v28,tof_3_v29,tof_3_v30,tof_3_v31,tof_3_v32,tof_3_v33,tof_3_v34,tof_3_v35,tof_3_v36,tof_3_v37,tof_3_v38,tof_3_v39,tof_3_v40,tof_3_v41,tof_3_v42,tof_3_v43,tof_3_v44,tof_3_v45,tof_3_v46,tof_3_v47,tof_3_v48,tof_3_v49,tof_3_v50,tof_3_v51,tof_3_v52,tof_3_v53,tof_3_v54,tof_3_v55,tof_3_v56,tof_3_v57,tof_3_v58,tof_3_v59,tof_3_v60,tof_3_v61,tof_3_v62,tof_3_v63,tof_4_v0,tof_4_v1,tof_4_v2,tof_4_v3,tof_4_v4,tof_4_v5,tof_4_v6,tof_4_v7,tof_4_v8,tof_4_v9,tof_4_v10,tof_4_v11,tof_4_v12,tof_4_v13,tof_4_v14,tof_4_v15,tof_4_v16,tof_4_v17,tof_4_v18,tof_4_v19,tof_4_v20,tof_4_v21,tof_4_v22,tof_4_v23,tof_4_v24,tof_4_v25,tof_4_v26,tof_4_v27,tof_4_v28,tof_4_v29,tof_4_v30,tof_4_v31,tof_4_v32,tof_4_v33,tof_4_v34,tof_4_v35,tof_4_v36,tof_4_v37,tof_4_v38,tof_4_v39,tof_4_v40,tof_4_v41,tof_4_v42,tof_4_v43,tof_4_v44,tof_4_v45,tof_4_v46,tof_4_v47,tof_4_v48,tof_4_v49,tof_4_v50,tof_4_v51,tof_4_v52,tof_4_v53,tof_4_v54,tof_4_v55,tof_4_v56,tof_4_v57,tof_4_v58,tof_4_v59,tof_4_v60,tof_4_v61,tof_4_v62,tof_4_v63,tof_5_v0,tof_5_v1,tof_5_v2,tof_5_v3,tof_5_v4,tof_5_v5,tof_5_v6,tof_5_v7,tof_5_v8,tof_5_v9,tof_5_v10,tof_5_v11,tof_5_v12,tof_5_v13,tof_5_v14,tof_5_v15,tof_5_v16,tof_5_v17,tof_5_v18,tof_5_v19,tof_5_v20,tof_5_v21,tof_5_v22,tof_5_v23,tof_5_v24,tof_5_v25,tof_5_v26,tof_5_v27,tof_5_v28,tof_5_v29,tof_5_v30,tof_5_v31,tof_5_v32,tof_5_v33,tof_5_v34,tof_5_v35,tof_5_v36,tof_5_v37,tof_5_v38,tof_5_v39,tof_5_v40,tof_5_v41,tof_5_v42,tof_5_v43,tof_5_v44,tof_5_v45,tof_5_v46,tof_5_v47,tof_5_v48,tof_5_v49,tof_5_v50,tof_5_v51,tof_5_v52,tof_5_v53,tof_5_v54,tof_5_v55,tof_5_v56,tof_5_v57,tof_5_v58,tof_5_v59,tof_5_v60,tof_5_v61,tof_5_v62,tof_5_v63,adult_child,age,sex,handedness,height_cm,shoulder_to_wrist_cm,elbow_to_wrist_cm
0,SEQ_000007_000000,Target,SEQ_000007,0,SUBJ_059520,Seated Lean Non Dom - FACE DOWN,Relaxes and moves hand to target location,Transition,Cheek - pinch skin,6.683594,6.214844,3.355469,0.134399,-0.355164,-0.447327,-0.809753,28.943842,31.822186,29.553024,28.592863,28.310535,131.0,134.0,132.0,135.0,98.0,74.0,64.0,60.0,-1.0,-1.0,152.0,153.0,141.0,89.0,68.0,63.0,-1.0,-1.0,-1.0,-1.0,169.0,118.0,86.0,73.0,-1.0,-1.0,-1.0,-1.0,-1.0,147.0,110.0,87.0,126.0,-1.0,-1.0,-1.0,-1.0,-1.0,137.0,108.0,115.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,128.0,110.0,129.0,140.0,-1.0,126.0,131.0,-1.0,-1.0,-1.0,108.0,122.0,139.0,113.0,121.0,-1.0,118.0,96.0,-1.0,-1.0,-1.0,-1.0,165.0,124.0,100.0,102.0,119.0,-1.0,-1.0,115.0,130.0,-1.0,124.0,107.0,117.0,132.0,136.0,116.0,120.0,-1.0,141.0,118.0,115.0,122.0,145.0,128.0,130.0,137.0,131.0,-1.0,116.0,117.0,130.0,115.0,116.0,117.0,108.0,-1.0,-1.0,119.0,118.0,110.0,93.0,90.0,90.0,-1.0,-1.0,-1.0,116.0,103.0,87.0,82.0,81.0,-1.0,-1.0,-1.0,115.0,91.0,84.0,80.0,85.0,58.0,55.0,59.0,59.0,63.0,96.0,93.0,-1.0,57.0,59.0,58.0,64.0,72.0,103.0,98.0,-1.0,55.0,57.0,62.0,63.0,88.0,103.0,105.0,108.0,56.0,59.0,58.0,77.0,94.0,106.0,-1.0,113.0,57.0,58.0,66.0,78.0,93.0,-1.0,-1.0,-1.0,59.0,67.0,69.0,82.0,104.0,-1.0,-1.0,-1.0,63.0,70.0,79.0,96.0,-1.0,-1.0,-1.0,-1.0,79.0,83.0,-1.0,-1.0,-1.0,-1.0,102.0,100.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,74.0,130.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,69.0,134.0,137.0,136.0,145.0,131.0,126.0,83.0,60.0,-1.0,138.0,135.0,148.0,121.0,109.0,69.0,51.0,-1.0,143.0,139.0,148.0,113.0,91.0,67.0,52.0,-1.0,-1.0,-1.0,-1.0,101.0,81.0,62.0,54.0,-1.0,-1.0,-1.0,-1.0,124.0,78.0,68.0,55.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,66.0,60.0,128.0,130.0,147.0,165.0,-1.0,-1.0,-1.0,122.0,121.0,140.0,164.0,-1.0,-1.0,-1.0,140.0,119.0,135.0,156.0,166.0,-1.0,-1.0,155.0,137.0,112.0,148.0,163.0,164.0,153.0,133.0,131.0,121.0,118.0,134.0,134.0,128.0,121.0,119.0,121.0,129.0,-1.0,113.0,124.0,122.0,131.0,-1.0,-1.0,-1.0,-1.0,120.0,127.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0,12,1,1,163.0,52,24.0
1,SEQ_000007_000001,Target,SEQ_000007,1,SUBJ_059520,Seated Lean Non Dom - FACE DOWN,Relaxes and moves hand to target location,Transition,Cheek - pinch skin,6.949219,6.214844,3.125,0.143494,-0.340271,-0.42865,-0.824524,29.340816,31.874645,29.79174,28.663383,28.406172,130.0,138.0,131.0,135.0,101.0,76.0,66.0,61.0,-1.0,-1.0,156.0,155.0,141.0,93.0,74.0,64.0,-1.0,-1.0,-1.0,-1.0,165.0,116.0,86.0,75.0,130.0,-1.0,-1.0,-1.0,-1.0,142.0,114.0,91.0,127.0,-1.0,-1.0,-1.0,-1.0,-1.0,145.0,114.0,114.0,-1.0,-1.0,-1.0,135.0,-1.0,-1.0,132.0,110.0,121.0,138.0,142.0,123.0,131.0,-1.0,-1.0,-1.0,106.0,120.0,139.0,119.0,124.0,131.0,117.0,109.0,-1.0,-1.0,-1.0,-1.0,165.0,134.0,108.0,106.0,123.0,-1.0,-1.0,121.0,147.0,-1.0,131.0,114.0,114.0,138.0,145.0,121.0,141.0,144.0,138.0,-1.0,120.0,124.0,147.0,115.0,141.0,135.0,125.0,-1.0,-1.0,122.0,122.0,117.0,103.0,108.0,108.0,-1.0,-1.0,-1.0,129.0,108.0,100.0,92.0,93.0,-1.0,-1.0,-1.0,116.0,99.0,93.0,90.0,91.0,-1.0,-1.0,-1.0,113.0,101.0,94.0,88.0,95.0,75.0,67.0,68.0,71.0,74.0,102.0,99.0,-1.0,64.0,68.0,67.0,72.0,88.0,112.0,103.0,-1.0,65.0,68.0,69.0,75.0,105.0,111.0,109.0,-1.0,66.0,71.0,72.0,81.0,109.0,116.0,121.0,118.0,61.0,67.0,75.0,93.0,116.0,128.0,130.0,121.0,62.0,72.0,80.0,92.0,115.0,-1.0,-1.0,-1.0,67.0,73.0,82.0,98.0,-1.0,-1.0,-1.0,-1.0,77.0,82.0,110.0,-1.0,-1.0,-1.0,112.0,105.0,134.0,-1.0,-1.0,-1.0,-1.0,-1.0,91.0,82.0,132.0,145.0,148.0,157.0,143.0,-1.0,117.0,66.0,142.0,142.0,149.0,147.0,136.0,109.0,80.0,60.0,142.0,142.0,143.0,135.0,126.0,92.0,73.0,61.0,-1.0,147.0,148.0,137.0,109.0,82.0,71.0,60.0,-1.0,-1.0,-1.0,-1.0,101.0,83.0,69.0,62.0,-1.0,-1.0,-1.0,-1.0,109.0,84.0,76.0,64.0,-1.0,-1.0,-1.0,-1.0,-1.0,93.0,72.0,74.0,126.0,137.0,157.0,174.0,-1.0,-1.0,140.0,130.0,124.0,143.0,168.0,-1.0,-1.0,-1.0,142.0,122.0,138.0,157.0,-1.0,-1.0,-1.0,155.0,133.0,117.0,145.0,170.0,163.0,157.0,139.0,127.0,126.0,121.0,136.0,142.0,133.0,127.0,123.0,127.0,134.0,-1.0,116.0,122.0,123.0,126.0,-1.0,-1.0,-1.0,-1.0,122.0,129.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0,12,1,1,163.0,52,24.0
2,SEQ_000007_000002,Target,SEQ_000007,2,SUBJ_059520,Seated Lean Non Dom - FACE DOWN,Relaxes and moves hand to target location,Transition,Cheek - pinch skin,5.722656,5.410156,5.421875,0.219055,-0.274231,-0.356934,-0.865662,30.339359,30.935045,30.090014,28.796087,28.529778,137.0,136.0,147.0,109.0,90.0,81.0,74.0,74.0,-1.0,164.0,165.0,146.0,106.0,94.0,77.0,77.0,-1.0,-1.0,-1.0,180.0,140.0,118.0,103.0,92.0,-1.0,-1.0,-1.0,-1.0,-1.0,155.0,119.0,122.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,148.0,130.0,123.0,158.0,-1.0,141.0,147.0,-1.0,157.0,141.0,113.0,131.0,152.0,147.0,146.0,157.0,144.0,127.0,-1.0,115.0,127.0,129.0,119.0,112.0,117.0,120.0,119.0,-1.0,-1.0,125.0,154.0,165.0,145.0,122.0,111.0,126.0,143.0,157.0,143.0,166.0,149.0,137.0,116.0,123.0,149.0,148.0,136.0,141.0,143.0,139.0,-1.0,134.0,133.0,142.0,123.0,118.0,116.0,128.0,-1.0,-1.0,-1.0,143.0,123.0,109.0,112.0,117.0,-1.0,-1.0,-1.0,-1.0,119.0,127.0,114.0,128.0,-1.0,-1.0,-1.0,138.0,136.0,129.0,125.0,132.0,-1.0,-1.0,-1.0,-1.0,-1.0,159.0,150.0,-1.0,73.0,94.0,109.0,117.0,126.0,-1.0,-1.0,-1.0,94.0,119.0,101.0,123.0,139.0,-1.0,-1.0,-1.0,98.0,97.0,101.0,128.0,144.0,-1.0,-1.0,-1.0,88.0,107.0,101.0,154.0,141.0,-1.0,-1.0,-1.0,84.0,89.0,107.0,140.0,146.0,149.0,147.0,131.0,76.0,100.0,108.0,141.0,-1.0,158.0,143.0,117.0,77.0,89.0,105.0,133.0,-1.0,-1.0,141.0,108.0,79.0,99.0,114.0,-1.0,-1.0,-1.0,130.0,118.0,139.0,149.0,167.0,162.0,-1.0,107.0,92.0,81.0,135.0,140.0,148.0,151.0,155.0,111.0,82.0,94.0,132.0,139.0,147.0,138.0,120.0,97.0,78.0,85.0,140.0,146.0,136.0,131.0,98.0,86.0,75.0,80.0,149.0,156.0,147.0,113.0,97.0,84.0,81.0,71.0,-1.0,-1.0,174.0,117.0,96.0,89.0,80.0,78.0,-1.0,-1.0,-1.0,145.0,104.0,92.0,88.0,76.0,-1.0,-1.0,-1.0,-1.0,-1.0,117.0,98.0,105.0,92.0,110.0,157.0,180.0,-1.0,128.0,123.0,126.0,142.0,165.0,185.0,-1.0,-1.0,-1.0,145.0,139.0,138.0,164.0,-1.0,-1.0,-1.0,-1.0,145.0,120.0,151.0,165.0,-1.0,-1.0,-1.0,151.0,138.0,127.0,151.0,187.0,-1.0,156.0,136.0,135.0,134.0,-1.0,133.0,142.0,131.0,130.0,132.0,136.0,-1.0,-1.0,112.0,121.0,123.0,125.0,-1.0,-1.0,-1.0,-1.0,112.0,119.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0,12,1,1,163.0,52,24.0
3,SEQ_000007_000003,Target,SEQ_000007,3,SUBJ_059520,Seated Lean Non Dom - FACE DOWN,Relaxes and moves hand to target location,Transition,Cheek - pinch skin,6.601562,3.53125,6.457031,0.297546,-0.26416,-0.238159,-0.885986,30.54373,27.044001,29.310717,29.018711,27.40201,143.0,147.0,170.0,127.0,109.0,98.0,95.0,95.0,-1.0,177.0,189.0,177.0,136.0,121.0,107.0,104.0,-1.0,-1.0,-1.0,202.0,171.0,160.0,141.0,135.0,-1.0,-1.0,-1.0,-1.0,-1.0,197.0,168.0,150.0,131.0,-1.0,-1.0,-1.0,170.0,179.0,174.0,164.0,125.0,140.0,161.0,175.0,154.0,174.0,160.0,159.0,-1.0,126.0,143.0,167.0,149.0,137.0,130.0,131.0,-1.0,-1.0,-1.0,141.0,137.0,129.0,115.0,124.0,108.0,123.0,146.0,166.0,152.0,168.0,158.0,161.0,123.0,133.0,138.0,155.0,163.0,151.0,132.0,151.0,-1.0,216.0,-1.0,-1.0,175.0,157.0,146.0,140.0,-1.0,-1.0,-1.0,-1.0,-1.0,173.0,153.0,164.0,-1.0,-1.0,-1.0,-1.0,246.0,189.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,225.0,-1.0,-1.0,-1.0,-1.0,-1.0,243.0,-1.0,220.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,119.0,132.0,164.0,-1.0,-1.0,-1.0,-1.0,-1.0,117.0,156.0,176.0,-1.0,-1.0,-1.0,-1.0,-1.0,144.0,151.0,187.0,-1.0,-1.0,-1.0,-1.0,-1.0,126.0,162.0,184.0,-1.0,-1.0,-1.0,-1.0,-1.0,128.0,158.0,180.0,189.0,202.0,171.0,168.0,164.0,117.0,153.0,183.0,197.0,192.0,164.0,156.0,160.0,113.0,142.0,192.0,197.0,192.0,157.0,149.0,146.0,114.0,158.0,-1.0,173.0,159.0,156.0,147.0,-1.0,139.0,146.0,153.0,181.0,161.0,113.0,112.0,100.0,131.0,153.0,148.0,162.0,138.0,120.0,106.0,116.0,141.0,157.0,159.0,153.0,128.0,111.0,111.0,115.0,149.0,173.0,162.0,141.0,120.0,117.0,108.0,120.0,181.0,178.0,210.0,137.0,143.0,112.0,126.0,112.0,-1.0,-1.0,209.0,202.0,144.0,163.0,133.0,155.0,-1.0,-1.0,-1.0,-1.0,-1.0,168.0,179.0,155.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,164.0,175.0,105.0,132.0,171.0,-1.0,157.0,167.0,149.0,131.0,149.0,189.0,203.0,-1.0,-1.0,164.0,133.0,-1.0,162.0,181.0,-1.0,-1.0,-1.0,152.0,134.0,-1.0,148.0,187.0,-1.0,-1.0,149.0,142.0,135.0,-1.0,159.0,181.0,150.0,135.0,129.0,139.0,-1.0,-1.0,141.0,136.0,120.0,122.0,132.0,-1.0,-1.0,-1.0,107.0,112.0,115.0,140.0,-1.0,-1.0,-1.0,-1.0,101.0,111.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,0,12,1,1,163.0,52,24.0
4,SEQ_000007_000004,Target,SEQ_000007,4,SUBJ_059520,Seated Lean Non Dom - FACE DOWN,Relaxes and moves hand to target location,Transition,Cheek - pinch skin,5.566406,0.277344,9.632812,0.333557,-0.218628,-0.063538,-0.914856,29.317265,25.270855,26.808746,29.408604,27.357603,178.0,191.0,183.0,157.0,146.0,139.0,143.0,148.0,-1.0,-1.0,236.0,238.0,208.0,200.0,185.0,190.0,-1.0,-1.0,-1.0,210.0,246.0,225.0,228.0,202.0,149.0,206.0,219.0,219.0,225.0,218.0,214.0,-1.0,162.0,177.0,206.0,219.0,207.0,182.0,225.0,-1.0,-1.0,-1.0,-1.0,233.0,195.0,204.0,190.0,-1.0,-1.0,-1.0,-1.0,-1.0,209.0,210.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,138.0,155.0,173.0,188.0,180.0,176.0,211.0,235.0,-1.0,-1.0,-1.0,-1.0,210.0,210.0,223.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,161.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,165.0,207.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,178.0,221.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,184.0,216.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,162.0,212.0,237.0,-1.0,-1.0,-1.0,-1.0,-1.0,152.0,198.0,221.0,213.0,204.0,211.0,235.0,-1.0,143.0,205.0,213.0,189.0,191.0,194.0,198.0,-1.0,139.0,138.0,159.0,145.0,120.0,121.0,118.0,116.0,149.0,143.0,152.0,136.0,127.0,138.0,125.0,125.0,163.0,161.0,148.0,135.0,127.0,137.0,153.0,129.0,184.0,197.0,155.0,146.0,140.0,149.0,154.0,164.0,-1.0,229.0,200.0,176.0,169.0,166.0,169.0,171.0,-1.0,-1.0,-1.0,-1.0,219.0,208.0,202.0,-1.0,-1.0,-1.0,202.0,-1.0,224.0,211.0,-1.0,-1.0,146.0,179.0,-1.0,191.0,192.0,194.0,-1.0,-1.0,127.0,185.0,-1.0,199.0,187.0,186.0,-1.0,-1.0,143.0,-1.0,-1.0,216.0,205.0,-1.0,-1.0,-1.0,197.0,-1.0,-1.0,219.0,192.0,-1.0,-1.0,-1.0,204.0,-1.0,-1.0,212.0,181.0,-1.0,-1.0,-1.0,184.0,-1.0,179.0,162.0,-1.0,-1.0,-1.0,-1.0,169.0,171.0,145.0,140.0,-1.0,-1.0,-1.0,-1.0,132.0,125.0,131.0,-1.0,-1.0,-1.0,-1.0,-1.0,101.0,109.0,125.0,-1.0,-1.0,-1.0,-1.0,-1.0,0,12,1,1,163.0,52,24.0


In [7]:
train_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 574945 entries, 0 to 574944
Columns: 348 entries, row_id to elbow_to_wrist_cm
dtypes: float64(334), int64(6), object(8)
memory usage: 1.5+ GB


## Feature Engineering and Experiment Checklist

**Wave 0: Baseline Features (Sequence-Level Aggregates)**
- **Goal**: Establish a rock-solid baseline score.

- Simple statistical aggregations (mean, std, min, max, median, skew) for the core sensor groups (acc_*, rot_*, thm_*) across each entire sequence_id.
- **ToF Handling**: For now, we will simply take the mean of all 320 tof columns as a single feature.

**Wave 1: Phase-Specific Features**

- **Goal**: Capture the distinct behavior within each phase (Transition, Gesture).
- All Wave 0 statistics, but now calculated on a per-phase basis (e.g., mean_acc_x_Transition, mean_acc_x_Gesture). This will significantly expand the feature set.

**Wave 2**: Advanced IMU & Thermopile Features

- **Goal**: Model the physics of the movements.
-   **IMU**: Calculate vector magnitudes (acc_mag, rot_mag), jerk (rate of change of acceleration), and potentially FFT features on the magnitude.  
-   Thermopile: Calculate gradients between adjacent sensors (thm_1-thm_2, etc.) and their statistics per phase.

**Wave 3: Advanced ToF & Demographic Interaction Features**

- **Goal: Extract spatial information from ToF sensors and leverage demographic context.**
- **ToF:** Calculate the percentage of invalid (-1) readings per sensor group. Use PCA or an Autoencoder on the valid ToF readings to create 5-10 latent features.
- **Demographics**: Create interaction features (e.g., mean_acc_mag_Gesture * age, mean_acc_mag_Gesture * height_cm).

**Wave 4: Temporal & Lag Features**

- **Goal**: Capture time-dependent patterns.
- Create rolling window statistics (e.g., 5-step rolling mean) and lag features (e.g., difference between current acc_x and acc_x from 3 steps prior). This requires careful handling of the sequence_counter

## Feature Engineering and Training - Wave 0 - Baseline

In [8]:
def create_baseline_features(df):
    """
    Create Wave 0 features: Simple sequence level aggregations.
    """
    aggs = {
        'acc_x': ['mean', 'std', 'min', 'max', 'median', 'skew'],
        'acc_y': ['mean', 'std', 'min', 'max', 'median', 'skew'],
        'acc_z': ['mean', 'std', 'min', 'max', 'median', 'skew'],
        'rot_w': ['mean', 'std', 'min', 'max', 'median', 'skew'],
        'rot_x': ['mean', 'std', 'min', 'max', 'median', 'skew'],
        'rot_y': ['mean', 'std', 'min', 'max', 'median', 'skew'],
        'rot_z': ['mean', 'std', 'min', 'max', 'median', 'skew'],
    }
    
    # Add Thermopile aggregations
    for i in range(1, 6):
        aggs[f'thm_{i}'] = ['mean', 'std', 'min', 'max']
        
    # Add time of flight aggregations (simple mean and std for baseline)
    tof_cols = [f'tof_{s}_v{p}' for s in range(1, 6) for p in range(64)]
    # Replace -1 with NaN to correctly calculate stats
    df[tof_cols] = df[tof_cols].replace(-1, np.nan)
    aggs['tof_mean'] = ['mean', 'std']
    df['tof_mean'] = df[tof_cols].mean(axis=1)
    
    # Group by sequence and aggregate and flatten multi index cols
    agg_df = df.groupby('sequence_id').agg(aggs)
    agg_df.columns = ['_'.join(col).strip() for col in agg_df.columns.values]
    
    # Get sequence level metadata (target, subject, etc.)
    meta_df = df.groupby('sequence_id').first()
    
    # Combine aggregated features with metadata
    final_df = pd.concat([meta_df[['subject', 'gesture']], agg_df], axis=1)
    
    # Encode gesture target
    final_df['gesture_encoded'] = final_df['gesture'].astype('category').cat.codes
    
    print(f"Feature engineering complete. Shape of features: {final_df.shape}")
    return final_df

In [9]:
# Create features
features_df = create_baseline_features(train_df)
display_all(features_df.head())

Feature engineering complete. Shape of features: (8151, 67)


Unnamed: 0_level_0,subject,gesture,acc_x_mean,acc_x_std,acc_x_min,acc_x_max,acc_x_median,acc_x_skew,acc_y_mean,acc_y_std,acc_y_min,acc_y_max,acc_y_median,acc_y_skew,acc_z_mean,acc_z_std,acc_z_min,acc_z_max,acc_z_median,acc_z_skew,rot_w_mean,rot_w_std,rot_w_min,rot_w_max,rot_w_median,rot_w_skew,rot_x_mean,rot_x_std,rot_x_min,rot_x_max,rot_x_median,rot_x_skew,rot_y_mean,rot_y_std,rot_y_min,rot_y_max,rot_y_median,rot_y_skew,rot_z_mean,rot_z_std,rot_z_min,rot_z_max,rot_z_median,rot_z_skew,thm_1_mean,thm_1_std,thm_1_min,thm_1_max,thm_2_mean,thm_2_std,thm_2_min,thm_2_max,thm_3_mean,thm_3_std,thm_3_min,thm_3_max,thm_4_mean,thm_4_std,thm_4_min,thm_4_max,thm_5_mean,thm_5_std,thm_5_min,thm_5_max,tof_mean_mean,tof_mean_std,gesture_encoded
sequence_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1
SEQ_000007,SUBJ_059520,Cheek - pinch skin,6.153098,1.334155,3.613281,9.015625,6.488281,-0.545319,3.91557,3.048287,-2.019531,6.519531,5.488281,-1.18487,5.577782,2.337517,1.09375,9.792969,4.964844,0.586111,0.263574,0.069033,0.134399,0.379272,0.254578,0.307637,-0.280817,0.056597,-0.442871,-0.204163,-0.275757,-0.846226,-0.33147,0.17505,-0.478027,0.005066,-0.414978,1.258707,-0.837994,0.040723,-0.914856,-0.757935,-0.825012,-0.477119,28.630612,0.582076,27.69651,30.54373,29.57187,2.576799,24.558798,32.010178,28.576605,1.260533,25.90749,30.090014,29.177937,0.278147,28.592863,29.76148,27.957446,0.877846,26.047148,29.428299,105.199184,53.223485,1
SEQ_000008,SUBJ_020948,Forehead - pull hairline,3.400506,1.087142,1.734375,5.90625,3.4375,0.146452,5.311179,3.268073,-0.222656,8.667969,7.0,-0.830467,6.581629,2.475402,1.722656,11.074219,5.839844,0.186346,0.243493,0.064414,0.157593,0.34198,0.226562,0.32673,-0.117145,0.049384,-0.263306,-0.050537,-0.097382,-0.937947,-0.342327,0.190164,-0.508606,-0.031555,-0.442169,0.830673,-0.875143,0.042626,-0.937805,-0.814697,-0.860046,-0.510329,30.464309,2.709212,25.985313,32.870808,29.678206,3.88508,23.907709,33.100945,29.179852,3.074828,24.414917,32.316135,30.501325,0.976249,28.755495,31.613327,25.824221,1.16594,24.181562,28.054575,150.73884,46.11008,6
SEQ_000013,SUBJ_040282,Cheek - pinch skin,-7.058962,1.295184,-9.25,-3.347656,-7.144531,0.518519,2.346182,2.564639,-3.273438,4.683594,3.382812,-1.445762,-6.068544,1.330784,-10.945312,-3.515625,-5.851562,-1.039566,0.392208,0.150629,0.061157,0.540771,0.439514,-1.359968,0.340804,0.182002,0.140991,0.726501,0.258362,1.350803,0.800506,0.090017,0.580505,0.881653,0.838135,-1.593112,0.002644,0.164305,-0.406799,0.129761,0.066101,-1.682267,24.522526,0.449773,24.181389,25.634346,24.367174,0.620555,23.933413,26.175961,24.892424,0.294962,24.406981,25.512794,24.93084,0.572871,24.419798,26.452927,24.733322,0.475044,24.16798,26.051331,195.963626,45.069032,1
SEQ_000016,SUBJ_052342,Write name on leg,5.524654,1.074108,3.4375,9.378906,5.390625,0.747648,-4.408491,0.598318,-5.71875,-2.960938,-4.492188,0.505319,-3.162077,6.139752,-8.078125,8.355469,-6.667969,0.964846,0.361083,0.041568,0.277527,0.459045,0.352234,0.525728,-0.728107,0.207529,-0.893677,-0.384827,-0.857361,0.843064,-0.223281,0.156706,-0.368713,0.035889,-0.315857,0.845564,-0.363684,0.301057,-0.817688,-0.082275,-0.190979,-0.794437,31.651703,4.006846,25.413513,36.053188,31.601259,4.495657,25.018881,36.705894,29.320353,3.274493,24.128819,33.617542,32.790761,3.253195,27.227589,35.665222,30.860562,3.310154,26.312038,35.801083,40.090805,46.833388,17
SEQ_000018,SUBJ_032165,Forehead - pull hairline,5.363715,1.627637,1.964844,6.832031,6.101562,-1.397824,4.109737,3.525304,-3.164062,6.71875,6.007812,-1.347944,5.937066,2.104544,4.148438,9.933594,4.761719,1.040372,0.859159,0.034238,0.828247,0.925049,0.846283,1.03008,0.177468,0.178091,-0.184204,0.305542,0.270111,-1.342766,-0.352176,0.149264,-0.457458,-0.022644,-0.422699,1.463709,-0.216601,0.073268,-0.367676,-0.159668,-0.18103,-1.283574,28.90361,1.144503,26.533083,30.267483,29.438643,1.658719,25.795074,31.035217,27.058073,0.951421,25.12772,28.468761,27.841705,0.431424,26.827133,28.400864,31.014364,1.394629,28.282324,32.180752,136.302421,33.106553,6


In [10]:
#Config
EXPERIMENT_NAME = "Baseline-Wave0-LGBM-0"
MODEL_NAME = "LightGBM"
FEATURE_WAVE = "Wave-0"
N_SPLITS = 5
SEED = 42

# Model Params
params = {
    'objective': 'multiclass',
    'num_class': features_df['gesture_encoded'].nunique(),
    'metric': 'multi_logloss',
    'n_estimators': 1000,
    'learning_rate': 0.05,
    'feature_fraction': 0.8,
    'bagging_fraction': 0.8,
    'bagging_freq': 1,
    'lambda_l1': 0.1,
    'lambda_l2': 0.1,
    'num_leaves': 31,
    'verbose': -1,
    'n_jobs': -1,
    'seed': SEED,
    'boosting_type': 'gbdt',
}

# Prep data for CV
X = features_df.drop(columns=['subject', 'gesture', 'gesture_encoded'])
y = features_df['gesture_encoded']
groups = features_df['subject']

In [11]:
# Cross Validation Setup
cv = StratifiedGroupKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)

oof_preds = np.zeros((len(features_df), y.nunique()))
oof_true = np.zeros(len(features_df))
fold_scores = []

# Training loop
for fold, (train_idx, val_idx) in enumerate(cv.split(X, y, groups)):
    print(f"--- Fold {fold+1}/{N_SPLITS} ---")
    
    # Split data
    X_train, y_train = X.iloc[train_idx], y.iloc[train_idx]
    X_val, y_val = X.iloc[val_idx], y.iloc[val_idx]
    
    # Init and train model
    model = lgb.LGBMClassifier(**params)
    model.fit(X_train, y_train,
              eval_set=[(X_val, y_val)],
              eval_metric='multi_logloss',
              callbacks=[lgb.early_stopping(100, verbose=False)])
    
    # Predict on validation set
    val_preds = model.predict_proba(X_val)
    oof_preds[val_idx] = val_preds
    oof_true[val_idx] = y_val
    
    # Eval fold performance using balanced_accuracy_score since it is robust to class imbalance
    fold_score = balanced_accuracy_score(y_val, np.argmax(val_preds, axis=1))
    fold_scores.append(fold_score)
    print(f"Fold {fold+1} Balanced Accuracy: {fold_score:.5f}")
    
# Final score and logging
mean_cv_score = np.mean(fold_scores)
print("\n--- CV Summary ---")
print(f"Mean Balanced Accuracy: {mean_cv_score:.5f}")
print(f"Std Dev.: {np.std(fold_scores):.5f}")

# Log the experiment
tracker.log_experiment(
    experiment_name=EXPERIMENT_NAME,
    model_name=MODEL_NAME,
    feature_wave=FEATURE_WAVE,
    cv_score=mean_cv_score,
    params=params,
    notes="Inital baseline model with sequence-level statistical aggregates."
)    

--- Fold 1/5 ---
Fold 1 Balanced Accuracy: 0.60245
--- Fold 2/5 ---
Fold 2 Balanced Accuracy: 0.52807
--- Fold 3/5 ---
Fold 3 Balanced Accuracy: 0.54332
--- Fold 4/5 ---
Fold 4 Balanced Accuracy: 0.48495
--- Fold 5/5 ---
Fold 5 Balanced Accuracy: 0.52315

--- CV Summary ---
Mean Balanced Accuracy: 0.53639
Std Dev.: 0.03821
Experiment 'Baseline-Wave0-LGBM-0' logged to ~/code/kaggle/kaggle-cmi-detect-behavior/experiment_log.csv


[34m[1mwandb[0m: Currently logged in as: [33mb-a-chaudhry[0m ([33mb-a-chaudhry-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


0,1
cv_score,▁

0,1
cv_score,0.53639
feature_wave,Wave-0
model_name,LightGBM


Experiment 'Baseline-Wave0-LGBM-0' logged to W&B


#### Analysis:

1. With 18 classes, a random classifier would achieve a balanced accuracy of approximately 1/18 (~5.6%). The current score of ~54% is nearly 10 times better and is a strong confirmation that even the simplest `Wave-0` features contain a strong predictive signal.
2. The standard deviation of 0.0381 across the folds is relatively low, indicating that the model's performance is consistent and not wildly dependent on a specific subset of subjects. 
3. The benchmark score ofo 0.53639 is now the number to beat! 

### Multi-Model Training and CV

Retaining the configurations from `Wave-0`

In [12]:
model_configs = {
    "XGBoost": {
        "model": xgb.XGBClassifier,
        "params": {
            'objective': 'multi:softprob', 'eval_metric': 'mlogloss',
            'n_estimators': 1000, 'learning_rate': 0.05, 'max_depth': 6,
            'subsample': 0.8, 'colsample_bytree': 0.8, 'gamma': 0.1,
            'seed': SEED, 'n_jobs': -1, 'use_label_encoder': False,
        }
    },
    "CatBoost": {
        "model": cat.CatBoostClassifier,
        "params": {
            'iterations': 1000, 'learning_rate': 0.05, 'depth': 6,
            'loss_function': 'MultiClass', 'eval_metric': 'MultiClass',
            'random_seed': SEED, 'verbose': 0,
        }
    },
    "RandomForest": {
        "model": RandomForestClassifier,
        "params": {
            'n_estimators': 200, 'max_depth': 10, 'min_samples_leaf': 5,
            'random_state': SEED, 'n_jobs': -1, 
        }
    }
}

print("="*50)
print("Starting Multi-Model Training on Wave 0 features")
print("="*50)

for model_name, config in model_configs.items():
    EXPERIMENT_NAME = f"Baseline-{FEATURE_WAVE}-{model_name}"
    
    oof_preds = np.zeros((len(features_df), y.nunique()))
    oof_true = np.zeros(len(features_df))
    fold_scores = []
    
    print(f"\n---> Training Model: {model_name} <---")
    cv = StratifiedGroupKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
    for fold, (train_idx, val_idx) in enumerate(cv.split(X, y, groups)):
        print(f"--- Fold {fold+1}/{N_SPLITS} ---")
        
        # Split data
        X_train, y_train = X.iloc[train_idx], y.iloc[train_idx]
        X_val, y_val = X.iloc[val_idx], y.iloc[val_idx]
        
        # Model Initialization and training
        model = config['model'](**config['params'])
        
        # Handle different early stopping mechanisms
        if model_name == 'XGBoost':
            model.fit(X_train, y_train,
                      verbose=False)
        elif model_name == 'CatBoost':
            model.fit(X_train, y_train,
                      eval_set=[(X_val, y_val)],
                      early_stopping_rounds=100,
                      verbose=0)
        else: #Random Forest
            model.fit(X_train, y_train)
        
        # Store preds and ground truth
        val_preds = model.predict_proba(X_val)
        oof_preds[val_idx] = val_preds
        oof_true[val_idx] = y_val
        
        fold_score = balanced_accuracy_score(y_val, np.argmax(val_preds, axis=1))
        fold_scores.append(fold_score)
        print(f"Fold {fold+1} Balanced Accuracy: {fold_score:.5f}")
    
    # Final score and logging
    mean_cv_score = np.mean(fold_scores)
    print(f"\n--- CV Summary for {model_name} ---")
    print(f"Mean Balanced Accuracy: {mean_cv_score:.5f}")
    print(f"Std Dev: {np.std(fold_scores):.5f}\n")
    
    # Log the experiment
    tracker.log_experiment(
        experiment_name=EXPERIMENT_NAME,
        model_name=model_name,
        feature_wave=FEATURE_WAVE,
        cv_score=mean_cv_score,
        params=config['params'],
        notes=f"Multi model baseline run on Wave 0 features using standard CPU training."
    )    

Starting Multi-Model Training on Wave 0 features

---> Training Model: XGBoost <---
--- Fold 1/5 ---
Fold 1 Balanced Accuracy: 0.62609
--- Fold 2/5 ---
Fold 2 Balanced Accuracy: 0.53573
--- Fold 3/5 ---
Fold 3 Balanced Accuracy: 0.55524
--- Fold 4/5 ---
Fold 4 Balanced Accuracy: 0.50275
--- Fold 5/5 ---
Fold 5 Balanced Accuracy: 0.53145

--- CV Summary for XGBoost ---
Mean Balanced Accuracy: 0.55025
Std Dev: 0.04147

Experiment 'Baseline-Wave-0-XGBoost' logged to ~/code/kaggle/kaggle-cmi-detect-behavior/experiment_log.csv


0,1
cv_score,▁

0,1
cv_score,0.55025
feature_wave,Wave-0
model_name,XGBoost


Experiment 'Baseline-Wave-0-XGBoost' logged to W&B

---> Training Model: CatBoost <---
--- Fold 1/5 ---
Fold 1 Balanced Accuracy: 0.61826
--- Fold 2/5 ---
Fold 2 Balanced Accuracy: 0.53993
--- Fold 3/5 ---
Fold 3 Balanced Accuracy: 0.55241
--- Fold 4/5 ---
Fold 4 Balanced Accuracy: 0.50637
--- Fold 5/5 ---
Fold 5 Balanced Accuracy: 0.54997

--- CV Summary for CatBoost ---
Mean Balanced Accuracy: 0.55339
Std Dev: 0.03637

Experiment 'Baseline-Wave-0-CatBoost' logged to ~/code/kaggle/kaggle-cmi-detect-behavior/experiment_log.csv


0,1
cv_score,▁

0,1
cv_score,0.55339
feature_wave,Wave-0
model_name,CatBoost


Experiment 'Baseline-Wave-0-CatBoost' logged to W&B

---> Training Model: RandomForest <---
--- Fold 1/5 ---
Fold 1 Balanced Accuracy: 0.55992
--- Fold 2/5 ---
Fold 2 Balanced Accuracy: 0.46962
--- Fold 3/5 ---
Fold 3 Balanced Accuracy: 0.48397
--- Fold 4/5 ---
Fold 4 Balanced Accuracy: 0.43562
--- Fold 5/5 ---
Fold 5 Balanced Accuracy: 0.47835

--- CV Summary for RandomForest ---
Mean Balanced Accuracy: 0.48549
Std Dev: 0.04082

Experiment 'Baseline-Wave-0-RandomForest' logged to ~/code/kaggle/kaggle-cmi-detect-behavior/experiment_log.csv


0,1
cv_score,▁

0,1
cv_score,0.48549
feature_wave,Wave-0
model_name,RandomForest


Experiment 'Baseline-Wave-0-RandomForest' logged to W&B


#### Analysis

1. Boosting models are out performing the standard Random Forest. Moving forward, we will be focusing on boosting algos.

## Feature Engineering - Wave 1 - Phase Specific Aggregates

`Note` Tree based models can handle NaNs which will naturally appear for sequences missing a phase like 'Pause'.

In [16]:
def create_wave1_features(df):
    """
    Creates Wave-1 features i.e. phase specific aggregates
    """
    imu_cols = ['acc_x', 'acc_y', 'acc_z', 'rot_w', 'rot_x', 'rot_y', 'rot_z']
    thm_cols = [f'thm_{i}' for i in range(1, 6)]
    tof_cols = [f'tof_{s}_v{p}' for s in range(1, 6) for p in range(64)]
    
    # Replace ToF -1 with NaN for correct stats
    df_feat = df.copy()
    df_feat[tof_cols] = df_feat[tof_cols].replace(-1, np.nan)
    
    # Single mean ToF column for all sensors
    df_feat['tof_mean_all_pixels'] = df_feat[tof_cols].mean(axis=1)
    
    # Define aggregations
    aggs = {}
    for col in imu_cols:
        aggs[col] = ['mean', 'std', 'min', 'max', 'skew']
    for col in thm_cols:
        aggs[col] = ['mean', 'std', 'min', 'max']
    aggs['tof_mean_all_pixels'] = ['mean', 'std', 'min', 'max']
    
    # Phase specific aggregations
    phase_agg_df = df_feat.groupby(['sequence_id', 'phase']).agg(aggs)
    
    # Flatten the first level of column names (e.g acc_x --> acc_x_mean)
    phase_agg_df.columns = ['_'.join(col).strip() for col in phase_agg_df.columns.values]
    
    # Unstack the phase level to pivot phases into columns and flatten
    phase_agg_df_unstacked = phase_agg_df.unstack(level='phase')
    phase_agg_df_unstacked.columns = ['_'.join(col).strip() for col in phase_agg_df_unstacked.columns.values] 
    
    # Sequence level meta data (target, subject etc.) to combine with features
    meta_df = df.groupby('sequence_id').first()
    final_df = pd.concat([meta_df[['subject', 'gesture']], phase_agg_df_unstacked], axis=1)
    
    # Encode gesture target
    final_df['gesture_encoded'] = final_df['gesture'].astype('category').cat.codes
    
    print(f"Feature engineering complete. Shape of features: {final_df.shape}")
    return final_df
    

In [17]:
features_df = create_wave1_features(train_df)
display_all(features_df)

Feature engineering complete. Shape of features: (8151, 121)


Unnamed: 0_level_0,subject,gesture,acc_x_mean_Gesture,acc_x_mean_Transition,acc_x_std_Gesture,acc_x_std_Transition,acc_x_min_Gesture,acc_x_min_Transition,acc_x_max_Gesture,acc_x_max_Transition,acc_x_skew_Gesture,acc_x_skew_Transition,acc_y_mean_Gesture,acc_y_mean_Transition,acc_y_std_Gesture,acc_y_std_Transition,acc_y_min_Gesture,acc_y_min_Transition,acc_y_max_Gesture,acc_y_max_Transition,acc_y_skew_Gesture,acc_y_skew_Transition,acc_z_mean_Gesture,acc_z_mean_Transition,acc_z_std_Gesture,acc_z_std_Transition,acc_z_min_Gesture,acc_z_min_Transition,acc_z_max_Gesture,acc_z_max_Transition,acc_z_skew_Gesture,acc_z_skew_Transition,rot_w_mean_Gesture,rot_w_mean_Transition,rot_w_std_Gesture,rot_w_std_Transition,rot_w_min_Gesture,rot_w_min_Transition,rot_w_max_Gesture,rot_w_max_Transition,rot_w_skew_Gesture,rot_w_skew_Transition,rot_x_mean_Gesture,rot_x_mean_Transition,rot_x_std_Gesture,rot_x_std_Transition,rot_x_min_Gesture,rot_x_min_Transition,rot_x_max_Gesture,rot_x_max_Transition,rot_x_skew_Gesture,rot_x_skew_Transition,rot_y_mean_Gesture,rot_y_mean_Transition,rot_y_std_Gesture,rot_y_std_Transition,rot_y_min_Gesture,rot_y_min_Transition,rot_y_max_Gesture,rot_y_max_Transition,rot_y_skew_Gesture,rot_y_skew_Transition,rot_z_mean_Gesture,rot_z_mean_Transition,rot_z_std_Gesture,rot_z_std_Transition,rot_z_min_Gesture,rot_z_min_Transition,rot_z_max_Gesture,rot_z_max_Transition,rot_z_skew_Gesture,rot_z_skew_Transition,thm_1_mean_Gesture,thm_1_mean_Transition,thm_1_std_Gesture,thm_1_std_Transition,thm_1_min_Gesture,thm_1_min_Transition,thm_1_max_Gesture,thm_1_max_Transition,thm_2_mean_Gesture,thm_2_mean_Transition,thm_2_std_Gesture,thm_2_std_Transition,thm_2_min_Gesture,thm_2_min_Transition,thm_2_max_Gesture,thm_2_max_Transition,thm_3_mean_Gesture,thm_3_mean_Transition,thm_3_std_Gesture,thm_3_std_Transition,thm_3_min_Gesture,thm_3_min_Transition,thm_3_max_Gesture,thm_3_max_Transition,thm_4_mean_Gesture,thm_4_mean_Transition,thm_4_std_Gesture,thm_4_std_Transition,thm_4_min_Gesture,thm_4_min_Transition,thm_4_max_Gesture,thm_4_max_Transition,thm_5_mean_Gesture,thm_5_mean_Transition,thm_5_std_Gesture,thm_5_std_Transition,thm_5_min_Gesture,thm_5_min_Transition,thm_5_max_Gesture,thm_5_max_Transition,tof_mean_all_pixels_mean_Gesture,tof_mean_all_pixels_mean_Transition,tof_mean_all_pixels_std_Gesture,tof_mean_all_pixels_std_Transition,tof_mean_all_pixels_min_Gesture,tof_mean_all_pixels_min_Transition,tof_mean_all_pixels_max_Gesture,tof_mean_all_pixels_max_Transition,gesture_encoded
sequence_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1,Unnamed: 73_level_1,Unnamed: 74_level_1,Unnamed: 75_level_1,Unnamed: 76_level_1,Unnamed: 77_level_1,Unnamed: 78_level_1,Unnamed: 79_level_1,Unnamed: 80_level_1,Unnamed: 81_level_1,Unnamed: 82_level_1,Unnamed: 83_level_1,Unnamed: 84_level_1,Unnamed: 85_level_1,Unnamed: 86_level_1,Unnamed: 87_level_1,Unnamed: 88_level_1,Unnamed: 89_level_1,Unnamed: 90_level_1,Unnamed: 91_level_1,Unnamed: 92_level_1,Unnamed: 93_level_1,Unnamed: 94_level_1,Unnamed: 95_level_1,Unnamed: 96_level_1,Unnamed: 97_level_1,Unnamed: 98_level_1,Unnamed: 99_level_1,Unnamed: 100_level_1,Unnamed: 101_level_1,Unnamed: 102_level_1,Unnamed: 103_level_1,Unnamed: 104_level_1,Unnamed: 105_level_1,Unnamed: 106_level_1,Unnamed: 107_level_1,Unnamed: 108_level_1,Unnamed: 109_level_1,Unnamed: 110_level_1,Unnamed: 111_level_1,Unnamed: 112_level_1,Unnamed: 113_level_1,Unnamed: 114_level_1,Unnamed: 115_level_1,Unnamed: 116_level_1,Unnamed: 117_level_1,Unnamed: 118_level_1,Unnamed: 119_level_1,Unnamed: 120_level_1,Unnamed: 121_level_1
SEQ_000007,SUBJ_059520,Cheek - pinch skin,6.875854,5.227969,0.779424,1.334370,5.492188,3.613281,9.015625,7.292969,0.671795,0.119323,5.640137,1.708125,0.582177,3.493182,4.683594,-2.019531,6.519531,6.214844,-0.177613,0.188431,4.259521,7.265156,1.247767,2.335664,1.093750,3.125000,6.875000,9.792969,0.123765,-0.430736,0.230537,0.305862,0.043031,0.073551,0.170898,0.134399,0.327942,0.379272,0.684978,-0.869169,-0.302607,-0.252927,0.058485,0.040159,-0.442871,-0.355164,-0.212952,-0.204163,-0.594216,-0.946573,-0.431194,-0.203823,0.033207,0.199606,-0.478027,-0.447327,-0.364746,0.005066,0.533235,-0.068043,-0.814034,-0.868665,0.023412,0.037632,-0.854187,-0.914856,-0.757935,-0.809753,0.157845,0.328752,28.572991,28.704366,0.381308,0.769848,27.782368,27.696510,29.215113,30.543730,30.761249,28.049465,0.485551,3.299713,30.041212,24.558798,31.663404,32.010178,29.116922,27.885000,0.197033,1.664297,28.557234,25.907490,29.396484,30.090014,29.078639,29.305039,0.148042,0.349623,28.819799,28.592863,29.346125,29.761480,28.359300,27.443073,0.513923,0.982867,27.721840,26.047148,29.428299,28.815403,71.437350,148.414330,16.201963,52.966776,43.805714,76.702830,104.703883,203.551402,1
SEQ_000008,SUBJ_020948,Forehead - pull hairline,3.449929,3.353906,0.736728,1.346681,1.734375,1.886719,4.949219,5.906250,-0.292417,0.278393,7.261364,3.472433,0.824113,3.638059,4.718750,-0.222656,8.667969,7.976562,-0.761132,0.295310,5.647609,7.462277,1.740919,2.752709,1.722656,3.175781,9.000000,11.074219,-0.225780,-0.278592,0.229918,0.256294,0.026109,0.084777,0.185181,0.157593,0.305420,0.341980,0.725468,-0.163603,-0.091453,-0.141370,0.028465,0.052901,-0.160339,-0.263306,-0.050537,-0.079468,-0.668575,-0.427363,-0.432821,-0.257005,0.079182,0.223424,-0.495483,-0.508606,-0.093811,-0.031555,3.356636,-0.156264,-0.862236,-0.887312,0.021701,0.053133,-0.933899,-0.937805,-0.837036,-0.814697,-2.201159,0.211861,32.273003,28.758968,1.175071,2.644869,26.886362,25.985313,32.870808,32.105263,32.016294,27.473724,1.988673,3.963966,24.039984,23.907709,33.100945,32.403992,30.335959,28.089808,1.261277,3.822748,25.774496,24.414917,31.516409,32.316135,31.252343,29.793223,0.365182,0.830051,29.622156,28.755495,31.613327,30.865301,26.816013,24.889104,0.895062,0.264488,24.378687,24.181562,28.054575,25.145569,125.667271,174.377748,19.353654,51.496462,109.512000,110.835681,214.094595,222.333333,6
SEQ_000013,SUBJ_040282,Cheek - pinch skin,-7.591276,-6.364640,1.143683,1.162073,-9.250000,-7.714844,-5.226562,-3.347656,0.395099,1.138648,3.751562,0.513077,0.571292,2.984959,2.656250,-3.273438,4.683594,3.957031,0.011108,-0.231933,-5.493620,-6.818444,0.968642,1.382213,-6.808594,-10.945312,-3.515625,-5.277344,0.279689,-1.467629,0.474384,0.285024,0.044857,0.172503,0.404968,0.061157,0.540771,0.480835,-0.014976,-0.239875,0.254718,0.453091,0.025604,0.232688,0.221008,0.140991,0.304016,0.726501,0.539700,-0.009959,0.836882,0.753060,0.022810,0.119621,0.797424,0.580505,0.875366,0.881653,-0.071165,-0.405698,0.078009,-0.095658,0.023455,0.212688,0.043457,-0.406799,0.105957,0.129761,-0.259239,-0.384956,24.311727,24.797481,0.090862,0.572397,24.181389,24.223915,24.484457,25.634346,24.065670,24.760440,0.061293,0.787152,23.933413,24.035677,24.194889,26.175961,24.744966,25.084761,0.207938,0.283411,24.406981,24.587378,25.136721,25.512794,24.664993,25.277597,0.150427,0.723736,24.419798,24.516800,24.930002,26.452927,24.628350,24.870243,0.114652,0.693859,24.437956,24.167980,24.858009,26.051331,217.693390,167.620456,8.296971,56.802998,200.689189,84.317269,234.555556,225.155556,1
SEQ_000016,SUBJ_052342,Write name on leg,5.291590,5.818142,0.940102,1.174814,3.859375,3.437500,9.378906,7.925781,2.323306,-0.451960,-4.744830,-3.984954,0.441674,0.494274,-5.718750,-4.875000,-3.113281,-2.960938,1.184362,0.267835,-6.674517,1.260995,1.942150,6.764263,-8.078125,-7.734375,4.062500,8.355469,5.405182,-0.484150,0.353268,0.370924,0.032622,0.049561,0.305847,0.277527,0.459045,0.449829,2.043232,-0.369664,-0.826032,-0.604795,0.119764,0.229897,-0.882019,-0.893677,-0.409119,-0.384827,3.084322,-0.394076,-0.303709,-0.122000,0.087160,0.167080,-0.356934,-0.368713,0.013916,0.035889,3.026020,-0.390096,-0.219314,-0.545485,0.176431,0.328741,-0.803955,-0.817688,-0.094604,-0.082275,-2.872020,0.443606,34.167663,28.483457,2.089392,3.578696,26.586309,25.413513,36.053188,34.632439,34.378147,28.104436,2.681950,3.837888,25.087893,25.018881,36.705894,34.943771,31.298491,26.829364,1.665449,3.108468,24.644175,24.128819,33.617542,31.820757,34.694450,30.393524,1.245293,3.427506,29.947037,27.227589,35.665222,35.334206,32.257148,29.101899,2.677430,3.228711,27.337736,26.312038,35.801083,34.356197,14.492667,72.325497,13.559525,53.780371,8.534615,5.501859,80.681319,173.119048,17
SEQ_000018,SUBJ_032165,Forehead - pull hairline,6.235639,3.881445,0.241211,1.901476,5.718750,1.964844,6.832031,6.332031,0.194196,0.300721,6.068704,0.779492,0.211266,4.005993,5.566406,-3.164062,6.449219,6.718750,-0.492244,0.461042,4.739890,7.972266,0.455216,2.251710,4.148438,4.261719,6.140625,9.933594,1.217216,-0.934850,0.841427,0.889304,0.010841,0.039327,0.828247,0.831604,0.865784,0.925049,0.666296,-0.697839,0.272425,0.016040,0.014179,0.211085,0.237183,-0.184204,0.305542,0.305481,-0.044626,0.472734,-0.431091,-0.218021,0.016419,0.177740,-0.457458,-0.435425,-0.398071,-0.022644,0.255425,-0.235940,-0.176968,-0.283978,0.007988,0.085287,-0.193359,-0.367676,-0.163818,-0.159668,0.057725,0.557964,29.581021,27.752010,0.380580,1.088661,28.659307,26.533083,30.267483,29.387503,30.334161,27.916263,0.426640,1.866541,29.461178,25.795074,31.035217,31.007061,27.559514,26.205623,0.381086,1.028559,26.780855,25.127720,28.381996,28.468761,28.062466,27.466412,0.185449,0.473282,27.696522,26.827133,28.400864,28.368055,31.859936,29.576891,0.122340,1.394618,31.598164,28.282324,32.180752,31.793083,118.690740,166.242279,4.706987,38.990861,108.169291,106.621514,126.880952,204.458333,6
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
SEQ_065508,SUBJ_027682,Text on phone,-7.286932,3.387445,0.393596,7.357197,-8.058594,-9.359375,-6.410156,12.167969,0.251036,-0.565475,6.254498,1.831577,0.255871,3.381171,5.542969,-2.042969,6.808594,7.425781,-0.516027,0.717290,-2.798177,-2.354288,0.333119,3.774008,-3.640625,-10.808594,-2.222656,3.179688,-0.582944,-0.904098,0.123006,0.367488,0.017124,0.269607,0.074341,0.009705,0.148315,0.687256,-1.093809,0.151336,-0.483036,-0.354645,0.012533,0.139434,-0.510071,-0.530457,-0.457642,0.174744,0.069883,1.206415,0.636162,-0.273304,0.006912,0.626471,0.624023,-0.893433,0.647217,0.644104,-0.127790,0.618518,0.588429,0.103543,0.010567,0.429799,0.565186,-0.556152,0.612488,0.796509,0.027960,0.558277,28.618700,27.659832,0.174713,2.418055,28.266153,24.017521,28.933847,30.199110,30.241228,27.995612,0.250120,2.707673,29.701616,24.144169,30.598179,30.412640,27.666960,27.730368,0.392424,1.774875,27.047375,24.401672,28.319439,29.677988,27.869123,27.923252,0.160213,2.165651,27.614656,24.345264,28.260014,30.226200,24.568769,26.219772,0.114216,2.327463,24.301918,23.851444,24.770197,29.097397,122.831703,65.681948,13.819683,52.748440,105.198113,16.914773,149.058824,188.774194,14
SEQ_065519,SUBJ_050642,Neck - scratch,0.811458,3.822917,0.862646,3.676374,-0.628906,-1.433594,1.937500,9.632812,-0.364536,-0.477535,9.518099,1.651403,0.245501,6.125286,8.769531,-3.800781,10.035156,9.683594,-0.605971,0.469965,-3.321224,2.824219,0.503772,5.168124,-4.156250,-4.156250,-2.472656,8.554688,0.184275,-0.523815,0.068650,0.323758,0.033377,0.153423,0.002869,0.133362,0.123291,0.459167,-0.126919,-0.309932,0.017643,-0.168518,0.056886,0.268920,-0.052307,-0.405579,0.146179,0.166016,0.595832,0.335591,-0.156085,0.297954,0.806781,0.433539,-0.816528,-0.591125,0.820557,0.814026,0.430013,0.232410,-0.126654,-0.245738,0.576403,0.668840,-0.598328,-0.839111,0.592346,0.624084,0.430244,0.388220,28.916185,25.363529,0.218017,2.035027,27.998026,23.151098,29.111376,28.002087,29.116303,25.877557,0.207229,1.689515,28.216278,23.118992,29.345762,28.235514,29.107903,25.434269,0.126985,2.631336,28.747730,23.136972,29.309811,28.761459,28.231606,25.042248,0.228395,2.039401,27.508478,23.197491,28.513880,27.501440,27.450978,24.623817,0.545437,0.680856,25.788750,23.379164,28.023863,25.831879,59.195932,165.164429,5.906686,66.545070,47.666667,71.388128,72.681416,235.750000,10
SEQ_065522,SUBJ_040282,Above ear - pull hair,-9.020573,1.589518,0.501440,5.849199,-10.023438,-9.339844,-7.847656,7.437500,0.315421,-1.283318,-3.511719,-0.959391,0.813275,1.719252,-4.785156,-3.441406,-2.332031,7.203125,0.298439,2.947546,-1.737630,6.378174,0.814991,4.556230,-3.742188,-3.820312,-0.257812,11.660156,-0.422970,-1.181143,0.498043,0.737235,0.020919,0.229527,0.457153,0.053833,0.531372,0.898743,-0.518649,-1.209716,-0.662750,-0.246306,0.019296,0.296925,-0.708923,-0.713257,-0.632629,0.641052,-0.646267,0.037016,0.385569,-0.087603,0.025810,0.255602,0.324707,-0.549500,0.449890,0.427673,-0.025129,0.756552,0.402863,-0.186490,0.018648,0.393318,0.369812,-0.533081,0.433838,0.502380,-0.113519,1.170457,28.241747,25.259864,0.271362,1.543620,27.698671,24.172684,28.850843,28.814703,31.219762,25.352062,0.786523,2.260060,27.654655,24.009996,32.134312,31.506523,30.562498,25.560998,0.150289,2.770233,30.319349,23.788067,30.888124,30.425909,28.158508,25.617107,0.275256,1.687317,27.623669,24.423225,29.049471,29.694206,25.819240,26.744059,0.112647,0.454887,25.561703,25.591505,26.016832,27.349192,102.745984,133.606241,6.654648,11.330897,88.834532,101.055944,116.441379,165.308824,0
SEQ_065526,SUBJ_063447,Cheek - pinch skin,7.894531,5.653646,0.252330,2.180823,7.324219,2.414062,8.390625,9.343750,-0.491839,0.135434,6.123958,1.706814,0.201382,5.072068,5.636719,-7.464844,6.671875,6.441406,-0.180862,-0.600652,0.311068,4.820747,0.422256,4.083398,-0.769531,-0.273438,1.449219,9.839844,0.096727,-0.130372,0.267440,0.349592,0.011516,0.070285,0.247009,0.248779,0.290100,0.431396,0.132236,-0.209322,-0.348621,-0.277395,0.015033,0.095053,-0.382263,-0.399963,-0.322083,-0.077759,-0.490353,0.948633,-0.606787,-0.302365,0.005854,0.325336,-0.620911,-0.633240,-0.597351,0.159424,-0.367587,0.334026,-0.662065,-0.766208,0.006043,0.098610,-0.671326,-0.894653,-0.646606,-0.646301,0.708584,0.066892,29.117884,27.831675,0.097559,1.721285,28.929562,25.056185,29.359003,29.842354,30.896496,28.327643,0.175823,2.737208,30.592413,24.754093,31.295233,31.209423,29.768847,27.314962,0.099378,2.558627,29.593616,24.062937,29.951387,29.851467,28.614043,28.278936,0.128938,0.728813,28.319683,26.993855,28.908787,29.438768,29.521346,28.327831,0.070518,1.805568,29.348232,25.141193,29.648335,29.871538,132.500703,165.610287,3.953532,41.854151,126.161716,121.307407,139.636364,233.750000,1


In [18]:
FEATURE_WAVE = "Wave 1"
N_SPLITS = 5
SEED = 42

In [19]:
# --- Prepare data for CV ---
X = features_df.drop(columns=['subject', 'gesture', 'gesture_encoded'])
y = features_df['gesture_encoded']
groups = features_df['subject']

In [22]:
# --- Updated Model Configurations (Top 3) ---
model_configs = {
    "LightGBM": {
        "model": lgb.LGBMClassifier,
        "params": { 'objective': 'multiclass', 'num_class': features_df['gesture_encoded'].nunique(), 'metric': 'multi_logloss', 'n_estimators': 1000, 'learning_rate': 0.05, 'feature_fraction': 0.8, 'bagging_fraction': 0.8, 'bagging_freq': 1, 'lambda_l1': 0.1, 'lambda_l2': 0.1, 'num_leaves': 31, 'verbose': -1, 'n_jobs': -1, 'seed': SEED, 'boosting_type': 'gbdt' }
    },
    "XGBoost": {
        "model": xgb.XGBClassifier,
        "params": { 'objective': 'multi:softprob', 'eval_metric': 'mlogloss', 'n_estimators': 1000, 'learning_rate': 0.05, 'max_depth': 6, 'subsample': 0.8, 'colsample_bytree': 0.8, 'gamma': 0.1, 'seed': SEED, 'n_jobs': -1, 'use_label_encoder': False }
    },
    "CatBoost": {
        "model": cat.CatBoostClassifier,
        "params": { 'iterations': 1000, 'learning_rate': 0.05, 'depth': 6, 'loss_function': 'MultiClass', 'eval_metric': 'MultiClass', 'random_seed': SEED, 'verbose': 0 }
    }
}

In [23]:
# --- Cross-Validation Loop for Each Model ---
print("="*50)
print(f"Starting Multi-Model CPU-Based Training on {FEATURE_WAVE} Features")
print("="*50)

for model_name, config in model_configs.items():
    
    EXPERIMENT_NAME = f"{FEATURE_WAVE}-{model_name}-CPU"
    fold_scores = []

    print(f"\n---> Training Model: {model_name} <---")
    
    cv = StratifiedGroupKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
    for fold, (train_idx, val_idx) in enumerate(cv.split(X, y, groups)):
        print(f"--- Fold {fold+1}/{N_SPLITS} ---")
        
        X_train, y_train = X.iloc[train_idx], y.iloc[train_idx]
        X_val, y_val = X.iloc[val_idx], y.iloc[val_idx]
        
        model = config['model'](**config['params'])
        
        if model_name == 'LightGBM':
            model.fit(X_train, y_train, eval_set=[(X_val, y_val)], callbacks=[lgb.early_stopping(100, verbose=False)])
        elif model_name == 'CatBoost':
            model.fit(X_train, y_train, eval_set=[(X_val, y_val)], early_stopping_rounds=100)
        elif model_name == 'XGBoost':
            model.fit(X_train, y_train, verbose=False)
            
        val_preds = model.predict_proba(X_val)
        fold_score = balanced_accuracy_score(y_val, np.argmax(val_preds, axis=1))
        fold_scores.append(fold_score)
        print(f"Fold {fold+1} Balanced Accuracy: {fold_score:.5f}")

    mean_cv_score = np.mean(fold_scores)
    print(f"\n--- CV Summary for {model_name} ---")
    print(f"Mean Balanced Accuracy: {mean_cv_score:.5f}")
    print(f"Std Dev: {np.std(fold_scores):.5f}\n")
    
    tracker.log_experiment(
        experiment_name=EXPERIMENT_NAME,
        model_name=model_name,
        feature_wave=FEATURE_WAVE,
        cv_score=mean_cv_score,
        params=config['params'],
        notes="Phase-specific statistical aggregates."
    )

Starting Multi-Model CPU-Based Training on Wave 1 Features

---> Training Model: LightGBM <---
--- Fold 1/5 ---
Fold 1 Balanced Accuracy: 0.68979
--- Fold 2/5 ---
Fold 2 Balanced Accuracy: 0.63108
--- Fold 3/5 ---
Fold 3 Balanced Accuracy: 0.65937
--- Fold 4/5 ---
Fold 4 Balanced Accuracy: 0.58927
--- Fold 5/5 ---
Fold 5 Balanced Accuracy: 0.59668

--- CV Summary for LightGBM ---
Mean Balanced Accuracy: 0.63324
Std Dev: 0.03783

Experiment 'Wave 1-LightGBM-CPU' logged to ~/code/kaggle/kaggle-cmi-detect-behavior/experiment_log.csv


0,1
cv_score,▁

0,1
cv_score,0.63324
feature_wave,Wave 1
model_name,LightGBM


Experiment 'Wave 1-LightGBM-CPU' logged to W&B

---> Training Model: XGBoost <---
--- Fold 1/5 ---
Fold 1 Balanced Accuracy: 0.70552
--- Fold 2/5 ---
Fold 2 Balanced Accuracy: 0.62616
--- Fold 3/5 ---
Fold 3 Balanced Accuracy: 0.65261
--- Fold 4/5 ---
Fold 4 Balanced Accuracy: 0.60229
--- Fold 5/5 ---
Fold 5 Balanced Accuracy: 0.61356

--- CV Summary for XGBoost ---
Mean Balanced Accuracy: 0.64003
Std Dev: 0.03678

Experiment 'Wave 1-XGBoost-CPU' logged to ~/code/kaggle/kaggle-cmi-detect-behavior/experiment_log.csv


0,1
cv_score,▁

0,1
cv_score,0.64003
feature_wave,Wave 1
model_name,XGBoost


Experiment 'Wave 1-XGBoost-CPU' logged to W&B

---> Training Model: CatBoost <---
--- Fold 1/5 ---
Fold 1 Balanced Accuracy: 0.71385
--- Fold 2/5 ---
Fold 2 Balanced Accuracy: 0.65075
--- Fold 3/5 ---
Fold 3 Balanced Accuracy: 0.64162
--- Fold 4/5 ---
Fold 4 Balanced Accuracy: 0.60518
--- Fold 5/5 ---
Fold 5 Balanced Accuracy: 0.61343

--- CV Summary for CatBoost ---
Mean Balanced Accuracy: 0.64497
Std Dev: 0.03839

Experiment 'Wave 1-CatBoost-CPU' logged to ~/code/kaggle/kaggle-cmi-detect-behavior/experiment_log.csv


0,1
cv_score,▁

0,1
cv_score,0.64497
feature_wave,Wave 1
model_name,CatBoost


Experiment 'Wave 1-CatBoost-CPU' logged to W&B


#### Analysis

1. Massive Performance Improvement: This was not an incremental gain; it was a huge leap across the board. The phase-specific features provided a much richer, more granular signal for the models to learn from.

2. Hypothesis Confirmed: The core hypothesis for Wave 1 was that the statistical properties of the sensors are fundamentally different and more informative during the Transition and Gesture phases. A ~9.2% absolute improvement in our best model's accuracy validates this.

3. **CatBoost Solidifies its Lead**: CatBoost continues to be the top-performing algorithm on this dataset, achieving a new best score of **0.64497**. This will be our new benchmark to beat. The stability remains excellent, with a standard deviation of ~0.038, indicating the performance is consistent across all folds.

## Feature Engineering - Wave 2 - Physically Relevant Features From IMU and Thermopile Sensors. 

IMU: Vector magnitudes (acc_mag, rot_mag) and jerk (the derivative of acceleration magnitude).

Thermopile: Temperature gradients between adjacent sensors.

In [24]:
def create_wave2_features(df):
    """ 
    Creates wave 2 features: Adds advanced IMU (magnitude, jerk) and Thermopile (gradients)
    features before performing phase specifc aggregation.
    """
    df_feat = df.copy()
    
    # Create sensor derived features
    # IMU magnitudes
    df_feat['acc_mag'] = np.sqrt(df_feat['acc_x']**2 + df_feat['acc_y']**2 + df_feat['acc_z']**2)
    df_feat['rot_mag'] = np.sqrt(df_feat['rot_w']**2 + df_feat['rot_x']**2 + df_feat['rot_y']**2 + df_feat['rot_z']**2)
    
    # IMU Jerk (Rate of change of acceleration magnitude)
    # Group by sequence_id to prevent calculating diff across different sequences
    df_feat['jerk'] = df_feat.groupby('sequence_id')['acc_mag'].diff().fillna(0)
    
    # Thermopile gradients
    for i in range(1, 5):
        df_feat[f'thm_grad_{i}_{i+1}'] = df_feat[f'thm_{i}'] - df_feat[f'thm_{i+1}']
    
    # Repeat Wave 1 Phase-Specific Aggregations
    # Define columns to aggregate
    imu_derived_cols = ['acc_mag', 'rot_mag', 'jerk']
    thm_grad_cols = [f'thm_grad_{i}_{i+1}' for i in range(1, 5)]
    original_sensor_cols = [c for c in df.columns if 'acc_' in c or 'rot_' in c or 'thm_' in c]
    tof_cols = [f'tof_{s}_v{p}' for s in range(1, 6) for p in range(64)]

    df_feat[tof_cols] = df_feat[tof_cols].replace(-1, np.nan)
    df_feat['tof_mean_all_pixels'] = df_feat[tof_cols].mean(axis=1)

    aggs = {}
    for col in original_sensor_cols + imu_derived_cols + thm_grad_cols:
        aggs[col] = ['mean', 'std', 'min', 'max', 'skew']
    aggs['tof_mean_all_pixels'] = ['mean', 'std', 'min', 'max']

    phase_agg_df = df_feat.groupby(['sequence_id', 'phase']).agg(aggs)
    phase_agg_df.columns = ['_'.join(col).strip() for col in phase_agg_df.columns.values]
    phase_agg_df_unstacked = phase_agg_df.unstack(level='phase')
    phase_agg_df_unstacked.columns = ['_'.join(col).strip() for col in phase_agg_df_unstacked.columns.values]
    
    meta_df = df.groupby('sequence_id').first()
    final_df = pd.concat([meta_df[['subject', 'gesture']], phase_agg_df_unstacked], axis=1)
    final_df['gesture_encoded'] = final_df['gesture'].astype('category').cat.codes
    
    print(f"Feature engineering complete. Shape of features: {final_df.shape}")
    return final_df

In [26]:
features_df = create_wave2_features(train_df)
display_all(features_df.head())

Feature engineering complete. Shape of features: (8151, 201)


Unnamed: 0_level_0,subject,gesture,acc_x_mean_Gesture,acc_x_mean_Transition,acc_x_std_Gesture,acc_x_std_Transition,acc_x_min_Gesture,acc_x_min_Transition,acc_x_max_Gesture,acc_x_max_Transition,acc_x_skew_Gesture,acc_x_skew_Transition,acc_y_mean_Gesture,acc_y_mean_Transition,acc_y_std_Gesture,acc_y_std_Transition,acc_y_min_Gesture,acc_y_min_Transition,acc_y_max_Gesture,acc_y_max_Transition,acc_y_skew_Gesture,acc_y_skew_Transition,acc_z_mean_Gesture,acc_z_mean_Transition,acc_z_std_Gesture,acc_z_std_Transition,acc_z_min_Gesture,acc_z_min_Transition,acc_z_max_Gesture,acc_z_max_Transition,acc_z_skew_Gesture,acc_z_skew_Transition,rot_w_mean_Gesture,rot_w_mean_Transition,rot_w_std_Gesture,rot_w_std_Transition,rot_w_min_Gesture,rot_w_min_Transition,rot_w_max_Gesture,rot_w_max_Transition,rot_w_skew_Gesture,rot_w_skew_Transition,rot_x_mean_Gesture,rot_x_mean_Transition,rot_x_std_Gesture,rot_x_std_Transition,rot_x_min_Gesture,rot_x_min_Transition,rot_x_max_Gesture,rot_x_max_Transition,rot_x_skew_Gesture,rot_x_skew_Transition,rot_y_mean_Gesture,rot_y_mean_Transition,rot_y_std_Gesture,rot_y_std_Transition,rot_y_min_Gesture,rot_y_min_Transition,rot_y_max_Gesture,rot_y_max_Transition,rot_y_skew_Gesture,rot_y_skew_Transition,rot_z_mean_Gesture,rot_z_mean_Transition,rot_z_std_Gesture,rot_z_std_Transition,rot_z_min_Gesture,rot_z_min_Transition,rot_z_max_Gesture,rot_z_max_Transition,rot_z_skew_Gesture,rot_z_skew_Transition,thm_1_mean_Gesture,thm_1_mean_Transition,thm_1_std_Gesture,thm_1_std_Transition,thm_1_min_Gesture,thm_1_min_Transition,thm_1_max_Gesture,thm_1_max_Transition,thm_1_skew_Gesture,thm_1_skew_Transition,thm_2_mean_Gesture,thm_2_mean_Transition,thm_2_std_Gesture,thm_2_std_Transition,thm_2_min_Gesture,thm_2_min_Transition,thm_2_max_Gesture,thm_2_max_Transition,thm_2_skew_Gesture,thm_2_skew_Transition,thm_3_mean_Gesture,thm_3_mean_Transition,thm_3_std_Gesture,thm_3_std_Transition,thm_3_min_Gesture,thm_3_min_Transition,thm_3_max_Gesture,thm_3_max_Transition,thm_3_skew_Gesture,thm_3_skew_Transition,thm_4_mean_Gesture,thm_4_mean_Transition,thm_4_std_Gesture,thm_4_std_Transition,thm_4_min_Gesture,thm_4_min_Transition,thm_4_max_Gesture,thm_4_max_Transition,thm_4_skew_Gesture,thm_4_skew_Transition,thm_5_mean_Gesture,thm_5_mean_Transition,thm_5_std_Gesture,thm_5_std_Transition,thm_5_min_Gesture,thm_5_min_Transition,thm_5_max_Gesture,thm_5_max_Transition,thm_5_skew_Gesture,thm_5_skew_Transition,acc_mag_mean_Gesture,acc_mag_mean_Transition,acc_mag_std_Gesture,acc_mag_std_Transition,acc_mag_min_Gesture,acc_mag_min_Transition,acc_mag_max_Gesture,acc_mag_max_Transition,acc_mag_skew_Gesture,acc_mag_skew_Transition,rot_mag_mean_Gesture,rot_mag_mean_Transition,rot_mag_std_Gesture,rot_mag_std_Transition,rot_mag_min_Gesture,rot_mag_min_Transition,rot_mag_max_Gesture,rot_mag_max_Transition,rot_mag_skew_Gesture,rot_mag_skew_Transition,jerk_mean_Gesture,jerk_mean_Transition,jerk_std_Gesture,jerk_std_Transition,jerk_min_Gesture,jerk_min_Transition,jerk_max_Gesture,jerk_max_Transition,jerk_skew_Gesture,jerk_skew_Transition,thm_grad_1_2_mean_Gesture,thm_grad_1_2_mean_Transition,thm_grad_1_2_std_Gesture,thm_grad_1_2_std_Transition,thm_grad_1_2_min_Gesture,thm_grad_1_2_min_Transition,thm_grad_1_2_max_Gesture,thm_grad_1_2_max_Transition,thm_grad_1_2_skew_Gesture,thm_grad_1_2_skew_Transition,thm_grad_2_3_mean_Gesture,thm_grad_2_3_mean_Transition,thm_grad_2_3_std_Gesture,thm_grad_2_3_std_Transition,thm_grad_2_3_min_Gesture,thm_grad_2_3_min_Transition,thm_grad_2_3_max_Gesture,thm_grad_2_3_max_Transition,thm_grad_2_3_skew_Gesture,thm_grad_2_3_skew_Transition,thm_grad_3_4_mean_Gesture,thm_grad_3_4_mean_Transition,thm_grad_3_4_std_Gesture,thm_grad_3_4_std_Transition,thm_grad_3_4_min_Gesture,thm_grad_3_4_min_Transition,thm_grad_3_4_max_Gesture,thm_grad_3_4_max_Transition,thm_grad_3_4_skew_Gesture,thm_grad_3_4_skew_Transition,thm_grad_4_5_mean_Gesture,thm_grad_4_5_mean_Transition,thm_grad_4_5_std_Gesture,thm_grad_4_5_std_Transition,thm_grad_4_5_min_Gesture,thm_grad_4_5_min_Transition,thm_grad_4_5_max_Gesture,thm_grad_4_5_max_Transition,thm_grad_4_5_skew_Gesture,thm_grad_4_5_skew_Transition,tof_mean_all_pixels_mean_Gesture,tof_mean_all_pixels_mean_Transition,tof_mean_all_pixels_std_Gesture,tof_mean_all_pixels_std_Transition,tof_mean_all_pixels_min_Gesture,tof_mean_all_pixels_min_Transition,tof_mean_all_pixels_max_Gesture,tof_mean_all_pixels_max_Transition,gesture_encoded
sequence_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1,Unnamed: 73_level_1,Unnamed: 74_level_1,Unnamed: 75_level_1,Unnamed: 76_level_1,Unnamed: 77_level_1,Unnamed: 78_level_1,Unnamed: 79_level_1,Unnamed: 80_level_1,Unnamed: 81_level_1,Unnamed: 82_level_1,Unnamed: 83_level_1,Unnamed: 84_level_1,Unnamed: 85_level_1,Unnamed: 86_level_1,Unnamed: 87_level_1,Unnamed: 88_level_1,Unnamed: 89_level_1,Unnamed: 90_level_1,Unnamed: 91_level_1,Unnamed: 92_level_1,Unnamed: 93_level_1,Unnamed: 94_level_1,Unnamed: 95_level_1,Unnamed: 96_level_1,Unnamed: 97_level_1,Unnamed: 98_level_1,Unnamed: 99_level_1,Unnamed: 100_level_1,Unnamed: 101_level_1,Unnamed: 102_level_1,Unnamed: 103_level_1,Unnamed: 104_level_1,Unnamed: 105_level_1,Unnamed: 106_level_1,Unnamed: 107_level_1,Unnamed: 108_level_1,Unnamed: 109_level_1,Unnamed: 110_level_1,Unnamed: 111_level_1,Unnamed: 112_level_1,Unnamed: 113_level_1,Unnamed: 114_level_1,Unnamed: 115_level_1,Unnamed: 116_level_1,Unnamed: 117_level_1,Unnamed: 118_level_1,Unnamed: 119_level_1,Unnamed: 120_level_1,Unnamed: 121_level_1,Unnamed: 122_level_1,Unnamed: 123_level_1,Unnamed: 124_level_1,Unnamed: 125_level_1,Unnamed: 126_level_1,Unnamed: 127_level_1,Unnamed: 128_level_1,Unnamed: 129_level_1,Unnamed: 130_level_1,Unnamed: 131_level_1,Unnamed: 132_level_1,Unnamed: 133_level_1,Unnamed: 134_level_1,Unnamed: 135_level_1,Unnamed: 136_level_1,Unnamed: 137_level_1,Unnamed: 138_level_1,Unnamed: 139_level_1,Unnamed: 140_level_1,Unnamed: 141_level_1,Unnamed: 142_level_1,Unnamed: 143_level_1,Unnamed: 144_level_1,Unnamed: 145_level_1,Unnamed: 146_level_1,Unnamed: 147_level_1,Unnamed: 148_level_1,Unnamed: 149_level_1,Unnamed: 150_level_1,Unnamed: 151_level_1,Unnamed: 152_level_1,Unnamed: 153_level_1,Unnamed: 154_level_1,Unnamed: 155_level_1,Unnamed: 156_level_1,Unnamed: 157_level_1,Unnamed: 158_level_1,Unnamed: 159_level_1,Unnamed: 160_level_1,Unnamed: 161_level_1,Unnamed: 162_level_1,Unnamed: 163_level_1,Unnamed: 164_level_1,Unnamed: 165_level_1,Unnamed: 166_level_1,Unnamed: 167_level_1,Unnamed: 168_level_1,Unnamed: 169_level_1,Unnamed: 170_level_1,Unnamed: 171_level_1,Unnamed: 172_level_1,Unnamed: 173_level_1,Unnamed: 174_level_1,Unnamed: 175_level_1,Unnamed: 176_level_1,Unnamed: 177_level_1,Unnamed: 178_level_1,Unnamed: 179_level_1,Unnamed: 180_level_1,Unnamed: 181_level_1,Unnamed: 182_level_1,Unnamed: 183_level_1,Unnamed: 184_level_1,Unnamed: 185_level_1,Unnamed: 186_level_1,Unnamed: 187_level_1,Unnamed: 188_level_1,Unnamed: 189_level_1,Unnamed: 190_level_1,Unnamed: 191_level_1,Unnamed: 192_level_1,Unnamed: 193_level_1,Unnamed: 194_level_1,Unnamed: 195_level_1,Unnamed: 196_level_1,Unnamed: 197_level_1,Unnamed: 198_level_1,Unnamed: 199_level_1,Unnamed: 200_level_1,Unnamed: 201_level_1
SEQ_000007,SUBJ_059520,Cheek - pinch skin,6.875854,5.227969,0.779424,1.33437,5.492188,3.613281,9.015625,7.292969,0.671795,0.119323,5.640137,1.708125,0.582177,3.493182,4.683594,-2.019531,6.519531,6.214844,-0.177613,0.188431,4.259521,7.265156,1.247767,2.335664,1.09375,3.125,6.875,9.792969,0.123765,-0.430736,0.230537,0.305862,0.043031,0.073551,0.170898,0.134399,0.327942,0.379272,0.684978,-0.869169,-0.302607,-0.252927,0.058485,0.040159,-0.442871,-0.355164,-0.212952,-0.204163,-0.594216,-0.946573,-0.431194,-0.203823,0.033207,0.199606,-0.478027,-0.447327,-0.364746,0.005066,0.533235,-0.068043,-0.814034,-0.868665,0.023412,0.037632,-0.854187,-0.914856,-0.757935,-0.809753,0.157845,0.328752,28.572991,28.704366,0.381308,0.769848,27.782368,27.69651,29.215113,30.54373,-0.425742,0.732803,30.761249,28.049465,0.485551,3.299713,30.041212,24.558798,31.663404,32.010178,0.29183,0.058794,29.116922,27.885,0.197033,1.664297,28.557234,25.90749,29.396484,30.090014,-1.083624,-0.106085,29.078639,29.305039,0.148042,0.349623,28.819799,28.592863,29.346125,29.76148,0.169882,-0.345342,28.3593,27.443073,0.513923,0.982867,27.72184,26.047148,29.428299,28.815403,0.780503,-0.13629,9.977855,10.07773,0.320601,0.370631,9.170917,9.561136,11.067592,11.140053,0.993389,1.955505,1.000002,0.999998,1.7e-05,1.7e-05,0.999975,0.99996,1.000032,1.000028,0.164118,0.020982,0.000941,0.017588,0.498341,0.454179,-1.896675,-1.275742,0.775677,1.242384,-1.641391,0.070289,-2.188258,0.654901,0.282866,2.965536,-2.649433,-2.878344,-1.557152,4.04641,0.623726,-0.157089,1.644328,0.164465,0.380689,1.835484,1.034723,-2.266716,2.397999,2.614817,0.170443,0.189357,0.038283,-1.420039,0.137916,1.993306,-0.325338,-3.726547,0.280298,1.293926,-0.710368,-0.07975,0.719338,1.861966,0.612231,1.303043,-0.387695,0.18368,1.624285,3.70616,-0.521508,0.14187,71.43735,148.41433,16.201963,52.966776,43.805714,76.70283,104.703883,203.551402,1
SEQ_000008,SUBJ_020948,Forehead - pull hairline,3.449929,3.353906,0.736728,1.346681,1.734375,1.886719,4.949219,5.90625,-0.292417,0.278393,7.261364,3.472433,0.824113,3.638059,4.71875,-0.222656,8.667969,7.976562,-0.761132,0.29531,5.647609,7.462277,1.740919,2.752709,1.722656,3.175781,9.0,11.074219,-0.22578,-0.278592,0.229918,0.256294,0.026109,0.084777,0.185181,0.157593,0.30542,0.34198,0.725468,-0.163603,-0.091453,-0.14137,0.028465,0.052901,-0.160339,-0.263306,-0.050537,-0.079468,-0.668575,-0.427363,-0.432821,-0.257005,0.079182,0.223424,-0.495483,-0.508606,-0.093811,-0.031555,3.356636,-0.156264,-0.862236,-0.887312,0.021701,0.053133,-0.933899,-0.937805,-0.837036,-0.814697,-2.201159,0.211861,32.273003,28.758968,1.175071,2.644869,26.886362,25.985313,32.870808,32.105263,-3.955113,0.294443,32.016294,27.473724,1.988673,3.963966,24.039984,23.907709,33.100945,32.403992,-3.533768,0.32212,30.335959,28.089808,1.261277,3.822748,25.774496,24.414917,31.516409,32.316135,-2.445173,0.154153,31.252343,29.793223,0.365182,0.830051,29.622156,28.755495,31.613327,30.865301,-3.210709,0.273863,26.816013,24.889104,0.895062,0.264488,24.378687,24.181562,28.054575,25.145569,-0.859372,-1.282659,9.993138,10.043379,0.900278,0.324914,7.160294,9.485942,11.575777,11.359967,-0.550851,2.131329,0.999997,0.99999,1.5e-05,1.3e-05,0.999974,0.999966,1.000025,1.000022,0.352029,0.424091,-0.091381,0.00379,1.14321,0.503068,-2.329603,-1.567471,1.677727,1.346165,-0.775943,0.05142,0.25671,1.285243,0.858768,1.383707,-0.368416,-0.417637,3.846796,3.829411,3.315974,-0.139404,1.680335,-0.616083,1.018859,0.966461,-1.734512,-4.542854,2.870394,0.281979,-2.065082,-2.777034,-0.916384,-1.703415,1.089324,3.108427,-4.048893,-4.839048,0.461081,2.746029,-1.252089,0.181023,4.43633,4.904119,0.630283,0.793452,3.4093,4.011812,5.667048,5.865318,0.313367,0.150111,125.667271,174.377748,19.353654,51.496462,109.512,110.835681,214.094595,222.333333,6
SEQ_000013,SUBJ_040282,Cheek - pinch skin,-7.591276,-6.36464,1.143683,1.162073,-9.25,-7.714844,-5.226562,-3.347656,0.395099,1.138648,3.751562,0.513077,0.571292,2.984959,2.65625,-3.273438,4.683594,3.957031,0.011108,-0.231933,-5.49362,-6.818444,0.968642,1.382213,-6.808594,-10.945312,-3.515625,-5.277344,0.279689,-1.467629,0.474384,0.285024,0.044857,0.172503,0.404968,0.061157,0.540771,0.480835,-0.014976,-0.239875,0.254718,0.453091,0.025604,0.232688,0.221008,0.140991,0.304016,0.726501,0.5397,-0.009959,0.836882,0.75306,0.02281,0.119621,0.797424,0.580505,0.875366,0.881653,-0.071165,-0.405698,0.078009,-0.095658,0.023455,0.212688,0.043457,-0.406799,0.105957,0.129761,-0.259239,-0.384956,24.311727,24.797481,0.090862,0.572397,24.181389,24.223915,24.484457,25.634346,0.20386,0.289273,24.06567,24.76044,0.061293,0.787152,23.933413,24.035677,24.194889,26.175961,0.211255,0.62245,24.744966,25.084761,0.207938,0.283411,24.406981,24.587378,25.136721,25.512794,0.051444,0.049188,24.664993,25.277597,0.150427,0.723736,24.419798,24.5168,24.930002,26.452927,0.32837,0.355179,24.62835,24.870243,0.114652,0.693859,24.437956,24.16798,24.858009,26.051331,0.03021,0.685011,10.189387,9.922447,0.750752,0.685781,8.309325,9.007764,11.418236,12.072645,-0.568246,1.604378,0.999995,0.999994,1.9e-05,1.7e-05,0.999965,0.999968,1.000032,1.000034,0.015226,0.552284,-0.037628,0.028173,1.238615,0.942134,-2.902877,-2.750848,2.458852,1.69366,-0.024235,-0.751645,0.246057,0.037041,0.101063,0.312214,-0.001427,-0.541615,0.447063,0.65233,-0.310067,-0.160022,-0.679296,-0.324321,0.191245,0.550552,-1.04072,-0.954321,-0.34141,0.743389,-0.006536,0.518358,0.079973,-0.192837,0.112028,0.545821,-0.140873,-1.216579,0.272602,0.49992,-0.2971,-0.109936,0.036643,0.407354,0.243243,0.287609,-0.330923,-0.020517,0.476324,1.246277,0.279415,1.506933,217.69339,167.620456,8.296971,56.802998,200.689189,84.317269,234.555556,225.155556,1
SEQ_000016,SUBJ_052342,Write name on leg,5.29159,5.818142,0.940102,1.174814,3.859375,3.4375,9.378906,7.925781,2.323306,-0.45196,-4.74483,-3.984954,0.441674,0.494274,-5.71875,-4.875,-3.113281,-2.960938,1.184362,0.267835,-6.674517,1.260995,1.94215,6.764263,-8.078125,-7.734375,4.0625,8.355469,5.405182,-0.48415,0.353268,0.370924,0.032622,0.049561,0.305847,0.277527,0.459045,0.449829,2.043232,-0.369664,-0.826032,-0.604795,0.119764,0.229897,-0.882019,-0.893677,-0.409119,-0.384827,3.084322,-0.394076,-0.303709,-0.122,0.08716,0.16708,-0.356934,-0.368713,0.013916,0.035889,3.02602,-0.390096,-0.219314,-0.545485,0.176431,0.328741,-0.803955,-0.817688,-0.094604,-0.082275,-2.87202,0.443606,34.167663,28.483457,2.089392,3.578696,26.586309,25.413513,36.053188,34.632439,-2.514382,0.576387,34.378147,28.104436,2.68195,3.837888,25.087893,25.018881,36.705894,34.943771,-2.526682,0.627349,31.298491,26.829364,1.665449,3.108468,24.644175,24.128819,33.617542,31.820757,-2.648939,0.546458,34.69445,30.393524,1.245293,3.427506,29.947037,27.227589,35.665222,35.334206,-2.829674,0.536214,32.257148,29.101899,2.67743,3.228711,27.337736,26.312038,35.801083,34.356197,-0.070206,0.535426,9.96768,9.829435,0.655484,0.58353,7.889443,8.166823,12.711292,11.176389,1.344158,-1.161512,1.000002,1.000005,1.8e-05,2e-05,0.999966,0.999964,1.000035,1.000036,-0.282396,-0.2113,-0.052604,0.009934,1.089846,0.608893,-4.821848,-1.580443,3.258967,1.766627,-1.792633,0.313984,-0.210485,0.379021,0.870409,0.450899,-1.005894,-0.311333,2.326494,1.828487,1.332841,1.625716,3.079657,1.275072,2.050599,0.840471,-1.3825,-0.35693,5.200777,3.123014,-0.864482,0.356835,-3.395959,-3.56416,0.90236,0.616531,-5.302862,-4.877464,-1.664322,-1.622339,0.007491,0.561658,2.437302,1.291625,2.793459,0.63414,-1.181065,0.440891,6.390553,3.293949,0.074523,1.538646,14.492667,72.325497,13.559525,53.780371,8.534615,5.501859,80.681319,173.119048,17
SEQ_000018,SUBJ_032165,Forehead - pull hairline,6.235639,3.881445,0.241211,1.901476,5.71875,1.964844,6.832031,6.332031,0.194196,0.300721,6.068704,0.779492,0.211266,4.005993,5.566406,-3.164062,6.449219,6.71875,-0.492244,0.461042,4.73989,7.972266,0.455216,2.25171,4.148438,4.261719,6.140625,9.933594,1.217216,-0.93485,0.841427,0.889304,0.010841,0.039327,0.828247,0.831604,0.865784,0.925049,0.666296,-0.697839,0.272425,0.01604,0.014179,0.211085,0.237183,-0.184204,0.305542,0.305481,-0.044626,0.472734,-0.431091,-0.218021,0.016419,0.17774,-0.457458,-0.435425,-0.398071,-0.022644,0.255425,-0.23594,-0.176968,-0.283978,0.007988,0.085287,-0.193359,-0.367676,-0.163818,-0.159668,0.057725,0.557964,29.581021,27.75201,0.38058,1.088661,28.659307,26.533083,30.267483,29.387503,-0.32283,0.184509,30.334161,27.916263,0.42664,1.866541,29.461178,25.795074,31.035217,31.007061,-0.364101,0.157125,27.559514,26.205623,0.381086,1.028559,26.780855,25.12772,28.381996,28.468761,0.268104,0.728573,28.062466,27.466412,0.185449,0.473282,27.696522,26.827133,28.400864,28.368055,0.080755,0.398008,31.859936,29.576891,0.12234,1.394618,31.598164,28.282324,32.180752,31.793083,0.008231,0.623453,9.92132,10.129568,0.220616,0.354229,9.585841,9.520368,10.866563,11.021747,2.581684,0.795032,1.000003,1.000008,1.8e-05,1.8e-05,0.999966,0.99997,1.000041,1.000036,0.322292,-0.389912,0.020273,0.032846,0.302133,0.351351,-0.424762,-0.653291,1.222857,0.677167,1.903539,0.175134,-0.75314,-0.164253,0.164019,0.855464,-1.1005,-1.661255,-0.52652,0.765812,-0.605604,-0.266751,2.774647,1.71064,0.207158,1.06177,2.523352,0.611027,3.235443,3.619762,0.84374,0.324464,-0.502952,-1.260789,0.217262,0.587334,-0.975103,-1.882231,0.006832,0.100706,0.3183,0.868645,-3.797471,-2.110478,0.178475,1.25932,-4.114485,-3.99962,-3.438377,-0.220291,0.558554,-0.428609,118.69074,166.242279,4.706987,38.990861,108.169291,106.621514,126.880952,204.458333,6


In [27]:
FEATURE_WAVE = "Wave 2"
N_SPLITS = 5
SEED = 42

# --- Prepare data for CV ---
X = features_df.drop(columns=['subject', 'gesture', 'gesture_encoded'])
y = features_df['gesture_encoded']
groups = features_df['subject']

In [28]:
# --- Cross-Validation Loop ---
print("="*50)
print(f"Starting Multi-Model CPU-Based Training on {FEATURE_WAVE} Features")
print("="*50)

for model_name, config in model_configs.items():
    EXPERIMENT_NAME = f"{FEATURE_WAVE}-{model_name}-CPU"
    fold_scores = []
    
    print(f"\n---> Training Model: {model_name} <---")
    cv = StratifiedGroupKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
    
    for fold, (train_idx, val_idx) in enumerate(cv.split(X, y, groups)):
        print(f"--- Fold {fold+1}/{N_SPLITS} ---")
        X_train, y_train = X.iloc[train_idx], y.iloc[train_idx]
        X_val, y_val = X.iloc[val_idx], y.iloc[val_idx]
        model = config['model'](**config['params'])
        if model_name == 'LightGBM':
            model.fit(X_train, y_train, eval_set=[(X_val, y_val)], callbacks=[lgb.early_stopping(100, verbose=False)])
        elif model_name == 'CatBoost':
            model.fit(X_train, y_train, eval_set=[(X_val, y_val)], early_stopping_rounds=100, verbose=0)
        elif model_name == 'XGBoost':
            model.fit(X_train, y_train, verbose=False)
        val_preds = model.predict_proba(X_val)
        fold_score = balanced_accuracy_score(y_val, np.argmax(val_preds, axis=1))
        fold_scores.append(fold_score)
        print(f"Fold {fold+1} Balanced Accuracy: {fold_score:.5f}")
    
    mean_cv_score = np.mean(fold_scores)
    print(f"\n--- CV Summary for {model_name} ---")
    print(f"Mean Balanced Accuracy: {mean_cv_score:.5f}")
    print(f"Std Dev: {np.std(fold_scores):.5f}\n")
    
    tracker.log_experiment(
        experiment_name=EXPERIMENT_NAME, model_name=model_name, feature_wave=FEATURE_WAVE,
        cv_score=mean_cv_score, params=config['params'],
        notes="Added IMU (mag, jerk) and Thermopile (gradient) features before phase aggregation."
    )

Starting Multi-Model CPU-Based Training on Wave 2 Features

---> Training Model: LightGBM <---
--- Fold 1/5 ---
Fold 1 Balanced Accuracy: 0.69788
--- Fold 2/5 ---
Fold 2 Balanced Accuracy: 0.63267
--- Fold 3/5 ---
Fold 3 Balanced Accuracy: 0.66461
--- Fold 4/5 ---
Fold 4 Balanced Accuracy: 0.59172
--- Fold 5/5 ---
Fold 5 Balanced Accuracy: 0.60961

--- CV Summary for LightGBM ---
Mean Balanced Accuracy: 0.63930
Std Dev: 0.03811

Experiment 'Wave 2-LightGBM-CPU' logged to ~/code/kaggle/kaggle-cmi-detect-behavior/experiment_log.csv


0,1
cv_score,▁

0,1
cv_score,0.6393
feature_wave,Wave 2
model_name,LightGBM


Experiment 'Wave 2-LightGBM-CPU' logged to W&B

---> Training Model: XGBoost <---
--- Fold 1/5 ---
Fold 1 Balanced Accuracy: 0.70817
--- Fold 2/5 ---
Fold 2 Balanced Accuracy: 0.64598
--- Fold 3/5 ---
Fold 3 Balanced Accuracy: 0.66839
--- Fold 4/5 ---
Fold 4 Balanced Accuracy: 0.61762
--- Fold 5/5 ---
Fold 5 Balanced Accuracy: 0.62309

--- CV Summary for XGBoost ---
Mean Balanced Accuracy: 0.65265
Std Dev: 0.03309

Experiment 'Wave 2-XGBoost-CPU' logged to ~/code/kaggle/kaggle-cmi-detect-behavior/experiment_log.csv


0,1
cv_score,▁

0,1
cv_score,0.65265
feature_wave,Wave 2
model_name,XGBoost


Experiment 'Wave 2-XGBoost-CPU' logged to W&B

---> Training Model: CatBoost <---
--- Fold 1/5 ---
Fold 1 Balanced Accuracy: 0.73970
--- Fold 2/5 ---
Fold 2 Balanced Accuracy: 0.65958
--- Fold 3/5 ---
Fold 3 Balanced Accuracy: 0.65665
--- Fold 4/5 ---
Fold 4 Balanced Accuracy: 0.62008
--- Fold 5/5 ---
Fold 5 Balanced Accuracy: 0.61642

--- CV Summary for CatBoost ---
Mean Balanced Accuracy: 0.65849
Std Dev: 0.04437

Experiment 'Wave 2-CatBoost-CPU' logged to ~/code/kaggle/kaggle-cmi-detect-behavior/experiment_log.csv


0,1
cv_score,▁

0,1
cv_score,0.65849
feature_wave,Wave 2
model_name,CatBoost


Experiment 'Wave 2-CatBoost-CPU' logged to W&B
