In [None]:
import ciclo

In [None]:
from ciclo import managed

def create_managed_state(strategy: str = "jit"):
    """Build a ``ManagedState`` around a freshly initialised ``Linear`` model.

    The model variables are initialised with a fixed ``PRNGKey(0)`` and a
    placeholder input of shape ``(1, 28, 28, 1)``; the optimizer is AdamW with
    a learning rate of ``1e-3``.

    Args:
        strategy: forwarded unchanged to ``ManagedState.create``
            (defaults to ``"jit"``).

    Returns:
        The object returned by ``managed.ManagedState.create``.
    """
    net = Linear()

    # The placeholder batch is only there so `init` can be called with an
    # input of the expected shape.
    init_key = jax.random.PRNGKey(0)
    placeholder_batch = jnp.empty((1, 28, 28, 1))
    initial_params = net.init(init_key, placeholder_batch)["params"]

    optimizer = optax.adamw(1e-3)
    return managed.ManagedState.create(
        apply_fn=net.apply,
        params=initial_params,
        tx=optimizer,
        strategy=strategy,
    )

In [None]:
@managed.train_step
def managed_train_step(state: managed.ManagedState, batch):
    """Run the forward pass for one batch and report its loss and metrics.

    Args:
        state: managed state providing ``apply_fn`` and ``params``.
        batch: mapping with an ``"image"`` entry (model inputs) and a
            ``"label"`` entry (integer class labels).

    Returns:
        ``(logs, state)``: the populated logs together with the state, which
        this function passes through without modifying it itself.
    """
    images = batch["image"]
    targets = batch["label"]

    predictions = state.apply_fn({"params": state.params}, images)
    per_example_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=predictions, labels=targets
    )
    loss = per_example_loss.mean()
    accuracy = jnp.mean(jnp.argmax(predictions, -1) == targets)

    logs = ciclo.logs()
    # At least one loss has to be registered: losses are what gradients are
    # computed from.
    logs.add_loss("loss", loss)
    # Anything else worth tracking is added as a metric.
    logs.add_metric("accuracy", accuracy)
    logs.add_metric("loss", loss)
    return logs, state

In [None]:
# Step budget: passed to the loop as `stop` and to the progress bar as `total`.
total_steps = 5_000
state = create_managed_state(strategy="jit")  # try "data_parallel" 🤯

# The checkpoint directory is suffixed with the current integer timestamp.
checkpoint_dir = f"logdir/getting_started/{int(time())}"

# Schedule -> list of callbacks.
# NOTE(review): the two `ciclo.every(1)` keys are separate objects; this
# relies on them not comparing equal, otherwise the second entry would
# replace the first. Confirm against ciclo before merging them.
schedule = {
    ciclo.every(1): [managed_train_step],
    ciclo.every(steps=1000): [ciclo.checkpoint(checkpoint_dir)],
    ciclo.every(1): [ciclo.keras_bar(total=total_steps, interval=0.4)],
}

state, history, elapsed = ciclo.loop(
    state,
    ds_train.as_numpy_iterator(),
    schedule,
    stop=total_steps,
)

In [None]:
import matplotlib.pyplot as plt

# Pull the "steps", "loss" and "accuracy" series out of the `history`
# returned by the training loop, one value per requested key, for plotting.
steps, loss, accuracy = history.collect("steps", "loss", "accuracy")

def plot_metrics(steps, loss, accuracy):
    """Show loss and accuracy curves side by side in a single figure.

    Args:
        steps: x-axis values shared by both curves.
        loss: values plotted in the left panel, titled "Loss".
        accuracy: values plotted in the right panel, titled "Accuracy".
    """
    _, (loss_ax, accuracy_ax) = plt.subplots(1, 2)
    panels = (
        (loss_ax, loss, "Loss"),
        (accuracy_ax, accuracy, "Accuracy"),
    )
    for axis, series, title in panels:
        axis.plot(steps, series)
        axis.set_title(title)
    plt.show()

# Render the loss and accuracy curves for the collected series.
plot_metrics(steps, loss, accuracy)