# LoRA example

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/lora.ipynb)

Example of using LoRA with Gemma (for both training and inference).

In [None]:
from etils import ecolab
import jax
import jax.numpy as jnp

# TODO(epot): Add open-source imports
with ecolab.adhoc():
  from gemma import gm
  from gemma import peft  # Parameter-efficient fine-tuning module

## Initializing the model

To use Gemma with LoRA, simply wrap any Gemma model in `gm.nn.LoRAWrapper`:

In [None]:
model = gm.nn.LoRAWrapper(
    rank=4,
    model=gm.nn.Gemma2_2B(),
)
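
For intuition about what `rank` controls: in the standard LoRA formulation, a frozen weight matrix of shape `(d_in, d_out)` is paired with two small trainable matrices of shapes `(d_in, rank)` and `(rank, d_out)`. The sketch below only does the parameter-count arithmetic. The function name and the example dimensions are illustrative assumptions and are not taken from the Gemma library or any Gemma config.

```python
def lora_param_counts(d_in: int, d_out: int, rank: int) -> tuple[int, int]:
  """Returns (frozen, trainable) parameter counts for one LoRA-adapted matrix."""
  frozen = d_in * d_out  # Original weight matrix, kept frozen
  trainable = rank * (d_in + d_out)  # Low-rank factors (d_in, rank) and (rank, d_out)
  return frozen, trainable


# Illustrative dimensions only (not an actual Gemma layer shape)
frozen, trainable = lora_param_counts(d_in=2048, d_out=2048, rank=4)
print(f'frozen={frozen}, trainable={trainable}')
```

The trainable count grows linearly with `rank`, so a small rank keeps the number of fine-tuned parameters to a small fraction of the frozen ones.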

Initialize the weights:

In [None]:
token_ids = jnp.zeros((1, 256), dtype=jnp.int32)  # Dummy input of shape (batch_size, seq_length)

params = model.init(
    jax.random.key(0),
    token_ids,
)

params = params['params']

Inspect the shape and structure of the params. The LoRA weights have been added alongside the original model weights.

In [None]:
with ecolab.collapse('Params'):
  # p: Pretty-print, s: Array specs, h: Syntax highlighting
  ecolab.disp(params, mode='psh')

Restore the pre-trained params. We use `peft.split_params` and `peft.merge_params` to replace the randomly initialized params with the pre-trained ones.

In [None]:
# Splits the params into non-LoRA and LoRA weights
original, lora = peft.split_params(params)

# Load the params from the checkpoint
# Passing `params=original` ensures that:
# * The memory from the old params is released (so only a single copy of the
#   weights stays in memory)
# * The restored params reuse the same sharding as the input (there's no
#   sharding here, so this isn't needed in this example)
original = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA2_2B_IT, params=original)

# Merge the pretrained params back with LoRA
params = peft.merge_params(original, lora)
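
To make the split/merge step concrete, here is a minimal pure-Python sketch of the idea on a nested dict. It is not the `peft` implementation: the helper names `split_tree` and `merge_trees`, and the assumption that LoRA weights live in a sub-dict keyed `'lora'`, are hypothetical and chosen only for illustration.

```python
def split_tree(tree: dict, lora_key: str = 'lora') -> tuple[dict, dict]:
  """Splits a nested dict into (non-LoRA, LoRA) trees with the same nesting."""
  original, lora = {}, {}
  for name, value in tree.items():
    if name == lora_key:
      lora[name] = value  # Whole sub-tree holds LoRA weights (assumed layout)
    elif isinstance(value, dict):
      sub_original, sub_lora = split_tree(value, lora_key)
      if sub_original:
        original[name] = sub_original
      if sub_lora:
        lora[name] = sub_lora
    else:
      original[name] = value  # Leaf belonging to the base model
  return original, lora


def merge_trees(a: dict, b: dict) -> dict:
  """Recursively merges two nested dicts whose leaves do not overlap."""
  merged = dict(a)
  for name, value in b.items():
    if isinstance(value, dict) and isinstance(merged.get(name), dict):
      merged[name] = merge_trees(merged[name], value)
    else:
      merged[name] = value
  return merged


# Toy tree with strings standing in for arrays
toy_params = {
    'layer_0': {'w': 'base weights', 'lora': {'a': 'lora a', 'b': 'lora b'}},
    'embedder': {'w': 'embedding table'},
}
toy_original, toy_lora = split_tree(toy_params)
toy_restored = merge_trees(toy_original, toy_lora)
```

Because the two trees keep the original nesting, the non-LoRA tree can be replaced wholesale (for example by checkpoint values) and merged back without touching the LoRA leaves.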

## Fine-tuning

In [None]:
# TODO(epot)

## Inference

Here's an example of running a single model call:

In [None]:
tokenizer = gm.text.Gemma2Tokenizer()

prompt = tokenizer.encode('The capital of France is')
prompt = jnp.asarray([tokenizer.special_tokens.BOS] + prompt)


# Run the model
out = model.apply(
    {'params': params},
    tokens=prompt,
    return_last_only=True,  # Only predict the last token
)


# Show the token distribution
tokenizer.plot_logits(out.logits)
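
Logits are unnormalized scores; a softmax turns them into a next-token probability distribution. The sketch below shows that conversion in plain Python on made-up logits for a hypothetical four-token vocabulary. `softmax` and `top_k` are illustrative helpers, not part of the Gemma API.

```python
import math


def softmax(logits: list[float]) -> list[float]:
  """Converts logits to probabilities (subtracts the max for numerical stability)."""
  max_logit = max(logits)
  exps = [math.exp(x - max_logit) for x in logits]
  total = sum(exps)
  return [e / total for e in exps]


def top_k(probs: list[float], k: int) -> list[tuple[int, float]]:
  """Returns the k most likely (token_id, probability) pairs, most likely first."""
  ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
  return ranked[:k]


toy_logits = [2.0, 0.5, -1.0, 0.0]  # Made-up values
toy_probs = softmax(toy_logits)
print(top_k(toy_probs, k=2))
```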

To sample an entire sentence:

In [None]:
sampler = gm.text.Sampler(
    model=model,
    params=params,
    tokenizer=tokenizer,
)

sampler.sample('The capital of France is', max_new_tokens=30)
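
Conceptually, generating a sentence repeats the single model call shown earlier: predict the next token, append it, and continue until a token budget or an end token is reached. The sketch below illustrates this with plain greedy decoding over a toy stand-in model. It is not how `gm.text.Sampler` is implemented, whose sampling strategy and stop conditions are not described here; `greedy_decode`, `toy_model`, and `eos_id` are hypothetical names.

```python
from typing import Callable


def greedy_decode(
    next_token_logits: Callable[[list[int]], list[float]],
    prompt: list[int],
    max_new_tokens: int,
    eos_id: int,
) -> list[int]:
  """Repeatedly appends the highest-scoring next token to the sequence."""
  tokens = list(prompt)
  for _ in range(max_new_tokens):
    logits = next_token_logits(tokens)
    next_id = max(range(len(logits)), key=logits.__getitem__)  # argmax
    tokens.append(next_id)
    if next_id == eos_id:  # Stop early on the end-of-sequence token
      break
  return tokens


def toy_model(tokens: list[int]) -> list[float]:
  """Stand-in model over a 4-token vocabulary: favors (last token + 1) % 4."""
  target = (tokens[-1] + 1) % 4
  return [1.0 if i == target else 0.0 for i in range(4)]


decoded = greedy_decode(toy_model, prompt=[0], max_new_tokens=30, eos_id=3)
print(decoded)
```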