In [0]:
import tensorflow as tf
import os

In [0]:
if tf.gfile.Exists('./fenwicks'):
  tf.gfile.DeleteRecursively('./fenwicks')
!git clone https://github.com/fenwickslab/fenwicks.git

import fenwicks as fw

Cloning into 'fenwicks'...
remote: Enumerating objects: 87, done.[K
remote: Counting objects:   1% (1/87)   [Kremote: Counting objects:   2% (2/87)   [Kremote: Counting objects:   3% (3/87)   [Kremote: Counting objects:   4% (4/87)   [Kremote: Counting objects:   5% (5/87)   [Kremote: Counting objects:   6% (6/87)   [Kremote: Counting objects:   8% (7/87)   [Kremote: Counting objects:   9% (8/87)   [Kremote: Counting objects:  10% (9/87)   [Kremote: Counting objects:  11% (10/87)   [Kremote: Counting objects:  12% (11/87)   [Kremote: Counting objects:  13% (12/87)   [Kremote: Counting objects:  14% (13/87)   [Kremote: Counting objects:  16% (14/87)   [Kremote: Counting objects:  17% (15/87)   [Kremote: Counting objects:  18% (16/87)   [Kremote: Counting objects:  19% (17/87)   [Kremote: Counting objects:  20% (18/87)   [Kremote: Counting objects:  21% (19/87)   [Kremote: Counting objects:  22% (20/87)   [Kremote: Counting objects:  24% (21/87)  

In [0]:
# Run configuration. The trailing `#@param {...}` annotations are Colab form
# markers, so each assignment must stay on a single `NAME = value #@param` line.

# Training batch size: used to derive steps_per_epoch and passed to the
# estimator as `trn_bs`.
BATCH_SIZE = 256 #@param {type:"integer"}
# Google Cloud Storage bucket URL handed to fw.io.get_gcs_dirs.
BUCKET = 'gs://gs_colab' #@param {type:"string"}
# Project name handed to fw.io.get_gcs_dirs together with BUCKET.
PROJECT = 'mnist' #@param {type:"string"}
# Number of training epochs; total training steps = steps_per_epoch * EPOCHS.
EPOCHS = 24 #@param {type:"integer"}



In [0]:
# Prepare Google Cloud Storage access for this Colab TPU session via fenwicks.
# NOTE(review): presumably this authenticates the Colab user and makes the
# credentials available to the TPU worker -- confirm against fw.colab_tpu.
fw.colab_tpu.setup_gcs()


For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.



In [0]:
# get_gcs_dirs returns a pair of GCS locations for this bucket/project; only the
# second one (the working directory, later used as the estimator's model
# directory root and cleaned at the end) is needed here, so the first is discarded.
# NOTE(review): the discarded first element is presumably a data directory -- confirm.
_, work_dir = fw.io.get_gcs_dirs(BUCKET, PROJECT)

In [0]:
def _as_scaled_images(images):
  """Reshape to (N, 28, 28, 1), divide by 255.0, then cast to float32.

  The division happens before the cast, so the scaling itself is carried out
  in NumPy's default float precision and only the result is narrowed.
  """
  return (images.reshape(-1, 28, 28, 1) / 255.0).astype('float32')


# Fetch MNIST as NumPy arrays (Keras downloads and caches the dataset on first use).
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

# Dataset sizes; used later for the shuffle buffer, steps per epoch and the estimator.
n_train = len(X_train)
n_test = len(X_test)

# Images: add a trailing channel axis and rescale. Labels: widen to int64.
X_train, X_test = _as_scaled_images(X_train), _as_scaled_images(X_test)
y_train, y_test = y_train.astype('int64'), y_test.astype('int64')

In [0]:
def train_input_func(params):
  """Input function for training, built from the in-memory training arrays.

  Args:
    params: dict supplied by the estimator; only params['batch_size'] is read.

  Returns:
    Whatever fw.io.numpy_ds returns for (X_train, y_train) with training=True
    and a shuffle buffer as large as the training set.
  """
  return fw.io.numpy_ds(
      X_train,
      y_train,
      batch_size=params['batch_size'],
      shuffle_buf_sz=n_train,
      training=True)


def eval_input_func(params):
  """Input function for evaluation, built from the in-memory test arrays.

  Args:
    params: dict supplied by the estimator; only params['batch_size'] is read.

  Returns:
    Whatever fw.io.numpy_ds returns for (X_test, y_test) with training=False.
  """
  return fw.io.numpy_ds(
      X_test,
      y_test,
      batch_size=params['batch_size'],
      training=False)

In [0]:
def build_nn(c=6, c_dense=200):
  """Assemble the MNIST classifier as a fenwicks Sequential model.

  Args:
    c: first argument of the first ConvBN block; the second and third blocks
      receive 2*c and 4*c (presumably the number of filters -- confirm against
      fw.layers.ConvBN).
    c_dense: first argument of the DenseBlk that follows the flatten
      (presumably its width -- confirm against fw.layers.DenseBlk).

  Returns:
    The populated fw.Sequential model. Its last layer is a 10-unit Dense layer
    without bias and without an explicit activation.
  """
  net = fw.Sequential()

  # Three ConvBN blocks; the second and third use a 6-wide kernel with stride 2.
  net.add(fw.layers.ConvBN(c, kernel_size=3))
  net.add(fw.layers.ConvBN(c * 2, kernel_size=6, strides=2))
  net.add(fw.layers.ConvBN(c * 4, kernel_size=6, strides=2))

  # Flatten, then the dense block (drop_rate=0.5) and the 10-way output layer.
  net.add(tf.layers.Flatten())
  net.add(fw.layers.DenseBlk(c_dense, drop_rate=0.5))
  net.add(tf.keras.layers.Dense(10, use_bias=False))

  return net

In [0]:
# Whole batches per pass over the training set (floor division drops any
# trailing partial batch from the count).
steps_per_epoch = n_train // BATCH_SIZE

# Optimizer factory from fenwicks.
# NOTE(review): presumably Adam with a learning rate decaying exponentially
# from init_lr towards base_lr over decay_steps -- confirm against fw.train.
opt_func = fw.train.adam_exp_decay(
    base_lr=0.0001,
    init_lr=0.01,
    decay_steps=2000,
)

# Classification model function combining the network builder and the optimizer.
model_func = fw.tpuest.get_clf_model_func(build_nn, opt_func)

In [0]:
# Build the TPU estimator; checkpoints and summaries are written under work_dir.
est = fw.tpuest.get_tpu_estimator(
    n_train, n_test, model_func, work_dir, trn_bs=BATCH_SIZE)

# Train for EPOCHS passes over the training set. Kept as the cell's last
# expression so the returned estimator is still displayed as the cell output.
est.train(input_fn=train_input_func, steps=EPOCHS * steps_per_epoch)

INFO:tensorflow:Using config: {'_model_dir': 'gs://gs_colab/work/mnist/2019-04-11-11:17:27', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.120.130.154:8470"
    }
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f891e8135c0>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.120.130.154:8470', '_evaluation_master': 'grpc://10.120.130.154:8470', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=234, num_shards=N

<tensorflow.contrib.tpu.python.tpu.tpu_estimator.TPUEstimator at 0x7f891e8133c8>

In [0]:
# Evaluate the trained model on the test input function for a single step.
# NOTE(review): steps=1 covers the whole test set only if the estimator's
# evaluation batch size equals n_test (n_test was passed to
# get_tpu_estimator) -- confirm, otherwise the metrics are from one batch only.
result = est.evaluate(input_fn=eval_input_func, steps=1)

INFO:tensorflow:Calling model_fn.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2019-04-11T11:19:24Z
INFO:tensorflow:TPU job name worker
INFO:tensorflow:Graph was finalized.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from gs://gs_colab/work/mnist/2019-04-11-11:17:27/model.ckpt-5616
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Init TPU system
INFO:tensorflow:Initialized TPU in 10 seconds
INFO:tensorflow:Starting infeed thread controller.
INFO:tensorflow:Starting outfeed thread controller.
INFO:tensorflow:Initialized dataset iterators in 0 seconds
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Evaluation [1/1]
INFO:tensorflow:Stop infeed thread controll

In [0]:
# Report the evaluation metrics: accuracy as a percentage and the loss, both
# with two decimals. The format spec is ':.2f' rather than ': .2f' -- the
# latter is the "space" sign flag and printed a stray blank after '='
# (e.g. "accuracy= 99.40%").
print(f'Test results: accuracy={result["accuracy"] * 100:.2f}%, loss={result["loss"]:.2f}.')

Test results: accuracy= 99.40%, loss= 0.04.


In [0]:
# Clean up the GCS working directory used by this run.
# NOTE(review): judging by its name, create_clean_dir presumably deletes the
# directory's contents (including the checkpoints written during training) and
# recreates it empty -- confirm before running if the trained model is still needed.
fw.io.create_clean_dir(work_dir)