In [1]:
import tensorflow as tf
from time import time
import numpy
import pickle
import pandas
import os
import json
import datetime
import sys
from functools import partial, reduce

# Make the project-local helper modules under ../libs importable from this notebook.
sys.path.append('../libs')
import initialize
import prepare_data
import data_pipeline
import conv_model

# Log the device every TensorFlow op is placed on; this is what produces the
# "Executing op ... in device ..." lines in the cell outputs of this notebook.
tf.debugging.set_log_device_placement(True)

In [2]:
from initialize import RESP_SCALE
from data_pipeline import get_windows, get_window_index_matrix, filter_datum

In [3]:
%%time

H = json.load(open('../hypes.json'))
%time metadata = initialize.load_metadata(H)
%time sig_data = initialize.load_sig_data(H, metadata)
partition = initialize.load_partition(H, metadata)
initialize.describe_data_size(metadata)

CPU times: user 12.1 s, sys: 1.82 s, total: 13.9 s
Wall time: 14 s
CPU times: user 32.8 s, sys: 7.76 s, total: 40.6 s
Wall time: 40.6 s
711 GB,  132 years,  182431 record segments
CPU times: user 46.2 s, sys: 10.2 s, total: 56.4 s
Wall time: 56.4 s


In [10]:
def calculate_chunks_per_record(H, rec_count, oversample_factor=5):
    """Return how many chunks to sample from each record.

    Computes ``epochs * steps_per_epoch * batch_size / windows_per_chunk``
    (presumably the number of chunks one training run consumes), scales it by
    ``oversample_factor`` and spreads the result evenly over ``rec_count``
    records.

    Args:
        H: hyperparameter mapping providing 'epochs', 'steps_per_epoch',
            'batch_size' and 'windows_per_chunk'.
        rec_count: number of records the chunks are spread over.
        oversample_factor: multiplier applied to the chunk count. Defaults to
            5, the value that used to be hard-coded here.

    Returns:
        The per-record chunk count, rounded with the built-in ``round``.

    Raises:
        ZeroDivisionError: if ``rec_count`` or ``H['windows_per_chunk']`` is 0.
    """
    chunk_count = H['epochs'] * H['steps_per_epoch'] * H['batch_size'] / H['windows_per_chunk']
    chunk_count *= oversample_factor
    return round(chunk_count / rec_count)

def sample_segments(replace, n, data):
    """Draw ``n`` rows at random from ``data``.

    When sampling without replacement from fewer than ``n`` rows, every row is
    first repeated ``n // m + 1`` times (``m`` being the row count) so the pool
    is large enough to draw ``n`` rows from.

    Args:
        replace: whether a row may be drawn more than once; forwarded to
            ``data.sample``.
        n: number of rows to draw.
        data: pandas object supporting ``.shape``, ``.iloc`` and ``.sample``.

    Returns:
        The ``n`` sampled rows.
    """
    row_count = data.shape[0]
    pool_too_small = n > row_count and not replace
    if pool_too_small:
        copies = n // row_count + 1
        # Positions 0,0,...,1,1,...: each row repeated `copies` times, in row order.
        repeated_positions = numpy.repeat(numpy.arange(row_count), copies)
        data = data.iloc[repeated_positions]
    return data.sample(n=n, replace=replace)

def get_chunk_paths(data):
    """Build the serialized-chunk file path for every row of ``data``.

    Each path has the form ``<ROOT_SERIAL><rec_id>_<segment>_<chunk_id>.tfrec``
    where segment and chunk id are left-padded with zeros to four characters
    and ``ROOT_SERIAL`` comes from the ``prepare_data`` module.

    Args:
        data: DataFrame that exposes 'rec_id', 'segment' and 'chunk_id' as
            columns after ``reset_index()``.

    Returns:
        numpy array of ``str`` paths, one per row of ``data``, in row order.

    Note:
        The fixed-width byte dtypes silently truncate record ids longer than
        7 characters and segment / chunk ids longer than 4 characters.
    """
    keys = data.reset_index()[['rec_id', 'segment', 'chunk_id']].values
    # 'S<n>' is the canonical fixed-width bytes dtype; the equivalent 'a<n>'
    # alias used previously is deprecated as of NumPy 2.0.
    rec_ids = keys[:, 0].astype('S7')
    segs = numpy.char.zfill(keys[:, 1].astype('S4'), 4)
    chunk_ids = numpy.char.zfill(keys[:, 2].astype('S4'), 4)
    root = str.encode(prepare_data.ROOT_SERIAL)
    # Element-wise concatenation; the scalar pieces broadcast against the arrays.
    parts = [root, rec_ids, b'_', segs, b'_', chunk_ids, b'.tfrec']
    return reduce(numpy.char.add, parts).astype(str)

def sample_data(H, data):
    """Sample chunks and window positions for every record in ``data``.

    Steps:
      1. Add a 'chunk_count' column computed from 'sig_len' via
         ``prepare_data.get_chunk_count``.
      2. For each level-0 index group (one per record), sample
         ``calculate_chunks_per_record(H, rec_count)`` rows with
         ``sample_segments`` and number them 0..k-1 as 'chunk_index'.
      3. Pick a 'chunk_id' per row: uniformly random in ``[0, chunk_count)``
         when sampling with replacement, otherwise the running row number
         modulo 'chunk_count'.
      4. Attach the chunk file path, repeat each row ``H['windows_per_chunk']``
         times and give every repetition a 'window_index' and a random
         'window_id' in ``[H['window_size'] * RESP_SCALE, CHUNK_SIZE)``.

    Args:
        H: hyperparameter mapping; reads 'sample_with_replacement',
            'windows_per_chunk', 'window_size' and the keys used by
            ``calculate_chunks_per_record``.
        data: DataFrame with a 'sig_len' column whose first index level
            identifies the record.

    Returns:
        DataFrame indexed by ('rec_id', 'chunk_index', 'window_index'), sorted
        by index, one row per sampled window.

    Side effects:
        The 'chunk_count' column is written into the frame passed in by the
        caller, and numpy's global random state is consumed.
    """
    # Whole-column assignment goes through `.loc`: `.at` is a scalar accessor
    # and only accepted a slice via an internal fallback in older pandas.
    # NOTE: this first assignment mutates the caller's frame on purpose-preserving
    # grounds (later cells read sig_data['chunk_count']).
    data.loc[:, 'chunk_count'] = [prepare_data.get_chunk_count(i) for i in data['sig_len']]
    rec_count = len(data.index.remove_unused_levels().levels[0])
    chunks_per_record = calculate_chunks_per_record(H, rec_count)
    sample_segs = partial(sample_segments, H['sample_with_replacement'], chunks_per_record)
    data = data.groupby(level=0).apply(sample_segs).droplevel(0)
    # Every group contributed exactly chunks_per_record rows, in group order.
    data.loc[:, 'chunk_index'] = [i for j in range(rec_count) for i in range(chunks_per_record)]
    data = data.reset_index().set_index(['rec_id', 'segment', 'chunk_index'], verify_integrity=True)
    data.sort_index(inplace=True)
    if H['sample_with_replacement']:
        data.loc[:, 'chunk_id'] = [numpy.random.randint(i) for i in data['chunk_count']]
    else:
        # Running row number wrapped into the valid chunk-id range of each segment.
        data.loc[:, 'chunk_id'] = range(data.shape[0])
        data['chunk_id'] %= data['chunk_count']

    data.loc[:, 'chunk_path'] = get_chunk_paths(data)
    # Repeat every chunk row once per window.
    I, J = range(data.shape[0]), range(H['windows_per_chunk'])
    data = data.iloc[[i for i in I for j in J]]
    data.loc[:, 'window_index'] = [j for i in I for j in J]
    data.loc[:, 'window_id'] = numpy.random.randint(
        low = H['window_size'] * RESP_SCALE,
        high = prepare_data.CHUNK_SIZE,
        size = data.shape[0]
    )
    data = data.reset_index().set_index(['rec_id', 'chunk_index', 'window_index'], verify_integrity=True)
    data.sort_index(inplace=True)
    return data

def dataframe_to_tensors(H, data):
    """Convert a sampled window table into constant tensors, one row per chunk.

    Args:
        H: hyperparameter mapping; reads 'input_sigs', 'output_sigs' and
            'windows_per_chunk'.
        data: DataFrame with a three-level index whose last level is the
            window index, and 'window_id', 'chunk_path', 'sig_index',
            'baseline' and 'adc_gain' columns.

    Returns:
        Tuple of ``tf.constant`` tensors in the order (chunk paths, signal
        indices, window ids, baselines, gains). All five share the same random
        row permutation.
    """
    # One row per chunk, one column per window index.
    window_ids = data['window_id'].unstack(-1).values
    # The remaining columns are taken from the window_index == 0 row of each chunk.
    per_chunk = data.loc[(slice(None), slice(None), 0), :]
    signals = H['input_sigs'] + H['output_sigs']
    signal_count = len(signals)
    chunk_total = per_chunk.shape[0]

    paths = per_chunk['chunk_path'].values
    signal_positions = per_chunk['sig_index'][signals].values
    baseline_values = per_chunk['baseline'][signals].values
    gain_values = per_chunk['adc_gain'][signals].values

    # A single permutation keeps the five tensors row-aligned.
    order = numpy.random.permutation(chunk_total)
    per_signal_shape = (chunk_total, signal_count)
    return (
        tf.constant(paths[order], dtype='string', shape=[chunk_total]),
        tf.constant(signal_positions[order], dtype='int8', shape=per_signal_shape),
        tf.constant(window_ids[order], dtype='int32', shape=(chunk_total, H['windows_per_chunk'])),
        tf.constant(baseline_values[order], dtype='int32', shape=per_signal_shape),
        tf.constant(gain_values[order], dtype='float32', shape=per_signal_shape),
    )

def build_pipeline(H, tensors):
    """Assemble the batched ``tf.data`` pipeline over per-chunk tensors.

    Slices ``tensors`` into one element per chunk, interleaves the datasets
    that ``get_windows`` returns for each element, optionally filters with
    ``filter_datum``, then shuffles and batches.

    Args:
        H: hyperparameter mapping; reads 'batch_buffer_size', 'batch_size',
            'windows_per_chunk' and 'filter_data'.
        tensors: structure accepted by ``tf.data.Dataset.from_tensor_slices``.

    Returns:
        The shuffled, batched dataset.
    """
    chunk_source = tf.data.Dataset.from_tensor_slices(tensors)
    read_windows = partial(get_windows, H, get_window_index_matrix(H))
    windows = chunk_source.interleave(
        read_windows,
        cycle_length=H['batch_buffer_size'] * H['batch_size'],
        block_length=1,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )
    shuffle_buffer = H['batch_buffer_size'] * H['batch_size'] * H['windows_per_chunk']
    if H['filter_data']:
        windows = windows.filter(filter_datum)
    shuffled = windows.shuffle(shuffle_buffer)
    return shuffled.batch(H['batch_size'])

In [11]:
%%time

# Sample without replacement, overriding the loaded hyperparameters with
# two windows per chunk; H itself is left untouched.
H_ = {**H, 'sample_with_replacement': False, 'windows_per_chunk': 2}
data = sample_data(H_, sig_data)

CPU times: user 50.7 s, sys: 4.52 s, total: 55.2 s
Wall time: 55.2 s


In [32]:
# Turn the sampled window table into tensors and build the tf.data pipeline from them.
tensors = dataframe_to_tensors(H_, data)
dataset = build_pipeline(H_, tensors)

Executing op ParallelInterleaveDatasetV2 in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op FilterDataset in device /job:localhost/replica:0/task:0/device:CPU:0


In [26]:
# Preview the first 50 rows of the sampled table (sampling without replacement).
data[:50]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,segment,sig_index,sig_index,sig_index,sig_index,sig_index,sig_index,baseline,baseline,baseline,...,adc_gain,adc_gain,adc_gain,adc_gain,adc_gain,sig_len,chunk_count,chunk_id,chunk_path,window_id
Unnamed: 0_level_1,Unnamed: 1_level_1,sig_name,Unnamed: 3_level_1,ABP,AVR,II,PLETH,RESP,V,ABP,AVR,II,...,AVR,II,PLETH,RESP,V,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1
rec_id,chunk_index,window_index,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2
3000003,0,0,9,3,0,1,0,0,2,-100,0,0,...,0.0,29.0,0.0,0.0,14.0,877500,13,8,/scr-ssd/mimic/waveforms/3000003_0009_0008.tfrec,44032
3000003,0,1,9,3,0,1,0,0,2,-100,0,0,...,0.0,29.0,0.0,0.0,14.0,877500,13,8,/scr-ssd/mimic/waveforms/3000003_0009_0008.tfrec,64512
3000003,0,2,9,3,0,1,0,0,2,-100,0,0,...,0.0,29.0,0.0,0.0,14.0,877500,13,8,/scr-ssd/mimic/waveforms/3000003_0009_0008.tfrec,62464
3000003,0,3,9,3,0,1,0,0,2,-100,0,0,...,0.0,29.0,0.0,0.0,14.0,877500,13,8,/scr-ssd/mimic/waveforms/3000003_0009_0008.tfrec,22528
3000003,0,4,9,3,0,1,0,0,2,-100,0,0,...,0.0,29.0,0.0,0.0,14.0,877500,13,8,/scr-ssd/mimic/waveforms/3000003_0009_0008.tfrec,19456
3000003,0,5,9,3,0,1,0,0,2,-100,0,0,...,0.0,29.0,0.0,0.0,14.0,877500,13,8,/scr-ssd/mimic/waveforms/3000003_0009_0008.tfrec,5120
3000003,0,6,9,3,0,1,0,0,2,-100,0,0,...,0.0,29.0,0.0,0.0,14.0,877500,13,8,/scr-ssd/mimic/waveforms/3000003_0009_0008.tfrec,24576
3000003,0,7,9,3,0,1,0,0,2,-100,0,0,...,0.0,29.0,0.0,0.0,14.0,877500,13,8,/scr-ssd/mimic/waveforms/3000003_0009_0008.tfrec,63488
3000003,0,8,9,3,0,1,0,0,2,-100,0,0,...,0.0,29.0,0.0,0.0,14.0,877500,13,8,/scr-ssd/mimic/waveforms/3000003_0009_0008.tfrec,23552
3000003,0,9,9,3,0,1,0,0,2,-100,0,0,...,0.0,29.0,0.0,0.0,14.0,877500,13,8,/scr-ssd/mimic/waveforms/3000003_0009_0008.tfrec,23552


In [52]:
# How often each chunk file was sampled, counting one row per chunk (window_index == 0).
val_counts = data.loc[(slice(None), slice(None), 0), 'chunk_path'].value_counts()
# Record ids of the chunk files sampled exactly 21 times. File names have the
# form <rec_id>_<segment>_<chunk_id>.tfrec, so the record id is the first
# '_'-separated field of the basename. The boolean mask replaces a discarded
# no-op expression and one `.loc` lookup per path.
rec_ids_1 = sorted({path.split('/')[-1].split('_')[0] for path in val_counts[val_counts == 21].index})

In [66]:
# Distinct 'chunk_count' values among the segments of the records flagged in rec_ids_1.
# NOTE(review): 'chunk_count' appears to exist on sig_data only because
# sample_data() writes that column into the frame it is given - confirm before
# relying on this cell in a fresh kernel.
set(sig_data.sort_index().loc[[int(i) for i in rec_ids_1]]['chunk_count'])

{1}

In [53]:
# Check that sampling with and without replacement flag the same records.
# NOTE(review): rec_ids_2 is assigned in a cell that appears further down in
# this notebook, so this comparison fails on a top-to-bottom run.
rec_ids_1 == rec_ids_2

True

In [30]:
# Mean number of times a distinct chunk file was sampled (one row per chunk via window_index == 0).
data.loc[(slice(None), slice(None), 0), 'chunk_path'].value_counts().mean()

1.175630512679162

In [8]:
# l1: total rows; l2: distinct (index levels, chunk_id, window_index) rows;
# l3: distinct (index levels, chunk_id) rows.
# NOTE(review): 'window_index' is selected as a column here, whereas the
# sample_data defined above moves it into the index - this cell presumably ran
# against an earlier version of that function.
l1 = len(data)
l2 = len(data[['chunk_id', 'window_index']].reset_index().drop_duplicates())
l3 = len(data['chunk_id'].reset_index().drop_duplicates())
print(l1, l2, l3)

4094580 3585092 328993


In [23]:
%%time

# Sample with replacement, using ten windows per chunk.
# NOTE: this rebinds H_, which the without-replacement cells also use.
H_ = {**H, 'sample_with_replacement': True, 'windows_per_chunk': 10}
data_wr = sample_data(H_, sig_data)

CPU times: user 1min 28s, sys: 3.27 s, total: 1min 31s
Wall time: 1min 31s


In [51]:
# How often each chunk file was sampled with replacement, counting one row per chunk (window_index == 0).
val_counts = data_wr.loc[(slice(None), slice(None), 0), 'chunk_path'].value_counts()
# Record ids of the chunk files sampled exactly 21 times. File names have the
# form <rec_id>_<segment>_<chunk_id>.tfrec, so the record id is the first
# '_'-separated field of the basename. The boolean mask replaces a discarded
# no-op expression and one `.loc` lookup per path.
rec_ids_2 = sorted({path.split('/')[-1].split('_')[0] for path in val_counts[val_counts == 21].index})

In [31]:
# Mean number of times a distinct chunk file was sampled when sampling with replacement.
data_wr.loc[(slice(None), slice(None), 0), 'chunk_path'].value_counts().mean()

1.3776352710804864

In [27]:
# Preview the first 50 rows of the table sampled with replacement.
data_wr[:50]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,segment,sig_index,sig_index,sig_index,sig_index,sig_index,sig_index,baseline,baseline,baseline,...,adc_gain,adc_gain,adc_gain,adc_gain,adc_gain,sig_len,chunk_count,chunk_id,chunk_path,window_id
Unnamed: 0_level_1,Unnamed: 1_level_1,sig_name,Unnamed: 3_level_1,ABP,AVR,II,PLETH,RESP,V,ABP,AVR,II,...,AVR,II,PLETH,RESP,V,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1
rec_id,chunk_index,window_index,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2
3000003,0,0,8,3,0,1,0,0,2,-100,0,0,...,0.0,29.0,0.0,0.0,14.0,217500,3,2,/scr-ssd/mimic/waveforms/3000003_0008_0002.tfrec,22528
3000003,0,1,8,3,0,1,0,0,2,-100,0,0,...,0.0,29.0,0.0,0.0,14.0,217500,3,2,/scr-ssd/mimic/waveforms/3000003_0008_0002.tfrec,21504
3000003,0,2,8,3,0,1,0,0,2,-100,0,0,...,0.0,29.0,0.0,0.0,14.0,217500,3,2,/scr-ssd/mimic/waveforms/3000003_0008_0002.tfrec,41984
3000003,0,3,8,3,0,1,0,0,2,-100,0,0,...,0.0,29.0,0.0,0.0,14.0,217500,3,2,/scr-ssd/mimic/waveforms/3000003_0008_0002.tfrec,11264
3000003,0,4,8,3,0,1,0,0,2,-100,0,0,...,0.0,29.0,0.0,0.0,14.0,217500,3,2,/scr-ssd/mimic/waveforms/3000003_0008_0002.tfrec,53248
3000003,0,5,8,3,0,1,0,0,2,-100,0,0,...,0.0,29.0,0.0,0.0,14.0,217500,3,2,/scr-ssd/mimic/waveforms/3000003_0008_0002.tfrec,50176
3000003,0,6,8,3,0,1,0,0,2,-100,0,0,...,0.0,29.0,0.0,0.0,14.0,217500,3,2,/scr-ssd/mimic/waveforms/3000003_0008_0002.tfrec,29696
3000003,0,7,8,3,0,1,0,0,2,-100,0,0,...,0.0,29.0,0.0,0.0,14.0,217500,3,2,/scr-ssd/mimic/waveforms/3000003_0008_0002.tfrec,57344
3000003,0,8,8,3,0,1,0,0,2,-100,0,0,...,0.0,29.0,0.0,0.0,14.0,217500,3,2,/scr-ssd/mimic/waveforms/3000003_0008_0002.tfrec,44032
3000003,0,9,8,3,0,1,0,0,2,-100,0,0,...,0.0,29.0,0.0,0.0,14.0,217500,3,2,/scr-ssd/mimic/waveforms/3000003_0008_0002.tfrec,43008


In [6]:
# Same duplicate statistics as for `data`, for the with-replacement sample:
# total rows, distinct (index levels, chunk_id, window_index) rows and
# distinct (index levels, chunk_id) rows.
# NOTE(review): 'window_index' is selected as a column here, whereas the
# sample_data defined above moves it into the index - presumably run against
# an earlier version of that function.
l1 = len(data_wr)
l2 = len(data_wr[['chunk_id', 'window_index']].reset_index().drop_duplicates())
l3 = len(data_wr['chunk_id'].reset_index().drop_duplicates())
print(l1, l2, l3)

4094580 3505198 297218


In [225]:
# Pivot the first 1000 rows so that each (rec_id, segment, chunk_id) is one row
# with one column group per 'window_index_' value.
# NOTE(review): no cell visible in this notebook creates a 'window_index_'
# column; this presumably depends on an earlier version of sample_data - confirm.
index = ['rec_id', 'segment', 'chunk_id', 'window_index_']
data[:1000].reset_index().set_index(index).unstack(-1)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,sig_index,sig_index,sig_index,sig_index,sig_index,sig_index,sig_index,sig_index,sig_index,sig_index,...,window_index,window_index,window_index,window_index,window_index,window_index,window_index,window_index,window_index,window_index
Unnamed: 0_level_1,Unnamed: 1_level_1,sig_name,ABP,ABP,ABP,ABP,ABP,ABP,ABP,ABP,ABP,ABP,...,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1
Unnamed: 0_level_2,Unnamed: 1_level_2,window_index_,0,1,2,3,4,5,6,7,8,9,...,990,991,992,993,994,995,996,997,998,999
rec_id,segment,chunk_id,Unnamed: 3_level_3,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3,Unnamed: 16_level_3,Unnamed: 17_level_3,Unnamed: 18_level_3,Unnamed: 19_level_3,Unnamed: 20_level_3,Unnamed: 21_level_3,Unnamed: 22_level_3,Unnamed: 23_level_3
3000003,7,4,,,,,,,,,,,...,,,,,,,,,,
3000003,7,5,,,,,,,,,,,...,,,,,,,,,,
3000003,7,6,,,,,,,,,,,...,,,,,,,,,,
3000003,7,9,,,,,,,,,,,...,,,,,,,,,,
3000003,7,11,,,,,,,,,,,...,,,,,,,,,,
3000003,7,12,,,,,,,,,,,...,,,,,,,,,,
3000003,8,0,,,,,,,,,,,...,,,,,,,,,,
3000003,8,1,,,,,,,,,,,...,,,,,,,,,,
3000003,8,2,,,,,,,,,,,...,,,,,,,,,,
3000003,9,0,,,,,,,,,,,...,,,,,,,,,,


In [166]:
# Number of distinct (rec_id, segment, chunk_id) chunks in the current sample.
# NOTE(review): the same statement is repeated in the next cell with a
# different result, so the value depends on which `data` was live when it ran.
len(data['chunk_id'].reset_index().drop_duplicates(['rec_id', 'segment', 'chunk_id']))

1441292

In [164]:
# Number of distinct (rec_id, segment, chunk_id) chunks in the current sample.
# NOTE(review): duplicate of the previous cell; the differing outputs come from
# different states of the global `data`.
len(data['chunk_id'].reset_index().drop_duplicates(['rec_id', 'segment', 'chunk_id']))

1463741

In [165]:
# Total number of sampled window rows.
len(data)

7994180

In [154]:
%%time

# Re-sample without replacement using the windows_per_chunk value from H.
# NOTE(review): this overwrites the global `data` inspected by the cells above.
H_ = {**H, 'sample_with_replacement': False}
data = sample_data(H_, sig_data)

CPU times: user 1min 22s, sys: 6.74 s, total: 1min 29s
Wall time: 1min 29s


In [141]:
%%time

# Re-sample without replacement using the windows_per_chunk value from H.
# NOTE(review): identical to the previous cell; only the recorded timing differs.
H_ = {**H, 'sample_with_replacement': False}
data = sample_data(H_, sig_data)

CPU times: user 3min 18s, sys: 5.51 s, total: 3min 23s
Wall time: 3min 23s


In [21]:
def sample_data(H, sig_data):
    """Sample chunk paths and window indices and return them per data split.

    NOTE(review): this redefines the ``sample_data`` declared earlier in the
    notebook and reads several names that no visible cell defines:
    ``calculate_examples_per_record``, ``sig_lens``,
    ``sample_segments_with_replacement``,
    ``sample_segments_without_replacement``, ``sample_chunk_ids``,
    ``sample_chunk_path`` and ``shuffled``. It looks like unfinished work in
    progress - confirm before relying on it.

    Args:
        H: hyperparameter mapping; reads 'sample_with_replacement',
            'input_sigs', 'output_sigs', 'window_size' and 'windows_per_chunk'.
        sig_data: DataFrame grouped by its first index level, with
            'sig_index', 'baseline' and 'adc_gain' column groups.

    Returns:
        dict with keys 'train' and 'validation'; both map to the same dict
        object holding 'chunk_paths', 'window_indices', 'sig_indices',
        'baselines' and 'gains'.
    """
    n = calculate_examples_per_record(H, sig_lens.index)
    # Pick the per-record sampler matching the configured mode.
    if H['sample_with_replacement']:
        sample_segments = partial(sample_segments_with_replacement, n)
    else:
        sample_segments = partial(sample_segments_without_replacement, n)
        
    sig_data = sig_data.groupby(level=0).apply(sample_segments).droplevel(0)
    sig_data = sig_data.reindex(sig_lens.index)
    # NOTE(review): chunk_ids is never read afterwards.
    chunk_ids = sig_lens.groupby(level=[0, 1]).apply(sample_chunk_ids)
            
    # NOTE(review): this value is overwritten by the list built further down.
    chunk_paths = sig_lens.reset_index().apply(sample_chunk_path, axis=1)
    
    # Per-signal columns, inputs first, then outputs.
    S = H['input_sigs'] + H['output_sigs']
    sig_indices = sig_data['sig_index'][S].values
    baselines   = sig_data['baseline'][S].values
    gains       = sig_data['adc_gain'][S].values
    
    
    # One random chunk index in [0, chunk_count) per entry of sig_lens.
    chunk_counts = sig_lens.apply(prepare_data.get_chunk_count)
    chunk_indices = chunk_counts.apply(numpy.random.randint)
    chunk_coords = zip(shuffled['rec_id'], shuffled['segment'], chunk_indices)
    # Path template: <ROOT_SERIAL><rec_id>_<segment, 4 digits>_<chunk index, 4 digits>.tfrec
    chunk_path = prepare_data.ROOT_SERIAL + '{}_{}_{}.tfrec'
    chunk_paths = [
        chunk_path.format(i, str(j).zfill(4), str(k).zfill(4)) 
        for i, j, k in chunk_coords
    ]
    
    # windows_per_chunk random positions per chunk, drawn from
    # [RESP_SCALE * window_size, CHUNK_SIZE).
    window_indices = numpy.random.randint(
        low = RESP_SCALE * H['window_size'],
        high = prepare_data.CHUNK_SIZE,
        size = [len(chunk_indices), H['windows_per_chunk']]
    )
    
    epoch = {
        'chunk_paths': chunk_paths,
        'window_indices': window_indices,
        'sig_indices': sig_indices,
        'baselines': baselines,
        'gains': gains
    }
    
    # NOTE(review): train and validation share the very same sample here.
    data = {'train': epoch, 'validation': epoch}
    
    return data
                    
def build_pipeline(H, tensors):
    """Assemble the batched ``tf.data`` pipeline over per-chunk tensors.

    Slices ``tensors`` into one element per chunk, interleaves the datasets
    that ``get_windows`` returns for each element, optionally filters with
    ``filter_datum``, then shuffles and batches.

    NOTE(review): this repeats the ``build_pipeline`` defined earlier in the
    notebook; the later definition is the one in effect once this cell runs.

    Args:
        H: hyperparameter mapping; reads 'batch_buffer_size', 'batch_size',
            'windows_per_chunk' and 'filter_data'.
        tensors: structure accepted by ``tf.data.Dataset.from_tensor_slices``.

    Returns:
        The shuffled, batched dataset.
    """
    chunk_source = tf.data.Dataset.from_tensor_slices(tensors)
    read_windows = partial(get_windows, H, get_window_index_matrix(H))
    windows = chunk_source.interleave(
        read_windows,
        cycle_length=H['batch_buffer_size'] * H['batch_size'],
        block_length=1,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )
    shuffle_buffer = H['batch_buffer_size'] * H['batch_size'] * H['windows_per_chunk']
    if H['filter_data']:
        windows = windows.filter(filter_datum)
    shuffled = windows.shuffle(shuffle_buffer)
    return shuffled.batch(H['batch_size'])

In [13]:
%%time

# NOTE(review): sample_data is called with three arguments, but both
# definitions visible in this notebook accept two (H plus a data frame). The
# recorded output presumably comes from an earlier three-argument version -
# confirm which signature is intended.
dataframes = sample_data(H, metadata, sig_data)

166665 5905
CPU times: user 1.39 s, sys: 20 ms, total: 1.41 s
Wall time: 1.4 s


In [22]:
%%time

# Build one input pipeline per data split, train first, then validation.
dataset = {
    split: build_pipeline(H, dataframes[split])
    for split in ('train', 'validation')
}

Executing op ParallelInterleaveDatasetV2 in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op FilterDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op ParallelInterleaveDatasetV2 in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op FilterDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op DeleteRandomSeedGenerator in device /job:localhost/replica:0/task:0/device:CPU:0
CPU times: user 256 ms, sys: 12 ms, total: 268 ms
Wall time: 261 ms
