In [None]:
import pandas as pd
import numpy as np
from google import genai
from google.genai import types
from google.api_core.exceptions import ResourceExhausted, ServiceUnavailable, DeadlineExceeded
from typing import Sequence
from typing import Tuple
import time
import os
from dotenv import load_dotenv
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, classification_report
from xgboost import XGBClassifier

In [32]:
# Load variables from a local .env file into the process environment
# (presumably where GEMINI_API_KEY is defined — confirm).
load_dotenv()

True

In [34]:
# Read the IMDB reviews CSV; later cells use its "review" and "sentiment" columns.
data = pd.read_csv("./data/IMDB Dataset.csv")

In [35]:
# Build the Gemini client from the GEMINI_API_KEY environment variable;
# os.getenv returns None when that variable is unset.
client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))

In [40]:
def make_batches(
    texts: Sequence[str],
    token_cap: int = 2400,
) -> list[list[str]]:
    """Greedily pack texts, in order, into batches under an estimated token budget.

    Token usage is approximated as one token per four characters, with a
    minimum of one token per text (falsy entries such as ``None`` or ``""``
    count as one token).

    Args:
        texts: Texts to group; their order is preserved across batches.
        token_cap: Maximum estimated tokens allowed in a single batch.

    Returns:
        A list of batches, each a list of the original text items. Empty
        input yields an empty list.

    Raises:
        ValueError: If a single text's estimate alone exceeds ``token_cap``.
    """
    grouped: list[list[str]] = []
    pending: list[str] = []
    budget_used = 0

    for item in texts:
        estimate = max(1, len(item or "") // 4)
        if estimate > token_cap:
            raise ValueError(
                f"Single text exceeds token cap ({estimate} > {token_cap})."
            )

        # Close the open batch first when this item would push it over budget.
        would_overflow = budget_used + estimate > token_cap
        if pending and would_overflow:
            grouped.append(pending)
            pending, budget_used = [], 0

        pending.append(item)
        budget_used += estimate

    # Flush whatever is left in the final, partially filled batch.
    if pending:
        grouped.append(pending)

    return grouped


def request_embeddings(
    client: genai.Client,
    texts: Sequence[str],
    max_retries: int = 3,
    base_delay: float = 1.0,
    backoff_factor: float = 2.0,
) -> np.ndarray:
    """Embed one batch of texts with Gemini, retrying transient API failures.

    Args:
        client: Gemini client used to issue the ``embed_content`` request.
        texts: Batch of texts sent in a single request.
        max_retries: Number of retries allowed after the first failed attempt.
        base_delay: Seconds to sleep before the first retry.
        backoff_factor: Multiplier applied to the delay after each retry.

    Returns:
        Array built from the ``values`` of each returned embedding, one row
        per embedding.

    Raises:
        RuntimeError: If a transient error persists after ``max_retries``
            retries, or if any other error occurs (including an empty
            response). The underlying exception is chained as the cause.
    """
    # Local import: the SDK's own exception hierarchy is needed to recognise
    # transient failures, and this keeps the block self-contained.
    from google.genai import errors as genai_errors

    # The google-genai SDK reports HTTP failures as errors.APIError with the
    # status in `.code` rather than as google.api_core exceptions, so the
    # api_core classes alone never trigger a retry. These codes mirror them:
    # 429 ResourceExhausted, 503 ServiceUnavailable, 504 DeadlineExceeded.
    retryable_codes = {429, 503, 504}
    legacy_transient = (ServiceUnavailable, DeadlineExceeded, ResourceExhausted)

    attempt = 0
    while True:
        try:
            result = client.models.embed_content(
                model="gemini-embedding-001",
                contents=texts,
                config=types.EmbedContentConfig(task_type="CLASSIFICATION")
            ).embeddings
            if not result:
                raise ValueError("Received empty embeddings from the API.")
            return np.array([embedding.values for embedding in result])

        except Exception as exc:
            is_transient = isinstance(exc, legacy_transient) or (
                isinstance(exc, genai_errors.APIError)
                and getattr(exc, "code", None) in retryable_codes
            )
            if not is_transient:
                # Non-retryable: surface immediately, keeping the original cause.
                raise RuntimeError("Embedding request failed") from exc

            attempt += 1
            if attempt > max_retries:
                msg = (
                    f"Request failed after {max_retries} retries "
                    f"(batch_size={len(texts)})."
                )
                raise RuntimeError(msg) from exc

            # Exponential backoff: base_delay, base_delay * factor, ...
            retry_delay = base_delay * (backoff_factor ** (attempt - 1))
            time.sleep(retry_delay)


def create_embeddings(
    client: genai.Client,
    data: pd.DataFrame,
    token_cap: int = 2400,
    text_column: str = "review",
) -> pd.DataFrame:
    """Embed a text column of ``data`` and return a copy with an "embedding" column.

    Texts are cast to ``str``, truncated to ``token_cap * 4`` characters,
    grouped with ``make_batches`` and embedded one batch per request via
    ``request_embeddings``.

    Args:
        client: Gemini client forwarded to ``request_embeddings``.
        data: DataFrame holding the texts to embed; it is not modified.
        token_cap: Estimated-token budget per batch; also bounds each text's
            length through character truncation.
        text_column: Name of the column containing the texts.

    Returns:
        A copy of ``data`` with an added "embedding" column holding one
        vector per row. For an empty ``data`` the column is empty and no
        request is made.

    Raises:
        ValueError: If a batch, or the overall result, does not contain
            exactly one embedding per input text.
    """
    # Truncate by characters so no single text can exceed the batching
    # estimate (len // 4) for token_cap.
    max_chars = token_cap * 4
    texts = [text[:max_chars] for text in data[text_column].astype(str).tolist()]

    result = data.copy()
    if not texts:
        # np.vstack raises on an empty list; nothing to embed, so return early.
        result["embedding"] = pd.Series(index=result.index, dtype=object)
        return result

    batches = make_batches(texts, token_cap=token_cap)

    all_embeddings: list[np.ndarray] = []
    for batch in batches:
        batch_embeddings = request_embeddings(client, batch)
        if batch_embeddings.shape[0] != len(batch):
            msg = (
                "Embedding batch size mismatch: "
                f"expected {len(batch)}, got {batch_embeddings.shape[0]}"
            )
            raise ValueError(msg)
        all_embeddings.append(batch_embeddings)

    embeddings_array = np.vstack(all_embeddings)
    if embeddings_array.shape[0] != len(data):
        msg = (
            "Total embeddings count does not match dataframe length: "
            f"{embeddings_array.shape[0]} vs {len(data)}"
        )
        raise ValueError(msg)

    # list(...) splits the 2-D array into one 1-D vector per row.
    result["embedding"] = list(embeddings_array)
    return result

In [41]:
# Embed every review (one API request per batch) and preview the result,
# which is a copy of `data` with an added "embedding" column.
embeddings_df = create_embeddings(client, data)
embeddings_df.head()

Unnamed: 0,review,sentiment,embedding
0,One of the other reviewers has mentioned that ...,positive,"[0.008090536, -0.032237276, 0.023127945, -0.05..."
1,A wonderful little production. <br /><br />The...,positive,"[-0.020639218, -0.0028484424, 0.02097729, -0.0..."
2,I thought this was a wonderful way to spend ti...,positive,"[-0.003580236, -0.014770088, 0.00019785065, -0..."
3,Basically there's a family where a little boy ...,negative,"[0.013992016, -0.013159083, -0.0051524853, -0...."
4,"Petter Mattei's ""Love in the Time of Money"" is...",positive,"[-0.0059860116, -0.024320487, 0.0072585535, -0..."


In [None]:
def train_xgb_on_embeddings(
    df: pd.DataFrame,
    embedding_column: str = "embedding",
    label_column: str = "sentiment",
    test_size: float = 0.2,
    random_state: int = 42,
) -> Tuple[XGBClassifier, float]:
    """Train an XGBoost classifier on precomputed embeddings and return model and test accuracy.

    Labels are integer-encoded with ``LabelEncoder``, the rows are split into
    stratified train/test sets, and the test accuracy and a classification
    report are printed.

    Args:
        df: DataFrame containing embeddings and labels.
        embedding_column: Name of the column with embedding vectors.
        label_column: Name of the column with class labels.
        test_size: Proportion of data to use for the test split.
        random_state: Seed passed to ``train_test_split`` so the split is
            reproducible.

    Returns:
        A tuple of (fitted XGBClassifier, test accuracy).
    """
    X = np.array([np.array(emb) for emb in df[embedding_column]])
    y_raw = df[label_column].to_numpy()

    encoder = LabelEncoder()
    y = encoder.fit_transform(y_raw)

    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        test_size=test_size,
        random_state=random_state,
        stratify=y,
    )

    model = XGBClassifier(
        n_estimators=300,
        max_depth=6,
        learning_rate=0.1,
        subsample=0.9,
        colsample_bytree=0.9,
        objective="binary:logistic",
        eval_metric="logloss",
        n_jobs=-1,
        tree_method="hist",
    )

    model.fit(X_train, y_train)

    y_pred = model.predict(X_test)
    acc = accuracy_score(y_test, y_pred)

    # classification_report needs string display names; encoder.classes_ holds
    # the raw label values, which are not strings for a numeric label column.
    class_names = [str(label) for label in encoder.classes_]

    print(f"Test accuracy: {acc:.4f}")
    print("\nClassification report:")
    print(classification_report(y_test, y_pred, target_names=class_names))

    return model, float(acc)


In [51]:
# Train and evaluate the classifier on the embedded reviews; the function
# prints the test accuracy and classification report shown below.
model, acc = train_xgb_on_embeddings(embeddings_df)


Test accuracy: 0.9657

Classification report:
              precision    recall  f1-score   support

    negative       0.97      0.96      0.97      5000
    positive       0.96      0.97      0.97      5000

    accuracy                           0.97     10000
   macro avg       0.97      0.97      0.97     10000
weighted avg       0.97      0.97      0.97     10000

