In [None]:
from emg_fatigue.utils.load_emg_data import load_all_participant_data
from emg_fatigue.utils.process_emg_data import process_all_participant_data
from emg_fatigue.utils.create_loocv_dataset import create_loocv_dataset
from emg_fatigue.modeling.train import train_model, fine_tune_model
from emg_fatigue.modeling.evaluate import evaluate_model
from emg_fatigue.modeling.build.rnn_model import build_lstm_model
from emg_fatigue.modeling.build.mlp_model import build_mlp_model
from emg_fatigue.modeling.build.transformer_model import build_transformer_model
from emg_fatigue.modeling.build.gru_model import build_gru_model
from emg_fatigue.plots.visualize_model_predictions import visualize_model_predictions
from emg_fatigue.config import BATCH_SIZE, PADDING_VALUE

In [None]:
raw_data = load_all_participant_data()
processed_data = process_all_participant_data(participant_data=raw_data)

# --- Define parameters ---
# Number of recordings per test subject reserved for fine-tuning.
# If N > 1, the N-th recording becomes fine-tune validation; presumably the
# other N-1 are fine-tune training (e.g. N=3 -> 2 FT-train + 1 FT-val).
# TODO confirm the exact assignment in create_loocv_dataset.
NUM_FINE_TUNING_RECS = 3
# NOTE(review): 1 epoch for each phase looks like a smoke-test setting;
# increase for real experiments.
INITIAL_EPOCHS = 1
FINE_TUNE_EPOCHS = 1  # Might need more epochs if FT val set is small
FINE_TUNE_LR = 1e-5
# How many participants (taken from the END of the ID list) are held out
# for validation and for testing. Defaults reproduce the original split:
# train = ids[:-3], val = ids[-3:-1], test = [ids[-1]].
NUM_VAL_PARTICIPANTS = 2
NUM_TEST_PARTICIPANTS = 1


# --- Create the train / validation / test split ---
# This is a single fixed hold-out split (the last participant(s) are the test
# subjects), not a full leave-one-out loop over every participant.
# NOTE(review): the split follows the key order of `processed_data` as returned
# by the loader; sort the IDs if that order is not guaranteed to be stable.
all_participant_ids = list(processed_data.keys())
if NUM_VAL_PARTICIPANTS < 1 or NUM_TEST_PARTICIPANTS < 1:
    raise ValueError("Need at least one validation and one test participant.")
num_held_out = NUM_VAL_PARTICIPANTS + NUM_TEST_PARTICIPANTS
if len(all_participant_ids) <= num_held_out:
    # Fail fast: otherwise train_ids would silently be empty.
    raise ValueError(
        f"Need more than {num_held_out} participants so at least one is left "
        f"for training; got {len(all_participant_ids)}."
    )
train_ids = all_participant_ids[:-num_held_out]
val_ids = all_participant_ids[-num_held_out:-NUM_TEST_PARTICIPANTS]
test_ids = all_participant_ids[-NUM_TEST_PARTICIPANTS:]

(train_ds,
 val_ds,
 fine_tune_train_ds,  # test-subject recordings used to fine-tune (may be None)
 fine_tune_val_ds,    # held-out test-subject recording(s) to validate fine-tuning
 test_ds,
 input_shape,
 output_shape,
 norm_mean,
 norm_std,
 test_recording_indices) = create_loocv_dataset(
    processed_data=processed_data,
    train_participant_ids=train_ids,
    validation_participant_ids=val_ids,
    test_participant_ids=test_ids,
    batch_size=BATCH_SIZE,
    padding_value=PADDING_VALUE,
    normalize=True,
    augment=True,  # Augmentation still applies only to train_ds
    num_fine_tuning_recordings_per_subject=NUM_FINE_TUNING_RECS
)

# Check if datasets were created successfully
if train_ds is None or val_ds is None or test_ds is None:
    raise RuntimeError("Failed to create one or more essential datasets.")

print(f"Fine-tuning train dataset created: {fine_tune_train_ds is not None}")
print(f"Fine-tuning validation dataset created: {fine_tune_val_ds is not None}")




In [None]:
# --- Build model ---
# Select the architecture by key instead of toggling commented-out lines.
MODEL_TYPE = "mlp"  # one of: "lstm", "mlp", "transformer", "gru"
MODEL_BUILDERS = {
    "lstm": build_lstm_model,
    "mlp": build_mlp_model,
    "transformer": build_transformer_model,
    "gru": build_gru_model,
}
if MODEL_TYPE not in MODEL_BUILDERS:
    raise ValueError(
        f"Unknown MODEL_TYPE {MODEL_TYPE!r}; expected one of {sorted(MODEL_BUILDERS)}"
    )

# NOTE(review): `test_id` receives the whole `test_ids` list rather than a
# single ID; presumably it is only used to tag the model name — confirm.
model, model_name = MODEL_BUILDERS[MODEL_TYPE](
    input_shape=input_shape, test_id=test_ids, padding_value=PADDING_VALUE
)


# --- Initial Training ---
print(f"--- Starting Initial Training for {model_name} ---")
initial_history = train_model(
    model=model,
    train_ds=train_ds,
    val_ds=val_ds,  # Use main validation set for initial training
    model_name=model_name,
    epochs=INITIAL_EPOCHS
)

if initial_history is None:
    raise RuntimeError(f"Initial training failed for {model_name}")

# --- Fine-tuning (Optional) ---
# The evaluation label stays `model_name` unless fine-tuning succeeds.
fine_tuned_model_name_for_eval = model_name
if fine_tune_train_ds is not None:
    print(f"--- Starting Fine-tuning for {model_name} ---")

    fine_tune_history = fine_tune_model(
        model=model,
        fine_tune_train_ds=fine_tune_train_ds,  # test-subject FT train set
        fine_tune_val_ds=fine_tune_val_ds,      # test-subject FT validation set (may be None)
        model_name=model_name,
        fine_tune_epochs=FINE_TUNE_EPOCHS,
        fine_tune_lr=FINE_TUNE_LR
    )
    if fine_tune_history is None:
        # NOTE(review): if fine_tune_model updates `model` in place before
        # failing, the "initial model" below may already be partially tuned.
        print(f"Warning: Fine-tuning failed or was skipped for {model_name}. Proceeding with initial model.")
    else:
        print("Fine-tuning complete. Using fine-tuned model.")
        # Drop a saved-model file extension (if any) before adding the suffix.
        if model_name.endswith((".h5", ".keras")):
            base_name = model_name.rsplit('.', 1)[0]
        else:
            base_name = model_name
        fine_tuned_model_name_for_eval = f"{base_name}_finetuned"

else:
    print(f"--- Skipping Fine-tuning (no fine_tune_train_ds available) for {model_name} ---")

# --- Evaluation ---
print(f"--- Evaluating final model: {fine_tuned_model_name_for_eval} ---")
final_metrics = evaluate_model(model=model, test_ds=test_ds, model_name=fine_tuned_model_name_for_eval)
print(f"Final Evaluation Metrics: {final_metrics}")

In [None]:
# --- Visualize Predictions ---
# Plot predictions of the final (possibly fine-tuned) model on the held-out
# test participant(s), using the same normalization statistics and recording
# indices produced by the dataset-creation cell.
visualization_kwargs = dict(
    model=model,
    model_name=fine_tuned_model_name_for_eval,
    processed_data=processed_data,
    test_participant_ids=test_ids,
    input_shape=input_shape,
    norm_mean=norm_mean,
    norm_std=norm_std,
    test_recording_indices=test_recording_indices,
)
print(f"--- Visualizing predictions for model: {fine_tuned_model_name_for_eval} ---")
visualize_model_predictions(**visualization_kwargs)