<a href="https://colab.research.google.com/github/hernandosalas/Code/blob/master/Machine%2520Learning/Fashion_MNIST_with_Keras_and_TPUs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:


# %pip install tensorflow==2.2
%tensorflow_version 2.x
import tensorflow as tf
print("Tensorflow version " + tf.__version__)
import numpy as np
import os
import time
from google.colab import files
from matplotlib import pyplot

# Fetch the Fashion-MNIST train/test splits through the Keras dataset helper.
fashion_mnist = tf.keras.datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

# Append a trailing axis of size 1 (the "empty color dimension") so each
# image carries an explicit channel axis for the Conv2D input.
x_train = x_train[..., np.newaxis]
x_test = x_test[..., np.newaxis]

def create_model():
  """Build the (uncompiled) Fashion-MNIST classifier.

  Layer stack, in order: Conv2D(128 filters, 3x3) -> MaxPooling2D(2x2,
  stride 2) -> ELU -> Flatten -> Dense(10) -> softmax.

  The input shape is taken from the module-level `x_train` (its shape with
  the first axis dropped), so `x_train` must exist before this is called.

  Returns:
    A new `tf.keras.models.Sequential` instance.
  """
  layers = tf.keras.layers
  return tf.keras.models.Sequential([
    # Feature extractor: convolution, then pooling, then the non-linearity.
    layers.Conv2D(128, (3, 3), input_shape=x_train.shape[1:]),
    layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)),
    layers.Activation('elu'),
    # Classifier head: one score per entry of the 10-way output.
    layers.Flatten(),
    layers.Dense(10),
    layers.Activation('softmax'),
  ])


# Train on the TPU.
#
# The statements below must run in this order: resolve the TPU cluster,
# connect the runtime to it, initialize the TPU system, and only then build
# the distribution strategy from the same resolver.
# NOTE(review): the no-argument resolver presumably picks up the TPU address
# from the Colab runtime environment -- confirm when running elsewhere.
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
# Show the 'worker' entries of the resolved cluster spec.
print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])

tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)

# NOTE(review): this is the `experimental` TPUStrategy symbol; newer
# TensorFlow releases also expose `tf.distribute.TPUStrategy` -- confirm
# which one the TF version selected above provides before changing it.
strategy = tf.distribute.experimental.TPUStrategy(tpu)
print("REPLICAS: ", strategy.num_replicas_in_sync)

# Build and compile the model inside the TPU strategy scope, so that the
# model and its optimizer are both created under `strategy`.
with strategy.scope():
  model = create_model()
  adam = tf.keras.optimizers.Adam(learning_rate=1e-3)
  model.compile(
    optimizer=adam,
    # Sparse variants: the targets passed to fit() are class indices,
    # one per sample, rather than one-hot vectors.
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=[tf.keras.metrics.sparse_categorical_accuracy],
  )
  
# Wall-clock timer; the reported duration covers both the training run and
# the weight save below.
start_time = time.time()

# Inputs and targets are handed to Keras as float32 copies of the arrays.
validation_set = (x_test.astype(np.float32), y_test.astype(np.float32))
model.fit(
  x=x_train.astype(np.float32),
  y=y_train.astype(np.float32),
  epochs=5,
  steps_per_epoch=60,
  validation_data=validation_set,
  # With epochs=5 this is expected to validate once, on the last epoch
  # (see the Keras Model.fit documentation for validation_freq).
  validation_freq=5,
)

# Persist the trained weights so a separate model instance can load them.
model.save_weights('./fashion_mnist.h5', overwrite=True)

elapsed_seconds = time.time() - start_time
print("--- %s seconds ---" % elapsed_seconds)

# Inference model: a fresh instance created outside the strategy scope, then
# populated from the weights file written after training.
weights_path = './fashion_mnist.h5'
inferencing_model = create_model()
inferencing_model.load_weights(weights_path)
inferencing_model.summary()

# Offer the weights file as a browser download through the Colab helper.
files.download(weights_path)

# Check Results
# Display name for each class index; position i names predicted class i.
LABEL_NAMES = [
  't_shirt',      # 0
  'trouser',      # 1
  'pullover',     # 2
  'dress',        # 3
  'coat',         # 4
  'sandal',       # 5
  'shirt',        # 6
  'sneaker',      # 7
  'bag',          # 8
  'ankle_boots',  # 9
]

%matplotlib inline

def plot_predictions(images, predictions):
  """Draw the images in a 4-column grid, each captioned with its prediction.

  Images are laid out row by row. Every cell has its axis hidden; cells
  beyond the last image are left blank. Each caption is the LABEL_NAMES
  entry at the argmax of that image's prediction vector, followed by the
  maximum score formatted to three decimals.

  Args:
    images: array whose first axis indexes the images; `images[i]` is what
      gets passed to `imshow`.
    predictions: per-image score vectors, indexed in step with `images`.
      `np.argmax(predictions[i])` must be a valid index into LABEL_NAMES.
  """
  n = images.shape[0]
  if n == 0:
    return  # nothing to draw; a subplot grid needs at least one row

  ncols = 4
  nrows = int(np.ceil(n / ncols))
  # squeeze=False keeps `axes` two-dimensional even when there is one row.
  _, axes = pyplot.subplots(nrows, ncols, squeeze=False)

  # `axes.flat` walks the grid row by row, so cell i corresponds to image i.
  for i, ax in enumerate(axes.flat):
    ax.axis('off')
    if i >= n:
      # Padding cell in the last row: check before touching predictions[i].
      continue

    scores = predictions[i]
    label = LABEL_NAMES[np.argmax(scores)]
    confidence = np.max(scores)
    ax.imshow(images[i])
    ax.text(0.5, 0.5, label + '\n%.3f' % confidence, fontsize=14)

  pyplot.gcf().set_size_inches(8, 8)

# Visual spot check on the first 16 test images: the model predicts on the
# images with their channel axis, while the plot receives them squeezed.
sample_images = x_test[:16]
sample_predictions = inferencing_model.predict(sample_images)
plot_predictions(np.squeeze(sample_images), sample_predictions)