In [1]:
import jax
from jax import numpy as jnp
from jax import grad, jit, vmap

# Create PRNGKey (PRNG State)
key = jax.random.PRNGKey(0)

# Model parameters and forward pass

In [2]:
## Create model parameters, excluding input layer, actually there are three Linear layers, each layer contains a set of (w, b), for a total of three sets of parameters

def random_layer_params(m, n, key, scale=1e-2):
    """
    A helper function to randomly initialize weights and biases
    for a dense neural network layer
    """
    w_key, b_key = jax.random.split(key)  # Update explicitly PRNG state
    return scale * jax.random.normal(w_key, (n, m)), scale * jax.random.normal(b_key, (n,))


def init_network_params(sizes, key):
    """Initialize all layers for a fully-connected neural network with sizes "sizes"
    """
    keys = jax.random.split(key, len(sizes))  # split can create multiple keys at the same time
    return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

In [3]:
layer_sizes = [784, 512, 512, 10]

key, init_key = jax.random.split(key)  # init_key used for initialization
params = init_network_params(layer_sizes, init_key)

print(len(params), len(params[0]), len(params[1]), len(params[2]))

3 2 2 2


In [4]:
print(params[0][0].shape, params[0][1].shape)

(512, 784) (512,)


In [5]:
# Creating a network is actually to write the forward pass
def relu(x):
    return jnp.maximum(0, x)

# Note that the x below is just an image，we don't need to implement batched_x
def model_forward(params, x):
    # per-example predictions
    for w, b in params[:-1]:
        x = jnp.dot(w, x) + b
        x = relu(x)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, x) + final_b
    return logits


# forward has been completed, test it
key, test_key = jax.random.split(key)
random_flattened_image = jax.random.normal(test_key, (784, ))
preds = model_forward(params, random_flattened_image)
print(preds.shape)

(10,)


In [6]:
# Create a random batch data, shape=(32, 784)
random_batched_flattened_images = jax.random.normal(jax.random.PRNGKey(1), (32, 784))
# model_forward(params, random_batched_flattened_images)  # error

In [7]:
# create a model_forward which support batch data? just to use vmap, life is so easy
batched_forward = vmap(model_forward, in_axes=(None, 0), out_axes=0)

batched_preds = batched_forward(params, random_batched_flattened_images)
print(batched_preds.shape)

(32, 10)


# Data loading

In [8]:
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import Sampler, SequentialSampler
from torchvision.datasets import MNIST


class FlattenAndCast(object):
    def __call__(self, pic):
        return np.ravel(np.array(pic, dtype=jnp.float32))

# DataLoader returns numpy array，not torch Tensor
def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

class JAXRandomSampler(Sampler):
    def __init__(self, data_source, rng_key):
        self.data_source = data_source
        self.rng_key = rng_key
        
    def __len__(self):
        return len(self.data_source)
    
    def __iter__(self):
        self.rng_key, current_rng = jax.random.split(self.rng_key)
        return iter(jax.random.permutation(current_rng, jnp.arange(len(self))).tolist())

In [9]:
class NumpyLoader(DataLoader):
    def __init__(self, dataset, rng_key=None, batch_size=1,
                 shuffle=False, **kwargs):
        if shuffle:
            sampler = JAXRandomSampler(dataset, rng_key)
        else:
            sampler = SequentialSampler(dataset)
        
        super().__init__(dataset, batch_size, sampler=sampler, **kwargs)

In [10]:
# With the help of torchvision and NumpyLoader
mnist_dataset_train = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
key, loader_key = jax.random.split(key)
train_loader = NumpyLoader(mnist_dataset_train, loader_key, batch_size=32, shuffle=True,
                           num_workers=0, collate_fn=numpy_collate, drop_last=True)

mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False, transform=FlattenAndCast())
eval_loader = NumpyLoader(mnist_dataset_test, batch_size=128, shuffle=False, num_workers=0,
                          collate_fn=numpy_collate, drop_last=False)

# Training process

In [11]:
from jax.scipy.special import logsumexp


def loss(params, images, targets):
    logits = batched_forward(params, images)
    preds = logits - logsumexp(logits)
    return -jnp.mean(preds * targets)


@jit
def sgd_update(params, x, y, lr):
    grads = grad(loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

In [12]:
def one_hot(x, k=10, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k."""
    return jnp.array(x[:, None] == jnp.arange(k), dtype)

def accuracy(params, loader):
    total_acc = 0
    total_num = 0
    for x, y in loader:
        predicted_class = jnp.argmax(batched_forward(params, x), axis=1)
        total_num += len(x)
        total_acc += jnp.sum(predicted_class == y)
    return total_acc / total_num


lr = 0.01
n_classes = 10
for epoch in range(5):
    for idx, (x, y) in enumerate(train_loader):
        y = one_hot(y, n_classes)
        params = sgd_update(params, x, y, lr)
        lr = lr * 0.999 if lr > 1e-3 else 1e-3  # very simple lr scheduler
        if idx % 100 == 0:  # evaluation
            train_acc = accuracy(params, train_loader)
            eval_acc = accuracy(params, eval_loader)
            print("Epoch {} - batch_idx {}, Training set acc {}, eval set accuracy {}".format(
                  epoch, idx, train_acc, eval_acc))

Epoch 0 - batch_idx 0, Training set acc 0.09814999997615814, eval set accuracy 0.09950000047683716
Epoch 0 - batch_idx 100, Training set acc 0.8302666544914246, eval set accuracy 0.8362999558448792
Epoch 0 - batch_idx 200, Training set acc 0.8892666697502136, eval set accuracy 0.8940999507904053
Epoch 0 - batch_idx 300, Training set acc 0.8997166752815247, eval set accuracy 0.9006999731063843
Epoch 0 - batch_idx 400, Training set acc 0.9085167050361633, eval set accuracy 0.9128999710083008
Epoch 0 - batch_idx 500, Training set acc 0.9076499938964844, eval set accuracy 0.911300003528595
Epoch 0 - batch_idx 600, Training set acc 0.9230999946594238, eval set accuracy 0.9253000020980835
Epoch 0 - batch_idx 700, Training set acc 0.9269000291824341, eval set accuracy 0.9298999905586243
Epoch 0 - batch_idx 800, Training set acc 0.9295666813850403, eval set accuracy 0.9334999918937683
Epoch 0 - batch_idx 900, Training set acc 0.9290666580200195, eval set accuracy 0.9296999573707581
Epoch 0 - b


KeyboardInterrupt

