##### Copyright 2025 Google LLC.

In [None]:
# @title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [1]:
%pip install -U -q "google-genai==1.7.0"


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
from google import genai
from google.genai import types

genai.__version__

'1.7.0'

In [3]:
import os
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

client = genai.Client(api_key=GOOGLE_API_KEY)

In [4]:
for model in client.models.list():
    if "createTunedModel" in model.supported_actions:
        print(model.name)

models/gemini-1.5-flash-001-tuning


## Download the dataset



In [5]:
from sklearn.datasets import fetch_20newsgroups

newsgroups_train = fetch_20newsgroups(subset="train")
newsgroups_test = fetch_20newsgroups(subset="test")

newsgroups_train.target_names

['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']

In [6]:
print(newsgroups_train.data[0])

From: lerxst@wam.umd.edu (where's my thing)
Subject: WHAT car is this!?
Nntp-Posting-Host: rac3.wam.umd.edu
Organization: University of Maryland, College Park
Lines: 15

 I was wondering if anyone out there could enlighten me on this car I saw
the other day. It was a 2-door sports car, looked to be from the late 60s/
early 70s. It was called a Bricklin. The doors were really small. In addition,
the front bumper was separate from the rest of the body. This is 
all I know. If anyone can tellme a model name, engine specs, years
of production, where this car is made, history, or whatever info you
have on this funky looking car, please e-mail.

Thanks,
- IL
   ---- brought to you by your neighborhood Lerxst ----







## Prepare the dataset

In [7]:
import email
import re

import pandas as pd

def preprocess_newsgroup_row(data):
    msg = email.message_from_string(data)
    text = f"{msg['Subject']}\n\n{msg.get_payload()}"

    text = re.sub(r"[\w\.-]+@[\w\.-]+", "", text)

    text = text[:40000]

    return text

def preprocess_newsgroup_data(newsgroup_dataset):
    df = pd.DataFrame(
        {"Text": newsgroup_dataset.data, "Label": newsgroup_dataset.target}
    )

    # Clean up the text
    df['Text'] =df['Text'].apply(preprocess_newsgroup_row)

    df["Class Name"] = df["Label"].map(lambda l: newsgroup_dataset.target_names[l])

    return df

In [8]:
# Apply preprocessing to training and test datasets
df_train = preprocess_newsgroup_data(newsgroups_train)
df_test = preprocess_newsgroup_data(newsgroups_test)

df_train.head()

Unnamed: 0,Text,Label,Class Name
0,WHAT car is this!?\n\n I was wondering if anyo...,7,rec.autos
1,SI Clock Poll - Final Call\n\nA fair number of...,4,comp.sys.mac.hardware
2,"PB questions...\n\nwell folks, my mac plus fin...",4,comp.sys.mac.hardware
3,Re: Weitek P9000 ?\n\nRobert J.C. Kyanko () wr...,1,comp.graphics
4,Re: Shuttle Launch Question\n\nFrom article <>...,14,sci.space


In [9]:
def sample_data(df, num_samples, classes_to_keep):
    df =(
        df.groupby("Label")[df.columns]
        .apply(lambda x: x.sample(num_samples))
        .reset_index(drop=True)
    )

    df = df[df["Class Name"].str.contains(classes_to_keep)]
    df["Class Name"] = df["Class Name"].astype("category")

    return df

In [10]:
TRAIN_NUM_SAMPLES = 50
TEST_NUM_SAMPLES = 10

CLASSES_TO_KEEP = "^rec|^sci"

df_train = sample_data(df_train, TRAIN_NUM_SAMPLES, CLASSES_TO_KEEP)
df_test = sample_data(df_test, TEST_NUM_SAMPLES, CLASSES_TO_KEEP)

## Evaluate baseline performance

In [11]:
sample_idx = 0
sample_row = preprocess_newsgroup_row(newsgroups_test.data[sample_idx])
sample_label = newsgroups_test.target_names[newsgroups_test.target[sample_idx]]

print(sample_row)
print('---')
print('Label:', sample_label)

Need info on 88-89 Bonneville


 I am a little confused on all of the models of the 88-89 bonnevilles.
I have heard of the LE SE LSE SSE SSEI. Could someone tell me the
differences are far as features or performance. I am also curious to
know what the book value is for prefereably the 89 model. And how much
less than book value can you usually get them for. In other words how
much are they in demand this time of year. I have heard that the mid-spring
early summer is the best time to buy.

			Neil Gandler

---
Label: rec.autos


In [12]:
response = client.models.generate_content(
    model="gemini-1.5-flash-001", contents=sample_row
)

print(response.text)

You are right to be confused, the 1988-1989 Pontiac Bonneville model designations are a bit of a mess.  Here is a breakdown to help clear things up:

**1988-1989 Pontiac Bonneville Trim Levels:**

* **LE:** This was the base model, offering a standard V6 engine and basic features. 
* **SE:**  The "SE" stood for "Special Edition."  It usually had a slightly more luxurious interior, possibly with cloth or vinyl upholstery, and some additional features.  It may have included an optional V8 engine.  The "SE" could be further broken down into:
    * **LSE:**  The "LSE" was a special "Luxury Special Edition" which added more interior luxury and features.  It could also include a V8 engine as an option.
    * **SSE:**  The "SSE"  was the "Sport Special Edition" and was more focused on performance.  It usually came standard with a powerful V8 engine, upgraded suspension, and sporty exterior touches. 
* **SSEi:**  This was the ultimate performance variant and stood for "Sport Special Edition In

In [13]:
prompt = "From what newsgroup does the following message originate?"
baseline_response = client.models.generate_content(
    model="gemini-1.5-flash-001",
    contents=[prompt, sample_row]
)
print(baseline_response.text)

This message likely originates from the **alt.autos.pontiac** newsgroup. 

Here's why:

* **Topic:** The message focuses specifically on Pontiac Bonnevilles from the 1988-1989 model years.
* **Specific Model References:** The message mentions specific trim levels (LE, SE, LSE, SSE, SSEi) which are relevant to Pontiac Bonnevilles.
* **Newsgroup Focus:**  The alt.autos.pontiac newsgroup is dedicated to discussions about Pontiac vehicles, making it the most likely source for this type of inquiry. 



In [14]:
from google.api_core import retry

system_instruct = """
You are a classification service. You will be passed input that represents
a newsgroup post and you must respond with the newsgroup from which the post
originates.
"""

# helper
is_retriable = lambda e: (isinstance(e, genai.errors.APIError) and e.code in {429, 503})

In [15]:
@retry.Retry(predicate=is_retriable)
def predict_label(post: str) -> str:
    response = client.models.generate_content(
        model="gemini-1.5-flash-001",
        config=types.GenerateContentConfig(
            system_instruction=system_instruct
        ),
        contents=post
    )

    rc = response.candidates[0]

    # Any errors, filters, recitation, etc we can mark as a general error
    if rc.finish_reason.name != "STOP":
        return "(error)"
    else:
        return response.text.strip()
    


In [16]:
prediction = predict_label(sample_row)

In [17]:
print(prediction)
print("---")
print("Correct!" if prediction == sample_label else "Incorrect.")

rec.autos.misc
---
Incorrect.


In [18]:
import tqdm
from tqdm.rich import tqdm as tqdmr
import warnings

tqdmr.pandas()

warnings.filterwarnings("ignore", category=tqdm.TqdmExperimentalWarning)


In [19]:
df_baseline_eval = sample_data(df_test, 2, ".*")

df_baseline_eval['Prediction'] = df_baseline_eval['Text'].progress_apply(predict_label)

accuracy = (df_baseline_eval["Class Name"] == df_baseline_eval["Prediction"]).sum() / len(df_baseline_eval)

print(f"Accuracy: {accuracy:.2%}")


Output()

Accuracy: 31.25%


In [20]:
df_baseline_eval

Unnamed: 0,Text,Label,Class Name,Prediction
0,"Re: WHAT car is this!?\n\nIn article , (Marku...",7,rec.autos,rec.autos.misc
1,Re: Choice of gauges\n\n (Dave Gauge) writes:\...,7,rec.autos,rec.autos.bmw
2,Re: Riceburner Respect\n\nIn article <> \n (Ch...,8,rec.motorcycles,rec.motorcycles
3,1st time Biker iso ADVICE\n\n\nI'm just starti...,8,rec.motorcycles,rec.motorcycles
4,John Franco\n\nWhat's with John Franco? The M...,9,rec.sport.baseball,rec.sport.baseball
5,Re: Bosox go down in smoke II (Seattle 7-0) .....,9,rec.sport.baseball,rec.sports.baseball
6,Re: ESPN sucks: OT or Baseball? Guess which.\...,10,rec.sport.hockey,rec.sport.hockey
7,"Re: Goodbye, good riddance, get lost 'Stars\n\...",10,rec.sport.hockey,rec.sport.hockey
8,Re: Why the clipper algorithm is secret\n\nIn ...,11,sci.crypt,(error)
9,Re: Once they get your keys....\n\nIn article ...,11,sci.crypt,(error)


## Tune a custom model

이 예제에서는 튜닝을 사용하여 프롬프트나 시스템 지침 없이 학습 데이터에 제공된 클래스에서 간결한 텍스트를 출력하는 모델을 생성합니다.

데이터에는 입력 텍스트(처리된 게시물)와 출력 텍스트(카테고리 또는 뉴스그룹)가 모두 포함되어 있으며, 이를 사용하여 모델 튜닝을 시작할 수 있습니다.

tune()을 호출할 때 모델 튜닝 하이퍼파라미터도 지정할 수 있습니다.

* epoch_count: 데이터 반복 횟수를 정의합니다.
* batch_size: 한 단계에서 처리할 행 수를 정의합니다.
* learning_rate: 각 단계에서 모델 가중치를 업데이트하기 위한 스케일링 인자를 정의합니다.

이러한 하이퍼파라미터를 생략하고 기본값을 사용할 수도 있습니다. 이러한 파라미터와 작동 방식에 대해 자세히 알아보세요. 이 예제에서는 튜닝 작업을 실행하고 효율적으로 수렴하는 파라미터를 선택하여 이러한 파라미터를 선택했습니다.

이 예제에서는 새 튜닝 작업이 아직 없는 경우에만 시작합니다. 이를 통해 이 코드랩을 종료하고 나중에 다시 돌아올 수 있습니다. 이 단계를 다시 실행하면 마지막 모델을 찾을 수 있습니다.

In [21]:
from collections.abc import Iterable
import random

# Convert the data frame into a dataset suitable for tuning.
input_data = {'examples': 
              df_train[['Text', 'Class Name']]
              .rename(columns={'Text': 'textInput', 'Class Name': 'output'})
              .to_dict(orient='records')
}

In [22]:
# If you are re-running this lab, add your model_id here.
model_id = None
# model_id = "tunedModels/newsgroup-classification-model-fqremn0a7"



In [23]:
# Or try and find a recent tuning job.
if not model_id:
  queued_model = None
  # Newest models first.
  for m in reversed(client.tunings.list()):
    # Only look at newsgroup classification models.
    if m.name.startswith('tunedModels/newsgroup-classification-model'):
      # If there is a completed model, use the first (newest) one.
      if m.state.name == 'JOB_STATE_SUCCEEDED':
        model_id = m.name
        print('Found existing tuned model to reuse.')
        break

      elif m.state.name == 'JOB_STATE_RUNNING' and not queued_model:
        # If there's a model still queued, remember the most recent one.
        queued_model = m.name
  else:
    if queued_model:
      model_id = queued_model
      print('Found queued model, still waiting.')

Found existing tuned model to reuse.


In [24]:
# Upload the training data and queue the tuning job.
if not model_id:
    tuning_op = client.tunings.tune(
        base_model="models/gemini-1.5-flash-001-tunign",
        training_dataset=input_data,
        config=types.CreateTuningJobConfig(
            tuned_model_display_name="Newsgroup classification model",
            batch_size=16,
            epoch_count=2,
        )
    )

    print(tuning_op.state)
    model_id = tuning_op.name

print(model_id)

tunedModels/newsgroup-classification-model-fqremn0a7


In [25]:
import datetime
import time


MAX_WAIT = datetime.timedelta(minutes=10)

while not (tuned_model := client.tunings.get(name=model_id)).has_ended:

    print(tuned_model.state)
    time.sleep(60)

    # Don't wait too long. Use a public model if this is going to take a while.
    if datetime.datetime.now(datetime.timezone.utc) - tuned_model.create_time > MAX_WAIT:
        print("Taking a shortcut, using a previously prepared model.")
        model_id = "tunedModels/newsgroup-classification-model-ltenbi1b"
        tuned_model = client.tunings.get(name=model_id)
        break


print(f"Done! The model state is: {tuned_model.state.name}")

if not tuned_model.has_succeeded and tuned_model.error:
    print("Error:", tuned_model.error)

Done! The model state is: JOB_STATE_SUCCEEDED


## Use the new model


In [26]:
new_text = """
First-timer looking to get out of here.
Hi, I'm writing about my interest in travelling to the outer limits!
What kind of craft can I buy? What is easiest to access from this 3rd rock?
Let me know how to do that please.
"""

In [27]:
response = client.models.generate_content(
    model=model_id,
    contents=new_text
)

print(response.text)

sci.space


## Evaluation

In [28]:
@retry.Retry(predicate=is_retriable)
def classify_text(text: str) -> str:
    """Classify the provided text into a known newsgroup."""
    response = client.models.generate_content(
        model=model_id,
        contents=text
    )

    rc =response.candidates[0]

    if rc.finish_reason.name != "STOP":
        return "(error)"
    else:
        return rc.content.parts[0].text
    

In [29]:
df_model_eval = sample_data(df_test, 4, '.*')

df_model_eval["Prediction"] = df_model_eval["Text"].progress_apply(classify_text)

accuracy = (df_model_eval["Class Name"] == df_model_eval["Prediction"]).sum() / len(df_model_eval)
print(f"Accuracy: {accuracy:.2%}")

Output()

Accuracy: 84.38%


## Compare token usage

In [30]:
sysint_tokens = client.models.count_tokens(
    model='gemini-1.5-flash-001',
    contents=[system_instruct, sample_row]
).total_tokens

print(f'System instructed baseline model: {sysint_tokens} (input)')



System instructed baseline model: 172 (input)


In [31]:
tuned_tokens = client.models.count_tokens(
    model=tuned_model.base_model, 
    contents=sample_row,
).total_tokens

print(f'Tuned model: {tuned_tokens} (input)')

Tuned model: 136 (input)


In [32]:
savings = (sysint_tokens - tuned_tokens) / tuned_tokens
print(f'Token savings: {savings:.2%}')

Token savings: 26.47%


In [33]:
baseline_token_output = baseline_response.usage_metadata.candidates_token_count
print('Baseline (verbose) output tokens:', baseline_token_output)

Baseline (verbose) output tokens: 123


In [35]:
tuned_model_output = client.models.generate_content(
    model=model_id,
    contents=sample_row
)

In [36]:
tuned_token_output= tuned_model_output.usage_metadata.candidates_token_count
print('Tuned output tokens:', tuned_token_output)

Tuned output tokens: 4
