In [1]:
import os
os.environ['JAX_PLATFORMS'] = 'cpu'
NUM_DEVICES = 8
os.environ['XLA_FLAGS'] = f"--xla_force_host_platform_device_count={NUM_DEVICES}"

In [2]:
import jax
import jax.numpy as jnp
import tensorflow as tf
import tensorflow_datasets as tfds
import time
from jax import jit, random, value_and_grad, vmap
from jax.nn import logsumexp, one_hot, swish
from jax.sharding import PositionalSharding

jax.devices()

2024-08-18 22:20:29.836581: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-18 22:20:29.846405: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-18 22:20:29.849142: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]

In [3]:
data, info = tfds.load(name="mnist",
                       data_dir='/tmp/tfds',
                       as_supervised=True,
                       with_info=True)
data_train = data['train']
data_test  = data['test']

I0000 00:00:1724044831.170295   19039 cuda_executor.cc:1001] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-08-18 22:20:31.194410: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2343] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [4]:
BATCH_SIZE  = 32
HEIGHT = 28
WIDTH  = 28
CHANNELS = 1
NUM_PIXELS = HEIGHT * WIDTH * CHANNELS
NUM_LABELS = info.features['label'].num_classes

In [5]:
def preprocess(img, label):
  return (tf.cast(img, tf.float32)/255.0), label

train_data = tfds.as_numpy(data_train.map(preprocess).batch(BATCH_SIZE).prefetch(1))
test_data  = tfds.as_numpy(data_test.map(preprocess).batch(BATCH_SIZE).prefetch(1))

In [6]:
def init_network_params(sizes, key=random.PRNGKey(0), scale=1e-2):

  def random_layer_params(m, n, key, scale=1e-2):
    w_key, b_key = random.split(key)
    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

In [7]:
def predict(params, image):
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = swish(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits

batched_predict = vmap(predict, in_axes=(None, 0))

def loss(params, images, targets):
  logits = batched_predict(params, images)
  log_preds = logits - jnp.expand_dims(logsumexp(logits, axis=1), 1)
  return -jnp.mean(targets*log_preds)

In [8]:
@jit
def batch_accuracy(params, images, targets):
  images = jnp.reshape(images, (len(images), NUM_PIXELS))
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == targets)

def accuracy(params, data):
  accs = []
  for images, targets in data:
    accs.append(batch_accuracy(params, images, targets))
  return jnp.mean(jnp.array(accs))

In [9]:
INIT_LR = 1.0
DECAY_RATE = 0.95
DECAY_STEPS = 5

@jit
def update(params, x, y, epoch_number):
  loss_value, grads = value_and_grad(loss)(params, x, y)
  lr = INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)
  return [(w - lr * dw, b - lr * db)
          for (w, b), (dw, db) in zip(params, grads)], loss_value


#### 8-way data parallelism

In [10]:
LAYER_SIZES = [HEIGHT * WIDTH, 512, 10]
PARAM_SCALE = 0.01
params = init_network_params(LAYER_SIZES, random.PRNGKey(0), scale=PARAM_SCALE)

sharding = PositionalSharding(jax.devices()).reshape(8, 1)

NUM_EPOCHS  = 5
for epoch in range(NUM_EPOCHS):
  start_time = time.time()
  losses = []
  for x, y in train_data:
    x = jnp.reshape(x, (len(x), NUM_PIXELS))
    y = one_hot(y, NUM_LABELS)
    x = jax.device_put(x, sharding)
    y = jax.device_put(y, sharding)
    params = jax.device_put(params, sharding.replicate())
    params, loss_value = update(params, x, y, epoch)
    losses.append(loss_value)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_data)
  test_acc = accuracy(params, test_data)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set loss {}".format(jnp.mean(jnp.array(losses))))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 10.08 sec
Training set loss 0.04564366862177849
Training set accuracy 0.9270833730697632
Test set accuracy 0.9273162484169006
Epoch 1 in 9.17 sec
Training set loss 0.021921737119555473
Training set accuracy 0.9524999856948853
Test set accuracy 0.9517771601676941
Epoch 2 in 9.31 sec
Training set loss 0.015344643034040928
Training set accuracy 0.9641333222389221
Test set accuracy 0.9603633880615234
Epoch 3 in 9.35 sec
Training set loss 0.011921508237719536
Training set accuracy 0.9709666967391968
Test set accuracy 0.9664536714553833
Epoch 4 in 9.42 sec
Training set loss 0.009811367839574814
Training set accuracy 0.9756333231925964
Test set accuracy 0.9702475666999817


#### 4-way data parallelism, 2-way tensor parallelism

In [16]:
#LAYER_SIZES = [HEIGHT * WIDTH, 10000, 10000, 10]
params = init_network_params(LAYER_SIZES, random.PRNGKey(0), scale=PARAM_SCALE)

sharding = PositionalSharding(jax.devices()).reshape(4, 2)

sharded_params = []
for i,(w,b) in enumerate(params):
  if i==(len(params)-1):
    w = jax.device_put(w, sharding.replicate())
    b = jax.device_put(b, sharding.replicate())
  else:
    w = jax.device_put(w, sharding.replicate(0))
    b = jax.device_put(b, sharding.replicate())
  sharded_params.append((w,b))
  #jax.debug.visualize_array_sharding(w)
  #jax.debug.visualize_array_sharding(b)

params = sharded_params

NUM_EPOCHS  = 5
for epoch in range(NUM_EPOCHS):
  start_time = time.time()
  losses = []
  for x, y in train_data:
    x = jnp.reshape(x, (len(x), NUM_PIXELS))
    y = one_hot(y, NUM_LABELS)
    x = jax.device_put(x, sharding.replicate(1))
    y = jax.device_put(y, sharding.replicate(1))
    params, loss_value = update(params, x, y, epoch)
    losses.append(jnp.sum(loss_value))
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_data)
  test_acc = accuracy(params, test_data)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set loss {}".format(jnp.mean(jnp.array(losses))))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 5.52 sec
Training set loss 0.04564366862177849
Training set accuracy 0.9270833730697632
Test set accuracy 0.9273162484169006
Epoch 1 in 5.55 sec
Training set loss 0.021921737119555473
Training set accuracy 0.9524999856948853
Test set accuracy 0.9517771601676941
Epoch 2 in 5.41 sec
Training set loss 0.015344643034040928
Training set accuracy 0.9641333222389221
Test set accuracy 0.9603633880615234
Epoch 3 in 5.27 sec
Training set loss 0.01192150916904211
Training set accuracy 0.9709666967391968
Test set accuracy 0.9664536714553833
Epoch 4 in 5.39 sec
Training set loss 0.009811367839574814
Training set accuracy 0.9756333231925964
Test set accuracy 0.9702475666999817
