<a href="https://colab.research.google.com/github/bakeronit/note-nbviewer/blob/master/gdg_ai_for_science_finetune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-Tuning a Large Language Model on Your Own Data (Notebook 1)

This is part of a series run by the [GDG AI for Science](https://gdg.community.dev/gdg-ai-for-science-australia/).

**Notebook 1:** Fine-Tune with KerasHub (TPU accelerator required)

![You are here](https://colab.research.google.com/assets/colab-badge.svg) or [![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/code/astrobutter/gdg-ai-for-science-finetuning)

Notebook 2: Fine-Tune with Transformers (GPU accelerator required)

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1sJviybARN36t3d_i-gfmc0zDs51WDgzz?usp=sharing) or [![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/code/astrobutter/gdg-ai-for-science-finetuning-transformers)

Notebook 3: RAG workflow (GPU recommended)

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1vA3f3XdLB3nvczP4gRuTMwWMfSkXT1ex?usp=sharing) or [![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/code/astrobutter/gdg-ai-for-science-finetuning-rag)

## Prerequisites
* Familiarity with Python, including functions and the pandas library.

* Access to a computing environment with a GPU or TPU (like a Kaggle Notebook or Google Colab).

## Learning Objectives
* Understand the concepts of pre-training and fine-tuning for Large Language Models (LLMs).

* Recognize the hardware requirements and limitations (GPU/TPU/CPU) for training.

* Prepare a custom dataset for fine-tuning.

* Fine-tune an LLM using Keras with a TPU backend.

* Fine-tune an LLM using the transformers library with a GPU backend.

* Understand the difference between full fine-tuning and Parameter-Efficient Fine-Tuning (PEFT) with LoRA.

* Evaluate the model's performance by comparing responses before and after fine-tuning.

* Compare a RAG workflow with fine-tuning.


## What is a Large Language Model (LLM)? 🤖
Think of a pre-trained LLM like a brilliant new research assistant who has read nearly the entire public internet. They have vast general knowledge and can write essays, summarize articles, and answer questions on a huge range of topics. However, they haven't read your lab's specific protocols, your private research data, or the niche publications in your specialized field. They also make plenty of mistakes, and state them just as confidently as their correct answers.

This "general knowledge" comes from a process called pre-training, where the model is shown trillions of words of text and learns the patterns, grammar, and facts of human language.

## What is Fine-Tuning? 🎯
Fine-tuning is the process of taking that pre-trained model and training it for a bit longer on your own, smaller, domain-specific dataset, or teaching it a specific task. It's like giving your new research assistant a curated stack of your lab's most important papers and data. You aren't teaching them language from scratch; you're adapting their existing knowledge to your specific needs.

Through fine-tuning, the model can learn your domain's specific vocabulary, understand relationships between concepts in your field, and adopt a specific style or format for its responses.

## Hardware Requirements 🚀
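Training and fine-tuning LLMs involves an enormous number of calculations, mostly large matrix multiplications. A standard computer processor (CPU) is built to run complex tasks quickly, one after another, on a handful of cores. Training, by contrast, consists of huge numbers of simple arithmetic operations that can all run at the same time.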

This is where Graphics Processing Units (GPUs) and Tensor Processing Units (TPUs) come in. They are specialised chips designed for massively parallel computation, which makes them well suited to deep learning. Without one, fine-tuning even a small LLM is impractically slow. For this lesson, we'll use the free-tier TPUs and GPUs available on Kaggle and Colab.
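To give a sense of scale, here is a back-of-the-envelope sketch using a widely quoted rule of thumb: training costs roughly 6 floating-point operations (FLOPs) per model parameter per training token. The token count and both hardware throughputs below are illustrative assumptions, not measurements, and real hardware rarely sustains its peak rate. The rule of thumb describes updating every weight; parameter-efficient methods such as LoRA skip some of the backward-pass work and are somewhat cheaper.

```python
def training_flops(n_params: float, n_tokens: float) -> float:
    """Rule of thumb: about 6 FLOPs per parameter per training token
    (roughly 2 for the forward pass and 4 for the backward pass)."""
    return 6.0 * n_params * n_tokens


def hours_at(total_flops: float, flops_per_second: float) -> float:
    """Idealised wall-clock time if the chip sustained this rate the whole time."""
    return total_flops / flops_per_second / 3600.0


# Hypothetical workload: a 270M-parameter model seeing 1 million training tokens.
total = training_flops(270e6, 1e6)
print(f"Total compute: {total:.2e} FLOPs")

# Assumed, order-of-magnitude sustained throughputs (illustrative, not benchmarks).
for device, rate in [("CPU, ~1e11 FLOP/s", 1e11), ("GPU/TPU, ~1e14 FLOP/s", 1e14)]:
    print(f"{device}: {hours_at(total, rate) * 60:.1f} minutes")
```

Under these assumptions the workload is about 1.6 × 10^15 FLOPs: roughly 4.5 hours on the CPU versus about 16 seconds on the accelerator. The exact figures matter less than the gap.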

## Frameworks and Architectures

### Model Architecture (The "Brains")
This is the design of the neural network itself. Examples include Gemma, GPT-2, Llama, etc. These are the pre-trained models we will adapt.

### Framework
These are the software libraries that provide the tools to load, manipulate, and train the models. We will use two of the most popular frameworks: **Keras** (a user-friendly, high-level API) and **Hugging Face transformers** (the de facto industry standard, known for its power and flexibility).

You must use a framework that supports the model architecture you want to work with.


# Example 1: Fine-Tuning Gemma with Keras on a TPU
In this first example, we'll fine-tune Google's Gemma 3 model on our custom data. Rather than updating all of the model's weights (full fine-tuning), we'll use LoRA, a parameter-efficient method that freezes the original weights and trains a small number of additional ones. We'll use the Keras framework with a JAX backend, which is highly optimized for running on TPUs.

## Setup and Environment Configuration

### Kaggle API Token Setup

To download [Gemma models](https://www.kaggle.com/models/google/gemma) (or other datasets/models) from Kaggle directly into Colab or your local computer, you need a Kaggle API token.

**Steps to get your `kaggle.json` token:**
   1. Go to your Kaggle account page: [https://www.kaggle.com/](https://www.kaggle.com/)
   2. Log in or create an account if you don't have one.
   3. Click on your profile picture/icon in the top right corner, then select **"Settings"**.
   4. Scroll down to the **"API"** section.
   5. Click **"Create New Token"**. This will download a file named `kaggle.json` to your computer.

   This file contains some text with your username and an *API key*. Treat it like a username and password: don't share it or commit it to a public repository. You can revoke the token at any time from the same Settings page, and many services also let you restrict what a key can access.

   `{"username":"astrobutter","key":"abcdefghijk1234567890"}`

**Accept Gemma terms of service**

   6. Visit https://www.kaggle.com/models/google/gemma and accept the usage license (many model families require this).
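If you'd rather not paste credentials into the interactive `kagglehub.login()` prompt every session, `kagglehub` (like the official `kaggle` command-line tool) can also read them from `~/.kaggle/kaggle.json` or from the `KAGGLE_USERNAME` and `KAGGLE_KEY` environment variables. The sketch below shows both options. The helper function names are our own, and the Downloads path in the final comment is only an example.

```python
import json
import os
import stat
from pathlib import Path


def load_kaggle_token(token_file) -> dict:
    """Read kaggle.json and check it has the two expected fields."""
    creds = json.loads(Path(token_file).expanduser().read_text())
    missing = {"username", "key"} - set(creds)
    if missing:
        raise ValueError(f"kaggle.json is missing: {sorted(missing)}")
    return creds


def install_kaggle_token(token_file, kaggle_dir=None) -> Path:
    """Copy the token to <kaggle_dir>/kaggle.json (default ~/.kaggle) and make it
    readable and writable by the owner only, so other users cannot see the key."""
    creds = load_kaggle_token(token_file)
    kaggle_dir = Path(kaggle_dir or Path.home() / ".kaggle").expanduser()
    kaggle_dir.mkdir(parents=True, exist_ok=True)
    dest = kaggle_dir / "kaggle.json"
    dest.write_text(json.dumps(creds))
    dest.chmod(stat.S_IRUSR | stat.S_IWUSR)  # equivalent to `chmod 600`
    return dest


def export_kaggle_env(token_file) -> None:
    """Alternative: expose the credentials as environment variables instead of a file."""
    creds = load_kaggle_token(token_file)
    os.environ["KAGGLE_USERNAME"] = creds["username"]
    os.environ["KAGGLE_KEY"] = creds["key"]


# Example (adjust the path to wherever your browser saved the file):
# install_kaggle_token("~/Downloads/kaggle.json")
```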
  
### Log in to Kaggle and basic imports
Now we can log in to Kaggle and download the data and models we need. Then we import our basic libraries and configure the environment so that Keras uses the JAX backend and can use the accelerator's full memory.

In [1]:
# Import basic libraries for file handling and data manipulation
import os
import pandas as pd

# Log in to Kaggle Hub - get credentials from https://www.kaggle.com/settings
import kagglehub
kagglehub.login()


Kaggle credentials set.
Kaggle credentials successfully validated.


In [2]:
# Download models and data from Kaggle
path_gemma = kagglehub.model_download("keras/gemma3/keras/gemma3_instruct_270m")
path_gpt = kagglehub.model_download("keras/gpt2/keras/gpt2_base_en")
path_data = kagglehub.dataset_download("gpreda/medquad")

In [3]:
# Update python libraries to use TPU in a kaggle/colab notebook
# jax 0.6.2 and keras-hub 0.21.1 seem to work
!pip install -U pip -q
!pip install -U "jax[tpu]"==0.6.2 -q
!pip install keras-hub==0.21.1 -U -q


In [4]:
# --- Environment Setup for Keras with JAX on a TPU ---

# Keras is a high-level API that can run on different backends like TensorFlow, PyTorch, or JAX.
# JAX is a high-performance library from Google that is especially efficient on TPUs.
# We explicitly tell Keras to use JAX for all its computations.
os.environ["KERAS_BACKEND"] = "jax"

# This sets the fraction of accelerator memory that JAX/XLA pre-allocates (the default is 75%).
# Pre-allocating everything can reduce memory fragmentation, but it means this notebook
# claims the memory for itself. Note: JAX documents this setting for GPUs; on TPUs it may have no effect.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

In [5]:
# --- Import Deep Learning Libraries ---

# Import JAX and configure it to use the TPU.
import jax
jax.config.update('jax_platform_name', 'tpu')
print(f"JAX is running on {jax.devices()[0].device_kind}")

# Import our main deep learning frameworks: Keras, and keras-hub (formerly keras-nlp) for LLM-specific tools.
import keras
import keras_hub

# Optional: bfloat16 uses half the memory of the standard float32, which can speed up training on a TPU
# without a major loss in accuracy. It is left disabled here; uncomment the line below to try it.
# keras.config.set_floatx("bfloat16")

JAX is running on TPU v2
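The bfloat16 comment above is easiest to understand with some arithmetic. Each float32 value takes 4 bytes and each bfloat16 value takes 2. During training with Adam, every *trainable* weight also needs a gradient plus Adam's two running moment estimates. The sketch assumes 270 million parameters (the size quoted for this Gemma 3 preset), treats every stored value as the same dtype, and ignores activation memory, so it is only a rough lower bound.

```python
BYTES_PER_VALUE = {"float32": 4, "bfloat16": 2}


def weights_gb(n_params: float, dtype: str = "float32") -> float:
    """Memory needed just to hold the weights, in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * BYTES_PER_VALUE[dtype] / 1e9


def adam_training_gb(n_trainable: float, n_frozen: float, dtype: str = "float32") -> float:
    """Rough lower bound for training with Adam: frozen weights are stored once;
    each trainable weight is stored with its gradient and two Adam moments (4 copies).
    Activations, which grow with batch size and sequence length, are ignored."""
    b = BYTES_PER_VALUE[dtype]
    return (n_frozen * b + n_trainable * 4 * b) / 1e9


n = 270e6  # approximate parameter count of the Gemma 3 270M preset
print(f"weights, float32:  {weights_gb(n):.2f} GB")
print(f"weights, bfloat16: {weights_gb(n, 'bfloat16'):.2f} GB")
print(f"full fine-tune with Adam, float32: {adam_training_gb(n, 0):.2f} GB")
```

With LoRA, `n_trainable` shrinks to the size of the adapters and almost all of the model moves into `n_frozen`, which is a large part of why parameter-efficient fine-tuning fits on smaller accelerators.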


## Data Loading and Preparation
An LLM needs to be trained on structured examples. For a question-answering task, this means clear pairs of "prompts" (questions) and "responses" (answers). We'll load the [medquad](https://www.kaggle.com/datasets/gpreda/medquad) dataset, which contains medical questions and answers, and then create a small, targeted subset for our task.

In [6]:
# Load and subset the data for training
df = pd.read_csv(path_data+"/medquad.csv")
# data = df.sample(n=100, random_state=42)

# For this workshop, we want the fine-tuning process to be fast and the results to be obvious.
# So, we will "cheat" by creating a very small, highly specific dataset focused only on "pernicious anemia".
# In a real-world project, you would use a much larger and more diverse dataset representing your entire domain.
df_subset_mask = df['question'].str.contains('pernicious anemia', case=False, na=False) | \
                         df['answer'].str.contains('pernicious anemia', case=False, na=False) | \
                         df['focus_area'].str.contains('pernicious anemia', case=False, na=False)
df_subset = df[df_subset_mask]

# Preview the first few lines of the data
df_subset.head(2)

|      | question | answer | source | focus_area |
|------|----------|--------|--------|------------|
| 1132 | Who is at risk for Gastrointestinal Carcinoid ... | Health history can affect the risk of gastroin... | CancerGov | Gastrointestinal Carcinoid Tumors |
| 3017 | What is (are) Autoimmune atrophic gastritis ? | Autoimmune atrophic gastritis is an autoimmune... | GARD | Autoimmune atrophic gastritis |


### Format the Data for the LLM
We want to train our model on a dataset of prompt-response pairs.
We'll write a simple function to convert our DataFrame into the dictionary format (parallel `"prompts"` and `"responses"` lists) **required** by the model we chose. For simplicity, the function uses the raw question and answer text; a templated version is left in as comments.
For best results, format the prompt and response to match the template the model was originally trained on. This often involves special tokens such as `<start_of_turn>user` and `<start_of_turn>model`. Check the [Gemma prompt-structure guide](https://ai.google.dev/gemma/docs/core/prompt-structure) for details.
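Following the prompt-structure page linked above, each Gemma conversation turn is opened with `<start_of_turn>` plus a role (`user` or `model`) and closed with `<end_of_turn>`. One way to produce templated prompt/response strings is sketched below; the helper name is our own. It assumes KerasHub's preprocessor adds the `<bos>` start-of-sequence token itself, so that token is not included here. Check this against the version you install.

```python
def to_gemma_turns(question: str, answer: str) -> tuple:
    """Wrap one question/answer pair in Gemma's chat-turn markup.

    The prompt ends by opening the model's turn, so the model learns to
    continue with the answer and then close its turn with <end_of_turn>.
    """
    prompt = f"<start_of_turn>user\n{question}<end_of_turn>\n<start_of_turn>model\n"
    response = f"{answer}<end_of_turn>"
    return prompt, response


demo_prompt, demo_response = to_gemma_turns("What is pernicious anemia?", "An example answer.")
print(demo_prompt + demo_response)
```

If you train with a template like this, send prompts in the same format at generation time too; otherwise the model sees a different input shape than it was tuned on.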

In [7]:
# Helper function to transform our dataframe into the required format.
def format_data(df):
    prompts = []
    responses = []
    for _, row in df.iterrows():
        question = row['question']
        response = row['answer']
        # Skip incomplete rows. NaN counts as "true" in Python, so check for it explicitly.
        if pd.notna(question) and pd.notna(response) and question and response:
            # prompts.append(f"<start_of_turn>user\nInstruction:\nAnswer the following question.\nQuestion:{question}\n<end_of_turn>")
            # responses.append(f"<start_of_turn>model\nResponse:{response}\n<end_of_turn>")
            prompts.append(str(question))
            responses.append(str(response))

    # KerasHub's Gemma preprocessor accepts a dict of parallel "prompts" and "responses" lists.
    data_to_preprocess = {"prompts": prompts, "responses": responses}
    return data_to_preprocess

# Apply the formatting to our data.
formatted_data = format_data(df_subset)

## Loading the Pre-Trained Model
Now, we'll load the pre-trained Gemma model. We are using a Gemma3CausalLM, which is a "Causal Language Model." This means it works by predicting the very next word (or "token") in a sequence based on the words that came before it. This is the fundamental mechanism behind text generation.
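The "predict the next token, append it, repeat" loop can be illustrated with a deliberately tiny stand-in for an LLM: a bigram model that only counts which word follows which. This is a toy sketch, not how Gemma works internally. A real LLM uses a neural network that conditions on the *entire* preceding sequence, and its tokens are usually sub-word pieces rather than whitespace-separated words. The decoding strategy shown is greedy: always pick the most likely next token.

```python
from collections import Counter, defaultdict


def train_bigram(corpus):
    """Count, for every word, which words follow it in the training text."""
    counts = defaultdict(Counter)
    for text in corpus:
        tokens = text.split()
        for prev, nxt in zip(tokens, tokens[1:]):
            counts[prev][nxt] += 1
    return counts


def generate_greedy(counts, prompt, max_new_tokens=5):
    """Repeatedly predict the next token from the previous one and append it."""
    tokens = prompt.split()
    for _ in range(max_new_tokens):
        candidates = counts.get(tokens[-1])
        if not candidates:
            break  # nothing ever followed this word in training: stop early
        next_token, _ = candidates.most_common(1)[0]  # greedy: most frequent follower
        tokens.append(next_token)
    return " ".join(tokens)


toy_model = train_bigram(["vitamin B12 is absorbed in the gut",
                          "vitamin B12 is needed to make red blood cells"])
print(generate_greedy(toy_model, "vitamin"))
```

Because each generated token becomes context for the next, greedy decoding in particular can get stuck repeating the same phrase.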

In [8]:
# Load the Gemma3 model
# `from_preset` is a convenient Keras function to load a model with its standard configuration.
# This includes the model architecture itself, the pre-trained weights, and the tokenizer
# which converts text into numbers the model can understand.
# We are loading a smaller 270 Million parameter version of Gemma 3, which is suitable for quick fine-tuning.
print("Loading model...")
causal_lm = keras_hub.models.Gemma3CausalLM.from_preset(path_gemma)

# The .summary() method gives us a look at the model's architecture.
# Pay attention to "Total params" and "Trainable params". At this point (before LoRA is
# enabled) they are the same: every weight in the model is trainable.
causal_lm.summary()

Loading model...


## Test Before Fine-Tuning (Establish a Baseline)
It's crucial to see how the model performs before we fine-tune it. This gives us a baseline to measure our improvements against. We will ask it a question about our topic and see what its general knowledge provides.

In [9]:
# Set a prompt
prompt = "What is pernicious anemia?"

In [10]:
print("Sending prompt to model...")

# The .generate() method takes our text prompt and produces a response.
response_raw = causal_lm.generate(prompt)

print(f"{response_raw}")

Sending prompt to model...
What is pernicious anemia?

The answer is that it is a condition where the body's ability to produce enough red blood cells is impaired. This can lead to a variety of symptoms, including fatigue, weakness, shortness of breath, and dizziness.

What is pernicious anemia?

The answer is that it is a condition where the body's ability to produce enough red blood cells is impaired. This can lead to a variety of symptoms, including fatigue, weakness, shortness of breath, and dizziness.

What is the main cause of pernicious anemia?

The answer is that it is a condition where the body's ability to produce enough red blood cells is impaired. This can lead to a variety of symptoms, including fatigue, weakness, shortness of breath, and dizziness.

What is the main symptom of pernicious anemia?

The answer is that it is a condition where the body's ability to produce enough red blood cells is impaired. This can lead to a variety of symptoms, including fatigue, weakness, 

## Compile and Fine-Tune the Model
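First, we choose which parts of the model are allowed to change: here we enable LoRA, so only a small set of new low-rank weights will be trained. Next, we "compile" the model with our training options (optimizer, loss, and metrics). Finally, we call `.fit()` to fine-tune on our data.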

In [11]:
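# Enable Low-Rank Adaptation (LoRA) for parameter-efficient fine-tuning.
# LoRA freezes all of the backbone's original weights and adds small trainable low-rank
# matrices to selected layers (by default, projections inside the attention blocks).
# The counts printed below are numbers of weight tensors, not individual parameters.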
causal_lm.backbone.enable_lora(rank=16)
print(f"Number of trainable weights after LoRA: {len(causal_lm.trainable_weights)}")
print(f"Number of non-trainable weights after LoRA: {len(causal_lm.non_trainable_weights)}")

Number of trainable weights after LoRA: 72
Number of non-trainable weights after LoRA: 236
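LoRA (Low-Rank Adaptation) keeps a pre-trained weight matrix W frozen and learns a low-rank update instead. For an input x, the adapted layer computes x·W + (α/r)·x·A·B, where A has shape (d_in, r), B has shape (r, d_out), and the rank r is small (16 in the cell above). Below is a minimal NumPy sketch of the general technique, not KerasHub's implementation. The α/r scaling and the initialisation (random A, zero B, so training starts from exactly the pre-trained behaviour) follow the original LoRA formulation. The sketch counts individual parameters, which is what determines memory use.

```python
import numpy as np


class LoRADense:
    """Frozen dense layer W plus a trainable low-rank update (A @ B), as in LoRA."""

    def __init__(self, W, rank, alpha=None, seed=0):
        rng = np.random.default_rng(seed)
        d_in, d_out = W.shape
        self.W = W                                         # pre-trained weights: frozen
        self.A = rng.normal(0.0, 0.02, size=(d_in, rank))  # trainable, random init
        self.B = np.zeros((rank, d_out))                   # trainable, zero init
        self.scale = (rank if alpha is None else alpha) / rank

    def __call__(self, x):
        # Base output plus the scaled low-rank correction.
        return x @ self.W + self.scale * (x @ self.A @ self.B)

    def trainable_params(self):
        return self.A.size + self.B.size

    def merged_weights(self):
        # Fold the update into a single matrix for inference.
        return self.W + self.scale * (self.A @ self.B)


d_in = d_out = 1024  # hypothetical projection size
layer = LoRADense(np.zeros((d_in, d_out)), rank=16)
print("full:", d_in * d_out, "LoRA:", layer.trainable_params())
```

For this hypothetical 1024 × 1024 projection, full fine-tuning would update 1,048,576 parameters while rank-16 LoRA trains 16 × (1024 + 1024) = 32,768, about 3%. After training, `merged_weights()` folds the update back into one matrix, so inference costs no more than the original layer.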


In [12]:
print("Compiling the model...")
causal_lm.compile(
    # The optimizer is the algorithm that updates the model's weights to minimize the loss.
    # Adam is a very popular and effective general-purpose optimizer.
    # The `learning_rate` is usually the most important hyperparameter to tune. It controls the size of the
    # weight updates: too large, and training can become unstable; too small, and it will be very slow.
    # A small learning rate like 1e-4 (0.0001) is a reasonable starting point for fine-tuning.
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    # The "loss function" calculates a score that measures how wrong the model's predictions are.
    # The goal of training is to minimize this score. SparseCategoricalCrossentropy is the standard
    # loss function for next-token prediction tasks.
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    # Metrics are used to monitor the training process. Here, we track token-level accuracy:
    # how often the model's highest-scoring next token is the correct one.
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()]
)
print("Done.")

Compiling the model...
Done.
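Here is what `SparseCategoricalCrossentropy(from_logits=True)` and the accuracy metric compute, sketched in NumPy for a single sequence. `from_logits=True` means the model outputs raw, unnormalised scores (logits) and the loss applies the softmax itself. The optional mask plays the role of per-token sample weights (for example, to ignore padding). Exactly how Keras averages weighted losses is a detail this sketch does not try to reproduce.

```python
import numpy as np


def sparse_ce_from_logits(logits, targets, mask=None) -> float:
    """Mean cross-entropy of the true next tokens.

    logits:  (seq_len, vocab_size) raw scores, one row per position
    targets: (seq_len,) integer id of the correct next token at each position
    mask:    optional (seq_len,) weights, e.g. 1.0 for real tokens and 0.0 for padding
    """
    logits = np.asarray(logits, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.int64)
    # Numerically stable log-softmax: subtract each row's max before exponentiating.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Negative log-probability assigned to the correct token at each position.
    nll = -log_probs[np.arange(len(targets)), targets]
    weights = np.ones_like(nll) if mask is None else np.asarray(mask, dtype=np.float64)
    return float((nll * weights).sum() / weights.sum())


def token_accuracy(logits, targets, mask=None) -> float:
    """Fraction of positions where the highest-scoring token is the correct one."""
    preds = np.asarray(logits).argmax(axis=-1)
    correct = (preds == np.asarray(targets)).astype(np.float64)
    weights = np.ones_like(correct) if mask is None else np.asarray(mask, dtype=np.float64)
    return float((correct * weights).sum() / weights.sum())


# Toy example: vocabulary of 4 tokens, sequence of 3 positions.
logits = [[2.0, 0.5, 0.1, 0.1],
          [0.1, 0.1, 3.0, 0.2],
          [1.0, 1.0, 1.0, 1.0]]
targets = [0, 2, 3]
print(sparse_ce_from_logits(logits, targets), token_accuracy(logits, targets))
```

A useful way to read the number: if the loss is an average over tokens, exp(-loss) is the geometric-mean probability the model assigns to the correct next token. A loss of about 0.65, as in the training log, corresponds to roughly exp(-0.65) ≈ 0.52. A model guessing uniformly over a vocabulary of V tokens would score ln V.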


In [13]:
print("Starting fine-tuning...")
causal_lm.fit(formatted_data, epochs=10, batch_size=1)  # Adjust batch_size to the accelerator memory available; increase epochs until the loss plateaus.
print("Fine-tuning complete!")

Starting fine-tuning...
Epoch 1/10
21/21 ━━━━━━━━━━━━━━━━━━━━ 48s 1s/step - loss: 0.8251 - sparse_categorical_accuracy: 0.5134
Epoch 2/10
21/21 ━━━━━━━━━━━━━━━━━━━━ 25s 101ms/step - loss: 0.7462 - sparse_categorical_accuracy: 0.5284
Epoch 3/10
21/21 ━━━━━━━━━━━━━━━━━━━━ 2s 88ms/step - loss: 0.6913 - sparse_categorical_accuracy: 0.5405
Epoch 4/10
21/21 ━━━━━━━━━━━━━━━━━━━━ 2s 101ms/step - loss: 0.6982 - sparse_categorical_accuracy: 0.5483
Epoch 5/10
21/21 ━━━━━━━━━━━━━━━━━━━━ 2s 100ms/step - loss: 0.6811 - sparse_categorical_accuracy: 0.5521
Epoch 6/10
21/21 ━━━━━━━━━━━━━━━━━━━━ 2s 100ms/step - loss: 0.6649 - sparse_categorical_accuracy: 0.5623
Epoch 7/10
21/21 ━━━━━━━━━━━━━━━━━━━━ 2s 100ms/step - loss: 0.6492 - sparse_categorical_accuracy: 0.5718
Epoch
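The comment on the `fit` call suggests increasing the number of epochs until the loss plateaus. One simple way to make "plateaus" concrete: stop once the best loss in the last few epochs has not beaten the earlier best by at least some minimum amount. Keras offers a built-in callback in the same spirit, `keras.callbacks.EarlyStopping(monitor=..., min_delta=..., patience=...)`. The function below is our own illustration, and its thresholds are arbitrary assumptions. On a dataset this small, training loss mostly measures memorisation, so monitoring loss on held-out examples is usually more informative.

```python
def has_plateaued(losses, patience=2, min_delta=0.01):
    """Return True if the best loss in the last `patience` epochs improved on the
    best loss before them by less than `min_delta` (i.e. training has stalled)."""
    if len(losses) <= patience:
        return False  # not enough history to judge
    best_before = min(losses[:-patience])
    best_recent = min(losses[-patience:])
    return best_before - best_recent < min_delta


# Per-epoch training losses copied from the Gemma log above (first 7 epochs).
gemma_losses = [0.8251, 0.7462, 0.6913, 0.6982, 0.6811, 0.6649, 0.6492]
print(has_plateaued(gemma_losses))  # False: the loss is still falling by about 0.03
```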

## Test After Fine-Tuning
Now, we give the exact same prompt to our newly fine-tuned model. The hope is that its answer will be ~~better, more accurate~~ closer to what we have trained the model to do.

In [16]:
print("Testing generation from the fine-tuned model:")
response_ft = causal_lm.generate(prompt)
print(f"{response_ft}")

Testing generation from the fine-tuned model:
What is pernicious anemia?

Pernicious anemia is a condition in which the body does not produce enough of the vitamin B12 in the body. This deficiency can lead to various health problems.

How pernicious anemia is caused.

Pernicious anemia is caused by a deficiency in the vitamin B12 in the body. This vitamin is essential for the proper functioning of the nervous system, the heart, and the eyes. When the body doesn't have enough B12, it can't properly perform these functions.

How pernicious anemia is treated.

Treatment for pernicious anemia is usually done by a doctor. The doctor will prescribe a medication to help the body absorb the B12. The medication may help to prevent or treat other health problems. The medication may also help to prevent or treat pernicious anemia.
The treatment for pernicious anemia is usually done by a doctor. The doctor will prescribe a medication to help the body absorb the B12. The medication may help to prev

In [18]:
# Compare with the summary printed before fine-tuning: with LoRA enabled, most parameters are now non-trainable.
causal_lm.summary()

In [19]:
# Save the fine-tuned model (configuration, weights, and tokenizer) to a local preset directory.
causal_lm.save_to_preset("./my-model-ft")

# Example 2: A Quick Look at GPT-2
Here we want to highlight how choosing a different model means making different choices about data formatting and framework code, using the absolute bare minimum amount of code for model tuning.

In [20]:
# Load GPT-2 with pre-trained weights. NOTE the different keras_hub.models class: the generic
# CausalLM class works out which model type to build from the preset.
causal_lm = keras_hub.models.CausalLM.from_preset(path_gpt)

prompt = "What is pernicious anemia?"
causal_lm.generate(prompt)

"What is pernicious anemia?\n\nAnemia is an abnormal form of anemia. It can be caused by a deficiency of iron in the blood, a lack of iron in the body's blood, or by any of the following factors:\n\nA lack of oxygen\n\nA lack of calcium\n\nA lack of iron in the blood\n\nA lack of vitamin A\n\nThe most common cause of anemia is a buildup of plaque in the blood. These buildup are called pernicious anemia.\n\nPernicious anemia is caused by a deficiency of iron in the blood. It can be caused by a deficiency of iron in the blood, a lack of iron in the body's blood, or by any of the following factors:\n\nA lack of oxygen\n\nA lack of calcium\n\nA lack of vitamin A\n\nThe most common cause of anemia is a buildup of plaque in the blood. These buildup are called pernicious anemia.\n\nPernicious anemia is caused by a deficiency of iron in the blood. It can be caused by a deficiency of iron in the blood, a lack of iron in the body's blood, or by any of the following factors:\n\nA lack of oxygen\n

In [21]:
# Format the data into what GPT2 model expects - different to Gemma!
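def format_data_gpt2(df):
    # GPT-2's preprocessor takes plain strings rather than prompt/response pairs,
    # so here we train on the answer text only.
    responses = []
    for _, row in df.iterrows():
        question = row['question']
        response = row['answer']
        # Skip incomplete rows (NaN counts as "true" in Python, so check for it explicitly).
        if pd.notna(question) and pd.notna(response) and question and response:
            responses.append(f"{response}\n")

    return responses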

formatted_data_gpt2 = format_data_gpt2(df_subset)

In [22]:
# Just use the defaults to show how lean model training can be! No LoRA here, so this is full fine-tuning: every weight is updated.
causal_lm.compile()

In [23]:
causal_lm.fit(formatted_data_gpt2, epochs=10, batch_size=7)

Epoch 1/10
3/3 ━━━━━━━━━━━━━━━━━━━━ 96s 23s/step - loss: 1.0016 - sparse_categorical_accuracy: 0.5545
Epoch 2/10
3/3 ━━━━━━━━━━━━━━━━━━━━ 49s 1s/step - loss: 0.9332 - sparse_categorical_accuracy: 0.5802
Epoch 3/10
3/3 ━━━━━━━━━━━━━━━━━━━━ 4s 1s/step - loss: 0.9064 - sparse_categorical_accuracy: 0.5795
Epoch 4/10
3/3 ━━━━━━━━━━━━━━━━━━━━ 4s 1s/step - loss: 0.8824 - sparse_categorical_accuracy: 0.5924
Epoch 5/10
3/3 ━━━━━━━━━━━━━━━━━━━━ 4s 1s/step - loss: 0.8623 - sparse_categorical_accuracy: 0.5959
Epoch 6/10
3/3 ━━━━━━━━━━━━━━━━━━━━ 4s 1s/step - loss: 0.8384 - sparse_categorical_accuracy: 0.6072
Epoch 7/10
3/3 ━━━━━━━━━━━━━━━━━━━━ 4s 1s/step - loss: 0.8232 - sparse_categorical_accuracy: 0.6092
Epoch 8/10
3/3 ━━━━━━━━━━━━━━━━━━━━

<keras.src.callbacks.history.History at 0x7b01cfd66900>
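The step counts in the two training logs (21/21 for Gemma with `batch_size=1`, 3/3 for GPT-2 with `batch_size=7`) come from simple arithmetic: one epoch over N examples takes ceil(N / batch_size) optimizer steps. The example count of 21 is inferred from the Gemma log rather than printed anywhere in the notebook.

```python
import math


def steps_per_epoch(n_examples: int, batch_size: int) -> int:
    """Optimizer steps in one pass over the data (the last batch may be partial)."""
    return math.ceil(n_examples / batch_size)


n_examples = 21  # inferred from "21/21" at batch_size=1 in the Gemma log
print(steps_per_epoch(n_examples, 1))       # 21, as in the Gemma log
print(steps_per_epoch(n_examples, 7))       # 3, as in the GPT-2 log
print(10 * steps_per_epoch(n_examples, 7))  # 30 weight updates over 10 epochs
```

So although both runs use 10 epochs, the GPT-2 run makes 30 weight updates, while the Gemma run makes 210.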

In [None]:
# Try again with the fine-tuned model
causal_lm.generate(prompt)

"What is pernicious anemia?\n\nPernicious anemia is an infection of the central nervous system caused by a type of cancer called a cancer called Mycoplasma gondii. The disease is a type of cancer that can cause a variety of health problems, including heart disease, type 2 diabetes, and cancers of the thyroid and lung.\n\nThe most common types of cancer in people are breast, prostate, and lung cancer. The disease affects the central nervous system, including those areas of the body where nerves and nerves play important roles.\n\nThe cause of pernicious anemia is not well understood. But a number of studies have identified a number of possible causes, including:\n\nHemorrhagic anemia\n\nHepatic anemia\n\nHepatic cell carcinoma\n\nTubercular anemia\n\nChronic wasting syndrome (CWS).\n\nWhat causes pernicious anemia?\n\nA person may develop pernicious anemia because of an abnormal immune response to a protein called the T helper protein. T cells attack the central nervous system and attac

# How You Can Adapt This For Your Research

The examples above use a medical question-answering dataset, but the workflow is highly adaptable.

Understand your task. Pick a model. Pick a framework. Build your pipeline!

The key is to structure your data in the format your framework and model require, in a way that also teaches the model the task you want it to perform.
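As a starting point for your own data, the helper below generalises the formatting step from the Gemma example to any pair of text columns, with an optional prompt template. The helper name, the DataFrame, and its column names are hypothetical; swap in your own, and make sure the template matches what your chosen model expects.

```python
import pandas as pd


def to_prompt_response(df: pd.DataFrame, prompt_col: str, response_col: str,
                       prompt_template: str = "{}") -> dict:
    """Turn any two text columns into the {"prompts": [...], "responses": [...]}
    dict used in the Gemma example, dropping rows where either side is missing or blank."""
    pairs = df[[prompt_col, response_col]].dropna().astype(str)
    keep = (pairs[prompt_col].str.strip() != "") & (pairs[response_col].str.strip() != "")
    pairs = pairs[keep]
    return {
        "prompts": [prompt_template.format(p) for p in pairs[prompt_col]],
        "responses": pairs[response_col].tolist(),
    }


# Hypothetical lab dataset: the column names and contents are made up for illustration.
lab = pd.DataFrame({
    "assay_question": ["Which buffer is used for step 3?", None, "   "],
    "assay_answer": ["Buffer A, described in protocol v2.", "orphan answer", "blank question"],
})
data = to_prompt_response(lab, "assay_question", "assay_answer",
                          prompt_template="Answer using our lab protocols.\nQuestion: {}")
print(len(data["prompts"]))  # 1: the rows with a missing or blank question are dropped
```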

## Links:
KerasHub Documentation: https://keras.io/keras_hub/api/models/gemma3/

Good Hugging Face training demo: https://www.youtube.com/watch?v=uikZs6y0qgI

Gemma 3 fine-tuning (LoRA) documentation: https://ai.google.dev/gemma/docs/core/lora_tuning

Gemma 3 on Kaggle: https://www.kaggle.com/models/google/gemma-3

GPT-2 on Kaggle: https://www.kaggle.com/models/keras/gpt2

Fine-tune Gemini: https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-use-supervised-tuning (deprecated AI Studio method: https://ai.google.dev/gemini-api/docs/model-tuning)

Hugging Face Transformers documentation: https://huggingface.co/docs/transformers/en/index

Hugging Face Sentence-Transformers documentation: https://huggingface.co/sentence-transformers