In [1]:
import numpy as np
import pickle
import tensorflow as tf

from argparse import Namespace

In [2]:
# Session shared by the notebook; later cells evaluate tensors with sess.run.
sess = tf.InteractiveSession()

In [None]:
# NOTE(review): this cell sits directly after the session is created, so
# running the notebook top to bottom closes `sess` before the later
# `sess.run` cells. Run it only once you are finished with the session.
sess.close()

In [3]:
# Run configuration; every key becomes an attribute on `args`.
# (A stray duplicate 'seqs_file': 4 entry was removed -- the later string
# value always won, so the effective configuration is unchanged.)
options = {'seqs_file': 'seqs',      # pickle read by load_data
           'max_v': 6,               # codes per visit after padding
           'max_t': 5,               # visits per patient after padding
           'n_codes': 7,
           'n_labels': 4,
           'labels_file': 'labs',    # pickle read by load_data (may be None)
           'demo_file': 'demo',      # pickle read by load_data (may be None)
           'out_file': 'zipped_TFR'}  # TFRecord file written further down
args = Namespace(**options)

In [164]:
def load_data(args=args):
    """Load the pickled sequences, labels and demographics named in args.

    Replace later with dataset stuff.

    args: namespace with `seqs_file`, `labels_file` and `demo_file`
          attributes. `labels_file` and `demo_file` may be None.
    returns:
        seqs: object unpickled from `args.seqs_file`.
        labs: object unpickled from `args.labels_file`, or None when no
              labels file is configured.
        demo: float32 constant tensor built from the unpickled
              demographics, or None when no demo file is configured.
        demo_dim: size of the last axis of the demographics, or 0 when
                  no demo file is configured.

    Note: pickle.load can execute arbitrary code stored in the file, so
          only point these paths at trusted files.
    """
    with open(args.seqs_file, 'rb') as seqs_handle:
        seqs = pickle.load(seqs_handle)

    labs = None
    if args.labels_file is not None:
        with open(args.labels_file, 'rb') as labels_handle:
            labs = pickle.load(labels_handle)

    # Initialise `demo` up front so the return statement is valid even
    # when no demographics file is configured.
    demo = None
    demo_dim = 0
    if args.demo_file is not None:
        with open(args.demo_file, 'rb') as demo_handle:
            D_t = pickle.load(demo_handle)
        demo_dim = D_t.shape[-1]
        demo = tf.constant(D_t, dtype=tf.float32)
    return seqs, labs, demo, demo_dim


def fill_visit(visit, args=args):
    """Right-pad a visit with -2 until it holds args.max_v codes.

    Padding gives every visit the same length for efficient tensor
    logic; the -2 filler codes are meant to be one-hot encoded as the
    zero vector downstream, so that they affect nothing.

    visit: a list of integer medical codes

    returns: a new list (the input is not modified). A visit that
             already has max_v or more codes is copied unchanged. The
             patient separator [-1] is not a visit and yields None.

    Note: No visit in training or testing should have more than max_v
          codes.
    """
    max_v = args.max_v
    if visit == [-1]:
        return None
    padding = [-2] * (max_v - len(visit))
    return list(visit) + padding


def fill_patient(patient, mask_batch, args=args):
    """Ensure that all patients have max_t visits.

    Append visits full of -2s, which are one-hot encoded as zero
    vectors. This makes all patients commensurate for efficient tensor
    logic.

    patient: list of list of integer codes (visits already padded to
             max_v codes)
    mask_batch: list of mask entries collected so far for this patient
    args: namespace providing max_t (the number of visits all patients
          ought to have) and max_v (codes per visit)

    returns:
        new_patient: copy of `patient` followed by (max_t - t) padding
                     visits, each a list of max_v -2s.
        new_mask_batch: copy of `mask_batch` with one extra entry: a
                        list of (max_t - t) all-zero rows of length
                        max_v.
        t: the number of visits `patient` had before padding.

    Neither input list is modified, and every padding row is its own
    list object, so editing one row later cannot change the others.

    Note: No patient in training or test data should have more
          than max_t visits.
    """
    max_t = args.max_t
    max_v = args.max_v
    t = len(patient)
    deficit = max_t - t  # a non-positive deficit adds no padding rows

    new_patient = list(patient)
    new_patient.extend([-2] * max_v for _ in range(deficit))

    # Copy rather than alias so the caller's list is left untouched.
    new_mask_batch = list(mask_batch)
    new_mask_batch.append([[0] * max_v for _ in range(deficit)])
    return new_patient, new_mask_batch, t


def tensorize_seqs(seqs, args=args, true_seqs=True):
    """Convert med2vec-style sequences to padded tensorflow constants.

    seqs: list of visits, cf  https://github.com/mp2893/med2vec
          Each visit is a list of integer codes and the visit [-1]
          separates consecutive patients. A closing [-1] is appended
          internally, so if `seqs` itself ends with [-1] one extra
          all-padding patient is emitted.
    args: namespace providing max_t and max_v (forwarded to fill_visit
          and fill_patient).
    true_seqs: bool. Are we tensorizing the true sequences? If false,
               we are tensorizing labels and no visit counts are
               collected.
    returns:
        patients: integer tensor of code ids, padded with -2. Its shape
                  is [patients, max_t, max_v] provided no patient
                  exceeds max_t visits and no visit exceeds max_v codes.
        row_masks: boolean tensor with the same shape as `patients`,
               True where the entry is a real code (not the -2 filler).
               Later, we will create a [patients, max_t, max_v, |C|]
               tensor where the [p, t, i, j] entry is p(c_j|c_i).
               Row_masks will drop the rows where c_i is the zero
               vector--that is, an NA ICD.

               A separate mask, col_mask, will be created from
               patients in order to mask, for each t, those j for
               which c_j did not appear in visit t, as well as
               p(c_i|c_i).

               The masks are to be applied in reverse order of creation.
               col_mask is applied with tf.multiply and row_masks
               with tf.boolean_mask to avoid needless reshaping.
        patients_ts: tensor with shape [patients, 1, 1] containing the
                     number of true visits for each patient; it has no
                     entries when true_seqs is False.
    """
    patients = []
    patients_ts = []
    current_patient = []
    # The extra [-1] flushes the last patient, which has no trailing
    # separator of its own.
    for visit in seqs + [[-1]]:
        if visit != [-1]:
            current_patient.append(fill_visit(visit, args))
            continue
        # Separator reached: pad the finished patient to max_t visits.
        # The mask list fill_patient returns is not needed here because
        # row_masks is derived from the padded tensor below.
        padded_patient, _, t = fill_patient(current_patient, [], args)
        patients.append(padded_patient)
        if true_seqs:
            patients_ts.append(t)
        current_patient = []
    patients = tf.constant(patients)
    patients_ts = tf.constant(patients_ts)
    # Two trailing unit axes: [patients] -> [patients, 1, 1].
    patients_ts = tf.expand_dims(patients_ts, -1)
    patients_ts = tf.expand_dims(patients_ts, -1)
    row_masks = tf.not_equal(patients, -2)
    return patients, row_masks, patients_ts

In [244]:
def _bytes_feature(value):
    """Wrap one string / bytes value in a Feature holding a bytes_list."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)


def _float_feature(value):
    """Wrap one float / double value in a Feature holding a float_list."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)


def _int64_feature(value):
    """Wrap one bool / enum / int / uint value in a Feature holding an int64_list."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)


def serialize_with_labels(patient, label, demo, row_mask, patient_t):
    """Create a serialized tf.Example message ready to be written to a file.

    Each argument is a single string / bytes value (in this notebook, a
    tensor already passed through tf.serialize_tensor) and is stored as
    a one-element bytes feature.

    The feature names must match the keys read by _parse_function:
    'patients', 'labels', 'demo', 'row_masks', 'patients_ts'. A name
    that does not match is not reported as an error on read; the parser
    returns its '' default for that feature instead.

    returns: the tf.Example serialized to a byte string.
    """
    # Map each feature name to its tf.Example-compatible value.
    feature = {'patients': _bytes_feature(patient),
               'labels': _bytes_feature(label),
               'demo': _bytes_feature(demo),
               'row_masks': _bytes_feature(row_mask),
               'patients_ts': _bytes_feature(patient_t)}

    # Wrap the features in a tf.train.Example message.
    example_proto = tf.train.Example(
        features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()



def tf_serialize_w_labels(patient, label, demo, row_mask, patient_t):
    """Graph-side wrapper so serialize_with_labels can be used in Dataset.map.

    The Python function is run through tf.py_func and its string result
    is reshaped to a scalar, giving one serialized example per element.
    """
    inputs = (patient, label, demo, row_mask, patient_t)
    serialized_example = tf.py_func(serialize_with_labels, inputs, tf.string)
    return tf.reshape(serialized_example, ())


def _parse_function(example_proto):
    """Parse one serialized tf.Example into a dict of scalar strings.

    Every feature is read as a scalar tf.string; a feature missing from
    the record comes back as the empty string.
    """
    feature_names = ('patients', 'labels', 'demo', 'row_masks', 'patients_ts')
    feature_description = {
        name: tf.FixedLenFeature([], tf.string, default_value='')
        for name in feature_names
    }
    return tf.parse_single_example(example_proto, feature_description)


In [245]:
# Unpickle the inputs named in `args` (sequences, labels, demographics).
seqs, labs, demo, demo_dim = load_data()

In [246]:
# Codes: keep the padded tensor, its mask and the per-patient visit counts.
patients, row_masks, patients_ts = tensorize_seqs(seqs, true_seqs=True)
# Labels: only the padded tensor is needed.
labels = tensorize_seqs(labs, true_seqs=False)[0]

In [247]:
# NOTE(review): `output` is first assigned in a later cell (In [250]), so this
# cell fails on a fresh kernel; the same check is repeated in In [251].
output.output_types

(tf.int32, tf.int32, tf.float32, tf.bool, tf.float32)

In [248]:
# One dataset per tensor, sliced along the leading axis.
# from_tensor_slices is a static factory, so it is called on the class
# itself instead of on a throwaway Dataset() instance.
patients_ds = tf.data.Dataset.from_tensor_slices(patients)
labels_ds = tf.data.Dataset.from_tensor_slices(labels)
demo_ds = tf.data.Dataset.from_tensor_slices(demo)
row_masks_ds = tf.data.Dataset.from_tensor_slices(row_masks)
patients_ts_ds = tf.data.Dataset.from_tensor_slices(patients_ts)

In [249]:
# Disabled alternative (dict-structured dataset). It is kept as a bare string
# literal, so the cell only echoes the text and builds nothing.
"""output = tf.data.Dataset().from_tensor_slices({
    'patients': patients,
    'labels': labels,
    'demo': demo,
    'row_masks': row_masks,
    'patients_ts': patients_ts
    })"""

"output = tf.data.Dataset().from_tensor_slices({\n    'patients': patients,\n    'labels': labels,\n    'demo': demo,\n    'row_masks': row_masks,\n    'patients_ts': patients_ts\n    })"

In [250]:
# Tuple-structured dataset: each element is one slice of every tensor, in the
# order (patients, labels, demo, row_masks, patients_ts).
# from_tensor_slices is a static factory, so it is called on the class itself
# instead of on a throwaway Dataset() instance.
output = tf.data.Dataset.from_tensor_slices((
    patients,
    labels,
    demo,
    row_masks,
    patients_ts,
))

In [251]:
# Per-component dtypes of the tuple dataset.
output.output_types

(tf.int32, tf.int32, tf.float32, tf.bool, tf.float32)

In [252]:
# Per-component static shapes of a single element.
output.output_shapes

(TensorShape([Dimension(5), Dimension(6)]),
 TensorShape([Dimension(5), Dimension(6)]),
 TensorShape([Dimension(5), Dimension(6)]),
 TensorShape([Dimension(5), Dimension(6)]),
 TensorShape([Dimension(1), Dimension(1)]))

In [253]:
# Iterator used to pull elements of `output` through the session.
output_it = output.make_one_shot_iterator()

In [254]:
# Evaluate and print the next element as a tuple of numpy arrays.
print(sess.run(output_it.get_next()))

(array([[ 5,  4, -2, -2, -2, -2],
       [ 2,  3,  4, -2, -2, -2],
       [-2, -2, -2, -2, -2, -2],
       [-2, -2, -2, -2, -2, -2],
       [-2, -2, -2, -2, -2, -2]]), array([[ 3,  2, -2, -2, -2, -2],
       [ 0,  3,  2, -2, -2, -2],
       [-2, -2, -2, -2, -2, -2],
       [-2, -2, -2, -2, -2, -2],
       [-2, -2, -2, -2, -2, -2]]), array([[-1.2766646 , -8.258818  ,  1.0295945 , -9.932409  ,  3.4238288 ,
         4.7623577 ],
       [ 4.781394  , -2.7135177 ,  8.019953  ,  0.30808517,  0.31495225,
        -3.8640842 ],
       [ 0.3880088 , -6.512215  ,  3.4989486 ,  2.7263956 , -1.5292437 ,
        -0.5698157 ],
       [ 4.377598  , -1.8317424 ,  2.9549975 , -1.0864103 ,  4.6243834 ,
        -0.11334902],
       [-7.137146  ,  0.8752901 , -4.1808968 ,  3.2075074 ,  4.62191   ,
        -1.7289466 ]], dtype=float32), array([[ True,  True, False, False, False, False],
       [ True,  True,  True, False, False, False],
       [False, False, False, False, False, False],
       [False, False

In [255]:
# Turn every component of each element into a serialized-tensor string, keeping
# the (patients, labels, demo, row_masks, patients_ts) order.
serialized = output.map(
    lambda *components: tuple(tf.serialize_tensor(component)
                              for component in components))

In [256]:
# Iterator over the per-component serialized strings.
serialized_it = serialized.make_one_shot_iterator()

In [263]:
# Evaluate and print the next element: a tuple of five byte strings.
print(sess.run(serialized_it.get_next()))

(b'\x08\x03\x12\x08\x12\x02\x08\x05\x12\x02\x08\x06"x\x05\x00\x00\x00\x06\x00\x00\x00\x04\x00\x00\x00\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\x00\x00\x00\x00\x03\x00\x00\x00\x06\x00\x00\x00\x05\x00\x00\x00\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff', b'\x08\x03\x12\x08\x12\x02\x08\x05\x12\x02\x08\x06"x\x03\x00\x00\x00\x03\x00\x00\x00\x02\x00\x00\x00\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\x00\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x02\x00\x00\x00\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff

In [258]:
# Pack each 5-tuple of serialized tensors into one serialized tf.Example string.
serialized_features_dataset = serialized.map(tf_serialize_w_labels)

In [259]:
# Each element should be a single string.
serialized_features_dataset.output_types

tf.string

In [260]:
# Each element should be a scalar (tf_serialize_w_labels reshapes to ()).
serialized_features_dataset.output_shapes

TensorShape([])

In [261]:
# NOTE(review): `sfd_it` is not referenced by any later cell visible here.
sfd_it = serialized_features_dataset.make_one_shot_iterator()

In [98]:
# Build (but do not yet run) the op that writes every serialized example to the
# TFRecord file at args.out_file.
writer = tf.contrib.data.TFRecordWriter(args.out_file)
writeop = writer.write(serialized_features_dataset)

In [99]:
# Run the write op built in the previous cell.
sess.run(writeop)

## Reading

In [152]:
# Read back the TFRecord file produced by the writer cell. The path comes from
# args.out_file so the reader cannot drift from the writer's path.
filenames = [args.out_file]
raw_dataset = tf.data.TFRecordDataset(filenames)

In [153]:
# Iterator over the raw (still serialized) records.
it_raw = raw_dataset.make_one_shot_iterator()

In [119]:
# Print the next raw record: the bytes of one serialized tf.Example.
print(sess.run(it_raw.get_next()))

b'\n\xe7\x05\n\xea\x01\n\x08patients\x12\xdd\x01\n\xda\x01\n\xd7\x01\x08\x03\x12\x08\x12\x02\x08\x05\x12\x02\x08\n"\xc8\x01\x05\x00\x00\x00\x04\x00\x00\x00\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\x02\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\n\xe8\x01\n\x06labels\x12\xdd\x01\n\xda\x01\n\xd7\x01\x08\x03\x12\x08\x12\x0

In [103]:
# Parse the next raw record into its feature dict and print it. A feature whose
# name is absent from the record shows up as b''.
print(sess.run(_parse_function(it_raw.get_next())))

{'demo': b'\x08\x03\x12\x08\x12\x02\x08\x05\x12\x02\x08\x06"x\x03\x00\x00\x00\x03\x00\x00\x00\x02\x00\x00\x00\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\x00\x00\x00\x00\x03\x00\x00\x00\x03\x00\x00\x00\x02\x00\x00\x00\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff\xfe\xff\xff\xff', 'labels': b'\x08\x01\x12\x08\x12\x02\x08\x05\x12\x02\x08\x06"x\xdb\xaf&>\xec\x1e\xae?`\x89\x05\xbfCZ\x9e\xc0\nT\x10\xc0x\xd5\xac@\x14;z\xc0\xcc\xdd\xfa\xc0\xfc\xdb\x1bA\x08\x06\xbd\xc0t\xa1\x0c\xc1\xf8\xdc\xe0\xbff\xfc\xa2@\x0c\xda\x00?\xdf\xc0\n@\x11\x14@?Ip\x97\xbf|BX\xc0?P\xf1\xbf&\xbe\xaf\xc0\xe8a\x07\xbfk\x03\x9d@n\xd1\x9a@\xb9\xf78@A\xb8\x16A\xc6\x17U\xbf\xd2:j\xc0\xc1\xf8\xf0\xc0i9\x1d\xc1j\x002@', 'patient_ts': b'', 'patients': b'\x08\x03\x12\x

In [154]:
# Dataset of feature dicts, one per record in the file.
parsed = raw_dataset.map(_parse_function)

In [282]:
# Iterator over the parsed feature dicts.
parsed_it = parsed.make_one_shot_iterator()

In [287]:
# Decode the 'patients' feature of the next record back into an int32 tensor.
print(sess.run(tf.parse_tensor(parsed_it.get_next()['patients'], out_type=tf.int32)))

[[ 5  6  4 -2 -2 -2 -2 -2 -2 -2]
 [ 0  3  6  5 -2 -2 -2 -2 -2 -2]
 [-2 -2 -2 -2 -2 -2 -2 -2 -2 -2]
 [-2 -2 -2 -2 -2 -2 -2 -2 -2 -2]
 [-2 -2 -2 -2 -2 -2 -2 -2 -2 -2]]
