<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/Gemma_Finetuning_TPU_COLAB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-Tune Gemma on Colab TPU
This notebook fine-tunes Gemma models on single-host TPUs. For information on TPU architecture, consult the [documentation](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm).


In [None]:
!git clone https://github.com/huggingface/optimum-tpu.git

# Install Optimum TPU
%cd optimum-tpu
!pip install -e . -f https://storage.googleapis.com/libtpu-releases/index.html -q

# Install TRL and PEFT for training (see later how they are used)
!pip install trl peft -q

# This will be necessary for the language modeling example
!pip install datasets evaluate accelerate -q

Next, load the tokenizer and model. This example uses [`google/gemma-2b`](https://huggingface.co/google/gemma-2b).

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Reuse the EOS token as the padding token
tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

To tune the model with the [Abirate/english_quotes](https://huggingface.co/datasets/Abirate/english_quotes) dataset, load it and build `prompt`, `completion`, and `text` columns from its `quote` column:


In [None]:
from datasets import load_dataset


data = load_dataset("Abirate/english_quotes")

def preprocess_function(samples):
    # Create prompt, completion, and combined text columns
    prompts = ["Generate a quote:\n\n" for _ in samples["quote"]]
    completions = [f"{quote}{tokenizer.eos_token}" for quote in samples["quote"]]
    texts = [p + c for p, c in zip(prompts, completions)]
    return {"prompt": prompts, "completion": completions, "text": texts}

# data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
data = data.map(
    preprocess_function,
    batched=True,
    remove_columns=data["train"].column_names
)
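
To make the resulting columns concrete, here is a minimal sketch that applies the same logic to a toy batch without loading the tokenizer or the dataset. The `EOS_TOKEN` string and the two sample quotes are stand-ins chosen for illustration; in the notebook the end-of-sequence string comes from `tokenizer.eos_token`.

```python
# Assumption: "<eos>" stands in for tokenizer.eos_token so the sketch runs
# without downloading a tokenizer.
EOS_TOKEN = "<eos>"


def preprocess_batch(samples, eos_token=EOS_TOKEN):
    """Mirror of preprocess_function, with the EOS string passed explicitly."""
    prompts = ["Generate a quote:\n\n" for _ in samples["quote"]]
    completions = [f"{quote}{eos_token}" for quote in samples["quote"]]
    texts = [p + c for p, c in zip(prompts, completions)]
    return {"prompt": prompts, "completion": completions, "text": texts}


# Made-up rows in the batched layout that datasets.map(batched=True) passes in:
# one dictionary whose values are lists.
toy_batch = {"quote": ["Be yourself.", "So it goes."]}
result = preprocess_batch(toy_batch)

for column, values in result.items():
    print(column, repr(values[0]))
```

Every output column has one entry per input quote, and `text` is always the prompt followed by the completion.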

You then need to specify the FSDP training arguments to enable sharding. The `get_fsdp_training_args` function deduces from the model which classes should be sharded:


In [4]:
from optimum.tpu import fsdp_v2

fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)

The `fsdp_training_args` dictionary specifies the PyTorch module class to shard:

In [5]:
fsdp_training_args

{'fsdp': 'full_shard',
 'fsdp_config': {'transformer_layer_cls_to_wrap': ['GemmaDecoderLayer'],
  'xla': True,
  'xla_fsdp_v2': True,
  'xla_fsdp_grad_ckpt': True}}
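
Because the result is a plain dictionary, it can be unpacked with `**` into a training configuration. Python raises a `TypeError` when the same keyword is passed both explicitly and through `**`, so it can help to merge the dictionaries first and check for collisions. The sketch below uses only the dictionary shown above; `build_config_kwargs` and the `base` values are hypothetical helpers for illustration, not part of Optimum TPU.

```python
# Dictionary copied from the cell output above.
fsdp_training_args = {
    "fsdp": "full_shard",
    "fsdp_config": {
        "transformer_layer_cls_to_wrap": ["GemmaDecoderLayer"],
        "xla": True,
        "xla_fsdp_v2": True,
        "xla_fsdp_grad_ckpt": True,
    },
}


def build_config_kwargs(base_kwargs, extra_kwargs):
    """Hypothetical helper: merge two keyword dicts and refuse duplicate keys."""
    duplicates = sorted(set(base_kwargs) & set(extra_kwargs))
    if duplicates:
        raise ValueError(f"Keys passed twice: {duplicates}")
    return {**base_kwargs, **extra_kwargs}


# Illustrative base arguments; the real ones are set in the training cell.
base = {"output_dir": "/tmp/output", "dataloader_drop_last": True}
config_kwargs = build_config_kwargs(base, fsdp_training_args)
print(sorted(config_kwargs))
```

The merged dictionary could then be passed as `SomeConfig(**config_kwargs)`, which avoids deleting keys from a copy by hand.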

Training can now be set up much like with the standard `Trainer` class. First, define the LoRA configuration:

In [6]:
from peft import LoraConfig

# Set up PEFT LoRA for fine-tuning.
lora_config = LoraConfig(
    r=8,
    target_modules=["k_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
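
To get a feel for how small this adapter is, the following sketch counts the trainable LoRA parameters for `r=8` on `k_proj` and `v_proj`. The model dimensions are assumptions taken from the published Gemma 2B configuration (hidden size 2048, 18 decoder layers, one key/value head of dimension 256); they are not stated in this notebook, so check them against `model.config` before relying on the result. Bias terms and other LoRA options are ignored.

```python
def lora_param_count(in_features, out_features, rank):
    """LoRA adds A (rank x in_features) and B (out_features x rank) per layer."""
    return rank * (in_features + out_features)


# Assumed Gemma 2B dimensions (not stated in this notebook).
HIDDEN_SIZE = 2048          # input width of k_proj and v_proj
KV_PROJ_OUT = 1 * 256       # num_key_value_heads * head_dim
NUM_LAYERS = 18             # decoder layers
RANK = 8                    # r from lora_config
TARGETS_PER_LAYER = 2       # k_proj and v_proj

per_module = lora_param_count(HIDDEN_SIZE, KV_PROJ_OUT, RANK)
total = per_module * TARGETS_PER_LAYER * NUM_LAYERS
print(f"LoRA parameters per projection: {per_module:,}")
print(f"Total trainable LoRA parameters: {total:,}")
```

Under these assumptions the adapter holds well under a million trainable weights, a tiny fraction of the base model.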

In [8]:
import jax

print(f"Number of available JAX devices (TPU cores): {jax.device_count()}")

Number of available JAX devices (TPU cores): 8


In [9]:
import torch_xla.runtime as xr

# world_size() counts participating processes, so a single-process notebook reports 1
num_devices = xr.world_size()
print(f"Number of available devices (TPU): {num_devices}")

Number of available devices (TPU): 1


In [12]:
from transformers import TrainingArguments, AutoModelForCausalLM, AutoTokenizer, AutoConfig
from trl import SFTTrainer, SFTConfig
from peft import PeftModel
import torch_xla.runtime as xr
import torch

import warnings
warnings.filterwarnings("ignore")


def formatting_func(examples):
    # The 'text' column contains the combined prompt and completion
    return examples["text"]

# Create a copy to avoid modifying the original dictionary
fsdp_training_args_copy = fsdp_training_args.copy()

# Extract fsdp_config from the copy
fsdp_config_value = fsdp_training_args_copy.get('fsdp_config', None)

# Remove keys that are explicitly passed to SFTConfig from the copy
if 'fsdp' in fsdp_training_args_copy:
    del fsdp_training_args_copy['fsdp']
if 'fsdp_config' in fsdp_training_args_copy:
    del fsdp_training_args_copy['fsdp_config']

# Reload the model and tokenizer to ensure a clean instance
print(f"Reloading model and tokenizer: {model_id}")
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})

# Load model configuration and set use_cache=False to avoid warning
config = AutoConfig.from_pretrained(model_id)
config.use_cache = False

model = AutoModelForCausalLM.from_pretrained(model_id, config=config, torch_dtype=torch.bfloat16)


# Check if the model is already a PeftModel and unload if necessary (should not be needed after reload, but kept as a safeguard)
if isinstance(model, PeftModel):
    print("Unloading existing PEFT adapters...")
    model = model.unload()

# Enable SPMD for torch_xla
xr.use_spmd()


sft_config = SFTConfig(
    per_device_train_batch_size=8,
    num_train_epochs=1,
    max_steps=-1,  # Negative value means no step limit; num_train_epochs decides
    output_dir="/tmp/output",
    optim="adafactor",
    logging_steps=50,
    dataloader_drop_last=True,  # Required by FSDP v2
    completion_only_loss=False, # Disable completion_only_loss
    fsdp="full_shard", # Explicitly set fsdp to "full_shard"
    fsdp_config=fsdp_config_value, # Pass fsdp_config dictionary directly
    **fsdp_training_args_copy, # Unpack remaining args from the copy
)


trainer = SFTTrainer(
    model=model,
    train_dataset=data["train"],
    args=sft_config, # Pass the SFTConfig object here
    peft_config=lora_config,
    #formatting_func=formatting_func, # Remove formatting_func when using SFTConfig
    #max_seq_length=512,
    #packing=True,
)

trainer.train()

Reloading model and tokenizer: google/gemma-2b



| Step | Training Loss |
|------|---------------|
| 50   | 2.72          |
| 100  | 2.58          |
| 150  | 2.47          |
| 200  | 2.62          |
| 250  | 2.6           |
| 300  | 2.64          |


TrainOutput(global_step=313, training_loss=2.603035143769968, metrics={'train_runtime': 4681.4581, 'train_samples_per_second': 0.535, 'train_steps_per_second': 0.067, 'total_flos': 4987841498185728.0, 'train_loss': 2.603035143769968})
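
The throughput figures in `TrainOutput` can be cross-checked with a little arithmetic. The sketch below uses only numbers from the output above and from `sft_config`, and it assumes the global batch size equals `per_device_train_batch_size` (a single training process), which is consistent with the reported values.

```python
# Values copied from the TrainOutput above and from sft_config.
global_step = 313
train_runtime = 4681.4581            # seconds
per_device_train_batch_size = 8      # assumed to be the global batch size

samples_seen = global_step * per_device_train_batch_size
steps_per_second = global_step / train_runtime
samples_per_second = samples_seen / train_runtime
seconds_per_step = train_runtime / global_step

print(f"samples seen:       {samples_seen}")
print(f"steps per second:   {steps_per_second:.3f}")
print(f"samples per second: {samples_per_second:.3f}")
print(f"seconds per step:   {seconds_per_step:.1f}")
```

The rounded rates match `train_steps_per_second` and `train_samples_per_second` in the output, and the run averages roughly 15 seconds per optimizer step.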