<a href="https://colab.research.google.com/github/cunning-stunts/Collab/blob/master/training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# How to train a ResNet50 on RxRx1 using TPUs 

Colaboratory makes it easy to train models using [Cloud TPUs](https://cloud.google.com/tpu/), and this notebook demonstrates how to use the code in [rxrx1-utils](https://github.com/recursionpharma/rxrx1-utils) to train a ResNet50 on the RxRx1 image set using a Colab TPU.

Be sure to select the TPU runtime (**Runtime > Change runtime type > Hardware accelerator: TPU**) before beginning!
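The training cell below reads the TPU address from the `COLAB_TPU_ADDR` environment variable, which Colab sets only in a TPU runtime, and stops with a bare `assert` if it is missing. The following is a minimal sketch of a friendlier check; `colab_tpu_grpc_address` is a hypothetical helper, not part of rxrx1-utils, and the address in the example is made up.

```python
import os


def colab_tpu_grpc_address(environ=os.environ):
    """Return the gRPC address of the Colab TPU, or raise a readable error.

    Hypothetical helper: it only inspects COLAB_TPU_ADDR, the same variable
    the training cell uses.
    """
    addr = environ.get('COLAB_TPU_ADDR')
    if not addr:
        raise RuntimeError(
            'COLAB_TPU_ADDR is not set; switch to a TPU runtime and rerun.')
    return 'grpc://{}'.format(addr)


# Example with a made-up address; inside Colab, call colab_tpu_grpc_address().
print(colab_tpu_grpc_address({'COLAB_TPU_ADDR': '10.0.0.2:8470'}))
```

Passing the environment as an argument keeps the check testable outside Colab.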

In [0]:
import json
import os
import sys
import tensorflow as tf

In [0]:
if 'google.colab' in sys.modules:
    !git clone https://github.com/recursionpharma/rxrx1-utils
    sys.path.append('/content/rxrx1-utils')

    from google.colab import auth
    auth.authenticate_user()
    
from rxrx.main import main

Cloning into 'rxrx1-utils'...
remote: Enumerating objects: 118, done.
remote: Total 118 (delta 0), reused 0 (delta 0), pack-reused 118
Receiving objects: 100% (118/118), 1.59 MiB | 9.13 MiB/s, done.
Resolving deltas: 100% (59/59), done.


W0728 19:27:31.382244 140184717584256 lazy_loader.py:50] 
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



## Train

Set `MODEL_DIR` to a Google Cloud Storage bucket that you can write to. The code will write your checkpoints to this directory.
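Cloud TPUs read input and write checkpoints through Google Cloud Storage rather than the Colab VM's local disk, so a local path such as `/content/model` will not work as `MODEL_DIR`. A minimal sketch of a sanity check on the path's shape follows; `split_gcs_path` is a hypothetical helper, and it does not verify that the bucket exists or that you have write access.

```python
def split_gcs_path(path):
    """Split 'gs://bucket/prefix' into (bucket, prefix); reject non-GCS paths.

    Hypothetical helper: checks only the form of the path, not permissions.
    """
    scheme = 'gs://'
    if not path.startswith(scheme):
        raise ValueError('MODEL_DIR must be a gs:// path, got {!r}'.format(path))
    # Everything before the first '/' is the bucket; the rest is the prefix.
    bucket, _, prefix = path[len(scheme):].partition('/')
    if not bucket:
        raise ValueError('missing bucket name in {!r}'.format(path))
    return bucket, prefix.rstrip('/')


print(split_gcs_path('gs://my-bucket/rxrx1/resnet50'))
```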

In [0]:
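# checkpoints are written here: replace with a gs:// bucket you can write to
MODEL_DIR = 'gs://cunning_stunts_ml'
# RxRx1 training data, stored as TFRecords in Google Cloud Storage
URL_BASE_PATH = 'gs://rxrx1-us-central1/tfrecords/random-42'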

# make sure we're in a TPU runtime
assert 'COLAB_TPU_ADDR' in os.environ

# set TPU-relevant args
tpu_grpc = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])
num_shards = 8  # colab uses Cloud TPU v2-8

# upload credentials (written to /content/adc.json by auth.authenticate_user()) to the TPU
with tf.Session(tpu_grpc) as sess:
    with open('/content/adc.json') as f:
        data = json.load(f)
    tf.contrib.cloud.configure_gcs(sess, credentials=data)

tf.logging.set_verbosity(tf.logging.INFO)

main(use_tpu=True,
     tpu=tpu_grpc,
     gcp_project=None,
     tpu_zone=None,
     url_base_path=URL_BASE_PATH,
     use_cache=False,
     model_dir=MODEL_DIR,
     train_epochs=1,
     train_batch_size=512,
     num_train_images=73030,
     epochs_per_loop=1,
     log_step_count_epochs=1,
     num_cores=num_shards,
     data_format='channels_last',
     transpose_input=True,
     tf_precision='bfloat16',
     n_classes=1108,
     momentum=0.9,
     weight_decay=1e-4,
     base_learning_rate=0.2,
     warmup_epochs=5)  # note: longer than train_epochs=1 above
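To see roughly what these arguments imply, the arithmetic can be sketched as below. This is a hypothetical helper, not code from rxrx1-utils, and it rests on assumptions: the global `train_batch_size` is split evenly across the `num_cores` TPU cores (as TPUEstimator does), step counts use integer division (`epochs * images // batch`), and the learning rate warms up linearly to `base_learning_rate`. `rxrx.main` may compute these quantities differently.

```python
def schedule_summary(num_train_images, train_batch_size, num_cores,
                     train_epochs, warmup_epochs, base_learning_rate):
    """Rough step and learning-rate arithmetic (hypothetical, see assumptions)."""
    if train_batch_size % num_cores:
        raise ValueError('train_batch_size must divide evenly across TPU cores')
    per_core_batch = train_batch_size // num_cores
    train_steps = train_epochs * num_train_images // train_batch_size
    warmup_steps = warmup_epochs * num_train_images // train_batch_size
    # Assumed linear warmup: the LR at the last training step is the base LR
    # scaled by the fraction of warmup completed (capped at 1).
    warmup_fraction = min(1.0, train_steps / warmup_steps) if warmup_steps else 1.0
    return {
        'per_core_batch': per_core_batch,
        'train_steps': train_steps,
        'warmup_steps': warmup_steps,
        'final_lr': base_learning_rate * warmup_fraction,
    }


summary = schedule_summary(num_train_images=73030, train_batch_size=512,
                           num_cores=8, train_epochs=1, warmup_epochs=5,
                           base_learning_rate=0.2)
print(summary)
```

Under these assumptions, each of the 8 cores processes 64 images per step, one epoch is 142 steps, and because `warmup_epochs` (5) exceeds `train_epochs` (1) the learning rate would reach only about a fifth of `base_learning_rate` by the end of the run.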