# Training a Simple Neural Network with PyTorch Data Loading
Reference doc: https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html

In [3]:
import time

import numpy as np
from torch import utils
from torchvision.datasets import MNIST

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax.scipy.special import logsumexp

### Hyperparameters

In [None]:
def random_layer_params(m, n, key, scale=1e-2):
    """Sample scaled Gaussian weights and biases for a single dense layer.

    Args:
        m: Size of the second axis of the weight matrix (the layer's input
            width when called from init_network_params).
        n: Size of the first axis of the weight matrix and length of the
            bias vector (the layer's output width).
        key: JAX PRNG key; it is split once so that the weights and the
            biases are drawn from independent random streams.
        scale: Multiplier applied to the standard-normal samples.

    Returns:
        A ``(weights, biases)`` tuple with shapes ``(n, m)`` and ``(n,)``.
    """
    weight_key, bias_key = random.split(key)
    weight_shape, bias_shape = (n, m), (n,)
    weights = scale * random.normal(weight_key, weight_shape)
    biases = scale * random.normal(bias_key, bias_shape)
    return weights, biases

def init_network_params(sizes, key):
    """ Init all layers for a fully-connected nn with sizes. """
    keys = random.split(key, len(sizes))
    layer_args = zip(sizes[:-1], sizes[1:], keys)
    return [random_layer_params(m, n, k