*Copyright 2024 The Penzai Authors.*

*Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at*

> http://www.apache.org/licenses/LICENSE-2.0

*Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.*

---

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/penzai/blob/main/notebooks/lora_from_scratch.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google-deepmind/penzai/blob/main/notebooks/lora_from_scratch.ipynb)

# LoRA From Scratch - Patching Pretrained Models in Penzai

Penzai is designed to make it easy to make targeted modifications to neural networks after they have been trained. In this notebook, we'll show how to take Penzai's reference implementation of the [Gemma 7B](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) open-weights transformer model, patch it to support Low-Rank Adaptation (LoRA, [Hu et al. 2021](https://arxiv.org/abs/2106.09685)), and train the new parameters on a toy problem with a hand-written loss function.

The goal of this notebook is to show how *you* could implement something like LoRA from scratch in fewer than a hundred lines of code, starting from a Penzai implementation of a model that doesn't support it already, and without having to fork the existing implementation source code or even modify the pretrained model's configuration. We'll define everything we need as we go and make changes to models interactively. In fact, our implementation will end up being completely modular; we'll start by applying LoRA to a small MLP and then immediately be able to transfer our implementation to Gemma 7B.

Let's get started!

## Setup

Before we can get started in earnest, we need to set up the environment.

### Imports

To run this notebook, you need a Python environment with `penzai` and its dependencies installed.

In Colab or Kaggle, you can install it using the following command:

In [None]:
try:
  import penzai
except ImportError:
  !pip install penzai[notebook]

In [None]:
from __future__ import annotations

In [None]:
import os
import gc
import collections

import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint
import optax
from jax.experimental import mesh_utils

In [None]:
import penzai
from penzai.deprecated.v1 import pz

In [None]:
import sentencepiece as spm

In [None]:
from penzai.deprecated.v1.example_models import gemma
from penzai.deprecated.v1.example_models import simple_mlp
from penzai.deprecated.v1.toolshed import basic_training
from penzai.toolshed import token_visualization
from penzai.deprecated.v1.toolshed import jit_wrapper

### Setting up Penzai

For this tutorial, we'll enable Treescope (Penzai's pretty-printer) as the default IPython pretty-printer. This is recommended when using Penzai in an interactive environment.

In [None]:
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.ts.register_context_manager_magic()

## Intro to Penzai's declarative combinator design

We'll start by giving a brief introduction to Penzai's design conventions, and how they make it easy to insert adapters into pretrained models. Let's begin by initializing a small MLP:

In [None]:
mlp = pz.nn.initialize_parameters(
    simple_mlp.MLP.from_config([8, 32, 32, 8]),
    jax.random.key(0),
)

Like most Penzai models and layers, this MLP takes named arrays as input and returns them as output. A named array is just a wrapped JAX array where a subset of its positional axes have been tagged with names. (See the [named axes tutorial](named_axes.ipynb) for more info on how to use Penzai's named axis system.)

We can call the MLP directly on an array of inputs to run it:

In [None]:
%%autovisualize
mlp(pz.nx.NamedArray.wrap(jnp.arange(8, dtype=jnp.float32)).tag("features"))

Penzai models are written in a *declarative*, *combinator*-based style. This means that the structure of the model directly matches the sequence of high-level operations that the model will run in its forward pass. Composite models, like our MLP, just hold onto their sublayers in a list and run these sublayers in order. Primitive layers, like `Linear`, hold on to their parameters as attributes instead of reading them from an external parameter dictionary.
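
To make the pattern concrete, here is a minimal plain-NumPy sketch of the same idea. The classes `ToyLinear`, `ToyRelu`, and `ToySequential` are hypothetical stand-ins written for illustration only; they are not Penzai's classes, and they ignore named axes entirely.

```python
import dataclasses
import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyLinear:
  # The parameter lives directly on the layer, not in an external dict.
  weights: np.ndarray

  def __call__(self, x):
    return x @ self.weights


@dataclasses.dataclass(frozen=True)
class ToyRelu:
  def __call__(self, x):
    return np.maximum(x, 0.0)


@dataclasses.dataclass(frozen=True)
class ToySequential:
  # A composite layer is just an ordered collection of sublayers.
  sublayers: tuple

  def __call__(self, x):
    for layer in self.sublayers:
      x = layer(x)
    return x


rng = np.random.default_rng(0)
toy_mlp = ToySequential((
    ToyLinear(rng.normal(size=(8, 32))),
    ToyRelu(),
    ToyLinear(rng.normal(size=(32, 8))),
))
out = toy_mlp(np.arange(8, dtype=np.float32))
print(out.shape)  # (8,)
```

Everything the toy model computes is visible in its structure: printing `toy_mlp` shows the ordered sublayers and the weight arrays they hold.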

We can see the sublayers by pretty-printing the model:

In [None]:
%%autovisualize
mlp

By convention, most of the "complicated" logic in Penzai model classes happens when we initialize them, using the `.from_config` method we called earlier. Once the model is built, the pretty-printed representation provides a full specification of everything the model does, and the parameters are stored as direct attributes on the layers that need them. A general design principle of Penzai is "*what you see is what you get*"; you should be able to learn everything you need to know about a model by printing it out.

In fact, every Penzai model class can be rebuilt by executing its pretty-printed representation. You can click on a pretty-printed output and press `r` to add qualified names to the visualization; if you remove the arrays first, you can then copy and paste the entire pretty-printed code and execute it to make a copy of the model structure:

In [None]:
# We can use `eval_shape` to remove array data and just keep the shapes.
# Try pressing `r` and copying the below output:
jax.eval_shape(lambda: mlp)

In [None]:
penzai.deprecated.v1.example_models.simple_mlp.MLP( # Sequential
  sublayers=[
    penzai.deprecated.v1.nn.linear_and_affine.Affine( # Sequential
      sublayers=[
        penzai.deprecated.v1.nn.linear_and_affine.Linear(weights=penzai.deprecated.v1.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 8, 'features_out': 32}), data_array=jax.ShapeDtypeStruct(shape=(8, 32), dtype=np.dtype('float32'))), name='Affine_0.Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)),
        penzai.deprecated.v1.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',)),
        penzai.deprecated.v1.nn.linear_and_affine.AddBias(bias=penzai.deprecated.v1.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32}), data_array=jax.ShapeDtypeStruct(shape=(32,), dtype=np.dtype('float32'))), name='Affine_0.AddBias.bias'), new_axis_names=()),
      ],
    ),
    penzai.deprecated.v1.nn.basic_ops.Elementwise(fn=jax.nn.relu),
    penzai.deprecated.v1.nn.linear_and_affine.Affine( # Sequential
      sublayers=[
        penzai.deprecated.v1.nn.linear_and_affine.Linear(weights=penzai.deprecated.v1.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32, 'features_out': 32}), data_array=jax.ShapeDtypeStruct(shape=(32, 32), dtype=np.dtype('float32'))), name='Affine_1.Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)),
        penzai.deprecated.v1.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',)),
        penzai.deprecated.v1.nn.linear_and_affine.AddBias(bias=penzai.deprecated.v1.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32}), data_array=jax.ShapeDtypeStruct(shape=(32,), dtype=np.dtype('float32'))), name='Affine_1.AddBias.bias'), new_axis_names=()),
      ],
    ),
    penzai.deprecated.v1.nn.basic_ops.Elementwise(fn=jax.nn.relu),
    penzai.deprecated.v1.nn.linear_and_affine.Affine( # Sequential
      sublayers=[
        penzai.deprecated.v1.nn.linear_and_affine.Linear(weights=penzai.deprecated.v1.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 32, 'features_out': 8}), data_array=jax.ShapeDtypeStruct(shape=(32, 8), dtype=np.dtype('float32'))), name='Affine_2.Linear.weights'), in_axis_names=('features',), out_axis_names=('features_out',)),
        penzai.deprecated.v1.nn.linear_and_affine.RenameAxes(old=('features_out',), new=('features',)),
        penzai.deprecated.v1.nn.linear_and_affine.AddBias(bias=penzai.deprecated.v1.nn.parameters.Parameter(value=penzai.core.named_axes.NamedArray(named_axes=collections.OrderedDict({'features': 8}), data_array=jax.ShapeDtypeStruct(shape=(8,), dtype=np.dtype('float32'))), name='Affine_2.AddBias.bias'), new_axis_names=()),
      ],
    ),
  ],
)

We won't usually do this in practice, because device arrays can't be copy-pasted this way; the parameters will be replaced with placeholder objects. Instead, Penzai provides a sophisticated *selector* system (`pz.select`) that allows us to make targeted modifications to (copies of) models. The point here is that Penzai model objects aren't "hiding" anything; they directly expose the structure of their computation as a data structure that can be manipulated.

The specific types of Penzai models and composite layers are provided primarily for ease of manipulation and as a way to identify how each part of your model was built. But you can also "flatten" a model into a list of sublayers that run in sequence, discarding this extra information:

In [None]:
pz.nn.inline_groups(mlp, lambda _: True, lambda _: True)

And you can freely add new logic as well, even if it wasn't configured in the initial model. For instance, here's how you could insert a new layer that prints out its intermediate activation:

In [None]:
@pz.pytree_dataclass  # <- This tags our class as being a Python dataclass and a JAX pytree node.
class DisplayIntermediateValue(pz.Layer):  # <- pz.Layer is the base class of Penzai layers.

  def __call__(self, intermediate_value):
    # Show the value:
    pz.show("Showing an intermediate value:", intermediate_value)
    # And return it unchanged.
    return intermediate_value

In [None]:
patched = (
    pz.select(mlp)
    .at(lambda model: model.sublayers[2])
    .insert_after(DisplayIntermediateValue())
)
pz.select(patched).at_instances_of(DisplayIntermediateValue).show_selection()

`patched` is a *copy* of our model that includes our new layer, and it will run our new logic when the model is called:

In [None]:
%%autovisualize
patched(pz.nx.NamedArray.wrap(jnp.arange(8, dtype=jnp.float32)).tag("features"))

This ability makes it remarkably easy to implement adapters like LoRA!

## Building a simple LoRA Layer in Penzai

Low-Rank Adaptation (LoRA) is a parameter-efficient fine-tuning strategy that augments each linear operation in the model with a decomposed low-rank adapter. The original weight matrix is frozen, and two smaller learnable parameter matrices are used to perturb its output. Because these new parameters are kept separate from the original matrix, only they need gradients and optimizer state, which makes fine-tuning both compute- and memory-efficient.

The effective weight matrix can be decomposed like this:

```
 ┌────────────────┐       ┌─────┐                      
 │                │       │     │                      
 │                │       │  A: │   ┌────────────────┐
 │    W: d*d      │   +   │ d*r │ * │     B: r*d     │
 │                │       │     │   └────────────────┘
 │                │       │     │                      
 └────────────────┘       └─────┘                      
```

Here `W` is the original frozen weight matrix, `A` is a randomly initialized matrix, and `B` is initialized to zero to ensure that the adapted model is equivalent to the original one at initialization.
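
As a quick numerical check of this decomposition, the NumPy sketch below builds `W`, a random `A`, and a zero `B` with the shapes from the diagram. The sizes `d = 16` and `r = 2` are arbitrary choices for illustration. It confirms that the merged form `x @ (W + A @ B)` matches the two-branch form `x @ W + (x @ A) @ B`, that both equal `x @ W` at initialization, and compares parameter counts.

```python
import numpy as np

d, r = 16, 2
rng = np.random.default_rng(0)
W = rng.normal(size=(d, d))   # frozen pretrained weight
A = rng.normal(size=(d, r))   # randomly initialized
B = np.zeros((r, d))          # zero-initialized

x = rng.normal(size=(4, d))   # a batch of 4 inputs

# Two equivalent ways to compute the adapted output.
merged = x @ (W + A @ B)
branched = x @ W + (x @ A) @ B

assert np.allclose(merged, branched)
assert np.allclose(merged, x @ W)  # B == 0, so nothing changes at init.

print("full matrix params:", W.size)        # 256
print("LoRA params:", A.size + B.size)      # 64
```

The branched form never materializes the `d*d` update, and only the `2*d*r` entries of `A` and `B` need to be trained.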

To enable LoRA, we'll do three things for each linear layer in our model:
- Freeze the original weight,
- Initialize our low-rank matrices A and B,
- And replace the original linear layer with the composition of W, A, and B.

Let's try it out with a simple MLP like the one we built in the last section. We'll just randomly initialize one for demonstration purposes; in a real LoRA adaptation setting we would generally load this from a pre-trained model checkpoint.

In [None]:
mlp = pz.nn.initialize_parameters(
    simple_mlp.MLP.from_config([2, 32, 32, 2]),
    jax.random.key(0),
)

### Step 1: Freeze parameters

We'll start by freezing all the parameters. Learnable parameters are identifiable because they are instances of `pz.nn.Parameter`:

In [None]:
pz.select(mlp).at_instances_of(pz.nn.Parameter).show_selection()

In [None]:
pz.select(mlp).at_instances_of(pz.nn.Parameter).get_sequence()

We can freeze these parameters by replacing each `Parameter` with an equivalent `FrozenParameter`. This is directly tracked inside the structure of the model.

In [None]:
frozen_mlp = pz.select(mlp).at_instances_of(pz.nn.Parameter).apply(
    lambda param: pz.nn.FrozenParameter(param.value, param.name)
)
frozen_mlp

In [None]:
pz.select(frozen_mlp).at_instances_of(pz.nn.Parameter).get_sequence()
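
Conceptually, freezing is a type-directed rewrite: walk the model, and wherever a node is a parameter, replace it with a frozen copy that holds the same array. The sketch below mimics that with hypothetical `ToyParameter` and `ToyFrozenParameter` classes, a nested dict standing in for the model, and hand-written tree helpers. It illustrates the idea only; it is not how `pz.select` is implemented.

```python
import dataclasses
import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyParameter:
  value: np.ndarray
  name: str


@dataclasses.dataclass(frozen=True)
class ToyFrozenParameter:
  value: np.ndarray
  name: str


def replace_instances(tree, cls, fn):
  """Returns a copy of `tree` where every instance of `cls` becomes fn(node)."""
  if isinstance(tree, cls):
    return fn(tree)
  if isinstance(tree, dict):
    return {k: replace_instances(v, cls, fn) for k, v in tree.items()}
  if isinstance(tree, (list, tuple)):
    return type(tree)(replace_instances(v, cls, fn) for v in tree)
  return tree


def find_instances(tree, cls):
  """Collects every instance of `cls` in `tree`, in traversal order."""
  if isinstance(tree, cls):
    return [tree]
  if isinstance(tree, dict):
    return [x for v in tree.values() for x in find_instances(v, cls)]
  if isinstance(tree, (list, tuple)):
    return [x for v in tree for x in find_instances(v, cls)]
  return []


toy_model = {
    "Affine_0": {"weights": ToyParameter(np.ones((2, 32)), "Affine_0.weights")},
    "Affine_1": {"weights": ToyParameter(np.ones((32, 2)), "Affine_1.weights")},
}
frozen_toy_model = replace_instances(
    toy_model, ToyParameter, lambda p: ToyFrozenParameter(p.value, p.name)
)

print(len(find_instances(toy_model, ToyParameter)))         # 2 (original untouched)
print(len(find_instances(frozen_toy_model, ToyParameter)))  # 0 (nothing left to train)
```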

### Step 2: Replace `Linear` layers with low-rank adapted versions

Next, we'll replace the Linear layers with implementations of LoRA.

In essence, a LoRA block is a sum of two computation paths: one that uses the original linear layer, and one that uses a sequence of two linear operations. This pattern can be directly mapped to one of Penzai's simple built-in combinators, `BranchAndAddTogether`. We can take each linear layer, like this one:

In [None]:
frozen_mlp.sublayers[0].sublayers[0]

And replace it with a block like this:

In [None]:
pz.nn.BranchAndAddTogether([
    # The original layer with frozen parameters:
    pz.nn.NamedGroup("Pretrained", [
        frozen_mlp.sublayers[0].sublayers[0],
    ]),
    # And a low-rank adapter:
    pz.nn.NamedGroup("Update", [
        pz.nn.add_parameter_prefix(
            "LoRA-A",
            pz.nn.Linear.from_config(
                input_axes={"features": 2},
                output_axes={"lowrank": 2},
            ),
        ),
        pz.nn.add_parameter_prefix(
            "LoRA-B",
            pz.nn.Linear.from_config(
                input_axes={"lowrank": 2},
                output_axes={"features_out": 32},
                initializer=pz.nn.zero_initializer,
            ),
        ),
    ]),
])

Note that the above code is a direct translation of a LoRA block into the structure of our model. The matrices A and B are represented as separate Linear blocks inside the overall combinator, and the order of execution is determined by the positions in the `NamedGroup`.
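
The semantics of these two combinators are simple enough to state as plain functions. In the NumPy sketch below, `branch_and_add` and `named_group` are hypothetical helpers written for illustration, not Penzai's implementation; the shapes match the first `Linear` layer of our `[2, 32, 32, 2]` MLP and a rank-2 adapter.

```python
import numpy as np


def named_group(*layers):
  """Runs layers in order, like a NamedGroup (the group's name is omitted)."""
  def run(x):
    for layer in layers:
      x = layer(x)
    return x
  return run


def branch_and_add(*branches):
  """Feeds the same input to every branch and sums the results."""
  def run(x):
    return sum(branch(x) for branch in branches)
  return run


rng = np.random.default_rng(0)
W = rng.normal(size=(2, 32))  # frozen weight: features -> features_out
A = rng.normal(size=(2, 2))   # LoRA-A: features -> lowrank
B = np.zeros((2, 32))         # LoRA-B: lowrank -> features_out, zero-initialized

block = branch_and_add(
    named_group(lambda x: x @ W),                   # "Pretrained"
    named_group(lambda x: x @ A, lambda x: x @ B),  # "Update"
)

x = rng.normal(size=(4, 2))
assert np.allclose(block(x), x @ W + (x @ A) @ B)
assert np.allclose(block(x), x @ W)  # identical to the original layer at init
```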

To simplify the process of making this transformation at every Linear block, we can encapsulate it into a new Layer subclass. Since the computation can already be written as a combination of existing pieces, the idiomatic Penzai approach is to define our new Layer as a subclass of `pz.nn.Sequential`, so that it can be easily flattened (like we did with the MLP) if needed. `Sequential` already defines the necessary attributes and `__call__` method, so we just need to provide a named initializer:

In [None]:
@pz.pytree_dataclass(has_implicitly_inherited_fields=True)
class LowRankAdapter(pz.nn.Sequential):

  @classmethod
  def from_linear(
      cls,
      linear: pz.nn.Linear,
      rank: int,
      name: str,
      lowrank_axis: str = "lowrank",
  ) -> 'LowRankAdapter':
    """Builds a LoRA layer from a Linear layer.

    Args:
      linear: The linear layer to adapt.
      rank: The rank of the low-rank adapter.
      name: Prefix for this block's parameters.
      lowrank_axis: The axis name for low-rank adaptation.

    Returns:
      A LoRA block with uninitialized parameters and the same initial
      behavior as `linear`.
    """
    return cls([
        pz.nn.BranchAndAddTogether([
            pz.nn.NamedGroup("Pretrained", [linear]),
            pz.nn.NamedGroup("Update", [
                pz.nn.add_parameter_prefix(
                    name + "/LoRA_A",
                    pz.nn.Linear.from_config(
                        input_axes=linear.input_axes,
                        output_axes={lowrank_axis: rank},
                        parallel_axes=linear.parallel_axes,
                    ),
                ),
                pz.nn.add_parameter_prefix(
                    name + "/LoRA_B",
                    pz.nn.Linear.from_config(
                        input_axes={lowrank_axis: rank},
                        output_axes=linear.output_axes,
                        parallel_axes=linear.parallel_axes,
                        initializer=pz.nn.zero_initializer,
                    ),
                ),
            ]),
        ])
    ])

Note: Idiomatic Penzai layers generally avoid overriding `__init__`, since dataclasses take their attributes as parameters to `__init__` and we want to ensure the output of the pretty-printer directly corresponds to code we could use to rebuild the model even if we've modified its attributes. When we have nontrivial construction logic, we'll usually define it in a class method like `from_linear` or `from_config` instead.

Also, in most Penzai layers, each layer is only responsible for ensuring its parameter names are *locally* unique, and parent layers add parameter prefixes using `pz.nn.add_parameter_prefix` at each level. In this case, however, we're planning on inserting the LoRA blocks into an existing model, so the names must be *globally* unique. This is why `from_linear` takes a name as an argument but `Linear.from_config` does not.

The next step is to write a helper function for inserting LoRA blocks into a model. We'll use Penzai's `pretty_keystr` function (a fancier version of `jax.tree_util.keystr`) to ensure each block has a unique name:

In [None]:
def loraify_all_linears(model, rank: int):
  return (
      pz.select(model)
      .at_instances_of(pz.nn.Linear)
      .apply(
          lambda keypath, lin: LowRankAdapter.from_linear(
              lin,
              rank=rank,
              name=pz.pretty_keystr(keypath, model),
          ),
          with_keypath=True,
      )
  )
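
Key paths make good globally unique names because every node in a tree is reached by exactly one path from the root, even when two subtrees are otherwise identical. The plain-Python sketch below uses a hypothetical `leaf_paths` helper on a nested dict standing in for the MLP; its path format is our own and differs from what `pz.pretty_keystr` prints.

```python
def leaf_paths(tree, prefix=""):
  """Yields (path, leaf) pairs; each path string is unique within `tree`."""
  if isinstance(tree, dict):
    for key, value in tree.items():
      yield from leaf_paths(value, f"{prefix}[{key!r}]")
  elif isinstance(tree, (list, tuple)):
    for i, value in enumerate(tree):
      yield from leaf_paths(value, f"{prefix}[{i}]")
  else:
    yield prefix, tree


# Two Affine blocks with identical local structure, separated by a ReLU.
toy_mlp = {
    "sublayers": [
        {"sublayers": ["Linear", "AddBias"]},
        "ReLU",
        {"sublayers": ["Linear", "AddBias"]},
    ]
}
linear_names = [path for path, leaf in leaf_paths(toy_mlp) if leaf == "Linear"]
print(linear_names)
# ["['sublayers'][0]['sublayers'][0]", "['sublayers'][2]['sublayers'][0]"]
```

Both adapters wrap a layer whose local name is just `Linear`, so only the full path tells them apart.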

Now we can run it on our MLP:

In [None]:
loraified_mlp_uninit = loraify_all_linears(frozen_mlp, rank=2)
loraified_mlp_uninit

You can directly check that this transformation is doing the right thing by expanding each `Affine` layer and making sure the `LowRankAdapter` looks right.

Note that `loraified_mlp_uninit` is a *copy* of `frozen_mlp` with the requested modifications. In Penzai, transformations of models always return new copies of the model, so you don't have to worry about accidentally making an irreversible change.

(Only the model *structure* is copied; the JAX arrays still share memory between the models. But JAX arrays are immutable, so you don't have to worry about those changing either: unless you explicitly delete or donate them, training loops always return new copies of your model with updated parameters.)
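
The same copy-the-structure, share-the-arrays behavior can be reproduced with ordinary frozen dataclasses. In this sketch, `ToyLayer` is a hypothetical class and `dataclasses.replace` plays the role of a model transformation. NumPy arrays are used for simplicity; unlike JAX arrays they are mutable, which is exactly why JAX's immutability is what makes this sharing safe.

```python
import dataclasses
import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyLayer:
  weights: np.ndarray
  label: str


original = ToyLayer(weights=np.ones((2, 2)), label="pretrained")
patched = dataclasses.replace(original, label="patched")  # builds a new object

print(original.label)                       # pretrained (unchanged)
print(patched is original)                  # False: the structure is a new copy
print(patched.weights is original.weights)  # True: the array itself is shared
```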

### Step 3: Initializing and training the LoRA weights

Finally, we can initialize and train the new weights we inserted into the model. To initialize them, we can use Penzai's standard parameter initialization helper, `pz.nn.initialize_parameters`, which finds all `UninitializedParameter` instances and initializes them. In this case, the `UninitializedParameter`s are the LoRA weights, and the `FrozenParameter`s from the pretrained model are ignored.

In [None]:
%%autovisualize
loraified_mlp = pz.nn.initialize_parameters(loraified_mlp_uninit, jax.random.key(42))
loraified_mlp

Since we froze the "pretrained" parameters before we applied `loraify_all_linears`, the learnable parameters of our new model are just the new LoRA weights:

In [None]:
pz.select(loraified_mlp).at_instances_of(pz.nn.Parameter).get_sequence()

This means you can easily train them using Penzai's basic training loop helpers, or write your own custom training loop for them. As a demonstration, we'll train this model to implement XOR by only fitting the low-rank adapter parameters.

In [None]:
xor_inputs = pz.nx.wrap(
    jnp.array([[-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=jnp.float32),
    "batch",
    "features",
)
xor_labels = jnp.array([[0, 1], [1, 0], [1, 0], [0, 1]], dtype=jnp.float32)

def loss_fn(model, rng, state):
  assert state is None
  model_out = model(xor_inputs)
  log_probs = jax.nn.log_softmax(
      model_out.unwrap("batch", "features"), axis=-1
  )
  losses = -log_probs * xor_labels
  loss = jnp.sum(losses) / 4
  return loss, None, {"loss": loss}
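
`loss_fn` is ordinary softmax cross-entropy, averaged over the four XOR examples (the `/ 4`). As a sanity check on that formula, the NumPy sketch below reimplements it as a hypothetical `xor_loss_from_logits` function: logits that carry no information should score exactly `log 2`, and confident, correct logits should score close to zero.

```python
import numpy as np

xor_labels = np.array([[0, 1], [1, 0], [1, 0], [0, 1]], dtype=np.float32)


def log_softmax(z, axis=-1):
  z = z - z.max(axis=axis, keepdims=True)  # subtract the max for stability
  return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))


def xor_loss_from_logits(logits):
  # Mean over the 4 examples of -log p(correct class).
  return -(log_softmax(logits) * xor_labels).sum() / len(xor_labels)


uninformative = np.zeros((4, 2))
print(xor_loss_from_logits(uninformative))  # ≈ 0.6931 (log 2)

# Logits of +5 on the correct class and -5 on the other give a tiny loss.
confident = 5.0 * (2 * xor_labels - 1)
print(xor_loss_from_logits(confident))
```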

In [None]:
train_step = basic_training.build_train_step_fn(loss_fn)
train_state = basic_training.TrainState.initial_state(
    model=loraified_mlp,
    optimizer_def=optax.adam(0.1),
    root_rng=jax.random.key(42),
)

In [None]:
train_state

In [None]:
for i in range(20):
  train_state, out = train_step(train_state)
  print(i, out)

In [None]:
train_state

`TrainState` is an optional utility that manages the optimizer states for us. It also partitions the model into learnable and nonlearnable parts, but we can combine them again by reading the computed property `train_state.model`:

In [None]:
%%autovisualize
train_state.model(xor_inputs)
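
One way to picture the partitioning `TrainState` performs: split the parameters by whether they are learnable, keep optimizer state only for the learnable part, and merge the two halves back when the full model is needed. The sketch below does this on a flat dict keyed by made-up parameter names, with hypothetical `partition` and `combine` helpers; it illustrates the idea and is not how `TrainState` is implemented.

```python
import numpy as np


def partition(params, is_learnable):
  """Splits a flat {name: array} dict into (learnable, frozen) dicts."""
  learnable = {k: v for k, v in params.items() if is_learnable(k)}
  frozen = {k: v for k, v in params.items() if not is_learnable(k)}
  return learnable, frozen


def combine(learnable, frozen):
  """Inverse of `partition`: merges the two halves back together."""
  return {**frozen, **learnable}


params = {
    "Affine_0.Linear.weights": np.ones((2, 32)),   # pretrained, frozen
    "Affine_0/LoRA_A.weights": np.ones((2, 2)),    # new, learnable
    "Affine_0/LoRA_B.weights": np.zeros((2, 32)),  # new, learnable
}
learnable, frozen = partition(params, lambda name: "LoRA" in name)

# Only the learnable half would need gradients and optimizer state.
print(sum(v.size for v in learnable.values()))  # 68
restored = combine(learnable, frozen)
print(sorted(restored) == sorted(params))       # True
```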

In [None]:
%%autovisualize
pz.nx.nmap(jnp.argmax)(train_state.model(xor_inputs).untag("features"))

Looks like it worked!

## Adding LoRA to Gemma

Let's now try adding LoRA to the Gemma 7B pretrained model. Because of Penzai's compositional design, the implementation in the previous section will just work out of the box!

### Loading Gemma

We'll start by loading the weights from the Gemma checkpoint. We'll use the 7B checkpoint for this tutorial, and shard it over our local devices using JAX's automatic partitioning. (You can read more about JAX's automatic distributed arrays [on this JAX documentation page](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).)

If you prefer, you can also run this tutorial with the 2B checkpoint.

You can download the Gemma checkpoints using a Kaggle account and an API key. If you don't have an API key already, you can:

1. Visit https://www.kaggle.com/ and create an account if needed.
2. Go to your account settings, then the 'API' section.
3. Click 'Create new token' to download your key.

Next, if you are running this notebook in Google Colab:

1. Click the "key" symbol on the left toolbar to open the "Secrets" tab.
2. Add two new secrets, named "KAGGLE_USERNAME" and "KAGGLE_KEY", and set their values based on the API key you downloaded.
3. Run the cell below and grant this notebook access to the secrets you just made.

If you are not running this notebook in Google Colab, you can instead run the cell below, input your username and API key in the textboxes, and click the login button.

In [None]:
import kagglehub
try:
  from google.colab import userdata
  kagglehub.config.set_kaggle_credentials(
      userdata.get("KAGGLE_USERNAME"), userdata.get("KAGGLE_KEY")
  )
except ImportError:
  kagglehub.login()

If everything went well, you should see:

```
Kaggle credentials set.
```

Before downloading Gemma, you will also need to consent to the Gemma Terms of Use. If you haven't done that yet, you can do so here:

> https://www.kaggle.com/models/google/gemma/license/consent

(Make sure you choose to "Verify via Kaggle Account" with the same account you used to log in above!)

Once you've agreed to the terms, you can run the next cell to download the Gemma weights:

In [None]:
weights_dir = kagglehub.model_download('google/gemma/Flax/7b')
ckpt_path = os.path.join(weights_dir, '7b')
vocab_path = os.path.join(weights_dir, 'tokenizer.model')

We can then load the SentencePiece vocabulary and restore the checkpointed parameters into JAX using `orbax`:

In [None]:
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)

In [None]:
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
metadata = checkpointer.metadata(ckpt_path)

In [None]:
n_devices = jax.local_device_count()
sharding_devices = mesh_utils.create_device_mesh((n_devices,))
sharding = jax.sharding.PositionalSharding(sharding_devices)
restore_args = jax.tree_util.tree_map(
    lambda m: orbax.checkpoint.ArrayRestoreArgs(
        restore_type=jax.Array,
        sharding=sharding.reshape((1,) * (len(m.shape) - 1) + (n_devices,))
    ),
    metadata,
)
flat_params = checkpointer.restore(ckpt_path, restore_args=restore_args)

In [None]:
gemma_model = gemma.model_core.GemmaTransformer.from_pretrained(
    flat_params,
    upcast_activations_to_float32=True,
)

In [None]:
del flat_params
gc.collect()

Here's what the Gemma model looks like:

In [None]:
%%autovisualize
gemma_model

Try clicking the triangle markers to explore the structure of Gemma and look at some of the parameters!

### Converting Gemma

Now we can freeze its parameters and LoRA-ify its linear blocks in the same way that we did for the simple MLP.

The Penzai implementation of Gemma uses the same `Linear` layer to implement all of the learnable operations in both the MLP blocks and the attention blocks. So we'll use a slightly modified helper function that lets us be more specific about which `Linear` layers we want to replace.

In [None]:
def loraify_linears_in_selection(selection, rank: int):
  model = selection.deselect()
  return (
      selection
      .at_instances_of(pz.nn.Linear)
      .apply(
          lambda keypath, lin: LowRankAdapter.from_linear(
              lin,
              rank=rank,
              name=pz.pretty_keystr(keypath, model),
          ),
          with_keypath=True,
      )
  )

Now we go through and apply each of the transformation steps:

In [None]:
# Step 1: Freeze the pretrained parameters.
frozen_gemma_model = (
    pz.select(gemma_model)
    .at_instances_of(pz.nn.Parameter)
    .apply(
        lambda param: pz.nn.FrozenParameter(param.value, param.name)
    )
)

In [None]:
# Step 2: LoRA-ify the Linear blocks. Following Hu et al. (2021), we'll only
# LoRA-ify the attention parameters.
loraified_gemma_model_uninit = loraify_linears_in_selection(
    pz.select(frozen_gemma_model).at_instances_of(gemma.model_core.GemmaAttention),
    rank=16,
)

In [None]:
# Step 3: Initialize the new LoRA parameters.
loraified_gemma_model = pz.nn.initialize_parameters(
    loraified_gemma_model_uninit, jax.random.key(123)
)

In [None]:
# Step 4 (optional): Look at it to make sure the transformation looks right.
pz.select(loraified_gemma_model).at_instances_of(LowRankAdapter).show_selection()

If we wanted, we could have just as easily adapted the MLP layers, by changing
```
.at_instances_of(gemma.model_core.GemmaAttention)
```
to
```
.at_instances_of(gemma.model_core.GemmaFeedForward)
```
We could have also customized which blocks have LoRA parameters by using Penzai's selector system (see the separate [selectors tutorial](selectors.ipynb) for more details).

### Fine-tuning Gemma with LoRA

We can now fine-tune our LoRA-ified Gemma model, with full control over the training loop.

For this tutorial, we'll just generate some synthetic data. Specifically, we'll show it some examples of evaluating a mysterious function, and train it to figure out what the function does. We won't worry too much about the efficiency of the data pipeline, since our goal is just to show how LoRA fine-tuning could work.

In [None]:
def mystery_function(a, b):
  return a + b

In [None]:
def generate_example(np_rng):
  a, b = np_rng.choice(1000, size=(2,))
  c = mystery_function(a, b)
  return f">>> mystery_function({a}, {b})\n{c}"

In [None]:
def tokenize_batch(examples, pad_length=32, include_eos=True):
  padded_tokens = []
  for example in examples:
    example_tokens = [vocab.bos_id()] + vocab.EncodeAsIds(example)
    if include_eos:
      example_tokens = example_tokens + [vocab.eos_id()]
    assert len(example_tokens) <= pad_length
    example_tokens = example_tokens + [vocab.pad_id()] * (pad_length - len(example_tokens))
    padded_tokens.append(example_tokens)
  return pz.nx.wrap(jnp.array(padded_tokens)).tag("batch", "seq")

Penzai has some useful utilities for visualizing token arrays:

In [None]:
%%autovisualize pz.ts.ArrayAutovisualizer.for_tokenizer(vocab)
np_rng = np.random.default_rng(123)
input_examples = tokenize_batch([generate_example(np_rng) for _ in range(20)])
input_examples

In [None]:
token_visualization.show_token_array(input_examples, vocab)

Let's train our new parameters on this data:

In [None]:
def xent_loss_fn(model, rng, state, input_examples):
  del rng, state  # Unused.
  # Run the model on shifted examples.
  # `GemmaInputs.from_basic_segments` is responsible for building the causal
  # attention mask and setting up positional embeddings.
  outputs = model(gemma.model_core.GemmaInputs.from_basic_segments(
      input_examples[{"seq": pz.slice[:-1]}]
  ))
  # Compute log-probabilities along the "vocabulary" axis.
  all_log_probs = pz.nx.nmap(jax.nn.log_softmax)(
      outputs.untag("vocabulary")
  ).tag("vocabulary")
  # Index by the correct tokens.
  correct_next_tokens = input_examples[{"seq": pz.slice[1:]}]
  correct_log_probs = all_log_probs[{"vocabulary": correct_next_tokens}]
  # Mask padding tokens.
  correct_log_probs = pz.nx.nmap(jnp.where)(
      correct_next_tokens == vocab.pad_id(),
      0.0,
      correct_log_probs,
  )
  # Take averages.
  loss = -correct_log_probs.untag("batch", "seq").unwrap().mean()
  return loss, None, {"loss": loss}
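
The shift-by-one and the padding mask in `xent_loss_fn` are easiest to see on a single toy sequence. In this NumPy sketch, the token ids are made up (we assume a padding id of `0`), and random logits stand in for Gemma's outputs.

```python
import numpy as np

PAD = 0  # assumed padding id for this toy example
# One made-up sequence: [bos, t1, t2, t3, eos, pad, pad], shape (batch=1, seq=7).
tokens = np.array([[2, 11, 12, 13, 1, PAD, PAD]])

inputs = tokens[:, :-1]   # what the model sees at each position
targets = tokens[:, 1:]   # the token it should predict at each position

# Random stand-in for the model's log-probabilities over a 16-token vocabulary.
rng = np.random.default_rng(0)
logits = rng.normal(size=inputs.shape + (16,))
log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))

# Pick out log p(correct next token) at every position.
correct = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
# Zero out positions whose target is padding, as the `jnp.where` above does.
correct = np.where(targets == PAD, 0.0, correct)
# Like `xent_loss_fn`, average over every position, masked ones included.
loss = -correct.mean()
print(targets[0].tolist())  # [11, 12, 13, 1, 0, 0]
```

Because the masked positions still count in the denominator of the mean, padding slightly dilutes the loss; dividing by the number of unmasked positions would be a common alternative.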

In [None]:
train_step = basic_training.build_train_step_fn(xent_loss_fn, donate_params_and_state=True)
train_state = basic_training.TrainState.initial_state(
    model=loraified_gemma_model,
    optimizer_def=optax.adamw(5e-5, weight_decay=0.01),
    root_rng=jax.random.key(42),
)
np_rng = np.random.default_rng(123)

In [None]:
# Train on 200 batches of 16 examples -> 3,200 examples
# (For reference, there are 1000 * 1000 = 1,000,000 possible examples in the
# synthetic distribution we are using.)
print_steps = {*range(10), *range(10, 200, 10)}
while train_state.step < 200:
  input_examples = tokenize_batch([
      generate_example(np_rng) for _ in range(16)
  ])
  train_state, out = train_step(train_state, input_examples=input_examples)
  if train_state.step in print_steps:
    print(train_state.step, out)

To see what the model learned, we can pull out the model from the train state and look at its parameters. In this case, all of the learnable parameters were added by our LoRA adapter.

We'll turn on the autovisualizer so that we can see the distribution of values in the arrays at a glance; try clicking on a few to expand their visualizations.

In [None]:
%%autovisualize
pz.select(train_state.model).at_instances_of(pz.nn.Parameter).get_sequence()

Recall that we initialized all of the "B" matrices to zero. So the fact that they are no longer zero indicates that the model has definitely learned something!

But has it learned what we wanted? Let's try running it on a randomly sampled batch of examples.

In [None]:
%%autovisualize pz.ts.ArrayAutovisualizer.for_tokenizer(vocab)
np_rng = np.random.default_rng(98765)
validation_examples = tokenize_batch([generate_example(np_rng) for _ in range(32)])
validation_examples

In [None]:
token_visualization.show_token_array(validation_examples, vocab)

In [None]:
outputs = train_state.model(gemma.model_core.GemmaInputs.from_basic_segments(
    validation_examples[{"seq": pz.slice[:-1]}]
))

In [None]:
# Compute log-probabilities along the "vocabulary" axis.
all_log_probs = pz.nx.nmap(jax.nn.log_softmax)(
    outputs.untag("vocabulary")
).tag("vocabulary")

# Index by the correct tokens.
correct_next_tokens = validation_examples[{"seq": pz.slice[1:]}]
correct_log_probs = all_log_probs[{"vocabulary": correct_next_tokens}]

# Plot the probability of the correct digit.
# This uses the same renderer as %%autovisualize, but doesn't truncate the array
# and lets us mask out elements.
pz.ts.render_array(
    pz.nx.nmap(jnp.exp)(correct_log_probs),
    valid_mask=(correct_next_tokens != vocab.pad_id()),
)

We can see that the model is predicting the arguments to `mystery_function` with about 10% accuracy, which is expected: those digits are random, so the model can do no better than guessing one of the ten possible digits. It also seems to be almost perfectly accurate on the answers, indicating that it has successfully fit the distribution.
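
The probability computation from the previous cells can be reproduced without named axes in a short NumPy sketch: shift the tokens by one, take a log-softmax over the vocabulary, pick out the log-probability of each true next token, and average over non-padding positions. The toy vocabulary (digits 0-9, padding id 10, one extra token) and the "uniform over digits" model are made up for illustration; they show why a model that can only guess uniformly among ten digits lands at about 0.1.

```python
import numpy as np


def log_softmax(logits, axis=-1):
  """Numerically stable log-softmax."""
  shifted = logits - logits.max(axis=axis, keepdims=True)
  return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def correct_token_probs(logits, tokens):
  """Probability assigned to each true next token.

  logits: [batch, seq - 1, vocab], predicted from tokens[:, :-1].
  tokens: [batch, seq] integer token ids.
  """
  targets = tokens[:, 1:]  # shift by one: position t predicts token t + 1
  log_probs = log_softmax(logits)
  picked = np.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
  return np.exp(picked), targets


# Toy vocabulary (an assumption): ids 0-9 are digits, 10 is padding, 11 is ">".
PAD_ID = 10
tokens = np.array([[11, 3, 7, PAD_ID]])

# A "model" that spreads its probability evenly over the ten digit tokens.
logits = np.full((1, 3, 12), -1e9)
logits[..., :10] = 0.0

probs, targets = correct_token_probs(logits, tokens)
mask = targets != PAD_ID  # ignore padding positions, like `valid_mask` above
print((probs * mask).sum() / mask.sum())  # approximately 0.1
```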

## Running inference on our LoRA-ified model

Now that we've fine-tuned the model, we can convert it into decoding mode and sample from it.

In Penzai, autoregressive decoding is performed by a separate class `GemmaKVCachingTransformer`, instead of being an alternative mode of `GemmaTransformer`. This is an instance of a more general pattern in Penzai models: each model and layer does a single thing at runtime, instead of doing different things depending on what arguments you pass. In fact, idiomatic Penzai layers always define a single function `__call__`, and that function always takes a single argument (although that argument can be a dictionary or tuple if needed). This makes it easy to compose many layers together in a uniform way without having to worry about how to handle function arguments.
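
To make the single-argument convention concrete, here is a plain-Python sketch. The `Scale`, `AddBias`, `AddPair`, and `Sequential` classes are hypothetical stand-ins, not Penzai layers. They show how a combinator can chain layers without any per-layer argument handling, and how a layer that needs two values can receive them packed into one tuple.

```python
from typing import Any, Callable, Sequence


# Plain-Python stand-ins (not Penzai classes): every layer is a callable
# object whose __call__ takes exactly one argument.
class Scale:
  def __init__(self, factor: float):
    self.factor = factor

  def __call__(self, x):
    return x * self.factor


class AddBias:
  def __init__(self, bias: float):
    self.bias = bias

  def __call__(self, x):
    return x + self.bias


class AddPair:
  """A layer that needs two values receives them packed in one tuple."""

  def __call__(self, pair):
    x, y = pair
    return x + y


class Sequential:
  """Composition needs no special cases: each output feeds the next layer."""

  def __init__(self, sublayers: Sequence[Callable[[Any], Any]]):
    self.sublayers = list(sublayers)

  def __call__(self, x):
    for layer in self.sublayers:
      x = layer(x)
    return x


model = Sequential([AddPair(), Scale(2.0), Sequential([AddBias(1.0)])])
print(model((3.0, 4.0)))  # (3 + 4) * 2 + 1 = 15.0
```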

The decoding mode transformation is actually very similar to the LoRA adaptation transformation we defined above. Instead of replacing `Linear` blocks with new `LowRankAdapter` blocks (which have new parameters), this transformation replaces `Attention` blocks with `KVCachingAttention` blocks (which have new state variables).

Since key-value caching for Gemma is itself implemented as a patching transformation, it can be applied immediately to our final `train_state.model`, even though we've already edited the model structure to add new adapted parameters. Our modifications don't conflict with the attention block structure, so the two transformations compose cleanly.
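
The composition argument can be illustrated with a toy model built from frozen dataclasses. In this sketch, `replace_instances` is a hypothetical helper (not Penzai's API) that rebuilds a tree with every outermost node of a given type swapped out; one patch wraps `Linear` nodes and the other wraps `Attention` nodes. Because the two patches target different node types, applying them in either order gives the same structure, and the adapters inside the attention block survive the caching patch. The real `KVCachingAttention` also introduces new state variables, which this sketch omits.

```python
import dataclasses
from typing import Any, Callable


# Hypothetical, frozen stand-ins for model nodes (not Penzai's classes).
@dataclasses.dataclass(frozen=True)
class Linear:
  name: str


@dataclasses.dataclass(frozen=True)
class Attention:
  query: Any
  output: Any


@dataclasses.dataclass(frozen=True)
class Block:
  attention: Any
  mlp: Any


@dataclasses.dataclass(frozen=True)
class LowRankAdapter:
  implementation: Any  # wraps the original Linear


@dataclasses.dataclass(frozen=True)
class KVCachingAttention:
  implementation: Any  # wraps the original Attention


def replace_instances(tree, cls, fn: Callable[[Any], Any]):
  """Returns a copy of `tree` with each outermost `cls` node replaced by fn(node)."""
  if isinstance(tree, cls):
    return fn(tree)
  if dataclasses.is_dataclass(tree):
    return dataclasses.replace(tree, **{
        f.name: replace_instances(getattr(tree, f.name), cls, fn)
        for f in dataclasses.fields(tree)
    })
  if isinstance(tree, (list, tuple)):
    return type(tree)(replace_instances(child, cls, fn) for child in tree)
  return tree


def add_lora(model):
  return replace_instances(model, Linear, LowRankAdapter)


def add_kv_caching(model):
  return replace_instances(model, Attention, KVCachingAttention)


model = (
    Block(attention=Attention(query=Linear("q"), output=Linear("o")),
          mlp=Linear("mlp")),
)

# The two patches touch different node types, so applying them in either
# order yields the same structure, with the adapters kept inside attention.
lora_then_cache = add_kv_caching(add_lora(model))
cache_then_lora = add_lora(add_kv_caching(model))
print(lora_then_cache == cache_then_lora)  # True
print(lora_then_cache[0].attention)
```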

Here's how we can enable decoding mode:

In [None]:
finetuned_inference_model, initial_inference_state = (
    gemma.sampling_mode.GemmaKVCachingTransformer.from_uncached(
        train_state.model,
        cache_len=64,
        batch_axes={"batch": 4},
    )
)

Let's look inside to see the changes:

In [None]:
# You can use a function to pick out an initial node to expand in the
# visualization. (You can also copy such a function by clicking the grey copy
# icon at the end of each line.)
pz.select(finetuned_inference_model).at(
    (lambda root: root.body.body.body.body.body.sublayers[5].sublayers[0].delta.sublayers[1].kv_cache)
).show_value()

The `LowRankAdapter` classes we inserted are still there in the model, but there have been a few other changes to the model structure:
- The outermost class is of a different type `GemmaKVCachingTransformer`.
- Inside it, there's a new `WithFunctionalLocalState` wrapper, which is responsible for managing the key-value caches, and a new `WithSideInputsFromInputTuple` wrapper that manages the current decoding position.
- Inside each of the transformer blocks, the `GemmaAttention` layers have been replaced with new `GemmaKVCachingAttention` layers that point back to these two new wrappers.
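
Conceptually, a key-value cache trades recomputation for state: each decoding step writes the new token's key and value at the current position and attends over everything written so far. The NumPy sketch here is single-head attention with hypothetical helpers (`init_cache`, `cached_attention_step`). It threads the cache through explicitly as an input and output and passes the position as an extra argument, which is loosely analogous to, but not the implementation of, the `WithFunctionalLocalState` and `WithSideInputsFromInputTuple` wrappers. The final check confirms that step-by-step decoding matches recomputing causal attention over the whole sequence.

```python
import numpy as np


def init_cache(cache_len, d):
  """An empty key/value cache with room for `cache_len` positions."""
  return {"keys": np.zeros((cache_len, d)), "values": np.zeros((cache_len, d))}


def cached_attention_step(cache, position, query, key, value):
  """One decoding step of single-head attention with a KV cache.

  Writes this token's key/value at `position`, attends over positions
  0..position, and returns (output, new_cache). The input cache is not
  mutated, so the state is threaded through explicitly.
  """
  keys = cache["keys"].copy()
  values = cache["values"].copy()
  keys[position] = key
  values[position] = value
  scores = keys[: position + 1] @ query / np.sqrt(query.shape[-1])
  weights = np.exp(scores - scores.max())
  weights /= weights.sum()
  return weights @ values[: position + 1], {"keys": keys, "values": values}


def full_causal_attention(queries, keys, values):
  """Reference: recompute causal attention over the whole sequence at once."""
  seq_len, d = queries.shape
  scores = queries @ keys.T / np.sqrt(d)
  causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
  scores = np.where(causal, scores, -np.inf)
  weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
  weights /= weights.sum(axis=-1, keepdims=True)
  return weights @ values


rng = np.random.default_rng(0)
seq_len, d, cache_len = 5, 4, 8
q, k, v = (rng.normal(size=(seq_len, d)) for _ in range(3))

cache = init_cache(cache_len, d)
incremental = []
for t in range(seq_len):
  out, cache = cached_attention_step(cache, t, q[t], k[t], v[t])
  incremental.append(out)

print(np.allclose(np.stack(incremental), full_causal_attention(q, k, v)))  # True
```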


Now that we've converted the model, we can use some existing helper functions to sample from it. (We discuss the decoding mode and helper functions in more detail in the separate ["Gemma From Scratch" tutorial](gemma_from_scratch.ipynb).) We'll wrap our model in `Jitted` so that it JIT-compiles itself whenever it is called.
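
For reference, temperature sampling itself is a short loop: divide the next-token logits by a temperature, softmax them, draw a token, append it, and repeat. This is a generic NumPy sketch with a hypothetical `toy_counter_model`; it is not the implementation of `temperature_sample_pyloop`, and it omits details like batching, padding, and stopping at an end-of-sequence token.

```python
import numpy as np


def temperature_sample(next_token_logits, prompt, rng, temperature=1.0, max_steps=5):
  """Generic autoregressive sampling loop (not Penzai's implementation).

  next_token_logits: callable mapping the token list so far to logits over
  the vocabulary for the next token.
  """
  tokens = list(prompt)
  for _ in range(max_steps):
    logits = np.asarray(next_token_logits(tokens), dtype=float) / temperature
    probs = np.exp(logits - logits.max())  # softmax with max subtraction
    probs /= probs.sum()
    tokens.append(int(rng.choice(len(probs), p=probs)))
  return tokens


def toy_counter_model(tokens):
  """Hypothetical model over a 10-token vocabulary that effectively always
  predicts (previous token + 1) mod 10."""
  logits = np.full(10, -1000.0)
  logits[(tokens[-1] + 1) % 10] = 0.0
  return logits


rng = np.random.default_rng(3)
print(temperature_sample(toy_counter_model, [7], rng, max_steps=4))  # [7, 8, 9, 0, 1]
```

Lower temperatures sharpen the distribution toward the most likely token, and higher temperatures flatten it; this toy model is so confident that any reasonable temperature gives the same output.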

In [None]:
prompts = [
    ">>> mystery_function(123, 123)",
    ">>> mystery_function(101, 15)",
    ">>> mystery_function(999, 876)",
    ">>>", # Let the model write and solve its own problem
]

In [None]:
%%autovisualize pz.ts.ArrayAutovisualizer.for_tokenizer(vocab)
tokenized_prompts = tokenize_batch(prompts, 16, include_eos=False)
tokenized_prompts

In [None]:
samples = gemma.simple_decoding_loop.temperature_sample_pyloop(
    jit_wrapper.Jitted(finetuned_inference_model),
    initial_inference_state,
    prompt=tokenized_prompts,
    rng=jax.random.key(3),
    pad_id=vocab.pad_id(),
    max_sampling_steps=20,
)

In [None]:
%%autovisualize pz.ts.ArrayAutovisualizer.for_tokenizer(vocab)
pz.show(samples)
token_visualization.show_token_array(samples, vocab)

As desired, our fine-tuned model seems to have learned the behavior of `mystery_function` using the low-rank updates to its parameters.

## Recap

This notebook demonstrates how Penzai makes it easy to edit the structure of a pretrained model without requiring any changes to the original model implementation. Our `LowRankAdapter` class and associated utilities took less than a hundred lines of code, and were immediately compatible with the pretrained Gemma 7B model, including both training and sampling modes.

The definitions in this notebook are also available in `penzai.toolshed.lora`, and can be imported from there if you are interested in using Penzai to perform parameter-efficient fine-tuning.

However, LoRA is just one example of what you can do with Penzai's powerful patching and model rewriting utilities. The key-value caching transformation is another, which we discuss in the ["Gemma From Scratch" tutorial](gemma_from_scratch.ipynb). And these tools can also be used to study intermediate activations and perform targeted counterfactual interventions to specific layers in the model, which we discuss in the ["Induction Heads" tutorial](induction_heads.ipynb). Penzai is designed to simplify the general process of editing, visualizing, and analyzing pretrained models; the goal is not to implement every possible type of fine-tuning or patching, but instead to give you powerful general-purpose tools and then get out of your way.
