
##### Copyright 2024 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Introduction

This tutorial demonstrates how to fine-tune Gemma on a Hugging Face dataset and share the model with the community. I'll be using a [Medicare and Medicaid Q&A Dataset](https://huggingface.co/datasets/kimmysue1102/medicare) from Hugging Face and fine-tuning Gemma to answer questions about Medicare and Medicaid coverage, including coverage for complex conditions.

**Please note that this tutorial is purely for educational purposes and should not be used for Medicare and Medicaid consultation.**

## Setup

### Get access to Gemma

To complete this tutorial, you will first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:

* Get access to Gemma on [kaggle.com](https://kaggle.com).
* Select a Colab runtime with sufficient resources to run
  the Gemma 2B model.
* Generate and configure a Kaggle username and API key.

After you've completed the Gemma setup, move on to the next sections, where you'll select a runtime and set environment variables for your Colab environment.

### Select the runtime

To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU or an A100 GPU (recommended, if available):

1. In the upper-right of the Colab window, select &#9662; (**Additional connection options**).
2. Select **Change runtime type**.
3. Under **Hardware accelerator**, select **T4 GPU** or **A100 GPU**.

### Configure your API key

To use Gemma, you must provide your Kaggle username and a Kaggle API key.

To generate a Kaggle API key, go to the **Account** tab of your Kaggle user profile and select **Create New Token**. This will trigger the download of a `kaggle.json` file containing your API credentials.

In Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`.

### Set environment variables

Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`.

In [1]:
import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
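The `userdata.get` calls above only work inside Colab. If you run this notebook elsewhere, a small helper can read from Colab Secrets when they're available and otherwise fall back to variables you've already exported in your shell. This is a sketch: `get_secret` is a hypothetical name, not part of any library, and it assumes you've set `KAGGLE_USERNAME` and `KAGGLE_KEY` yourself outside Colab.

```python
import os


def get_secret(name: str) -> str:
    """Hypothetical helper: read a secret from Colab Secrets, else from os.environ."""
    try:
        # This import only succeeds inside a Colab runtime.
        from google.colab import userdata
    except ImportError:
        userdata = None

    if userdata is not None:
        return userdata.get(name)

    value = os.environ.get(name)
    if value is None:
        raise KeyError(f"{name} is not set; export it before running the notebook.")
    return value
```

Usage: `os.environ["KAGGLE_KEY"] = get_secret("KAGGLE_KEY")`. Failing loudly on a missing variable surfaces the problem here rather than later as an authentication error during the model download.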

### Install dependencies

Install Keras, KerasNLP, and other dependencies.

In [2]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
# Quote version specifiers so the shell doesn't treat `>` as output redirection.
!pip install -q -U tf-keras
!pip install -q -U keras-nlp==0.10.0
!pip install -q -U "kagglehub>=0.2.4"
!pip install -q -U "keras>=3"


### Select a backend

Keras is a high-level, multi-framework deep learning API designed for simplicity and ease of use. Using Keras 3, you can run workflows on one of three backends: TensorFlow, JAX, or PyTorch.

For this tutorial, configure the backend for JAX.

In [3]:
os.environ["KERAS_BACKEND"] = "jax"
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

### Import packages

Import Keras, KerasNLP, and the `csv` package.

In [4]:
import keras_nlp
import keras
import csv

print("KerasNLP version: ", keras_nlp.__version__)
print("Keras version: ", keras.__version__)

KerasNLP version:  0.10.0
Keras version:  3.7.0


## Load Model

Let's download the 2B variant of Gemma from Kaggle. You can see the model page [here](https://www.kaggle.com/models/keras/gemma/keras/gemma_2b_en).

In [5]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

In [6]:
gemma_lm.summary()

## Load Dataset

Let's download a [Medicare & Medicaid Question Answering Dataset](https://huggingface.co/datasets/kimmysue1102/medicare) from Hugging Face for this fine-tuning example.

In [7]:
!wget -O medicare.jsonl https://huggingface.co/datasets/kimmysue1102/medicare/resolve/main/medicare.jsonl

--2024-12-02 05:12:21--  https://huggingface.co/datasets/kimmysue1102/medicare/resolve/main/medicare.jsonl
Resolving huggingface.co (huggingface.co)... 18.164.174.23, 18.164.174.118, 18.164.174.55, ...
Connecting to huggingface.co (huggingface.co)|18.164.174.23|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 72543 (71K) [text/plain]
Saving to: ‘medicare.jsonl’


2024-12-02 05:12:21 (5.87 MB/s) - ‘medicare.jsonl’ saved [72543/72543]



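After downloading the `medicare.jsonl` file, install the Hugging Face `datasets` library, load the `kimmysue1102/medicare` dataset from the Hub, and format each row with a Question-Answer template.

This is the dataset the model will be fine-tuned on.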

In [8]:
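!pip install datasets


from datasets import load_dataset

# Load the Medicare & Medicaid QA dataset from the Hugging Face Hub.
dataset = load_dataset('kimmysue1102/medicare')

# Question-Answer prompt template. Each row supplies the `Question` and
# `Answer` fields; the template is reused later to build inference prompts.
template = "Question:\n{Question}\n\nAnswer:\n{Answer}"

# Format every training row as a single prompt/response string.
data = []
for row in dataset['train']:
  data.append(template.format(**row))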

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



Let's take a look at an example to make sure the data has been formatted correctly with the Question-Answer template:

In [9]:
print(data[3])

Question:
What are the two main ways to get Medicare coverage?

Answer:
The two main options are Original Medicare (Parts A and B) and Medicare Advantage (Part C).
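The `load_dataset` call fetches the dataset from the Hub, so the `medicare.jsonl` file downloaded with `wget` isn't actually read. If you'd rather build the training strings from that local file, a standard-library sketch follows. It assumes each line is a JSON object with `Question` and `Answer` fields (the keys the template expects); `load_qa_examples` is a hypothetical name, and the local file may not contain exactly the same rows as the Hub `train` split.

```python
import json

TEMPLATE = "Question:\n{Question}\n\nAnswer:\n{Answer}"


def load_qa_examples(path: str) -> list[str]:
    """Hypothetical helper: format each JSONL record with the QA template."""
    examples = []
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            record = json.loads(line)
            missing = {"Question", "Answer"} - set(record)
            if missing:
                raise ValueError(f"line {line_no} is missing {sorted(missing)}")
            examples.append(
                TEMPLATE.format(Question=record["Question"], Answer=record["Answer"])
            )
    return examples
```

Usage: `local_data = load_qa_examples("medicare.jsonl")`. Checking the keys up front gives a clear error for a malformed record instead of a `KeyError` from inside `str.format`.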


### Inference before fine-tuning

The original Gemma model has a lot of general knowledge, but fine-tuning can help improve domain-specific knowledge.

To test the pre-trained model on more specific Medicare knowledge, let's pick a more complex topic: **Medicare Part D**.

Let's start by asking Gemma about drugs covered by the plan, then try a few other Medicare questions. Each prompt uses the Question-Answer template we previously defined, with an empty answer for the model to complete.

In [10]:
prompt = template.format(
    Question="Where can I find out if a specific drug is covered by my Part D plan?",
    Answer="",
)
print(gemma_lm.generate(prompt, max_length=128))

Question:
Where can I find out if a specific drug is covered by my Part D plan?

Answer:
You can find out if a specific drug is covered by your Part D plan by calling the number on the back of your Part D ID card.

Question:
What is a generic drug?

Answer:
A generic drug is a drug that is chemically identical to a brand-name drug. Generic drugs are usually less expensive than brand-name drugs.

Question:
What is a preferred drug?

Answer:
A preferred drug is a drug that is covered at a lower cost than a non-


In [11]:
prompt = template.format(
    Question="Does Medicare cover screenings for cervical cancer?",
    Answer="",
)
print(gemma_lm.generate(prompt, max_length=128))

Question:
Does Medicare cover screenings for cervical cancer?

Answer:
Medicare covers cervical cancer screenings if you have a high-risk condition. This includes:

* Human papillomavirus (HPV) infection
* Human papillomavirus (HPV) infection and abnormal cervical cells
* Human papillomavirus (HPV) infection and abnormal cervical cells and abnormal Pap test results
* Human papillomavirus (HPV) infection and abnormal cervical cells and abnormal Pap test results and abnormal pelvic exam results
* Human papillomavirus (HPV) infection and abnormal cervical cells and abnormal Pap test results and abnormal pelvic exam results and abnormal pelvic exam


In [12]:
prompt = template.format(
    Question="Can I join a Medicare Advantage Plan if I have a pre-existing condition?",
    Answer="",
)
print(gemma_lm.generate(prompt, max_length=128))

Question:
Can I join a Medicare Advantage Plan if I have a pre-existing condition?

Answer:
Yes, you can join a Medicare Advantage Plan if you have a pre-existing condition.

Question:
What is a Medicare Advantage Plan?

Answer:
Medicare Advantage Plans are offered by private insurance companies and are designed to provide you with a variety of benefits and services.

Question:
What are the benefits of a Medicare Advantage Plan?

Answer:
Medicare Advantage Plans offer a variety of benefits and services, including:

* Prescription drug coverage
* Dental, vision, and hearing coverage
* Gym


In [13]:
prompt = template.format(
    Question="Explain Medigap in a way that a child could understand.",
    Answer="",
)
print(gemma_lm.generate(prompt, max_length=128))

Question:
Explain Medigap in a way that a child could understand.

Answer:
Medigap is a type of supplemental insurance that helps pay for the costs of Medicare Part A and Part B. It is available to people who have Medicare Part A and Part B. Medigap plans are sold by private insurance companies. They are available in all 50 states and the District of Columbia. Medigap plans are also known as Medicare Supplement Insurance.

Medicare Part A is hospital insurance. It helps pay for room and board in a hospital, skilled nursing facility, or hospice. Medicare Part B is medical insurance


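As you can see, the pre-trained model's answers are only partly useful. For the Part D question, it suggests calling the number on the back of your Part D ID card, then keeps generating unrelated questions and answers about generic and preferred drugs. The other responses either repeat themselves (the cervical cancer screening list) or drift into follow-up questions instead of staying on the question asked.

This is where fine-tuning on our Medicare & Medicaid dataset can help.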
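Part of what makes the pre-trained output hard to use is that generation doesn't stop after one answer: the model keeps inventing new `Question:` turns until it reaches `max_length`. Independently of fine-tuning, a string-level post-processing step can cut the text at the first follow-up question. This is a hypothetical helper, not a KerasNLP feature, and it only trims the output; it doesn't change what the model says in the first answer.

```python
def first_answer(generated: str, marker: str = "\n\nQuestion:") -> str:
    """Hypothetical helper: keep only the first Question/Answer pair."""
    # The prompt itself starts with "Question:" (no preceding blank line),
    # so the first match of the marker is a model-invented follow-up turn.
    cut = generated.find(marker)
    return generated if cut == -1 else generated[:cut].rstrip()
```

Usage: `print(first_answer(gemma_lm.generate(prompt, max_length=128)))`.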

## LoRA Fine-tuning

To get better responses from the model, fine-tune the model with Low Rank Adaptation (LoRA) using our Medicare and Medicaid Question-Answer dataset.

The LoRA rank determines the dimensionality of the trainable matrices that are added to the original weights of the LLM. It controls the expressiveness and precision of the fine-tuning adjustments.
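To make the idea concrete, here is a minimal NumPy sketch of a single LoRA-adapted dense layer. The pretrained weight `W` stays frozen, and only the two small matrices `A` and `B`, whose shared inner dimension is the rank, would be trained; their product is the low-rank update added to `W`. The layer sizes are illustrative assumptions, and this shows the generic technique rather than KerasNLP's internal implementation, which may differ in details such as scaling the update.

```python
import numpy as np

d_in, d_out, rank = 2048, 2048, 4  # illustrative sizes, not read from Gemma
rng = np.random.default_rng(0)

# Frozen pretrained weight: never updated during LoRA fine-tuning.
W = rng.standard_normal((d_in, d_out)).astype(np.float32)
# Trainable low-rank factors: A is d_in x rank, B is rank x d_out.
A = (0.01 * rng.standard_normal((d_in, rank))).astype(np.float32)
B = np.zeros((rank, d_out), dtype=np.float32)  # zero-init, so the update starts at 0


def lora_dense(x: np.ndarray) -> np.ndarray:
    """Frozen projection plus the low-rank update x @ (A @ B)."""
    # Computing (x @ A) @ B avoids materializing the full d_in x d_out update.
    return x @ W + (x @ A) @ B


x = rng.standard_normal((3, d_in)).astype(np.float32)  # a batch of 3 inputs
y = lora_dense(x)
```

Because `B` starts at zero, the adapted layer initially behaves exactly like the original one, and training moves it away from the pretrained behavior only through `A` and `B`.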

A higher rank means more detailed changes are possible, but also means more trainable parameters. A lower rank means less computational overhead, but potentially less precise adaptation.

This tutorial uses a LoRA rank of 4. In practice, begin with a relatively small rank (such as 4, 8, 16). This is computationally efficient for experimentation. Train your model with this rank and evaluate the performance improvement on your task. Gradually increase the rank in subsequent trials and see if that further boosts performance.
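The trade-off between rank and trainable parameters is simple arithmetic: for one `d_in x d_out` weight matrix, LoRA adds `rank * (d_in + d_out)` trainable values, versus `d_in * d_out` for updating the matrix directly. The sketch below compares the ranks suggested above for a hypothetical 2048 x 2048 projection; it illustrates the scaling and is not a parameter count for the full Gemma model.

```python
def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Trainable values LoRA adds for one d_in x d_out weight (A: d_in x r, B: r x d_out)."""
    return rank * (d_in + d_out)


d_in = d_out = 2048  # hypothetical layer size
full = d_in * d_out
for rank in (4, 8, 16):
    added = lora_params(d_in, d_out, rank)
    print(f"rank={rank:>2}: {added:>7,} trainable vs {full:,} full ({added / full:.2%})")
```

The added parameters grow linearly with the rank, so doubling the rank doubles the trainable parameters for that layer, yet even rank 16 stays under 2% of the matrix's size.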

In [14]:
# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Note that enabling LoRA reduces the number of trainable parameters significantly.

In [15]:
# Fine-tune on the Medicare & Medicaid QA dataset.

# Limit the input sequence length to 128 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 128
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)

345/345 ━━━━━━━━━━━━━━━━━━━━ 166s 423ms/step - loss: 1.2957 - sparse_categorical_accuracy: 0.5930


<keras.src.callbacks.history.History at 0x79d3602ddba0>
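The optimizer cell relies on two ideas: AdamW's *decoupled* weight decay, which shrinks a weight directly rather than adding an L2 penalty to the gradient, and excluding bias and layer-norm `scale` variables from that decay. The pure-Python sketch below shows one textbook AdamW step for a single scalar weight, plus a hypothetical name-based filter. The betas and epsilon are common defaults assumed here, and Keras's actual update and name matching may differ in detail.

```python
import math

EXCLUDED = ("bias", "scale")  # same substrings passed to exclude_from_weight_decay above


def uses_weight_decay(var_name: str) -> bool:
    """Hypothetical filter: no decay for names containing an excluded substring."""
    return not any(s in var_name for s in EXCLUDED)


def adamw_step(w, grad, m, v, t, lr=5e-5, weight_decay=0.01,
               beta1=0.9, beta2=0.999, eps=1e-7, decay=True):
    """One textbook AdamW update for a scalar weight at step t (1-based)."""
    if decay:
        w = w - lr * weight_decay * w  # decoupled decay: independent of the gradient
    m = beta1 * m + (1 - beta1) * grad  # first-moment (mean) estimate
    v = beta2 * v + (1 - beta2) * grad * grad  # second-moment estimate
    m_hat = m / (1 - beta1 ** t)  # bias correction for zero-initialized moments
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v
```

With a zero gradient, a decayed weight still shrinks by `lr * weight_decay` of its value each step, while a variable whose name contains `bias` or `scale` is left alone by the decay term.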

### Inference after fine-tuning

After fine-tuning the model, let's run the same prompts again, including the Part D drug coverage question and the Medigap explanation.

In [16]:
prompt = template.format(
    Question="Where can I find out if a specific drug is covered by my Part D plan?",
    Answer="",
)
print(gemma_lm.generate(prompt, max_length=128))

Question:
Where can I find out if a specific drug is covered by my Part D plan?

Answer:
You can find out if a specific drug is covered by your Part D plan by calling your plan's Member Services number. You can also find out if a drug is covered by your plan by looking at your plan's formulary.


In [17]:
prompt = template.format(
    Question="Does Medicare cover screenings for cervical cancer?",
    Answer="",
)
print(gemma_lm.generate(prompt, max_length=128))

Question:
Does Medicare cover screenings for cervical cancer?

Answer:
Medicare covers cervical cancer screenings if you have Medicare Part B. Medicare Part B covers cervical cancer screenings if you have a Pap test and a pelvic exam. Medicare Part B covers the cost of the test and the exam.


In [18]:
prompt = template.format(
    Question="Can I join a Medicare Advantage Plan if I have a pre-existing condition?",
    Answer="",
)
print(gemma_lm.generate(prompt, max_length=128))

Question:
Can I join a Medicare Advantage Plan if I have a pre-existing condition?

Answer:
Yes, you can join a Medicare Advantage Plan if you have a pre-existing condition. However, you may have to pay a higher premium and/or a higher deductible.


In [19]:
prompt = template.format(
    Question="Explain Medigap in a way that a child could understand.",
    Answer="",
)
print(gemma_lm.generate(prompt, max_length=128))

Question:
Explain Medigap in a way that a child could understand.

Answer:
Medigap plans are sold by private companies that contract with Medicare to provide benefits to people who have Medicare Part A and Part B. Medigap plans are sold in all 50 states and the District of Columbia. Medigap plans are sold in addition to Medicare Part A and Part B. Medigap plans are not insurance. They are a way to pay some of the costs of Medicare Part A and Part B.


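The responses are more focused than before fine-tuning: each answer addresses the question asked and then stops, instead of repeating itself or generating unrelated follow-up questions, and the Part D answer now also points to the plan's formulary. Fluency isn't the same as accuracy, though. The fine-tuned Medigap answer, for example, says Medigap plans "are not insurance," even though they are also known as Medicare Supplement Insurance.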

## Upload your model to Kaggle

Create a preset directory for your model files.

Then, save the model to that preset directory.

In [20]:
preset = "./medicare_gemma"
# Save the model to the preset directory.
gemma_lm.save_to_preset(preset)
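Expect the saved preset to be large, since it contains the model's full weights. As rough arithmetic, assuming about 2.5 billion parameters (a rounded assumption; read the exact count from `gemma_lm.summary()`), the raw weights take `num_params * bytes_per_param` bytes. `weights_size_gb` is a hypothetical helper for that calculation.

```python
def weights_size_gb(num_params: int, bytes_per_param: int) -> tuple[float, float]:
    """Return (decimal GB, binary GiB) for a raw dump of the weights."""
    total_bytes = num_params * bytes_per_param
    return total_bytes / 1e9, total_bytes / 2**30


num_params = 2_500_000_000  # assumed, rounded; check gemma_lm.summary()
for dtype, nbytes in (("float32", 4), ("bfloat16", 2)):
    gb, gib = weights_size_gb(num_params, nbytes)
    print(f"{dtype:>8}: {gb:.1f} GB ({gib:.1f} GiB)")
```

The `float32` estimate of about 10 GB is in line with the size of `model.weights.h5` in the upload log, which suggests the weights are stored at 4 bytes per parameter.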

Create a Kaggle URI for your model.
It should use the following format:

`kaggle://{KAGGLE USERNAME}/{MODEL NAME}/keras/{VARIATION NAME}`

In [21]:
kaggle_username = userdata.get('KAGGLE_USERNAME')
model_name = "gemma"
variation_name = "medicare_gemma"

uri = f"kaggle://{kaggle_username}/{model_name}/keras/{variation_name}"
uri

'kaggle://kimberlymac1102/gemma/keras/medicare_gemma'
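Before uploading, it can help to check that the handle has the four parts shown in the format above. The helper below is a hypothetical sketch written against that format only; it doesn't contact Kaggle and doesn't replace whatever validation `keras_nlp.upload_preset` performs.

```python
def parse_kaggle_uri(uri: str) -> dict:
    """Hypothetical check for kaggle://{username}/{model}/keras/{variation}."""
    prefix = "kaggle://"
    if not uri.startswith(prefix):
        raise ValueError(f"expected a URI starting with {prefix!r}: {uri!r}")
    parts = uri[len(prefix):].split("/")
    if len(parts) != 4 or not all(parts):
        raise ValueError(f"expected 4 non-empty path parts, got {parts}")
    username, model, framework, variation = parts
    if framework != "keras":
        raise ValueError(f"expected framework 'keras', got {framework!r}")
    return {"username": username, "model": model, "variation": variation}
```

Usage: `parse_kaggle_uri(uri)` returns the username, model, and variation, or raises `ValueError` for a typo such as a missing `keras` segment.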

Then, upload the preset to Kaggle!

If this is your first upload of this model, a Kaggle model page associated with your profile will be created.

You can view all your models on your [Work Page](https://www.kaggle.com/work/models).

In [22]:
# Upload preset to Kaggle
keras_nlp.upload_preset(uri, preset)

Uploading Model https://www.kaggle.com/models/kimberlymac1102/gemma/keras/medicare_gemma ...
Model 'gemma' does not exist or access is forbidden for user 'kimberlymac1102'. Creating or handling Model...
Model 'gemma' Created.
Starting upload for file ./medicare_gemma/model.weights.h5


Uploading: 100%|██████████| 10.0G/10.0G [02:09<00:00, 77.6MB/s]

Upload successful: ./medicare_gemma/model.weights.h5 (9GB)
Starting upload for file ./medicare_gemma/task.json



Uploading: 100%|██████████| 2.45k/2.45k [00:00<00:00, 3.91kB/s]

Upload successful: ./medicare_gemma/task.json (2KB)
Starting upload for file ./medicare_gemma/preprocessor.json



Uploading: 100%|██████████| 1.25k/1.25k [00:00<00:00, 1.96kB/s]

Upload successful: ./medicare_gemma/preprocessor.json (1KB)
Starting upload for file ./medicare_gemma/tokenizer.json



Uploading: 100%|██████████| 498/498 [00:00<00:00, 850B/s]

Upload successful: ./medicare_gemma/tokenizer.json (498B)
Starting upload for file ./medicare_gemma/config.json



Uploading: 100%|██████████| 501/501 [00:00<00:00, 828B/s]

Upload successful: ./medicare_gemma/config.json (501B)
Starting upload for file ./medicare_gemma/metadata.json



Uploading: 100%|██████████| 143/143 [00:00<00:00, 233B/s]

Upload successful: ./medicare_gemma/metadata.json (143B)
Starting upload for file ./medicare_gemma/assets/tokenizer/vocabulary.spm



Uploading: 100%|██████████| 4.24M/4.24M [00:00<00:00, 5.94MB/s]

Upload successful: ./medicare_gemma/assets/tokenizer/vocabulary.spm (4MB)





Your model instance has been created.
Files are being processed...
See at: https://www.kaggle.com/models/kimberlymac1102/gemma/keras/medicare_gemma


Now view the model page using the URL in the output of the previous cell.

Verify that your new model instance is successfully uploaded.
Note this can take several minutes if this is your first upload of this model type.

**That's it!** You've now learned how to fine-tune Gemma using Kaggle and Keras and share your model with the community.