<link rel="stylesheet" href="/site-assets/css/gemma.css">
<link rel="stylesheet" href="https://fonts.googleapis.com/css2?family=Google+Symbols:opsz,wght,FILL,GRAD@20..48,100..700,0..1,-50..200" />

##### Copyright 2025 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fine-tune Gemma in Keras using LoRA

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://ai.google.dev/gemma/docs/core/lora_tuning"><img src="https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png" height="32" width="32" />View on ai.google.dev</a>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/core/lora_tuning.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/google/generative-ai-docs/main/site/en/gemma/docs/core/lora_tuning.ipynb"><img src="https://ai.google.dev/images/cloud-icon.svg" width="40" />Open in Vertex AI</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/core/lora_tuning.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

Generative artificial intelligent (AI) models like Gemma are effective at a variety of tasks. You can further fine-tune Gemma models with domain-specific data to perform tasks such as sentiment analysis. However, full fine-tuning of generative models by updating billions of parameters is resource intensive, requiring specialized hardware, such as GPUs, processing time, and memory to load the model parameters.

[Low Rank Adaptation](https://arxiv.org/abs/2106.09685) (LoRA) is a fine-tuning technique which greatly reduces the number of trainable parameters for downstream tasks by freezing the weights of the model and inserting a smaller number of new weights into the model. This technique makes training with LoRA much faster and more memory-efficient, and produces smaller model weights (a few hundred MBs), all while maintaining the quality of the model outputs. This tutorial walks you through using Keras to perform LoRA fine-tuning on a Gemma model.

## Setup

To complete this tutorial, you will first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:

* Get access to Gemma on [kaggle.com](https://kaggle.com).
* Select a Colab runtime with sufficient resources to tune
  the Gemma model you want to run. [Learn more](https://ai.google.dev/gemma/docs/core#sizes).
* Generate and configure a Kaggle username and API key.

After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.

### Select a Colab runtime

To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU:

1. In the upper-right of the Colab window, select &#9662; (**Additional connection options**).
2. Select **Change runtime type**.
3. Under **Hardware accelerator**, select **T4 GPU**.

### Configure your API key

To use Gemma, you must provide your Kaggle username and a Kaggle API key.

To generate a Kaggle API key, go to the **Account** tab of your Kaggle user profile and select **Create New Token**. This triggers the download of a `kaggle.json` file containing your API credentials.

In Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`.

### Set environment variables

Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`.

In [2]:
import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

### Install Keras packages

Install the Keras and KerasHub Python packages.

In [3]:
!pip install -q -U keras-hub
!pip install  -q -U keras

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/947.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━[0m [32m419.8/947.9 kB[0m [31m12.3 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m947.9/947.9 kB[0m [31m17.5 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
keras-nlp 0.21.1 requires keras-hub==0.21.1, but you have keras-hub 0.22.2 which is incompatible.[0m[31m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m17.2 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
keras-nlp 0.21.1 requires keras

### Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch. For this tutorial, configure the backend for JAX as it typically provides the better performance.

In [4]:
os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

### Import packages

Import the Python packages needed for this tutorial, including Keras and KerasHub.

In [5]:
import keras
import keras_hub

## Load model

Keras provides implementations of Gemma and many other popular [model architectures](https://keras.io/keras_hub/api/models/). Use the `Gemma3CausalLM.from_preset()` method to configure an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.

In [9]:
gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_instruct_1b")
gemma_lm.summary()

The `Gemma3CausalLM.from_preset()` method instantiates the model from a preset architecture and weights. In the code above, the string `"gemma#_xxxxxxx"` specifies a preset version and parameter size for Gemma. You can find the code strings for Gemma models in their **Model Variation** listings on [Kaggle](https://www.kaggle.com/models/keras/gemma3).

## Inference before fine tuning

Once you have downloaded and configured a Gemma model, you can query it with various prompts to see how it responds.

### Europe trip prompt

Query the model for suggestions on what to do on a trip to Europe.

In [10]:
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_hub.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What should I do on a trip to Europe?

Response:
Europe offers a huge and diverse range of experiences! To help you plan the perfect trip, here's a breakdown of suggestions based on different interests and budgets:

**1. Popular Destinations (Good for First-Timers):**

*   **Paris:** Iconic landmarks like the Eiffel Tower, Louvre Museum, and Notre Dame Cathedral. Romantic atmosphere, world-class cuisine.
*   **Rome:** Ancient history, stunning architecture (Colosseum, Roman Forum), delicious pasta and pizza.
*   **London:** Royal history, vibrant theater scene, museums galore (British Museum, National Gallery).
*   **Barcelona:** Unique architecture (Gaudí's Sagrada Familia), beaches, lively nightlife.
*   **Amsterdam:** Canals, museums (Rijksmuseum, Van Gogh Museum), charming architecture, cycling culture.

**2. Budget-Friendly Options:**

*   **Portugal:** Affordable cost of living, beautiful coastline, delicious seafood, charming towns.
*   **Czech Republic:** Historic 

The model responds with generic tips on how to plan a trip.

### Photosynthesis prompt

Prompt the model to explain photosynthesis in terms simple enough for a 5 year old child to understand.

In [17]:
prompt = template.format(
    instruction="Analyze this SMS and classify it as transactional or non-transactional. Extract relevant financial information and return as JSON with fields: is_transaction, transaction_type, amount, account, card_number, transaction_id, date, time, bank, recipient, sender. Use null for missing fields. input_sms: 13-Nov-23 17:16 நேரப்படி 50% டேட்டா ஒதுக்கீட்டை பயன்படுத்திவிட்டீர்கள் . ஜியோ எண்: 9500502953 திட்டத்தின் படி தினசரி டேட்டா ஒதுக்கீடு: 1.50 GB MyJio ஆப் மூலம் உங்கள் கணக்கை எவ்வாறு நிர்வகிப்பது என்பது பற்றி மேலும் அறிய, https://youtu.be/FZlrHSEv_Po என்பதைக் கிளிக் செய்க உங்கள் தற்போதைய பேலன்ஸ், வேலிடிட்டி, பிளான் விவரங்கள் மற்றும் ரீசார்ஜ் திட்டங்களுக்கு 1991 ஐ டயல் செய்யுங்கள்.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Analyze this SMS and classify it as transactional or non-transactional. Extract relevant financial information and return as JSON with fields: is_transaction, transaction_type, amount, account, card_number, transaction_id, date, time, bank, recipient, sender. Use null for missing fields. input_sms: 13-Nov-23 17:16 நேரப்படி 50% டேட்டா ஒதுக்கீட்டை பயன்படுத்திவிட்டீர்கள் . ஜியோ எண்: 9500502953 திட்டத்தின் படி தினசரி டேட்டா ஒதுக்கீடு: 1.50 GB MyJio ஆப் மூலம் உங்கள் கணக்கை எவ்வாறு நிர்வகிப்பது என்பது பற்றி மேலும் அறிய, https://youtu.be/FZlrHSEv_Po என்பதைக் கிளிக் செய்க உங்கள் தற்போதைய பேலன்ஸ், வேலிடிட்டி, பிளான் விவரங்கள் மற்றும் ரீசார்ஜ் திட்டங்களுக்கு 1991 ஐ டயல் செய்யுங்கள்.

Response:
```json
{
  "is_transaction": true,
  "transaction_type": "Data Upgrade",
  "amount": 1500,
  "account": "MyJio",
  "card_


The model response contains words that might not be easy to understand for a child such as chlorophyll.

## LoRA fine-tuning

This section shows you how to do fine-tuning using the Low Rank Adaptation (LoRA) tuning technique. This approach allows you to change the behavior of Gemma models using fewer compute resources.

### Load dataset

Prepare a dataset for tuning by downloading an existing data set and formatting if for use with the the Keras `fit()` fine-tuning method. This tutorial uses the [Databricks Dolly 15k dataset](https://huggingface.co/datasets/databricks/databricks-dolly-15k) for fine-tuning. The dataset contains 15,000 high-quality human-generated prompt and response pairs specifically designed for tuning generative models.

In [21]:
!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl

--2025-09-17 17:02:20--  https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
Resolving huggingface.co (huggingface.co)... 3.163.189.74, 3.163.189.90, 3.163.189.114, ...
Connecting to huggingface.co (huggingface.co)|3.163.189.74|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cas-bridge.xethub.hf.co/xet-bridge-us/64358e2179c45fcf1ada09f4/63c4dabe683d7254493568d2d3995c0e51abc8528ef3b4936497c538cb501e93?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=cas%2F20250917%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250917T170220Z&X-Amz-Expires=3600&X-Amz-Signature=290d390cdcd3bc6f8ab0252eed91bb7e72bbbd8f943c068f911d2a631f3a27f1&X-Amz-SignedHeaders=host&X-Xet-Cas-Uid=public&response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&x-id=GetObject&Expires=1758132140&Policy=eyJTdGF0ZW1lbnQi

### Format tuning data

Format the downloaded data for use with the Keras `fit()` method. The following code extracts a subset of the training examples to execute the notebook faster. Consider using more training data for higher quality fine-tuning.

In [38]:
import json

prompts = []
responses = []
context = []

line_count = 0

filePath = "sample_data/converted_dataset.jsonl"


with open(filePath) as file:

    for line in file:
        if line_count >= 5000:
            break  # Limit the training examples, to reduce execution time.
        # print(line)
        examples = json.loads(line)
        # Filter out examples with context, to keep it simple.
        # if examples["context"]:
            # continue
        # Format data into prompts and response lists.
        prompts.append(examples["instruction"])
        context.append(examples["context"])
        responses.append(examples["response"])
        line_count += 1

data = {
    "prompts": prompts,
    "context": context,
    "responses": responses
}

print(data)

{'prompts': ['Analyze this SMS and classify it as transactional or non-transactional. Extract relevant financial information and return as JSON with fields: is_transaction, transaction_type, amount, account, card_number, transaction_id, date, time, bank, recipient, sender. Use null for missing fields.', 'Analyze this SMS and classify it as transactional or non-transactional. Extract relevant financial information and return as JSON with fields: is_transaction, transaction_type, amount, account, card_number, transaction_id, date, time, bank, recipient, sender. Use null for missing fields.', 'Analyze this SMS and classify it as transactional or non-transactional. Extract relevant financial information and return as JSON with fields: is_transaction, transaction_type, amount, account, card_number, transaction_id, date, time, bank, recipient, sender. Use null for missing fields.', 'Analyze this SMS and classify it as transactional or non-transactional. Extract relevant financial information

### Configure LoRA tuning

Activate LoRA tuning using the Keras `model.backbone.enable_lora()` method, including a LoRA rank value. The *LoRA rank* determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments. A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.

This example uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, 16). This setting is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance.

In [33]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)

ValueError: lora is already enabled. This can only be done once per layer.

Check the model summary after setting the LoRA rank. Notice that enabling LoRA reduces the number of trainable parameters significantly compared to the total number of parameters in the model:

In [39]:
gemma_lm.summary()

Configure the rest of the fine-tuning settings, including the preprocessor settings, optimizer, number of tuning epochs, and batch size:

In [40]:
# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

### Run the fine-tune process

Run the fine-tuning process using the `fit()` method. This process can take several minutes depending on your compute resources, data size, and number of epochs:

In [1]:
gemma_lm.fit(data, epochs=1, batch_size=1)

NameError: name 'gemma_lm' is not defined

#### Mixed precision fine-tuning on NVIDIA GPUs

Full precision is recommended for fine-tuning. When fine-tuning on NVIDIA GPUs, you can use mixed precision (`keras.mixed_precision.set_global_policy('mixed_bfloat16')`) to speed up training with minimal effect on training quality.

In [None]:
# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

## Inference after fine-tuning

After fine-tuning, you should see changes in the responses when the tuned model is given the same prompt.

### Europe trip prompt

Try the Europe trip prompt from earlier and note the differences in the response.

In [37]:
prompt = template.format(
    instruction="Analyze this SMS and classify it as transactional or non-transactional. Extract relevant financial information and return as JSON with fields: is_transaction, transaction_type, amount, account, card_number, transaction_id, date, time, bank, recipient, sender. Use null for missing fields. input_sms: 13-Nov-23 17:16 நேரப்படி 50% டேட்டா ஒதுக்கீட்டை பயன்படுத்திவிட்டீர்கள் . ஜியோ எண்: 9500502953 திட்டத்தின் படி தினசரி டேட்டா ஒதுக்கீடு: 1.50 GB MyJio ஆப் மூலம் உங்கள் கணக்கை எவ்வாறு நிர்வகிப்பது என்பது பற்றி மேலும் அறிய, https://youtu.be/FZlrHSEv_Po என்பதைக் கிளிக் செய்க உங்கள் தற்போதைய பேலன்ஸ், வேலிடிட்டி, பிளான் விவரங்கள் மற்றும் ரீசார்ஜ் திட்டங்களுக்கு 1991 ஐ டயல் செய்யுங்கள்.",
    response="",
)
sampler = keras_hub.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Analyze this SMS and classify it as transactional or non-transactional. Extract relevant financial information and return as JSON with fields: is_transaction, transaction_type, amount, account, card_number, transaction_id, date, time, bank, recipient, sender. Use null for missing fields. input_sms: 13-Nov-23 17:16 நேரப்படி 50% டேட்டா ஒதுக்கீட்டை பயன்படுத்திவிட்டீர்கள் . ஜியோ எண்: 9500502953 திட்டத்தின் படி தினசரி டேட்டா ஒதுக்கீடு: 1.50 GB MyJio ஆப் மூலம் உங்கள் கணக்கை எவ்வாறு நிர்வகிப்பது என்பது பற்றி மேலும் அறிய, https://youtu.be/FZlrHSEv_Po என்பதைக் கிளிக் செய்க உங்கள் தற்போதைய பேலன்ஸ், வேலிடிட்டி, பிளான் விவரங்கள் மற்றும் ரீசார்ஜ் திட்டங்களுக்கு 1991 ஐ டயல் செய்யுங்கள்.

Response:
```json
{
  "is_transaction": true,
  "transaction_type": "Reload",
  "amount": 1991,
  "account": null,
  "card_number": null,


The model now provides a shorter response to a question about visiting Europe.

### Photosynthesis prompt

Try the photosynthesis explanation prompt from earlier and note the differences in the response.

In [None]:
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
The process of photosynthesis is a chemical reaction in plants that converts the energy of sunlight into chemical energy, which the plants can then use to grow and develop. During photosynthesis, a plant will absorb carbon dioxide (CO2) from the air and water from the soil and use the energy from the sun to produce oxygen (O2) and sugars (glucose) as a by-product.


The model now explains photosynthesis in simpler terms.

## Improving fine-tune results

For demonstration purposes, this tutorial fine-tunes the model on a small subset of the dataset for just one epoch and with a low LoRA rank value. To get better responses from the fine-tuned model, you can experiment with:

1. Increasing the size of the fine-tuning dataset
2. Training for more steps (epochs)
3. Setting a higher LoRA rank
4. Modifying the hyperparameter values such as `learning_rate` and `weight_decay`.

## Summary and next steps

This tutorial covered LoRA fine-tuning on a Gemma model using Keras. Check out the following docs next:

* Learn how to [generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).
* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/core/distributed_tuning).
* Learn how to [use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).
* Learn how to [fine-tune Gemma using Keras and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb).