[View in Colaboratory](https://colab.research.google.com/github/christianmerkwirth/colabs/blob/master/TPU_Resnet_Training_on_Imagenet_III.ipynb)

In [0]:
 # colab.research.google.com specific
import sys

if 'google.colab' in sys.modules:
  import json
  import os
  from google.colab import auth

  !git clone https://github.com/tensorflow/tpu.git
  !cp  tpu/models/official/resnet/resnet_main.py  .
  !mv tpu/models/official .

In [0]:
# Import resnet_main so that its flags and constants get defined. Its main function must not be executed.

import resnet_main

In [0]:
import os
import time

from official.resnet import imagenet_input
from official.resnet import lars_util
from official.resnet import resnet_model
from tensorflow.contrib import summary
from tensorflow.contrib.tpu.python.tpu import async_checkpoint
from tensorflow.contrib.training.python.training import evaluation
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.estimator import estimator


def main(FLAGS):
  """Trains, evaluates and optionally exports the ResNet TPUEstimator.

  Builds a `TPUEstimator` around `resnet_main.resnet_model_fn` and then acts
  according to `FLAGS.mode`:

  * 'eval': evaluates every new checkpoint that appears in `FLAGS.model_dir`
    until a checkpoint at or beyond `FLAGS.train_steps` has been evaluated.
  * 'train': trains up to `FLAGS.train_steps`.
  * 'train_and_eval' (any other value is rejected by an assert): alternates
    training for `FLAGS.steps_per_eval` steps with an evaluation pass.

  After 'train' / 'train_and_eval', a SavedModel is exported when
  `FLAGS.export_dir` is not None.

  Args:
    FLAGS: flags object exposing the attributes read below (tpu, use_tpu,
      model_dir, mode, train_steps, ...) and `flag_values_dict()`.

  Raises:
    AssertionError: if `FLAGS.precision` is neither 'bfloat16' nor 'float32',
      or if `FLAGS.mode` is not one of the supported modes.
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  # NOTE(review): `params` is computed but never handed to the TPUEstimator
  # below -- confirm whether resnet_model_fn expects it via `params=`.
  params = FLAGS.flag_values_dict()

  tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu if (FLAGS.tpu or FLAGS.use_tpu) else '')

  # With async checkpointing the saver hook added in 'train' mode writes the
  # checkpoints, so the estimator's own periodic saving is switched off.
  if FLAGS.use_async_checkpointing:
    save_checkpoints_steps = None
  else:
    save_checkpoints_steps = max(100, FLAGS.iterations_per_loop)
  config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      save_checkpoints_steps=save_checkpoints_steps,
      log_step_count_steps=FLAGS.log_step_count_steps,
      session_config=tf.ConfigProto(
          graph_options=tf.GraphOptions(
              rewrite_options=rewriter_config_pb2.RewriterConfig(
                  disable_meta_optimizer=True))),
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          num_shards=FLAGS.num_cores,
          per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))

  resnet_classifier = tf.contrib.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=resnet_main.resnet_model_fn,
      config=config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      export_to_tpu=False)
  assert FLAGS.precision == 'bfloat16' or FLAGS.precision == 'float32', (
      'Invalid value for --precision flag; must be bfloat16 or float32.')
  tf.logging.info('Precision: %s', FLAGS.precision)
  use_bfloat16 = FLAGS.precision == 'bfloat16'

  tf.logging.info('Using dataset: %s', FLAGS.data_dir)
  # One input pipeline for training and one for evaluation; caching is only
  # enabled for the training pipeline.
  imagenet_train, imagenet_eval = [
      imagenet_input.ImageNetInput(
          is_training=is_training,
          data_dir=FLAGS.data_dir,
          transpose_input=FLAGS.transpose_input,
          cache=FLAGS.use_cache and is_training,
          num_parallel_calls=FLAGS.num_parallel_calls,
          use_bfloat16=use_bfloat16) for is_training in [True, False]
  ]

  # Integer division: a trailing partial batch is not counted.
  steps_per_epoch = FLAGS.num_train_images // FLAGS.train_batch_size
  eval_steps = FLAGS.num_eval_images // FLAGS.eval_batch_size

  if FLAGS.mode == 'eval':
    # Run evaluation when there's a new checkpoint
    for ckpt in evaluation.checkpoints_iterator(
        FLAGS.model_dir, timeout=FLAGS.eval_timeout):
      tf.logging.info('Starting to evaluate.')
      try:
        start_timestamp = time.time()  # This time will include compilation time
        eval_results = resnet_classifier.evaluate(
            input_fn=imagenet_eval.input_fn,
            steps=eval_steps,
            checkpoint_path=ckpt)
        elapsed_time = int(time.time() - start_timestamp)
        tf.logging.info('Eval results: %s. Elapsed seconds: %d',
                        eval_results, elapsed_time)

        # Terminate eval job when final checkpoint is reached. The step is
        # parsed from the checkpoint file name ("<prefix>-<step>").
        current_step = int(os.path.basename(ckpt).split('-')[1])
        if current_step >= FLAGS.train_steps:
          tf.logging.info(
              'Evaluation finished after training step %d', current_step)
          break

      except tf.errors.NotFoundError:
        # Since the coordinator is on a different job than the TPU worker,
        # sometimes the TPU worker does not finish initializing until long after
        # the CPU job tells it to start evaluating. In this case, the checkpoint
        # file could have been deleted already.
        tf.logging.info('Checkpoint %s no longer exists, skipping checkpoint', ckpt)

  else:   # FLAGS.mode == 'train' or FLAGS.mode == 'train_and_eval'
    current_step = estimator._load_global_step_from_checkpoint_dir(FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long

    tf.logging.info('Training for %d steps (%.2f epochs in total). Current'
                    ' step %d.',
                    FLAGS.train_steps,
                    FLAGS.train_steps / steps_per_epoch,
                    current_step)

    start_timestamp = time.time()  # This time will include compilation time

    if FLAGS.mode == 'train':
      hooks = []
      if FLAGS.use_async_checkpointing:
        hooks.append(
            async_checkpoint.AsyncCheckpointSaverHook(
                checkpoint_dir=FLAGS.model_dir,
                save_steps=max(100, FLAGS.iterations_per_loop)))
      resnet_classifier.train(
          input_fn=imagenet_train.input_fn,
          max_steps=FLAGS.train_steps,
          hooks=hooks)

    else:
      assert FLAGS.mode == 'train_and_eval'
      while current_step < FLAGS.train_steps:
        # Train for up to steps_per_eval number of steps.
        # At the end of training, a checkpoint will be written to --model_dir.
        next_checkpoint = min(current_step + FLAGS.steps_per_eval,
                              FLAGS.train_steps)
        resnet_classifier.train(
            input_fn=imagenet_train.input_fn, max_steps=next_checkpoint)
        current_step = next_checkpoint

        tf.logging.info('Finished training up to step %d. Elapsed seconds %d.',
                        next_checkpoint, int(time.time() - start_timestamp))

        # Evaluate the model on the most recent model in --model_dir.
        # Since evaluation happens in batches of --eval_batch_size, some images
        # may be excluded modulo the batch size. As long as the batch size is
        # consistent, the evaluated images are also consistent.
        tf.logging.info('Starting to evaluate.')
        eval_results = resnet_classifier.evaluate(
            input_fn=imagenet_eval.input_fn,
            steps=eval_steps)
        tf.logging.info('Eval results at step %d: %s',
                        next_checkpoint, eval_results)

      elapsed_time = int(time.time() - start_timestamp)
      tf.logging.info('Finished training up to step %d. Elapsed seconds %d.',
                      FLAGS.train_steps, elapsed_time)

    if FLAGS.export_dir is not None:
      # The bare name `image_serving_input_fn` is not defined anywhere in this
      # notebook, so it is looked up on the imported model modules instead.
      # NOTE(review): which module provides it depends on the cloned
      # tensorflow/tpu revision -- confirm.
      serving_input_fn = getattr(resnet_main, 'image_serving_input_fn', None)
      if serving_input_fn is None:
        serving_input_fn = imagenet_input.image_serving_input_fn

      tf.logging.info('Starting to export model.')
      resnet_classifier.export_savedmodel(
          export_dir_base=FLAGS.export_dir,
          serving_input_receiver_fn=serving_input_fn)


In [0]:
# colab.research.google.com specific
import sys
from absl import flags
import absl.logging as _logging
import tensorflow as tf


FLAGS = flags.FLAGS

if 'google.colab' in sys.modules:
  import json
  import os
  from google.colab import auth

  # Authenticate to access GCS bucket
  auth.authenticate_user()

  # Parse an argv that holds only the program name, i.e. no command-line
  # overrides; individual flags are then set explicitly below.
  FLAGS(['resnet_trainer'])

  # Change parameters according to your setup.
  FLAGS.train_steps = 1200000
  FLAGS.model_dir = 'gs://tpu-cmerk-2/imagenet/models/resnet/20181022_003'
  FLAGS.data_dir = 'gs://tpu-cmerk-2/imagenet/train'

  # When connected to the TPU runtime
  if 'COLAB_TPU_ADDR' in os.environ:
    tpu_grpc = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])

    FLAGS.tpu = tpu_grpc
    FLAGS.use_tpu = True

    # Read the credentials with a context manager so the file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open('/content/adc.json') as credentials_file:
      data = json.load(credentials_file)

    # Upload credentials to the TPU
    with tf.Session(tpu_grpc) as sess:
      tf.contrib.cloud.configure_gcs(sess, credentials=data)


main(FLAGS)