# Fine-tuning Gemma2 2B model on Roadrunner with JAX, Flax.

We have adopted the Gemma notebook from Google Deepmind to use HuggingFace's libraries, added support for doing **model parallel training** and simplified the setup.

## Setup 

In [1]:
import os
import sys
import importlib
def import_local_module(module_path: str):
    sys.path.append('')
    module = importlib.import_module(module_path)
    return importlib.reload(module)

# Imports felafax trainer_engine
setup = import_local_module("trainer_engine.setup")
setup.setup_environment()

In [None]:
# PyTorch
!pip install torch --index-url https://download.pytorch.org/whl/cpu -q

# JAX ecosystem
!pip install --upgrade jax -q
!pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -q
!pip install jax-lorax -q
!pip install "flax[all]" -q
!pip install --upgrade optax==0.2.2

# Machine learning libraries
!pip install --no-cache-dir transformers==4.43.3
!pip install --no-cache-dir datasets==2.18.0
!pip install qax -q

# Utility libraries
!pip install --upgrade einops
!pip install --upgrade tqdm
!pip install --upgrade requests
!pip install --upgrade typing-extensions
!pip install --upgrade sentencepiece
!pip install --upgrade pydantic
!pip install --upgrade cloudpickle
!pip install gcsfs

# Web development libraries
!pip install --upgrade fastapi
!pip install --upgrade uvicorn
!pip install --upgrade gradio

# Configuration management
!pip install --upgrade ml_collections

In [75]:
globals().update(setup.setup_imports())

utils = import_local_module("trainer_engine.utils")
llama_model = import_local_module("trainer_engine.llama_model")
checkpoint_lib = import_local_module("trainer_engine.checkpoint_lib")
training_pipeline = import_local_module("trainer_engine.training_pipeline")
convert_to_hf = import_local_module("trainer_engine.convert_to_hf")
config_lib = import_local_module("trainer_engine.config_lib")

## Step 0: Input your HF username, token and download model weights

### Select the base model you want to fine-tune 👇

In [4]:
# Select a supported model from above list to use!
MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B"
JAX_MODEL_NAME = "felafax/llama-3.1-8B-JAX"
model_ckpt_path = "/mnt/persistent-disk/hf/models--felafax--llama-3.1-8B-JAX/snapshots/ebca17f216e4c02e0f31cc47264a9d65a4f5b9a9/llama3.1_8b_serialized.flax"

### Input your HuggingFace🤗 username and token below

In [5]:
hf_model_name = MODEL_NAME
HUGGINGFACE_USERNAME = input("INPUT: Please provide your HUGGINGFACE_USERNAME: ")
HUGGINGFACE_TOKEN = input("INPUT: Please provide your HUGGINGFACE_TOKEN: ")

INPUT: Please provide your HUGGINGFACE_USERNAME:  felarof01
INPUT: Please provide your HUGGINGFACE_TOKEN:  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


In [6]:
config = AutoConfig.from_pretrained(
    MODEL_NAME, 
    token=HUGGINGFACE_TOKEN)

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME, 
    token=HUGGINGFACE_TOKEN,
)
tokenizer.pad_token = tokenizer.eos_token

In [7]:
from huggingface_hub import snapshot_download
model_path = snapshot_download(repo_id=JAX_MODEL_NAME, token=HUGGINGFACE_TOKEN)

Fetching 3 files: 100%|██████████| 3/3 [01:09<00:00, 23.06s/it]


## Step 1: prepare the dataset

For this project, we're utilizing the refined **Alpaca dataset**, curated by yahma. This dataset is a carefully filtered selection of 52,000 entries from the original Alpaca collection. Feel free to substitute this section with your own data preparation code if you prefer.

It's crucial to include the EOS_TOKEN (End of Sequence Token) in your tokenized output. Failing to do so may result in endless generation loops.

In [8]:
def get_dataset(*, tokenizer, batch_size=1, seq_length=32, max_examples=None):
    # Define Alpaca prompt template
    alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
    
    ### Instruction: {}
    
    ### Input: {}
    
    ### Response: {}"""
    
    EOS_TOKEN = tokenizer.eos_token
    
    # Defines formatting function.
    def _format_prompts(examples):
        instructions = examples["instruction"]
        inputs = examples["input"]
        outputs = examples["output"]
        texts = []
        for instruction, input, output in zip(instructions, inputs, outputs):
            text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
            texts.append(text)
        return {"text": texts}

    def _tokenize(examples):
        tokenized = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=seq_length+1)
        return {
            'input_tokens': [input_id[:-1] for input_id in tokenized['input_ids']],
            'target_tokens': [input_id[1:] for input_id in tokenized['input_ids']],
            'loss_masks': [input_id[1:] for input_id in tokenized['attention_mask']]
        }

    def _custom_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, jnp.ndarray]:
        """
        Collates batch items and converts PyTorch tensors to JAX arrays.
        Applies default_data_collator, then converts tensors to JAX format.
        """
        collated = default_data_collator(batch)
        jax_batch = {}
        for key, value in collated.items():
            jax_batch[key] = jnp.array(value.numpy()) if isinstance(value, torch.Tensor) else value
        
        return jax_batch

    # Load and preprocess the dataset
    dataset = load_dataset("yahma/alpaca-cleaned", split="train")
    if max_examples:
        dataset = dataset.select(range(max_examples))
    dataset = dataset.map(_format_prompts, batched=True)

    # Create train and test dataset.
    ds = dataset.train_test_split(test_size=0.15)
    for split in ['train', 'test']:
        ds[split] = ds[split].map(_tokenize, batched=True, remove_columns=dataset.column_names)

    # Create DataLoaders
    dataloader_args = dict(shuffle=True, batch_size=batch_size, collate_fn=_custom_collate_fn)
    train_dataloader = torch.utils.data.DataLoader(ds['train'], **dataloader_args)
    test_dataloader = torch.utils.data.DataLoader(ds['test'], **dataloader_args)

    return train_dataloader, test_dataloader

**Uncomment below code ⬇️ if you'd like to run and test 💯 your dataset pipeline.**

In [9]:
def test_dataset_pipeline(tokenizer):
    """Print shapes of first batch to verify dataset pipeline."""
    train_loader, _ = get_dataset(tokenizer=tokenizer, batch_size=1, seq_length=32, max_examples=512)
    batch = next(iter(train_loader))
    print("Input tokens shape:", batch['input_tokens'].shape)
    print("Target mask shape:", batch['target_tokens'].shape)
test_dataset_pipeline(tokenizer)

Downloading readme: 100%|██████████| 11.6k/11.6k [00:00<00:00, 38.1MB/s]
Downloading data: 100%|██████████| 44.3M/44.3M [00:00<00:00, 86.2MB/s]
Generating train split: 51760 examples [00:00, 88661.10 examples/s]
Map: 100%|██████████| 512/512 [00:00<00:00, 5515.41 examples/s]
Map: 100%|██████████| 435/435 [00:00<00:00, 2064.32 examples/s]
Map: 100%|██████████| 77/77 [00:00<00:00, 713.18 examples/s]


Input tokens shape: (1, 32)
Target mask shape: (1, 32)


## Step 2: Train the model by configuring the hyperparameters below.

In [10]:
@chex.dataclass(frozen=True)
class TrainingConfig:
    learning_rate: float = 1e-4
    num_epochs: int = 1
    max_steps: int | None = 5
    batch_size: int = 32
    seq_length: int = 64
    dataset_size_limit: int | None = 512
    print_every_n_steps: int = 1


training_cfg = TrainingConfig()


**NOTE**: The **time-to-first step of training will be slow** because XLA takes time initially to compile the computational graph. However, once the compilation is complete, subsequent steps will run much faster using the compiled and cached graph, leveraging the full power of all TPU cores for accelerated training.

In [11]:
# Configure mesh
devices = jax.devices()
device_count = len(devices)
device_mesh = mesh_utils.create_device_mesh((1, device_count, 1))
mesh = Mesh(devices=device_mesh, axis_names=("dp", "fsdp", "mp"))

In [13]:
# Initialize model and optimizer
llama_config = llama_model.LlamaConfig("llama3_8b")
hf_pretrained_llama_config = llama_config.get_hf_pretrained_config(dict(llama_config.get_model_config()))

model = llama_model.CausalLlamaModule(
    hf_pretrained_llama_config,
    dtype=jnp.float32,
    param_dtype=jnp.float32,
)
optimizer = optax.sgd(training_cfg.learning_rate)


In [14]:
# Prepare dataset
train_dataloader, val_dataloader = get_dataset(
    tokenizer=tokenizer,
    seq_length=training_cfg.seq_length,
    max_examples=training_cfg.dataset_size_limit,
)

Map: 100%|██████████| 435/435 [00:00<00:00, 2196.97 examples/s]
Map: 100%|██████████| 77/77 [00:00<00:00, 2159.40 examples/s]


In [15]:
model_ckpt_path

'/mnt/persistent-disk/hf/models--felafax--llama-3.1-8B-JAX/snapshots/ebca17f216e4c02e0f31cc47264a9d65a4f5b9a9/llama3.1_8b_serialized.flax'

In [33]:
# Initialize the Trainer
trainer = training_pipeline.Trainer(
    model=model,
    model_ckpt_path=model_ckpt_path,
    model_config=llama_config,
    optimizer=optimizer,
    training_config=training_cfg,
    mesh=mesh,
    model_params=state.params
)

Loading llama JAX model...


In [17]:
state = trainer.train_state

In [18]:
# Train the model
state = trainer.train(mesh, state, train_dataloader)

Starting epoch 0 of training...


See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.


Epoch 0, Step 0, Train Loss: 2.7612, Accuracy: 0.5156
Epoch 0, Step 1, Train Loss: 2.3961, Accuracy: 0.4844
Epoch 0, Step 2, Train Loss: 2.6189, Accuracy: 0.4531
Epoch 0, Step 3, Train Loss: 2.1961, Accuracy: 0.4844
Epoch 0, Step 4, Train Loss: 2.5189, Accuracy: 0.4531
Epoch 0, Step 5, Train Loss: 2.5808, Accuracy: 0.4375


In [34]:
trainer.save_model(mesh, 
                   state, 
                   trainer.gather_fns)

In [40]:
load_model_path = os.path.join(os.path.dirname(model_ckpt_path), "streaming_params")

In [41]:
load_model_path

'/mnt/persistent-disk/hf/models--felafax--llama-3.1-8B-JAX/snapshots/ebca17f216e4c02e0f31cc47264a9d65a4f5b9a9/streaming_params'

In [56]:
with mesh:
    train_state, restored_params = trainer.checkpointer.load_trainstate_checkpoint(
        load_from='params::/mnt/persistent-disk/hf/models--felafax--llama-3.1-8B-JAX/snapshots/ebca17f216e4c02e0f31cc47264a9d65a4f5b9a9/streaming_params',
        trainstate_target=state,
        trainstate_shard_fns=trainer.shard_fns)

In [85]:
config = config_lib.config_dict(
    load_checkpoint='not required',
    output_dir='/mnt/persistent-disk/easy/e2hf/',
    llama=llama_config,
)


In [68]:
restored_params['params'].keys()

dict_keys(['lm_head', 'transformer'])

In [76]:
with config_lib.open_file('/mnt/persistent-disk/easy/e2hf/saved_model.flax', "wb") as fout:
    fout.write(flax.serialization.msgpack_serialize(restored_params['params'], in_place=True))

In [79]:
import gc
import json
import math
import os
import shutil

import flax
import jax
import jax.numpy as jnp
import numpy as np
import torch
from flax.traverse_util import flatten_dict
from transformers import LlamaConfig, LlamaForCausalLM

def match_keywords(string, positives, negatives):
    for positive in positives:
        if positive not in string:
            return False
    for negative in negatives:
        if negative in string:
            return False
    return True


In [80]:
def load_and_convert_checkpoint(path):
    _, flax_params = checkpoint_lib.Checkpointer.load_trainstate_checkpoint(path)
    flax_params = flatten_dict(flax_params['params'], sep='.')
    torch_params = {}
    for key, tensor in flax_params.items():
        if match_keywords(key, ["kernel"], ["norm", 'ln_f']):
            tensor = tensor.T
        torch_params[key] = torch.tensor(utils.float_tensor_to_dtype(tensor, 'fp32'),
                                         dtype=torch.float16)
    return torch_params

In [81]:
def read_json(path):
    with open(path, "r") as f:
        return json.load(f)


def write_json(text, path):
    with open(path, "w") as f:
        json.dump(text, f)


def permute(w, n_heads, input_dim, output_dim):
    # permute for sliced rotary embedding
    return w.view(n_heads, output_dim // n_heads // 2, 2,
                  input_dim).transpose(1, 2).reshape(output_dim, input_dim)

In [87]:
def write_model(loaded, model_path):
    os.makedirs(model_path, exist_ok=True)
    tmp_model_path = os.path.join(model_path, "tmp")
    os.makedirs(tmp_model_path, exist_ok=True)

    llama_config = hf_pretrained_llama_config
    n_layers = llama_config.num_hidden_layers
    n_heads = llama_config.num_attention_heads
    n_kv_heads = llama_config.num_key_value_heads
    dim = llama_config.hidden_size
    dims_per_head = dim // n_heads
    base = llama_config.rope_theta
    inv_freq = 1.0 / (base**(torch.arange(0, dims_per_head, 2).float() /
                             dims_per_head))

    param_count = 0
    index_dict = {"weight_map": {}}
    for layer_i in range(n_layers):
        filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
        state_dict = {
            f"model.layers.{layer_i}.self_attn.q_proj.weight":
            permute(
                loaded[f"transformer.h.{layer_i}.attention.wq.kernel"],
                llama_config.num_attention_heads,
                llama_config.hidden_size,
                llama_config.hidden_size,
            ),
            f"model.layers.{layer_i}.self_attn.k_proj.weight":
            permute(
                loaded[f"transformer.h.{layer_i}.attention.wk.kernel"],
                llama_config.num_key_value_heads,
                llama_config.hidden_size,
                llama_config.hidden_size //
                (llama_config.num_attention_heads //
                 llama_config.num_key_value_heads),
            ),
            f"model.layers.{layer_i}.self_attn.v_proj.weight":
            loaded[f"transformer.h.{layer_i}.attention.wv.kernel"],
            f"model.layers.{layer_i}.self_attn.o_proj.weight":
            loaded[f"transformer.h.{layer_i}.attention.wo.kernel"],
            f"model.layers.{layer_i}.mlp.gate_proj.weight":
            loaded[f"transformer.h.{layer_i}.feed_forward.w1.kernel"],
            f"model.layers.{layer_i}.mlp.down_proj.weight":
            loaded[f"transformer.h.{layer_i}.feed_forward.w2.kernel"],
            f"model.layers.{layer_i}.mlp.up_proj.weight":
            loaded[f"transformer.h.{layer_i}.feed_forward.w3.kernel"],
            f"model.layers.{layer_i}.input_layernorm.weight":
            loaded[f"transformer.h.{layer_i}.attention_norm.kernel"],
            f"model.layers.{layer_i}.post_attention_layernorm.weight":
            loaded[f"transformer.h.{layer_i}.ffn_norm.kernel"],
        }

        state_dict[
            f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
        for k, v in state_dict.items():
            index_dict["weight_map"][k] = filename
            param_count += v.numel()
        torch.save(state_dict, os.path.join(tmp_model_path, filename))

    filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
    # Unsharded
    state_dict = {
        "model.embed_tokens.weight": loaded["transformer.wte.embedding"],
        "model.norm.weight": loaded["transformer.ln_f.kernel"],
        "lm_head.weight": loaded["lm_head.kernel"],
    }

    for k, v in state_dict.items():
        index_dict["weight_map"][k] = filename
        param_count += v.numel()
    torch.save(state_dict, os.path.join(tmp_model_path, filename))

    # Write configs
    index_dict["metadata"] = {"total_size": param_count * 2}
    write_json(index_dict,
               os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))

    config = LlamaConfig(
        vocab_size=llama_config.vocab_size,
        hidden_size=llama_config.hidden_size,
        intermediate_size=llama_config.intermediate_size,
        num_hidden_layers=llama_config.num_hidden_layers,
        num_attention_heads=llama_config.num_attention_heads,
        num_key_value_heads=llama_config.num_key_value_heads,
        initializer_range=llama_config.initializer_range,
        rms_norm_eps=llama_config.rms_norm_eps,
        max_position_embeddings=llama_config.max_position_embeddings,
        rope_theta=llama_config.rope_theta,
    )
    config.save_pretrained(tmp_model_path)

    # Make space so we can load the model properly now.
    del state_dict
    del loaded
    gc.collect()

    print("Loading the checkpoint in a Llama model.")
    model = LlamaForCausalLM.from_pretrained(tmp_model_path,
                                             torch_dtype=torch.float16)
    # Avoid saving this as part of the config.
    del model.config._name_or_path

    print("Saving in the Transformers format.")
    model.save_pretrained(model_path)
    shutil.rmtree(tmp_model_path)

In [86]:
loaded_params = load_and_convert_checkpoint('flax_params::/mnt/persistent-disk/easy/e2hf/saved_model.flax')

In [88]:
write_model(
    loaded_params,
    model_path=config.output_dir,
)

Loading the checkpoint in a Llama model.


Loading checkpoint shards: 100%|██████████| 33/33 [00:25<00:00,  1.31it/s]


Saving in the Transformers format.


In [None]:
# convert_to_hf.main(config, 'flax_params::/mnt/persistent-disk/easy/e2hf/saved_model.flax')

In [None]:
from huggingface_hub import HfApi


In [None]:
api = HfApi()
api.upload_folder(
    folder_path="/mnt/persistent-disk/easy/e2hf/",
    repo_id="felafax/llama3.1-8b-easylm-to-hf",
    repo_type="model",
    ignore_patterns=[".*"],
    token="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"
)