<a href="https://colab.research.google.com/github/gmsahu/Finetuning-LLM/blob/main/fine_tuning_gemma.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-tune Gemma models in Keras using LoRA


# Overview
Gemma is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models.

Large Language Models (LLMs) like Gemma have been shown to be effective at a variety of NLP tasks. An LLM is first pre-trained on a large corpus of text in a self-supervised fashion. Pre-training helps LLMs learn general-purpose knowledge, such as statistical relationships between words. An LLM can then be fine-tuned with domain-specific data to perform downstream tasks (such as sentiment analysis).

LLMs are extremely large (on the order of billions of parameters). Full fine-tuning, which updates all the parameters in the model, is not required for most applications because typical fine-tuning datasets are much smaller than the pre-training datasets.

# Low Rank Adaptation (LoRA)
LoRA is a fine-tuning technique that greatly reduces the number of trainable parameters for downstream tasks by freezing the weights of the model and inserting a smaller number of new weights into the model. This makes training with LoRA much faster and more memory-efficient, and produces smaller model weights (a few hundred MB), all while maintaining the quality of the model outputs.
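To make the idea of frozen weights plus a small number of new weights concrete, here is a minimal pure-Python sketch of a LoRA-style layer. It is an illustration under simplifying assumptions (a single dense layer, nested lists instead of tensors, no scaling factor), not a description of how KerasNLP implements LoRA internally. The names `matmul`, `lora_forward`, `W`, `A`, and `B` are hypothetical.

```python
def matmul(x, w):
    """Multiply matrix x (n x k) by matrix w (k x m), both as nested lists."""
    return [
        [sum(x_row[i] * w[i][j] for i in range(len(w))) for j in range(len(w[0]))]
        for x_row in x
    ]


def lora_forward(x, W, A, B):
    """Compute x @ W + (x @ A) @ B.

    W is the frozen pretrained weight (d_in x d_out). A (d_in x r) and
    B (r x d_out) are the small trainable matrices, with rank r much
    smaller than d_in and d_out.
    """
    frozen = matmul(x, W)
    update = matmul(matmul(x, A), B)
    return [
        [f + u for f, u in zip(f_row, u_row)]
        for f_row, u_row in zip(frozen, update)
    ]


# Toy sizes: d_in = 3, d_out = 2, rank r = 1.
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # frozen: 3 * 2 = 6 values
A = [[0.5], [0.5], [0.5]]                 # trainable: 3 * 1 = 3 values
B = [[0.0, 0.0]]                          # trainable: 1 * 2 = 2 values, all zero

x = [[1.0, 2.0, 3.0]]
# While B is all zeros, the output equals the frozen layer's output x @ W.
print(lora_forward(x, W, A, B))
```

Starting one of the two small matrices at zero means the adapted layer initially behaves exactly like the pretrained one, and only the values in `A` and `B` change during training.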

This tutorial walks you through using KerasNLP to perform LoRA fine-tuning on a Gemma 2B model using the Databricks Dolly 15k dataset. This dataset contains 15,000 high-quality human-generated prompt / response pairs specifically designed for fine-tuning LLMs.

# Setup
# Get access to Gemma
To complete this tutorial, you will first need to complete the setup instructions at Gemma setup. The Gemma setup instructions show you how to do the following:

* Get access to Gemma on kaggle.com.
* Select a Colab runtime with sufficient resources to run the Gemma 2B model.
* Generate and configure a Kaggle username and API key.

After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.

# Select the runtime
To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU:

1. In the upper-right of the Colab window, select ▾ (Additional connection options).
2. Select Change runtime type.
3. Under Hardware accelerator, select T4 GPU.

# Configure your API key
To use Gemma, you must provide your Kaggle username and a Kaggle API key.

To generate a Kaggle API key, go to the Account tab of your Kaggle user profile and select Create New Token. This will trigger the download of a kaggle.json file containing your API credentials.

In Colab, select Secrets (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name KAGGLE_USERNAME and your API key under the name KAGGLE_KEY.



In [1]:
import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
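Outside Colab, one possible approach is to read the downloaded kaggle.json file directly and export its two fields. The sketch below assumes the file holds a JSON object with `username` and `key` entries; the helper name `load_kaggle_credentials` is hypothetical, and the demo writes a throwaway file with placeholder values rather than real credentials.

```python
import json
import os
import tempfile


def load_kaggle_credentials(path):
    """Read a kaggle.json file and export its values as environment variables."""
    with open(path, encoding="utf-8") as f:
        creds = json.load(f)
    os.environ["KAGGLE_USERNAME"] = creds["username"]
    os.environ["KAGGLE_KEY"] = creds["key"]
    return creds["username"]


# Demo with a temporary file and placeholder values (not real credentials).
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "kaggle.json")
    with open(demo_path, "w", encoding="utf-8") as f:
        json.dump({"username": "demo_user", "key": "demo_key"}, f)
    loaded_user = load_kaggle_credentials(demo_path)

print(loaded_user)
```

In real use you would pass the path to your own kaggle.json and keep that file out of version control.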




# Install dependencies
Install Keras, KerasNLP, and other dependencies.



In [2]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U keras-hub
# Quote the requirement so the shell does not treat ">=3" as a redirect.
!pip install -q -U "keras>=3"

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m691.2/691.2 kB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.2/5.2 MB[0m [31m33.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m615.3/615.3 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m78.2 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tf-keras 2.17.0 requires tensorflow<2.18,>=2.17, but you have tensorflow 2.18.0 which is incompatible.[0m[31m
[0m[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tf-keras 2.17.0 requires tensorflow<2.18,>=

# Select a backend
Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.

For this tutorial, configure the backend for JAX.

In [3]:
os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
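The backend has to be chosen before Keras is imported, because it cannot be switched once the package is loaded. The sketch below wraps that rule in a small guard. The helper name `select_keras_backend` and the `SUPPORTED_BACKENDS` tuple are hypothetical conveniences, not part of Keras.

```python
import os
import sys

SUPPORTED_BACKENDS = ("jax", "torch", "tensorflow")


def select_keras_backend(name):
    """Set KERAS_BACKEND, rejecting unknown names and too-late calls."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"Unknown backend {name!r}; expected one of {SUPPORTED_BACKENDS}"
        )
    if "keras" in sys.modules:
        # Keras reads the variable at import time, so changing it now has no effect.
        raise RuntimeError(
            "keras is already imported; restart the runtime to switch backends"
        )
    os.environ["KERAS_BACKEND"] = name
    return name


select_keras_backend("jax")
print(os.environ["KERAS_BACKEND"])
```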


# Import packages
Import Keras, KerasHub, and KerasNLP.

In [4]:
import keras
import keras_hub
import keras_nlp

# Load dataset

In [5]:
!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl



Preprocess the data. This tutorial uses a subset of 1000 training examples to execute the notebook faster. Consider using more training data for higher quality fine-tuning.

In [6]:
import json

# Format each example as a single string with this template.
# Defined once here because later cells reuse it to build prompts.
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

data = []
with open("databricks-dolly-15k.jsonl", encoding="utf-8") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]
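The cell above drops every example that has a context passage. If you later want to keep those examples, one possible variant is sketched below. The `Context:` section in `TEMPLATE_WITH_CONTEXT` is an assumption made for illustration and is not part of this tutorial's template; `format_example` and `load_examples` are hypothetical helper names, and the two sample records are made up for the demo.

```python
import json

TEMPLATE = "Instruction:\n{instruction}\n\nResponse:\n{response}"
# Hypothetical template for examples that carry a context passage.
TEMPLATE_WITH_CONTEXT = (
    "Instruction:\n{instruction}\n\nContext:\n{context}\n\nResponse:\n{response}"
)


def format_example(features):
    """Format one record as a single training string."""
    if features.get("context"):
        return TEMPLATE_WITH_CONTEXT.format(**features)
    return TEMPLATE.format(**features)


def load_examples(lines, limit=1000):
    """Parse JSONL lines and return at most `limit` formatted examples."""
    examples = []
    for line in lines:
        if not line.strip():
            continue  # skip blank lines
        examples.append(format_example(json.loads(line)))
        if len(examples) >= limit:
            break
    return examples


# Made-up sample records using the same field names as the dataset.
sample_lines = [
    json.dumps({"instruction": "Name a color.", "context": "", "response": "Blue."}),
    json.dumps(
        {
            "instruction": "Who wrote it?",
            "context": "The note was signed by Ada.",
            "response": "Ada.",
        }
    ),
]
examples = load_examples(sample_lines)
print(examples[1])
```

Examples with context are longer, so they are more likely to be truncated by whatever sequence length limit you set for training.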


In [7]:
data[1]

'Instruction:\nWhy can camels survive for long without water?\n\nResponse:\nCamels use the fat in their humps to keep them filled with energy and hydration for long periods of time.'

# Load Model
KerasNLP provides implementations of many popular model architectures. In this tutorial, you'll create a model using GemmaCausalLM, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.

Create the model using the from_preset method:

In [8]:
gemma_llm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
#gemma_llm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")

gemma_llm.summary()

The from_preset method instantiates the model from a preset architecture and weights. In the code above, the string "gemma_2b_en" specifies the preset architecture: a Gemma model with 2 billion parameters.

NOTE: A Gemma model with 7 billion parameters is also available. To run the larger model in Colab, you need access to the premium GPUs available in paid plans. Alternatively, you can perform distributed tuning on a Gemma 7B model on Kaggle or Google Cloud.

# Inference before fine-tuning
In this section, you will query the model with various prompts to see how it responds.

# History of India Prompt
Query the model about the history of India.

In [9]:
prompt = template.format(
    instruction="What is the history of India?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_llm.compile(sampler=sampler)
print(gemma_llm.generate(prompt, max_length=256))

Instruction:
What is the history of India?

Response:
Indian civilization dates back over 5000 years. The Indus River valley is considered to be the birthplace of the Indian civilization. It is believed the first city was established in the Harrappa Valley, which later developed into the Indus Valley Civilization. This civilization is considered to be the first urban civilization of the ancient world. The first city of the Indus Valley Civilization, Harrappa was a planned city.

The civilization is divided into four phases: Harrapa, the first phase was a small town with a population of 5,000 people. The second phase was a large town. The third phase was the urban phase, where the city grew into a major metropolis of 150,000 people. The fourth phase was the decline phase, where the civilization declined and the city was abandoned.

The Harrapa civilization is considered to be one of the world's first urban civilization and the first planned city of the ancient world. It is believed to h
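The cell above compiles the model with a top-k sampler (k=5). As a rough illustration of the general technique, the sketch below keeps only the k highest-scoring tokens and samples among them in proportion to their softmax probabilities. It is a simplified stand-in written with the standard library, not the KerasNLP implementation; `top_k_sample` and the toy `logits` are hypothetical.

```python
import math
import random


def top_k_sample(logits, k, rng):
    """Sample a token index from the k highest-scoring logits."""
    # Indices of the k largest logits.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax weights over only those k logits (subtract the max for stability).
    peak = max(logits[i] for i in top)
    weights = [math.exp(logits[i] - peak) for i in top]
    return rng.choices(top, weights=weights, k=1)[0]


rng = random.Random(2)
logits = [2.0, 0.5, -1.0, 3.0, 0.0, 1.5]  # toy scores for a 6-token vocabulary
samples = [top_k_sample(logits, k=3, rng=rng) for _ in range(1000)]
# With k=3, only the three best indices (3, 0, and 5) can ever be drawn.
print(sorted(set(samples)))
```

A small k keeps the output focused on likely tokens, while a larger k allows more varied text.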

# ELI5 Hurricane Prompt
Prompt the model to explain how hurricanes happen in terms simple enough for a 7-year-old child to understand.



In [10]:
prompt = template.format(
    instruction="Explain the process of Hurricane in a way that a 7 years old could understand.",
    response="",
)
print(gemma_llm.generate(prompt, max_length=256))


Instruction:
Explain the process of Hurricane in a way that a 7 years old could understand.

Response:
Hurricanes are the most powerful and dangerous weather phenomenon on Earth. The wind and waves that accompany hurricanes can kill people and cause billions of dollars worth of damage.

Hurricanes are a natural phenomenon that occur when warm water and wind come together to create the perfect conditions for a hurricane. When the air gets warm, it rises up. This makes the air around the hurricane get colder. As the air cools, it creates a pressure difference that causes the hurricane to spin. The hurricane is then able to suck in air from other places around it and create more wind and rain.

Hurricanes can happen anywhere in the world but they tend to happen more commonly in the Caribbean and the Atlantic Ocean. The most common hurricane season is from August to October but they can occur at any time of the year.

When a hurricane is coming, the first thing that should be done is to ev
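Both prompts in this section are built the same way: the training template is filled with an instruction and an empty response, so the model continues the text after the response marker. A small helper can make that pattern explicit. The name `make_prompt` is hypothetical, and `TEMPLATE` repeats the template string used in the preprocessing cell.

```python
TEMPLATE = "Instruction:\n{instruction}\n\nResponse:\n{response}"


def make_prompt(instruction):
    """Build an inference prompt: the training template with an empty response."""
    return TEMPLATE.format(instruction=instruction.strip(), response="")


prompt = make_prompt("What is the history of India?")
print(repr(prompt))
```

Using the same template for inference and fine-tuning keeps the prompt format consistent with what the model sees during training.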

# LoRA Fine-tuning
To get better responses from the model, fine-tune the model with Low Rank Adaptation (LoRA) using the Databricks Dolly 15k dataset.

The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.

A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.

This tutorial uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, or 16), which is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials to see whether that further boosts performance.
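To see how the rank drives the number of trainable parameters, consider a single weight matrix adapted with two low-rank factors. The sketch below is back-of-the-envelope arithmetic under assumed shapes: the 2048 x 2048 layer is purely illustrative and does not describe Gemma's actual layer layout, and `lora_trainable_params` is a hypothetical helper.

```python
def lora_trainable_params(d_in, d_out, rank):
    """Parameters in the two low-rank factors: (d_in x rank) and (rank x d_out)."""
    return rank * (d_in + d_out)


d_in, d_out = 2048, 2048  # illustrative layer shape (an assumption)
frozen = d_in * d_out
for rank in (4, 8, 16):
    added = lora_trainable_params(d_in, d_out, rank)
    print(f"rank={rank:>2}: {added} trainable vs {frozen} frozen ({added / frozen:.2%})")
```

The count grows linearly with the rank, so doubling the rank doubles the trainable parameters while the frozen weights stay the same size.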

In [11]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_llm.backbone.enable_lora(rank=4)
gemma_llm.summary()

Note that enabling LoRA reduces the number of trainable parameters significantly (from 2.5 billion to 1.3 million).



In [12]:
# Limit the input sequence length to 512 (to control memory usage).
gemma_llm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_llm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_llm.fit(data, epochs=1, batch_size=1)

1000/1000 ━━━━━━━━━━━━━━━━━━━━ 1343s 1s/step - loss: 0.4589 - sparse_categorical_accuracy: 0.5231


<keras.src.callbacks.history.History at 0x7805861730d0>

# Note on mixed precision fine-tuning on NVIDIA GPUs
Full precision is recommended for fine-tuning. When fine-tuning on NVIDIA GPUs, you can use mixed precision (keras.mixed_precision.set_global_policy('mixed_bfloat16')) to speed up training with minimal effect on training quality. Mixed precision fine-tuning consumes more memory, so it is useful only on larger GPUs.

For inference, half precision (keras.config.set_floatx("bfloat16")) works and saves memory, while mixed precision is not applicable.



In [None]:
# Uncomment the line below if you want to enable mixed precision training on GPUs
#keras.mixed_precision.set_global_policy('mixed_bfloat16')
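As a rough guide to why half precision saves memory at inference time, the sketch below estimates the memory needed for the weights alone at 4 bytes per value (float32) versus 2 bytes per value (bfloat16). It uses the approximate 2.5 billion parameter count quoted earlier in this tutorial and deliberately ignores activations, optimizer state, and framework overhead. The helper `weight_memory_gib` is hypothetical.

```python
BYTES_PER_VALUE = {"float32": 4, "bfloat16": 2}


def weight_memory_gib(num_params, dtype):
    """Approximate memory needed to hold the weights alone, in GiB."""
    return num_params * BYTES_PER_VALUE[dtype] / 2**30


num_params = 2_500_000_000  # approximate count; weights only
for dtype in ("float32", "bfloat16"):
    print(f"{dtype}: {weight_memory_gib(num_params, dtype):.1f} GiB")
```

Actual usage on a GPU will be higher than these figures, so treat them as a lower bound when choosing a runtime.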


# Inference after fine-tuning
After fine-tuning, responses follow the instruction provided in the prompt.



In [13]:
#template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

prompt = template.format(
    instruction="What should I do to save a river?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_llm.compile(sampler=sampler)
print(gemma_llm.generate(prompt, max_length=256))

Instruction:
What should I do to save a river?

Response:
To save a river, it is important to understand where the river starts from its headwaters and where it flows to the ocean. Once this is understood then you should work towards protecting the watershed.


In [14]:
prompt = template.format(
    instruction="Explain the process of solar ecclipse in a way that a child could understand.",
    response="",
)
print(gemma_llm.generate(prompt, max_length=512))

Instruction:
Explain the process of solar ecclipse in a way that a child could understand.

Response:
The solar eclipse occurs when the moon passes between the Earth and the sun, and blocks part of the sun's visible rays. The sun, the moon and the Earth all lie in a line in the sky, so the shadow from the moon will fall directly on one of the two points where the sun and the earth meet. This shadow is called the "umbra". The other point where the sun and the Earth meet will not be shadowed, because the shadow from the moon is not large enough to reach that far out. The part of the sky that is shadowed by the moon is called the "annulus".

The shadow of the moon moves around the Earth as it orbits the Sun, so the eclipse lasts for a few minutes at each point on Earth, and the length of time that it lasts varies depending on where you are. The shadow will be at its largest and darkest when it falls on the middle of Earth, and it will become larger as you move away from that point in the 
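As the printed outputs show, the generated text includes the prompt followed by the model's answer. If you only want the answer, one simple option is to split on the response marker from the template. The helper `extract_response` below is a hypothetical sketch that assumes the marker appears exactly as in the template used in this tutorial.

```python
RESPONSE_MARKER = "Response:\n"


def extract_response(generated):
    """Return the text after the first response marker, or the whole string."""
    _, marker, answer = generated.partition(RESPONSE_MARKER)
    return answer.strip() if marker else generated.strip()


generated = (
    "Instruction:\nWhat should I do to save a river?\n\n"
    "Response:\nProtect the watershed."
)
print(extract_response(generated))
```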

Note that for demonstration purposes, this tutorial fine-tunes the model on a small subset of the dataset for just one epoch and with a low LoRA rank value. To get better responses from the fine-tuned model, you can experiment with:

* Increasing the size of the fine-tuning dataset.
* Training for more steps (epochs).
* Setting a higher LoRA rank.
* Modifying hyperparameter values such as learning_rate and weight_decay.
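One way to organize these experiments is to enumerate the settings you want to try before running anything. The sketch below builds such a grid with the standard library. Apart from the values this tutorial uses (rank 4, learning rate 5e-5, one epoch), the candidate values are arbitrary examples, and `experiment_grid` is a hypothetical helper that only lists configurations without training anything.

```python
import itertools


def experiment_grid(ranks, learning_rates, epochs):
    """Return one config dict per combination of the given settings."""
    return [
        {"rank": rank, "learning_rate": lr, "epochs": n_epochs}
        for rank, lr, n_epochs in itertools.product(ranks, learning_rates, epochs)
    ]


grid = experiment_grid(
    ranks=(4, 8, 16),
    learning_rates=(5e-5, 1e-4),
    epochs=(1, 2),
)
print(len(grid))  # 3 * 2 * 2 = 12 configurations
print(grid[0])
```

Each configuration would normally be run starting from a freshly loaded model so that results from one run do not carry over into the next.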


# Summary and next steps
This tutorial covered LoRA fine-tuning on a Gemma model using KerasNLP. Check out the following docs next:

1.   Learn how to [generate text with a Gemma model](https://).
2.   Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://).
3.   Learn how to use [Gemma open models with Vertex AI](https://).
4.   Learn how to [fine-tune Gemma using KerasNLP and deploy to Vertex AI](https://).