# Finetuning

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/finetuning.ipynb)

This is an example of fine-tuning Gemma. For an example of how to run a pre-trained Gemma model, see the [sampling](https://github.com/google-deepmind/gemma/blob/main/docs/sampling.md) tutorial.

To fine-tune Gemma, we use the [kauldron](https://kauldron.readthedocs.io/en/latest/) library, which abstracts away most of the boilerplate (checkpoint management, training loop, evaluation, metric reporting, sharding, ...).


In [None]:
!pip install -q gemma

In [21]:
# Common imports
import optax
import treescope

# Gemma imports
from kauldron import kd
from gemma import gm

## Data pipeline

First create the tokenizer, as it's required in the data pipeline.

In [22]:
tokenizer = gm.text.Gemma2Tokenizer()

tokenizer.encode('This is an example sentence', add_bos=True)

[<_Gemma2SpecialTokens.BOS: 2>, 1596, 603, 671, 3287, 13060]

Next, we need a data pipeline. Multiple pipelines are supported, including:

* [HuggingFace](https://kauldron.readthedocs.io/en/latest/api/kd/data/py/HuggingFace.html)
* [TFDS](https://kauldron.readthedocs.io/en/latest/api/kd/data/py/Tfds.html)
* [Json](https://kauldron.readthedocs.io/en/latest/api/kd/data/py/Json.html)
* ...

It's quite simple to add your own data, or to create mixtures from multiple sources. See the [pipeline documentation](https://kauldron.readthedocs.io/en/latest/data_py.html).

We use `transforms` to customize the data pipeline. This includes:

* Tokenizing the inputs (with `gm.data.Tokenize`)
* Creating the model inputs (with `gm.data.AddNextTokenPredictionFields`)
* Adding padding (with `gm.data.Pad`), which is required to batch inputs of different lengths
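
To make the last two steps concrete, here is a minimal pure-Python sketch of what building next-token-prediction fields and padding amounts to conceptually. It is not the library's implementation: `make_next_token_fields` and `pad_or_truncate` are hypothetical helper names, and the token ids are made up for illustration.

```python
def make_next_token_fields(prompt, response):
    """Builds input/target/loss_mask lists from prompt and response token ids."""
    sequence = list(prompt) + list(response)
    inputs = sequence[:-1]   # The model sees every token except the last one.
    targets = sequence[1:]   # At each position, the target is the next token.
    # Only positions whose target is a response token contribute to the loss.
    loss_mask = [False] * (len(prompt) - 1) + [True] * len(response)
    return {"input": inputs, "target": targets, "loss_mask": loss_mask}


def pad_or_truncate(values, max_length, pad_value):
    """Cuts `values` to `max_length`, then right-pads with `pad_value`."""
    values = list(values)[:max_length]
    return values + [pad_value] * (max_length - len(values))


# Hypothetical token ids: 2 stands in for BOS, 1 for EOS.
prompt_ids = [2, 10, 11, 12]
response_ids = [20, 21, 1]

fields = make_next_token_fields(prompt_ids, response_ids)
padded = {
    "input": pad_or_truncate(fields["input"], max_length=8, pad_value=0),
    "target": pad_or_truncate(fields["target"], max_length=8, pad_value=0),
    "loss_mask": pad_or_truncate(fields["loss_mask"], max_length=8, pad_value=False),
}
```

In this sketch the padded positions get a `False` mask, so they never contribute to the loss, which is why padding can be applied safely after the mask is created.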

Note that in practice, you can combine multiple transforms into a higher-level transform. For an example, see the `gm.data.AddContrastiveFields()` transform in the [DPO example](https://github.com/google-deepmind/gemma/tree/main/examples/dpo.py).

Here, we try [mtnt](https://www.tensorflow.org/datasets/catalog/mtnt), a small translation dataset. The dataset structure is `{'src': ..., 'dst': ...}`.

In [27]:
ds = kd.data.py.Tfds(
    name='mtnt/en-fr',
    split='train',
    shuffle=True,
    batch_size=8,
    transforms=[
        # TFDS returns `bytes` rather than `str`, so we need to decode them first
        gm.data.DecodeBytes(key=['src', 'dst']),
        # We format the input to add the special tokens
        # See `<start_of_turn>` section in
        # https://github.com/google-deepmind/gemma/blob/main/docs/tokenizer.md
        gm.data.FormatText(
            key='src',
            template="""\
            <start_of_turn>user
            {text}<end_of_turn>
            <start_of_turn>model
            """,
        ),
        # Tokenize the inputs/outputs
        gm.data.Tokenize(key='src', tokenizer=tokenizer, add_bos=True),
        gm.data.Tokenize(key='dst', tokenizer=tokenizer, add_eos=True),
        # Create the model inputs/targets/loss_mask.
        gm.data.AddNextTokenPredictionFields(
            in_prompt='src',
            in_response='dst',
            out_input='input',
            out_target='target',
            out_target_mask='loss_mask',
        ),
        # Only keep the fields we need.
        kd.data.Elements(keep=["input", "target", "loss_mask"]),
        # Pad the sequences to support batching.
        gm.data.Pad(
            key=["input", "target", "loss_mask"],
            max_length=200,
            # In this dataset, ~1% of examples are longer than 200 tokens.
            # TODO(epot): Compute statistics
            truncate=True,
        ),
        # For shape compatibility with the loss
        kd.data.Rearrange(
            key=["target", "loss_mask"], pattern="... -> ... 1"
        ),
    ],
)

(ex,) = ds.take(1)

treescope.show(ex)

Disabling pygrain multi-processing (unsupported in colab).
{
    'input': i64[8 200],
    'loss_mask': bool_[8 200 1],
    'target': i64[8 200 1],
}
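
The `max_length=200` setting in the pipeline above truncates any longer example. As a rough sketch of how one might check what share of a dataset is affected, the snippet below computes that fraction from a list of token counts. The helper name `fraction_longer_than` is hypothetical, and the lengths are invented for illustration, not measured on mtnt.

```python
def fraction_longer_than(lengths, max_length):
    """Returns the share of sequences that truncation would cut."""
    if not lengths:
        return 0.0
    return sum(length > max_length for length in lengths) / len(lengths)


# Hypothetical token counts per example (not real dataset statistics).
token_lengths = [12, 48, 35, 230, 77, 150, 19, 64, 201, 90]
truncated_share = fraction_longer_than(token_lengths, max_length=200)
```

If the share turns out to be large, raising `max_length` trades more memory per batch for fewer truncated examples.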


We can decode an example from the batch to inspect the model input and check it is properly formatted:

In [28]:
text = tokenizer.decode(ex['input'][0])

print(text)

<start_of_turn>user
Would love any other tips from anyone, but specially from someone who’s been where I’m at.<end_of_turn>
<start_of_turn>model
J'apprécierais vraiment d'autres astuces, mais particulièrement par quelqu'un qui était était déjà là où je me trouve.


## Trainer

The [kauldron](https://kauldron.readthedocs.io/en/latest/) trainer allows you to train Gemma simply by providing a dataset, a model, a loss, and an optimizer.

The dataset, model, and losses are connected through a system of `key` strings. For more information, see the [key documentation](https://kauldron.readthedocs.io/en/latest/intro.html#keys-and-context).

Each key starts with a registered prefix. Common prefixes include:

* `batch`: The output of the dataset (after all transformations). Here our batch is `{'input': ..., 'loss_mask': ..., 'target': ...}`
* `preds`: The output of the model. For Gemma models, this is `gm.nn.Output(logits=..., cache=...)`
* `params`: Model parameters (can be used to add a weight decay loss, or monitor the params norm in metrics)
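
As a mental model for the key strings, the sketch below resolves a dotted key such as `batch.input` against a nested dictionary. This is only an illustration under simplifying assumptions: `resolve_key` is a hypothetical helper, the context is a plain `dict` with placeholder values, and Kauldron's actual resolution logic is more general.

```python
def resolve_key(context, key):
    """Follows a dotted key such as 'batch.input' through nested dicts."""
    value = context
    for part in key.split("."):
        value = value[part]
    return value


# Simplified stand-in for the training context (placeholder values).
context = {
    "batch": {
        "input": [[2, 10, 11]],
        "target": [[10, 11, 1]],
        "loss_mask": [[False, True, True]],
    },
    "preds": {"logits": "placeholder for the model logits"},
}

model_tokens = resolve_key(context, "batch.input")
loss_labels = resolve_key(context, "batch.target")
```

The first path component selects the prefix (`batch`, `preds`, ...), and the remaining components select a field inside it.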

In [None]:
model = gm.nn.Gemma2_2B(
    tokens="batch.input",
)

In [None]:
loss = kd.losses.SoftmaxCrossEntropyWithIntLabels(
    logits="preds.logits",
    labels="batch.target",
    mask="batch.loss_mask",
)
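
For intuition about what a masked cross-entropy with integer labels computes, here is a minimal pure-Python sketch for a single sequence. It is an illustration only: `masked_cross_entropy` is a hypothetical function, and averaging over the unmasked positions is an assumption, so the library's exact reduction may differ.

```python
import math


def masked_cross_entropy(logits, labels, mask):
    """Mean negative log-likelihood over the positions where mask is True."""
    total, count = 0.0, 0
    for position_logits, label, keep in zip(logits, labels, mask):
        if not keep:
            continue  # Masked positions (prompt, padding) are ignored.
        # Numerically stable log-sum-exp: subtract the max before exponentiating.
        max_logit = max(position_logits)
        log_norm = max_logit + math.log(
            sum(math.exp(x - max_logit) for x in position_logits)
        )
        total += log_norm - position_logits[label]
        count += 1
    return total / count if count else 0.0


# Three positions over a toy vocabulary of size 2; the first is masked out.
toy_logits = [[0.0, 0.0], [2.0, 0.0], [0.0, 0.0]]
toy_labels = [0, 0, 1]
toy_mask = [False, True, True]
toy_loss = masked_cross_entropy(toy_logits, toy_labels, toy_mask)
```

Because the prompt positions are masked, only the tokens of the response influence the gradient.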

We then create the trainer:

In [None]:
trainer = kd.train.Trainer(
    seed=42,  # The seed of enlightenment
    workdir='/tmp/ckpts',  # TODO(epot): Make the workdir optional by default
    # Dataset
    train_ds=ds,
    # Model
    model=model,
    init_transform=gm.ckpts.LoadCheckpoint(  # Load the weights from the pretrained checkpoint
        path=gm.ckpts.CheckpointPath.GEMMA2_2B_IT,
    ),
    # Training parameters
    num_train_steps=300,
    train_losses={"loss": loss},
    optimizer=optax.adafactor(learning_rate=1e-3),
)

Training can be launched with the `.train()` method.

Note that the trainer, like the model, is immutable, so it does not store the state or the params. Instead, the state containing the trained parameters is returned.
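
The same pattern can be illustrated with a toy example: a frozen dataclass holds only the configuration, and `train()` returns the result instead of mutating the object. `ToyTrainer` is a hypothetical class written for this sketch and is unrelated to the real `kd.train.Trainer` implementation.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyTrainer:
    """Immutable configuration; training returns the new parameter."""

    learning_rate: float
    num_train_steps: int

    def train(self, initial_param):
        param = initial_param
        for _ in range(self.num_train_steps):
            # Gradient descent step on the toy loss (param - 3)^2.
            param = param - self.learning_rate * 2.0 * (param - 3.0)
        return param


toy_trainer = ToyTrainer(learning_rate=0.1, num_train_steps=50)
trained_param = toy_trainer.train(0.0)  # The trainer itself is unchanged.
```

Keeping the configuration immutable means the same trainer object can be reused safely, since every call starts from the state passed in rather than from hidden internal state.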

In [None]:
state, aux = trainer.train()

Configuring ...
Initializing ...
Starting training loop at step 0


train:   0%|          | 0/301 [00:00<?, ?it/s]

## Checkpointing

To save the model params, you can either:

* Activate checkpointing in the trainer by adding:

  ```python
  trainer = kd.train.Trainer(
      workdir='/tmp/my_experiment/',
      checkpointer=kd.ckpts.Checkpointer(
          save_interval_steps=500,
      ),
      ...
  )
  ```

  This will also save the optimizer, step, dataset state, ...


* Manually save the trained params:

  ```python
  gm.ckpts.save_params('/tmp/my_ckpt/', state.params)
  ```

## Evaluation

Here, we only perform a qualitative evaluation by sampling from a prompt.

For more info on evals:

* See the [sampling](https://github.com/google-deepmind/gemma/blob/main/docs/sampling.md) tutorial for more info on running inference.
* To add evals during training, see the Kauldron [evaluator](https://kauldron.readthedocs.io/en/latest/eval.html) documentation.


In [None]:
sampler = gm.text.Sampler(
    model=model,
    params=state.params,
    tokenizer=tokenizer,
)

We test a sentence, using the same formatting as during fine-tuning:

In [None]:
prompt = """\
<start_of_turn>user
Hello! My next holidays are in Paris.<end_of_turn>
<start_of_turn>model
"""

sampler.sample(prompt, max_new_tokens=30)

'Salut ! Mes vacances suivantes seront à Paris.'

The model correctly translated our prompt to French!

## Next steps

To fine-tune outside of Colab, see our [examples/](https://github.com/google-deepmind/gemma/tree/main/examples/) folder for more complex trainer configs, including LoRA, DPO, and sharding.