## A hands-on guide to TFRecords

The dataset uses TFRecords to store the data. This is a walkthrough of how to work with TFRecords.

The code is based on the following article: https://towardsdatascience.com/a-practical-guide-to-tfrecords-584536bc786c

In [None]:
import numpy as np
import tensorflow as tf 
import tqdm
import glob

### Creating 100 random small images

In [None]:
image_small_shape = (250, 250, 3)
number_of_images_small = 100

# Random pixel values in [0, 256), stored as int16. The decoder in
# parse_tfr_element parses the serialized tensor as int16, so keep these in sync.
images_small = np.random.randint(
    low=0,
    high=256,
    size=(number_of_images_small,) + image_small_shape,
    dtype=np.int16,
)
print(images_small.shape)

### Creating some labels

In [None]:
# One random class label in [0, 5) per small image, flattened from an
# (N, 1) column into a plain Python list of NumPy integer scalars.
labels_small = list(
    np.random.randint(low=0, high=5, size=(number_of_images_small, 1))[:, 0]
)
print(labels_small[:10])

### Get these {image, label} pairs into the TFRecord file

In [None]:
def _bytes_feature(value):
    """Wrap one bytes/string value in a ``tf.train.Feature`` holding a bytes_list.

    If ``value`` is an eager tensor (the same type as ``tf.constant(0)``), it is
    first converted with ``.numpy()`` so the raw value is stored.
    """
    eager_tensor_type = type(tf.constant(0))
    raw = value.numpy() if isinstance(value, eager_tensor_type) else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))

def _float_feature(value):
    """Wrap one float/double in a ``tf.train.Feature`` holding a float_list."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
    """Wrap one bool/enum/int/uint in a ``tf.train.Feature`` holding an int64_list."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

def serialize_array(array):
    """Serialize ``array`` into a string tensor with ``tf.io.serialize_tensor``.

    The result can be decoded again with ``tf.io.parse_tensor``, given the
    original dtype.
    """
    return tf.io.serialize_tensor(array)

In [None]:
def parse_single_image(image, label):
    """Build a ``tf.train.Example`` describing one (image, label) pair.

    The first three dimensions of ``image`` are stored as the int64 features
    ``height``, ``width`` and ``depth``. The image itself is stored
    (serialized via ``serialize_array``) under ``raw_image``, and ``label``
    is stored as an int64 feature.
    """
    shape = image.shape
    feature_map = {
        'height': _int64_feature(shape[0]),
        'width': _int64_feature(shape[1]),
        'depth': _int64_feature(shape[2]),
        'raw_image': _bytes_feature(serialize_array(image)),
        'label': _int64_feature(label),
    }
    features = tf.train.Features(feature=feature_map)
    return tf.train.Example(features=features)

### Write our complete dataset to a TFRecord file

In [None]:
def write_images_to_tfr_short(images, labels, filename: str = "images"):
    """Write every (image, label) pair into a single TFRecord file.

    Args:
        images: Sequence of images. Each element is passed to
            ``parse_single_image``, which reads its first three dimensions.
        labels: Sequence of integer labels, one per image.
        filename: Output path without extension; ``.tfrecords`` is appended.

    Returns:
        The number of records written.

    Raises:
        ValueError: If ``images`` and ``labels`` have different lengths.
    """
    # Validate before opening the file so a mismatch cannot leave a
    # half-written TFRecord behind.
    if len(images) != len(labels):
        raise ValueError(
            f"images and labels must have the same length, "
            f"got {len(images)} and {len(labels)}"
        )

    path = filename + ".tfrecords"
    count = 0
    # The context manager closes (and flushes) the writer even if building or
    # serializing an Example raises part-way through.
    with tf.io.TFRecordWriter(path) as writer:
        for image, label in zip(images, labels):
            example = parse_single_image(image=image, label=label)
            writer.write(example.SerializeToString())
            count += 1

    print(f"Wrote {count} elements to TFRecord")
    return count

In [None]:

# Serialize the small dataset to ../data/small_images.tfrecords; keep the record count.
count = write_images_to_tfr_short(
    images=images_small,
    labels=labels_small,
    filename="../data/small_images",
)

### Read the TFRecord file we created

In [None]:
def parse_tfr_element(element, out_type=tf.int16):
    """Decode one serialized ``tf.train.Example`` produced by ``parse_single_image``.

    Args:
        element: One serialized Example, as yielded by ``tf.data.TFRecordDataset``.
        out_type: dtype the image was serialized with. ``tf.io.parse_tensor``
            requires it to match the stored dtype exactly. Defaults to
            ``tf.int16``, the dtype of the images generated in this notebook.

    Returns:
        A tuple ``(image, label)``. ``image`` is reshaped to
        ``[height, width, depth]`` and ``label`` is an int64 scalar.
    """
    # Mirrors the keys written by parse_single_image; every field is a scalar.
    feature_spec = {
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'depth': tf.io.FixedLenFeature([], tf.int64),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'raw_image': tf.io.FixedLenFeature([], tf.string),
    }
    content = tf.io.parse_single_example(element, feature_spec)

    # The raw bytes hold a serialized tensor; decode it, then restore the
    # image shape from the stored dimensions.
    image = tf.io.parse_tensor(content['raw_image'], out_type=out_type)
    image = tf.reshape(
        image, shape=[content['height'], content['width'], content['depth']]
    )
    return (image, content['label'])

In [None]:
def get_dataset_small(filename):
    """Return a ``tf.data.Dataset`` of ``(image, label)`` pairs read from ``filename``.

    Every serialized record in the TFRecord file is decoded by ``parse_tfr_element``.
    """
    raw_records = tf.data.TFRecordDataset(filename)
    return raw_records.map(parse_tfr_element)

In [None]:
dataset_small = get_dataset_small("../data/small_images.tfrecords")

# Print the shapes of the first decoded (image, label) pair.
for first_image, first_label in dataset_small.take(1):
    print(first_image.shape)
    print(first_label.shape)

## Large Dataset and TFRecords

In [None]:
image_large_shape = (400, 750, 3)
# Constrained to 500 images to stay within RAM: 500 * 400 * 750 * 3 int16
# values is already about 0.9 GB.
number_of_images_large = 500

images_large = np.random.randint(
    low=0,
    high=256,
    size=(number_of_images_large,) + image_large_shape,
    dtype=np.int16,
)

In [None]:
# One random class label in [0, 5) per large image, as a plain Python list of
# NumPy integer scalars.
labels_large = list(
    np.random.randint(low=0, high=5, size=(number_of_images_large, 1))[:, 0]
)

In [None]:
def write_images_to_tfr_long(images, labels, filename: str = "large_image", max_files: int = 10, out_dir: str = "../data/large_dataset/"):
    """Write (image, label) pairs into several TFRecord shards.

    Shard ``i`` (1-based) of ``splits`` is written to
    ``f"{out_dir}{i}_{splits}{filename}.tfrecords"``. ``out_dir`` is used as a
    plain string prefix, so it should end with a path separator.

    Args:
        images: Sequence of images, each passed to ``parse_single_image``.
        labels: Sequence of integer labels, one per image.
        filename: Suffix of every shard's file name (before ``.tfrecords``).
        max_files: Maximum number of samples per shard; must be >= 1.
        out_dir: Prefix prepended to every shard path.

    Returns:
        The total number of records written across all shards.

    Raises:
        ValueError: If ``max_files`` < 1, or if ``images`` and ``labels``
            have different lengths.
    """
    if max_files < 1:
        raise ValueError(f"max_files must be at least 1, got {max_files}")
    if len(images) != len(labels):
        raise ValueError(
            f"images and labels must have the same length, "
            f"got {len(images)} and {len(labels)}"
        )

    num_images = len(images)
    # Ceiling division: the last shard may be only partially filled.
    splits = (num_images + max_files - 1) // max_files
    print(f"\nUsing {splits} shard(s) for {num_images} files, with up to {max_files} samples per shard")

    file_count = 0
    for i in tqdm.tqdm(range(splits)):
        current_shard_name = f"{out_dir}{i+1}_{splits}{filename}.tfrecords"
        start = i * max_files
        stop = min(start + max_files, num_images)
        # The context manager closes each shard's writer even on an exception.
        with tf.io.TFRecordWriter(current_shard_name) as writer:
            for index in range(start, stop):
                out = parse_single_image(image=images[index], label=labels[index])
                writer.write(out.SerializeToString())
                file_count += 1

    print(f"\nWrote {file_count} elements to TFRecord")
    return file_count

In [None]:
# Shard the large dataset into files of at most 30 samples each.
write_images_to_tfr_long(
    images=images_large,
    labels=labels_large,
    max_files=30,
)

In [None]:
def get_dataset_large(tfr_dir: str = "../data/large_dataset/", pattern: str = "*large_image.tfrecords"):
    """Return a dataset of ``(image, label)`` pairs from every shard matching a glob.

    Args:
        tfr_dir: Directory prefix, concatenated directly with ``pattern``, so
            it should end with a path separator.
        pattern: Glob pattern for the shard file names.

    Returns:
        A ``tf.data.Dataset`` whose records are decoded by ``parse_tfr_element``.
        Shards are read in sorted (lexicographic) file-name order.
    """
    # glob.glob returns matches in arbitrary, filesystem-dependent order; sort
    # so the record order is reproducible across runs and machines.
    files = sorted(glob.glob(tfr_dir + pattern, recursive=False))

    dataset = tf.data.TFRecordDataset(files)
    return dataset.map(parse_tfr_element)

In [None]:
dataset_large = get_dataset_large()

# Print the shapes of the first decoded (image, label) pair from the shards.
for first_image, first_label in dataset_large.take(1):
    print(first_image.shape)
    print(first_label.shape)