Copyright 2024 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
See the License for the specific language governing permissions and
limitations under the License.

---

# Fine-tuning the 2B Gemma model with Flax

In this tutorial you will learn how to fine-tune the 2B Gemma model for a simple translation task. To run this colab, you will need to use a TPU v4 runtime.

## Setup

In [1]:
!pip install --upgrade kagglehub -q

In [2]:
import kagglehub

In [3]:
# @title Installation
!pip install git+https://github.com/google-deepmind/gemma.git -q

## Downloading the checkpoint

To use Gemma's checkpoints, you'll need a Kaggle account and an API key. Here's how to get them:

1. Visit https://www.kaggle.com/ and create an account.
2. Go to your account settings, then the 'API' section.
3. Click 'Create new token' to download your key.

Then run the cell below.

In [4]:
!pip install ipywidgets -q
!pip install tensorflow-cpu -q
!pip install tensorflow_datasets -q

In [5]:
import os

# Set your Kaggle credentials as environment variables.
# Replace the placeholders with your own values and never commit real keys.
os.environ['KAGGLE_USERNAME'] = 'your_kaggle_username'
os.environ['KAGGLE_KEY'] = 'your_kaggle_api_key'

kagglehub.login()


If everything went well, you should see:
```
Kaggle credentials set.
Kaggle credentials successfully validated.
```

Now select and download the checkpoint you want to try. On a single host, only the 2b model can fit in memory for fine-tuning.
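
As a rough illustration of why only the 2B model fits, the sketch below estimates the memory needed to hold the weights and one gradient buffer during full fine-tuning with plain SGD. The parameter count (about 2.5 billion for the 2B variant) and the 4-byte float32 width are assumptions for illustration only; check the model card and your checkpoint's dtype for exact figures. Activations and any optimizer state come on top of this estimate.

```python
def finetune_memory_gib(num_params: float,
                        bytes_per_param: int = 4,
                        copies: int = 2) -> float:
    """Estimates the memory in GiB for `copies` parameter-sized buffers.

    copies=2 covers the weights plus one gradient buffer (plain SGD).
    """
    return num_params * bytes_per_param * copies / 2**30


# Hypothetical parameter count, for illustration only.
approx_2b_params = 2.5e9
print(f"{finetune_memory_gib(approx_2b_params):.1f} GiB")
```

Larger variants scale this figure linearly with their parameter count, which is why they would not fit on a single host under the same assumptions.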

In [6]:
%%capture
VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')

In [7]:
ckpt_path = os.path.join(weights_dir, VARIANT)
vocab_path = os.path.join(weights_dir, 'tokenizer.model')

In [8]:
# @title Python imports

import enum
import re
import string

# We import JAX and some related packages.
import chex
import jax
import jax.numpy as jnp
import optax

# We will use tensorflow to handle the dataset
import tensorflow as tf
import tensorflow_datasets as tfds

# Finally, we import Gemma.
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm

## Step 1: prepare the dataset

### The MTNT dataset

In this tutorial, we will use the MTNT dataset, from the paper [MTNT: A Testbed for Machine Translation of Noisy Text](https://arxiv.org/abs/1809.00388). This dataset is directly available in the [TensorFlow dataset catalog](https://www.tensorflow.org/datasets/catalog/mtnt).

More precisely, we will focus on English-to-French translation.

But first, let's have a look at the data itself.

Let's customize `SentencePieceProcessor` for our English-to-French translation task. Since we're fine-tuning the English-only Gemma 2B model, we need a few adjustments:

- **Input Prefix**: Adding a common prefix to each input signals the translation task. For example we could go with a prompt like `Translate this into French: [INPUT_SENTENCE]`.

- **Translation Start suffix**: A suffix at the end of each prompt tells the model exactly when to begin the translation. A new line should do the job.

- **LM Tokens**: Gemma models expect a *beginning of sequence* token at the beginning of each sequence. Similarly, we need to add an *end of sequence* token at the end of each training example.
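
The three adjustments above can be sketched in plain Python. This is a minimal illustration, not the tokenizer wrapper itself: the helper names `format_example` and `add_lm_tokens` are hypothetical, and the token ids used in the usage example are placeholders rather than values read from the real vocabulary.

```python
PREFIX = "Translate this into French:\n"
SUFFIX = "\n"


def format_example(source: str, target: str) -> tuple[str, str]:
    """Returns the (prompt, completion) strings for one translation pair."""
    return f"{PREFIX}{source}{SUFFIX}", target


def add_lm_tokens(prompt_ids: list[int],
                  target_ids: list[int],
                  bos_id: int,
                  eos_id: int) -> tuple[list[int], list[int]]:
    """Builds BOS + prompt + target + EOS and a matching loss mask.

    The mask is 1 on the target tokens and the EOS token, so the loss
    ignores the BOS token and the prompt.
    """
    tokens = [bos_id] + prompt_ids + target_ids + [eos_id]
    mask = [0] * (1 + len(prompt_ids)) + [1] * (len(target_ids) + 1)
    return tokens, mask


prompt, completion = format_example("Hello", "Bonjour")
# Placeholder ids: in practice they come from the tokenizer.
tokens, mask = add_lm_tokens([5, 6], [7], bos_id=2, eos_id=1)
```

Keeping the mask alongside the tokens lets the loss function train only on the translation while still conditioning on the prompt.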

In [9]:
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator
from datasets import Dataset, load_dataset, concatenate_datasets


In [10]:
MODEL_NAME="google/gemma-2-2b-it"
# Replace the placeholder with your own access token and never commit real tokens.
HUGGINGFACE_TOKEN = "your_hugging_face_token"

model_name=MODEL_NAME
hugging_face_token=HUGGINGFACE_TOKEN

In [11]:
config = AutoConfig.from_pretrained(
    model_name, 
    token=hugging_face_token)

tokenizer = AutoTokenizer.from_pretrained(
    model_name, 
    token=hugging_face_token
)

if not tokenizer.pad_token:
    tokenizer.pad_token = tokenizer.eos_token
    config.pad_token_id = tokenizer.pad_token_id
    
model = AutoModelForCausalLM.from_pretrained(
    model_name, 
    token=hugging_face_token,
)


In [12]:
# Small seq size so that everything fits in memory
SEQ_SIZE = 25
max_length = SEQ_SIZE

# Define Alpaca prompt template
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction: {}

### Input: {}

### Response: {}"""

EOS_TOKEN = tokenizer.eos_token

# Define formatting function.
def _format_prompts(examples):
    instructions = examples["instruction"]
    inputs = examples["input"]
    outputs = examples["output"]
    texts = []
    for instruction, input, output in zip(instructions, inputs, outputs):
        text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
        texts.append(text)
    return {"text": texts}

# Tokenize the dataset.
def _tokenize(examples):
    # Tokenize to max_length + 1 tokens per example (padding or truncating as needed).
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=512 + 1 if not max_length else max_length + 1,
    )
    # Next-token labels are the inputs shifted by one position. They are computed
    # here but not returned: the loss function derives its targets from the inputs.
    labels = tokenized['input_ids'].copy()
    tokenized['labels'] = [label[1:] for label in labels]
    # Drop the last token so that each input has max_length tokens.
    tokenized['input_ids'] = [input_id[:-1] for input_id in tokenized['input_ids']]
    # Note: the attention mask is not truncated, so `target_mask` is one element
    # longer than `input_tokens`.
    new_tokenized = {}
    new_tokenized['input_tokens'] = tokenized['input_ids']
    new_tokenized['target_mask'] = tokenized['attention_mask']
    return new_tokenized

# Load and preprocess the dataset.
dataset = load_dataset("yahma/alpaca-cleaned", split="train")
dataset = dataset.select(range(32)) # ***DEBUG***
dataset = dataset.map(_format_prompts, batched=True)

In [13]:
ds = dataset.train_test_split(test_size=0.15)
ds['train'] = ds['train'].map(_tokenize, batched=True, remove_columns=dataset.column_names)
ds['test'] = ds['test'].map(_tokenize, batched=True, remove_columns=dataset.column_names)


In [14]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu -q

In [15]:
# Create DataLoader
import torch
import jax
import jax.numpy as jnp
from functools import partial

batch_size=1

def _custom_collate_fn(batch):
    """Applies `default_data_collator` from transformers and converts tensors to JAX arrays."""
    batch = default_data_collator(batch)
    jax_batch = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            # Go through NumPy to turn a torch tensor into a JAX array.
            jax_batch[key] = jnp.array(value.numpy())
        else:
            jax_batch[key] = value

    return jax_batch
    

train_dataloader = torch.utils.data.DataLoader(
    ds['train'],
    shuffle=True,
    batch_size=1 if not batch_size else batch_size,
    collate_fn=_custom_collate_fn
)

test_dataloader = torch.utils.data.DataLoader(
    ds['test'],
    shuffle=True,
    batch_size=1 if not batch_size else batch_size,
    collate_fn=_custom_collate_fn
)

In [16]:
# for i, batch in enumerate(train_dataloader):
#     input_ids, attention_mask = (
#         batch["input_tokens"],
#         batch["target_mask"],
        
#     )
#     print(input_ids)
#     print()
#     print(attention_mask)

## Fine-tuning the Gemma model

### Getting started

First, let's load the model.

In [17]:
# Load parameters

# TODO: change once the downloading url is known
params = params_lib.load_and_format_params(ckpt_path)

# We use the `transformer_lib.TransformerConfig.from_params` function to
# automatically load the correct configuration from a checkpoint. Note that the
# vocabulary size is smaller than the number of input embeddings due to unused
# tokens in this release.
config_2b = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=30  # Number of time steps in the transformer's cache
)
model_2b = transformer_lib.Transformer(config=config_2b)

Can our model translate into French? Well, let's try it out!

As expected, it didn't work. Let's see if we can get better results by fine-tuning.

Before moving further, don't forget to clear the memory if necessary.

### Model forward and loss function

Gemma `Transformer` class inherits from [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html). It offers two essential methods:

- `init`: Initializes the model's parameters.

- `apply`: Executes the model's `__call__` function using a given set of parameters.

Since we are working with pre-trained weights, we won't use the `init` function.

We define a `forward_and_loss_fn` as follows:

In [19]:
def forward_and_loss_fn(params,
                        *,
                        model: transformer_lib.Transformer,
                        input_tokens: jax.Array,            # Shape [B, L]
                        input_mask: jax.Array,              # Shape [B, L + 1] in this notebook
                        positions: jax.Array,               # Shape [B, L]
                        attention_mask: jax.Array,          # Shape [B, L, L]
                        ) -> jax.Array:
  """Forward pass and loss function.

  Args:
    params: model's input parameters.
    model: gemma transformer model to call.
    input_tokens: input tokens sequence, shape [B, L].
    input_mask: mask selecting the tokens that contribute to the loss (1 = keep).
    positions: position of each token, shape [B, L].
    attention_mask: input attention mask, shape [B, L, L].

  Returns:
    Softmax cross-entropy loss for the next-token prediction task.
  """

  # Forward pass on the input data.
  # No attention cache is needed here.
  logits, _ = model.apply(
        params,
        input_tokens,
        positions,
        None,              # Attention cache is None.
        attention_mask,
    )

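  # Exclude the last step as it does not appear in the targets.
  # Only the first batch element is used (the batch size is 1 in this notebook).
  logits = logits[0, :-1]

  # Similarly, the first token cannot be predicted.
  target_tokens = input_tokens[0, 1:]
  target_mask = input_mask[0, 1:]

  # Convert the target labels into one-hot encoded vectors.
  one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])

  # In this notebook `input_mask` is one element longer than `input_tokens`
  # (see `_tokenize`), so slice once more to match the length of the targets.
  target_mask = target_mask[...,1:]
  # Don't update on unwanted tokens.
  one_hot = one_hot * target_mask.astype(one_hot.dtype)[...,None]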

  # Normalisation factor.
  norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)

  # Return the nll loss.
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor
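
The masked next-token loss computed above can be checked on a tiny example. The following is a NumPy sketch for a single sequence, not the notebook's JAX function: `masked_next_token_nll` is a hypothetical helper, and it assumes the mask has the same length as the token sequence.

```python
import numpy as np


def masked_next_token_nll(logits: np.ndarray,
                          tokens: np.ndarray,
                          mask: np.ndarray) -> float:
    """Mean negative log-likelihood of tokens[1:] given logits[:-1].

    logits has shape [L, V]; tokens and mask have shape [L].
    A mask value of 1 means the token contributes to the loss.
    """
    logits = logits[:-1]      # The last step has no target.
    targets = tokens[1:]      # The first token is never predicted.
    target_mask = mask[1:].astype(logits.dtype)

    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    # Pick the log-probability assigned to each target token.
    token_ll = log_probs[np.arange(len(targets)), targets]
    return float(-(token_ll * target_mask).sum() / (target_mask.sum() + 1e-8))


# Uniform logits over a vocabulary of 4 give a loss close to log(4).
loss = masked_next_token_nll(np.zeros((3, 4)), np.array([0, 1, 2]), np.array([1, 1, 1]))
```

Indexing the log-probabilities directly is equivalent to multiplying by a one-hot matrix, but avoids materialising a `[L, V]` array.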

The Gemma transformer requires an attention mask and position vector alongside each input. We can conveniently generate these using the following function:

In [20]:
def get_attention_mask_and_positions(example: jax.Array,
                                     pad_id : int,
                                     )-> tuple[jax.Array, jax.Array]:
  """Builds the position and attention mask vectors from the given tokens."""
  pad_mask = example != pad_id
  current_token_position = transformer_lib.build_positions_from_mask(pad_mask)
  attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)
  return current_token_position, attention_mask
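
To see what these two vectors look like, here is a NumPy sketch of one common way to build them from a padding mask. It is an assumption about what the library helpers compute, written for illustration; `positions_from_mask` and `causal_attention_mask` are hypothetical names, and the `gemma` implementations may differ in detail.

```python
import numpy as np


def positions_from_mask(pad_mask: np.ndarray) -> np.ndarray:
    """Position index of each token, counting only non-padding tokens."""
    positions = np.cumsum(pad_mask, axis=-1)
    # Shift so that the first real token sits at position 0.
    return positions - (positions >= 1).astype(positions.dtype)


def causal_attention_mask(pad_mask: np.ndarray) -> np.ndarray:
    """Mask of shape [B, L, L]: token i may attend to real tokens j <= i."""
    seq_len = pad_mask.shape[-1]
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    return pad_mask[:, None, :] & causal[None, :, :]


tokens = np.array([[2, 5, 6, 0]])   # Placeholder ids, with 0 as the pad id.
pad_mask = tokens != 0
positions = positions_from_mask(pad_mask)
attention = causal_attention_mask(pad_mask)
```

Here the padding token never becomes visible to any position, and each real token only attends to itself and to earlier tokens.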

We can now build the train_step function which performs the backward pass and updates the model's parameters accordingly.

In [21]:
def train_step(model: transformer_lib.Transformer,
               params,
               optimizer: optax.GradientTransformation,
               opt_state: optax.OptState,
               pad_id: int,
               example):
  """Train step.

  Args:
    model: gemma transformer model.
    params: model's input parameters.
    optimizer: optax optimizer to use.
    opt_state: input optimizer's state.
    pad_id: id of the pad token.
    example: input batch.

  Returns:
    Training loss, updated parameters, updated optimizer state.
  """
  # Build the position and attention mask vectors.
  positions, attention_mask = get_attention_mask_and_positions(example['input_tokens'], pad_id)


  # Forward and backward passes
  train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(params,
                                                             model=model,
                                                             input_tokens=example['input_tokens'],
                                                             input_mask=example['target_mask'],
                                                             positions=positions,
                                                             attention_mask=attention_mask)
  # Update the parameters
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

  return train_loss, params, opt_state

Similarly, we build a `validation_step` function without backward pass.
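
A minimal sketch of such a function is shown here. To keep the example self-contained, the loss function and the mask builder are passed in as arguments (`loss_fn` and `mask_fn` are hypothetical parameter names); in this notebook they would be `forward_and_loss_fn` and `get_attention_mask_and_positions`.

```python
from typing import Any, Callable


def validation_step(loss_fn: Callable[..., Any],
                    mask_fn: Callable[..., tuple],
                    model: Any,
                    params: Any,
                    pad_id: int,
                    example: dict) -> Any:
  """Computes the loss on one batch without computing gradients."""
  # Build the position and attention mask vectors.
  positions, attention_mask = mask_fn(example['input_tokens'], pad_id)

  # Forward pass only: no call to jax.value_and_grad, no parameter update.
  return loss_fn(params,
                 model=model,
                 input_tokens=example['input_tokens'],
                 input_mask=example['target_mask'],
                 positions=positions,
                 attention_mask=attention_mask)
```

Because no gradients are computed, this step needs noticeably less memory than `train_step` and can be run periodically on the held-out split.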

And now the training loop itself.

In [22]:
@chex.dataclass(frozen=True)
class TrainingConfig:
  learning_rate: float
  num_epochs: int
  eval_every_n: int
  batch_size: int
  max_steps: int | None = None

from dataclasses import dataclass
import numpy as np


def train_loop(
    model: transformer_lib.Transformer,
    params,
    train_dataloader,
    tokenizer,
    training_cfg: TrainingConfig):


  # We jit the train step, making the whole loop much more efficient
  compiled_train_step = jax.jit(train_step , static_argnames=['model', 'optimizer'])


  # To save memory, we use a SGD optimizer instead of the usual Adam. Note that
  # for this specific example SGD is more than enough.
  optimizer = optax.sgd(training_cfg.learning_rate)
  opt_state = optimizer.init(params)

  n_steps = 0
  avg_loss=0

  for train_example in train_dataloader:
    # Use the jitted step defined above rather than the uncompiled function.
    train_loss, params, opt_state = compiled_train_step(
        model=model,
        params=params,
        optimizer=optimizer,
        opt_state=opt_state,
        pad_id=tokenizer.pad_token_id,
        example=train_example)
    n_steps += 1
    avg_loss += train_loss
    # Stop once max_steps training steps have run.
    if training_cfg.max_steps is not None and n_steps >= training_cfg.max_steps:
      break
  return params
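
The training loop uses SGD rather than Adam to save memory. The NumPy sketch below illustrates why: plain SGD keeps no per-parameter state, whereas Adam stores two moment buffers for every parameter. The helper names `sgd_update` and `adam_state_size` are hypothetical and only mirror the update rule, not the optax implementation.

```python
import numpy as np


def sgd_update(params: dict, grads: dict, learning_rate: float) -> dict:
    """Plain SGD: new_param = param - learning_rate * grad. No optimizer state."""
    return {name: p - learning_rate * grads[name] for name, p in params.items()}


def adam_state_size(params: dict) -> int:
    """Number of extra values Adam would store: first and second moments."""
    return 2 * sum(p.size for p in params.values())


params = {"w": np.ones((2, 3)), "b": np.zeros(3)}
grads = {"w": np.full((2, 3), 0.5), "b": np.ones(3)}

new_params = sgd_update(params, grads, learning_rate=0.1)
extra_values_for_adam = adam_state_size(params)
```

For a model with billions of parameters, those two extra buffers roughly triple the memory taken by the weights alone, which matters on a single host.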

We can fine-tune our model on a limited number of steps.

In [23]:
# Small seq size so that everything fits in memory
# SEQ_SIZE = 25
# tokenizer = GemmaTokenizer(vocab)
training_cfg = TrainingConfig(learning_rate=1e-4,
                              num_epochs=1,
                              eval_every_n=20,
                              batch_size=1,
                              max_steps=100)

params = train_loop(model=model_2b,
                    params={'params': params['transformer']},
                    train_dataloader=train_dataloader,
                    tokenizer=tokenizer,
                    training_cfg=training_cfg)

ValueError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 128.00M. That was not possible. There are 72.03M free.; (0x0x0_HBM0)

Both the training loss and the validation loss are going down. But is it working? Let's try again with our previous example:

In [None]:
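# Load the SentencePiece vocabulary that ships with the checkpoint.
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)

sampler = sampler_lib.Sampler(
    transformer=model_2b,
    vocab=vocab,
    params=params['params'],
)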

To ensure our input matches the training format, remember to use the prefix `Translate this into French:\n` and a newline character at the end. This signals the model to begin translation.

In [None]:
sampler(
    ["Translate this into French:\nHello, my name is Morgane.\n"],
    total_generation_steps=30,
    ).text
