[View in Colaboratory](https://colab.research.google.com/github/christianmerkwirth/colabs/blob/master/TPU_Resnet_Training_on_Imagenet_III.ipynb)

In [2]:
 # colab.research.google.com specific
import sys

if 'google.colab' in sys.modules:
  import json
  import os
  from google.colab import auth

  !git clone https://github.com/tensorflow/tpu.git
  !mv tpu/models/official .

Cloning into 'tpu'...
remote: Enumerating objects: 157, done.[K
remote: Counting objects: 100% (157/157), done.[K
remote: Compressing objects: 100% (117/117), done.[K
remote: Total 2218 (delta 55), reused 92 (delta 39), pack-reused 2061[K
Receiving objects: 100% (2218/2218), 1.40 MiB | 12.77 MiB/s, done.
Resolving deltas: 100% (1307/1307), done.


In [0]:
# Import resnet_main so that its flags and constants get defined. Its main() function is not executed.

from official.resnet import resnet_main

In [4]:
# This cell contains the Colab-specific setup and hacks.

import json
import os
import sys
import time

from absl import flags
import absl.logging as _logging

import tensorflow as tf
tf.logging.set_verbosity(tf.logging.INFO)

FLAGS = flags.FLAGS

if 'google.colab' in sys.modules:
  from google.colab import auth

  # Authenticate so this notebook can access the GCS bucket.
  auth.authenticate_user()

  # absl flags must be parsed before they can be read. Parse a fake argv that
  # contains only a program name, so every flag keeps its default value.
  FLAGS(['resnet_trainer'])

  # Only when connected to the TPU runtime.
  if 'COLAB_TPU_ADDR' in os.environ:
    tpu_grpc = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])

    FLAGS.tpu = tpu_grpc
    FLAGS.use_tpu = True

    # Upload credentials to the TPU so it can reach the GCS bucket itself.
    # The credentials file (presumably written by authenticate_user() above)
    # is read inside a `with` block so its handle is closed, not leaked.
    with open('/content/adc.json') as adc_file:
      data = json.load(adc_file)
    with tf.Session(tpu_grpc) as sess:
      tf.contrib.cloud.configure_gcs(sess, credentials=data)


W1023 09:11:59.163012 140486750078848 _default.py:280] No project ID could be determined. Consider running `gcloud config set project` or setting the GOOGLE_CLOUD_PROJECT environment variable


In [5]:
from official.resnet import imagenet_input
from official.resnet import lars_util
from official.resnet import resnet_model


from tensorflow.python.estimator import estimator
      
# Image counts of the standard ImageNet splits.
NUM_TRAIN_IMAGES = 1281167  # training split
NUM_EVAL_IMAGES = 50000     # evaluation split

# Storage locations -- change these to match your own GCS bucket.
FLAGS.model_dir = 'gs://tpu-cmerk-2/imagenet/models/resnet/20181022_007'
FLAGS.data_dir = 'gs://tpu-cmerk-2/imagenet/train'

# Learning-rate knots as (multiplier, start_epoch) tuples; the multiplier is
# interpolated linearly between consecutive knots.
LR_SCHEDULE = [
    (0.001, 0),     # warm up from 0.001 ...
    (1.0, 4),       # ... to the peak by epoch 4
    (0.1, 21),
    (0.01, 35),
    (0.001, 43),
    (0.001, 1000),  # zero-slope tail: stays at 0.001 from epoch 43 on
]
def piecewise_linear_lr_schedule(current_epoch, scaled_lr=1.0, schedule=None):
  """Piecewise linear learning rate.

  `schedule` is a list of `(multiplier, start_epoch)` knots. The multiplier is
  interpolated linearly between consecutive knots. Past the last knot the
  final segment is extrapolated, and before the first knot's start epoch the
  rate is 0. The resulting multiplier is scaled by `scaled_lr`.

  The new keyword arguments default to the values previously hard-coded, so
  the one-argument call `piecewise_linear_lr_schedule(current_epoch)` behaves
  exactly as before.

  Args:
    current_epoch: `Tensor` for current epoch.
    scaled_lr: Factor every interpolated multiplier is scaled by.
      NOTE(review): with the default 1.0 the multipliers act as absolute
      rates and do not scale with FLAGS.train_batch_size. Confirm that this
      is intended if the batch size changes.
    schedule: Sequence of `(multiplier, start_epoch)` tuples with strictly
      increasing `start_epoch`. Defaults to the module-level `LR_SCHEDULE`.

  Returns:
    A scaled `Tensor` for current learning rate.

  Raises:
    ValueError: if `schedule` has fewer than two knots or its start epochs
      are not strictly increasing.
  """
  if schedule is None:
    schedule = LR_SCHEDULE
  if len(schedule) < 2:
    raise ValueError('schedule needs at least two (multiplier, epoch) knots, '
                     'got %d' % len(schedule))
  start_epochs = [epoch for _, epoch in schedule]
  for earlier, later in zip(start_epochs, start_epochs[1:]):
    if later <= earlier:
      # Equal epochs would divide by zero below. Decreasing epochs would break
      # the "last reached segment wins" selection.
      raise ValueError('schedule start epochs must be strictly increasing, '
                       'got %r' % (start_epochs,))

  lr = 0.0  # rate used before the first knot's start epoch
  for (mult_lo, epoch_lo), (mult_hi, epoch_hi) in zip(schedule, schedule[1:]):
    # float() keeps this true division under Python 2 for integer multipliers.
    slope = float(mult_hi - mult_lo) / (epoch_hi - epoch_lo)
    segment_lr = scaled_lr * (slope * (current_epoch - epoch_lo) + mult_lo)
    # Segments are visited in order, so the last one whose start epoch has
    # been reached overrides all earlier selections.
    lr = tf.where(current_epoch >= epoch_lo, segment_lr, lr)
  return lr

# Swap resnet_main's module-level schedule for the piecewise-linear one
# defined above (presumably looked up by resnet_model_fn at call time).
resnet_main.learning_rate_schedule = piecewise_linear_lr_schedule

# The model_fn receives its TPU-specific settings as a plain dict of flags.
params = FLAGS.flag_values_dict()

tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(FLAGS.tpu)

# Checkpoint once per TPU training loop and never delete old checkpoints.
config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    model_dir=FLAGS.model_dir,
    keep_checkpoint_max=None,
    save_checkpoints_steps=FLAGS.iterations_per_loop,
    tpu_config=tf.contrib.tpu.TPUConfig(
        num_shards=FLAGS.num_cores,
        iterations_per_loop=FLAGS.iterations_per_loop,
        per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))

# Training and evaluation pipelines differ only in `is_training`, which
# switches shuffling and preprocessing. Built in order: train, then eval.
imagenet_train, imagenet_eval = (
    imagenet_input.ImageNetInput(
        is_training=is_training,
        data_dir=FLAGS.data_dir,
        use_bfloat16=True,
        transpose_input=FLAGS.transpose_input)
    for is_training in (True, False))

resnet_classifier = tf.contrib.tpu.TPUEstimator(
    model_fn=resnet_main.resnet_model_fn,
    use_tpu=FLAGS.use_tpu,
    config=config,
    params=params,
    train_batch_size=FLAGS.train_batch_size,
    eval_batch_size=FLAGS.eval_batch_size)

I1023 09:12:05.666531 140486750078848 tf_logging.py:115] Using config: {'_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      value: "10.30.223.98:8470"
    }
  }
}
, '_keep_checkpoint_max': None, '_task_type': 'worker', '_train_distribute': None, '_is_chief': True, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fc5674afad0>, '_model_dir': 'gs://tpu-cmerk-2/imagenet/models/resnet/20181022_007', '_protocol': None, '_save_checkpoints_steps': 1251, '_keep_checkpoint_every_n_hours': 10000, '_service': None, '_num_ps_replicas': 0, '_tpu_config': TPUConfig(iterations_per_loop=1251, num_shards=8, num_cores_per_replica=None, per_host_input_for_training=3, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None), '_tf_random_seed': None, '_save_summary_steps': 100, '_device_fn': None, '_cluster': <tensorflow.contrib.cluster_resolver.python.training.tpu_cluste

In [0]:
# Global step of the latest checkpoint in model_dir. This is a private
# Estimator helper, so it may change between TensorFlow versions.
current_step = estimator._load_global_step_from_checkpoint_dir(FLAGS.model_dir)
# float() makes both divisions below true division under Python 2 as well.
# With integer division the "%.2f epochs" figure would be silently truncated.
batches_per_epoch = NUM_TRAIN_IMAGES / float(FLAGS.train_batch_size)
print('Training for %d steps (%.2f epochs in total). Current step %d.' %
      (FLAGS.train_steps, FLAGS.train_steps / batches_per_epoch, current_step))

# max_steps is an absolute global-step target, so a run restored from an
# existing checkpoint only trains the remaining steps.
resnet_classifier.train(input_fn=imagenet_train.input_fn, max_steps=FLAGS.train_steps)

In [6]:
import time

# Floor division drops the partial last batch, so only
# eval_steps * FLAGS.eval_batch_size of the NUM_EVAL_IMAGES images are
# evaluated (presumably because TPU evaluation needs full, fixed-size batches).
eval_steps = NUM_EVAL_IMAGES // FLAGS.eval_batch_size

print('Starting to evaluate.')
# Wall-clock timing; the elapsed time includes graph compilation.
eval_start = time.time()
eval_results = resnet_classifier.evaluate(input_fn=imagenet_eval.input_fn, steps=eval_steps)
eval_time = int(time.time() - eval_start)
print('Eval results: %s. Elapsed seconds: %d' % (eval_results, eval_time))

I1023 09:12:20.592689 140486750078848 tf_logging.py:115] Querying Tensorflow master (grpc://10.30.223.98:8470) for TPU system metadata.
I1023 09:12:20.605627 140486750078848 tf_logging.py:115] Found TPU system:
I1023 09:12:20.607089 140486750078848 tf_logging.py:115] *** Num TPU Cores: 8
I1023 09:12:20.612056 140486750078848 tf_logging.py:115] *** Num TPU Workers: 1
I1023 09:12:20.614027 140486750078848 tf_logging.py:115] *** Num TPU Cores Per Worker: 8
I1023 09:12:20.615044 140486750078848 tf_logging.py:115] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 4081254722591740278)
I1023 09:12:20.618617 140486750078848 tf_logging.py:115] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 3870344600991421031)
I1023 09:12:20.619668 140486750078848 tf_logging.py:115] *** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_GPU:0, XLA_GPU, 17179869184, 1684632002989994667

Starting to evaluate.


I1023 09:12:22.125166 140486750078848 tf_logging.py:115] Calling model_fn.
W1023 09:12:22.257639 140486750078848 tf_logging.py:125] From official/resnet/imagenet_input.py:293: parallel_interleave (from tensorflow.contrib.data.python.ops.interleave_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.experimental.parallel_interleave(...)`.
W1023 09:12:22.302369 140486750078848 tf_logging.py:125] From official/resnet/imagenet_input.py:180: map_and_batch (from tensorflow.contrib.data.python.ops.batching) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.experimental.map_and_batch(...)`.
I1023 09:12:25.992311 140486750078848 tf_logging.py:115] Done calling model_fn.
I1023 09:12:26.016001 140486750078848 tf_logging.py:115] Starting evaluation at 2018-10-23-09:12:26
I1023 09:12:26.017507 140486750078848 tf_logging.py:115] TPU job name worker
I1023 09:12:26.663330 140486750078848 tf_logging.py:115] Gra

Eval results: {'loss': 1.3765411, 'top_1_accuracy': 0.7579956, 'global_step': 112590, 'top_5_accuracy': 0.9286499}. Elapsed seconds: 37
