In [1]:
import jax
import jax.numpy as jnp
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
import time
from jax import jit, random, value_and_grad
from flax import linen as nn
from flax.training import train_state

jax.devices()

2024-08-18 22:16:09.286930: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-18 22:16:09.296576: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-18 22:16:09.299193: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


[CudaDevice(id=0)]

In [2]:
data, info = tfds.load(name="mnist",
                       data_dir='/tmp/tfds',
                       as_supervised=True,
                       with_info=True)
data_train = data['train']
data_test  = data['test']

I0000 00:00:1724044570.833302   15361 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-08-18 22:16:10.854902: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2343] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [3]:
BATCH_SIZE  = 32
HEIGHT = 28
WIDTH  = 28
CHANNELS = 1
NUM_PIXELS = HEIGHT * WIDTH * CHANNELS
NUM_LABELS = info.features['label'].num_classes

In [4]:
def preprocess(img, label):
  return (tf.cast(img, tf.float32)/255.0), label

train_data = tfds.as_numpy(data_train.map(preprocess).batch(BATCH_SIZE).prefetch(1))
test_data  = tfds.as_numpy(data_test.map(preprocess).batch(BATCH_SIZE).prefetch(1))

In [5]:
class MLP(nn.Module):
  """A simple MLP model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=512)(x)
    x = nn.activation.swish(x)
    x = nn.Dense(features=10)(x)
    return x

model = MLP()

In [6]:
@jax.jit
def batch_accuracy(params, images, targets):
  images = jnp.reshape(images, (len(images), NUM_PIXELS))
  predicted_class = jnp.argmax(model.apply(params, images), axis=1)
  return jnp.mean(predicted_class == targets)

def accuracy(params, data):
  accs = []
  for images, targets in data:
    accs.append(batch_accuracy(params, images, targets))
  return jnp.mean(jnp.array(accs))

In [7]:
@jit
def update(train_state, x, y):
  def loss(params, images, targets):
    logits = train_state.apply_fn(params, images)
    loss_ce = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=targets).mean()
    return loss_ce
  loss_value, grads = value_and_grad(loss)(train_state.params, x, y)
  train_state = train_state.apply_gradients(grads=grads)
  return train_state, loss_value

In [8]:
key1, key2 = random.split(random.PRNGKey(0))
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=model.init(key2, random.normal(key1, (NUM_PIXELS,))),
    tx=optax.sgd(learning_rate=0.01, momentum=0.9))

In [9]:
num_epochs = 5
for epoch in range(num_epochs):
  start_time = time.time()
  losses = []
  for x, y in train_data:
    x = jnp.reshape(x, (len(x), NUM_PIXELS))
    state, loss_value = update(state, x, y)
    losses.append(loss_value)
  epoch_time = time.time() - start_time

  train_acc = accuracy(state.params, train_data)
  test_acc = accuracy(state.params, test_data)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set loss {}".format(jnp.mean(jnp.array(losses))))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 1.84 sec
Training set loss 0.3728047013282776
Training set accuracy 0.9355999827384949
Test set accuracy 0.9355031847953796
Epoch 1 in 1.23 sec
Training set loss 0.19248442351818085
Training set accuracy 0.9587833285331726
Test set accuracy 0.9567691683769226
Epoch 2 in 1.38 sec
Training set loss 0.13508528470993042
Training set accuracy 0.9692333340644836
Test set accuracy 0.9654552340507507
Epoch 3 in 1.31 sec
Training set loss 0.10496184974908829
Training set accuracy 0.9751333594322205
Test set accuracy 0.9698482155799866
Epoch 4 in 1.32 sec
Training set loss 0.08559435606002808
Training set accuracy 0.9791666865348816
Test set accuracy 0.972743570804596
