*Copyright 2024 The Penzai Authors.*

*Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at*

> http://www.apache.org/licenses/LICENSE-2.0

*Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.*

---

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/penzai/blob/main/notebooks/v2_jitting_and_sharding.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google-deepmind/penzai/blob/main/notebooks/v2_jitting_and_sharding.ipynb)

# Jitting and Sharding Penzai Models (V2 API)

Penzai is designed to be compatible with JAX's standard function transformations, including JIT compilation and array sharding. If you're already familiar with JIT compilation and distributed arrays in JAX, you shouldn't have to learn anything fundamentally new to apply them to Penzai! Penzai does, however, provide some utilities that make it easier to construct and manipulate shardings for Penzai models.

This notebook walks through some of the common aspects of JIT-compilation and sharding as they apply to Penzai tools and Penzai models. It assumes some basic familiarity with JAX's [JIT compilation](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html) and [distributed array](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) systems.

```{note}
This tutorial uses the V2 neural network API, defined in `pz.experimental.v2`.
```

## Setup

Before we can get started in earnest, we need to set up the environment.

### Imports

To run this notebook, you need a Python environment with `penzai` and its dependencies installed.

In Colab or Kaggle, you can install it using the following command:

In [1]:
try:
  import penzai
except ImportError:
  !pip install penzai[notebook]

In [2]:
from __future__ import annotations

In [3]:
import dataclasses

import jax
import jax.numpy as jnp
import optax

In [4]:
import penzai
from penzai.experimental.v2 import pz

In [5]:
from penzai.experimental.v2.models import transformer
from penzai.experimental.v2.models import simple_mlp
from penzai.experimental.v2.toolshed import basic_training

### Setting up Penzai

For this tutorial, we'll enable Treescope (Penzai's pretty-printer) as the default IPython pretty-printer. This is recommended when using Penzai in an interactive environment. We'll also enable automatic array visualization, which makes it easy to visualize array shardings.

In [6]:
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.ts.register_context_manager_magic()

In [7]:
pz.enable_interactive_context()
pz.ts.active_autovisualizer.set_interactive(pz.ts.ArrayAutovisualizer())

We'll assume this notebook is running on a backend with eight devices. If needed, you can force JAX to treat the CPU backend as multiple devices by setting the following flag before JAX initializes its backends (in practice, before the first `import jax`):
```python
import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
```

In [8]:
pz.show(jax.local_devices())
assert jax.local_device_count() == 8

## JIT-Compiling Penzai Models

Penzai model objects themselves are always JAX PyTrees. However, in addition to arrays and arraylike leaves, Penzai models can also include two types of "variable" leaves: `pz.Parameter` and `pz.StateVariable`. These are currently not directly supported by `jax.jit`.

For example, consider the following (somewhat contrived) model, which has a learnable parameter and an incrementing counter:

In [9]:
@pz.pytree_dataclass
class CounterLayer(pz.nn.Layer):
  counter: pz.StateVariable[int]

  def __call__(self, x, **_side_inputs):
    self.counter.value += 1
    return (x, self.counter.value)

model = pz.nn.Sequential([
    pz.nn.Linear.from_config(
        name="linear",
        init_base_rng=jax.random.PRNGKey(0),
        input_axes={"features": 8},
        output_axes={"features_out": 8},
    ),
    CounterLayer(counter=pz.StateVariable(value=0, label="counter")),
])

model

In [10]:
model(pz.nx.ones({"features": 8}))

In [11]:
model(pz.nx.ones({"features": 8}))
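
As plain-JAX background for why this kind of mutable state needs special handling, here is a minimal sketch that does not use Penzai (the `counter` dictionary and `impure_increment` function are hypothetical names for illustration). Python-level mutation inside a `jax.jit`-compiled function runs only while JAX traces the function, so later calls that reuse the compiled code skip it:

```python
import jax
import jax.numpy as jnp

# Hypothetical Python-side state, standing in for a mutable variable.
counter = {"value": 0}


@jax.jit
def impure_increment(x):
  # This mutation is a Python side effect: it runs during tracing only.
  counter["value"] += 1
  return x + 1


first = impure_increment(jnp.ones(3))
# Same shape and dtype as before, so the cached trace is reused.
second = impure_increment(jnp.ones(3))

# The counter advanced once, not twice, because the second call did not re-trace.
print(counter["value"])
```

The options below all avoid this problem by making sure variable values enter and leave the compiled function explicitly.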

To JIT-compile a Penzai model, you have three options:

- The "functional API": A set of Penzai tools to help you manipulate variable states using pure functions and JAX PyTrees.
- `pz.variable_jit`: A convenience wrapper around `jax.jit` that also works for PyTrees containing `pz.Parameter` and `pz.StateVariable`.
- `toolshed.jit_wrapper.Jitted`: A model combinator that acts like an ordinary `Layer`, but always runs under `jax.jit` (using `pz.variable_jit` around its `__call__` method).

### The "Functional API"

Each of Penzai's variables comes in three forms:

- Mutable variables (`pz.Parameter` and `pz.StateVariable`), which are Python objects whose `.value` attribute can be modified freely,
- Frozen variable values (`pz.ParameterValue` and `pz.StateVariableValue`), which are immutable JAX PyTree objects that are safe to pass through JAX transforms,
- Variable slots (`pz.ParameterSlot` and `pz.StateVariableSlot`), which are empty placeholders that indicate locations of variables in a larger tree.

For full control over JIT compilation, you can manually convert variables from their mutable form to their immutable form when crossing JAX transform boundaries. The relevant functions:

- `pz.unbind_variables` (and type-specific variants `pz.unbind_params` and `pz.unbind_state_vars`): Extracts and deduplicates variables, returning a tree of variable slots along with the deduplicated variables.
- `pz.bind_variables`: Re-inserts variables into variable slots.
- `Parameter.freeze()` and `StateVariable.freeze()`: Converts a mutable variable into an immutable value.
- `ParameterValue.unfreeze_as_copy()` and `StateVariableValue.unfreeze_as_copy()`: Converts an immutable value back into a (new) mutable variable.
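
The idea behind this pattern can be sketched in plain JAX, without any Penzai types: state goes into the jitted function as an immutable PyTree and comes back out as a new PyTree, and the caller decides what to do with it. The names `pure_step` and `state` below are hypothetical, and this is only an analogy for the functional API, not how Penzai implements it:

```python
import jax
import jax.numpy as jnp


@jax.jit
def pure_step(state, x):
  # Build a new state PyTree instead of mutating the input.
  new_state = {"counter": state["counter"] + 1}
  return x * 2.0, new_state


state = {"counter": jnp.asarray(0)}
for _ in range(3):
  # The caller threads the returned state into the next call.
  out, state = pure_step(state, jnp.ones(4))

print(int(state["counter"]))
```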

For instance, for our example model above, we can use `pz.unbind_variables` and `.freeze()` to extract the mutable parts:

In [12]:
model_with_slots, all_vars = pz.unbind_variables(model)
pz.show("model_with_slots:", model_with_slots)
pz.show("all_vars:", all_vars)

In [13]:
frozen_vars = [var.freeze() for var in all_vars]
pz.show("frozen_vars:", frozen_vars)

We can then define a pure function that re-binds these variables, and call it under `jax.jit`:

In [14]:
@jax.jit
def rebinding_call(model_with_slots, frozen_vars, arg):
  # Make temporary mutable copies:
  new_vars = [var.unfreeze_as_copy() for var in frozen_vars]
  # Re-attach them to the model:
  model = pz.bind_variables(model_with_slots, new_vars)
  # Run it:
  result = model(arg)
  # Extract and re-freeze the variables:
  refrozen_vars = [var.freeze() for var in new_vars]
  return result, refrozen_vars

In [15]:
result, new_frozen_vars = rebinding_call(
    model_with_slots, frozen_vars, pz.nx.ones({"features": 8})
)
pz.show("result:", result)
pz.show("new_frozen_vars:", new_frozen_vars)

We can then update the old variables with their new values:

In [16]:
for var, new_value in zip(all_vars, new_frozen_vars):
  var.update(new_value)

To make this a bit less verbose, `pz.nn.Layer` has a method `.stateless_call(vars, ...)` that makes temporary mutable copies of its input variables, like `rebinding_call`. So, we could have equivalently written the following:

In [17]:
@jax.jit
def rebinding_call_2(model_with_slots, frozen_vars, arg):
  result, refrozen_vars = model_with_slots.stateless_call(frozen_vars, arg)
  return result, refrozen_vars

In [18]:
result, new_frozen_vars = rebinding_call_2(
    model_with_slots, frozen_vars, pz.nx.ones({"features": 8})
)
pz.show("result:", result)
pz.show("new_frozen_vars:", new_frozen_vars)

If you want to JIT-compile your model initializer, you can do this using the functional API:

In [19]:
@jax.jit
def functional_init(init_base_rng):
  model = pz.nn.Sequential([
      pz.nn.Linear.from_config(
          name="linear",
          init_base_rng=init_base_rng,
          input_axes={"features": 8},
          output_axes={"features_out": 8},
      ),
      CounterLayer(counter=pz.StateVariable(value=0, label="counter")),
  ])
  # Unbind and also freeze all variables:
  return pz.unbind_variables(model, freeze=True)

In [20]:
model_with_slots, init_var_values = functional_init(jax.random.PRNGKey(0))
# Re-bind variables and also make them mutable again:
model = pz.bind_variables(
    model_with_slots, init_var_values, unfreeze_as_copy=True
)

In [21]:
model

### `pz.variable_jit`

If you don't want to use the functional API directly, you can instead use `pz.variable_jit`, which is a wrapper around `jax.jit` that allows the function arguments to contain `pz.Parameter` and `pz.StateVariable` in addition to arrays, and handles updating their values for you. For instance, you could write:

In [22]:
@pz.variable_jit
def jitted_call(model, arg):
  return model(arg)

In [23]:
jitted_call(model, pz.nx.ones({"features": 8}))

In [24]:
jitted_call(model, pz.nx.ones({"features": 8}))

Note that `pz.variable_jit` does not support returning variables from the jitted computation, so it can't be used to JIT-compile model initialization. It also does not support "closing over" global references to variable objects defined outside of the function. Every variable used by the function inside `pz.variable_jit` must have been passed in as an input argument.

### `jit_wrapper.Jitted`

`pz.variable_jit` works for top-level functions, but sometimes you may want to JIT-compile a specific part of a Penzai model, or compile the forward pass without having to use an indirect `jitted_call` function. For this purpose, Penzai provides a layer wrapper `Jitted` in `penzai.experimental.v2.toolshed.jit_wrapper`, which JIT-compiles its forward pass when called.

To use it, you can simply wrap your model in `jit_wrapper.Jitted` and then call it as normal:

In [25]:
from penzai.experimental.v2.toolshed import jit_wrapper

In [26]:
jit_model = jit_wrapper.Jitted(model)
jit_model

In [27]:
jit_model(pz.nx.ones({"features": 8}))

You can also insert `Jitted` around any sublayer of the model, e.g.

In [28]:
jit_model_2 = (
    pz.select(model)
    .at_instances_of(pz.nn.Linear | CounterLayer)
    .apply(jit_wrapper.Jitted)
)
jit_model_2

In [29]:
jit_model_2(pz.nx.ones({"features": 8}))

Note that the `Jitted` wrapper is just an ordinary Penzai layer, and you can still pull back out the original model:

In [30]:
jit_model.body == model

## Sharding Basics, and Visualizing Shardings with Treescope

Penzai's array autovisualizer supports showing shardings and sharded arrays by default. This section explains the basics of JAX's distributed array shardings and how you can visualize the different components in Treescope. (See [this page](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) for the official documentation of JAX's sharding system.)

### Positional shardings

At a high level, you can think of a "sharding" as a multidimensional array of device objects, which will be matched with your multidimensional array of data to determine which part of the array ends up on each device. You generally build a sharding by starting with a NumPy array of devices:

In [31]:
from jax.experimental import mesh_utils
devices = mesh_utils.create_device_mesh((8,))
devices

A simple type of sharding is `PositionalSharding`, which essentially just holds onto these devices and tracks some extra JAX-specific information. If you print out a `PositionalSharding` in Treescope, it color-codes the devices and shows you their arrangement:

In [32]:
pos_sharding = jax.sharding.PositionalSharding(devices)
pos_sharding

In this case, the sharding has a single positional axis, of length 8. We can use this to shard arrays whose (first) positional axis has a length that is a multiple of 8. For instance:

In [33]:
jax.device_put(jnp.ones(16), pos_sharding)

You can click the "Sharded across 8 TPU devices" message (the device type shown depends on your backend) to show a visualization of the sharding for this array. When automatic array visualization is enabled, sharding visualizations are automatically added to any array that is sharded or replicated.

We can reshape positional shardings to give them multiple axes:

In [34]:
pos_sharding.reshape((4,2))

In [35]:
jax.device_put(jnp.ones([8, 8]), pos_sharding.reshape((4,2)))

If you expand the sharding visualization above, you'll see how the two axes of the array are matched with the two axes of the sharding.

You can also use shardings to indicate that certain parts of the array should be *replicated* on multiple devices, using `replicate`:

In [36]:
pos_sharding.reshape((2, 4)).replicate(axis=0)

In [37]:
jax.device_put(jnp.ones([8, 8]), pos_sharding.reshape((2, 4)).replicate(axis=0))

Each element of an array with a replicated sharding will appear on more than one device. This is visually represented in Treescope using a multicolored pattern.

You can also fully-replicate the array over all of the devices:

In [38]:
pos_sharding.replicate(axis=0)

In [39]:
jax.device_put(jnp.ones([8, 8]), pos_sharding.replicate(axis=0).reshape((1, 1)))

Fully-replicated arrays are also identified as such in the sharding summary before being expanded.

### Meshes and named shardings

It is often convenient to refer to different axes of an array of devices by name instead of by position. JAX represents this using the type `jax.sharding.Mesh`. Conceptually, just as a `PositionalSharding` is essentially a positional array of devices, a `Mesh` is essentially a named array of devices, i.e. an array of devices where each axis has a name.

Penzai annotates the device ID arrays of `Mesh` instances with axis names instead of axis positions:

In [40]:
mesh = jax.sharding.Mesh(devices.reshape((4, 2)), axis_names=('foo', 'bar'))
mesh

To shard a (positionally-indexed) JAX array using a mesh, you can use `jax.sharding.NamedSharding` to assign particular axis indices to mesh axis names, like this:

In [41]:
jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('foo', 'bar'))

In [42]:
jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, ('bar', 'foo'), None))

In [43]:
jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('foo'))

Note: Each `NamedSharding` specifies how to shard an input array's *positional axes*, since ordinary JAX arrays only have positional axes. The names in the `NamedSharding` are just a way to match the positional axes in the array with the corresponding names in the `Mesh`. For this reason, visualizations of `NamedSharding` instances are annotated with positional axes, not axis names.
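
As a self-contained sketch of how a `NamedSharding` is applied to an ordinary JAX array, the following builds a mesh from whatever devices are available instead of assuming eight. The mesh shape and array sizes here are illustrative choices, not taken from the cells above:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 2D mesh from the available devices (also works with one device).
num_devices = jax.device_count()
device_grid = np.array(jax.devices()).reshape((num_devices, 1))
mesh = Mesh(device_grid, axis_names=("foo", "bar"))

x = jnp.ones((2 * num_devices, 4))

# Positional axis 0 is split over mesh axis "foo", axis 1 over "bar".
split_sharding = NamedSharding(mesh, PartitionSpec("foo", "bar"))
# An empty PartitionSpec replicates the whole array on every device.
replicated_sharding = NamedSharding(mesh, PartitionSpec())

x_split = jax.device_put(x, split_sharding)
x_replicated = jax.device_put(x, replicated_sharding)

# Shape of the block each device holds under the split sharding.
print(split_sharding.shard_shape(x.shape))
print(replicated_sharding.is_fully_replicated)
```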

(Penzai already has its own mechanism for binding names to an array's positional axes: `pz.nx.NamedArray`. We'll discuss how to shard Penzai's `NamedArray` next.)

## Sharding Penzai's NamedArrays

### Manually sharding NamedArrays

Fundamentally, there are no changes when applying JAX shardings to Penzai's `NamedArray`s. Internally, a `NamedArray` is just a dataclass PyTree node that contains a JAX array and some axis name annotations, which we can see if we disable automatic array visualization temporarily:

In [44]:
arr = pz.nx.arange("foo", 1, 4) + pz.nx.arange("bar", 0, 4)

In [45]:
# With automatic array visualization enabled:
arr

In [46]:
%%autovisualize None
# ^ With automatic array visualization disabled (and expanding it to show detail)
pz.select(arr).at_instances_of(jax.Array).show_value()

JAX's sharding system allows you to specify the sharding for a PyTree of arrays by using a matching PyTree of shardings. So, we can build a sharding for this named array by inserting a positional sharding into it:

In [47]:
data_array_sharding = jax.sharding.PositionalSharding(devices).reshape((2,4)).replicate(axis=0)
sharding_for_arr = pz.nx.NamedArray(
    named_axes=arr.named_axes,
    data_array=data_array_sharding,
)
sharding_for_arr

Applying this sharding to the NamedArray shards the `data_array` attribute (try expanding below):

In [48]:
%%autovisualize lambda a,p: pz.ts.ArrayAutovisualizer()(a, p) if isinstance(a, jax.Array) else None
# (^ this line overrides the autovisualizer to show the sharding of the data array when expanded)

sharded_arr = jax.device_put(arr, sharding_for_arr)
pz.select(sharded_arr).at_instances_of(jax.Array).show_value()

But with normal automatic array visualization, Treescope will show you how the *named* axes are sharded, since that's usually what you care about when using Penzai models in practice:

In [49]:
sharded_arr
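
The "matching PyTree of shardings" rule used in this section is a general JAX feature, not something specific to `NamedArray`. Here is a minimal plain-JAX sketch with a dictionary of arrays; the `arrays` and `shardings` names and shapes are hypothetical, and `NamedSharding` is used in place of a positional sharding:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

num_devices = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))

# A PyTree of arrays ...
arrays = {
    "weights": jnp.ones((2 * num_devices, 3)),
    "bias": jnp.zeros((3,)),
}
# ... and a PyTree of shardings with the same structure.
shardings = {
    "weights": NamedSharding(mesh, PartitionSpec("devices", None)),
    "bias": NamedSharding(mesh, PartitionSpec()),
}

# Each array is placed according to the sharding at the same tree position.
placed = jax.device_put(arrays, shardings)
print({name: value.shape for name, value in placed.items()})
```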

### Automatically building shardings for NamedArrays

To simplify this process, Penzai provides some optional utilities for constructing shardings for `NamedArray` instances. These utilities take a `Mesh`, and allow you to map from `NamedArray` axis names to `Mesh` axis names across a tree of arrays.

For instance, consider this tree of arrays:

In [50]:
some_array_tree = {
    "one": pz.nx.ones({"a": 4, "b": 8, "c": 6}),
    "two": pz.nx.ones({"a": 8}),
    "three": pz.nx.ones({"b": 4, "d": 12}),
}
some_array_tree

And this mesh:

In [51]:
mesh = jax.sharding.Mesh(devices.reshape((4, 2)), axis_names=('foo', 'bar'))
mesh

We can assign each named axis in `some_array_tree` to an axis in the mesh using the `name_to_name_sharding` utility, which builds a tree of shardings that is compatible with the tree of arrays:

In [52]:
from penzai.experimental.v2.toolshed import sharding_util

In [53]:
shardings = sharding_util.name_to_name_sharding(
    some_array_tree,
    mesh,
    axis_name_to_mesh_name={
        "a": "bar",
        "b": "foo",
    },
)
shardings

We can then apply those shardings to the original array tree to shard the corresponding axes:

In [54]:
jax.device_put(some_array_tree, shardings)

Even simpler, if you just want to call `device_put` you can bundle them into one call:

In [55]:
sharding_util.name_to_name_device_put(
    some_array_tree,
    mesh,
    axis_name_to_mesh_name={
        "a": "bar",
        "b": "foo",
    },
)

If your mesh happens to use the exact same axis names as your arrays, you don't need the `axis_name_to_mesh_name` argument:

In [56]:
already_matching_mesh = jax.sharding.Mesh(devices.reshape((4, 2)), axis_names=('b', 'a'))
sharding_util.name_to_name_device_put(
    some_array_tree,
    already_matching_mesh,
    # axis_name_to_mesh_name inferred as {"a":"a", "b":"b"}
)
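
Conceptually, a name-to-name mapping like this has to be turned into one `PartitionSpec` entry per array axis. The sketch below shows that idea with a hypothetical helper, `spec_for_named_axes`. It is a simplification for illustration and is not Penzai's implementation, which also has to account for how a `NamedArray` lays out its underlying data array:

```python
from jax.sharding import PartitionSpec


def spec_for_named_axes(array_axis_names, axis_name_to_mesh_name):
  """Hypothetical helper: maps each array axis name to a mesh axis or None."""
  return PartitionSpec(
      *(axis_name_to_mesh_name.get(name) for name in array_axis_names)
  )


# Axes without an entry in the mapping are left unsharded (None).
spec = spec_for_named_axes(("a", "b", "c"), {"a": "bar", "b": "foo"})
print(spec)
```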

## Sharding Penzai Models and Training Loops

Penzai also provides some utilities that are specific to training and using Penzai neural network models. These are simple self-contained utilities that can be a good starting point, but you are free to customize them to get lower-level control when needed.

### Sharding Parameter Initializers

If you already know the shardings for your model parameters, you can pass them to `jax.jit` directly and JIT-compile parameter initialization using something like

```python
def functional_init(init_base_rng):
  model = ...
  return pz.unbind_variables(model, freeze=True)

sharded_init = jax.jit(
  functional_init,
  out_shardings=..., # <- insert your desired sharding specification here
)

model = pz.bind_variables(*sharded_init(rng))
```
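
For a runnable version of the same `out_shardings` pattern, here is a plain-JAX sketch in which a dictionary of arrays stands in for the unbound model and its variables. The `init_params` function, its shapes, and the mesh are hypothetical, chosen so the example also works on a single device:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

num_devices = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))


def init_params(key):
  # Stand-in initializer returning a PyTree of arrays.
  return {
      "weights": jax.random.normal(key, (2 * num_devices, 4)),
      "bias": jnp.zeros((4,)),
  }


# One sharding per output leaf, matching the structure of the return value.
out_shardings = {
    "weights": NamedSharding(mesh, PartitionSpec("devices", None)),
    "bias": NamedSharding(mesh, PartitionSpec()),
}

sharded_init = jax.jit(init_params, out_shardings=out_shardings)
params = sharded_init(jax.random.PRNGKey(0))
print({name: value.shape for name, value in params.items()})
```

Because the shardings are attached to the compiled initializer, the parameters are produced with the requested layout directly.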

If you want to infer `out_shardings` using the axis names of your parameters, you can do that using the helper function `sharding_util.sharded_init`. This function just traces the initializer to figure out the parameter shapes, infers the right sharding to use, and then runs your initializer accordingly.

For instance, here's how you could initialize the parameters of a small transformer in a sharded way:

In [57]:
from penzai.experimental.v2.toolshed import sharding_util
from penzai.experimental.v2.models.transformer.variants import llamalike_common

In [58]:
# Very small transformer config, for demo purposes
config = llamalike_common.LlamalikeTransformerConfig(
    num_kv_heads=2,
    query_head_multiplier=1,
    embedding_dim=64,
    projection_dim=16,
    mlp_hidden_dim=128,
    num_decoder_blocks=2,
    vocab_size=100,
    mlp_variant="geglu_approx",
    rope_wavelength=10_000,
    tie_embedder_and_logits=True,
    use_layer_stack=False,
    parameter_dtype=jnp.float32,
    activation_dtype=jnp.float32,
)

tiny_transformer = sharding_util.sharded_init(
    llamalike_common.build_llamalike_transformer,
    config=config,
    init_base_rng=jax.random.key(42),
    mesh=jax.sharding.Mesh(devices, axis_names=('devices',)),
    axis_name_to_mesh_name={
        # Shard the embedding dimension across devices.
        "embedding": "devices",
    },
)

In [59]:
tiny_transformer

### Sharding Model Training and Inference

Once you've sharded your parameters, you usually don't have to do anything special to enable device-parallel computation when training or running a model. This is because JAX can automatically propagate and infer array sharding information. (See JAX's documentation on [automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)!)

For instance, we can call our sharded model with a sharded input:

In [60]:
tokens = sharding_util.name_to_name_device_put(
    pz.nx.ones({"batch": 16, "seq": 20}, dtype=jnp.int32),
    mesh=jax.sharding.Mesh(devices, axis_names=('devices',)),
    axis_name_to_mesh_name={"batch": "devices"},
)

In [61]:
result = tiny_transformer(tokens)
result

The result in this case will usually also be sharded over the batch axis, which means JAX automatically chose a "fully sharded data parallel" (FSDP) sharding for our computation!

In [62]:
pz.ts.render_array_sharding(result)


If you need more direct control, you can use the `in_shardings` and `out_shardings` arguments of `jax.jit` in combination with the "Functional API" for Penzai's parameters and state variables, described above.
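
As a rough illustration of the `in_shardings` and `out_shardings` arguments on their own, here is a plain-JAX sketch without Penzai variables. The `apply_linear` function and the array shapes are hypothetical:

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

num_devices = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))

batch_sharded = NamedSharding(mesh, PartitionSpec("devices", None))
replicated = NamedSharding(mesh, PartitionSpec())


@functools.partial(
    jax.jit,
    # One sharding per positional argument: (weights, inputs).
    in_shardings=(replicated, batch_sharded),
    out_shardings=batch_sharded,
)
def apply_linear(weights, inputs):
  return inputs @ weights


weights = jnp.ones((4, 3))
inputs = jnp.ones((2 * num_devices, 4))
outputs = apply_linear(weights, inputs)
print(outputs.shape)
```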

### Adding Sharding Constraints to Models

You may want more control over the way that intermediate values are sharded. JAX allows you to control this using `jax.lax.with_sharding_constraint`, which forces a particular value to have a particular sharding.
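
In plain JAX, a sharding constraint on an intermediate value looks roughly like the sketch below. The function `constrained_forward` and the shapes are hypothetical, and the mesh is built from the available devices so the example also runs on one device:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

num_devices = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))
batch_sharded = NamedSharding(mesh, PartitionSpec("devices", None))


@jax.jit
def constrained_forward(x):
  hidden = jnp.sin(x)
  # Ask the compiler to keep this intermediate sharded over the first axis.
  hidden = jax.lax.with_sharding_constraint(hidden, batch_sharded)
  return hidden.sum(axis=1)


x = jnp.ones((2 * num_devices, 4))
result = constrained_forward(x)
print(result.shape)
```

The constraint changes how the intermediate is laid out across devices, not the numerical result.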

In a Penzai model, sharding constraints can be enforced by simply inserting new layers into the model at the points where you want to constrain the shardings. Penzai's `sharding_util` module provides two simple classes `ConstrainSharding` and `ConstrainShardingByName` for this purpose, defined as
```python
@pz.pytree_dataclass
class ConstrainSharding(pz.nn.Layer):
  sharding: PyTreeOfShardings = field(metadata={"pytree_node": False})
  def __call__(self, tree: Any, **_unused_side_inputs) -> Any:
    return jax.lax.with_sharding_constraint(tree, self.sharding)

@pz.pytree_dataclass
class ConstrainShardingByName(pz.nn.Layer):
  mesh: jax.sharding.Mesh = field(metadata={"pytree_node": False})
  axis_name_to_mesh_name: dict[str, str | tuple[str, ...]] | None = (
      field(default=None, metadata={"pytree_node": False})
  )
  def __call__(self, tree: PyTreeOfNamedArrays, **_unused_side_inputs) -> PyTreeOfNamedArrays:
    return jax.lax.with_sharding_constraint(
        tree,
        name_to_name_sharding(tree, self.mesh, self.axis_name_to_mesh_name),
    )
```

You can insert them into the model using logic like this:

In [63]:
mesh = jax.sharding.Mesh(devices, axis_names=('devices',))

In [64]:
# Make sure it's sharded over the batch axis after each block.
tiny_transformer_constrained = (
    pz.select(tiny_transformer)
    .at_instances_of(transformer.model_parts.TransformerBlock)
    .insert_after(sharding_util.ConstrainShardingByName(
        mesh, axis_name_to_mesh_name={"batch": "devices"}
    ))
)

In [65]:
# Visualize the constraints:
pz.select(tiny_transformer_constrained).at_instances_of(sharding_util.ConstrainShardingByName)

This gives you a version of the model whose intermediates will always be sharded in the way you specified.

If you later want to change how your model's intermediates are sharded, you can simply remove these constraints:

In [66]:
tiny_transformer_unconstrained = (
    pz.select(tiny_transformer_constrained)
    .at_instances_of(sharding_util.ConstrainShardingByName)
    .remove_from_parent()
)

# No more constraints:
(
    pz.select(tiny_transformer_unconstrained)
    .at_instances_of(sharding_util.ConstrainShardingByName)
    .assert_count_is(0)
)