In [1]:
import numpy as np
import pickle
import tensorflow as tf

from argparse import Namespace

In [2]:
# Open an interactive TensorFlow session; later cells evaluate ops through sess.run(...).
sess = tf.InteractiveSession()

In [None]:
# Release the session's resources.
# NOTE(review): later cells call sess.run(...); executing this cell in notebook
# order closes the session before they run. Run it only once the work is done.
sess.close()

In [130]:
# Run configuration, exposed as attributes on `args` and used as the default
# argument of the helper functions defined below.
# The dict literal previously listed 'seqs_file' twice (4, then 'seqs'); Python
# keeps only the last value, so the discarded first entry has been removed.
options = {'seqs_file': 'seqs',          # pickle file holding the visit sequences
           'max_v': 6,                   # codes per visit after padding
           'max_t': 5,                   # visits per patient after padding
           'n_codes': 7,
           'n_labels': 4,
           'labels_file': 'labs',        # pickle file holding the label sequences
           'demo_file': 'demo',          # pickle file holding the demographics
           'out_file': 'zipped_TFR',     # path of the TFRecord file to write
           'n_patients': 4}
args = Namespace(**options)

In [140]:
def load_data(args=args):
    """Load the pickled sequences, labels and demographics named in `args`.

    TODO (from the original author): replace later with dataset stuff.

    args: namespace providing `seqs_file`, `labels_file`, `demo_file` and
          `n_patients`. `labels_file` and `demo_file` may be None.

    returns:
        seqs: the object unpickled from `args.seqs_file`.
        labs: the object unpickled from `args.labels_file`, or None when no
              labels file is configured.
        demo: float32 tensor reshaped to [n_patients, -1], or None when no
              demo file is configured.
        demo_dim: size of the last axis of the unpickled demographics, or 0
                  when no demo file is configured.
    """
    def _unpickle(path):
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        # The context manager closes the handle, which the bare open() did not.
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    seqs = _unpickle(args.seqs_file)

    labs = None
    if args.labels_file is not None:
        labs = _unpickle(args.labels_file)

    # Default `demo` up front so the return below cannot hit an unbound name
    # when no demo file is configured.
    demo = None
    demo_dim = 0
    if args.demo_file is not None:
        D_t = _unpickle(args.demo_file)
        demo_dim = D_t.shape[-1]
        demo = tf.constant(D_t, dtype=tf.float32)
        demo = tf.reshape(demo, [args.n_patients, -1])
    return seqs, labs, demo, demo_dim


def fill_visit(visit, args=args):
    """Pad a visit with -2 until it holds max_v codes.

    Every visit must carry the same number of ICD codes so the data can
    be stacked into a tensor. The -2 filler is meant to be one-hot
    encoded as the zero vector downstream, so it affects nothing.

    visit: a list of integer medical codes
    args: namespace providing max_v, the target number of codes

    returns: a new list of the visit's codes followed by -2 filler, or
             None when visit is the patient separator [-1].

    Note: No visit in training or testing should have more than max_v
          codes; a longer visit comes back as an unpadded copy.
    """
    max_v = args.max_v
    if visit != [-1]:
        n_missing = max_v - len(visit)
        # A non-positive count yields an empty filler list.
        return list(visit) + [-2] * n_missing
    # The separator is not a visit, so there is nothing to pad.
    return None


def fill_patient(patient, mask_batch, args=args):
    """Ensure that a patient has max_t visits.

    Appends visits full of -2s, which are meant to be one-hot encoded as
    zero vectors downstream. This makes all patients commensurate for
    efficient tensor logic.

    patient: list of list of integer codes (already padded to max_v)
    mask_batch: list of mask entries accumulated so far; it is copied,
                never modified in place
    args: namespace providing max_t (visits per patient) and max_v
          (codes per visit)

    returns:
        new_patient: copy of `patient` followed by the filler visits
        new_mask_batch: copy of `mask_batch` with ONE extra element
                        appended: the list of all-zero mask rows, one
                        per filler visit
        t: the number of visits the patient had before padding

    Note: No patient in training or test data should have more
          than max_t visits; such a patient is returned unpadded.
    """
    max_t = args.max_t
    max_v = args.max_v
    new_patient = list(patient)
    t = len(new_patient)
    deficit = max_t - t
    # Build each filler row separately: `[[...]] * deficit` would repeat one
    # shared list object, so editing one filler visit would change them all.
    new_patient.extend([-2] * max_v for _ in range(deficit))
    # Copy before appending so the caller's list is not mutated through an
    # alias; callers get the updated masks via the return value.
    new_mask_batch = list(mask_batch)
    new_mask_batch.append([[0] * max_v for _ in range(deficit)])
    return new_patient, new_mask_batch, t


def tensorize_seqs(seqs, args=args, true_seqs=True):
    """Convert med2vec-style sequences to padded tensorflow data.

    seqs: list of visits, each a list of integer codes, with [-1]
          separating consecutive patients.
          cf  https://github.com/mp2893/med2vec
    args: namespace providing max_t, max_v and n_patients.
    true_seqs: bool. Are we tensorizing the true sequences? If False,
               we are tensorizing labels and no visit counts are
               collected.
    returns:
        patients: integer tensor reshaped to [n_patients, -1]; each row
                  is one patient's padded visits laid end to end
                  (max_t * max_v entries when no patient or visit
                  exceeds the limits). Filler entries are -2.
        row_masks: int32 tensor with the same shape as `patients`,
                   holding 1 where the entry is a real code and 0 where
                   it is the -2 filler.
        patients_ts: python list with the number of true visits for
                     each patient; empty when true_seqs is False.

    The previous implementation also accumulated per-patient mask lists
    and then overwrote them with the tensor mask computed from the
    padded codes; that dead bookkeeping has been removed.
    """
    patients = []
    patients_ts = []
    current_patient = []
    # The trailing separator flushes the final patient, which the input does
    # not terminate with [-1] itself.
    for visit in seqs + [[-1]]:
        if visit != [-1]:
            current_patient.append(fill_visit(visit, args))
            continue
        # Only the padded visits and the true visit count are needed; the mask
        # list returned by fill_patient is superseded by the tensor mask below.
        padded_patient, _, t = fill_patient(current_patient, [], args)
        patients.append(padded_patient)
        if true_seqs:
            patients_ts.append(t)
        current_patient = []
    patients = tf.constant(patients)
    # Mark real codes with 1 and -2 filler with 0.
    row_masks = tf.cast(tf.not_equal(patients, -2), tf.int32)
    patients = tf.reshape(patients, [args.n_patients, -1])
    row_masks = tf.reshape(row_masks, [args.n_patients, -1])
    return patients, row_masks, patients_ts

In [141]:
# Load the sequences, labels and demographics using the default `args`.
seqs, labs, demo, demo_dim = load_data()

In [142]:
# Pad and tensorize the code sequences, keeping their mask and visit counts.
patients, row_masks, patients_ts = tensorize_seqs(seqs, true_seqs=True)
# Labels go through the same padding; their mask and visit counts are discarded.
labels, _, _ = tensorize_seqs(labs, true_seqs=False)

In [145]:
# Slice the five aligned inputs along their first axis, one element per patient.
# from_tensor_slices is a static factory, so call it on the class: instantiating
# the abstract Dataset first (`tf.data.Dataset()`) raises TypeError on
# TensorFlow versions that enforce the abstract base class.
output = tf.data.Dataset.from_tensor_slices((patients,
    labels,
    demo,
    row_masks,
    patients_ts
    ))

In [146]:
# One-shot iterator over the per-patient dataset.
# NOTE(review): output_it is not referenced by any later cell visible here.
output_it = output.make_one_shot_iterator()

In [37]:
#serial = output.map(lambda v,w,x,y,z: (tf.serialize_tensor(v),tf.serialize_tensor(w),tf.serialize_tensor(x),tf.serialize_tensor(y),z))

In [163]:
def make_example(patient, label, demo, row_mask, patient_t, args=args):
    """Serialize one patient's data as a tf.train.SequenceExample string.

    patient, label, demo, row_mask: iterables walked in lockstep; the
        i-th entry of each becomes the i-th step of the feature lists
        "patient", "label", "demo" and "row_mask". "demo" is stored as
        floats, the other three as int64.
    patient_t: value stored in the context under "patient_t".
    args: namespace providing max_t and max_v, which are stored in the
          context as well.

    returns: the serialized SequenceExample.

    Note: the inputs are combined with zip, so every feature list is
          cut to the length of the shortest input.
    """
    example = tf.train.SequenceExample()

    # Context: the non-sequential, per-example values.
    context = example.context.feature
    context["patient_t"].int64_list.value.append(patient_t)
    context["max_t"].int64_list.value.append(args.max_t)
    context["max_v"].int64_list.value.append(args.max_v)

    # Feature lists: one per sequential feature (four in total).
    sequences = example.feature_lists.feature_list
    patient_steps = sequences["patient"]
    label_steps = sequences["label"]
    demo_steps = sequences["demo"]
    mask_steps = sequences["row_mask"]

    for code, label_code, demo_value, mask_value in zip(patient, label,
                                                        demo, row_mask):
        patient_steps.feature.add().int64_list.value.append(code)
        label_steps.feature.add().int64_list.value.append(label_code)
        demo_steps.feature.add().float_list.value.append(demo_value)
        mask_steps.feature.add().int64_list.value.append(mask_value)

    return example.SerializeToString()

In [164]:
def tf_serialize_w_labels(patient, label, demo, row_mask, patient_t):
    """Wrap make_example in a graph op so it can be mapped over a dataset.

    The five tensors are handed to make_example through tf.py_func and
    the resulting string tensor is reshaped to a scalar.
    """
    py_inputs = (patient, label, demo, row_mask, patient_t)
    serialized_example = tf.py_func(make_example, py_inputs, tf.string)
    scalar_shape = ()
    return tf.reshape(serialized_example, scalar_shape)

In [193]:
# Turn every dataset element into a scalar serialized-example string.
serialized = output.map(tf_serialize_w_labels)

In [194]:
# One-shot iterator over the serialized examples, used below to inspect one.
serialized_it = serialized.make_one_shot_iterator()

In [195]:
# Evaluate the first serialized example.
ex = sess.run(serialized_it.get_next())

In [197]:
# Create a TFRecord writer targeting args.out_file and build the op that
# writes the serialized dataset to it; nothing is written until the op runs.
writer = tf.contrib.data.TFRecordWriter(args.out_file)
writeop = writer.write(serialized)

In [198]:
# Execute the write op built in the previous cell.
sess.run(writeop)

## Reading

In [289]:
# Open the file at args.out_file as a dataset of raw records.
raw = tf.data.TFRecordDataset([args.out_file])

In [290]:
# One-shot iterator over the raw records, used below to inspect one.
raw_it = raw.make_one_shot_iterator()

In [291]:
# Evaluate the first raw record.
# NOTE(review): this rebinds `ex`, which an earlier cell assigned from serialized_it.
ex = sess.run(raw_it.get_next())

In [292]:
def parse_sequence_examples(example_proto, args=args):
    """Decode one serialized SequenceExample back into per-patient tensors.

    example_proto: scalar string tensor holding a serialized
                   SequenceExample with context keys "patient_t",
                   "max_t", "max_v" and feature lists "patient",
                   "label", "demo", "row_mask".
    args: accepted for signature symmetry; not read by this function.

    returns: tuple (patient, label, demo, row_mask, patient_t). The
             first four are reshaped to [max_t, max_v] using the values
             stored in the example's context; patient_t is reshaped to
             [1, 1].
    """
    context_spec = {
        key: tf.FixedLenFeature([], dtype=tf.int64)
        for key in ("patient_t", "max_t", "max_v")
    }
    sequence_spec = {
        "patient": tf.FixedLenSequenceFeature([], dtype=tf.int64),
        "label": tf.FixedLenSequenceFeature([], dtype=tf.int64),
        "demo": tf.FixedLenSequenceFeature([], dtype=tf.float32),
        "row_mask": tf.FixedLenSequenceFeature([], dtype=tf.int64)
    }
    context, sequences = tf.parse_single_sequence_example(
        serialized=example_proto,
        context_features=context_spec,
        sequence_features=sequence_spec
    )
    # The target grid shape travels with each example in its context.
    grid_shape = tf.stack([context['max_t'], context['max_v']])
    patient, label, demo, row_mask = (
        tf.reshape(sequences[key], grid_shape)
        for key in ('patient', 'label', 'demo', 'row_mask')
    )
    patient_t = tf.reshape(context['patient_t'], [1, 1])
    return (patient, label, demo, row_mask, patient_t)

In [293]:
# Decode every raw record with parse_sequence_examples.
parse = raw.map(parse_sequence_examples)

In [301]:
# Group the parsed examples into batches of 2.
parse_batch = parse.batch(2)

In [302]:
# One-shot iterator over the batched, parsed examples.
parse_batch_it = parse_batch.make_one_shot_iterator()

In [303]:
# Evaluate and display the first parsed batch.
sess.run(parse_batch_it.get_next())

(array([[[ 5,  4, -2, -2, -2, -2],
         [ 2,  3,  4, -2, -2, -2],
         [-2, -2, -2, -2, -2, -2],
         [-2, -2, -2, -2, -2, -2],
         [-2, -2, -2, -2, -2, -2]],
 
        [[ 5,  6,  4, -2, -2, -2],
         [ 0,  3,  6,  5, -2, -2],
         [-2, -2, -2, -2, -2, -2],
         [-2, -2, -2, -2, -2, -2],
         [-2, -2, -2, -2, -2, -2]]], dtype=int64),
 array([[[ 3,  2, -2, -2, -2, -2],
         [ 0,  3,  2, -2, -2, -2],
         [-2, -2, -2, -2, -2, -2],
         [-2, -2, -2, -2, -2, -2],
         [-2, -2, -2, -2, -2, -2]],
 
        [[ 3,  3,  2, -2, -2, -2],
         [ 0,  3,  3,  2, -2, -2],
         [-2, -2, -2, -2, -2, -2],
         [-2, -2, -2, -2, -2, -2],
         [-2, -2, -2, -2, -2, -2]]], dtype=int64),
 array([[[-1.2766646 , -8.258818  ,  1.0295945 , -9.932409  ,
           3.4238288 ,  4.7623577 ],
         [ 4.781394  , -2.7135177 ,  8.019953  ,  0.30808517,
           0.31495225, -3.8640842 ],
         [ 0.3880088 , -6.512215  ,  3.4989486 ,  2.7263956 ,
  