In [None]:
from emg_fatigue.utils.load_emg_data import load_all_participant_data
from emg_fatigue.utils.process_emg_data import process_all_participant_data
from emg_fatigue.utils.create_loocv_dataset import create_loocv_dataset
from emg_fatigue.modeling.train import train_model
from emg_fatigue.modeling.evaluate import evaluate_model
from emg_fatigue.modeling.build.rnn_model import build_lstm_model
from emg_fatigue.plots.visualize_model_predictions import visualize_model_predictions
from emg_fatigue.config import BATCH_SIZE, PADDING_VALUE

In [None]:
# --- Load and preprocess every participant's recordings ---
raw_data = load_all_participant_data()
processed_data = process_all_participant_data(raw_data)


# --- Create train / validation / test datasets (single hold-out split) ---
# NOTE(review): the helper is named create_loocv_dataset, but this cell builds
# one fixed split rather than looping over every left-out participant.
# The split follows the key iteration order of `processed_data`:
#   * last participant          -> test
#   * the two participants prior -> validation
#   * everyone else             -> training
all_participant_ids = list(processed_data.keys())

# The negative slices below silently produce an EMPTY training split when
# fewer than 4 participants are available (1 test + 2 validation + >=1 train),
# so fail early with an explicit message instead of training on nothing.
min_required_participants = 4
if len(all_participant_ids) < min_required_participants:
    raise ValueError(
        f"Need at least {min_required_participants} participants to build the "
        f"train/validation/test split, but only {len(all_participant_ids)} "
        f"were loaded: {all_participant_ids}"
    )

train_ids = all_participant_ids[:-3]
val_ids = all_participant_ids[-3:-1]
test_ids = [all_participant_ids[-1]]

# `norm_mean` / `norm_std` are returned alongside the datasets and are reused
# by the visualization cell; presumably they are the normalization statistics
# computed because normalize=True -- confirm against create_loocv_dataset.
train_ds, val_ds, test_ds, input_shape, output_shape, norm_mean, norm_std = create_loocv_dataset(
    processed_data,
    train_participant_ids=train_ids,
    validation_participant_ids=val_ids,
    test_participant_ids=test_ids,
    batch_size=BATCH_SIZE,
    padding_value=PADDING_VALUE,
    normalize=True
)



In [None]:
# Label handed to both the training and evaluation helpers for this run.
model_name = "lstm_fatigue_model"

# Upper bound on training epochs passed to train_model.
num_epochs = 100

# Build the LSTM from the shapes reported by the dataset-creation step; the
# padding value is forwarded so the model is built with the same value that
# was used when the datasets were created.
lstm_model = build_lstm_model(
    input_shape,
    output_shape,
    padding_value=PADDING_VALUE,
)

# Fit on the training split (validation split supplied alongside), then
# evaluate on the held-out test split.
train_model(lstm_model, train_ds, val_ds, model_name, epochs=num_epochs)
evaluate_model(lstm_model, test_ds, model_name)

In [None]:
# Collect the arguments for the prediction plot in one place: the trained
# model, the processed recordings, the held-out participant(s), and the
# input shape / normalization values returned by the dataset-creation step.
prediction_plot_kwargs = {
    "model": lstm_model,
    "processed_data": processed_data,
    "test_participant_ids": test_ids,
    "input_shape": input_shape,
    "norm_mean": norm_mean,
    "norm_std": norm_std,
}

# Render the model's predictions for the test participant(s).
visualize_model_predictions(**prediction_plot_kwargs)