<a href="https://colab.research.google.com/github/mbernico/gcp_colab/blob/master/Keras_TPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from tensorflow.python.lib.io import file_io
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
from tensorflow.keras.applications import InceptionV3
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Input
from tensorflow.keras.models import Model
from tensorflow.train import AdamOptimizer
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint

In [0]:
def input_fn(train_filename, val_filename, test_filename, IM_SIZE=299, BATCH=32):
  """Build batched `tf.data` pipelines for the train, validation and test splits.

  Every record is parsed with two features: 'image', a byte string decoded as
  raw float32 values and reshaped to (IM_SIZE, IM_SIZE, 3), and 'label', an
  int64 scalar cast to float32.

  Random flips and brightness jitter are applied to the training split only.
  Validation and test records are decoded without augmentation so that the
  metrics computed on them are deterministic and measured on unmodified images.

  Args:
    train_filename: path of the training TFRecord file.
    val_filename: path of the validation TFRecord file.
    test_filename: path of the test TFRecord file.
    IM_SIZE: height and width the decoded image is reshaped to.
    BATCH: number of examples per batch.

  Returns:
    A `(train, val, test)` tuple of batched datasets. `train` and `val` repeat
    indefinitely, so consumers must bound them with explicit step counts;
    `test` makes a single pass over its file.
  """
  features = {
              'image': tf.FixedLenFeature(shape=[], dtype=tf.string),
              'label': tf.FixedLenFeature(shape=[], dtype=tf.int64),
             }

  def _decode_record(example_proto):
    # Deterministic part of the pipeline, shared by all three splits.
    ex = tf.parse_single_example(example_proto, features)

    im = tf.decode_raw(ex['image'], tf.float32)
    im = tf.reshape(im, (IM_SIZE, IM_SIZE, 3))
    label = tf.cast(ex['label'], tf.float32)

    return im, label

  def _decode_and_augment_record(example_proto):
    # Training-only: random perturbations on top of the plain decode.
    im, label = _decode_record(example_proto)
    im = tf.image.random_flip_left_right(im)
    im = tf.image.random_flip_up_down(im)
    im = tf.image.random_brightness(im, max_delta=0.5)

    return im, label

  # NOTE(review): no shuffle is applied to the training records; presumably the
  # TFRecord file was shuffled when written -- confirm.
  train = tf.data.TFRecordDataset(train_filename).map(_decode_and_augment_record).repeat().batch(BATCH)
  val = tf.data.TFRecordDataset(val_filename).map(_decode_record).repeat().batch(BATCH)
  test = tf.data.TFRecordDataset(test_filename).map(_decode_record).batch(BATCH)
  return train, val, test


In [0]:
# TFRecord locations of the three data splits in the GCS bucket.
train_filename, val_filename, test_filename = (
    "gs://sandbox226501-acme-grape-demo/data/train.tfr",
    "gs://sandbox226501-acme-grape-demo/data/val.tfr",
    "gs://sandbox226501-acme-grape-demo/data/test.tfr",
)

# Despite the *_iterator names, input_fn returns batched tf.data datasets.
train_iterator, val_iterator, test_iterator = input_fn(
    train_filename=train_filename,
    val_filename=val_filename,
    test_filename=test_filename,
)

In [0]:
def build_model():
  """Build and compile a binary classifier on a frozen InceptionV3 backbone.

  The backbone is created with ImageNet weights, without its top layers, for
  299x299x3 inputs, and every one of its layers is marked non-trainable. The
  head added on top is global average pooling, a 1024-unit ReLU dense layer
  and a single sigmoid output unit.

  Returns:
    The model compiled with `AdamOptimizer()`, binary cross-entropy loss and
    the accuracy metric.
  """
  backbone = InceptionV3(weights='imagenet', include_top=False, input_shape=(299, 299, 3))

  # Freeze the backbone layers so only the new head is marked trainable.
  for backbone_layer in backbone.layers:
    backbone_layer.trainable = False

  pooled = GlobalAveragePooling2D()(backbone.output)
  hidden = Dense(1024, activation='relu')(pooled)
  predictions = Dense(1, activation='sigmoid')(hidden)

  model = Model(inputs=backbone.inputs, outputs=predictions)
  model.compile(optimizer=AdamOptimizer(), loss='binary_crossentropy', metrics=['accuracy'])
  return model

In [0]:
def create_callbacks(name):
    """Create the TensorBoard and checkpoint callbacks for one training run.

    Args:
        name: run identifier; used as the sub-directory of `tb_log` under the
            current working directory and embedded in the checkpoint file name.

    Returns:
        A list `[tensorboard_callback, checkpoint_callback]`.
    """
    log_dir = os.path.join(os.getcwd(), "tb_log", name)
    # The {epoch} / {val_loss} placeholders are left in the string for
    # ModelCheckpoint to fill in when it writes a file.
    weights_path = "./model-weights" + name + ".{epoch:02d}-{val_loss:.6f}.hdf5"

    tensorboard_callback = TensorBoard(
        log_dir=log_dir,
        write_graph=True,
        write_grads=False,
    )
    # save_best_only=True with monitor='val_loss': keep a checkpoint only when
    # the monitored validation loss improves.
    checkpoint_callback = ModelCheckpoint(
        filepath=weights_path,
        monitor='val_loss',
        verbose=0,
        save_best_only=True,
    )
    return [tensorboard_callback, checkpoint_callback]

In [0]:
# Run identifier passed to create_callbacks for naming its log/checkpoint outputs.
RUN_NAME = "keras_colab"

model = build_model()
callbacks = create_callbacks(name=RUN_NAME)

In [47]:


# Address of the TPU worker, built from the COLAB_TPU_ADDR environment variable;
# it is what TensorFlow is pointed at below.
# The TPU's service-id must also be granted access to your bucket.
TPU_WORKER = 'grpc://' + os.environ['COLAB_TPU_ADDR']
tf.logging.set_verbosity(tf.logging.INFO)

# Resolver -> distribution strategy -> TPU copy of the Keras model.
tpu_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)
tpu_strategy = tf.contrib.tpu.TPUDistributionStrategy(tpu_resolver)
tpu_model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=tpu_strategy)

tpu_model.summary()

INFO:tensorflow:Querying Tensorflow master (b'grpc://10.36.109.58:8470') for TPU system metadata.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 1849449424700706414)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 5540719510239057380)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_GPU:0, XLA_GPU, 17179869184, 8967291611663988348)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 10147755256607187518)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 15132723120231712926)
INFO:tensorflow:*** Available Device: _Device

In [0]:

# Training schedule. The train/val datasets repeat indefinitely, so each epoch
# and each validation pass is bounded by an explicit step count.
EPOCHS = 100
STEPS_PER_EPOCH = 22
VALIDATION_STEPS = 10

# NOTE(review): this trains `model`, not the `tpu_model` created in the TPU
# setup cell -- confirm which one is intended before relying on TPU execution.
model.fit(
    train_iterator,
    validation_data=val_iterator,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=VALIDATION_STEPS,
    callbacks=callbacks,
    verbose=1,
)

Epoch 1/100