In [1]:
import stanza.runtime
stanza.runtime.setup()

from stanza.data import PyTreeData
from stanza import train

from typing import Sequence

import optax
import jax
import jax.numpy as jnp
import flax.linen as nn
import stanza.train.ipython
import stanza.train.wandb

In [2]:
X = jnp.arange(1000).reshape((100, 10))
Y = (jnp.arange(100) > 50).reshape((100,))
train_data = PyTreeData((X, Y))

In [3]:
class SimpleMLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        for i, feat in enumerate(self.features):
            x = nn.Dense(feat, name=f'layers_{i}')(x)
            if i != len(self.features) - 1:
                x = nn.relu(x)
        return x

def loss(vars, rng_key, sample, iteration):
    x, y = sample
    y_hat = model.apply(vars, x)
    loss = jnp.sum(jnp.square(y_hat - y))
    return train.LossOutput(
        loss=loss,
        metrics={"loss": loss}
    )
loss_fn = train.batch_loss(loss)

model = SimpleMLP([20, 10])

In [5]:
import wandb
run = wandb.init(reinit=True)
print(run.url)

optimizer = optax.adamw(1e-4)

vars = model.init(jax.random.key(42), jnp.zeros_like(train_data.structure[0]))
opt_state = optimizer.init(vars["params"])

with train.loop(train_data, 
            batch_size=16, 
            rng_key=jax.random.key(42),
            iterations=1000,
            progress=True
        ) as loop:
    for epoch in loop.epochs():
        for step in epoch.steps():
            # *note*: consumes opt_state, vars
            opt_state, vars, metrics = train.step(
                loss_fn, optimizer, opt_state, vars, 
                step.rng_key, step.batch,
                # extra arguments for the loss function
                iteration=step.iteration
            )
            if step.iteration % 100 == 0:
                train.ipython.log(step.iteration, metrics)
                train.wandb.log(step.iteration, metrics, run=run)
    train.ipython.log(step.iteration, metrics)
    train.wandb.log(step.iteration, metrics, run=run)

https://wandb.ai/dpfrommer-projects/stanza-projects_examples_notebooks_common/runs/wm25rwnw


Iteration       [32m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [35m1000/1000[0m [35m100%[0m [36m0:00:00[0m [33m0:00:01[0m
Epoch           [32m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [35m167/167  [0m [35m100%[0m [36m0:00:00[0m [33m0:00:01[0m
Epoch Iteration [32m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [35m4/4      [0m [35m100%[0m [36m0:00:00[0m [33m0:00:00[0m


[2;36m[15:56:55][0m[2;36m [0m[34mINFO    [0m stanza.train -      [1;36m0[0m | loss: [1;36m1134275.625[0m
[2;36m          [0m[2;36m [0m[34mINFO    [0m stanza.train -    [1;36m100[0m | loss: [1;36m558311.9375[0m
[2;36m          [0m[2;36m [0m[34mINFO    [0m stanza.train -    [1;36m200[0m | loss: [1;36m375918.28125[0m
[2;36m          [0m[2;36m [0m[34mINFO    [0m stanza.train -    [1;36m300[0m | loss: [1;36m123352.2109375[0m
[2;36m[15:56:56][0m[2;36m [0m[34mINFO    [0m stanza.train -    [1;36m400[0m | loss: [1;36m133825.625[0m
[2;36m          [0m[2;36m [0m[34mINFO    [0m stanza.train -    [1;36m500[0m | loss: [1;36m56262.05859375[0m
[2;36m          [0m[2;36m [0m[34mINFO    [0m stanza.train -    [1;36m600[0m | loss: [1;36m37692.49609375[0m
[2;36m          [0m[2;36m [0m[34mINFO    [0m stanza.train -    [1;36m700[0m | loss: [1;36m8036.95751953125[0m
[2;36m          [0m[2;36m [0m[34mINFO    [0m stanza.train - 