[View in Colaboratory](https://colab.research.google.com/github/christianmerkwirth/colabs/blob/master/TPU_Resnet_Training_on_Imagenet_III.ipynb)

In [0]:
 # colab.research.google.com specific
import sys

if 'google.colab' in sys.modules:
  import json
  import os
  from google.colab import auth

  !git clone https://github.com/tensorflow/tpu.git
  !cp  tpu/models/official/resnet/resnet_main.py  .
  !mv tpu/models/official .

In [0]:
# Import resnet_main so its module-level flag definitions are registered with
# absl FLAGS and constants such as resnet_model_fn become available. Only the
# module is imported; its main() is not executed.
#
# Import it through the `official.resnet` package rather than the top-level
# copy. Later cells do `from official.resnet import resnet_main`, and that
# import will now reuse this same module object. Otherwise the file would be
# loaded a second time under a different module name, and every flag would be
# defined again; absl rejects that with DuplicateFlagError.
from official.resnet import resnet_main

In [0]:
# This cell contains the Colab-specific setup and hacks.

import json
import os
import sys
import time

from absl import flags
import absl.logging as _logging

import tensorflow as tf

FLAGS = flags.FLAGS

# Make tf.logging.info messages (training/eval reports in later cells) visible
# in the notebook. NOTE(review): TF 1.x typically filters INFO by default.
tf.logging.set_verbosity(tf.logging.INFO)

if 'google.colab' in sys.modules:
  from google.colab import auth

  # Authenticate to access GCS bucket.
  auth.authenticate_user()

# Flags are normally parsed by app.run(). main() is never run here, so parse
# a dummy argv to allow reading and assigning flag values. The guard makes
# re-running this cell harmless.
if not FLAGS.is_parsed():
  FLAGS(['resnet_trainer'])

# When connected to the Colab TPU runtime.
if 'google.colab' in sys.modules and 'COLAB_TPU_ADDR' in os.environ:
  tpu_grpc = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])

  FLAGS.tpu = tpu_grpc
  FLAGS.use_tpu = True

  # Upload credentials to the TPU so it can access GCS. The JSON file is
  # presumably written by auth.authenticate_user() above; verify this.
  # The context manager ensures the file handle is closed.
  with open('/content/adc.json') as adc_file:
    data = json.load(adc_file)
  with tf.Session(tpu_grpc) as sess:
    tf.contrib.cloud.configure_gcs(sess, credentials=data)


In [0]:
from official.resnet import imagenet_input
from official.resnet import lars_util
from official.resnet import resnet_model
from official.resnet import resnet_main

from tensorflow.python.estimator import estimator
      
# Image counts of the standard ImageNet dataset's train and eval splits.
NUM_TRAIN_IMAGES = 1281167
NUM_EVAL_IMAGES = 50000

# Storage locations: change these to point at your own GCS bucket.
FLAGS.model_dir = 'gs://tpu-cmerk-2/imagenet/models/resnet/20181022_005'
FLAGS.data_dir = 'gs://tpu-cmerk-2/imagenet/train'

# Run configuration: checkpoint once per TPU loop and keep every checkpoint
# (keep_checkpoint_max=None).
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(FLAGS.tpu)
_tpu_config = tf.contrib.tpu.TPUConfig(
    iterations_per_loop=FLAGS.iterations_per_loop,
    num_shards=FLAGS.num_cores,
    per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2)
config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    model_dir=FLAGS.model_dir,
    save_checkpoints_steps=FLAGS.iterations_per_loop,
    keep_checkpoint_max=None,
    tpu_config=_tpu_config)

# One input pipeline per split (train first, then eval). The is_training flag
# selects the split-specific shuffling and preprocessing inside ImageNetInput.
imagenet_train, imagenet_eval = (
    imagenet_input.ImageNetInput(
        is_training=training,
        data_dir=FLAGS.data_dir,
        use_bfloat16=True,
        transpose_input=FLAGS.transpose_input)
    for training in (True, False))

# The estimator wraps resnet_main's model_fn for TPU (or fallback) execution.
resnet_classifier = tf.contrib.tpu.TPUEstimator(
    model_fn=resnet_main.resnet_model_fn,
    use_tpu=FLAGS.use_tpu,
    config=config,
    train_batch_size=FLAGS.train_batch_size,
    eval_batch_size=FLAGS.eval_batch_size)

In [0]:
# Global step of the latest checkpoint in model_dir, used only for the log
# line below. NOTE(review): this is a private Estimator helper and may change
# between TF releases.
current_step = estimator._load_global_step_from_checkpoint_dir(FLAGS.model_dir)

# Force true division. Without `from __future__ import division`, int / int
# truncates on a Python 2 runtime, and the %.2f epoch count loses its fraction.
batches_per_epoch = float(NUM_TRAIN_IMAGES) / FLAGS.train_batch_size
tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                ' step %d.', FLAGS.train_steps,
                FLAGS.train_steps / batches_per_epoch, current_step)

# max_steps is an absolute global-step target: training resumes from the
# latest checkpoint and stops once FLAGS.train_steps is reached.
resnet_classifier.train(input_fn=imagenet_train.input_fn, max_steps=FLAGS.train_steps)

In [0]:
import time

# Only whole batches are evaluated: `//` drops the trailing
# NUM_EVAL_IMAGES % FLAGS.eval_batch_size images.
eval_steps = NUM_EVAL_IMAGES // FLAGS.eval_batch_size

tf.logging.info('Starting to evaluate.')
eval_start = time.time()  # This time will include compilation time
eval_results = resnet_classifier.evaluate(
    input_fn=imagenet_eval.input_fn,
    steps=eval_steps)
eval_time = int(time.time() - eval_start)
tf.logging.info('Eval results: %s. Elapsed seconds: %d', eval_results, eval_time)

# Show the metrics dict as the cell's output so it is visible regardless of
# the tf.logging verbosity.
eval_results