In [1]:
import os
import sys
import numpy as np
import tensorflow as tf
from data_reader import data as D
from terminaltables import AsciiTable
from edfpy.channel import lowpass_resample
from keras.utils.np_utils import to_categorical
from sklearn.model_selection import train_test_split

# Patch keras
import inspect
cat_arglist = inspect.getargspec(to_categorical).args
if 'nb_classes' in cat_arglist:
    to_cat_args = { 'nb_classes': 6 }
elif 'num_classes' in cat_arglist:
    to_cat_args = { 'num_classes': 6 }


kw_get_samples = {
    'dtype': np.float32,
    'resample': lowpass_resample
}

Using TensorFlow backend.
  del sys.path[0]


In [2]:
def get_data(I, channel, sr, dt):
    m = int(dt*sr)
    X = np.empty((len(I), m), dtype=np.float32)
    for i, a in enumerate(I):
        dset = D.get_dataset(a.file)
        X[i, :] = dset[channel].get_physical_samples(a.t, dt, sr,
                **kw_get_samples)
    return X

In [3]:
# load annotations index
I = np.load('./index.npy').view(np.recarray)

In [4]:
files = np.unique(I.file)
I_split = [
    I[I.file==f]
    for f in files
]

In [5]:
def floats_feature(arr):
    return tf.train.Feature(float_list=tf.train.FloatList(value=arr.flatten().tolist()))

def int_feature(i):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[i]))

In [7]:
for I in I_split:
    fname = I.file[0]
    print('converting', fname)
    output_fname = './tfr-database/'+os.path.splitext(os.path.basename(fname))[0]+'.tfrecords'
    
    if os.path.exists(output_fname):
        print(f'{output_fname} already exists')
        continue
    
    features = np.swapaxes(np.array([get_data(I, channel, D.sr, D.dt) for channel in D.channels]), 0, 1)
    target = D.encode(I.event)

    writer = tf.python_io .TFRecordWriter(output_fname)
    for i, (np_features, np_target) in enumerate(zip(features, target)):
        tf_features_map = {
            ch: floats_feature(x)
            for ch, x in zip(D.channels, np_features)
        }
        tf_features_map['target'] = int_feature(np_target)
        tf_features = tf.train.Features(feature=tf_features_map)
        tf_example = tf.train.Example(features=tf_features)
        
        writer.write(tf_example.SerializeToString())
    writer.close()

converting /home/jtschw2/Data/capslpdb/brux1.edf
converting /home/jtschw2/Data/capslpdb/brux2.edf
./tfr-database/brux2.tfrecords already exists
converting /home/jtschw2/Data/capslpdb/ins1.edf
./tfr-database/ins1.tfrecords already exists
converting /home/jtschw2/Data/capslpdb/ins2.edf
./tfr-database/ins2.tfrecords already exists
converting /home/jtschw2/Data/capslpdb/ins3.edf
./tfr-database/ins3.tfrecords already exists
converting /home/jtschw2/Data/capslpdb/ins4.edf
./tfr-database/ins4.tfrecords already exists
converting /home/jtschw2/Data/capslpdb/ins5.edf
./tfr-database/ins5.tfrecords already exists
converting /home/jtschw2/Data/capslpdb/ins6.edf
./tfr-database/ins6.tfrecords already exists
converting /home/jtschw2/Data/capslpdb/ins7.edf
./tfr-database/ins7.tfrecords already exists
converting /home/jtschw2/Data/capslpdb/ins8.edf
./tfr-database/ins8.tfrecords already exists
converting /home/jtschw2/Data/capslpdb/ins9.edf
./tfr-database/ins9.tfrecords already exists
converting /home/jt