# Fine-tuning Gemma2 2B model on Roadrunner with JAX, Flax.

We have adopted the Gemma2 notebook from Google Deepmind to use HuggingFace's libraries and and simplified the steps.

## Setup

In [None]:
%%capture
!pip install --upgrade kagglehub -q
!pip install ipywidgets -q
!pip install torch --index-url https://download.pytorch.org/whl/cpu -q
!pip install git+https://github.com/felafax/gemma.git -q
!pip install qax -q
!pip install jax-lorax -q

In [24]:
import os
os.environ['HF_HUB_CACHE'] = '/mnt/persistent-disk/hf/'
os.environ['HF_HOME'] = '/mnt/persistent-disk/hf/'
!export HF_HUB_CACHE="/mnt/persistent-disk/hf/"
!export HF_HOME="/mnt/persistent-disk/hf/"

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [25]:
# Standard library imports
import os
import pdb
import enum
import re
import string
from dataclasses import dataclass
import functools
from functools import partial
from typing import (
    Any, List, Dict, Tuple, Optional, Union, Sequence, Mapping
)

# JAX and related libraries (including Flax and Optax)
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
from flax.training import train_state
from flax.core.meta import unbox
import optax
import chex
import lorax

# JAX model partitioning and sharding
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as PS
from jax.lax import with_sharding_constraint
from jax.experimental import mesh_utils

# Hugging Face Transformers and Datasets
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator
from datasets import Dataset, load_dataset, concatenate_datasets
import torch

# Gemma-specific imports
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib

In [None]:
import sys
import importlib
def import_local_module(module_path: str):
    sys.path.append('')
    module = importlib.import_module(module_path)
    return importlib.reload(module)

In [26]:
# HuggingFace username and token to use when downloading.
MODEL_NAME="felafax/gemma-2-2b-it-JAX"
HUGGINGFACE_USERNAME = input("INPUT: Please provide your HUGGINGFACE_USERNAME: ")
HUGGINGFACE_TOKEN = input("INPUT: Please provide your HUGGINGFACE_TOKEN: ")

model_name=MODEL_NAME
hugging_face_token=HUGGINGFACE_TOKEN

INPUT: Please provide your HUGGINGFACE_USERNAME:  felarof01
INPUT: Please provide your HUGGINGFACE_TOKEN:  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


In [27]:
%%capture
from huggingface_hub import snapshot_download

ckpt_path = snapshot_download(repo_id=MODEL_NAME, token=HUGGINGFACE_TOKEN)
vocab_path = os.path.join(ckpt_path, 'tokenizer.model')

## Fine tuning the Gemma model

## Step 1: prepare the dataset

For this project, we're utilizing the refined **Alpaca dataset**, curated by yahma. This dataset is a carefully filtered selection of 52,000 entries from the original Alpaca collection. Feel free to substitute this section with your own data preparation code if you prefer.

It's crucial to include the EOS_TOKEN (End of Sequence Token) in your tokenized output. Failing to do so may result in endless generation loops.

In [29]:
def get_dataset(*, tokenizer, batch_size=1, max_length=32, max_examples=32):
    # Define Alpaca prompt template
    alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
    
    ### Instruction: {}
    
    ### Input: {}
    
    ### Response: {}"""
    
    EOS_TOKEN = tokenizer.eos_token
    
    # Define formatting function.
    def _format_prompts(examples):
        instructions = examples["instruction"]
        inputs = examples["input"]
        outputs = examples["output"]
        texts = []
        for instruction, input, output in zip(instructions, inputs, outputs):
            text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
            texts.append(text)
        return {"text": texts}

    def _tokenize(examples):
        tokenized = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=max_length+1)
        tokenized['input_ids'] = [input_id[:-1] for input_id in tokenized['input_ids']]
        tokenized['target_mask'] = [input_id[:-1] for input_id in tokenized['attention_mask']]
        return {
            'input_tokens': tokenized['input_ids'],
            'target_mask': tokenized['target_mask']
        }

    def _custom_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, jnp.ndarray]:
        """
        Collates batch items and converts PyTorch tensors to JAX arrays.
        Applies default_data_collator, then converts tensors to JAX format.
        """
        collated = default_data_collator(batch)
        jax_batch = {}
        for key, value in collated.items():
            jax_batch[key] = jnp.array(value.numpy()) if isinstance(value, torch.Tensor) else value
        
        return jax_batch

    # Load and preprocess the dataset
    dataset = load_dataset("yahma/alpaca-cleaned", split="train")
    if max_examples:
        dataset = dataset.select(range(max_examples))
    dataset = dataset.map(_format_prompts, batched=True)

    # Create train and test dataset.
    ds = dataset.train_test_split(test_size=0.15)
    for split in ['train', 'test']:
        ds[split] = ds[split].map(_tokenize, batched=True, remove_columns=dataset.column_names)

    # Create DataLoaders
    dataloader_args = dict(shuffle=True, batch_size=batch_size, collate_fn=_custom_collate_fn)
    train_dataloader = torch.utils.data.DataLoader(ds['train'], **dataloader_args)
    test_dataloader = torch.utils.data.DataLoader(ds['test'], **dataloader_args)

    return train_dataloader, test_dataloader

In [30]:
# # Uncomment to test dataset pipeline
# def test_dataset_pipeline(tokenizer):
#     """Print shapes of first batch to verify dataset pipeline."""
#     train_loader, _ = get_dataset(tokenizer=tokenizer, batch_size=2, max_length=64)
#     batch = next(iter(train_loader))
#     print("Input tokens shape:", batch['input_tokens'].shape)
#     print("Target mask shape:", batch['target_mask'].shape)

# tokenizer = AutoTokenizer.from_pretrained(
#     MODEL_NAME, 
#     token=HUGGINGFACE_TOKEN
# )
# test_dataset_pipeline(tokenizer)

In [31]:
def forward_and_loss_fn(params,
                        *,
                        state,
                        input_tokens: jax.Array,            # Shape [B, L]
                        input_mask: jax.Array,              # Shape [B, L]
                        positions: jax.Array,               # Shape [B, L]
                        attention_mask: jax.Array,          # [B, L, L]
                        ) -> jax.Array:
  """Forward pass and loss function.

  Args:
    params: model's input parameters.
    model: gemma transformer model to call.
    input_tokens: input tokens sequence, shape [B, L].
    input_mask: tokens to ignore when computing the loss, shape [B, L].
    positions: relative position of each token, shape [B, L].
    attention_mask: input attention mask, shape [B, L].

  Returns:
    Softmax cross-entropy loss for the next-token prediction task.
  """

  # Forward pass on the input data.
  # No attention cache is needed here.
  logits, _ = state.apply_fn(
        {"params": params},
        input_tokens,
        positions,
        None,              # Attention cache is None.
        attention_mask,
    )

  # Exclude the last step as it does not appear in the targets.
  logits = logits[:, :-1]

  # Similarly, the first token cannot be predicteds.
  target_tokens = input_tokens[:, 1:]
  target_mask = input_mask[:, 1:]

  # Convert the target labels into one-hot encoded vectors.
  one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])

  # Don't update on unwanted tokens.
  one_hot = one_hot * target_mask.astype(one_hot.dtype)[..., None]

  # Normalisation factor.
  norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)

  # Return the nll loss.
  loss =  -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor
  # pdb.set_trace()
  return loss

The Gemma transformer requires an attention mask and position vector alongside each input. We can conveniently generate these using the following function:

In [32]:
def get_attention_mask_and_positions(example: jax.Array,
                                     pad_id : int,
                                     )-> tuple[jax.Array, jax.Array]:
  """Builds the position and attention mask vectors from the given tokens."""
  pad_mask = example != pad_id
  current_token_position = transformer_lib.build_positions_from_mask(pad_mask)
  attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)
  return current_token_position, attention_mask

We can now build the train_step function which performs the backward pass and updates the model's parameters accordingly.

In [48]:
def train_step(state,
               batch,
              pad_id: int,
              ):
  """Train step.

  Args:
    model: gemma transformer model.
    params: model's input parameters.
    pad_id: id of the pad token.
    batch: input batch.

  Returns:
    Training loss, updated parameters, updated optimizer state.
  """
  # Build the position and attention mask vectors.
  positions, attention_mask = get_attention_mask_and_positions(batch['input_tokens'], pad_id)

  # pdb.set_trace()
  # Forward and backward passes
  train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(state.params,
                                                             state=state,
                                                             input_tokens=batch['input_tokens'],
                                                             input_mask=batch['target_mask'],
                                                             positions=positions,
                                                             attention_mask=attention_mask)

  # pdb.set_trace()
  # Update the parameters
  state = state.apply_gradients(grads=grads)

  return state, train_loss

Similarly, we build a `validation_step` function without backward pass.

In [49]:
def shard_params_pytree(params, mesh):
    def shard_param(param):
        if len(param.shape) == 0:
            return NamedSharding(mesh, PS())
        elif len(param.shape) == 1:
            return NamedSharding(mesh, PS('model'))
        elif len(param.shape) == 2:
            return NamedSharding(mesh, PS('data', 'model'))
        elif len(param.shape) == 3:
            return NamedSharding(mesh, PS('data', 'model', 'replica'))
        else:
            # For higher-dimensional tensors, might need a more complex strategy. But reeplicate by default fornow.
            return NamedSharding(mesh, PS())

    return jax.tree_util.tree_map(shard_param, params)

In [50]:
@chex.dataclass(frozen=True)
class TrainingConfig:
  learning_rate: float
  num_epochs: int
  eval_every_n: int
  batch_size: int
  max_steps: int | None = None

In [51]:
def create_trainstate_from_params(params, model_apply_fn, optimizer):
    state = train_state.TrainState.create(
        apply_fn=model_apply_fn,
        params=params['params'],
        tx=optimizer)
    return state

In [58]:
def train_loop(
    model: transformer_lib.Transformer,
    params,
    train_dataloader,
    tokenizer,
    training_cfg: TrainingConfig,
    mesh):

    # To create sharded train step, you need to figure out how params would look when sharded, for that
    # first, trace (eval) create_trainstate_from_params to get param shapes.
    # second, then shard the params with shapes.
    # third, create sharded_train_step passing sharded param shapes as input to compiler/pjit via in_shardings.
    #   out_shardings can be skipped in jax.jit.
    
    state_shapes = jax.eval_shape(
        functools.partial(
            create_trainstate_from_params,
            params=params,
            model_apply_fn=model.apply,
            optimizer=optimizer,
        ),
    )
    state_shapes_partitioned = shard_params_pytree(
        state_shapes, mesh
    )
    sharded_train_step = jax.jit(
        train_step,
        in_shardings=(state_shapes_partitioned, NamedSharding(mesh, PS())),
        out_shardings=(state_shapes_partitioned, NamedSharding(mesh, PS())),
        static_argnums=(2,)
        # donate_argnums=(0, 1),
    )
    
    n_steps = 0
    avg_loss=0

    # here I'm first creating params which are unsharded. Then, during train_step, it should get sharded.
    # if you visualize the params now, it shouldn't be sharded -- YEP, verified it, everything on TPU0.
    state = create_trainstate_from_params(params, model.apply, optimizer)
    
    for i, train_batch in enumerate(train_dataloader):
        train_batch = jax.device_put(train_batch, NamedSharding(mesh, PS()))
        state, train_loss = sharded_train_step(state, train_batch, tokenizer.pad_token_id, )
        n_steps += 1
        avg_loss += train_loss
        print(f"train_loss {train_loss}")
        if training_cfg.max_steps is not None and n_steps > training_cfg.max_steps:
          break
    return state

In [59]:
# Set up the device mesh
devices = jax.devices()
device_mesh = mesh_utils.create_device_mesh((1, 4, 1))
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model', 'replica'))

In [60]:
# Load parameters.
params = {"params": params_lib.load_and_format_params(os.path.join(ckpt_path, 'gemma2-2b-it'))['transformer']}

In [61]:
training_cfg = TrainingConfig(learning_rate=1e-4,
                              num_epochs=1,
                              eval_every_n=20,
                              batch_size=1,
                              max_steps=10)

In [62]:
# Load model config.
config = transformer_lib.TransformerConfig.gemma2_2b(cache_size=30)
model = transformer_lib.Transformer(config=config)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME, 
    token=HUGGINGFACE_TOKEN
)
optimizer = optax.sgd(training_cfg.learning_rate)

In [63]:
train_dataloader, val_dataloader = get_dataset(tokenizer=tokenizer)

# with chex.fake_jit():
new_state = train_loop(model=model,
                    params=params,
                    train_dataloader=train_dataloader,
                    tokenizer=tokenizer,
                    training_cfg=training_cfg, 
                   mesh = mesh)

Map:   0%|          | 0/27 [00:00<?, ? examples/s]

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

train_loss 3.5580010414123535
train_loss 3.219733476638794
train_loss 2.9344639778137207
train_loss 2.6792898178100586
train_loss 2.4685018062591553
train_loss 2.2560482025146484
train_loss 2.0904698371887207
train_loss 1.9557868242263794
train_loss 1.824852705001831
train_loss 1.7282675504684448
train_loss 1.6210218667984009
