In [None]:
import numpy as np
import tensorflow as tf
# Record the TensorFlow version in the cell output so a saved run can be tied
# to the release that produced it.
print('TF version:', tf.__version__)

# Training Classifier

## Parameters

In [None]:
# Hyperparameters for the classifier training run.
LEARNING_RATE = 0.01  # step size handed to the Adam optimizer
BATCH_SIZE = 128      # samples per batch, used for both training and validation
NUM_EPOCHS = 5        # number of full passes over the training set

## Data Preparation

In [None]:
# Fetch the MNIST train/test splits through the Keras dataset helper.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale the pixel values by 1/255 and append a trailing channel axis, which
# the convolutional layers of the classifier expect.
x_train = (x_train / 255.0)[..., np.newaxis]
x_test = (x_test / 255.0)[..., np.newaxis]

# Sanity check: both shapes should now end in a channel dimension of size 1.
print(x_train.shape)
print(x_test.shape)

## Model Definition

In [None]:
# Small CNN classifier: two strided convolutions followed by a three-layer
# dense head with dropout. The last layer applies softmax, so the model emits
# class probabilities for the 10 digits.
# With Keras' default 'valid' padding the 3x3 / stride-2 convolutions shrink the
# spatial size 28 -> 13 -> 6 before flattening.
classifier = tf.keras.models.Sequential(
    [
        tf.keras.layers.Input(shape=(28, 28, 1), name='input'),
        tf.keras.layers.Conv2D(filters=8, kernel_size=3, strides=2,
                               activation='relu', name='conv1'),
        tf.keras.layers.Conv2D(filters=16, kernel_size=3, strides=2,
                               activation='relu', name='conv2'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(units=128, activation='relu', name='fc1'),
        tf.keras.layers.Dropout(rate=0.2),
        # 'fc2' is looked up by name later to build the feature extractor.
        tf.keras.layers.Dense(units=64, activation='relu', name='fc2'),
        tf.keras.layers.Dropout(rate=0.2),
        tf.keras.layers.Dense(units=10, activation='softmax', name='fc3'),
    ],
    name='classifier',
)

In [None]:
# Print the layer-by-layer overview of the classifier (layer names, output
# shapes and parameter counts) as a quick check of the architecture.
classifier.summary()

## Training

In [None]:
# Adam optimizer driven by the notebook-level LEARNING_RATE.
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

# The classifier's final layer already applies softmax, so the loss is fed
# probabilities rather than raw logits (hence from_logits=False). The sparse
# variant takes the integer class labels directly, without one-hot encoding.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

classifier.compile(
    optimizer=optimizer,
    loss=loss_fn,
    metrics=['accuracy'],
)

In [None]:
# Train for NUM_EPOCHS passes over the training split.
# Note: the test split doubles as the validation set here, so the per-epoch
# validation metrics are computed on the same data as the final evaluation.
classifier.fit(
    x_train,
    y_train,
    epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    validation_data=(x_test, y_test),
    validation_batch_size=BATCH_SIZE,
)

## Evaluation

In [None]:
# Score the trained classifier on the test split. Because the model was
# compiled with a single metric, evaluate() yields the loss followed by the
# accuracy.
eval_loss, eval_acc = classifier.evaluate(x=x_test, y=y_test, verbose=2)

print('eval_loss:', eval_loss)
print('eval_acc:', eval_acc)

# Visualization using Embedding Projector

In [None]:
import os
from PIL import Image
from tensorboard.plugins import projector

# Load the TensorBoard notebook extension
%load_ext tensorboard

## Dataset for Visualization

In [None]:
# Number of test samples to show in the Embedding Projector.
VISUALIZATION_COUNT = 900

# Take the same leading slice from the images and from their labels so the two
# arrays stay aligned sample-for-sample.
x_test_, y_test_ = (
    split[:VISUALIZATION_COUNT] for split in (x_test, y_test)
)

print(x_test_.shape)
print(y_test_.shape)

## Feature Extractor Definition

In [None]:
# Reuses the trained classifier's graph but stops at the 64-unit 'fc2' layer,
# so the model returns that layer's activations instead of class probabilities.
#
# The input and output are passed as bare tensors rather than one-element
# lists: the nesting of `outputs` determines the structure the model returns
# when called, and a list here could come back as `[tensor]` on Keras versions
# that preserve it (NOTE(review): confirm against the installed Keras). A bare
# tensor guarantees that `feature_extractor.output.shape` below and the later
# call to the extractor both deal with a single tensor.
feature_extractor = tf.keras.Model(inputs=classifier.input,
                                   outputs=classifier.get_layer('fc2').output)
print(feature_extractor.output.shape)

## Embedding Projector Setup

In [None]:
# Sets up a logs directory for TensorBoard.
log_dir = 'logs/mnist-embeddings'

# exist_ok=True creates any missing parent directories and is a no-op when the
# directory is already there, so the cell can be re-run safely without a
# separate (and racy) existence check.
os.makedirs(log_dir, exist_ok=True)

In [None]:
# Saves labels to metadata.tsv: one human-readable class name per line, in the
# same order as the samples in y_test_ (no header row is written).
classes = ['Zero', 'One', 'Two', 'Three', 'Four', 'Five', 'Six', 'Seven',
           'Eight', 'Nine']
metadata_path = os.path.join(log_dir, 'metadata.tsv')
with open(metadata_path, "w") as f:
  # Each label is an integer digit and therefore indexes `classes` directly.
  f.writelines("{}\n".format(classes[y]) for y in y_test_)


# Converts every visualized sample back to an 8-bit grayscale PIL image: the
# single channel is selected and the [0, 1] pixel values are rescaled to 0..255.
# Only the pixel data is needed here, so the labels are not iterated.
images_pil = [
    Image.fromarray((x[..., 0] * 255).astype(np.uint8))
    for x in x_test_
]

# Saves sprite image.
# The tiles are laid out row by row on a square grid just large enough to hold
# VISUALIZATION_COUNT images. The tile size is read from the data instead of
# being hard-coded.
# NOTE: the projector config elsewhere in this notebook declares the tile size
# as 28x28 -- keep the two in sync if the input resolution ever changes.
tile_height, tile_width = x_test_.shape[1:3]
one_square_size = int(np.ceil(np.sqrt(VISUALIZATION_COUNT)))
master_width = tile_width * one_square_size
master_height = tile_height * one_square_size
spriteimage = Image.new(
    mode='RGB',
    size=(master_width, master_height),
    color=(0, 0, 0)  # opaque black background (RGB mode has no alpha channel)
)

for count, image in enumerate(images_pil):
    # Row-major placement: `row` selects the vertical offset, `col` the
    # horizontal one. PIL's paste() takes the position as (left, upper).
    row, col = divmod(count, one_square_size)
    spriteimage.paste(image, (col * tile_width, row * tile_height))

# The canvas was created in RGB mode, so it can be written as JPEG directly.
spriteimage.save(os.path.join(log_dir, 'sprite.jpg'))

# Last expression of the cell: renders the sprite inline for a visual check.
spriteimage

In [None]:
# Compute the fc2 activations for the visualized samples in inference mode
# (training=False) and wrap them in a variable so they can be checkpointed.
features = feature_extractor(x_test_, training=False)
features_var = tf.Variable(features)
print(features_var.shape)

# Write the variable to a checkpoint inside the log directory. The keyword used
# here (`embedding`) is the key under which the tensor is stored, and it must
# match the tensor name given to the projector config.
checkpoint_prefix = os.path.join(log_dir, "embedding.ckpt")
checkpoint = tf.train.Checkpoint(embedding=features_var)
checkpoint.save(checkpoint_prefix)

In [None]:
# Describes the embedding to the projector plugin: which checkpointed tensor to
# load, and which label and sprite files accompany it.
config = projector.ProjectorConfig()
embedding = config.embeddings.add()

# The checkpoint key `embedding` is suffixed by `/.ATTRIBUTES/VARIABLE_VALUE`
# to form the name of the stored tensor.
embedding.tensor_name = "embedding/.ATTRIBUTES/VARIABLE_VALUE"

# Sprite sheet and the size of one thumbnail inside it.
embedding.sprite.image_path = 'sprite.jpg'
embedding.sprite.single_image_dim.extend([28, 28])

# Per-sample labels; both files are referenced by the bare names under which
# they were saved into log_dir.
embedding.metadata_path = 'metadata.tsv'

# Writes the projector configuration into the log directory.
projector.visualize_embeddings(log_dir, config)

## Visualization using TensorBoard

In [None]:
# Launch TensorBoard inline against the log directory populated above; `{log_dir}`
# is interpolated from the Python variable of the same name. The embeddings
# appear under the "Projector" tab.
%tensorboard --logdir {log_dir}