# Fine-tuning Gemma2 2B model on Roadrunner with JAX, Flax.

We have adapted the Gemma2 notebook from Google DeepMind to use HuggingFace's libraries and simplified the steps.

## Setup

In [None]:
%%capture
# Install the notebook's dependencies. The %%capture magic on the first line
# keeps pip's output out of the notebook.
!pip install --upgrade kagglehub -q
!pip install ipywidgets -q
# CPU-only builds are requested for TensorFlow (tensorflow-cpu) and PyTorch
# (the "whl/cpu" package index).
!pip install tensorflow-cpu -q
!pip install tensorflow_datasets -q
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu -q
# Gemma is installed from the felafax fork on GitHub rather than from PyPI.
!pip install git+https://github.com/felafax/gemma.git -q

In [None]:
import os
# Point the HuggingFace cache directories at the persistent disk.
# Assigning to os.environ is sufficient: it changes the environment of this
# kernel and of every `!` subprocess started from it. A `!export VAR=...` line
# would only run in a short-lived subshell and leave no lasting effect, so it
# is not used here.
os.environ['HF_HUB_CACHE'] = '/mnt/persistent-disk/hf/'
os.environ['HF_HOME'] = '/mnt/persistent-disk/hf/'

In [None]:
# @title Python imports

import enum
import re
import string

# We import JAX and some related packages.
import chex
import jax
import jax.numpy as jnp
import optax
from functools import partial

# We will use HuggingFace's dataset, tokenizer, and model classes.
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator
from datasets import Dataset, load_dataset, concatenate_datasets
import torch

# Finally, we import Gemma.
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm


In [None]:
# HuggingFace model id plus the credentials used when downloading it.
from getpass import getpass

MODEL_NAME = "felafax/gemma-2-2b-it-Flax"
HUGGINGFACE_USERNAME = input("INPUT: Please provide your HUGGINGFACE_USERNAME: ")
# Read the access token without echoing it, so the secret does not end up in
# the notebook's saved cell output.
HUGGINGFACE_TOKEN = getpass("INPUT: Please provide your HUGGINGFACE_TOKEN: ")

# Lower-case aliases of the values above.
model_name = MODEL_NAME
hugging_face_token = HUGGINGFACE_TOKEN

In [None]:
%%capture
from huggingface_hub import snapshot_download

ckpt_path = snapshot_download(repo_id=MODEL_NAME, token=HUGGINGFACE_TOKEN)
vocab_path = os.path.join(ckpt_path, 'tokenizer.model')

print(ckpt_path)
print()
print(vocab_path)

## Fine tuning the Gemma model

In [None]:
# Load the model parameters from the 'gemma2-2b-it' entry inside the
# downloaded snapshot directory (`ckpt_path`).
params = params_lib.load_and_format_params(os.path.join(ckpt_path, 'gemma2-2b-it'))

In [None]:
# Load the predefined Gemma2 2B model config.
# NOTE(review): cache_size is presumably the attention-cache length used when
# sampling; the training code in this notebook passes no cache -- confirm.
config_2b = transformer_lib.TransformerConfig.gemma2_2b(cache_size=30)

# You can also infer the model config by using the number of layers in the params.
# config_2b = transformer_lib.TransformerConfig.from_params(params, cache_size=30)

In [None]:
# Build the transformer module from the config. The weights are not stored on
# the module; they are passed to `model.apply` as `params` at call time.
model_2b = transformer_lib.Transformer(config=config_2b)

## Step 1: prepare the dataset

For this project, we're utilizing the refined **Alpaca dataset**, curated by yahma. This dataset is a carefully filtered selection of 52,000 entries from the original Alpaca collection. Feel free to substitute this section with your own data preparation code if you prefer.

It's crucial to include the EOS_TOKEN (End of Sequence Token) in your tokenized output. Failing to do so may result in endless generation loops.

In [None]:
# Load the tokenizer published alongside the model checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HUGGINGFACE_TOKEN)

# Padding to a fixed length needs a pad token; reuse the EOS token when the
# tokenizer does not define one.
if not tokenizer.pad_token:
    print("Tokenizer doesn't have a pad token.")
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
def get_dataset(*, tokenizer, batch_size=1, max_length=25, debug_mode=True):
    """Build train and test dataloaders over the "yahma/alpaca-cleaned" dataset.

    Every example is rendered with the Alpaca prompt template, terminated with
    the tokenizer's EOS token, tokenized, and padded/truncated to a fixed length.

    Args:
        tokenizer: HuggingFace tokenizer; it is called on the prompt texts and
            its `eos_token` is appended to each prompt.
        batch_size: examples per batch; a falsy value falls back to 1.
        max_length: number of token ids kept per example; a falsy value falls
            back to 25.
        debug_mode: when True, only the first 32 examples are used.

    Returns:
        A `(train_dataloader, test_dataloader)` pair (85% / 15% split). Each
        batch is a dict with:
          'input_tokens': token ids, shape [B, max_length].
          'target_mask': the tokenizer's attention mask, shape
              [B, max_length + 1] (one entry longer than 'input_tokens').
    """
    # Alpaca prompt template.
    # NOTE(review): the four leading spaces on the continuation lines are part
    # of the prompt text (they come from the template's source indentation).
    # They are spelled out explicitly here so the training text is unchanged.
    alpaca_prompt = (
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. Write a response that appropriately "
        "completes the request.\n"
        "    \n"
        "    ### Instruction: {}\n"
        "    \n"
        "    ### Input: {}\n"
        "    \n"
        "    ### Response: {}"
    )

    eos_token = tokenizer.eos_token

    # One extra token is requested from the tokenizer; the last id is dropped
    # again in `_tokenize`, while the mask keeps the full tokenized length.
    tokenized_length = (max_length or 25) + 1

    def _format_prompts(examples):
        """Render each (instruction, input, output) row into one EOS-terminated text."""
        rows = zip(examples["instruction"], examples["input"], examples["output"])
        texts = [
            alpaca_prompt.format(instruction, context, output) + eos_token
            for instruction, context, output in rows
        ]
        return {"text": texts}

    def _tokenize(examples):
        """Tokenize the texts to a fixed length, honoring the caller's `max_length`."""
        tokenized = tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=tokenized_length,
        )
        return {
            "input_tokens": [ids[:-1] for ids in tokenized["input_ids"]],
            "target_mask": tokenized["attention_mask"],
        }

    def _custom_collate_fn(batch):
        """Apply transformers' default_data_collator and convert tensors to JAX arrays."""
        collated = default_data_collator(batch)
        return {
            key: jnp.array(value.numpy()) if isinstance(value, torch.Tensor) else value
            for key, value in collated.items()
        }

    # Load and preprocess the dataset.
    dataset = load_dataset("yahma/alpaca-cleaned", split="train")
    if debug_mode:
        dataset = dataset.select(range(32))  # Use just 32 examples for faster iteration.
    dataset = dataset.map(_format_prompts, batched=True)

    # Create the train and test splits, then tokenize each and wrap it in a loader.
    ds = dataset.train_test_split(test_size=0.15)
    loaders = []
    for split in ("train", "test"):
        tokenized_split = ds[split].map(
            _tokenize, batched=True, remove_columns=dataset.column_names
        )
        loaders.append(
            torch.utils.data.DataLoader(
                tokenized_split,
                shuffle=True,
                batch_size=batch_size or 1,
                collate_fn=_custom_collate_fn,
            )
        )

    train_dataloader, test_dataloader = loaders
    return train_dataloader, test_dataloader

In [None]:
# Smoke-test the data pipeline: build the loaders with the default settings
# and print the token ids and mask of the first few training batches.
# `train_dataloader` is reused later when launching training.
train_dataloader, _ = get_dataset(tokenizer=tokenizer)
for i, batch in enumerate(train_dataloader):
    if i > 10:
        break
    input_ids = batch["input_tokens"]
    attention_mask = batch["target_mask"]
    print(input_ids, attention_mask, sep="\n\n")

In [None]:
def forward_and_loss_fn(params,
                        *,
                        model: "transformer_lib.Transformer",
                        input_tokens: jax.Array,            # Shape [B, L]
                        input_mask: jax.Array,              # Shape [B, L] (or longer)
                        positions: jax.Array,               # Shape [B, L]
                        attention_mask: jax.Array,          # [B, L, L]
                        ) -> jax.Array:
  """Forward pass and loss function.

  Args:
    params: model's input parameters.
    model: gemma transformer model to call.
    input_tokens: input tokens sequence, shape [B, L].
    input_mask: per-token loss mask (non-zero = token counts towards the loss),
      shape [B, L]. A longer mask (e.g. [B, L + 1]) is accepted; only the
      entries aligned with `input_tokens` are used.
    positions: relative position of each token, shape [B, L].
    attention_mask: input attention mask, shape [B, L, L].

  Returns:
    Softmax cross-entropy loss for the next-token prediction task, averaged
    over the unmasked target tokens.
  """
  seq_len = input_tokens.shape[1]

  # Forward pass on the input data.
  # No attention cache is needed here.
  logits, _ = model.apply(
        params,
        input_tokens,
        positions,
        None,              # Attention cache is None.
        attention_mask,
    )

  # Exclude the last step as it does not appear in the targets.
  logits = logits[:, :-1]

  # Similarly, the first token cannot be predicted.
  target_tokens = input_tokens[:, 1:]

  # Take the mask entries of exactly the target tokens (positions 1..L-1).
  # Bounding the slice by `seq_len` keeps mask and targets aligned even when
  # the mask is longer than the tokens; slicing the mask a second time would
  # shift it by one position (or break broadcasting for a [B, L] mask).
  target_mask = input_mask[:, 1:seq_len]

  # Convert the target labels into one-hot encoded vectors.
  one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])

  # Don't update on unwanted tokens.
  one_hot = one_hot * target_mask.astype(one_hot.dtype)[..., None]

  # Normalisation factor; the epsilon guards against an all-zero mask.
  norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)

  # Return the nll loss.
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor

The Gemma transformer requires an attention mask and position vector alongside each input. We can conveniently generate these using the following function:

In [None]:
def get_attention_mask_and_positions(example: jax.Array,
                                     pad_id: int,
                                     ) -> tuple[jax.Array, jax.Array]:
  """Derives token positions and the causal attention mask from a token batch.

  Args:
    example: token ids to build the vectors for.
    pad_id: id of the pad token; entries equal to it are treated as padding.

  Returns:
    A `(positions, attention_mask)` pair. Note the order: positions come first.
  """
  # True wherever the token is not padding.
  is_real_token = example != pad_id
  return (
      transformer_lib.build_positions_from_mask(is_real_token),
      transformer_lib.make_causal_attn_mask(is_real_token),
  )

We can now build the train_step function which performs the backward pass and updates the model's parameters accordingly.

In [None]:
def train_step(model: transformer_lib.Transformer,
               params,
               optimizer: optax.GradientTransformation,
               opt_state: optax.OptState,
               pad_id: int,
               example):
  """Runs one optimisation step on a single batch.

  Args:
    model: gemma transformer model.
    params: model's input parameters.
    optimizer: optax optimizer to use.
    opt_state: input optimizer's state.
    pad_id: id of the pad token.
    example: input batch; its 'input_tokens' and 'target_mask' entries are read.

  Returns:
    Training loss, updated parameters, updated optimizer state.
  """
  tokens = example['input_tokens']

  # Position ids and attention mask are derived from the padding layout.
  positions, attention_mask = get_attention_mask_and_positions(tokens, pad_id)

  # Loss and its gradient with respect to `params` (the first argument).
  loss_and_grad_fn = jax.value_and_grad(forward_and_loss_fn)
  train_loss, grads = loss_and_grad_fn(
      params,
      model=model,
      input_tokens=tokens,
      input_mask=example['target_mask'],
      positions=positions,
      attention_mask=attention_mask,
  )

  # Turn the gradients into parameter updates and apply them.
  updates, new_opt_state = optimizer.update(grads, opt_state)
  new_params = optax.apply_updates(params, updates)

  return train_loss, new_params, new_opt_state

Similarly, a `validation_step` function can be built without the backward pass (it is omitted here).

And now the training loop itself.

In [None]:
@chex.dataclass(frozen=True)
class TrainingConfig:
  """Immutable bundle of training hyper-parameters."""
  # Learning rate handed to the optimizer.
  learning_rate: float
  # Number of passes over the data.
  # NOTE(review): not read by train_loop in this notebook.
  num_epochs: int
  # Evaluation interval, in steps.
  # NOTE(review): not read by train_loop in this notebook.
  eval_every_n: int
  # Batch size.
  # NOTE(review): not read by train_loop; the dataloader's batch size is set
  # separately in get_dataset.
  batch_size: int
  # Optional cap on the number of training steps; None means no cap.
  max_steps: int | None = None

from dataclasses import dataclass
import numpy as np


def train_loop(
    model: transformer_lib.Transformer,
    params,
    train_dataloader,
    tokenizer,
    training_cfg: TrainingConfig):
  """Fine-tunes `params` with SGD over one pass of `train_dataloader`.

  The loss of every step is printed. Of `training_cfg`, only `learning_rate`
  and `max_steps` are read.

  Args:
    model: gemma transformer model.
    params: initial model parameters.
    train_dataloader: iterable of batches accepted by `train_step`.
    tokenizer: tokenizer whose `pad_token_id` marks padding in the batches.
    training_cfg: training hyper-parameters.

  Returns:
    The updated parameters.
  """
  optimizer = optax.sgd(training_cfg.learning_rate)
  opt_state = optimizer.init(params)

  # `model` and `optimizer` are not arrays, so they are marked static for jit.
  compiled_train_step = jax.jit(train_step, static_argnames=['model', 'optimizer'])

  max_steps = training_cfg.max_steps
  pad_id = tokenizer.pad_token_id

  for n_steps, train_example in enumerate(train_dataloader):
    # Stop once exactly `max_steps` steps have been run.
    if max_steps is not None and n_steps >= max_steps:
      break
    train_loss, params, opt_state = compiled_train_step(model=model,
                                                        params=params,
                                                        optimizer=optimizer,
                                                        opt_state=opt_state,
                                                        pad_id=pad_id,
                                                        example=train_example)
    print(f"train_loss {train_loss}")
  return params

We can fine-tune our model on a limited number of steps.

In [None]:
# Hyper-parameters for a short demonstration run.
training_cfg = TrainingConfig(
    learning_rate=1e-4,
    num_epochs=1,
    eval_every_n=20,
    batch_size=1,
    max_steps=10,
)

# Only the 'transformer' entry of the loaded checkpoint is trained; it is
# wrapped under a 'params' key before being handed to the training loop.
# NOTE(review): `params` is rebound to the training result, so re-running this
# cell without reloading the checkpoint will likely fail on the 'transformer'
# lookup.
params = train_loop(
    model=model_2b,
    params={'params': params['transformer']},
    train_dataloader=train_dataloader,
    tokenizer=tokenizer,
    training_cfg=training_cfg,
)