##### Copyright 2018 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Training a Simple Neural Network, with tensorflow/datasets Data Loading

_Forked from_ `neural_network_and_data_loading.ipynb`

_Dougal Maclaurin, Peter Hawkins, Matthew Johnson, Roy Frostig, Alex Wiltschko, Chris Leary_

![JAX](https://raw.githubusercontent.com/google/jax/master/images/jax_logo_250px.png)

Let's combine everything we showed in the [quickstart notebook](https://colab.research.google.com/github/google/jax/blob/master/notebooks/quickstart.ipynb) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use the `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).

Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model.

In [1]:
!pip install --upgrade https://storage.googleapis.com/jax-wheels/cuda92/jaxlib-0.1.6-cp36-none-linux_x86_64.whl
!pip install --upgrade jax

In [2]:
from __future__ import print_function, division, absolute_import
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random

### Hyperparameters
Let's get a few bookkeeping items out of the way.

In [3]:
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
param_scale = 0.1
step_size = 0.001
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))
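
To make the layout of `params` concrete, here is a minimal sketch that repeats the initializer logic above and lists the shape of each layer's weights and biases. It uses the `import jax` style of recent JAX releases rather than the notebook's imports, and the names `demo_sizes`, `demo_params`, and `param_shapes` are illustrative only.

```python
import jax

# Same logic as the helpers above, repeated so this sketch is self-contained.
def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = jax.random.split(key)
    return (scale * jax.random.normal(w_key, (n, m)),
            scale * jax.random.normal(b_key, (n,)))

def init_network_params(sizes, key):
    # One key per entry of `sizes`; zip() stops at the shortest input,
    # so only len(sizes) - 1 layers are created.
    keys = jax.random.split(key, len(sizes))
    return [random_layer_params(m, n, k)
            for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

demo_sizes = [784, 512, 512, 10]  # illustrative, mirrors layer_sizes above
demo_params = init_network_params(demo_sizes, jax.random.PRNGKey(0))

# Each layer is a (weights, biases) pair; weights have shape (n_out, n_in).
param_shapes = [(w.shape, b.shape) for w, b in demo_params]
for w_shape, b_shape in param_shapes:
    print(w_shape, b_shape)
```

Storing each weight matrix as `(n_out, n_in)` is what lets the prediction function later apply a layer to a single input vector with `np.dot(w, activations)`.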



### Auto-batching predictions

Let us first define our prediction function. Note that we're defining this for a _single_ image example. We're going to use JAX's `vmap` function to automatically handle mini-batches, with no performance penalty.

In [4]:
from jax.scipy.misc import logsumexp

def relu(x):
  return np.maximum(0, x)

def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = np.dot(w, activations) + b
    activations = relu(outputs)
  
  final_w, final_b = params[-1]
  logits = np.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)
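
The last line of `predict` returns log-probabilities: subtracting `logsumexp(logits)` from the logits is a log-softmax. The sketch below checks this on a hand-picked vector. It imports `logsumexp` from `jax.scipy.special`, which is where recent JAX releases expose it (the cell above uses the older `jax.scipy.misc` path); `demo_logits` is a made-up input.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

demo_logits = jnp.array([2.0, 1.0, 0.1])  # made-up logits for illustration

# Log-softmax: subtract the log of the sum of exponentials.
log_probs = demo_logits - logsumexp(demo_logits)

# Exponentiating should give a valid probability distribution.
probs = jnp.exp(log_probs)
total = float(jnp.sum(probs))
print(log_probs, total)
```

Working with log-probabilities directly keeps the later loss computation a simple sum of products, without taking the log of possibly tiny probabilities.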

Let's check that our prediction function only works on single images.

In [5]:
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)

(10,)


In [6]:
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!')

Invalid shapes!


In [7]:
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

(10, 10)
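
To see what `in_axes=(None, 0)` does in isolation, here is a small sketch on a toy function. `None` tells `vmap` to reuse the first argument unchanged for every example, while `0` maps over the leading axis of the second argument. The function `affine` and the arrays `w` and `xs` are hypothetical stand-ins for `predict`, `params`, and a batch of images.

```python
import jax
import jax.numpy as jnp

def affine(w, x):
    # Written for a single input vector x of shape (3,).
    return jnp.dot(w, x)

w = jnp.arange(6.0).reshape(2, 3)  # shared weights, shape (2, 3)
xs = jnp.ones((4, 3))              # a batch of 4 input vectors

# Broadcast w (None) and map over axis 0 of xs.
batched_affine = jax.vmap(affine, in_axes=(None, 0))
ys = batched_affine(w, xs)

# The same result written by hand as a matrix product.
reference = xs @ w.T
print(ys.shape)
```

The batched result matches the hand-written matrix product, which is the kind of rewriting `vmap` saves us from doing manually.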


At this point, we have all the ingredients we need to define our neural network and train it. We've built an auto-batched version of `predict`, which we should be able to use in a loss function. We should be able to use `grad` to take the derivative of the loss with respect to the neural network parameters. Last, we should be able to use `jit` to speed up everything.
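
As a quick illustration of how `grad` and `jit` compose, here is a minimal sketch on a toy least-squares loss, checked against the gradient derived by hand. The function `squared_error` and its inputs are illustrative and are not part of the network defined above.

```python
import jax
import jax.numpy as jnp

def squared_error(w, x, y):
    # Sum of squared residuals of a linear model x @ w.
    return jnp.sum((jnp.dot(x, w) - y) ** 2)

# grad differentiates with respect to the first argument by default;
# jit compiles the resulting gradient function.
grad_fn = jax.jit(jax.grad(squared_error))

w = jnp.array([1.0, -1.0])
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([0.0, 1.0])

g = grad_fn(w, x, y)

# Analytic gradient of the same loss: 2 * x^T (x w - y).
manual = 2.0 * x.T @ (x @ w - y)
print(g, manual)
```

The same pattern, `grad` of a scalar loss wrapped in `jit`, is what the `update` function below relies on, only with a list of `(w, b)` pairs as the first argument.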

### Utility and loss functions

In [8]:
def one_hot(x, k, dtype=np.float32):
  """Create a one-hot encoding of x of size k."""
  return np.array(x[:, None] == np.arange(k), dtype)
  
def accuracy(params, images, targets):
  target_class = np.argmax(targets, axis=1)
  predicted_class = np.argmax(batched_predict(params, images), axis=1)
  return np.mean(predicted_class == target_class)

def loss(params, images, targets):
  preds = batched_predict(params, images)
  return -np.sum(preds * targets)

@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]
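
The following sketch works through `one_hot` and the loss on two hand-made examples. Because the network outputs log-probabilities and the targets are one-hot, `-np.sum(preds * targets)` picks out the log-probability of the correct class for each example and sums the negatives. The arrays `labels` and `log_preds` are made-up values, and `jnp` is used as the alias for `jax.numpy`.

```python
import math
import jax.numpy as jnp

def one_hot(x, k, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k (same logic as above)."""
    # x[:, None] has shape (n, 1); comparing with arange(k) broadcasts to (n, k).
    return jnp.array(x[:, None] == jnp.arange(k), dtype)

labels = jnp.array([0, 2])    # made-up class labels
targets = one_hot(labels, 3)  # rows [1, 0, 0] and [0, 0, 1]

# Made-up log-probabilities for two examples over three classes.
log_preds = jnp.log(jnp.array([[0.5, 0.25, 0.25],
                               [0.1, 0.1, 0.8]]))

# Negative log-likelihood: only the entries at the true classes survive.
nll = -jnp.sum(log_preds * targets)
expected = -(math.log(0.5) + math.log(0.8))
print(float(nll), expected)
```

Note that the loss is a sum over the batch rather than a mean, so its scale (and the scale of its gradient) grows with the batch size.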

### Data Loading with `tensorflow/datasets`

JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll use the `tensorflow/datasets` data loader.

In [9]:
# Install tensorflow-datasets
# TODO(rsepassi): Switch to stable version on release
!pip install -q --upgrade tfds-nightly tf-nightly


In [10]:
import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.dataset_as_numpy
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.dataset_as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

# Full train set
train_images, train_labels = train_data['image'], train_data['label']
train_images = np.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)

# Full test set
test_images, test_labels = test_data['image'], test_data['label']
test_images = np.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)

In [11]:
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)

Train: (60000, 784) (60000, 10)
Test: (10000, 784) (10000, 10)
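
The `(60000, 784)` shape comes from flattening each image's height, width, and channel axes into a single vector. Here is a small sketch of that reshape on a placeholder array, so it runs without downloading MNIST. The array `fake_images` is a stand-in, assuming the same `(N, 28, 28, 1)` layout that `info.features['image'].shape` describes per image.

```python
import jax.numpy as jnp

# Placeholder batch: 5 "images" of height 28, width 28, 1 channel.
fake_images = jnp.zeros((5, 28, 28, 1), dtype=jnp.uint8)

h, w, c = fake_images.shape[1:]
num_pixels = h * w * c  # 28 * 28 * 1 = 784

# Keep the batch axis, collapse everything else into one axis.
flat = jnp.reshape(fake_images, (len(fake_images), num_pixels))
print(flat.shape)
```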


### Training Loop

In [12]:
import time

def get_train_batches():
  # as_supervised=True gives us the (image, label) as a tuple instead of a dict
  ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
  # You can build up an arbitrary tf.data input pipeline
  ds = ds.batch(batch_size).prefetch(1)
  # tfds.dataset_as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
  return tfds.dataset_as_numpy(ds)

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in get_train_batches():
    x = np.reshape(x, (len(x), num_pixels))
    y = one_hot(y, num_labels)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 5.25 sec
Training set accuracy 0.0987166687846
Test set accuracy 0.097999997437
Epoch 1 in 3.95 sec
Training set accuracy 0.0987166687846
Test set accuracy 0.097999997437
Epoch 2 in 3.96 sec
Training set accuracy 0.0987166687846
Test set accuracy 0.097999997437
Epoch 3 in 4.14 sec
Training set accuracy 0.0987166687846
Test set accuracy 0.097999997437
Epoch 4 in 3.97 sec
Training set accuracy 0.0987166687846
Test set accuracy 0.097999997437
Epoch 5 in 3.96 sec
Training set accuracy 0.0987166687846
Test set accuracy 0.097999997437
Epoch 6 in 4.06 sec
Training set accuracy 0.0987166687846
Test set accuracy 0.097999997437
Epoch 7 in 3.95 sec
Training set accuracy 0.0987166687846
Test set accuracy 0.097999997437
Epoch 8 in 4.08 sec
Training set accuracy 0.0987166687846
Test set accuracy 0.097999997437
Epoch 9 in 3.95 sec
Training set accuracy 0.0987166687846
Test set accuracy 0.097999997437
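
The logged accuracy stays at roughly 0.098 in every epoch, so the parameters do not appear to improve in this run. The notebook does not say why. One possible contributor, stated here as an assumption rather than a confirmed diagnosis, is that the MNIST images are `uint8` values in [0, 255] and reach the network unscaled. A common preprocessing step is to cast to `float32` and scale to [0, 1], as in the sketch below. The helper name `scale_pixels` is hypothetical and does not appear in the notebook.

```python
import jax.numpy as jnp

def scale_pixels(images):
    """Cast uint8 pixel values in [0, 255] to float32 values in [0, 1]."""
    return jnp.asarray(images, dtype=jnp.float32) / 255.0

# Tiny made-up example covering the extremes of the uint8 range.
raw = jnp.array([[0, 128, 255]], dtype=jnp.uint8)
scaled = scale_pixels(raw)
print(scaled.dtype, float(scaled.min()), float(scaled.max()))
```

If this were applied, it would go right after the `np.reshape` calls on the train, test, and per-batch images. Whether it changes the accuracy above has not been verified here.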


We've now used the core of the JAX API: `grad` for derivatives, `jit` for speedups, and `vmap` for auto-vectorization.
We used NumPy to specify all of our computation, borrowed the great data loaders from `tensorflow/datasets`, and ran the whole thing on the GPU.