In [2]:
import os
from pathlib import Path

from generative_ecg.dataset.load_data import load_dataset
from generative_ecg.train.train_cnn_model import train_discriminator
from generative_ecg.train.train_vae_model import train_vae
from generative_ecg.generate.generate_ecg import generate_and_save_ecgs

# Dataset and output locations.
# Each path can be overridden with an environment variable so the notebook can
# run on another machine without editing this cell; when the variable is unset
# the original author's local path is used, so existing behaviour is unchanged.
#   PTB_XL_PATH      -> directory of the raw PTB-XL dataset
#   ECG_RESULT_PATH  -> directory where results are written
ptb_xl_path = os.environ.get(
    "PTB_XL_PATH",
    r"C:\Users\Aaron Zhang\Desktop\ecg_project\mlscience-ekgs\mlscience_ekgs\Data\raw\ptb-xl",
)
result_path = os.environ.get(
    "ECG_RESULT_PATH",
    r"C:\Users\Aaron Zhang\Desktop\ecg_project\mlscience-ekgs\mlscience_ekgs\Data\results",
)

# Generative ECG Example

This notebook is a reference for using the functionality of the Generative ECG repository. It walks through four functions:
1. `load_dataset` - takes the path of the dataset, plus other optional arguments, loads the dataset into memory, and saves the data in a user-specified directory.

2. `train_discriminator` - trains a CNN discriminator model and saves it to a user-specified path.

3. train_vae - this function trains a generative VAE model that will be used to generate the ECGs

4. `generate_and_save_ecgs` - plots the ECGs produced by the input VAE, and then saves the ECGs to a user-specified path.

In [None]:
# NOTE(review): this cell has no execution count and references `gen_ecg`,
# `filepath`, and `dirpath`, none of which are defined or imported in the
# visible cells, so it fails on a fresh kernel as written. It reads as a
# step-by-step sketch of the preprocessing pipeline; the "FINISHED, in <module>"
# markers presumably record which package module each step was moved into
# (TODO confirm with the author).
#
# Step 1: load the signals from `filepath`. (FINISHED, in process_data)
x_signal, y_signal = gen_ecg.load_signals(filepath)
# Step 2: compute peaks from the signals, passing sampling_rate=500. (FINISHED, in segment_ecg)
x_peaks = gen_ecg.get_peaks(x_signal, sampling_rate=500)
# Step 3: turn signals + peaks into per-beat arrays with tmax=400
# (presumably the beat length in samples; verify). (FINISHED, in segment_ecg)
x_beats, y_beats = gen_ecg.process(x_signal, x_peaks, y_signal, tmax=400)
# Step 4: filter the beats using the thresholds passed below; the exact meaning
# of each threshold is defined by filter_beats, not here. (FINISHED, in segment_ecg)
x_beats, y_beats = gen_ecg.filter_beats(x_beats, y_beats, drop_first=True, drop_last=True, range_min=0.5, sd_min=0.06, autocorr_min=0.75)
# Step 5: apply `project` with tolerance 1e-6. (FINISHED, in process_data)
x_beats, y_beats = gen_ecg.project(x_beats, y_beats, tol=1e-6)

# Finally, a save/load pair of helpers lets later runs skip the steps above by
# reading the beat dataset back from `dirpath`.
gen_ecg.save_beat_dataset(x_beats, y_beats, dirpath)
x_beats, y_beats = gen_ecg.load_beat_dataset(dirpath)

`load_dataset` automatically returns the processed train and test features. If `load_dataset` has already been run once, the `processed` flag can be used to load the saved train and test files automatically.

In [None]:
# Train the CNN on the beat dataset stored under `dirpath`.
# NOTE(review): `gen_ecg`, `optax`, and `dirpath` are not imported or defined in
# the visible cells; confirm they are provided before running this cell.
X_beats, y_beats = gen_ecg.load_beat_dataset(dirpath)

# Convolutional model and its loss, both configured for 4 outputs.
model = gen_ecg.models.ECGConv(
    tmax=400,
    n_channels=12,
    n_layers_conv=2,
    n_layers_dense=2,
    n_outputs=4,
)
loss_fn = gen_ecg.models.rmse(n_outputs=4)

# Learning-rate schedule: warm up from 1e-3 to 1e-2, then cosine-decay back to 1e-3.
_schedule_kwargs = {
    "init_value": 1e-3,
    "peak_value": 1e-2,
    "warmup_steps": 50,
    "decay_steps": 50,
    "end_value": 1e-3,
}
lr_schedule = optax.warmup_cosine_decay_schedule(**_schedule_kwargs)

state = gen_ecg.train.train_cnn(
    X_beats,
    y_beats,
    model,
    loss_fn,
    lr_schedule,
)

Epoch      2 | RMSE:   25.88020: 100%|██████████| 2/2 [00:51<00:00, 25.83s/it]


Test loss: 29.4504


In [7]:
# (Disabled) del y_tr, X_te, y_te
# Wrap the results location in a Path before handing it to train_vae.
# Re-running this cell is safe: Path(Path(...)) yields an equal Path.
result_path = Path(result_path)
# NOTE(review): `X_tr` and `ckpt_dir` are not defined in any visible cell;
# confirm an earlier cell provides them.
result = train_vae(
    X_tr,
    ckpt_dir,
    result_path,
    beat_segment=True,
    processed=True,
    n_epochs=5,
)

Encoder params size: (249968,)
Decoder params size: (351600,)


Epoch 4 average loss: 5.527275562286377: 100%|██████████| 5/5 [11:21<00:00, 136.29s/it]


In [8]:
# Produce ECGs from the VAE training result and save them under `result_path`.
generate_and_save_ecgs(
    X_tr,
    result,
    result_path,
    processed=True,
)

100%|██████████| 5/5 [00:03<00:00,  1.66it/s]
