In [None]:
import os
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns


# ml4cvd imports
from ml4cvd.tensor_generators import test_train_valid_tensor_generators
from ml4cvd.tensor_maps_partners_ecg import TMAPS
%matplotlib inline

In [None]:
# Constants
DATA_DIR = '/data/apollo/hd5_waveforms/'  # Data directory
BATCH_SIZE = 32
# Train / valid / test split. Only VALID_RATIO and TEST_RATIO are passed to the
# generator factory below; TRAIN_RATIO is kept so the three fractions can be checked
# to sum to 1 (presumably the train split is the remainder — TODO confirm).
TRAIN_RATIO = .7
VALID_RATIO = .1
TEST_RATIO = .2
assert abs(TRAIN_RATIO + VALID_RATIO + TEST_RATIO - 1) < 1e-6, 'train/valid/test ratios must sum to 1'
NUM_WORKERS = 0  # number of multiprocessing workers. Set to 0 for no multiprocessing
CACHE_SIZE = 1e9  # size in bytes of cache that keeps data in memory after loading and transforming


# These are keys in the TMAPS dictionary. More can be found in ml4cvd.tensor_maps_partners_ecg
# Explicit parenthesized tuples: a bare trailing comma (`'a',`) is easy to misread or drop,
# which would turn the tuple into a plain string and iterate over its characters.
tensor_maps_in_keys = ('partners_ecg_5000',)
tensor_maps_out_keys = ('pressure_hr', 'pressure_RA_mean_pressure')

tensor_maps_in = [TMAPS[k] for k in tensor_maps_in_keys]
tensor_maps_out = [TMAPS[k] for k in tensor_maps_out_keys]
# Report every input TensorMap first, then every output TensorMap.
for tm in tensor_maps_in + tensor_maps_out:
    print(f'TensorMap {tm.name} has shape {tm.shape}')

In [None]:
# Make sure the data exists: fail loudly here instead of deep inside the generators.
# os.listdir raises FileNotFoundError if DATA_DIR is missing; the assert catches an empty directory.
num_data_files = len(os.listdir(DATA_DIR))
assert num_data_files > 0, f'No files found in {DATA_DIR}'
num_data_files

In [None]:
# Build the train / valid / test generators from the files in DATA_DIR.
# Only the valid and test fractions are passed; TRAIN_RATIO is not used here.
generate_train, generate_valid, generate_test = test_train_valid_tensor_generators(
    # data source and the input / output TensorMaps
    tensors=DATA_DIR,
    tensor_maps_in=tensor_maps_in,
    tensor_maps_out=tensor_maps_out,
    # split configuration
    valid_ratio=VALID_RATIO,
    test_ratio=TEST_RATIO,
    test_modulo=0,  # adds extra data to the test set; should be left at 0
    balance_csvs=[],  # can be used to resample unbalanced data
    # batching and loading
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    cache_size=CACHE_SIZE,
)

In [None]:
# Pull one example batch from the training generator.
# The generator yields (input batch, output batch, sample_weights); the weights are ignored here.
input_batch, output_batch, _ = next(generate_train)
print(f'Input data has {input_batch.keys()}.')
print(f'Output data has {output_batch.keys()}.')
print()
# Shapes of every input tensor, then every output tensor.
for batch in (input_batch, output_batch):
    for key, tensor in batch.items():
        print(f'{key} has shape {tensor.shape}')

In [None]:
# Plot some data
def plot_ecg(ecg, title):
    """Plot each column of a single ECG in its own vertically stacked subplot.

    Args:
        ecg: 2-D array indexed as ``ecg[sample, channel]``; one subplot is drawn per
            column (``ecg.shape[1]``). Presumably (time, leads) — TODO confirm.
        title: Title shown above the top subplot.
    """
    num_channels = ecg.shape[1]
    # squeeze=False keeps `axes` a 2-D array even for a single channel; with the default
    # squeeze=True a 1-channel ECG yields a bare Axes and indexing it raises TypeError.
    _, axes = plt.subplots(num_channels, 1, figsize=(12, 3 * num_channels), squeeze=False)
    axes[0, 0].set_title(title)
    for i in range(num_channels):
        axes[i, 0].plot(ecg[:, i])
    plt.show()


# Inputs and outputs get the same treatment: rank-2 tensors are shown as a histogram over
# the batch, rank-3 tensors as an ECG plot of the first sample in the batch.
for batch in (input_batch, output_batch):
    for k, v in batch.items():
        if len(v.shape) == 2:
            # sns.distplot is deprecated; histplot(kde=True, stat='density') is its replacement
            # and matches distplot's density-normalized histogram. np.ravel flattens the batch so
            # tensors with more than one column still produce a single histogram.
            fig, ax = plt.subplots()
            sns.histplot(np.ravel(v), kde=True, stat='density', ax=ax)
            ax.set_title(k)
            plt.show()
        elif len(v.shape) == 3:
            plot_ecg(v[0], k)  # only plot first in batch