### See Flax’s [Getting Started](https://flax.readthedocs.io/en/latest/getting_started.html)

## Getting Started

This tutorial demonstrates how to construct a simple convolutional neural network (CNN) using the [Flax](https://flax.readthedocs.io/en/latest/) Linen API and train the network for image classification on the MNIST dataset.

Note: This notebook is based on Flax’s official [MNIST Example](https://github.com/google/flax/tree/main/examples/mnist). If you notice any differences between the two, feel free to create a [pull request](https://github.com/google/flax/compare) to synchronize this Colab with the actual example.

## 1. Imports

Import JAX, [JAX NumPy](https://jax.readthedocs.io/en/latest/jax.numpy.html), Flax, ordinary NumPy, and TensorFlow Datasets (TFDS). Flax works with any data-loading pipeline; this example uses TFDS.

In [2]:
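import os
# Set these before JAX initializes its backend (i.e. before the first JAX computation).
# Both XLA flags go in one string: assigning XLA_FLAGS twice keeps only the last value.
os.environ["XLA_FLAGS"] = "--xla_gpu_strict_conv_algorithm_picker=false --xla_gpu_autotune_level=0"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # don't preallocate GPU memory at startup
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".5"    # fraction of GPU memory to use if preallocating
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" # allocate on demand and free unused memory (slower)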

In [1]:
import jax
import jax.numpy as jnp                # JAX NumPy

from flax import linen as nn           # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state

import numpy as np                     # Ordinary NumPy
import optax                           # Optimizers
import tensorflow_datasets as tfds     # TFDS for MNIST

In [3]:
import tensorflow as tf

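# TensorFlow is only used here for data loading (TFDS); enabling memory growth
# stops it from reserving all GPU memory up front, leaving room for JAX.
gpus = tf.config.list_physical_devices('GPU')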
if gpus:
  try:
    # Currently, memory growth needs to be the same across GPUs
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized
    print(e)

In [4]:
# Print the platform of JAX's default backend ('cpu', 'gpu', or 'tpu').
print(jax.default_backend())

cpu


## 2. Define network

Create a convolutional neural network with the Linen API by subclassing [Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#core-module-abstraction). Because the architecture in this example is relatively simple—you’re just stacking layers—you can define the inlined submodules directly within the `__call__` method and wrap it with the [@compact](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#compact-methods) decorator. To learn more about the Flax Linen @compact decorator, refer to the [setup vs compact](https://flax.readthedocs.io/en/latest/guides/setup_or_nncompact.html) guide.

In [5]:
class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x

## 3. Define loss

We use `optax.softmax_cross_entropy()`. Note that this function expects both `logits` and `labels` to have shape `[batch, num_classes]`. Since the labels are read from TFDS as integers, we first need to convert them to a one-hot encoding.

Optax’s loss function returns a vector of shape `[batch]`, one loss per example, so we take its mean to obtain the single scalar value that the optimizer minimizes.

In [6]:
def cross_entropy_loss(*, logits, labels):
  labels_onehot = jax.nn.one_hot(labels, num_classes=10)
  return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()

## 4. Metric computation

For loss and accuracy metrics, create a separate function:

In [7]:
def compute_metrics(*, logits, labels):
  loss = cross_entropy_loss(logits=logits, labels=labels)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  metrics = {
      'loss': loss,
      'accuracy': accuracy,
  }
  return metrics
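The accuracy line compares the index of the largest logit in each row with the integer label and averages the resulting booleans. A tiny hand-checkable example with made-up numbers:

```python
import jax.numpy as jnp

logits = jnp.array([[2.0, 1.0],   # argmax 0
                    [0.0, 3.0],   # argmax 1
                    [1.0, 0.0]])  # argmax 0
labels = jnp.array([0, 1, 1])

predictions = jnp.argmax(logits, -1)        # [0, 1, 0]
accuracy = jnp.mean(predictions == labels)  # mean of [True, True, False] = 2/3
print(predictions, accuracy)
```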

## 5. Loading data

Define a function that loads and prepares the MNIST dataset and converts the images to floating-point values scaled to the range [0, 1].

In [8]:
def get_datasets():
  """Load MNIST train and test datasets into memory."""
  ds_builder = tfds.builder('mnist')
  ds_builder.download_and_prepare()
  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
  train_ds['image'] = jnp.float32(train_ds['image']) / 255.
  test_ds['image'] = jnp.float32(test_ds['image']) / 255.
  return train_ds, test_ds
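`get_datasets` converts the images with `jnp.float32`, which produces a JAX array on the default device; for MNIST that is small enough to be convenient. If you would rather keep the data in host memory and let each batch be transferred when the jitted step receives it, the scaling step could be done in NumPy instead. The helper name `normalize_images` and the synthetic uint8 data are illustrative assumptions, not part of the original example.

```python
import numpy as np


def normalize_images(ds):
  """Hypothetical helper: scale uint8 pixels to float32 in [0, 1], staying in NumPy."""
  out = dict(ds)  # shallow copy so the caller's dict is not modified
  out['image'] = ds['image'].astype(np.float32) / 255.0
  return out


# Synthetic stand-in for TFDS output: 2 images of shape 28x28x1 plus labels.
fake_ds = {
    'image': np.array([np.zeros((28, 28, 1)), np.full((28, 28, 1), 255)], dtype=np.uint8),
    'label': np.array([3, 7]),
}
scaled = normalize_images(fake_ds)
print(scaled['image'].dtype, scaled['image'].min(), scaled['image'].max())  # float32 0.0 1.0
```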

## 6. Create train state

A common pattern in Flax is to create a single dataclass that represents the entire training state, including step number, parameters, and optimizer state.

Adding the optimizer and the model to this state also means that we only need to pass a single argument to functions like `train_step()` (see below).

Because this is such a common pattern, Flax provides the class [flax.training.train_state.TrainState](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#compact-methods) that serves most basic use cases. Usually you would subclass it to track additional data, but in this example we can use it without any modifications.

In [9]:
def create_train_state(rng, learning_rate, momentum):
  """Creates initial `TrainState`."""
  cnn = CNN()
  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params'] # initialize parameters by passing a template image
  tx = optax.sgd(learning_rate, momentum)
  return train_state.TrainState.create(
      apply_fn=cnn.apply, params=params, tx=tx)

## 7. Training step

Define a function that:

- Evaluates the neural network given the parameters and a batch of input images with the [Module.apply](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.apply) method (forward pass).
- Computes the `cross_entropy_loss` loss function.
- Evaluates the gradient of the loss function using [jax.grad](https://jax.readthedocs.io/en/latest/jax.html#jax.grad).
- Applies a [pytree](https://jax.readthedocs.io/en/latest/pytrees.html#pytrees-and-jax-functions) of gradients to the optimizer to update the model’s parameters.
- Computes the metrics using `compute_metrics` (defined earlier).

Use JAX’s [@jit](https://jax.readthedocs.io/en/latest/jax.html#jax.jit) decorator to trace the entire `train_step` function and just-in-time compile it with [XLA](https://www.tensorflow.org/xla) into fused device operations that run faster and more efficiently on hardware accelerators.

In [10]:
@jax.jit
def train_step(state, batch):
  """Train for a single step."""
  def loss_fn(params):
    logits = CNN().apply({'params': params}, batch['image'])
    loss = cross_entropy_loss(logits=logits, labels=batch['label'])
    return loss, logits
  grad_fn = jax.grad(loss_fn, has_aux=True)
  grads, logits = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  metrics = compute_metrics(logits=logits, labels=batch['label'])
  return state, metrics
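`train_step` calls `jax.grad` with `has_aux=True`, so `loss_fn` may return a pair `(loss, logits)`: the gradient is taken with respect to the first element only, and the second is passed through unchanged. If you also want the loss value itself, `jax.value_and_grad` accepts the same flag. The toy quadratic below is not from the original example; it only shows how the return values are structured.

```python
import jax
import jax.numpy as jnp


def loss_fn(params):
  """Toy loss with an auxiliary output, mirroring the (loss, logits) pair in train_step."""
  preds = params['w'] * jnp.array([1.0, 2.0])
  loss = jnp.sum(preds ** 2)
  return loss, preds  # (value to differentiate, auxiliary data)


params = {'w': jnp.array(3.0)}

# jax.grad with has_aux=True returns (gradients, aux).
grads, preds = jax.grad(loss_fn, has_aux=True)(params)

# jax.value_and_grad with has_aux=True returns ((loss, aux), gradients).
(loss, preds2), grads2 = jax.value_and_grad(loss_fn, has_aux=True)(params)

print(grads['w'], loss, preds)
```

Here `loss = (3*1)^2 + (3*2)^2 = 45` and `d loss / d w = 2 * 3 * (1 + 4) = 30`; the gradients have the same dict structure as `params`.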

## 8. Evaluation step
Create a function that evaluates your model on the test set with [Module.apply](https://flax.readthedocs.io/en/latest/api_reference/flax.linen.html#flax.linen.Module.apply).

In [11]:
@jax.jit
def eval_step(params, batch):
  logits = CNN().apply({'params': params}, batch['image'])
  return compute_metrics(logits=logits, labels=batch['label'])
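`eval_model` passes the entire 10,000-image test set to `eval_step` in a single call, which is fine for MNIST. For larger test sets you might evaluate in chunks instead; because the last chunk can be smaller, the per-chunk means then have to be weighted by chunk size. The helper `batched_eval` and the stand-in `toy_eval_step` are hypothetical and not part of the Flax example.

```python
import numpy as np
import jax.numpy as jnp


def batched_eval(eval_fn, ds, batch_size):
  """Hypothetical helper: average per-batch metric means, weighted by batch size."""
  n = len(ds['label'])
  totals, seen = {}, 0
  for start in range(0, n, batch_size):
    batch = {k: v[start:start + batch_size] for k, v in ds.items()}
    size = len(batch['label'])
    for k, v in eval_fn(batch).items():
      totals[k] = totals.get(k, 0.0) + float(v) * size
    seen += size
  return {k: v / seen for k, v in totals.items()}


def toy_eval_step(batch):
  # Stand-in for eval_step: "predicts" class 0 for every example.
  return {'accuracy': jnp.mean(batch['label'] == 0)}


ds = {'image': np.zeros((5, 28, 28, 1), np.float32), 'label': np.array([0, 0, 1, 0, 1])}
result = batched_eval(toy_eval_step, ds, batch_size=2)
print(result)  # {'accuracy': 0.6}
```

An unweighted mean of the three batch accuracies (1.0, 0.5, 0.0) would give 0.5, not the correct 0.6. If `eval_fn` is wrapped in `jax.jit`, the smaller final batch has a new shape and triggers one extra compilation.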

## 9. Train function

Define a training function that:

- Shuffles the training data before each epoch with [jax.random.permutation](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.permutation.html), which takes a PRNGKey as a parameter (see [JAX - The Sharp Bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#JAX-PRNG)).
- Runs an optimization step for each full batch, dropping the final incomplete batch.
- Keeps the per-batch metrics on the device so that steps can be dispatched asynchronously, then retrieves them all at once with `jax.device_get` at the end of the epoch and computes their mean.
- Prints the mean training loss and accuracy and returns the updated train state.



In [12]:
def train_epoch(state, train_ds, batch_size, epoch, rng):
  """Train for a single epoch."""
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, train_ds_size) # get a randomized index array
  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size)) # index array, where each row is a batch
  batch_metrics = []
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()} # dict{'image': array, 'label': array}
    state, metrics = train_step(state, batch)
    batch_metrics.append(metrics)

  # compute mean of metrics across each batch in epoch.
  batch_metrics_np = jax.device_get(batch_metrics)
  epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]} # jnp.mean does not work on lists

  print('train epoch: %d, loss: %.4f, accuracy: %.2f' % (
      epoch, epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100))

  return state
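The index bookkeeping in `train_epoch` can be checked on a toy size. With 10 examples and `batch_size = 3` there are 3 full steps, so one shuffled index is dropped and the remaining 9 are reshaped into a `(steps_per_epoch, batch_size)` array in which each row selects one batch. The sizes are made up for illustration.

```python
import jax

train_ds_size, batch_size = 10, 3
steps_per_epoch = train_ds_size // batch_size         # 3

perms = jax.random.permutation(jax.random.PRNGKey(0), train_ds_size)
perms = perms[:steps_per_epoch * batch_size]          # drop the incomplete last batch
perms = perms.reshape((steps_per_epoch, batch_size))  # one row of indices per batch

print(perms.shape)  # (3, 3)
# Each index appears at most once, so no example is used twice in an epoch.
print(len(set(perms.ravel().tolist())) == perms.size)  # True
```

Because a fresh key is passed in every epoch, a different example is dropped each time, so over many epochs every example still gets used.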

## 10. Eval function

Create a model evaluation function that:

- Runs `eval_step` on the whole test set.
- Retrieves the evaluation metrics from the device with `jax.device_get`.
- Converts the metrics, [stored](https://flax.readthedocs.io/en/latest/advanced_topics/linen_design_principles.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables) as leaves of a JAX [pytree](https://jax.readthedocs.io/en/latest/pytrees.html#pytrees-and-jax-functions), into plain Python numbers by applying `.item()` to every leaf with `jax.tree_util.tree_map`.

In [13]:
def eval_model(params, test_ds):
  metrics = eval_step(params, test_ds)
  metrics = jax.device_get(metrics)
  summary = jax.tree_util.tree_map(lambda x: x.item(), metrics) # map the function over all leaves in metrics
  return summary['loss'], summary['accuracy']
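`eval_step` returns a dict of 0-dimensional arrays. `jax.device_get` copies them to the host, and `jax.tree_util.tree_map` applies `.item()` to every leaf, producing plain Python floats that format cleanly with `%` or f-strings. A minimal illustration with made-up metric values:

```python
import jax
import jax.numpy as jnp

metrics = {'loss': jnp.array(0.25), 'accuracy': jnp.array(0.5)}  # made-up values
host_metrics = jax.device_get(metrics)  # leaves become NumPy values on the host
summary = jax.tree_util.tree_map(lambda x: x.item(), host_metrics)
print(summary['loss'], summary['accuracy'])  # 0.25 0.5
print(type(summary['loss']))                 # <class 'float'>
```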

## 11. Download data

In [14]:
train_ds, test_ds = get_datasets()



## 12. Seed randomness

- Get one [PRNGKey](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.PRNGKey.html#jax.random.PRNGKey) and [split](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.split.html#jax.random.split) it to get a second key that you’ll use for parameter initialization. (Learn more about [PRNG chains](https://flax.readthedocs.io/en/latest/advanced_topics/linen_design_principles.html#how-are-parameters-represented-and-how-do-we-handle-general-differentiable-algorithms-that-update-stateful-variables) and [JAX PRNG design](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html).)

In [15]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
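JAX random functions are deterministic in their key: the same key always yields the same numbers, which is why the notebook splits `rng` into new keys instead of reusing it. The sketch below shows that the split keys produce different samples, while splitting the same parent key again reproduces the same children.

```python
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

# Same key -> identical samples; different keys -> different samples.
a = jax.random.normal(init_rng, (3,))
b = jax.random.normal(init_rng, (3,))
c = jax.random.normal(rng, (3,))
print(jnp.array_equal(a, b), jnp.array_equal(a, c))  # True False

# Splitting is itself deterministic: the same parent key gives the same children.
_, init_rng_again = jax.random.split(jax.random.PRNGKey(0))
print(jnp.array_equal(init_rng, init_rng_again))  # True
```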

## 13. Initialize train state

Recall that `create_train_state` initializes both the model parameters and the optimizer state and puts them into the `TrainState` dataclass that it returns.

In [16]:
learning_rate = 0.1
momentum = 0.9

In [17]:
state = create_train_state(init_rng, learning_rate, momentum)
del init_rng  # Must not be used anymore.

## 14. Train and evaluate

After 10 epochs of training and evaluation, the output should show that your model reaches approximately 99% test accuracy.

In [18]:
num_epochs = 10
batch_size = 32

In [19]:
import time

start_time = time.time()
for epoch in range(1, num_epochs + 1):
  start_epoch_time = time.time()
  # Use a separate PRNG key to permute image data during shuffling
  rng, input_rng = jax.random.split(rng)
  # Run one training epoch: an optimization step for every full batch
  state = train_epoch(state, train_ds, batch_size, epoch, input_rng)
  epoch_time = time.time() - start_epoch_time
  # Evaluate on the test set after each training epoch
  test_loss, test_accuracy = eval_model(state.params, test_ds)
  print(' test epoch: %d, loss: %.2f, accuracy: %.2f' % (
      epoch, test_loss, test_accuracy * 100))
  print('Epoch %d in %.2f sec' % (epoch, epoch_time))

total_time = time.time() - start_time
print('Total in %.2f sec' % (total_time))

train epoch: 1, loss: 0.1405, accuracy: 95.81
 test epoch: 1, loss: 0.05, accuracy: 98.22
Epoch 1 in 21.68 sec
train epoch: 2, loss: 0.0502, accuracy: 98.44
 test epoch: 2, loss: 0.04, accuracy: 98.66
Epoch 2 in 20.89 sec
train epoch: 3, loss: 0.0343, accuracy: 98.94
 test epoch: 3, loss: 0.03, accuracy: 98.97
Epoch 3 in 21.48 sec
train epoch: 4, loss: 0.0259, accuracy: 99.16
 test epoch: 4, loss: 0.04, accuracy: 99.00
Epoch 4 in 21.08 sec
train epoch: 5, loss: 0.0198, accuracy: 99.42
 test epoch: 5, loss: 0.04, accuracy: 98.93
Epoch 5 in 20.91 sec
train epoch: 6, loss: 0.0180, accuracy: 99.42
 test epoch: 6, loss: 0.05, accuracy: 98.48
Epoch 6 in 21.08 sec
train epoch: 7, loss: 0.0153, accuracy: 99.53
 test epoch: 7, loss: 0.05, accuracy: 98.60
Epoch 7 in 20.78 sec
train epoch: 8, loss: 0.0133, accuracy: 99.59
 test epoch: 8, loss: 0.05, accuracy: 98.93
Epoch 8 in 21.16 sec
train epoch: 9, loss: 0.0108, accuracy: 99.65
 test epoch: 9, loss: 0.05, accuracy: 99.00
Epoch 9 in 20.90 sec

Congrats! You made it to the end of the annotated MNIST example. You can revisit the same example, but structured differently as a couple of Python modules, test modules, config files, another Colab, and documentation in Flax’s Git repo: [https://github.com/google/flax/tree/main/examples/mnist](https://github.com/google/flax/tree/main/examples/mnist)