In [None]:
# Random Forest Feature Extraction & ONNX Pipeline
import os
import sys
import random
import pickle as pkl
from pathlib import Path
import numpy as np
import pandas as pd
from scipy.stats import skew, kurtosis
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import balanced_accuracy_score, f1_score, confusion_matrix
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import compute_class_weight
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
import onnx
import onnxruntime as ort
import matplotlib.pyplot as plt

# Configuration
PROJECT_ROOT = Path("../..")  # resolved relative to the current working directory

# Dataset selection
TRAIN_PIDS = [10, 100, 101, 102, 103]
TEST_PIDS = [3]
HAND = 'Right'
SR = '16000'
MODEL_VERSION = 1
# Version suffix of the normalization-parameter file; kept separate from
# MODEL_VERSION because the two need not move together.
NORM_VERSION = 1

DATA_DIR = PROJECT_ROOT / 'Data/Train_Data/3_MMExamples'
# The parameter file name embeds the hand, so derive it from HAND rather than
# hard-coding 'Right' -- otherwise changing HAND silently reuses the wrong file.
NORM_PATH = (
    PROJECT_ROOT
    / 'Normalization_params/Normalization_params_pickle'
    / f"normalization_params_{HAND}_ver{NORM_VERSION}.pkl"
)
ONNX_DIR = PROJECT_ROOT / 'Models/Onnx_model'
# Created at import time so the ONNX export step can always write into it.
ONNX_DIR.mkdir(parents=True, exist_ok=True)

# Feature configuration: one feature per (sensor, axis, statistic) combination.
SENSORS = ['Acc', 'Gyro', 'Rotvec']
AXES = ['x', 'y', 'z']
FEATURE_FUNCS = [np.mean, np.std, np.max, np.min, np.median, np.var, skew, kurtosis]

#  Utility Functions
def extract_features(window: np.ndarray) -> np.ndarray:
    """Summarise each column of ``window`` with every statistic in FEATURE_FUNCS.

    Args:
        window: 2-D array; every column ``window[:, j]`` is summarised on its own.

    Returns:
        1-D array of length ``window.shape[1] * len(FEATURE_FUNCS)``, grouped by
        column: all statistics for column 0 first, then column 1, and so on.
    """
    per_column = (window[:, j] for j in range(window.shape[1]))
    return np.array([stat(column) for column in per_column for stat in FEATURE_FUNCS])


def load_preprocess(pids, norm_params) -> "tuple[pd.DataFrame, pd.Series]":
    """Load IMU pickle files for the given participants, normalize them, and extract features.

    For each id in ``pids`` every ``*.pkl`` file under ``DATA_DIR/<pid>/HAND/SR`` is
    read (in sorted order, for reproducibility). The activity label is the middle
    part of the file stem, which must have the form ``<pid>---<activity>---<trial>``.

    Args:
        pids: Iterable of participant ids; each is passed through ``str`` to build
            the folder path.
        norm_params: Mapping with keys ``'max'``, ``'min'``, ``'mean'`` and ``'std'``.
            Each value is reshaped to ``(1, 1, -1)`` so it broadcasts over the last
            axis of the IMU array (presumably one entry per channel -- TODO confirm).

    Returns:
        ``(features, labels)``: a DataFrame with one row per window and one column
        per (sensor, axis, statistic), and a Series of activity labels named
        ``'activity'``.

    Raises:
        ValueError: if a file stem does not split into exactly three ``'---'`` parts.
    """
    # Broadcast the stored statistics over the last axis of the IMU array.
    max_, min_, mean_, std_ = (
        np.asarray(norm_params[key]).reshape(1, 1, -1)
        for key in ('max', 'min', 'mean', 'std')
    )

    rows, labels = [], []

    for pid in pids:
        folder = DATA_DIR / str(pid) / HAND / SR
        # Path.glob yields files in arbitrary filesystem order; sort so the row
        # order (and therefore the trained model) is reproducible.
        for file in sorted(folder.glob('*.pkl')):
            _, activity, _ = file.stem.split('---')
            # NOTE: pickle can execute arbitrary code -- only load trusted local data.
            with open(file, 'rb') as fh:
                data = np.asarray(pkl.load(fh)['IMU'])
            if data.size == 0:
                continue
            # Min-max scale to [-1, 1] (algebraically 2*(x-min)/(max-min) - 1),
            # then standardize with the stored mean/std.
            norm = 1 + (data - max_) * 2 / (max_ - min_)
            norm = (norm - mean_) / std_
            # Each element along the first axis is one window
            # (presumably data is (n_windows, window_len, n_channels) -- confirm).
            for window in norm:
                rows.append(extract_features(window))
                labels.append(activity)

    cols = [f"{s}_{a}_{f.__name__}" for s in SENSORS for a in AXES for f in FEATURE_FUNCS]
    return pd.DataFrame(rows, columns=cols), pd.Series(labels, name='activity')

# Main
if __name__ == '__main__':
    # Load normalization parameters
    norm_params = pkl.load(open(NORM_PATH, 'rb'))

    # Load train and test
    X_train_df, y_train = load_preprocess(TRAIN_PIDS, norm_params)
    X_test_df, y_test = load_preprocess(TEST_PIDS, norm_params)

    # Encode labels
    le = LabelEncoder()
    y_train_enc = le.fit_transform(y_train)
    y_test_enc = le.transform(y_test)

    # Compute class weights
    cw = compute_class_weight('balanced', classes=np.unique(y_train_enc), y=y_train_enc)
    class_wt = dict(enumerate(cw))

    # Train Random Forest
    rf = RandomForestClassifier(n_estimators=100, max_depth=10, random_state=50,
                                class_weight=class_wt)
    rf.fit(X_train_df.values, y_train_enc)

    # Evaluate Scikit-learn model
    y_pred = rf.predict(X_test_df.values)
    ba = balanced_accuracy_score(y_test_enc, y_pred)
    f1w = f1_score(y_test_enc, y_pred, average='weighted')
    print(f"Test Balanced Accuracy: {ba:.4f}")
    print(f"Test Weighted F1 Score: {f1w:.4f}")

    # Plot confusion matrix
    cm = confusion_matrix(y_test_enc, y_pred)
    cm_pct = cm.astype(float) / cm.sum(axis=1)[:,None] * 100
    fig, ax = plt.subplots(figsize=(6,5))
    im = ax.imshow(cm_pct, cmap='Blues', vmin=0, vmax=100)
    fig.colorbar(im, ax=ax, label='Percent')
    ticks = np.arange(len(le.classes_))
    ax.set_xticks(ticks); ax.set_yticks(ticks)
    ax.set_xticklabels(le.classes_, rotation=45, ha='right')
    ax.set_yticklabels(le.classes_)
    thresh = cm_pct.max()/2
    for i in range(cm_pct.shape[0]):
        for j in range(cm_pct.shape[1]):
            color = 'white' if cm_pct[i,j]>thresh else 'black'
            ax.text(j, i, f"{cm_pct[i,j]:.1f}%", ha='center', va='center', color=color)
    plt.tight_layout()
    plt.show()

    # Convert and save ONNX
    initial_type = [('float_input', FloatTensorType([None, X_train_df.shape[1]]))]
    onnx_model = convert_sklearn(rf, initial_types=initial_type, target_opset=8)
    onnx_path = ONNX_DIR / f"random_forest_ver{MODEL_VERSION}.onnx"
    with open(onnx_path, 'wb') as f:
        f.write(onnx_model.SerializeToString())
    print(f"Saved ONNX model to {onnx_path}")

    # ONNX inference test
    sess = ort.InferenceSession(str(onnx_path))
    inp_name = sess.get_inputs()[0].name
    out_name = sess.get_outputs()[0].name
    preds_onnx = sess.run([out_name], {inp_name: X_test_df.values.astype(np.float32)})[0]
    ba2 = balanced_accuracy_score(y_test, preds_onnx)
    f1m = f1_score(y_test, preds_onnx, average='macro')
    print(f"ONNX Test Balanced Accuracy: {ba2:.4f}")
    print(f"ONNX Test Macro F1 Score:    {f1m:.4f}")