<a href="https://colab.research.google.com/github/mohsenh17/jaxLearning/blob/main/flax/MNIST_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Train and test set (TensorFlow)

In [32]:
import tensorflow_datasets as tfds
import tensorflow as tf

from flax import nnx
from functools import partial

import jax.numpy as jnp

import optax


In [22]:
tf.random.set_seed(42)

train_steps = 1400
eval_every = 200
batch_size = 32

mnist_train: tf.data.Dataset = tfds.load('mnist', split='train')
mnist_test: tf.data.Dataset = tfds.load('mnist', split='test')

mnist_train = mnist_train.map(lambda sample: {'image': tf.cast(sample['image'], tf.float32) / 255., 'label': sample['label']})
mnist_test = mnist_test.map(lambda sample: {'image': tf.cast(sample['image'], tf.float32) / 255., 'label': sample['label']})

# the current setting won't cover all images -> (1875, 32)
mnist_train = mnist_train.repeat().shuffle(1024)
mnist_train = mnist_train.batch(batch_size, drop_remainder=True).take(train_steps).prefetch(1)

mnist_test = mnist_test.batch(batch_size, drop_remainder=True).prefetch(1)






In [28]:
class CNN(nnx.Module):
  def __init__(self, *, num_classes: int, rngs:nnx.Rngs):
    self.num_classes = num_classes
    self.conv1 = nnx.Conv(in_features=1, out_features=32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(in_features=32, out_features=64, kernel_size=(3, 3), rngs=rngs)
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
    self.linear2 = nnx.Linear(256, num_classes, rngs=rngs)
  def __call__(self, x):
    x = self.avg_pool(nnx.gelu(self.conv1(x)))
    x = self.avg_pool(nnx.gelu(self.conv2(x)))
    x = x.reshape(x.shape[0], -1)  # flatten
    x = nnx.gelu(self.linear1(x))
    x = self.linear2(x)
    return x

# Instantiate the model.
model = CNN(num_classes=10, rngs=nnx.Rngs(0))
# Visualize it.
nnx.display(model)

CNN(
  num_classes=10,
  conv1=Conv(
    kernel_shape=(3, 3, 1, 32),
    kernel=Param(
      value=Array(shape=(3, 3, 1, 32), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(32,), dtype=float32)
    ),
    in_features=1,
    out_features=32,
    kernel_size=(3, 3),
    strides=1,
    padding='SAME',
    input_dilation=1,
    kernel_dilation=1,
    feature_group_count=1,
    use_bias=True,
    mask=None,
    dtype=None,
    param_dtype=<class 'jax.numpy.float32'>,
    precision=None,
    kernel_init=<function variance_scaling.<locals>.init at 0x797f231a68c0>,
    bias_init=<function zeros at 0x797f3e206b00>,
    conv_general_dilated=<function conv_general_dilated at 0x797f3ea2b760>
  ),
  conv2=Conv(
    kernel_shape=(3, 3, 32, 64),
    kernel=Param(
      value=Array(shape=(3, 3, 32, 64), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(64,), dtype=float32)
    ),
    in_features=32,
    out_features=64,
    kernel_size=(3, 3),
    strides=1,
    padding='S

In [31]:
y = model(jnp.ones((1, 28, 28, 1)))
nnx.display(y)

[[-0.0380987  -0.1663384   0.04239376 -0.1821506   0.05388633  0.03874649
  -0.03934722  0.05054386  0.17043065 -0.06008959]]


In [34]:
lr = 1e-3
momentum = 0.9

optimizer = nnx.Optimizer(model, optax.adamw(lr, momentum))

metrics = nnx.MultiMetric(
  accuracy=nnx.metrics.Accuracy(),
  loss=nnx.metrics.Average('loss'),
)

nnx.display(metrics)
nnx.display(optimizer)

MultiMetric(
  accuracy=Accuracy(
    argname='values',
    total=MetricState(
      value=Array(0., dtype=float32)
    ),
    count=MetricState(
      value=Array(0, dtype=int32)
    )
  ),
  loss=Average(
    argname='loss',
    total=MetricState(
      value=Array(0., dtype=float32)
    ),
    count=MetricState(
      value=Array(0, dtype=int32)
    )
  )
)
Optimizer(
  step=OptState(
    value=Array(0, dtype=uint32)
  ),
  model=CNN(
    num_classes=10,
    conv1=Conv(
      kernel_shape=(3, 3, 1, 32),
      kernel=Param(
        value=Array(shape=(3, 3, 1, 32), dtype=float32)
      ),
      bias=Param(
        value=Array(shape=(32,), dtype=float32)
      ),
      in_features=1,
      out_features=32,
      kernel_size=(3, 3),
      strides=1,
      padding='SAME',
      input_dilation=1,
      kernel_dilation=1,
      feature_group_count=1,
      use_bias=True,
      mask=None,
      dtype=None,
      param_dtype=<class 'jax.numpy.float32'>,
      precision=None,
      kernel_ini

In [35]:
def loss_fn(model: CNN, batch):
  logits = model(batch['image'])
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch['label']
  ).mean()
  return loss, logits

@nnx.jit
def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
  """Train for a single step."""
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch['label'])  # In-place updates.
  optimizer.update(grads)  # In-place updates.

@nnx.jit
def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
  loss, logits = loss_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch['label'])  # In-place updates.

In [37]:
metrics_history = {
  'train_loss': [],
  'train_accuracy': [],
  'test_loss': [],
  'test_accuracy': [],
}

for step, batch in enumerate(mnist_train.as_numpy_iterator()):
  # Run the optimization for one step and make a stateful update to the following:
  # - The train state's model parameters
  # - The optimizer state
  # - The training loss and accuracy batch metrics
  train_step(model, optimizer, metrics, batch)

  if step > 0 and (step % eval_every == 0 or step == train_steps - 1):  # One training epoch has passed.
    # Log the training metrics.
    for metric, value in metrics.compute().items():  # Compute the metrics.
      metrics_history[f'train_{metric}'].append(value)  # Record the metrics.
    metrics.reset()  # Reset the metrics for the test set.

    # Compute the metrics on the test set after each training epoch.
    for test_batch in mnist_test.as_numpy_iterator():
      eval_step(model, metrics, test_batch)

    # Log the test metrics.
    for metric, value in metrics.compute().items():
      metrics_history[f'test_{metric}'].append(value)
    metrics.reset()  # Reset the metrics for the next training epoch.

    print(
      f"[train] step: {step}, "
      f"loss: {metrics_history['train_loss'][-1]}, "
      f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
    )
    print(
      f"[test] step: {step}, "
      f"loss: {metrics_history['test_loss'][-1]}, "
      f"accuracy: {metrics_history['test_accuracy'][-1] * 100}"
    )

[train] step: 200, loss: 0.4062541127204895, accuracy: 87.98196411132812
[test] step: 200, loss: 0.14917251467704773, accuracy: 95.49000549316406
[train] step: 400, loss: 0.13892975449562073, accuracy: 95.828125
[test] step: 400, loss: 0.09353181719779968, accuracy: 96.98999786376953
[train] step: 600, loss: 0.09272447228431702, accuracy: 97.203125
[test] step: 600, loss: 0.06901627033948898, accuracy: 97.83999633789062
[train] step: 800, loss: 0.0753866508603096, accuracy: 97.65625
[test] step: 800, loss: 0.0621148981153965, accuracy: 97.98999786376953
[train] step: 1000, loss: 0.06575710326433182, accuracy: 97.703125
[test] step: 1000, loss: 0.050169169902801514, accuracy: 98.41999816894531
[train] step: 1200, loss: 0.06956712901592255, accuracy: 97.890625
[test] step: 1200, loss: 0.05240246653556824, accuracy: 98.2699966430664
[train] step: 1399, loss: 0.058782048523426056, accuracy: 98.13127899169922
[test] step: 1399, loss: 0.05572550743818283, accuracy: 98.18999481201172
