# Training a Simple Neural Network with PyTorch Data Loading
Reference doc: https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html

In [3]:
import time

import numpy as np
from torch import utils
from torchvision.datasets import MNIST

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax.scipy.special import logsumexp

### Hyperparameters

In [4]:
def random_layer_params(m, n, key, scale=1e-2):
    """Draw scaled Gaussian initial parameters for one dense layer.

    Args:
        m: number of inputs to the layer (second dimension of the weights).
        n: number of outputs of the layer (first dimension of the weights,
            and the length of the bias vector).
        key: JAX PRNG key; it is split so that weights and biases are drawn
            from separate subkeys.
        scale: factor applied to the standard-normal samples.

    Returns:
        A ``(weights, biases)`` tuple with shapes ``(n, m)`` and ``(n,)``.
    """
    weight_key, bias_key = random.split(key)
    weights = random.normal(weight_key, (n, m)) * scale
    biases = random.normal(bias_key, (n,)) * scale
    return weights, biases

def init_network_params(sizes, key):
    """Build initial parameters for every layer of a fully-connected network.

    Args:
        sizes: sequence of layer widths, input width first. Each consecutive
            pair ``(sizes[i], sizes[i + 1])`` defines one layer.
        key: JAX PRNG key from which the per-layer keys are derived.

    Returns:
        A list with ``len(sizes) - 1`` entries, one ``(weights, biases)``
        tuple per layer, as produced by ``random_layer_params``.
    """
    # One key is generated per entry in `sizes`; pairing them with the
    # (input, output) widths below consumes all but the last one.
    layer_keys = random.split(key, len(sizes))
    layers = []
    for n_in, n_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):
        layers.append(random_layer_params(n_in, n_out, layer_key))
    return layers

# Hyperparameters and the initial network parameters.
# Layer widths, input first. Consecutive pairs give each layer's
# (inputs, outputs), so this describes three dense layers:
# 784 -> 512 -> 512 -> 10.
layer_sizes = [784, 512, 512, 10]
# The next four values are not used in this part of the notebook; the notes
# below are inferred from their names -- TODO confirm against the training code.
# Presumably the gradient-descent learning rate.
step_size = 0.01
# Presumably the number of passes over the training data.
num_epochs = 8
# Presumably the number of examples per mini-batch.
batch_size = 128
# Presumably the number of target classes; equals layer_sizes[-1].
n_targets = 10
# Initial (weights, biases) for every layer, drawn from a fixed PRNG key (seed 0).
params = init_network_params(layer_sizes, random.PRNGKey(0))



### Auto-batching predictions
Let us first define our prediction function.
Note that we define it for a *single* image example and then use JAX's `vmap` function to handle mini-batches automatically (with no performance penalty).

In [5]:
def relu(x):
    """Rectified linear unit: elementwise maximum of ``0`` and ``x``."""
    rectified = jnp.maximum(0, x)
    return rectified

def predict(params, image):
    """Run the network on one example and return normalised log-scores.

    Args:
        params: sequence of ``(weights, biases)`` pairs, one per layer, in
            order from the first layer to the output layer.
        image: input to the first layer; it is combined with the first
            layer's weights via ``jnp.dot(weights, image)``.

    Returns:
        The output-layer logits minus their ``logsumexp``, i.e. the
        log-softmax of the logits.
    """
    final_w, final_b = params[-1]

    # Every layer except the last: affine map followed by ReLU.
    hidden = image
    for weights, bias in params[:-1]:
        hidden = relu(jnp.dot(weights, hidden) + bias)

    # The output layer has no ReLU; subtracting logsumexp normalises the logits.
    logits = jnp.dot(final_w, hidden) + final_b
    return logits - logsumexp(logits)

In [6]:
# Sanity check: `predict` works on a single example (no batch dimension).
k = random.PRNGKey(1)
# Gaussian noise of length 28 * 28 = 784 stands in for a flattened image;
# 784 matches the first entry of `layer_sizes`.
random_flattened_image = random.normal(k, (28 * 28,))
preds = predict(params, random_flattened_image)
# One value per output unit of the last layer.
print(preds.shape)

(10,)
