In [1]:
import os
import tensorflow as tf
from datetime import datetime
from tensorflow import keras

In [2]:
# Training configuration.
BATCH_SIZE = 64   # examples per batch
NUM_EPOCHS = 4    # number of epochs passed to model.fit
# Backward-compatible alias for the original, misspelled name; other cells
# (get_dataset's default argument and the model.fit call) still reference it.
NUM_EPOCS = NUM_EPOCHS

In [3]:
def decode(serialized_example):
    """Deserialize one TFRecord record into an ``(image, label)`` pair.

    Args:
        serialized_example: serialized ``tf.train.Example`` string tensor with
            an ``image_raw`` bytes feature and an ``label`` int64 feature.

    Returns:
        A tuple ``(image, label)``: the raw bytes decoded as uint8 and reshaped
        to ``(28, 28, 1)``, and the label cast to int32.
    """
    feature_spec = {
        'image_raw': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized_example, features=feature_spec)

    # The stored bytes are a flat pixel buffer; reshape restores HxWxC layout.
    pixels = tf.reshape(tf.io.decode_raw(parsed['image_raw'], tf.uint8), (28, 28, 1))
    label = tf.cast(parsed['label'], tf.int32)
    return pixels, label

In [4]:
def normalize(image, label):
    """Rescale pixel values from [0, 255] to float32 values in [-0.5, 0.5].

    ``label`` is returned unchanged. Works element-wise, so it can be applied
    to single images or to whole batches.
    """
    scale = 1. / 255
    as_float = tf.cast(image, tf.float32)
    return as_float * scale - 0.5, label

In [5]:
def get_dataset(filename, batch_size=BATCH_SIZE, epochs=NUM_EPOCS, shuffle_buffer=None):
    """Build a batched, normalized ``tf.data`` pipeline from TFRecord file(s).

    Args:
        filename: path (or list of paths) to TFRecord files whose records
            ``decode`` can parse.
        batch_size: examples per batch; the final batch may be smaller.
        epochs: NOT applied. Kept only so existing calls keep working. Epoch
            looping is done by ``model.fit(epochs=...)``; calling
            ``dataset.repeat(epochs)`` here as well would make every Keras
            "epoch" iterate the data ``epochs`` times.
        shuffle_buffer: if truthy, shuffle individual examples with a buffer of
            this size before batching (recommended for training data). The
            default ``None`` keeps the records in file order.

    Returns:
        A ``tf.data.Dataset`` yielding ``(image, label)`` batches, with images
        normalized by ``normalize``.
    """
    del epochs  # intentionally unused; see docstring
    autotune = tf.data.experimental.AUTOTUNE

    dataset = tf.data.TFRecordDataset(filename)
    # Parallel map keeps element order by default (deterministic output).
    dataset = dataset.map(decode, num_parallel_calls=autotune)
    if shuffle_buffer:
        # Shuffle examples, not batches, so batch composition varies per epoch.
        dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.batch(batch_size)
    # normalize is element-wise, so applying it once per batch is cheaper.
    dataset = dataset.map(normalize, num_parallel_calls=autotune)
    # Overlap input preparation with model execution.
    return dataset.prefetch(autotune)

In [6]:
# Small CNN: two 3x3 conv layers, max-pooling, dropout, and a dense head
# ending in a 10-way softmax over (28, 28, 1) inputs.
model = keras.Sequential()
model.add(keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu',
                              input_shape=(28, 28, 1)))
model.add(keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(keras.layers.Dropout(0.25))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(128, activation='relu'))
model.add(keras.layers.Dropout(0.5))
model.add(keras.layers.Dense(10, activation='softmax'))

# Integer labels (not one-hot), hence the sparse categorical cross-entropy.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [7]:
# Each run logs to its own sub-directory under cnn_logs/, so TensorBoard can
# compare runs. The full date plus seconds in the timestamp means runs started
# in different months, or within the same minute, do not share a directory.
tb_callback = tf.keras.callbacks.TensorBoard(
    log_dir=os.path.join('cnn_logs', datetime.now().strftime('%Y%m%d-%H%M%S')),
    histogram_freq=1,         # weight histograms every epoch
    profile_batch='700,730',  # profile the range of training batches 700-730
)

In [8]:
# Train on the training split and evaluate on the test split after every
# epoch; `fit` holds the Keras History object returned by model.fit.
fit = model.fit(
    x=get_dataset('../tfrecords/train.tfrecords'),
    validation_data=get_dataset('../tfrecords/test.tfrecords', epochs=1),
    epochs=NUM_EPOCS,
    callbacks=[tb_callback],
)

Epoch 1/4
Instructions for updating:
use `tf.profiler.experimental.stop` instead.
Epoch 2/4
Epoch 3/4
Epoch 4/4


In [9]:
# Load the TensorBoard notebook extension, which provides the %tensorboard magic.
%load_ext tensorboard

In [10]:
# Launch TensorBoard inline, reading every run sub-directory written under cnn_logs/.
%tensorboard --logdir=cnn_logs