In [1]:
import os
import tensorflow as tf
from datetime import datetime
from tensorflow import keras
from tb_cscs import tensorboard

In [63]:
# Run configuration.
BATCH_SIZE = 128  # default `batch_size` for get_dataset
NUM_EPOCHS = 4    # training epochs passed to model.fit; also get_dataset's default repeat count
# Backward-compatible alias for the original (misspelled) name, which other
# cells in this notebook still reference.
NUM_EPOCS = NUM_EPOCHS

In [71]:
def decode(serialized_example):
    """Parse one serialized example into an `(image, label)` pair.

    The record must contain a raw byte string under 'image_raw' and an int64
    under 'label'. The image bytes are reinterpreted as a flat uint8 vector,
    and the label is narrowed to int32.

    NOTE(review): the vector's static length is unknown here; it presumably
    matches the model's 784-feature input — TODO confirm against the writer
    of the TFRecord files.
    """
    feature_spec = {
        'image_raw': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized_example, features=feature_spec)
    target = tf.cast(parsed['label'], tf.int32)
    pixels = tf.io.decode_raw(parsed['image_raw'], tf.uint8)
    return pixels, target

In [72]:
def normalize(image, label):
    """Rescale uint8 pixel values from [0, 255] to float32 values in [-0.5, 0.5].

    The label is passed through unchanged.
    """
    pixels = tf.cast(image, tf.float32)
    centered = pixels * (1. / 255) - 0.5
    return centered, label

In [80]:
def get_dataset(filename, batch_size=BATCH_SIZE, epochs=NUM_EPOCS):
    """Build a shuffled, batched, normalized input pipeline from TFRecord data.

    Args:
        filename: path (or list of paths) to TFRecord file(s) whose records
            match the schema parsed by `decode`.
        batch_size: number of examples per batch.
        epochs: how many times the dataset is repeated. When the result is
            passed to `model.fit(..., epochs=E)` without `steps_per_epoch`,
            use `epochs=1` here: Keras re-iterates the dataset on every fit
            epoch, so repeating here as well multiplies the number of passes.

    Returns:
        A `tf.data.Dataset` yielding `(images, labels)` batches, with images
        scaled by `normalize` to float32 values in [-0.5, 0.5].
    """
    autotune = tf.data.experimental.AUTOTUNE
    dataset = tf.data.TFRecordDataset(filename)
    dataset = dataset.map(decode, num_parallel_calls=autotune)
    # Shuffle individual examples *before* batching so that batch composition
    # varies between passes.
    dataset = dataset.shuffle(6000)
    # Keras `fit` treats each dataset element as one batch, so batching must
    # happen here (batch_size was previously ignored).
    dataset = dataset.batch(batch_size)
    # Normalizing after batching applies one vectorized op per batch.
    dataset = dataset.map(normalize, num_parallel_calls=autotune)
    dataset = dataset.repeat(epochs)
    # Prefetch as the final stage so the whole pipeline overlaps with training.
    dataset = dataset.prefetch(autotune)
    return dataset

In [81]:
# Small fully connected classifier: 784 input features -> 128 ReLU units ->
# 10-way softmax. `decode` yields integer labels, which pair with the sparse
# categorical cross-entropy loss.
model = keras.Sequential()
# NOTE(review): the inputs are already flat (784,) vectors, so this Flatten is
# effectively an identity; it is kept because it declares the input shape.
model.add(keras.layers.Flatten(input_shape=(784,)))
model.add(keras.layers.Dense(128, activation=tf.nn.relu))
model.add(keras.layers.Dense(10, activation=tf.nn.softmax))

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [82]:
# Each run logs to its own timestamped subdirectory of 'dense_logs' so runs can
# be compared in TensorBoard. A full date with seconds keeps run directories
# unique and chronologically sortable; a day-hour-minute stamp collided across
# months and between runs started within the same minute.
run_dir = os.path.join('dense_logs', datetime.now().strftime("%Y%m%d-%H%M%S"))
tb_callback = tf.keras.callbacks.TensorBoard(
    log_dir=run_dir,
    histogram_freq=1,         # compute weight/activation histograms every epoch
    profile_batch='700,730',  # profile the batch range 700-730
)

In [83]:
# Train for NUM_EPOCS epochs, validating on the test split after each epoch.
# The training dataset is built with epochs=1 on purpose. Without
# steps_per_epoch, Keras re-iterates the dataset at the start of every fit
# epoch. Also repeating it NUM_EPOCS times inside get_dataset would therefore
# run NUM_EPOCS * NUM_EPOCS passes over the data. `fit(epochs=...)` alone
# should control the number of passes.
fit = model.fit(get_dataset('../input_pipelines/tfrecords/train.tfrecords', epochs=1),
                validation_data=get_dataset('../input_pipelines/tfrecords/test.tfrecords', epochs=1),
                epochs=NUM_EPOCS,
                callbacks=[tb_callback])

Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4


In [8]:
# Load the `tensorboard` IPython extension, which provides the %tensorboard magic.
# NOTE(review): this cell ran out of order (In [8]); for Restart & Run All it
# should sit near the imports.
%load_ext tensorboard

In [84]:
# Show the runs that tb_callback wrote under 'dense_logs' in an inline TensorBoard.
%tensorboard --logdir=dense_logs

Reusing TensorBoard on port 6007 (pid 17972), started 0:53:56 ago. (Use '!kill 17972' to kill it.)