# TFRecord

## What is a TFRecord

A TFRecord is a single compact file that aggregates all the data (whatever its original format) needed to train or test a model. The file can be moved between systems and is independent of the model that will be trained on it. A TFRecord may also carry overhead data needed to reconstruct the original data, such as image dimensions, which would not be needed when training without TFRecords. If the dataset is extremely large, the data may have to be split across multiple TFRecord files of the same structure.
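
As a rough illustration of splitting a large dataset across several files, the sketch below assigns records to shards round-robin and builds one file name per shard. The `shard_records` helper, the `train` prefix, and the `-00000-of-00003` naming pattern are assumptions made for illustration; none of them come from the code in this notebook.

```python
def shard_records(records, num_shards, prefix='train'):
    """Distribute records round-robin over num_shards hypothetical TFRecord file names."""
    if num_shards < 1:
        raise ValueError('num_shards must be at least 1')
    # One empty bucket per shard, keyed by the file name it would be written to
    shards = {
        '{}-{:05d}-of-{:05d}.tfrecord'.format(prefix, i, num_shards): []
        for i in range(num_shards)
    }
    paths = list(shards)
    for index, record in enumerate(records):
        shards[paths[index % num_shards]].append(record)
    return shards

shards = shard_records(['rec{}'.format(i) for i in range(7)], num_shards=3)
for path, recs in shards.items():
    print(path, recs)
```

Each bucket would then be written with its own writer, and the resulting paths can all be passed together as a list to `TFRecordDataset`.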


## How to build a TFRecord

Every value in a TFRecord must be stored as a list of bytes, a list of floats, or a list of int64 values. Each such list is wrapped in a Feature. The features are then collected in a dictionary whose keys name each feature; the same keys are used later to extract the data from the TFRecord. The dictionary is passed to the Features class, and the resulting Features object is passed to the Example class. Finally, the serialized Example is appended to the TFRecord. This procedure is repeated for every data record that has to be stored. The code to create a TFRecord from simple data is given next.

In [1]:
import tensorflow as tf

data_arr = [
    {
        'int_data': 108,
        'float_data': 2.45,
        'str_data': 'String 100',
        'float_list_data': [256.78, 13.9]
    },
    {
        'int_data': 37,
        'float_data': 84.3,
        'str_data': 'String 200',
        'float_list_data': [1.34, 843.9, 65.22]
    }
]

def get_example_object(data_record):
    # Convert individual data into a list of int64 or float or bytes
    int_list1 = tf.train.Int64List(value = [data_record['int_data']])
    float_list1 = tf.train.FloatList(value = [data_record['float_data']])
    # Convert string data into list of bytes
    str_list1 = tf.train.BytesList(value = [data_record['str_data'].encode('utf-8')])
    float_list2 = tf.train.FloatList(value = data_record['float_list_data'])

    # Create a dictionary with above lists individually wrapped in Feature
    feature_key_value_pair = {
        'int_list1': tf.train.Feature(int64_list = int_list1),
        'float_list1': tf.train.Feature(float_list = float_list1),
        'str_list1': tf.train.Feature(bytes_list = str_list1),
        'float_list2': tf.train.Feature(float_list = float_list2)
    }

    # Create Features object with above feature dictionary
    features = tf.train.Features(feature = feature_key_value_pair)

    # Create Example object with features
    example = tf.train.Example(features = features)
    return example

with tf.python_io.TFRecordWriter('example.tfrecord') as tfwriter:
    # Iterate through all records
    for data_record in data_arr:
        example = get_example_object(data_record)

        # Append each example into tfrecord
        tfwriter.write(example.SerializeToString())
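
The rule stated at the start of this section, that every value must end up as a list of bytes, floats, or int64 values, can be sketched without TensorFlow. The helper `to_feature_kind` below is hypothetical and is not a TensorFlow API; it only shows which of the three `tf.train.Feature` fields a plain Python value would map to, and how a single value becomes a one-element list.

```python
def to_feature_kind(value):
    """Return (kind, values), where kind names the matching tf.train.Feature field."""
    # A single value is treated as a one-element list
    values = list(value) if isinstance(value, (list, tuple)) else [value]
    if not values:
        raise ValueError('cannot infer a type from an empty list')
    if all(isinstance(v, int) for v in values):
        return 'int64_list', [int(v) for v in values]
    if all(isinstance(v, (int, float)) for v in values):
        return 'float_list', [float(v) for v in values]
    if all(isinstance(v, (str, bytes)) for v in values):
        # Strings must be encoded to bytes before they can be stored
        return 'bytes_list', [v.encode('utf-8') if isinstance(v, str) else v for v in values]
    raise TypeError('unsupported value type')

print(to_feature_kind(108))
print(to_feature_kind([256.78, 13.9]))
print(to_feature_kind('String 100'))
```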

## Create TFRecord for Images

Now that we have a basic understanding of how to create a TFRecord from simple data made up of dictionaries and lists, let us move on to images. Our [toy dataset](./TFRecordImages) consists of 10 images in two classes, cats and dogs. The dataset is a mixture of PNG and JPEG images.

In [9]:
import tensorflow as tf
import os
import matplotlib.image as mpimg

class GenerateTFRecord:
    def __init__(self, labels):
        self.labels = labels

    def convert_image_folder(self, img_folder, tfrecord_file_name):
        # Get all file names of images present in folder
        img_paths = os.listdir(img_folder)
        img_paths = [os.path.abspath(os.path.join(img_folder, i)) for i in img_paths]

        with tf.python_io.TFRecordWriter(tfrecord_file_name) as writer:
            for img_path in img_paths:
                example = self._convert_image(img_path)
                writer.write(example.SerializeToString())

    # Alternative (disabled): store the decoded pixels as a raw NumPy byte string; this makes the TFRecord file huge
    # def _convert_image(self, img_path):
    #     label = self._get_label_with_filename(img_path)
    #     image_data = mpimg.imread(img_path)
    #     # Convert image to string data
    #     image_str = image_data.tostring()
    #     # Store shape of image for reconstruction purposes
    #     img_shape = image_data.shape
    #     # Get filename
    #     filename = os.path.basename(img_path)
        
    #     example = tf.train.Example(features = tf.train.Features(feature = {
    #         'filename': tf.train.Feature(bytes_list = tf.train.BytesList(value = [filename.encode('utf-8')])),
    #         'rows': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[0]])),
    #         'cols': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[1]])),
    #         'channels': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[2]])),
    #         'image': tf.train.Feature(bytes_list = tf.train.BytesList(value = [image_str])),
    #         'label': tf.train.Feature(int64_list = tf.train.Int64List(value = [label]))
    #     }))
    #     return example

    def _convert_image(self, img_path):
        label = self._get_label_with_filename(img_path)
        img_shape = mpimg.imread(img_path).shape
        filename = os.path.basename(img_path)

        # Read image data in terms of bytes
        with tf.gfile.GFile(img_path, 'rb') as fid:
            image_data = fid.read()

        example = tf.train.Example(features = tf.train.Features(feature = {
            'filename': tf.train.Feature(bytes_list = tf.train.BytesList(value = [filename.encode('utf-8')])),
            'rows': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[0]])),
            'cols': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[1]])),
            'channels': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[2]])),
            'image': tf.train.Feature(bytes_list = tf.train.BytesList(value = [image_data])),
            'label': tf.train.Feature(int64_list = tf.train.Int64List(value = [label])),
        }))
        return example        

    def _get_label_with_filename(self, filename):
        basename = os.path.basename(filename).split('.')[0]
        basename = basename.split('_')[0]        
        return self.labels[basename]

if __name__ == '__main__':
    labels = {'cat': 0, 'dog': 1}
    t = GenerateTFRecord(labels)
    t.convert_image_folder('TFRecordImages', 'images.tfrecord')

The resulting images.tfrecord file is 1.2 MB, which is almost the same as the combined size of the individual images.
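
The label lookup in `_get_label_with_filename` relies on file names of the form `<class>_<index>.<extension>`. A standalone sketch of that parsing is given below; the function name `label_from_path`, the example paths, and the explicit error for unknown classes are assumptions added for illustration.

```python
import os

def label_from_path(path, labels):
    """Map a file such as 'cat_3.png' to its integer label via the class-name prefix."""
    # Drop the directory and the extension, then keep the text before the first underscore
    stem = os.path.basename(path).split('.')[0]
    class_name = stem.split('_')[0]
    if class_name not in labels:
        raise KeyError('no label for class {!r} in {!r}'.format(class_name, path))
    return labels[class_name]

labels = {'cat': 0, 'dog': 1}
print(label_from_path('TFRecordImages/cat_3.png', labels))
print(label_from_path('TFRecordImages/dog_0.jpeg', labels))
```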

## Reduce TFRecord size further

Now, let us try to bring the TFRecord size down further. PNG is a lossless format, so it preserves sharp edge details, at the cost of a larger file. Converting to JPEG, a lossy format, blurs the image very slightly but gives a measurable reduction in storage size. TensorFlow also lets you choose how much quality to retain during the conversion, through the quality argument of tf.image.encode_jpeg.

In [13]:
import tensorflow as tf
import os
import matplotlib.image as mpimg

class GenerateTFRecord:
    def __init__(self, labels):
        self.labels = labels
        self._create_graph()

    def convert_image_folder(self, img_folder, tfrecord_file_name):
        # Get all file names of images present in folder
        img_paths = os.listdir(img_folder)
        img_paths = [os.path.abspath(os.path.join(img_folder, i)) for i in img_paths]

        with tf.python_io.TFRecordWriter(tfrecord_file_name) as writer:
            for img_path in img_paths:
                example = self._convert_image(img_path)
                writer.write(example.SerializeToString())        

    # Create graph to convert PNG image data to JPEG data
    def _create_graph(self):
        tf.reset_default_graph()
        self.png_img_pl = tf.placeholder(tf.string)
        decoded_png = tf.image.decode_png(self.png_img_pl, channels = 3)
        # JPEG quality to retain during conversion (0 to 100)
        self.png_to_jpeg = tf.image.encode_jpeg(decoded_png, format = 'rgb', quality = 100)

    def _is_png_image(self, filename):
        ext = os.path.splitext(filename)[1].lower()
        return ext == '.png'

    # Run graph to convert PNG image data to JPEG data
    def _convert_png_to_jpeg(self, img):
        # Use a context manager so the session is closed after each conversion
        with tf.Session() as sess:
            return sess.run(self.png_to_jpeg, feed_dict = {self.png_img_pl: img})

    def _convert_image(self, img_path):
        label = self._get_label_with_filename(img_path)
        img_shape = mpimg.imread(img_path).shape
        filename = os.path.basename(img_path).split('.')[0]

        # Read image data in terms of bytes
        with tf.gfile.GFile(img_path, 'rb') as fid:
            image_data = fid.read()

            # Encode PNG data to JPEG data
            if self._is_png_image(img_path):
                image_data = self._convert_png_to_jpeg(image_data)

        example = tf.train.Example(features = tf.train.Features(feature = {
            'filename': tf.train.Feature(bytes_list = tf.train.BytesList(value = [filename.encode('utf-8')])),
            'rows': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[0]])),
            'cols': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[1]])),
            'channels': tf.train.Feature(int64_list = tf.train.Int64List(value = [3])),
            'image': tf.train.Feature(bytes_list = tf.train.BytesList(value = [image_data])),
            'label': tf.train.Feature(int64_list = tf.train.Int64List(value = [label])),
        }))
        return example

    def _get_label_with_filename(self, filename):
        basename = os.path.basename(filename).split('.')[0]
        basename = basename.split('_')[0]        
        return self.labels[basename]

if __name__ == '__main__':
    labels = {'cat': 0, 'dog': 1}
    t = GenerateTFRecord(labels)
    t.convert_image_folder('TFRecordImages', 'images.tfrecord')        

With the encoding quality kept at 100, the TFRecord file has shrunk from the earlier 1.2 MB to 579.5 KB.
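
`_is_png_image` trusts the file extension. A slightly more defensive sketch, offered as an assumption rather than as part of the original class, inspects the first bytes of the file instead: PNG files begin with the 8-byte signature `\x89PNG\r\n\x1a\n`, and JPEG files begin with the bytes `\xff\xd8\xff`. The function name `sniff_image_format` is hypothetical.

```python
PNG_SIGNATURE = b'\x89PNG\r\n\x1a\n'
JPEG_PREFIX = b'\xff\xd8\xff'

def sniff_image_format(image_bytes):
    """Guess the image format from the leading bytes of the encoded file."""
    if image_bytes.startswith(PNG_SIGNATURE):
        return 'png'
    if image_bytes.startswith(JPEG_PREFIX):
        return 'jpeg'
    return 'unknown'

# Stand-in byte strings; real input would be the bytes read from the image file
print(sniff_image_format(PNG_SIGNATURE + b'rest of file'))
print(sniff_image_format(JPEG_PREFIX + b'\xe0rest of file'))
```

Because `_convert_image` already reads the whole file as bytes, a check like this could be applied to `image_data` directly, so that a PNG file saved with the wrong extension would still be converted.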

## Extracting data from TFRecord

Now that our TFRecords are ready, it is time to feed them into the training pipeline. The first step is to initialize a TFRecordDataset with all the TFRecord file paths. Next, we extract the features stored in each record, specifying the same keys that were used when the TFRecord was created. If the number of items in a feature's list of bytes, floats, or int64 values is known beforehand and is the same for every data record, we use FixedLenFeature; otherwise we use VarLenFeature. The parse_single_example API then returns a dictionary for each data record. Let us look at the extraction procedure for the TFRecord created earlier from simple dictionary data.

In [14]:
import tensorflow as tf

def extract_fn(data_record):
    features = {
        # Extract features using the keys set during creation
        'int_list1': tf.FixedLenFeature([], tf.int64),
        'float_list1': tf.FixedLenFeature([], tf.float32),
        'str_list1': tf.FixedLenFeature([], tf.string),
        # If the list length differs between records, use VarLenFeature
        'float_list2': tf.VarLenFeature(tf.float32)
    }
    sample = tf.parse_single_example(data_record, features)
    return sample

# Initialize all tfrecord paths
dataset = tf.data.TFRecordDataset(['example.tfrecord'])
dataset = dataset.map(extract_fn)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    try:
        while True:
            data_record = sess.run(next_element)
            print(data_record)
    except tf.errors.OutOfRangeError:
        # Raised once the dataset is exhausted
        pass

Instructions for updating:
Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)`.
{'float_list2': SparseTensorValue(indices=array([[0],
       [1]]), values=array([256.78,  13.9 ], dtype=float32), dense_shape=array([2])), 'float_list1': 2.45, 'int_list1': 108, 'str_list1': b'String 100'}
{'float_list2': SparseTensorValue(indices=array([[0],
       [1],
       [2]]), values=array([  1.34, 843.9 ,  65.22], dtype=float32), dense_shape=array([3])), 'float_list1': 84.3, 'int_list1': 37, 'str_list1': b'String 200'}
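
Each `VarLenFeature` comes back as a sparse value with three parts, as the output above shows: `indices`, `values`, and `dense_shape`. The pure-Python sketch below (the helper name `sparse_to_dense_1d` is hypothetical) shows how those parts describe an ordinary one-dimensional list. Within TensorFlow, the same conversion is provided by `tf.sparse.to_dense`.

```python
def sparse_to_dense_1d(indices, values, dense_shape, default=0.0):
    """Rebuild a one-dimensional dense list from the three parts of a sparse value."""
    # Start with a list of the full length, filled with the default value
    dense = [default] * dense_shape[0]
    # Each index is a one-element coordinate such as [0] or [1]
    for (position,), value in zip(indices, values):
        dense[position] = value
    return dense

# Values taken from the second record printed above
print(sparse_to_dense_1d([[0], [1], [2]], [1.34, 843.9, 65.22], [3]))
```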


## Extract Images from TFRecord

We extend the same approach to extract images from a TFRecord. The tf.image.decode_image API decodes an image stored in any of its supported formats (BMP, GIF, JPEG, or PNG). As a precautionary measure, we verify that the shape of the decoded image matches the rows, cols, and channels values stored as overhead data in the TFRecord. Let us dive into the code for extracting images from a TFRecord.

In [15]:
import tensorflow as tf
import os
import shutil
import matplotlib.image as mpimg
import numpy as np

class TFRecordExtractor:
    def __init__(self, tfrecord_file):
        self.tfrecord_file = os.path.abspath(tfrecord_file)

    def _extract_fn(self, tfrecord):
        # Extract features using the keys set during creation
        features = {
            'filename': tf.FixedLenFeature([], tf.string),
            'rows': tf.FixedLenFeature([], tf.int64),
            'cols': tf.FixedLenFeature([], tf.int64),
            'channels': tf.FixedLenFeature([], tf.int64),
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64)
        }

        # Extract the data record
        sample = tf.parse_single_example(tfrecord, features)

        image = tf.image.decode_image(sample['image'])        
        img_shape = tf.stack([sample['rows'], sample['cols'], sample['channels']])
        label = sample['label']
        filename = sample['filename']
        return [image, label, filename, img_shape]        

    def extract_image(self):
        # Create folder to store extracted images
        folder_path = './ExtractedImages'
        shutil.rmtree(folder_path, ignore_errors = True)
        os.mkdir(folder_path)

        # Pipeline of dataset and iterator 
        dataset = tf.data.TFRecordDataset([self.tfrecord_file])
        dataset = dataset.map(self._extract_fn)
        iterator = dataset.make_one_shot_iterator()
        next_image_data = iterator.get_next()

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            try:
                # Keep extracting data till TFRecord is exhausted
                while True:
                    image_data = sess.run(next_image_data)

                    # Check if image shape is same after decoding
                    if not np.array_equal(image_data[0].shape, image_data[3]):
                        print('Image {} not decoded properly'.format(image_data[2]))
                        continue
                        
                    save_path = os.path.abspath(os.path.join(folder_path, image_data[2].decode('utf-8')))
                    mpimg.imsave(save_path, image_data[0])
                    print('Save path = ', save_path, ', Label = ', image_data[1])
            except tf.errors.OutOfRangeError:
                # Raised once the TFRecord is exhausted
                pass

if __name__ == '__main__':
    t = TFRecordExtractor('./images.tfrecord')
    t.extract_image()

Save path =  /data/Projects/MachineLearning-ComputerVision-DataScience/framework/ExtractedImages/dog_2 , Label =  1
Save path =  /data/Projects/MachineLearning-ComputerVision-DataScience/framework/ExtractedImages/cat_3 , Label =  0
Save path =  /data/Projects/MachineLearning-ComputerVision-DataScience/framework/ExtractedImages/dog_0 , Label =  1
Save path =  /data/Projects/MachineLearning-ComputerVision-DataScience/framework/ExtractedImages/dog_4 , Label =  1
Save path =  /data/Projects/MachineLearning-ComputerVision-DataScience/framework/ExtractedImages/cat_4 , Label =  0
Save path =  /data/Projects/MachineLearning-ComputerVision-DataScience/framework/ExtractedImages/dog_3 , Label =  1
Save path =  /data/Projects/MachineLearning-ComputerVision-DataScience/framework/ExtractedImages/cat_2 , Label =  0
Save path =  /data/Projects/MachineLearning-ComputerVision-DataScience/framework/ExtractedImages/dog_1 , Label =  1
Save path =  /data/Projects/MachineLearning-ComputerVision-DataScience/f