In [None]:
from generative_ecg.dataset import load_signals, get_peaks, segment, filter_beats, project, save_beat_dataset, load_beat_dataset
from generative_ecg.train import train_cnn
import tqdm
import jax

from pathlib import Path

# Machine-specific locations of the raw PTB-XL records and of the cached results.
# NOTE(review): absolute, user-specific paths; consider making these configurable.
ptb_xl_path = r"C:\Users\Aaron Zhang\Desktop\College\Senior Year\generative_ecg\examples\ptb_xl\data"
result_path = r"C:\Users\Aaron Zhang\Desktop\College\Senior Year\generative_ecg\examples\ptb_xl\results"

SAMPLING_RATE = 500  # passed to both load_signals and get_peaks
BEAT_LENGTH = 400    # tmax passed to segment; last axis of every stored beat
N_CHANNELS = 9       # middle axis of every stored beat (shape (N, 9, 400))

x_signals, y_signals = load_signals(filepath=ptb_xl_path, sampling_rate=SAMPLING_RATE)

# Accumulate per-record beats in lists and concatenate once after the loop.
# Concatenating inside the loop re-copies the whole dataset on every record (quadratic).
# The lists are seeded with the same empty arrays the dataset used to start from,
# so the final dtype promotion is unchanged.
x_parts = [jax.numpy.zeros((0, N_CHANNELS, BEAT_LENGTH))]
y_parts = [jax.numpy.zeros((0,))]
n_skipped = 0

for i, x_signal in enumerate(tqdm.tqdm(x_signals, desc="Processing and Filtering Beats")):
    try:
        y_signal = y_signals[i]
        x_peaks = get_peaks(x_signal, sampling_rate=SAMPLING_RATE)
        x_seg, x_windows = segment(x_signal, x_peaks, tmax=BEAT_LENGTH)
        x_filter, y_filter = filter_beats(
            x_seg, y_signal, x_windows, x_peaks,
            drop_first=True, drop_last=True, range_min=0.5, sd_min=0.06,
        )
        x_proj, y_proj = project(x_filter, y_filter, tol=1e-6)

        if x_proj.shape[0] == 0:
            continue  # every beat of this record was filtered out

        # Reject records whose beats or targets would not line up with the dataset.
        # Previously a failing y-concat could leave x_beats and y_beats with different lengths.
        if x_proj.shape[1:] != (N_CHANNELS, BEAT_LENGTH) or y_proj.shape != (x_proj.shape[0],):
            n_skipped += 1
            continue

        x_parts.append(x_proj)
        y_parts.append(y_proj)
    except Exception:
        # Best-effort: skip records that fail processing.
        # Unlike a bare `except:`, this does not swallow KeyboardInterrupt,
        # so the long loop can still be stopped.
        n_skipped += 1
        continue

x_beats = jax.numpy.concatenate(x_parts, axis=0)
y_beats = jax.numpy.concatenate(y_parts, axis=0)

print(f"Skipped {n_skipped} of {len(x_signals)} records")
print(x_beats.shape, y_beats.shape)
# Cache the beats so later sessions can call load_beat_dataset instead of redoing the work above.
save_beat_dataset(x_beats, y_beats, filepath=result_path)

Loading data from records: 100%|██████████| 21799/21799 [03:28<00:00, 104.64it/s]
  mrrs /= th2
  mrrs /= th2
  mrrs /= th2
  mrrs /= th2
  mrrs /= th2
  mrrs /= th2
  mrrs /= th2
  mrrs /= th2
  mrrs /= th2
  mrrs /= th2
  mrrs /= th2
  warn(
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  warn(
Processing and Filtering Beats: 100%|██████████| 21799/21799 [30:19<00:00, 11.98it/s]


(34064, 9, 400) (34064,)


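# Generative ECG Example

This notebook does four things:

1. Loads the cached PTB-XL beat dataset.
2. Trains a CNN regressor on the beats.
3. Trains a VAE that is given the CNN's predictions.
4. Samples synthetic ECGs from the trained VAE.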

In [1]:
from generative_ecg.models import ECGConv, rmse_loss
from generative_ecg.dataset import load_signals, get_peaks, segment, filter_beats, project, save_beat_dataset, load_beat_dataset
from generative_ecg.train import train_cnn
import optax
import sklearn

# `import sklearn` does not necessarily load its `model_selection` subpackage.
# Import it explicitly instead of relying on some other module having done so.
import sklearn.model_selection

# Same machine-specific locations used when the beat dataset was built and cached.
ptb_xl_path = r"C:\Users\Aaron Zhang\Desktop\College\Senior Year\generative_ecg\examples\ptb_xl\data"
result_path = r"C:\Users\Aaron Zhang\Desktop\College\Senior Year\generative_ecg\examples\ptb_xl\results"

# Beats and targets previously written by save_beat_dataset.
x_beats, y_beats = load_beat_dataset(filepath=result_path)

# CNN with a single output, trained with RMSE loss.
model = ECGConv(output_dim=1)
loss_fn = rmse_loss
# Linear warmup from 1e-3 to 1e-2 over 50 steps, then cosine decay to 1e-3.
# Optax's decay_steps is the total schedule length, so the decay ends at step 1000.
lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=1e-3,
    peak_value=1e-2,
    warmup_steps=50,
    decay_steps=1000,
    end_value=1e-3
)

# Reproducible train/test split (sklearn's default test_size of 0.25).
X_tr, X_te, y_tr, y_te = sklearn.model_selection.train_test_split(x_beats, y_beats, random_state=42)

In [2]:
# Train the CNN only when its checkpoint directory does not exist yet.
# The following cell restores the CNN from this directory, so a fresh Restart & Run All
# needs it produced here. train_cnn presumably writes its checkpoint to ckpt_dir — confirm.
import os

cnn_ckpt_dir = result_path + "/cnn_model_checkpoint/"
if not os.path.isdir(cnn_ckpt_dir):
    state = train_cnn(X_tr, X_te, y_tr, y_te, model, loss_fn, lr_schedule, ckpt_dir=cnn_ckpt_dir, batch_size=64, n_epochs=100)

## Load the trained CNN and train the VAE

In [None]:
from generative_ecg.train import create_cnn_train_state, train_vae
from generative_ecg.models import ECGConv
import orbax.checkpoint

# Hyperparameters handed to train_vae; the generation cell also passes them to load_vae_from_ckpt.
# NOTE(review): "n_channels" is 12, but the cached beat dataset has 9 channels
# (its shape was printed as (34064, 9, 400)). Confirm how train_vae uses this value.
# NOTE(review): the lr_* entries differ from the CNN `lr_schedule` that is also passed to
# train_vae below. Confirm which one train_vae actually uses.
model_params = dict(
    beta1=1.0,
    beta2=0.0,
    z_dim=512,
    hidden_width=100,
    hidden_depth=4,
    lr_init=1e-7,
    lr_peak=1e-4,
    lr_end=1e-7,
    beta1_scheduler="warmup_cosine",
    target="age",
    n_channels=12,
    beat_segment=False,
    processed=False,
    seed=0,
    batch_size=512,
    n_epochs=5,
    encoder_type="cnn",
    use_bias=False,
)

cnn_ckpt_dir = result_path + "/cnn_model_checkpoint/"
vae_ckpt_dir = result_path + "/vae_model_checkpoint/"

# Build a CNN train state from X_tr and use it as the restore target,
# so the checkpoint is loaded into a state of the same structure.
template_state = create_cnn_train_state(X_tr)
ckptr = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler())
state_disc = ckptr.restore(cnn_ckpt_dir, item=template_state)

model = ECGConv(output_dim=1)


def vae_pred_fn(x):
    """Run the restored CNN on ``x``; train_vae receives this as its prediction function."""
    return model.apply(state_disc.params, x)


result = train_vae(X_tr, y_tr, vae_pred_fn, model_params, lr_schedule, ckpt_dir=vae_ckpt_dir)

Encoder params size: (288368,)
Decoder params size: (441600,)


Epoch 4 average loss 9.4886: 100%|██████████| 5/5 [13:01<00:00, 156.20s/it] 


In [None]:
from generative_ecg.generate import generate_ecgs
from generative_ecg.train import load_vae_from_ckpt

# Sampling settings for generate_ecgs.
# z_dim and n_channels are read from model_params instead of being re-typed,
# so the sampler always matches the VAE configuration it loads.
gen_params = {
    "seed": 0,
    "n_ecgs": 10,
    "z_dim": model_params["z_dim"],
    # NOTE(review): "processed" is True here but False in model_params; confirm which is intended.
    "processed": True,
    "n_channels": model_params["n_channels"],
    "find_closest_real": False,
    "std": None,
    "title": "ECG",
    "ylim": None,
}

# Reload the trained VAE from its checkpoint directory rather than reusing the in-memory result.
result = load_vae_from_ckpt(X_tr, model_params, ckpt_dir=result_path + "/vae_model_checkpoint/")
# NOTE(review): this call previously failed with "TypeError: 'NoneType' object is not callable".
# Unlike train_vae, load_vae_from_ckpt is not given vae_pred_fn, so a predictor entry in
# `result` may be None. Confirm in load_vae_from_ckpt / generate_ecgs.
generate_ecgs(X_tr, result, gen_params, save_dir=result_path + "/generated_ecgs/")

[-0.03318268 -0.03163719 -0.02997451 ...  0.20374052  0.06817751
 -0.01388908]


  0%|          | 0/10 [00:00<?, ?it/s]


TypeError: 'NoneType' object is not callable