In [1]:
import jax
import jax.numpy as jnp
import tensorflow as tf
import tensorflow_datasets as tfds
import time
from jax import jit, random, value_and_grad, vmap
from jax.nn import logsumexp, one_hot, swish

jax.devices()

2024-08-18 22:15:29.523618: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-18 22:15:29.532971: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-18 22:15:29.535613: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


[CudaDevice(id=0)]

In [2]:
data, info = tfds.load(name="mnist",
                       data_dir='/tmp/tfds',
                       as_supervised=True,
                       with_info=True)
data_train = data['train']
data_test  = data['test']

I0000 00:00:1724044530.954866   15489 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-08-18 22:15:30.976102: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2343] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [3]:
BATCH_SIZE  = 32
HEIGHT = 28
WIDTH  = 28
CHANNELS = 1
NUM_PIXELS = HEIGHT * WIDTH * CHANNELS
NUM_LABELS = info.features['label'].num_classes

In [4]:
def preprocess(img, label):
  return (tf.cast(img, tf.float32)/255.0), label

train_data = tfds.as_numpy(data_train.map(preprocess).batch(BATCH_SIZE).prefetch(1))
test_data  = tfds.as_numpy(data_test.map(preprocess).batch(BATCH_SIZE).prefetch(1))

In [5]:
def init_network_params(sizes, key=random.PRNGKey(0), scale=1e-2):

  def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

In [6]:
def predict(params, image):
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = swish(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits

batched_predict = vmap(predict, in_axes=(None, 0))

def loss(params, images, targets):
  logits = batched_predict(params, images)
  log_preds = logits - jnp.expand_dims(logsumexp(logits, axis=1), 1)
  return -jnp.mean(targets*log_preds)

In [7]:
@jit
def batch_accuracy(params, images, targets):
  images = jnp.reshape(images, (len(images), NUM_PIXELS))
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == targets)

def accuracy(params, data):
  accs = []
  for images, targets in data:
    accs.append(batch_accuracy(params, images, targets))
  return jnp.mean(jnp.array(accs))

In [8]:
INIT_LR = 1.0
DECAY_RATE = 0.95
DECAY_STEPS = 5

@jit
def update(params, x, y, epoch_number):
  loss_value, grads = value_and_grad(loss)(params, x, y)
  lr = INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)
  return [(w - lr * dw, b - lr * db)
          for (w, b), (dw, db) in zip(params, grads)], loss_value

In [9]:
LAYER_SIZES = [HEIGHT * WIDTH, 512, 10]
PARAM_SCALE = 0.01
params = init_network_params(LAYER_SIZES, random.PRNGKey(0), scale=PARAM_SCALE)

NUM_EPOCHS = 5
for epoch in range(NUM_EPOCHS):
  start_time = time.time()
  losses = []
  for x, y in train_data:
    x = jnp.reshape(x, (len(x), NUM_PIXELS))
    y = one_hot(y, NUM_LABELS)
    params, loss_value = update(params, x, y, epoch)
    losses.append(loss_value)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_data)
  test_acc = accuracy(params, test_data)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set loss {}".format(jnp.mean(jnp.array(losses))))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 2.20 sec
Training set loss 0.0456433929502964
Training set accuracy 0.9271000027656555
Test set accuracy 0.9273162484169006
Epoch 1 in 1.56 sec
Training set loss 0.021921297535300255
Training set accuracy 0.9525666832923889
Test set accuracy 0.9517771601676941
Epoch 2 in 1.55 sec
Training set loss 0.0153444679453969
Training set accuracy 0.9641667008399963
Test set accuracy 0.9604632258415222
Epoch 3 in 1.57 sec
Training set loss 0.011921359226107597
Training set accuracy 0.9709500074386597
Test set accuracy 0.9663538336753845
Epoch 4 in 1.53 sec
Training set loss 0.009811271913349628
Training set accuracy 0.9756333231925964
Test set accuracy 0.9702475666999817
