##### Copyright 2024 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [2]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"

### Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.

For this tutorial, configure the backend for JAX.
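
Keras reads the `KERAS_BACKEND` environment variable when it is first imported, so the variable has to be set before `import keras` runs. The following is a minimal sketch of a guard for that ordering. `select_backend` and `VALID_BACKENDS` are hypothetical names introduced here for illustration and are not part of Keras.

```python
import os
import sys

# Backend names accepted by Keras 3.
VALID_BACKENDS = ("jax", "tensorflow", "torch")


def select_backend(name):
    """Hypothetical helper: set KERAS_BACKEND, refusing to do so too late."""
    if name not in VALID_BACKENDS:
        raise ValueError(f"Unknown backend {name!r}; expected one of {VALID_BACKENDS}")
    if "keras" in sys.modules:
        # Keras has already picked its backend; changing the variable now has no effect.
        raise RuntimeError("Set KERAS_BACKEND before importing keras.")
    os.environ["KERAS_BACKEND"] = name
    return name


select_backend("jax")
print(os.environ["KERAS_BACKEND"])
```

If the helper raises `RuntimeError`, restart the notebook kernel and run the backend selection before any Keras import.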

In [5]:
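# Import the standard-library modules first.
import os
import json

# Select the backend before importing Keras; the setting is read at import time.
os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on the JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

import keras
import keras_nlp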

## Load Dataset

In [6]:
# Prompt template that shows the instruction and the response together.
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

data = []
with open('/kaggle/input/databricks-dolly-15k/databricks-dolly-15k.jsonl') as file:
    for line in file:
        features = json.loads(line)
        # Skip examples that include context, to keep the prompts simple.
        if features["context"]:
            continue
        data.append(template.format(**features))

# Use only the first 1,000 examples so that training runs quickly.
data = data[:1000]

In [8]:
data[:5]

['Instruction:\nWhich is a species of fish? Tope or Rope\n\nResponse:\nTope',
 'Instruction:\nWhy can camels survive for long without water?\n\nResponse:\nCamels use the fat in their humps to keep them filled with energy and hydration for long periods of time.',
 "Instruction:\nAlice's parents have three daughters: Amy, Jessy, and what’s the name of the third daughter?\n\nResponse:\nThe name of the third daughter is Alice",
 'Instruction:\nWho gave the UN the land in NY to build their HQ\n\nResponse:\nJohn D Rockerfeller',
 'Instruction:\nWhy mobile is bad for human\n\nResponse:\nWe are always engaged one phone which is not good.']
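
The filtering and formatting done by the loading loop can be tried without the Kaggle file. The following is a minimal sketch under assumptions: `build_examples`, `TEMPLATE`, and the two sample records are made up for illustration. Only the `instruction`, `context`, and `response` field names come from the loading code above.

```python
import json

TEMPLATE = "Instruction:\n{instruction}\n\nResponse:\n{response}"


def build_examples(lines, limit=None):
    """Format context-free JSONL records into instruction/response strings."""
    examples = []
    for line in lines:
        record = json.loads(line)
        if record["context"]:
            # Skip records that depend on a reference passage.
            continue
        examples.append(
            TEMPLATE.format(
                instruction=record["instruction"],
                response=record["response"],
            )
        )
    return examples if limit is None else examples[:limit]


# Two made-up records in the same shape as the dataset's JSONL lines.
sample_lines = [
    json.dumps({"instruction": "Name a primary color.", "context": "", "response": "Red"}),
    json.dumps({"instruction": "Summarize the passage.", "context": "Some passage.", "response": "A summary."}),
]
examples = build_examples(sample_lines)
print(examples)
```

Only the first record survives, because the second one has a non-empty `context` field.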

## Load Model

This notebook uses `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on the previous tokens.
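
To make next-token prediction concrete, here is a toy sketch of greedy decoding. It is not how Gemma works internally: the score table is made up, and it conditions only on the last token, whereas a real causal language model conditions on all previous tokens. `NEXT_TOKEN_SCORES` and `greedy_generate` are hypothetical names.

```python
# Made-up next-token scores keyed by the previous token (a bigram table).
NEXT_TOKEN_SCORES = {
    "<start>": {"take": 0.6, "visit": 0.4},
    "take": {"a": 0.9, "the": 0.1},
    "a": {"trip": 0.8, "walk": 0.2},
    "trip": {".": 0.7, "to": 0.3},
    ".": {"take": 0.9, "<end>": 0.1},
}


def greedy_generate(max_tokens):
    """Repeatedly append the highest-scoring next token."""
    tokens = ["<start>"]
    while len(tokens) - 1 < max_tokens:
        scores = NEXT_TOKEN_SCORES[tokens[-1]]
        next_token = max(scores, key=scores.get)
        if next_token == "<end>":
            break
        tokens.append(next_token)
    return tokens[1:]  # Drop the start marker.


print(" ".join(greedy_generate(max_tokens=10)))
# prints: take a trip . take a trip . take a
```

Always picking the single most likely token can get stuck in a loop, as this toy table does, until the length limit cuts generation off.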

Create the model using the `from_preset` method:

In [9]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.summary()

## Results before fine-tuning


In [10]:
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))  # Very basic response: it only repeats "Take a trip to Europe."

Instruction:
What should I do on a trip to Europe?

Response:
1. Take a trip to Europe.
2. Take a trip to Europe.
3. Take a trip to Europe.
4. Take a trip to Europe.
5. Take a trip to Europe.
6. Take a trip to Europe.
7. Take a trip to Europe.
8. Take a trip to Europe.
9. Take a trip to Europe.
10. Take a trip to Europe.
11. Take a trip to Europe.
12. Take a trip to Europe.
13. Take a trip to Europe.
14. Take a trip to Europe.
15. Take a trip to Europe.
16. Take a trip to Europe.
17. Take a trip to Europe.
18. Take a trip to Europe.
19. Take a trip to Europe.
20. Take a trip to Europe.
21. Take a trip to Europe.
22. Take a trip to Europe.
23. Take a trip to Europe.
24. Take a trip to Europe.
25. Take a trip to


In [12]:
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))  # A child would struggle with terms such as "carbon dioxide" and "glucose".

Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Photosynthesis is the process by which plants use the energy from the sun to convert water and carbon dioxide into oxygen and glucose. The process begins with the absorption of light energy by chlorophyll molecules in the leaves of plants. The energy from the light is used to split water molecules into hydrogen and oxygen. The oxygen is released into the atmosphere, while the hydrogen is used to make glucose. The glucose is then used by the plant to make energy and grow.

Explanation:
Photosynthesis is the process by which plants use the energy from the sun to convert water and carbon dioxide into oxygen and glucose. The process begins with the absorption of light energy by chlorophyll molecules in the leaves of plants. The energy from the light is used to split water molecules into hydrogen and oxygen. The oxygen is released into the atmosphere, while the hydrogen is used to make gluc

In [13]:
prompt = template.format(
    instruction="What is the main difference between normal football and AMerican football?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))

Instruction:
What is the main difference between normal football and AMerican football?

Response:
The main difference between normal football and American football is that in normal football, the ball is kicked, but in American football, the ball is thrown.

Instruction:
What is the main difference between normal football and American football?

Response:
The main difference between normal football and American football is that in normal football, the ball is kicked, but in American football, the ball is thrown.

Instruction:
What is the main difference between normal football and American football?

Response:
The main difference between normal football and American football is that in normal football, the ball is kicked, but in American football, the ball is thrown.

Instruction:
What is the main difference between normal football and American football?

Response:
The main difference between normal football and American football is that in normal football, the ball is kicked, but in 

## LoRA Fine-tuning

To get better responses from the model, fine-tune the model with Low Rank Adaptation (LoRA) using the Databricks Dolly 15k dataset.

The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.
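
As a rough illustration of what "trainable matrices added to the original weights" means, the sketch below computes a layer output as `x @ W + (x @ A) @ B`, where `W` stays frozen and only the low-rank factors `A` and `B` would be trained. This is a simplified pure-Python picture, not the Keras implementation: the sizes and values are made up, any scaling factor is ignored, and `matmul` and `lora_forward` are hypothetical helpers.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def lora_forward(x, w_frozen, lora_a, lora_b):
    """Compute x @ W + (x @ A) @ B, where only A and B would be trained."""
    base = matmul(x, w_frozen)
    update = matmul(matmul(x, lora_a), lora_b)
    return [
        [base_val + update_val for base_val, update_val in zip(base_row, update_row)]
        for base_row, update_row in zip(base, update)
    ]


# Toy sizes: 3 input features, 2 output features, rank 1.
x = [[1.0, 2.0, 3.0]]
w_frozen = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
lora_a = [[0.5], [0.5], [0.5]]    # shape (3, rank)
lora_b_zero = [[0.0, 0.0]]        # shape (rank, 2); zero before any training
lora_b_trained = [[1.0, -1.0]]    # made-up values standing in for a trained factor

print(lora_forward(x, w_frozen, lora_a, lora_b_zero))     # [[4.0, 5.0]]
print(lora_forward(x, w_frozen, lora_a, lora_b_trained))  # [[7.0, 2.0]]
```

With `B` at zero the layer behaves exactly like the frozen layer, so fine-tuning starts from the pretrained behavior and moves away from it only as `A` and `B` are updated.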

A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.

This tutorial uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, 16). This is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance.
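
The trade-off between rank and parameter count can be estimated with simple arithmetic. For a single dense layer with `d_in` inputs and `d_out` outputs, the frozen kernel has `d_in * d_out` weights, while rank-`r` factors add `r * (d_in + d_out)` trainable weights. The sketch below uses an illustrative 2048 by 2048 layer; that size is an assumption chosen for the example, not a value read from the model, and `dense_params` and `lora_params` are hypothetical helpers.

```python
def dense_params(d_in, d_out):
    """Number of weights in a full dense kernel."""
    return d_in * d_out


def lora_params(d_in, d_out, rank):
    """Number of weights in the two low-rank factors A (d_in x rank) and B (rank x d_out)."""
    return rank * (d_in + d_out)


d_in, d_out = 2048, 2048  # Illustrative layer size, not read from the model.
full = dense_params(d_in, d_out)
for rank in (4, 8, 16):
    added = lora_params(d_in, d_out, rank)
    print(f"rank={rank:>2}: {added:>7,} trainable vs {full:,} frozen ({added / full:.2%})")
```

The trainable count grows linearly with the rank, so doubling the rank doubles the number of added parameters for the layer.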

In [14]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Note that enabling LoRA reduces the number of trainable parameters significantly (from 2.5 billion to 1.3 million).

In [None]:
# Limit the input sequence length to 512 tokens (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512

# Use the AdamW optimizer.
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)

# Exclude bias and layer-norm scale terms from weight decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    # Targets are integer token IDs, so use the sparse variant of accuracy.
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# One epoch keeps the run short; more epochs would likely improve the results.
gemma_lm.fit(data, epochs=1, batch_size=1)

## Results after fine-tuning

In [None]:
prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))  # Now the response suggests places to visit in Europe.

In [None]:
prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))  # Now the response explains it in simpler terms.