## Custom Jaxloop

Jaxloop comprises five key parts: model, datasets, steps, inner loops, and outer loop. To construct an experiment using Jaxloop, follow these steps:

1. **Define the Model:**  Specify the architecture and functionality of your machine learning model.
2. **Define Datasets:**  Prepare your training and evaluation data, ensuring proper formatting and loading procedures.
3. **Define Steps:**  Create individual functions for distinct operations, such as calculating loss, updating parameters, or evaluating metrics. In Keras terms, a step is similar to an "iteration" - a single forward and backward pass on a batch of data.
4. **Define Inner Loops:**  Combine steps into iterative processes, such as a training loop that updates model parameters over multiple batches of data or an evaluation loop that assesses the performance of those parameters. In Keras, an inner loop is analogous to an "epoch": one complete pass through the entire dataset (or one pass over a given number of data samples).
5. **Define and Invoke the Outer Loop:**  Orchestrate the overall experimental workflow, potentially encompassing multiple inner loops for training, evaluation, and hyperparameter tuning. In Keras, this is similar to the overall "training loop" that encompasses all epochs.

Each of these components will be explored in detail below.

<!-- TODO(b/379344058) Describe how to install Jaxloop and its dependencies using pip or link to installation instructions -->
**Note:** To run the code examples provided in this documentation, please install Jaxloop and its dependencies.

# Model

Jaxloop is compatible with models written using either [Flax Linen](http://shortn/_x49iy2sUHL) or the newer [Flax NNX](http://shortn/_wisT9XxhAv) API.

For the purposes of this colab, we'll utilize a standard CNN example.

In [None]:
from flax import linen as nn

class CNN(nn.Module):

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x

# Datasets

Jaxloop requires the training dataset to be provided as an **iterator**. This is essential because training involves iterating over different subsets of the dataset.

Here's why:

* **Statefulness:** Iterators maintain an internal state, allowing them to keep track of their position within the dataset. This ensures that each training epoch covers distinct data subsets.
* **Lazy Evaluation:** Iterators generate data on demand, rather than loading the entire dataset into memory at once. This is crucial for handling large datasets efficiently.

Jaxloop offers flexibility in how you provide the training data:

* **Generators:** You can use Python generators to create custom data iterators.
* **Data Ingestion Frameworks:**  Integrate seamlessly with popular frameworks like [TF Data](http://shortn/_oX61MDRjYv) or [PyGrain](http://shortn/_aCGKZaAhTI) for advanced data loading and preprocessing.
* **Automatic Conversion:** If you happen to provide a non-iterator iterable (like a list), Jaxloop will automatically convert it into an iterator internally.

In contrast to the training dataset, the evaluation dataset in Jaxloop must be **an iterable object, not an iterator**. This is because evaluation typically uses the same, fixed subset of data, so it must be possible to iterate over that data from the beginning on every evaluation run.
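
The difference between an iterator and a re-iterable object can be seen in plain Python, independent of Jaxloop. The sketch below is only an illustration: `batches` is a hypothetical generator standing in for a real data pipeline, and the fake batches carry no real image data.

```python
from collections.abc import Iterable, Iterator


def batches(num_batches):
  """Hypothetical generator that lazily yields one fake batch at a time."""
  for i in range(num_batches):
    yield {'image': [i], 'label': i}


# A generator is an iterator: it is stateful and can be consumed only once.
train_iter = batches(3)
# A list is an iterable: every `for` loop over it starts from the first item.
eval_data = list(batches(3))

first_pass = [b['label'] for b in train_iter]
second_pass = [b['label'] for b in train_iter]  # Already exhausted.

eval_pass_1 = [b['label'] for b in eval_data]
eval_pass_2 = [b['label'] for b in eval_data]  # Restarts from the beginning.

is_train_iterator = isinstance(train_iter, Iterator)
is_eval_iterator = isinstance(eval_data, Iterator)
is_eval_iterable = isinstance(eval_data, Iterable)
```

Here the second pass over `train_iter` yields nothing, while both passes over `eval_data` see the same batches, which is the behavior an evaluation loop relies on.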

In [None]:
import os
import tensorflow as tf
import tensorflow_datasets as tfds

def mnist_datasets(batch_size, data_dir):
  def map_fn(x):
    return {
        'image': tf.cast(x['image'], tf.float32) / 255.0,
        'label': tf.cast(x['label'], tf.int32),
    }

  train_dir = os.path.join(data_dir, 'train')
  test_dir = os.path.join(data_dir, 'test')
  train_ds = (
      tfds.load('mnist', data_dir=train_dir, split='train', shuffle_files=True)
      .map(map_fn)
      .batch(batch_size, drop_remainder=True)
      .prefetch(tf.data.AUTOTUNE)
      .repeat()  # Repeat indefinitely so the training iterator never runs out.
  )
  eval_ds = (
      tfds.load('mnist', data_dir=test_dir, split='test', shuffle_files=False)
      .map(map_fn)
      .batch(batch_size, drop_remainder=True)
      .prefetch(tf.data.AUTOTUNE)
      .cache()
  )
  return train_ds, eval_ds

# Steps

In Jaxloop, a "step" refers to a modular unit of computation that processes a batch of data from your dataset. Each step is defined as a class inheriting from `jaxloop.Step`, and its core logic resides within the `run` function.

Here's a breakdown of key aspects:

* **`run` function:** This is the heart of a Jaxloop step. You, as the user, are responsible for implementing this function to define how a batch of data is processed. This could involve anything from computing the loss and gradients to updating model parameters or calculating evaluation metrics.
* **Train vs. Eval Steps:** A Jaxloop step can be designated as either a training step or an evaluation step. This distinction is controlled by the boolean `train` argument passed when the step is constructed, which is available inside the step as `self.train`. The implementation of the `run` function will typically differ between train and eval steps, reflecting the different tasks they perform.
* **Outputs:** The `run` function returns two values:
    * `train_state`:  An object containing the current state of your model and optimizer.
    * `output`: A dictionary (with string keys) where you can store any data relevant to the step's execution, such as loss values, accuracy metrics, gradients, or anything else you need to track.
* **JIT Compilation and Sharding:** To maximize performance, Jaxloop automatically compiles your `run` function using Just-In-Time (JIT) compilation and distributes the computation across multiple devices (if available) using a technique called sharding. The specific sharding strategy is determined by a "partitioner" object, which we'll discuss in more detail later.
* **`begin` and `end` functions:**  Jaxloop provides optional `begin` and `end` functions that you can implement within your step class. These hooks run immediately before and after `run`, letting you pre-process the step's inputs (the `MnistStep` example below uses `begin` to convert TensorFlow tensors in the batch to NumPy arrays) or post-process the resulting `train_state` and `output`.

By encapsulating different stages of your training and evaluation pipeline into these modular steps, Jaxloop promotes code organization, reusability, and flexibility in designing your experiments.
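
The `begin`, `run`, `end` flow described above can be sketched without any Jaxloop dependency. Everything in this block is hypothetical: `ToyStep` is not the `jaxloop.Step` implementation, the `end` signature is an assumption, and the "loss" is a placeholder. It only shows how the three hooks chain together and how a `train` flag changes what `run` does.

```python
from typing import Any

State = dict[str, Any]
Batch = dict[str, Any]
Output = dict[str, Any]


class ToyStep:
  """Hypothetical stand-in for a step; not the jaxloop.Step implementation."""

  def __init__(self, train: bool):
    self.train = train

  def begin(self, state: State, batch: Batch) -> tuple[State, Batch]:
    # Pre-processing hook: scale raw pixel values into [0, 1].
    return state, {**batch, 'x': [v / 255.0 for v in batch['x']]}

  def run(self, state: State, batch: Batch) -> tuple[State, Output]:
    # Core logic: the placeholder "loss" is the mean of the inputs.
    loss = sum(batch['x']) / len(batch['x'])
    if self.train:
      # Only a training step advances the state.
      state = {**state, 'step': state['step'] + 1}
    return state, {'loss': loss}

  def end(self, state: State, output: Output) -> tuple[State, Output]:
    # Post-processing hook (assumed signature): round values for logging.
    return state, {k: round(v, 4) for k, v in output.items()}

  def __call__(self, state: State, batch: Batch) -> tuple[State, Output]:
    state, batch = self.begin(state, batch)
    state, output = self.run(state, batch)
    return self.end(state, output)


train_state, train_out = ToyStep(train=True)({'step': 0}, {'x': [0, 255]})
eval_state, eval_out = ToyStep(train=False)({'step': 0}, {'x': [0, 255]})
```

Both steps report the same output, but only the training step returns a modified state, mirroring the train/eval split in a real step.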

In [None]:
from typing import Optional

import flax.linen as nn
import jax
import jax.numpy as jnp
from jaxloop import step as step_lib
from jaxloop import types
import optax
import tensorflow as tf

Batch = types.Batch
Output = types.Output
State = types.TrainState
Step = step_lib.Step

class MnistStep(Step):

  def begin(self, state: State, batch: Batch) -> tuple[State, Batch]:
    # Convert TensorFlow tensors to NumPy arrays before the batch reaches `run`.
    if isinstance(batch['image'], tf.Tensor):
      batch['image'] = batch['image'].numpy()
    if isinstance(batch['label'], tf.Tensor):
      batch['label'] = batch['label'].numpy()
    return state, batch

  def run(self, state: State, batch: Batch) -> tuple[State, Optional[Output]]:
    images, labels = batch['image'], batch['label']

    def loss_fn(params):
      logits = state.apply_fn({'params': params}, images)
      one_hot = jax.nn.one_hot(labels, 10)
      loss = jnp.mean(optax.softmax_cross_entropy(logits, one_hot))
      return loss, logits

    if self.train:
      grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
      (loss, logits), grads = grad_fn(state.params)
      state = state.apply_gradients(grads=grads)
    else:
      loss, logits = loss_fn(state.params)

    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    return state, {'loss': loss, 'accuracy': accuracy}

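A Jaxloop `Step` uses a partitioner to JIT-compile and shard [batch data](http://shortn/_CYzkbhgGP2), [model initialization](http://shortn/_6IvLhMS8nf), and the [step run](http://shortn/_B0hgnwfG4G). The cell below builds a data-parallel partitioner over a one-dimensional mesh of all available devices.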

<!-- TODO(b/379340967) update link to partitioner explanation -->
<!-- Please refer to this [section]() for more in-depth explanation of partitioners. -->


In [None]:
from jax import sharding
from jaxloop import partition
from jax.experimental import mesh_utils

Mesh = sharding.Mesh

num_devices = len(jax.devices())
mesh = Mesh(mesh_utils.create_device_mesh((num_devices,)), ('data',))
partitioner = partition.DataParallelPartitioner(mesh=mesh, data_axis='data')

# Inner Loops

Jaxloop provides "inner loops" to handle the iterative execution of your experiment steps, whether for training or evaluation.

Here's how they work:

* **Train Loop:** This loop focuses on training your model. It repeatedly executes your defined training step for a specified number of iterations (`train_loop_steps`). If `train_loop_steps` is not provided, the loop will iterate over the entire training dataset.
* **Eval Loop:** This loop is designed for evaluating your model's performance. It runs your defined evaluation step for a set number of iterations (`eval_spec.num_steps`) or until the entire evaluation dataset has been processed.

**Adding Functionality with Actions**

Jaxloop allows you to incorporate custom actions at the beginning and end of each inner loop. These actions are essentially functions that perform specific tasks periodically during training or evaluation.

* **Built-in Actions:** Jaxloop offers pre-built actions like:
    * `SummaryAction`:  Used for logging summaries (e.g., metrics, visualizations) during training.
    * `CheckpointAction`:  Handles saving model checkpoints at regular intervals.
* **Custom Actions:** You can also define your own actions to perform any operations you need, such as learning rate scheduling, early stopping, or custom logging.

We'll delve deeper into actions and how to use them effectively in a dedicated section later in this documentation.
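
As a rough mental model, an inner loop with end actions behaves like the plain-Python sketch below. All names here (`IntervalAction`, `toy_inner_loop`, `count_step`) are hypothetical and are not part of the Jaxloop API; in particular, how real actions decide when to fire is an assumption made for illustration.

```python
class IntervalAction:
  """Hypothetical action that fires only when the step count hits its interval."""

  def __init__(self, interval):
    self.interval = interval
    self.fired_at = []

  def __call__(self, state, outputs):
    if state['step'] % self.interval == 0:
      self.fired_at.append(state['step'])


def toy_inner_loop(step_fn, state, data_iter, num_steps, end_actions=()):
  """Runs `step_fn` on `num_steps` batches, then calls each end action once."""
  outputs = []
  for _ in range(num_steps):
    state, output = step_fn(state, next(data_iter))
    outputs.append(output)
  for action in end_actions:
    action(state, outputs)
  return state, outputs


def count_step(state, batch):
  """Placeholder step that only counts how many batches it has seen."""
  return {'step': state['step'] + 1}, {'seen': batch}


action = IntervalAction(interval=20)
state = {'step': 0}
data = iter(range(1000))
for _ in range(5):  # Five inner loops of 10 steps each.
  state, outputs = toy_inner_loop(count_step, state, data, 10, [action])
```

With an interval of 20 and loops of 10 steps, the action fires after every second loop, which is why an action's interval is usually chosen as a multiple of the loop length.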

In [None]:
from orbax import checkpoint
from clu import metric_writers
from jaxloop import actions
from etils import epath

import tempfile

work_dir = tempfile.mkdtemp()
work_dir = epath.Path(work_dir)

ckpt_manager = checkpoint.CheckpointManager(
    work_dir / 'checkpoints',
    checkpoint.Checkpointer(checkpoint.PyTreeCheckpointHandler()),
    checkpoint.CheckpointManagerOptions(max_to_keep=3),
)
metrics_writer = metric_writers.create_default_writer(
    work_dir,
    asynchronous=False,
)

ckpt_action = actions.CheckpointAction(ckpt_manager, interval=100)
summary_action = actions.SummaryAction(metrics_writer, interval=100)

# Outer Loop

The outer loop in Jaxloop serves as the orchestrator of your entire training experiment. It combines the train and eval inner loops, running them repeatedly until a specified number of total training steps (`train_total_steps`) is reached.

**Initiating the Training Process**

To kickstart your training experiment using the Jaxloop outer loop, you'll need to provide the following:

* **Initialized Model State:** Start by initializing your model's parameters and optimizer state. In the example below, `train_step.initialize_model` builds this state from the input shape.
* **Training Dataset Iterator:**  Provide your training dataset in the form of an iterator, as explained earlier.
* **`train_total_steps`:** Define the total number of steps you want the training process to run for. This determines the overall duration of the experiment.
* **`train_loop_steps`:** Specify the number of steps to execute within each individual training loop.
* **`eval_specs`:** Configure the evaluation phase by providing:
    * The evaluation dataset.
    * The interval (in terms of steps) at which you want to run the evaluation loop.

With these elements configured, the outer loop has everything it needs to alternate between the training and evaluation loops until `train_total_steps` is reached.
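
The relationship between `train_total_steps` and `train_loop_steps` can be sketched as a simple schedule. This is a simplified model under stated assumptions, not Jaxloop's actual scheduling logic: the hypothetical `toy_outer_loop` assumes one evaluation pass after every train loop and ignores any evaluation interval setting.

```python
def toy_outer_loop(train_total_steps, train_loop_steps, eval_num_steps):
  """Returns the schedule a simplified outer loop would follow."""
  schedule = []
  step = 0
  while step < train_total_steps:
    # Never overshoot the total step budget on the final train loop.
    steps_this_loop = min(train_loop_steps, train_total_steps - step)
    step += steps_this_loop
    schedule.append(('train', step))  # Global step after this train loop.
    schedule.append(('eval', eval_num_steps))  # Assumed: eval after each loop.
  return schedule


schedule = toy_outer_loop(
    train_total_steps=100, train_loop_steps=10, eval_num_steps=100
)
num_train_loops = sum(1 for kind, _ in schedule if kind == 'train')
```

With 100 total steps and 10 steps per loop, this model runs ten train loops, so end-of-loop actions get ten opportunities to fire.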

# Putting it all Together: A Custom Jaxloop Experiment

The following code snippet demonstrates how to orchestrate a complete training experiment using Jaxloop, integrating all the components we've discussed – models, datasets, steps, inner loops, and the outer loop. This serves as the main driver for your custom Jaxloop experiments.

In [None]:
from jaxloop import eval_loop as eval_loop_lib
from jaxloop import outer_loop as outer_loop_lib
from jaxloop import train_loop as train_loop_lib
import tempfile
import tensorflow_datasets as tfds

# 1. Define model.
model = CNN()

# 2. Define train and eval datasets.
data_dir = tempfile.mkdtemp()
train_ds, eval_ds = mnist_datasets(
    batch_size=32, data_dir=data_dir
)

# 3. Define training and eval steps.
train_step = MnistStep(
    jax.random.PRNGKey(0),
    model,
    optimizer=optax.sgd(learning_rate=0.005, momentum=0.9),
    partitioner=partitioner,
    train=True,
)
eval_step = MnistStep(
    jax.random.PRNGKey(0),
    model,
    partitioner=partitioner,
    train=False,
)

# 4. Define inner loops.
train_loop = train_loop_lib.TrainLoop(
    train_step, end_actions=[summary_action, ckpt_action]
)
eval_loop = eval_loop_lib.EvalLoop(eval_step, end_actions=[summary_action])

# 5. Define and invoke outer loop.
outer_loop = outer_loop_lib.OuterLoop(
    train_loop=train_loop, eval_loops=[eval_loop]
)
state, outputs = outer_loop(
    train_step.initialize_model([1, 28, 28, 1]),
    train_dataset=train_ds.as_numpy_iterator(),
    train_total_steps=100,
    train_loop_steps=10,
    eval_specs=[
        outer_loop_lib.EvalSpec(
            dataset=tfds.as_numpy(eval_ds), num_steps=100
        )
    ],
)