# Fine-tuning Gemma2 2B model on Roadrunner with JAX, Flax.

We have adapted the Gemma notebook from Google DeepMind to use Hugging Face's libraries, added support for **model-parallel training**, and simplified the setup.

## Setup 

In [1]:
import os
import sys
import importlib
def import_local_module(module_path: str):
    """Import a module by dotted path and return a freshly reloaded copy.

    The empty string is placed on ``sys.path`` so that modules living next to
    the notebook (the current working directory) can be resolved. The module
    is then reloaded so that edits made to the file since the last import are
    picked up when the cell is re-run.

    Args:
        module_path: Dotted module name, e.g. ``"trainer_engine.setup"``.

    Returns:
        The reloaded module object.

    Raises:
        ImportError: If the module cannot be found or imported.
    """
    # Only add the entry once; re-running the cell used to append a duplicate
    # '' to sys.path on every call.
    if '' not in sys.path:
        sys.path.append('')
    module = importlib.import_module(module_path)
    return importlib.reload(module)

# Import felafax libraries
# Load (or hot-reload) the local `trainer_engine.setup` module, then run its
# environment setup. What setup_environment() configures is defined in that
# module and is not visible from this notebook.
setup = import_local_module("trainer_engine.setup")
setup.setup_environment()

In [2]:
%%capture
!pip install --upgrade kagglehub -q
!pip install ipywidgets -q
!pip install torch --index-url https://download.pytorch.org/whl/cpu -q
!pip install git+https://github.com/felafax/gemma.git -q
!pip install qax -q
!pip install jax-lorax -q

In [3]:
# Inject every name returned by setup.setup_imports() into the notebook's
# global namespace. Later cells use names that are never imported explicitly
# here (e.g. params_lib, transformer_lib, jnp, optax) -- presumably they are
# supplied by this call; verify against trainer_engine.setup.
globals().update(setup.setup_imports())

# Load (or hot-reload) the local helper modules used by the cells below.
utils = import_local_module("trainer_engine.utils")
training_pipeline = import_local_module("trainer_engine.training_pipeline")

## Step 0: Input your HF username, token and download model weights

In [4]:
# HuggingFace username and token to use when downloading.
from getpass import getpass

MODEL_NAME = "felafax/gemma-2-2b-it-JAX"
HUGGINGFACE_USERNAME = input("INPUT: Please provide your HUGGINGFACE_USERNAME: ")
# Read the token with getpass so it is not echoed back and therefore never
# ends up in the saved cell output. Strip stray whitespace from copy/paste.
HUGGINGFACE_TOKEN = getpass("INPUT: Please provide your HUGGINGFACE_TOKEN: ").strip()

# Lower-case aliases kept for code that refers to these names.
model_name = MODEL_NAME
hugging_face_token = HUGGINGFACE_TOKEN

INPUT: Please provide your HUGGINGFACE_USERNAME:  felarof01
INPUT: Please provide your HUGGINGFACE_TOKEN:  <REDACTED - never commit access tokens; revoke the token that was previously shown here>


In [5]:
%%capture
from huggingface_hub import snapshot_download

# Fetch the model repository snapshot from the Hugging Face Hub. The returned
# value is used below as a local directory path.
ckpt_path = snapshot_download(repo_id=MODEL_NAME, token=HUGGINGFACE_TOKEN)
# Location of the tokenizer model file inside the downloaded snapshot.
vocab_path = os.path.join(ckpt_path, 'tokenizer.model')

# Load parameters.
# Only the 'transformer' sub-tree of the loaded checkpoint is kept, nested
# under a top-level "params" key (presumably the layout the model's apply
# function expects -- confirm against training_pipeline).
params = {"params": params_lib.load_and_format_params(os.path.join(ckpt_path, 'gemma2-2b-it'))['transformer']}

## Step 1: Prepare the dataset

For this project, we're utilizing the refined **Alpaca dataset**, curated by yahma. This dataset is a carefully filtered selection of 52,000 entries from the original Alpaca collection. Feel free to substitute this section with your own data preparation code if you prefer.

It's crucial to include the EOS_TOKEN (End of Sequence Token) in your tokenized output. Failing to do so may result in endless generation loops.

In [6]:
def get_dataset(*, tokenizer, batch_size=1, max_length=32, max_examples=None, seed=None):
    """Build train/test DataLoaders over the ``yahma/alpaca-cleaned`` dataset.

    Every row is rendered with the Alpaca prompt template, terminated with the
    tokenizer's EOS token, then tokenized and padded/truncated to a fixed
    length. The data is split 85/15 into train and test sets.

    Args:
        tokenizer: Tokenizer exposing ``eos_token`` and callable on a list of
            strings with ``truncation``/``padding``/``max_length`` keywords.
        batch_size: Number of examples per batch in both DataLoaders.
        max_length: Number of positions per example in the returned batches.
        max_examples: If truthy, only the first ``max_examples`` rows are used
            (capped at the dataset size). ``None`` or ``0`` keeps every row.
        seed: Optional seed forwarded to ``train_test_split`` so the split is
            reproducible. ``None`` (the default) leaves the split unseeded.

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple. Each batch is a dict
        with the keys ``'input_tokens'`` and ``'target_mask'``; torch tensors
        produced by the collator are converted to JAX arrays.

    Raises:
        ValueError: If the tokenizer does not define an EOS token.
    """
    # Define Alpaca prompt template. It is assembled from separate literals so
    # that the function's source indentation does not leak into the prompt
    # text (a triple-quoted string inside the function body would embed four
    # leading spaces on every line after the first).
    alpaca_prompt = (
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. Write a response that appropriately "
        "completes the request.\n"
        "\n"
        "### Instruction: {}\n"
        "\n"
        "### Input: {}\n"
        "\n"
        "### Response: {}"
    )

    EOS_TOKEN = tokenizer.eos_token
    if EOS_TOKEN is None:
        # Fail early with a clear message instead of a TypeError deep inside
        # dataset.map() when the prompt is concatenated with None.
        raise ValueError("tokenizer.eos_token is not set; an EOS token is required.")

    # Define formatting function.
    def _format_prompts(examples):
        """Render a batch of rows into prompt strings ending in EOS."""
        rows = zip(examples["instruction"], examples["input"], examples["output"])
        texts = [
            alpaca_prompt.format(instruction, context, response) + EOS_TOKEN
            for instruction, context, response in rows
        ]
        return {"text": texts}

    def _tokenize(examples):
        """Tokenize prompt strings into fixed-length ids and masks."""
        # Tokenize to max_length + 1 positions and drop the final one, so
        # every row ends up with exactly max_length entries.
        tokenized = tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=max_length + 1,
        )
        return {
            'input_tokens': [ids[:-1] for ids in tokenized['input_ids']],
            'target_mask': [mask[:-1] for mask in tokenized['attention_mask']],
        }

    def _custom_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, jnp.ndarray]:
        """
        Collates batch items and converts PyTorch tensors to JAX arrays.
        Applies default_data_collator, then converts tensors to JAX format.
        """
        collated = default_data_collator(batch)
        return {
            key: jnp.array(value.numpy()) if isinstance(value, torch.Tensor) else value
            for key, value in collated.items()
        }

    # Load and preprocess the dataset
    dataset = load_dataset("yahma/alpaca-cleaned", split="train")
    if max_examples:
        # Cap at the dataset size so an oversized request does not make
        # select() fail on out-of-range indices.
        dataset = dataset.select(range(min(max_examples, len(dataset))))
    dataset = dataset.map(_format_prompts, batched=True)

    # Create train and test dataset.
    ds = dataset.train_test_split(test_size=0.15, seed=seed)
    for split in ['train', 'test']:
        ds[split] = ds[split].map(_tokenize, batched=True, remove_columns=dataset.column_names)

    # Create DataLoaders
    dataloader_args = dict(shuffle=True, batch_size=batch_size, collate_fn=_custom_collate_fn)
    train_dataloader = torch.utils.data.DataLoader(ds['train'], **dataloader_args)
    test_dataloader = torch.utils.data.DataLoader(ds['test'], **dataloader_args)

    return train_dataloader, test_dataloader

In [7]:
# # Uncomment to test dataset pipeline
# def test_dataset_pipeline(tokenizer):
#     """Print shapes of first batch to verify dataset pipeline."""
#     train_loader, _ = get_dataset(tokenizer=tokenizer, batch_size=2, max_length=64)
#     batch = next(iter(train_loader))
#     print("Input tokens shape:", batch['input_tokens'].shape)
#     print("Target mask shape:", batch['target_mask'].shape)

# tokenizer = AutoTokenizer.from_pretrained(
#     MODEL_NAME, 
#     token=HUGGINGFACE_TOKEN
# )
# test_dataset_pipeline(tokenizer)

## Step 2: Train the model by configuring the hyperparameters below.

In [8]:
@chex.dataclass(frozen=True)
class TrainingConfig:
    """Immutable bundle of hyperparameters for one fine-tuning run."""

    # Step size handed to the optimizer.
    learning_rate: float
    # Epoch count; consumed by the training loop (defined outside this notebook).
    num_epochs: int
    # Evaluation interval in steps -- presumably; confirm in training_pipeline.
    eval_every_n: int
    # Examples per batch.
    batch_size: int
    # Optional cap on the number of training steps; None means no cap is set here.
    max_steps: int | None = None
    # Optional cap on the number of dataset examples; defaults to 32.
    max_examples: int | None = 32

In [9]:
# Hyperparameters for this run. max_examples is not passed, so it keeps the
# TrainingConfig default. Note that eval_every_n (20) is larger than
# max_steps (10) in this configuration.
training_cfg = TrainingConfig(learning_rate=1e-4,
                              num_epochs=1,
                              eval_every_n=20,
                              batch_size=1,
                              max_steps=10)

In [10]:
# Load model config.
# NOTE(review): cache_size is passed straight to the Gemma 2 2B config; its
# effect on training is not visible from this notebook -- confirm in transformer_lib.
config = transformer_lib.TransformerConfig.gemma2_2b(cache_size=30)
model = transformer_lib.Transformer(config=config)
# Tokenizer comes from the same Hub repository as the weights, using the same token.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME, 
    token=HUGGINGFACE_TOKEN
)
# SGD optimizer driven by the configured learning rate.
optimizer = optax.sgd(training_cfg.learning_rate)

In [11]:
# Set up the device mesh for distributing the model across TPU cores and do model parallel training.
devices = jax.devices()
# Put every visible device on the 'model' axis instead of hard-coding 4:
# create_device_mesh requires the product of the mesh shape to equal the
# number of devices, so a fixed (1, 4, 1) fails on any other device count.
device_mesh = mesh_utils.create_device_mesh((1, len(devices), 1), devices=devices)
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model', 'replica'))

In [12]:
# Build the dataloaders from the same config object the training loop uses, so
# that TrainingConfig.batch_size and TrainingConfig.max_examples actually take
# effect. Without max_examples the whole dataset is formatted and tokenized.
train_dataloader, val_dataloader = get_dataset(
    tokenizer=tokenizer,
    batch_size=training_cfg.batch_size,
    max_examples=training_cfg.max_examples,
)

# Run fine-tuning; the return value is kept as the updated training state.
new_state = training_pipeline.train_loop(
    model=model,
    params=params,
    optimizer=optimizer,
    train_dataloader=train_dataloader,
    tokenizer=tokenizer,
    training_cfg=training_cfg,
    mesh=mesh,
)

Map:   0%|          | 0/27 [00:00<?, ? examples/s]

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Step 1, Train Loss: 3.5580
Step 2, Train Loss: 3.2197
Step 3, Train Loss: 2.9345
Step 4, Train Loss: 2.6793
Step 5, Train Loss: 2.4685
Step 6, Train Loss: 2.2560
Step 7, Train Loss: 2.0905
Step 8, Train Loss: 1.9558
Step 9, Train Loss: 1.8249
Step 10, Train Loss: 1.7283
Training complete. Average loss: 2.4715
