Training a simple neural network, with PyTorch data loading: https://github.com/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

I am noticing that this is basically the same as the blog post, except that the blog post has more explanation and more examples. So we can refer back to that one later if we need more fundamental examples.

Let's build a simple MLP without Flax for now

In [2]:
def random_layer_params(m, n, key, scale=1e-2):
    '''
    Draw random weights and biases for one dense layer.

    m is the layer's input width and n its output width, so the weight
    matrix has shape (n, m) and the bias has shape (n,). Both are sampled
    from a standard normal and multiplied by scale.
    '''
    # weights and bias each get their own subkey so the two draws differ
    weight_key, bias_key = random.split(key)
    weights = random.normal(weight_key, shape=(n, m)) * scale
    biases = random.normal(bias_key, shape=(n,)) * scale
    return weights, biases

def init_network_params(sizes, key):
    '''
    Build the parameters of a fully-connected network.

    sizes lists the layer widths from input to output; every adjacent pair
    (m, n) becomes one layer initialised by random_layer_params. Returns a
    list with one random_layer_params result per layer.
    '''
    # one subkey per entry in sizes; there is one fewer layer than entries,
    # and zip stops at the shorter sequence, so the last subkey goes unused
    layer_keys = random.split(key, len(sizes))
    layer_shapes = zip(sizes[:-1], sizes[1:])

    layers = []
    for (m, n), k in zip(layer_shapes, layer_keys):
        print(f'm, n, k: {m}, {n}, {k}')
        layers.append(random_layer_params(m, n, k))
    return layers

# Architecture: 28*28 flattened pixels in, two hidden layers of 512, 10 outputs
layer_sizes = [28 * 28, 512, 512, 10]
n_targets = 10

# Training hyperparameters
step_size = 0.01
num_epochs = 8
batch_size = 128

# initialize the model's parameters from a fixed seed
params = init_network_params(sizes=layer_sizes, key=random.key(0))

m, n, k: 784, 512, Array((), dtype=key<fry>) overlaying:
[1797259609 2579123966]
m, n, k: 512, 512, Array((), dtype=key<fry>) overlaying:
[ 928981903 3453687069]
m, n, k: 512, 10, Array((), dtype=key<fry>) overlaying:
[4146024105 2718843009]


The 784 inputs get mapped to 512 dimensions, pass through a second 512-to-512 layer, and are finally projected to a 10-dimensional output, with one entry per MNIST digit class.

Now let's define our prediction function

In [3]:
from jax.scipy.special import logsumexp

def ReLU(x):
    ''' Rectified linear unit: element-wise maximum of 0 and x '''
    floor = 0
    return jnp.maximum(floor, x)

def predict(params, image):
    '''
    Run one example through the network and return its log-softmax scores.

    params is a sequence of (w, b) pairs. Every layer except the last is
    followed by ReLU; the last layer's raw outputs are the logits.
    '''
    hidden_layers, output_layer = params[:-1], params[-1]

    # forward pass through the hidden layers, starting from the raw input
    hidden = image
    for weights, bias in hidden_layers:
        hidden = ReLU(jnp.dot(weights, hidden) + bias)

    # the output layer is affine only (no ReLU)
    out_w, out_b = output_layer
    scores = jnp.dot(out_w, hidden) + out_b

    # subtracting logsumexp turns the logits into a log-softmax
    log_normaliser = logsumexp(scores)
    return scores - log_normaliser

In [4]:
# have the model score a single random flattened "image" of 28*28 values
random_flattened_image = random.normal(random.key(1), shape=(28 * 28,))
predict(params, random_flattened_image)

Array([-2.2909994, -2.2938476, -2.2901456, -2.3175743, -2.3124409,
       -2.3007157, -2.302447 , -2.3278596, -2.2889743, -2.301602 ],      dtype=float32)

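So the model doesn't really have a solid prediction: every log-probability is close to log(1/10) ≈ -2.30, i.e. the untrained network assigns roughly equal probability to all 10 classes.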

This was just for one image; let's use vmap so that we can apply predict to every image in a batch at once

In [5]:
# map predict over axis 0 of the images while sharing params (in_axes None)
batched_predict = vmap(predict, in_axes=(None, 0), out_axes=0)

# one batch of 10 random flattened 28*28 images
random_flattened_images = random.normal(random.key(1), shape=(10, 28 * 28))
batched_preds = batched_predict(params, random_flattened_images)

Now let's use grad to take the derivative of the loss in order to train the model

First define utility and loss functions

In [29]:
def one_hot(x, k, dtype=jnp.float32):
    '''
    One-hot encode a 1-D array of integer class labels.

    Returns an array of shape (len(x), k) in which row i holds a 1 in
    column x[i] and 0 everywhere else. A label outside 0..k-1 matches no
    column, so its row is all zeros.

        # Example:
        # one_hot(jnp.array([2, 0, 1]), 3)
        # -> [[0, 0, 1], [1, 0, 0], [0, 1, 0]]
    '''
    class_ids = jnp.arange(k)
    # broadcasting a column of labels against the row of class ids gives a
    # (len(x), k) boolean table that is True where label == class id
    matches = x[:, None] == class_ids
    return jnp.array(matches, dtype)


def accuracy(params, images, targets):
    ''' Fraction of images whose predicted class equals the target class '''
    # the largest score in each row of the batched predictions is the guess
    scores = batched_predict(params, images)
    guessed = jnp.argmax(scores, axis=1)

    # targets hold one row per image (presumably one-hot); argmax over the
    # row recovers the class index
    expected = jnp.argmax(targets, axis=1)

    # averaging the boolean hits gives the share of correct predictions
    hits = guessed == expected
    return jnp.mean(hits)


def loss(params, images, targets):
    '''
    Negated mean of the predictions weighted by the targets.

    batched_predict returns log-softmax outputs, one row per image. With
    one-hot targets, preds * targets keeps only the log-probability the
    model gave to the true class and zeroes everything else, so minimising
    the result pushes those log-probabilities up. This is the cross-entropy
    loss, except that the mean runs over every element (batch x classes)
    rather than over the batch only, which scales it by 1 / n_classes.
    '''
    preds = batched_predict(params, images)

    # Debugging aids. When this function is traced (update wraps it in grad
    # and jit) they run at trace time only and show tracers, not values.
    print(f'batched predictions shape: {preds.shape}')  # (batch, n_classes), e.g. (128, 10)
    print(f'targets shape: {targets.shape}')  # same shape as preds
    # print the first target row itself; this line used to print
    # targets.shape again under the 'targets[0]' label
    print(f'targets[0]: {targets[0]}')

    return -jnp.mean(preds * targets)


# one step of plain gradient descent on the model's parameters
@jit
def update(params, x, y):
    '''
    Return new parameters after one gradient-descent step on batch (x, y).

    The gradient of loss is taken with respect to params (grad's default
    first argument), so it pairs up with params layer by layer.
    NOTE: step_size is a module-level global read while jit traces this
    function; rebinding it later does not affect an already-compiled update.
    '''
    gradients = grad(loss)(params, x, y)

    stepped = []
    for (w, b), (dw, db) in zip(params, gradients):
        stepped.append((w - step_size * dw, b - step_size * db))
    return stepped

In [30]:
import numpy as np
from jax.tree_util import tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision.datasets import MNIST

def numpy_collate(batch):
    '''
    Combine a list of dataset samples into one batch of NumPy arrays.

    default_collate does the batching and produces PyTorch tensors;
    tree_map then applies np.asarray to every leaf of that structure.
    '''
    collated = default_collate(batch)
    return tree_map(np.asarray, collated)

def flatten_and_cast(pic):
    ''' Turn an image (e.g. a PIL image) into a flat 1-D float32 numpy array '''
    pixels = np.array(pic, dtype=jnp.float32)
    return pixels.ravel()

In [31]:
# MNIST from torchvision (downloaded to /tmp/mnist/ if missing); every image
# goes through flatten_and_cast when it is fetched
mnist_dataset = MNIST(root='/tmp/mnist/',
                      transform=flatten_and_cast,
                      download=True)

# PyTorch data loader whose batches are assembled by numpy_collate
training_generator = DataLoader(dataset=mnist_dataset,
                                collate_fn=numpy_collate,
                                batch_size=batch_size)

Now let's train the model

In [32]:
# Get the full train dataset (for checking accuracy while training).
# torchvision renamed train_data/test_data to .data and
# train_labels/test_labels to .targets; the old names are deprecated aliases
# that emit a warning. The train images are also converted to a float32 JAX
# array so they match the test images.
train_images = jnp.array(
    mnist_dataset.data.numpy().reshape(len(mnist_dataset.data), -1),
    dtype=jnp.float32)
train_labels = one_hot(mnist_dataset.targets.numpy(), n_targets)

# Get full test dataset
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(
    mnist_dataset_test.data.numpy().reshape(len(mnist_dataset_test.data), -1),
    dtype=jnp.float32)
test_labels = one_hot(mnist_dataset_test.targets.numpy(), n_targets)

In [33]:
import time

# Train for num_epochs passes over the data, reporting accuracy after each.
for epoch in range(num_epochs):
    # perf_counter is a monotonic clock meant for measuring intervals;
    # time.time() can jump if the system clock is adjusted
    start_time = time.perf_counter()

    for x, y in training_generator:
        # the loader yields integer labels; the loss expects one-hot rows
        y = one_hot(y, n_targets)
        params = update(params, x, y)

    # NOTE(review): JAX dispatches work asynchronously, so this may stop the
    # clock before the last updates have finished on the device - confirm
    epoch_time = time.perf_counter() - start_time

    train_acc = accuracy(params, train_images, train_labels)
    test_acc = accuracy(params, test_images, test_labels)

    print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
    print(f"Training set accuracy {train_acc}")
    print(f"Test set accuracy {test_acc}")

batched predictions shape: (128, 10)
targets shape: (128, 10)
targets[0]: (128, 10)


KeyboardInterrupt: 