### Text Classification Using the Google Gemini API

This project riffs on the notebook supplied on day four in Google and Kaggle's generative AI course (see 'notebook' folder). Having trained my own models to complete text classification tasks (see 'coursework_text_classifier' notebook), I wanted to experiment with an LLM to see how the results compared. I've retained and/or reworked some of the text from the original Google/Kaggle notebook so that I can continue to use this as a resource for my own learning and development.

In this notebook I use the Gemini API to fine-tune a custom, task-specific model. Fine-tuning can be used for a variety of tasks, from classic NLP problems like entity extraction or summarisation to creative tasks like stylised generation. I fine-tune a model to classify a piece of text (a tweet) into the category it belongs to (a specific natural disaster, or not disaster-related).

This notebook tunes a model with the API. [AI Studio](https://aistudio.google.com/app/tune) also supports creating new tuned models directly in the web UI, allowing you to quickly create and monitor models using data from Google Sheets, Drive or your own files.

In [33]:
!pip uninstall -qqy jupyterlab  # Remove unused conflicting packages
!pip install -U -q "google-genai==1.7.0"

### Imports

In [34]:
# Standard library imports
import os
import re
import random
import datetime
import time
import warnings
from collections import Counter
from collections.abc import Iterable

# Third-party imports
import pandas as pd
import ndjson
from sklearn.model_selection import train_test_split
from tqdm.rich import tqdm as tqdmr

# Application-specific / service-specific imports
from dotenv import load_dotenv
from google import genai
from google.genai import types
from google.api_core import retry

genai.__version__

'1.7.0'

### Set up the API key

In [35]:
load_dotenv()

GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY")

if not GOOGLE_API_KEY:
    raise ValueError("GEMINI_API_KEY not found in environment variables.")

client = genai.Client(api_key=GOOGLE_API_KEY)

### Explore available models

I'll be using the [`TunedModel.create`](https://ai.google.dev/api/tuning#method:-tunedmodels.create) API method to start the fine-tuning job and create my custom model. You can find a model that supports it through the [`models.list`](https://ai.google.dev/api/models#method:-models.list) endpoint. You can also find more information about tuning models in [the model tuning docs](https://ai.google.dev/gemini-api/docs/model-tuning/tutorial?lang=python).

In [36]:
for model in client.models.list():
    if "createTunedModel" in model.supported_actions:
        print(model.name)

models/gemini-1.5-flash-001-tuning


## Download the dataset

For this activity, I once more use the 'Disaster Tweet Corpus 2020' data set. The entire data set consists of 48 newline-delimited JSON (ndjson) files covering 10 different disaster types. These include earthquakes, tsunamis, and human-made industrial disasters. Each file contains an even split of disaster-related and non-disaster-related tweets. They are human-labelled, with a '1' indicating that the tweet relates to a disaster and a '0' indicating that it does not. As well as the tweet text and classification, a third variable, the user id, is also included. The data sets are intended for benchmarking filtering algorithms, which fits the scope of this project. I used a curated selection of these files, choosing one to represent each of five disasters: earthquake, flood, hurricane, tornado, and wildfire. I have not repeated the formal references in this project, so please look to the original notebook for that information.

In [37]:
# Read in the files but ignore the ID column as it's unnecessary for this work.

data_dir = 'data'
all_data = []

for filename in os.listdir(data_dir):
    if filename.endswith('.ndjson'):
        
        label_word = filename.split('-')[0].lower()
        
        filepath = os.path.join(data_dir, filename)
        
        with open(filepath, 'r', encoding='utf-8') as f:
            tweets = ndjson.load(f)
            
            for tweet in tweets:
                text = tweet.get('text', '')
                relevance = tweet.get('relevance', 0)
                
                if relevance == 1:
                    label = label_word
                else:
                    label = 'non-disaster'
                    
                all_data.append({'text': text, 'label': label})
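
The labelling rule in the cell above can be tried in isolation. This is a minimal sketch using only the standard library; the two records and the file name `earthquake-sample.ndjson` are hypothetical, and only the `text` and `relevance` fields read by the cell above are assumed.

```python
import io
import json

# Hypothetical two-line NDJSON payload; the real files carry many more tweets.
raw = io.StringIO(
    '{"text": "Ground is shaking", "relevance": 1}\n'
    '{"text": "Lunch was great", "relevance": 0}\n'
)
filename = "earthquake-sample.ndjson"  # hypothetical file name

# The disaster type is the part of the file name before the first hyphen.
label_word = filename.split("-")[0].lower()

rows = []
for line in raw:
    if not line.strip():
        continue  # skip blank lines
    tweet = json.loads(line)
    # Relevant tweets take the disaster name; everything else is 'non-disaster'.
    label = label_word if tweet.get("relevance", 0) == 1 else "non-disaster"
    rows.append({"text": tweet.get("text", ""), "label": label})

print(rows)
```

Each NDJSON line is an independent JSON document, so `json.loads` per line is enough for a quick check like this.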

In [38]:
# Check the labels and their respective counts.

label_counts = Counter(d['label'] for d in all_data)
print(label_counts)

Counter({'non-disaster': 10048, 'hurricane': 3837, 'tornado': 1876, 'flood': 1797, 'earthquake': 1673, 'wildfire': 865})
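
To put these counts in proportion, here is a small sketch that turns them into shares. The numbers are copied from the output above; nothing else is assumed.

```python
from collections import Counter

# Counts copied from the printed Counter above.
label_counts = Counter({
    "non-disaster": 10048, "hurricane": 3837, "tornado": 1876,
    "flood": 1797, "earthquake": 1673, "wildfire": 865,
})

total = sum(label_counts.values())
shares = {label: count / total for label, count in label_counts.items()}

# Print the largest class first.
for label, share in sorted(shares.items(), key=lambda kv: -kv[1]):
    print(f"{label:>12}: {share:.1%}")
```

The non-disaster class makes up exactly half of the rows, which matches the even split described earlier, while wildfire is the smallest class.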


Here's what a single row looks like.

In [39]:
print(all_data[0])



## Prepare the dataset

In the Google/Kaggle example, they state:

    "This pre-processing removes personal information, which can be used to "shortcut" to known users of a forum, and formats the text to appear a bit more like regular text and less like a newsgroup post (e.g. by removing the mail headers). This normalisation allows the model to generalise to regular text and not over-depend on specific fields. If your input data is always going to be newsgroup posts, it may be helpful to leave this structure in place if they provide genuine signals."

The approach to preprocessing the text I've adopted below is much more light touch than when I trained my own models. For example, I've opted not to tokenize and lemmatize the text as this will likely deny the LLM useful information when classifying the text. I think leaving in some of the symbols is possibly not advisable, but given that this is an experiment I want to see how it performs with those included. In any case, I can come back and edit the preprocessing in an attempt to improve the model's accuracy.

In [40]:
# Go back and adjust some of this if results are not satisfactory.

def clean_text_for_llm(text):
    """
    Preprocesses a text string for input into a large language model (LLM).

    This function performs the following cleaning steps:
    - Converts all characters to lowercase
    - Replaces URLs with a <URL> placeholder
    - Removes specific special characters and encoding noise
    - Collapses excess whitespace into single spaces

    Args:
        text (str): The input text to clean.

    Returns:
        str: The cleaned and normalized text.
    """
    # Lowercase
    text = text.lower()
    
    # Replace URLs with placeholder
    text = re.sub(r'https?://\S+','<URL>', text)
    
    # Fix special characters or encoding noise
    text = re.sub(r'[â€—\x01]', '', text)
    
    # Remove excess whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    
    return text
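
To see the effect of these steps, here is a self-contained sketch on a made-up tweet. The helper `clean` is a hypothetical stand-in that mirrors the lowercase, URL and whitespace steps of the function above and deliberately omits the encoding-noise step.

```python
import re

def clean(text: str) -> str:
    """Lowercase, replace URLs with a placeholder, and collapse whitespace."""
    text = text.lower()
    text = re.sub(r'https?://\S+', '<URL>', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

# Hypothetical input with mixed case, a URL, a newline and stray spaces.
sample = "  TORNADO   warning!\nSee https://Example.com/alert  "
cleaned = clean(sample)
print(cleaned)  # tornado warning! see <URL>
```

Because lowercasing happens before the URL substitution, any placeholder inserted by this step stays uppercase, whereas placeholders already present in the source text are lowercased.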

In [41]:
cleaned_data = []

for tweet in all_data:
    cleaned_tweet = {
        'text': clean_text_for_llm(tweet['text']),
        'label': tweet['label']
    }
    cleaned_data.append(cleaned_tweet)

In [42]:
df_cleaned = pd.DataFrame(cleaned_data)

In [43]:
df_cleaned

Unnamed: 0,text,label
0,cruuuuud. tornado warning for denton county. a...,tornado
1,tornado warning for tarrant county. storm spot...,tornado
2,tornado on the ground in azle headed right tow...,tornado
3,us tornado toll rises to <number> (afp) <url>,tornado
4,tornado sirens. this is scary.,tornado
...,...,...
20091,<number> goes out to <user> @paytonlund <user>...,non-disaster
20092,the perception of doctors has undergone a dram...,non-disaster
20093,"birmingham, live on tour. <url>",non-disaster
20094,february <number> h is international polar bea...,non-disaster


When sampling the data, I will keep 50 rows for each category for training. Google says:

    Note that this is even fewer than the Keras example, as this technique (parameter-efficient fine-tuning, or PEFT) updates a relatively small number of parameters and does not require training a new model or updating the large model.

In [44]:
def sample_data(df, num_samples, labels_to_keep):
    """
    Samples a specified number of examples per label from a DataFrame.

    Filters the DataFrame to include only the specified labels, then 
    randomly samples a fixed number of examples for each label category.

    Args:
        df (pd.DataFrame): The input DataFrame, expected to contain a 'label' column.
        num_samples (int): The number of samples to draw for each label.
        labels_to_keep (Iterable): A list-like object of labels to include in the sample.

    Returns:
        pd.DataFrame: A new DataFrame containing the sampled rows, 
                      with the 'label' column cast to categorical type.
    """
    df = df[df["label"].isin(labels_to_keep)]
    
    df = (
        df.groupby("label", observed=False)[df.columns]
        .apply(lambda x: x.sample(num_samples, random_state=42))
        .reset_index(drop=True)
    )

    df["label"] = df["label"].astype("category")
    
    return df

In [45]:
labels_to_keep = ['non-disaster', 'hurricane', 'tornado', 'flood', 'earthquake','wildfire']
train_num_samples = 50 
test_num_samples = 10

In [46]:
df_train_raw, df_test_raw = train_test_split(
    df_cleaned,
    test_size=0.2,
    stratify=df_cleaned['label'],
    random_state=42
)

In [47]:
df_train = sample_data(df_train_raw, train_num_samples, labels_to_keep) 
df_test = sample_data(df_test_raw, test_num_samples, labels_to_keep) 
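
The per-label sampling idea can also be sketched without pandas. This is only an illustration of the technique, not the function used above; `sample_per_label` and the toy rows are hypothetical.

```python
import random
from collections import defaultdict

def sample_per_label(rows, num_samples, seed=42):
    """Return num_samples randomly chosen rows for each label."""
    by_label = defaultdict(list)
    for row in rows:
        by_label[row["label"]].append(row)

    rng = random.Random(seed)  # fixed seed for repeatable samples
    sampled = []
    for label in sorted(by_label):
        # Raises ValueError if a label has fewer rows than num_samples.
        sampled.extend(rng.sample(by_label[label], num_samples))
    return sampled

# Hypothetical toy data: five rows for each of two labels.
toy = [{"text": f"tweet {i}", "label": "flood" if i % 2 else "tornado"} for i in range(10)]
picked = sample_per_label(toy, 3)
print([row["label"] for row in picked])
```

Sampling a fixed number per label gives a balanced set regardless of how skewed the full data is, at the cost of discarding most rows from the larger classes.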

In [48]:
df_test.iloc[59]['text']

'for all in sydney and most of illawarra <user> sydney water statement: <hashtag> bluemountains <hashtag> bushfires <url> <hashtag> nswfires”'

## Evaluate baseline performance

First, I perform an evaluation on the available models to ensure I can measure how much the tuning helps.

Below is a single sample row to use for visual inspection.

In [49]:
sample_idx = 0
sample_row = df_train.iloc[sample_idx]['text']
sample_label = df_train.iloc[sample_idx]['label']

print(sample_row)
print('---')
print('Label:', sample_label)

napa valley wineries sustain damage from <number> earthquake la times: <url> <url> via @carolcnn
---
Label: earthquake


Passing the text directly in as a prompt does not yield the desired results. The model will attempt to respond to the message.

In [50]:
response = client.models.generate_content(
    model="gemini-1.5-flash-001", contents=sample_row)
print(response.text)

This looks like a social media post about Napa Valley wineries being damaged by an earthquake. To give you the most helpful information, I need more details.  

Please tell me:

* **What number is missing?**  Is it the magnitude of the earthquake? The number of wineries damaged? 
* **What are the URLs?** I need the full URLs to access the articles.

Once I have this information, I can help you find out more about the earthquake and its impact on Napa Valley wineries. 



Prompt engineering techniques induce the model to perform the desired task.

In [51]:
# Ask the model directly in a zero-shot prompt.

prompt = "From what natural disaster does the following tweet originate?"
baseline_response = client.models.generate_content(
    model="gemini-1.5-flash-001",
    contents=[prompt, sample_row])
print(baseline_response.text)

The tweet originates from an **earthquake**. 



This technique still produces quite a verbose response. I try to parse out the relevant text, and then refine the prompt even further.

In [52]:
from google.api_core import retry

# You can use a system instruction to do more direct prompting, and get a
# more succinct answer.

system_instruct = """
You are a classification service. You will be passed input that represents
a tweet from a natural disaster and you must respond with the natural disaster from which the tweet
originates.
"""

# Define a helper to retry when per-minute quota is reached.
is_retriable = lambda e: (isinstance(e, genai.errors.APIError) and e.code in {429, 503})

# If you want to evaluate your own technique, replace this body of this function
# with your model, prompt and other code and return the predicted answer.
@retry.Retry(predicate=is_retriable)
def predict_label(post: str) -> str:
    response = client.models.generate_content(
        model="gemini-1.5-flash-001",
        config=types.GenerateContentConfig(
            system_instruction=system_instruct),
        contents=post)

    rc = response.candidates[0]

    # Any errors, filters, recitation, etc we can mark as a general error
    if rc.finish_reason.name != "STOP":
        return "(error)"
    else:
        # Clean up the response.
        return response.text.strip()


prediction = predict_label(sample_row)

print(prediction)
print()
print("Correct!" if prediction == sample_label else "Incorrect.")

Earthquake

Incorrect.


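Because the comparison is case-sensitive and the model's reply begins with an uppercase letter ('Earthquake' versus the label 'earthquake'), the check wrongly reports this classification as incorrect. I refine the prompt to resolve this problem.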
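
An alternative to changing the prompt is to normalise the reply in code before comparing it with the label. The sketch below is one way to do that; `normalise_prediction` is a hypothetical helper, and the set of stripped characters is an assumption based on the replies seen so far.

```python
def normalise_prediction(raw: str) -> str:
    """Trim whitespace, Markdown emphasis and trailing punctuation, then lowercase."""
    return raw.strip().strip("*_.!").lower()

print(normalise_prediction("Earthquake"))                # earthquake
print(normalise_prediction("**earthquake**. "))          # earthquake
print(normalise_prediction("(error)"))                   # (error)
```

Only a narrow set of characters is stripped so that the `(error)` sentinel and hyphenated labels such as `non-disaster` pass through unchanged.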

In [53]:
from google.api_core import retry

# You can use a system instruction to do more direct prompting, and get a
# more succinct answer.

system_instruct = """
You are a classification service. You will be passed input that represents
a tweet from a natural disaster and you must respond with the natural disaster from which the tweet
originates. Your response must be lowercase.
"""

# Define a helper to retry when per-minute quota is reached.
is_retriable = lambda e: (isinstance(e, genai.errors.APIError) and e.code in {429, 503})

# If you want to evaluate your own technique, replace this body of this function
# with your model, prompt and other code and return the predicted answer.
@retry.Retry(predicate=is_retriable)
def predict_label(post: str) -> str:
    response = client.models.generate_content(
        model="gemini-1.5-flash-001",
        config=types.GenerateContentConfig(
            system_instruction=system_instruct),
        contents=post)

    rc = response.candidates[0]

    # Any errors, filters, recitation, etc we can mark as a general error
    if rc.finish_reason.name != "STOP":
        return "(error)"
    else:
        # Clean up the response.
        return response.text.strip()


prediction = predict_label(sample_row)

print(prediction)
print()
print("Correct!" if prediction == sample_label else "Incorrect.")

earthquake

Correct!
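
The retry-on-quota idea used in the cell above can be illustrated with a small standalone decorator. This is a sketch of the general pattern, not how `google.api_core.retry.Retry` is implemented; `retry_if`, `QuotaError` and `flaky` are hypothetical names.

```python
import time
from functools import wraps

def retry_if(predicate, max_attempts=5, base_delay=1.0, sleep=time.sleep):
    """Retry the wrapped call with exponential backoff while predicate(exc) is true."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_attempts):
                try:
                    return func(*args, **kwargs)
                except Exception as exc:
                    last_attempt = attempt == max_attempts - 1
                    if last_attempt or not predicate(exc):
                        raise  # give up: out of attempts or not retriable
                    sleep(base_delay * 2 ** attempt)
        return wrapper
    return decorator

class QuotaError(Exception):
    """Hypothetical stand-in for an API error carrying an HTTP-style code."""
    def __init__(self, code):
        super().__init__(f"code {code}")
        self.code = code

calls = {"n": 0}

# Sleeping is replaced with a no-op so the demo runs instantly.
@retry_if(lambda e: isinstance(e, QuotaError) and e.code in {429, 503}, sleep=lambda s: None)
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise QuotaError(429)  # simulate hitting the per-minute quota twice
    return "ok"

result = flaky()
print(result, calls["n"])  # ok 3
```

Errors that do not satisfy the predicate are re-raised immediately, so genuine bugs are not hidden behind repeated retries.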


I run a short evaluation using the function defined above. The test set is further sampled to ensure the experiment runs smoothly on the API's free tier. In practice you would evaluate over the whole set.

In [54]:
df_baseline_eval = sample_data(df_test, 2, labels_to_keep)

In [55]:
import tqdm
from tqdm.rich import tqdm as tqdmr
import warnings

# Enable tqdm features on Pandas.
tqdmr.pandas()

# But suppress the experimental warning
warnings.filterwarnings("ignore", category=tqdm.TqdmExperimentalWarning)

# Further sample the test data to be mindful of the free-tier quota.
df_baseline_eval = sample_data(df_test, 2, labels_to_keep)

# Make predictions using the sampled data.
df_baseline_eval['prediction'] = df_baseline_eval['text'].progress_apply(predict_label)

# And calculate the accuracy.
accuracy = (df_baseline_eval["label"] == df_baseline_eval["prediction"]).sum() / len(df_baseline_eval)
print(f"Accuracy: {accuracy:.2%}")


Accuracy: 50.00%


Have a look at the dataframe to compare the predictions with the labels.

In [56]:
df_baseline_eval

Unnamed: 0,text,label,prediction
0,california usa downey » <url> <hashtag> sfgate...,earthquake,earthquake
1,totally felt the last <hashtag> napaquake afte...,earthquake,earthquake
2,"<user> discuss how to help flood,police author...",flood,(error)
3,monsoon floods in nepal and india cause <numbe...,flood,flood
4,"usgs:m <number> - <number> m wnw of rincon, pu...",hurricane,earthquake
5,<user> cast members raise <number> 000 for pue...,hurricane,hurricane
6,<user> you deserve win guys! we love you! you ...,non-disaster,This is not a tweet about a natural disaster.
7,"""i've been happy for two days. now it's time t...",non-disaster,This tweet does not describe a natural disaster.
8,i live in joplin mo where the <number> tornado...,tornado,tornado
9,my dad's been in <user> <user> today because h...,tornado,tornado
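
A per-label breakdown makes the failure pattern easier to see. The sketch below uses only the ten rows displayed above, so its overall figure differs from the accuracy computed on the full evaluation sample.

```python
from collections import defaultdict

# (label, prediction) pairs copied from the ten rows displayed above.
pairs = [
    ("earthquake", "earthquake"),
    ("earthquake", "earthquake"),
    ("flood", "(error)"),
    ("flood", "flood"),
    ("hurricane", "earthquake"),
    ("hurricane", "hurricane"),
    ("non-disaster", "This is not a tweet about a natural disaster."),
    ("non-disaster", "This tweet does not describe a natural disaster."),
    ("tornado", "tornado"),
    ("tornado", "tornado"),
]

hits = defaultdict(int)
totals = defaultdict(int)
for label, prediction in pairs:
    totals[label] += 1
    hits[label] += int(label == prediction)  # exact string match, as in the cell above

per_label = {label: hits[label] / totals[label] for label in totals}
for label, acc in per_label.items():
    print(f"{label:>12}: {acc:.0%}")
```

Both non-disaster rows miss because the model answers in a full sentence. That appears to follow from the system instruction, which names no label for tweets unrelated to a disaster.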


## Tune a custom model

In this example I use tuning to create a model that requires no prompting or system instructions and outputs succinct text from the classes provided in the training data.

The data contains both input text (the processed tweets) and output text (the category, or disaster), that I can use to start tuning a model.

When calling `tune()`, you can specify model tuning hyperparameters too:
 - `epoch_count`: defines how many times to loop through the data,
 - `batch_size`: defines how many rows to process in a single step, and
 - `learning_rate`: defines the scaling factor for updating model weights at each step.

You can also choose to omit them and use the defaults. [Learn more](https://developers.google.com/machine-learning/crash-course/linear-regression/hyperparameters) about these parameters and how they work.
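
As a rough worked example of how these settings interact, the sketch below estimates the number of weight updates for this notebook's values (50 training rows for each of six labels, `batch_size=16`, `epoch_count=2`). It assumes each epoch is one full pass over the data and that a final partial batch counts as a step; the tuning service may count differently.

```python
import math

num_examples = 50 * 6   # 50 training rows for each of the six labels
batch_size = 16
epoch_count = 2

# One step per batch; the last, smaller batch is assumed to count as a step.
steps_per_epoch = math.ceil(num_examples / batch_size)
total_steps = steps_per_epoch * epoch_count

print(steps_per_epoch, total_steps)  # 19 38
```

With so few steps, small changes to `epoch_count` or `batch_size` noticeably change how much training the model receives.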

*Note to self: go back and look at this to see how I can adjust the parameters for the purposes of my experiment.*

For Google's example, they chose parameters by running some tuning jobs and selecting parameters that converged efficiently.

This example will start a new tuning job, but only if one does not already exist. This allows you to leave this codelab and come back later - re-running this step will find your last model.

In [57]:
# Convert the data frame into a dataset suitable for tuning.
#input_data = {'examples': 
#    df_train[['text', 'label']]
#      .rename(columns={'text': 'textInput', 'label': 'output'})
#      .to_dict(orient='records')
# }

# If you are re-running this lab, add your model_id here.
#model_id = None

# Or try and find a recent tuning job.
#if not model_id:
#  queued_model = None
  # Newest models first.
#  for m in reversed(client.tunings.list()):
    # Only look at newsgroup classification models.
#    if m.name.startswith('tunedModels/newsgroup-classification-model'):
      # If there is a completed model, use the first (newest) one.
#      if m.state.name == 'JOB_STATE_SUCCEEDED':
#        model_id = m.name
#        print('Found existing tuned model to reuse.')
#        break

#      elif m.state.name == 'JOB_STATE_RUNNING' and not queued_model:
        # If there's a model still queued, remember the most recent one.
#        queued_model = m.name
#  else:
#    if queued_model:
#      model_id = queued_model
#      print('Found queued model, still waiting.')


# Upload the training data and queue the tuning job.
#if not model_id:
#    tuning_op = client.tunings.tune(
#        base_model="models/gemini-1.5-flash-001-tuning",
#        training_dataset=input_data,
#        config=types.CreateTuningJobConfig(
#            tuned_model_display_name="Newsgroup classification model",
#            batch_size=16,
#            epoch_count=2,
#        ),
#    )
#
#    print(tuning_op.state)
#    model_id = tuning_op.name
#
#print(model_id)

In [58]:
# Prepare your new dataset
input_data = {
    'examples': df_train[['text', 'label']]
        .rename(columns={'text': 'textInput', 'label': 'output'})
        .to_dict(orient='records')
}

# Set model_id to None to ensure a new job is created
model_id = None

# Always create a new tuning job
tuning_op = client.tunings.tune(
    base_model="models/gemini-1.5-flash-001-tuning",
    training_dataset=input_data,
    config=types.CreateTuningJobConfig(
        tuned_model_display_name="Tweet classification",  # You can rename this if you'd like
        batch_size=16,
        epoch_count=2,
    ),
)

print(tuning_op.state)
model_id = tuning_op.name

print(f"New tuning job started: {model_id}")

JobState.JOB_STATE_QUEUED
New tuning job started: tunedModels/tweet-classification-8hm6pyjgdtwr


This has created a tuning job that will run in the background. To inspect the progress of the tuning job, run the cell below, which polls the job and prints its state once a minute. Once the job reports `JOB_STATE_SUCCEEDED`, tuning is complete and the model is ready to use.

Tuning jobs are queued, so it may look like no training steps have been taken initially, but the job will progress. Tuning can take anywhere from a few minutes to multiple hours, depending on factors like your dataset size and how busy the tuning infrastructure is.

It is safe to stop this cell at any point. It will not stop the tuning job.

Have a look at the [Search grounding](https://www.kaggle.com/code/markishere/day-4-google-search-grounding/) codelab. If you want to try tuning a local LLM, check out [the fine-tuning guides for tuning a Gemma model](https://ai.google.dev/gemma/docs/tune).

In [59]:
MAX_WAIT = datetime.timedelta(minutes=10)

while not (tuned_model := client.tunings.get(name=model_id)).has_ended:

    print(tuned_model.state)
    time.sleep(60)

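    # Don't wait too long. Fall back to a public model if this is going to take a while.
    # NOTE: the fallback below is the newsgroup classifier from the original course
    # notebook, so its labels will likely not match the disaster categories used here.
    if datetime.datetime.now(datetime.timezone.utc) - tuned_model.create_time > MAX_WAIT:
        print("Taking a shortcut, using a previously prepared model.")
        model_id = "tunedModels/newsgroup-classification-model-ltenbi1b"
        tuned_model = client.tunings.get(name=model_id)
        break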


print(f"Done! The model state is: {tuned_model.state.name}")

if not tuned_model.has_succeeded and tuned_model.error:
    print("Error:", tuned_model.error)

JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_RUNNING
JobState.JOB_STATE_RUNNING
Done! The model state is: JOB_STATE_SUCCEEDED
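
The poll-with-timeout pattern in the cell above can be sketched independently of the API. `wait_until` and `fake_job_done` are hypothetical names; the fake job simply reports completion on its fourth status check.

```python
import time

def wait_until(is_done, timeout_s, interval_s, clock=time.monotonic, sleep=time.sleep):
    """Poll is_done() until it returns True or timeout_s elapses; return whether it finished."""
    deadline = clock() + timeout_s
    while not is_done():
        if clock() >= deadline:
            return False  # gave up waiting; the job itself is unaffected
        sleep(interval_s)
    return True

# Hypothetical job that reports completion on the fourth status check.
checks = {"n": 0}

def fake_job_done():
    checks["n"] += 1
    return checks["n"] >= 4

finished = wait_until(fake_job_done, timeout_s=5, interval_s=0.01)
print(finished, checks["n"])  # True 4
```

Using a monotonic clock for the deadline avoids surprises if the system time changes while waiting.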


## Use the new model

Now that I have a tuned model, I can try it out with custom data. I use the same API as a normal Gemini API interaction, but specify the new model as the model name, which starts with `tunedModels/`.

In [60]:
model_id = 'tunedModels/tweet-classification-6oock2yj2wg0'

new_text = """
The water came in the door.
"""

response = client.models.generate_content(
    model=model_id, contents=new_text)

print(response.text)

flood


### Evaluation

You can see that the model outputs labels that correspond to those in the training data, and without any system instructions or prompting, which is already a significant improvement. Now see how well it performs on the test set.

Note that there is no parallelism in this example; classifying the test sub-set will take a few minutes.

In [61]:
@retry.Retry(predicate=is_retriable)
def classify_text(text: str) -> str:
    """Classify the provided text into a known category."""
    response = client.models.generate_content(
        model=model_id, contents=text)
    rc = response.candidates[0]

    # Any errors, filters, recitation, etc we can mark as a general error
    if rc.finish_reason.name != "STOP":
        return "(error)"
    else:
        return rc.content.parts[0].text


# The sampling here is just to minimise your quota usage. If you can, you should
# evaluate the whole test set with `df_model_eval = df_test.copy()`.
df_model_eval = sample_data(df_test, 4, labels_to_keep)

df_model_eval["prediction"] = df_model_eval["text"].progress_apply(classify_text)

accuracy = (df_model_eval["label"] == df_model_eval["prediction"]).sum() / len(df_model_eval)
print(f"Accuracy: {accuracy:.2%}")


Accuracy: 79.17%


In [62]:
df_model_eval

Unnamed: 0,text,label,prediction
0,california usa downey » <url> <hashtag> sfgate...,earthquake,earthquake
1,totally felt the last <hashtag> napaquake afte...,earthquake,earthquake
2,if you really want to help <hashtag> drinknapa...,earthquake,earthquake
3,<user> there was also an earthquake in califor...,earthquake,earthquake
4,"<user> discuss how to help flood,police author...",flood,(error)
5,monsoon floods in nepal and india cause <numbe...,flood,flood
6,many dead in nepal and india floods <url> @rar...,flood,flood
7,<hashtag> news <hashtag> mostrecent hundreds d...,flood,flood
8,"usgs:m <number> - <number> m wnw of rincon, pu...",hurricane,earthquake
9,<user> cast members raise <number> 000 for pue...,hurricane,hurricane


## Compare token usage

AI Studio and the Gemini API provide model tuning at no cost; however, normal limits and charges apply for *use* of a tuned model.

The size of the input prompt and other generation config like system instructions, as well as the number of generated output tokens, all contribute to the overall cost of a request.

In [63]:
# Calculate the *input* cost of the baseline model with system instructions.
sysint_tokens = client.models.count_tokens(
    model='gemini-1.5-flash-001', contents=[system_instruct, sample_row]
).total_tokens
print(f'System instructed baseline model: {sysint_tokens} (input)')

# Calculate the input cost of the tuned model.
tuned_tokens = client.models.count_tokens(model=tuned_model.base_model, contents=sample_row).total_tokens
print(f'Tuned model: {tuned_tokens} (input)')

# Express the saving as a fraction of the baseline's input tokens.
savings = (sysint_tokens - tuned_tokens) / sysint_tokens
print(f'Token savings: {savings:.2%}')  # Note that this is only n=1.

System instructed baseline model: 69 (input)
Tuned model: 25 (input)
Token savings: 63.77%
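
The same two token counts can be expressed against either denominator, and the choice changes the headline figure considerably. A small worked example, using only the counts reported above:

```python
sysint_tokens = 69  # baseline input tokens reported above
tuned_tokens = 25   # tuned-model input tokens reported above

saved = sysint_tokens - tuned_tokens
reduction_vs_baseline = saved / sysint_tokens  # share of the baseline that is saved
overhead_vs_tuned = saved / tuned_tokens       # extra baseline tokens relative to tuned

print(f"Tokens saved per request: {saved}")                                # 44
print(f"Reduction relative to baseline: {reduction_vs_baseline:.2%}")      # 63.77%
print(f"Baseline overhead relative to tuned: {overhead_vs_tuned:.2%}")     # 176.00%
```

A saving is usually quoted relative to the baseline, which keeps it below 100%; dividing by the tuned count instead describes how much larger the baseline prompt is.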


The earlier verbose model also produced more output tokens than needed for this task.

In [64]:
baseline_token_output = baseline_response.usage_metadata.candidates_token_count
print('Baseline (verbose) output tokens:', baseline_token_output)

tuned_model_output = client.models.generate_content(
    model=model_id, contents=sample_row)
tuned_tokens_output = tuned_model_output.usage_metadata.candidates_token_count
print('Tuned output tokens:', tuned_tokens_output)

Baseline (verbose) output tokens: 9
Tuned output tokens: 2


## Next steps

While the accuracy here (79.17% on a 24-tweet sample of the test set) is good given the relatively small amount of time I have put into tuning the model, it is lower than that of the models I have trained myself. I'm confident that I could increase the accuracy of this classifier through further feature engineering, hyperparameter tuning, or reworking the system instruction.