# Train THOSnet 

In [None]:
import os
import sys
import random

###############################################################################
# Suppress TensorFlow logs (only custom prints go to stdout)
###############################################################################
# This must be assigned BEFORE `import tensorflow`: the native logging level is
# consulted while TensorFlow is being imported, so setting it afterwards does
# not silence the start-up messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Hide TF internal logs

###############################################################################
# GPU Selection
###############################################################################
# Also set ahead of the TensorFlow import so the choice is in place before any
# CUDA initialisation can happen.
GPU_INDICES = "0"  # e.g., "0" or "1" or "0,1"
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_INDICES

import numpy as np  # noqa: E402  (imports intentionally follow the env setup)
import tensorflow as tf  # noqa: E402
from tensorflow.keras import layers  # noqa: E402
from sklearn.model_selection import train_test_split  # noqa: E402
from sklearn.metrics import f1_score, confusion_matrix, classification_report  # noqa: E402

# ---- Add these two for plotting ----
import seaborn as sns  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402

# The Python-side logger only exists once TensorFlow has been imported.
tf.get_logger().setLevel('ERROR')         # Hide TF python warnings

###############################################################################
# Random Seeds
###############################################################################
def set_random_seeds(seed_value=42):
    """
    Seed NumPy, Python's `random` module and TensorFlow with `seed_value`,
    and export the same value as PYTHONHASHSEED.

    NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
    assigning it here presumably only affects child processes - confirm if
    hash determinism in this process matters.
    """
    for seeder in (np.random.seed, random.seed, tf.random.set_seed):
        seeder(seed_value)
    os.environ['PYTHONHASHSEED'] = str(seed_value)

set_random_seeds(seed_value=42)

###############################################################################
# Enable memory growth for the selected GPU
###############################################################################
# Ask TensorFlow which GPUs it can see (after CUDA_VISIBLE_DEVICES filtering).
physical_gpus = tf.config.list_physical_devices('GPU')
if not physical_gpus:
    print("No GPU devices available.")
else:
    try:
        # Turn on on-demand memory allocation for every visible GPU.
        for device in physical_gpus:
            tf.config.experimental.set_memory_growth(device, True)
    except RuntimeError as err:
        # Reported instead of raised so the script can still continue.
        print(f"Failed to set memory growth: {err}")
    else:
        print(f"Memory growth enabled. Visible GPUs: {physical_gpus}")

###############################################################################
# Data Augmentation Utilities
###############################################################################

# The following data augmentation techniques are used in the THOSnet paper.
# You can experiment with different augmentation techniques here.
def add_noise(data, noise_factor=0.0185):
    """
    Add element-wise Gaussian noise to `data` and clip the result to [0, 1].

    The noise is drawn from NumPy's global RNG as standard-normal values
    multiplied by `noise_factor`, with the same shape as `data`.
    """
    perturbation = noise_factor * np.random.randn(*data.shape)
    noisy = data + perturbation
    return noisy.clip(0, 1)

def scale(data, scale_factor=0.1605):
    """
    Multiply all of `data` by one random gain and clip the result to [0, 1].

    The gain is 1 plus a single draw from U(-scale_factor, scale_factor),
    taken from NumPy's global RNG.
    """
    gain = 1 + np.random.uniform(low=-scale_factor, high=scale_factor)
    rescaled = gain * data
    return rescaled.clip(0, 1)

def generate_random_curves(data, sigma=0.078, knot=2):
    """
    Multiply every column of `data` by its own smooth random curve and clip
    the result to [0, 1].

    Args:
        data: 2-D array; axis 0 is the position along the sequence and axis 1
            indexes the columns that each receive an independent curve.
        sigma: Standard deviation of the random control values, which are
            drawn around 1.0 from NumPy's global RNG.
        knot: Number of interior control points; `knot + 2` points are spread
            evenly over axis 0.

    Returns:
        Array with the same shape and dtype as `data`, clipped to [0, 1].
    """
    from scipy.interpolate import CubicSpline

    num_steps, num_columns = data.shape[0], data.shape[1]
    knot_positions = np.linspace(0, num_steps - 1, knot + 2)
    knot_values = np.random.normal(loc=1.0, scale=sigma, size=(knot + 2, num_columns))

    # One spline object interpolates all columns at once (axis 0 is the
    # interpolation axis), replacing a Python loop that built one spline per
    # column.
    curves = CubicSpline(knot_positions, knot_values, axis=0)(np.arange(num_steps))

    # Cast back so the output keeps the input dtype, as the previous
    # zeros_like-based implementation did.
    warped = (data * curves).astype(data.dtype, copy=False)
    return np.clip(warped, 0, 1)

def augment_sample(sample):
    """
    Run one sample through the full augmentation chain:
    additive noise, then global scaling, then random-curve warping.
    """
    return generate_random_curves(scale(add_noise(sample)))

def _augment_seeded(sample, label, seed):
    """Seed NumPy's global RNG with `seed`, then augment one (sample, label) pair."""
    np.random.seed(seed)
    return augment_sample(sample), label


def augment_dataset(X, y, target_size=10000):
    """
    Augment the dataset up to 'target_size' samples.

    Randomly chosen originals (with replacement) are passed through
    `augment_sample` until the combined set has `target_size` rows. The
    returned arrays contain the originals first, followed by the new samples.

    Args:
        X: Array of samples; the first axis indexes samples.
        y: 2-D array of labels aligned with `X` (stacked with `np.vstack`).
        target_size: Desired total number of samples after augmentation.

    Returns:
        Tuple `(X_out, y_out)`. If `X` already holds at least `target_size`
        samples, copies of the inputs are returned unchanged.

    Raises:
        ValueError: If `X` is empty while `target_size` is positive.
    """
    num_original = len(X)
    num_augmented = target_size - num_original

    # Nothing to add: previously this crashed (negative size / stacking an
    # empty 1-D array onto X).
    if num_augmented <= 0:
        return np.array(X), np.array(y)
    if num_original == 0:
        raise ValueError("Cannot augment an empty dataset.")

    from joblib import Parallel, delayed

    indices = np.random.randint(0, num_original, size=num_augmented)
    # joblib's default backend runs tasks in separate worker processes whose
    # NumPy global RNG is not driven by this process's seed. Drawing one seed
    # per task here makes the augmentation repeatable for a fixed parent seed.
    # NOTE(review): with a thread-based backend this would reseed the shared
    # global RNG from several threads - confirm the backend if that is changed.
    seeds = np.random.randint(0, 2**31 - 1, size=num_augmented)

    # Only the selected sample/label is shipped to each task, not all of X.
    augmented_samples = Parallel(n_jobs=-1)(
        delayed(_augment_seeded)(X[idx], y[idx], int(seed))
        for idx, seed in zip(indices, seeds)
    )

    augmented_X = np.array([pair[0] for pair in augmented_samples])
    augmented_y = np.array([pair[1] for pair in augmented_samples])
    return np.vstack((X, augmented_X)), np.vstack((y, augmented_y))

###############################################################################
# Dataset Preparation Utilities
###############################################################################
def split_hand_data(X):
    """
    Cut the feature axis of X into the two per-hand halves.

    X is assumed to have shape (samples, frames, 126); features 0-62 form the
    first returned array and features 63 onward form the second.
    """
    first_hand = slice(None, 63)
    second_hand = slice(63, None)
    return X[:, :, first_hand], X[:, :, second_hand]

def load_or_create_subject_datasets(X, y, subject_number, split_indices, data_folder,
                                    target_size=10000):
    """
    Load, or create and cache, the original and augmented arrays of one subject.

    If all four cached `.npy` files for the subject exist in `data_folder`
    they are loaded as-is. Otherwise the subject's slice is cut from `X`/`y`,
    augmented with `augment_dataset`, and all four arrays are saved.

    To regenerate the data with different augmentation settings, delete the
    cached files (or the whole folder); they are rebuilt on the next call.
    Existing cached files are never checked against `target_size`.

    Args:
        X, y: Full sample and label arrays; only used when the cache is missing.
        subject_number: Identifier used in the cache file names.
        split_indices: `(start, end)` slice bounds of this subject within `X`/`y`.
        data_folder: Directory holding the cached files.
        target_size: Total sample count requested from `augment_dataset`.

    Returns:
        `((X_original, y_original), (X_aug, y_aug))`.
    """
    def cache_path(prefix, suffix):
        return os.path.join(data_folder, f'{prefix}_subject{subject_number}_{suffix}.npy')

    X_original_path = cache_path('X', 'original')
    y_original_path = cache_path('y', 'original')
    X_aug_path = cache_path('X', 'aug')
    y_aug_path = cache_path('y', 'aug')

    all_paths = (X_original_path, y_original_path, X_aug_path, y_aug_path)
    if all(os.path.exists(path) for path in all_paths):
        print(f"Loading datasets for Subject {subject_number} from disk.")
        X_subject_original = np.load(X_original_path)
        y_subject_original = np.load(y_original_path)
        X_subject_aug = np.load(X_aug_path)
        y_subject_aug = np.load(y_aug_path)
    else:
        print(f"Creating and saving datasets for Subject {subject_number}.")
        # Do not rely on the caller having created the folder already.
        os.makedirs(data_folder, exist_ok=True)

        start_idx, end_idx = split_indices
        X_subject_original = X[start_idx:end_idx]
        y_subject_original = y[start_idx:end_idx]
        np.save(X_original_path, X_subject_original)
        np.save(y_original_path, y_subject_original)

        X_subject_aug, y_subject_aug = augment_dataset(
            X_subject_original, y_subject_original, target_size
        )
        np.save(X_aug_path, X_subject_aug)
        np.save(y_aug_path, y_subject_aug)

    return (X_subject_original, y_subject_original), (X_subject_aug, y_subject_aug)

###############################################################################
# Model Building
###############################################################################
def transformer_decoder_block(inputs, encoder_output, head_size, num_heads, ff_dim,
                             dropout=0.1, name_prefix=""):
    """
    Build one decoder block on top of `inputs` and return its output tensor.

    Three stages, each starting with layer normalisation and ending with a
    residual addition:
      1) self-attention over `inputs`
      2) cross-attention, with `encoder_output` as the attended sequence
      3) a two-layer GELU feed-forward network projecting back to the
         last-axis size of `inputs`

    Every layer is named `<name_prefix>_<role>`.
    """
    def tag(role):
        return f"{name_prefix}_{role}"

    # --- Stage 1: self-attention; the residual uses the raw block input.
    normed_inputs = layers.LayerNormalization(epsilon=1e-6, name=tag("ln1"))(inputs)
    self_attended = layers.MultiHeadAttention(
        key_dim=head_size, num_heads=num_heads, dropout=dropout,
        name=tag("mha1")
    )(normed_inputs, normed_inputs)
    self_attended = layers.Dropout(dropout, name=tag("dropout1"))(self_attended)
    after_self = layers.Add(name=tag("add1"))([self_attended, inputs])

    # --- Stage 2: cross-attention; the residual adds the *normalised* tensor.
    normed_self = layers.LayerNormalization(epsilon=1e-6, name=tag("ln2"))(after_self)
    cross_attended = layers.MultiHeadAttention(
        key_dim=head_size, num_heads=num_heads, dropout=dropout,
        name=tag("mha2")
    )(normed_self, encoder_output)
    cross_attended = layers.Dropout(dropout, name=tag("dropout2"))(cross_attended)
    after_cross = layers.Add(name=tag("add2"))([cross_attended, normed_self])

    # --- Stage 3: feed-forward; again the residual adds the normalised tensor.
    normed_cross = layers.LayerNormalization(epsilon=1e-6, name=tag("ln3"))(after_cross)
    hidden = layers.Dense(ff_dim, activation='gelu', name=tag("dense1_ff"))(normed_cross)
    hidden = layers.Dropout(dropout, name=tag("dropout_ff"))(hidden)
    projected = layers.Dense(inputs.shape[-1], name=tag("dense2_ff"))(hidden)
    return layers.Add(name=tag("add_ff"))([projected, normed_cross])

def build_pretrained_structure(input_shape, head_size, num_heads, ff_dim,
                               dropout=0.1, num_decoders=1, bilstm_unit_size=256,
                               num_classes=9):
    """
    Creates a classification model using:
      - BiLSTMs for left/right inputs
      - Two decoder stacks: (A) left queries right, (B) right queries left
      - Concatenate outputs
      - Flatten -> Dense(128) -> Dense(64) -> Dense(num_classes, softmax)

    Args:
        input_shape: Shape of one hand's input, excluding the batch axis.
        head_size: `key_dim` of every multi-head attention layer.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of each decoder block's feed-forward network.
        dropout: Dropout rate used in the decoder blocks and the MLP head.
        num_decoders: Number of decoder blocks in each of the two stacks.
        bilstm_unit_size: Units per direction of each BiLSTM.
        num_classes: Size of the softmax output layer (default 9).

    Returns:
        An uncompiled `tf.keras.Model` with inputs
        `[left_hand_input, right_hand_input]`.
    """
    left_input = tf.keras.Input(shape=input_shape, name="left_hand_input")
    right_input = tf.keras.Input(shape=input_shape, name="right_hand_input")

    # BiLSTMs
    left_bilstm_output = layers.Bidirectional(
        layers.LSTM(bilstm_unit_size, return_sequences=True),
        name="left_bilstm"
    )(left_input)

    right_bilstm_output = layers.Bidirectional(
        layers.LSTM(bilstm_unit_size, return_sequences=True),
        name="right_bilstm"
    )(right_input)

    def decoder_stack(query_sequence, context_sequence, stack_id):
        """Chain `num_decoders` decoder blocks; all attend to `context_sequence`."""
        stacked = query_sequence
        for i in range(num_decoders):
            stacked = transformer_decoder_block(
                stacked, context_sequence,
                head_size=head_size, num_heads=num_heads,
                ff_dim=ff_dim, dropout=dropout,
                name_prefix=f"decoder{stack_id}_block_{i}"
            )
        return stacked

    # Decoder A: left as query, right as key
    xA = decoder_stack(left_bilstm_output, right_bilstm_output, "A")
    # Decoder B: right as query, left as key
    xB = decoder_stack(right_bilstm_output, left_bilstm_output, "B")

    # Concatenate outputs of both decoders
    combined = layers.Concatenate(axis=-1, name="concat_decoders")([xA, xB])
    flat = layers.Flatten(name="flatten")(combined)

    # MLP
    x = layers.Dense(128, activation='gelu', name="mlp_dense_128")(flat)
    x = layers.Dropout(dropout, name="mlp_dropout_128")(x)
    x = layers.Dense(64, activation='gelu', name="mlp_dense_64")(x)
    x = layers.Dropout(dropout, name="mlp_dropout_64")(x)
    output = layers.Dense(num_classes, activation='softmax', name="classification_output")(x)

    model = tf.keras.Model(
        inputs=[left_input, right_input], outputs=output,
        name="THOSnet_with_two_decoders"
    )
    return model

###############################################################################
# Load data. These are the original (non-augmented) arrays.
###############################################################################
# Raw sample and label arrays covering all subjects.
X = np.load('dataset_20240709_X_3096.npy')
y = np.load('dataset_20240709_y_3096.npy')

# Cache directory for the per-subject original/augmented arrays.
data_folder = 'processed_datasets'
os.makedirs(data_folder, exist_ok=True)

# Index boundaries at 60% and 80% of the data.
# NOTE(review): presumably the samples are stored subject by subject so that
# these cuts coincide with subject boundaries - confirm against the dataset.
split_60 = int(0.60 * len(X))
split_80 = int(0.80 * len(X))

# subject number -> {"aug": (X, y), "original": (X, y)}
subject_datasets = {}
for _subject, _bounds in (
    (1, (0, split_60)),
    (2, (split_60, split_80)),
    (3, (split_80, len(X))),
):
    _original, _augmented = load_or_create_subject_datasets(
        X, y, subject_number=_subject, split_indices=_bounds, data_folder=data_folder
    )
    subject_datasets[_subject] = {"aug": _augmented, "original": _original}

# Keep the individual per-subject names available at module level.
(X_subject1, y_subject1), (X_subject1_aug, y_subject1_aug) = (
    subject_datasets[1]["original"], subject_datasets[1]["aug"]
)
(X_subject2, y_subject2), (X_subject2_aug, y_subject2_aug) = (
    subject_datasets[2]["original"], subject_datasets[2]["aug"]
)
(X_subject3, y_subject3), (X_subject3_aug, y_subject3_aug) = (
    subject_datasets[3]["original"], subject_datasets[3]["aug"]
)

###############################################################################
# Hyperparameters and Subject Splits
# The lists below define the hyperparameter search space for THOSnet; edit them
# to try other settings. Refer to the original paper for the best-performing
# parameters.
# main() runs a full grid search over every combination of these values.
###############################################################################
# Per-hand model input shape handed to build_pretrained_structure.
# 63 matches the per-hand feature count produced by split_hand_data;
# 30 is presumably the number of frames per sample - TODO confirm.
input_shape_hand = (30, 63)

# Each list holds the candidate values for one hyperparameter.
bilstm_unit_sizes = [256]
head_sizes = [256,128]
num_heads_list = [8]
ff_dims = [64]
dropouts = [0.3]
# Three evenly spaced peak learning rates: 2e-4, 3.5e-4 and 5e-4.
learning_rates = np.linspace(2e-4, 5e-4, 3)
batch_sizes = [128]
# Number of training epochs for every run (a single value, not searched).
epochs = 200
num_decoders_list = [1]

# Leave-one-subject-out folds: train on two subjects, hold out the third.
combinations_subjects = [
    {'train_subjects': [1, 2], 'test_subject': 3},
    {'train_subjects': [1, 3], 'test_subject': 2},
    {'train_subjects': [2, 3], 'test_subject': 1}
]

###############################################################################
# Prepare data with an option to use original-only or original+aug
###############################################################################
def prepare_datasets(train_subjects, test_subject, use_augmentation=True):
    """
    Assemble the train / validation / test arrays for one subject fold.

    Training data is gathered from `train_subjects`: each subject contributes
    its "original" arrays, stacked with its "aug" arrays when
    `use_augmentation` is True. The held-out subject's original data is split
    50/50 (fixed random_state) into validation and test sets. Every feature
    array is finally divided into its two per-hand halves.

    NOTE(review): augment_dataset returns the originals followed by the new
    samples, so stacking "original" on top of "aug" here appears to include
    each original sample twice - confirm this is intended.

    Returns a 9-tuple:
        (train_hand1, train_hand2, y_train,
         val_hand1,   val_hand2,   y_val,
         test_hand1,  test_hand2,  y_test)
    """
    X_parts, y_parts = [], []
    for subject in train_subjects:
        entry = subject_datasets[subject]
        X_plain, y_plain = entry["original"]
        X_extra, y_extra = entry["aug"]

        if not use_augmentation:
            X_parts.append(X_plain)
            y_parts.append(y_plain)
            continue

        X_parts.append(np.vstack((X_plain, X_extra)))
        y_parts.append(np.vstack((y_plain, y_extra)))

    # Merge the contributions of all training subjects.
    X_train_full = np.vstack(X_parts)
    y_train_full = np.vstack(y_parts)

    # Validation and test both come from the held-out subject's original data.
    X_holdout, y_holdout = subject_datasets[test_subject]["original"]
    X_val_full, X_test_full, y_val_full, y_test_full = train_test_split(
        X_holdout, y_holdout, test_size=0.5, random_state=42
    )

    # Separate the two hands' feature blocks in every split.
    train_hands = split_hand_data(X_train_full)
    val_hands = split_hand_data(X_val_full)
    test_hands = split_hand_data(X_test_full)

    return (train_hands[0], train_hands[1], y_train_full,
            val_hands[0],   val_hands[1],   y_val_full,
            test_hands[0],  test_hands[1],  y_test_full)

###############################################################################
# Run Training/Evaluation (single GPU)
###############################################################################
def _make_lr_schedule(init_lr, total_epochs, warmup_epochs=20, min_lr=1e-8):
    """Return an epoch -> learning-rate function: linear warmup, then linear decay to `min_lr`."""
    decay_epochs = total_epochs - warmup_epochs

    def lr_schedule(epoch):
        if epoch < warmup_epochs:
            # linear warmup
            return float(init_lr * (epoch + 1) / warmup_epochs)
        # linear decay
        decayed = init_lr - (init_lr - min_lr) * (epoch - warmup_epochs + 1) / decay_epochs
        return float(max(decayed, min_lr))

    return lr_schedule


def _make_dataset(hand1, hand2, labels, batch_size, shuffle=False):
    """Wrap two-hand inputs and labels in a batched, prefetched tf.data pipeline."""
    dataset = tf.data.Dataset.from_tensor_slices(((hand1, hand2), labels))
    if shuffle:
        # The buffer spans the whole set. A smaller buffer only mixes nearby
        # elements, and the training array is laid out subject by subject.
        dataset = dataset.shuffle(buffer_size=len(labels))
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)


def _run_combination(combo, hparams, use_augmentation=True):
    """
    Train one model on a subject fold, report its metrics, and return
    `(val_accuracy, test_loss, test_accuracy, f1)`.
    """
    (X_train_hand1, X_train_hand2, y_train,
     X_val_hand1,   X_val_hand2,   y_val,
     X_test_hand1,  X_test_hand2,  y_test
    ) = prepare_datasets(
        combo['train_subjects'],
        combo['test_subject'],
        use_augmentation=use_augmentation
    )

    # Many models are built in a loop; drop Keras' global state from the
    # previous run so memory use does not keep growing.
    tf.keras.backend.clear_session()

    thosnet_model = build_pretrained_structure(
        input_shape_hand,
        head_size=hparams['head_size'],
        num_heads=hparams['num_heads'],
        ff_dim=hparams['ff_dim'],
        dropout=hparams['dropout'],
        num_decoders=hparams['num_decoders'],
        bilstm_unit_size=hparams['bilstm_unit_size']
    )

    for layer in thosnet_model.layers:
        layer.trainable = True

    # The 0.0 is only a placeholder: the scheduler callback below sets the
    # learning rate at the start of every epoch.
    thosnet_model.compile(
        loss="categorical_crossentropy",
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.0),
        metrics=["accuracy"]
    )

    callbacks = [
        tf.keras.callbacks.LearningRateScheduler(
            _make_lr_schedule(hparams['learning_rate'], epochs), verbose=0
        )
    ]

    batch_size = hparams['batch_size']
    train_dataset = _make_dataset(X_train_hand1, X_train_hand2, y_train, batch_size, shuffle=True)
    val_dataset = _make_dataset(X_val_hand1, X_val_hand2, y_val, batch_size)
    test_dataset = _make_dataset(X_test_hand1, X_test_hand2, y_test, batch_size)

    # Train
    history = thosnet_model.fit(
        train_dataset,
        epochs=epochs,
        validation_data=val_dataset,
        callbacks=callbacks,
        verbose=0
    )
    val_acc_final = history.history['val_accuracy'][-1]

    # Test
    y_pred_probs = thosnet_model.predict(test_dataset, verbose=0)
    y_pred = np.argmax(y_pred_probs, axis=1)
    y_true = np.argmax(y_test, axis=1)

    # Pass the full label set explicitly so the matrix and report keep one
    # row per class even if a class is missing from this test split.
    num_classes = y_test.shape[1]
    class_labels = list(range(num_classes))

    cm = confusion_matrix(y_true, y_pred, labels=class_labels)
    test_loss, test_accuracy = thosnet_model.evaluate(test_dataset, verbose=0)
    f1 = f1_score(y_true, y_pred, average='weighted')

    # ===== PLOT THE CONFUSION MATRIX =====
    fig = plt.figure(figsize=(7, 5))
    sns.heatmap(cm, annot=True, cmap="Blues", fmt='g')
    plt.title(f'Confusion Matrix (Train {combo["train_subjects"]} -> Test {combo["test_subject"]})')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()
    plt.close(fig)  # release the figure so they do not pile up over the grid search

    # ===== PRINT A CLASSIFICATION REPORT (precision, recall, f1) =====
    print("\nClassification Report:")
    print(classification_report(
        y_true, y_pred,
        labels=class_labels,
        digits=4,
        target_names=[f"Gesture_{i}" for i in class_labels]
    ))

    print(f"Final Validation Accuracy: {val_acc_final:.4f}")
    print(f"Test Loss: {test_loss:.4f}, "
          f"Test Accuracy: {test_accuracy:.4f}, "
          f"F1 Score: {f1:.4f}")

    return val_acc_final, test_loss, test_accuracy, f1


def main():
    """
    Grid-search every hyperparameter combination; for each one, train and
    evaluate on all subject folds and print the per-fold and averaged metrics.
    """
    from itertools import product

    # product() iterates with the right-most list varying fastest, i.e. in
    # the same order as the equivalent nested for-loops.
    search_space = product(
        bilstm_unit_sizes, head_sizes, num_heads_list, ff_dims,
        dropouts, learning_rates, batch_sizes, num_decoders_list,
    )

    for (bilstm_unit_size, head_size, num_heads, ff_dim,
         dropout, learning_rate, batch_size, num_decoders) in search_space:

        print("\n======================================")
        print(f"Testing hyperparameters: "
              f"bilstm_unit_size={bilstm_unit_size}, "
              f"head_size={head_size}, num_heads={num_heads}, "
              f"ff_dim={ff_dim}, dropout={dropout}, "
              f"lr={learning_rate:.5f}, batch_size={batch_size}, "
              f"epochs={epochs}, num_decoders={num_decoders}")
        print("======================================\n")

        hparams = {
            'bilstm_unit_size': bilstm_unit_size,
            'head_size': head_size,
            'num_heads': num_heads,
            'ff_dim': ff_dim,
            'dropout': dropout,
            'learning_rate': learning_rate,
            'batch_size': batch_size,
            'num_decoders': num_decoders,
        }

        test_losses = []
        test_accuracies = []
        f1_scores = []
        val_accuracies = []

        for idx, combo in enumerate(combinations_subjects):
            print(f"\nRunning combination {idx+1}: "
                  f"Train on {combo['train_subjects']}, "
                  f"Test on {combo['test_subject']}")

            # You can toggle `use_augmentation` to True/False here
            val_acc_final, test_loss, test_accuracy, f1 = _run_combination(
                combo, hparams, use_augmentation=True  # <-- Toggle as needed
            )

            val_accuracies.append(val_acc_final)
            test_losses.append(test_loss)
            test_accuracies.append(test_accuracy)
            f1_scores.append(f1)

        # Summaries over the 3 subject combos
        avg_test_loss = np.mean(test_losses)
        avg_test_accuracy = np.mean(test_accuracies)
        avg_f1_score = np.mean(f1_scores)
        avg_val_acc = np.mean(val_accuracies)

        print("\n*** AVERAGED RESULTS FOR THIS HYPERPARAMETER SET ***")
        print(f"Average Test Loss: {avg_test_loss:.4f}")
        print(f"Average Test Accuracy: {avg_test_accuracy:.4f}")
        print(f"Average F1 Score: {avg_f1_score:.4f}")
        print(f"Average Validation Accuracy (over 3 combos): {avg_val_acc:.4f}")
        print("****************************************************\n")

if __name__ == "__main__":
    main()
