In [1]:
import math
import os

import jax.numpy as jnp
from jax import grad, jit, random, tree_util, vmap
from jax.nn import relu
from jax.scipy.special import logsumexp
import tensorflow as tf
import tensorflow_datasets as tfds

In [2]:
# Load MNIST from TensorFlow Datasets.
# The cache directory can be overridden with the TFDS_DATA_DIR environment
# variable (e.g. TFDS_DATA_DIR=./data/tfds); otherwise it falls back to /tmp/tfds.
data_dir = os.environ.get("TFDS_DATA_DIR", "/tmp/tfds")
# batch_size=-1 returns each split as full in-memory tensors rather than a streamed dataset
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)

E0000 00:00:1764616896.260375   10649 cuda_executor.cc:1309] INTERNAL: CUDA Runtime error: Failed call to cudaGetRuntimeVersion: Error loading CUDA libraries. GPU will not be used.: Error loading CUDA libraries. GPU will not be used.
W0000 00:00:1764616896.277408   10649 gpu_device.cc:2342] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [3]:
def normalise(x, x_max=255.0):
    """Divide ``x`` by ``x_max``.

    With the default ``x_max=255.0``, 8-bit pixel intensities are mapped into [0, 1].
    """
    scaled = x / x_max
    return scaled

def convert_to_jax(data_np, data_type):
    """Turn a NumPy array from the dataset into a JAX array.

    Args:
        data_np: array-like data for one dataset field.
        data_type: ``"image"`` (cast to float32 and normalised by ``normalise``)
            or ``"label"`` (converted with ``jnp.array`` unchanged).

    Returns:
        The converted JAX array.

    Raises:
        ValueError: if ``data_type`` is neither ``"image"`` nor ``"label"``.
    """
    if data_type == "image":
        return normalise(jnp.array(data_np, dtype=jnp.float32))
    if data_type == "label":
        return jnp.array(data_np)
    raise ValueError("not image or label")

def flatten_image_for_mlp(data_jax):
    """Flatten every sample into a single feature vector for an MLP.

    Args:
        data_jax: array whose leading axis indexes samples, e.g. images of shape
            (n_batch, n_pixels_vertical, n_pixels_horizontal, n_channels).
            Any number of trailing axes is accepted.

    Returns:
        Array of shape (n_batch, product of the trailing dims). For single-channel
        (greyscale) images this is one pixel vector per sample.
    """
    n_batch = data_jax.shape[0]
    # Compute the feature count explicitly instead of using -1 so that an
    # empty batch (n_batch == 0) still reshapes instead of raising.
    n_features = math.prod(data_jax.shape[1:])
    return data_jax.reshape(n_batch, n_features)

def prepare_data(data_dict):
    """Convert a dict of TensorFlow tensors into a dict of JAX arrays.

    Each value is pulled out with ``.numpy()`` and passed to ``convert_to_jax``
    under its key. The ``"image"`` entry is additionally flattened to one vector
    per sample so it can feed an MLP.
    """
    converted = {}
    for field, tensor in data_dict.items():
        array = convert_to_jax(tensor.numpy(), field)
        if field == "image":
            array = flatten_image_for_mlp(array)
        converted[field] = array
    return converted

In [4]:
# Pick a TFDS split and convert its tensors into JAX arrays
# (images become normalised, flattened float32 vectors).
dataset_tf = "train"
all_data_tf = mnist_data[dataset_tf]
all_data_jax = prepare_data(data_dict=all_data_tf)



In [5]:
# Separate the inputs from the targets for the selected split.
images, labels = all_data_jax["image"], all_data_jax["label"]

In [6]:
# Sanity-check the array shapes after preprocessing.
for array_name, array in (("Images", images), ("Labels", labels)):
    print(f"{array_name} shape:", array.shape)

Images shape: (60000, 784)
Labels shape: (60000,)


In [7]:
def calculate_preactivations(x, W, b):
    """Affine transform of a single sample for one dense layer.

    Args:
        x: input row vector, shape (n_in,) (e.g. (784,) for the first layer).
        W: weight matrix, shape (n_in, n_out).
        b: bias vector, shape (n_out,).

    Returns:
        Pre-activation vector ``x @ W + b`` of shape (n_out,).
    """
    weighted_sum = x @ W
    return weighted_sum + b

In [8]:
def forward_pass(x, params_list):
    """Run one sample through every dense layer of the network.

    ReLU is applied after every layer except the last, so the returned
    vector holds raw logits.

    Args:
        x: single input vector.
        params_list: list of per-layer dicts with keys ``"W"`` and ``"b"``.
    """
    n_layers = len(params_list)
    for layer_idx, layer in enumerate(params_list, start=1):
        x = calculate_preactivations(x, layer["W"], layer["b"])
        # layer_idx is 1-based, so this skips the activation on the output layer
        if layer_idx < n_layers:
            x = relu(x)
    return x

In [9]:
def calculate_loss(predictions_logits, observed_label):
    """Cross-entropy (negative log-likelihood) of one sample's true class.

    Args:
        predictions_logits: unnormalised class scores for a single sample.
        observed_label: integer index of the true class.

    Returns:
        ``-log softmax(predictions_logits)[observed_label]`` as a scalar.
    """
    # logsumexp is the log of the softmax normaliser, computed stably
    log_normaliser = logsumexp(predictions_logits)
    observed_log_prob = predictions_logits[observed_label] - log_normaliser
    return -observed_log_prob

In [10]:
# Batched variants: map over the leading (sample) axis of the data arguments.
# The network parameters are shared (in_axes=None) across the whole batch.
forward_pass_batch = vmap(forward_pass, in_axes=(0, None))
# A single int in_axes applies to every positional argument (logits and labels).
calculate_loss_batch = vmap(calculate_loss, in_axes=0)

In [11]:
def initialise_params(dims, key, scale=0.1):
    """Sample a ``dims``-shaped array from a zero-mean normal with std ``scale``."""
    standard_normal = random.normal(key, dims)
    return standard_normal * scale

def initialise_layer(m, n, key):
    """Create random parameters for one dense layer mapping m inputs to n outputs.

    Returns:
        dict with ``"W"`` of shape (m, n) and ``"b"`` of shape (n,), each drawn
        from an independent sub-key of ``key``.
    """
    weight_key, bias_key = random.split(key)
    weights = initialise_params((m, n), weight_key)
    biases = initialise_params((n,), bias_key)
    return {"W": weights, "b": biases}

def initialise_network(sizes, key):
    """Build the parameter list for a fully-connected network.

    Args:
        sizes: layer widths, input first, e.g. ``[784, 128, 10]``.
        key: PRNG key; it is split into one sub-key per layer.

    Returns:
        List of ``len(sizes) - 1`` per-layer ``{"W", "b"}`` dicts.
    """
    layer_keys = random.split(key, len(sizes) - 1)
    layer_shapes = zip(sizes[:-1], sizes[1:])
    return [
        initialise_layer(n_in, n_out, layer_keys[idx])
        for idx, (n_in, n_out) in enumerate(layer_shapes)
    ]

In [12]:
def calculate_mean_loss_batch(params, images, labels):
    """Average cross-entropy loss of the network over a batch of samples."""
    per_sample_losses = calculate_loss_batch(forward_pass_batch(images, params), labels)
    return jnp.mean(per_sample_losses)


# Gradient w.r.t. the first argument (params); same pytree structure as params.
calculate_gradients_by_param = grad(calculate_mean_loss_batch)

def update_parameters(params, gradients_by_param, learning_rate):
    """Apply one plain gradient-descent step to every leaf of the params pytree.

    Each leaf becomes ``param - learning_rate * gradient``. The returned pytree
    has the same structure as ``params``.
    """
    def descend(param, gradient):
        return param - learning_rate * gradient

    return tree_util.tree_map(descend, params, gradients_by_param)

@jit
def take_training_step(params, images, labels, learning_rate=0.1):
    """One JIT-compiled gradient-descent step on the given batch; returns new params."""
    return update_parameters(
        params,
        calculate_gradients_by_param(params, images, labels),
        learning_rate,
    )

In [13]:
def run_training(images, labels, n_steps, layer_sizes, key, learning_rate=0.1):
    """Initialise a fresh MLP and train it with full-batch gradient descent.

    Every step uses the whole ``images``/``labels`` batch. The mean loss is
    printed after each step and once more at the end.

    Args:
        images: batch of flattened input vectors.
        labels: integer class labels, one per image.
        n_steps: number of gradient steps; 0 returns the initial parameters.
        layer_sizes: layer widths passed to ``initialise_network``.
        key: PRNG key used for parameter initialisation.
        learning_rate: step size for each update (default 0.1, the same as
            ``take_training_step``'s default).

    Returns:
        The trained parameter list.
    """
    params = initialise_network(layer_sizes, key)
    # Evaluate once up front so `loss` is always bound, even when n_steps == 0
    # (previously the final print raised UnboundLocalError in that case).
    loss = calculate_mean_loss_batch(params, images, labels)
    for step in range(n_steps):
        params = take_training_step(params, images, labels, learning_rate)
        loss = calculate_mean_loss_batch(params, images, labels)
        print(f"step {step} complete: loss = {loss}")
    print(f"training finished after {n_steps} steps: final loss = {loss}")
    return params

In [15]:
layer_sizes = [784, 128, 10]
key = random.key(0)

# NOTE: the "test" subset is just the first `trial_set_size` samples of the
# training split, so the evaluation below measures fit on training data.
trial_set_size = 20
test_images, test_labels = images[:trial_set_size], labels[:trial_set_size]

# Train the model on the trial subset
params = run_training(test_images, test_labels, n_steps=5, layer_sizes=layer_sizes, key=key)

# Evaluate the trained model on those same samples
logits = forward_pass_batch(test_images, params)
loss = calculate_loss_batch(logits, test_labels)
predictions = logits.argmax(axis=1)

step 0 complete: loss = 1.7302042245864868
step 1 complete: loss = 1.4029744863510132
step 2 complete: loss = 1.1516979932785034
step 3 complete: loss = 0.9586992263793945
step 4 complete: loss = 0.8078511357307434
training finished after 5: final loss = 0.8078511357307434


In [17]:
# Compare predicted classes against the true labels, with per-sample losses.
for caption, value in (
    ("True labels:    ", test_labels),
    ("Predictions:    ", predictions),
    ("Match:          ", predictions == test_labels),
    ("Loss            ", loss),
):
    print(caption, value)

True labels:     [4 1 0 7 8 1 2 7 1 6 6 4 7 7 3 3 7 9 9 1]
Predictions:     [7 1 0 7 8 1 2 7 1 6 6 4 7 7 3 3 7 9 7 1]
Match:           [False  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True False  True]
Loss             [1.9643625  0.53660583 1.072345   0.45093012 0.9924766  0.38990068
 1.3392543  0.36016273 0.3386774  0.85717463 1.1086264  1.5149956
 0.34358168 0.6781447  0.74361444 0.48400784 0.31850433 1.1063684
 1.2447559  0.31253457]


### Data details
- Input: 28 × 28 greyscale images (rows × columns × 1 colour channel), flattened to 784-element vectors
- Output: logits for 10 classes (digits 0–9)

### Tutorial
- https://docs.jax.dev/en/latest/notebooks/neural_network_with_tfds_data.html