# Fine-tuning Gemma2 2B model on Roadrunner with JAX, Flax.

We have adopted the Gemma notebook from Google Deepmind to use HuggingFace's libraries, added support for doing **model parallel training** and simplified the setup.

## Setup 

In [41]:
import os
import sys
import importlib
def import_local_module(module_path: str):
    sys.path.append('')
    module = importlib.import_module(module_path)
    return importlib.reload(module)

# Imports felafax trainer_engine
setup = import_local_module("trainer_engine.setup")
setup.setup_environment()

In [42]:
%%capture
!pip install --upgrade kagglehub -q
!pip install ipywidgets -q
!pip install torch --index-url https://download.pytorch.org/whl/cpu -q
!pip install git+https://github.com/felafax/gemma.git -q
!pip install qax -q
!pip install jax-lorax -q

In [43]:
globals().update(setup.setup_imports())

utils = import_local_module("trainer_engine.utils")
training_pipeline = import_local_module("trainer_engine.training_pipeline")

## Step 0: Input your HF username, token and download model weights

### Select the base model you want to fine-tune 👇

In [None]:
supported_models = [
    "gemma-2-2b-it",  # 2b
    "gemma-2-9b-it",  # 9b
]

MODEL_NAME="gemma-2-9b-it"

### Input your HuggingFace🤗 username and token below

In [44]:
hf_model_name = f"felafax/{MODEL_NAME}-JAX"
HUGGINGFACE_USERNAME = input("INPUT: Please provide your HUGGINGFACE_USERNAME: ")
HUGGINGFACE_TOKEN = input("INPUT: Please provide your HUGGINGFACE_TOKEN: ")

INPUT: Please provide your HUGGINGFACE_USERNAME:  
INPUT: Please provide your HUGGINGFACE_TOKEN:  


In [45]:
%%capture
# Downloads the model to disk.
from huggingface_hub import snapshot_download
ckpt_path = snapshot_download(repo_id=hf_model_name, token=HUGGINGFACE_TOKEN)
vocab_path = os.path.join(ckpt_path, 'tokenizer.model')
model_path = os.path.join(ckpt_path, re.sub(r'gemma-(\d+)-', r'gemma\1-', MODEL_NAME))

In [46]:
# Loads the downloaded model.
params = {
    "params": params_lib.load_and_format_params(model_path)['transformer']
}
model_config = transformer_lib.TransformerConfig.from_params(params={"transformer": params["params"]}, cache_size=30)
model = transformer_lib.Transformer(config=model_config)
tokenizer = AutoTokenizer.from_pretrained(
    hf_model_name, 
    token=HUGGINGFACE_TOKEN
)

## Step 1: prepare the dataset

For this project, we're utilizing the refined **Alpaca dataset**, curated by yahma. This dataset is a carefully filtered selection of 52,000 entries from the original Alpaca collection. Feel free to substitute this section with your own data preparation code if you prefer.

It's crucial to include the EOS_TOKEN (End of Sequence Token) in your tokenized output. Failing to do so may result in endless generation loops.

In [47]:
def get_dataset(*, tokenizer, batch_size=1, max_length=32, max_examples=None):
    # Define Alpaca prompt template
    alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
    
    ### Instruction: {}
    
    ### Input: {}
    
    ### Response: {}"""
    
    EOS_TOKEN = tokenizer.eos_token
    
    # Defines formatting function.
    def _format_prompts(examples):
        instructions = examples["instruction"]
        inputs = examples["input"]
        outputs = examples["output"]
        texts = []
        for instruction, input, output in zip(instructions, inputs, outputs):
            text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
            texts.append(text)
        return {"text": texts}

    def _tokenize(examples):
        tokenized = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=max_length+1)
        tokenized['input_ids'] = [input_id[:-1] for input_id in tokenized['input_ids']]
        tokenized['target_mask'] = [input_id[:-1] for input_id in tokenized['attention_mask']]
        return {
            'input_tokens': tokenized['input_ids'],
            'target_mask': tokenized['target_mask']
        }

    def _custom_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, jnp.ndarray]:
        """
        Collates batch items and converts PyTorch tensors to JAX arrays.
        Applies default_data_collator, then converts tensors to JAX format.
        """
        collated = default_data_collator(batch)
        jax_batch = {}
        for key, value in collated.items():
            jax_batch[key] = jnp.array(value.numpy()) if isinstance(value, torch.Tensor) else value
        
        return jax_batch

    # Load and preprocess the dataset
    dataset = load_dataset("yahma/alpaca-cleaned", split="train")
    if max_examples:
        dataset = dataset.select(range(max_examples))
    dataset = dataset.map(_format_prompts, batched=True)

    # Create train and test dataset.
    ds = dataset.train_test_split(test_size=0.15)
    for split in ['train', 'test']:
        ds[split] = ds[split].map(_tokenize, batched=True, remove_columns=dataset.column_names)

    # Create DataLoaders
    dataloader_args = dict(shuffle=True, batch_size=batch_size, collate_fn=_custom_collate_fn)
    train_dataloader = torch.utils.data.DataLoader(ds['train'], **dataloader_args)
    test_dataloader = torch.utils.data.DataLoader(ds['test'], **dataloader_args)

    return train_dataloader, test_dataloader

**Uncomment below code ⬇️ if you'd like to run and test 💯 your dataset pipeline.**

In [48]:
# def test_dataset_pipeline(tokenizer):
#     """Print shapes of first batch to verify dataset pipeline."""
#     train_loader, _ = get_dataset(tokenizer=tokenizer, batch_size=1, max_length=512)
#     batch = next(iter(train_loader))
#     print("Input tokens shape:", batch['input_tokens'].shape)
#     print("Target mask shape:", batch['target_mask'].shape)
# test_dataset_pipeline(tokenizer)

## Step 2: Train the model by configuring the hyperparameters below.

In [55]:
@chex.dataclass(frozen=True)
class TrainingConfig:
  learning_rate: float = 1e-4
  num_epochs: int = 1
  max_steps: int | None = 40  # max number of training steps (**set to None** to train for full num_epochs)

  # Dataset config
  batch_size: int = 32
  max_length: int = 64  # max seq lenght of tokens in input batch
  dataset_size_limit: int | None = None    # limit on number of dataset examples for testing (**set to None** to use full dataset)

  # Misc config
  print_every_n_steps: int = 1

training_cfg = TrainingConfig()

In [51]:
train_dataloader, val_dataloader = get_dataset(tokenizer=tokenizer, max_length=training_cfg.max_length, max_examples=training_cfg.dataset_size_limit)
optimizer = optax.sgd(training_cfg.learning_rate)

Map:   0%|          | 0/43996 [00:00<?, ? examples/s]

Map:   0%|          | 0/7764 [00:00<?, ? examples/s]

In [52]:
# Sets up the device mesh for sharding the model across TPU cores and to do model parallel training.
devices = jax.devices()
device_count = len(devices)
device_mesh = mesh_utils.create_device_mesh((1, device_count, 1))
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model', 'replica'))
print("Sharding model acorss", device_count, " devices.")

Sharding model acorss 4  devices.


**NOTE**: The **time-to-first step of training will be slow** because XLA takes time initially to compile the computational graph. However, once the compilation is complete, subsequent steps will run much faster using the compiled and cached graph, leveraging the full power of all TPU cores for accelerated training.

In [53]:
state = training_pipeline.train_loop(model=model,
                    tokenizer=tokenizer,
                    params=params,
                    optimizer=optimizer,
                    train_dataloader=train_dataloader,
                    training_cfg=training_cfg, 
                    mesh = mesh)

Step 0, Train Loss: 3.6503
Step 1, Train Loss: 2.4262
Step 2, Train Loss: 3.2592
Step 3, Train Loss: 2.2910
Step 4, Train Loss: 2.3543
Step 5, Train Loss: 2.2851
Step 6, Train Loss: 2.1437
Step 7, Train Loss: 2.2376
Step 8, Train Loss: 1.8792
Step 9, Train Loss: 1.9982
Step 10, Train Loss: 2.2012
Step 11, Train Loss: 2.0664
Step 12, Train Loss: 1.9715
Step 13, Train Loss: 2.1281
Step 14, Train Loss: 2.7430
Step 15, Train Loss: 2.1233
Step 16, Train Loss: 1.8954
Step 17, Train Loss: 2.1002
Step 18, Train Loss: 1.9138
Step 19, Train Loss: 1.9351
Step 20, Train Loss: 2.3138
Step 21, Train Loss: 1.7246
Step 22, Train Loss: 1.4195
Step 23, Train Loss: 1.4904
Step 24, Train Loss: 1.5521
Step 25, Train Loss: 1.7244
Step 26, Train Loss: 1.5133
Step 27, Train Loss: 1.5379
Step 28, Train Loss: 1.3721
Step 29, Train Loss: 1.4148
Step 30, Train Loss: 1.3466
Step 31, Train Loss: 1.5437
Step 32, Train Loss: 1.3475
Step 33, Train Loss: 1.4616
Step 34, Train Loss: 1.4873
Step 35, Train Loss: 1.6208
St