In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<a href="https://colab.research.google.com/github/moficodes/ai-on-gke/blob/main/tutorials/finetune-gemma-7b-on-tpu/finetune-gemma-on-gke-using-TPU.ipynb" target="_blank"><img height="40" alt="Run your own notebook in Colab" src = "https://colab.research.google.com/assets/colab-badge.svg"></a>

# Finetune Gemma on GKE using TPU

## Overview

This notebook demonstrates downloading and fine-tuning Gemma, a family of open models from Google DeepMind, using PyTorch and Hugging Face libraries. We finetune a Gemma model and publish it on Hugging Face. This guide uses TPU v4, but it should also work on any TPU version with enough memory.


### Objective

Finetune and publish Gemma with Transformers and LoRA on TPUs.

### TPUs

Tensor Processing Units (TPUs) are Google's custom-developed application-specific integrated circuits (ASICs) used to accelerate machine learning workloads.

Before you use TPUs in GKE, we recommend that you complete the following learning path:

Learn about [TPUs in GKE](https://cloud.google.com/tpu/docs/tpus-in-gke)

## Before you begin

### Configure Environment

Set the following variables for the experiment environment.

In [None]:
# The HuggingFace token used to download models.
# Make sure Token has Write Permission
HF_TOKEN = "<YOUR_HF_TOKEN>"  # @param {type:"string"}

# The size of the model to launch
MODEL_SIZE = "7b"  # @param ["2b", "7b"]

# Cloud project id.
PROJECT_ID = "<YOUR_PROJECT_ID>"  # @param {type:"string"}

# Region for launching clusters.
REGION = "us-central2"  # @param {type:"string"}

LOCATION = "us-central2-a"  # @param {type:"string"}

# The cluster name to create
CLUSTER_NAME = "keras"  # @param {type:"string"}
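Later cells interpolate these variables directly into `gcloud` and `kubectl` commands, so a placeholder left unfilled only surfaces as a failed command. A minimal sketch of an early check follows; `find_unset` is a hypothetical helper that is not part of the original notebook, and the `PROJECT_ID` value is made up for illustration.

```python
# Hypothetical helper (not in the original notebook): flag settings that are
# empty or still hold a "<YOUR_...>" placeholder.
PLACEHOLDER_PREFIX = "<YOUR_"


def find_unset(settings: dict) -> list:
    """Return the names of settings that are empty or still placeholders."""
    return [
        name
        for name, value in settings.items()
        if not value or value.startswith(PLACEHOLDER_PREFIX)
    ]


example = {
    "HF_TOKEN": "<YOUR_HF_TOKEN>",   # still a placeholder
    "PROJECT_ID": "my-project",      # made-up value for illustration
    "REGION": "us-central2",
    "LOCATION": "us-central2-a",
    "CLUSTER_NAME": "keras",
}

print(find_unset(example))
```

In the notebook itself, the same function could be called with the real variables before running any `gcloud` command.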

In [None]:
# When running in public Colab, log in with gcloud auth login first.
! gcloud auth login
! gcloud config set project "$PROJECT_ID"
! gcloud services enable container.googleapis.com

# Add kubectl to the set of available tools.
! mkdir -p /tools/google-cloud-sdk/.install
! gcloud components install kubectl --quiet

### Create a GKE cluster and a node pool

The commands below create the following resources:

- A Standard cluster with a default node pool of one `n1-standard-4` node
- A TPU v4 node pool of four `ct4p-hightpu-4t` nodes in a `2x2x4` topology

If you already have a cluster, skip to `Use an existing GKE cluster` instead.

In [None]:
! gcloud container --project {PROJECT_ID} clusters create {CLUSTER_NAME} \
    --location {LOCATION} \
    --cluster-version "1.29.2-gke.1217000" \
    --release-channel "rapid" \
    --machine-type "n1-standard-4" \
    --num-nodes "1" \
    --node-locations {LOCATION}

! gcloud container --project {PROJECT_ID} node-pools create "tpu" \
    --cluster {CLUSTER_NAME} \
    --location {LOCATION} \
    --node-version "1.29.2-gke.1217000" \
    --machine-type "ct4p-hightpu-4t"  \
    --num-nodes "4" \
    --placement-type=COMPACT \
    --tpu-topology=2x2x4

### Use an existing GKE cluster

In [None]:
! gcloud container clusters \
    get-credentials {CLUSTER_NAME} \
    --location {LOCATION}

### Create Kubernetes secret for Hugging Face credentials

Create a Kubernetes Secret that contains the Hugging Face token.

In [None]:
! kubectl create secret generic hf-secret \
--from-literal=hf_api_token={HF_TOKEN} \
--dry-run=client -o yaml | kubectl apply -f -
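Kubernetes stores Secret values base64-encoded under the `data` field, which is an encoding and not encryption. The sketch below builds a manifest roughly equivalent to what the command above generates, to make that visible. `secret_manifest` is a hypothetical helper and the token value is made up.

```python
import base64


def secret_manifest(name: str, key: str, value: str) -> dict:
    """Build an Opaque Secret manifest with a single base64-encoded entry."""
    encoded = base64.b64encode(value.encode("utf-8")).decode("ascii")
    return {
        "apiVersion": "v1",
        "kind": "Secret",
        "metadata": {"name": name},
        "type": "Opaque",
        "data": {key: encoded},
    }


# "hf_example_token" is a made-up value, not a real token.
manifest = secret_manifest("hf-secret", "hf_api_token", "hf_example_token")
print(manifest["data"]["hf_api_token"])

# Anyone who can read the Secret can decode it again.
print(base64.b64decode(manifest["data"]["hf_api_token"]).decode("utf-8"))
```

Because the value is only encoded, access to the Secret should be restricted to those who are allowed to see the token.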

## The Dataset
We use LoRA to quickly finetune Gemma on the `b-mc2/sql-create-context` dataset.

This dataset has the following structure.

| Answer                                              | Question                                                                      | Context                                                           |
|-----------------------------------------------------|-------------------------------------------------------------------------------|-------------------------------------------------------------------|
| SELECT COUNT(*) FROM head WHERE age > 56            | How many heads of the departments are older than 56 ?                         | CREATE TABLE head (age INTEGER)                                   |
| SELECT name, born_state, age FROM head ORDER BY age | List the name, born state and age of the heads of departments ordered by age. | CREATE TABLE head (name VARCHAR, born_state VARCHAR, age VARCHAR) |

We will finetune the `google/gemma-7b` model to generate SQL queries from questions and context.
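To make the training input concrete, the sketch below formats the first row of the table with the same `Question / Context / Answer` template that the training script in this notebook uses. The `to_text` name is chosen here for illustration.

```python
# Same template string as the transform() function in fsdp.py.
TEMPLATE = "Question: {question}\nContext: {context}\nAnswer: {answer}"


def to_text(row: dict) -> str:
    """Flatten one dataset row into a single training string."""
    return TEMPLATE.format(
        question=row["question"],
        context=row["context"],
        answer=row["answer"],
    )


row = {
    "question": "How many heads of the departments are older than 56 ?",
    "context": "CREATE TABLE head (age INTEGER)",
    "answer": "SELECT COUNT(*) FROM head WHERE age > 56",
}

print(to_text(row))
```

Each row becomes one text sample in which the SQL answer follows the question and the table schema.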

## Finetuning Gemma on GKE using TPU with PyTorch

In this demo we will use PyTorch/XLA and Hugging Face libraries to finetune Gemma. Save the following code in a file named `fsdp.py`.

```python
# Make sure to run the script with the following envs:
#   PJRT_DEVICE=TPU XLA_USE_SPMD=1
import os
import torch
import torch_xla

import torch_xla.core.xla_model as xm

from datasets import load_dataset
from peft import LoraConfig, PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer

import transformers

print("TORCH: ", torch.__version__)
print("TRANSFORMERS: ", transformers.__version__)

# Set up TPU device.
device = xm.xla_device()
model_id = os.getenv("MODEL_ID","google/gemma-7b")
new_model_id = os.getenv("NEW_MODEL_ID","gemma-7b-sql-context")

job_index = os.getenv("JOB_COMPLETION_INDEX")

print("### LOAD TOKENIZER ###")
# Load the pretrained model and tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training


print("### LOAD MODEL ###")
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

print(model)

# Set up PEFT LoRA for fine-tuning.
lora_config = LoraConfig(
    r=8,
    lora_alpha = 16,
    lora_dropout = 0.1,
    bias="none",
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)

print("### LOAD DATASET ###")

limit = int(os.getenv("LIMIT", "5000"))

dataset_name = "b-mc2/sql-create-context"
# Load the dataset and format it for training.
dataset = load_dataset(dataset_name, split="train")
dataset = dataset.shuffle(seed=42).select(range(limit))

def transform(data):
    question = data['question']
    context = data['context']
    answer = data['answer']
    template = "Question: {question}\nContext: {context}\nAnswer: {answer}"
    return {'text': template.format(question=question, context=context, answer=answer)}

print("### TRANSFORM DATASET ###")
dataset = dataset.map(transform)


max_seq_length = 512

# Set up the FSDP config. To enable FSDP via SPMD, set xla_fsdp_v2 to True.
fsdp_config = {"fsdp_transformer_layer_cls_to_wrap": [
        "GemmaDecoderLayer"
    ],
    "xla": True,
    "xla_fsdp_v2": True,
    "xla_fsdp_grad_ckpt": True}

print("### CREATE SFTTRAINER ###")
# Finally, set up the trainer and train the model.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=TrainingArguments(
        per_device_train_batch_size=64,  # This is actually the global batch size for SPMD.
        num_train_epochs=1,
        max_steps=-1,
        output_dir="./output",
        optim="adafactor",
        logging_steps=1,
        dataloader_drop_last = True,  # Required for SPMD.
        fsdp="full_shard",
        fsdp_config=fsdp_config,
    ),
    peft_config=lora_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    packing=True,
)


print("### STARTING TRAINING ###")
trainer.train()
print("### TRAINING ENDED ###")


print("JOB INDEX: ", job_index)

print("### SAVE AND MERGE MODEL WEIGHTS ###")
trainer.save_model(new_model_id)
# Reload the base model in bfloat16 and merge it with the LoRA weights
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.bfloat16,
)

model = PeftModel.from_pretrained(base_model, new_model_id)
model = model.merge_and_unload()

print("### DONE MERGING ###")

if job_index == "0":
    print("### UPLOAD MODEL TO HUGGING FACE ###")
    # model.config.to_json_file("adapter_config.json")
    print(model)
    print(os.listdir(new_model_id))
    model.push_to_hub(repo_id=new_model_id)
    tokenizer.push_to_hub(repo_id=new_model_id)
else:
    print("Model will be uploaded by job 0")

```
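LoRA trains two small matrices per targeted layer instead of the full weight, which is why the adapter is cheap to train and save. A rough sketch of the arithmetic for `r=8` on `q_proj` and `v_proj` follows. The layer shapes and layer count are assumptions about `google/gemma-7b`, not values taken from this notebook; the `print(model)` call in the script shows the actual shapes.

```python
def lora_param_count(r: int, in_features: int, out_features: int) -> int:
    """Parameters added by one LoRA adapter: A is (r x in), B is (out x r)."""
    return r * (in_features + out_features)


# ASSUMED shapes for google/gemma-7b; verify against print(model).
HIDDEN_SIZE = 3072      # assumption: input width of q_proj and v_proj
PROJ_OUT = 4096         # assumption: output width of q_proj and v_proj
NUM_LAYERS = 28         # assumption: number of decoder layers
TARGET_MODULES = ["q_proj", "v_proj"]  # from the LoraConfig above

per_module = lora_param_count(8, HIDDEN_SIZE, PROJ_OUT)
total = per_module * len(TARGET_MODULES) * NUM_LAYERS

print("per module:", per_module)
print("total trainable LoRA parameters:", total)
```

Under these assumptions the adapter has a few million trainable parameters, a small fraction of the base model.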

## Create a Container Manifest with Dockerfile

Use the following `Dockerfile` to create a container image.

```dockerfile
FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20240229

RUN pip install -U git+https://github.com/huggingface/transformers.git
RUN pip install -U git+https://github.com/huggingface/trl.git
RUN pip install -U datasets peft

COPY . .

CMD python fsdp.py
```

In [None]:
DOCKERFILE = """
FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20240229

RUN pip install -U git+https://github.com/huggingface/transformers.git
RUN pip install -U git+https://github.com/huggingface/trl.git
RUN pip install -U datasets peft

COPY . .

CMD python fsdp.py
"""

with open("Dockerfile", "w") as f:
    f.write(DOCKERFILE)

### Containerize the Code with Docker and Cloud Build

Using Cloud Build and the Dockerfile above, we build the image and push it to an Artifact Registry Docker repository.

In [None]:
# Create an Artifact Registry repo
! gcloud artifacts repositories create gemma \
    --project={PROJECT_ID} \
    --repository-format=docker \
    --location=us \
    --description="Gemma Repo"

In [None]:
# Build and push the image using Cloud Build
! gcloud builds submit \
    --tag us-docker.pkg.dev/{PROJECT_ID}/gemma/finetune-gemma-tpu:1.0.1 .

## Run Finetune Job on GKE

Use the following YAML to run the Gemma finetuning Job on GKE. Notice that the manifest defines a Job together with a headless Service. This is because a TPU v4 2x2x4 slice consists of 4 nodes with 4 TPU devices each, connected via high-speed interconnect. We create an Indexed Job and a headless Service so that the instances of the Job can communicate with each other over the network.
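The relationship between the topology string and the node count can be checked with a little arithmetic. This is a minimal sketch with hypothetical helper names; the 4 chips per node corresponds to the `google.com/tpu: 4` request in the manifest.

```python
from math import prod


def chips_in_topology(topology: str) -> int:
    """Multiply the dimensions of a topology string such as '2x2x4'."""
    return prod(int(dim) for dim in topology.split("x"))


def nodes_needed(topology: str, chips_per_node: int) -> int:
    """Number of nodes required to host every chip in the slice."""
    chips = chips_in_topology(topology)
    if chips % chips_per_node != 0:
        raise ValueError("topology is not divisible by chips per node")
    return chips // chips_per_node


print(chips_in_topology("2x2x4"))   # total TPU chips in the slice
print(nodes_needed("2x2x4", 4))     # nodes, and also Job completions
```

The result matches `--num-nodes "4"` for the node pool and `completions: 4` and `parallelism: 4` for the Job.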

In [None]:
K8S_JOB_YAML = f"""
apiVersion: v1
kind: Service
metadata:
  name: headless-svc
spec:
  clusterIP: None
  selector:
    job-name: tpu-job
---
apiVersion: batch/v1
kind: Job
metadata:
  name: tpu-job
spec:
  backoffLimit: 0
  completions: 4
  parallelism: 4
  completionMode: Indexed
  template:
    spec:
      subdomain: headless-svc
      restartPolicy: Never
      nodeSelector:
        cloud.google.com/gke-tpu-accelerator: tpu-v4-podslice
        cloud.google.com/gke-tpu-topology: 2x2x4
      containers:
      - name: tpu-job
        image: us-docker.pkg.dev/{PROJECT_ID}/gemma/finetune-gemma-tpu:1.0.1
        ports:
        - containerPort: 8471 # Default port using which TPU VMs communicate
        - containerPort: 8431 # Port to export TPU runtime metrics, if supported.
        securityContext:
          privileged: true
        resources:
          requests:
            google.com/tpu: 4
          limits:
            google.com/tpu: 4
        env:
        - name: PJRT_DEVICE
          value: "TPU"
        - name: XLA_USE_SPMD
          value: "1"
        - name: XLA_USE_BF16
          value: "1"
        - name: HF_TOKEN
          valueFrom:
            secretKeyRef:
              name: hf-secret
              key: hf_api_token
        - name: NEW_MODEL_ID
          value: gemma-7b-sql-kubecon-eu-2024
        - name: LIMIT
          value: "10000"
        volumeMounts:
        - mountPath: /dev/shm
          name: dshm
      volumes:
      - name: dshm
        emptyDir:
          medium: Memory
"""

with open("finetune.yaml", "w") as f:
    f.write(K8S_JOB_YAML)
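With `completionMode: Indexed`, each Pod gets a hostname of the form `<job-name>-<index>`, and the `subdomain` field combined with the headless Service makes those hostnames resolvable through DNS. The sketch below lists the names the four Pods would get, assuming the `default` namespace and the common `cluster.local` cluster domain; both are assumptions, and `indexed_job_hosts` is a hypothetical helper.

```python
def indexed_job_hosts(
    job_name: str,
    completions: int,
    subdomain: str,
    namespace: str = "default",            # assumption: Job runs in "default"
    cluster_domain: str = "cluster.local",  # assumption: default cluster domain
) -> list:
    """Return the DNS name of every Pod in an Indexed Job."""
    return [
        f"{job_name}-{index}.{subdomain}.{namespace}.svc.{cluster_domain}"
        for index in range(completions)
    ]


for host in indexed_job_hosts("tpu-job", 4, "headless-svc"):
    print(host)
```

The index in each hostname is the same value that the training script reads from `JOB_COMPLETION_INDEX`.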

In [None]:
!kubectl apply -f finetune.yaml

#### Waiting for the container to create

Use the command below to check on the status of the container.

In [None]:
! kubectl get po -l job-name=tpu-job
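When the same command is run with `-o json`, the output can be summarized programmatically. A minimal sketch follows, assuming the JSON has already been parsed into a dictionary; the sample data below is made up and `summarize_phases` is a hypothetical helper.

```python
from collections import Counter


def summarize_phases(pod_list: dict) -> Counter:
    """Count Pods by status.phase in a parsed `kubectl get po -o json` result."""
    return Counter(
        item.get("status", {}).get("phase", "Unknown")
        for item in pod_list.get("items", [])
    )


# Made-up sample shaped like a Kubernetes PodList, for illustration only.
sample = {
    "items": [
        {"metadata": {"name": "tpu-job-0-abcde"}, "status": {"phase": "Running"}},
        {"metadata": {"name": "tpu-job-1-fghij"}, "status": {"phase": "Running"}},
        {"metadata": {"name": "tpu-job-2-klmno"}, "status": {"phase": "Pending"}},
        {"metadata": {"name": "tpu-job-3-pqrst"}, "status": {"phase": "Pending"}},
    ]
}

print(summarize_phases(sample))
```

Training can only make progress once all four Pods report `Running`, since each one drives a part of the TPU slice.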

### View the logs from the running Job

The Job downloads the needed artifacts and runs the finetuning code. This process takes close to 30 minutes.

In [None]:
! kubectl logs -f job/tpu-job

## Find the model on Hugging Face

If the Job ran successfully, you can now find the model on your Hugging Face profile.
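Once the model is published, it can be prompted with the same template used for training, leaving the answer empty for the model to complete. The following is a sketch, not a tested recipe: `generate_sql` and `build_prompt` are hypothetical helpers, the repository id is a placeholder to replace with your own, and loading a 7B model requires substantial memory.

```python
def build_prompt(question: str, context: str) -> str:
    """Format a prompt like the training text, with the answer left open."""
    return f"Question: {question}\nContext: {context}\nAnswer:"


def generate_sql(repo_id: str, question: str, context: str,
                 max_new_tokens: int = 64) -> str:
    """Load the published model and generate a completion for one prompt."""
    # Imported lazily so build_prompt can be used without transformers installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id)
    inputs = tokenizer(build_prompt(question, context), return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


print(build_prompt(
    "How many heads of the departments are older than 56 ?",
    "CREATE TABLE head (age INTEGER)",
))

# Example call (placeholder repository id, not executed here):
# generate_sql("<YOUR_HF_USERNAME>/gemma-7b-sql-kubecon-eu-2024", question, context)
```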

## Clean up resources

In [None]:
! kubectl delete job tpu-job
! kubectl delete svc headless-svc

In [None]:
! kubectl delete secrets hf-secret

In [None]:
! gcloud container clusters delete {CLUSTER_NAME} \
  --location={LOCATION} \
  --quiet