<a href="https://colab.research.google.com/github/honomax/modernbert-text-classification/blob/main/%E9%95%B7%E6%96%87%E3%83%86%E3%82%AD%E3%82%B9%E3%83%88%E5%88%86%E9%A1%9E%E3%83%97%E3%83%AD%E3%82%B0%E3%83%A9%E3%83%A0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Driveのマウント

In [None]:
# Mount Google Drive into the Colab filesystem; the project paths used later
# in this notebook live under this mount point.
from google.colab import drive

mount_point = '/content/drive'
drive.mount(mount_point)

# ライブラリのインポート

In [None]:
import os
import torch
import random
import collections
import numpy as np
import pandas as pd
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import StratifiedKFold, train_test_split
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix

# CSVデータの読み込み

In [None]:
# Locate the sample CSV inside the project folder on the mounted Drive.
base_dir = "/content/drive/MyDrive/modernbert-text-classification"
data_path = os.path.join(base_dir, "data", "sample.csv")

# Read the table and pull the two columns used downstream out as Python lists.
df = pd.read_csv(data_path)
texts, labels = (df[column].tolist() for column in ("text", "label"))

# Preview the first rows of the loaded data.
print(df.head())

# シードの固定

In [None]:
def set_seed(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)
    with ``seed``, exports it as ``PYTHONHASHSEED``, and switches cuDNN to
    deterministic mode with auto-tuning (benchmark) disabled.

    Args:
        seed: Integer seed applied to all generators. Defaults to 42.
    """
    # Exported for any subprocesses; the running interpreter's own hash seed
    # was already fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)  # any integer may be used as the seed

# データセットクラスの定義

In [None]:
class TextDataset(Dataset):
    """Dataset that tokenizes one text per access and pairs it with its label.

    Args:
        texts: Indexable collection of input texts.
        labels: Indexable collection of labels, aligned with ``texts``.
        tokenizer: Callable applied to a single text with ``max_length``,
            ``padding``, ``truncation`` and ``return_tensors`` keyword arguments.
        max_len: Token limit; every sample is padded or truncated to this length.
    """

    def __init__(self, texts, labels, tokenizer, max_len=2048):  # max_len: token limit
        self.texts, self.labels = texts, labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of texts held by the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize the text at ``idx`` and return its tensors plus a ``labels`` entry."""
        tokenized = self.tokenizer(
            self.texts[idx],
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # The tokenizer returns a leading batch dimension of size 1; drop it so
        # the DataLoader can stack samples into batches itself.
        sample = {}
        for name, tensor in tokenized.items():
            sample[name] = tensor.squeeze(0)

        sample["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample

# モデルとトークナイザの準備

In [None]:
# Checkpoint name shared by the tokenizer and the classifier.
model_name = "answerdotai/ModernBERT-base"

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Sequence-classification model configured with two output labels.
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
model.to(device)

# データの分割

In [None]:
# Hold out 20% of the samples as the test set, stratified by label and with a
# fixed random_state so the split is reproducible.
split_parts = train_test_split(
    texts,
    labels,
    test_size=0.2,
    stratify=labels,
    random_state=42,
)
texts_train_val, texts_test, labels_train_val, labels_test = split_parts

# Report the split sizes and the per-split label counts.
print(f"Train+Val size: {len(texts_train_val)}, Test size: {len(texts_test)}")
train_val_counts = collections.Counter(labels_train_val)
test_counts = collections.Counter(labels_test)
print("Train+Val label dist:", train_val_counts)
print("Test label dist:", test_counts)

# 層化K分割交差検証の準備

In [None]:
# Stratified, shuffled K-fold splitter with a fixed seed.
skf = StratifiedKFold(
    n_splits=3,  # number of folds
    shuffle=True,
    random_state=42,
)

# Accumulators updated by the training loop: per-fold validation accuracies,
# the best accuracy seen so far, and the weights of the model that achieved it.
val_accs, best_val_acc, best_model_state = [], 0, None

# モデルの学習

In [None]:
# Cross-validation: for every fold, fine-tune a fresh model for one pass over
# the training split, measure accuracy on the validation split, and remember
# the weights of the best-scoring fold.
for fold, (train_idx, val_idx) in enumerate(skf.split(texts_train_val, labels_train_val)):
    print(f"\n[Fold {fold+1}]")

    # Re-seed at the start of every fold so each fold starts from the same
    # random state (model head initialisation, batch shuffling).
    set_seed(42)

    # Start every fold from a freshly loaded pretrained model.
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
    model.to(device)
    optimizer = AdamW(model.parameters(), lr=2e-5)

    # One underlying dataset is enough; the two subsets only differ in which
    # indices they expose.
    fold_dataset = TextDataset(texts_train_val, labels_train_val, tokenizer)
    train_dataset = Subset(fold_dataset, train_idx)
    val_dataset = Subset(fold_dataset, val_idx)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

    # ---- training: a single pass over the fold's training split ----
    model.train()
    for batch in train_loader:
        inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"}
        labels_batch = batch["labels"].to(device)

        outputs = model(**inputs)
        loss = torch.nn.functional.cross_entropy(outputs.logits, labels_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # ---- validation ----
    model.eval()
    val_preds, val_labels = [], []

    with torch.no_grad():
        for batch in val_loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"}

            outputs = model(**inputs)
            preds = torch.argmax(outputs.logits, dim=1)

            val_preds.extend(preds.cpu().tolist())
            # Labels are only compared on the host, so they never need to
            # visit the device.
            val_labels.extend(batch["labels"].tolist())

    val_acc = accuracy_score(val_labels, val_preds)
    val_accs.append(val_acc)
    print(f"Fold {fold+1} Accuracy: {val_acc:.4f}")

    # Keep the best fold's weights. Accept the first fold unconditionally so a
    # state is recorded even if every fold scores 0.0.
    if best_model_state is None or val_acc > best_val_acc:
        best_val_acc = val_acc
        # state_dict() returns tensors that share storage with the live model.
        # Store detached CPU copies so the snapshot is independent of the model
        # and does not keep a second full model resident on the GPU.
        best_model_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }

# 交差検証結果とモデルの保存

In [None]:
# Fail early with a clear message if cross-validation has not produced results;
# otherwise the average would divide by zero or an empty checkpoint would be saved.
if not val_accs or best_model_state is None:
    raise RuntimeError(
        "No model state was recorded by cross-validation; "
        "run the training cell before saving."
    )

# Report the mean validation accuracy over all folds.
average_val_acc = sum(val_accs) / len(val_accs)
print(f"Average Validation Accuracy across folds: {average_val_acc:.4f}")

# Save the best fold's weights, creating the target directory on first use
# because torch.save does not create missing parent directories.
save_path = os.path.join(base_dir,"model","best_model.pth")
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(best_model_state, save_path)
print("Best model saved with val accuracy:", best_val_acc)

# テストデータ評価

In [None]:
# Restore the best cross-validation weights into the model.
load_path = os.path.join(base_dir,"model","best_model.pth")
# map_location lets a checkpoint saved from a GPU be loaded on a CPU-only
# runtime; weights_only=True restricts unpickling to tensors and plain
# containers instead of arbitrary Python objects.
best_state = torch.load(load_path, map_location=device, weights_only=True)
model.load_state_dict(best_state)
model.to(device)
model.eval()

# Held-out test split, evaluated in its original order.
test_dataset = TextDataset(texts_test, labels_test, tokenizer)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

test_preds, test_labels = [], []

with torch.no_grad():
    for batch in test_loader:
        inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"}

        outputs = model(**inputs)
        preds = torch.argmax(outputs.logits, dim=1)

        test_preds.extend(preds.cpu().tolist())
        # Labels are only compared on the host, so they stay on the CPU.
        test_labels.extend(batch["labels"].tolist())

# 評価指標の出力

In [None]:
# Overall accuracy on the test split.
test_acc = accuracy_score(test_labels, test_preds)

# Precision / recall / F1 using the 'weighted' average; the fourth return
# value is not used.
precision, recall, f1, _ = precision_recall_fscore_support(
    test_labels,
    test_preds,
    average='weighted',
)

conf_mat = confusion_matrix(test_labels, test_preds)

# Print the scalar metrics one per line, followed by the confusion matrix.
print(
    f"Test Accuracy: {test_acc:.4f}",
    f"Precision: {precision:.4f}",
    f"Recall: {recall:.4f}",
    f"F1 Score: {f1:.4f}",
    sep="\n",
)
print("Confusion Matrix:\n", conf_mat)