In [None]:
import logging
import multiprocessing as mp
import os

import h5py
import numpy as np
import pandas as pd
import tensorflow as tf

from ml4c3.tensormap.TensorMap import update_tmaps, PatientData
from ml4c3.datasets import infer_mrn_column
from ml4c3.datasets import make_dataset

In [None]:
# Registry of tensor maps, filled in by name. `update_tmaps` is expected to
# add an entry under each requested name, since those keys are read back below.
tmaps = {}
update_tmaps('ecg_2500_std_no_pacemaker_180_days_pre_echo', tmaps)
update_tmaps('as_significant_180_days_post_ecg', tmaps)

# Model inputs and outputs, each as a single-element list of tensor maps.
input_tmaps, output_tmaps = (
    [tmaps['ecg_2500_std_no_pacemaker_180_days_pre_echo']],
    [tmaps['as_significant_180_days_post_ecg']],
)

In [None]:
# --- Run configuration -------------------------------------------------------
data_split = "train"
batch_size, num_workers = 32, 20

# Flags forwarded to each tensor map's postprocess_tensor call.
augment, validate, normalize = False, True, True

# --- Data locations ----------------------------------------------------------
# Directories searched for "<patient_id>.hd5" files.
hd5_sources = ['/storage/shared/ecg/mgh']
# (csv path, name under which matching rows are exposed in the patient data).
csv_sources = [('/home/sn69/dropbox/ecgnet-as/data/echo.csv', 'echo')]

# NOTE(review): ids are read from test.csv although data_split is "train" -
# confirm this pairing is intended.
patient_ids = set(pd.read_csv('/home/sn69/dropbox/ecgnet-as/data/test.csv')['patientid'])

# NOTE: from here on `tmaps` is the ordered list [inputs..., outputs...],
# replacing the name -> tensor map dict built in the earlier cell.
tmaps = input_tmaps + output_tmaps

In [None]:
# Load every CSV source once up front. Each entry is
# (name, dataframe, MRN column name), consumed per patient later on.
csv_data = []
for csv_source, csv_name in csv_sources:
    df = pd.read_csv(csv_source, low_memory=False)
    mrn_col = infer_mrn_column(df, csv_source)
    # Drop rows without an MRN *before* casting. Assigning
    # `df[col].dropna().astype(int)` back to the frame re-aligns on the index,
    # which reintroduces NaN for the dropped rows and silently turns the
    # column back into floats. Rows with a missing MRN can never match a
    # patient id, so removing them does not change which rows get selected.
    df = df.dropna(subset=[mrn_col]).astype({mrn_col: int})
    csv_data.append((csv_name, df, mrn_col))

In [None]:
def get_patient_tensors(patient_id):
    """Load and postprocess the tensors of every tensor map for one patient.

    Reads the module-level configuration (`tmaps`, `hd5_sources`, `csv_data`,
    `augment`, `validate`, `normalize`).

    Args:
        patient_id: Patient identifier. May be a plain Python value, a numpy
            scalar, or the eager tensor handed over by `tf.py_function`.

    Returns:
        A list with one array per entry of `tmaps`, in the same order. Samples
        (axis 0) whose postprocessing failed for *any* tensor map are removed
        from *all* arrays so they stay aligned. If the patient cannot be
        loaded at all, every entry is an empty array, so the patient
        contributes no samples.
    """
    # Unwrap tensor / numpy scalar inputs so path formatting and the pandas
    # comparison below operate on a plain Python value.
    if hasattr(patient_id, "numpy"):
        patient_id = patient_id.numpy()
    if isinstance(patient_id, bytes):
        patient_id = patient_id.decode()
    elif isinstance(patient_id, np.generic):
        patient_id = patient_id.item()

    open_hd5s = []
    bad_idxs = set()
    tensors = []
    try:
        data = PatientData(patient_id=patient_id)

        # Add top level groups in hd5s to patient dictionary
        for hd5_source in hd5_sources:
            hd5_path = os.path.join(hd5_source, f"{patient_id}.hd5")
            if not os.path.isfile(hd5_path):
                continue
            hd5 = h5py.File(hd5_path, "r")
            # Register the handle first so it is closed even if reading its
            # keys raises.
            open_hd5s.append(hd5)
            for key in hd5:
                data[key] = hd5[key]

        # Add rows in csv with patient data accessible in patient dictionary
        for csv_name, df, mrn_col in csv_data:
            mask = df[mrn_col] == patient_id
            if not mask.any():
                continue
            data[csv_name] = df[mask]

        for tm in tmaps:
            _tensor = tm.tensor_from_file(tm, data)
            # Tensor maps without a time series limit yield a single sample;
            # add a leading axis so every tensor is indexed by sample.
            if tm.time_series_limit is None:
                _tensor = _tensor[None, ...]

            for j in range(len(_tensor)):
                try:
                    _tensor[j] = tm.postprocess_tensor(
                        tensor=_tensor[j],
                        data=data,
                        augment=augment,
                        validate=validate,
                        normalize=normalize,
                    )
                except Exception:
                    # Remember the sample; it is dropped from every tensor map.
                    bad_idxs.add(j)
            tensors.append(_tensor)

    except Exception as e:
        # Best effort: a patient that cannot be loaded is skipped. Return
        # empty arrays for *every* tensor map so the outputs stay aligned,
        # instead of a mix of loaded arrays and empty placeholders.
        logging.debug("Skipping patient %s: %r", patient_id, e)
        tensors = [np.array([]) for _ in tmaps]
        bad_idxs.clear()
    finally:
        for hd5 in open_hd5s:
            hd5.close()

    drop = sorted(bad_idxs)
    return [np.delete(tensor, drop, axis=0) for tensor in tensors]

In [None]:
# dtype of each tensor returned by get_patient_tensors, in `tmaps` order:
# strings for language tensor maps, float32 for everything else.
output_types = [
    (tf.string if tm.is_language else tf.float32)
    for tm in tmaps
]


def wrapped(patient_id):
    """Turn one patient id into a dataset of (inputs, outputs) examples.

    Calls `get_patient_tensors` through `tf.py_function`, keys the resulting
    tensors by tensor map input/output name, and slices them along axis 0.
    """
    loaded = tf.py_function(
        func=get_patient_tensors,
        inp=[patient_id],
        Tout=output_types,
    )

    # Inputs occupy the first positions, outputs follow directly after.
    n_inputs = len(input_tmaps)

    features = {}
    for idx, tm in enumerate(input_tmaps):
        features[tm.input_name] = loaded[idx]

    labels = {}
    for idx, tm in enumerate(output_tmaps, start=n_inputs):
        labels[tm.output_name] = loaded[idx]

    return tf.data.Dataset.from_tensor_slices((features, labels))

In [None]:
# Stream the patient ids and expand each one, via `wrapped`, into that
# patient's individual examples.
dataset = tf.data.Dataset.from_tensor_slices(list(patient_ids)).flat_map(wrapped)

In [None]:
# %%timeit -n1 -r1
# out = list(dataset.as_numpy_iterator())
# print(len(out))

In [None]:
# Build a dataset with the library's make_dataset from the same configuration.
# NOTE(review): `stats` and `cleanup` are never used in this notebook.
# Presumably `cleanup` releases resources held by make_dataset - confirm and
# call it once the dataset is no longer needed.
d, stats, cleanup = make_dataset(
    data_split,
    hd5_sources,
    csv_sources,
    patient_ids,
    input_tmaps,
    output_tmaps,
    batch_size,
    num_workers,
    cache=False,
)
# Undo batching so iteration yields one element per example.
d = d.unbatch()

In [None]:
%%timeit -n1 -r5
# Drain the make_dataset pipeline completely and report how many examples it yields.
examples = list(d.as_numpy_iterator())
print(len(examples))

In [None]:
def run_dispatcher(port):
    """Run a tf.data service dispatcher on `port`; blocks until it stops.

    Used as the target of a separate process because `join()` does not return
    while the server is running.
    """
    config = tf.data.experimental.service.DispatcherConfig(port=port)
    server = tf.data.experimental.service.DispatchServer(config)
    server.join()


# daemon=True: the dispatcher blocks forever in join(), so a non-daemonic
# child would be waited on at interpreter exit (hanging kernel shutdown) and
# could be left holding the port. Daemonic children are terminated when the
# parent process exits.
dispatcher = mp.Process(
    target=run_dispatcher,
    name='dispatcher',
    args=(5050,),
    daemon=True,
)
dispatcher.start()

In [None]:
def run_worker(dispatcher_address):
    """Run a tf.data service worker registered with `dispatcher_address`.

    Blocks until the worker server stops, so it is meant to be run as the
    target of a separate process.
    """
    config = tf.data.experimental.service.WorkerConfig(dispatcher_address=dispatcher_address)
    server = tf.data.experimental.service.WorkerServer(config)
    server.join()


# One process per worker. daemon=True so the workers, which block forever in
# join(), are terminated when the notebook kernel exits instead of keeping it
# from shutting down or lingering as orphans.
workers = []
for i in range(num_workers):
    worker = mp.Process(
        target=run_worker,
        name=f'worker_{i}',
        args=("localhost:5050",),
        daemon=True,
    )
    worker.start()
    workers.append(worker)

In [None]:
# Route the pipeline through the tf.data service started above.
# NOTE: this rebinds `dataset`, so re-running the cell applies the
# transformation a second time - run it once per kernel session.
dataset = dataset.apply(
    tf.data.experimental.service.distribute(
        service="grpc://localhost:5050",
        processing_mode="distributed_epoch",
    )
)

In [None]:
%%timeit -n1 -r5
# Drain the service-backed pipeline completely and report how many examples it yields.
examples = list(dataset.as_numpy_iterator())
print(len(examples))