# Fine-tune Gemma models using LoRA for the Cake Boss example
Adding additional changes based on feedback

## Setup

In [None]:
import os
from google.colab import userdata

# `userdata.get` is a Colab-only API. When running outside Colab, set the
# KAGGLE_USERNAME and KAGGLE_KEY environment variables yourself instead.
for secret_name in ("KAGGLE_USERNAME", "KAGGLE_KEY"):
    # Copy each secret into the process environment under the same name.
    os.environ[secret_name] = userdata.get(secret_name)

In [None]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m572.2/572.2 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.2/5.2 MB[0m [31m15.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
[?25h

### Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.

For this tutorial, configure the backend for JAX.

In [None]:
# Both variables are set before Keras is imported.
os.environ.update(
    {
        # Backend selection: "jax" here; "torch" or "tensorflow" also work.
        "KERAS_BACKEND": "jax",
        # Avoid memory fragmentation on the JAX backend.
        "XLA_PYTHON_CLIENT_MEM_FRACTION": "1.00",
    }
)

### Import packages

Import Keras and KerasNLP.

In [None]:
import keras
import keras_nlp

## Load Dataset

## Load Model

KerasNLP provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/). In this tutorial, you'll create a model using `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.

Create the model using the `from_preset` method:

In [None]:
# Name of the KerasNLP preset to load.
preset_name = "gemma2_instruct_2b_en"

# Build the causal language model from the preset and print its summary.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(preset_name)
gemma_lm.summary()

### Cake prompt
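This output comes from the untuned model. The results aren't what we want: the model invents a pickup location that is not in the message and repeats the JSON block until it runs out of length.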


In [None]:
# Every example is laid out as the instruction, a newline, then the response.
template = "{instruction}\n{response}"

# Extraction request sent to the untuned model; the response slot stays empty
# so the model has to produce it.
untuned_instruction = """From the following get the type of inquiry, (order or request for information), filling, flavor, size, and pickup location and put it into a json
Hi,
I'd like to order a red velvet cake with custard filling. Please make it 8 inch round"""
prompt = template.format(instruction=untuned_instruction, response="")

# Greedy decoding works best for this use case. The alternative that was tried
# is keras_nlp.samplers.TopKSampler(k=5, seed=2).
gemma_lm.compile(sampler="greedy")

untuned_output = gemma_lm.generate(prompt, max_length=256)
print(untuned_output)

From the following get the type of inquiry, (order or request for information), filling, flavor, size, and pickup location and put it into a json
Hi,
I'd like to order a red velvet cake with custard filling. Please make it 8 inch round
and pick it up from the bakery on 22nd street.

Thanks!
 
```json
{
  "inquiry_type": "order",
  "filling": "custard",
  "flavor": "red velvet",
  "size": "8 inch round",
  "pickup_location": "22nd street bakery"
}
```
```json
{
  "inquiry_type": "request",
  "filling": "custard",
  "flavor": "red velvet",
  "size": "8 inch round",
  "pickup_location": "22nd street bakery"
}
```
```json
{
  "inquiry_type": "order",
  "filling": "custard",
  "flavor": "red velvet",
  "size": "8 inch round",
  "pickup_location": "22nd street bakery"


In [None]:
import json
# Training example 1: an availability inquiry covering two sweets (with
# quantities) and one cake. The label is written as JSON text and parsed with
# json.loads, so a malformed label fails as soon as this cell runs.
prompt_1 = {
    "prompt": """
Hi Indian Bakery Central,
Do you happen to have 10 pendas, and thirty bundi ladoos on hand? Also do you sell a vanilla frosting and chocolate flavor cakes. I'm looking for a 6 inch size
""",
    "response": json.loads("""
  {
    "type": "inquiry",
    "items": [
      {
        "name": "pendas",
        "quantity": 10
      },
      {
        "name": "bundi ladoos",
        "quantity": 30
      },
      {
        "name": "cake",
        "filling": null,
        "frosting": "vanilla",
        "flavor": "chocolate",
        "size": "6 in"
      }
    ]
}
"""),
}


In [None]:
# Illustrative record showing a training prompt next to its expected response.
# Fixes relative to the earlier version of this cell:
#   * a comma now separates the two dict entries (previously a SyntaxError);
#   * the stray closing brace after the list is removed, so the response text
#     is valid JSON.
training_example = {
    "training_prompt": """
Hi Indian Bakery Central,
Do you happen to have 10 pendas, and thirty bundi ladoos on hand? Also do you sell a vanilla frosting and chocolate flavor cakes. I'm looking for a 6 inch size
""",
    "response": """
    [
      {
        "name": "pendas",
        "quantity": 10
      },
      {
        "name": "bundi ladoos",
        "quantity": 30
      },
      {
        "name": "cake",
        "filling": null,
        "frosting": "vanilla",
        "flavor": "chocolate",
        "size": "6 in"
      }
    ]
""",
}
# Display the record as the cell output.
training_example

In [None]:
# Training example 2: an inquiry about two sweets with no quantities given,
# so each quantity is labelled as JSON null.
prompt_2 = {
    "prompt": """
I saw your business on google maps. Do you sell jellabi and gulab jamun?
""",
    "response": json.loads("""
  {
    "type": "inquiry",
    "items": [
      {
        "name": "jellabi",
        "quantity": null
      },
      {
        "name": "gulab jamun",
        "quantity": null
      }
    ]
}
"""),
}

In [None]:
# Training example 3: an order for a single cake.
# The label now matches the message: the flavor is red velvet (not chocolate)
# and no filling is mentioned, so "filling" is JSON null rather than the
# cake size. The chocolate-chip topping has no field in the schema used by the
# other examples, so it is not labelled.
prompt_3 = dict(prompt = """
I'd like to place an order for a 8 inch red velvet cake with lemon frosting and chocolate chips topping.
""",
response = json.loads("""
  {
    "type": "order",
    "items": [
      {
        "name": "cake",
        "filling": null,
        "frosting": "lemon",
        "flavor": "red velvet",
        "size": "8 in"
      }
    ]
}
""")
)

In [None]:
# Training example 4: an order for two sweets with spelled-out quantities,
# labelled as integers.
prompt_4 = {
    "prompt": """
I'd like four jellabi and three gulab Jamun.
""",
    "response": json.loads("""
  {
    "type": "order",
    "items": [
      {
        "name": "Jellabi",
        "quantity": 4
      },
      {
        "name": "Gulab Jamun",
        "quantity": 3
      }
    ]
}
"""),
}
# Show the parsed example as the cell output.
prompt_4

{'prompt': "\nI'd like four jellabi and three gulab Jamun.\n",
 'response': {'type': 'order',
  'items': [{'name': 'Jellabi', 'quantity': 4},
   {'name': 'Gulab Jamun', 'quantity': 3}]}}

In [None]:
# Training example 4.2: an order for a single sweet with a numeric quantity.
prompt_4_2 = {
    "prompt": """
Please pack me a box with 10 halva.
""",
    "response": json.loads("""
  {
    "type": "order",
    "items": [
      {
        "name": "halva",
        "quantity": 10
      }
    ]
}
"""),
}

In [None]:
# Training example 5: an inquiry about a cake with no size given.
# The missing size is labelled as JSON null (parsed to None), matching how the
# other examples mark unknown fields. The earlier label used the string "null",
# which teaches the model to emit a literal "null" string.
prompt_5 = dict(prompt = """
Do you sell strawberry cakes with vanilla frosting with custard inside?
""",
response = json.loads("""
  {
    "type": "inquiry",
    "items": [
      {
        "name": "cake",
        "filling": "custard",
        "frosting": "vanilla",
        "flavor": "strawberry",
        "size": null
      }
    ]
}
""")
)


In [None]:
# Training example 5.2: an inquiry about a cake with no filling or size given.
# Unknown fields are labelled as JSON null (parsed to None) rather than the
# string "null", matching how the other examples mark missing values.
prompt_5_2 = dict(prompt = """
Do you sell carrot cakes with cream cheese frosting?
""",
response = json.loads("""
  {
    "type": "inquiry",
    "items": [
      {
        "name": "cake",
        "filling": null,
        "frosting": "cream cheese",
        "flavor": "carrot",
        "size": null
      }
    ]
}
""")
)
# Display the example defined in this cell (previously this showed prompt_5).
prompt_5_2

{'prompt': '\nDo you sell strawberry cakes with vanilla frosting with custard inside?\n',
 'response': {'type': 'inquiry',
  'items': [{'name': 'cake',
    'filling': 'custard',
    'frosting': 'vanilla',
    'flavor': 'strawberry',
    'size': 'null'}]}}

In [None]:
# Training example 6: a general inquiry that names no products, so the item
# list is empty.
prompt_6 = {
    "prompt": """
I found your website. What kind of items do you sell?
""",
    "response": json.loads("""
  {
    "type": "inquiry",
    "items": [
    ]
}
"""),
}


In [None]:
# Starts overfitting on lemon if you add this

# prompt_7 = dict(prompt = """
# Can I buy 18 halva, as well as a lemon cake with lemon frosting?
# """,
# response = json.loads("""
#   {
#     "type": "inquiry",
#     "items": [
#       {
#         "name": "halva",
#         "quantity": 18
#       },
#       {
#         "filling": null,
#         "frosting": "lemon",
#         "flavor": "lemon",
#         "size": null
#       }
#     ]
# }
# """)
# )

In [None]:
# Build the fine-tuning corpus: one formatted string per labelled example.
#
# Each response is serialized with json.dumps so the training target is real
# JSON (double quotes, `null`). Formatting the dict directly would insert its
# Python repr (single quotes, `None`), which the model then learns to imitate.
#
# The loop variable is named `example` so it does not overwrite the global
# `prompt` string used by the inference cells.
training_examples = [
    prompt_1,
    prompt_2,
    prompt_3,
    prompt_4,
    prompt_4_2,
    prompt_5,
    prompt_5_2,
    prompt_6,
]

data = [
    template.format(
        instruction=example["prompt"],
        response=json.dumps(example["response"]),
    )
    for example in training_examples
]

## LoRA Fine-tuning

The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.

A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.

This tutorial uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, 16). This is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance.

Watch for overfitting and underfitting. The main settings to tune are:
* Rank
* Learning rate

In [None]:
# Enable LoRA for the model and set the LoRA rank to 4.
# LoRA rank used in this tutorial.
lora_rank = 4

# Turn on LoRA for the backbone, then print the model summary again.
gemma_lm.backbone.enable_lora(rank=lora_rank)
gemma_lm.summary()

Note that enabling LoRA reduces the number of trainable parameters significantly (from 2.6 billion to 2.9 million).

In [None]:
# Fine-tune the LoRA-enabled model on the formatted examples in `data`.
# (A sweep over weight_decay values such as .009 and .0001 was sketched here
# earlier; the values below are the ones kept.)

# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256

# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=9e-4,
    weight_decay=0.004,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

# Re-compiling also replaces the sampler, so "greedy" is passed again here.
# Without it, the generate() calls after training would fall back to the
# default sampler instead of the greedy decoding chosen for this use case.
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
    sampler="greedy",
)
# One example per step, three passes over the examples.
gemma_lm.fit(data, epochs=3, batch_size=1)

Epoch 1/3
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m83s[0m 6s/step - loss: 0.7486 - sparse_categorical_accuracy: 0.6278
Epoch 2/3
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 2s/step - loss: 0.5113 - sparse_categorical_accuracy: 0.6984
Epoch 3/3
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 770ms/step - loss: 0.3469 - sparse_categorical_accuracy: 0.7796


<keras.src.callbacks.history.History at 0x7c8dfc58b1c0>

## Inference after fine-tuning
After fine-tuning, responses follow the instruction provided in the prompt.

### Order Prompt

In [None]:
# A single cake order. The response slot is left empty so the model has to
# generate it.
order_instruction = """Hi, I'd like to order an 8 inch red velvet cake with custard filling"""
prompt = template.format(instruction=order_instruction, response="")

order_output = gemma_lm.generate(prompt, max_length=256)
print(order_output)

Hi, I'd like to order an 8 inch red velvet cake with custard filling
{'type': 'order', 'items': [{'name': 'cake', 'filling': 'custard', 'size': '8 inch', 'flavor': 'red velvet'}]}


In [None]:
# Misspelling
prompt = template.format(
    instruction="""Hi Indian Bakery Central,
                   I'd like to order one lemon cake that has vanilla filing, 2 gulab jamun and 1 penda""",
    response="",
)

print(gemma_lm.generate(prompt, max_length=256))

Hi Indian Bakery Central,
                   I'd like to order one lemon cake that has vanilla filing, 2 gulab jamun and 1 penda
{'type': 'order', 'items': [{'name': 'lemon cake', 'filling': 'vanilla', 'quantity': '1'}, {'name': 'gulab jamun', 'quantity': '2'}, {'name': 'penda', 'quantity': '1'}]}


In [None]:
# Failure case
prompt = template.format(
    instruction="""Hello, do you have 20 pendas and 10 ladoos? Also Can you make a chocolate cake with raspberry filling?""",
    response="",
)

print(gemma_lm.generate(prompt, max_length=256))

Hello, do you have 20 pendas and 10 ladoos? Also Can you make a chocolate cake with raspberry filling?
{'type': 'inquiry', 'items': [{'name': 'pendas', 'quantity': 20}, {'name': 'ladoos', 'quantity': 10}], 'order': {'cake': 'chocolate', 'filling': 'raspberry', 'size': 'null'}}


## Summary and next steps

This tutorial covered LoRA fine-tuning on a Gemma model using KerasNLP. Check out the following docs next:

* Learn how to [generate text with a Gemma model](https://ai.google.dev/gemma/docs/get_started).
* Learn how to perform [distributed fine-tuning and inference on a Gemma model](https://ai.google.dev/gemma/docs/distributed_tuning).
* Learn how to [use Gemma open models with Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/open-models/use-gemma).
* Learn how to [fine-tune Gemma using KerasNLP and deploy to Vertex AI](https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_gemma_kerasnlp_to_vertexai.ipynb).