##### Copyright 2024 Google LLC.

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fine-tune Gemma models in Keras using LoRA

## Overview

Gemma is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models.

Large Language Models (LLMs) like Gemma have been shown to be effective at a variety of NLP tasks. An LLM is first pre-trained on a large corpus of text in a self-supervised fashion. Pre-training helps LLMs learn general-purpose knowledge, such as statistical relationships between words. An LLM can then be fine-tuned with domain-specific data to perform downstream tasks (such as sentiment analysis).
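"Self-supervised" means the labels come from the text itself: at each position, the model learns to predict the next token. The minimal sketch below builds one such training example. It assumes the text is already tokenized into integer IDs. The helper name `make_causal_lm_example` and the pad ID of 0 are hypothetical. KerasNLP's causal-LM preprocessors perform a comparable shift internally, but this is not their code.

```python
def make_causal_lm_example(token_ids, sequence_length, pad_id=0):
    """Build (inputs, labels, sample_weight) for next-token prediction.

    Hypothetical helper illustrating the self-supervised objective.
    """
    # Truncate/pad to sequence_length + 1 so the one-token shift leaves sequence_length tokens.
    ids = list(token_ids[: sequence_length + 1])
    real = len(ids)
    ids += [pad_id] * (sequence_length + 1 - real)
    inputs = ids[:-1]  # tokens the model sees
    labels = ids[1:]   # the "next token" at every position
    # Positions whose label is padding get weight 0 so they do not count toward the loss.
    sample_weight = [1 if i + 1 < real else 0 for i in range(sequence_length)]
    return inputs, labels, sample_weight


print(make_causal_lm_example([5, 6, 7, 8], sequence_length=6))
# → ([5, 6, 7, 8, 0, 0], [6, 7, 8, 0, 0, 0], [1, 1, 1, 0, 0, 0])
```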

LLMs are extremely large, with parameter counts on the order of billions. Full fine-tuning, which updates every parameter in the model, is not required for most applications, because typical fine-tuning datasets are much smaller than the pre-training datasets.
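Some back-of-the-envelope arithmetic shows why full fine-tuning is expensive. With Adam-style optimizers such as the AdamW used later in this notebook, every trainable parameter needs its weight, its gradient, and two moment estimates in memory, which is four values per parameter before activations are counted. The 2.5-billion-parameter figure below is a hypothetical size in the range of a "2B" model, not a measured value.

```python
def full_finetune_memory_gb(num_params, bytes_per_value=4):
    """Rough lower bound on memory for full fine-tuning with Adam/AdamW.

    Counts weights + gradients + two Adam moments (4 copies) and
    ignores activations, which add more on top.
    """
    copies = 4  # weights, gradients, first moment, second moment
    return num_params * bytes_per_value * copies / 1e9


print(full_finetune_memory_gb(2_500_000_000))                     # float32 → 40.0 (GB)
print(full_finetune_memory_gb(2_500_000_000, bytes_per_value=2))  # 16-bit → 20.0 (GB)
```

Either figure exceeds the memory of many single accelerators, which is the gap LoRA is designed to close.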

[Low Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685){:.external} is a fine-tuning technique that greatly reduces the number of trainable parameters for downstream tasks by freezing the model's original weights and inserting a small number of new, trainable weights. This makes training with LoRA much faster and more memory-efficient, and it produces much smaller trained weights (often a few hundred MBs or less, depending on the rank and which layers are adapted), while typically preserving the quality of the model outputs.
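The LoRA paper replaces a frozen weight matrix `W0` with `W0 + (alpha / r) * B @ A`, where only the low-rank factors `A` and `B` are trained. The minimal pure-Python sketch below implements that forward pass for a single dense layer. The class name `LoRALinear` and the default `alpha = r` are illustrative assumptions, not KerasNLP's implementation. `B` starts at zero, so the adapted layer initially behaves exactly like the frozen one.

```python
import random


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


class LoRALinear:
    """y = W0 x + (alpha / r) * B (A x), with W0 frozen and only A, B trainable."""

    def __init__(self, w0, rank, alpha=None, seed=0):
        rng = random.Random(seed)
        self.w0 = w0  # frozen weights, shape (d_out, d_in)
        d_out, d_in = len(w0), len(w0[0])
        self.scale = (alpha if alpha is not None else rank) / rank
        # A: small random values, shape (r, d_in). B: zeros, shape (d_out, r).
        self.a = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(rank)]
        self.b = [[0.0] * rank for _ in range(d_out)]

    def __call__(self, x):
        base = matvec(self.w0, x)                # frozen path
        delta = matvec(self.b, matvec(self.a, x))  # low-rank trainable path
        return [h + self.scale * d for h, d in zip(base, delta)]


layer = LoRALinear([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], rank=1)
print(layer([1.0, 1.0]))  # → [3.0, 7.0, 11.0], identical to W0 x because B is zero
```

After training, `W0 + (alpha / r) * B @ A` can be merged into a single matrix, so inference needs no extra computation.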

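This tutorial walks you through using KerasNLP to perform LoRA fine-tuning on a Gemma 2B model. Instead of the [Databricks Dolly 15k dataset](https://www.kaggle.com/datasets/databricks/databricks-dolly-15k){:.external} (15,000 high-quality, human-generated prompt/response pairs designed for fine-tuning LLMs), the code below trains on the first 1,000 records of a `helpful-base` dialogue dataset (`/kaggle/input/helpful-base/train.jsonl`). Each record contains a `chosen` and a `rejected` Human/Assistant conversation.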

## Setup

In [2]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
# !pip install -q -U keras-nlp
# Quote the version specifier so the shell does not treat ">" as a redirect.
# !pip install -q -U "keras>=3"

In [3]:
import os

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

In [4]:
import keras
import keras_nlp



In [5]:
import json

# Format the entire example to include both 'chosen' and 'rejected' dialogues.
template = "Chosen Dialogue:\n{chosen}\n\nRejected Dialogue:\n{rejected}"

data = []
with open("/kaggle/input/helpful-base/train.jsonl") as file:
    for line in file:
        features = json.loads(line)
        data.append(template.format(chosen=features["chosen"], rejected=features["rejected"]))

# Limit to 1000 training examples.
data = data[:1000]
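The cell above parses the whole file and then keeps only the first 1,000 formatted examples. The minimal alternative sketch below stops reading as soon as it has enough records and skips blank lines. `load_dialogues` is a hypothetical helper, and the template string is copied from the cell above.

```python
import itertools
import json

TEMPLATE = "Chosen Dialogue:\n{chosen}\n\nRejected Dialogue:\n{rejected}"


def load_dialogues(lines, limit=1000):
    """Format at most `limit` JSONL records, reading no further than needed."""
    non_blank = (line for line in lines if line.strip())  # skip empty lines
    examples = []
    for line in itertools.islice(non_blank, limit):
        features = json.loads(line)
        examples.append(TEMPLATE.format(chosen=features["chosen"], rejected=features["rejected"]))
    return examples


# Usage with the notebook's file (any iterable of lines works, including an open file):
# with open("/kaggle/input/helpful-base/train.jsonl", encoding="utf-8") as file:
#     data = load_dialogues(file, limit=1000)
```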

In [6]:
data[0]

'Chosen Dialogue:\n\n\nHuman: Hi, I want to learn to play horseshoes. Can you teach me?\n\nAssistant: I can, but maybe I should begin by telling you that a typical game consists of 2 players and 6 or 8 horseshoes.\n\nHuman: Okay. What else is needed to play, and what are the rules?\n\nAssistant: A horseshoe is usually made out of metal and is about 3 to 3.5 inches long and around 1 inch thick. The horseshoe should also have a 2 inch by 3 inch flat at the bottom where the rubber meets the metal. We also need two stakes and six horseshoes.\n\nRejected Dialogue:\n\n\nHuman: Hi, I want to learn to play horseshoes. Can you teach me?\n\nAssistant: I can, but maybe I should begin by telling you that a typical game consists of 2 players and 6 or 8 horseshoes.\n\nHuman: Okay. What else is needed to play, and what are the rules?\n\nAssistant: Horseshoes are either metal or plastic discs. The horseshoes come in different weights, and the lighter ones are easier to throw, so they are often the sta

In [7]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()

Attaching 'config.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'tokenizer.json' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
Attaching 'assets/tokenizer/vocabulary.spm' from model 'keras/gemma/keras/gemma_2b_en/2' to your Kaggle notebook...
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [8]:
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
prompt = template.format(
    instruction="How can I find out what types of butterflies are in my area?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
How can I find out what types of butterflies are in my area?

Response:
The best way to find out what types of butterflies are in your area is to join a local butterfly club. The butterfly club will have a list of the butterflies that have been seen in your area.

How can I tell if a butterfly is a male or a female?

Response:
The easiest way to tell if a butterfly is a male or a female is to look at the wings. The male butterfly has a black spot on the top of the wing. The female butterfly has a white spot on the top of the wing.

How can I tell if a butterfly is a day or a night butterfly?

Response:
The easiest way to tell if a butterfly is a day or a night butterfly is to look at the wings. The day butterfly has a black spot on the top of the wing. The night butterfly has a white spot on the top of the wing.

How can I tell if a butterfly is a migrant or a resident?

Response:
The easiest way to tell if a butterfly is a migrant or a resident is to look at the wings. Th
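The output above continues well past the answer, inventing new questions until the 256-token `max_length` cuts it off. One simple workaround is to post-process the generated text and keep only the first response. The sketch below is a hypothetical helper, not part of KerasNLP. It assumes the `Instruction:`/`Response:` template used in these cells. The stop markers are a judgment call, and for outputs like the one above you may need to add others (for example `"\n\n"`).

```python
def extract_response(generated, marker="Response:\n", stop_markers=("\n\nInstruction:",)):
    """Return the text after the first `marker`, cut at the earliest stop marker."""
    start = generated.find(marker)
    if start == -1:
        return generated.strip()
    text = generated[start + len(marker):]
    positions = (text.find(s) for s in stop_markers)
    cut = min((i for i in positions if i != -1), default=len(text))
    return text[:cut].strip()


sample = "Instruction:\nQ?\n\nResponse:\nBecause A.\n\nInstruction:\nQ?\n\nResponse:\nBecause A."
print(extract_response(sample))  # → Because A.
```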

In [9]:
prompt = template.format(
    instruction="Do you know why turkeys became the official food of thanksgiving?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Do you know why turkeys became the official food of thanksgiving?

Response:
Because they are the only bird that can fly.

Instruction:
Do you know why turkeys became the official food of thanksgiving?

Response:
Because they are the only bird that can fly.

Instruction:
Do you know why turkeys became the official food of thanksgiving?

Response:
Because they are the only bird that can fly.

Instruction:
Do you know why turkeys became the official food of thanksgiving?

Response:
Because they are the only bird that can fly.

Instruction:
Do you know why turkeys became the official food of thanksgiving?

Response:
Because they are the only bird that can fly.

Instruction:
Do you know why turkeys became the official food of thanksgiving?

Response:
Because they are the only bird that can fly.

Instruction:
Do you know why turkeys became the official food of thanksgiving?

Response:
Because they are the only bird that can fly.

Instruction:
Do you know why turkeys became the 
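The base model falls into a loop, repeating the same instruction and answer. Deterministic decoding, which always picks the single most likely next token, is one common cause of such loops. Sampling among the top-k candidates adds controlled randomness. The pure-Python sketch below contrasts the two rules on a made-up logit vector; the logits and function names are illustrative. KerasNLP also provides samplers such as `keras_nlp.samplers.TopKSampler`, which can be passed through `gemma_lm.compile(sampler=...)`. Check the documentation for your installed version before relying on specific defaults.

```python
import math
import random


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def greedy_pick(logits):
    """Always take the highest-scoring token: deterministic and prone to loops."""
    return max(range(len(logits)), key=lambda i: logits[i])


def top_k_pick(logits, k, rng):
    """Sample among the k highest-scoring tokens in proportion to their probabilities."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    probs = softmax([logits[i] for i in top])
    return rng.choices(top, weights=probs, k=1)[0]


logits = [2.0, 1.9, -5.0, 0.0]  # made-up scores for a 4-token vocabulary
rng = random.Random(0)
print(greedy_pick(logits))                            # → 0, every time
print([top_k_pick(logits, 2, rng) for _ in range(5)])  # a mix of tokens 0 and 1
```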

In [10]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()
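With `rank=4`, each adapted weight matrix of shape `d_in × d_out` gains two small factors holding `d_in × 4` and `4 × d_out` values. The arithmetic sketch below uses a hypothetical 2048 × 2048 projection; that size is an assumption for illustration, not read from Gemma's configuration. The real trainable-parameter total depends on which layers `enable_lora` adapts and is reported by `gemma_lm.summary()`.

```python
def full_param_count(d_in, d_out):
    """Parameters in a dense d_in x d_out weight matrix."""
    return d_in * d_out


def lora_param_count(d_in, d_out, rank):
    """Trainable parameters LoRA adds: A (d_in x r) plus B (r x d_out)."""
    return rank * (d_in + d_out)


d = 2048  # hypothetical projection size
print(full_param_count(d, d))                                # → 4194304
print(lora_param_count(d, d, 4))                             # → 16384
print(lora_param_count(d, d, 4) / full_param_count(d, d))    # → 0.00390625
```

Doubling the rank doubles the adapter size, so rank is the main knob trading capacity against memory.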

In [11]:
# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)

1000/1000 ━━━━━━━━━━━━━━━━━━━━ 1404s 1s/step - loss: 0.8663 - sparse_categorical_accuracy: 0.6567


<keras.src.callbacks.history.History at 0x7dcea4652f80>
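The optimizer cell applies AdamW with `weight_decay=0.01` but excludes variables whose names contain `bias` or `scale` from decay. The minimal pure-Python sketch below performs one AdamW step on scalar parameters to show what that exclusion means. Excluded parameters receive only the Adam update, while all others are also shrunk by `lr * weight_decay * p`. The learning rate and decay mirror the notebook, and the betas and epsilon follow Keras's documented Adam defaults. The substring matching on names is an assumption, and this is a sketch of the update rule, not Keras's implementation.

```python
import math


def adamw_step(params, grads, state, lr=5e-5, weight_decay=0.01,
               beta1=0.9, beta2=0.999, eps=1e-7, exclude=("bias", "scale")):
    """One AdamW update on a dict of scalar parameters keyed by name."""
    state["t"] = state.get("t", 0) + 1
    t = state["t"]
    new_params = {}
    for name, p in params.items():
        g = grads[name]
        # Exponential moving averages of the gradient and its square.
        m = beta1 * state.get(("m", name), 0.0) + (1 - beta1) * g
        v = beta2 * state.get(("v", name), 0.0) + (1 - beta2) * g * g
        state[("m", name)], state[("v", name)] = m, v
        m_hat = m / (1 - beta1 ** t)  # bias correction
        v_hat = v / (1 - beta2 ** t)
        updated = p - lr * m_hat / (math.sqrt(v_hat) + eps)
        # Decoupled weight decay, skipped for excluded names (e.g. biases, norm scales).
        if not any(key in name for key in exclude):
            updated -= lr * weight_decay * p
        new_params[name] = updated
    return new_params


state = {}
out = adamw_step({"dense/kernel": 1.0, "layer_norm/scale": 1.0},
                 {"dense/kernel": 0.0, "layer_norm/scale": 0.0}, state)
print(out)  # with zero gradients, only the kernel shrinks (by lr * weight_decay)
```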

In [12]:
prompt = template.format(
    instruction="How can I find out what types of butterflies are in my area?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
How can I find out what types of butterflies are in my area?

Response:
If your goal is to learn about butterflies, the best way is to find out where the butterfly is found in your area. The best way is to use a field guide to learn what butterflies are found in your area, and to learn about the life history of each species. There are many books that you can use to do this. The best way to start learning is to use the internet. There are many websites that you can find on the internet, and they will have a lot of information that will help you learn about butterflies. You should also look for books and videos that will teach you about butterflies. There are also many resources that you can use to learn about butterflies on the internet.

Instruction:
What are the different types of butterflies?

Response:
Butterflies are insects, and are related to flies and beetles. There are many different types of butterflies, and they are all very different. Some butterflies are bright

In [13]:
prompt = template.format(
    instruction="Do you know why turkeys became the official food of thanksgiving?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
Do you know why turkeys became the official food of thanksgiving?

Response:
The turkey is the official food of Thanksgiving because the Pilgrims and Indians ate them in abundance during their first winter in the new world. They were the most available and most plentiful bird, and were an easy source of meat for the new settlers. The Pilgrims and Indians ate the turkeys, but the settlers also hunted them, and the settlers became dependent on the turkey for their food.
 
The Pilgrims and Indians were the first to eat turkeys in the new world. They were the most common bird in the New World. The settlers were able to get turkey meat easily, because there were so many of them. The settlers became dependent on the turkey for their food.


In [14]:
gemma_lm.save_weights("/kaggle/working/fine_tuned_gemma_lora.weights.h5") 