# SimpleTrainer (Intermediate)

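This notebook shows a more involved customization of the Jaxloop SimpleTrainer: it trains a small CNN on MNIST using a custom step class that replaces the default loss with softmax cross-entropy.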

## Setup

In [1]:
from typing import Tuple, Mapping, Sequence

from colabtools import adhoc_import
from flax import linen as nn
import jax
import jax.numpy as jnp
from jax.random import PRNGKey
import matplotlib.pyplot as plt
import optax
import tensorflow as tf
import tensorflow_datasets as tfds

with adhoc_import.Google3SubmittedChangelist():
  from jaxloop.trainers import simple_trainer
  from jaxloop.trainers import simple_step

from jaxloop import step
from jaxloop import types

## Load Dataset

In [2]:
def get_datasets(batch_size: int):
  """Builds batched MNIST input pipelines for training and evaluation.

  Each image is cast to float32 and divided by 255.0. Every dataset element is
  a dict with the keys 'input_features' (the image) and 'label'.

  Args:
    batch_size: Number of examples per batch. Incomplete final batches are
      dropped in both splits.

  Returns:
    train_ds: A tf.data.Dataset over the train split that repeats
      indefinitely, is shuffled with a 1024-element buffer, and is batched.
    test_ds: A tf.data.Dataset over the test split, shuffled with a
      1024-element buffer and batched (single pass, no repeat).
  """

  def _to_features(sample):
    # Rename the TFDS fields to the keys used by the rest of the notebook.
    image = tf.cast(sample['image'], tf.float32) / 255.0
    return {'input_features': image, 'label': sample['label']}

  train_ds = (
      tfds.load('mnist', split='train')
      .map(_to_features)
      .repeat()
      .shuffle(1024)
      .batch(batch_size, drop_remainder=True)
      .prefetch(1)
  )
  test_ds = (
      tfds.load('mnist', split='test')
      .map(_to_features)
      .shuffle(1024)
      .batch(batch_size, drop_remainder=True)
      .prefetch(1)
  )

  return train_ds, test_ds

## Model

In [3]:
class SimpleCNN(nn.Module):
  """A small CNN: two conv + avg-pool stages followed by two dense layers.

  Attributes:
    num_classes: Width of the final dense layer, i.e. the number of logits
      produced per example.
  """

  num_classes: int = 10

  @nn.compact
  def __call__(self, input_features: jax.Array, train: bool = False):
    """Computes class logits for a batch of inputs.

    Args:
      input_features: Batch of inputs; the leading axis is kept as the batch
        axis when the convolutional features are flattened.
      train: Accepted so callers can pass a train/eval flag; this model does
        not use it.

    Returns:
      The output of the final dense layer (`num_classes` features).
    """
    hidden = input_features

    # Two identical stages that differ only in the number of conv filters.
    for conv_features in (32, 128):
      hidden = nn.Conv(features=conv_features, kernel_size=(3, 3))(hidden)
      hidden = nn.relu(hidden)
      hidden = nn.avg_pool(hidden, window_shape=(2, 2), strides=(2, 2))

    # Flatten everything except the batch axis before the dense head.
    hidden = hidden.reshape((hidden.shape[0], -1))
    hidden = nn.relu(nn.Dense(features=256)(hidden))

    return nn.Dense(features=self.num_classes)(hidden)

## Customizing the Step Class
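The default SimpleStep class uses the MSE loss function and reads its targets from the `output_features` key of each batch. Because MNIST is a classification task, we extend the SimpleStep class, override its loss function to use softmax cross-entropy, and override `_get_output_features` so that targets are read from the `label` key instead.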

In [4]:
def cross_entropy_loss_fn(logits: jax.Array, labels: jax.Array) -> jax.Array:
  """Returns the mean softmax cross-entropy of `logits` against integer `labels`.

  Args:
    logits: Unnormalized class scores predicted by the model.
    labels: Integer class labels.

  Returns:
    The per-example losses averaged into a single value.
  """
  per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=labels
  )
  return jnp.mean(per_example_loss)


class CrossEntropyStep(simple_step.SimpleStep):
  """A SimpleStep variant that trains with softmax cross-entropy loss.

  Two hooks are overridden:
    * `_get_output_features`: SimpleStep expects the ground truth under the
      key "output_features"; here it is read from the key "label" instead.
    * `loss_fn`: uses cross-entropy with integer labels.
  """

  def _get_output_features(self, batch: dict) -> jax.Array:
    """Returns the ground-truth labels stored under `batch["label"]`."""
    labels = batch["label"]
    return labels

  def loss_fn(
      self,
      output_features_pred: jax.Array,
      true_output_features: jax.Array,
  ) -> jax.Array:
    """Computes the cross-entropy loss of the predictions against the labels.

    Args:
      output_features_pred: The logits predicted by the model.
      true_output_features: The true integer labels.

    Returns:
      The cross-entropy loss.
    """
    logits, labels = output_features_pred, true_output_features
    return cross_entropy_loss_fn(logits=logits, labels=labels)

## Initialization & Training

In [5]:
# Optimization hyperparameters.
batch_size = 32
learning_rate = 1e-3

# Training schedule (passed to the trainer as `epochs` / `steps_per_epoch`).
NUM_EPOCHS, STEPS_PER_EPOCH = 10, 125

# Total number of training steps over the whole run.
num_train_steps = NUM_EPOCHS * STEPS_PER_EPOCH

In [6]:
# Build the batched MNIST train and test pipelines with the configured batch size.
train_ds, test_ds = get_datasets(batch_size)

In [7]:
# BATCH_SPEC maps each batch key to a (shape, dtype) pair.
#
# For every key we read the corresponding entry of `train_ds.element_spec`,
# convert its TensorShape to a plain Python tuple, and convert its TensorFlow
# dtype to the matching NumPy dtype. The same recipe is applied to both
# "input_features" and "label".
_batch_element_spec = train_ds.element_spec

BATCH_SPEC: Mapping[str, Tuple[Sequence[int], type]] = {
    key: (
        tuple(_batch_element_spec[key].shape),
        _batch_element_spec[key].dtype.as_numpy_dtype,
    )
    for key in ("input_features", "label")
}

In [8]:
# Model and optimizer shared by the step and the trainer.
CNN_MODEL = SimpleCNN()
OPTIMIZER = optax.adam(learning_rate)

# Fixed seed so that runs are repeatable; the key is registered under "params".
prng_seed = 0
prng = PRNGKey(prng_seed)
BASE_PRNG = {"params": prng}

# No checkpointing configuration is supplied to the trainer.
CHECKPOINTING_CONFIG = None

# A standalone instance of the custom step. The trainer below is given the
# step *class*, not this instance.
CROSS_ENTROPY_STEP = CrossEntropyStep(
    train=True,
    model=CNN_MODEL,
    optimizer=OPTIMIZER,
    base_prng=BASE_PRNG,
)

# Create the trainer, telling it which step class to use.
trainer = simple_trainer.SimpleTrainer(
    model=CNN_MODEL,
    optimizer=OPTIMIZER,
    step_class=CrossEntropyStep,  # The class itself (same object as CROSS_ENTROPY_STEP.__class__)
    batch_spec=BATCH_SPEC,
    base_prng=BASE_PRNG,
    epochs=NUM_EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    checkpointing_config=CHECKPOINTING_CONFIG,
    log_num_params=True,
)

In [9]:
# Run training, feeding the trainer batches from the training dataset via a NumPy iterator.
train_outputs = trainer.train(train_ds.as_numpy_iterator())
# Keep a handle on the trainer's model state; it is used for evaluation below.
trained_model_state = trainer.model_state

## Testing & Visualization

In [10]:
@jax.jit
def predict_batch(model_state: types.TrainState, batch_input_features: jax.Array) -> jax.Array:
  """Runs a forward pass with `train=False` and returns the model's logits.

  Args:
    model_state: State object providing `apply_fn` and `params`.
    batch_input_features: A batch of model inputs.

  Returns:
    Whatever `apply_fn` returns for this batch (the logits).
  """
  variables = {'params': model_state.params}
  return model_state.apply_fn(variables, batch_input_features, train=False)

# Per-batch metrics plus running totals over the whole test set.
all_test_losses = []
all_test_accuracies = []
total_samples = 0
total_correct_predictions = 0

# Declared up front so the last batch and its logits stay available after the
# loop (they are reused for visualization).
test_batch = None
logits = None

for test_batch in test_ds.as_numpy_iterator():
  batch_input_features = test_batch['input_features']
  batch_labels = test_batch['label']
  num_samples_in_batch = batch_labels.shape[0]

  logits = predict_batch(trained_model_state, batch_input_features)

  batch_loss = cross_entropy_loss_fn(logits, batch_labels)
  all_test_losses.append(batch_loss.item())

  predicted_classes = jnp.argmax(logits, axis=-1)
  correct_predictions_in_batch = jnp.sum(predicted_classes == batch_labels)
  # Convert the array to a Python scalar once and reuse it below instead of
  # calling .item() for every use.
  num_correct_in_batch = correct_predictions_in_batch.item()
  all_test_accuracies.append(num_correct_in_batch / num_samples_in_batch)

  total_correct_predictions += num_correct_in_batch
  total_samples += num_samples_in_batch

# Fail with an explicit message rather than a ZeroDivisionError (or an argmax
# over None) when the test pipeline produced nothing to evaluate.
if total_samples == 0:
  raise ValueError(
      "test_ds yielded no batches; nothing to evaluate. Check that the batch "
      "size does not exceed the number of test examples."
  )

# Loss is averaged over batches; accuracy is computed over all samples seen.
average_test_loss = sum(all_test_losses) / len(all_test_losses)
average_test_accuracy = total_correct_predictions / total_samples

print(f"Average Test Loss: {average_test_loss}")
print(f"Average Test Accuracy: {average_test_accuracy}")

# Store the last batch for visualization.
example_test_batch = test_batch
example_logits = logits
example_preds = jnp.argmax(example_logits, axis=-1)

In [11]:
# Show up to 25 examples from the stored test batch in a 5x5 grid. Each title
# is green when the prediction matches the true label and red otherwise.
fig, axs = plt.subplots(5, 5, figsize=(12, 12))
fig.suptitle('MNIST Test Set Predictions', fontsize=16)

for i, ax in enumerate(axs.flatten()):
    if i < len(example_preds):
        # Plain Python ints make the comparison below an ordinary bool and keep
        # the title text independent of array formatting.
        true_label = int(example_test_batch['label'][i])
        predicted_label = int(example_preds[i])

        ax.imshow(jnp.squeeze(example_test_batch['input_features'][i]), cmap='gray')
        ax.set_title(f"True: {true_label}\nPred: {predicted_label}",
                     color='green' if true_label == predicted_label else 'red')

    # Axes are hidden for every panel, including unused ones.
    ax.axis('off')

plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()