<a href="https://colab.research.google.com/github/christianmerkwirth/colabs/blob/master/TPU_Resnet_Training_on_Imagenet_III.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# colab.research.google.com specific
import sys

# Colab-only bootstrap: locate the TPU worker, authenticate the user and pass
# the stored credentials to the TPU via tf.contrib.cloud.configure_gcs.
# The statements below are order-dependent side effects; keep their sequence.
if 'google.colab' in sys.modules:
  import json
  import os
  import pprint
  import tensorflow as tf

  # Fail fast when COLAB_TPU_ADDR is missing; per the message below this means
  # the notebook is not attached to a TPU runtime.
  assert 'COLAB_TPU_ADDR' in os.environ, 'ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!'
  # gRPC target of the TPU worker, built from the address Colab exports.
  TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']
  print('TPU address is', TPU_ADDRESS)
  # Authenticate the Colab user.
  # NOTE(review): presumably this is what creates /content/adc.json, which is
  # read further down -- confirm before reordering anything here.
  from google.colab import auth
  auth.authenticate_user()
  # Open a session against the TPU worker; it is used both to list the
  # available devices and to push the credentials.
  with tf.Session(TPU_ADDRESS) as session:
    print('TPU devices:')
    pprint.pprint(session.list_devices())

    # Upload credentials to TPU.
    with open('/content/adc.json', 'r') as f:
      auth_info = json.load(f)
    tf.contrib.cloud.configure_gcs(session, credentials=auth_info)
    # Now credentials are set for all future sessions on this TPU.

    
# Clone the official TPU model repo (skipped when it is already present).
!test -d tpu || git clone https://github.com/tensorflow/tpu.git
!test -d official || mv tpu/models/official .
!test -d common || mv tpu/models/common .

In [0]:
# Import resnet_main to set up its flags and constants. Its main function should not be executed.

from official.resnet import resnet_main

In [0]:
# This cell contains the Colab-specific setup and hacks.

import json
import os
import sys
import time

from absl import flags
import absl.logging as _logging

import tensorflow as tf
# Show INFO-level TensorFlow log messages in the cell output.
tf.logging.set_verbosity(tf.logging.INFO)

# Global absl flag registry. NOTE(review): the flags read and written in this
# notebook are presumably the ones registered when resnet_main was imported.
FLAGS = flags.FLAGS

if 'google.colab' in sys.modules:
  # A notebook has no real command line: parse a fake argv that holds only a
  # program name so that flag values can be read and overridden afterwards.
  FLAGS(['resnet_trainer'])

  # When connected to the TPU runtime
  if 'COLAB_TPU_ADDR' in os.environ:
    tpu_grpc = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])

    FLAGS.tpu = tpu_grpc
    FLAGS.use_tpu = True

    # Read the credentials through a context manager so the file handle is
    # closed deterministically (a bare open() inside json.load() leaked it),
    # and do it before opening the session so a missing file fails early.
    with open('/content/adc.json') as adc_file:
      data = json.load(adc_file)

    # Upload credentials to the TPU
    with tf.Session(tpu_grpc) as sess:
      tf.contrib.cloud.configure_gcs(sess, credentials=data)


In [0]:
from official.resnet import imagenet_input
from official.resnet import lars_util
from official.resnet import resnet_model

from tensorflow.python.estimator import estimator
      
# Image counts of the train and eval splits of the standard ImageNet dataset.
NUM_TRAIN_IMAGES = 1281167
NUM_EVAL_IMAGES = 50000

# Adjust these two locations to your own setup.
FLAGS.data_dir = 'gs://tpu-cmerk-2/imagenet/train'
FLAGS.model_dir = 'gs://tpu-cmerk-2/imagenet/models/resnet/20181219_002'

# Training batches needed for one pass over the training images.
BATCHES_PER_EPOCH = NUM_TRAIN_IMAGES / FLAGS.train_batch_size

# Snapshot of every flag value as a plain dict so the model_fn can use TPU
# specific settings. Taken after the overrides above so it includes them.
params = FLAGS.flag_values_dict()

tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(FLAGS.tpu)

# TPU-side settings, built separately and then attached to the RunConfig.
tpu_config = tf.contrib.tpu.TPUConfig(
    iterations_per_loop=FLAGS.iterations_per_loop,
    num_shards=FLAGS.num_cores,
    per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2)

# A checkpoint is written every `iterations_per_loop` steps.
# NOTE(review): keep_checkpoint_max=None should retain every checkpoint --
# confirm against the RunConfig documentation.
config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    model_dir=FLAGS.model_dir,
    save_checkpoints_steps=FLAGS.iterations_per_loop,
    keep_checkpoint_max=None,
    tpu_config=tpu_config)

# The training and evaluation input pipelines share every setting except
# `is_training` (which changes shuffling and preprocessing).
_imagenet_input_kwargs = dict(
    data_dir=FLAGS.data_dir,
    use_bfloat16=True,
    transpose_input=FLAGS.transpose_input)
imagenet_train = imagenet_input.ImageNetInput(
    is_training=True, **_imagenet_input_kwargs)
imagenet_eval = imagenet_input.ImageNetInput(
    is_training=False, **_imagenet_input_kwargs)

In [0]:
def learning_rate_schedule(unused_current_epoch):
  """Builds the learning rate: a noisy cosine wave times a step-wise envelope.

  The current batch step is taken from tf.train.get_global_step(), so the
  learning rate is recomputed for each batch of the training.

  Args:
    unused_current_epoch: Ignored; the schedule depends on the global step only.

  Returns:
    A scalar Tensor storing the current learning rate.
  """
  step = tf.train.get_global_step()

  # Base rate scaled linearly with the batch size (reference batch size 256).
  peak_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)

  # Oscillating factor: constant 0.6 plus the non-negative part of a noisy
  # linear-cosine decay that starts at 0.4 and spans the whole training run.
  noisy_cosine = tf.train.noisy_linear_cosine_decay(
      0.4,
      step,
      FLAGS.train_steps,
      num_periods=90.0,
      initial_variance=0.6,
      alpha=0.5)
  wave = 0.6 + tf.math.maximum(0.00, noisy_cosine)

  # Envelope: short warm-up followed by step-wise decay. Boundaries are given
  # in epochs and converted to whole global steps.
  steps_per_epoch = int(BATCHES_PER_EPOCH)
  boundary_epochs = (0.1, 0.5, 5, 15, 30, 45, 55, 70, 80)
  envelope = tf.train.piecewise_constant(
      step,
      boundaries=[int(epochs * steps_per_epoch) for epochs in boundary_epochs],
      values=[0.05, 0.4, 1.0, 0.8, 0.5, 0.2, 0.1, 0.02, 0.01, 0.002])

  # The combined factor is floored at 0.001 before scaling.
  return peak_lr * tf.math.maximum(0.001, wave * envelope)

In [0]:
# Replace resnet_main's module-level schedule with the custom one defined in
# this notebook. NOTE(review): presumably resnet_model_fn looks the schedule up
# on the module when it builds the graph -- confirm.
resnet_main.learning_rate_schedule = learning_rate_schedule

resnet_classifier = tf.contrib.tpu.TPUEstimator(
    model_fn=resnet_main.resnet_model_fn,
    config=config,
    params=params,
    use_tpu=FLAGS.use_tpu,
    train_batch_size=FLAGS.train_batch_size,
    eval_batch_size=FLAGS.eval_batch_size)

# Global step stored in the newest checkpoint under model_dir, reported below
# next to the training target.
current_step = estimator._load_global_step_from_checkpoint_dir(FLAGS.model_dir)
total_epochs = FLAGS.train_steps / BATCHES_PER_EPOCH

print('Training for %d steps (%.2f epochs in total). Current step %d.' %
      (FLAGS.train_steps, total_epochs, current_step))

# max_steps is an absolute global-step target, not a number of extra steps.
resnet_classifier.train(
    input_fn=imagenet_train.input_fn,
    max_steps=FLAGS.train_steps)

In [0]:
import time

# Number of whole evaluation batches; the integer division drops any remainder.
eval_steps = NUM_EVAL_IMAGES // FLAGS.eval_batch_size

print('Starting to evaluate.')
# The measured wall-clock time also covers compilation.
eval_start = time.time()
eval_results = resnet_classifier.evaluate(
    steps=eval_steps,
    input_fn=imagenet_eval.input_fn)
eval_end = time.time()
eval_time = int(eval_end - eval_start)
print('Eval results: %s. Elapsed seconds: %d' % (eval_results, eval_time))