# Fine-tuning Llama 3.1 8B model on Roadrunner with JAX, Flax.

We have adapted the Gemma notebook from Google DeepMind to use HuggingFace's libraries, added support for **model parallel training**, and simplified the setup.

## Setup 

In [1]:
import os
import sys
import importlib
def import_local_module(module_path: str):
    """Import (or re-import) a module that lives next to this notebook.

    Args:
        module_path: Dotted module path, e.g. ``"trainer_engine.setup"``.

    Returns:
        The module object returned by ``importlib.reload``.
    """
    # '' makes the import system search the current working directory. Add it
    # only once so that re-running cells does not keep growing sys.path.
    if '' not in sys.path:
        sys.path.append('')
    module = importlib.import_module(module_path)
    # Reload so the module's code is re-executed even if it was already
    # imported earlier in this kernel session.
    return importlib.reload(module)

# Load the felafax `trainer_engine.setup` module from the working directory.
setup = import_local_module("trainer_engine.setup")
# Run the module's environment setup hook.
# NOTE(review): what it configures is not visible in this notebook; see
# trainer_engine/setup.py.
setup.setup_environment()

In [2]:
# PyTorch
!pip install torch --index-url https://download.pytorch.org/whl/cpu -q

# JAX ecosystem
!pip install --upgrade jax -q
!pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -q
!pip install jax-lorax -q
!pip install "flax[all]" -q
# NOTE(review): the recorded output of this cell reports that qax 0.3.1
# requires optax<0.2.0,>=0.1.5, so this pin conflicts with the `qax` install
# further down. Confirm which of the two versions is actually needed.
!pip install --upgrade optax==0.2.2

# Machine learning libraries
!pip install --no-cache-dir transformers==4.43.3
!pip install --no-cache-dir datasets==2.18.0
!pip install qax -q

# Utility libraries
!pip install --upgrade einops
!pip install --upgrade tqdm
!pip install --upgrade requests
!pip install --upgrade typing-extensions
!pip install --upgrade sentencepiece
!pip install --upgrade pydantic
!pip install --upgrade cloudpickle
!pip install gcsfs

# Web development libraries
!pip install --upgrade fastapi
!pip install --upgrade uvicorn
!pip install --upgrade gradio

# Configuration management
!pip install --upgrade ml_collections

[0mCollecting optax==0.2.2
  Using cached optax-0.2.2-py3-none-any.whl.metadata (8.1 kB)
Using cached optax-0.2.2-py3-none-any.whl (223 kB)
Installing collected packages: optax
  Attempting uninstall: optax
    Found existing installation: optax 0.1.9
    Uninstalling optax-0.1.9:
      Successfully uninstalled optax-0.1.9
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
qax 0.3.1 requires optax<0.2.0,>=0.1.5, but you have optax 0.2.2 which is incompatible.[0m[31m
[0mSuccessfully installed optax-0.2.2
Collecting fsspec<=2024.2.0,>=2023.1.0 (from fsspec[http]<=2024.2.0,>=2023.1.0->datasets==2.18.0)
  Downloading fsspec-2024.2.0-py3-none-any.whl.metadata (6.8 kB)
Downloading fsspec-2024.2.0-py3-none-any.whl (170 kB)
Installing collected packages: fsspec
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2024.6.1
    Uninstalling fsspec

In [3]:
# Inject whatever `setup.setup_imports()` returns into the notebook's global
# namespace.
# NOTE(review): later cells use names such as AutoConfig, jnp, optax and torch
# without importing them, so they presumably come from here — confirm against
# trainer_engine/setup.py.
globals().update(setup.setup_imports())

# Load (and reload) the trainer_engine submodules used by the rest of the
# notebook.
utils = import_local_module("trainer_engine.utils")
llama_model = import_local_module("trainer_engine.llama_model")
checkpoint_lib = import_local_module("trainer_engine.checkpoint_lib")
training_pipeline = import_local_module("trainer_engine.training_pipeline")
convert_to_hf = import_local_module("trainer_engine.convert_to_hf")

  from .autonotebook import tqdm as notebook_tqdm


## Step 0: Input your HF username, token and download model weights

### Select the base model you want to fine-tune 👇

In [4]:
# Base model to fine-tune (Hugging Face repo id); used below to load the
# config and tokenizer.
# NOTE(review): an earlier version of this comment referred to a "list above",
# but no list of supported models appears in this notebook.
MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B"
# Hugging Face repo id holding the JAX/Flax weights; downloaded further down
# with snapshot_download.
JAX_MODEL_NAME = "felafax/llama-3.1-8B-JAX"
# Local checkpoint file handed to the Trainer.
# NOTE(review): this path is machine-specific (fixed cache directory and
# snapshot hash) — verify it exists on the machine running the notebook.
model_ckpt_path = "/mnt/persistent-disk/hf/models--felafax--llama-3.1-8B-JAX/snapshots/ebca17f216e4c02e0f31cc47264a9d65a4f5b9a9/llama3.1_8b_serialized.flax"

### Input your HuggingFace🤗 username and token below

In [5]:
from getpass import getpass

hf_model_name = MODEL_NAME
HUGGINGFACE_USERNAME = input("INPUT: Please provide your HUGGINGFACE_USERNAME: ").strip()
# Read the token with getpass so that it is not echoed into the cell output
# (and therefore not saved inside the notebook file).
HUGGINGFACE_TOKEN = getpass("INPUT: Please provide your HUGGINGFACE_TOKEN: ").strip()

INPUT: Please provide your HUGGINGFACE_USERNAME:  felarof01
INPUT: Please provide your HUGGINGFACE_TOKEN:  <redacted>


In [6]:
# Load the configuration for the selected base model, authenticating with the
# user's token.
config = AutoConfig.from_pretrained(MODEL_NAME, token=HUGGINGFACE_TOKEN)

# Load the matching tokenizer with the same credentials.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HUGGINGFACE_TOKEN)

# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

In [7]:
from huggingface_hub import snapshot_download

# Download (or reuse the cached copy of) the JAX checkpoint repo.
# snapshot_download returns the local directory of the resolved snapshot.
model_path = snapshot_download(repo_id=JAX_MODEL_NAME, token=HUGGINGFACE_TOKEN)

# Point the checkpoint path at the serialized weights inside the snapshot that
# was just resolved, instead of relying on a machine-specific cache directory
# and snapshot hash.
model_ckpt_path = os.path.join(model_path, "llama3.1_8b_serialized.flax")

Fetching 3 files: 100%|██████████| 3/3 [00:00<00:00, 21254.92it/s]


## Step 1: prepare the dataset

For this project, we're utilizing the refined **Alpaca dataset**, curated by yahma. This dataset is a carefully filtered selection of 52,000 entries from the original Alpaca collection. Feel free to substitute this section with your own data preparation code if you prefer.

It's crucial to include the EOS_TOKEN (End of Sequence Token) in your tokenized output. Failing to do so may result in endless generation loops.

In [8]:
def get_dataset(*, tokenizer, batch_size=1, seq_length=32, max_examples=None):
    """Build train/test DataLoaders over the ``yahma/alpaca-cleaned`` dataset.

    Each row is rendered with the Alpaca prompt template, terminated with the
    tokenizer's EOS token, tokenized to ``seq_length + 1`` tokens, and split
    into next-token-prediction inputs (all but the last token) and targets
    (all but the first token).

    Args:
        tokenizer: Tokenizer that is callable on a list of strings and exposes
            ``eos_token``.
        batch_size: Number of examples per batch in both DataLoaders.
        seq_length: Number of tokens in ``input_tokens`` / ``target_tokens``
            for each example.
        max_examples: If truthy, only the first ``max_examples`` rows of the
            dataset are used (clamped to the dataset size).

    Returns:
        A ``(train_dataloader, test_dataloader)`` tuple. Each batch is a dict
        with the keys ``input_tokens``, ``target_tokens`` and ``loss_masks``.
    """
    # Alpaca prompt template. It is built from explicit lines so that the
    # indentation of this function body does not leak into the prompt text.
    alpaca_prompt = (
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. Write a response that appropriately "
        "completes the request.\n"
        "\n"
        "### Instruction: {}\n"
        "\n"
        "### Input: {}\n"
        "\n"
        "### Response: {}"
    )

    EOS_TOKEN = tokenizer.eos_token

    def _format_prompts(examples):
        """Render a batch of rows into prompt strings ending with EOS."""
        texts = [
            alpaca_prompt.format(instruction, context, response) + EOS_TOKEN
            for instruction, context, response in zip(
                examples["instruction"], examples["input"], examples["output"]
            )
        ]
        return {"text": texts}

    def _tokenize(examples):
        """Tokenize prompts and derive shifted inputs, targets and loss masks."""
        # One extra token is requested so that, after shifting by one position,
        # inputs and targets both have exactly `seq_length` tokens.
        tokenized = tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=seq_length + 1,
        )
        input_ids = tokenized['input_ids']
        attention_masks = tokenized['attention_mask']
        return {
            'input_tokens': [ids[:-1] for ids in input_ids],
            'target_tokens': [ids[1:] for ids in input_ids],
            # The mask is shifted the same way as the targets so that it lines
            # up with them position by position.
            'loss_masks': [mask[1:] for mask in attention_masks],
        }

    def _custom_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, jnp.ndarray]:
        """
        Collates batch items and converts PyTorch tensors to JAX arrays.
        Applies default_data_collator, then converts tensors to JAX format.
        """
        collated = default_data_collator(batch)
        jax_batch = {}
        for key, value in collated.items():
            # Non-tensor values are passed through unchanged.
            jax_batch[key] = jnp.array(value.numpy()) if isinstance(value, torch.Tensor) else value

        return jax_batch

    # Load and preprocess the dataset.
    dataset = load_dataset("yahma/alpaca-cleaned", split="train")
    if max_examples:
        # Clamp so that a limit larger than the dataset does not raise.
        dataset = dataset.select(range(min(max_examples, len(dataset))))
    dataset = dataset.map(_format_prompts, batched=True)

    # Create train and test datasets; drop the raw text columns after
    # tokenization so only the token fields remain.
    ds = dataset.train_test_split(test_size=0.15)
    for split in ['train', 'test']:
        ds[split] = ds[split].map(_tokenize, batched=True, remove_columns=dataset.column_names)

    # Create DataLoaders.
    dataloader_args = dict(shuffle=True, batch_size=batch_size, collate_fn=_custom_collate_fn)
    train_dataloader = torch.utils.data.DataLoader(ds['train'], **dataloader_args)
    test_dataloader = torch.utils.data.DataLoader(ds['test'], **dataloader_args)

    return train_dataloader, test_dataloader

**Run the cell below ⬇️ if you'd like to test 💯 your dataset pipeline.**

In [None]:
def test_dataset_pipeline(tokenizer):
    """Print the shapes of the first training batch to verify the dataset pipeline.

    Args:
        tokenizer: Tokenizer forwarded to ``get_dataset``.
    """
    train_loader, _ = get_dataset(tokenizer=tokenizer, batch_size=1, seq_length=32, max_examples=512)
    batch = next(iter(train_loader))
    print("Input tokens shape:", batch['input_tokens'].shape)
    # This line reports `target_tokens`, so it is labelled as tokens, not mask.
    print("Target tokens shape:", batch['target_tokens'].shape)
    print("Loss masks shape:", batch['loss_masks'].shape)


test_dataset_pipeline(tokenizer)

Map: 100%|██████████| 435/435 [00:00<00:00, 934.31 examples/s]
Map: 100%|██████████| 77/77 [00:00<00:00, 904.35 examples/s]


## Step 2: Train the model by configuring the hyperparameters below.

In [None]:
@chex.dataclass(frozen=True)
class TrainingConfig:
    """Hyperparameters for the fine-tuning run (declared as a frozen dataclass)."""

    # Learning rate handed to the optax optimizer.
    learning_rate: float = 1e-4
    # NOTE(review): presumably the number of passes over the training data;
    # consumed by the Trainer — confirm in trainer_engine.training_pipeline.
    num_epochs: int = 1
    # NOTE(review): presumably an upper bound on training steps (None for no
    # bound) — confirm in trainer_engine.training_pipeline.
    max_steps: int | None = 5
    # Examples per batch.
    # NOTE(review): confirm where this value is consumed.
    batch_size: int = 32
    # Passed to get_dataset as `seq_length`.
    seq_length: int = 64
    # Passed to get_dataset as `max_examples`.
    dataset_size_limit: int | None = 512
    # NOTE(review): presumably the logging interval in steps — confirm in
    # trainer_engine.training_pipeline.
    print_every_n_steps: int = 1


# Instantiate the configuration with the defaults above.
training_cfg = TrainingConfig()


**NOTE**: The **time-to-first step of training will be slow** because XLA takes time initially to compile the computational graph. However, once the compilation is complete, subsequent steps will run much faster using the compiled and cached graph, leveraging the full power of all TPU cores for accelerated training.

In [None]:
# Configure mesh
# Enumerate the devices JAX can see.
devices = jax.devices()
device_count = len(devices)
# Arrange the devices in a (1, device_count, 1) grid: every device sits on the
# second axis, and the first and third axes have size 1.
device_mesh = mesh_utils.create_device_mesh((1, device_count, 1))
# Name the three axes; with the shape above, all devices lie along "fsdp".
mesh = Mesh(devices=device_mesh, axis_names=("dp", "fsdp", "mp"))

In [None]:
# NOTE(review): this cell previously held an unfinished `config = ` assignment,
# which is a SyntaxError and breaks "Run All". `config` is already set by
# AutoConfig.from_pretrained in Step 0, so nothing is reassigned here.

In [None]:
# Build the Llama configuration and derive the HF-style pretrained config
# from its model-config dictionary.
llama_config = llama_model.LlamaConfig("llama3_8b")
hf_pretrained_llama_config = llama_config.get_hf_pretrained_config(
    dict(llama_config.get_model_config())
)

# Instantiate the causal LM module with float32 compute and parameter dtypes.
model = llama_model.CausalLlamaModule(
    hf_pretrained_llama_config, dtype=jnp.float32, param_dtype=jnp.float32
)

# Plain SGD using the learning rate from the training configuration.
optimizer = optax.sgd(training_cfg.learning_rate)


In [None]:
# Prepare dataset
# Forward every dataset-related setting from the training configuration.
# batch_size must be passed explicitly; otherwise get_dataset falls back to its
# default of 1 and TrainingConfig.batch_size is silently ignored.
train_dataloader, val_dataloader = get_dataset(
    tokenizer=tokenizer,
    batch_size=training_cfg.batch_size,
    seq_length=training_cfg.seq_length,
    max_examples=training_cfg.dataset_size_limit,
)

In [None]:
# Display the checkpoint path that is handed to the Trainer below.
model_ckpt_path

In [None]:
# Initialize the Trainer
# `state` is only created by a later cell, so on a fresh top-to-bottom run it
# does not exist yet. Reuse its params when the cell is re-run after training;
# otherwise pass None.
# NOTE(review): assumes Trainer accepts model_params=None and then loads the
# weights from model_ckpt_path — confirm in trainer_engine.training_pipeline.
trainer = training_pipeline.Trainer(
    model=model,
    model_ckpt_path=model_ckpt_path,
    model_config=llama_config,
    optimizer=optimizer,
    training_config=training_cfg,
    mesh=mesh,
    model_params=getattr(globals().get("state"), "params", None),
)

In [None]:
# Take the train state held by the Trainer as the starting point for training.
state = trainer.train_state

In [None]:
# Train the model
# Run training over the training DataLoader on the configured mesh and keep
# the returned state.
state = trainer.train(mesh, state, train_dataloader)

In [None]:
# Run the trainer_engine.convert_to_hf entry point with an empty argument list.
# NOTE(review): presumably this converts the trained checkpoint to Hugging Face
# format and writes it to the folder uploaded below — confirm its input/output
# paths in trainer_engine/convert_to_hf.py.
convert_to_hf.main([])

In [None]:
from huggingface_hub import HfApi


In [None]:
# Upload the converted model folder to the Hugging Face Hub.
api = HfApi()
api.upload_folder(
    folder_path="/mnt/persistent-disk/easy/e2hf/",
    repo_id="felafax/llama3.1-8b-easylm-to-hf",
    repo_type="model",
    # Skip files whose names start with a dot.
    ignore_patterns=[".*"],
    # Use the token entered in Step 0 rather than a literal credential.
    # NOTE(review): the token that was previously hard-coded here has been
    # exposed in this notebook and should be revoked.
    token=HUGGINGFACE_TOKEN,
)