In [130]:
# In this notebook we create a very simple convolutional network and train it
# on MNIST across multiple (simulated) CPU devices in order to explore
# parallel training, sharding, etc.
import os
# Let's keep it simple and simulate 2 CPU devices. This is set before jax is
# imported in the next cell.
flags = os.environ.get('XLA_FLAGS', '')

device_count_flag = "--xla_force_host_platform_device_count=2"
# Only append the flag when it is not already present, so re-running this cell
# does not keep growing XLA_FLAGS with duplicate entries. Comparing whole
# whitespace-separated tokens avoids a false match on e.g. "...count=20".
if device_count_flag not in flags.split():
    os.environ['XLA_FLAGS'] = flags + " " + device_count_flag

In [131]:
import jax
# Select the CPU platform for JAX, then print the devices JAX reports so we
# can check that the device count requested via XLA_FLAGS took effect.
jax.config.update('jax_platform_name', 'cpu')
print(f'We have 2 fake JAX devices now: {jax.devices()}')

We have 2 fake JAX devices now: [CpuDevice(id=0), CpuDevice(id=1)]


In [132]:
# important imports for this notebook
from absl import logging
import flax
import jax.numpy as jnp
from matplotlib import pyplot as plt
import numpy as np
import tensorflow_datasets as tfds
import tensorflow as tf

import optax
from flax.training import train_state
from flax.metrics import tensorboard

# Give TensorFlow an empty list of visible GPUs. TF is only used for the
# tf.data input pipeline in this notebook.
# NOTE(review): presumably this stops TF from claiming GPU memory - confirm.
tf.config.experimental.set_visible_devices([], "GPU")

# Emit absl log messages at INFO level; the training loop reports its metrics
# through logging.info.
logging.set_verbosity(logging.INFO)

In [133]:
# Define some constants
# Number of target classes (used for the one-hot encoding of the labels).
NUM_CLASSES = 10

# Input pipeline settings: host batch size (split across the local devices
# later on), tf.data shuffle buffer size and number of batches to prefetch.
batch_size, shuffle_buffer_size, prefetch = 64, 1000, 10

# Shape of a single input image: image_size x image_size x num_channels.
image_size, num_channels = 28, 1

In [149]:
# Let's get the data first - start with mnist

# Decoding functions for normalization and augmentation
def normalize_image(example):
    """Cast an example's image to float32 and divide it by 255.

    Args:
        example: dict with an 'image' tensor and a 'label' entry.

    Returns:
        A new dict with the rescaled 'image' and the untouched 'label'.
        Assuming 8-bit pixel values, the image ends up in [0, 1].
    """
    scaled = tf.cast(example['image'], tf.float32) / 255.
    return {'image': scaled, 'label': example['label']}

# Break the usual one-shot load into its individual steps for more control.
ds_builder = tfds.builder('mnist')
ds_builder.download_and_prepare()

# NOTE: despite its name, this is the number of training *examples* in the
# split; the number of steps per epoch is derived from it further down.
num_train_steps = ds_builder.info.splits['train'].num_examples

train_ds = ds_builder.as_dataset(split='train')

# Input pipeline, one stage per line:
# cache -> repeat forever -> shuffle -> normalise -> batch -> prefetch.
train_ds = train_ds.cache()
train_ds = train_ds.repeat()
train_ds = train_ds.shuffle(shuffle_buffer_size)
train_ds = train_ds.map(
    normalize_image,
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)
# drop_remainder keeps every batch at exactly batch_size examples.
train_ds = train_ds.batch(batch_size, drop_remainder=True)
train_ds = train_ds.prefetch(prefetch)

INFO:absl:Load dataset info from /home/don/tensorflow_datasets/mnist/3.0.1
INFO:absl:For 'mnist/3.0.1': fields info.[citation, splits, supervised_keys, module_name] differ on disk and in the code. Keeping the one from code.
INFO:absl:Reusing dataset mnist (/home/don/tensorflow_datasets/mnist/3.0.1)
INFO:absl:Creating a tf.data.Dataset reading 1 files located in folders: /home/don/tensorflow_datasets/mnist/3.0.1.
INFO:absl:Constructing tf.data.Dataset mnist for split train, from /home/don/tensorflow_datasets/mnist/3.0.1


In [135]:
# Display the value. NOTE: despite the name, this is the number of examples in
# the 'train' split (read from the dataset info), not a number of steps.
num_train_steps

60000

In [150]:
from flax import jax_utils

# Because we are simulating multiple devices, we will need to shard 
# the data. In this case we create a new dimension before batch 
# equal to the number of devices. pmap takes care of the details
# after that for the most part.
def prepare_tf_data(xs):
  """Turn a host batch of tf Tensors into per-device numpy arrays.

  Every leaf of `xs` is converted to a numpy array and its leading (batch)
  axis is split in two, giving
  (local_devices, device_batch_size, ...) instead of (host_batch_size, ...).

  Args:
    xs: a pytree (here a dict) whose leaves are tf Tensors.

  Returns:
    A pytree with the same structure holding the reshaped numpy arrays.
  """
  num_local_devices = jax.local_device_count()

  def _to_sharded_numpy(tensor):
    # Use _numpy() for zero-copy conversion between TF and NumPy.
    array = tensor._numpy()  # pylint: disable=protected-access
    # Device axis first, then whatever is left of the batch axis, followed by
    # the untouched trailing dimensions.
    per_device_shape = (num_local_devices, -1) + array.shape[1:]
    return array.reshape(per_device_shape)

  # tree_map applies the conversion to every entry of the dictionary.
  return jax.tree_util.tree_map(_to_sharded_numpy, xs)

# Lazily apply the sharding function to every batch coming out of the tf.data
# pipeline.
it = (prepare_tf_data(host_batch) for host_batch in train_ds)
# Keep a buffer of 2 batches prefetched onto the devices. According to the
# Flax documentation this overlaps data transfer with compute and is mostly
# useful on GPUs.
it = jax_utils.prefetch_to_device(it, size=2)

2024-06-09 18:39:26.622367: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2024-06-09 18:39:26.622840: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


In [137]:
# We need a model, so let's make the simplest one possible
from flax import linen as nn

class CNN(nn.Module):
  """A simple CNN model.

  Two (3x3 conv -> ReLU -> 2x2 average pool) stages with 32 and 64 features,
  followed by a 256-unit dense layer with ReLU and a final 10-unit dense
  layer that produces the logits.

  Attributes:
    dtype: computation dtype handed to every layer. Defaults to float32;
      half precision will be faster on A100 and greater.
  """
  # Annotated so that it is a real Module field: `model.dtype` keeps working
  # and `CNN(dtype=...)` becomes possible.
  dtype: jnp.dtype = jnp.float32

  @nn.compact
  def __call__(self, x):
    """Map a batch of images (batch, height, width, channels) to logits."""
    x = nn.Conv(features=32, kernel_size=(3, 3), dtype=self.dtype)(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3), dtype=self.dtype)(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256, dtype=self.dtype)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10, dtype=self.dtype)(x)
    return x

In [138]:
# SGD hyper-parameters, passed to create_train_state when the training state
# is built.
learning_rate, momentum = 0.01, 0.9

# Model and state initialization functions
def initialized(key, image_size, num_channels, model):
  """Initialise `model` on a dummy batch and return its parameters.

  Args:
    key: PRNG key used for the 'params' RNG stream.
    image_size: height and width of the (square) dummy input image.
    num_channels: number of channels of the dummy input image.
    model: object exposing `init(rngs, inputs)` and a `dtype` attribute.

  Returns:
    The 'params' entry of the variables returned by `model.init`.
  """
  dummy_batch = jnp.ones(
      (1, image_size, image_size, num_channels), model.dtype
  )
  # Run the initialisation under jit.
  jitted_init = jax.jit(lambda rngs, inputs: model.init(rngs, inputs))
  variables = jitted_init({'params': key}, dummy_batch)
  # Could add batch stats here
  return variables['params']

def create_train_state(
    rng, model, image_size, num_channels, learning_rate, momentum
):
  """Build the initial TrainState: fresh parameters plus an SGD optimiser.

  Args:
    rng: PRNG key used to initialise the parameters.
    model: the model to initialise; its `apply` becomes the state's apply_fn.
    image_size: height/width of the dummy input used for initialisation.
    num_channels: channel count of the dummy input used for initialisation.
    learning_rate: learning rate for SGD.
    momentum: momentum for SGD (Nesterov momentum is enabled).

  Returns:
    A `train_state.TrainState` holding apply_fn, params and optimiser state.
  """
  # A custom TrainState subclass would be needed for more advanced networks
  # (e.g. to also carry batch statistics for batch normalisation).
  return train_state.TrainState.create(
      apply_fn=model.apply,
      params=initialized(rng, image_size, num_channels, model),
      tx=optax.sgd(
          learning_rate=learning_rate,
          momentum=momentum,
          nesterov=True,
      ),
  )

In [141]:
# The main training step
from jax import lax

def cross_entropy_loss(logits, labels):
  """Mean softmax cross-entropy of `logits` against integer class `labels`.

  The labels are one-hot encoded over NUM_CLASSES classes first.
  """
  targets = jax.nn.one_hot(labels, num_classes=NUM_CLASSES)
  per_example = optax.softmax_cross_entropy(logits=logits, labels=targets)
  return per_example.mean()

def compute_metrics(logits, labels):
  """Return loss and accuracy averaged over the 'batch' device axis.

  Has to run inside a mapped computation whose axis is named 'batch'
  (e.g. a jax.pmap with axis_name='batch'), because of the lax.pmean call.

  Args:
    logits: model outputs; the class axis is the last one.
    labels: integer class labels.

  Returns:
    Dict with 'loss' and 'accuracy'.
  """
  predictions = jnp.argmax(logits, -1)
  per_device = {
      'loss': cross_entropy_loss(logits, labels),
      'accuracy': jnp.mean(predictions == labels),
  }
  # Each device contributes the mean over its own shard, so the combined
  # values are means of means.
  return lax.pmean(per_device, axis_name='batch')

def train_step(state, batch, learning_rate, weight_decay=0.0001):
  """Perform a single training step.

  Has to run inside a mapped computation whose axis is named 'batch'
  (e.g. jax.pmap(..., axis_name='batch')): gradients and metrics are averaged
  across that axis with lax.pmean.

  Args:
    state: current TrainState (provides apply_fn, params, apply_gradients).
    batch: dict with the 'image' inputs and integer 'label' targets.
    learning_rate: only reported in the returned metrics. The step size that
      is actually applied is the one configured in the state's optimiser.
    weight_decay: coefficient of the L2 penalty added to the loss. The
      default reproduces the previously hard-coded value.

  Returns:
    (new_state, metrics), where metrics holds 'loss', 'accuracy' and
    'learning_rate'.
  """

  def loss_fn(params):
    """Cross-entropy plus L2 penalty; logits are returned as auxiliary data."""
    logits = state.apply_fn(
        {'params': params},
        batch['image']
    )
    loss = cross_entropy_loss(logits, batch['label'])
    # Only parameters with more than one dimension are penalised, which
    # leaves out 1-D parameters such as biases.
    weight_l2 = sum(
        jnp.sum(x**2)
        for x in jax.tree_util.tree_leaves(params)
        if x.ndim > 1
    )
    weight_penalty = weight_decay * 0.5 * weight_l2
    return loss + weight_penalty, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  # The penalised loss value itself is not needed; the reported loss is
  # recomputed (without the penalty) in compute_metrics.
  (_, logits), grads = grad_fn(state.params)
  # Re-use same axis_name as in the call to `pmap(...train_step...)` below.
  grads = lax.pmean(grads, axis_name='batch')

  metrics = compute_metrics(logits, batch['label'])
  metrics['learning_rate'] = learning_rate

  # Any other custom state variables (like batch stats) would also need to be
  # updated here.
  new_state = state.apply_gradients(grads=grads)

  return new_state, metrics


In [None]:
# We need a separate function for evaluation. No weight decay or gradients need
# to be calculated, but we still report the loss so we can watch for overfitting.
def eval_step(state, batch):
  """Compute loss and accuracy for one evaluation batch.

  Like train_step, this has to run inside a mapped computation whose axis is
  named 'batch', because compute_metrics averages with lax.pmean over it.

  Args:
    state: TrainState providing apply_fn and params.
    batch: dict with the 'image' inputs and integer 'label' targets.

  Returns:
    Dict with 'loss' and 'accuracy'.
  """
  variables = {'params': state.params}
  # The model's __call__ only takes the input batch - it has no `train`
  # argument - so no extra keyword is forwarded here (passing train=False
  # raises a TypeError).
  logits = state.apply_fn(variables, batch['image'])
  return compute_metrics(logits, batch['label'])


In [155]:
# Here we will build up the train and evaluate loop
import functools
from flax.training import common_utils
from clu import metric_writers

# Run configuration
num_epochs = 10
workdir = './pmap_tests'

# just_logging is switched on for every process except process 0.
writer = metric_writers.create_default_writer(
    logdir=workdir,
    just_logging=(jax.process_index() != 0),
)

rng = jax.random.key(0)
model = CNN()

# num_train_steps holds the number of training examples, so this is the
# number of full batches per pass over the data.
steps_per_epoch = num_train_steps // batch_size

# Build the training state and replicate it across the devices: the data is
# split between the devices, but each device needs its own copy of the state
# (the copies are assumed to stay in sync after every training step).
state = jax_utils.replicate(
    create_train_state(
        rng, model, image_size, num_channels, learning_rate, momentum
    )
)

# pmap the training step so that the leading axis of its inputs is spread over
# the devices. Everything produced inside then has to be averaged over the
# 'batch' axis.
p_train_step = jax.pmap(
    functools.partial(train_step, learning_rate=learning_rate),
    axis_name='batch',
)

train_metrics = []

# How often (in steps) the accumulated training metrics are averaged and
# reported.
log_every_steps = 200
total_steps = steps_per_epoch * num_epochs

try:
  for step, batch in zip(range(total_steps), it):
    state, metrics = p_train_step(state, batch)

    train_metrics.append(metrics)
    # Report the metrics accumulated since the previous report, then start
    # accumulating afresh (metrics from previous steps are discarded).
    if (step + 1) % log_every_steps == 0:
      train_metrics = common_utils.get_metrics(train_metrics)

      summary = {
          f'train_{k}': v
          for k, v in jax.tree_util.tree_map(
              lambda x: x.mean(), train_metrics
          ).items()
      }
      logging.info(summary)
      writer.write_scalars(step + 1, summary)
      train_metrics = []

    if (step + 1) % steps_per_epoch == 0:
      epoch = step // steps_per_epoch

      # TODO: placeholder - no evaluation is run yet, so the loss and accuracy
      # logged here are hard-coded zeros rather than measured values.
      logging.info(
          'test epoch: %d, loss: %.4f, accuracy: %.2f',
          epoch,
          0,
          0 * 100,
      )
finally:
  # Flush the scalars already handed to the writer, also when the loop is
  # interrupted (e.g. by a KeyboardInterrupt) before finishing.
  writer.flush()
        


INFO:absl:{'train_accuracy': 0.7897656, 'train_learning_rate': 0.009999999, 'train_loss': 0.71739966}
INFO:absl:[200] train_accuracy=0.7897655963897705, train_learning_rate=0.009999998845160007, train_loss=0.7173996567726135
INFO:absl:{'train_accuracy': 0.9203906, 'train_learning_rate': 0.009999999, 'train_loss': 0.2663974}
INFO:absl:[400] train_accuracy=0.9203906059265137, train_learning_rate=0.009999998845160007, train_loss=0.2663973867893219
INFO:absl:{'train_accuracy': 0.9429687, 'train_learning_rate': 0.009999999, 'train_loss': 0.19493891}
INFO:absl:[600] train_accuracy=0.9429687261581421, train_learning_rate=0.009999998845160007, train_loss=0.19493891298770905
INFO:absl:{'train_accuracy': 0.95109373, 'train_learning_rate': 0.009999999, 'train_loss': 0.16796201}
INFO:absl:[800] train_accuracy=0.9510937333106995, train_learning_rate=0.009999998845160007, train_loss=0.16796201467514038
INFO:absl:test epoch: 0, loss: 0.0000, accuracy: 0.00
INFO:absl:{'train_accuracy': 0.9627344, 'tra

KeyboardInterrupt: 

In [144]:
# Inspect the metrics dict returned by the most recent p_train_step call; each
# entry holds one value per device.
metrics

{'accuracy': Array([0.03125, 0.03125], dtype=float32),
 'learning_rate': Array([0.01, 0.01], dtype=float32, weak_type=True),
 'loss': Array([2.3066719, 2.3066719], dtype=float32)}