# Fine-tuning Gemma with ChatML and HuggingFace TRL

Google Gemma comes in two sizes:
* 7B, for efficient deployment and development on consumer-size GPU and TPU,
* 2B, for CPU and on-device applications.

Both come in base and instruction-tuned variants.

In this example, we will use the Hugging Face TRL, Transformers, and Datasets libraries.

In [None]:
# Install PyTorch & other libraries
!pip install "torch==2.1.2" tensorboard

# Install Hugging Face libraries
!pip install --upgrade \
  "transformers==4.38.2" \
  "datasets==2.16.1" \
  "accelerate==0.26.1" \
  "evaluate==0.4.1" \
  "bitsandbytes==0.42.0" \
  "trl==0.7.11" \
  "peft==0.8.2"

If we have a GPU with the Ampere architecture (e.g. NVIDIA A10G or RTX 3090) or newer (e.g. RTX 4090), we can use Flash Attention. Flash Attention is a method that reorders the attention computation and leverages classical techniques (tiling, recomputation) to significantly speed it up and reduce memory usage from quadratic to linear in sequence length.
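
To make the tiling idea concrete, here is a simplified NumPy sketch. It is not the actual fused CUDA kernel and leaves out the recomputation part: keys and values are processed block by block with an "online" softmax that rescales the partial results whenever a larger score appears, so only an n × block_size slice of the score matrix exists at any time. The function names are illustrative.

```python
import numpy as np

def naive_attention(q, k, v):
    """Reference attention: materializes the full (n x n) score matrix."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def tiled_attention(q, k, v, block_size=4):
    """Same result, but keys/values are visited in blocks with an online softmax."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros((n, v.shape[-1]))
    row_max = np.full((n, 1), -np.inf)  # running max score per query row
    row_sum = np.zeros((n, 1))          # running softmax denominator
    for start in range(0, k.shape[0], block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        s = q @ k_blk.T * scale                          # (n, block) scores only
        new_max = np.maximum(row_max, s.max(axis=-1, keepdims=True))
        correction = np.exp(row_max - new_max)           # rescale earlier partial sums
        p = np.exp(s - new_max)
        row_sum = row_sum * correction + p.sum(axis=-1, keepdims=True)
        out = out * correction + p @ v_blk
        row_max = new_max
    return out / row_sum

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((10, 8)) for _ in range(3))
print(np.allclose(naive_attention(q, k, v), tiled_attention(q, k, v, block_size=3)))  # True
```

Both functions agree up to floating-point error. The real kernel additionally tiles the queries and keeps the blocks in fast on-chip memory, which is where the speedup comes from.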

In [None]:
import torch
assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'
# install flash-attn
!pip install ninja packaging
!MAX_JOBS=4 pip install flash-attn --no-build-isolation --upgrade
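
The `assert` above stops the notebook on older GPUs. If you would rather keep going, one option (an assumption, not part of the original setup) is to fall back to PyTorch's scaled-dot-product attention, which Transformers exposes as `attn_implementation='sdpa'`. The helper name below is hypothetical:

```python
def pick_attn_implementation(capability):
    """Choose an `attn_implementation` value for `from_pretrained`.

    `capability` is a (major, minor) tuple like the one returned by
    torch.cuda.get_device_capability(); Ampere and newer report major >= 8.
    """
    major, _minor = capability
    return "flash_attention_2" if major >= 8 else "sdpa"

print(pick_attn_implementation((8, 6)))  # flash_attention_2
print(pick_attn_implementation((7, 5)))  # sdpa

# In the notebook (requires a CUDA device):
# attn_impl = pick_attn_implementation(torch.cuda.get_device_capability())
```

You would then pass `attn_impl` instead of the hard-coded `'flash_attention_2'` when loading the model.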

In [None]:
from huggingface_hub import notebook_login
notebook_login()

## Create and prepare the dataset

We will use the Databricks Dolly dataset, converted to the OpenAI-style messages format, which means that we can use the `conversational` format to fine-tune our model:
```python
{'messages': [{'role': 'system', 'content': 'You are...'}, {'role': 'user', 'content': "..."}, {'role': 'assistant', 'content': "..."}]}
{'messages': [{'role': 'system', 'content': 'You are...'}, {'role': 'user', 'content': "..."}, {'role': 'assistant', 'content': "..."}]}
{'messages': [{'role': 'system', 'content': 'You are...'}, {'role': 'user', 'content': "..."}, {'role': 'assistant', 'content': "..."}]}
```
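
Before training, it can help to check that every record really has this shape. A minimal sketch with a hypothetical helper, `is_conversational`; requiring the last turn to come from the assistant is our own assumption about what a useful training example looks like:

```python
VALID_ROLES = {"system", "user", "assistant"}

def is_conversational(example):
    """Return True if `example` looks like {'messages': [{'role': ..., 'content': ...}, ...]}."""
    messages = example.get("messages")
    if not isinstance(messages, list) or not messages:
        return False
    for message in messages:
        if not isinstance(message, dict):
            return False
        if message.get("role") not in VALID_ROLES or not isinstance(message.get("content"), str):
            return False
    # assumption: the final turn is the assistant answer the model should learn
    return messages[-1]["role"] == "assistant"

sample = {"messages": [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]}
print(is_conversational(sample))  # True

# With the Hugging Face dataset loaded below: dataset = dataset.filter(is_conversational)
```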

In [3]:
from datasets import load_dataset

# load dolly dataset
dataset = load_dataset(
    'philschmid/dolly-15k-oai-style',
    split='train'
)
dataset

Dataset({
    features: ['messages'],
    num_rows: 15011
})

In [8]:
dataset[0]

{'messages': [{'content': "When did Virgin Australia start operating?\nVirgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.",
   'role': 'user'},
  {'content': 'Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.',
   'role': 'assistant'}]}

## Fine-tune LLM using `trl` and the `SFTTrainer`

We will use the `SFTTrainer` from `trl` to fine-tune the model. The `SFTTrainer` makes it straightforward to run supervised fine-tuning (SFT) on open-source LLMs. It is a subclass of the `Trainer` from the `transformers` library and supports all the same features, including logging, evaluation and checkpointing, and adds several conveniences, including:
* Dataset formatting, including conversational and instruction format
* Training on completions only, ignoring prompts
* Packing datasets for more efficient training
* PEFT (parameter-efficient fine-tuning) support, including QLoRA
* Preparing the model and tokenizer for conversational fine-tuning (e.g., adding special tokens)
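
Training on completions only usually means the loss is computed just on the assistant's tokens: prompt positions get the label `-100`, which PyTorch's `CrossEntropyLoss` ignores by default. This notebook does not enable it, but the masking idea can be sketched in plain Python (hypothetical helper and illustrative token ids; TRL ships its own collator for this, `DataCollatorForCompletionOnlyLM`):

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def mask_prompt_labels(input_ids, response_start):
    """Copy input_ids into labels, hiding every token before `response_start`
    so that only the completion contributes to the loss."""
    labels = list(input_ids)
    for i in range(min(response_start, len(labels))):
        labels[i] = IGNORE_INDEX
    return labels

print(mask_prompt_labels([2, 10, 11, 12, 20, 21, 1], response_start=4))
# [-100, -100, -100, -100, 20, 21, 1]
```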

In this example, we will use dataset formatting, packing and PEFT. As the PEFT method, we will use QLoRA, which reduces the memory footprint of large language models during fine-tuning by quantizing the frozen base model to 4 bits, without sacrificing performance.
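
A back-of-the-envelope calculation shows where the savings come from. The sketch below counts only the memory for the weights themselves (no activations, optimizer states, LoRA adapters, or the small per-block overhead of the quantization constants) and uses a nominal 7 billion parameters; the exact count of a given checkpoint differs:

```python
def weight_memory_gb(num_params, bits_per_param):
    """Gigabytes (1e9 bytes) needed to store the weights alone."""
    return num_params * bits_per_param / 8 / 1e9

num_params = 7e9  # nominal size; the real parameter count of a checkpoint differs
print(f"16-bit (bf16): {weight_memory_gb(num_params, 16):.1f} GB")  # 14.0 GB
print(f"4-bit (NF4):   {weight_memory_gb(num_params, 4):.1f} GB")   # 3.5 GB
```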

Gemma comes with a vocabulary of ~250,000 tokens. Normally, to fine-tune an LLM on the ChatML format, we would need to add special tokens to the tokenizer and teach the model to understand the different roles in a conversation.

However, Google included ~100 placeholder tokens in the vocabulary, which we can replace with special tokens such as `<|im_start|>` and `<|im_end|>`. We use a [tokenizer for the ChatML format](https://huggingface.co/philschmid/gemma-tokenizer-chatml) that includes these special tokens, so we can fine-tune Gemma on ChatML-formatted conversations.

The chat template used during fine-tuning is not 100% compatible with the ChatML format, because `google/gemma-7b` requires every input to start with a `<bos>` token. This means that our inputs will look like this:
```
<bos><|im_start|>system
You are Gemma.<|im_end|>
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant
I'm doing great. How can I help you?<|im_end|>\n<eos>
```
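
The authoritative template is the `chat_template` stored with the tokenizer. As an illustration only, this hypothetical pure-Python function spells out the layout shown above, so it is clear where each special token goes:

```python
def render_chatml(messages, bos="<bos>", eos="<eos>"):
    """Render messages in the <bos> + ChatML + <eos> layout used for training."""
    text = bos
    for message in messages:
        text += f"<|im_start|>{message['role']}\n{message['content']}<|im_end|>\n"
    return text + eos

messages = [
    {"role": "system", "content": "You are Gemma."},
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing great. How can I help you?"},
]
print(render_chatml(messages))
```

You can compare its output with `tokenizer.apply_chat_template(messages, tokenize=False)` for the same messages; if the two differ, trust the tokenizer.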

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = 'google/gemma-7b'
tokenizer_id = 'philschmid/gemma-tokenizer-chatml'

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map='auto',
    attn_implementation='flash_attention_2',
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
tokenizer.padding_side = 'right' # to prevent warnings

`SFTTrainer` has a native integration with `peft`, which makes it easy to efficiently fine-tune LLMs using, e.g., QLoRA. We only need to create our `LoraConfig` and pass it to the trainer.

In [None]:
from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=8,
    lora_dropout=0.05,
    r=6,
    bias='none',
    target_modules='all-linear',
    task_type='CAUSAL_LM'
)
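
Conceptually, LoRA keeps each targeted weight matrix `W` frozen and learns a low-rank update, computing `W x + (lora_alpha / r) · B A x`, with `B` initialized to zero so training starts from the base model's behavior. A minimal NumPy sketch of one such layer (the class name is hypothetical; dropout is omitted, and `A` uses a small Gaussian init rather than PEFT's default initializer):

```python
import numpy as np

class LoRALinear:
    """Frozen weight W plus a trainable low-rank update scaled by alpha / r."""

    def __init__(self, weight, r=6, alpha=8, seed=0):
        rng = np.random.default_rng(seed)
        d_out, d_in = weight.shape
        self.weight = weight                            # frozen, (d_out, d_in)
        self.A = rng.standard_normal((r, d_in)) * 0.01  # trainable, small random init
        self.B = np.zeros((d_out, r))                   # trainable, zero init
        self.scaling = alpha / r

    def __call__(self, x):
        # x: (batch, d_in) -> (batch, d_out)
        return x @ self.weight.T + self.scaling * (x @ self.A.T @ self.B.T)

    def num_trainable(self):
        # r * (d_in + d_out) parameters instead of d_in * d_out
        return self.A.size + self.B.size

rng = np.random.default_rng(1)
W = rng.standard_normal((16, 32))       # toy frozen weight: d_out=16, d_in=32
layer = LoRALinear(W, r=6, alpha=8)
x = rng.standard_normal((2, 32))
print(np.allclose(layer(x), x @ W.T))   # True, because B starts at zero
print(layer.num_trainable(), W.size)    # 288 512
```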

Next, we define the training hyperparameters with `TrainingArguments`.

In [None]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='gemma-7b-dolly-chatml',
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2, # number of forward/backward passes before each optimizer update
    gradient_checkpointing=True,   # use gradient checkpointing to save memory
    optim='adamw_torch_fused',
    logging_steps=10,   # log every 10 steps
    save_strategy='epoch',   # save checkpoint every epoch
    bf16=True,   # use bfloat16 precision
    tf32=True,   # use tf32 precision
    learning_rate=2e-4,  # learning rate based on QLoRA paper
    max_grad_norm=0.3,   # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,   # warmup ratio based on QLoRA paper
    lr_scheduler_type='constant_with_warmup',  # plain 'constant' ignores warmup_ratio
    push_to_hub=False,
    report_to='tensorboard'
)
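
Two quantities follow directly from these arguments. The effective batch size per optimizer update is `per_device_train_batch_size × gradient_accumulation_steps × number of GPUs` (with packing, each "example" is one packed sequence), and `warmup_ratio` is turned into a number of warmup steps from the total number of update steps. A small sketch with hypothetical helper names; we round up, as Transformers' `TrainingArguments.get_warmup_steps` does:

```python
import math

def effective_batch_size(per_device_batch, grad_accum_steps, num_devices=1):
    """Sequences contributing to each optimizer update."""
    return per_device_batch * grad_accum_steps * num_devices

def warmup_steps(total_update_steps, warmup_ratio):
    """Warmup ratio converted to optimizer steps, rounded up."""
    return math.ceil(total_update_steps * warmup_ratio)

print(effective_batch_size(2, 2))  # 4 on a single GPU
print(warmup_steps(1000, 0.03))    # 30 for a hypothetical 1000-step run
```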

In [None]:
from trl import SFTTrainer

max_seq_length = 1512 # max sequence length for model and packing of the dataset

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    dataset_kwargs={
        'add_special_tokens': False,  # the chat template already adds <bos> and <eos>
        'append_concat_token': False, # No need to add additional separator token
    }
)
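
Packing concatenates the tokenized examples into one stream and cuts it into chunks of exactly `max_seq_length` tokens, so no compute is wasted on padding. A simplified sketch of the idea (hypothetical helper; TRL's implementation also buffers and may shuffle, which is omitted here). With `append_concat_token=False`, no separator is inserted between examples; that is reasonable here because each templated example already ends with `<eos>`:

```python
from itertools import chain

def pack_sequences(tokenized_examples, max_seq_length):
    """Concatenate token lists end to end and cut them into fixed-length chunks.

    No separator is inserted between examples, and a trailing remainder
    shorter than max_seq_length is dropped.
    """
    stream = list(chain.from_iterable(tokenized_examples))
    return [
        stream[i:i + max_seq_length]
        for i in range(0, len(stream) - max_seq_length + 1, max_seq_length)
    ]

print(pack_sequences([[1, 2, 3], [4, 5], [6, 7, 8, 9]], max_seq_length=4))
# [[1, 2, 3, 4], [5, 6, 7, 8]]
```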

Now we can start training. Since we use a PEFT method, `save_model()` only saves the adapter weights, not the full model.

In [None]:
trainer.train()

# save model
trainer.save_model()

## Test model and run inference

Evaluating generative AI models is not a trivial task, since one input can have multiple correct outputs. Here we simply run a few sample prompts and inspect the responses manually.
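
To see why multiple correct outputs complicate evaluation, consider a simple bag-of-words F1 score. This is one illustrative metric with hypothetical helpers, not something this notebook uses: a correct paraphrase scores below 1.0 against a single reference, and taking the best score over several acceptable references softens that penalty.

```python
from collections import Counter

def token_f1(prediction, reference):
    """Bag-of-words F1 between two strings (lower-cased, whitespace tokenized)."""
    pred, ref = prediction.lower().split(), reference.lower().split()
    overlap = sum((Counter(pred) & Counter(ref)).values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(pred), overlap / len(ref)
    return 2 * precision * recall / (precision + recall)

def best_f1(prediction, references):
    """Score against several acceptable answers and keep the best match."""
    return max(token_f1(prediction, ref) for ref in references)

answer = "Berlin is the capital"
print(token_f1(answer, "The capital of Germany is Berlin"))  # below 1.0 despite being correct
print(best_f1(answer, ["The capital of Germany is Berlin", "Berlin is the capital"]))  # 1.0
```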

In [None]:
# free up GPU memory held by the training run
import gc
del model, trainer
gc.collect()
torch.cuda.empty_cache()

We load the adapter model and the tokenizer into a `pipeline` to test it easily, and extract the token id of `<|im_end|>` so we can pass it to generation as the end-of-sequence token.

In [None]:
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, pipeline

peft_model_id = 'gemma-7b-dolly-chatml'

tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
model = AutoPeftModelForCausalLM.from_pretrained(
    peft_model_id,
    device_map='auto',
    torch_dtype=torch.bfloat16  # same precision as used for training (bf16=True)
)
pipe = pipeline(
    'text-generation',
    model=model,
    tokenizer=tokenizer,
)

# get token id for end of conversation
eos_token = tokenizer("<|im_end|>", add_special_tokens=False)['input_ids'][0]
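
Passing `eos_token_id` to generation makes it stop as soon as that id is produced, so the ChatML end-of-turn marker ends each answer. The effect on a list of generated ids can be mimicked in plain Python (hypothetical helper, illustrative ids):

```python
def truncate_at_eos(token_ids, eos_token_id):
    """Keep tokens up to and including the first end-of-turn token."""
    if eos_token_id in token_ids:
        return token_ids[: token_ids.index(eos_token_id) + 1]
    return token_ids

print(truncate_at_eos([15, 16, 17, 99, 18, 19], eos_token_id=99))  # [15, 16, 17, 99]
```

This only works if `<|im_end|>` encodes to a single id, which you can check with `len(tokenizer("<|im_end|>", add_special_tokens=False)['input_ids']) == 1`.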

Now we can test some prompt samples:

In [None]:
prompts = [
    "What is the capital of Germany? Explain why that's the case and if it was different in the past?",
    "Write a Python function to calculate the factorial of a number.",
    "A rectangular garden has a length of 25 feet and a width of 15 feet. If you want to build a fence around the entire garden, how many feet of fencing will you need?",
    "What is the difference between a fruit and a vegetable? Give examples of each.",
]

def test_inference(prompt):
    prompt = pipe.tokenizer.apply_chat_template(
        [{'role': 'user', 'content': prompt}],
        tokenize=False,
        add_generation_prompt=True
    )
    outputs = pipe(
        prompt,
        max_new_tokens=1024,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        eos_token_id=eos_token
    )

    return outputs[0]['generated_text'][len(prompt):].strip()


for prompt in prompts:
    print(f"    prompt:\n{prompt}")
    print(f"    response:\n{test_inference(prompt)}")
    print('-'*50)
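
`top_k` and `top_p` decide which tokens remain candidates before sampling. The sketch below (hypothetical helper, working on an already temperature-scaled probability list) keeps the `top_k` most likely tokens and then the smallest prefix of them whose cumulative probability reaches `top_p`, always keeping at least one token. It follows the idea of nucleus sampling; the Transformers implementation works on logits and may differ at exact boundary values:

```python
def top_k_top_p_filter(probs, top_k=50, top_p=0.95):
    """Return the indices of tokens that stay eligible for sampling."""
    # rank tokens by probability (ties keep their original order) and cut to top_k
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    kept, cumulative = [], 0.0
    for idx in ranked:
        kept.append(idx)
        cumulative += probs[idx]
        if cumulative >= top_p:  # smallest prefix reaching top_p
            break
    return kept

probs = [0.5, 0.25, 0.125, 0.125]
print(top_k_top_p_filter(probs, top_p=0.8))   # [0, 1, 2]
print(top_k_top_p_filter(probs, top_p=0.05))  # [0]
```

A very small `top_p` leaves only the single most likely token, which makes sampling behave almost like greedy decoding.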