In [None]:
from pathlib import Path

import pandas as pd
import seaborn as sns
from plotnine import aes, geom_line, geom_point, ggplot, labs, theme, theme_bw

from lstm_autoencoder.data.preprocessing import scale_data, train_test_val_split
from lstm_autoencoder.data.simulation import simulate_ecg_data
from lstm_autoencoder.data.windowed_dataset import get_windowed_datasets
from lstm_autoencoder.models.autoencoder import create_prediction, train_lstm_autoencoder

# Toggle: when True, the preprocessing cell rescales train/val/test via scale_data.
scale: bool = True

## Data Simulation

In [None]:
# Simulate an ECG trace and keep only the amplitude channel for training.
# NOTE(review): `fs` is presumably the sampling frequency — confirm in
# simulate_ecg_data. No seed is set here; if the simulation is stochastic,
# the splits (and everything downstream) will differ between runs.
ecg_sim = simulate_ecg_data(n_beats=500, fs=50, peak_width_factor=10)
df = ecg_sim.loc[:, ["ecg_amplitude"]]

train, val, test = train_test_val_split(df)

# number of rows in the training split
train.shape[0]

In [None]:
# Quick look at the start of the training split.
train_preview = train.iloc[:200]
ax = sns.lineplot(data=train_preview, x=train_preview.index, y="ecg_amplitude")
# trailing `;` suppresses the list-of-artists repr returned by Axes.set
ax.set(title="Training split: first 200 samples", xlabel="Index", ylabel="ecg_amplitude");

## Data Preprocessing

In [None]:
# Rescale the three splits with the project's scale_data helper.
# NOTE(review): scaler_path presumably receives the fitted scaler — confirm.
# NOTE(review): the splits are passed as (train, test, val) but the result is
# unpacked as (train, val, test), while train_test_val_split returns
# (train, val, test). Confirm against scale_data's signature that val and test
# are not silently swapped here.
# NOTE(review): this rebinds train/val/test, so re-running this cell on its own
# scales already-scaled data — re-run from the simulation cell instead.
if scale:
    scaler_filename = "../data/02_intermediate/scaler.pkl"
    train, val, test = scale_data(train, test, val, scaler_path=scaler_filename)

    train_preview = train.iloc[:100]
    ax = sns.lineplot(data=train_preview, x=train_preview.index, y="ecg_amplitude")
    ax.set(
        title="Scaled training split: first 100 samples",
        xlabel="Index",
        ylabel="ecg_amplitude",
    )

In [None]:
# Windowing configuration for get_windowed_datasets. window_size is reused by
# the training cell. NOTE(review): the split_model_* entries are interpreted
# inside get_windowed_datasets; their meaning is not visible here — confirm
# before tuning.
prep_params = dict(
    window_size=60,
    window_shift=1,
    split_model_method="kendall",
    split_model_th=0.9,
    split_model_th_aux=0.9,
)

tf_train, tf_val, tf_test = get_windowed_datasets(train, val, test, prep_params)

## Model Training

In [None]:
# Training configuration forwarded as-is to train_lstm_autoencoder.
train_params = dict(
    batch_size=256,
    shuffle=False,
    min_epochs=10,
    max_epochs=100,
    train_device="cpu",
    train_workers=1,
    load_workers=0,
)

# Train on the windowed train split and validate on the windowed val split,
# reusing the same window size that was used for windowing above.
train_windows = tf_train.data_windowed
val_windows = tf_val.data_windowed
model = train_lstm_autoencoder(
    train_windows,
    val_windows,
    strategy="auto",
    window_size=prep_params["window_size"],
    train_params=train_params,
    save_path="../data/03_models",
    compression_factor=1.25,
)

## Inference & Plotting

In [None]:
df_pred = create_prediction(
    model, tf_test, save_name="test_prediction", save_fig=False, use_averaging=True
)
df_test_vs_pred = pd.concat(
    [
        df_pred[["ecg_amplitude"]].assign(type="actual"),
        df_pred[["ecg_amplitude_pred"]]
        .rename(columns={"ecg_amplitude_pred": "ecg_amplitude"})
        .assign(type="prediction"),
    ]
).sort_index()

# fully reset index for plotting
df_plt_ = df_test_vs_pred[:500].reset_index(drop=True).reset_index()
(
    ggplot(df_plt_, aes(x="index", y="ecg_amplitude", color="type"))
    + geom_line(size=1, alpha=0.8)
    + geom_point(size=1, alpha=0.8)
    + theme_bw()
    + theme(figure_size=(10, 6), legend_position="bottom")
    + labs(title="Actual vs Predicted (averaging)", x="Index", y="Value", color="")
)

In [None]:
df_pred = create_prediction(
    model, tf_test, save_name="test_prediction", save_fig=False, use_averaging=False
)
df_test_vs_pred = pd.concat(
    [
        df_pred[["ecg_amplitude"]].assign(type="actual"),
        df_pred[["ecg_amplitude_pred"]]
        .rename(columns={"ecg_amplitude_pred": "ecg_amplitude"})
        .assign(type="prediction"),
    ]
).sort_index()

# fully reset index for plotting
df_plt_ = df_test_vs_pred[:500].reset_index(drop=True).reset_index()
(
    ggplot(df_plt_, aes(x="index", y="ecg_amplitude", color="type"))
    + geom_line(size=1, alpha=0.8)
    + geom_point(size=1, alpha=0.8)
    + theme_bw()
    + theme(figure_size=(10, 6), legend_position="bottom")
    + labs(title="Actual vs Predicted (last observation)", x="Index", y="Value", color="")
)

In [None]:
# Save the actual-vs-predicted comparison to disk. df_plt_ is whatever the
# previous cell produced (the use_averaging=False predictions when the
# notebook is run top to bottom).
figure_path = Path("../figures/actual_vs_predicted.png")
# Create the target directory so a fresh checkout does not fail on save.
figure_path.parent.mkdir(parents=True, exist_ok=True)

actual_vs_pred_plot = (
    ggplot(df_plt_, aes(x="index", y="ecg_amplitude", color="type"))
    + geom_line(size=1)
    + geom_point(size=1)
    + theme_bw()
    + theme(figure_size=(10, 7), legend_position="bottom")
    + labs(title="Actual vs Predicted", x="Index", y="Value", color="")
)

# ggplot.save renders at the theme's figure_size; verbose=False drops the
# "Saving ... image" message.
actual_vs_pred_plot.save(figure_path, dpi=300, verbose=False)

In [None]:
# Display seaborn's current colour palette.
# NOTE(review): not used by the analysis above — looks like leftover
# exploration; consider removing.
current_palette = sns.color_palette()
current_palette