*Copyright 2024 The Penzai Authors.*

*Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at*

> http://www.apache.org/licenses/LICENSE-2.0

*Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
See the License for the specific language governing permissions and
limitations under the License.*

---

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/penzai/blob/main/notebooks/lora_from_scratch.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google-deepmind/penzai/blob/main/notebooks/lora_from_scratch.ipynb)

# LoRA From Scratch - Patching Pretrained Models in Penzai

Penzai is designed to make it easy to modify neural networks in targeted ways after they have been trained. In this notebook, we'll show how to take Penzai's reference implementation of the [Gemma 7B](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) open-weights transformer model, patch it to support Low-Rank Adaptation (LoRA; [Hu et al. 2021](https://arxiv.org/abs/2106.09685)), and train the new parameters on a toy problem with a hand-written loss function.

The goal of this notebook is to show how *you* could implement something like LoRA from scratch in fewer than a hundred lines of code, starting from a Penzai implementation of a model that doesn't support it already, and without having to fork the existing implementation source code or even modify the pretrained model's configuration. We'll define everything we need as we go and make changes to models interactively. In fact, our implementation will end up being completely modular; we'll start by applying LoRA to a small MLP and then immediately be able to transfer our implementation to Gemma 7B.

Let's get started!

```{note}
This tutorial uses the V2 neural network API, defined in `pz.experimental.v2`.
```

## Setup

Before we can get started in earnest, we need to set up the environment.

### Imports

To run this notebook, you need a Python environment with `penzai` and its dependencies installed.

In Colab or Kaggle, you can install it using the following command:

In [None]:
try:
  import penzai
except ImportError:
  !pip install penzai[notebook]

In [None]:
from __future__ import annotations

In [None]:
import os
import gc

import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint
import optax
from jax.experimental import mesh_utils

In [None]:
import treescope
import penzai
from penzai import pz

In [None]:
import sentencepiece as spm

In [None]:
from penzai.models import transformer
from penzai.models import simple_mlp
from penzai.toolshed import token_visualization
from penzai.toolshed import basic_training
from penzai.toolshed import jit_wrapper

### Setting up Penzai

For this tutorial, we'll enable [Treescope](https://treescope.readthedocs.io/en/stable/) (Penzai's companion pretty-printer) as the default IPython pretty-printer. This is recommended when using Penzai in an interactive environment.

In [None]:
treescope.basic_interactive_setup(autovisualize_arrays=False)

## Intro to Penzai's declarative combinator design

We'll start by giving a brief introduction to Penzai's design conventions, and how they make it easy to insert adapters into pretrained models. Let's begin by initializing a small MLP:

In [None]:
mlp = simple_mlp.MLP.from_config(
    name="mlp",
    init_base_rng=jax.random.key(0),
    feature_sizes=[8, 32, 32, 8],
)

Like most Penzai models and layers, this MLP takes named arrays as input and returns them as output. A named array is just a wrapped JAX array where a subset of its positional axes have been tagged with names. (See the [named axes tutorial](named_axes.ipynb) for more info on how to use Penzai's named axis system.)

We can call the MLP directly on an array of inputs to run it:

In [None]:
%%autovisualize
mlp(pz.nx.NamedArray.wrap(jnp.arange(8, dtype=jnp.float32)).tag("features"))

Penzai models are written in a *declarative*, *combinator*-based style. This means that the structure of the model directly matches the sequence of high-level operations that the model will run in its forward pass. Composite models, like our MLP, just hold onto their sublayers in a list and run these sublayers in order. Primitive layers, like `Linear`, hold on to their parameters as attributes instead of reading them from an external parameter dictionary.
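To make the combinator idea concrete, here is a minimal sketch of the same pattern in plain JAX. The classes `ToyLinear`, `ToyRelu`, and `ToySequential` are hypothetical stand-ins written for this illustration, not Penzai's actual `Linear` or `Sequential` implementations: the composite layer just stores its sublayers in a list and calls them in order, and the primitive layer stores its weight as an attribute.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class ToyLinear:
  """Hypothetical primitive layer: the weight lives directly on the layer."""
  weight: jax.Array  # shape (in_features, out_features)

  def __call__(self, x):
    return x @ self.weight


@dataclasses.dataclass(frozen=True)
class ToyRelu:
  """Hypothetical parameter-free layer."""

  def __call__(self, x):
    return jax.nn.relu(x)


@dataclasses.dataclass(frozen=True)
class ToySequential:
  """Hypothetical composite layer: runs its sublayers in list order."""
  sublayers: list

  def __call__(self, x):
    for layer in self.sublayers:
      x = layer(x)
    return x


k1, k2 = jax.random.split(jax.random.key(0))
toy_mlp = ToySequential([
    ToyLinear(jax.random.normal(k1, (8, 32))),
    ToyRelu(),
    ToyLinear(jax.random.normal(k2, (32, 8))),
])
y = toy_mlp(jnp.arange(8, dtype=jnp.float32))
```

Because the whole forward pass is visible in `toy_mlp.sublayers`, adding a step is just a matter of building a new list, which is the property the rest of this notebook relies on.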

We can see the sublayers by pretty-printing the model:

In [None]:
%%autovisualize
mlp

By convention, most of the "complicated" logic in Penzai model classes happens when we initialize them, using the `.from_config` method we called earlier. Once the model is built, the pretty-printed representation provides a full specification of everything the model does, and the parameters are stored as direct attributes on the layers that need them. A general design principle of Penzai is "*what you see is what you get*"; you should be able to learn everything you need to know about a model by printing it out.

In fact, you can click on a pretty-printed output and press `r` to add qualified names to the pretty-printed visualization (try it above!), which will tell you exactly what type each layer has. (If you remove the parameters first using `pz.unbind_params`, you can even copy and paste the pretty-printed output to rebuild the model structure!)

Note that many classes are annotated with "Sequential", which means they are just an informatively-named sequence of other layers that run one after another. You can also "flatten" a model into a list of sublayers that run in sequence, discarding this extra information:

In [None]:
pz.nn.inline_groups(pz.nn.Sequential([mlp]), lambda _: True, lambda _: True)

And you can freely add new logic as well, even if it wasn't configured in the initial model. For instance, here's how you could insert a new layer that prints out its intermediate activation:

In [None]:
@pz.pytree_dataclass  # <- This tags our class as being a Python dataclass and a JAX pytree node.
class DisplayIntermediateValue(pz.nn.Layer):  # <- pz.nn.Layer is the base class of Penzai layers.

  def __call__(self, intermediate_value, **unused_side_inputs):
    # Show the value:
    pz.show("Showing an intermediate value:", intermediate_value)
    # And return it unchanged.
    return intermediate_value

In [None]:
patched = (
    pz.select(mlp)
    .at(lambda model: model.sublayers[2])
    .insert_after(DisplayIntermediateValue())
)
pz.select(patched).at_instances_of(DisplayIntermediateValue).show_selection()

`patched` is a *copy* of our model that includes our new layer, and it will run our new logic when the model is called:

In [None]:
%%autovisualize
patched(pz.nx.NamedArray.wrap(jnp.arange(8, dtype=jnp.float32)).tag("features"))

This ability makes it remarkably easy to implement adapters like LoRA!

## Building a simple LoRA Layer in Penzai

Low-Rank Adaptation (LoRA) is a parameter-efficient fine-tuning strategy that augments each linear operation in the model with a decomposed low-rank adapter. The original weight matrix is frozen, and two smaller learnable parameter matrices are used to perturb its output. Because these new parameters are kept separate from the original matrix, only they need gradients and optimizer state, which makes fine-tuning much cheaper in both compute and memory.

The effective weight matrix can be decomposed like this:

```
 ┌────────────────┐       ┌─────┐                      
 │                │       │     │                      
 │                │       │  A: │   ┌────────────────┐
 │    W: d*d      │   +   │ d*r │ * │     B: r*d     │
 │                │       │     │   └────────────────┘
 │                │       │     │                      
 └────────────────┘       └─────┘                      
```

Here `W` is the original frozen weight matrix, `A` is a randomly-initialized matrix, and `B` is initialized to zero to ensure that the adapted model is equivalent to the original one at initialization.
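As a quick numerical check of this decomposition, the sketch below uses plain JAX arrays with illustrative sizes `d = 16` and `r = 2` (assumptions chosen for the example, not values from any model). It confirms that with `B` at zero the adapted layer matches the original one, that applying the update as a separate path `(x @ A) @ B` agrees with a single merged matrix `W + A @ B`, and how few trainable values the adapter adds.

```python
import jax
import jax.numpy as jnp

d, r = 16, 2  # illustrative sizes: model width d, adapter rank r
k_w, k_a, k_b, k_x = jax.random.split(jax.random.key(0), 4)

W = jax.random.normal(k_w, (d, d))                 # frozen pretrained weight
A = jax.random.normal(k_a, (d, r)) / jnp.sqrt(d)   # randomly initialized
B = jnp.zeros((r, d))                              # zero-initialized
x = jax.random.normal(k_x, (4, d))                 # a batch of 4 inputs


def lora_forward(x, W, A, B):
  """Frozen path plus low-rank update path, added together."""
  return x @ W + (x @ A) @ B


# At initialization the update path contributes exactly zero.
adapted_init = lora_forward(x, W, A, B)

# After training changes B (simulated here with random values), the two-path
# form matches a single merged weight matrix W + A @ B.
B_trained = jax.random.normal(k_b, (r, d))
two_path = lora_forward(x, W, A, B_trained)
merged = x @ (W + A @ B_trained)

# Trainable adapter values versus frozen values for this one matrix.
n_trainable = A.size + B.size   # 2 * d * r = 64
n_frozen = W.size               # d * d = 256
```

For a square `d*d` matrix the ratio of trainable to frozen values is `2r / d`, so the savings grow with model width.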

To enable LoRA, we'll do three things for each linear layer in our model:
- Freeze the original weight,
- Initialize our low-rank matrices A and B,
- And replace the original linear layer with the composition of W, A, and B.

Let's try it out with a simple MLP like the one we built in the last section. We'll just randomly initialize one for demonstration purposes; in a real LoRA adaptation setting we would generally load this from a pre-trained model checkpoint.

In [None]:
mlp = simple_mlp.MLP.from_config(
    name="mlp",
    init_base_rng=jax.random.key(0),
    feature_sizes=[2, 32, 32, 2],
)

### Step 1: Freeze parameters

We'll start by freezing all the parameters. Learnable parameters are identifiable because they are instances of `pz.Parameter`:

In [None]:
pz.select(mlp).at_instances_of(pz.Parameter).show_selection()

In this case, the parameters are also the JAX PyTree leaves of the model. This is because they are mutable objects, and are designed to be updated by optimizers.

In [None]:
jax.tree_util.tree_leaves(mlp)

If needed, we can extract these parameters while safely handling repeated parameters using the function `pz.unbind_params`:

In [None]:
mlp_with_slots, params = pz.unbind_params(mlp)
pz.show("mlp_with_slots:", mlp_with_slots)
pz.show("params:", params)
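The idea behind unbinding can be sketched in a few lines of plain JAX. `ToyParam`, `ToySlot`, and `toy_unbind` are hypothetical names for this illustration and are not Penzai's implementation. The sketch assumes each parameter carries a unique label, so a parameter that appears twice in the tree (like a tied embedding) is collected once, and both of its uses become slots with the same label.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(eq=False)  # identity semantics, like a mutable object
class ToyParam:
  """Hypothetical mutable parameter with a unique label."""
  label: str
  value: jax.Array


@dataclasses.dataclass(frozen=True)
class ToySlot:
  """Hypothetical placeholder left where a parameter used to be."""
  label: str


def toy_unbind(tree):
  """Replaces each ToyParam with a ToySlot; returns (tree, unique params)."""
  found = {}

  def replace(leaf):
    if isinstance(leaf, ToyParam):
      if leaf.label in found and found[leaf.label] is not leaf:
        raise ValueError(f"Two different parameters named {leaf.label!r}")
      found[leaf.label] = leaf  # a shared parameter is only recorded once
      return ToySlot(leaf.label)
    return leaf

  # ToyParam is not a registered pytree node, so tree_map treats it as a leaf.
  return jax.tree_util.tree_map(replace, tree), tuple(found.values())


shared = ToyParam("embed", jnp.ones((4, 2)))
model = {"encoder": [shared, ToyParam("w", jnp.zeros(2))], "decoder": shared}
model_with_slots, params = toy_unbind(model)
```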

In this case, however, we just need to "freeze" the parameters, which makes them immutable. We can do this using `pz.freeze_params`:

In [None]:
frozen_mlp = pz.freeze_params(mlp)
frozen_mlp

In [None]:
# No more parameters:
pz.select(frozen_mlp).at_instances_of(pz.Parameter).get_sequence()

In [None]:
# Leaves are now ordinary JAX arrays:
jax.tree_util.tree_leaves(frozen_mlp)
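Conceptually, freezing swaps each mutable parameter object for a snapshot of its current array value. The following sketch shows that idea with a hypothetical `ToyParam` class and `toy_freeze` helper (not Penzai's `pz.freeze_params` implementation): after freezing, the leaves are ordinary JAX arrays, and later updates to the original parameter no longer affect the frozen copy.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(eq=False)
class ToyParam:
  """Hypothetical mutable parameter: an optimizer would overwrite `.value`."""
  value: jax.Array


def toy_freeze(tree):
  """Replaces every ToyParam with its current (immutable) array value."""
  return jax.tree_util.tree_map(
      lambda leaf: leaf.value if isinstance(leaf, ToyParam) else leaf, tree
  )


weights = ToyParam(jnp.ones((2, 3)))
model = {"linear": weights, "bias": ToyParam(jnp.zeros(3))}
frozen = toy_freeze(model)

# Mutating the original parameter afterwards does not change the frozen copy.
weights.value = weights.value * 2
```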

### Step 2: Replace `Linear` layers with low-rank adapted versions

Next, we'll replace the Linear layers with implementations of LoRA.

In essence, a LoRA block is a sum of two computation paths: one that uses the original linear layer, and one that uses a sequence of two linear operations. This pattern can be directly mapped to one of Penzai's simple built-in combinators, `BranchAndAddTogether`. We can take each linear layer, like this one:

In [None]:
frozen_mlp.sublayers[0].sublayers[0]

And replace it with a block like this:

In [None]:
pz.nn.BranchAndAddTogether([
    # The original layer with frozen parameters:
    pz.nn.NamedGroup("Pretrained", [
        frozen_mlp.sublayers[0].sublayers[0],
    ]),
    # And a low-rank adapter:
    pz.nn.NamedGroup("Update", [
        pz.nn.Linear.from_config(
            name="LoRA-A",
            init_base_rng=jax.random.key(1),
            # Match the first Linear layer of the 2-32-32-2 MLP (2 -> 32 features).
            input_axes={"features": 2},
            output_axes={"lowrank": 2},
        ),
        pz.nn.Linear.from_config(
            name="LoRA-B",
            init_base_rng=jax.random.key(1),
            input_axes={"lowrank": 2},
            output_axes={"features_out": 32},
            initializer=pz.nn.zero_initializer,
        ),
    ]),
])

Note that the above code is a direct translation of a LoRA block into the structure of our model. The matrices A and B are represented as separate Linear blocks inside the overall combinator, and the order of execution is determined by the positions in the `NamedGroup`.
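Stripped of names and named axes, the block above computes `pretrained(x) + lora_b(lora_a(x))`. This plain-JAX sketch spells that out with hypothetical helpers (`branch_and_add_together`, `sequential`) that mirror, but are not, Penzai's combinators. Shapes follow the first layer of our 2-32-32-2 MLP with rank 2.

```python
import jax
import jax.numpy as jnp


def branch_and_add_together(branches, x):
  """Hypothetical analogue: run every branch on the same input, sum outputs."""
  outputs = [branch(x) for branch in branches]
  total = outputs[0]
  for out in outputs[1:]:
    total = total + out
  return total


def sequential(layers):
  """Hypothetical helper: compose layers so they run in list order."""
  def run(x):
    for layer in layers:
      x = layer(x)
    return x
  return run


k_w, k_a, k_x = jax.random.split(jax.random.key(1), 3)
W = jax.random.normal(k_w, (2, 32))   # frozen "Pretrained" branch
A = jax.random.normal(k_a, (2, 2))    # "LoRA-A": features -> lowrank
B = jnp.zeros((2, 32))                # "LoRA-B": lowrank -> features_out
x = jax.random.normal(k_x, (5, 2))


def lora_block(x):
  pretrained = sequential([lambda h: h @ W])
  update = sequential([lambda h: h @ A, lambda h: h @ B])
  return branch_and_add_together([pretrained, update], x)


y = lora_block(x)
```

Since `B` starts at zero, `y` equals the pretrained branch's output, which is why inserting the block does not change the model's behavior until training begins.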

To simplify the process of making this transformation at every Linear block, we can encapsulate it into a new Layer subclass. Since the computation can already be written as a combination of existing pieces, the idiomatic Penzai approach is to define our new Layer as a subclass of `pz.nn.Sequential`, so that it can be easily flattened if needed (like we did with the MLP). `Sequential` already defines the necessary attributes and `__call__` method, so we just need to provide a named constructor:

In [None]:
@pz.pytree_dataclass(has_implicitly_inherited_fields=True)
class LowRankAdapter(pz.nn.Sequential):

  @classmethod
  def from_linear(
      cls,
      linear: pz.nn.Linear,
      name: str,
      init_base_rng: jax.Array | None,
      rank: int,
      lowrank_axis: str = "lowrank",
  ) -> 'LowRankAdapter':
    """Builds a LoRA layer from a Linear layer.

    Args:
      linear: The linear layer to adapt.
      name: Name for this layer's parameters. Must be globally unique across all
        LoRA blocks; we recommend using `jax.tree_util.keystr` or
        `pz.pretty_keystr` and setting the name based on the path to the
        original Linear layer being replaced.
      init_base_rng: The base RNG to use for initializing model parameters.
      rank: The rank of the low-rank adapter.
      lowrank_axis: The axis name for low-rank adaptation.

    Returns:
      A LoRA block with newly initialized adapter parameters and the same
      initial behavior as `linear` (since the "B" matrix starts at zero).
    """
    return cls([
        pz.nn.BranchAndAddTogether([
            pz.nn.NamedGroup("Pretrained", [linear]),
            pz.nn.NamedGroup(
                "Update",
                [
                    pz.nn.Linear.from_config(
                        name=f"{name}/LoRA_A",
                        init_base_rng=init_base_rng,
                        input_axes=linear.input_axes,
                        output_axes={lowrank_axis: rank},
                        parallel_axes=linear.parallel_axes,
                    ),
                    pz.nn.Linear.from_config(
                        name=f"{name}/LoRA_B",
                        init_base_rng=init_base_rng,
                        input_axes={lowrank_axis: rank},
                        output_axes=linear.output_axes,
                        parallel_axes=linear.parallel_axes,
                        initializer=pz.nn.zero_initializer,
                    ),
                ],
            ),
        ])
    ])

Note: Idiomatic Penzai layers generally avoid overriding `__init__`, since dataclasses take their attributes as parameters to `__init__` and we want to ensure the output of the pretty-printer directly corresponds to code we could use to rebuild the model even if we've modified its attributes. When we have nontrivial construction logic, we'll usually define it in a class method like `from_linear` or `from_config` instead.

Layer constructors are generally responsible for ensuring their parameter names are unique within a model, and for initializing their parameters when constructed. For this reason, most layer constructors take arguments `name` and `init_base_rng`. (Note that the name is combined with the RNG when initializing each parameter, so we don't need to manually split the RNGs.)
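One way such a scheme can work is to fold a stable hash of each parameter's name into a shared base key with `jax.random.fold_in`. The sketch below is an illustrative assumption (the `param_key` helper and the CRC32 hash are ours), not necessarily how Penzai derives its keys; it only shows why a unique name is enough to get independent initializations from one base RNG.

```python
import zlib

import jax


def param_key(init_base_rng, name):
  """Hypothetical scheme: derive a per-parameter key from a name.

  Folds a stable 31-bit hash of the name into the base key, so one base key
  can be reused for every parameter without splitting it manually.
  """
  # zlib.crc32 is stable across runs (unlike Python's built-in hash()).
  name_hash = zlib.crc32(name.encode("utf-8")) & 0x7FFFFFFF
  return jax.random.fold_in(init_base_rng, name_hash)


base = jax.random.key(1)
a_init = jax.random.normal(param_key(base, "LoRA-A.weights"), (2, 2))
b_init = jax.random.normal(param_key(base, "LoRA-B.weights"), (2, 2))
a_again = jax.random.normal(param_key(base, "LoRA-A.weights"), (2, 2))
```

The same name always reproduces the same values, and different names give different values, even though both calls above used the same `jax.random.key(1)`.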



The next step is to write a helper function for inserting LoRA blocks into a model. We'll use Penzai's `pretty_keystr` function (a fancier version of `jax.tree_util.keystr`) to ensure each block has a unique name:

In [None]:
def loraify_all_linears(model, rank: int, init_base_rng):
  return (
      pz.select(model)
      .at_instances_of(pz.nn.Linear)
      .apply(
          lambda keypath, lin: LowRankAdapter.from_linear(
              lin,
              name="LoRA:" + pz.pretty_keystr(keypath, model),
              init_base_rng=init_base_rng,
              rank=rank,
          ),
          with_keypath=True,
      )
  )
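Path-derived names are unique because no two leaves of a pytree share the same path, even when the layers at those paths look identical. Penzai's `pretty_keystr` refines JAX's own key strings; this sketch uses only `jax.tree_util.tree_flatten_with_path` and `jax.tree_util.keystr` on a hypothetical nested dictionary to show the idea.

```python
import jax
import jax.numpy as jnp

# A toy nested "model": dict keys and list positions form each leaf's path.
toy_model = {
    "blocks": [
        {"attn": jnp.zeros((4, 4)), "mlp": jnp.zeros((4, 8))},
        {"attn": jnp.zeros((4, 4)), "mlp": jnp.zeros((4, 8))},
    ],
}

path_and_leaves, _ = jax.tree_util.tree_flatten_with_path(toy_model)
names = ["LoRA:" + jax.tree_util.keystr(path) for path, _ in path_and_leaves]
# e.g. names[0] is "LoRA:['blocks'][0]['attn']"
```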

Now we can run it on our MLP:

In [None]:
loraified_mlp = loraify_all_linears(
    frozen_mlp, rank=2, init_base_rng=jax.random.key(42)
)
loraified_mlp

You can directly check that this transformation is doing the right thing by expanding each `Affine` layer and making sure the `LowRankAdapter` looks right.

Note that `loraified_mlp` is a *copy* of `frozen_mlp` with the requested modifications. In Penzai, transformations of models always return new copies of the model, so you don't have to worry about accidentally making an irreversible change.

Only the model *structure* is copied; the JAX arrays still share memory between the models, and any mutable parameters in the original model will also be shared with the new one. In this case, though, we froze the parameters of `frozen_mlp` first, so only the new parameters are mutable:

In [None]:
pz.select(loraified_mlp).at_instances_of(pz.Parameter).get_sequence()

### Step 3: Training the LoRA weights

We can now train these adapter parameters using Penzai's basic training loop helpers, or use a custom training loop for them. As a demonstration, we'll train this model to implement XOR by only fitting the low-rank adapter parameters.

In [None]:

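def loss_fn(model, rng, state, example_inputs, example_labels):
  del rng  # Unused: there is no randomness in this model.
  assert state is None
  model_out = model(example_inputs)
  # Log-probabilities over the two output classes, as a positional array.
  log_probs = jax.nn.log_softmax(
      model_out.unwrap("batch", "features"), axis=-1
  )
  # Cross-entropy against the one-hot labels, averaged over the batch.
  losses = -log_probs * example_labels
  loss = jnp.sum(losses) / example_labels.shape[0]
  return loss, None, {"loss": loss}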
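For reference, `loss_fn` is the standard cross-entropy against one-hot labels, averaged over the four XOR examples. This standalone sketch (the helper name `one_hot_xent` is ours) recomputes it on plain arrays and checks two reference values that help when reading the training log: uninformative logits give a loss of ln 2 ≈ 0.693 (chance level for two classes), and confident correct logits push the loss toward zero.

```python
import jax
import jax.numpy as jnp


def one_hot_xent(logits, labels):
  """Mean cross-entropy over the batch, with the same structure as loss_fn."""
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  per_example = -jnp.sum(log_probs * labels, axis=-1)
  return jnp.mean(per_example)


labels = jnp.array([[0, 1], [1, 0], [1, 0], [0, 1]], dtype=jnp.float32)
# Uninformative logits give probability 1/2 to each class.
uniform_loss = one_hot_xent(jnp.zeros((4, 2)), labels)
# Large logits on the correct class drive the loss toward zero.
confident_loss = one_hot_xent(10.0 * (2 * labels - 1), labels)
```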

In [None]:
trainer = basic_training.StatefulTrainer.build(
    model=loraified_mlp,
    optimizer_def=optax.adam(0.1),
    root_rng=jax.random.key(42),
    loss_fn=loss_fn
)

In [None]:
trainer

In [None]:
xor_inputs = pz.nx.wrap(
    jnp.array([[-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=jnp.float32),
    "batch",
    "features",
)
xor_labels = jnp.array([[0, 1], [1, 0], [1, 0], [0, 1]], dtype=jnp.float32)

for i in range(20):
  out = trainer.step(example_inputs=xor_inputs, example_labels=xor_labels)
  print(i, out)

The parameters in the model will be updated in place, allowing us to use the trained model:

In [None]:
%%autovisualize
loraified_mlp(xor_inputs)

In [None]:
%%autovisualize
pz.nx.nmap(jnp.argmax)(loraified_mlp(xor_inputs).untag("features"))

Looks like it worked!

(Note: If you prefer a "functional" training loop, you can extract an immutable version of your parameters by calling `pz.unbind_params(loraified_mlp, freeze=True)`, update them yourself, and then substitute the immutable parameters back in using `pz.bind_variables`.)
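Here is the functional style in miniature. This pure-JAX sketch does not use Penzai at all: the frozen weight `W` and the trainable `adapter` dictionary are hypothetical stand-ins, the target is a synthetic rank-1 change to `W`, and plain SGD replaces the `optax` optimizer. What it illustrates is that the adapter parameters are explicit function inputs, gradients are taken only with respect to them, and each step returns new values instead of mutating anything.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy setup: a frozen weight W and LoRA factors A (4x1), B (1x4).
k_w, k_a, k_x, k_t = jax.random.split(jax.random.key(0), 4)
W = jax.random.normal(k_w, (4, 4))  # frozen: never updated below
adapter = {
    "A": jax.random.normal(k_a, (4, 1)) * 0.1,
    "B": jnp.zeros((1, 4)),
}
x = jax.random.normal(k_x, (32, 4))
# Synthetic target: the frozen layer plus an unknown rank-1 change.
target = x @ (W + jax.random.normal(k_t, (4, 1)) @ jnp.ones((1, 4)))


def loss(adapter, W, x, target):
  pred = x @ W + (x @ adapter["A"]) @ adapter["B"]
  return jnp.mean((pred - target) ** 2)


@jax.jit
def sgd_step(adapter, W, x, target, lr=0.05):
  # Differentiate only with respect to argument 0 (the adapter); W is an
  # ordinary input and never receives an update.
  value, grads = jax.value_and_grad(loss)(adapter, W, x, target)
  new_adapter = jax.tree_util.tree_map(lambda p, g: p - lr * g, adapter, grads)
  return new_adapter, value


initial_loss = loss(adapter, W, x, target)
for _ in range(200):
  adapter, _ = sgd_step(adapter, W, x, target)
final_loss = loss(adapter, W, x, target)
```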

## Adding LoRA to Gemma

Let's now try adding LoRA to the Gemma 7B pretrained model. Because of Penzai's compositional design, the implementation in the previous section will just work out of the box!

### Loading Gemma

We'll start by loading the weights from the Gemma checkpoint. We'll use the 7B checkpoint for this tutorial, and shard it over our local devices using JAX's automatic partitioning. (You can read more about JAX's automatic distributed arrays [on this JAX documentation page](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html).)

If you prefer, you can also run this tutorial with the 2B checkpoint.

You can download the Gemma checkpoints using a Kaggle account and an API key. If you don't have an API key already, you can:

1. Visit https://www.kaggle.com/ and create an account if needed.
2. Go to your account settings, then the 'API' section.
3. Click 'Create new token' to download your key.

Next, if you are running this notebook in Google Colab:

1. Click the "key" symbol on the left toolbar to open the "Secrets" tab.
2. Add two new secrets, named "KAGGLE_USERNAME" and "KAGGLE_KEY", and set their values based on the API key you downloaded.
3. Run the cell below and grant this notebook access to the secrets you just made.

If you are not running this notebook in Google Colab, you can instead run the cell below, input your username and API key in the textboxes, and click the login button.

In [None]:
import kagglehub
try:
  from google.colab import userdata
  kagglehub.config.set_kaggle_credentials(
      userdata.get("KAGGLE_USERNAME"), userdata.get("KAGGLE_KEY")
  )
except ImportError:
  kagglehub.login()

If everything went well, you should see:

```
Kaggle credentials set.
```

Before downloading Gemma, you will also need to consent to the Gemma Terms of Use. If you haven't done that yet, you can do so here:

> https://www.kaggle.com/models/google/gemma/license/consent

(Make sure you choose to "Verify via Kaggle Account" with the same account you used to log in above!)

Once you've agreed to the terms, you can run the next cell to download the Gemma weights:

In [None]:
weights_dir = kagglehub.model_download('google/gemma/Flax/7b')
ckpt_path = os.path.join(weights_dir, '7b')
vocab_path = os.path.join(weights_dir, 'tokenizer.model')

We can then load the SentencePiece vocabulary and restore the checkpointed parameters into JAX using `orbax`:

In [None]:
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)

In [None]:
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
metadata = checkpointer.metadata(ckpt_path)

In [None]:
n_devices = jax.local_device_count()
sharding_devices = mesh_utils.create_device_mesh((n_devices,))
sharding = jax.sharding.PositionalSharding(sharding_devices)
restore_args = jax.tree_util.tree_map(
    lambda m: orbax.checkpoint.ArrayRestoreArgs(
        restore_type=jax.Array,
        sharding=sharding.reshape((1,) * (len(m.shape) - 1) + (n_devices,))
    ),
    metadata,
)
flat_params = checkpointer.restore(ckpt_path, restore_args=restore_args)

In [None]:
gemma_model = transformer.variants.gemma.gemma_from_pretrained_checkpoint(
    flat_params,
    upcast_activations_to_float32=True,
)

In [None]:
del flat_params
gc.collect()

Here's what the Gemma model looks like:

In [None]:
%%autovisualize
gemma_model

Try clicking the triangle markers to explore the structure of Gemma and look at some of the parameters!

### Converting Gemma

Now we can freeze its parameters and LoRA-ify its linear blocks in the same way that we did for the simple MLP.

The Penzai implementation of Gemma uses the same `Linear` layer to implement all of the learnable operations, in both the MLP blocks and the attention blocks. So we'll use a slightly-modified helper function that lets us be more specific about which `Linear` layers we want to replace.

In [None]:
def loraify_linears_in_selection(
    selection, rank: int, init_base_rng: jax.Array | None,
):
  model = selection.deselect()
  return selection.at_instances_of(pz.nn.Linear).apply(
      lambda keypath, lin: LowRankAdapter.from_linear(
          lin,
          name="LoRA:" + pz.pretty_keystr(keypath, model),
          init_base_rng=init_base_rng,
          rank=rank,
      ),
      with_keypath=True,
  )

Now we go through and apply each of the transformation steps:

In [None]:
# Step 1: Freeze the pretrained parameters.
frozen_gemma_model = pz.freeze_params(gemma_model)

In [None]:
# Step 2: LoRA-ify the Linear blocks. Following Hu et al. (2021), we'll only
# LoRA-ify the attention parameters.
loraified_gemma_model = loraify_linears_in_selection(
    pz.select(frozen_gemma_model).at_instances_of(pz.nn.Attention),
    rank=16,
    init_base_rng=jax.random.key(123),
)

In [None]:
# Step 3 (optional): Look at it to make sure the transformation looks right.
pz.select(loraified_gemma_model).at_instances_of(LowRankAdapter).show_selection()

If we wanted, we could have just as easily adapted the MLP layers, by changing
```
.at_instances_of(pz.nn.Attention)
```
to
```
.at_instances_of(transformer.model_parts.TransformerFeedForward)
```
We could have also customized which blocks have LoRA parameters by using other features of Penzai's selector system (see the separate [selectors tutorial](selectors.ipynb) for more details).

### Fine-tuning Gemma with LoRA

We can now fine-tune our LoRA-ified Gemma model! For this tutorial, we'll just generate some synthetic data. Specifically, we'll show it some examples of evaluating a mysterious function, and train it to figure out what the function does. We won't worry too much about the efficiency of the data pipeline, since our goal is just to show how LoRA fine-tuning could work.

In [None]:
def mystery_function(a, b):
  return a + b

In [None]:
def generate_example(np_rng):
  a, b = np_rng.choice(1000, size=(2,))
  c = mystery_function(a, b)
  return f">>> mystery_function({a}, {b})\n{c}"

In [None]:
def tokenize_batch(examples, pad_length=32, include_eos=True):
  padded_tokens = []
  for example in examples:
    example_tokens = [vocab.bos_id()] + vocab.EncodeAsIds(example)
    if include_eos:
      example_tokens = example_tokens + [vocab.eos_id()]
    assert len(example_tokens) <= pad_length
    # Pad from the right (simplifies input positional embeddings)
    example_tokens = (
        example_tokens
        + [vocab.pad_id()] * (pad_length - len(example_tokens))
    )
    padded_tokens.append(example_tokens)
  return pz.nx.wrap(jnp.array(padded_tokens)).tag("batch", "seq")

Penzai has some useful utilities for visualizing token arrays:

In [None]:
%%autovisualize treescope.ArrayAutovisualizer.for_tokenizer(vocab)
np_rng = np.random.default_rng(123)
input_examples = tokenize_batch([generate_example(np_rng) for _ in range(20)])
input_examples

In [None]:
token_visualization.show_token_array(input_examples, vocab)

Let's train our new parameters on this data:

In [None]:
def xent_loss_fn(model, rng, state, input_examples):
  del rng, state  # Unused.
  # Run the model on shifted examples.
  tokens_without_last = input_examples[{"seq": pz.slice[:-1]}]
  outputs = model(tokens_without_last)
  # Compute log-probabilities along the "vocabulary" axis.
  all_log_probs = pz.nx.nmap(jax.nn.log_softmax)(
      outputs.untag("vocabulary")
  ).tag("vocabulary")
  # Index by the correct tokens.
  correct_next_tokens = input_examples[{"seq": pz.slice[1:]}]
  correct_log_probs = all_log_probs[{"vocabulary": correct_next_tokens}]
  # Mask padding tokens.
  correct_log_probs = pz.nx.nmap(jnp.where)(
      correct_next_tokens == vocab.pad_id(),
      0.0,
      correct_log_probs,
  )
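  # Average over every (batch, seq) position; masked padding positions
  # contribute zero but still count toward the mean.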
  loss = -correct_log_probs.untag("batch", "seq").unwrap().mean()
  return loss, None, {"loss": loss}
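The shift-and-mask logic in `xent_loss_fn` can also be written with ordinary positional arrays. In this sketch, `masked_next_token_xent` and the padding id `0` are illustrative assumptions (the real pad id comes from `vocab.pad_id()`). It also makes one design choice visible: `xent_loss_fn` takes `.mean()` over every position, padding included, so masked positions contribute zero but still count in the denominator, whereas dividing by the number of real tokens gives a per-token average.

```python
import jax
import jax.numpy as jnp

PAD = 0  # assumed padding id, for illustration only


def masked_next_token_xent(logits, tokens, pad_id=PAD):
  """Positional-array version of the shift-and-mask logic in xent_loss_fn.

  logits: (batch, seq - 1, vocab), predictions made from tokens[:, :-1].
  tokens: (batch, seq), right-padded token ids.
  Returns (mean over all positions, mean over non-pad positions only).
  """
  targets = tokens[:, 1:]  # each position predicts the next token
  log_probs = jax.nn.log_softmax(logits, axis=-1)
  target_log_probs = jnp.take_along_axis(
      log_probs, targets[..., None], axis=-1
  )[..., 0]
  mask = targets != pad_id
  masked = jnp.where(mask, target_log_probs, 0.0)
  mean_all = -masked.mean()
  mean_real = -masked.sum() / jnp.maximum(mask.sum(), 1)
  return mean_all, mean_real


vocab_size = 5
tokens = jnp.array([[1, 3, 4, 2, PAD, PAD]])
uniform_logits = jnp.zeros((1, tokens.shape[1] - 1, vocab_size))
mean_all, mean_real = masked_next_token_xent(uniform_logits, tokens)
```

With uniform logits every real target has log-probability `-ln 5`; three of the five shifted positions are real, so `mean_real` is `ln 5` while `mean_all` is `0.6 * ln 5`.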

In [None]:
trainer = basic_training.StatefulTrainer.build(
    model=loraified_gemma_model,
    optimizer_def=optax.adamw(5e-5, weight_decay=0.01),
    root_rng=jax.random.key(42),
    loss_fn=xent_loss_fn,
    donate_states=True,
)

In [None]:
# Train on 200 batches of 16 examples -> 3,200 examples
# (For reference, there are 1000 * 1000 = 1,000,000 possible examples in the
# synthetic distribution we are using.)
print_steps = {*range(10), *range(10, 200, 10)}
while trainer.state.value.step < 200:
  input_examples = tokenize_batch([
      generate_example(np_rng) for _ in range(16)
  ])
  out = trainer.step(input_examples=input_examples)
  if int(trainer.state.value.step) in print_steps:
    print(trainer.state.value.step, out)

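To see what the model learned, we can extract its learnable parameters with `pz.unbind_params` and look at them. (The trainer updated these parameters in place, so `loraified_gemma_model` already contains the trained values.) In this case, all of the learnable parameters were added by our LoRA adapter.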

We'll turn on the autovisualizer so that we can see the distribution of values in the arrays at a glance; try clicking on a few to expand their visualizations.

In [None]:
%%autovisualize
_, params = pz.unbind_params(loraified_gemma_model)
params

Recall that we initialized all of the "B" matrices to zero. So the fact that they are no longer zero indicates that the model has definitely learned something!

But has it learned what we wanted? Let's try running it on a randomly sampled batch of examples.

In [None]:
%%autovisualize treescope.ArrayAutovisualizer.for_tokenizer(vocab)
np_rng = np.random.default_rng(98765)
validation_examples = tokenize_batch([generate_example(np_rng) for _ in range(32)])
validation_examples

In [None]:
token_visualization.show_token_array(validation_examples, vocab)

In [None]:
tokens_without_last = validation_examples[{"seq": pz.slice[:-1]}]
outputs = loraified_gemma_model(tokens_without_last)

In [None]:
# Compute log-probabilities along the "vocabulary" axis.
all_log_probs = pz.nx.nmap(jax.nn.log_softmax)(
    outputs.untag("vocabulary")
).tag("vocabulary")

# Index by the correct tokens.
correct_next_tokens = validation_examples[{"seq": pz.slice[1:]}]
correct_log_probs = all_log_probs[{"vocabulary": correct_next_tokens}]

# Plot the probability of the correct digit.
# This uses the same renderer as %%autovisualize, but doesn't truncate the array
# and lets us mask out elements.
treescope.render_array(
    pz.nx.nmap(jnp.exp)(correct_log_probs),
    valid_mask=(correct_next_tokens != vocab.pad_id()),
)

We can see that the model assigns roughly 10% probability to the correct digits of the arguments to `mystery_function`, which is about as well as it can do, since those digits are random. It also seems to be almost perfectly accurate on the answers, indicating that it has successfully fit the distribution.

## Running inference on our LoRA-ified model

Now that we've fine-tuned the model, we can convert it into decoding mode and sample from it.

In Penzai, autoregressive decoding is performed by a separate class `KVCachingTransformerLM`, instead of being an alternative mode of `TransformerLM`. This is an instance of a more general pattern in Penzai models: each model and layer does a single thing at runtime, instead of doing different things depending on what arguments you pass. In fact, idiomatic Penzai layers always define a single function `__call__`, and that function always takes a single positional argument (although that argument can be a dictionary or tuple if needed) along with keyword "side inputs". This makes it easy to compose many layers together in a uniform way without having to worry about how to handle function arguments.

The decoding mode transformation is actually very similar to the LoRA adaptation transformation we defined above. Instead of replacing `Linear` blocks with new `LowRankAdapter` blocks (which have new parameters), this transformation replaces `Attention` blocks with `KVCachingAttention` blocks (which have new state variables).

Since the key-value caching for Gemma is itself implemented as a patching transformation, it can be applied directly to our fine-tuned `loraified_gemma_model`, even though we've already edited the model structure to add new adapted parameters. Our modifications don't conflict with the attention block structure, so the two transformations compose easily.
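To make the idea of a stateful attention replacement concrete, here is a minimal single-head key-value cache in plain JAX. It is a sketch under simplifying assumptions (no batching, no query/key/value projections, no positional encoding, a fixed `cache_len`), and the function names are hypothetical rather than Penzai's `KVCachingAttention` API. It checks that decoding one token at a time with a cache reproduces full causal attention over the whole sequence, which is the property that lets a caching layer stand in for the original attention block.

```python
import jax
import jax.numpy as jnp


def causal_attention(q, k, v):
  """Full causal self-attention over a whole sequence. Shapes: (seq, dim)."""
  seq = q.shape[0]
  scores = q @ k.T / jnp.sqrt(q.shape[-1])
  mask = jnp.tril(jnp.ones((seq, seq), dtype=bool))
  scores = jnp.where(mask, scores, -jnp.inf)
  return jax.nn.softmax(scores, axis=-1) @ v


def cached_attention_step(q_t, k_t, v_t, cache, t):
  """One decoding step: write k_t, v_t into cache slot t, then attend from
  q_t to slots 0..t only. Returns (output, updated cache)."""
  k_cache = cache["k"].at[t].set(k_t)
  v_cache = cache["v"].at[t].set(v_t)
  scores = k_cache @ q_t / jnp.sqrt(q_t.shape[-1])
  valid = jnp.arange(k_cache.shape[0]) <= t  # ignore unwritten slots
  scores = jnp.where(valid, scores, -jnp.inf)
  out = jax.nn.softmax(scores) @ v_cache
  return out, {"k": k_cache, "v": v_cache}


seq, dim, cache_len = 6, 8, 16
kq, kk, kv = jax.random.split(jax.random.key(0), 3)
q = jax.random.normal(kq, (seq, dim))
k = jax.random.normal(kk, (seq, dim))
v = jax.random.normal(kv, (seq, dim))

full = causal_attention(q, k, v)

cache = {"k": jnp.zeros((cache_len, dim)), "v": jnp.zeros((cache_len, dim))}
steps = []
for t in range(seq):
  out_t, cache = cached_attention_step(q[t], k[t], v[t], cache, t)
  steps.append(out_t)
incremental = jnp.stack(steps)
```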

Here's how we can enable decoding mode:

In [None]:
finetuned_inference_model = (
    transformer.sampling_mode.KVCachingTransformerLM.from_uncached(
        loraified_gemma_model, cache_len=64, batch_axes={"batch": 4},
    )
)

Let's look inside to see the changes:

In [None]:
# You can use a function to pick out an initial node to expand in the
# visualization. (You can also copy such a function by clicking the grey copy
# icon at the end of each line.)
pz.select(finetuned_inference_model).at(
    (lambda root: root.body.sublayers[5].sublayers[0].delta.sublayers[1].kv_cache)
).show_value()

The `LowRankAdapter` classes we inserted are still there in the model, but there have been a few other changes to the model structure:
- The outermost class is of a different type, `KVCachingTransformerLM`.
- This outer class also has new attributes, which track metadata and state necessary for sampling.
- Inside each of the transformer blocks, the `Attention` layers have been replaced with new `KVCachingAttention` layers that also have internal state variables.

Like `Parameter` objects, `StateVariable` objects are mutable. The difference is that `StateVariable`s are intended to be modified while the model runs. You can unbind them using `pz.unbind_state_vars` and rebind them using `pz.bind_variables` (which works for both parameters and state variables).

Now that we've converted the model, we can use some existing helper functions to sample from it. We'll also wrap our model in `Jitted` so that it JIT-compiles itself whenever it is called.

In [None]:
prompts = [
    ">>> mystery_function(123, 123)",
    ">>> mystery_function(101, 15)",
    ">>> mystery_function(999, 876)",
    ">>>", # Let the model write and solve its own problem
]

In [None]:
%%autovisualize treescope.ArrayAutovisualizer.for_tokenizer(vocab)
tokenized_prompts = tokenize_batch(prompts, 16, include_eos=False)
tokenized_prompts

In [None]:
samples = transformer.simple_decoding_loop.temperature_sample_pyloop(
    (
        pz.select(finetuned_inference_model)
        .at(lambda root: root.body)
        .apply(jit_wrapper.Jitted)
    ),
    prompt=tokenized_prompts,
    rng=jax.random.key(3),
    max_sampling_steps=20,
)

In [None]:
%%autovisualize treescope.ArrayAutovisualizer.for_tokenizer(vocab)
pz.show(samples)
token_visualization.show_token_array(samples, vocab)

As desired, our fine-tuned model seems to have learned the behavior of `mystery_function` using the low-rank updates to its parameters.
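Each step of a sampling loop like the one used above draws the next token from the model's output distribution. The sketch below shows only that draw, assuming the usual definition of temperature sampling (dividing logits by a temperature before the softmax) and implementing it with `jax.random.categorical`. `sample_next_token` is a hypothetical helper, not the signature of `temperature_sample_pyloop`.

```python
import jax
import jax.numpy as jnp


def sample_next_token(key, logits, temperature):
  """Draws one token id per row from softmax(logits / temperature)."""
  return jax.random.categorical(key, logits / temperature, axis=-1)


logits = jnp.array([[2.0, 1.0, 0.5, -1.0],
                    [0.0, 3.0, 0.0, 0.0]])
keys = jax.random.split(jax.random.key(3), 1000)

# A near-zero temperature approaches greedy (argmax) decoding.
greedy_like = sample_next_token(keys[0], logits, temperature=1e-3)

# At temperature 1.0, sample frequencies approach the softmax probabilities.
samples = jax.vmap(lambda k: sample_next_token(k, logits[0], 1.0))(keys)
empirical = jnp.bincount(samples, length=4) / samples.shape[0]
expected = jax.nn.softmax(logits[0])
```

Lower temperatures make the output more deterministic, while higher temperatures flatten the distribution and increase diversity.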

## Recap

This notebook demonstrates how Penzai makes it easy to edit the structure of a pretrained model without requiring any changes to the original model implementation. Our `LowRankAdapter` class and associated utilities took fewer than a hundred lines of code, and were immediately compatible with the pretrained Gemma 7B model, including both training and sampling modes.

The definitions in this notebook are also available in `penzai.toolshed.lora`, and can be imported from there if you are interested in using Penzai to perform parameter-efficient fine-tuning.

However, LoRA is just one example of what you can do with Penzai's powerful patching and model rewriting utilities. The key-value caching transformation is another, which we've seen above. And these tools can also be used to study intermediate activations and perform targeted counterfactual interventions to specific layers in the model, which we discuss in the ["Induction Heads" tutorial](induction_heads.ipynb). Penzai is designed to simplify the general process of editing, visualizing, and analyzing pretrained models; the goal is not to implement every possible type of fine-tuning or patching, but instead to give you powerful general-purpose tools and then get out of your way.
