# SimpleTrainer (Quick Start)


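This notebook is a quick-start introduction to the Jaxloop `SimpleTrainer`: it trains a small fully-connected network to fit y = sin(x) and compares the model's predictions before and after training.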

## Setup

In [22]:
from colabtools import adhoc_import
from flax import linen as nn
import jax
import jax.numpy as jnp
from jax.random import PRNGKey
import optax

with adhoc_import.Google3SubmittedChangelist():
  from jaxloop.trainers import simple_trainer

from jaxloop import types  # For types.TrainState

In [23]:
def datasets():
  """Builds the train and test splits for the ground truth y = sin(x).

  Training inputs are sampled from [-4, 4) with a step of 0.003; test inputs
  are sampled from the wider interval [-5, 5) with a step of 0.01.

  Returns:
    A tuple (train_x, train_y, test_x, test_y) of float32 arrays, each with
    shape (num_points, 1, 1).
  """

  def _as_model_input(values):
    # (num_points,) -> (num_points, 1, 1): append two trailing unit axes.
    return values[:, None, None]

  train_x = _as_model_input(jnp.arange(-4, 4, 0.003, dtype=jnp.float32))
  test_x = _as_model_input(jnp.arange(-5, 5, 0.01, dtype=jnp.float32))

  # Targets are computed element-wise, so they inherit the input shapes.
  return train_x, jnp.sin(train_x), test_x, jnp.sin(test_x)


# Util to shuffle and batch data as (x, y) tuples
def _batch_generator(
    data_x: jax.Array,
    data_y: jax.Array,
    batch_size: int,
    prng_key: jax.random.PRNGKey,
):
  """Yields batches of (x, y) tuples."""
  num_samples = data_x.shape[0]
  indices = jax.random.permutation(prng_key, jnp.arange(num_samples))

  data_x_shuffled = data_x[indices]
  data_y_shuffled = data_y[indices]

  batch_count = 0
  for i in range(0, num_samples, batch_size):
    x_batch = data_x_shuffled[i : i + batch_size]
    y_batch = data_y_shuffled[i : i + batch_size]

    x_batch = x_batch.reshape(x_batch.shape[0], -1)
    y_batch = y_batch.reshape(y_batch.shape[0], -1)

    if x_batch.shape[0] == batch_size:
      batch_count += 1
      yield {"input_features": x_batch, "output_features": y_batch}

In [24]:
class SimpleNN(nn.Module):
  """A fully-connected network: ReLU hidden layers plus a linear output layer.

  With the default attributes this is a 6-layer model: five Dense layers of 64
  units, each followed by a ReLU, and a final Dense layer with a single unit.

  Attributes:
    hidden_features: Number of units in every hidden Dense layer.
    num_hidden_layers: Number of hidden Dense + ReLU layers.
    output_features: Number of units in the final Dense layer.
  """

  hidden_features: int = 64
  num_hidden_layers: int = 5
  output_features: int = 1

  @nn.compact
  def __call__(self, x, train=False):
    """Applies the network to `x`.

    Args:
      x: Input array; the Dense layers act on its last axis.
      train: Accepted so callers can pass `train=...`, but ignored because the
        model contains no layers that behave differently during training.

    Returns:
      The output of the final Dense layer.
    """
    for _ in range(self.num_hidden_layers):
      x = nn.Dense(features=self.hidden_features)(x)
      x = nn.relu(x)
    return nn.Dense(features=self.output_features)(x)


# Default configuration: 5 hidden layers of 64 units and a scalar output.
NN_MODEL = SimpleNN()

In [25]:
# Hyperparameters
batch_size = 16
learning_rate = 0.005

# Take the training-set size from datasets() itself instead of repeating its
# arange(-4, 4, 0.003) literals here, so the step count cannot drift out of
# sync if the data definition changes.
num_total_points = datasets()[0].shape[0]

NUM_EPOCHS = 1
# _batch_generator only yields full batches, hence the floor division.
STEPS_PER_EPOCH = num_total_points // batch_size

In [26]:
"""A BatchSpec is a dictionary that maps feature names to tuples of (shape, dtype).

The shape is a tuple of integers representing the shape of the feature, and the
dtype is a NumPy dtype object representing the data type of the feature.

In this case each input feature (an x value) is a 1D array of shape (1,) with
dtype float32, and each output feature (a y value) has the same shape and
dtype. These are per-example shapes: the batches yielded by _batch_generator
have shape (batch_size, 1).

The SimpleTrainer expects keys to be "input_features" and "output_features", if
this is not the case you must override _get_input_features() and
_get_output_features() respectively.
"""

# Both features share the same per-example spec: shape (1,), dtype float32.
BATCH_SPEC = {
    feature_name: ((1,), jnp.float32)
    for feature_name in ("input_features", "output_features")
}

In [27]:
OPTIMIZER = optax.adam(learning_rate)

# Get data splits
train_x_arr, train_y_arr, test_x_arr, test_y_arr = datasets()

# Derive two independent keys from the seed. A JAX key should be consumed only
# once, so the key handed to the trainer is kept separate from `prng`, which is
# split again later for batch shuffling.
prng_seed = 0
prng, init_prng = jax.random.split(PRNGKey(prng_seed))

BASE_PRNG = {"params": init_prng}

# A CheckpointingConfig can be created via simple_trainer.trainer_utils.CheckpointingConfig(...)
CHECKPOINTING_CONFIG = None

## Initialization & Training

In addition to the parameters shown below, the `partitioner`, `step_class`, `train_loop_class`, `eval_loop_class`, and `outer_loop_class` arguments can also be customized.

In [28]:
# Build the SimpleTrainer from the model, optimizer and settings defined above.
trainer = simple_trainer.SimpleTrainer(
    # What to train and how to update it.
    model=NN_MODEL,
    optimizer=OPTIMIZER,
    # Training schedule.
    epochs=NUM_EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    # Layout of the batches fed to the trainer.
    batch_spec=BATCH_SPEC,
    # Randomness, checkpointing and logging.
    base_prng=BASE_PRNG,
    checkpointing_config=CHECKPOINTING_CONFIG,
    log_num_params=True,
)

In [29]:
# Keep a reference to the state before training for later comparison.
untrained_state = trainer.model_state

prng_train_batch, prng_plot = jax.random.split(prng)

# Shuffled, fixed-size training batches.
train_data_generator = _batch_generator(
    data_x=train_x_arr,
    data_y=train_y_arr,
    batch_size=batch_size,
    prng_key=prng_train_batch,
)

# Run training on the generated batches.
train_outputs = trainer.train(train_data_generator)

# State held by the trainer once training has finished.
trained_state = trainer.model_state

In [30]:
@jax.jit
def pred_step(state: types.TrainState, batch_x: jax.Array):
  """Runs the model forward in inference mode.

  Args:
    state: Train state providing `apply_fn` and `params`.
    batch_x: Inputs passed to the model.

  Returns:
    The result of `state.apply_fn` on `batch_x` with `train=False`.
  """
  variables = {'params': state.params}
  return state.apply_fn(variables, batch_x, train=False)


# Predict on the same test inputs with the untrained model, then the trained one.
untrained_preds, trained_preds = (
    pred_step(model_state, test_x_arr)
    for model_state in (untrained_state, trained_state)
)

## Visualization

In [31]:
import matplotlib.pyplot as plt

# Flatten everything to 1D so each series is an explicit vector of points.
xs_plot = test_x_arr.ravel()
ys_true_plot = test_y_arr.ravel()
untrained_plot = untrained_preds.ravel()
trained_plot = trained_preds.ravel()

fig, ax = plt.subplots(figsize=(10, 10))
ax.scatter(
    xs_plot, untrained_plot, c="blue", label="untrained (SimpleTrainer)"
)
ax.scatter(xs_plot, trained_plot, c="purple", label="pred (SimpleTrainer)")
ax.scatter(xs_plot, ys_true_plot, c="red", label="true")
ax.legend(loc="upper left")
ax.set_title("Jaxloop SimpleTrainer: y = sin(x)")
ax.set_xlabel("x")
ax.set_ylabel("y")
plt.show()

# optax.l2_loss is 0.5 * (pred - target)**2, so its mean is only half the MSE.
# Compute the plain mean squared error so the printed label is accurate.
final_test_loss = jnp.mean(jnp.square(trained_preds - test_y_arr))
print(f"Final MSE loss on the test set: {final_test_loss}")