# In [None]:
# ====================================================================================================
# CMI BFRB Detection - Two-Stage Classification System (All-in-One)
# Score Target: 0.730+ (Binary F1: 0.94+, Macro F1: 0.52+)
# ====================================================================================================

import os
import sys
import json
import pickle
import joblib
import warnings
import gc
from pathlib import Path
from datetime import datetime
from typing import Tuple, Dict, List, Optional, Any
import numpy as np
import pandas as pd
import polars as pl
from scipy import stats, signal
from scipy.spatial.transform import Rotation as R
from scipy.fft import fft, fftfreq
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score, confusion_matrix, classification_report
import lightgbm as lgb
import xgboost as xgb

# Try to import SMOTE (optional for Kaggle).
# imbalanced-learn is an optional dependency: SMOTE_AVAILABLE records whether
# the import succeeded and is read by CONFIG['use_smote'] further down.
# NOTE(review): SMOTETomek is imported but not referenced in the visible code.
try:
    from imblearn.over_sampling import SMOTE
    from imblearn.combine import SMOTETomek
    SMOTE_AVAILABLE = True
except ImportError:
    SMOTE_AVAILABLE = False
    print("SMOTE not available, will use class weights instead")

# Suppress every Python warning for the remainder of the run.
warnings.filterwarnings('ignore')

# Progress marker for the notebook output.
print('✓ All imports loaded successfully')

# ====================================================================================================
# CONFIGURATION
# ====================================================================================================

# Global pipeline settings.  The three *_lgbm entries are passed verbatim as
# keyword arguments to lgb.LGBMClassifier by TwoStageClassifier.
CONFIG = dict(
    data_path='cmi-detect-behavior-with-sensor-data/',
    n_folds=5,
    random_state=42,
    sample_rate=20,  # Hz
    use_smote=SMOTE_AVAILABLE,
    two_stage=True,  # Enable two-stage classification

    # Stage 1: Binary classification (BFRB vs Non-BFRB)
    binary_lgbm=dict(
        objective='binary',
        metric='binary_logloss',
        boosting_type='gbdt',
        num_leaves=31,
        learning_rate=0.05,
        feature_fraction=0.8,
        bagging_fraction=0.8,
        bagging_freq=5,
        n_estimators=800,
        max_depth=8,
        min_child_samples=20,
        reg_alpha=0.1,
        reg_lambda=0.1,
        random_state=42,
        n_jobs=-1,
        verbosity=-1,
    ),

    # Stage 2A: BFRB multi-class (8 classes)
    bfrb_lgbm=dict(
        objective='multiclass',
        num_class=8,
        metric='multi_logloss',
        boosting_type='gbdt',
        num_leaves=25,
        learning_rate=0.03,
        feature_fraction=0.7,
        bagging_fraction=0.7,
        bagging_freq=5,
        n_estimators=1000,
        max_depth=6,
        min_child_samples=30,
        reg_alpha=0.2,
        reg_lambda=0.2,
        class_weight='balanced',
        random_state=42,
        n_jobs=-1,
        verbosity=-1,
    ),

    # Stage 2B: Non-BFRB multi-class (10 classes)
    non_bfrb_lgbm=dict(
        objective='multiclass',
        num_class=10,
        metric='multi_logloss',
        boosting_type='gbdt',
        num_leaves=31,
        learning_rate=0.05,
        feature_fraction=0.8,
        bagging_fraction=0.8,
        bagging_freq=5,
        n_estimators=600,
        max_depth=7,
        min_child_samples=20,
        random_state=42,
        n_jobs=-1,
        verbosity=-1,
    ),
)

# Gesture label encoding.  The position in the tuple is the integer label:
# the first eight entries (0-7) are the BFRB gestures and the remaining ten
# (8-17) the non-BFRB gestures, which is the split the classifier relies on.
GESTURE_MAPPER = {
    gesture: label
    for label, gesture in enumerate((
        'Above ear - pull hair',
        'Cheek - pinch skin',
        'Eyebrow - pull hair',
        'Eyelash - pull hair',
        'Forehead - pull hairline',
        'Forehead - scratch',
        'Neck - pinch skin',
        'Neck - scratch',
        'Drink from bottle/cup',
        'Feel around in tray and pull out an object',
        'Glasses on/off',
        'Pinch knee/leg skin',
        'Pull air toward your face',
        'Scratch knee/leg skin',
        'Text on phone',
        'Wave hello',
        'Write name in air',
        'Write name on leg',
    ))
}
# Inverse lookup: integer label -> gesture name.
REVERSE_GESTURE_MAPPER = dict(zip(GESTURE_MAPPER.values(), GESTURE_MAPPER.keys()))

print(f'✓ Configuration loaded ({len(GESTURE_MAPPER)} gesture classes)')

# ====================================================================================================
# WORLD ACCELERATION TRANSFORMATION
# ====================================================================================================

def compute_world_acceleration(acc: np.ndarray, rot: np.ndarray,
                               sample_rate: Optional[float] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Convert device coordinates to world coordinates using quaternions.

    Args:
        acc: Array of shape (n, 3) with device-frame acceleration.
        rot: Array of shape (n, 4) with quaternions in (w, x, y, z) order.
        sample_rate: Sampling frequency in Hz used for the gravity low-pass
            filter.  Defaults to ``CONFIG['sample_rate']``.

    Returns:
        ``(world_acc, linear_acc)``, both of shape (n, 3).  ``world_acc`` is
        the rotated acceleration; ``linear_acc`` is ``world_acc`` minus the
        low-pass (gravity) estimate.
    """
    acc = np.asarray(acc, dtype=float)
    rot = np.asarray(rot, dtype=float)
    if sample_rate is None:
        sample_rate = CONFIG['sample_rate']

    n_samples = len(acc)
    if n_samples == 0:
        # Nothing to rotate or filter.
        empty = np.zeros((0, 3))
        return empty, empty.copy()

    # Convert quaternion format (w,x,y,z) to scipy format (x,y,z,w).
    # Fancy indexing returns a copy, so the caller's array is not modified.
    rot_scipy = rot[:, [1, 2, 3, 0]]

    # Rotation.from_quat rejects zero-norm quaternions; treat zero-norm (and
    # NaN) rows as "no rotation" instead of failing the whole sequence.
    norms = np.linalg.norm(rot_scipy, axis=1)
    degenerate = ~(norms > 0)
    if degenerate.any():
        rot_scipy[degenerate] = [0.0, 0.0, 0.0, 1.0]

    world_acc = R.from_quat(rot_scipy).apply(acc)

    # Estimate gravity with a zero-phase low-pass filter.  filtfilt requires
    # the signal to be longer than padlen, so shrink the padding for short
    # sequences; for longer ones this equals scipy's default padlen.
    b, a = signal.butter(3, 0.3, 'low', fs=sample_rate)
    padlen = min(3 * max(len(a), len(b)), n_samples - 1)
    gravity = signal.filtfilt(b, a, world_acc, axis=0, padlen=padlen)

    # Linear acceleration = total - gravity
    linear_acc = world_acc - gravity

    return world_acc, linear_acc

# ====================================================================================================
# FEATURE EXTRACTION
# ====================================================================================================

def extract_statistical_features(data: np.ndarray, prefix: str) -> Dict[str, float]:
    """Extract comprehensive statistical features from time series data.

    The returned dictionary always contains the same set of keys so that
    feature frames built from sequences of different lengths line up:
    segment features are NaN when the series has fewer than 9 samples, and an
    empty series yields NaN statistics with zero counts.

    Args:
        data: 1-D numeric series.
        prefix: String prepended (with an underscore) to every feature name.

    Returns:
        Mapping of ``f'{prefix}_<feature>'`` to its value.
    """
    data = np.asarray(data, dtype=float)
    if data.size == 0:
        # A single NaN placeholder makes every statistic NaN and every
        # count 0 without special-casing each feature below.
        data = np.array([np.nan])
    n = len(data)

    features = {}

    # Basic statistics
    features[f'{prefix}_mean'] = np.mean(data)
    features[f'{prefix}_std'] = np.std(data)
    features[f'{prefix}_var'] = np.var(data)
    features[f'{prefix}_min'] = np.min(data)
    features[f'{prefix}_max'] = np.max(data)
    features[f'{prefix}_range'] = features[f'{prefix}_max'] - features[f'{prefix}_min']

    # Percentiles
    for p in [10, 25, 50, 75, 90]:
        features[f'{prefix}_p{p}'] = np.percentile(data, p)
    features[f'{prefix}_iqr'] = features[f'{prefix}_p75'] - features[f'{prefix}_p25']

    # Higher moments
    features[f'{prefix}_skew'] = stats.skew(data)
    features[f'{prefix}_kurtosis'] = stats.kurtosis(data)

    # Peak features
    peaks, _ = signal.find_peaks(data)
    features[f'{prefix}_n_peaks'] = len(peaks)
    features[f'{prefix}_peak_density'] = len(peaks) / n

    # Temporal features
    features[f'{prefix}_first'] = data[0]
    features[f'{prefix}_last'] = data[-1]
    features[f'{prefix}_delta'] = features[f'{prefix}_last'] - features[f'{prefix}_first']

    # Zero crossing rate (crossings of the mean, not of literal zero)
    zero_crossings = np.where(np.diff(np.sign(data - np.mean(data))))[0]
    features[f'{prefix}_zero_crossing_rate'] = len(zero_crossings) / n

    # Energy
    features[f'{prefix}_energy'] = np.sum(data ** 2) / n

    # Segment features: three consecutive parts, the last one absorbing the
    # remainder.  Too-short series get NaN so the keys are still present.
    if n >= 9:
        seg_size = n // 3
        bounds = [(0, seg_size), (seg_size, 2 * seg_size), (2 * seg_size, n)]
    else:
        bounds = [None, None, None]
    for seg_no, bound in enumerate(bounds, start=1):
        if bound is None:
            seg_mean = seg_std = seg_max = np.nan
        else:
            seg = data[bound[0]:bound[1]]
            seg_mean, seg_std, seg_max = np.mean(seg), np.std(seg), np.max(seg)
        features[f'{prefix}_seg{seg_no}_mean'] = seg_mean
        features[f'{prefix}_seg{seg_no}_std'] = seg_std
        features[f'{prefix}_seg{seg_no}_max'] = seg_max

    return features

def extract_frequency_features(data: np.ndarray, prefix: str, sample_rate: int = 20) -> Dict[str, float]:
    """Extract frequency domain features using FFT and spectral analysis.

    Every key is always present in the result.  Features that cannot be
    computed (empty input, no positive-frequency bins, an empty band, a
    spectrum that sums to zero, or a failed PSD estimate) are reported as 0.

    Args:
        data: 1-D numeric series.
        prefix: String prepended (with an underscore) to every feature name.
        sample_rate: Sampling frequency in Hz.

    Returns:
        Mapping of ``f'{prefix}_<feature>'`` to its value.
    """
    data = np.asarray(data, dtype=float)
    n = len(data)

    # Frequency bands (0-2Hz, 2-5Hz, 5-10Hz)
    bands = [(0, 2), (2, 5), (5, 10)]

    # Start from defaults so the schema does not depend on the input length.
    features = {
        f'{prefix}_dominant_freq': 0.0,
        f'{prefix}_dominant_freq_magnitude': 0.0,
        f'{prefix}_spectral_energy': 0.0,
        f'{prefix}_spectral_entropy': 0.0,
    }
    for band_no in range(1, len(bands) + 1):
        features[f'{prefix}_band{band_no}_energy'] = 0.0
        features[f'{prefix}_band{band_no}_ratio'] = 0.0
    features[f'{prefix}_psd_max'] = 0.0
    features[f'{prefix}_psd_mean'] = 0.0
    features[f'{prefix}_psd_std'] = 0.0

    if n == 0:
        return features

    # FFT magnitude, restricted to the positive frequencies
    spectrum = np.abs(fft(data))
    freqs = fftfreq(n, 1 / sample_rate)
    positive = freqs > 0
    spectrum = spectrum[positive]
    freqs = freqs[positive]

    if spectrum.size > 0:
        peak = np.argmax(spectrum)
        features[f'{prefix}_dominant_freq'] = freqs[peak]
        features[f'{prefix}_dominant_freq_magnitude'] = spectrum[peak]

        total_energy = np.sum(spectrum ** 2)
        features[f'{prefix}_spectral_energy'] = total_energy

        # Entropy of the normalised magnitude spectrum; left at 0 when the
        # spectrum is identically zero (normalising would divide by zero).
        total_magnitude = np.sum(spectrum)
        if total_magnitude > 0:
            features[f'{prefix}_spectral_entropy'] = stats.entropy(spectrum / total_magnitude)

        for band_no, (low, high) in enumerate(bands, start=1):
            in_band = (freqs >= low) & (freqs < high)
            band_energy = np.sum(spectrum[in_band] ** 2)  # 0.0 for an empty band
            features[f'{prefix}_band{band_no}_energy'] = band_energy
            features[f'{prefix}_band{band_no}_ratio'] = band_energy / total_energy if total_energy > 0 else 0

    # Power spectral density (best effort: keep the 0 defaults on failure)
    try:
        _, psd = signal.welch(data, fs=sample_rate, nperseg=min(256, n))
    except (ValueError, ZeroDivisionError, FloatingPointError):
        psd = None
    if psd is not None and np.size(psd) > 0:
        features[f'{prefix}_psd_max'] = np.max(psd)
        features[f'{prefix}_psd_mean'] = np.mean(psd)
        features[f'{prefix}_psd_std'] = np.std(psd)

    return features

def extract_cross_correlation_features(data1: np.ndarray, data2: np.ndarray, prefix: str) -> Dict[str, float]:
    """Describe how two signals co-vary.

    Returns three features keyed by ``prefix``: the Pearson correlation
    (0 when ``data1`` has a single sample), the maximum of the mean-removed
    cross-correlation computed with ``mode='same'``, and the array index at
    which that maximum occurs.
    """
    # Pearson correlation needs at least two samples.
    if len(data1) > 1:
        pearson = np.corrcoef(data1, data2)[0, 1]
    else:
        pearson = 0

    # Cross-correlation of the mean-centred signals.
    centred_first = data1 - np.mean(data1)
    centred_second = data2 - np.mean(data2)
    xcorr = np.correlate(centred_first, centred_second, mode='same')
    non_empty = len(xcorr) > 0

    return {
        f'{prefix}_corr': pearson,
        f'{prefix}_cross_corr_max': np.max(xcorr) if non_empty else 0,
        f'{prefix}_cross_corr_argmax': np.argmax(xcorr) if non_empty else 0,
    }

def extract_features_from_sequence(seq_df: pd.DataFrame, demo_df: Optional[pd.DataFrame] = None) -> pd.DataFrame:
    """Extract all features from a sequence.

    Args:
        seq_df: Rows of one sequence.  IMU features are only produced when
            all of ``acc_x/y/z`` and ``rot_w/x/y/z`` are present.
        demo_df: Demographics of the subject; only the first row is used.
            May be ``None`` or empty.

    Returns:
        A single-row DataFrame.  The demographic columns are always present
        (NaN when the value is unavailable) so that frames built with partial
        or missing demographics still share one schema.
    """
    features = {}

    # Sequence metadata
    features['sequence_length'] = len(seq_df)
    features['duration_seconds'] = len(seq_df) / CONFIG['sample_rate']

    # Demographics: always emit every column, NaN when it cannot be read.
    demo_cols = ['age', 'adult_child', 'sex', 'handedness', 'height_cm',
                 'shoulder_to_wrist_cm', 'elbow_to_wrist_cm']
    demo = demo_df.iloc[0] if demo_df is not None and len(demo_df) > 0 else None
    for col in demo_cols:
        features[col] = demo[col] if demo is not None and col in demo.index else np.nan

    # Check if IMU columns exist
    acc_cols = ['acc_x', 'acc_y', 'acc_z']
    rot_cols = ['rot_w', 'rot_x', 'rot_y', 'rot_z']
    axes = ['x', 'y', 'z']

    if all(col in seq_df.columns for col in acc_cols + rot_cols):
        # Missing acceleration samples are treated as 0.
        acc = seq_df[acc_cols].fillna(0).values

        # Rotation gaps are filled from neighbouring samples first; columns
        # that are entirely missing fall back to the identity quaternion
        # [1, 0, 0, 0].
        rot_df = seq_df[rot_cols].copy()
        rot_df = rot_df.ffill().bfill()
        default_quat = {'rot_w': 1, 'rot_x': 0, 'rot_y': 0, 'rot_z': 0}
        for col, default_val in default_quat.items():
            rot_df[col] = rot_df[col].fillna(default_val)
        rot = rot_df.values

        # World acceleration transformation
        try:
            world_acc, linear_acc = compute_world_acceleration(acc, rot)

            # Extract features for world acceleration
            for i, axis in enumerate(axes):
                features.update(extract_statistical_features(world_acc[:, i], f'world_acc_{axis}'))
                features.update(extract_frequency_features(world_acc[:, i], f'world_acc_{axis}'))

            # Extract features for linear acceleration
            for i, axis in enumerate(axes):
                features.update(extract_statistical_features(linear_acc[:, i], f'linear_acc_{axis}'))
                features.update(extract_frequency_features(linear_acc[:, i], f'linear_freq_{axis}'))

            # Magnitude features
            world_mag = np.linalg.norm(world_acc, axis=1)
            linear_mag = np.linalg.norm(linear_acc, axis=1)
            features.update(extract_statistical_features(world_mag, 'world_mag'))
            features.update(extract_statistical_features(linear_mag, 'linear_mag'))
            features.update(extract_frequency_features(world_mag, 'world_mag_freq'))
            features.update(extract_frequency_features(linear_mag, 'linear_mag_freq'))

        except Exception as e:
            # Deliberate best effort: one bad sequence must not abort the run.
            # NOTE: this fallback emits `acc_*` keys in place of the
            # world/linear keys, so the row's schema differs in this case.
            print(f"World acceleration error: {e}")
            # Fallback to device coordinates
            for i, axis in enumerate(axes):
                features.update(extract_statistical_features(acc[:, i], f'acc_{axis}'))
                features.update(extract_frequency_features(acc[:, i], f'acc_{axis}'))

        # Original acceleration features (device coordinates)
        for i, col in enumerate(acc_cols):
            features.update(extract_statistical_features(acc[:, i], f'device_{col}'))

        # Rotation features
        for i, col in enumerate(rot_cols):
            features.update(extract_statistical_features(rot[:, i], col))

        # Cross-correlation features between each pair of acceleration axes
        for i, j in [(0, 1), (0, 2), (1, 2)]:
            features.update(extract_cross_correlation_features(
                acc[:, i], acc[:, j], f'acc_corr_{i}{j}'
            ))

    return pd.DataFrame([features])

print('✓ Feature extraction functions defined')

# ====================================================================================================
# TWO-STAGE CLASSIFIER
# ====================================================================================================

class TwoStageClassifier:
    """Two-stage classification: Binary (BFRB detection) -> Multi-class (specific gesture).

    Stage 1 separates BFRB gestures (labels 0-7) from non-BFRB gestures
    (labels 8-17).  Stage 2 picks the concrete gesture with a model trained
    only on the corresponding half of the label space.  ``fit`` trains one
    model triple per cross-validation fold; ``predict`` majority-votes across
    the folds.
    """

    # Labels below this value are BFRB gestures.
    _N_BFRB_CLASSES = 8
    # Labels used when a fold has no stage-2 model for the predicted half.
    _DEFAULT_BFRB_LABEL = 0
    _DEFAULT_NON_BFRB_LABEL = 14  # Text on phone

    def __init__(self, config: Dict[str, Any]):
        """Store ``config`` and create empty per-fold model lists."""
        self.config = config
        self.binary_models = []  # Stage 1: BFRB vs Non-BFRB
        self.bfrb_models = []    # Stage 2A: BFRB 8-class
        self.non_bfrb_models = [] # Stage 2B: Non-BFRB 10-class
        self.feature_columns = None
        self.scaler = StandardScaler()

    def fit(self, X: np.ndarray, y: np.ndarray, groups: np.ndarray = None):
        """Train two-stage classifier with cross-validation.

        Args:
            X: Feature matrix (ndarray or DataFrame; DataFrame column names
                are remembered for ``predict``).
            y: Integer gesture labels in 0-17.
            groups: Group id per row, forwarded to ``StratifiedGroupKFold``.

        Returns:
            Dict with per-fold lists ``binary_f1``, ``macro_f1`` and
            ``combined``.  ``macro_f1`` is computed over all validation rows
            with every non-BFRB label collapsed into one class; ``combined``
            is the mean of the two.
        """
        # Store feature columns
        if isinstance(X, pd.DataFrame):
            self.feature_columns = X.columns.tolist()
            X = X.values
        y = np.asarray(y)
        n_bfrb = self._N_BFRB_CLASSES

        # Start from a clean slate so a second fit() does not mix in models
        # from a previous run (fold index == list index everywhere below).
        self.binary_models = []
        self.bfrb_models = []
        self.non_bfrb_models = []

        # Scale features
        X = self.scaler.fit_transform(X)

        # Stage 1: Binary labels (BFRB: 0-7, Non-BFRB: 8-17)
        y_binary = (y < n_bfrb).astype(int)

        # Cross-validation
        skf = StratifiedGroupKFold(n_splits=self.config['n_folds'], shuffle=True,
                                   random_state=self.config['random_state'])

        cv_scores = {'binary_f1': [], 'macro_f1': [], 'combined': []}

        for fold, (train_idx, val_idx) in enumerate(skf.split(X, y, groups)):
            print(f"\nFold {fold + 1}/{self.config['n_folds']}")

            X_train, X_val = X[train_idx], X[val_idx]
            y_train, y_val = y[train_idx], y[val_idx]
            y_binary_train, y_binary_val = y_binary[train_idx], y_binary[val_idx]

            # ========== Stage 1: Binary Classification ==========
            print("  Training Stage 1: Binary classifier...")
            binary_model = lgb.LGBMClassifier(**self.config['binary_lgbm'])
            binary_model.fit(X_train, y_binary_train,
                             eval_set=[(X_val, y_binary_val)],
                             eval_metric='binary_logloss',
                             callbacks=[lgb.log_evaluation(period=100),
                                        lgb.early_stopping(stopping_rounds=50)])
            self.binary_models.append(binary_model)
            binary_pred = binary_model.predict(X_val)

            # ========== Stage 2A: BFRB Multi-class ==========
            print("  Training Stage 2A: BFRB classifier...")
            bfrb_model = None
            train_is_bfrb = y_train < n_bfrb
            if train_is_bfrb.any():
                X_bfrb_train = X_train[train_is_bfrb]
                y_bfrb_train = y_train[train_is_bfrb]

                # Apply SMOTE for class balancing (best effort)
                if self.config['use_smote'] and len(np.unique(y_bfrb_train)) > 1:
                    try:
                        smote = SMOTE(random_state=self.config['random_state'], k_neighbors=3)
                        X_bfrb_train, y_bfrb_train = smote.fit_resample(X_bfrb_train, y_bfrb_train)
                        print(f"    SMOTE applied: {len(X_bfrb_train)} samples")
                    except Exception as exc:
                        print(f"    SMOTE failed, using original data ({exc})")

                bfrb_model = lgb.LGBMClassifier(**self.config['bfrb_lgbm'])
                bfrb_model.fit(X_bfrb_train, y_bfrb_train)
            # Append even when None so list index always equals fold index.
            self.bfrb_models.append(bfrb_model)

            # ========== Stage 2B: Non-BFRB Multi-class ==========
            print("  Training Stage 2B: Non-BFRB classifier...")
            non_bfrb_model = None
            train_is_non_bfrb = y_train >= n_bfrb
            if train_is_non_bfrb.any():
                X_non_bfrb_train = X_train[train_is_non_bfrb]
                y_non_bfrb_train = y_train[train_is_non_bfrb] - n_bfrb  # Shift labels to 0-9

                non_bfrb_model = lgb.LGBMClassifier(**self.config['non_bfrb_lgbm'])
                non_bfrb_model.fit(X_non_bfrb_train, y_non_bfrb_train)
            self.non_bfrb_models.append(non_bfrb_model)

            # ========== Combined Predictions ==========
            y_pred_combined = self._predict_combined(X_val, fold)

            # Calculate metrics
            binary_f1 = f1_score(y_binary_val, binary_pred)

            # Macro F1 over gestures with all non-BFRB labels collapsed into
            # a single class, evaluated on every validation row.  Scoring
            # only the BFRB rows with raw labels would add a zero-F1 class
            # for each distinct non-BFRB label predicted and would ignore
            # non-BFRB rows mistaken for BFRB.
            true_collapsed = np.where(y_val < n_bfrb, y_val, n_bfrb)
            pred_collapsed = np.where(y_pred_combined < n_bfrb, y_pred_combined, n_bfrb)
            macro_f1 = f1_score(true_collapsed, pred_collapsed, average='macro')

            combined_score = (binary_f1 + macro_f1) / 2

            cv_scores['binary_f1'].append(binary_f1)
            cv_scores['macro_f1'].append(macro_f1)
            cv_scores['combined'].append(combined_score)

            print(f"  Binary F1: {binary_f1:.4f}, Macro F1: {macro_f1:.4f}, Combined: {combined_score:.4f}")

        # Print CV results
        print("\n" + "="*50)
        print("Cross-validation Results:")
        print(f"Binary F1: {np.mean(cv_scores['binary_f1']):.4f} ± {np.std(cv_scores['binary_f1']):.4f}")
        print(f"Macro F1:  {np.mean(cv_scores['macro_f1']):.4f} ± {np.std(cv_scores['macro_f1']):.4f}")
        print(f"Combined:  {np.mean(cv_scores['combined']):.4f} ± {np.std(cv_scores['combined']):.4f}")

        return cv_scores

    def _predict_combined(self, X: np.ndarray, fold: int) -> np.ndarray:
        """Combine predictions from the two-stage models of one fold.

        Rows whose stage-1 BFRB probability exceeds 0.5 are labelled by the
        BFRB model, the others by the non-BFRB model (shifted back by 8).
        Each stage-2 model is called once on its whole batch.  When a fold
        has no model for a half, the default label for that half is used.
        """
        n_samples = len(X)
        y_pred = np.zeros(n_samples, dtype=int)
        if n_samples == 0:
            return y_pred

        # Stage 1: Binary prediction
        binary_proba = self.binary_models[fold].predict_proba(X)
        is_bfrb = np.asarray(binary_proba)[:, 1] > 0.5
        is_non_bfrb = ~is_bfrb

        bfrb_model = self.bfrb_models[fold] if fold < len(self.bfrb_models) else None
        non_bfrb_model = self.non_bfrb_models[fold] if fold < len(self.non_bfrb_models) else None

        # Stage 2: Conditional prediction, one batched call per branch
        if is_bfrb.any():
            if bfrb_model is not None:
                y_pred[is_bfrb] = bfrb_model.predict(X[is_bfrb])
            else:
                y_pred[is_bfrb] = self._DEFAULT_BFRB_LABEL

        if is_non_bfrb.any():
            if non_bfrb_model is not None:
                y_pred[is_non_bfrb] = (
                    np.asarray(non_bfrb_model.predict(X[is_non_bfrb])) + self._N_BFRB_CLASSES
                )
            else:
                y_pred[is_non_bfrb] = self._DEFAULT_NON_BFRB_LABEL

        return y_pred

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Predict using ensemble of models.

        A DataFrame is aligned to the columns seen in ``fit``; columns it
        lacks become NaN instead of raising ``KeyError``.  The final label of
        each row is the majority vote over folds (ties resolve to the
        smallest label).

        Raises:
            RuntimeError: If no models are available (not fitted or loaded).
        """
        if not self.binary_models:
            raise RuntimeError("TwoStageClassifier has no trained models; call fit() or load() first")

        if isinstance(X, pd.DataFrame):
            if self.feature_columns is not None:
                X = X.reindex(columns=self.feature_columns)
            X = X.values

        X = self.scaler.transform(X)
        n_samples = len(X)

        # Ensemble predictions: one row per fold
        votes = np.array([self._predict_combined(X, fold)
                          for fold in range(len(self.binary_models))])

        # Majority voting
        return np.array([np.bincount(votes[:, i]).argmax() for i in range(n_samples)],
                        dtype=int)

    def save(self, path: str):
        """Save model to file (pickle first, joblib as a fallback)."""
        model_data = {
            'binary_models': self.binary_models,
            'bfrb_models': self.bfrb_models,
            'non_bfrb_models': self.non_bfrb_models,
            'feature_columns': self.feature_columns,
            'scaler': self.scaler,
            'config': self.config
        }
        # Try both pickle and joblib
        try:
            with open(path, 'wb') as f:
                pickle.dump(model_data, f)
        except Exception:
            # Best-effort fallback; a failure here propagates to the caller.
            joblib.dump(model_data, path)

    @classmethod
    def load(cls, path: str):
        """Load model from file.

        SECURITY: unpickling executes arbitrary code from the file.  Only
        load model files from a trusted source.
        """
        try:
            with open(path, 'rb') as f:
                model_data = pickle.load(f)
        except Exception:
            # pickle can raise many exception types on a file it cannot
            # read (see the pickle docs), so fall back on any of them.
            model_data = joblib.load(path)

        classifier = cls(model_data['config'])
        classifier.binary_models = model_data['binary_models']
        classifier.bfrb_models = model_data['bfrb_models']
        classifier.non_bfrb_models = model_data['non_bfrb_models']
        classifier.feature_columns = model_data['feature_columns']
        classifier.scaler = model_data['scaler']

        return classifier

print('✓ Two-stage classifier defined')

# ====================================================================================================
# DATA LOADING AND PREPROCESSING
# ====================================================================================================

def load_and_prepare_data():
    """Load training data and prepare for model training.

    Reads ``train.csv`` and the demographics file from
    ``CONFIG['data_path']`` and extracts one feature row per sequence.

    Returns:
        ``(X, y, groups)``: the feature DataFrame, the integer gesture label
        of each sequence, and the subject id of each sequence.
    """
    print("Loading data...")

    data_path = CONFIG['data_path']
    train_df = pl.read_csv(data_path + 'train.csv')

    # The demographics file has been seen under two names; prefer the
    # competition name and fall back to the alternative if it is absent.
    demo_path = data_path + 'train_demographics.csv'
    if not os.path.exists(demo_path):
        demo_path = data_path + 'demographics.csv'
    demo_df = pl.read_csv(demo_path)

    # Convert to pandas for easier manipulation
    train_df = train_df.to_pandas()
    demo_df = demo_df.to_pandas()

    # Group once instead of boolean-filtering the whole frame per sequence.
    # sort=False keeps the sequences in order of first appearance.
    sequences = train_df.groupby('sequence_id', sort=False)
    n_sequences = sequences.ngroups
    print(f"Total sequences: {n_sequences}")
    print(f"Found {n_sequences} IMU sequences")

    # Process sequences
    features_list = []
    labels = []
    groups = []

    for i, (_, seq_data) in enumerate(sequences):
        if i % 1000 == 0:
            print(f"Processing sequence {i+1}/{n_sequences}")

        # Subject and gesture are taken from the first row of the sequence.
        subject_id = seq_data['subject'].iloc[0]
        gesture = seq_data['gesture'].iloc[0]

        # Get demographics
        subject_demo = demo_df[demo_df['subject'] == subject_id]

        # Extract features
        features_list.append(extract_features_from_sequence(seq_data, subject_demo))
        labels.append(GESTURE_MAPPER[gesture])
        groups.append(subject_id)

    # Combine features
    X = pd.concat(features_list, ignore_index=True)
    y = np.array(labels)
    groups = np.array(groups)

    print(f"Feature matrix shape: {X.shape}")
    print(f"Class distribution:")
    for label, gesture_name in sorted(REVERSE_GESTURE_MAPPER.items()):
        count = np.sum(y == label)
        if count > 0:
            print(f"  {gesture_name}: {count}")

    return X, y, groups

# ====================================================================================================
# TRAINING PIPELINE
# ====================================================================================================

def train_model():
    """Run the full training pipeline and return the fitted classifier.

    Loads the data, cross-validates a TwoStageClassifier, writes the model to
    ``two_stage_model.pkl`` and the CV summary to ``training_results.json``.
    """
    banner = "=" * 70
    print(banner)
    print("CMI BFRB Detection - Two-Stage Classification Training")
    print(banner)

    # Data
    X, y, groups = load_and_prepare_data()

    # Model
    print("\nTraining Two-Stage Classifier...")
    classifier = TwoStageClassifier(CONFIG)
    cv_scores = classifier.fit(X, y, groups)

    # Persist the model
    model_path = 'two_stage_model.pkl'
    classifier.save(model_path)
    print(f"\nModel saved to {model_path}")

    # Persist the scores
    def _jsonable(value):
        # Convert numpy floats; anything else is handed back unchanged.
        if isinstance(value, np.floating):
            return float(value)
        return value

    results = {'cv_scores': cv_scores}
    results['mean_binary_f1'] = float(np.mean(cv_scores['binary_f1']))
    results['mean_macro_f1'] = float(np.mean(cv_scores['macro_f1']))
    results['mean_combined'] = float(np.mean(cv_scores['combined']))
    results['timestamp'] = datetime.now().isoformat()

    with open('training_results.json', 'w') as f:
        json.dump(results, f, indent=2, default=_jsonable)

    return classifier

# ====================================================================================================
# INFERENCE AND SUBMISSION
# ====================================================================================================

def predict_for_submission(sequence: pl.DataFrame, demographics: pl.DataFrame, model) -> str:
    """Return the predicted gesture name for one sequence.

    Polars inputs are converted to pandas; anything else is passed through
    unchanged.  Any failure is printed and answered with the default gesture
    'Text on phone' so that a prediction is always returned.
    """
    try:
        seq_df = sequence
        if isinstance(sequence, pl.DataFrame):
            seq_df = sequence.to_pandas()

        demo_df = demographics
        if isinstance(demographics, pl.DataFrame):
            demo_df = demographics.to_pandas()

        feature_frame = extract_features_from_sequence(seq_df, demo_df)
        predicted_label = model.predict(feature_frame)[0]
        return REVERSE_GESTURE_MAPPER[predicted_label]

    except Exception as e:
        print(f"Prediction error: {e}")
        # Default prediction
        return 'Text on phone'

# ====================================================================================================
# MAIN EXECUTION
# ====================================================================================================

if __name__ == '__main__':
    # Script entry point.  On Kaggle (detected via /kaggle/input) a model is
    # loaded or trained, smoke-tested and handed to the CMI inference server;
    # elsewhere a model is loaded or trained and smoke-tested locally.
    is_kaggle = os.path.exists('/kaggle/input')

    if is_kaggle:
        print("Running in Kaggle environment")
        # Update paths for Kaggle
        CONFIG['data_path'] = '/kaggle/input/cmi-detect-behavior-with-sensor-data/'

        # Try multiple possible model paths
        model_paths = [
            '/kaggle/input/cmi-two-stage-models/two_stage_model.pkl',
            '/kaggle/working/two_stage_model.pkl',
            'two_stage_model.pkl'
        ]

        model = None
        for model_path in model_paths:
            if os.path.exists(model_path):
                print(f"Loading model from {model_path}...")
                model = TwoStageClassifier.load(model_path)
                break

        if model is None:
            # Train new model if not found
            print("No pre-trained model found. Training new model...")
            model = train_model()

            # Save for inference
            model_path = '/kaggle/working/two_stage_model.pkl'
            model.save(model_path)
            print(f"Model saved to {model_path}")

        # Set up prediction function
        def predict(sequence: pl.DataFrame, demographics: pl.DataFrame) -> str:
            return predict_for_submission(sequence, demographics, model)

        # Test prediction function.  The synthetic inputs carry every column
        # the feature extractor reads (identity quaternions, full
        # demographics); otherwise prediction fails internally and the smoke
        # test only ever sees the default label.
        print("\nTesting prediction function...")
        test_seq = pl.DataFrame({
            'acc_x': np.random.randn(100),
            'acc_y': np.random.randn(100),
            'acc_z': np.random.randn(100),
            'rot_w': np.ones(100),
            'rot_x': np.zeros(100),
            'rot_y': np.zeros(100),
            'rot_z': np.zeros(100)
        })
        test_demo = pl.DataFrame({
            'age': [25],
            'adult_child': [1],
            'sex': [0],
            'handedness': [1],
            'height_cm': [170],
            'shoulder_to_wrist_cm': [50],
            'elbow_to_wrist_cm': [30]
        })
        test_result = predict(test_seq, test_demo)
        print(f"Test result: {test_result}")
        assert isinstance(test_result, str) and test_result in GESTURE_MAPPER, "Invalid prediction"
        print("✓ Test passed!")

        # Initialize inference server (the package ships with the dataset)
        sys.path.append('/kaggle/input/cmi-detect-behavior-with-sensor-data')
        try:
            import kaggle_evaluation.cmi_inference_server

            print("\nInitializing CMI inference server...")
            inference_server = kaggle_evaluation.cmi_inference_server.CMIInferenceServer(predict)
            print("✓ Inference server initialized")

            # Run inference
            print("\nStarting inference...")
            if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
                print("Competition environment - serving predictions...")
                inference_server.serve()
            else:
                print("Local testing mode...")
                try:
                    inference_server.run_local_gateway(
                        data_paths=(
                            CONFIG['data_path'] + 'test.csv',
                            CONFIG['data_path'] + 'test_demographics.csv',
                        )
                    )
                    print("\n✓ Inference complete!")
                    print("✓ submission.parquet has been generated")

                    # Check if submission file was created
                    if os.path.exists('submission.parquet'):
                        submission_df = pd.read_parquet('submission.parquet')
                        print(f"\nSubmission shape: {submission_df.shape}")
                        print(f"Submission columns: {submission_df.columns.tolist()}")
                        print(f"\nFirst 5 predictions:")
                        print(submission_df.head())
                except Exception as e:
                    print(f"Inference error (may be normal in notebook): {e}")
        except ImportError as e:
            print(f"Could not import CMI inference server: {e}")
            print("This is normal if running locally without the CMI package.")
    else:
        # Local training and testing
        print("Running in local environment")

        # Check if we should train or just run inference
        if os.path.exists('two_stage_model.pkl'):
            print("Found existing model. Loading...")
            model = TwoStageClassifier.load('two_stage_model.pkl')
            print("✓ Model loaded successfully")
        else:
            print("Training new model...")
            model = train_model()
            print("✓ Training completed!")

        # Test inference on a synthetic sequence
        print("\nTesting inference...")
        test_seq_df = pd.DataFrame({
            'acc_x': np.random.randn(100),
            'acc_y': np.random.randn(100),
            'acc_z': np.random.randn(100),
            'rot_w': np.ones(100),
            'rot_x': np.zeros(100),
            'rot_y': np.zeros(100),
            'rot_z': np.zeros(100)
        })
        test_demo_df = pd.DataFrame({
            'age': [30],
            'adult_child': [1],
            'sex': [0],
            'handedness': [1],
            'height_cm': [170],
            'shoulder_to_wrist_cm': [50],
            'elbow_to_wrist_cm': [30]
        })

        test_features = extract_features_from_sequence(test_seq_df, test_demo_df)
        test_pred = model.predict(test_features)[0]
        print(f"Test prediction: {REVERSE_GESTURE_MAPPER[test_pred]}")

        print("\n" + "="*70)
        print("To use this model for Kaggle submission:")
        print("1. Upload this notebook to Kaggle")
        print("2. Run it to train the model (or upload pre-trained model as dataset)")
        print("3. The notebook will automatically run inference and generate submission.parquet")
        print("="*70)