In [1]:
import os
os.environ['JAX_PLATFORMS'] = 'cpu'
NUM_DEVICES = 8
os.environ['XLA_FLAGS'] = f"--xla_force_host_platform_device_count={NUM_DEVICES}"

In [2]:
import jax
import jax.numpy as jnp
import tensorflow as tf
import tensorflow_datasets as tfds
import time
from functools import partial
from jax import jit, random, value_and_grad, vmap
from jax.nn import logsumexp, one_hot, swish
from jax.tree_util import tree_map

jax.devices()

2024-08-18 22:18:19.629452: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-18 22:18:19.639264: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-18 22:18:19.641992: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]

In [3]:
data, info = tfds.load(name="mnist",
                       data_dir='/tmp/tfds',
                       as_supervised=True,
                       with_info=True)
data_train = data['train']
data_test  = data['test']

I0000 00:00:1724044700.956437   18017 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-08-18 22:18:20.978009: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2343] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [4]:
BATCH_SIZE  = 32
HEIGHT = 28
WIDTH  = 28
CHANNELS = 1
NUM_PIXELS = HEIGHT * WIDTH * CHANNELS
NUM_LABELS = info.features['label'].num_classes

In [5]:
def preprocess(img, label):
  return (tf.cast(img, tf.float32)/255.0), label

train_data = tfds.as_numpy(data_train.map(preprocess).batch(BATCH_SIZE).prefetch(1))
test_data  = tfds.as_numpy(data_test.map(preprocess).batch(BATCH_SIZE).prefetch(1))

In [6]:
def init_network_params(sizes, key=random.PRNGKey(0), scale=1e-2):

  def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

In [7]:
def predict(params, image):
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = swish(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits

batched_predict = vmap(predict, in_axes=(None, 0))

def loss(params, images, targets):
  logits = batched_predict(params, images)
  log_preds = logits - jnp.expand_dims(logsumexp(logits, axis=1), 1)
  return -jnp.mean(targets*log_preds)

In [8]:
@jit
def batch_accuracy(params, images, targets):
  images = jnp.reshape(images, (len(images), NUM_PIXELS))
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == targets)

def accuracy(params, data):
  accs = []
  for images, targets in data:
    accs.append(batch_accuracy(params, images, targets))
  return jnp.mean(jnp.array(accs))

In [9]:
INIT_LR = 1.0
DECAY_RATE = 0.95
DECAY_STEPS = 5

@partial(jax.pmap, axis_name='devices', in_axes=(0, 0, 0, None))
def update(params, x, y, epoch_number):
  loss_value, grads = value_and_grad(loss)(params, x, y)
  grads = [(jax.lax.psum(dw, 'devices'), jax.lax.psum(db, 'devices'))
    for dw, db in grads]
  lr = INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)
  return tree_map(lambda p, g: p - lr * g, params, grads), loss_value

In [10]:
LAYER_SIZES = [HEIGHT * WIDTH, 512, 10]
PARAM_SCALE = 0.01
params = init_network_params(LAYER_SIZES, random.PRNGKey(0), scale=PARAM_SCALE)
replicated_params = tree_map(lambda x: jnp.broadcast_to(x, (NUM_DEVICES,) + x.shape), params)

NUM_EPOCHS  = 5
for epoch in range(NUM_EPOCHS):
  start_time = time.time()
  losses = []
  for x, y in train_data:
    num_elements = len(y)
    x = jnp.reshape(x, (NUM_DEVICES, num_elements//NUM_DEVICES, NUM_PIXELS))
    y = jnp.reshape(one_hot(y, NUM_LABELS), (NUM_DEVICES, num_elements//NUM_DEVICES, NUM_LABELS))
    replicated_params, loss_value = update(replicated_params, x, y, epoch)
    losses.append(jnp.sum(loss_value))
  epoch_time = time.time() - start_time

  params = tree_map(lambda x: x[0], replicated_params)
  train_acc = accuracy(params, train_data)
  test_acc = accuracy(params, test_data)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set loss {}".format(jnp.mean(jnp.array(losses))))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 12.21 sec
Training set loss 0.18622834980487823
Training set accuracy 0.9729166626930237
Test set accuracy 0.9697483777999878
Epoch 1 in 13.60 sec
Training set loss 0.07245483249425888
Training set accuracy 0.9823166728019714
Test set accuracy 0.9731429815292358
Epoch 2 in 11.63 sec
Training set loss 0.04596871882677078
Training set accuracy 0.9859833717346191
Test set accuracy 0.9756389856338501
Epoch 3 in 10.66 sec
Training set loss 0.03174363821744919
Training set accuracy 0.9881666898727417
Test set accuracy 0.9758386611938477
Epoch 4 in 10.73 sec
Training set loss 0.024489162489771843
Training set accuracy 0.9887666702270508
Test set accuracy 0.9758386611938477
