*Copyright 2024 The Penzai Authors.*

*Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at*

> http://www.apache.org/licenses/LICENSE-2.0

*Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.*

---

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/penzai/blob/main/notebooks/v2_how_to_think_in_penzai.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google-deepmind/penzai/blob/main/notebooks/v2_how_to_think_in_penzai.ipynb)

# How to Think in Penzai (v2 NN API)

Penzai prioritizes legibility, visualization, and easy editing of neural network models. It strives to follow a simple mental model, avoid magic wherever possible, and decompose into modular tools that can be combined without getting in your way. This means that Penzai models are often structured somewhat differently than models written with other libraries.

This document explains the key principles of Penzai's "v2" neural network system, which is available in `penzai.experimental.v2`. The v2 design attempts to remove some of the boilerplate in the initial design and simplify common operations on models.

This is intended to be a self-contained overview and does not require familiarity with the original design. (For a summary of the differences, see [this page](/guides/v2_differences).)

In [1]:
try:
  import penzai
except ImportError:
  !pip install penzai[notebook]

In [2]:
import collections
import dataclasses
import jax
import jax.numpy as jnp
from typing import Any, Callable, Sequence

In [3]:
import penzai
from penzai.experimental.v2 import pz
from penzai.experimental.v2.models import simple_mlp

## Principles

### 1. What You See is What You Get

The first central principle of Penzai is that models are designed to be visualizable by default.

Penzai includes a powerful interactive IPython pretty-printer with automatic embedded array visualizations (called Treescope), which can be used to look inside any JAX-compatible data structure. You can enable Treescope like this:

In [4]:
pz.ts.basic_interactive_setup()

Penzai goes out of its way to make sure that the pretty-printer representation of a model tells you everything you need to know about it:

- Every sublayer of the model is directly contained in its parent, and can be viewed by expanding it.
- Every parameter is an attribute of the layer that owns it, and all attributes are statically known and type-annotated.
- Most model objects are immutable, and all stateful modifications are constrained to explicit "Variable" objects.

For instance, here's what a simple MLP looks like in Penzai:


In [5]:
mlp = simple_mlp.MLP.from_config(
    name="mlp",
    init_base_rng=jax.random.key(0),
    feature_sizes=[8, 32, 32, 8]
)
mlp

Try clicking to expand or collapse different sublayers! We've turned on automatic array visualization, so if you expand one of the parameters, you can immediately visualize its shape and array data.

Importantly, this isn't just a pretty visualization of the model; it's a **full specification of the model structure**. Every attribute of the model object appears in the pretty-printed output, so if something doesn't show up there, it's not part of the model.

(Note: you can click on the pretty-printed output and press `r` to see the fully-qualified type of any object!)

### 2. Models Are Callable, Patchable Data Structures

To make it easier to inspect and modify models, Penzai prioritizes treating models as user-modifiable data structures, rather than as opaque objects.
Every Penzai model object is a frozen [Python dataclass](https://docs.python.org/3/library/dataclasses.html), which means that all of the instance variables of Penzai models are explicitly type-annotated and tracked.
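To see what "frozen dataclass" buys you, here is a minimal, stdlib-only sketch. `Scale` and `Chain` are hypothetical toy classes written for this illustration, not Penzai layers (real Penzai layers are `pz.Struct` pytree dataclasses with more machinery). The sketch shows three things: every attribute is declared up front and can be enumerated, in-place assignment is rejected, and edits are made by building a modified copy.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Scale:
  """Hypothetical toy layer: multiplies its input by a fixed factor."""
  factor: float

  def __call__(self, x, /, **side_inputs):
    return x * self.factor


@dataclasses.dataclass(frozen=True)
class Chain:
  """Hypothetical toy combinator: runs its sublayers in order."""
  sublayers: tuple

  def __call__(self, x, /, **side_inputs):
    for layer in self.sublayers:
      x = layer(x, **side_inputs)
    return x


model = Chain(sublayers=(Scale(2.0), Scale(3.0)))

# Every attribute is declared and annotated, so the structure is enumerable.
field_names = [f.name for f in dataclasses.fields(model)]

# Frozen dataclasses reject in-place assignment...
try:
  model.sublayers = ()
  mutation_blocked = False
except dataclasses.FrozenInstanceError:
  mutation_blocked = True

# ...so "editing" a model means building a modified copy.
edited = dataclasses.replace(model, sublayers=model.sublayers[:1])

print(field_names, mutation_blocked, model(1.0), edited(1.0))
```

Penzai's `pz.select` tooling follows the same copy-on-edit idea, but works on arbitrary JAX pytrees rather than on one hand-picked field.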

Models can be called with an input argument in order to run the model forward pass:

In [6]:
mlp(pz.nx.ones({"features": 8}))

You can also just as easily call one of their sublayers:

In [7]:
mlp.sublayers[2].sublayers[0]

In [8]:
mlp.sublayers[2].sublayers[0](pz.nx.ones({"features": 32}))

You can also modify the model's forward pass just as easily, by modifying the model data structure.

Penzai models are designed to be freely modified after they are built, including isolating small parts of larger models, combining models together, or inserting arbitrary logic at arbitrary points in a model's forward pass. And Penzai includes a structure-rewriting utility, `pz.select`, which lets you make arbitrary modifications to Penzai models using `.at(...).set(...)`-style syntax. For instance, you can find and remove particular layers:

In [9]:
# Find bias layers
pz.select(mlp).at_instances_of(pz.nn.AddBias)

In [10]:
# Remove them:
pz.select(mlp).at_instances_of(pz.nn.AddBias).remove_from_parent()

Or insert new layers to run new logic:

In [11]:
@pz.pytree_dataclass
class HelloWorld(pz.nn.Layer):
  def __call__(self, arg, **side_inputs):
    pz.show("Hello world! My value:", arg)
    return arg

In [12]:
# Insert a new layer after each nonlinearity:
patched = (
    pz.select(mlp).at_instances_of(pz.nn.Elementwise).insert_after(HelloWorld())
)
pz.select(patched).at_instances_of(HelloWorld)

In [13]:
# Run it:
patched(pz.nx.ones({"features": 8}))

Penzai models are registered as [JAX Pytree nodes](https://jax.readthedocs.io/en/latest/pytrees.html) (similar to Equinox) so that any Penzai model can be traversed using `jax.tree_util`. In fact, `pz.select` is a general-purpose tool for modifying any JAX Pytree! Modifications to Penzai models always occur by making a *modified copy* of the model, instead of being stored as global state. For instance, the model `patched` above is a modified copy of `mlp` that behaves differently when it is run, while `mlp` itself is unchanged.

Penzai models are also designed to be *as permissive as possible* about their contents after construction. For instance, the MLP class doesn't specifically require its children to be Affine layers (a.k.a. Dense layers), and doesn't run the activation functions directly. Instead, it is a subclass of `Sequential`, and it just runs its sublayers in sequence without caring about their types. This means we are free to insert new logic into an `MLP` at runtime to customize its behavior, without having to change its original code.
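The benefit of this permissiveness can be sketched without Penzai at all. In the stdlib-only toy below, `Sequential`, `AddConstant`, `Elementwise`, `Record`, and `insert_after_instances` are hypothetical stand-ins (loosely mimicking what `pz.select(...).at_instances_of(...).insert_after(...)` did in the `HelloWorld` example), not Penzai's implementation. Because the combinator never inspects its children's types, a new layer can be spliced in without touching any class definitions, and the original model is left unchanged.

```python
import dataclasses
from typing import Any, Callable


@dataclasses.dataclass(frozen=True)
class Sequential:
  """Toy combinator: calls each sublayer in order, whatever its type."""
  sublayers: tuple

  def __call__(self, x: Any, /, **side_inputs) -> Any:
    for layer in self.sublayers:
      x = layer(x, **side_inputs)
    return x


@dataclasses.dataclass(frozen=True)
class AddConstant:
  """Toy layer: adds a constant to every element of a list."""
  value: float

  def __call__(self, x, /, **side_inputs):
    return [v + self.value for v in x]


@dataclasses.dataclass(frozen=True)
class Elementwise:
  """Toy activation layer: applies `fn` to every element of a list."""
  fn: Callable[[float], float]

  def __call__(self, x, /, **side_inputs):
    return [self.fn(v) for v in x]


@dataclasses.dataclass(frozen=True)
class Record:
  """Toy inserted layer: records its input in a list and passes it through."""
  log: list

  def __call__(self, x, /, **side_inputs):
    self.log.append(list(x))
    return x


def insert_after_instances(seq: Sequential, cls: type, make_layer) -> Sequential:
  """Returns a copy of `seq` with `make_layer()` inserted after each `cls`."""
  new_sublayers = []
  for layer in seq.sublayers:
    new_sublayers.append(layer)
    if isinstance(layer, cls):
      new_sublayers.append(make_layer())
  return dataclasses.replace(seq, sublayers=tuple(new_sublayers))


model = Sequential((AddConstant(-1.0), Elementwise(abs), AddConstant(0.5)))
log = []
patched = insert_after_instances(model, Elementwise, lambda: Record(log))

out = patched([0.0, 3.0])  # Record sees the activation output [1.0, 2.0].
```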

### 3. Parameters And State Are Tracked With Explicit Variable Nodes

As Pytree nodes, Penzai model objects are immutable, which simplifies working with JAX and allows you to safely make copies of your model that behave in different ways. However, models often require keeping track of mutable state:
- Parameters are often updated by gradient descent, and shared parameters need to stay in sync.
- Some model configurations, like key-value caching in Transformers, require keeping track of per-layer model states while the model runs.
- It can be useful to save intermediate model activations so that you can inspect or modify them later.

Penzai supports this using "variable" nodes, which are explicit "pockets of mutable state" inside Penzai models. Each Penzai model tree has two types of leaves:
- JAX arrays and scalars, which are immutable and often represent hyperparameters, and
- Variable objects, which can be modified, and come in two variants:
  - `Parameter`s, which are usually modified by optimizers (not by the model),
  - `StateVariable`s, which are usually updated as the model runs.

For instance, the leaves of the `mlp` above are its parameters, each of which is an instance of `Parameter`:

In [14]:
jax.tree_util.tree_leaves(mlp)

The same parameter can appear multiple times in a single model. As an example, here's a model that repeats the same layer multiple times, along with a scaling factor:

In [15]:
layer = pz.nn.Affine.from_config(
    name="shared_layer",
    init_base_rng=jax.random.key(0),
    input_axes={"features": 8},
    output_axes={"features": 8},
)
my_model_with_repeats = pz.nn.Sequential([
    layer,
    pz.nn.Elementwise(jax.nn.relu),
    pz.nn.ConstantRescale(0.5),
    layer,
])
my_model_with_repeats

In this case, the PyTree leaves of this model will include each shared parameter twice (once per use of `layer`), as well as the rescaling hyperparameter:

In [16]:
jax.tree_util.tree_leaves(my_model_with_repeats)

To extract and deduplicate the parameters, you can use the helper function `pz.unbind_params`. This produces:
- A copy of the model with each `Parameter` replaced with a `ParameterSlot` placeholder,
- A tuple of all unique parameters in the model.
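The deduplication step can be illustrated with a small pure-Python toy. It is a hypothetical sketch, not Penzai's implementation: `ToyParameter`, `ToySlot`, `toy_unbind`, and `toy_bind` are invented names, and the "model" is just a flat list of leaves, whereas Penzai traverses arbitrary pytrees. It assumes that each parameter has a unique label, so that slots can refer back to parameters by label.

```python
import dataclasses
from typing import Any


class ToyParameter:
  """A mutable, labeled box standing in for a shared parameter."""

  def __init__(self, value: Any, label: str):
    self.value = value
    self.label = label


@dataclasses.dataclass(frozen=True)
class ToySlot:
  """Placeholder recording which parameter used to live at this position."""
  label: str


def toy_unbind(leaves: list) -> tuple[list, tuple]:
  """Swaps each parameter for a slot; returns each unique parameter once."""
  unique = {}
  unbound = []
  for leaf in leaves:
    if isinstance(leaf, ToyParameter):
      existing = unique.setdefault(leaf.label, leaf)
      if existing is not leaf:
        raise ValueError(f"Two different parameters share label {leaf.label!r}")
      unbound.append(ToySlot(leaf.label))
    else:
      unbound.append(leaf)
  return unbound, tuple(unique.values())


def toy_bind(unbound: list, params: tuple) -> list:
  """Substitutes parameters back into their slots, matching by label."""
  by_label = {p.label: p for p in params}
  return [by_label[x.label] if isinstance(x, ToySlot) else x for x in unbound]


# A "model" that uses the same parameter twice, plus a rescaling constant.
shared = ToyParameter([1.0, 2.0], label="shared_layer.weights")
leaves = [shared, 0.5, shared]

unbound, params = toy_unbind(leaves)
rebound = toy_bind(unbound, params)
```

Rebinding restores the sharing: both positions point at the same parameter object again, so an update to it is visible everywhere it is used.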

In [17]:
unbound_model, params = pz.unbind_params(my_model_with_repeats)

pz.show("unbound_model:", unbound_model)
pz.show("params:", params)

These parameters can then be substituted back into the model using `pz.bind_variables`.

You can implement stateful layers using a similar mechanism, but with `StateVariable` instead of `Parameter`. Here's a layer that stores its intermediate activation into a list:

In [18]:
@pz.pytree_dataclass
class SaveIntermediate(pz.nn.Layer):
  saved: pz.StateVariable[list[Any]]
  def __call__(self, x: Any, **unused_side_inputs) -> Any:
    self.saved.value = self.saved.value + [x]
    return x

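We can insert a copy of it after each of the MLP's two activation layers, with both copies sharing the same `StateVariable`. Calling the model then appends each intermediate activation to that variable, where we can retrieve it: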

In [19]:
var = pz.StateVariable(value=[], label="my_intermediate")

saving_model = (
    pz.select(mlp)
    .at_instances_of(pz.nn.Elementwise)
    .insert_after(SaveIntermediate(var))
)

saving_model

In [20]:
saving_model(pz.nx.ones({"features": 8}))

In [21]:
var

You can similarly unbind state variables using `pz.unbind_state_vars`:

In [22]:
unbound_saving_mlp, all_vars = pz.unbind_state_vars(saving_model)

pz.show("unbound_saving_mlp:", unbound_saving_mlp)
pz.show("all_vars:", all_vars)

Or unbind both parameters and states using `pz.unbind_variables`:

In [23]:
unbound_saving_mlp, all_vars = pz.unbind_variables(saving_model)

pz.show("unbound_saving_mlp:", unbound_saving_mlp)
pz.show("all_vars:", all_vars)

To make it easier to manipulate variables with JAX, any variable can be "frozen" using the `.freeze` method or the `pz.freeze_variables` function:

In [24]:
pz.freeze_variables(all_vars)

Frozen variables are JAX PyTrees, and can be safely passed through JAX transformations. Penzai models also support a "pure" interface that lets you pass frozen variables in and get new frozen variables out:

In [25]:
# Freeze parameters:
frozen_param_model = pz.freeze_params(saving_model)

# Unbind the state vars. Passing `freeze=True` returns them already frozen as
# `StateVariableValue`s, so no separate `pz.freeze_state_vars` call is needed:
unbound_frozen_model, state_var_values = pz.unbind_state_vars(
    frozen_param_model, freeze=True
)

# Call it in "pure" style, tracking modifications to the intermediates variable.
# The input and output variables are frozen, but the variable can be locally
# modified while the model runs. We pass a fresh, empty value for the
# "my_intermediate" label instead of `state_var_values`, which still holds the
# activations saved by the earlier call:
output, updated_var_states = unbound_frozen_model.stateless_call(
    [pz.StateVariableValue(label='my_intermediate', value=[])],
    pz.nx.ones({"features": 8})
)
pz.show("output:", output)
pz.show("updated_var_states:", updated_var_states)

You may need to freeze variables in order to pass them through JAX's function transformations. (For `jit`, Penzai includes a wrapped version called `pz.variable_jit` that handles this for you.)
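The "pure" calling convention boils down to thaw, run, freeze. The pure-Python toy below sketches that pattern under stated assumptions: `ToyStateValue`, `ToyStateVar`, and `toy_stateless_call` are hypothetical names loosely playing the roles of `pz.StateVariableValue`, `pz.StateVariable`, and `Layer.stateless_call`, and this is not how Penzai implements them. Frozen snapshots go in, temporary mutable variables exist only for the duration of the call, and new snapshots come out, so the caller's inputs are never modified.

```python
import dataclasses
from typing import Any, Callable


@dataclasses.dataclass(frozen=True)
class ToyStateValue:
  """Immutable snapshot of a state variable: a label plus a value."""
  label: str
  value: Any


class ToyStateVar:
  """Mutable state variable that exists only while a call is running."""

  def __init__(self, label: str, value: Any):
    self.label = label
    self.value = value

  def freeze(self) -> ToyStateValue:
    return ToyStateValue(self.label, self.value)


def toy_stateless_call(
    fn: Callable[[dict, Any], Any], frozen_values, x: Any
) -> tuple[Any, tuple]:
  """Runs `fn` with thawed variables; returns its output and new snapshots."""
  # Thaw: build fresh mutable variables from the frozen inputs.
  live = {v.label: ToyStateVar(v.label, v.value) for v in frozen_values}
  output = fn(live, x)
  # Freeze: snapshot whatever `fn` wrote while it ran.
  return output, tuple(var.freeze() for var in live.values())


def save_and_double(state_vars: dict, x):
  """Toy 'layer': appends its input to the 'seen' variable, returns 2 * x."""
  state_vars["seen"].value = state_vars["seen"].value + [x]
  return 2 * x


initial = (ToyStateValue("seen", []),)
out1, state1 = toy_stateless_call(save_and_double, initial, 3)
out2, state2 = toy_stateless_call(save_and_double, state1, 4)
```

Because everything crossing the function boundary is an immutable snapshot, the whole call behaves like a pure function of its inputs, which is the property JAX transformations rely on.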

### 4. Each Layer Has The Same Signature And Does A Single Thing

Penzai models are built by composing layers, where each layer implements the following interface:

```python
class Layer(pz.Struct, abc.ABC):
  @abc.abstractmethod
  def __call__(self, argument: Any, /, **side_inputs) -> Any:
    ...
```

In short:

- Each layer defines a method `__call__`, which enables it to be called directly like a function, and which contains all of the layer's runtime logic.
- `__call__` always takes exactly one positional argument, which is its input from the previous layer. (If necessary, this argument can be a tuple, dictionary, `pz.Struct`, or other JAX Pytree.)
- `__call__` also takes an arbitrary number of keyword arguments, which are *side inputs*. Side inputs can be used for information like attention masks or random number generators, and are usually shared across every layer in the model. Layers should ignore side inputs that they do not recognize.
- Whenever possible, idiomatic Penzai models should not contain Python conditional branches in their `__call__`. You should be able to JIT-compile the `__call__` of any model, and there should generally be only a single control flow path through it.

Penzai uses this convention because it makes it straightforward to compose layers with each other. For instance, there's an unambiguous way to run two layers in order: pass the output of the first layer as the positional input of the second, and pass the same side inputs to both layers.
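That composition rule fits in a few lines. In this stdlib-only sketch, `AddOne`, `ScaleBySideInput`, and `run_in_order` are hypothetical names, not Penzai APIs (in Penzai, this role is played by `pz.nn.Sequential`). Each layer accepts every side input and simply ignores the ones it does not use.

```python
from typing import Any


class AddOne:
  """Toy layer that uses no side inputs at all."""

  def __call__(self, x: Any, /, **side_inputs) -> Any:
    return x + 1


class ScaleBySideInput:
  """Toy layer that reads the one side input it recognizes ('scale')."""

  def __call__(self, x: Any, /, **side_inputs) -> Any:
    return x * side_inputs["scale"]


def run_in_order(layers, x: Any, /, **side_inputs) -> Any:
  """Feeds each output to the next layer; every layer gets all side inputs."""
  for layer in layers:
    x = layer(x, **side_inputs)
  return x


# `rng` is ignored by every layer here; `scale` is read by only one of them.
result = run_in_order(
    [AddOne(), ScaleBySideInput(), AddOne()], 4, scale=3, rng=None
)
```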

This means that, instead of passing configuration data as arguments to the forward pass of each layer, most configuration is directly attached to the layer itself:
- Configuration metadata, such as the input or output axis names for `pz.nn.Linear`, the activation function for `pz.nn.Elementwise`, or the name of a dynamic side input, are stored as attributes on the layer, and set when the layer is initially built.
- Different "modes" of computation, such as "whether or not we should enable dropout" or "whether we are doing scoring or autoregressive decoding", are usually represented as *different classes*. This makes sure that the number of configuration attributes is small, and that the implementation of each layer is simple. You can then swap out model components using `pz.select` to switch between different model behaviors.
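As an illustration of "modes as classes", here is a stdlib-only toy in which training-mode and inference-mode dropout are separate classes, and switching modes means swapping one for the other in a copy of the model. `StochasticDropout`, `DisabledDropout`, `Sequential`, and `swap_instances` are hypothetical names for this sketch; in Penzai, the swap would typically be written with `pz.select(...).at_instances_of(...).apply(...)`.

```python
import dataclasses
import random


@dataclasses.dataclass(frozen=True)
class StochasticDropout:
  """Training mode: zeroes each entry with probability `rate`, rescales rest."""
  rate: float

  def __call__(self, x, /, **side_inputs):
    rng = side_inputs["rng"]
    keep = 1.0 - self.rate
    # (draw < keep) is a bool, which multiplies as 0 or 1.
    return [(rng.random() < keep) * v / keep for v in x]


@dataclasses.dataclass(frozen=True)
class DisabledDropout:
  """Inference mode: passes the input through unchanged."""

  def __call__(self, x, /, **side_inputs):
    return x


@dataclasses.dataclass(frozen=True)
class Sequential:
  """Toy combinator: calls each sublayer in order."""
  sublayers: tuple

  def __call__(self, x, /, **side_inputs):
    for layer in self.sublayers:
      x = layer(x, **side_inputs)
    return x


def swap_instances(seq: Sequential, cls: type, convert) -> Sequential:
  """Returns a copy of `seq` with each `cls` instance replaced by convert(it)."""
  return dataclasses.replace(
      seq,
      sublayers=tuple(
          convert(layer) if isinstance(layer, cls) else layer
          for layer in seq.sublayers
      ),
  )


train_model = Sequential((StochasticDropout(rate=0.5),))
eval_model = swap_instances(
    train_model, StochasticDropout, lambda _: DisabledDropout()
)

x = [1.0, 2.0, 3.0]
train_out = train_model(x, rng=random.Random(0))
eval_out = eval_model(x, rng=random.Random(0))
```

Neither class needs an `is_training` flag, so each `__call__` has a single control-flow path.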

The emphasis on "doing one thing" also extends to composite layers. In Penzai, composite layers are usually defined as direct compositions of simpler layers, by subclassing the `pz.nn.Sequential` combinator. Then, their responsibility at runtime is just to call their children in sequence, which means it's easy to insert new logic without interfering with the model's computation. We've already seen an example of this: the `MLP` model and `Affine` blocks in our `mlp` are both subclasses of `pz.nn.Sequential`.

More complex combinators also tend to adhere to this pattern. For instance, the core `Attention` block in Penzai is purely a dataflow combinator, defined as

```python
@struct.pytree_dataclass
class Attention(Layer):
  input_to_query: Layer
  input_to_key: Layer
  input_to_value: Layer
  query_key_to_attn: Layer
  attn_value_to_output: Layer

  def __call__(self, x: NamedArray, **side_inputs) -> NamedArray:
    query = self.input_to_query(x, **side_inputs)
    key = self.input_to_key(x, **side_inputs)
    value = self.input_to_value(x, **side_inputs)
    attn = self.query_key_to_attn((query, key), **side_inputs)
    output = self.attn_value_to_output((attn, value), **side_inputs)
    return output
```

All of the specific logic of computing positional embeddings, applying attention masks, and computing the softmax weights is left to the child layers, which makes it easy to go in and capture intermediates or intervene on their behaviors at any point, without needing to change the attention implementation. `Attention` itself does just one thing: route data between its components during training or scoring.
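To make the "pure dataflow combinator" idea concrete, here is a NumPy toy that mirrors the structure of the excerpt above. `ToyAttention`, `softmax_weights`, `weighted_sum`, and `capturing_weights` are hypothetical names for this sketch; it uses plain arrays with a `[sequence, features]` layout and no masking, whereas Penzai's real children operate on named arrays. Capturing the attention weights only requires replacing one child in a copy of the block.

```python
import dataclasses
from typing import Callable

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyAttention:
  """Routing-only attention block; all of the math lives in the children."""
  input_to_query: Callable
  input_to_key: Callable
  input_to_value: Callable
  query_key_to_attn: Callable
  attn_value_to_output: Callable

  def __call__(self, x, **side_inputs):
    query = self.input_to_query(x, **side_inputs)
    key = self.input_to_key(x, **side_inputs)
    value = self.input_to_value(x, **side_inputs)
    attn = self.query_key_to_attn((query, key), **side_inputs)
    return self.attn_value_to_output((attn, value), **side_inputs)


def softmax_weights(query_key, **side_inputs):
  """Scaled dot-product attention weights, shape [seq, seq]."""
  query, key = query_key
  logits = query @ key.T / np.sqrt(query.shape[-1])
  logits = logits - logits.max(axis=-1, keepdims=True)  # Numerical stability.
  weights = np.exp(logits)
  return weights / weights.sum(axis=-1, keepdims=True)


def weighted_sum(attn_value, **side_inputs):
  """Mixes the values according to the attention weights."""
  attn, value = attn_value
  return attn @ value


rng = np.random.default_rng(0)
w_q, w_k, w_v = (rng.normal(size=(4, 4)) for _ in range(3))
attention = ToyAttention(
    input_to_query=lambda x, **_: x @ w_q,
    input_to_key=lambda x, **_: x @ w_k,
    input_to_value=lambda x, **_: x @ w_v,
    query_key_to_attn=softmax_weights,
    attn_value_to_output=weighted_sum,
)

# Intervene on one child only: record the weights, leave the block untouched.
captured = []


def capturing_weights(query_key, **side_inputs):
  weights = softmax_weights(query_key, **side_inputs)
  captured.append(weights)
  return weights


probed = dataclasses.replace(attention, query_key_to_attn=capturing_weights)

x = rng.normal(size=(3, 4))  # [sequence=3, features=4]
out = attention(x)
probed_out = probed(x)
```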

If you want to do autoregressive decoding, you can swap out `Attention` blocks for `KVCachingAttention` blocks using something like
```python
(
  pz.select(model)
  .at_instances_of(pz.nn.Attention)
  .apply(lambda attn: pz.nn.KVCachingAttention.from_uncached(attn, **kwargs))
)
```
This produces a copy of your model that additionally manages and updates KV caches, while still supporting arbitrary child layer logic and without changing any of the rest of your model.

### 5. Configuration Happens During Construction (Not `__call__`)

As discussed above, Penzai layers avoid passing configuration arguments at runtime, and avoid making assumptions about their child layers and parameters as much as possible. However, it's still important for layers and models to be able to configure themselves and initialize their parameters. In Penzai, all of this happens when the layers are initially constructed.

By convention, Penzai layers configure themselves using a class method, often called `from_config(cls, name: str, init_base_rng, ...)`. `from_config` takes all of the configuration arguments needed to initialize the layer and uses them to set up its sublayers and parameters. Penzai layers usually do NOT override `__init__`, so that it's easy to bypass the initialization logic and rebuild models with different attributes.

We can see this by calling the `from_config` method of `simple_mlp.MLP`:

In [26]:
mlp = simple_mlp.MLP.from_config(
    name="mlp",
    init_base_rng=jax.random.PRNGKey(1),
    feature_sizes=[8, 32, 32, 8],
    activation_fn=jax.nn.gelu,
)
mlp

Notice that the arguments to `from_config` aren't actually stored on the `MLP` itself. Instead, they are simply used to configure and set up the list of sublayers. In general, the configuration arguments of complex models will often "vanish" in this way after the model is initially built.
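Here is a minimal stdlib-only sketch of that convention, using hypothetical `Stack` and `Linear` classes rather than Penzai's own. `from_config` consumes the configuration (`feature_sizes`) to build sublayers and derive unique child labels, the dataclass keeps its default `__init__`, and so the same class can also be constructed directly with whatever sublayers you like.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Linear:
  """Toy linear layer; `weights` holds rows of shape [out_features][in_features]."""
  label: str
  weights: tuple

  def __call__(self, x, /, **side_inputs):
    return [sum(w * v for w, v in zip(row, x)) for row in self.weights]


@dataclasses.dataclass(frozen=True)
class Stack:
  """Toy composite layer: stores only its sublayers, not its configuration."""
  sublayers: tuple

  def __call__(self, x, /, **side_inputs):
    for layer in self.sublayers:
      x = layer(x, **side_inputs)
    return x

  @classmethod
  def from_config(cls, name: str, feature_sizes: list) -> "Stack":
    """Builds one Linear per consecutive pair of sizes (identity-like weights)."""
    sublayers = []
    for i in range(len(feature_sizes) - 1):
      n_in, n_out = feature_sizes[i], feature_sizes[i + 1]
      weights = tuple(
          tuple(1.0 if r == c else 0.0 for c in range(n_in)) for r in range(n_out)
      )
      # Child labels extend the parent's name so they stay unique.
      sublayers.append(Linear(label=f"{name}/Linear_{i}", weights=weights))
    return cls(tuple(sublayers))


stack = Stack.from_config("mlp", feature_sizes=[3, 2, 2])
# The configuration "vanishes": only `sublayers` is stored on the object.
stored_fields = [f.name for f in dataclasses.fields(stack)]
# The default __init__ bypasses from_config entirely.
custom = Stack((Linear(label="custom", weights=((2.0,),)),))
```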

In fact, all of the custom logic of `MLP` and `Affine` lives in their `from_config` methods, not in `__call__`. Once the model is built, you can strip away these wrapper classes entirely without affecting its behavior. For instance, we can inline every nested `Sequential` subclass into a single flat `pz.nn.Sequential` and get the same computation:

In [27]:
pz.nn.inline_groups(
    pz.nn.Sequential([mlp]),
    parent_filter=lambda _: True,
    child_filter=lambda _: True,
)

This pattern also applies to layers that are designed for hot-swapping. For instance, the `KVCachingAttention` block defines a classmethod `.from_uncached` that converts an `Attention` block into a `KVCachingAttention`, which takes ownership of the children of that `Attention` block and then discards the original block.

In general, it may be useful to think of a Penzai model as a "declarative" *list of steps in the model's forward pass*. If different configurations run different steps, they are usually represented as models with different structures.

By convention, layer builders like `from_config` follow the signature

```python
def from_config(cls, name: str, init_base_rng: jax.Array | None, ...):
  ...
```

The `name` argument is used to ensure that all parameters have unique names, and the `init_base_rng` determines how to initialize the parameters:
- If `init_base_rng` is a JAX PRNGKey, it is combined with the `name` argument to initialize the parameter randomly. The resulting model will contain a `Parameter` variable holding each initialized value.
- If `init_base_rng` is `None`, parameter initialization is skipped, and the resulting model will instead contain a `ParameterSlot` placeholder in place of each parameter. This can be useful for loading pretrained models from checkpoints instead of initializing them from scratch.

To make this work:
- Layers that contain other sublayers should give them unique names by adding a suffix to their own name, e.g. passing `name=f"{name}/Linear_0"` to their child. The `init_base_rng` should be forwarded to sublayers unchanged.
- Layers that directly initialize parameters should use the helper function `pz.nn.make_parameter`, which implements the above logic and ensures parameters with different names are initialized differently, even with the same `init_base_rng`.
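The name-plus-base-key scheme can be sketched with NumPy. This is a hypothetical analogue, not Penzai's code: `pz.nn.make_parameter` works with JAX PRNG keys, while `toy_make_parameter`, `ToyParam`, and `ToySlot` below are invented for illustration, and mixing a CRC32 hash of the name into NumPy's seed sequence is an assumption standing in for whatever Penzai does internally. The properties it demonstrates are the ones described above: the same name and base seed give the same values, different names give different values, and a missing base seed yields a placeholder slot.

```python
import dataclasses
import zlib
from typing import Callable, Optional, Union

import numpy as np


@dataclasses.dataclass(eq=False)
class ToyParam:
  """An initialized parameter: a label and its array value."""
  label: str
  value: np.ndarray


@dataclasses.dataclass(frozen=True)
class ToySlot:
  """An uninitialized placeholder, to be filled later (e.g. from a checkpoint)."""
  label: str


def toy_make_parameter(
    name: str,
    init_base_seed: Optional[int],
    initializer: Callable[[np.random.Generator], np.ndarray],
) -> Union[ToyParam, ToySlot]:
  """Returns a slot if no seed is given; otherwise a name-keyed random init."""
  if init_base_seed is None:
    return ToySlot(name)
  # Mix the base seed with a stable hash of the name, so that different names
  # draw from different random streams even with the same base seed.
  rng = np.random.default_rng([init_base_seed, zlib.crc32(name.encode())])
  return ToyParam(name, initializer(rng))


def init(rng: np.random.Generator) -> np.ndarray:
  return rng.normal(size=(2, 3))


a1 = toy_make_parameter("mlp/Linear_0.weights", 0, init)
a2 = toy_make_parameter("mlp/Linear_0.weights", 0, init)
b = toy_make_parameter("mlp/Linear_1.weights", 0, init)
slot = toy_make_parameter("mlp/Linear_0.weights", None, init)
```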

### 6. Layers Use Named Axes (Via Lifted Positional Operations)

Axis ordering can make it harder to reason about what complex models are doing, especially when trying to visualize or intervene on internal activations, or when using models from an unfamiliar codebase. It's often easier to refer to axes by name. But you shouldn't have to learn a whole new array API just to use named axes; the existing Numpy and JAX APIs are pretty good!

Penzai strikes a middle ground using a lightweight *locally-positional* named-axis system, defined in a single file and with a minimal API surface. In short:

- The `pz.nx.NamedArray` class wraps an ordinary array, and assigns each axis to either a position *or* a name (but not both).
- You can convert positional axes to named ones using `.tag(...)`, or convert named axes back to positional axes using `.untag(...)`.
- Any JAX function can be *lifted* using `pz.nx.nmap`. The lifted function will act normally over the positional axes but will be automatically vectorized over all of the named axes (using `jax.vmap` internally). Only `NamedArray` arguments are processed in this way; other arguments are just passed through.
- Standard array methods and operators (e.g. `.sum()`, `+`, or slicing) are also lifted so that they operate over positional axes and vectorize over named axes.
- By convention, Penzai layers use axis names to define their interface, but then use `.untag`, `nmap`, and `.tag` to implement their internal logic.

For instance, here's how you might take a softmax over a vocabulary axis:

In [28]:
# Start with a JAX array:
array = jax.random.normal(jax.random.key(0), [8, 32])
# Wrap it as a named array:
wrapped = pz.nx.wrap(array)
# Assign names:
named = wrapped.tag("batch", "vocabulary")
# Visualize it:
named

In [29]:
# Un-tag the vocabulary axis:
untagged = named.untag("vocabulary")
# Map the ordinary JAX softmax function over the temporary positional axis:
softmaxed = pz.nx.nmap(jax.nn.softmax)(untagged, axis=0)
# Tag the positional axis with a name again:
softmaxed.tag("vocabulary")

And here's how you might wrap that in an idiomatic layer, which has a named-axis interface:

In [30]:
@pz.pytree_dataclass
class Softmax(pz.nn.Layer):
  axis_name: str = dataclasses.field(metadata={"pytree_node": False})
  def __call__(self, arg, **unused_side_inputs):
    # Write the logic as if the argument is one dimensional:
    arr = arg.untag(self.axis_name)
    assert len(arr.positional_shape) == 1
    result = pz.nx.nmap(jax.nn.softmax)(arr, axis=0)
    # Then re-bind names at the end:
    return result.tag(self.axis_name)

In [31]:
layer = Softmax("vocabulary")
layer

In [32]:
layer(named)

Because everything vectorizes over names by default, Penzai models can usually be used with arbitrary numbers of batch axes at runtime as long as you give them unique names. You can even insert new layers that manipulate specific batch axes by name (e.g. copying activations from one input to another), without interfering with any of the shapes in the rest of your model.
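The core idea (apply a one-axis function along a named axis and vectorize over everything else) can be sketched in NumPy without the real `pz.nx` machinery. `ToyNamedArray` and `apply_along_name` are hypothetical names; unlike Penzai's `NamedArray`, this toy names every axis and loops with `np.apply_along_axis` rather than using `jax.vmap`. The sketch shows that the result depends neither on the physical axis order nor on how many extra batch axes are present.

```python
import dataclasses
from typing import Callable

import numpy as np


@dataclasses.dataclass(frozen=True, eq=False)
class ToyNamedArray:
  """Toy array in which every axis has a name (no positional axes)."""
  data: np.ndarray
  names: tuple

  def axis_of(self, name: str) -> int:
    return self.names.index(name)


def apply_along_name(
    fn: Callable[[np.ndarray], np.ndarray], arr: ToyNamedArray, name: str
) -> ToyNamedArray:
  """Runs a 1-D `fn` along axis `name`, looping over all other axes.

  This mimics the `untag(name)` -> `nmap(fn)` -> `tag(name)` pattern.
  """
  result = np.apply_along_axis(fn, arr.axis_of(name), arr.data)
  return ToyNamedArray(result, arr.names)


def softmax_1d(v: np.ndarray) -> np.ndarray:
  e = np.exp(v - v.max())
  return e / e.sum()


logits = np.random.default_rng(0).normal(size=(8, 32))
by_batch = ToyNamedArray(logits, ("batch", "vocabulary"))
by_vocab = ToyNamedArray(logits.T, ("vocabulary", "batch"))  # Same data, transposed.
# An extra, uniquely named batch axis needs no change to the softmax code.
with_beam = ToyNamedArray(
    np.stack([logits, 2 * logits]), ("beam", "batch", "vocabulary")
)

probs_a = apply_along_name(softmax_1d, by_batch, "vocabulary")
probs_b = apply_along_name(softmax_1d, by_vocab, "vocabulary")
probs_c = apply_along_name(softmax_1d, with_beam, "vocabulary")
```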

## Putting It All Together: A Basic Penzai Neural Network

To show how these principles interact, here's how we might implement a neural network from scratch in Penzai. We'll focus on re-implementing a basic MLP (like the running example above), and omit a few advanced features to keep things simple.

An MLP is composed of a sequence of steps, including linear operations, biases, and elementwise activations. We can implement each of these using a separate layer so that we can manipulate them after the model is built, and define each using named axes:

In [33]:
@pz.pytree_dataclass
class SimpleLinear(pz.nn.Layer):
  """A simple linear layer with a single input/output axis."""

  # Parameters are annotated as `ParameterLike` to allow swapping them out after
  # initialization.
  kernel: pz.nn.ParameterLike[pz.nx.NamedArray]

  # Non-Pytree fields (which are not arraylike) should be annotated as such to
  # tell JAX not to try to convert them:
  features_axis: str = dataclasses.field(metadata={"pytree_node": False})

  def __call__(
      self, x: pz.nx.NamedArray, /, **unused_side_inputs
  ) -> pz.nx.NamedArray:
    """Multiplies the input by the learned kernel."""
    # pos_x has one positional axis
    pos_x = x.untag(self.features_axis)
    # pos_kernel has two positional axes
    pos_kernel = self.kernel.value.untag("out_features", "in_features")
    # We can combine them using ordinary positional semantics:
    pos_y = pz.nx.nmap(jnp.dot)(pos_kernel, pos_x)
    return pos_y.tag(self.features_axis)

  @classmethod
  def from_config(
      cls,
      name: str,
      init_base_rng: jax.Array | None,
      in_features: int,
      out_features: int,
      features_axis: str = "features",
  ) -> "SimpleLinear":
    """Constructs a linear layer from configuration arguments."""
    def _initializer(key):
      arr = jax.nn.initializers.xavier_normal()(
          key, (out_features, in_features)
      )
      return pz.nx.wrap(arr).tag("out_features", "in_features")

    return cls(
        kernel=pz.nn.make_parameter(
            name=f"{name}.kernel",
            init_base_rng=init_base_rng,
            initializer=_initializer,
        ),
        features_axis=features_axis,
    )

In [34]:
@pz.pytree_dataclass
class SimpleBias(pz.nn.Layer):
  """A simple bias layer."""
  # The SimpleBias layer doesn't need to store its output axis name at all!
  bias: pz.nn.ParameterLike[pz.nx.NamedArray]

  def __call__(self, x: pz.nx.NamedArray, /, **unused_side_inputs) -> pz.nx.NamedArray:
    """Adds a bias to the input."""
    return x + self.bias.value  # Automatically vectorized!

  @classmethod
  def from_config(
      cls,
      name: str,
      init_base_rng: jax.Array | None,
      features: int,
      features_axis: str = "features",
  ) -> "SimpleBias":
    """Constructs a bias layer from configuration arguments."""
    return cls(
        bias=pz.nn.make_parameter(
            name=f"{name}.bias",
            init_base_rng=init_base_rng,
            initializer=lambda _: pz.nx.zeros({features_axis: features}),
        ),
    )

In [35]:
@pz.pytree_dataclass
class SimpleElementwise(pz.nn.Layer):
  """A simple elementwise layer."""
  fn: Callable[[jax.Array], jax.Array] = dataclasses.field(
      metadata={"pytree_node": False}
  )

  def __call__(self, x: pz.nx.NamedArray, /, **unused_side_inputs) -> pz.nx.NamedArray:
    """Runs the activation function."""
    return pz.nx.nmap(self.fn)(x)

  # No need for `from_config`, since it would be the same as `__init__`.

We can then define a top-level MLP layer as a subclass of `Sequential`:

In [36]:
@pz.pytree_dataclass
class SimpleMLP(pz.nn.Sequential):
  # sublayers is inherited from Sequential, but we restate it here for clarity.
  sublayers: list[pz.nn.Layer]

  # __call__ is inherited from Sequential, so no need to reimplement it! In
  # fact, Sequential.__call__ is marked with @typing.final so you don't
  # accidentally override it.

  @classmethod
  def from_config(
      cls,
      name: str,
      init_base_rng: jax.Array | None,
      feature_sizes: Sequence[int],
      activation: Callable[[jax.Array], jax.Array] = jax.nn.relu,
      features_axis: str = "features",
  ) -> "SimpleMLP":
    """Constructs an MLP, initializing parameters if `init_base_rng` is set."""
    # We build the steps of the forward pass in from_config, and push all
    # configuration arguments down to the sublayers:
    sublayers = []
    for i in range(len(feature_sizes) - 1):
      # We need to ensure parameter name uniqueness ourselves:
      sublayers.append(SimpleLinear.from_config(
          name=f"{name}/block_{i}/linear",
          init_base_rng=init_base_rng,
          in_features=feature_sizes[i],
          out_features=feature_sizes[i + 1],
          features_axis=features_axis,
      ))
      sublayers.append(SimpleBias.from_config(
          name=f"{name}/block_{i}/bias",
          init_base_rng=init_base_rng,
          features=feature_sizes[i + 1],
          features_axis=features_axis,
      ))
      if i < len(feature_sizes) - 2:
        sublayers.append(SimpleElementwise(activation))
    return cls(sublayers)

Building our model without an initialization PRNGKey just builds the structure:

In [37]:
SimpleMLP.from_config(
    name="mlp",
    init_base_rng=None,
    feature_sizes=[8, 32, 32, 8],
    activation=jax.nn.relu,
    features_axis="features",
)

If we pass `init_base_rng`, it will also initialize the parameters as mutable `Parameter` variables:

In [38]:
model = SimpleMLP.from_config(
    name="mlp",
    init_base_rng=jax.random.key(42),
    feature_sizes=[8, 32, 32, 8],
    activation=jax.nn.relu,
    features_axis="features",
)
model

We can call it with some example inputs to check that it works:

In [39]:
model(pz.nx.ones({"features": 8}))

Or set up a simple training loop:

In [40]:
from penzai.experimental.v2.toolshed import basic_training
import optax

In [41]:
example_inputs = pz.nx.wrap(
    jax.random.normal(jax.random.key(100), (100, 8))
).tag("batch", "features")
example_targets = pz.nx.wrap(
    jax.random.normal(jax.random.key(101), (100, 8))
).tag("batch", "features")

def loss_fn(model, rng, state, current_input, current_target):
  del rng, state  # More complex training loops could use these if needed
  model_out = model(current_input)
  losses = pz.nx.nmap(jnp.square)(model_out - current_target)
  loss = losses.untag("batch", "features").unwrap().sum()
  return (loss, None, {"my_loss": loss})

trainer = basic_training.StatefulTrainer.build(
    root_rng=jax.random.key(42),
    model=model,
    optimizer_def=optax.adam(0.01),
    loss_fn=loss_fn,
)

while trainer.state.value.step < 1000:
  out = trainer.step(
      current_input=example_inputs,
      current_target=example_targets,
  )
  if trainer.state.value.step % 20 == 0:
    print(f"At {trainer.state.value.step}: {out}")

At 20: {'my_loss': Array(562.4356, dtype=float32)}
At 40: {'my_loss': Array(354.94263, dtype=float32)}
At 60: {'my_loss': Array(208.88644, dtype=float32)}
At 80: {'my_loss': Array(116.63308, dtype=float32)}
At 100: {'my_loss': Array(69.82226, dtype=float32)}
At 120: {'my_loss': Array(48.01413, dtype=float32)}
At 140: {'my_loss': Array(36.271587, dtype=float32)}
At 160: {'my_loss': Array(27.9783, dtype=float32)}
At 180: {'my_loss': Array(22.618397, dtype=float32)}
At 200: {'my_loss': Array(19.75925, dtype=float32)}
At 220: {'my_loss': Array(15.713261, dtype=float32)}
At 240: {'my_loss': Array(14.278907, dtype=float32)}
At 260: {'my_loss': Array(12.619974, dtype=float32)}
At 280: {'my_loss': Array(11.935994, dtype=float32)}
At 300: {'my_loss': Array(8.933923, dtype=float32)}
At 320: {'my_loss': Array(7.9672227, dtype=float32)}
At 340: {'my_loss': Array(6.886015, dtype=float32)}
At 360: {'my_loss': Array(7.233006, dtype=float32)}
At 380: {'my_loss': Array(5.6690035, dtype=float32)}
At 400: {'my_loss': Array(6.6670713, dtype=float32)}
At 420: {'my_loss': Array(5.707902, dtype=float32)}
At 440: {'my_loss': Array(5.3238034, dtype=float32)}
At 460: {'my_loss': Array(5.2755117, dtype=float32)}
At 480: {'my_loss': Array(3.6372209, dtype=float32)}
At 500: {'my_loss': Array(4.060416, dtype=float32)}
At 520: {'my_loss': Array(3.7824037, dtype=float32)}
At 540: {'my_loss': Array(3.8496475, dtype=float32)}
At 560: {'my_loss': Array(3.2768314, dtype=float32)}
At 580: {'my_loss': Array(3.402127, dtype=float32)}
At 600: {'my_loss': Array(1.9185027, dtype=float32)}
At 620: {'my_loss': Array(2.0824428, dtype=float32)}
At 640: {'my_loss': Array(2.2999728, dtype=float32)}
At 660: {'my_loss': Array(3.2461543, dtype=float32)}
At 680: {'my_loss': Array(2.2447917, dtype=float32)}
At 700: {'my_loss': Array(2.2864742, dtype=float32)}
At 720: {'my_loss': Array(1.6309336, dtype=float32)}
At 740: {'my_loss': Array(3.654787, dtype=float32)}
At 760: {'my_loss': Array(1.7311825, dtype=float32)}
At 780: {'my_loss': Array(2.9659948, dtype=float32)}
At 800: {'my_loss': Array(4.086928, dtype=float32)}
At 820: {'my_loss': Array(2.963339, dtype=float32)}
At 840: {'my_loss': Array(1.3671762, dtype=float32)}
At 860: {'my_loss': Array(1.7845328, dtype=float32)}
At 880: {'my_loss': Array(1.1663316, dtype=float32)}
At 900: {'my_loss': Array(2.9988499, dtype=float32)}
At 920: {'my_loss': Array(2.2414417, dtype=float32)}
At 940: {'my_loss': Array(2.2406294, dtype=float32)}
At 960: {'my_loss': Array(0.8716761, dtype=float32)}
At 980: {'my_loss': Array(1.9427958, dtype=float32)}
At 1000: {'my_loss': Array(1.9259623, dtype=float32)}


## Summary

You now know everything you need to get started with neural networks in Penzai!

Penzai strives to enable complex modifications and interventions on models either before or after training them, without getting in your way. Following the principles described here is a recommended starting point and a great way to take advantage of all of Penzai's tooling, but it's not strictly enforced! You're free to use Penzai's visualization and patching tools with non-Penzai models, or define your own callable PyTree components without conforming to the `pz.nn.Layer` interface, if that makes more sense for your use case.