# In [None]:
# Hybrid RF + TFLite Inference Pipeline
import pickle
from pathlib import Path

import numpy as np
import tensorflow as tf
import onnxruntime as ort
from scipy.stats import skew, kurtosis
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import balanced_accuracy_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
import itertools

# Configure TensorFlow's own intra-op and inter-op thread pools.
# NOTE(review): the TFLite Interpreter created below is given its own
# `num_threads=4`; tf.config.threading is not expected to size it — confirm
# if inference threading matters.
TF_THREADS = 14
for _set_tf_threads in (tf.config.threading.set_intra_op_parallelism_threads,
                        tf.config.threading.set_inter_op_parallelism_threads):
    _set_tf_threads(TF_THREADS)

# Paths & models, relative to the working directory the notebook runs from:
#   DATA_ROOT    - folder containing one sub-folder per participant
#   RF6_ONNX     - 6-class random-forest gate (ONNX)
#   TFLITE_MODEL - multimodal (IMU + audio) CNN (TFLite)
#   NORM_PATH    - pickled IMU normalization parameters
DATA_ROOT, RF6_ONNX, TFLITE_MODEL, NORM_PATH = map(Path, (
    '../../Participant',
    '../../random_forest_6class_ver36.onnx',
    '../../quan_multimodal_cnn_ver36.tflite',
    '../../normalization_params_Right_ver36.pkl',
))

# Label encoders. LabelEncoder stores its classes sorted, so ENC6 encodes
# 'Other' as 0 (it sorts first); the RF gate in infer_sample treats class
# index 0 as 'Other'. ENC2 encodes 'Event' -> 0, 'Other' -> 1.
ENC6 = LabelEncoder()
ENC6.fit(['Other', 'Shower', 'Tooth_brushing',
          'Washing_hands', 'Wiping', 'Vacuum_Cleaner'])
ENC2 = LabelEncoder()
ENC2.fit(['Other', 'Event'])

# Load the 6-class random-forest ONNX model. Only its first input and first
# output are used (presumably the feature tensor and predicted-label output —
# confirm against the exported graph).
ses6 = ort.InferenceSession(str(RF6_ONNX))
in6 = ses6.get_inputs()[0].name
out6 = ses6.get_outputs()[0].name

# Load the multimodal TFLite CNN; tensors must be allocated before any
# set_tensor / get_tensor call.
interp = tf.lite.Interpreter(model_path=str(TFLITE_MODEL), num_threads=4)
interp.allocate_tensors()
inputs = interp.get_input_details()
outputs = interp.get_output_details()
# NOTE(review): assumes input 0 is the IMU branch and input 1 the audio
# branch; TFLite input order is not guaranteed to follow the Keras model's
# input order — confirm via inputs[i]['name'] / inputs[i]['shape'].
idx_imu = inputs[0]['index']
idx_audio = inputs[1]['index']
idx_out = outputs[0]['index']

# Load the IMU normalization parameters (keys 'max', 'min', 'mean', 'std').
# The context manager closes the file deterministically instead of leaving
# the handle to the garbage collector.
# NOTE: pickle can execute arbitrary code on load — only load trusted files.
with open(NORM_PATH, 'rb') as _norm_fh:
    norm = pickle.load(_norm_fh)

# Reshape every parameter to (1, 1, C) so it broadcasts along the last axis
# of a batched IMU array (presumably (batch, time, channel) — confirm).
pm = np.asarray(norm['max']).reshape(1, 1, -1).astype(np.float32)
pn = np.asarray(norm['min']).reshape(1, 1, -1).astype(np.float32)
mu = np.asarray(norm['mean']).reshape(1, 1, -1).astype(np.float32)
sd = np.asarray(norm['std']).reshape(1, 1, -1).astype(np.float32)

# Feature extraction for RF
def extract_stats(window: np.ndarray) -> np.ndarray:
    """Summarize each column of a 2-D window with eight statistics.

    For every column (taken down axis 0), in this order: mean, std, max, min,
    median, var, skew, kurtosis (NumPy/SciPy defaults, i.e. biased
    estimators and Fisher/excess kurtosis). The eight-value groups are laid
    out column after column.

    Args:
        window: 2-D array; one group of statistics is produced per column.

    Returns:
        float32 array of shape (1, 8 * window.shape[1]).
    """
    def _column_summary(col: np.ndarray) -> tuple:
        return (col.mean(), col.std(), col.max(), col.min(),
                np.median(col), col.var(), skew(col), kurtosis(col))

    flat = [stat
            for j in range(window.shape[1])
            for stat in _column_summary(window[:, j])]
    return np.array(flat, dtype=np.float32).reshape(1, -1)

# Single-sample inference
def infer_sample(imu_win: np.ndarray, audio_win: np.ndarray) -> str:
    """Classify one (IMU, audio) window with the RF -> TFLite cascade.

    Stage 1 runs the 6-class ONNX random forest on ``extract_stats(imu_win)``.
    If it predicts encoded class 0 (``'Other'`` under ENC6's sorted encoding)
    the sample is returned as ``'Other'`` and the CNN is skipped.
    Otherwise stage 2 normalizes the IMU window with the loaded min/max and
    mean/std parameters, feeds it together with the audio window to the
    multimodal TFLite model, and decodes the argmax with ENC6.

    Args:
        imu_win: One IMU window without a batch axis. Must be 2-D, because
            ``extract_stats`` iterates over ``window.shape[1]`` columns
            (presumably (time, channels) — confirm against the model).
        audio_win: One audio window without a batch axis.

    Returns:
        The predicted activity label, one of ``ENC6.classes_``.
    """
    # ---- Stage 1: random-forest gate ----
    rf_out = ses6.run([out6], {in6: extract_stats(imu_win)})[0]
    # ravel() handles both (1,) and (1, 1) label outputs. The result is bound
    # on its own line: `x := a != 0` would bind the boolean, not the label.
    rf_idx = int(np.asarray(rf_out).ravel()[0])
    if rf_idx == 0:  # ENC6 index 0 is 'Other': no need to run the CNN
        return 'Other'

    # ---- Stage 2: multimodal CNN ----
    # Add the batch axis *before* normalizing. The parameters are (1, 1, C),
    # so a (1, T, C) input broadcasts to (1, T, C). Adding the axis after
    # normalizing would yield a spurious 4-D (1, 1, T, C) tensor.
    imu_b = np.asarray(imu_win)[None, ...]
    # Min-max scale to [-1, 1] (x == max -> 1, x == min -> -1), then
    # standardize with the stored mean/std.
    imu_n = (((1 + (imu_b - pm) * 2 / (pm - pn)) - mu) / sd).astype(np.float32)
    audio_n = np.asarray(audio_win, dtype=np.float32)[None, ...]
    # NOTE(review): assumes float32 model inputs. A fully integer-quantized
    # model would need (scale, zero_point) quantization here — check
    # interp.get_input_details()[i]['dtype'].

    # Resize and re-allocate only when an input shape actually changes.
    # allocate_tensors() on every call is needlessly expensive.
    current = {d['index']: tuple(int(s) for s in d['shape'])
               for d in interp.get_input_details()}
    needs_alloc = False
    for idx, arr in ((idx_imu, imu_n), (idx_audio, audio_n)):
        if current.get(idx) != arr.shape:
            interp.resize_tensor_input(idx, arr.shape)
            needs_alloc = True
    if needs_alloc:
        interp.allocate_tensors()

    interp.set_tensor(idx_imu, imu_n)
    interp.set_tensor(idx_audio, audio_n)
    interp.invoke()
    scores = interp.get_tensor(idx_out)
    return ENC6.inverse_transform([int(np.argmax(scores, axis=1)[0])])[0]

# Batch over dataset: run the cascade on every window of every participant
# folder, collecting ground-truth and predicted labels in parallel lists.
all_true, all_pred = [], []
for pid_dir in sorted(DATA_ROOT.iterdir()):
    if not pid_dir.is_dir():
        continue
    pid = pid_dir.name
    # Context manager closes each file deterministically.
    # NOTE: pickle can execute arbitrary code on load — only load trusted files.
    with open(pid_dir / f'{pid}_preprocessing.pkl', 'rb') as pkl_fh:
        data = pickle.load(pkl_fh)
    IMU, Audio, activity = data['IMU'], data['Audio'], data['Activity']
    if Audio.ndim == 3:
        Audio = Audio[..., None]  # add a trailing channel axis
    # zip() would silently drop trailing samples on a length mismatch, hiding
    # corrupt data and skewing the metrics, so fail loudly instead.
    if not (len(IMU) == len(Audio) == len(activity)):
        raise ValueError(
            f'{pid}: IMU/Audio/Activity length mismatch '
            f'({len(IMU)}, {len(Audio)}, {len(activity)})'
        )
    for imu_w, aud_w, true_lbl in zip(IMU, Audio, activity):
        pred_lbl = infer_sample(imu_w, aud_w)
        all_true.append(true_lbl)
        all_pred.append(pred_lbl)

# Evaluate: balanced accuracy (mean per-class recall) and support-weighted F1
# over all collected samples. `labels` fixes the class order for plotting.
labels = [*ENC6.classes_]
ba, f1 = (
    balanced_accuracy_score(all_true, all_pred),
    f1_score(all_true, all_pred, average='weighted'),
)
print(f'Overall BA: {ba:.4f}, F1: {f1:.4f}')

# Confusion matrix, row-normalized to the percentage of each true class.
cm = confusion_matrix(all_true, all_pred, labels=labels)
row_totals = cm.sum(axis=1, keepdims=True)
# A class with no true samples would divide by zero (NaN row plus a
# RuntimeWarning); show such rows as 0% instead.
cm_pct = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float),
                   where=row_totals > 0) * 100

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(cm_pct, cmap='Blues', vmin=0, vmax=100)
n_cls = len(labels)
for i, j in itertools.product(range(n_cls), repeat=2):
    # White text on dark (>50%) cells keeps the annotation readable.
    ax.text(j, i, f'{cm_pct[i, j]:.1f}%', ha='center', va='center',
            color='white' if cm_pct[i, j] > 50 else 'black')
ax.set_xticks(range(n_cls))
ax.set_xticklabels(labels, rotation=45, ha='right')  # align rotated labels to ticks
ax.set_yticks(range(n_cls))
ax.set_yticklabels(labels)
ax.set_xlabel('Pred')
ax.set_ylabel('True')
ax.set_title('Confusion Matrix (%)')
fig.colorbar(im, ax=ax)
fig.tight_layout()
plt.show()
