<a href="https://colab.research.google.com/github/cyber-prags/Fine-Tuning-Gemma-model-using-LoRA/blob/main/Fine_Tuning_Gemma_models_in_Keras_using_LoRa.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Fine-Tuning Gemma Models in Keras using LoRA**

## **Project Description**


In this project, we fine-tune a Google Gemma model in Keras using the Low-Rank Adaptation technique.

Basic definitions of the concepts and techniques used in the notebook:

- **Gemma**: A family of lightweight, open-weight language models from Google, designed for efficient fine-tuning and inference.

- **LoRA (Low-Rank Adaptation)**: An efficient fine-tuning technique that freezes the pretrained weights and injects small trainable low-rank matrices into transformer layers, reducing memory and computation costs.

- **Top-K Sampling**: A decoding technique for text generation in which the model samples the next token from only the K most probable tokens.
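
To make the Top-K idea concrete, here is a minimal sketch in plain Python. It is not the KerasNLP sampler used later in the notebook; the function name `top_k_sample` and the logit values are made up for illustration.

```python
import math
import random

def top_k_sample(logits, k, rng):
    """Sample one token index from the k highest-scoring logits."""
    # Keep the indices of the k largest logits.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax over the surviving logits only (shifted by the max for stability).
    m = max(logits[i] for i in top)
    weights = [math.exp(logits[i] - m) for i in top]
    # Draw one index in proportion to its renormalized probability.
    return rng.choices(top, weights=weights, k=1)[0]

rng = random.Random(2)
logits = [2.0, 0.5, -1.0, 3.0, 0.0, 1.5]  # made-up scores for a 6-token vocabulary
samples = [top_k_sample(logits, k=3, rng=rng) for _ in range(1000)]
print(sorted(set(samples)))  # every sampled index is one of the three best: 0, 3, 5
```

With `k=1` the sketch always returns the single most probable token, which is greedy decoding; larger values of `k` trade determinism for variety.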

## **Setup**

Let us quickly set up our environment and our API keys to get started.

To run this project we need a **Kaggle API key**, because the Gemma model weights are downloaded from Kaggle.

In [None]:
import os
from google.colab import userdata

os.environ['KAGGLE_USERNAME'] = userdata.get('KAGGLE_USERNAME')
os.environ['KAGGLE_KEY'] = userdata.get('KAGGLE_KEY')

Install the necessary Keras libraries.

We use the KerasNLP library, which offers tokenizers, models, and samplers designed specifically for Natural Language Processing tasks.

In [None]:
# !pip install -q -U keras-nlp
# Quote the requirement so the shell does not treat ">" as an output redirect
# !pip install -q -U "keras>=3"

## **Selecting a backend**

We select **JAX** as our Keras backend. The backend must be set before Keras is imported.

In [None]:
os.environ['KERAS_BACKEND'] = "jax"

# to avoid memory fragmentation on JAX backend
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = "1.00"

## **Import our packages**

In [None]:
import keras
import keras_nlp

## **Load our Dataset**

We make use of the ***databricks-dolly-15k*** dataset which is an open-source collection of 15,000 high-quality, human-generated prompt-response pairs designed specifically for instruction tuning large language models (LLMs). Developed by over 5,000 Databricks employees during March and April 2023, this dataset aims to enhance the instruction-following capabilities of LLMs, enabling more interactive and accurate responses.




In [None]:
!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl


In [None]:
import json

# Prompt template shared by the training examples and the inference prompts
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

data = []
with open('databricks-dolly-15k.jsonl') as file:
  for line in file:
    features = json.loads(line)
    # Filter out examples with context, to keep it simple
    if features["context"]:
      continue
    # Format the entire example as a single string
    data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast
data = data[:1000]
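
To check the filtering and formatting logic without downloading the file, the same steps can be sketched on a couple of in-memory records. The helper `format_examples` and the two sample records are hypothetical; only the field names (`instruction`, `context`, `response`, `category`) follow the dataset.

```python
import json

TEMPLATE = "Instruction:\n{instruction}\n\nResponse:\n{response}"

def format_examples(lines, limit=None):
    """Turn JSONL lines into prompt strings, skipping records that carry context."""
    examples = []
    for line in lines:
        record = json.loads(line)
        # Records with a non-empty context are dropped, as in the cell above.
        if record["context"]:
            continue
        # Extra keys such as "category" are ignored by str.format.
        examples.append(TEMPLATE.format(**record))
    return examples[:limit]

# Two made-up records: one without context, one with context.
sample_lines = [
    json.dumps({"instruction": "Name a primary color.", "context": "",
                "response": "Red", "category": "open_qa"}),
    json.dumps({"instruction": "Summarize the passage.", "context": "Some passage.",
                "response": "A summary.", "category": "summarization"}),
]
formatted = format_examples(sample_lines)
print(formatted[0])
```

Only the first record survives the filter, so `formatted` holds a single string in the same shape as `data[0]`.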

In [None]:
data[0]

'Instruction:\nWhich is a species of fish? Tope or Rope\n\nResponse:\nTope'

## **Load model**

In [None]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()


Our Gemma model has about **2.5 billion** trainable parameters.

## **Inference before Fine-Tuning**

Before fine-tuning, we ask the base model two questions and evaluate its responses. This gives us a baseline to compare against after fine-tuning.

### Question about Travel

In [None]:
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What should I do on a trip to Europe?

Response:
The first thing to remember is to be flexible, because the weather in Europe is always changing. If it is sunny one day and raining the next day, don’t be surprised.

The next thing to remember is to bring a jacket. Even if it’s warm, it can get chilly in the mountains or at night.

The third thing is to pack light. You don’t want to lug around a heavy bag all day. A small backpack is perfect for this.

The fourth thing to remember is to bring a good map. You never know when you might get lost.

The fifth thing is to bring sunscreen. The sun can be very strong in Europe, so you’ll want to protect your skin.

The sixth thing to remember is to bring a good camera. You’ll want to capture all the beautiful sights that Europe has to offer.

The seventh thing to remember is to be polite. In Europe, it is important to show respect for other people and their culture.

The eighth thing to remember is to be patient. The lines for trai

### Question about Photosynthesis

In [None]:
prompt = template.format(
    instruction="Explain photosynthesis in a way a child could understand.",
    response="",
)

print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Explain photosynthesis in a way a child could understand.

Response:
Plants need energy from the sun to make food. Plants can make food in two different ways, but they all start with the same process.
Plants use energy from the sun to make food. They take in carbon dioxide from the air. They take in water. They take in sunlight. They take in nutrients. They use energy from the sun. They break down the carbon dioxide. They break down the water. They use energy from the sun. They break down the nutrients. They use the energy they have made from the sunlight for energy. They use the energy they make to make food and oxygen. Plants take in the carbon dioxide, water, and sunlight. Plants make oxygen as a by-product. Plants need to have water, carbon dioxide, and sunlight.

The process of photosynthesis is the process by which green plants use light energy from the sun, carbon dioxide, and water to create glucose and oxygen. The reaction of the process is:
6CO2+6H₂O+light energy

#### **Initial Thoughts**

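The base Gemma model gives us reasonable responses, but they are generic and repetitive: the travel answer is a list of packing tips, and the photosynthesis answer repeats itself and is not pitched at a child. Fine-tuning can improve the quality of these answers. Let us do that in our next section.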

## **LoRA Fine-Tuning**

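We enable LoRA on the model and set its rank. The rank is the inner dimension of the low-rank decomposition: a higher rank gives the adapter more capacity but adds more trainable parameters.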

In [None]:
# Enable LoRA for the model and set the rank to 4
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Enabling LoRA reduces the number of trainable parameters significantly (from 2.5 billion to 1.3 million).
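
To see where the savings come from, here is a rough sketch of the parameter arithmetic for a single weight matrix. The layer dimensions below are hypothetical and chosen only for illustration; the exact total for Gemma depends on which layers KerasNLP adapts and on their shapes, which this sketch does not reproduce.

```python
def full_params(d_in, d_out):
    """Parameters in a dense weight matrix of shape (d_in, d_out)."""
    return d_in * d_out

def lora_params(d_in, d_out, rank):
    """Parameters in the two LoRA factors: A is (d_in, rank), B is (rank, d_out)."""
    return rank * (d_in + d_out)

# Hypothetical projection layer, with rank 4 as in the cell above.
d_in, d_out, rank = 2048, 2048, 4
full = full_params(d_in, d_out)
low_rank = lora_params(d_in, d_out, rank)
print(full, low_rank, full // low_rank)  # 4194304 16384 256
```

For this example layer, the trainable LoRA factors are 256 times smaller than the frozen weight matrix they adapt.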

In [None]:
# Limit the input sequence length to 512 (to control memory usage)
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models)
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layer-norm scale and bias terms from weight decay
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

# Compile the model
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
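
As a side note on the loss used above, the following is a minimal plain-Python sketch of what sparse categorical cross-entropy computes for a single token when `from_logits=True`: the negative log-probability of the correct token, with the softmax applied inside the loss. It is an illustration, not the Keras implementation, and the logit values are made up.

```python
import math

def sparse_categorical_crossentropy(logits, label):
    """Negative log-softmax of the logit at the integer class index `label`."""
    # Log-sum-exp, shifted by the max logit for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]

# Uniform logits over 4 classes: the model is guessing, so the loss is ln(4).
uniform = sparse_categorical_crossentropy([0.0, 0.0, 0.0, 0.0], label=2)
# A large logit on the correct class drives the loss toward zero.
confident = sparse_categorical_crossentropy([0.0, 0.0, 10.0, 0.0], label=2)
print(round(uniform, 4))  # 1.3863
print(confident < 0.001)  # True
```

"Sparse" here means the labels are integer token ids rather than one-hot vectors, which is why the sketch indexes `logits[label]` directly.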

Due to computing limitations, we fine-tune our Gemma model with a batch size of one for a single epoch.

In [None]:
gemma_lm.fit(data, epochs=1, batch_size=1)

1000/1000 ━━━━━━━━━━━━━━━━━━━━ 302s 278ms/step - loss: 0.4588 - sparse_categorical_accuracy: 0.5231


<keras.src.callbacks.history.History at 0x787444714c50>

## **Post Fine-Tuning Inference**

### Question about Travel

In [None]:
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What should I do on a trip to Europe?

Response:
Europe is the best travel destination. You can see so much in such a short period of time. I would recommend going to Italy and France. Italy has amazing art, food, and architecture. You can see the Colosseum in Rome, the Sistine Chapel in Vatican City, the Leaning Tower of Pisa, and the Uffizi Gallery in Florence. You can see the Eiffel Tower, the Notre Dame Cathedral, and the Louvre in Paris.


We notice that the answer is significantly better than the initial response: it recommends specific countries and landmarks instead of generic packing advice.