<a href="https://colab.research.google.com/github/josejailson/tensorflow_data_api/blob/main/proto_buffer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import numpy as np

In [None]:
# Fashion-MNIST's load_data() returns 60,000 training and 10,000 test images
# (per the Keras docs). Keep the first 55,000 training images for training and
# hold out the last 5,000 for validation.
(X_train_full, y_train_full), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
X_train, y_train = X_train_full[:55000], y_train_full[:55000]
X_valid, y_valid = X_train_full[55000:], y_train_full[55000:]

In [None]:
# Wrap each split in a tf.data pipeline. Only the training split is shuffled,
# with a buffer as large as the split itself (a full shuffle).
train_set = tf.data.Dataset.from_tensor_slices(
    (X_train, y_train)
).shuffle(buffer_size=len(X_train))
test_set = tf.data.Dataset.from_tensor_slices((X_test, y_test))
valid_set = tf.data.Dataset.from_tensor_slices((X_valid, y_valid))

In [None]:
from tensorflow.train import Int64List, BytesList
from tensorflow.train import Feature, Features, Example


def create_example(image, label):
  """Pack one (image, label) pair into a ``tf.train.Example`` protobuf.

  The image tensor is serialized with ``tf.io.serialize_tensor`` and stored as
  a single bytes value under the "image" key; the label is stored as a single
  int64 value under the "label" key.

  Args:
    image: image tensor to serialize. Must be called in eager mode, since
      ``.numpy()`` is taken on the serialized result.
    label: integer label, wrapped in a one-element ``Int64List``.

  Returns:
    A ``tf.train.Example`` with features "image" and "label".
  """
  serialized_image = tf.io.serialize_tensor(image).numpy()
  image_feature = Feature(bytes_list=BytesList(value=[serialized_image]))
  label_feature = Feature(int64_list=Int64List(value=[label]))
  feature_map = {"image": image_feature, "label": label_feature}
  return Example(features=Features(feature=feature_map))

In [None]:
from contextlib import ExitStack

def write_tfrecords(name, dataset, n_shards=10):
  """Serialize ``dataset`` into ``n_shards`` TFRecord files, round-robin.

  Element ``i`` of ``dataset`` is written to shard ``i % n_shards``, so shard
  sizes differ by at most one. Each element must be an ``(image, label)`` pair
  accepted by ``create_example``.

  Args:
    name: path prefix for the shard files.
    dataset: iterable of ``(image, label)`` pairs (e.g. an eager
      ``tf.data.Dataset``).
    n_shards: number of output files; must be >= 1.

  Returns:
    List of the shard paths, formatted as
    ``"{name}.tfrecords-{index:05d}-{n_shards:05d}"``.

  Raises:
    ValueError: if ``n_shards`` is less than 1.
  """
  if n_shards < 1:
    raise ValueError("n_shards must be >= 1, got {}".format(n_shards))
  paths = ["{}.tfrecords-{:05d}-{:05d}".format(name, index, n_shards)
           for index in range(n_shards)]
  # ExitStack closes every writer on exit, even if serialization fails midway.
  with ExitStack() as stack:
    writers = [stack.enter_context(tf.io.TFRecordWriter(path))
               for path in paths]
    # Python-side enumerate keeps the shard index a plain int, avoiding
    # per-element tensor arithmetic from Dataset.enumerate().
    for index, (image, label) in enumerate(dataset):
      example = create_example(image, label)
      writers[index % n_shards].write(example.SerializeToString())
  return paths

In [None]:
# Serialize each split into sharded TFRecord files (write_tfrecords' default shard count).
train_filepaths = write_tfrecords(name="my_fashion_mnist.train", dataset=train_set)
test_filepaths = write_tfrecords(name="my_fashion_mnist.test", dataset=test_set)
valid_filepaths = write_tfrecords(name="my_fashion_mnist.valid", dataset=valid_set)

In [None]:
# Display the list of training shard paths returned by write_tfrecords.
train_filepaths

['my_fashion_mnist.train.tfrecords-00000-00010',
 'my_fashion_mnist.train.tfrecords-00001-00010',
 'my_fashion_mnist.train.tfrecords-00002-00010',
 'my_fashion_mnist.train.tfrecords-00003-00010',
 'my_fashion_mnist.train.tfrecords-00004-00010',
 'my_fashion_mnist.train.tfrecords-00005-00010',
 'my_fashion_mnist.train.tfrecords-00006-00010',
 'my_fashion_mnist.train.tfrecords-00007-00010',
 'my_fashion_mnist.train.tfrecords-00008-00010',
 'my_fashion_mnist.train.tfrecords-00009-00010']

In [None]:
def preprocess(tfrecord):
  """Decode one serialized ``tf.train.Example`` into an ``(image, label)`` pair.

  Inverse of the "image"/"label" layout written by ``create_example``:
  the "image" bytes are parsed back into a uint8 tensor and reshaped to
  28x28, and the int64 "label" is returned unchanged. Missing features take
  the defaults "" (image) and -1 (label).

  Args:
    tfrecord: a scalar string tensor holding one serialized Example.

  Returns:
    Tuple ``(image, label)``: a ``[28, 28]`` uint8 tensor and an int64 scalar.
  """
  schema = {
      "image": tf.io.FixedLenFeature([], tf.string, default_value=""),
      "label": tf.io.FixedLenFeature([], tf.int64, default_value=-1),
  }
  parsed = tf.io.parse_single_example(tfrecord, schema)
  pixels = tf.io.parse_tensor(parsed["image"], out_type=tf.uint8)
  return tf.reshape(pixels, shape=[28, 28]), parsed["label"]

In [None]:
def mnist_dataset(filepaths, n_read_threads=5, shuffle_buffer_size=None, n_parse_threads=5, batch_size=32, cache=True):
  """Build a batched ``tf.data`` input pipeline from TFRecord shard files.

  Stages: parallel TFRecord read -> optional cache -> optional shuffle ->
  ``preprocess`` (parse + decode) -> batch -> prefetch.

  Args:
    filepaths: TFRecord file path(s) to read.
    n_read_threads: number of files read in parallel (``num_parallel_reads``).
    shuffle_buffer_size: shuffle buffer size; a falsy value (None/0) disables
      shuffling.
    n_parse_threads: parallel calls used for ``preprocess``.
    batch_size: number of examples per batch.
    cache: if True, cache the raw serialized records read from the files.

  Returns:
    A prefetching ``tf.data.Dataset`` of ``(images, labels)`` batches.
  """
  dataset = tf.data.TFRecordDataset(filepaths, num_parallel_reads=n_read_threads)
  if cache:
    # Cache before parsing so only the file reads are skipped on later epochs.
    dataset = dataset.cache()
  # Shuffling, parsing and batching apply regardless of ``cache``; nesting
  # them under the ``if`` made cache=False return None.
  if shuffle_buffer_size:
    dataset = dataset.shuffle(shuffle_buffer_size)
  dataset = dataset.map(preprocess, num_parallel_calls=n_parse_threads)
  dataset = dataset.batch(batch_size)
  return dataset.prefetch(1)

In [None]:
# Read the shards back as parsed, batched pipelines; only training data is shuffled.
# NOTE(review): this rebinds train_set/valid_set/test_set, so re-running the
# TFRecord-writing cell afterwards would feed batched, parsed data to write_tfrecords.
train_set = mnist_dataset(train_filepaths, shuffle_buffer_size=60000)
valid_set = mnist_dataset(filepaths=valid_filepaths)
test_set = mnist_dataset(filepaths=test_filepaths)

In [None]:
for X, y in train_set.take(1):
  for i in range(5):
    plt