<a href="https://www.kaggle.com/code/gpreda/fine-tuning-gemma-2-model-with-roleplay-dataset?scriptVersionId=197291626" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

<center><h1>Fine-tuning Gemma 2 model using LoRA and Keras</h1></center>

<center><img src="https://res.infoq.com/news/2024/02/google-gemma-open-model/en/headerimage/generatedHeaderImage-1708977571481.jpg" width="400"></center>


# Introduction

This notebook will demonstrate three things:

1. How to fine-tune a Gemma 2 model using LoRA
2. How to create a specialized class to query the model
3. Some results of querying the model about various topics while instructing it to adopt one of the personas included in the data used for fine-tuning.



# What is Gemma 2?

Gemma is a collection of lightweight, advanced open models developed by Google, leveraging the same research and technology behind the Gemini models. These models are text-to-text, decoder-only large language models available in English, with open weights provided for both pre-trained and instruction-tuned versions. Gemma models excel in a range of text generation tasks, such as question answering, summarization, and reasoning. Their compact size allows for deployment in resource-constrained environments like laptops, desktops, or personal cloud infrastructure, making state-of-the-art AI models more accessible and encouraging innovation for all. 

Gemma 2 represents the second generation of Gemma models. These models were trained on a dataset of text data that includes a wide variety of sources. The **27B** model was trained with **13 trillion** tokens, the **9B** model was trained with **8 trillion** tokens, and the **2B** model was trained with **2 trillion** tokens. Here is a summary of the key components of the training data:
* **Web Documents**: A diverse collection of web text ensures the model is exposed to a broad range of linguistic styles, topics, and vocabulary. Primarily English-language content.
* **Code**: Exposing the model to code helps it to learn the syntax and patterns of programming languages, which improves its ability to generate code or understand code-related questions.
* **Mathematics**: Training on mathematical text helps the model learn logical reasoning, symbolic representation, and to address mathematical queries.

To learn more about Gemma 2, follow this link: [Gemma 2 Model Card](https://www.kaggle.com/models/google/gemma-2).




# What is LoRA?  

**LoRA** stands for **Low-Rank Adaptation**. It is a method for fine-tuning large language models (LLMs) that freezes the weights of the LLM and injects trainable rank-decomposition matrices. The number of trainable parameters during fine-tuning therefore decreases considerably. According to the **LoRA** paper, compared with fully fine-tuning GPT-3 175B with Adam, LoRA can reduce the number of trainable parameters by **10,000 times** and the GPU memory requirement by 3 times.
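
To make the idea of rank-decomposition matrices more concrete, here is a minimal NumPy sketch for a single dense layer. The layer sizes, the rank, and the names `W`, `A`, `B`, and `lora_forward` are assumptions chosen for illustration; this is not how Keras implements LoRA internally.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical layer sizes and rank, chosen only for illustration.
d_out, d_in, rank = 64, 64, 4

W = rng.normal(size=(d_out, d_in))         # frozen pretrained weight
A = rng.normal(size=(rank, d_in)) * 0.01   # trainable, small random init
B = np.zeros((d_out, rank))                # trainable, zero init so the update starts at 0

def lora_forward(x):
    # Original projection plus the low-rank update B @ A.
    return x @ W.T + x @ (B @ A).T

full_params = W.size            # parameters updated by full fine-tuning
lora_params = A.size + B.size   # parameters updated by LoRA
print(full_params, lora_params)
```

In this toy setting, the low-rank update trains `rank * (d_in + d_out)` values instead of `d_in * d_out`, and the gap grows as the layer gets larger.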

# How do we proceed?

For fine-tuning with LoRA, we will follow these steps:

1. Install prerequisites
2. Load and process the data for fine-tuning
3. Initialize the code for Gemma causal language model (Gemma Causal LM)
4. Perform fine-tuning so that the model learns the various personas and is able to perform in each role.
5. Test the fine-tuned model with questions from the data used for fine-tuning and with additional questions

# Prerequisites


## Install packages

We start by installing the `keras-nlp`, `kagglehub`, and `keras` packages.

In [None]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U kagglehub
# Quote the requirement so the shell does not treat ">" as an output redirection.
!pip install -q -U "keras>=3"

## Import packages

Now we can import the packages we just installed. We will also import `os`, so that we can set the environment variables needed for the Keras backend. We will use `jax` as `KERAS_BACKEND`.

Because we want to publish the model from the notebook, we also import `kagglehub`; the credentials will be read from Kaggle secrets.

In [None]:
import os
os.environ["KERAS_BACKEND"] = "jax" # you can also use tensorflow or torch
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00" # avoid memory fragmentation on JAX backend.
os.environ["JAX_PLATFORMS"] = ""
import keras
import keras_nlp
import kagglehub





import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
tqdm.pandas() # progress bar for pandas

import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, Markdown

## Initialize user secrets

We initialize user secrets, so that we can publish the model using `kagglehub`.

In [None]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
os.environ["KAGGLE_USERNAME"] = user_secrets.get_secret("kaggle_username")
os.environ["KAGGLE_KEY"] = user_secrets.get_secret("kaggle_key")

## Configurations


We use a `Config` class to group the information needed to control the fine-tuning process:
* random seed
* dataset path
* preset - name of the pretrained Gemma 2 model
* sequence length - the maximum size of the input sequence for training
* batch size - size of the input batch in training
* lora rank - rank for LoRA, higher means more trainable parameters
* learning rate - learning rate used for training
* epochs - number of epochs for training

In [None]:
class Config:
    seed = 42
    dataset_path = "/kaggle/input/roleplay-snapshot/roleplay.csv"
    preset = "/kaggle/input/gemma2/keras/gemma2_2b_en/1" # name of pretrained Gemma 2
    sequence_length = 512 # max size of input sequence for training
    batch_size = 1 # size of the input batch in training
    lora_rank = 4 # rank for LoRA, higher means more trainable parameters
    learning_rate = 8e-5 # learning rate used in train
    epochs = 15 # number of epochs to train

Set a random seed for results reproducibility.

In [None]:
keras.utils.set_random_seed(Config.seed)

# Load the data


We load the data we will use for fine-tuning.

In [None]:
df = pd.read_csv(Config.dataset_path, sep=";")
df.head()

Let's check the shape of this dataset (number of rows and columns) and its column names.

In [None]:
df.shape, df.columns

# Preprocess the data

We will preprocess the data so that, from the sequences in the `text` column, we extract the `<|system|>` prompt and the pairs of {`<|user|>`, `<|assistant|>`} to form triplets of {`<|system|>`, `<|user|>`, `<|assistant|>`}  for each entry in the data for fine-tuning.

In [None]:
import re
def extract_dialogue_components(text):
    # Extract the system prompt
    system_prompt = re.search(r"<\|system\|>.*?</s>", text, re.DOTALL).group(0)
    
    # Extract user and assistant dialogues
    dialogue_pairs = re.findall(r"(<\|user\|>.*?</s>\s*<\|assistant\|>.*?</s>)", text, re.DOTALL)
    
    return system_prompt, dialogue_pairs
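
To illustrate what this function returns, here is a small self-contained example. The `sample` string is made up for illustration and only assumes the `<|system|>`, `<|user|>`, `<|assistant|>`, and `</s>` markers described above; it is not taken from the dataset.

```python
import re

def extract_dialogue_components(text):
    # Same logic as the function defined in the notebook
    system_prompt = re.search(r"<\|system\|>.*?</s>", text, re.DOTALL).group(0)
    dialogue_pairs = re.findall(
        r"(<\|user\|>.*?</s>\s*<\|assistant\|>.*?</s>)", text, re.DOTALL
    )
    return system_prompt, dialogue_pairs

# Made-up sample in the assumed format (not taken from the dataset).
sample = (
    "<|system|>You are a pirate.</s>\n"
    "<|user|>Hello?</s>\n<|assistant|>Ahoy!</s>\n"
    "<|user|>Where to?</s>\n<|assistant|>To sea!</s>"
)

system_prompt, pairs = extract_dialogue_components(sample)
print(system_prompt)   # the system segment, including its markers
print(len(pairs))      # one entry per user/assistant exchange
```

Each element of `pairs` keeps one user turn together with the assistant reply that follows it, which is what allows one training sample to be built per exchange.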

We process the data. Each dialogue pair is combined with the system prompt of its row to form one training sample. Samples longer than the configured maximum sequence length are truncated later by the preprocessor.

In [None]:
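data = []
for idx, row in df.iterrows():
    text = row["text"]
    try:
        system_prompt, dialogue_pairs = extract_dialogue_components(text)
        # One training sample per (user, assistant) pair, prefixed by the system prompt
        for pair in dialogue_pairs:
            prompt_sample = f"{system_prompt}\n\n{pair}"
            data.append(prompt_sample)
    except Exception as ex:
        # Report rows that do not match the expected format and continue
        print(ex, idx)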

## Template utility function


We use this function to reformat the output of our queries, so that it is more user friendly.

We replace and highlight the initial special tokens with more human-readable text (Instruction, Question, Answer).

In [None]:
def colorize_text(text):
    for word, formatted_word, color in zip(["<|system|>:", "<|user|>:", "<|assistant|>:"], 
                                           ["Instruction:", "Question:", "Answer:"],
                                           ["blue", "red", "green"]):
        text = text.replace(f"\n\n{word}", f"\n\n**<font color='{color}'>{formatted_word}</font>**")
    return text
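
As a quick check of the formatting, the sketch below applies the function to a made-up response string (the persona and the dialogue are hypothetical). Note that a marker is replaced only when it is preceded by two newlines, which matches the query template used later in the notebook.

```python
def colorize_text(text):
    # Same logic as the function defined in the notebook
    for word, formatted_word, color in zip(
        ["<|system|>:", "<|user|>:", "<|assistant|>:"],
        ["Instruction:", "Question:", "Answer:"],
        ["blue", "red", "green"],
    ):
        text = text.replace(
            f"\n\n{word}", f"\n\n**<font color='{color}'>{formatted_word}</font>**"
        )
    return text

# Made-up response text, for illustration only.
raw = "\n\n<|system|>:\nYou are a pirate.\n\n<|user|>:\nHello?\n\n<|assistant|>:\nAhoy!"
formatted = colorize_text(raw)
print(formatted)
```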

# Specialized class to query Gemma


We define a specialized class to query Gemma. But first, we need to initialize an object of GemmaCausalLM class.

## Initialize the code for Gemma Causal LM

In [None]:
gemma_causal_lm = keras_nlp.models.GemmaCausalLM.from_preset(Config.preset)
gemma_causal_lm.summary()

## Define the specialized class

Here we define the specialized class `GemmaQA`.
In `__init__`, we store a reference to the `GemmaCausalLM` object created before.
The `query` member function uses the `generate` member function of `GemmaCausalLM` to generate the answer, based on a prompt that includes the instruction (the persona description) and the question.

In [None]:
template = "\n\n<|system|>:\n{instruct}\n\n<|user|>:\n{question}\n\n<|assistant|>:\n{answer}"
class GemmaQA:
    def __init__(self, max_length=512):
        self.max_length = max_length
        self.prompt = template
        self.gemma_causal_lm = gemma_causal_lm
        
    def query(self, instruct, question):
        response = self.gemma_causal_lm.generate(
            self.prompt.format(
                instruct=instruct,
                question=question,
                answer=""), 
            max_length=self.max_length)
        display(Markdown(colorize_text(response)))
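
To see the exact prompt that `query` sends to the model, the template can be filled in without calling the model at all. The persona and the question below are made up for illustration.

```python
template = "\n\n<|system|>:\n{instruct}\n\n<|user|>:\n{question}\n\n<|assistant|>:\n{answer}"

prompt = template.format(
    instruct="You are a pirate.",        # made-up persona
    question="Where is the treasure?",   # made-up question
    answer="",                           # left empty so the model completes it
)
print(repr(prompt))
```

Because `answer` is empty, the prompt ends right after the `<|assistant|>:` marker, so the text generated by the model continues from that point.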
        

## Gemma preprocessor


This preprocessing layer will take in batches of strings, and return outputs in a ```(x, y, sample_weight)``` format, where the y label is the next token id in the x sequence.

From the code below, we can see that, after the preprocessor, the data shape is ```(num_samples, sequence_length)```.

In [None]:
x, y, sample_weight = gemma_causal_lm.preprocessor(data[0:2])

In [None]:
print(x, y)
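
The relationship between `x`, `y`, and `sample_weight` can be sketched with plain NumPy. This is a conceptual illustration with hypothetical token ids and an assumed padding id of 0, not the preprocessor's actual code.

```python
import numpy as np

# Hypothetical token ids for one sequence, padded with 0 up to length 8.
token_ids = np.array([2, 15, 27, 9, 1, 0, 0, 0])
padding_mask = token_ids != 0

x = token_ids[:-1]                 # model input
y = token_ids[1:]                  # label: the next token at each position
sample_weight = padding_mask[1:]   # ignore positions whose label is padding

print(x)
print(y)
print(sample_weight.astype(int))
```

Shifting by one position is what turns a plain sequence into a next-token prediction task, and the weights keep padded positions from contributing to the loss and the metric.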

# Perform fine-tuning with LoRA

## Enable LoRA for the model

The LoRA rank determines the number of trainable parameters. A larger rank will result in a larger number of parameters to train.

In [None]:
# Enable LoRA for the model and set the LoRA rank to the lora_rank as set in Config (4).
gemma_causal_lm.backbone.enable_lora(rank=Config.lora_rank)
gemma_causal_lm.summary()

We see that only a small part of the parameters is trainable: 2.6 billion parameters in total, of which only 2.9 million are trainable.
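
As a quick back-of-the-envelope check, the trainable share can be computed from the two rounded figures reported above (the exact counts in the model summary will differ slightly):

```python
total_params = 2.6e9       # approximate total, as reported above
trainable_params = 2.9e6   # approximate number of LoRA parameters, as reported above

fraction = trainable_params / total_params
print(f"{fraction:.2%} of the parameters are trainable")
```

With these rounded values, roughly one parameter in a thousand is updated during fine-tuning.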

## Run the training sequence

We set the `sequence_length` for the `GemmaCausalLM` preprocessor (from the configuration, it will be 512).
We compile the model with the loss, optimizer, and metric.
The metric used is `SparseCategoricalAccuracy`, which calculates how often predictions match integer labels.
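
As a rough illustration of what a weighted sparse categorical accuracy measures, here is a NumPy sketch with hypothetical logits, labels, and sample weights. It mirrors the idea behind the metric and is not the Keras implementation.

```python
import numpy as np

# Hypothetical logits for 4 positions over a 3-token vocabulary.
logits = np.array([
    [2.0, 0.1, 0.3],
    [0.2, 1.5, 0.1],
    [0.9, 0.2, 0.4],
    [0.1, 0.2, 3.0],
])
labels = np.array([0, 1, 2, 0])          # integer token ids
sample_weight = np.array([1, 1, 1, 0])   # the last position is padding

predictions = logits.argmax(axis=-1)               # most likely token per position
matches = (predictions == labels).astype(float)    # 1.0 where the prediction is right
accuracy = (matches * sample_weight).sum() / sample_weight.sum()
print(accuracy)
```

Passing the metric through `weighted_metrics` is what lets padded positions be excluded in this way.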

In [None]:
# Set the sequence length from the config (512)
gemma_causal_lm.preprocessor.sequence_length = Config.sequence_length 

# Compile the model with loss, optimizer, and metric
gemma_causal_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(learning_rate=Config.learning_rate),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train model
gemma_causal_lm.fit(data, epochs=Config.epochs, batch_size=Config.batch_size)

We obtained a rather good accuracy after the 15 epochs of fine-tuning.

# Test the fine-tuned model

We instantiate an object of class GemmaQA. Because `gemma_causal_lm` was fine-tuned using LoRA, `gemma_qa` defined here will use the fine-tuned model.

In [None]:
gemma_qa = GemmaQA()

To start, we test the model with some of the data from the training set itself.

## Sample 1

In [None]:
gemma_qa = GemmaQA(max_length=96)
instruct = "Sherlock the renowned detective from Baker Street is known for his astute logical reasoning disguise ability and use of forensic science to solve perplexing crimes"
question = "What's Sherlock secret to solving crimes?"
gemma_qa.query(instruct, question)

## Unseen questions

In [None]:
gemma_qa = GemmaQA(max_length=128)
instruct = "Sherlock the renowned detective from Baker Street is known for his astute logical reasoning disguise ability and use of forensic science to solve perplexing crimes"
question = "How is Sherlock able to be so successful in solving crimes?"
gemma_qa.query(instruct, question)

In [None]:
instruct = df.description.values[1]
question = "How does Serena Williams play against Simona Halep?"

gemma_qa.query(instruct,question)

In [None]:
gemma_qa = GemmaQA(max_length=256)
instruct = df.description.values[2]
question = "How will Beethoven solve the equation system x + 1 = 2?"

gemma_qa.query(instruct,question)

In [None]:
gemma_qa = GemmaQA(max_length=90)
instruct = df.description.values[10]
question = "How will Michael Jordan knit a sweater?"

gemma_qa.query(instruct,question)

# Save the model

In [None]:
preset_dir = "./gemma2_2b_en_roleplay"
gemma_causal_lm.save_to_preset(preset_dir)

# Publish Model on Kaggle as a Kaggle Model

We now publish the saved model as a Kaggle Model.

In [None]:
kaggle_username = os.environ["KAGGLE_USERNAME"]

kaggle_uri = f"kaggle://{kaggle_username}/gemma2_2b_en_roleplay/keras/gemma2_2b_en_roleplay"
keras_nlp.upload_preset(kaggle_uri, preset_dir)


# Conclusions



We demonstrated how to fine-tune a **Gemma 2** model using LoRA.   
We also created a class to run queries against the **Gemma 2** model and tested it with some examples from the existing training data, as well as with some new, unseen questions.   