In [None]:
#https://github.com/zaidalyafeai/Notebooks/blob/master/GPUvsTPU.ipynb

In [None]:
import tensorflow as tf
import os
import numpy as np
from tensorflow.keras.utils import to_categorical

def get_data():
    """Load MNIST and prepare it for training a Keras CNN.

    The data is read with ``tf.keras.datasets.mnist.load_data()``.

    Returns:
        A tuple ``(x_train, y_train, x_test, y_test)`` where the image arrays
        are float32, divided by 255, with a channel axis inserted at
        position 3, and the label arrays are one-hot encoded with 10 columns.
    """
    # MNIST has ten digit classes. Passing the count explicitly keeps the
    # one-hot width the same for both splits (and equal to the model's
    # 10-unit output) instead of letting to_categorical infer it from the
    # largest label that happens to be present in each array.
    num_classes = 10

    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    # Convert to float32 and rescale by 1/255.
    x_train = x_train.astype('float32') / 255
    x_test = x_test.astype('float32') / 255

    # Insert the channel axis expected by the (28, 28, 1) model input.
    x_train = np.expand_dims(x_train, 3)
    x_test = np.expand_dims(x_test, 3)

    # One-hot encode the integer labels.
    y_train = to_categorical(y_train, num_classes=num_classes)
    y_test = to_categorical(y_test, num_classes=num_classes)

    return x_train, y_train, x_test, y_test

In [None]:
from tensorflow.contrib.tpu.python.tpu import keras_support

def get_model(tpu = False):
    """Build and compile the MNIST CNN, optionally converting it for TPU.

    Args:
        tpu: When truthy, the compiled model is passed through
            ``tf.contrib.tpu.keras_to_tpu_model`` using the TPU address read
            from the ``COLAB_TPU_ADDR`` environment variable.

    Returns:
        The compiled Keras model, or the converted TPU model when ``tpu``
        is truthy.

    Raises:
        KeyError: If ``tpu`` is truthy and ``COLAB_TPU_ADDR`` is not set.
    """
    keras_layers = tf.keras.layers

    # Two conv -> pool -> dropout stages, then a dense classifier head
    # ending in a 10-way softmax.
    network = tf.keras.Sequential([
        keras_layers.Conv2D(filters=64, kernel_size=2, padding='same', activation='relu', input_shape=(28,28,1)),
        keras_layers.MaxPooling2D(pool_size=2),
        keras_layers.Dropout(0.3),
        keras_layers.Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'),
        keras_layers.MaxPooling2D(pool_size=2),
        keras_layers.Dropout(0.3),
        keras_layers.Flatten(),
        keras_layers.Dense(256, activation='relu'),
        keras_layers.Dropout(0.5),
        keras_layers.Dense(10, activation='softmax'),
    ])

    network.compile(
        loss='categorical_crossentropy',
        optimizer='adam',
        metrics=['accuracy'],
    )

    # Without the flag, the plain compiled model is returned as-is.
    if not tpu:
        return network

    # Build the gRPC address of the TPU from the Colab environment.
    grpc_endpoint = "grpc://" + os.environ["COLAB_TPU_ADDR"]
    resolver = tf.contrib.cluster_resolver.TPUClusterResolver(grpc_endpoint)

    # Wrap the resolver in a TPU distribution strategy and convert the model.
    tpu_strategy = keras_support.TPUDistributionStrategy(resolver)
    return tf.contrib.tpu.keras_to_tpu_model(network, strategy=tpu_strategy)

In [None]:
# Fetch the prepared MNIST splits.
x_train, y_train, x_test, y_test = get_data()

# Pass tpu=True to get_model to run the model on a TPU instead.
model = get_model(tpu=False)

# Training settings; the test split is used as validation data.
fit_options = dict(
    batch_size=1024,
    epochs=10,
    validation_data=(x_test, y_test),
)
model.fit(x_train, y_train, **fit_options)

