In [None]:
import jax.numpy as jnp
from jax import grad, jit, random, tree_util
# relu is exported by jax.nn, not by the top-level jax package.
# jax.nn has no sparse_softmax_cross_entropy_with_logits (importing it raises
# ImportError); a sparse cross-entropy can be built from logsumexp instead.
from jax.nn import relu
from jax.scipy.special import logsumexp
import tensorflow as tf
import tensorflow_datasets as tfds

# tf.config.set_visible_devices([], device_type='GPU')

In [None]:
# Load MNIST from TensorFlow Datasets.
# data_dir is the directory handed to tfds.load for the dataset files; the
# trailing comment keeps a project-local alternative path.
data_dir = '/tmp/tfds' # data_dir = './data/tfds'
# with_info=True makes tfds.load return a second value (`info`) next to the data.
# NOTE(review): batch_size=-1 presumably returns each split as one full batch
# (the train split prints as 60000 examples further down) - confirm in tfds docs.
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)

E0000 00:00:1764349571.743580    7916 cuda_executor.cc:1309] INTERNAL: CUDA Runtime error: Failed call to cudaGetRuntimeVersion: Error loading CUDA libraries. GPU will not be used.: Error loading CUDA libraries. GPU will not be used.
W0000 00:00:1764349571.751990    7916 gpu_device.cc:2342] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [None]:
def normalise(x, x_max=255.0):
    """Scale ``x`` by dividing it by ``x_max``.

    Args:
        x: Value or array to scale.
        x_max: Divisor; defaults to 255.0 (presumably the largest 8-bit pixel
            intensity, so image data ends up in [0, 1]).

    Returns:
        ``x / x_max``.
    """
    scaled = x / x_max
    return scaled

def convert_to_jax(data_np, data_type):
    """Convert one array of the dataset to a JAX array.

    Args:
        data_np: Array-like holding the raw data.
        data_type: Either "image" or "label".

    Returns:
        For "image": a float32 JAX array passed through ``normalise``.
        For "label": a JAX array built from ``data_np`` as-is.

    Raises:
        ValueError: If ``data_type`` is neither "image" nor "label".
    """
    if data_type == "image":
        # Cast to float32 first so the normalising division is done in floats.
        return normalise(jnp.array(data_np, dtype=jnp.float32))
    if data_type == "label":
        return jnp.array(data_np)
    raise ValueError("not image or label")

def flatten_image_for_mlp(data_jax):
    """Flatten a 4-D image batch into one row of features per example.

    Args:
        data_jax: Array of shape (batch, vertical pixels, horizontal pixels,
            channels).

    Returns:
        Array of shape (batch, vertical * horizontal * channels).

    Raises:
        ValueError: If ``data_jax`` does not have exactly four dimensions
            (raised by the shape unpacking).
    """
    # Unpacking exactly four dimensions rejects inputs that are not 4-D;
    # only the batch size is needed for the reshape itself.
    n_batch, _n_rows, _n_cols, _n_channels = data_jax.shape
    return data_jax.reshape((n_batch, -1))

def prepare_data(data_dict):
    """Convert every entry of a dataset dict to a JAX array ready for the MLP.

    Args:
        data_dict: Mapping from data type ("image" or "label") to an object
            exposing ``.numpy()``.

    Returns:
        Dict with the same keys; values are JAX arrays produced by
        ``convert_to_jax``, with the "image" entry additionally flattened by
        ``flatten_image_for_mlp``.

    Raises:
        ValueError: Propagated from ``convert_to_jax`` for unknown keys.
    """
    prepared = {}
    for data_type, data_tf in data_dict.items():
        converted = convert_to_jax(data_tf.numpy(), data_type)
        # Only images need reshaping to (batch, features); labels pass through.
        prepared[data_type] = (
            flatten_image_for_mlp(converted) if data_type == "image" else converted
        )
    return prepared

In [None]:
# Select the "train" split of the loaded data and convert it with prepare_data
# (images are normalised and flattened, labels become a JAX array).
dataset_tf = "train"
all_data_tf = mnist_data[dataset_tf]
all_data_jax = prepare_data(all_data_tf)

In [None]:
# Pull the prepared image matrix and the label vector out of the dict.
images = all_data_jax["image"]
labels = all_data_jax["label"]

In [None]:
# Sanity check: one flattened row per image and one label per image.
print("Images shape:", images.shape)
print("Labels shape:", labels.shape)

Images shape: (60000, 784)
Labels shape: (60000,)


In [None]:
def calculate_preactivations(X, W, b):
    """Compute the affine pre-activations ``X @ W + b`` of one dense layer.

    Args:
        X: Input matrix, one sample per row - e.g. shape (batch_size, 784).
        W: Weight matrix, rows = input dim, cols = output dim - e.g. (784, 128).
        b: Bias vector added to every row - e.g. shape (128,).

    Returns:
        Pre-activation matrix - e.g. shape (batch_size, 128).
    """
    weighted_inputs = X @ W
    return weighted_inputs + b

In [None]:
def forward_pass(X, params_list):
    """Run ``X`` through every layer of the MLP and return the final output.

    Args:
        X: Input matrix of shape (batch, n_in).
        params_list: One dict per layer, each holding a weight matrix under
            "W" and a bias vector under "b".

    Returns:
        Output of the last layer, shape (batch, n_out). ReLU is applied after
        every layer except the last one, so the result is left un-activated.
    """
    hidden = X
    for index, layer in enumerate(params_list):
        hidden = calculate_preactivations(hidden, layer["W"], layer["b"])
        is_output_layer = index == len(params_list) - 1
        if not is_output_layer:
            hidden = relu(hidden)
    return hidden

In [None]:
def calculate_loss(predictions_logits, observed_labels):
    """Mean sparse softmax cross-entropy between logits and integer labels.

    ``jax.nn`` provides no ``sparse_softmax_cross_entropy_with_logits``, so the
    loss is computed directly: for each example it is
    ``logsumexp(logits) - logits[label]``, i.e. the negative log-probability
    that the softmax assigns to the observed class.

    Args:
        predictions_logits: Unnormalised scores, shape (batch, n_classes).
        observed_labels: Integer class indices, shape (batch,).

    Returns:
        Scalar: the per-example losses averaged over the batch.
    """
    labels = jnp.asarray(observed_labels)
    # logsumexp keeps the log-normaliser finite even for very large logits.
    log_normaliser = logsumexp(predictions_logits, axis=-1)
    # Pick, for every row, the logit that belongs to the observed class.
    true_class_logits = jnp.take_along_axis(
        predictions_logits, labels[..., None], axis=-1
    ).squeeze(axis=-1)
    return (log_normaliser - true_class_logits).mean()

### TODO
- use `jax.vmap` over a per-example forward pass instead of carrying the batch dimension through the input matrix
- initialisation of parameters
- backprop

### Data details
- each image is n rows x n cols x n colour channels (28 x 28 x 1 for MNIST, i.e. 784 values once flattened)
- logits output (10 classes)