In [0]:
import tensorflow as tf

In [0]:
# Fetch a fresh copy of the fenwicks library. Any previous checkout is removed
# first so that re-running this cell does not make `git clone` fail on an
# existing ./fenwicks directory.
# NOTE(review): this clones whatever the repository's default branch currently
# is, so results may drift as the library changes; consider pinning a commit.
if tf.gfile.Exists('./fenwicks'):
  tf.gfile.DeleteRecursively('./fenwicks')
!git clone https://github.com/fenwickslab/fenwicks.git

Cloning into 'fenwicks'...
remote: Enumerating objects: 109, done.[K
remote: Counting objects:   0% (1/109)   [Kremote: Counting objects:   1% (2/109)   [Kremote: Counting objects:   2% (3/109)   [Kremote: Counting objects:   3% (4/109)   [Kremote: Counting objects:   4% (5/109)   [Kremote: Counting objects:   5% (6/109)   [Kremote: Counting objects:   6% (7/109)   [Kremote: Counting objects:   7% (8/109)   [Kremote: Counting objects:   8% (9/109)   [Kremote: Counting objects:   9% (10/109)   [Kremote: Counting objects:  10% (11/109)   [Kremote: Counting objects:  11% (12/109)   [Kremote: Counting objects:  12% (14/109)   [Kremote: Counting objects:  13% (15/109)   [Kremote: Counting objects:  14% (16/109)   [Kremote: Counting objects:  15% (17/109)   [Kremote: Counting objects:  16% (18/109)   [Kremote: Counting objects:  17% (19/109)   [Kremote: Counting objects:  18% (20/109)   [Kremote: Counting objects:  19% (21/109)   [Kremote: Counting ob

In [0]:
%load_ext autoreload
%autoreload 2

from fenwicks.datasets import *
from fenwicks.io import *
from fenwicks.vision.models.keras_models import *
from fenwicks.utils.colab_tpu import *
from fenwicks.tpu_estimator import *

In [0]:
# Storage locations. BUCKET is a GCS bucket and the *_DIR values are paths
# inside it (joined with os.path.join further down). `#@param` exposes each
# value as a Colab form field, so keep each assignment on a single line.
# MODEL_DIR: passed to InceptionResNetV2_ckpt (pretrained checkpoint location).
# DATA_DIR:  where train.tfrec / valid.tfrec are written and read.
# WORK_DIR:  the estimator's model_dir (checkpoints, summaries).
BUCKET = 'gs://gs_colab' #@param {type:"string"}
MODEL_DIR = 'model/InceptionResNetV2' #@param {type:"string"}
DATA_DIR = 'data/dvc' #@param {type:"string"}
WORK_DIR = 'work/dvc' #@param {type:"string"}

# Training hyper-parameters. BATCH_SIZE is passed to get_tpu_estimator;
# IMG_SIZE is used as both height and width by the TFRecord image parsers.
BATCH_SIZE = 128 #@param {type: "integer"}
EPOCHS = 35 #@param {type:"integer"}
IMG_SIZE = 299 #@param {type: "integer"}

# Initial Adam learning rate; model_fn applies a cosine schedule to it.
LEARNING_RATE = 0.001 #@param {type:"number"}


In [0]:
# fenwicks helper for Google Cloud Storage access from Colab. It presumably
# authenticates the session so that the gs:// paths below can be read and
# written (confirm against the fenwicks implementation).
setup_gcs()


For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.



In [0]:
# Explicit import: otherwise `os` is only in scope as a side effect of the
# `from fenwicks... import *` lines, which breaks if those modules change.
import os

# Fully-qualified gs:// locations built from BUCKET and the *_DIR settings.
model_dir = os.path.join(BUCKET, MODEL_DIR)
data_dir = os.path.join(BUCKET, DATA_DIR)
work_dir = os.path.join(BUCKET, WORK_DIR)

In [0]:
# Obtain the pretrained InceptionResNetV2 checkpoint under model_dir (presumably
# downloading it if missing; confirm in fenwicks). ws_dir is later passed to
# get_tpu_estimator and ws_ckpt_fn to get_ws_vars, both for warm-starting.
ws_dir, ws_ckpt_fn = InceptionResNetV2_ckpt(model_dir)

In [0]:
# Variables to warm-start from the pretrained checkpoint. These are presumably
# the checkpoint's variable names (TODO confirm); passed to get_tpu_estimator below.
ws_vars = get_ws_vars(ws_ckpt_fn)

In [0]:
# Fetch and extract the dogs-vs-cats archive into ./dvc, then point at the
# per-split image folders inside its `dogscats/` directory.
data_dir_local = untar_data(URLs.DVC, './dvc')
train_dir_local, valid_dir_local = (
    os.path.join(data_dir_local, split_subdir)
    for split_subdir in ('dogscats/train', 'dogscats/valid'))

In [0]:
# One TFRecord file per split, stored in the GCS data directory.
train_fn, valid_fn = (os.path.join(data_dir, record_name)
                      for record_name in ('train.tfrec', 'valid.tfrec'))

In [0]:
# Serialize each local image directory into its TFRecord file on GCS.
# Results are assigned to `_` so the return value is not echoed as cell output.
# NOTE(review): these calls are unconditional; unless data_dir_tfrecord skips
# existing files, every run re-encodes the dataset. Consider guarding with
# tf.gfile.Exists(train_fn) / tf.gfile.Exists(valid_fn).
_ = data_dir_tfrecord(train_dir_local, train_fn)
_ = data_dir_tfrecord(valid_dir_local, valid_fn)

In [0]:
class TransferLearningNet(tf.keras.Model):
  """InceptionResNetV2 backbone followed by a small classification head.

  Head: Dense(c, no bias) -> BatchNormalization -> ReLU -> Dense(num_classes, no bias).
  The output is unnormalized logits; the loss is applied by the caller.

  Args:
    c: width of the hidden Dense layer in the head.
    num_classes: number of output logits. Defaults to 2 (the previously
      hard-coded value, used here for dogs vs. cats).
  """

  def __init__(self, c=256, num_classes=2):
    super().__init__()
    # Layers are created in the original order so that auto-generated variable
    # names (and therefore checkpoint / warm-start mapping) stay unchanged.
    self.ir2 = get_InceptionResNetV2()
    self.dense = tf.keras.layers.Dense(c, use_bias=False)
    self.bn = tf.keras.layers.BatchNormalization()
    self.linear = tf.keras.layers.Dense(num_classes, use_bias=False)

  def call(self, x):
    # Backbone output is assumed to be pooled features of shape (batch, d) -- TODO confirm.
    backbone_features = self.ir2(x)
    hidden = tf.nn.relu(self.bn(self.dense(backbone_features)))
    return self.linear(hidden)

In [0]:
# Total optimizer steps: whole batches per epoch over the (assumed) 23000
# training images, times EPOCHS. With the defaults this is 179 * 35 = 6265.
total_steps = EPOCHS * (23000 // BATCH_SIZE)

In [0]:
def model_fn(features, labels, mode, params):
  """TPUEstimator model_fn for the InceptionResNetV2 transfer-learning classifier.

  Args:
    features: batch of images from the input_fn (presumably shape
      (batch, IMG_SIZE, IMG_SIZE, 3) -- TODO confirm against the parser).
    labels: integer class ids for the batch; None in PREDICT mode.
    mode: a `tf.estimator.ModeKeys` value.
    params: dict supplied by TPUEstimator (unused here).

  Returns:
    A `tf.contrib.tpu.TPUEstimatorSpec`. The optimizer and train_op are built
    only in TRAIN mode and eval metrics only in EVAL mode. Evaluation graphs
    therefore carry no backward pass or optimizer slot variables. PREDICT
    mode returns class ids and softmax probabilities.
  """
  is_training = mode == tf.estimator.ModeKeys.TRAIN
  # Keras layers such as BatchNormalization read the global learning phase:
  # 1 selects training behaviour, 0 inference behaviour.
  tf.keras.backend.set_learning_phase(1 if is_training else 0)

  model = TransferLearningNet()
  logits = model(features)
  preds = tf.math.argmax(logits, axis=-1)

  if mode == tf.estimator.ModeKeys.PREDICT:
    # No labels are available, so no loss can be computed.
    predictions = {'class_ids': preds, 'probabilities': tf.nn.softmax(logits)}
    return tf.contrib.tpu.TPUEstimatorSpec(mode, predictions=predictions)

  loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)

  train_op = None
  eval_metrics = None
  if is_training:
    step = tf.train.get_or_create_global_step()
    # first_decay_steps == total_steps, so within a run of total_steps steps
    # from a fresh work_dir this is one cosine decay from LEARNING_RATE.
    lr = tf.train.cosine_decay_restarts(LEARNING_RATE, step, total_steps)
    opt = tf.contrib.tpu.CrossShardOptimizer(
        tf.train.AdamOptimizer(learning_rate=lr))
    # Run the model's pending updates (e.g. BatchNormalization moving
    # statistics) together with each optimizer step.
    with tf.control_dependencies(model.get_updates_for(features)):
      train_op = opt.minimize(loss, global_step=step)
  else:
    def metric_fn(preds, labels):
      """Computes streaming accuracy; TPUEstimator calls this on the host."""
      return {'accuracy': tf.metrics.accuracy(labels, preds)}
    eval_metrics = (metric_fn, [preds, labels])

  return tf.contrib.tpu.TPUEstimatorSpec(
      mode, loss=loss, train_op=train_op, eval_metrics=eval_metrics)

In [0]:
# The training parser requests augmentation; the evaluation parser uses the
# helper's default (augment not requested). Both emit IMG_SIZE x IMG_SIZE images.
parser_train = get_tfexample_image_parser(IMG_SIZE, IMG_SIZE, augment=True)
parser_eval = get_tfexample_image_parser(IMG_SIZE, IMG_SIZE)

def train_input_func(params):
  """TPUEstimator input_fn for the training TFRecords (batch size from params)."""
  batch_size = params['batch_size']
  return tfrecord_ds(train_fn, parser_train, batch_size, training=True)

def valid_input_func(params):
  """TPUEstimator input_fn for the validation TFRecords, built with training=False."""
  batch_size = params['batch_size']
  return tfrecord_ds(valid_fn, parser_eval, batch_size, training=False)

In [0]:
# 23000 / 2000 are the train / validation split sizes. The logged
# iterations_per_loop=179 equals 23000 // 128, and the reported accuracy 0.987
# is consistent with all 2000 validation images being scored, so the helper
# appears to size the eval batch so that steps=1 covers the whole validation
# set (TODO confirm in fenwicks.tpu_estimator).
estimator = get_tpu_estimator(
    23000, 2000, model_fn, work_dir, ws_dir, ws_vars, BATCH_SIZE)

estimator.train(input_fn=train_input_func, steps=total_steps)
estimator.evaluate(input_fn=valid_input_func, steps=1)

INFO:tensorflow:Using config: {'_model_dir': 'gs://gs_colab/work/dvc', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.7.107.114:8470"
    }
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7ff03d1ebb00>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.7.107.114:8470', '_evaluation_master': 'grpc://10.7.107.114:8470', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=179, num_shards=None, num_cores_per_replica=N

{'accuracy': 0.987, 'global_step': 6265, 'loss': 0.11687436}

In [0]:
# Clean up the estimator's working directory (its model_dir: checkpoints and
# summaries written by training above).
# NOTE(review): presumably deletes and recreates work_dir -- confirm in fenwicks
# before running, since trained checkpoints would be lost.
create_clean_dir(work_dir)