<a href="https://colab.research.google.com/github/mdrk300902/demo-repo/blob/main/gemma_lora_finetuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Efficient Fine-Tuning of Gemma Models in Keras with LoRA

## Overview

Gemma is a series of lightweight, cutting-edge open models built using the same research and technology behind the Gemini models.

Large Language Models (LLMs) like Gemma work well for many NLP tasks. They are first pre-trained on huge amounts of text to learn general language patterns. After that, you can fine-tune these models on smaller, domain-specific datasets to make them perform specific tasks like sentiment analysis.

Because LLMs have billions of parameters, full fine-tuning (updating all model weights) is often unnecessary and resource-heavy, especially since fine-tuning datasets are usually much smaller than the pre-training data.

Low-Rank Adaptation (LoRA) is a technique that makes fine-tuning cheaper by freezing the original model weights and adding a small set of new trainable parameters. This approach reduces training time and memory use, produces smaller model files, and keeps output quality high.

In this tutorial, you'll learn how to use KerasNLP to fine-tune the Gemma 2B model with LoRA on the Databricks Dolly 15k dataset, which contains 15,000 high-quality, human-generated prompt/response pairs designed for fine-tuning LLMs.

## Setup

### Project Setup Overview

Before starting with the Gemma 2B model in this project, you need to complete some setup steps to get access and configure your environment:

- First, request access to the Gemma model on Kaggle by visiting its page and accepting the terms.
- Use a Colab runtime with enough resources to run the Gemma 2B model, ideally a T4 GPU or better.
- Generate your Kaggle API credentials (`kaggle.json` file) from your Kaggle account settings.
- Add your Kaggle username and API key as secrets in Colab, so your code can authenticate and download required resources.

Once all that’s done, you can set environment variables in your Colab notebook and get started with running your project code.

### Select the runtime

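In the Colab menu, choose **Runtime** > **Change runtime type**, then under **Hardware accelerator** select **T4 GPU**.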

### Set Kaggle API Credentials

This cell sets up the Kaggle API credentials as environment variables required for accessing Kaggle resources.

It uses Colab's `userdata` API to securely retrieve your Kaggle username and API key and sets them in the environment.

If you're not using Colab, you'll need to set these environment variables manually for your system.


In [None]:
import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
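Outside Colab, one possible approach is to read the `kaggle.json` file generated earlier and export its two fields. The sketch below assumes the file contains the usual `username` and `key` fields. The helper name `load_kaggle_credentials` and the default path `~/.kaggle/kaggle.json` are illustrative choices, not part of this notebook.

```python
import json
import os
from pathlib import Path


def load_kaggle_credentials(path="~/.kaggle/kaggle.json"):
    """Read a kaggle.json file and export its fields as environment variables.

    Hypothetical helper for non-Colab environments; the default path is an
    assumption and may differ on your system.
    """
    creds = json.loads(Path(path).expanduser().read_text())
    os.environ["KAGGLE_USERNAME"] = creds["username"]
    os.environ["KAGGLE_KEY"] = creds["key"]
    return creds["username"]
```

Calling `load_kaggle_credentials()` once, before the model is loaded, would play the same role as the `userdata.get` calls above.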

### Install dependencies

Install Keras, KerasNLP, and other dependencies.

In [None]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
# Quote the requirement so the shell does not treat `>=3` as a redirect.
!pip install -q -U "keras>=3"

### Configure Keras Backend to JAX

This cell sets Keras to use the JAX backend, which offers fast and efficient performance for this project.

The backend must be set **before** importing Keras, as it cannot be changed afterward.

Additionally, setting `XLA_PYTHON_CLIENT_MEM_FRACTION` to `1.0` lets JAX preallocate all of the available GPU memory, which helps avoid memory fragmentation issues when using the JAX backend.



In [None]:
os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

### Import packages

Import Keras and KerasNLP.

In [None]:
import keras
import keras_nlp

## Load Dataset

In [None]:
!wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl


Preprocess the data. This tutorial uses a subset of 1000 training examples to execute the notebook faster. Consider using more training data for higher quality fine-tuning.

In [None]:
import json

# Prompt template shared by preprocessing and the inference cells below.
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]
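To see what the preprocessing step produces, here is a small self-contained sketch that applies the same filter and template to two made-up records. The records are invented for illustration and only mimic the field names of the Dolly JSONL file (`instruction`, `context`, `response`, `category`).

```python
import json

template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

# Two made-up records in the same JSONL shape as the Dolly dataset.
lines = [
    json.dumps({"instruction": "Name a primary color.", "context": "",
                "response": "Red.", "category": "open_qa"}),
    json.dumps({"instruction": "Summarize the passage.", "context": "Some passage.",
                "response": "A summary.", "category": "summarization"}),
]

examples = []
for line in lines:
    features = json.loads(line)
    if features["context"]:  # Skip records that carry a context passage.
        continue
    # Extra keys such as "category" are ignored by str.format.
    examples.append(template.format(**features))

print(len(examples))
print(examples[0])
```

Only the first record survives the filter, and it is rendered as a single string with an `Instruction:` block followed by a `Response:` block.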

## Load the Gemma Model

This cell loads the Gemma 2B causal language model using KerasNLP’s `from_preset` method. The `gemma_2b_en` string specifies that we want the 2 billion parameter version of the Gemma model.



In [None]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()

Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/1' to your Colab notebook...


The `from_preset` method instantiates the model from a preset architecture and pretrained weights.

There is also a larger **Gemma 7B** model with 7 billion parameters. Running the 7B model on Colab requires premium GPUs available through paid plans. Alternatively, you can perform distributed training on Kaggle or Google Cloud to fine-tune the Gemma 7B model.

### Key differences between Gemma 2B and 7B

- **Gemma 2B** uses Multi-Query Attention (MQA), which is more memory-efficient during inference.
- **Gemma 7B** uses Multi-Head Attention (MHA), which provides richer representations but demands more computation.
- Gemma 7B typically outperforms Gemma 2B on tasks involving reasoning, math, and code.
- Both models were trained on huge datasets (2 trillion tokens for the 2B model, 6 trillion tokens for the 7B) and support long context lengths up to 8,192 tokens.

In summary, Gemma 7B is more powerful but needs stronger hardware or cloud resources. Gemma 2B offers a good balance between performance and resource requirements, making it better suited for local or moderate GPU fine-tuning.
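To make the memory-efficiency point about Multi-Query Attention more concrete, the sketch below estimates the size of the key/value cache for a single sequence. The layer count (18), head size (256), and query-head count (8) are assumptions taken from the published Gemma 2B configuration rather than from this notebook, and the 8-KV-head variant is a hypothetical comparison, not a real model.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_value=2):
    """Bytes needed to cache keys and values for one sequence."""
    # Factor of 2: one tensor for keys and one for values in every layer.
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_value


# Assumed configuration: 18 layers, head size 256, 8,192-token context, bfloat16.
mqa = kv_cache_bytes(num_layers=18, num_kv_heads=1, head_dim=256, seq_len=8192)
mha = kv_cache_bytes(num_layers=18, num_kv_heads=8, head_dim=256, seq_len=8192)

print(f"MQA (1 KV head):  {mqa / 2**20:.0f} MiB")
print(f"MHA (8 KV heads): {mha / 2**20:.0f} MiB")
```

Under these assumptions the cache is 144 MiB with one shared key/value head and 1152 MiB with eight, which is the kind of saving the MQA bullet refers to.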


## Inference before fine-tuning

### Europe Trip Prompt

Query the model for suggestions on what to do on a trip to Europe.

In [None]:
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What should I do on a trip to Europe?

Response:
It's easy, you just need to follow these steps:

First you must book your trip with a travel agency.
Then you must choose a country and a city.
Next you must choose your hotel, your flight, and your travel insurance
And last you must pack for your trip.
 


What are the benefits of a travel agency?

Response:
Travel agents have the best prices, they know how to negotiate and they can find deals that you won't find on your own.

What are the disadvantages of a travel agency?

Response:
Travel agents are not as flexible as you would like. If you need to change your travel plans last minute, they may charge you a fee for that.
 


How do I choose a travel agency?

Response:
There are a few things you can do to choose the right travel agent. First, check to see if they are accredited by the Better Business Bureau. Second, check their website and see what kind of information they offer. Third, look at their reviews online to see 

The model responds with generic tips on how to plan a trip.

### Photosynthesis Prompt

Ask the model to explain photosynthesis in a way that a child can understand.

In [None]:
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Plants use light energy and carbon dioxide to make sugar and oxygen. This is a simple chemical change because the chemical bonds in the sugar and oxygen are unchanged. Plants also release oxygen during photosynthesis.

Instruction:
Explain how photosynthesis is an example of chemical change.

Response:
Photosynthesis is a chemical reaction that produces oxygen and sugar.

Instruction:
Explain how plants make their own food.

Response:
Plants use energy from sunlight to make sugar and oxygen during photosynthesis.

Instruction:
Explain how the chemical change in a plant during photosynthesis can be described as an example of a chemical reaction.

Response:
Photosynthesis is a chemical change that results in the formation of sugar from carbon dioxide, water, and energy from sunlight.

Instruction:
Explain the role of chlorophyll in plant photosynthesis.

Response:
Chlorophyll is a green 

The response contains words, such as chlorophyll, that a child might not easily understand.

## LoRA Fine-tuning

To get better responses from the model, fine-tune it using Low-Rank Adaptation (LoRA) with the Databricks Dolly 15k dataset.

LoRA keeps the original model weights frozen and adds a pair of small trainable matrices alongside selected weight matrices. The LoRA rank sets the inner dimension of these matrices, and therefore how much capacity the fine-tuning has to adapt the model.

- A higher rank means more trainable parameters and potentially better performance, but also higher computational cost.
- A lower rank reduces compute needs but might limit how well the model adapts.

In this project, I use a LoRA rank of 4 to keep training efficient while achieving good results. It's a good idea to start small (e.g., 4, 8, or 16), then increase the rank in later experiments to see if performance improves.
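The trade-off described above can be sketched with plain NumPy. This is a minimal illustration of the low-rank idea, not the KerasNLP implementation: the layer sizes are made up, and the names `W`, `A`, and `B` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, rank = 2048, 2048, 4  # Illustrative sizes, not read from the model.

W = rng.normal(size=(d_in, d_out))        # Frozen pretrained weight.
A = rng.normal(size=(d_in, rank)) * 0.01  # Trainable down-projection.
B = np.zeros((rank, d_out))               # Trainable up-projection, starts at zero.

x = rng.normal(size=(1, d_in))
y = x @ W + (x @ A) @ B  # Frozen path plus the low-rank update.

full_params = W.size
lora_params = A.size + B.size
print(full_params, lora_params)
```

With these sizes the dense layer has 4,194,304 weights while the rank-4 pair adds only 16,384 trainable ones. Because `B` starts at zero, the output initially matches the frozen layer exactly, and doubling the rank doubles the number of trainable parameters.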


In [None]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Note that enabling LoRA reduces the number of trainable parameters significantly, from about 2.5 billion to about 1.4 million.
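As a back-of-the-envelope check on the trainable-parameter count, the sketch below tallies the LoRA weights by hand. It rests on several assumptions that are not stated in this notebook: that LoRA is attached only to the query and value projections of each attention layer, that Gemma 2B has 18 layers with 8 query heads, 1 value head, hidden size 2048, and head size 256, and that each LoRA pair is shaped as described in the docstring. The helper `lora_param_count` is hypothetical.

```python
def lora_param_count(kernel_shape, rank):
    """Parameters added by a rank-`rank` LoRA pair on a projection kernel.

    Assumes the first factor has shape kernel_shape[:-1] + (rank,) and the
    second has shape (rank, kernel_shape[-1]).
    """
    a = rank
    for dim in kernel_shape[:-1]:
        a *= dim
    b = rank * kernel_shape[-1]
    return a + b


rank, num_layers = 4, 18
query = lora_param_count((8, 2048, 256), rank)  # 8 query heads.
value = lora_param_count((1, 2048, 256), rank)  # 1 shared value head (MQA).
total = num_layers * (query + value)
print(f"{total:,}")
```

Under these assumptions the total is 1,363,968. Compare it with the trainable-parameter line printed by `gemma_lm.summary()`; if the numbers differ, the assumptions above do not match your installed version.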

I chose the AdamW optimizer with a learning rate of 5e-5, a common starting point for transformer fine-tuning.

In [None]:
# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)

1000/1000 ━━━━━━━━━━━━━━━━━━━━ 1524s 1s/step - loss: 0.4591 - sparse_categorical_accuracy: 0.5230


<keras.src.callbacks.history.History at 0x7ca3a01701c0>

### Note on mixed precision fine-tuning on NVIDIA GPUs

Full precision is recommended for fine-tuning. On NVIDIA GPUs, you can instead enable mixed precision (`keras.mixed_precision.set_global_policy('mixed_bfloat16')`) to speed up training with minimal effect on training quality. Mixed-precision fine-tuning consumes more memory, so it is useful only on larger GPUs.

For inference, half precision (`keras.config.set_floatx("bfloat16")`) works and saves memory; mixed precision is not applicable.

In [None]:
# keras.mixed_precision.set_global_policy('mixed_bfloat16')
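To give a rough sense of the memory that half precision saves at inference time, the sketch below estimates the footprint of the weights alone. The parameter count is an assumed round figure for Gemma 2B (check `gemma_lm.summary()` for the exact value), and the estimate ignores activations, the key/value cache, and optimizer state.

```python
def weights_gib(num_params, bytes_per_param):
    """Approximate memory for the weights alone, in GiB."""
    return num_params * bytes_per_param / 2**30


num_params = 2.5e9  # Assumed round figure for Gemma 2B; check gemma_lm.summary().
fp32 = weights_gib(num_params, 4)  # float32: 4 bytes per parameter.
bf16 = weights_gib(num_params, 2)  # bfloat16: 2 bytes per parameter.

print(f"float32:  {fp32:.1f} GiB")
print(f"bfloat16: {bf16:.1f} GiB")
```

With this assumed count, the weights take roughly 9.3 GiB in float32 and roughly 4.7 GiB in bfloat16.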

## Inference after fine-tuning

After fine-tuning, responses follow the instruction provided in the prompt.

### Europe Trip Prompt

In [None]:
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What should I do on a trip to Europe?

Response:
If you have the time, I would visit London, Paris, Rome, and Berlin. If you're in London, you have to visit Buckingham Palace. If you're in Paris, you have to visit Notre Dame and the Eiffel Tower. If you're in Rome, you have to visit the Coliseum. If you're in Berlin, you have to visit the Brandenburg Gate.


The model now recommends places to visit in Europe.

### Photosynthesis Prompt

In [None]:
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Photosynthesis is when a plant uses sunlight to make energy. The plants use carbon dioxide and water to make sugar and oxygen. This sugar is used by the plant to make food and the oxygen that is made is released into the air. The plant also releases energy that can then be used by the plant or animal that is using it.


The model now explains photosynthesis in simpler terms.

Note that for demonstration purposes, this tutorial fine-tunes the model on a small subset of the dataset for just one epoch and with a low LoRA rank value. To get better responses from the fine-tuned model, you can experiment with:

1. Increasing the size of the fine-tuning dataset
2. Training for more steps (epochs)
3. Setting a higher LoRA rank
4. Modifying hyperparameter values such as `learning_rate` and `weight_decay`
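One way to organize such experiments is to enumerate the combinations up front. The sketch below builds a small grid over the four knobs listed above. All values are hypothetical starting points, none of them were evaluated in this tutorial, and the key names are illustrative.

```python
from itertools import product

# Hypothetical values to try; none of these were evaluated in this tutorial.
grid = {
    "num_examples": [1000, 5000],
    "epochs": [1, 3],
    "lora_rank": [4, 8, 16],
    "learning_rate": [5e-5, 1e-4],
}

# One dict per combination, in the order the keys appear in `grid`.
experiments = [dict(zip(grid, values)) for values in product(*grid.values())]

print(len(experiments))
print(experiments[0])
```

This yields 24 configurations. In practice you would probably change one setting at a time from the baseline used here (1,000 examples, 1 epoch, rank 4, learning rate 5e-5) rather than train all of them.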