# Training a Simple Neural Network, with tensorflow/datasets Data Loading

https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html


In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from typing import Tuple, List, Dict, Any

In [None]:
# What hardware is this running on?
# Print global/local device counts, then a short description of each device.
print(f'jax.device_count() {jax.device_count()}')
print(f'jax.local_device_count() {jax.local_device_count()}')
for i, device in enumerate(jax.devices()):
    print(f' --- found device: {i} ')
    print(f'device_kind {device.device_kind}')
    print(f'platform {device.platform}')
    # `Device.host_id` is a deprecated alias of `process_index`; use the
    # current name so this cell keeps working on newer JAX releases.
    print(f'process_index {device.process_index}')

In [None]:
# JAX has no global random state: randomness comes from explicit PRNG keys.
# Build the notebook's root key from a fixed seed (for reproducibility) and
# display the concrete type of the resulting key object.
RANDOM_KEY = jax.random.PRNGKey(seed=42)
type(RANDOM_KEY)

In [None]:
def random_layer_params(
        input_size: int,
        output_size: int,
        random_key: jax.Array,
        scale: float = 1e-2) -> Tuple[jax.Array, jax.Array]:
    """Create the weight matrix and bias vector of one dense layer.

    Both are sampled from a standard normal distribution and multiplied
    by ``scale``.

    Args:
        input_size: Number of input features of the layer.
        output_size: Number of output units of the layer.
        random_key: PRNG key; split into independent keys for ``w`` and ``b``.
        scale: Multiplier applied to the sampled values.

    Returns:
        ``(w, b)`` with ``w.shape == (output_size, input_size)`` and
        ``b.shape == (output_size,)``.
    """
    w_key, b_key = jax.random.split(random_key)
    weights = scale * jax.random.normal(w_key, (output_size, input_size))
    biases = scale * jax.random.normal(b_key, (output_size,))
    return weights, biases


def init_network_params(
        layer_sizes: List[int],
        random_key: jax.Array) -> List[Tuple[jax.Array, jax.Array]]:
    """Initialize the parameters of an MLP with the given layer widths.

    Args:
        layer_sizes: Width of every layer, input first (e.g. ``[784, 512, 10]``).
        random_key: PRNG key; one sub-key is derived per layer.

    Returns:
        One ``(w, b)`` pair per consecutive pair of ``layer_sizes``, i.e.
        ``len(layer_sizes) - 1`` pairs (empty if fewer than two sizes).
    """
    # Splitting into len(layer_sizes) keys leaves the last key unused; the
    # count is kept as-is so initial parameters stay reproducible.
    layer_keys = jax.random.split(random_key, len(layer_sizes))
    return [
        random_layer_params(in_size, out_size, key)
        for in_size, out_size, key in zip(layer_sizes[:-1], layer_sizes[1:], layer_keys)
    ]

In [None]:
# Hyperparameters
LAYERS = [28 * 28, 512, 512, 10]  # flattened 28x28 input, two hidden layers, 10 outputs
STEP_SIZE = 0.01  # gradient-descent step size used by `update`
NUM_EPOCHS = 10
BATCH_SIZE = 128
N_TARGETS = 10

In [None]:
# Initial MLP parameters: one (w, b) pair per consecutive pair of LAYERS.
params = init_network_params(layer_sizes=LAYERS, random_key=RANDOM_KEY)

In [None]:
def relu(x: jax.Array) -> jax.Array:
    """Rectified linear unit: element-wise ``max(0, x)``."""
    return jnp.maximum(0, x)


def predict(
    params: List[Tuple[jax.Array, jax.Array]],
    image: jax.Array,
) -> jax.Array:
    """Forward pass of the MLP described by ``params`` for a single input.

    Hidden layers apply ``relu`` after their affine map. The output layer is
    purely affine, so its logits may be negative. The logits are then
    normalized with log-softmax.

    Args:
        params: ``(w, b)`` pairs, one per layer, applied in order.
        image: A single flattened input vector (no batch axis; vectorize
            with ``jax.vmap`` for batches).

    Returns:
        Log-probabilities, one per output unit of the last layer.
    """
    activations = image
    for w, b in params[:-1]:
        activations = relu(jnp.dot(w, activations) + b)
    # No activation on the output layer: a relu here would clamp every
    # negative logit to 0 and make those classes indistinguishable.
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    # log-softmax: subtracting log(sum(exp(logits))) makes exp(result) sum to 1.
    return logits - jax.scipy.special.logsumexp(logits)


In [None]:
# Try the flattening on random data before using real MNIST images.
# Single image: sampled directly as a flat vector, and as 28x28 then flattened.
random_image_f = jax.random.normal(RANDOM_KEY, (28 * 28,))
print(f' image shape {random_image_f.shape}')
random_image = jax.random.normal(RANDOM_KEY, (28, 28))
random_flat_image = random_image.ravel()
print(f' flat image shape {random_flat_image.shape}')

# Batch of images: the same two routes, keeping the leading batch axis.
random_images_f = jax.random.normal(RANDOM_KEY, (BATCH_SIZE, 28 * 28))
print(f' images shape {random_images_f.shape}')
random_images = jax.random.normal(RANDOM_KEY, (BATCH_SIZE, 28, 28))
random_flat_images = random_images.reshape(BATCH_SIZE, -1)
print(f' flat images shape {random_flat_images.shape}')


In [None]:
# One unbatched image in -> one log-probability per output unit out.
jnp.shape(predict(params, random_flat_image))

In [None]:
# Vectorize `predict` over a batch of images: `params` is shared across the
# batch (axis None) while the image argument is mapped over its leading axis.
batched_predict = jax.vmap(predict, in_axes=(None, 0))
jnp.shape(batched_predict(params, random_flat_images))


In [None]:
def one_hot(
    x: jax.Array,
    k: int,
    dtype=jnp.float32,
) -> jax.Array:
    """One-hot encode integer class labels into ``k`` classes.

    Args:
        x: Integer labels. Any array-like is accepted (JAX/NumPy array, list,
            scalar); the encoding is added as a new trailing axis.
        k: Number of classes.
        dtype: dtype of the returned array.

    Returns:
        Array of shape ``x.shape + (k,)`` with 1 at each label's index and 0
        elsewhere. A label outside ``[0, k)`` yields an all-zero row.
    """
    labels = jnp.asarray(x)
    # Broadcast labels (..., 1) against class ids (k,) -> (..., k) booleans.
    return jnp.asarray(labels[..., None] == jnp.arange(k), dtype=dtype)


def accuracy(
    params: List[Tuple[jax.Array, jax.Array]],
    images: jax.Array,
    targets: jax.Array,
) -> jax.Array:
    """Fraction of images whose predicted class equals the target class.

    Args:
        params: MLP parameters, one ``(w, b)`` pair per layer.
        images: Batch of flattened images, batch axis first.
        targets: One-hot targets of shape ``(batch, num_classes)``.

    Returns:
        Scalar array in ``[0, 1]``.
    """
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
    return jnp.mean(predicted_class == target_class)


def loss(
    params: List[Tuple[jax.Array, jax.Array]],
    images: jax.Array,
    targets: jax.Array,
) -> jax.Array:
    """Cross-entropy between predicted log-probabilities and one-hot targets.

    ``batched_predict`` returns log-probabilities, so (assuming one-hot
    ``targets``) ``log_probs * targets`` keeps only the log-probability of
    each example's target class. The mean is taken over *both* the batch and
    class axes, so the value equals the usual mean cross-entropy divided by
    the number of classes. This only rescales the gradient, which is
    equivalent to a smaller step size.

    Args:
        params: MLP parameters, one ``(w, b)`` pair per layer.
        images: Batch of flattened images, batch axis first.
        targets: Targets with the same shape as the predictions.

    Returns:
        Scalar loss value.
    """
    log_probs = batched_predict(params, images)
    return -jnp.mean(log_probs * targets)


@jax.jit
def update(
    params: List[Tuple[jax.Array, jax.Array]],
    images: jax.Array,
    targets: jax.Array,
) -> List[Tuple[jax.Array, jax.Array]]:
    """Take one plain gradient-descent step on ``loss``.

    Every weight and bias moves by ``-STEP_SIZE`` times its gradient. Because
    the function is jitted, the global ``STEP_SIZE`` is captured when the
    function is traced, not re-read on every call.

    Args:
        params: Current MLP parameters, one ``(w, b)`` pair per layer.
        images: Batch of flattened images.
        targets: Targets for the batch, as expected by ``loss``.

    Returns:
        New parameters with the same structure as ``params``.
    """
    grads = jax.grad(loss)(params, images, targets)
    return [
        (w - STEP_SIZE * dw, b - STEP_SIZE * db)
        for (w, b), (dw, db) in zip(params, grads)
    ]


In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

In [None]:
DATA_DIR = '/tmp/tfds'

# Load every MNIST split at once (batch_size=-1) together with the dataset
# metadata, and convert the tensors to NumPy arrays.
mnist_data, info = tfds.load(name='mnist', batch_size=-1, data_dir=DATA_DIR, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']

# Number of classes, and number of values in one flattened image.
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

# Train split: flatten each image into a row of num_pixels values; one-hot labels.
train_images = jnp.reshape(train_data['image'], (train_data['image'].shape[0], num_pixels))
train_labels = one_hot(train_data['label'], num_labels)

# Test split: same preprocessing as the train split.
test_images = jnp.reshape(test_data['image'], (test_data['image'].shape[0], num_pixels))
test_labels = one_hot(test_data['label'], num_labels)


In [None]:
# Report (images, labels) shapes for both splits.
print('Train: {}, {}'.format(train_images.shape, train_labels.shape))
print('Test: {}, {}'.format(test_images.shape, test_labels.shape))

In [None]:
import time

def get_train_batches():
    """Return an iterable of ``(images, labels)`` NumPy batches of MNIST train.

    Batches hold ``BATCH_SIZE`` examples and are prefetched one batch ahead.
    No shuffling is applied here, so examples arrive in the split's stored
    order every epoch.
    """
    ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=DATA_DIR)
    ds = ds.batch(BATCH_SIZE).prefetch(1)
    return tfds.as_numpy(ds)

# training loop
for epoch in range(NUM_EPOCHS):
    start_time = time.time()
    for images, labels in get_train_batches():
        images = jnp.reshape(images, (len(images), num_pixels))
        labels = one_hot(labels, num_labels)
        params = update(params, images, labels)
    # JAX dispatches work asynchronously: wait until the last update has
    # actually finished so the measured duration covers the computation.
    params = jax.block_until_ready(params)
    epoch_time = time.time()

    # Re-calculating accuracy on the entire dataset each epoch; this is very inefficient.
    train_accuracy = accuracy(params, train_images, train_labels)
    test_accuracy = accuracy(params, test_images, test_labels)
    print(f'Epoch {epoch} started at {start_time}, total duration {epoch_time - start_time}')
    print(f'\t train accuracy {train_accuracy}')
    print(f'\t test accuracy {test_accuracy}')