# Introduction


This notebook is based on a tutorial from goo.gle/ai-kaggle-keras-gemma about how to fine-tune **Gemma** on a **Kaggle** dataset and share the model with the community.  

## Model used

We use the **Gemma** model with 2B parameters: the Keras implementation, English version, v2 (preset `gemma_2b_en`).

## Dataset

We will fine-tune **Gemma** using a [Medical Q & A](https://www.kaggle.com/datasets/gpreda/medquad/) dataset. This is a subset of the full public dataset [Healthcare NLP: LLMs, Transformers, Datasets](https://www.kaggle.com/datasets/jpmiller/layoutlm).



# Prepare packages


We install updated versions of Keras and KerasNLP, which we need for fine-tuning, along with other dependencies.

In [1]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U tf-keras
!pip install -q -U keras-nlp==0.10.0
# Quote the version specifier so the shell does not treat ">" as an output redirect.
!pip install -q -U "keras>=3"

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow-decision-forests 1.8.1 requires wurlitzer, which is not installed.
tensorflow-decision-forests 1.8.1 requires tensorflow~=2.15.0, but you have tensorflow 2.17.0 which is incompatible.
tensorflow-text 2.15.0 requires tensorflow<2.16,>=2.15.0; platform_machine != "arm64" or platform_system != "Darwin", but you have tensorflow 2.17.0 which is incompatible.

With Keras 3, we can run workflows on one of three backends: **TensorFlow**, **JAX**, and **PyTorch**.

For this Notebook, we will configure the backend for **JAX**.

In [2]:
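import os

# Keras reads the backend from this variable when it is first imported,
# so it must be set before `import keras`.
os.environ["KERAS_BACKEND"] = "jax"  # you can also use "tensorflow" or "torch"
# Let JAX preallocate all GPU memory up front (avoids memory fragmentation on the JAX backend).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"
# An empty value removes any platform restriction, so JAX uses its default platform selection.
os.environ["JAX_PLATFORMS"] = ""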

With the additional packages installed, we can now import them.

In [3]:
import keras_nlp
import keras
import csv

print("KerasNLP version: ", keras_nlp.__version__)
print("Keras version: ", keras.__version__)

2024-08-22 09:10:12.195480: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-22 09:10:12.195535: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-22 09:10:12.197004: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


KerasNLP version:  0.10.0
Keras version:  3.5.0


# Load the model

We load the model `gemma_2b_en` using `keras_nlp`.

In [4]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

Attaching 'metadata.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'task.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...

Let's print a summary of the loaded model.

In [5]:
gemma_lm.summary()

# Prepare the training data

We prepare the **Medical Q & A** data for training. Each row of the CSV file is formatted with a template into a single string containing the question followed by its answer; these strings form the training set.
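To see what the formatting loop produces, here is a minimal sketch on a small in-memory CSV using only the standard library. The rows are hypothetical placeholders, not taken from the dataset, and the extra `source` column only stands in for any additional columns the real file may have.

```python
import csv
import io

template = "Question:\n{question}\n\nAnswer:\n{answer}"

# Hypothetical two-row CSV standing in for medquad.csv.
sample_csv = io.StringIO(
    "question,answer,source\n"
    "What is X ?,X is a placeholder condition.,example\n"
    '"What causes X, exactly ?",Nobody knows.,example\n'
)

# DictReader maps each row to {column: value}; str.format(**row) uses the
# fields named in the template and silently ignores the others.
rows = [template.format(**row) for row in csv.DictReader(sample_csv)]
for r in rows:
    print(r)
    print("---")
```

Each training example is therefore one plain string, which is the input format `gemma_lm.fit` receives later in the notebook.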

In [6]:
data = []

# Template that turns one CSV row into a single training string.
template = "Question:\n{question}\n\nAnswer:\n{answer}"

# The CSV file has 'question' and 'answer' columns; any extra columns
# in a row are ignored by str.format.
with open("/kaggle/input/medquad/medquad.csv", mode="r", encoding="utf-8") as file:
    reader = csv.DictReader(file)
    for row in reader:
        data.append(template.format(**row))

In [7]:
len(data)

16412

Let's limit the data to **300 rows** of questions and answers, to keep the training time short for this demonstration notebook. You can come back later and fine-tune on the entire dataset, or on a subset of a different size.
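The cell below keeps the first 300 rows. If the rows in the CSV file are ordered by topic, the first rows may cover only a few conditions; drawing a reproducible random sample is one way to get a more varied subset. A minimal sketch with the standard library, where `sample_subset` and `toy_data` are hypothetical names used only for illustration:

```python
import random

def sample_subset(rows, k, seed=42):
    """Draw k distinct rows uniformly at random; the seed makes it reproducible."""
    rng = random.Random(seed)
    return rng.sample(rows, k)

# Hypothetical stand-in for the list of formatted training strings.
toy_data = [f"Question:\nQ{i} ?\n\nAnswer:\nA{i}." for i in range(1000)]
subset = sample_subset(toy_data, 300)
print(len(subset), len(set(subset)))  # prints: 300 300
```

In this notebook, that would mean writing `data = sample_subset(data, 300)` instead of `data = data[:300]`.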

In [8]:
data = data[:300]

# Check model inference before fine-tuning

We will first check the model before proceeding to fine-tuning, by testing it with a few questions about medical matters.

First, we define a utility function that displays the question and the LLM's answer.

In [9]:
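from IPython.display import display, Markdown

def colorize_text(text):
    """Render the 'Category:', 'Question:' and 'Answer:' labels as bold, colored Markdown."""
    for word, color in zip(["Category", "Question", "Answer"], ["blue", "red", "green"]):
        # Start a new paragraph at each label and wrap the label in an HTML <font> tag.
        # ('Category' does not occur in our template, so that replacement is a no-op.)
        text = text.replace(f"{word}:", f"\n\n**<font color='{color}'>{word}:</font>**")
    return text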

Let's display one training example, first as plain text and then formatted with the `colorize_text` function.

In [10]:
print(data[3])

Question:
What are the treatments for Glaucoma ?

Answer:
Although open-angle glaucoma cannot be cured, it can usually be controlled. While treatments may save remaining vision, they do not improve sight already lost from glaucoma. The most common treatments for glaucoma are medication and surgery. Medications  Medications for glaucoma may be either in the form of eye drops or pills. Some drugs reduce pressure by slowing the flow of fluid into the eye. Others help to improve fluid drainage. (Watch the video to learn more about coping with glaucoma. To enlarge the video, click the brackets in the lower right-hand corner. To reduce the video, press the Escape (Esc) button on your keyboard.) For most people with glaucoma, regular use of medications will control the increased fluid pressure. But, these drugs may stop working over time. Or, they may cause side effects. If a problem occurs, the eye care professional may select other drugs, change the dose, or suggest other ways to deal with 

In [11]:
display(Markdown(colorize_text(data[3])))



**<font color='red'>Question:</font>**
What are the treatments for Glaucoma ?



**<font color='green'>Answer:</font>**
Although open-angle glaucoma cannot be cured, it can usually be controlled. While treatments may save remaining vision, they do not improve sight already lost from glaucoma. The most common treatments for glaucoma are medication and surgery. Medications  Medications for glaucoma may be either in the form of eye drops or pills. Some drugs reduce pressure by slowing the flow of fluid into the eye. Others help to improve fluid drainage. (Watch the video to learn more about coping with glaucoma. To enlarge the video, click the brackets in the lower right-hand corner. To reduce the video, press the Escape (Esc) button on your keyboard.) For most people with glaucoma, regular use of medications will control the increased fluid pressure. But, these drugs may stop working over time. Or, they may cause side effects. If a problem occurs, the eye care professional may select other drugs, change the dose, or suggest other ways to deal with the problem.  Read or listen to ways some patients are coping with glaucoma. Surgery Laser surgery is another treatment for glaucoma. During laser surgery, a strong beam of light is focused on the part of the anterior chamber where the fluid leaves the eye. This results in a series of small changes that makes it easier for fluid to exit the eye. Over time, the effect of laser surgery may wear off. Patients who have this form of surgery may need to keep taking glaucoma drugs. Researching Causes and Treatments Through studies in the laboratory and with patients, NEI is seeking better ways to detect, treat, and prevent vision loss in people with glaucoma. For example, researchers have discovered genes that could help explain how glaucoma damages the eye. NEI also is supporting studies to learn more about who is likely to get glaucoma, when to treat people who have increased eye pressure, and which treatment to use first.

Now we will ask the model to answer a question for which we know the expected answer.
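The prompt is built with the same template used for training, but with an empty answer, so it ends right after `Answer:` and whatever the model generates next is its answer. A quick check in plain Python (the question string is just an example):

```python
template = "Question:\n{question}\n\nAnswer:\n{answer}"

# Fill the template with an empty answer to obtain an inference prompt.
prompt = template.format(question="What are the treatments for Diabetes ?", answer="")
print(repr(prompt))

# The prompt stops exactly where the model should start writing the answer.
assert prompt.endswith("Answer:\n")
```

Keeping the inference prompt identical in shape to the training strings is what lets the fine-tuned model recognize where its answer should begin.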

In [12]:
prompt = template.format(
    question="What are the complications of Paget's Disease of Bone ?",
    answer="",
)
response = gemma_lm.generate(prompt, max_length=128)
display(Markdown(colorize_text(response)))



**<font color='red'>Question:</font>**
What are the complications of Paget's Disease of Bone ?



**<font color='green'>Answer:</font>**
Paget's disease of bone is a disorder of bone metabolism characterized by excessive bone resorption and osteoclast activity. The disease is characterized by the formation of new bone in the form of osteoid, which is subsequently replaced by mature bone. The disease is characterized by the formation of new bone in the form of osteoid, which is subsequently replaced by mature bone. The disease is characterized by the formation of new bone in the form of osteoid, which is subsequently replaced by mature bone. The disease is characterized by the formation

Another example below:

In [13]:
prompt = template.format(
    question="What are the treatments for Diabetes ?",
    answer="",
)
response = gemma_lm.generate(prompt, max_length=128)
display(Markdown(colorize_text(response)))



**<font color='red'>Question:</font>**
What are the treatments for Diabetes ?



**<font color='green'>Answer:</font>**
Diabetes is a chronic disease that affects the way your body uses blood sugar (glucose). It is caused by a lack of insulin, a hormone that helps glucose enter your cells to be used for energy.

There are two types of diabetes:

Type 1 diabetes: This is an autoimmune disease in which the body’s immune system attacks the cells that produce insulin. It is often called “insulin-dependent diabetes” or “juvenile diabetes.”

Type 2 diabetes: This is the most common type of diabetes. It is caused by a combination of factors

# Fine-tuning with LoRA


We now use **LoRA** for fine-tuning. **LoRA** stands for **Low-Rank Adaptation**, a method for adapting a pretrained model (for example, an LLM or vision transformer) to a specific, often smaller, dataset by **freezing the original weights and training only small low-rank update matrices** added to selected layers.


The LoRA rank controls the number of trainable parameters added for fine-tuning: a higher rank allows more expressive updates, at the cost of more parameters to train.
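To make the role of the rank concrete, here is a minimal NumPy sketch of a LoRA-adapted dense layer. It assumes the usual formulation in which a frozen weight matrix `W` receives a low-rank update `A @ B`, with `B` initialized to zero. The layer sizes are illustrative, not Gemma's actual dimensions, and some implementations also scale the update by a factor `alpha / r`, which is omitted here.

```python
import numpy as np

def lora_forward(x, W, A, B):
    """Dense layer with a LoRA update: y = x @ (W + A @ B).

    W (d_in x d_out) is the frozen pretrained weight; only A (d_in x r)
    and B (r x d_out) would be trained.
    """
    return x @ W + (x @ A) @ B

# Illustrative sizes only; not read from Gemma's configuration.
d_in, d_out, rank = 2048, 2048, 3
rng = np.random.default_rng(0)

W = rng.normal(size=(d_in, d_out))             # frozen
A = rng.normal(scale=0.01, size=(d_in, rank))  # trainable
B = np.zeros((rank, d_out))                    # trainable, starts at zero

x = rng.normal(size=(4, d_in))
# With B = 0 the adapted layer starts out identical to the pretrained layer.
print(np.allclose(lora_forward(x, W, A, B), x @ W))

full_params = d_in * d_out           # entries updated by full fine-tuning
lora_params = rank * (d_in + d_out)  # parameters added (and trained) by LoRA
print(full_params, lora_params, f"{lora_params / full_params:.4%}")
```

For a 2048 × 2048 matrix, rank 3 adds 12,288 trainable parameters, about 0.3% of the 4,194,304 entries in `W`; doubling the rank doubles that number.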

In [14]:
# Enable LoRA for the model and set the LoRA rank to 3.
gemma_lm.backbone.enable_lora(rank=3)
gemma_lm.summary()

In [15]:
# Fine-tune on the Medical QA dataset.

# Limit the input sequence length to 128 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=10, batch_size=1)

Epoch 1/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m85s[0m 236ms/step - loss: 1.7991 - sparse_categorical_accuracy: 0.5784
Epoch 2/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m71s[0m 208ms/step - loss: 1.5249 - sparse_categorical_accuracy: 0.6040
Epoch 3/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m62s[0m 208ms/step - loss: 1.4618 - sparse_categorical_accuracy: 0.6136
Epoch 4/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m62s[0m 208ms/step - loss: 1.4250 - sparse_categorical_accuracy: 0.6189
Epoch 5/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m62s[0m 208ms/step - loss: 1.3936 - sparse_categorical_accuracy: 0.6265
Epoch 6/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m62s[0m 208ms/step - loss: 1.3620 - sparse_categorical_accuracy: 0.6315
Epoch 7/10
[1m300/300[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m63s[0m 208ms/step - loss: 1.3294 - sparse_categorical_accuracy: 0.6371
Epoch 

<keras.src.callbacks.history.History at 0x7cc48c157370>
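The optimizer cell above uses AdamW with `weight_decay=0.01` and excludes variables named `bias` or `scale` from decay. As an illustration, here is a minimal NumPy sketch of a single AdamW step in the decoupled weight decay formulation. The default `beta1`, `beta2` and `eps` values are assumptions chosen to resemble Keras defaults, Keras's own implementation may differ in details, and the hypothetical `decay` flag plays the role of the exclusion list.

```python
import numpy as np

def adamw_step(w, g, m, v, t, lr=5e-5, beta1=0.9, beta2=0.999,
               eps=1e-7, weight_decay=0.01, decay=True):
    """One AdamW update for parameter w with gradient g at step t (t >= 1)."""
    # Adam first and second moment estimates, with bias correction.
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g**2
    m_hat = m / (1 - beta1**t)
    v_hat = v / (1 - beta2**t)
    # Decoupled weight decay: shrink w directly instead of adding it to the gradient.
    if decay:
        w = w - lr * weight_decay * w
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w, m, v

w = np.array([1.0, -2.0])
zeros = np.zeros_like(w)

# With a zero gradient only the weight decay acts.
kernel, _, _ = adamw_step(w, zeros, zeros, zeros, t=1)
bias, _, _ = adamw_step(w, zeros, zeros, zeros, t=1, decay=False)
print(kernel)  # shrunk by a factor (1 - 5e-5 * 0.01)
print(bias)    # unchanged: excluded from weight decay
```

Because the decay is applied directly to the weights rather than folded into the gradient, its strength does not depend on Adam's per-parameter scaling, which is the main point of AdamW over Adam with L2 regularization.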

# Inference after fine-tuning


Let's try the two examples from above again, this time running the queries through the fine-tuned model.

In [16]:
prompt = template.format(
    question="What are the complications of Paget's Disease of Bone ?",
    answer="",
)
response = gemma_lm.generate(prompt, max_length=128)
display(Markdown(colorize_text(response)))



**<font color='red'>Question:</font>**
What are the complications of Paget's Disease of Bone ?



**<font color='green'>Answer:</font>**
Complications of Paget's Disease of Bone include:  -  Osteoporosis, which is a disease that weakens bones and makes them more likely to break.  -  Fractures, which are breaks in bones.  -  Joint pain and stiffness.  -  Loss of height.  -  Loss of bone density in the spine, which can lead to back pain and loss of height.  -  Loss of bone density in the pelvis, which can lead to back pain and loss of height.  -  Loss

In [17]:
prompt = template.format(
    question="What are the treatments for Diabetes ?",
    answer="",
)
response = gemma_lm.generate(prompt, max_length=128)
display(Markdown(colorize_text(response)))



**<font color='red'>Question:</font>**
What are the treatments for Diabetes ?



**<font color='green'>Answer:</font>**
Treatments for Diabetes There is no cure for diabetes, but there are many ways to manage it. The goal of treatment is to keep blood glucose levels as close to normal as possible. This can help prevent or delay the complications of diabetes. The treatment plan for each person with diabetes is different. It depends on the type of diabetes a person has, how well the person is able to control blood glucose levels, and how severe the complications are. The treatment plan may include: - Lifestyle changes, such as eating a healthy diet, getting regular exercise, and losing weight if

# Save the model

In [18]:
preset = "./medical_gemma"
# Save the model to the preset directory.
gemma_lm.save_to_preset(preset)

# Conclusions

We saw how to fine-tune the Gemma model using **LoRA** and **KerasNLP**.

We only used a subset of the **Medical Q & A** data for fine-tuning. If you have access to more computational resources, feel free to improve the fine-tuned model by:
* Adding more data from the available dataset.  
* Increasing the rank parameter used in fine-tuning.

We will use the saved model to publish a Kaggle Model.