In [0]:
import numpy as np
import tensorflow as tf
import os, json, datetime, math

In [0]:
BATCH_SIZE = 512 #@param {type:"integer"}
MOMENTUM = 0.9 #@param {type:"number"}
LEARNING_RATE = 0.4 #@param {type:"number"}
WEIGHT_DECAY = 0.0005 #@param {type:"number"}
EPOCHS = 24 #@param {type:"integer"}
WARMUP = 5 #@param {type:"integer"}
BUCKET = 'gs://gs_colab' #@param {type:"string"}


In [0]:
from google.colab import auth
auth.authenticate_user()

In [0]:
with tf.Session('grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])) as sess:
  with open('/content/adc.json', 'r') as f:
    auth_info = json.load(f)
  tf.contrib.cloud.configure_gcs(sess, credentials=auth_info)


For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.



In [0]:
TPU_ADDRESS = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])
print('Using TPU:', TPU_ADDRESS)

Using TPU: grpc://10.49.64.218:8470


In [0]:
def get_ds_from_tfrec(data_dir, training, batch_size, num_parallel_calls=12, prefetch=8, dtype=tf.float32):

  def _parser(serialized_example):
    features = tf.parse_single_example(
        serialized_example,
        features={
            "image": tf.FixedLenFeature([], tf.string),
            "label": tf.FixedLenFeature([], tf.int64),
        })
    image = tf.decode_raw(features["image"], tf.uint8)
    image = tf.reshape(image, [3, 32, 32])
    image = tf.transpose(image, [1, 2, 0])
    image = tf.cast(image, dtype)
    image = (image - [125.30691805, 122.95039414, 113.86538318]) / [62.99321928, 62.08870764, 66.70489964]
    
    label = features["label"]

    if training:
      image = tf.pad(image, [[4, 4], [4, 4], [0, 0]], mode='reflect')
      image = tf.random_crop(image, [32, 32, 3])
      image = tf.image.random_flip_left_right(image)

    return image, label

  split = 'train' if training else 'test'
  filename = os.path.join(data_dir, split + ".tfrecords")
  dataset = tf.data.TFRecordDataset(filename)
  dataset = dataset.repeat()
  dataset = dataset.map(_parser, num_parallel_calls=num_parallel_calls)

  if training:
    dataset = dataset.shuffle(50000, reshuffle_each_iteration=True)

  dataset = dataset.batch(batch_size, drop_remainder=True)
  dataset = dataset.prefetch(prefetch)

  return dataset

In [0]:
train_input_fn = lambda params: get_ds_from_tfrec(BUCKET + '/cifar10_tfrec', training=True, batch_size=params['batch_size'])
eval_input_fn = lambda params: get_ds_from_tfrec(BUCKET + '/cifar10_tfrec', training=False, batch_size=params['batch_size'])

In [0]:
def init_pytorch(shape, dtype=tf.float32, partition_info=None):
  fan = np.prod(shape[:-1])
  bound = 1 / math.sqrt(fan)
  return tf.random.uniform(shape, minval=-bound, maxval=bound, dtype=dtype)

class ConvBN(tf.keras.Model):
  def __init__(self, c_out):
    super().__init__()
    self.conv = tf.keras.layers.Conv2D(filters=c_out, kernel_size=3, padding="SAME", kernel_initializer=init_pytorch, use_bias=False)
    self.bn = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

  def call(self, inputs):
    return tf.nn.relu(self.bn(self.conv(inputs)))
  
class Blk(tf.keras.Model):
  def __init__(self, c_out, pool):
    super().__init__()
    self.conv_bn = ConvBN(c_out)
    self.pool = pool

  def call(self, inputs):
    return self.pool(self.conv_bn(inputs))
  
class ResBlk(tf.keras.Model):
  def __init__(self, c_out, pool):
    super().__init__()
    self.blk = Blk(c_out, pool)
    self.res1 = ConvBN(c_out)
    self.res2 = ConvBN(c_out)

  def call(self, inputs):
    h = self.blk(inputs)
    return h + self.res2(self.res1(h))
  
class DavidNet(tf.keras.Model):
  def __init__(self, c=64, weight=0.125):
    super().__init__()
    pool = tf.keras.layers.MaxPooling2D()
    self.init_conv_bn = ConvBN(c)
    self.blk1 = ResBlk(c*2, pool)
    self.blk2 = Blk(c*4, pool)
    self.blk3 = ResBlk(c*8, pool)
    self.pool = tf.keras.layers.GlobalMaxPool2D()
    self.linear = tf.keras.layers.Dense(10, kernel_initializer=init_pytorch, use_bias=False)
    self.weight = weight

  def call(self, x):
    h = self.pool(self.blk3(self.blk2(self.blk1(self.init_conv_bn(x)))))
    return self.linear(h) * self.weight
  
  def compute_grads(self, loss):
    var = self.trainable_variables
    grads = tf.gradients(loss, var)
    for g, v in zip(grads, self.trainable_variables):
      g += v * WEIGHT_DECAY * BATCH_SIZE
    return grads  

In [0]:
steps_per_epoch = 50000 // BATCH_SIZE

In [0]:
def model_fn(features, labels, mode, params):
  phase = 1 if mode == tf.estimator.ModeKeys.TRAIN else 0
  tf.keras.backend.set_learning_phase(phase)

  model = DavidNet()
  logits = model(features)
  
  step = tf.train.get_or_create_global_step()
  lr_schedule = lambda t: tf.cond(tf.less_equal(t, WARMUP), lambda: t * LEARNING_RATE / WARMUP, lambda: (EPOCHS-t) * LEARNING_RATE / (EPOCHS - WARMUP))
  lr_func = lambda: lr_schedule(tf.cast(step, tf.float32)/steps_per_epoch)/BATCH_SIZE

  opt = tf.train.MomentumOptimizer(lr_func, momentum=MOMENTUM, use_nesterov=True)
  opt = tf.contrib.tpu.CrossShardOptimizer(opt, reduction=tf.losses.Reduction.SUM)

  loss = tf.losses.sparse_softmax_cross_entropy(labels, logits, reduction=tf.losses.Reduction.SUM)

  grads = model.compute_grads(loss)
  with tf.control_dependencies(model.get_updates_for(features)):
    train_op = opt.apply_gradients(zip(grads, model.trainable_variables), global_step=step)

  classes = tf.math.argmax(logits, axis=-1)
  metric_fn = lambda classes, labels: {'accuracy': tf.metrics.accuracy(classes, labels)}
  tpu_metrics = (metric_fn, [classes, labels])
  
  return tf.contrib.tpu.TPUEstimatorSpec(
    mode=mode,
    loss=loss,
    train_op=train_op,
    eval_metrics = tpu_metrics
  )

In [0]:
now = datetime.datetime.now()
MODEL_DIR = BUCKET+"/cifar10jobs/job" + "-{}-{:02d}-{:02d}-{:02d}:{:02d}:{:02d}".format(now.year, now.month, now.day, now.hour, now.minute, now.second)

training_config = tf.contrib.tpu.RunConfig(
    cluster=tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS),
    model_dir=MODEL_DIR,
    tpu_config=tf.contrib.tpu.TPUConfig(
    iterations_per_loop=steps_per_epoch,
    per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))
   
estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,
    model_dir=MODEL_DIR,
    train_batch_size=BATCH_SIZE,
    eval_batch_size=10000,
    config=training_config)

estimator.train(train_input_fn, steps=steps_per_epoch*EPOCHS)

INFO:tensorflow:Using config: {'_model_dir': 'gs://gs_colab/cifar10jobs/job-2019-03-11-11:56:35', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      value: "10.49.64.218:8470"
    }
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fd2361a3b38>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.49.64.218:8470', '_evaluation_master': 'grpc://10.49.64.218:8470', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=97, num_shards=None, num_cores_

<tensorflow.contrib.tpu.python.tpu.tpu_estimator.TPUEstimator at 0x7fd23611d710>

In [0]:
estimator.evaluate(input_fn=eval_input_fn, steps=1)

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2019-03-11T11:57:48Z
INFO:tensorflow:TPU job name worker
INFO:tensorflow:Graph was finalized.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from gs://gs_colab/cifar10jobs/job-2019-03-11-11:56:35/model.ckpt-2328
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Init TPU system
INFO:tensorflow:Initialized TPU in 10 seconds
INFO:tensorflow:Starting infeed thread controller.
INFO:tensorflow:Starting outfeed thread controller.
INFO:tensorflow:Initialized dataset iterators in 0 seconds
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Evaluation [1/1]
INFO:tensorflow:Stop infeed thread controller
INFO:tensorflow:Shutting down InfeedController thread.
INFO:tensorflow

{'accuracy': 0.9302, 'global_step': 2328, 'loss': 262.35968}