# Feature Pruning and Hyper Parameter Tuning

**`HYPOTHESIS`** By removing "dead-weight" features, the model will be faster and simpler. The removal of noise may lead to a slight increase in CV score and a potential increase in stability - measured by std. deviation

## Setup

In [1]:
#### TEMPORARY
import sys
sys.path.append('/home/bac/code/kaggle/kaggle-cmi-detect-behavior/')

In [2]:
import pandas as pd
import numpy as np
import os
import catboost as cat
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import f1_score, make_scorer
from sklearn.inspection import permutation_importance
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns
import optuna
import warnings

warnings.filterwarnings('ignore')
pd.set_option('display.max_columns', None)
sns.set_style('whitegrid')

# --- Pathing and Experiment Tracking Setup ---
from src.tracking import ExperimentTracker
from src.config import DATA_PATH, PROJECT_PATH, USE_WANDB, WANDB_PROJECT, WANDB_ENTITY

tracker = ExperimentTracker(
    project_path=os.path.expanduser(PROJECT_PATH),
    use_wandb=USE_WANDB,
    wandb_project_name=WANDB_PROJECT,
    wandb_entity=WANDB_ENTITY
)

SEED = 42

In [3]:
def display_all(df):
    with pd.option_context("display.max_rows", 500, "display.max_columns", 1000):
        display(df)

In [4]:
# Load the raw data and create helper maps
train_sensor = pd.read_csv(os.path.join(os.path.expanduser(DATA_PATH), 'train.csv'))
train_demos = pd.read_csv(os.path.join(os.path.expanduser(DATA_PATH), 'train_demographics.csv'))
train_df = pd.merge(train_sensor, train_demos, on='subject', how='left')
metadata = train_df[['gesture', 'sequence_type']].drop_duplicates()
gesture_to_seq_type_map = metadata.set_index('gesture')['sequence_type'].to_dict()
gesture_map = {label: i for i, label in enumerate(metadata['gesture'].unique())}
inv_gesture_map = {i: label for label, i in gesture_map.items()}

In [5]:
# Load features feather file and permutation importance csv
features_w4_df = pd.read_feather("outputs/features_4n.fea")
perm_importance_df = pd.read_csv("outputs/permutation_importance_W4n.csv")

features_w4_df.head()

Unnamed: 0,sequence_id,subject,gesture,adult_child,age,sex,handedness,height_cm,shoulder_to_wrist_cm,elbow_to_wrist_cm,acc_x_mean_Gesture,acc_x_mean_Transition,acc_x_std_Gesture,acc_x_std_Transition,acc_x_min_Gesture,acc_x_min_Transition,acc_x_max_Gesture,acc_x_max_Transition,acc_x_skew_Gesture,acc_x_skew_Transition,acc_y_mean_Gesture,acc_y_mean_Transition,acc_y_std_Gesture,acc_y_std_Transition,acc_y_min_Gesture,acc_y_min_Transition,acc_y_max_Gesture,acc_y_max_Transition,acc_y_skew_Gesture,acc_y_skew_Transition,acc_z_mean_Gesture,acc_z_mean_Transition,acc_z_std_Gesture,acc_z_std_Transition,acc_z_min_Gesture,acc_z_min_Transition,acc_z_max_Gesture,acc_z_max_Transition,acc_z_skew_Gesture,acc_z_skew_Transition,rot_w_mean_Gesture,rot_w_mean_Transition,rot_w_std_Gesture,rot_w_std_Transition,rot_w_min_Gesture,rot_w_min_Transition,rot_w_max_Gesture,rot_w_max_Transition,rot_w_skew_Gesture,rot_w_skew_Transition,rot_x_mean_Gesture,rot_x_mean_Transition,rot_x_std_Gesture,rot_x_std_Transition,rot_x_min_Gesture,rot_x_min_Transition,rot_x_max_Gesture,rot_x_max_Transition,rot_x_skew_Gesture,rot_x_skew_Transition,rot_y_mean_Gesture,rot_y_mean_Transition,rot_y_std_Gesture,rot_y_std_Transition,rot_y_min_Gesture,rot_y_min_Transition,rot_y_max_Gesture,rot_y_max_Transition,rot_y_skew_Gesture,rot_y_skew_Transition,rot_z_mean_Gesture,rot_z_mean_Transition,rot_z_std_Gesture,rot_z_std_Transition,rot_z_min_Gesture,rot_z_min_Transition,rot_z_max_Gesture,rot_z_max_Transition,rot_z_skew_Gesture,rot_z_skew_Transition,thm_1_mean_Gesture,thm_1_mean_Transition,thm_1_std_Gesture,thm_1_std_Transition,thm_1_min_Gesture,thm_1_min_Transition,thm_1_max_Gesture,thm_1_max_Transition,thm_1_skew_Gesture,thm_1_skew_Transition,thm_2_mean_Gesture,thm_2_mean_Transition,thm_2_std_Gesture,thm_2_std_Transition,thm_2_min_Gesture,thm_2_min_Transition,thm_2_max_Gesture,thm_2_max_Transition,thm_2_skew_Gesture,thm_2_skew_Transition,thm_3_mean_Gesture,thm_3_mean_Transition,thm_3_std_Gesture,thm_3_std_Transition,thm_3_min_Gesture,thm_3_min_Transition,thm_3_max_Gesture,thm_3_max_Transition,thm_3_skew_Gesture,thm_3_skew_Transition,thm_4_mean_Gesture,thm_4_mean_Transition,thm_4_std_Gesture,thm_4_std_Transition,thm_4_min_Gesture,thm_4_min_Transition,thm_4_max_Gesture,thm_4_max_Transition,thm_4_skew_Gesture,thm_4_skew_Transition,thm_5_mean_Gesture,thm_5_mean_Transition,thm_5_std_Gesture,thm_5_std_Transition,thm_5_min_Gesture,thm_5_min_Transition,thm_5_max_Gesture,thm_5_max_Transition,thm_5_skew_Gesture,thm_5_skew_Transition,acc_mag_mean_Gesture,acc_mag_mean_Transition,acc_mag_std_Gesture,acc_mag_std_Transition,acc_mag_min_Gesture,acc_mag_min_Transition,acc_mag_max_Gesture,acc_mag_max_Transition,acc_mag_skew_Gesture,acc_mag_skew_Transition,rot_mag_mean_Gesture,rot_mag_mean_Transition,rot_mag_std_Gesture,rot_mag_std_Transition,rot_mag_min_Gesture,rot_mag_min_Transition,rot_mag_max_Gesture,rot_mag_max_Transition,rot_mag_skew_Gesture,rot_mag_skew_Transition,jerk_mean_Gesture,jerk_mean_Transition,jerk_std_Gesture,jerk_std_Transition,jerk_min_Gesture,jerk_min_Transition,jerk_max_Gesture,jerk_max_Transition,jerk_skew_Gesture,jerk_skew_Transition,thm_grad_1_2_mean_Gesture,thm_grad_1_2_mean_Transition,thm_grad_1_2_std_Gesture,thm_grad_1_2_std_Transition,thm_grad_1_2_min_Gesture,thm_grad_1_2_min_Transition,thm_grad_1_2_max_Gesture,thm_grad_1_2_max_Transition,thm_grad_1_2_skew_Gesture,thm_grad_1_2_skew_Transition,thm_grad_2_3_mean_Gesture,thm_grad_2_3_mean_Transition,thm_grad_2_3_std_Gesture,thm_grad_2_3_std_Transition,thm_grad_2_3_min_Gesture,thm_grad_2_3_min_Transition,thm_grad_2_3_max_Gesture,thm_grad_2_3_max_Transition,thm_grad_2_3_skew_Gesture,thm_grad_2_3_skew_Transition,thm_grad_3_4_mean_Gesture,thm_grad_3_4_mean_Transition,thm_grad_3_4_std_Gesture,thm_grad_3_4_std_Transition,thm_grad_3_4_min_Gesture,thm_grad_3_4_min_Transition,thm_grad_3_4_max_Gesture,thm_grad_3_4_max_Transition,thm_grad_3_4_skew_Gesture,thm_grad_3_4_skew_Transition,thm_grad_4_5_mean_Gesture,thm_grad_4_5_mean_Transition,thm_grad_4_5_std_Gesture,thm_grad_4_5_std_Transition,thm_grad_4_5_min_Gesture,thm_grad_4_5_min_Transition,thm_grad_4_5_max_Gesture,thm_grad_4_5_max_Transition,thm_grad_4_5_skew_Gesture,thm_grad_4_5_skew_Transition,tof_invalid_pct_mean_Gesture,tof_invalid_pct_mean_Transition,tof_invalid_pct_std_Gesture,tof_invalid_pct_std_Transition,tof_invalid_pct_min_Gesture,tof_invalid_pct_min_Transition,tof_invalid_pct_max_Gesture,tof_invalid_pct_max_Transition,tof_invalid_pct_skew_Gesture,tof_invalid_pct_skew_Transition,tof_pca_0_mean_Gesture,tof_pca_0_mean_Transition,tof_pca_0_std_Gesture,tof_pca_0_std_Transition,tof_pca_0_min_Gesture,tof_pca_0_min_Transition,tof_pca_0_max_Gesture,tof_pca_0_max_Transition,tof_pca_0_skew_Gesture,tof_pca_0_skew_Transition,tof_pca_1_mean_Gesture,tof_pca_1_mean_Transition,tof_pca_1_std_Gesture,tof_pca_1_std_Transition,tof_pca_1_min_Gesture,tof_pca_1_min_Transition,tof_pca_1_max_Gesture,tof_pca_1_max_Transition,tof_pca_1_skew_Gesture,tof_pca_1_skew_Transition,tof_pca_2_mean_Gesture,tof_pca_2_mean_Transition,tof_pca_2_std_Gesture,tof_pca_2_std_Transition,tof_pca_2_min_Gesture,tof_pca_2_min_Transition,tof_pca_2_max_Gesture,tof_pca_2_max_Transition,tof_pca_2_skew_Gesture,tof_pca_2_skew_Transition,tof_pca_3_mean_Gesture,tof_pca_3_mean_Transition,tof_pca_3_std_Gesture,tof_pca_3_std_Transition,tof_pca_3_min_Gesture,tof_pca_3_min_Transition,tof_pca_3_max_Gesture,tof_pca_3_max_Transition,tof_pca_3_skew_Gesture,tof_pca_3_skew_Transition,tof_pca_4_mean_Gesture,tof_pca_4_mean_Transition,tof_pca_4_std_Gesture,tof_pca_4_std_Transition,tof_pca_4_min_Gesture,tof_pca_4_min_Transition,tof_pca_4_max_Gesture,tof_pca_4_max_Transition,tof_pca_4_skew_Gesture,tof_pca_4_skew_Transition,tof_pca_5_mean_Gesture,tof_pca_5_mean_Transition,tof_pca_5_std_Gesture,tof_pca_5_std_Transition,tof_pca_5_min_Gesture,tof_pca_5_min_Transition,tof_pca_5_max_Gesture,tof_pca_5_max_Transition,tof_pca_5_skew_Gesture,tof_pca_5_skew_Transition,tof_pca_6_mean_Gesture,tof_pca_6_mean_Transition,tof_pca_6_std_Gesture,tof_pca_6_std_Transition,tof_pca_6_min_Gesture,tof_pca_6_min_Transition,tof_pca_6_max_Gesture,tof_pca_6_max_Transition,tof_pca_6_skew_Gesture,tof_pca_6_skew_Transition,tof_pca_7_mean_Gesture,tof_pca_7_mean_Transition,tof_pca_7_std_Gesture,tof_pca_7_std_Transition,tof_pca_7_min_Gesture,tof_pca_7_min_Transition,tof_pca_7_max_Gesture,tof_pca_7_max_Transition,tof_pca_7_skew_Gesture,tof_pca_7_skew_Transition,tof_pca_8_mean_Gesture,tof_pca_8_mean_Transition,tof_pca_8_std_Gesture,tof_pca_8_std_Transition,tof_pca_8_min_Gesture,tof_pca_8_min_Transition,tof_pca_8_max_Gesture,tof_pca_8_max_Transition,tof_pca_8_skew_Gesture,tof_pca_8_skew_Transition,tof_pca_9_mean_Gesture,tof_pca_9_mean_Transition,tof_pca_9_std_Gesture,tof_pca_9_std_Transition,tof_pca_9_min_Gesture,tof_pca_9_min_Transition,tof_pca_9_max_Gesture,tof_pca_9_max_Transition,tof_pca_9_skew_Gesture,tof_pca_9_skew_Transition,acc_mag_mean_Gesture_div_age,acc_mag_mean_Gesture_mul_age,acc_mag_mean_Gesture_div_height_cm,acc_mag_mean_Gesture_mul_height_cm,acc_mag_mean_Gesture_div_shoulder_to_wrist_cm,acc_mag_mean_Gesture_mul_shoulder_to_wrist_cm,acc_mag_std_Gesture_div_age,acc_mag_std_Gesture_mul_age,acc_mag_std_Gesture_div_height_cm,acc_mag_std_Gesture_mul_height_cm,acc_mag_std_Gesture_div_shoulder_to_wrist_cm,acc_mag_std_Gesture_mul_shoulder_to_wrist_cm,jerk_mean_Gesture_div_age,jerk_mean_Gesture_mul_age,jerk_mean_Gesture_div_height_cm,jerk_mean_Gesture_mul_height_cm,jerk_mean_Gesture_div_shoulder_to_wrist_cm,jerk_mean_Gesture_mul_shoulder_to_wrist_cm,jerk_std_Gesture_div_age,jerk_std_Gesture_mul_age,jerk_std_Gesture_div_height_cm,jerk_std_Gesture_mul_height_cm,jerk_std_Gesture_div_shoulder_to_wrist_cm,jerk_std_Gesture_mul_shoulder_to_wrist_cm,tof_pca_0_mean_Gesture_div_age,tof_pca_0_mean_Gesture_mul_age,tof_pca_0_mean_Gesture_div_height_cm,tof_pca_0_mean_Gesture_mul_height_cm,tof_pca_0_mean_Gesture_div_shoulder_to_wrist_cm,tof_pca_0_mean_Gesture_mul_shoulder_to_wrist_cm,tof_invalid_pct_mean_Gesture_div_age,tof_invalid_pct_mean_Gesture_mul_age,tof_invalid_pct_mean_Gesture_div_height_cm,tof_invalid_pct_mean_Gesture_mul_height_cm,tof_invalid_pct_mean_Gesture_div_shoulder_to_wrist_cm,tof_invalid_pct_mean_Gesture_mul_shoulder_to_wrist_cm,gesture_encoded
0,SEQ_000007,SUBJ_059520,Cheek - pinch skin,0,12,1,1,163.0,52,24.0,6.875854,5.227969,0.779424,1.33437,5.492188,3.613281,9.015625,7.292969,0.671795,0.119323,5.640137,1.708125,0.582177,3.493182,4.683594,-2.019531,6.519531,6.214844,-0.177613,0.188431,4.259521,7.265156,1.247767,2.335664,1.09375,3.125,6.875,9.792969,0.123765,-0.430736,0.230537,0.305862,0.043031,0.073551,0.170898,0.134399,0.327942,0.379272,0.684978,-0.869169,-0.302607,-0.252927,0.058485,0.040159,-0.442871,-0.355164,-0.212952,-0.204163,-0.594216,-0.946573,-0.431194,-0.203823,0.033207,0.199606,-0.478027,-0.447327,-0.364746,0.005066,0.533235,-0.068043,-0.814034,-0.868665,0.023412,0.037632,-0.854187,-0.914856,-0.757935,-0.809753,0.157845,0.328752,28.572991,28.704366,0.381308,0.769848,27.782368,27.69651,29.215113,30.54373,-0.425742,0.732803,30.761249,28.049465,0.485551,3.299713,30.041212,24.558798,31.663404,32.010178,0.29183,0.058794,29.116922,27.885,0.197033,1.664297,28.557234,25.90749,29.396484,30.090014,-1.083624,-0.106085,29.078639,29.305039,0.148042,0.349623,28.819799,28.592863,29.346125,29.76148,0.169882,-0.345342,28.3593,27.443073,0.513923,0.982867,27.72184,26.047148,29.428299,28.815403,0.780503,-0.13629,9.977855,10.07773,0.320601,0.370631,9.170917,9.561136,11.067592,11.140053,0.993389,1.955505,1.000002,0.999998,1.7e-05,1.7e-05,0.999975,0.99996,1.000032,1.000028,0.164118,0.020982,0.000941,0.017588,0.498341,0.454179,-1.896675,-1.275742,0.775677,1.242384,-1.641391,0.070289,-2.188258,0.654901,0.282866,2.965536,-2.649433,-2.878344,-1.557152,4.04641,0.623726,-0.157089,1.644328,0.164465,0.380689,1.835484,1.034723,-2.266716,2.397999,2.614817,0.170443,0.189357,0.038283,-1.420039,0.137916,1.993306,-0.325338,-3.726547,0.280298,1.293926,-0.710368,-0.07975,0.719338,1.861966,0.612231,1.303043,-0.387695,0.18368,1.624285,3.70616,-0.521508,0.14187,0.37168,0.47975,0.047432,0.167239,0.28125,0.23125,0.453125,0.684375,-0.256594,0.091825,-22.607758,543.027036,241.834416,323.473134,-384.95863,53.941335,436.71827,1200.546304,0.560053,-0.046242,52.458076,162.377648,60.708098,211.132096,-79.900672,-153.878832,153.915369,430.489704,-0.450427,0.045137,-144.839784,315.982696,74.499527,461.571761,-249.215313,-217.515079,24.56842,861.657139,0.835621,0.05015,99.929143,-147.750075,49.515569,255.271966,10.248764,-460.565167,198.986008,145.085479,-0.135408,-0.046852,236.865309,200.322996,54.023843,70.797483,128.589022,12.469457,306.145225,283.695181,-0.805197,-0.780005,-98.249872,-150.78804,74.96971,217.958795,-234.949524,-471.2518,70.733944,264.743052,0.708392,0.373321,-32.102086,10.531423,64.981824,75.152808,-215.542898,-198.907051,82.282031,179.908747,-1.236574,-0.401348,-6.357957,-77.120343,46.66483,97.363443,-120.826869,-377.375667,74.373146,57.334558,-0.233751,-1.61597,-55.193642,-19.757805,97.921693,64.83769,-204.706496,-186.57534,117.340559,106.709203,0.450027,-0.171208,-47.608941,-121.93904,99.625825,50.821,-124.599026,-250.524568,219.235456,-30.747318,1.715353,-1.073033,0.831488,119.734261,0.061214,1626.390378,0.191882,518.848464,0.026717,3.847215,0.001967,52.258001,0.006165,16.671264,7.8e-05,0.01129,6e-06,0.153358,1.8e-05,0.048924,0.041528,5.980093,0.003057,81.229598,0.009583,25.913737,-1.88398,-271.293096,-0.138698,-3685.064553,-0.434765,-1175.603416,0.030973,4.460156,0.00228,60.583789,0.007148,19.327344,0
1,SEQ_000008,SUBJ_020948,Forehead - pull hairline,1,24,1,1,173.0,49,26.0,3.449929,3.353906,0.736728,1.346681,1.734375,1.886719,4.949219,5.90625,-0.292417,0.278393,7.261364,3.472433,0.824113,3.638059,4.71875,-0.222656,8.667969,7.976562,-0.761132,0.29531,5.647609,7.462277,1.740919,2.752709,1.722656,3.175781,9.0,11.074219,-0.22578,-0.278592,0.229918,0.256294,0.026109,0.084777,0.185181,0.157593,0.30542,0.34198,0.725468,-0.163603,-0.091453,-0.14137,0.028465,0.052901,-0.160339,-0.263306,-0.050537,-0.079468,-0.668575,-0.427363,-0.432821,-0.257005,0.079182,0.223424,-0.495483,-0.508606,-0.093811,-0.031555,3.356636,-0.156264,-0.862236,-0.887312,0.021701,0.053133,-0.933899,-0.937805,-0.837036,-0.814697,-2.201159,0.211861,32.273003,28.758968,1.175071,2.644869,26.886362,25.985313,32.870808,32.105263,-3.955113,0.294443,32.016294,27.473724,1.988673,3.963966,24.039984,23.907709,33.100945,32.403992,-3.533768,0.32212,30.335959,28.089808,1.261277,3.822748,25.774496,24.414917,31.516409,32.316135,-2.445173,0.154153,31.252343,29.793223,0.365182,0.830051,29.622156,28.755495,31.613327,30.865301,-3.210709,0.273863,26.816013,24.889104,0.895062,0.264488,24.378687,24.181562,28.054575,25.145569,-0.859372,-1.282659,9.993138,10.043379,0.900278,0.324914,7.160294,9.485942,11.575777,11.359967,-0.550851,2.131329,0.999997,0.99999,1.5e-05,1.3e-05,0.999974,0.999966,1.000025,1.000022,0.352029,0.424091,-0.091381,0.00379,1.14321,0.503068,-2.329603,-1.567471,1.677727,1.346165,-0.775943,0.05142,0.25671,1.285243,0.858768,1.383707,-0.368416,-0.417637,3.846796,3.829411,3.315974,-0.139404,1.680335,-0.616083,1.018859,0.966461,-1.734512,-4.542854,2.870394,0.281979,-2.065082,-2.777034,-0.916384,-1.703415,1.089324,3.108427,-4.048893,-4.839048,0.461081,2.746029,-1.252089,0.181023,4.43633,4.904119,0.630283,0.793452,3.4093,4.011812,5.667048,5.865318,0.313367,0.150111,0.269508,0.636518,0.095645,0.274283,0.203125,0.309375,0.76875,0.928125,4.723081,-0.192509,962.369093,183.828937,155.828401,510.54826,297.661368,-470.607221,1245.514052,802.297193,-2.245366,0.035907,-231.601257,-185.054577,135.856006,433.613544,-478.276002,-927.49554,25.564866,303.060156,-0.046413,-0.180489,237.804746,371.908914,81.717635,138.408061,159.760266,49.852018,648.391508,577.209804,4.116503,-0.160388,-292.886214,-238.515622,94.946252,137.854018,-593.972239,-583.730319,-98.357729,98.661902,-0.959556,-0.009018,-193.017296,15.207375,79.185715,145.420677,-307.384447,-182.255156,76.576567,189.779336,2.056715,-0.275894,112.318938,64.094566,68.660825,103.992361,-167.533609,-218.902711,217.304143,339.792856,-1.989233,-0.058425,-137.281659,6.725439,66.178269,45.827349,-244.018094,-147.436902,81.995597,130.567031,1.026228,-0.143645,13.40464,-8.424366,78.238608,99.656513,-282.18293,-359.071366,218.451954,113.687761,-1.10574,-2.500468,-104.822739,-5.642818,78.536712,180.001712,-366.934511,-362.977023,-28.428955,230.174779,-1.937548,-0.177006,-81.096235,-143.535153,65.276093,59.719988,-325.369799,-353.58314,38.881442,110.48803,-1.484493,1.052192,0.416381,239.835322,0.057764,1728.812944,0.203942,489.663782,0.037512,21.606667,0.005204,155.748057,0.018373,44.113611,-0.003808,-2.193132,-0.000528,-15.808827,-0.001865,-4.477645,0.047634,27.437037,0.006608,197.775309,0.023331,56.017284,40.098711,23096.858235,5.562827,166489.853114,19.640185,47156.085564,0.011229,6.468182,0.001558,46.624811,0.0055,13.205871,1
2,SEQ_000013,SUBJ_040282,Cheek - pinch skin,0,12,1,1,157.0,44,26.0,-7.591276,-6.36464,1.143683,1.162073,-9.25,-7.714844,-5.226562,-3.347656,0.395099,1.138648,3.751562,0.513077,0.571292,2.984959,2.65625,-3.273438,4.683594,3.957031,0.011108,-0.231933,-5.49362,-6.818444,0.968642,1.382213,-6.808594,-10.945312,-3.515625,-5.277344,0.279689,-1.467629,0.474384,0.285024,0.044857,0.172503,0.404968,0.061157,0.540771,0.480835,-0.014976,-0.239875,0.254718,0.453091,0.025604,0.232688,0.221008,0.140991,0.304016,0.726501,0.5397,-0.009959,0.836882,0.75306,0.02281,0.119621,0.797424,0.580505,0.875366,0.881653,-0.071165,-0.405698,0.078009,-0.095658,0.023455,0.212688,0.043457,-0.406799,0.105957,0.129761,-0.259239,-0.384956,24.311727,24.797481,0.090862,0.572397,24.181389,24.223915,24.484457,25.634346,0.20386,0.289273,24.06567,24.76044,0.061293,0.787152,23.933413,24.035677,24.194889,26.175961,0.211255,0.62245,24.744966,25.084761,0.207938,0.283411,24.406981,24.587378,25.136721,25.512794,0.051444,0.049188,24.664993,25.277597,0.150427,0.723736,24.419798,24.5168,24.930002,26.452927,0.32837,0.355179,24.62835,24.870243,0.114652,0.693859,24.437956,24.16798,24.858009,26.051331,0.03021,0.685011,10.189387,9.922447,0.750752,0.685781,8.309325,9.007764,11.418236,12.072645,-0.568246,1.604378,0.999995,0.999994,1.9e-05,1.7e-05,0.999965,0.999968,1.000032,1.000034,0.015226,0.552284,-0.037628,0.028173,1.238615,0.942134,-2.902877,-2.750848,2.458852,1.69366,-0.024235,-0.751645,0.246057,0.037041,0.101063,0.312214,-0.001427,-0.541615,0.447063,0.65233,-0.310067,-0.160022,-0.679296,-0.324321,0.191245,0.550552,-1.04072,-0.954321,-0.34141,0.743389,-0.006536,0.518358,0.079973,-0.192837,0.112028,0.545821,-0.140873,-1.216579,0.272602,0.49992,-0.2971,-0.109936,0.036643,0.407354,0.243243,0.287609,-0.330923,-0.020517,0.476324,1.246277,0.279415,1.506933,0.887292,0.551495,0.079218,0.33337,0.76875,0.15,0.975,0.915625,-0.328373,-0.071184,-386.201317,246.614082,309.114367,657.013661,-714.557738,-524.018765,70.079935,1426.695253,0.354448,0.30464,-526.798589,-302.76867,312.123331,506.363391,-967.185402,-1155.603005,-160.581616,442.584322,-0.178211,0.081631,224.246225,199.922423,168.782378,187.448466,14.959682,-234.387395,487.573283,562.834494,0.169694,-0.475579,364.715763,268.958663,244.660771,234.408394,46.501682,-327.304973,705.686603,521.115531,-0.021498,-1.023586,-170.624795,24.254217,61.778892,276.568305,-273.038502,-388.40407,-91.879904,384.009092,0.019431,0.267777,-177.697891,-37.070773,55.263405,182.298385,-275.05142,-292.181855,-112.228152,249.59464,-0.177806,0.461914,-189.03962,-122.635751,94.796652,187.40294,-376.110181,-360.211048,-48.224443,270.170687,-0.460681,0.673649,101.306468,157.300616,75.6102,137.620268,-84.982882,-290.506158,172.811201,368.919687,-1.005489,-1.64594,142.088005,72.737102,56.772362,164.109345,27.482126,-267.615522,295.615107,274.991692,0.338005,-0.830583,-75.708389,-109.362959,78.44354,111.799101,-201.581575,-340.622004,124.489261,203.674609,0.68423,1.299255,0.849116,122.272644,0.064901,1599.733758,0.231577,448.333028,0.062563,9.009029,0.004782,117.868131,0.017063,33.033107,-0.003136,-0.45153,-0.00024,-5.907524,-0.000855,-1.655612,0.103218,14.863381,0.007889,194.462564,0.02815,54.499062,-32.18344,-4634.415799,-2.459881,-60633.606706,-8.777302,-16992.85793,0.073941,10.6475,0.005652,139.304792,0.020166,39.040833,0
3,SEQ_000016,SUBJ_052342,Write name on leg,0,13,0,1,171.0,54,26.0,5.29159,5.818142,0.940102,1.174814,3.859375,3.4375,9.378906,7.925781,2.323306,-0.45196,-4.74483,-3.984954,0.441674,0.494274,-5.71875,-4.875,-3.113281,-2.960938,1.184362,0.267835,-6.674517,1.260995,1.94215,6.764263,-8.078125,-7.734375,4.0625,8.355469,5.405182,-0.48415,0.353268,0.370924,0.032622,0.049561,0.305847,0.277527,0.459045,0.449829,2.043232,-0.369664,-0.826032,-0.604795,0.119764,0.229897,-0.882019,-0.893677,-0.409119,-0.384827,3.084322,-0.394076,-0.303709,-0.122,0.08716,0.16708,-0.356934,-0.368713,0.013916,0.035889,3.02602,-0.390096,-0.219314,-0.545485,0.176431,0.328741,-0.803955,-0.817688,-0.094604,-0.082275,-2.87202,0.443606,34.167663,28.483457,2.089392,3.578696,26.586309,25.413513,36.053188,34.632439,-2.514382,0.576387,34.378147,28.104436,2.68195,3.837888,25.087893,25.018881,36.705894,34.943771,-2.526682,0.627349,31.298491,26.829364,1.665449,3.108468,24.644175,24.128819,33.617542,31.820757,-2.648939,0.546458,34.69445,30.393524,1.245293,3.427506,29.947037,27.227589,35.665222,35.334206,-2.829674,0.536214,32.257148,29.101899,2.67743,3.228711,27.337736,26.312038,35.801083,34.356197,-0.070206,0.535426,9.96768,9.829435,0.655484,0.58353,7.889443,8.166823,12.711292,11.176389,1.344158,-1.161512,1.000002,1.000005,1.8e-05,2e-05,0.999966,0.999964,1.000035,1.000036,-0.282396,-0.2113,-0.052604,0.009934,1.089846,0.608893,-4.821848,-1.580443,3.258967,1.766627,-1.792633,0.313984,-0.210485,0.379021,0.870409,0.450899,-1.005894,-0.311333,2.326494,1.828487,1.332841,1.625716,3.079657,1.275072,2.050599,0.840471,-1.3825,-0.35693,5.200777,3.123014,-0.864482,0.356835,-3.395959,-3.56416,0.90236,0.616531,-5.302862,-4.877464,-1.664322,-1.622339,0.007491,0.561658,2.437302,1.291625,2.793459,0.63414,-1.181065,0.440891,6.390553,3.293949,0.074523,1.538646,0.220496,0.625694,0.145764,0.362688,0.1,0.084375,0.8,0.934375,3.182656,-0.479829,-623.968067,-629.761604,48.879485,115.016883,-661.976619,-715.205258,-364.384672,-238.026687,4.789194,2.536521,-51.275098,-10.851367,82.692556,114.01907,-158.552044,-341.478744,297.256535,302.012546,2.633366,0.407993,-28.862416,41.004199,24.223706,123.244726,-78.969902,-215.499255,47.354899,401.74066,0.944819,0.535582,-52.47755,-68.911543,47.129406,98.214484,-122.688353,-131.618959,37.510928,386.850096,0.379817,4.179699,-43.633572,12.462743,50.792239,97.217736,-75.630351,-148.160156,164.065169,354.214902,3.773492,1.82195,-61.243723,34.834453,54.107752,124.278141,-342.528901,-197.538824,-30.110056,319.483373,-4.48623,-0.045719,33.63212,-16.014506,33.813308,67.892181,-61.105438,-130.815428,98.148785,179.976177,-0.305826,0.871564,77.198954,35.358729,50.127417,67.478862,-33.110951,-90.324049,163.034908,272.935952,-0.302733,1.712086,7.738039,21.00965,54.565516,54.142887,-80.613427,-82.49996,151.204978,229.358479,0.424938,2.103149,34.180665,29.104782,26.436484,75.238266,-21.581167,-223.381799,80.785423,179.564859,-0.067516,-1.246691,0.766745,129.579839,0.058291,1704.473272,0.184587,538.254718,0.050422,8.52129,0.003833,112.087733,0.012139,35.396126,-0.004046,-0.683848,-0.000308,-8.995236,-0.000974,-2.840601,0.083834,14.168002,0.006373,186.363715,0.020182,58.851699,-47.99754,-8111.584877,-3.648936,-106698.539531,-11.554964,-33694.275641,0.016961,2.866452,0.001289,37.704871,0.004083,11.906801,2
4,SEQ_000018,SUBJ_032165,Forehead - pull hairline,0,13,0,1,165.0,52,23.0,6.235639,3.881445,0.241211,1.901476,5.71875,1.964844,6.832031,6.332031,0.194196,0.300721,6.068704,0.779492,0.211266,4.005993,5.566406,-3.164062,6.449219,6.71875,-0.492244,0.461042,4.73989,7.972266,0.455216,2.25171,4.148438,4.261719,6.140625,9.933594,1.217216,-0.93485,0.841427,0.889304,0.010841,0.039327,0.828247,0.831604,0.865784,0.925049,0.666296,-0.697839,0.272425,0.01604,0.014179,0.211085,0.237183,-0.184204,0.305542,0.305481,-0.044626,0.472734,-0.431091,-0.218021,0.016419,0.17774,-0.457458,-0.435425,-0.398071,-0.022644,0.255425,-0.23594,-0.176968,-0.283978,0.007988,0.085287,-0.193359,-0.367676,-0.163818,-0.159668,0.057725,0.557964,29.581021,27.75201,0.38058,1.088661,28.659307,26.533083,30.267483,29.387503,-0.32283,0.184509,30.334161,27.916263,0.42664,1.866541,29.461178,25.795074,31.035217,31.007061,-0.364101,0.157125,27.559514,26.205623,0.381086,1.028559,26.780855,25.12772,28.381996,28.468761,0.268104,0.728573,28.062466,27.466412,0.185449,0.473282,27.696522,26.827133,28.400864,28.368055,0.080755,0.398008,31.859936,29.576891,0.12234,1.394618,31.598164,28.282324,32.180752,31.793083,0.008231,0.623453,9.92132,10.129568,0.220616,0.354229,9.585841,9.520368,10.866563,11.021747,2.581684,0.795032,1.000003,1.000008,1.8e-05,1.8e-05,0.999966,0.99997,1.000041,1.000036,0.322292,-0.389912,0.020273,0.032846,0.302133,0.351351,-0.424762,-0.653291,1.222857,0.677167,1.903539,0.175134,-0.75314,-0.164253,0.164019,0.855464,-1.1005,-1.661255,-0.52652,0.765812,-0.605604,-0.266751,2.774647,1.71064,0.207158,1.06177,2.523352,0.611027,3.235443,3.619762,0.84374,0.324464,-0.502952,-1.260789,0.217262,0.587334,-0.975103,-1.882231,0.006832,0.100706,0.3183,0.868645,-3.797471,-2.110478,0.178475,1.25932,-4.114485,-3.99962,-3.438377,-0.220291,0.558554,-0.428609,0.180147,0.49375,0.023723,0.265933,0.14375,0.171875,0.24375,0.79375,0.418988,-0.087291,990.020173,517.438202,92.703562,575.555875,739.003538,-165.433173,1105.490979,1632.12621,-1.116725,0.387316,553.200444,841.751227,99.393094,193.953525,344.437643,571.895682,695.115813,1121.369488,-1.003365,-0.115696,-30.436667,140.474868,36.484407,257.801236,-110.974911,-271.441783,28.989863,416.407085,-0.741782,-0.301716,-114.709447,158.019817,28.65947,187.701046,-199.498411,-115.672779,-58.957353,384.856007,-0.634319,-0.269588,0.494619,-31.077793,37.827817,179.767297,-166.03473,-410.953221,46.667918,214.96925,-2.636246,-0.992038,92.246991,51.972631,36.740444,118.346495,13.236341,-235.584705,209.490245,222.787131,0.468399,-0.724521,94.042621,106.421945,20.167244,92.958487,48.685766,-114.351806,129.430181,216.580579,-0.360698,-0.875963,41.809525,-127.553279,44.662556,189.58466,-40.760707,-385.962255,114.614866,132.774889,0.064958,-0.071539,-114.035334,-15.492355,21.394785,199.522887,-179.892425,-412.467169,-82.271356,204.24719,-1.005034,-0.539752,-79.17229,-28.196564,68.103104,136.651776,-166.508632,-300.047714,55.050714,142.576322,0.866289,-0.506372,0.763178,128.977162,0.060129,1637.017827,0.190795,515.908648,0.01697,2.86801,0.001337,36.40167,0.004243,11.472041,0.001559,0.263544,0.000123,3.344983,0.00039,1.054176,0.023241,3.927727,0.001831,49.851923,0.00581,15.710909,76.155392,12870.262243,6.000122,163353.328474,19.038849,51481.048974,0.013857,2.341912,0.001092,29.724265,0.003464,9.367647,1


In [6]:
perm_importance_df.head()

Unnamed: 0,feature,importance_mean
0,acc_z_std_Gesture,0.011634
1,acc_y_std_Gesture,0.008877
2,acc_x_std_Gesture,0.004547
3,thm_2_max_Gesture,0.003645
4,thm_2_mean_Gesture,0.003379


In [7]:
features_w4_df.shape, perm_importance_df.shape

((8151, 347), (343, 2))

## Pipeline

In [8]:
def average_f1_score(y_true_encoded, y_pred_proba):
    """
    Calculates the official competition F1 score.
    
    Args:
        y_true_encoded: True labels, integer encoded.
        y_pred_proba: Predicted probabilities from the model.
    """
    # Get predicted labels by finding the class with the highest probability
    y_pred_encoded = np.argmax(y_pred_proba, axis=1)
    
    # Map integer-encoded labels back to the string representations
    y_true_str = pd.Series(y_true_encoded).map(inv_gesture_map)
    y_pred_str = pd.Series(y_pred_encoded).map(inv_gesture_map)
    
    # Binary F1
    y_true_binary = y_true_str.map(gesture_to_seq_type_map)
    y_pred_binary = y_true_str.map(gesture_to_seq_type_map)
    binary_f1 = f1_score(y_true_binary, y_pred_binary, pos_label='Target', average='binary')
    
    # Macro F1 (collaped non-target class)
    def collapse_non_target(gesture):
        return 'non_target' if gesture_to_seq_type_map[gesture] == 'Non-Target' else gesture
    
    y_true_collapsed = y_true_str.apply(collapse_non_target)
    y_pred_collapsed = y_pred_str.apply(collapse_non_target)
    macro_f1 = f1_score(y_true_collapsed, y_pred_collapsed, average='macro')
    
    # Final score = average of the two components
    return (binary_f1 + macro_f1) / 2

## Pruning + Training

In [10]:
# Identify features with positive importance
keep_features = perm_importance_df[perm_importance_df['importance_mean'] > 0]['feature'].tolist()

# Target and grouping columns need to be added to the features list
keep_cols = ['sequence_id', 'gesture', 'gesture_encoded'] + keep_features

# Create new, pruned dataframe
pruned_features = features_w4_df[keep_cols]

print(f"Original number of features: {len(features_w4_df.columns) - 3}")
print(f"Number of features after pruning: {len(pruned_features.columns) - 3}")
print("Pruned features DataFrame created successfully.")

Original number of features: 344
Number of features after pruning: 76
Pruned features DataFrame created successfully.


In [11]:
# Configuration 
FEATURE_WAVE = "Wave 4 (Pruned)"
MODEL_NAME = "CatBoost"
EXPERIMENT_NAME = f"{FEATURE_WAVE}-{MODEL_NAME}-CPU"
N_SPLITS = 5
SEED = 42

# Model Parameters
params = {
    'iterations': 1000, 'learning_rate': 0.05, 'depth': 6,
    'loss_function': 'MultiClass', 'eval_metric': 'MultiClass',
    'random_seed': SEED, 'verbose': 0
}

## Prepare data for CV ---
X = pruned_features.drop(columns=['sequence_id', 'gesture', 'gesture_encoded'])
y = pruned_features['gesture_encoded']

# Get subject from the un-pruned dataframe for CV
groups = features_w4_df['subject']

In [12]:
# Cross-Validation Loop
print("\n" + "="*50)
print(f"Running CV for {MODEL_NAME} on {FEATURE_WAVE} Features")
print("="*50)

fold_scores = []
cv = StratifiedGroupKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
for fold, (train_idx, val_idx) in enumerate(cv.split(X, y, groups)):
    print(f"--- Fold {fold+1}/{N_SPLITS} ---")
    
    X_train, y_train = X.iloc[train_idx], y.iloc[train_idx]
    X_val, y_val = X.iloc[val_idx], y.iloc[val_idx]
    
    model = cat.CatBoostClassifier(**params)
    model.fit(X_train, y_train, eval_set=[(X_val, y_val)], early_stopping_rounds=100)
    
    val_preds_proba = model.predict_proba(X_val)
    fold_score = average_f1_score(y_val, val_preds_proba)
    fold_scores.append(fold_score)
    print(f"Fold {fold+1} Competition F1 Score: {fold_score:.5f}")

# --- Final Score and Logging ---
mean_cv_score = np.mean(fold_scores)
print(f"\n--- CV Summary for {MODEL_NAME} on {FEATURE_WAVE} Features ---")
print(f"Mean Competition F1 Score: {mean_cv_score:.5f}")
print(f"Std Dev: {np.std(fold_scores):.5f}\n")

tracker.log_experiment(
    experiment_name=EXPERIMENT_NAME, model_name=MODEL_NAME, feature_wave=FEATURE_WAVE,
    cv_score=mean_cv_score, params=params,
    notes="Re-trained Waved 4 (fixed) CatBoost on a pruned feature set (Permutation Importance > 0)."
)


Running CV for CatBoost on Wave 4 (Pruned) Features
--- Fold 1/5 ---
Fold 1 Competition F1 Score: 0.82885
--- Fold 2/5 ---
Fold 2 Competition F1 Score: 0.77407
--- Fold 3/5 ---
Fold 3 Competition F1 Score: 0.79894
--- Fold 4/5 ---
Fold 4 Competition F1 Score: 0.77499
--- Fold 5/5 ---
Fold 5 Competition F1 Score: 0.78693

--- CV Summary for CatBoost on Wave 4 (Pruned) Features ---
Mean Competition F1 Score: 0.79275
Std Dev: 0.02020

Experiment 'Wave 4 (Pruned)-CatBoost-CPU' logged to /home/bac/code/kaggle/kaggle-cmi-detect-behavior/experiment_log.csv


[34m[1mwandb[0m: Currently logged in as: [33mb-a-chaudhry[0m ([33mb-a-chaudhry-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


0,1
cv_score,▁

0,1
cv_score,0.79275
feature_wave,Wave 4 (Pruned)
model_name,CatBoost


Experiment 'Wave 4 (Pruned)-CatBoost-CPU' logged to W&B


**`ANALYSIS OF PRUNING`**

- The Full Set is Superior: Despite a 77% reduction in feature count, the pruned model's performance decreased by a significant -0.01097.

- The "Weak Features as Regularizers" Hypothesis:  The 268 features with low or zero permutation importance were not useless. In aggregate, they acted as a form of regularization, helping the model generalize better and preventing it from overfitting to the more dominant signals. Removing them made the model simpler but ultimately less accurate on this specific task.

- Final Verdict on Feature Engineering:  The evidence is conclusive: the full Wave 4 - Fixed feature set is our champion. We can now proceed with very high confidence that we have the best possible set of "ingredients" for our model.

## Hyper-Parameter Tuning

We will be using the [`Optuna`](https://optuna.org/#key_features) library to automate the search for the best model configuration.

In [17]:
# Define the Optuna objective function - single trial

def objective(trial, X, y, groups, scorer):
    """
    Defines a single trial for Optuna. It suggests hyperparameters,
    trains a CatBoost model using them, and returns the mean CV F1 score.
    """
    params = {
        'iterations': 2000,
        'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.08, log=True),
        'depth': trial.suggest_int('depth', 5, 8),
        'l2_leaf_reg': trial.suggest_float('l2_leaf_reg', 2.0, 15.0, log=True),
        'bagging_temperature': trial.suggest_float('bagging_temperature', 0.0, 1.0),
        'random_strength': trial.suggest_float('random_strength', 1.0, 10.0, log=True),
        'loss_function': 'MultiClass',
        'eval_metric': 'MultiClass',
        'random_seed': SEED,
        'verbose': 0
    }

    cv = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=SEED)
    fold_scores = []
    
    for fold, (train_idx, val_idx) in enumerate(cv.split(X, y, groups)):
        X_train, y_train = X.iloc[train_idx], y.iloc[train_idx]
        X_val, y_val = X.iloc[val_idx], y.iloc[val_idx]
        
        model = cat.CatBoostClassifier(**params)
        model.fit(X_train, y_train, eval_set=[(X_val, y_val)], early_stopping_rounds=150, verbose=0)
        
        val_preds_proba = model.predict_proba(X_val)
        
        fold_score = scorer(y_val, val_preds_proba)
        fold_scores.append(fold_score)
        
        trial.report(fold_score, fold)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return np.mean(fold_scores)

In [18]:
# --- Prepare data for tuning ---
X = features_w4_df.drop(columns=['sequence_id', 'subject', 'gesture', 'gesture_encoded'])
y = features_w4_df['gesture_encoded']
groups = features_w4_df['subject']

In [None]:
# Create the study object with direction set to 'maximize' for our F1 score
study = optuna.create_study(
    direction='maximize',
    sampler=optuna.samplers.TPESampler(seed=SEED),
    pruner=optuna.pruners.MedianPruner() # Review other options - FLAG
)

scorer = lambda y_true, y_pred: average_f1_score(y_true, y_pred)
objective_with_scorer = lambda trial: objective(trial, X, y, groups, scorer)

# Run the optimization with the corrected objective.
study.optimize(objective_with_scorer, n_trials=50, show_progress_bar=True)

# --- Print and Log the Best Results ---
print("\n" + "="*50)
print("Optuna Hyperparameter Tuning Complete")
print(f"Number of finished trials: {len(study.trials)}")

# Select the best trial
best_trial = study.best_trial
print(f"Best trial's F1 Score: {best_trial.value:.5f}")
print("Best trial's hyperparameters:")
for key, value in best_trial.params.items():
    print(f"    {key}: {value}")

# Log the best trial as our new champion experiment
tracker.log_experiment(
    experiment_name="CatBoost-Tuned-Wave4-Features",
    model_name="CatBoost (Tuned)",
    feature_wave="Wave 4 (Fixed)",
    cv_score=best_trial.value,
    params=best_trial.params,
    notes="Best parameters found after 50 Optuna trials on the Wave 4 feature set."
)

**`POORLY OPTIMIZED TUNING SETUP`**

- In its current form, the Optuna setup would take approximately 84+ hours to run trials on various hyper-parameter tests. This is totally unacceptable.
- This part of the modeling process has been moved to notebook `05A-hyper-parameter.ipynb` on Colab.
- Instead of re-splitting the data 5 times for every trial, we will perform one single split at the very beginning using a large training set (~80% of the data) and a single, fixed hold-out validation set (~20%).
- Inside each Optuna trial, the model will train once on the large training set and evaluate once on the hold-out set. This immediately gives us a 5x speedup.
- Running these optimizations on a CPU is extremely time consuming, so we will leverage Catboost's internal GPU support while leveraging GPU compute in Colab to achieve another 2-5x in speedups.

Based on Colab's hyper-parameter tuning, these are the results based on the fixed hold out approach:

```

Optuna Hyperparameter Tuning Complete
Number of finished trials: 75
Best trial's Hold-Out F1 Score: 0.84691
Best trial's hyperparameters:
    learning_rate: 0.0676968816733559
    depth: 5
    l2_leaf_reg: 2.115556787694446
    bagging_temperature: 0.195243584852945
    random_strength: 1.0061621160663667

```

## Hyper-Parameter Optimized (Fixed-Hold Out) Run 

In [10]:
FEATURE_WAVE = "Wave 4 (Fixed)"
MODEL_NAME = "CatBoost (Final-Tuned)"
EXPERIMENT_NAME = f"Final-Validation-{MODEL_NAME}-CPU"
N_SPLITS = 5
SEED = 42

# Using best performing parameters from the Optuna trials

best_params = {
    'iterations': 2000, # Increased iterations to allow early stopping to work
    'learning_rate': 0.0676968816733559,
    'depth': 5,
    'l2_leaf_reg': 2.115556787694446,
    'bagging_temperature': 0.195243584852945,
    'random_strength': 1.0061621160663667,
    'loss_function': 'MultiClass',
    'eval_metric': 'MultiClass',
    'random_seed': SEED,
    'task_type' : 'GPU',
    'verbose': 0
}

# Prepare data for CV 
X = features_w4_df.drop(columns=['sequence_id', 'subject', 'gesture', 'gesture_encoded'])
y = features_w4_df['gesture_encoded']
groups = features_w4_df['subject']

# Cross-Validation Loop 
print("\n" + "="*50)
print(f"Running Final Validation for {MODEL_NAME} on {FEATURE_WAVE} Features")
print("="*50)

fold_scores = []
cv = StratifiedGroupKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
for fold, (train_idx, val_idx) in enumerate(cv.split(X, y, groups)):
    print(f"--- Fold {fold+1}/{N_SPLITS} ---")
    
    X_train, y_train = X.iloc[train_idx], y.iloc[train_idx]
    X_val, y_val = X.iloc[val_idx], y.iloc[val_idx]
    
    # Instantiate the model with our best parameters
    model = cat.CatBoostClassifier(**best_params)
    
    # Train with early stopping
    model.fit(X_train, y_train, eval_set=[(X_val, y_val)], early_stopping_rounds=150)
    
    # Predict and score
    val_preds_proba = model.predict_proba(X_val)
    fold_score = average_f1_score(y_val, val_preds_proba)
    fold_scores.append(fold_score)
    print(f"Fold {fold+1} Competition F1 Score: {fold_score:.5f}")

# --- Final Score and Logging ---
mean_cv_score = np.mean(fold_scores)
print(f"\n--- Final CV Summary for Tuned {MODEL_NAME} ---")
print(f"Mean Competition F1 Score: {mean_cv_score:.5f}")
print(f"Std Dev: {np.std(fold_scores):.5f}\n")

tracker.log_experiment(
    experiment_name=EXPERIMENT_NAME, 
    model_name=MODEL_NAME, 
    feature_wave=FEATURE_WAVE,
    cv_score=mean_cv_score, 
    params=best_params,
    notes="Validation of best Optuna parameters using 5-fold StratifiedGroupKFold CV."
)



Running Final Validation for CatBoost (Final-Tuned) on Wave 4 (Fixed) Features
--- Fold 1/5 ---
Fold 1 Competition F1 Score: 0.84691
--- Fold 2/5 ---
Fold 2 Competition F1 Score: 0.77755
--- Fold 3/5 ---
Fold 3 Competition F1 Score: 0.80802
--- Fold 4/5 ---
Fold 4 Competition F1 Score: 0.79580
--- Fold 5/5 ---
Fold 5 Competition F1 Score: 0.80530

--- Final CV Summary for Tuned CatBoost (Final-Tuned) ---
Mean Competition F1 Score: 0.80671
Std Dev: 0.02276

Experiment 'Final-Validation-CatBoost (Final-Tuned)-CPU' logged to /home/bac/code/kaggle/kaggle-cmi-detect-behavior/experiment_log.csv


[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Network error (ConnectionError), entering retry loop.


0,1
cv_score,▁

0,1
cv_score,0.80671
feature_wave,Wave 4 (Fixed)
model_name,CatBoost (Final-Tune...


Experiment 'Final-Validation-CatBoost (Final-Tuned)-CPU' logged to W&B


## Train Final Model on 100% Training Data and Validate With the Test Set

The results from an intensive 50 trial Stratified K-Fold [Optuna tuning run on Colab](https://github.com/bachaudhry/kaggle-cmi-detect-behavior/blob/main/notebooks/05B-hyper-parameter-tuning.ipynb) indicate that the combination of hyper-parameters below result in an improvement of **+.00783** mean CV F1 score compared to the Wave 4 baseline:


    Final Robust Optuna Tuning Complete
    Number of finished trials: 50
    Best trial's Mean CV F1 Score: 0.81155
    Best trial's hyperparameters:
        
        learning_rate: 0.06442938671410496
        depth: 6
        l2_leaf_reg: 4.3544165839313385
        bagging_temperature: 0.13574803789441542
        random_strength: 7.435358758357957

We will now train the final model (of this leg of the development process) on the entirety of the training data while prepping to make our submission to Kaggle.

In [9]:
FEATURE_WAVE = "Wave 4 (Fixed)"
MODEL_NAME = "CatBoost Wave-4 Tuned (Final)"
EXPERIMENT_NAME = f"Final-Model-Training"

# final features are already loaded

# Best Hyperparameters from Final Optuna Study 
best_params = {
    'iterations': 2000, 
    'learning_rate': 0.06442938671410496,
    'depth': 6,
    'l2_leaf_reg': 4.3544165839313385,
    'bagging_temperature': 0.13574803789441542,
    'random_strength': 7.435358758357957,
    'loss_function': 'MultiClass',
    'eval_metric': 'MultiClass',
    'random_seed': SEED,
    'task_type' : 'GPU', # Speed up training using GPU
    'verbose': 100
}

In [20]:
X_train_full = features_w4_df.drop(columns=['sequence_id', 'subject', 'gesture', 'gesture_encoded'])
y_train_full = features_w4_df['gesture_encoded']

print("\n" + "="*50)
print(f"Training Final Model on ALL Training Data")
print(f"Features: {FEATURE_WAVE}")
print("="*50)

final_model = cat.CatBoostClassifier(**best_params)

final_model.fit(X_train_full, y_train_full, plot=True)

# Save the model artifact
MODELS_PATH = os.path.join(os.path.expanduser(PROJECT_PATH), 'models')
if not os.path.exists(MODELS_PATH):
    os.makedirs(MODELS_PATH)
    
model_save_path = os.path.join(MODELS_PATH, 'catboost_final_model.cbm')
final_model.save_model(model_save_path)
    
print(f"Model saved to: {model_save_path}")

# Log training event

tracker.log_experiment(
    experiment_name=EXPERIMENT_NAME, 
    model_name=MODEL_NAME, 
    feature_wave=FEATURE_WAVE,
    cv_score=0.81155, # Log our best validated CV score for this model
    params=best_params,
    notes="Final model trained on complete training data. Saved to file for submission."
)


Training Final Model on ALL Training Data
Features: Wave 4 (Fixed)


MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))

0:	learn: 2.7471977	total: 55.2ms	remaining: 1m 50s
100:	learn: 1.2250646	total: 1.34s	remaining: 25.2s
200:	learn: 0.9171906	total: 2.78s	remaining: 24.9s
300:	learn: 0.7106409	total: 4.61s	remaining: 26s
400:	learn: 0.5723588	total: 6.44s	remaining: 25.7s
500:	learn: 0.4767190	total: 8.27s	remaining: 24.7s
600:	learn: 0.4044465	total: 10.1s	remaining: 23.5s
700:	learn: 0.3465841	total: 11.9s	remaining: 22.1s
800:	learn: 0.2989821	total: 13.8s	remaining: 20.6s
900:	learn: 0.2602891	total: 15.6s	remaining: 19s
1000:	learn: 0.2288687	total: 17.4s	remaining: 17.3s
1100:	learn: 0.2023758	total: 19.2s	remaining: 15.6s
1200:	learn: 0.1810780	total: 20.9s	remaining: 13.9s
1300:	learn: 0.1615314	total: 22.8s	remaining: 12.2s
1400:	learn: 0.1447627	total: 24.6s	remaining: 10.5s
1500:	learn: 0.1304741	total: 26.3s	remaining: 8.76s
1600:	learn: 0.1182468	total: 28.1s	remaining: 7.01s
1700:	learn: 0.1073759	total: 29.9s	remaining: 5.26s
1800:	learn: 0.0981724	total: 31.8s	remaining: 3.51s
1900:	l

[34m[1mwandb[0m: Currently logged in as: [33mb-a-chaudhry[0m ([33mb-a-chaudhry-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


0,1
cv_score,▁

0,1
cv_score,0.81155
feature_wave,Wave 4 (Fixed)
model_name,CatBoost Wave-4 Tune...


Experiment 'Final-Model-Training' logged to W&B
