# Finetune Gemma-7B on Vertex AI


[Gemma-7b](https://huggingface.co/google/gemma-7b) is a state-of-the-art open model from Google, built from the same research and technology used to create the Gemini models. It is a text-to-text, decoder-only large language model, available in English, with open weights, and is well-suited for a variety of text generation tasks, including question answering, summarization, and reasoning. Learn more about it in [Welcome Gemma - Google’s new open LLM](https://huggingface.co/blog/gemma).

In this tutorial you will learn how to finetune [google/gemma-7b](https://huggingface.co/google/gemma-7b) on Vertex AI. 


What you'll learn in this tutorial:

1. [Setup development environment](#1-setup-development-environment)
2. [Load the dataset](#2-load-the-dataset)
3. [Fine-tune Gemma-7b using `trl` and `SFTTrainer`](#3-fine-tune-gemma-7b-using-trl-and-sfttrainer)
4. [Inference with Fine-tuned Model](#4-inference-with-fine-tuned-model)

## 1. Setup development environment


In this example, we use a Vertex AI Workbench instance with an A100 GPU and the [Hugging Face Deep Learning Containers](https://cloud.google.com/deep-learning-containers/docs/choosing-container#hugging-face). The Hugging Face PyTorch DLC comes with all important libraries, like Transformers, Datasets, PEFT, and TRL, pre-installed. This makes it easy to get started, since there is no need for environment management. You can find all Hugging Face containers on [Google Cloud](https://cloud.google.com/deep-learning-containers/docs/choosing-container#hugging-face).


**ToDo**: Add info on how to spin-up a workbench instance or small intro about Vertex AI Workbench Instance.

**ToDo**: Update the link for the image once GPU containers are released.


Once the instance is up and running, we can access a Jupyter environment, which we can use for preparing our dataset and launching the training.

## 2. Load the dataset 

We use the [timdettmers/openassistant-guanaco](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) dataset, which is a refined part of the [OpenAssistant dataset](https://huggingface.co/datasets/OpenAssistant/oasst1) designed specifically to train versatile chatbots. The dataset contains various questions that require generative outputs.

Each example consists of a question along with its answer. The dataset is also multilingual, with questions in languages such as English and Spanish. It contains about 9.85K training instances and 518 test instances. An example from the dataset:

```
### Human: Can you write a joke with the following setup? A penguin and a walrus walk into a bar### Assistant: A penguin and a walrus walk into a bar. The bartender looks up and says, "What is this, some kind of Arctic joke?" The penguin and walrus just look at each other, confused. Then the walrus shrugs and says, "I don't know about Arctic jokes, but we sure know how to break the ice!" The penguin rolls his eyes, but can't help but chuckle.
```
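Since each record stores the whole conversation as one string, it can help to see how the turns are separated. The following is a minimal sketch that splits such a string on the `### Human:` and `### Assistant:` markers shown in the sample; the helper name `split_turns` and the shortened sample string are made up for illustration and are not part of the dataset or any library.

```python
import re

# Role markers as they appear in the sample above (optional space after "###").
TURN_PATTERN = re.compile(r"###\s*(Human|Assistant):")

def split_turns(text):
    """Return a list of (role, message) tuples in order of appearance."""
    parts = TURN_PATTERN.split(text)
    # With one capture group, parts looks like:
    # [prefix, role1, message1, role2, message2, ...]
    roles = parts[1::2]
    messages = parts[2::2]
    return [(role, message.strip()) for role, message in zip(roles, messages)]

# Shortened, made-up sample in the same format as the dataset.
sample = "### Human: Can you write a joke?### Assistant: Sure, here is one."
for role, message in split_turns(sample):
    print(f"{role}: {message}")
```

Training does not require this split, because `SFTTrainer` consumes the raw `text` field, but it is a convenient way to inspect individual turns.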


To load the `timdettmers/openassistant-guanaco` dataset, we use the `load_dataset()` function from the 🤗 Datasets library.

```python
# Import the function for loading datasets
from datasets import load_dataset

# Specify the name of the dataset
dataset_name = "timdettmers/openassistant-guanaco"

# Load the dataset and select the "train" split
dataset = load_dataset(dataset_name, split="train")

# Shuffle and keep only 2500 examples for faster training
dataset = dataset.shuffle(seed=42).select(range(2500))
```

## 3. Fine-tune Gemma-7b using `trl` and `SFTTrainer`

We will use the [SFTTrainer](https://huggingface.co/docs/trl/en/sft_trainer) from 🤗 `trl` to fine-tune our model. The `SFTTrainer` is built on top of the 🤗 Transformers `Trainer` and inherits all its core functionality, such as logging, evaluation, and checkpointing, and adds enhancements like:

- Packing datasets for more efficient training
- PEFT (parameter-efficient fine-tuning) support including Q-LoRA
- Preparing the model and tokenizer for conversational fine-tuning (e.g. adding special tokens)

You can read more about it in the [trl docs](https://huggingface.co/docs/trl/en/sft_trainer).
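To make the idea of packing more concrete, here is a simplified sketch: tokenized examples are concatenated into one stream, separated by an end-of-sequence token, and then cut into blocks of equal length so that no compute is spent on padding. This is only an illustration of the concept, not the implementation used by `trl`; the function name, the token ids, and the EOS id are made up.

```python
def pack_sequences(tokenized_examples, block_size, eos_token_id):
    """Concatenate examples (each followed by EOS) and cut into equal blocks."""
    stream = []
    for ids in tokenized_examples:
        stream.extend(ids)
        stream.append(eos_token_id)
    # Drop the incomplete tail so every block has exactly block_size tokens.
    usable = (len(stream) // block_size) * block_size
    return [stream[i:i + block_size] for i in range(0, usable, block_size)]

examples = [[5, 6, 7], [8, 9], [10, 11, 12, 13]]  # made-up token ids
blocks = pack_sequences(examples, block_size=4, eos_token_id=1)
print(blocks)
```

Note that a block may contain the end of one example and the start of the next, which is the trade-off packing makes in exchange for fully used sequences.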


LLMs are large, and running or training them on consumer hardware is a major challenge for users and accessibility. Therefore, we are going to use [QLoRA](https://arxiv.org/abs/2305.14314), a technique that reduces the memory footprint of LLMs during fine-tuning without sacrificing performance. How it works:

- Quantize the pretrained model to 4 bits and freeze it.
- Attach small, trainable adapter layers (LoRA).
- Fine-tune only the adapter layers, while the frozen quantized model supplies the base weights for the forward and backward pass.
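A quick back-of-the-envelope calculation shows why quantizing the frozen base model matters. The sketch below estimates the memory needed for the weights alone at different precisions. The parameter count of `7e9` is a nominal figure assumed for illustration, not the exact size of the model, and the estimate ignores activations, optimizer state, and quantization overhead.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Memory needed to store the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

NUM_PARAMS = 7e9  # nominal figure used for illustration, not the exact count

for label, bits in [("bf16", 16), ("int8", 8), ("4-bit", 4)]:
    print(f"{label:>5}: {weight_memory_gb(NUM_PARAMS, bits):.1f} GB")
```

Under these assumptions, going from 16-bit to 4-bit weights cuts the weight memory to a quarter, which leaves room on the GPU for the adapters, gradients, and activations.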

To further improve training efficiency, we combine `QLoRA` with `Flash Attention 2`, a high-performance attention implementation that is integrated with Transformers. It can be up to 3x faster than the standard attention mechanism.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer
```

```python
# Hugging Face model id
model_id = "google/gemma-7b"
```

```python
from transformers import BitsAndBytesConfig

config = BitsAndBytesConfig(
    load_in_4bit=True,               # quantize the model to 4 bits when loading it
    bnb_4bit_quant_type="nf4",       # NF4: a 4-bit data type designed for normally distributed weights
    bnb_4bit_use_double_quant=True,  # nested quantization: also quantize the quantization constants
    # Use torch.float16 on GPUs without bfloat16 support (T4, V100).
    # Converting bfloat16 to float16 may overflow; the opposite may lose precision.
    bnb_4bit_compute_dtype=torch.bfloat16,
)
```
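The comment in the configuration above says to fall back to float16 on GPUs without bfloat16 support. One way to encode that decision is sketched below. The helper `choose_compute_dtype` is hypothetical and takes plain booleans so that the logic can be read and tested without a GPU; the float32 branch for machines without CUDA is an added assumption, not something the tutorial prescribes.

```python
def choose_compute_dtype(cuda_available, bf16_supported):
    """Return the name of a compute dtype for the given hardware capabilities."""
    if not cuda_available:
        return "float32"   # assumed CPU fallback
    if bf16_supported:
        return "bfloat16"  # e.g. A100
    return "float16"       # e.g. T4, V100

print(choose_compute_dtype(True, True))
print(choose_compute_dtype(True, False))
```

With PyTorch installed, the two flags can be obtained from `torch.cuda.is_available()` and `torch.cuda.is_bf16_supported()`, and the returned name can be turned into a dtype with `getattr(torch, name)`.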

To use `google/gemma-7b` you need a Hugging Face Hub token, so make sure to log in first:

```bash
huggingface-cli login # The easiest way to authenticate and it saves the token on your machine. 
```

Other ways to authenticate are described in the [docs](https://huggingface.co/docs/huggingface_hub/en/quick-start).
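Inside a notebook it can be more convenient to authenticate from Python than from the shell. The sketch below reads a token from an environment variable and passes it to `huggingface_hub.login`. The helper `login_from_env` and the choice of `HF_TOKEN` as the default variable name are assumptions made for this example; adapt them to however you store your token.

```python
import os

def login_from_env(var_name="HF_TOKEN"):
    """Log in to the Hugging Face Hub if the given environment variable is set.

    Returns True if a login was attempted, False if no token was found.
    """
    token = os.environ.get(var_name)
    if not token:
        return False
    # Imported lazily so the helper can be defined without the library installed.
    from huggingface_hub import login
    login(token=token)
    return True
```

Calling `login_from_env()` in a cell before loading the model has the same effect as the CLI login, provided the variable holds a valid token.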

```python
# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=config,
    attn_implementation="flash_attention_2",  # use flash-attention-2 for faster training
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = "right"  # to prevent warnings
```

To use QLoRA with `SFTTrainer`, we create a `LoraConfig` and pass it as an argument to the `SFTTrainer`.

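```python
from peft import LoraConfig

peft_config = LoraConfig(
    task_type="CAUSAL_LM",        # must match the PEFT task type name exactly
    target_modules="all-linear",  # attach adapters to all linear layers
    inference_mode=False,         # we want to train the adapters
    r=8,                          # rank of the adapter matrices
    lora_alpha=16,                # scaling factor numerator (scale = lora_alpha / r)
    lora_dropout=0.05,            # dropout applied to the adapter inputs
)
```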
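To get a feeling for what `r` and `lora_alpha` mean, the following sketch counts the trainable parameters that LoRA adds to a single linear layer and computes the scaling factor applied to the adapter output. The layer size of 3072 x 3072 is an illustrative assumption, not a value read from the model config.

```python
def lora_param_count(d_in, d_out, r):
    """Trainable parameters LoRA adds to one d_in x d_out linear layer."""
    # LoRA learns two low-rank matrices: A with shape (r, d_in) and B with shape (d_out, r).
    return r * (d_in + d_out)

d_in, d_out = 3072, 3072  # illustrative layer size, assumed for this example
r, lora_alpha = 8, 16     # values from the LoraConfig above

full = d_in * d_out
added = lora_param_count(d_in, d_out, r)
print(f"frozen weights: {full:,}")
print(f"LoRA weights:   {added:,} ({100 * added / full:.2f}% of the layer)")
print(f"scaling factor: {lora_alpha / r}")
```

For a layer of this size, the adapter is well under one percent of the frozen weights, which is why only a small fraction of the model needs gradients and optimizer state.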

Before we can start training, we need to define the hyperparameters (`TrainingArguments`) we want to use.

```python
training_args = TrainingArguments(
    output_dir="output",              # directory to save the trained model
    num_train_epochs=2,               # number of training epochs
    learning_rate=2e-4,               # learning rate for training
    max_grad_norm=0.3,                # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",     # use constant learning rate scheduler
    optim="paged_adamw_8bit",         # optimizer for training
    per_device_train_batch_size=1,    # batch size per device during training
    gradient_accumulation_steps=4,    # number of steps to accumulate gradients before updating the model
    logging_steps=100,                # log every 100 steps
    # Use fp16=True instead on GPUs without bfloat16 support (T4, V100).
    # Converting bfloat16 to float16 may overflow; the opposite may lose precision.
    bf16=True,
)
```
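These settings determine how many optimizer updates the run performs. The sketch below works through the arithmetic for the 2500 selected examples, assuming a single GPU and no packing (one row per training example). The helper name `optimizer_steps` is made up, and the exact rounding inside `Trainer` may differ when the numbers do not divide evenly.

```python
import math

def optimizer_steps(num_examples, per_device_batch_size, grad_accum_steps,
                    num_epochs, num_devices=1):
    """Return (effective_batch_size, steps_per_epoch, total_steps)."""
    effective_batch = per_device_batch_size * grad_accum_steps * num_devices
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    return effective_batch, steps_per_epoch, steps_per_epoch * num_epochs

# Values from the dataset selection and TrainingArguments above.
effective, per_epoch, total = optimizer_steps(
    num_examples=2500, per_device_batch_size=1, grad_accum_steps=4, num_epochs=2
)
print(f"effective batch size: {effective}")
print(f"steps per epoch:      {per_epoch}")
print(f"total steps:          {total}")
```

With `logging_steps=100`, a run of this length reports the training loss roughly a dozen times.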


```python
# Initialize the trl SFTTrainer
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",  # field that contains the text in the dataset
    args=training_args,
    peft_config=peft_config,
)
```

```python
# Start training
trainer.train()

# Save the adapter weights to output_dir
trainer.save_model()
```

## 4. Inference with Fine-tuned Model

Once fine-tuning is done, we want to run inference with the fine-tuned model. We use some prompts from the original dataset and look at the text the fine-tuned model generates.

```python
# Free the GPU memory used during training
del model
del trainer
torch.cuda.empty_cache()
```

```python
from peft import PeftModel

# Load the base model, then attach the fine-tuned adapter saved in "output"
device = "cuda"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = PeftModel.from_pretrained(model, "output").to(device)
```

Select some prompts for text generation and see how the model performs.

```python
prompts = [
    "### Human: Explain in layman's terms what does options trading mean?",
    "### Human: Was kannst Du im Vergleich zu anderen Large Language Models?",
]
```


```python
def test_inference(prompt):
    # Tokenize the prompt and move the tensors to the GPU
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.autocast(device):
        outputs = model.generate(**inputs, max_new_tokens=50)
    # Decode and strip the prompt so only the newly generated text is returned
    return tokenizer.decode(outputs[0], skip_special_tokens=True)[len(prompt):].strip()
```

```python
for prompt in prompts:
    print(f"    prompt:\n{prompt}")
    print(f"    response:\n{test_inference(prompt)}")
    print("-" * 50)
```
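Because the model was fine-tuned on text in the `### Human: ...### Assistant: ...` format, it may help to end the prompt with the assistant marker and to cut the output at the next human turn. The sketch below shows one way to do this with plain string handling. Both helper names and the example generation string are made up for illustration, and whether this improves the responses depends on the fine-tuned model.

```python
def build_prompt(question):
    """Wrap a question in the '### Human: ...### Assistant:' format."""
    return f"### Human: {question}### Assistant:"

def first_reply(generated_text):
    """Keep only the text up to the next '### Human:' marker, if any."""
    return generated_text.split("### Human:")[0].strip()

prompt = build_prompt("What is the capital of France?")
# Made-up continuation standing in for real model output.
fake_generation = " Paris.### Human: And of Italy?### Assistant: Rome."
print(prompt)
print(first_reply(fake_generation))
```

In the notebook, `build_prompt` would be applied to the question before calling `test_inference`, and `first_reply` to the string it returns.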