In [25]:
import itertools
import numpy as np
import pandas as pd
import tensorflow as tf

class FrameConverter:
    """Build ``tf.data`` input pipelines over frame-level TFRecord files.

    An instance is callable: ``converter(subset, record_indices)`` returns a
    one-shot iterator yielding ``(X, y)`` batches, where ``X`` is a dict with
    the decoded ``'rgb'`` and ``'audio'`` sequence features and ``y`` is a
    dense label vector.

    Args:
        X_transforms: Optional iterable of callables applied, in order, to
            the feature dict of every example. ``None`` means no transforms.
        y_transforms: Optional iterable of callables applied, in order, to
            the dense label vector of every example. ``None`` means none.
        repeat_count: Value passed to ``Dataset.repeat``.
        n_parallel: Value passed as ``num_parallel_calls`` to ``Dataset.map``.
    """

    # Length of the dense label vector built for every example.
    NUM_CLASSES = 3862

    def __init__(self, X_transforms=None, y_transforms=None, repeat_count=1, n_parallel=1):
        self.filename_base = '/home/data/full/frame/{}{}.tfrecord'
        class_lookup_df = pd.read_csv('/home/data/label_names_2018.csv')
        self.index_to_label = dict(enumerate(class_lookup_df.label_id.values))
        self.label_to_index = {v: k for k, v in self.index_to_label.items()}
        self.repeat_count = repeat_count
        # Keep the configuration on the instance; the methods below read it
        # from ``self`` rather than from whatever globals happen to exist.
        self.n_parallel = n_parallel
        # Copy into fresh lists so instances never share (or alias) the
        # caller's sequence.
        self.X_transforms = list(X_transforms) if X_transforms is not None else []
        self.y_transforms = list(y_transforms) if y_transforms is not None else []
        self.keys_to_features = {
            'rgb': tf.FixedLenSequenceFeature([], tf.string, allow_missing=True),
            'audio': tf.FixedLenSequenceFeature([], tf.string, allow_missing=True),
        }
        self.key_to_label = {
            'labels': tf.VarLenFeature(tf.int64)
        }

    @staticmethod
    def _apply_transforms(value, transforms):
        """Feed ``value`` through each callable in ``transforms``, in order."""
        for transform in transforms:
            value = transform(value)
        return value

    def get_data_from_record(self, filename):
        """Parse one serialized SequenceExample into an ``(X, y)`` pair.

        Args:
            filename: Despite the name, this is the element handed over by
                ``Dataset.map`` in ``__call__``, i.e. one serialized record
                read from the TFRecord files, not a path.

        Returns:
            A tuple ``(X, y)``. ``X`` is a dict with keys ``'audio'`` and
            ``'rgb'`` holding float32 tensors decoded from raw bytes. ``y``
            is a dense vector of length ``NUM_CLASSES`` with value 1 at the
            indices listed in the ``'labels'`` context feature. Both have
            the configured transforms applied.
        """
        # Context features (labels) come back first, sequence features second.
        y, X = tf.parse_single_sequence_example(filename,
                                                  self.key_to_label,
                                                  self.keys_to_features)
        # X is still bytes; convert to float
        X['audio'] = tf.cast(tf.decode_raw(X['audio'], tf.uint8), tf.float32)
        X['rgb'] = tf.cast(tf.decode_raw(X['rgb'], tf.uint8), tf.float32)
        y = tf.sparse_to_dense(y['labels'].values, [self.NUM_CLASSES], 1)

        # now apply custom transformations
        X = self._apply_transforms(X, self.X_transforms)
        y = self._apply_transforms(y, self.y_transforms)
        return X, y

    def __call__(self, subset, record_indices):
        """Create a one-shot iterator over the requested record files.

        Args:
            subset: Prefix substituted into ``filename_base`` (e.g. 'train').
            record_indices: Iterable of record indices; each one is
                substituted into ``filename_base`` to form a file path.

        Returns:
            A one-shot iterator whose elements are ``(X, y)`` batches of
            size 1.
        """
        filenames = [self.filename_base.format(subset, index) for index in record_indices]

        dataset = tf.data.TFRecordDataset(filenames)
        dataset = dataset.map(self.get_data_from_record,
                              num_parallel_calls=self.n_parallel)
        # NOTE: repeat precedes shuffle, so the shuffle buffer can mix
        # examples from adjacent repetitions.
        dataset = dataset.repeat(self.repeat_count)
        dataset = dataset.shuffle(buffer_size=256)
        dataset = dataset.batch(1)
        dataset = dataset.prefetch(1)
        iterator = dataset.make_one_shot_iterator()
        return iterator
    
# Smoke test: build a converter with default settings and pull one batch
# from the 'train' record file with index 2500.
frame_converter = FrameConverter()
iterator = frame_converter('train', [2500])
sess = tf.Session()
# get_next() only creates the graph op; sess.run() actually fetches a batch.
next_sample = iterator.get_next()
sample = sess.run(next_sample)
# sample is the (X, y) pair; index 0 is the feature dict ('audio' / 'rgb').
print(sample[0])

{'audio': array([[[ 76.,  46., 175., ..., 186., 255., 118.],
        [ 34.,  45., 131., ..., 110.,   0., 138.],
        [ 26.,  76., 134., ..., 124.,  34.,  63.],
        ...,
        [ 85.,  55., 170., ..., 193.,  45., 255.],
        [ 58.,  87., 146., ..., 167., 195., 218.],
        [ 98.,  60., 192., ..., 160., 255., 169.]]], dtype=float32), 'rgb': array([[[130., 213., 111., ..., 157., 188., 210.],
        [124., 203., 100., ..., 146., 113., 167.],
        [ 89., 210.,  91., ..., 154., 174., 183.],
        ...,
        [ 48., 206.,  84., ..., 167., 168., 218.],
        [ 36., 203.,  88., ..., 170., 138., 255.],
        [ 15., 205.,  87., ..., 174., 157., 210.]]], dtype=float32)}
