# Getting Started With Trax: Resnet50

## Overview

Author: Henry Allen

In this tutorial, we use a Resnet50 to classify handwritten digits from the MNIST dataset. Colab provides a copy of this dataset in its sample_data directory, so no download is needed.

Objectives:
1. Create an iterator to stream training data
2. Classify handwritten digits

## Setup

Install and import the necessary packages. Notice that two versions of numpy are imported: the default numpy (as onp) and jax.numpy (as np). This tutorial uses the default numpy almost everywhere (jax.numpy appears only as np.argmax in the inference step), but the jax.numpy package is essential for creating Trax layers and custom loss functions.

In [1]:
! pip install -q -U trax
! pip install -q tensorflow

import os
import numpy as onp
import jax.numpy as np

import trax
import trax.layers as tl

import cv2

[?25l[K     |▉                               | 10kB 24.4MB/s eta 0:00:01[K     |█▊                              | 20kB 4.7MB/s eta 0:00:01[K     |██▋                             | 30kB 6.7MB/s eta 0:00:01[K     |███▌                            | 40kB 8.4MB/s eta 0:00:01[K     |████▍                           | 51kB 5.4MB/s eta 0:00:01[K     |█████▎                          | 61kB 6.3MB/s eta 0:00:01[K     |██████▏                         | 71kB 7.2MB/s eta 0:00:01[K     |███████                         | 81kB 8.0MB/s eta 0:00:01[K     |████████                        | 92kB 8.8MB/s eta 0:00:01[K     |████████▉                       | 102kB 7.0MB/s eta 0:00:01[K     |█████████▊                      | 112kB 7.0MB/s eta 0:00:01[K     |██████████▋                     | 122kB 7.0MB/s eta 0:00:01[K     |███████████▌                    | 133kB 7.0MB/s eta 0:00:01[K     |████████████▍                   | 143kB 7.0MB/s eta 0:00:01[K     |█████████████▎            


## Data Formatting

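Load the MNIST data from Colab's sample_data directory. These files should already be present. In each row, the first column is the digit label and the remaining 784 columns are the pixel values.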

In [2]:
data = onp.genfromtxt("sample_data/mnist_train_small.csv", delimiter=',')

label = data[:, 0]
print(label.shape)

train = data[:, 1:]
print(train.shape)

test = onp.genfromtxt("sample_data/mnist_test.csv", delimiter=',')
print(test.shape)

test_label = test[:, 0]
test_data = test[:, 1:]

(20000,)
(20000, 784)
(10000, 785)
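
The slicing above relies on each CSV row holding the label in column 0, followed by the 784 values of a flattened 28 x 28 image. As a minimal sketch of that split, the snippet below applies the same indexing to a small synthetic array. The names `fake_rows`, `fake_labels`, and `fake_pixels` and all values in them are made up for illustration and are not taken from the MNIST files.

```python
import numpy as onp

# Hypothetical stand-in for the loaded CSV: 3 rows, 1 label column + 784 pixel columns.
fake_rows = onp.zeros((3, 785))
fake_rows[:, 0] = [7, 2, 1]      # made-up labels in column 0
fake_rows[:, 1:] = 255.0         # made-up pixel values

fake_labels = fake_rows[:, 0]    # shape (3,)
fake_pixels = fake_rows[:, 1:]   # shape (3, 784)

# Each flat row of 784 values can be viewed as a 28 x 28 image.
first_image = fake_pixels[0].reshape(28, 28)

print(fake_labels.shape, fake_pixels.shape, first_image.shape)
```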


## Training Data Iterator and Model

Define the model (resnet50_model) and the iterator that streams training data (mnist_iterator2). The iterator yields (X, y) tuples, where X is a batch of 64 training images and y holds the corresponding labels. Each 28 x 28 image is resized to 256 x 256 pixels and given a trailing channel axis.

In [0]:
def resnet50_model(mode):
  return trax.models.Resnet50(d_hidden=64, n_output_classes=10, mode=mode)


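def mnist_iterator2():
  """
  Generator that streams batches of 64 (image, label) pairs from the training set.
  """
  count = 0
  while True:
    X = []
    y = []
    for _ in range(64):
      # Reshape the flat 784-pixel row to 28 x 28, upscale it to 256 x 256,
      # and add a trailing channel axis.
      X_i = cv2.resize(train[count].reshape(28, 28), (256, 256)).reshape((256, 256, 1))
      X.append(X_i)
      y.append(label[count])
      # Wrap around after the last of the 20000 training rows.
      count = (count + 1) % 20000
    X = onp.array(X)
    y = onp.array(y)

    yield (X, y)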
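
To make the batching and wrap-around logic easier to inspect without OpenCV or the MNIST files, here is a rough, self-contained sketch of the same idea. It is not the tutorial's iterator: `upscale_nearest` and `batch_stream` are hypothetical helpers, the data is synthetic, and nearest-neighbor index mapping stands in for `cv2.resize`, which interpolates bilinearly by default and therefore produces different pixel values.

```python
import numpy as onp

def upscale_nearest(image, size):
  """Nearest-neighbor upscaling of a square 2-D array to (size, size)."""
  src = image.shape[0]
  # Map every output row/column index back to a source index.
  idx = onp.arange(size) * src // size
  return image[idx][:, idx]

def batch_stream(pixels, labels, batch_size, size):
  """Yields (X, y) batches forever, wrapping around at the end of the data."""
  n = len(pixels)
  count = 0
  while True:
    X, y = [], []
    for _ in range(batch_size):
      img = upscale_nearest(pixels[count].reshape(28, 28), size)
      X.append(img.reshape(size, size, 1))  # add a channel axis
      y.append(labels[count])
      count = (count + 1) % n
    yield onp.array(X), onp.array(y)

# Made-up data: 5 flat "images" whose pixels all equal their label.
toy_labels = onp.arange(5)
toy_pixels = onp.repeat(toy_labels[:, None], 784, axis=1).astype(float)

stream = batch_stream(toy_pixels, toy_labels, batch_size=4, size=56)
X1, y1 = next(stream)
X2, y2 = next(stream)
print(X1.shape, y1, y2)
```

With five samples and a batch size of four, the second batch starts at the last sample and then wraps back to the first, which mirrors what the modulo on `count` does in mnist_iterator2.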

## Evaluation Stream

This data stream is used in the inference step. It is essentially the same as mnist_iterator2, but it yields the test images and test labels instead of the training data.

In [0]:
def mnist_eval_stream():
  count = 0
  while True:
    X = []
    y = []
    for i in range(64):
      X_i = cv2.resize(test_data[count].reshape((28,28)), (256,256)).reshape((256,256,1))
      X += [X_i]
      y += [test_label[count]]
      count = (count + 1) % 10000
    X = onp.array(X)
    y = onp.array(y)
    yield (X, y)

## Setup Trainer Inputs

This wraps the training iterator in a trax.supervised.Inputs object and pulls one batch from the training stream to confirm its shape.

In [0]:
resnet_inputs = trax.supervised.Inputs(lambda _: mnist_iterator2())

data_stream = resnet_inputs.train_stream(1)
inputs, labels = next(data_stream)
print(inputs.shape)
print(labels.shape)

(64, 256, 256, 1)
(64,)


## Train Classifier

We need to create an instance of the Trainer class and initialize it with the model, loss function, optimizer, learning rate schedule, and inputs.


In [0]:
output_dir = os.path.expanduser('~/train_mnist_dir/')
!rm -f ~/train_mnist_dir/model.pkl

trainer = trax.supervised.Trainer(
    model=resnet50_model,
    loss_fn=trax.layers.CrossEntropyLoss,
    optimizer=trax.optimizers.Adam,
    lr_schedule=trax.lr.MultifactorSchedule,
    inputs=resnet_inputs,
    output_dir=output_dir,
    has_weights=False)  # Make sure this is set to False for Resnet50

# Train
n_epochs  = 3
train_steps = 100
eval_steps = 20
for _ in range(n_epochs):
  trainer.train_epoch(train_steps, eval_steps)
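
For a sense of scale, the short calculation below estimates how much training data these settings cover. It assumes that each training step consumes exactly one batch of 64 images from mnist_iterator2 and ignores the batches drawn for evaluation; `batch_size`, `n_train_rows`, `samples_seen`, and `passes_over_data` are names introduced here for illustration.

```python
n_epochs = 3
train_steps = 100
batch_size = 64        # batch size yielded by mnist_iterator2
n_train_rows = 20000   # rows in mnist_train_small.csv

# Total images seen during training, under the one-batch-per-step assumption.
samples_seen = n_epochs * train_steps * batch_size
passes_over_data = samples_seen / n_train_rows

print(samples_seen)      # 19200
print(passes_over_data)  # 0.96
```

Under that assumption, the three "epochs" of 100 steps together cover slightly less than one full pass over the 20000 training rows.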

## Inference

Initialize the prediction model from the training model we just created and perform inference.

In [0]:
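predict_model = resnet50_model(mode='predict')
# Load the weights saved by the trainer in output_dir (~/train_mnist_dir/).
predict_model.init_from_file(os.path.join(output_dir, "model.pkl"))

eval_stream = mnist_eval_stream()
n_batches = 100
correct = 0
for _ in range(n_batches):
  X_test, y_test = next(eval_stream)
  y_pred = predict_model(X_test)
  # Count predictions whose highest-scoring class matches the label.
  for j in range(len(y_pred)):
    if np.argmax(y_pred[j]) == y_test[j]:
      correct += 1

# Fraction of the 100 * 64 evaluated test images that were classified correctly.
print(correct / (n_batches * 64))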
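
The per-sample loop above can also be written as a vectorized comparison. The following is a minimal sketch using the default numpy on made-up scores. `batch_accuracy`, `toy_scores`, and `toy_targets` are hypothetical names, and the sketch assumes the model output has one row of class scores per sample.

```python
import numpy as onp

def batch_accuracy(scores, labels):
  """Fraction of rows whose highest-scoring class equals the label."""
  predicted = onp.argmax(scores, axis=-1)
  return float(onp.mean(predicted == labels))

# Made-up scores for 4 samples and 3 classes.
toy_scores = onp.array([[0.1, 0.7, 0.2],
                        [0.9, 0.05, 0.05],
                        [0.2, 0.3, 0.5],
                        [0.6, 0.3, 0.1]])
toy_targets = onp.array([1, 0, 2, 1])  # the last label deliberately disagrees

print(batch_accuracy(toy_scores, toy_targets))  # 0.75
```

Averaging this value over batches of equal size gives the same overall fraction as counting correct predictions one at a time.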