# Tut 4a - Fine-tuning a custom model (Kaggle 5-Day Generative AI Course)

Today, we’ll learn how to **fine-tune** a custom model using the **Gemini API**. Fine-tuning means adjusting a pre-existing AI model to make it **work better for a specific task**, like **classifying text**. In this case, we'll teach the model to identify which newsgroup a post came from.

In [24]:
!pip uninstall -qqy jupyterlab  # Remove unused conflicting packages
!pip install -U -q "google-genai==1.7.0"


### Set up your API key

To run the following cell, your API key must be stored in a Kaggle secret named `GOOGLE_API_KEY`.

If you don't already have an API key, you can grab one from AI Studio. You can find detailed instructions in the docs.

To make the key available through Kaggle secrets, choose Secrets from the Add-ons menu and follow the instructions to add your key or enable it for this notebook.

In [25]:
import os
from google import genai
from google.genai import types
from kaggle_secrets import UserSecretsClient

GOOGLE_API_KEY = UserSecretsClient().get_secret("GOOGLE_API_KEY")
os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY

client = genai.Client(api_key=GOOGLE_API_KEY)


To start fine-tuning, use `client.models.list()` to find a base model that supports tuning (its `supported_actions` include `createTunedModel`), then call `client.tunings.tune` to start a tuning job that creates your custom model. See the [model tuning tutorial](https://ai.google.dev/gemini-api/docs/model-tuning/tutorial?lang=python) for more details.

In [26]:
for model in client.models.list():
    if "createTunedModel" in model.supported_actions:
        print(model.name)

models/gemini-1.5-flash-001-tuning

## Download the dataset

In this activity, you will fine-tune a Gemini model on the 20 Newsgroups Text Dataset and test it on the same dataset. The dataset contains about 18,000 newsgroup posts on 20 topics, divided into training and test sets.

In [27]:
from sklearn.datasets import fetch_20newsgroups

newsgroups_train = fetch_20newsgroups(subset="train")
newsgroups_test = fetch_20newsgroups(subset="test")

# View list of class names for dataset
newsgroups_train.target_names

print(newsgroups_train.data[0])


From: lerxst@wam.umd.edu (where's my thing)
Subject: WHAT car is this!?
Nntp-Posting-Host: rac3.wam.umd.edu
Organization: University of Maryland, College Park
Lines: 15

 I was wondering if anyone out there could enlighten me on this car I saw
the other day. It was a 2-door sports car, looked to be from the late 60s/
early 70s. It was called a Bricklin. The doors were really small. In addition,
the front bumper was separate from the rest of the body. This is 
all I know. If anyone can tellme a model name, engine specs, years
of production, where this car is made, history, or whatever info you
have on this funky looking car, please e-mail.

Thanks,
- IL
   ---- brought to you by your neighborhood Lerxst ----


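### Why this cleaning is important

Raw newsgroup posts (think online forum messages) begin with email headers such as `From`, `Nntp-Posting-Host`, `Organization`, and `Lines`, as in the sample above. These are mostly noise for topic classification, so the code below keeps only the subject and body, strips email addresses, and truncates each post to 40,000 characters so it fits within the model's input limits.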

In [28]:
import email
import re

import pandas as pd


def preprocess_newsgroup_row(data):
    # Extract only the subject and body
    msg = email.message_from_string(data)
    text = f"{msg['Subject']}\n\n{msg.get_payload()}"
    # Strip any remaining email addresses
    text = re.sub(r"[\w\.-]+@[\w\.-]+", "", text)
    # Truncate the text to fit within the input limits
    text = text[:40000]

    return text


def preprocess_newsgroup_data(newsgroup_dataset):
    # Put data points into dataframe
    df = pd.DataFrame(
        {"Text": newsgroup_dataset.data, "Label": newsgroup_dataset.target}
    )
    # Clean up the text
    df["Text"] = df["Text"].apply(preprocess_newsgroup_row)
    # Match label to target name index
    df["Class Name"] = df["Label"].map(lambda l: newsgroup_dataset.target_names[l])

    return df
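To see the effect of the cleaning steps without downloading the dataset, here is a self-contained sketch that applies the same three steps (keep subject and body, strip email addresses, truncate) to a made-up post. The sample post, `clean_post`, and `MAX_CHARS` are illustrative names, not part of the notebook; the 40,000-character limit is the one used above.

```python
import email
import re

MAX_CHARS = 40000  # same truncation limit as preprocess_newsgroup_row above


def clean_post(raw: str) -> str:
    """Keep the subject and body of a raw post, drop addresses, and truncate."""
    msg = email.message_from_string(raw)
    # Headers such as From, Organization and Lines are discarded here.
    text = f"{msg['Subject']}\n\n{msg.get_payload()}"
    # Remove anything that looks like an email address.
    text = re.sub(r"[\w\.-]+@[\w\.-]+", "", text)
    return text[:MAX_CHARS]


# A made-up post in the same shape as the dataset's raw messages.
raw_post = """From: poster@example.edu (A. Poster)
Subject: WHAT car is this!?
Organization: Example University
Lines: 2

Please e-mail poster@example.edu if you know the model name.
"""

cleaned = clean_post(raw_post)
print(cleaned)
```

The subject line survives, the headers disappear, and the address in the body is replaced by an empty string.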

In [29]:
# Apply preprocessing to training and test datasets
df_train = preprocess_newsgroup_data(newsgroups_train)
df_test = preprocess_newsgroup_data(newsgroups_test)

df_train.head()

| | Text | Label | Class Name |
|---|---|---|---|
| 0 | WHAT car is this!?\n\n I was wondering if anyo... | 7 | rec.autos |
| 1 | SI Clock Poll - Final Call\n\nA fair number of... | 4 | comp.sys.mac.hardware |
| 2 | PB questions...\n\nwell folks, my mac plus fin... | 4 | comp.sys.mac.hardware |
| 3 | Re: Weitek P9000 ?\n\nRobert J.C. Kyanko () wr... | 1 | comp.graphics |
| 4 | Re: Shuttle Launch Question\n\nFrom article <>... | 14 | sci.space |


**Sampling the Data for Training**

* Select a small, representative subset of the data.
* Keep 50 posts from each category for training and 10 from each category for testing.
* Keep only the `rec.*` and `sci.*` categories (8 in total), which gives 400 training and 80 test examples.
* This small sample is sufficient because we're using a parameter-efficient fine-tuning (PEFT) technique, which can train effectively on relatively few examples.

In [30]:
def sample_data(df, num_samples, classes_to_keep):
    # Sample rows, selecting num_samples of each Label.
    df = (
        df.groupby("Label")[df.columns]
        .apply(lambda x: x.sample(num_samples))
        .reset_index(drop=True)
    )

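    # Keep only classes whose name matches the classes_to_keep regex.
    # .copy() avoids pandas' SettingWithCopyWarning on the assignment below.
    df = df[df["Class Name"].str.contains(classes_to_keep)].copy()
    df["Class Name"] = df["Class Name"].astype("category")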

    return df


TRAIN_NUM_SAMPLES = 50
TEST_NUM_SAMPLES = 10
# Keep rec.* and sci.*
CLASSES_TO_KEEP = "^rec|^sci"

df_train = sample_data(df_train, TRAIN_NUM_SAMPLES, CLASSES_TO_KEEP)
df_test = sample_data(df_test, TEST_NUM_SAMPLES, CLASSES_TO_KEEP)
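The sampling logic can be illustrated without pandas. This is a minimal stdlib sketch, assuming rows are `(text, class_name)` pairs; `sample_per_class` is a hypothetical helper, not part of the notebook. It groups rows by class, keeps classes whose name matches the regex (as `str.contains` does), and draws a fixed number of rows from each.

```python
import random
import re
from collections import defaultdict


def sample_per_class(rows, num_samples, classes_to_keep, seed=None):
    """Return num_samples (text, class_name) pairs from every class whose
    name matches the regex classes_to_keep."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for text, class_name in rows:
        by_class[class_name].append((text, class_name))

    sampled = []
    for class_name in sorted(by_class):
        # re.search matches anywhere, like pandas str.contains; "^" anchors it.
        if re.search(classes_to_keep, class_name):
            # Raises ValueError if a class has fewer rows than num_samples.
            sampled.extend(rng.sample(by_class[class_name], num_samples))
    return sampled


# Toy data: 5 posts in each of four classes.
classes = ["comp.graphics", "rec.autos", "sci.space", "talk.politics.misc"]
rows = [(f"post {i} about {c}", c) for c in classes for i in range(5)]

subset = sample_per_class(rows, num_samples=3, classes_to_keep="^rec|^sci", seed=0)
print(len(subset), sorted({c for _, c in subset}))
```

Passing a fixed `seed` makes the subset reproducible. The notebook's `sample_data` does not pass `random_state` to `DataFrame.sample`, so its subset changes on each run.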

## Baseline Evaluation: Sample Inspection

To gauge tuning effectiveness, we first inspect a sample from the test data.

1.  **Select Sample:** We pick a row (e.g., the first).
2.  **Preprocess & Label:** We process the text and get its category.
3.  **Display:** We show the processed text and its label.

This helps us understand the data and verify labels before model tuning.

In [31]:
sample_idx = 0
sample_row = preprocess_newsgroup_row(newsgroups_test.data[sample_idx])
sample_label = newsgroups_test.target_names[newsgroups_test.target[sample_idx]]

print(sample_row)
print('---')
print('Label:', sample_label)

Need info on 88-89 Bonneville


 I am a little confused on all of the models of the 88-89 bonnevilles.
I have heard of the LE SE LSE SSE SSEI. Could someone tell me the
differences are far as features or performance. I am also curious to
know what the book value is for prefereably the 89 model. And how much
less than book value can you usually get them for. In other words how
much are they in demand this time of year. I have heard that the mid-spring
early summer is the best time to buy.

			Neil Gandler

---
Label: rec.autos


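First, send the raw post to the base model with no instructions. As the output shows, the model answers the question in the post instead of classifying it.

In [32]: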
response = client.models.generate_content(
    model="gemini-1.5-flash-001", contents=sample_row)
print(response.text)

##  Bonneville Models Explained (1988-1989)

You're right, the Bonneville lineup was a bit of a jungle in those years. Here's a breakdown:

**Base Bonneville:**  The basic, no-frills model. It still had a powerful engine, but lacked many of the luxuries of the higher trims.

**LE (Luxury Edition):**  Added comfort and convenience features like power windows, locks, and mirrors. Often came with cloth upholstery.

**SE (Special Edition):**  A sporty trim with a more aggressive look. Featured a unique grille, wheels, and sometimes a slightly more powerful engine.

**LSE (Luxury Sport Edition):**  Combined the luxury of the LE with the sporty features of the SE. It was the ultimate "best of both worlds" Bonneville.

**SSE (Sport Sedan Edition):**  A true performance version. The SSE came with a more powerful engine, sporty suspension, and often had a special appearance package.

**SSEi (Sport Sedan Edition with Injection):**  The top-of-the-line Bonneville. The SSEi received the most power

Next, ask the model directly which newsgroup the post comes from, using a zero-shot prompt:

a.  **Prompt Definition:**
    * `prompt = "From what newsgroup does the following message originate?"`

b.  **Model Interaction:**
    * `baseline_response = client.models.generate_content(...)`

c.  **Output Display:**
    * `print(baseline_response.text)`

In [33]:
# Ask the model directly in a zero-shot prompt.

prompt = "From what newsgroup does the following message originate?"
baseline_response = client.models.generate_content(
    model="gemini-1.5-flash-001",
    contents=[prompt, sample_row])
print(baseline_response.text)

While the message doesn't explicitly state the newsgroup, it's highly likely it originated from **rec.autos.pontiac**, a newsgroup dedicated to discussions about Pontiac cars. 

Here's why:

* **Topic:** The message specifically focuses on Pontiac Bonneville models from the 1988-1989 years.
* **Jargon:** The use of model designations like LE, SE, LSE, SSE, and SSEi are specific to Pontiac Bonneville models of that era.
* **Common Interest:**  People interested in buying or learning about a particular Pontiac model would likely seek information in a newsgroup dedicated to the brand.

Therefore, **rec.autos.pontiac** is the most likely origin of this message. 



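## System instruction and retry mechanism

A system instruction tells the model to act as a classification service and reply with just the newsgroup name. The `retry.Retry` decorator re-sends the request when the API returns a 429 (per-minute quota reached) or 503 (service unavailable) error.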

In [34]:
from google.api_core import retry

# You can use a system instruction to do more direct prompting, and get a
# more succinct answer.

system_instruct = """
You are a classification service. You will be passed input that represents
a newsgroup post and you must respond with the newsgroup from which the post
originates.
"""

# Define a helper to retry when per-minute quota is reached.
is_retriable = lambda e: (isinstance(e, genai.errors.APIError) and e.code in {429, 503})

# If you want to evaluate your own technique, replace the body of this function
# with your model, prompt and other code and return the predicted answer.
@retry.Retry(predicate=is_retriable)
def predict_label(post: str) -> str:
    response = client.models.generate_content(
        model="gemini-1.5-flash-001",
        config=types.GenerateContentConfig(
            system_instruction=system_instruct),
        contents=post)

    rc = response.candidates[0]

    # Any errors, filters, recitation, etc we can mark as a general error
    if rc.finish_reason.name != "STOP":
        return "(error)"
    else:
        # Clean up the response.
        return response.text.strip()


prediction = predict_label(sample_row)

print(prediction)
print()
print("Correct!" if prediction == sample_label else "Incorrect.")

rec.autos.misc

Incorrect.
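`google.api_core.retry.Retry` retries a failing call with exponential backoff. To make the idea concrete without calling the API, here is a stdlib-only sketch of such a loop. It is not how `google.api_core` implements `Retry`; `FakeAPIError`, `call_with_retry`, and its delay parameters are hypothetical names chosen for this illustration.

```python
import time


class FakeAPIError(Exception):
    """Hypothetical stand-in for genai.errors.APIError; only carries a status code."""

    def __init__(self, code):
        super().__init__(f"API error {code}")
        self.code = code


def is_retriable(e):
    # Same rule as the notebook: retry on 429 (quota) and 503 (unavailable).
    return isinstance(e, FakeAPIError) and e.code in {429, 503}


def call_with_retry(fn, *args, max_attempts=5, initial_delay=1.0,
                    multiplier=2.0, sleep=time.sleep):
    """Call fn(*args), sleeping with exponential backoff after retriable errors."""
    delay = initial_delay
    for attempt in range(1, max_attempts + 1):
        try:
            return fn(*args)
        except Exception as e:
            # Give up on non-retriable errors or when attempts run out.
            if not is_retriable(e) or attempt == max_attempts:
                raise
            sleep(delay)
            delay *= multiplier


# Usage: a fake classifier that hits errors twice before answering.
outcomes = [FakeAPIError(429), FakeAPIError(503), "rec.autos"]


def flaky_predict(post):
    result = outcomes.pop(0)
    if isinstance(result, Exception):
        raise result
    return result


waits = []  # record the delays instead of really sleeping
print(call_with_retry(flaky_predict, "Need info on 88-89 Bonneville", sleep=waits.append))
print(waits)
```

Injecting `sleep` keeps the example instant; the recorded waits double after each retriable failure.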


## Evaluating the baseline on test data

This cell runs `predict_label` (defined above) on a small subset of the test data and measures how often its prediction exactly matches the true newsgroup name.


In [35]:
import tqdm
from tqdm.rich import tqdm as tqdmr
import warnings

# Enable tqdm features on Pandas.
tqdmr.pandas()

# But suppress the experimental warning
warnings.filterwarnings("ignore", category=tqdm.TqdmExperimentalWarning)


# Further sample the test data to be mindful of the free-tier quota.
df_baseline_eval = sample_data(df_test, 2, '.*')

# Make predictions using the sampled data.
df_baseline_eval['Prediction'] = df_baseline_eval['Text'].progress_apply(predict_label)

# And calculate the accuracy.
accuracy = (df_baseline_eval["Class Name"] == df_baseline_eval["Prediction"]).sum() / len(df_baseline_eval)
print(f"Accuracy: {accuracy:.2%}")

Accuracy: 18.75%
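Accuracy here is an exact string match between the predicted and true newsgroup names, so a near miss such as `rec.autos.misc` for `rec.autos` counts as wrong. The evaluation set is 8 classes × 2 posts = 16 posts, so 18.75% corresponds to 3 correct answers. A minimal stdlib sketch of the same calculation, using made-up predictions:

```python
def exact_match_accuracy(labels, predictions):
    """Fraction of predictions that exactly equal the true label."""
    if len(labels) != len(predictions):
        raise ValueError("labels and predictions must have the same length")
    correct = sum(1 for y, p in zip(labels, predictions) if y == p)
    return correct / len(labels)


# Made-up example: 16 posts, of which only 3 predictions match exactly.
labels = ["rec.autos"] * 16
predictions = ["rec.autos"] * 3 + ["rec.autos.misc"] * 13
print(f"Accuracy: {exact_match_accuracy(labels, predictions):.2%}")
```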


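### Tuning the model

This cell prepares the training data and either reuses an existing tuned model or queues a new tuning job for newsgroup classification based on Gemini 1.5 Flash.

**Here's a breakdown:**

1.  **Data Preparation:**
    * `input_data` converts the training data (`df_train`) into a dictionary of `examples`, each pairing a post (`textInput`) with its newsgroup name (`output`).

2.  **Model ID Handling:**
    * `model_id` stores the ID of the tuned model to use. Set it by hand when re-running the lab; otherwise the code searches recent tuning jobs named `newsgroup-classification-model`, reusing the newest succeeded one or, failing that, the newest one still running.

3.  **Tuning Job Submission:**
    * If no existing `model_id` was found, it submits a new tuning job with `client.tunings.tune`, using `models/gemini-1.5-flash-001-tuning` as the base model, a batch size of 16, and 2 epochs.

4.  **Output:**
    * `print(model_id)` prints the `model_id`, which is either the ID of an existing tuned model or the ID of the newly queued tuning job.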

In [36]:
# Convert the data frame into a dataset suitable for tuning.
input_data = {
    'examples': df_train[['Text', 'Class Name']]
    .rename(columns={'Text': 'textInput', 'Class Name': 'output'})
    .to_dict(orient='records')
}

# If you are re-running this lab, add your model_id here.
model_id = None

# Or try and find a recent tuning job.
if not model_id:
  queued_model = None
  # Newest models first.
  for m in reversed(client.tunings.list()):
    # Only look at newsgroup classification models.
    if m.name.startswith('tunedModels/newsgroup-classification-model'):
      # If there is a completed model, use the first (newest) one.
      if m.state.name == 'JOB_STATE_SUCCEEDED':
        model_id = m.name
        print('Found existing tuned model to reuse.')
        break

      elif m.state.name == 'JOB_STATE_RUNNING' and not queued_model:
        # If there's a model still queued, remember the most recent one.
        queued_model = m.name
  else:
    if queued_model:
      model_id = queued_model
      print('Found queued model, still waiting.')


# Upload the training data and queue the tuning job.
if not model_id:
    tuning_op = client.tunings.tune(
        base_model="models/gemini-1.5-flash-001-tuning",
        training_dataset=input_data,
        config=types.CreateTuningJobConfig(
            tuned_model_display_name="Newsgroup classification model",
            batch_size=16,
            epoch_count=2,
        ),
    )

    print(tuning_op.state)
    model_id = tuning_op.name

print(model_id)


JobState.JOB_STATE_QUEUED
tunedModels/newsgroup-classification-model-v03myfqwv
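The `input_data` dictionary is just a list of `{'textInput': ..., 'output': ...}` records under an `examples` key. Here is a stdlib-only sketch of the same conversion on two made-up rows, plus a rough count of training steps. The step count assumes one optimizer step per batch, which is an assumption about the tuning service, not something this notebook states; `to_tuning_examples` and `estimated_steps` are hypothetical helpers.

```python
import math


def to_tuning_examples(rows):
    """Turn (text, class_name) pairs into the {'examples': [...]} layout used above."""
    return {
        "examples": [
            {"textInput": text, "output": class_name} for text, class_name in rows
        ]
    }


def estimated_steps(num_examples, batch_size, epoch_count):
    # Assumption: one step per (possibly partial) batch, repeated each epoch.
    return math.ceil(num_examples / batch_size) * epoch_count


rows = [
    ("WHAT car is this!?\n\nI was wondering...", "rec.autos"),
    ("Re: Shuttle Launch Question\n\nFrom article...", "sci.space"),
]
print(to_tuning_examples(rows)["examples"][0]["output"])

# 8 rec.*/sci.* classes x 50 posts = 400 examples, batch_size=16, epoch_count=2.
print(estimated_steps(400, 16, 2))
```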


### Waiting for the tuning job

This cell waits for the tuning job to finish and handles timeouts and errors:

1.  **Initialization:**
    * `MAX_WAIT` defines the maximum time to wait for the tuning job (10 minutes).
2.  **Waiting and Status Updates:**
    * While the job has not ended, the cell prints its current state and checks again every 60 seconds.
3.  **Timeout Handling:**
    * If more than `MAX_WAIT` has passed since the job was created, it switches to a previously prepared tuned model (`tunedModels/newsgroup-classification-model-ltenbi1b`) instead of waiting indefinitely.
4.  **Completion and Error Handling:**
    * It prints the final job state and, if the job did not succeed, the error it reported.

In [37]:
import datetime
import time


MAX_WAIT = datetime.timedelta(minutes=10)

while not (tuned_model := client.tunings.get(name=model_id)).has_ended:

    print(tuned_model.state)
    time.sleep(60)

    # Don't wait too long. Use a public model if this is going to take a while.
    if datetime.datetime.now(datetime.timezone.utc) - tuned_model.create_time > MAX_WAIT:
        print("Taking a shortcut, using a previously prepared model.")
        model_id = "tunedModels/newsgroup-classification-model-ltenbi1b"
        tuned_model = client.tunings.get(name=model_id)
        break


print(f"Done! The model state is: {tuned_model.state.name}")

if not tuned_model.has_succeeded and tuned_model.error:
    print("Error:", tuned_model.error)

JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_RUNNING


KeyboardInterrupt: 
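The waiting logic is a poll-until-done loop with a deadline. The sketch below isolates that pattern with an injectable clock and sleep so it runs instantly. `wait_for_job`, `FakeClock`, and the `TERMINAL_STATES` set are assumptions for illustration; the notebook itself relies on the SDK's `has_ended` property. Unlike the notebook cell, this sketch measures the deadline from when polling started rather than from the job's `create_time`.

```python
# Assumed, simplified set of final job states for this illustration.
TERMINAL_STATES = {"JOB_STATE_SUCCEEDED", "JOB_STATE_FAILED", "JOB_STATE_CANCELLED"}


def wait_for_job(get_state, now, sleep, max_wait_s=600.0, poll_interval_s=60.0):
    """Poll get_state() until a terminal state or until max_wait_s elapses.

    Returns (last_state, timed_out)."""
    start = now()
    while True:
        state = get_state()
        if state in TERMINAL_STATES:
            return state, False
        if now() - start > max_wait_s:
            return state, True
        print(state)
        sleep(poll_interval_s)


class FakeClock:
    """Simulated time so the example runs without really sleeping."""

    def __init__(self):
        self.t = 0.0

    def now(self):
        return self.t

    def sleep(self, seconds):
        self.t += seconds


# A fake job that reports RUNNING three times, then SUCCEEDED.
states = iter(["JOB_STATE_RUNNING"] * 3 + ["JOB_STATE_SUCCEEDED"])
clock = FakeClock()
print(wait_for_job(lambda: next(states), clock.now, clock.sleep))
```

A job that never finishes returns with `timed_out=True` once more than `max_wait_s` of simulated time has passed, which is the point where the notebook falls back to its prepared model.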

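Once the tuning job has succeeded, the tuned model can classify new posts. In this run the waiting cell was interrupted before tuning finished, so the call below fails with a "not ready to use" error.

In [38]: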
new_text = """
First-timer looking to get out of here.

Hi, I'm writing about my interest in travelling to the outer limits!

What kind of craft can I buy? What is easiest to access from this 3rd rock?

Let me know how to do that please.
"""

response = client.models.generate_content(
    model=model_id, contents=new_text)

print(response.text)

ClientError: 400 INVALID_ARGUMENT. {'error': {'code': 400, 'message': 'Tuned model tunedModels/newsgroup-classification-model-v03myfqwv is not ready to use.', 'status': 'INVALID_ARGUMENT'}}


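To compare costs, the next cells count tokens for the same sample post:

1.  **Baseline Model Input Cost:** the system instruction plus the post, counted for `gemini-1.5-flash-001`.
2.  **Tuned Model Input Cost:** the post alone, since the tuned model no longer needs the system instruction (counted with the tuned model's base model).
3.  **Token Savings Calculation:** `(sysint_tokens - tuned_tokens) / tuned_tokens`, the extra input tokens the baseline needs relative to the tuned model.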

In [39]:
# Calculate the *input* cost of the baseline model with system instructions.
sysint_tokens = client.models.count_tokens(
    model='gemini-1.5-flash-001', contents=[system_instruct, sample_row]
).total_tokens
print(f'System instructed baseline model: {sysint_tokens} (input)')

# Calculate the input cost of the tuned model.
tuned_tokens = client.models.count_tokens(model=tuned_model.base_model, contents=sample_row).total_tokens
print(f'Tuned model: {tuned_tokens} (input)')

savings = (sysint_tokens - tuned_tokens) / tuned_tokens
print(f'Token savings: {savings:.2%}')  # Note that this is only n=1.

System instructed baseline model: 172 (input)
Tuned model: 136 (input)
Token savings: 26.47%
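The savings figure above divides by the tuned model's token count, so it measures how many more input tokens the system-instructed baseline uses. Dividing by the baseline count instead gives the fraction of tokens the tuned model saves. A small sketch computing both from the printed counts (172 and 136, a single sample, so only indicative); `relative_difference` is a hypothetical helper:

```python
def relative_difference(baseline_tokens, tuned_tokens):
    """Return (extra tokens relative to tuned, tokens saved relative to baseline)."""
    diff = baseline_tokens - tuned_tokens
    return diff / tuned_tokens, diff / baseline_tokens


# Input token counts printed by the cell above (n=1 sample post).
more_than_tuned, saved_vs_baseline = relative_difference(172, 136)
print(f"Baseline uses {more_than_tuned:.2%} more input tokens than the tuned model")
print(f"Tuned model uses {saved_vs_baseline:.2%} fewer input tokens than the baseline")
```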


In [40]:
baseline_token_output = baseline_response.usage_metadata.candidates_token_count
print('Baseline (verbose) output tokens:', baseline_token_output)

tuned_model_output = client.models.generate_content(
    model=model_id, contents=sample_row)
tuned_tokens_output = tuned_model_output.usage_metadata.candidates_token_count
print('Tuned output tokens:', tuned_tokens_output)

Baseline (verbose) output tokens: 157


ClientError: 400 INVALID_ARGUMENT. {'error': {'code': 400, 'message': 'Tuned model tunedModels/newsgroup-classification-model-v03myfqwv is not ready to use.', 'status': 'INVALID_ARGUMENT'}}
