In [4]:
import treescope

In [5]:
from __future__ import annotations

import typing
from typing import Any

import dataclasses

import jax
import jax.numpy as jnp
import numpy as np

import IPython

import numpy as np
# Quick sanity check shown as the cell output: the integers 0..9 as a NumPy array.
np.arange(0, 10)


array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [4]:
# Ten 20x50 standard-normal arrays keyed "array_0" .. "array_9"; array i is
# drawn from PRNG key i, so the contents are reproducible across runs.
some_arrays = dict(
    (f"array_{idx}", jax.random.normal(jax.random.key(idx), (20, 50)))
    for idx in range(10)
)
some_arrays

{'array_0': Array([[ 1.62264216e+00,  2.02526474e+00, -4.33594435e-01,
         -7.86173493e-02,  1.76090896e-01, -9.72089231e-01,
         -4.95298743e-01,  4.94378597e-01,  6.64349318e-01,
         -9.50163484e-01,  2.17953038e+00, -1.95515060e+00,
          3.58570725e-01,  1.57795131e-01,  1.27708471e+00,
          1.51046479e+00,  9.70655978e-01,  5.99608064e-01,
          2.47007050e-02, -1.91647720e+00, -1.85934913e+00,
          1.72814405e+00,  4.71903495e-02,  8.14127982e-01,
          1.31327674e-01,  2.82847047e-01,  1.24359429e+00,
          6.90280080e-01, -8.00737441e-01, -7.40989983e-01,
         -1.53882873e+00,  3.02691847e-01, -2.07160451e-02,
          1.13287210e-01, -2.20654696e-01,  7.05225617e-02,
          8.53295803e-01, -8.21773827e-01, -1.46142114e-02,
         -1.50462165e-01, -9.00135219e-01, -7.59072721e-01,
          3.33095133e-01,  8.09249043e-01,  4.26925533e-02,
         -5.77671230e-01, -4.14398938e-01, -1.94125330e+00,
          1.31611836e+00,  7.

In [5]:
# Display `some_arrays` with treescope's ArrayAutovisualizer made active only for
# the duration of this `with` block (set_scoped; presumably the previously active
# autovisualizer is restored on exit -- confirm against the treescope docs).
with treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()):
  treescope.display(some_arrays)


In [8]:
# Presumably configures treescope as the default renderer for subsequent cell
# outputs (see treescope docs). NOTE(review): this cell executed as In [8], after
# cells that already used treescope; for Restart & Run All it belongs near the imports.
treescope.basic_interactive_setup()


In [10]:
import dataclasses

# Three-field container with untyped (Any) fields; used below to build a list of
# nested values whose display exercises long and multi-line field reprs.
@dataclasses.dataclass
class MyDataclass:
  a: Any
  b: Any
  c: Any

class TheZenOfPython:
  """Object whose repr is a multi-line rendering of the Zen of Python.

  Handy for checking how a viewer lays out long reprs that contain newlines.
  """

  def __repr__(self):
    aphorisms = (
        "Beautiful is better than ugly.",
        "Explicit is better than implicit.",
        "Simple is better than complex.",
        "Complex is better than complicated.",
        "Flat is better than nested.",
        "Sparse is better than dense.",
        "Readability counts.",
        "Special cases aren't special enough to break the rules.",
        "Although practicality beats purity.",
        "Errors should never pass silently.",
        "Unless explicitly silenced.",
        "In the face of ambiguity, refuse the temptation to guess.",
        "There should be one-- and preferably only one --obvious way to do it.",
        "Although that way may not be obvious at first unless you're Dutch.",
        "Now is better than never.",
        "Although never is often better than *right* now.",
        "If the implementation is hard to explain, it's a bad idea.",
        "If the implementation is easy to explain, it may be a good idea.",
        "Namespaces are one honking great idea -- let's do more of those!",
    )
    # Header line, one aphorism per line, and a closing angle bracket.
    return "<The Zen of Python:\n" + "\n".join(aphorisms) + ">"

# Ten dataclasses with increasingly long string fields (field `c` is multi-line),
# followed by one whose fields all have long multi-line reprs.
[
    *(MyDataclass("a" * n, "b" * n, "cccc\n" * n) for n in range(10)),
    MyDataclass(TheZenOfPython(), TheZenOfPython(), TheZenOfPython()),
]


In [11]:
# Nested lists mixing ints, strings, the singletons True/False/None/
# NotImplemented/Ellipsis, and a multi-line string; displayed as the cell output.
[
    [1, 2, 3, 4],
    ["a", "b", "c", "d"],
    [True, False, None, NotImplemented, Ellipsis],
    ["a\n  multiline\n    string"]
]

In [12]:
# Frozen demo dataclass; `some_list` gets a fresh empty list per instance via
# default_factory (a shared mutable default is not allowed by dataclasses).
# NOTE(review): frozen=True (with the default eq=True) makes dataclasses generate
# __hash__ over all fields, so hash(Bar(...)) raises TypeError because
# `some_list` is a list. Use a tuple field if instances must be hashable.
@dataclasses.dataclass(frozen=True)
class Bar:
  c: str
  d: int
  some_list: list = dataclasses.field(default_factory=list)

# Render explicitly through IPython's display machinery rather than relying on
# a bare trailing expression.
IPython.display.display(Bar(c="bar", d=2))

In [13]:
# Same kind of value, shown via the cell's trailing-expression display.
Bar("bar", 2)

In [1]:
from functools import partial
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad
from jax.nn.initializers import Initializer, truncated_normal

import optax

import tensorflow_datasets as tfds

from jaxtyping import PRNGKeyArray, PyTree, Array, Num
from pydantic import BaseModel, ConfigDict
from typing import Callable


class ModelBase(BaseModel):
    """Immutable model description with explicitly threaded params and states.

    Subclasses declare their architecture as (frozen) pydantic fields and
    implement ``params`` and ``forward``; ``states`` may be overridden. The
    parameter and state pytrees are created by ``init`` and passed into every
    call, so the model object itself never changes.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)

    def params(self, rng: PRNGKeyArray) -> PyTree:
        """Build the trainable parameter pytree from ``rng``. Must be overridden."""
        raise NotImplementedError

    def states(self, rng: PRNGKeyArray) -> PyTree:
        """Build the non-trainable state pytree; this default ignores ``rng``."""
        return dict(is_training=True)

    def init(self, rng: PRNGKeyArray) -> tuple[PyTree, PyTree]:
        """Return ``(params, states)``, each built from its own half of ``rng``."""
        keys = jax.random.split(rng)
        return self.params(keys[0]), self.states(keys[1])

    def forward(self, ps: PyTree, x: PyTree, st: PyTree) -> tuple[PyTree, PyTree]:
        """Map ``(params, input, state)`` to ``(output, new_state)``. Must be overridden."""
        raise NotImplementedError

    def __call__(self, ps: PyTree, x: PyTree, st: PyTree) -> tuple[PyTree, PyTree]:
        """Alias for ``forward`` so a model can be applied like a function."""
        return self.forward(ps, x, st)


class Dense(ModelBase):
    """Fully connected layer computing ``activation(x @ w + b)``.

    Fields:
      in_dim, out_dim: input/output feature sizes; ``w`` has shape
        ``(in_dim, out_dim)`` and ``b`` has shape ``(out_dim,)``.
      w_init, b_init: JAX initializers, called as ``init(key, shape)``.
        NOTE(review): biases use truncated_normal rather than zeros -- confirm
        this is intentional.
      activation: optional elementwise function applied after the bias;
        ``None`` means identity.

    The layer has no state of its own: ``forward`` returns ``st`` unchanged.
    """

    in_dim: int
    out_dim: int
    w_init: Initializer = truncated_normal()
    b_init: Initializer = truncated_normal()
    activation: None | Callable = None

    def params(self, rng: PRNGKeyArray) -> PyTree:
        """Return ``{"w": ..., "b": ...}`` drawn from two independent keys."""
        rng_w, rng_b = jax.random.split(rng)
        return {
            "w": self.w_init(rng_w, (self.in_dim, self.out_dim)),
            "b": self.b_init(rng_b, (self.out_dim,)),
        }

    def forward(
        self, ps: PyTree, x: Num[Array, "... d"], st: PyTree
    ) -> tuple[Num[Array, "... h"], PyTree]:
        """Apply the layer to ``x`` (any leading dims) and pass ``st`` through.

        ``st`` is whatever ``states`` produced -- with the inherited default that
        is ``{"is_training": True}``, not ``None`` -- and is returned unchanged.
        """
        # Contract only the last axis of x with w, so leading batch dims are kept.
        h = jnp.einsum("...d,dh->...h", x, ps["w"])
        o = h + ps["b"]
        # Explicit None check: the field is Optional[Callable], and truthiness of an
        # arbitrary callable object is not a reliable "is set" test.
        if self.activation is not None:
            o = self.activation(o)
        return o, st


class Chain(ModelBase):
    """Sequential composition: each layer's output is the next layer's input.

    ``params`` and ``states`` are lists with one entry per layer, in layer order.
    """

    layers: tuple[ModelBase, ...]

    def params(self, rng: PRNGKeyArray) -> PyTree:
        """Return one parameter pytree per layer, each from its own split key."""
        rngs = jax.random.split(rng, len(self.layers))
        return [layer.params(key) for layer, key in zip(self.layers, rngs)]

    def states(self, rng: PRNGKeyArray) -> list[PyTree]:
        """Return one state pytree per layer, each from its own split key.

        Splitting (mirroring ``params``) avoids handing the same key to every
        layer, which would correlate any randomly initialized per-layer state.
        """
        rngs = jax.random.split(rng, len(self.layers))
        return [layer.states(key) for layer, key in zip(self.layers, rngs)]

    def forward(
        self, ps: list[PyTree], x: PyTree, st: tuple[PyTree, ...]
    ) -> tuple[PyTree, tuple[PyTree, ...]]:
        """Run ``x`` through every layer in order, threading per-layer state.

        Returns the final output and a tuple of each layer's updated state.
        NOTE(review): ``states``/``init`` produce a list but this returns a tuple;
        under ``jax.jit`` the changed pytree structure forces one extra trace on
        the second call. Consider making the container types agree.
        """
        h = x
        new_states = []
        for layer, p, s in zip(self.layers, ps, st):
            h, s_out = layer(p, h, s)
            new_states.append(s_out)
        return h, tuple(new_states)


# Hyperparameters.
step_size = 0.01  # SGD learning rate
batch_size = 32

# MNIST input pipelines.
# Training: an endless shuffled stream of fixed-size batches, cut off after 1000
# steps. Stopping mid-epoch is presumably what triggers the tf.data "did not fully
# read the dataset being cached" warning seen in the output; it is harmless here.
train_ds = (
    tfds.load("mnist", split="train")
    .repeat()
    .shuffle(1024, seed=123)
    .batch(batch_size, drop_remainder=True)
    .take(1000)
    .as_numpy_iterator()
)
# Evaluation: the test split in fixed-size batches; drop_remainder=True discards
# the final partial batch. NOTE(review): .take(1000) is likely a no-op here since
# the test split yields far fewer batches.
test_ds = (
    tfds.load("mnist", split="test").batch(batch_size, drop_remainder=True).take(1000)
)

# 784 = 28 * 28 flattened pixels -> two ReLU hidden layers of 512 -> 10 class logits.
model = Chain(
    layers=(
        Dense(in_dim=784, out_dim=512, activation=jax.nn.relu),
        Dense(in_dim=512, out_dim=512, activation=jax.nn.relu),
        Dense(in_dim=512, out_dim=10),
    )
)

rng = jax.random.key(0)
params, states = model.init(rng)

# Use the configured learning rate instead of repeating the literal, so editing
# `step_size` actually changes training.
optimizer = optax.sgd(step_size)
opt_state = optimizer.init(params)


def loss_fn(model, params, states, x, y):
    """Mean softmax cross-entropy of ``model`` on inputs ``x`` and integer labels ``y``.

    Returns ``(loss, new_states)`` so it can be differentiated with
    ``value_and_grad(..., has_aux=True)``.
    """
    logits, new_states = model(params, x, states)
    per_example = optax.softmax_cross_entropy_with_integer_labels(logits, y)
    return per_example.mean(), new_states


def accuracy(model, params, states, ds=None):
    """Classification accuracy of ``model`` over every batch of a dataset.

    Args:
      model: callable ``model(params, x, states) -> (logits, states)``.
      params: parameters passed through to ``model``.
      states: states passed through to ``model``; returned states are discarded.
      ds: object with ``as_numpy_iterator()`` yielding dicts with ``"image"`` and
        ``"label"`` arrays. Defaults to the global ``test_ds``.

    Returns:
      ``n_correct / n_total`` as a Python float.

    Raises:
      ZeroDivisionError: if ``ds`` yields no examples.
    """
    if ds is None:
        ds = test_ds
    n_correct, n_total = 0, 0
    for batch in ds.as_numpy_iterator():
        labels = batch["label"]
        n = labels.shape[0]
        # Flatten each example to a vector using the batch's actual size, so a
        # short final batch or a different batch_size is handled correctly.
        x = jnp.reshape(batch["image"], (n, -1))
        logits, _ = model(params, x, states)
        preds = jnp.argmax(logits, axis=-1)
        n_correct += (preds == labels).sum().item()
        n_total += n
    return n_correct / n_total


@partial(jit, static_argnames=["model"])
def step(model, params, states, opt_state, x, y):
    """One jitted optimizer step.

    ``model`` is a static argument (it is a frozen pydantic model), so each
    distinct model gets its own compilation. Gradients are taken w.r.t.
    ``params`` only (argnums=1); ``states`` flow through as auxiliary output.
    The module-level ``optimizer`` is captured when the function is traced.

    Returns ``(loss, new_params, new_states, new_opt_state)``.
    """
    grad_fn = value_and_grad(loss_fn, argnums=1, has_aux=True)
    (loss, new_states), grads = grad_fn(model, params, states, x, y)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return loss, new_params, new_states, new_opt_state


# Train on every batch `train_ds` yields. Report test accuracy every `eval_every`
# steps (measured before that step's update) and once more after the last update,
# so the final trained parameters are actually evaluated.
eval_every = 100
for i, batch in enumerate(train_ds):
    if i % eval_every == 0:
        acc = accuracy(model, params, states)
        print(f"Step {i}, Accuracy: {acc:.4f}")
    x, y = batch["image"], batch["label"]
    # Flatten using the batch's real size rather than a hard-coded 32, so changing
    # `batch_size` cannot silently break this reshape.
    # NOTE(review): pixels are presumably raw uint8 values (not scaled to [0, 1]);
    # accuracy() feeds them the same way, so keep both in sync if adding scaling.
    loss, params, states, opt_state = step(
        model, params, states, opt_state, jnp.reshape(x, (x.shape[0], -1)), y
    )

acc = accuracy(model, params, states)
print(f"Final accuracy: {acc:.4f}")


Step 0, Accuracy: 0.0852
Step 100, Accuracy: 0.7899
Step 200, Accuracy: 0.9189
Step 300, Accuracy: 0.9266
Step 400, Accuracy: 0.9467
Step 500, Accuracy: 0.8985
Step 600, Accuracy: 0.9448
Step 700, Accuracy: 0.9182
Step 800, Accuracy: 0.9491
Step 900, Accuracy: 0.9504


2025-08-03 12:30:10.553748: W tensorflow/core/kernels/data/cache_dataset_ops.cc:916] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


In [7]:
# Render the trained parameter pytree (a list with one {"w", "b"} dict per Dense
# layer) with treescope's ArrayAutovisualizer active only inside this block.
with treescope.active_autovisualizer.set_scoped(treescope.ArrayAutovisualizer()):
    treescope.display(params)