In [137]:
import numpy as np
import pickle
import tensorflow as tf

from argparse import Namespace

import random

In [20]:
sess = tf.InteractiveSession()



In [None]:
sess.close()

In [21]:
options = {'seqs_file': 4,
           'max_v': 6,
           'max_t': 5,
           'n_codes': 7,
           'n_labels': 4,
           'seqs_file': 'seqs',
           'labels_file': 'labs',
           'demo_file': 'demo',
           'out_file': 'zipped_TFR',
           'n_patients': 4}
args = Namespace(**options)

In [22]:
def load_data(args=args):
    """Replace later with dataset stuff."""
    seqs_file = args.seqs_file
    if args.labels_file is not None:
        labels_file = args.labels_file
    seqs = pickle.load(open(seqs_file, 'rb'))
    labs = None
    if args.labels_file is not None:
        labels_file = args.labels_file
        labs = pickle.load(open(labels_file, 'rb'))
    D_t = None
    demo_dim = 0
    if args.demo_file is not None:
        demo_file = args.demo_file
        D_t = pickle.load(open(demo_file, 'rb'))
        demo_dim = D_t.shape[-1]
        demo = tf.constant(D_t, dtype=tf.float32)
        demo = tf.reshape(demo, [args.n_patients, -1])
    return seqs, labs, demo, demo_dim


def fill_visit(visit, args=args):
    """Fill all deficit visits with -2.

    Ensure that all visits have the same number of ICDs for efficient
    tensor logic. If a visit has fewer ICDs, filler ICDs get one-hot
    encoded as the zero vector, so that they affect nothing.

    visit: a list of integer medical codes

    Note: No visit in training or testing should have more than max_v
          visits.
    """
    max_v = args.max_v
    if visit != [-1]:
        new_visit = []
        new_visit.extend(visit)
        n_icd = len(visit)
        deficit = max_v - n_icd
        new_visit.extend([-2] * deficit)
        return new_visit


def fill_patient(patient, mask_batch, args=args):
    """Ensure that all patients have max_t visits.

    Create visits full of -2s, which are one-hot encoded as zero
    vectors. This makes all patients commensurate for efficient tensor
    logic.

    patient: list of list of integer codes
    max_t: the number of visits all patients ought to have

    Note: No patient in training or test data should have more
          than max_t visits.
    """
    max_t = args.max_t
    max_v = args.max_v
    new_patient = []
    new_patient.extend(patient)
    new_mask_batch = mask_batch
    t = len(new_patient)
    deficit = (max_t - t)
    new_patient.extend([[-2] * max_v] * deficit)
    new_mask_batch.append([[0] * max_v] * deficit)
    return new_patient, new_mask_batch, t


def tensorize_seqs(seqs, args=args, true_seqs=True):
    """Convert med2vec to tensorflow data.

    seqs: list of list. cf  https://github.com/mp2893/med2vec
    true_seqs: bool. Are we tensorizing the true sequences? If false,
               we are tonsorizing labels.
    returns:
        patients: tensor with shape [patients, max_t, max_v, |C|]
                  or [patients, max_t, max_v, n_labels] if true_seqs is
                  False.
        row_masks: numpy array with shape [patients, max_t, max_v]
               Later, we will create a [patients, max_t, max_v, |C|]
               tensor where the [p, t, i, j] entry is p(c_j|c_i).
               Row_masks will drop the rows where c_i is the zero
               vector--that is, an NA ICD.

               A separate mask, col_mask, will be created from
               patients in order to mask, for each t, those j for
               which c_j did not appear in visit t, as well as
               p(c_i|c_i).

               The masks are to be applied in reverse order of creation.
               col_mask is applied with tf.multiply and row_masks
               with tf.boolean_mask to avoid needless reshaping.
        patients_ts: numpy array with shape [patients,] containing the
                     number of true visits for each patient.
    """
    patients = []
    new_patient = []
    row_masks = []
    mask_batch = []
    patients_ts = []
    for visit in seqs + [[-1]]:
        if visit != [-1]:
            visit = fill_visit(visit, args)
            new_patient.append(visit)
        else:
            new_patient, mask_batch, t = fill_patient(new_patient,
                                                      mask_batch,
                                                      args)
            patients.append(new_patient)
            if true_seqs:
                patients_ts.append(t)
                row_masks.append(mask_batch)
                mask_batch = []
            new_patient = []
    patients = tf.constant(patients)
    patients_ts = patients_ts
    #patients_ts = tf.expand_dims(patients_ts, -1)
    #patients_ts = tf.expand_dims(patients_ts, -1)
    row_masks = tf.not_equal(patients, -2)
    row_masks = tf.cast(row_masks, tf.int32)
    patients = tf.reshape(patients, [args.n_patients, -1])
    row_masks = tf.reshape(row_masks, [args.n_patients, -1])
    return patients, row_masks, patients_ts

In [23]:
seqs, labs, demo, demo_dim = load_data()

In [24]:
patients, row_masks, patients_ts = tensorize_seqs(seqs, true_seqs=True)
labels, _, _ = tensorize_seqs(labs, true_seqs=False)

In [25]:
output = tf.data.Dataset().from_tensor_slices((patients,
                                               labels,
                                               demo,
                                               row_masks,
                                               patients_ts
                                              ))

In [26]:
output_it = output.make_one_shot_iterator()

In [27]:
output_it.get_next()

(<tf.Tensor 'IteratorGetNext_2:0' shape=(30,) dtype=int32>,
 <tf.Tensor 'IteratorGetNext_2:1' shape=(30,) dtype=int32>,
 <tf.Tensor 'IteratorGetNext_2:2' shape=(30,) dtype=float32>,
 <tf.Tensor 'IteratorGetNext_2:3' shape=(30,) dtype=int32>,
 <tf.Tensor 'IteratorGetNext_2:4' shape=() dtype=int32>)

In [37]:
#serial = output.map(lambda v,w,x,y,z: (tf.serialize_tensor(v),tf.serialize_tensor(w),tf.serialize_tensor(x),tf.serialize_tensor(y),z))

In [28]:
def serialize_lab_dem(patient, label, demo, row_mask, patient_t, args=args):
    """Turn each row of zipped dataset to example protos for writing to TFR."""
    ex = tf.train.SequenceExample()
    # Non-sequential features of the Example
    ex.context.feature["patient_t"].int64_list.value.append(patient_t)
    ex.context.feature["max_t"].int64_list.value.append(args.max_t)
    ex.context.feature["max_v"].int64_list.value.append(args.max_v)
    # Feature lists for the "sequential" features of the Example
    fl_patients = ex.feature_lists.feature_list["patient"]
    fl_labels = ex.feature_lists.feature_list["label"]
    fl_demo = ex.feature_lists.feature_list["demo"]
    fl_row_masks = ex.feature_lists.feature_list["row_mask"]
    for visit, lab, dem, mask in zip(patient, label, demo, row_mask):
        fl_patients.feature.add().int64_list.value.append(visit)
        fl_labels.feature.add().int64_list.value.append(lab)
        fl_demo.feature.add().float_list.value.append(dem)
        fl_row_masks.feature.add().int64_list.value.append(mask)
    return ex.SerializeToString()

In [29]:
def tf_serialize_lab_dem(patient, label, demo, row_mask, patient_t):
    """Map serialize_with_labels to tf.data.Dataset."""
    tf_string = tf.py_func(serialize_lab_dem,
                           (patient, label, demo, row_mask, patient_t),
                           tf.string)
    return tf.reshape(tf_string, ())

In [30]:
serialized = output.map(tf_serialize_lab_dem)

In [31]:
serialized_it = serialized.make_one_shot_iterator()

In [32]:
ex = sess.run(serialized_it.get_next())

In [197]:
writer = tf.contrib.data.TFRecordWriter(args.out_file)
writeop = writer.write(serialized)

In [198]:
sess.run(writeop)

## Reading

In [128]:
raw = tf.data.TFRecordDataset(['testtesttest'])

In [129]:
raw_it = raw.make_one_shot_iterator()

In [130]:
ex = sess.run(raw_it.get_next())

In [103]:
test = vars(args)

In [104]:
test['fake'] = "fake"

In [105]:
args.fake

'fake'

In [136]:
parse_lab_dem(ex)['patient'].eval()

array([[ 5,  4, -2, -2, -2, -2],
       [ 2,  3,  4, -2, -2, -2],
       [-2, -2, -2, -2, -2, -2],
       [-2, -2, -2, -2, -2, -2],
       [-2, -2, -2, -2, -2, -2]], dtype=int64)

In [364]:
def parse_lab_dem(example_proto, args=args):
    """Prepare TFRecords for training."""
    ctxt_fts = {
        "patient_t": tf.FixedLenFeature([], dtype=tf.float32),
        "max_t": tf.FixedLenFeature([], dtype=tf.int64),
        "max_v": tf.FixedLenFeature([], dtype=tf.int64),
        "demo_dim": tf.FixedLenFeature([], dtype=tf.int64)
    }
    seq_fts = {
        "patient": tf.FixedLenSequenceFeature([], dtype=tf.int64),
        "label": tf.FixedLenSequenceFeature([], dtype=tf.int64),
        "demo": tf.FixedLenSequenceFeature([], dtype=tf.float32),
        "row_mask": tf.FixedLenSequenceFeature([], dtype=tf.int64)
    }
    ctxt_parsed, seq_parsed = tf.parse_single_sequence_example(
        serialized=example_proto,
        context_features=ctxt_fts,
        sequence_features=seq_fts
    )
    print(ctxt_parsed['demo_dim'].eval())
    output_shape = [ctxt_parsed['max_t'], ctxt_parsed['max_v']]
    demo_shape = [ctxt_parsed['max_t'], ctxt_parsed['demo_dim']]
    output_shape = tf.stack(output_shape)
    demo_shape = tf.stack(demo_shape)
    patient = tf.reshape(seq_parsed['patient'], output_shape)
    label = tf.reshape(seq_parsed['label'], output_shape)
    demo = tf.reshape(seq_parsed['demo'], demo_shape)
    row_mask = tf.reshape(seq_parsed['row_mask'], output_shape)
    patient_t = tf.reshape(ctxt_parsed['patient_t'], [1, 1])
    return {'patient': patient, 'label': label, 'demo': demo,
            'row_mask': row_mask, 'patient_t': patient_t}

In [121]:
parse = raw.map(parse_lab_dem)

In [122]:
parse_it = parse.make_one_shot_iterator()

In [123]:
sess.run(parse_it.get_next())

{'patient': array([[ 5,  4, -2, -2, -2, -2],
        [ 2,  3,  4, -2, -2, -2],
        [-2, -2, -2, -2, -2, -2],
        [-2, -2, -2, -2, -2, -2],
        [-2, -2, -2, -2, -2, -2]], dtype=int64),
 'label': array([[ 3,  2, -2, -2, -2, -2],
        [ 0,  3,  2, -2, -2, -2],
        [-2, -2, -2, -2, -2, -2],
        [-2, -2, -2, -2, -2, -2],
        [-2, -2, -2, -2, -2, -2]], dtype=int64),
 'demo': array([[-1.2766646 , -8.258818  ,  1.0295945 , -9.932409  ,  3.4238288 ,
          4.7623577 ],
        [ 4.781394  , -2.7135177 ,  8.019953  ,  0.30808517,  0.31495225,
         -3.8640842 ],
        [ 0.3880088 , -6.512215  ,  3.4989486 ,  2.7263956 , -1.5292437 ,
         -0.5698157 ],
        [ 4.377598  , -1.8317424 ,  2.9549975 , -1.0864103 ,  4.6243834 ,
         -0.11334902],
        [-7.137146  ,  0.8752901 , -4.1808968 ,  3.2075074 ,  4.62191   ,
         -1.7289466 ]], dtype=float32),
 'row_mask': array([[1, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0],
    

In [124]:
parse = parse.batch(2)

In [125]:
parse_it = parse.make_one_shot_iterator()

In [126]:
sess.run(parse_it.get_next())

{'patient': array([[[ 5,  4, -2, -2, -2, -2],
         [ 2,  3,  4, -2, -2, -2],
         [-2, -2, -2, -2, -2, -2],
         [-2, -2, -2, -2, -2, -2],
         [-2, -2, -2, -2, -2, -2]],
 
        [[ 5,  6,  4, -2, -2, -2],
         [ 0,  3,  6,  5, -2, -2],
         [-2, -2, -2, -2, -2, -2],
         [-2, -2, -2, -2, -2, -2],
         [-2, -2, -2, -2, -2, -2]]], dtype=int64),
 'label': array([[[ 3,  2, -2, -2, -2, -2],
         [ 0,  3,  2, -2, -2, -2],
         [-2, -2, -2, -2, -2, -2],
         [-2, -2, -2, -2, -2, -2],
         [-2, -2, -2, -2, -2, -2]],
 
        [[ 3,  3,  2, -2, -2, -2],
         [ 0,  3,  3,  2, -2, -2],
         [-2, -2, -2, -2, -2, -2],
         [-2, -2, -2, -2, -2, -2],
         [-2, -2, -2, -2, -2, -2]]], dtype=int64),
 'demo': array([[[-1.2766646 , -8.258818  ,  1.0295945 , -9.932409  ,
           3.4238288 ,  4.7623577 ],
         [ 4.781394  , -2.7135177 ,  8.019953  ,  0.30808517,
           0.31495225, -3.8640842 ],
         [ 0.3880088 , -6.512215  , 

In [88]:
row = parse_it.get_next()

In [89]:
a = 2*row['patient']
b = 2 * row['label']

In [90]:
sess.run([a, b])

[array([[10,  8, -4, -4, -4, -4],
        [ 4,  6,  8, -4, -4, -4],
        [-4, -4, -4, -4, -4, -4],
        [-4, -4, -4, -4, -4, -4],
        [-4, -4, -4, -4, -4, -4]], dtype=int64),
 array([[ 6,  4, -4, -4, -4, -4],
        [ 0,  6,  4, -4, -4, -4],
        [-4, -4, -4, -4, -4, -4],
        [-4, -4, -4, -4, -4, -4],
        [-4, -4, -4, -4, -4, -4]], dtype=int64)]

In [92]:
parse.batch(2)

<BatchDataset shapes: {patient: (?, ?, ?), label: (?, ?, ?), demo: (?, ?, ?), row_mask: (?, ?, ?), patient_t: (?, 1, 1)}, types: {patient: tf.int64, label: tf.int64, demo: tf.float32, row_mask: tf.int64, patient_t: tf.int64}>

In [100]:
patients.shape.as_list()

[4, 30]

# Making more dummy data

In [259]:
def simulate_seqs(n_codes, max_v, max_t, n_iters, n_patients):
    """Making some test data with built-in relationshipps between codes."""
    simu_seqs = []
    codes = range(n_codes)
    dist = {}
    for _i, _j in [(i,j) for i in codes for j in codes]:
        i, j = str(_i), str(_j)
        if i == j:
            continue
        dist[",".join([i,j])] = random.random()
    for _i in codes:
        i = str(_i)
        dist[",".join([i,i])] = random.random()
        total = 0
        for _j in codes:
            j = str(j)
            total += dist[",".join([i,j])]
        for _j in codes:
            j = str(j)
            dist[",".join([i,j])] /= total
            dist[",".join([i,j])] *= dist[",".join([i,i])]
    for thing in dist:
        dist[thing] *= .01
    for patient in range(n_patients):
        for t in range(random.choice(range(2, max_t))):
            v = random.choice(range(2, max_v))
            visit = set()
            r = 0
            while r < n_iters or len(visit) < v:
                if len(visit) < max_v:
                    for code in codes:
                        if len(visit) == 0:
                            dart = random.random()
                            if dart <= dist[",".join([str(code),str(code)])]:
                                visit.add(code)
                    for code_1 in list(visit):
                        for code_2 in codes:
                            if len(visit) < max_v:
                                dart = random.random()
                                if dart <= dist[",".join([str(code_2),str(code_1)])]:
                                    visit.add(code_2)
                                    if len(visit) == max_v:
                                        break
                r += 1
            simu_seqs.append(list(visit))
        simu_seqs.append([-1])
    return dist, simu_seqs[:-1]

In [260]:
# n_codes, max_v, max_t, n_iters, n_patients
dist, simu_seqs = simulate_seqs(n_codes=12, max_v=5, max_t=8, n_iters=10, n_patients=1000)

In [272]:
len([thing for thing in simu_seqs if (11 in thing and 4 in thing)])

152

In [262]:
with open('simu_seqs', 'wb') as oFile:
    pickle.dump(simu_seqs, oFile, protocol=2)

In [263]:
for thing in simu_seqs:
    if len(thing) > 5:
        print(thing)

In [332]:
def divvy(seqs, n_patients=5):
    batches = []
    batch = []
    total_patients = len([x for x in seqs if x == [-1]]) + 1
    patients = 1
    for idx, visit in enumerate(seqs):
        if visit == [-1]:
            patients += 1
        if patients > n_patients  or idx == len(seqs) - 1:
            batches.append(batch)
            batch = []
            patients = 1
        else:
            batch.append(visit)
    return batches

In [289]:
test = divvy(simu_seqs, 3)

In [290]:
test[-1]

[[0, 4, 5, 7], [0, 1, 6, 8], [3, 5]]

array([[ 0.99096453,  5.162933  , -6.2529473 ,  3.6117225 ,  7.993999  ,
         5.283886  ],
       [ 7.7175665 ,  4.076336  , -1.2595476 ,  5.600626  ,  5.1977305 ,
         1.6988049 ],
       [-1.4223924 ,  3.7327557 , -5.697534  , -2.8188312 , -0.50290483,
        -2.4053864 ],
       [ 0.43594104, -2.3363402 ,  2.7335165 , -9.652917  , -1.457068  ,
        -1.7781298 ],
       [-5.504733  , -2.1465867 , -0.70124686,  1.7979473 ,  7.8137245 ,
         0.40736306]], dtype=float32)

In [300]:
simu_lab_map = {-1:-1, 0: 0, 1:1, 2:1, 3:2, 4: 3, 5:1, 6:4, 7:0, 8:5, 9:2, 10:6, 11:6}

In [302]:
simu_labs = [[simu_lab_map[code] for code in visit] for visit in simu_seqs]

In [303]:
simu_labs

[[0, 2, 0],
 [2, 3, 4],
 [0, 6],
 [1, 4],
 [3, 1, 4, 0],
 [-1],
 [5, 2, 4],
 [2, 1, 3],
 [2, 0],
 [-1],
 [1, 4],
 [0, 1, 1, 2],
 [5, 1],
 [0, 0],
 [-1],
 [2, 6, 1, 3],
 [5, 0],
 [5, 6, 1, 0],
 [2, 4, 0, 2, 6],
 [-1],
 [6, 1, 4, 0],
 [0, 1, 0],
 [1, 2, 3, 0],
 [6, 3, 1],
 [2, 6, 1],
 [5, 4, 0],
 [-1],
 [1, 1, 2],
 [1, 6, 1, 0],
 [1, 2],
 [5, 1],
 [-1],
 [5, 1],
 [1, 6, 4, 0],
 [5, 1],
 [-1],
 [0, 5, 0],
 [1, 6, 2, 0],
 [0, 6],
 [2, 1, 6, 3],
 [0, 3, 4],
 [-1],
 [2, 4, 1, 0],
 [2, 3, 1, 0],
 [3, 1],
 [1, 4],
 [-1],
 [0, 6, 2, 1],
 [2, 1, 4, 0],
 [2, 2],
 [0, 1, 6, 0],
 [0, 1, 3],
 [0, 5, 6, 0],
 [-1],
 [5, 3],
 [1, 1],
 [1, 4, 0],
 [6, 2, 2, 3],
 [6, 1],
 [2, 2],
 [-1],
 [2, 1],
 [2, 1, 2, 1],
 [6, 6],
 [-1],
 [6, 2],
 [2, 4, 0],
 [-1],
 [0, 3],
 [6, 2, 4, 0],
 [-1],
 [0, 6, 1, 4],
 [6, 0],
 [0, 1],
 [0, 5, 1, 0],
 [0, 1, 1, 2],
 [6, 1, 3],
 [-1],
 [2, 1, 0],
 [0, 5, 6, 0],
 [-1],
 [5, 1, 1],
 [6, 1, 2],
 [1, 6],
 [6, 6, 1],
 [-1],
 [3, 1],
 [1, 6, 0],
 [6, 4],
 [5, 6, 1, 3],
 [-1],
 [2,

In [304]:
simu_demo = []
demo_dim = 7
for visit in simu_seqs:
    if visit == [-1]:
        simu_demo.append(visit)
    else:
        new_vec = [random.random() for i in range(demo_dim)]
        simu_demo.append(new_vec)

In [305]:
simu_demo

[[0.8187589069626391,
  0.589141427353065,
  0.5090105558913021,
  0.005875008470987209,
  0.05679415799163834,
  0.038621842416045826,
  0.7896159386947912],
 [0.5476637730802864,
  0.05326825795130541,
  0.2965937609398399,
  0.07915699601091053,
  0.5253323141944548,
  0.7225475559570019,
  0.774022331798564],
 [0.6864043548235739,
  0.5129799130530761,
  0.5713403698102727,
  0.47506163080938835,
  0.15102044287516248,
  0.9238557385322455,
  0.6791717110808009],
 [0.5475711651749273,
  0.33328293848560486,
  0.29481848452254733,
  0.18773818344892956,
  0.2468326972775423,
  0.5408704076482236,
  0.36910404480954606],
 [0.07318519248073163,
  0.961029953908276,
  0.9502625693556245,
  0.49046824018890534,
  0.16975648998695525,
  0.011636009425135763,
  0.36717532176266166],
 [-1],
 [0.6235068923820783,
  0.36489871473960067,
  0.2448107711994495,
  0.9327035337483227,
  0.6036026759183437,
  0.3380854949818839,
  0.7243742920120072],
 [0.7391636347819479,
  0.42156990913103554,
 

In [306]:
with open('simu_labs', 'wb') as oFile:
    pickle.dump(simu_labs, oFile, protocol=2)
with open('simu_demo', 'wb') as oFile:
    pickle.dump(simu_demo, oFile, protocol=2)

In [338]:
raw_data = []

In [339]:
raw_data.append(divvy(simu_seqs)[:2])

In [340]:
raw_data.append(divvy(simu_labs)[:2])

In [341]:
raw_data.append(divvy(simu_demo)[:2])

In [344]:
divvy(simu_demo)[:2]

[[[0.8187589069626391,
   0.589141427353065,
   0.5090105558913021,
   0.005875008470987209,
   0.05679415799163834,
   0.038621842416045826,
   0.7896159386947912],
  [0.5476637730802864,
   0.05326825795130541,
   0.2965937609398399,
   0.07915699601091053,
   0.5253323141944548,
   0.7225475559570019,
   0.774022331798564],
  [0.6864043548235739,
   0.5129799130530761,
   0.5713403698102727,
   0.47506163080938835,
   0.15102044287516248,
   0.9238557385322455,
   0.6791717110808009],
  [0.5475711651749273,
   0.33328293848560486,
   0.29481848452254733,
   0.18773818344892956,
   0.2468326972775423,
   0.5408704076482236,
   0.36910404480954606],
  [0.07318519248073163,
   0.961029953908276,
   0.9502625693556245,
   0.49046824018890534,
   0.16975648998695525,
   0.011636009425135763,
   0.36717532176266166],
  [-1],
  [0.6235068923820783,
   0.36489871473960067,
   0.2448107711994495,
   0.9327035337483227,
   0.6036026759183437,
   0.3380854949818839,
   0.7243742920120072],
  [

In [347]:
for thing in zip(*raw_data):
    print(thing)
    break

([[0, 3, 7], [3, 4, 6], [0, 10], [1, 6], [4, 5, 6, 7], [-1], [8, 9, 6], [9, 2, 4], [3, 7], [-1], [5, 6], [0, 1, 5, 9], [8, 5], [0, 7], [-1], [3, 10, 2, 4], [8, 7], [8, 11, 5, 7], [3, 6, 7, 9, 10], [-1], [10, 5, 6, 7], [0, 5, 7], [1, 3, 4, 7], [11, 4, 5], [9, 10, 2], [8, 6, 7]], [[0, 2, 0], [2, 3, 4], [0, 6], [1, 4], [3, 1, 4, 0], [-1], [5, 2, 4], [2, 1, 3], [2, 0], [-1], [1, 4], [0, 1, 1, 2], [5, 1], [0, 0], [-1], [2, 6, 1, 3], [5, 0], [5, 6, 1, 0], [2, 4, 0, 2, 6], [-1], [6, 1, 4, 0], [0, 1, 0], [1, 2, 3, 0], [6, 3, 1], [2, 6, 1], [5, 4, 0]], [[0.8187589069626391, 0.589141427353065, 0.5090105558913021, 0.005875008470987209, 0.05679415799163834, 0.038621842416045826, 0.7896159386947912], [0.5476637730802864, 0.05326825795130541, 0.2965937609398399, 0.07915699601091053, 0.5253323141944548, 0.7225475559570019, 0.774022331798564], [0.6864043548235739, 0.5129799130530761, 0.5713403698102727, 0.47506163080938835, 0.15102044287516248, 0.9238557385322455, 0.6791717110808009], [0.5475711651749

In [330]:
raw_data

[[[[0, 3, 7],
   [3, 4, 6],
   [0, 10],
   [1, 6],
   [4, 5, 6, 7],
   [-1],
   [8, 9, 6],
   [9, 2, 4],
   [3, 7],
   [-1],
   [5, 6],
   [0, 1, 5, 9],
   [8, 5],
   [0, 7],
   [-1],
   [3, 10, 2, 4],
   [8, 7],
   [8, 11, 5, 7],
   [3, 6, 7, 9, 10],
   [-1],
   [10, 5, 6, 7],
   [0, 5, 7],
   [1, 3, 4, 7],
   [11, 4, 5],
   [9, 10, 2],
   [8, 6, 7],
   [-1],
   [1, 5, 9],
   [2, 10, 5, 7],
   [2, 3],
   [8, 5],
   [-1],
   [8, 5],
   [2, 11, 6, 7],
   [8, 5],
   [-1],
   [0, 8, 7],
   [1, 10, 3, 7],
   [0, 10],
   [9, 2, 10, 4],
   [0, 4, 6],
   [-1],
   [9, 6, 1, 7],
   [3, 4, 5, 7],
   [4, 5],
   [5, 6],
   [-1],
   [0, 10, 3, 5],
   [3, 5, 6, 7],
   [9, 3],
   [0, 1, 11, 7],
   [0, 2, 4],
   [0, 8, 10, 7],
   [-1],
   [8, 4],
   [1, 2],
   [5, 6, 7],
   [11, 9, 3, 4],
   [10, 2],
   [9, 3],
   [-1],
   [3, 5],
   [9, 2, 3, 5],
   [10, 11],
   [-1],
   [10, 3],
   [9, 6, 7],
   [-1],
   [0, 4],
   [10, 3, 6, 7],
   [-1],
   [0, 10, 5, 6],
   [10, 7],
   [0, 1],
   [0, 8, 5, 7],
   

In [346]:
zip(alist, blist, clist)

<zip at 0x165f37f1708>

In [365]:
data = tf.data.TFRecordDataset(['./data/simu_seqs_TFR0'])

In [366]:
data_it = data.make_one_shot_iterator()

In [367]:
ex = sess.run(data_it.get_next())

In [369]:
parse_lab_dem(ex)['demo'].eval()

7


array([[ 0.8187589 ,  0.5891414 ,  0.50901055,  0.00587501,  0.05679416,
         0.03862184,  0.7896159 ],
       [ 0.54766375,  0.05326826,  0.29659376,  0.07915699,  0.52533233,
         0.72254753,  0.77402234],
       [ 0.68640435,  0.5129799 ,  0.5713404 ,  0.47506163,  0.15102044,
         0.9238557 ,  0.6791717 ],
       [ 0.5475712 ,  0.33328295,  0.2948185 ,  0.18773818,  0.2468327 ,
         0.5408704 ,  0.36910406],
       [ 0.07318519,  0.96102995,  0.95026255,  0.49046823,  0.16975649,
         0.01163601,  0.3671753 ],
       [-2.        , -2.        , -2.        , -2.        , -2.        ,
        -2.        , -2.        ],
       [-2.        , -2.        , -2.        , -2.        , -2.        ,
        -2.        , -2.        ],
       [-2.        , -2.        , -2.        , -2.        , -2.        ,
        -2.        , -2.        ]], dtype=float32)