# Fine-Tune Gemma models in Keras using LoRA

## Overview

Gemma is a family of lightweight, advanced open models derived from the same research as the Gemini models.

Large Language Models (LLMs) like Gemma are effective for various NLP tasks. They are pre-trained on vast text corpora to learn general knowledge, then fine-tuned with domain-specific data for tasks such as sentiment analysis.

Because these models have billions of parameters, full fine-tuning (updating every weight) is often impractical in terms of memory and compute. Low-Rank Adaptation (LoRA) reduces the number of trainable parameters by freezing the pre-trained weights and adding a small number of new trainable weights. This makes training faster and more memory-efficient, and the resulting fine-tuned weights are much smaller to store.

In this notebook, we will use KerasNLP to apply LoRA fine-tuning to a Gemma 2 2B model (the `gemma2_2b_en` preset) with the **Databricks Dolly 15k dataset**, which consists of about 15,000 high-quality prompt/response pairs for LLM fine-tuning.


In [None]:
# Read Kaggle credentials from Colab secrets (used to download the Gemma preset)
import os
from google.colab import userdata


os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

In [None]:
# Install dependencies; Keras 3 is installed last.
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"

In [None]:
# Select the backend before importing Keras; it cannot be changed after import.
os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

In [None]:
# Import packages
import keras
import keras_nlp

In [None]:
# Download the dataset from Hugging Face
!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl


In [None]:
# Preprocess the data
import json

# Prompt format shared by the training examples and the inference prompts.
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

data = []
with open("databricks-dolly-15k.jsonl", encoding="utf-8") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]
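
To make the formatting step concrete, here is a minimal sketch that applies the same template to a single record. The record is made up for illustration and is not a row from the dataset; its keys mirror the fields the cell above reads, plus a `category` key that Dolly records are assumed to carry.

```python
# Hypothetical record for illustration only; not a real dataset row.
example = {
    "instruction": "Name three primary colors.",
    "context": "",
    "response": "Red, yellow, and blue.",
    "category": "open_qa",
}

template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

# str.format(**mapping) fills only the fields named in the template
# and ignores the remaining keys ("context", "category").
formatted = template.format(**example)
print(formatted)
```

At inference time the same template is filled with an empty `response`, so the prompt ends right after `Response:` and the model continues from there.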

In [None]:
# Load the Model
# We use the GemmaCausalLM, an end-to-end Gemma model for causal language modeling.

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.summary()

## Inference Before Fine-Tuning

In [None]:
# Ask for Europe travel suggestions, using top-k sampling (k=5) with a fixed seed

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What should I do on a trip to Europe?

Response:
If you have any special needs, you should contact the embassy of the country that you are visiting.
You should contact the embassy of the country that I will be visiting.

What are my responsibilities when I go on a trip?

Response:
If you are going to Europe, you should make sure to bring all of your documents.
If you are going to Europe, make sure that you have all of your documents.

When do you travel abroad?

Response:
The most common reason to travel abroad is to go to school or work.
The most common reason to travel abroad is to work.

How can I get a visa to Europe?

Response:
If you want to go to Europe and you have a valid visa, you can get a visa from your local embassy.
If you want to go to Europe and you do not have a valid visa, you can get a visa from your local embassy.

When should I go to Europe?

Response:
You should go to Europe when the weather is nice.
You should go to Europe when the weather is bad.

H
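
The generation above uses top-k sampling with `k=5`. As a rough illustration of the idea (not KerasNLP's actual implementation), the sketch below keeps only the `k` highest-scoring candidates at one decoding step, renormalizes their probabilities, and samples from them. The function name `top_k_sample` and the toy logits are assumptions made for this example.

```python
import numpy as np

def top_k_sample(logits, k, rng):
    """Sample one token id from the k highest-scoring logits."""
    logits = np.asarray(logits, dtype=np.float64)
    # Indices of the k largest logits (their relative order does not matter).
    top_ids = np.argpartition(logits, -k)[-k:]
    top_logits = logits[top_ids]
    # Softmax restricted to the top-k candidates (shifted for stability).
    probs = np.exp(top_logits - top_logits.max())
    probs /= probs.sum()
    return int(rng.choice(top_ids, p=probs))

rng = np.random.default_rng(2)
# Made-up scores over an 8-token vocabulary.
toy_logits = [2.0, 0.5, -1.0, 3.0, 0.0, 1.5, -2.0, 0.2]
samples = [top_k_sample(toy_logits, k=5, rng=rng) for _ in range(1000)]
print(sorted(set(samples)))
```

Tokens outside the top five (ids 2, 4, and 6 here) can never be drawn, which limits very unlikely continuations while still allowing variety.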

In [None]:
prompt = template.format(
    instruction="Explain the process of corrective action program in nuclear industry in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Explain the process of corrective action program in nuclear industry in a way that a child could understand.

Response:
I am a nuclear power plant engineer, and I am currently working at the Nuclear Power Plants of South Korea, so I know the process of corrective action program very well.

First of all, when there is a problem at the plant, the problem is analyzed and then a report is made, which contains the root cause of the problem and corrective measures to prevent the problem from occurring again.

Then the report is reviewed by the nuclear power plant management and the corrective action team is formed to implement the corrective action. The corrective action team will be responsible for developing and implementing the corrective action plan.

After that, the corrective action plan is implemented. This plan includes actions that must be taken to fix the problem.

Finally, the corrective action team will evaluate the effectiveness of corrective action and make recomme

## Conduct Fine-Tuning using LoRA

LoRA (Low-Rank Adaptation) fine-tunes a pre-trained model by freezing its original weight matrix \( W_0 \) and learning a low-rank update, the product of two small matrices \( A \) and \( B \), that is added to it:

\[
W = W_0 + A \cdot B
\]

Where:
- \( W_0 \) is the frozen pre-trained weight matrix of shape \( d \times k \).
- \( A \) (shape \( d \times r \)) and \( B \) (shape \( r \times k \)) are the trainable matrices, so their product has rank at most \( r \).
- The rank \( r \) sets the dimensionality of the trainable matrices and therefore the number of trainable parameters, \( r(d + k) \) per adapted matrix. A higher \( r \) allows more expressive adjustments but costs more compute and memory, while a lower \( r \) is cheaper but may adapt less precisely.

In essence, the LoRA rank trades computational cost against fine-tuning precision.
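
The formula can be checked numerically with a small NumPy sketch. It assumes \( A \) has shape \( d \times r \) and \( B \) has shape \( r \times k \); the sizes are illustrative rather than read from the model, and zero-initializing \( B \) is a common choice assumed here so that the adapted layer starts out identical to the frozen one.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, r = 2304, 2048, 4  # illustrative sizes, not read from the model

W0 = rng.standard_normal((d, k)).astype(np.float32)        # frozen weights
A = rng.standard_normal((d, r)).astype(np.float32) * 0.01  # trainable
B = np.zeros((r, k), dtype=np.float32)                     # trainable, zero init (assumed)

x = rng.standard_normal((1, d)).astype(np.float32)

# Effective weight from the formula: W = W0 + A @ B.
W = W0 + A @ B
y_merged = x @ W
# Equivalent forward pass that never materializes the d x k update.
y_factored = x @ W0 + (x @ A) @ B

full_params = W0.size
lora_params = A.size + B.size
print(full_params, lora_params)
```

For these sizes the frozen matrix holds 4,718,592 values while the two LoRA factors hold 17,408, which is where the savings in trainable parameters come from.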

In [None]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()
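
As a back-of-the-envelope check on the trainable parameter count reported by the summary, the sketch below adds up the LoRA factors by hand. Every number in it is an assumption rather than a value read from the loaded model: 26 decoder layers, hidden size 2304, head size 256, 8 query heads, 4 key/value heads, LoRA applied only to the query and value projections, and factor shapes of `(heads, hidden, rank)` and `(rank, head_dim)`.

```python
# All architecture numbers below are assumptions about Gemma 2 2B,
# not values read from the loaded model.
num_layers = 26
hidden_dim = 2304
head_dim = 256
num_query_heads = 8
num_kv_heads = 4
rank = 4

def lora_params(num_heads):
    # Assumed factor shapes for a (num_heads, hidden_dim, head_dim) kernel:
    # A: (num_heads, hidden_dim, rank), B: (rank, head_dim).
    return num_heads * hidden_dim * rank + rank * head_dim

# Assumed targets: the query and value projections of every layer.
per_layer = lora_params(num_query_heads) + lora_params(num_kv_heads)
total = num_layers * per_layer
print(f"{total:,}")
```

Under these assumptions the total is roughly 2.9 million. The output of `gemma_lm.summary()` remains the authoritative figure.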

In [None]:
# NOTE: LoRA reduces the number of trainable parameters significantly (from 2.6 billion to 2.9 million)
# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)

1000/1000 ━━━━━━━━━━━━━━━━━━━━ 875s 836ms/step - loss: 0.8394 - sparse_categorical_accuracy: 0.5395


<keras.src.callbacks.history.History at 0x7d570028f410>
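
The loss reported above is a next-token cross-entropy computed from logits. The NumPy sketch below illustrates that objective under stated assumptions: labels are the input tokens shifted left by one position, and padded positions are given zero weight. The helper name `masked_next_token_loss` and the toy inputs are hypothetical, and the exact preprocessing KerasNLP applies may differ.

```python
import numpy as np

def masked_next_token_loss(logits, token_ids, padding_mask):
    """Mean cross-entropy of predicting token t+1 from position t."""
    # Predictions at positions 0..T-2 are scored against tokens 1..T-1.
    logits = np.asarray(logits, dtype=np.float64)[:-1]
    targets = np.asarray(token_ids)[1:]
    # Weight is 0 wherever the label is a padding token.
    weights = np.asarray(padding_mask, dtype=np.float64)[1:]
    # Log-softmax over the vocabulary (shifted for numerical stability).
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    nll = -log_probs[np.arange(len(targets)), targets]
    return float((nll * weights).sum() / weights.sum())

vocab_size = 6
token_ids = [1, 4, 2, 0, 0]     # made-up ids; 0 stands for padding here
padding_mask = [1, 1, 1, 0, 0]  # 1 = real token, 0 = padding
uniform_logits = np.zeros((len(token_ids), vocab_size))
loss = masked_next_token_loss(uniform_logits, token_ids, padding_mask)
print(loss)
```

With uniform logits the loss equals the natural log of the vocabulary size, which is the baseline of a model that knows nothing. Passing the accuracy through `weighted_metrics` applies the same kind of per-token weighting to the metric.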

## Inference After Fine-Tuning

### Europe Trip

In [None]:
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What should I do on a trip to Europe?

Response:
It's always best to plan out your trip before you go. There are so many things to do and see in Europe that it is easy to miss some important things. You should make sure that you are prepared with the right clothing and shoes for the climate in each country. You may want to bring some souvenirs for your friends and family back home. You can also take some pictures and make memories to cherish. It is a good idea to plan out a budget before you go so that you have money for souvenirs or other things you may want to purchase while you are there. It is also a good idea to make reservations at some restaurants in advance. This will help you to avoid lines and wait times. You may also want to plan out a schedule for each day so that you can make the most of your trip.


### Nuclear Corrective Action Program (CAP)

In [None]:
prompt = template.format(
    instruction="Explain the process of corrective action program in nuclear industry in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Explain the process of corrective action program in nuclear industry in a way that a child could understand.

Response:
A corrective action program in nuclear industry is used to identify potential issues and take actions to mitigate them. The corrective action program is used by nuclear industry in two stages. The first stage is called the Root Cause Analysis (RCA). This stage is used to identify the root cause of the issue, and then to develop a corrective action plan to prevent the issue from recurring. The RCA is typically conducted by a team of experts from the nuclear industry. This team is composed of experts from the nuclear industry, the nuclear plant operator, the plant owner, and the nuclear regulatory body. The team will review all the available information and conduct interviews with all relevant stakeholders. The RCA team will then present their findings to the plant operator, plant owner, and the nuclear regulatory body. After the RCA is completed, the plant