# Text Classification with the Gemini API

In [None]:
!pip install -q datasets==2.20.0 google-genai==1.18.0 scikit-learn==1.5.0 tqdm==4.66.5

## Setup

In [2]:
import os
import time

import tqdm
from datasets import load_dataset, Dataset
from google import genai
from google.genai import types
from sklearn.metrics import classification_report

In [3]:
dataset = load_dataset("rotten_tomatoes")
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 8530
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 1066
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 1066
    })
})

In [4]:
def evaluate_performance(y_true, y_pred):
    """Create and print the classification report"""
    performance = classification_report(
        y_true, y_pred,
        target_names=["Negative Review", "Positive Review"]
    )
    print(performance)

## Text Classification with the Gemini API

Get an API key by following the page below.

- [Get a Gemini API key](https://ai.google.dev/gemini-api/docs/api-key?hl=ja)

The API is rate-limited. For details, see the page below.

- [Rate limits](https://ai.google.dev/gemini-api/docs/rate-limits?hl=ja)

Then pass the API key to the client of the [Google Gen AI SDK](https://github.com/googleapis/python-genai).
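Rather than pasting the key directly into a notebook cell, you may prefer to read it from an environment variable. A minimal sketch, assuming the key is stored in a variable named `GEMINI_API_KEY`; both that name and the helper `load_api_key` are illustrative choices, not something this notebook sets up:

```python
import os


def load_api_key(env_var: str = "GEMINI_API_KEY") -> str:
    """Return the API key from the environment, failing loudly if it is missing.

    The default variable name GEMINI_API_KEY is an assumption for illustration.
    """
    key = os.environ.get(env_var, "").strip()
    if not key:
        raise RuntimeError(f"Environment variable {env_var} is not set")
    return key
```

The client could then be created with `genai.Client(api_key=load_api_key())`, which keeps the key out of the saved notebook.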

In [8]:
client = genai.Client(api_key="YOUR_API_KEY")  # replace with your own API key

In [6]:
def predict(prompt: str, document: str, model: str = "gemini-2.0-flash-lite") -> str:
    """Generate an output based on a prompt and an input document."""
    response = client.models.generate_content(
        model=model,
        contents=prompt.replace("[DOCUMENT]", document),
        config=types.GenerateContentConfig(
            max_output_tokens=1,  # allow at most one output token (just "0" or "1")
            temperature=0.0,
        ),
    )
    return response.text

In [7]:
prompt = """Predict whether the following document is a positive or negative movie review:

[DOCUMENT]

If it is positive return 1 and if it is negative return 0. Do not give any other answers.
"""

document = "unpretentious , charming , quirky , original"
predict(prompt, document)

'1'
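`predict` returns the raw response text, here the string `'1'`. The evaluation cell later converts each response with `int(pred)`, which raises an error if the model ever returns anything other than a bare digit, or if `response.text` is `None` (which may happen when no text is generated). A minimal defensive parser, sketched under the assumption that any response other than `'0'` or `'1'` should be treated as invalid; `parse_label` is a hypothetical helper, not part of the SDK:

```python
from typing import Optional


def parse_label(text: Optional[str]) -> Optional[int]:
    """Map a raw model response to 0 or 1; return None for anything else."""
    if text is None:
        return None  # e.g. the response contained no text
    cleaned = text.strip()
    if cleaned in ("0", "1"):
        return int(cleaned)
    return None  # unexpected answer such as "Positive"


raw_responses = ["1", "0", " 1\n", None, "Positive"]
print([parse_label(r) for r in raw_responses])  # [1, 0, 1, None, None]
```

With this, invalid responses can be counted and excluded before calling `evaluate_performance`, instead of stopping the whole evaluation with a `ValueError` or `TypeError`.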

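The rate limits for Gemini 2.0 Flash-Lite show that the free tier allows 30 requests per minute and 1,500 requests per day. The test split has 1,066 examples, so evaluating all of them would stay within the free tier's daily request limit. However, evaluating every example takes a long time, so the code below evaluates only 100 of them (50 per label).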
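The per-minute limit translates directly into a minimum spacing between requests. A quick calculation using the free-tier figures quoted above (these limits may change, so check the rate-limit page for current values):

```python
# Free-tier limits for Gemini 2.0 Flash-Lite as quoted above (may change over time)
requests_per_minute = 30
requests_per_day = 1500
num_test_examples = 1066

# Minimum gap between requests to stay under the per-minute limit
min_delay_seconds = 60 / requests_per_minute

# Lower bound on wall-clock time to evaluate the whole test split at that pace
min_minutes_full_test = num_test_examples / requests_per_minute

print(f"minimum delay: {min_delay_seconds:.1f} s")                   # 2.0 s
print(f"full test split: >= {min_minutes_full_test:.1f} min")        # 35.5 min
print(f"fits daily quota: {num_test_examples <= requests_per_day}")  # True
```

The 2-second minimum matches the `delay_in_seconds = 2` used in the evaluation loop. Because each API call also takes some time, the actual rate ends up a little below 30 requests per minute: the 100-example run took 3 min 45 s, about 2.26 s per item.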

In [41]:
max_predictions = 100
num_labels = len(set(dataset["test"]["label"]))
max_predictions_per_label = max_predictions // num_labels

# Convert to pandas
df = dataset["test"].to_pandas()

# Sample max_predictions_per_label (= 50) examples from each label
sampled_df = df.groupby("label", group_keys=False).sample(n=max_predictions_per_label, random_state=42)

# Convert back to a Hugging Face Dataset
sampled_dataset = Dataset.from_pandas(sampled_df, preserve_index=False)
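`groupby(...).sample(n=...)` draws the same number of rows from each label, i.e. a stratified sample that keeps the classes balanced even if the source data is not. A toy illustration with a made-up, deliberately imbalanced DataFrame (the data is synthetic, not taken from rotten_tomatoes):

```python
import pandas as pd

# Synthetic imbalanced data: 7 negatives (0) and 3 positives (1)
toy_df = pd.DataFrame({
    "text": [f"doc{i}" for i in range(10)],
    "label": [0] * 7 + [1] * 3,
})

# Draw 2 rows per label; random_state makes the draw reproducible
balanced = toy_df.groupby("label", group_keys=False).sample(n=2, random_state=42)
print(balanced["label"].value_counts().sort_index().tolist())  # [2, 2]

# Rows come out grouped by label; shuffle them if order matters downstream
shuffled = balanced.sample(frac=1, random_state=42).reset_index(drop=True)
```

Requesting an `n` larger than the smallest class raises a `ValueError` unless `replace=True` is passed, so `max_predictions_per_label` must not exceed the size of either class; that is not an issue for the test split here.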

In [44]:
delay_in_seconds = 2
predictions = []
for doc in tqdm.tqdm(sampled_dataset["text"]):
    predictions.append(predict(prompt, doc))
    time.sleep(delay_in_seconds)

100%|██████████| 100/100 [03:45<00:00,  2.26s/it]
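The fixed 2-second sleep keeps the loop near the per-minute limit, but a single transient error (for example, a rate-limit response) would still abort the whole loop and lose its progress. A minimal sketch of retrying with exponential backoff; `call_with_retries` is a hypothetical helper, and it retries on any exception, whereas real code would probably narrow this to the SDK's error types:

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def call_with_retries(
    fn: Callable[[], T],
    max_attempts: int = 5,
    base_delay: float = 2.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call fn(); after the k-th failure wait base_delay * 2**k seconds and retry."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(max_attempts):
        try:
            return fn()
        except Exception:
            if attempt == max_attempts - 1:
                raise  # out of attempts: surface the last error
            sleep(base_delay * 2 ** attempt)
    raise RuntimeError("unreachable")


# Demo with a fake call that fails twice before succeeding (no network needed)
attempts = {"n": 0}


def flaky() -> str:
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise RuntimeError("simulated rate-limit error")
    return "1"


waits: list = []
result = call_with_retries(flaky, sleep=waits.append)
print(result, waits)  # 1 [2.0, 4.0]
```

In the evaluation loop this would look like `predictions.append(call_with_retries(lambda: predict(prompt, doc)))`, keeping the `time.sleep(delay_in_seconds)` between successful calls.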


In [46]:
y_pred = [int(pred) for pred in predictions]
evaluate_performance(sampled_dataset["label"], y_pred)

                 precision    recall  f1-score   support

Negative Review       0.83      0.98      0.90        50
Positive Review       0.98      0.80      0.88        50

       accuracy                           0.89       100
      macro avg       0.90      0.89      0.89       100
   weighted avg       0.90      0.89      0.89       100
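The report can be checked by hand. Because each class has 50 examples, the recall values pin down the confusion counts: a negative recall of 0.98 means 49 of 50 negative reviews were predicted as negative, and a positive recall of 0.80 means 40 of 50 positive reviews were predicted as positive. Recomputing the metrics from those inferred counts:

```python
# Confusion counts inferred from the report above (support 50 per class)
tn, fp = 49, 1   # negatives predicted 0 / negatives wrongly predicted 1
tp, fn = 40, 10  # positives predicted 1 / positives wrongly predicted 0

neg_precision = tn / (tn + fn)  # share of "0" predictions that were truly negative
pos_precision = tp / (tp + fp)  # share of "1" predictions that were truly positive
neg_recall = tn / (tn + fp)
pos_recall = tp / (tp + fn)
accuracy = (tp + tn) / (tp + tn + fp + fn)


def f1(p: float, r: float) -> float:
    """Harmonic mean of precision and recall."""
    return 2 * p * r / (p + r)


print(f"Negative: P={neg_precision:.2f} R={neg_recall:.2f} F1={f1(neg_precision, neg_recall):.2f}")
print(f"Positive: P={pos_precision:.2f} R={pos_recall:.2f} F1={f1(pos_precision, pos_recall):.2f}")
print(f"Accuracy: {accuracy:.2f}")
# Negative: P=0.83 R=0.98 F1=0.90
# Positive: P=0.98 R=0.80 F1=0.88
# Accuracy: 0.89
```

The counts also show where the errors come from: 10 of the 11 misclassifications are positive reviews answered with `0`, so on this sample the model leans toward predicting a negative review.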

