In [None]:
import os
import random
import math
from pathlib import Path

import numpy as np
import pickle as pkl
import tensorflow as tf
from tensorflow.keras.models import load_model, clone_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

# --- Configuration -----------------------------------------------------------
# Input artifacts. DATA_ROOT holds one sub-directory per participant.
DATA_ROOT = Path("../Data/Experiment_Data/3_PreprocessDataset_Oversample")
# Pretrained model that every fine-tuning run is cloned from.
BASE_MODEL_PATH = Path("../Models/tensorflow_model/MultiModal/MultiModal_ver1/Right/MM_Scratch.h5")
# Pickled label binarizer used to encode activity labels into training targets.
LABEL_BINARIZER = Path("../LabelBinarizer/Label_binarizer_6_classes.pkl")
# Pickled dict with 'max', 'min', 'mean', 'std' used by normalize_imu_data().
NORM_PARAMS = Path("../Normalization_params/Normalization_params_pickle/normalization_params_Right_ver1.pkl")
# Output root; each run writes to SAVE_BASE/<strategy>/<pid>/train<sec>/.
SAVE_BASE = Path("../Models/Finetuned_Model/Finetune_model_6class_ver3")

# Fine-tuning strategies understood by set_trainable_layers().
STRATEGIES = ["full", "some", "one"]
# Seconds of training data per class to sweep: 10 s, 30 s, then every minute up to 8 min.
TRAIN_SEC_LIST = [10, 30, *range(60, 481, 60)]
# Seconds dropped at the start of each contiguous activity segment.
NOISE_SEC = 3
# Seconds per class held out for validation.
VAL_SEC = 10
EPOCHS = 50
BATCH_SIZE = 32
# Activity labels that split_train_val_idxs() iterates over.
CLASSES = ["Shower", "Tooth_brushing", "Washing_hands", "Vacuum_Cleaner", "Wiping", "Other"]
SEED = 4

# Utility Functions

def compute_num_frames(duration_sec: float, window_length: float = 2.0, hop_length: float = 0.2):
    """Return how many overlapping windows fit into a signal of the given duration.

    Windows are ``window_length`` seconds long and start every ``hop_length``
    seconds from t=0. Only windows lying entirely inside ``[0, duration_sec]``
    are counted.

    Args:
        duration_sec: Signal length in seconds.
        window_length: Length of one window in seconds.
        hop_length: Step between consecutive window starts in seconds.

    Returns:
        The number of complete windows, or 0 if the signal is shorter than one
        window.
    """
    if duration_sec < window_length:
        return 0
    steps = (duration_sec - window_length) / hop_length
    # Float division can land just below an exact integer (e.g. 0.6 / 0.2 ==
    # 2.9999999999999996), which floor() would turn into an off-by-one.
    # Snap quotients that are within rounding error of an integer first.
    nearest = round(steps)
    if math.isclose(steps, nearest, rel_tol=1e-9, abs_tol=1e-9):
        steps = nearest
    return 1 + int(math.floor(steps))

def split_train_val_idxs(labels: np.ndarray,
                         train_frames: int,
                         val_frames: int,
                         noise_frames: int,
                         classes=None):
    """Select per-class training and validation frame indices.

    For every class, the positions where ``labels`` equals that class are
    grouped into runs of consecutive indices. The first ``noise_frames``
    indices of each run are discarded, and runs not longer than that are
    dropped entirely. The remaining indices of all runs are concatenated in
    order. From that sequence, the first ``train_frames`` go to training and
    the next ``val_frames`` go to validation. A class with too few usable
    frames contributes fewer (possibly zero) indices; no error is raised.

    NOTE(review): validation indices of a class directly follow its training
    indices. If frames are overlapping windows, the first validation frames
    may share raw samples with the last training frames. Confirm whether a
    gap is wanted.

    Args:
        labels: Per-frame class labels (anything ``np.asarray`` accepts).
        train_frames: Number of training frames to take per class.
        val_frames: Number of validation frames to take per class.
        noise_frames: Frames skipped at the start of every contiguous run.
        classes: Classes to consider, in order. Defaults to the module-level
            ``CLASSES``.

    Returns:
        ``(train_idxs, val_idxs)``: sorted integer index arrays into ``labels``.
    """
    if classes is None:
        classes = CLASSES
    # Accept lists too: ``list == str`` would be a plain False, not a mask.
    labels = np.asarray(labels)

    train_parts, val_parts = [], []
    for cls in classes:
        positions = np.where(labels == cls)[0]
        if positions.size == 0:
            continue
        # Cut the positions wherever two consecutive indices are not adjacent.
        runs = np.split(positions, np.where(np.diff(positions) != 1)[0] + 1)
        usable = [run[noise_frames:] for run in runs if run.size > noise_frames]
        combined = np.concatenate(usable) if usable else np.array([], dtype=int)

        train_parts.append(combined[:train_frames])
        val_parts.append(combined[train_frames:train_frames + val_frames])

    train_idxs = np.sort(np.concatenate(train_parts)) if train_parts else np.array([], dtype=int)
    val_idxs = np.sort(np.concatenate(val_parts)) if val_parts else np.array([], dtype=int)
    return train_idxs, val_idxs

def set_trainable_layers(model: tf.keras.Model, strategy: str):
    """Freeze or unfreeze the layers of ``model`` for a fine-tuning strategy.

    Strategies:
        'full': every layer in ``model.layers`` is trainable.
        'some': layers whose name contains 'conv', 'batch_normalization' or
                'pool' are frozen; every other layer is trainable.
        'one':  only the last entry of ``model.layers`` is trainable.

    Raises:
        ValueError: if ``strategy`` is not one of the above. No layer is
            modified in that case.
    """
    if strategy not in ('full', 'some', 'one'):
        raise ValueError(f"Unknown strategy: {strategy}")

    frozen_name_parts = ('conv', 'batch_normalization', 'pool')
    layers = model.layers
    for layer in layers:
        if strategy == 'some':
            layer.trainable = all(part not in layer.name for part in frozen_name_parts)
        else:
            # 'full' -> True for all; 'one' -> False for all (last re-enabled below).
            layer.trainable = strategy == 'full'
    if strategy == 'one':
        layers[-1].trainable = True

def normalize_imu_data(x: np.ndarray, norm: dict):
    """Rescale IMU data onto [-1, 1] with stored min/max, then z-score it.

    ``norm`` must provide the keys 'max', 'min', 'mean' and 'std'. Each value
    (scalar or array) is combined arithmetically with ``x``. Before
    standardisation, a value equal to ``norm['min']`` maps to -1 and a value
    equal to ``norm['max']`` maps to +1.

    Returns:
        The normalised data cast to float32.
    """
    hi = norm['max']
    lo = norm['min']
    mu = norm['mean']
    sigma = norm['std']
    # Linear map [lo, hi] -> [-1, 1].
    unit_range = 2 * (x - hi) / (hi - lo) + 1
    standardized = (unit_range - mu) / sigma
    return standardized.astype('float32')

# Main Fine-tuning Pipeline
if __name__ == '__main__':
    # Reproducibility: seed the Python, NumPy and TensorFlow RNGs.
    random.seed(SEED)
    np.random.seed(SEED)
    tf.random.set_seed(SEED)

    # Load the base model, label binarizer and IMU normalization params once.
    # They are shared by every (train_sec, participant, strategy) run below.
    # NOTE(review): pickle.load can execute code embedded in the file; only
    # point these paths at trusted artifacts.
    base_model = load_model(BASE_MODEL_PATH)
    with open(LABEL_BINARIZER, 'rb') as f:
        lb = pkl.load(f)
    with open(NORM_PARAMS, 'rb') as f:
        norm = pkl.load(f)

    # Per-class frame counts for the discarded segment lead-in and for validation.
    noise_frames = compute_num_frames(NOISE_SEC)
    val_frames = compute_num_frames(VAL_SEC)

    for train_sec in TRAIN_SEC_LIST:
        train_frames = compute_num_frames(train_sec)
        for pid_dir in sorted(DATA_ROOT.iterdir()):
            if not pid_dir.is_dir():
                continue
            pid = pid_dir.name
            print(f"\n=== Fine-tuning {pid}, train_sec={train_sec} ===")

            # Load this participant's preprocessed data.
            with open(pid_dir / f"{pid}_preprocessing.pkl", 'rb') as f:
                data = pkl.load(f)
            imu_data = data['IMU']
            audio_data = data['Audio']
            # asarray so the label comparison and the fancy indexing below are
            # positional NumPy operations even if the pickle stored a list.
            labels = np.asarray(data['Activity'])

            # Split train/val indices per class, after dropping noise frames.
            train_idxs, val_idxs = split_train_val_idxs(labels, train_frames, val_frames, noise_frames)
            # fit() cannot train on zero samples, and both callbacks monitor
            # val_loss. Skip this participant instead of aborting the sweep.
            if train_idxs.size == 0 or val_idxs.size == 0:
                print(f"Skipping {pid} (train_sec={train_sec}): "
                      f"{train_idxs.size} train / {val_idxs.size} val frames available")
                continue

            # Prepare datasets.
            X_imu_train = normalize_imu_data(imu_data[train_idxs], norm)
            X_imu_val = normalize_imu_data(imu_data[val_idxs], norm)
            X_audio_train = audio_data[train_idxs].astype('float32')
            X_audio_val = audio_data[val_idxs].astype('float32')
            y_train = lb.transform(labels[train_idxs])
            y_val = lb.transform(labels[val_idxs])

            # Shuffle training data (one permutation shared by all inputs/targets).
            perm = np.random.permutation(len(train_idxs))
            X_imu_train = X_imu_train[perm]
            X_audio_train = X_audio_train[perm]
            y_train = y_train[perm]

            # Fine-tune a fresh copy of the base model for each strategy.
            for strategy in STRATEGIES:
                print(f"-- Strategy: {strategy}")
                model = clone_model(base_model)
                model.set_weights(base_model.get_weights())
                set_trainable_layers(model, strategy)

                # Dummy call for build.
                _ = model([tf.zeros((1,) + X_imu_train.shape[1:]), tf.zeros((1,) + X_audio_train.shape[1:])])
                # Compile after changing `trainable` so the new flags take effect.
                model.compile(optimizer=Adam(1e-4),
                              loss='categorical_crossentropy', metrics=['accuracy'])

                # Callbacks.
                out_dir = SAVE_BASE / strategy / pid / f"train{train_sec}"
                out_dir.mkdir(parents=True, exist_ok=True)
                best_path = out_dir / 'best.h5'
                # Drop a checkpoint left by an earlier run, so the reload below
                # only picks up weights written during this fit().
                if best_path.exists():
                    best_path.unlink()
                checkpoint = ModelCheckpoint(str(best_path), save_best_only=True, monitor='val_loss')
                earlystop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

                # Fine-tune.
                model.fit(
                    [X_imu_train, X_audio_train], y_train,
                    validation_data=([X_imu_val, X_audio_val], y_val),
                    epochs=EPOCHS, batch_size=BATCH_SIZE,
                    callbacks=[checkpoint, earlystop], verbose=2
                )

                # In tf.keras, EarlyStopping restores the best weights only when
                # it actually stops training early. If all EPOCHS run, the
                # in-memory model keeps the last epoch's weights. Reload the best
                # checkpoint so the TFLite export matches best.h5.
                if best_path.exists():
                    model.load_weights(str(best_path))

                # Convert and save TFLite.
                converter = tf.lite.TFLiteConverter.from_keras_model(model)
                converter.optimizations = [tf.lite.Optimize.DEFAULT]
                tflite_model = converter.convert()
                (out_dir / 'model.tflite').write_bytes(tflite_model)

            print(f"Completed fine-tuning for {pid} (train_sec={train_sec})")