In [1]:
import os
import tensorflow as tf
from datetime import datetime
from tb_cscs import tensorboard

In [2]:
def decode(serialized_example, image_size=(224, 224), num_classes=1001):
    """Parse one serialized example into an ``(image, label)`` training pair.

    Args:
        serialized_example: A scalar string tensor holding one serialized
            example that carries the features ``'image/encoded'`` (JPEG
            bytes) and ``'image/class/label'`` (integer class id).
        image_size: ``(height, width)`` the decoded image is resized to.
            Defaults to ``(224, 224)``.
        num_classes: Depth of the one-hot label vector. Defaults to 1001.

    Returns:
        A tuple ``(image, label)`` where ``image`` is the decoded 3-channel
        image resized to ``image_size`` and ``label`` is the one-hot encoding
        of the class id with ``num_classes`` entries.
    """
    features = tf.io.parse_single_example(
        serialized_example,
        features={
            'image/encoded': tf.io.FixedLenFeature([], tf.string),
            'image/class/label': tf.io.FixedLenFeature([], tf.int64),
        })
    # Force three channels so every decoded image has the same depth.
    image = tf.image.decode_jpeg(features['image/encoded'], channels=3)
    image = tf.image.resize(image, image_size)
    # The label is already parsed as int64 above, so no extra cast is needed
    # before one-hot encoding it.
    label = tf.one_hot(features['image/class/label'], num_classes)
    return image, label

In [3]:
# Directory holding the input record files; every entry in it is used as input.
data_dir = '/scratch/snx3000/stud50/imagenet/'
list_of_files = [os.path.join(data_dir, entry) for entry in os.listdir(data_dir)]

# Input pipeline: file names -> interleaved TFRecord readers -> decoded
# (image, label) pairs -> batches of 128.
dataset = (
    tf.data.Dataset.list_files(list_of_files)
    .interleave(tf.data.TFRecordDataset, cycle_length=160, block_length=2)
    .map(decode)
    .batch(128)
)

In [4]:
# InceptionV3 built with randomly initialised weights (weights=None) for
# 224x224 RGB inputs and 1001 output classes.
model = tf.keras.applications.InceptionV3(weights=None,
                                          input_shape=(224, 224, 3),
                                          classes=1001)

# SGD with momentum. `learning_rate` is the supported argument name; the
# older `lr` alias is deprecated and rejected by newer Keras releases.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)

# Categorical cross-entropy loss, with accuracy reported during training.
model.compile(optimizer=optimizer,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

In [6]:
# TensorBoard callback: logs go to a sub-directory of 'inceptionv3_logs' named
# after the current day and time (day-HourMinute), histograms are written
# every epoch, and profiling is requested for the batch range '85,95'.
tb_callback = tf.keras.callbacks.TensorBoard(
    log_dir=os.path.join('inceptionv3_logs', datetime.now().strftime("%d-%H%M")),
    histogram_freq=1,
    profile_batch='85,95',
)

In [7]:
# Train for one epoch on the first 100 batches only, with the TensorBoard
# callback attached; the returned object is kept in `fit`.
fit = model.fit(
    dataset.take(100),
    epochs=1,
    callbacks=[tb_callback],
)

Instructions for updating:
use `tf.profiler.experimental.stop` instead.


In [2]:
# Load the IPython extension named `tensorboard` so the %tensorboard magic is available.
%load_ext tensorboard

In [3]:
# Open TensorBoard on 'inceptionv3_logs', the parent directory the training callback logs into.
%tensorboard --logdir=inceptionv3_logs