# Set the root directory.

In [None]:
import os

# Working directory for the rest of the notebook.
# NOTE(review): '/content/' is an absolute path (presumably the Colab session
# directory); os.chdir raises if it does not exist, so this cell only works on
# a machine that has that directory — confirm before running elsewhere.
root_dir = '/content/'
os.chdir(root_dir)

# IPython shell escape: list the contents of the new working directory.
!ls -al

# Import TensorFlow 2.x and supporting libraries.

In [None]:
try:
  %tensorflow_version 2.x
except Exception:
  pass

import tensorflow as tf
from tensorflow.keras import backend as K

import tensorflow.keras.layers as layers
import tensorflow.keras.models as models

import numpy as np
np.random.seed(7)

import matplotlib.pyplot as plot

print(tf.__version__)

In [None]:
# Fetch the MNIST train/test split through the Keras dataset helper.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Append a trailing axis of length 1 to the image arrays; create_model()
# reads its input shape from x_train.shape[1:].
x_train = x_train[..., np.newaxis]
x_test = x_test[..., np.newaxis]

# Create the model.

In [None]:
def create_model(input_shape=None, num_classes=10):
  """Build the (uncompiled) convolutional classifier.

  Architecture:
    * three blocks of BatchNormalization -> Conv2D (5x5 kernel, 'same'
      padding, ELU) -> 2x2 max-pooling with stride 2 -> Dropout(0.25),
      using 64, 128 and 256 filters respectively;
    * Flatten -> Dense(256) -> ELU -> Dropout(0.5);
    * Dense(num_classes) -> softmax.

  Args:
    input_shape: Shape of a single input sample, without the batch axis.
      Defaults to `x_train.shape[1:]`, i.e. the shape of the training array
      defined earlier in the notebook, which is what the original
      zero-argument call used.
    num_classes: Number of units in the final softmax layer. Defaults to 10.

  Returns:
    A `Sequential` model that still has to be compiled.
  """
  if input_shape is None:
    # Keep create_model() with no arguments working exactly as before.
    input_shape = x_train.shape[1:]

  model = models.Sequential()

  for block_index, filters in enumerate((64, 128, 256)):
    if block_index == 0:
      # Only the first layer of a Sequential model declares the input shape.
      # The later blocks receive the pooled feature maps of the previous
      # block, so tagging them with the raw input shape would be wrong.
      model.add(layers.BatchNormalization(input_shape=input_shape))
    else:
      model.add(layers.BatchNormalization())
    model.add(layers.Conv2D(filters, (5, 5), padding='same', activation='elu'))
    # strides=(2, 2) equals the pooling default (stride == pool_size); it is
    # spelled out so that all three blocks read identically.
    model.add(layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
    model.add(layers.Dropout(0.25))

  model.add(layers.Flatten())
  model.add(layers.Dense(256))
  model.add(layers.Activation('elu'))
  model.add(layers.Dropout(0.5))
  model.add(layers.Dense(num_classes))
  model.add(layers.Activation('softmax'))
  return model

In [None]:
# Locate the TPU, connect this runtime to it and initialise the TPU system.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
# The print is purely informational, so use .get(): a cluster spec without a
# 'worker' entry then prints None instead of aborting the cell with KeyError.
print('Running on TPU -', resolver.cluster_spec().as_dict().get('worker'))
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

# Prefer the stable tf.distribute.TPUStrategy; fall back to the experimental
# alias on TensorFlow releases that do not provide the stable name yet.
if hasattr(tf.distribute, 'TPUStrategy'):
  strategy = tf.distribute.TPUStrategy(resolver)
else:
  strategy = tf.distribute.experimental.TPUStrategy(resolver)

In [None]:
# Instantiate the network inside the distribution strategy's scope; `strategy`
# is the TPU strategy object created in the TPU setup cell.
with strategy.scope():
  model = create_model()

# Train the model.

### Compile the model.

In [None]:
# Configure training under the strategy scope: Adam with a learning rate of
# 1e-3, sparse categorical cross-entropy as the loss, and sparse categorical
# accuracy as the reported metric.
with strategy.scope():
  model.compile(
      loss='sparse_categorical_crossentropy',
      metrics=['sparse_categorical_accuracy'],
      optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
  )

### Train the model.

In [None]:
# Training schedule, consumed by the model.fit(...) call in the next cell.
# Passed to model.fit as `epochs`.
epochs = 20
# Passed to model.fit as `validation_freq`.
validation_freq = 5

In [None]:
# Hand float32 copies of the images and labels to Keras.
# NOTE(review): presumably the cast is needed because the arrays returned by
# the dataset loader are integer-typed — confirm.
train_images = x_train.astype(np.float32)
train_labels = y_train.astype(np.float32)
test_images = x_test.astype(np.float32)
test_labels = y_test.astype(np.float32)

# NOTE(review): steps_per_epoch is fixed at 60 and no batch_size is given, so
# the batch size is whatever Keras picks by default — confirm this covers the
# intended portion of the training set in every epoch.
model.fit(
    train_images,
    train_labels,
    epochs=epochs,
    steps_per_epoch=60,
    validation_data=(test_images, test_labels),
    validation_freq=validation_freq,
)