# Day 2 - Classifying embeddings with Keras and the Gemini API

## Overview

Welcome back to the Kaggle 5-day Generative AI course. In this notebook, you'll learn to use the embeddings produced by the Gemini API to train a model that can classify newsgroup posts into the categories (the newsgroup itself) from the post contents.

This technique uses the Gemini API's embeddings as input, avoiding the need to train on text input directly, and as a result it is able to perform quite well using relatively few examples compared to training a text model from scratch.


In [4]:
import google.generativeai as genai
from openai import OpenAI
from dotenv import load_dotenv
import os

In [15]:
# Load the variables from the .env file
load_dotenv()

# Access the variables
gemini_api_key = os.getenv("GOOGLE_API_KEY")

client = OpenAI(
    api_key=gemini_api_key,
    base_url="https://generativelanguage.googleapis.com/v1beta/"
)


# 1. Dataset

The 20 Newsgroups Text Dataset contains 18,000 newsgroups posts on 20 topics divided into training and test sets. The split between the training and test datasets are based on messages posted before and after a specific date. For this tutorial, you will use sampled subsets of the training and test sets, and perform some processing using Pandas.

In [8]:
from sklearn.datasets import fetch_20newsgroups


newsgroups_train = fetch_20newsgroups(subset="train")
newsgroups_test = fetch_20newsgroups(subset="test")

# View list of class names for dataset
newsgroups_train.target_names

['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']

In [9]:
print(newsgroups_train.data[0])

From: lerxst@wam.umd.edu (where's my thing)
Subject: WHAT car is this!?
Nntp-Posting-Host: rac3.wam.umd.edu
Organization: University of Maryland, College Park
Lines: 15

 I was wondering if anyone out there could enlighten me on this car I saw
the other day. It was a 2-door sports car, looked to be from the late 60s/
early 70s. It was called a Bricklin. The doors were really small. In addition,
the front bumper was separate from the rest of the body. This is 
all I know. If anyone can tellme a model name, engine specs, years
of production, where this car is made, history, or whatever info you
have on this funky looking car, please e-mail.

Thanks,
- IL
   ---- brought to you by your neighborhood Lerxst ----







In [11]:
import email
import re

import pandas as pd

"""
Start by preprocessing the data for this tutorial in a Pandas dataframe. To remove any sensitive information like 
names and email addresses, you will take only the subject and body of each message. This is an optional step that transforms 
the input data into more generic text, rather than email posts, so that it will work in other contexts.
"""

def preprocess_newsgroup_row(data):
    # Extract only the subject and body
    msg = email.message_from_string(data)
    text = f"{msg['Subject']}\n\n{msg.get_payload()}"
    # Strip any remaining email addresses
    text = re.sub(r"[\w\.-]+@[\w\.-]+", "", text)
    # Truncate each entry to 5,000 characters
    text = text[:5000]

    return text


def preprocess_newsgroup_data(newsgroup_dataset):
    # Put data points into dataframe
    df = pd.DataFrame(
        {"Text": newsgroup_dataset.data, "Label": newsgroup_dataset.target}
    )
    # Clean up the text
    df["Text"] = df["Text"].apply(preprocess_newsgroup_row)
    # Match label to target name index
    df["Class Name"] = df["Label"].map(lambda l: newsgroup_dataset.target_names[l])

    return df
    

# Apply preprocessing function to training and test datasets
df_train = preprocess_newsgroup_data(newsgroups_train)
df_test = preprocess_newsgroup_data(newsgroups_test)

df_train.head()

Unnamed: 0,Text,Label,Class Name
0,WHAT car is this!?\n\n I was wondering if anyo...,7,rec.autos
1,SI Clock Poll - Final Call\n\nA fair number of...,4,comp.sys.mac.hardware
2,"PB questions...\n\nwell folks, my mac plus fin...",4,comp.sys.mac.hardware
3,Re: Weitek P9000 ?\n\nRobert J.C. Kyanko () wr...,1,comp.graphics
4,Re: Shuttle Launch Question\n\nFrom article <>...,14,sci.space


In [12]:
def sample_data(df, num_samples, classes_to_keep):
    # Sample rows, selecting num_samples of each Label.
    df = (
        df.groupby("Label")[df.columns]
        .apply(lambda x: x.sample(num_samples))
        .reset_index(drop=True)
    )

    df = df[df["Class Name"].str.contains(classes_to_keep)]

    # We have fewer categories now, so re-calibrate the label encoding.
    df["Class Name"] = df["Class Name"].astype("category")
    df["Encoded Label"] = df["Class Name"].cat.codes

    return df
    

TRAIN_NUM_SAMPLES = 100
TEST_NUM_SAMPLES = 25
CLASSES_TO_KEEP = "sci"  # Class name should contain 'sci' to keep science categories

df_train = sample_data(df_train, TRAIN_NUM_SAMPLES, CLASSES_TO_KEEP)
df_test = sample_data(df_test, TEST_NUM_SAMPLES, CLASSES_TO_KEEP)

In [13]:
df_train.value_counts("Class Name")

Class Name
sci.crypt          100
sci.electronics    100
sci.med            100
sci.space          100
Name: count, dtype: int64

In [14]:
df_test.value_counts("Class Name")

Class Name
sci.crypt          25
sci.electronics    25
sci.med            25
sci.space          25
Name: count, dtype: int64

# 2. Create the embeddings

In this section, you will generate embeddings for each piece of text using the Gemini API embeddings endpoint. To learn more about embeddings, visit the embeddings guide.

In [31]:
from google.api_core import retry
from tqdm.rich import tqdm
import time


tqdm.pandas()

model = "models/text-embedding-004"

@retry.Retry(timeout=300.0)
def embed_fn(text: str) -> list[float]:
    # You will be performing classification, so set task_type accordingly.
    embed = client.embeddings.create(input=text,
                                     model=model)
    time.sleep(0.2)
    return embed.data[0].embedding


def create_embeddings(df):
    df["Embeddings"] = df["Text"].progress_apply(embed_fn)
    return df

In [23]:
df_train = create_embeddings(df_train)
df_test = create_embeddings(df_test)
df_train.head()

Output()

  t = cls(total=total, **tqdm_kwargs)


Output()

  t = cls(total=total, **tqdm_kwargs)


Unnamed: 0,Text,Label,Class Name,Encoded Label,Embeddings
1100,Re: Fifth Amendment and Passwords\n\nIn articl...,11,sci.crypt,0,"[-0.027680234983563423, 0.04840068146586418, -..."
1101,"Re: Once tapped, your code is no good any more...",11,sci.crypt,0,"[0.051765553653240204, 0.010603193193674088, -..."
1102,"Re: Once tapped, your code is no good any more...",11,sci.crypt,0,"[0.030571889132261276, 0.010716283693909645, -..."
1103,Re: Off the shelf cheap DES keyseach machine (...,11,sci.crypt,0,"[0.02225799486041069, 0.003658391535282135, -0..."
1104,Re: Secret algorithm [Re: Clipper Chip and cry...,11,sci.crypt,0,"[0.06348730623722076, 0.05079176649451256, -0...."


# 3. Build a classification model

In [32]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score


class ClassificationModel(nn.Module):
    def __init__(self, input_size: int, num_classes: int):
        super(ClassificationModel, self).__init__()
        self.hidden = nn.Linear(input_size, input_size)
        self.relu = nn.ReLU()
        self.output_probs = nn.Linear(input_size, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.hidden(x)
        x = self.relu(x)
        x = self.output_probs(x)
        x = self.softmax(x)
        return x


# Extract embedding size and class count from the data
embedding_size = len(df_train["Embeddings"].iloc[0])
num_classes = len(df_train["Class Name"].unique())

# Convert the data into NumPy arrays and then to PyTorch tensors
y_train = torch.tensor(df_train["Encoded Label"].values, dtype=torch.long)
x_train = torch.tensor(np.stack(df_train["Embeddings"].values), dtype=torch.float32)
y_val = torch.tensor(df_test["Encoded Label"].values, dtype=torch.long)
x_val = torch.tensor(np.stack(df_test["Embeddings"].values), dtype=torch.float32)

# Create DataLoader for batching
train_dataset = TensorDataset(x_train, y_train)
val_dataset = TensorDataset(x_val, y_val)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Initialize the model, loss function, and optimizer
classifier = ClassificationModel(input_size=embedding_size, num_classes=num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=0.001)

# Training loop
NUM_EPOCHS = 20
early_stop_patience = 10
best_val_acc = 0
patience = 0

for epoch in range(NUM_EPOCHS):
    classifier.train()
    train_loss = 0.0
    train_correct = 0
    total_samples = 0

    for x_batch, y_batch in train_loader:
        optimizer.zero_grad()
        outputs = classifier(x_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        train_correct += (predicted == y_batch).sum().item()
        total_samples += y_batch.size(0)

    train_acc = train_correct / total_samples
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Train Loss: {train_loss/len(train_loader)}, Train Acc: {train_acc}")

    # Validation
    classifier.eval()
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for x_batch, y_batch in val_loader:
            outputs = classifier(x_batch)
            _, predicted = torch.max(outputs, 1)
            val_correct += (predicted == y_batch).sum().item()
            val_total += y_batch.size(0)

    val_acc = val_correct / val_total
    print(f"Validation Accuracy: {val_acc}")

    # Early stopping logic
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience = 0
    else:
        patience += 1
        if patience >= early_stop_patience:
            print("Early stopping triggered")
            break


Epoch 1/20, Train Loss: 1.3539885741013746, Train Acc: 0.7025
Validation Accuracy: 0.86
Epoch 2/20, Train Loss: 1.2009927401175866, Train Acc: 0.9075
Validation Accuracy: 0.89
Epoch 3/20, Train Loss: 0.972532318188594, Train Acc: 0.97
Validation Accuracy: 0.92
Epoch 4/20, Train Loss: 0.8394117997242854, Train Acc: 0.975
Validation Accuracy: 0.94
Epoch 5/20, Train Loss: 0.7929117862994854, Train Acc: 0.985
Validation Accuracy: 0.93
Epoch 6/20, Train Loss: 0.7766444820624131, Train Acc: 0.99
Validation Accuracy: 0.95
Epoch 7/20, Train Loss: 0.765714012659513, Train Acc: 0.995
Validation Accuracy: 0.94
Epoch 8/20, Train Loss: 0.7613680775348957, Train Acc: 0.9975
Validation Accuracy: 0.92
Epoch 9/20, Train Loss: 0.7552154614375188, Train Acc: 0.9975
Validation Accuracy: 0.93
Epoch 10/20, Train Loss: 0.7540987179829524, Train Acc: 0.9975
Validation Accuracy: 0.92
Epoch 11/20, Train Loss: 0.7510198171322162, Train Acc: 1.0
Validation Accuracy: 0.94
Epoch 12/20, Train Loss: 0.748965763128720

# 4. Try a custom prediction

In [33]:
# This example avoids any space-specific terminology to see if the model avoids
# biases towards specific jargon.
new_text = """
First-timer looking to get out of here.

Hi, I'm writing about my interest in travelling to the outer limits!

What kind of craft can I buy? What is easiest to access from this 3rd rock?

Let me know how to do that please.
"""
embedded = embed_fn(new_text)

inp = torch.Tensor([embedded])
with torch.no_grad():
    outputs = classifier(inp).numpy()[0]
    for idx, category in enumerate(df_test["Class Name"].cat.categories):
        print(f"{category}: {outputs[idx] * 100:0.2f}%")