In [1]:
import pandas as pd
import torch
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    Trainer,
    TrainingArguments,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# データの読み込み
data = pd.read_csv("data/data_long_texts.tsv", sep="\t")

# 入力データの表示
print(data.head())

     ラベル 満足度                                                 文章        会員ID
0  移動・交通  満足                             鉄道路線が充実している。どこにいくにも便利。  1226930387
1   自然景観  満足                            海も近く県立公園もあり、リラックスできるので。  1233151099
2  移動・交通  満足  今居住している武蔵小杉はJR横須賀線、南武線、東急東横線、目黒線と選択の余地がたくさんあり、...  1229140632
3  買物・飲食  満足                          スーパーマーケットもたくさんあり、選択の余地が広い  1229140632
4  遊び・娯楽  不満                       かなり発展した街なのに映画館がない。ライブハウスがない。  1229140632


In [3]:
# モデルとトークナイザーのロード
model_name = "cl-tohoku/bert-base-japanese-v3"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(
    model_name, num_labels=2
)  # 2クラス分類（満足、不満）

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BertJapaneseTokenizer'. 
The class this function is called from is 'BertTokenizer'.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cl-tohoku/bert-base-japanese-v3 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
class CustomDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

In [7]:
def preprocess_data(data: pd.DataFrame):
    encodings = tokenizer(
        data["文章"].tolist(), truncation=True, padding=True, max_length=128
    )
    labels = [1 if label == "満足" else 0 for label in data["満足度"]]
    return encodings, labels


# 評価指標の計算関数
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc}

In [6]:
encodings, labels = preprocess_data(data)
dataset = CustomDataset(encodings, labels)

In [8]:
# モデルのトレーニング
train_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
)

trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=dataset,
    eval_dataset=dataset,
    compute_metrics=compute_metrics,
)

In [9]:
trainer.train()

Step,Training Loss


TrainOutput(global_step=276, training_loss=0.5009553881659023, metrics={'train_runtime': 27.775, 'train_samples_per_second': 78.956, 'train_steps_per_second': 9.937, 'total_flos': 144250636101120.0, 'train_loss': 0.5009553881659023, 'epoch': 3.0})

In [10]:
# 評価と正解率の計算
result = trainer.evaluate()
print(result)

{'eval_loss': 0.3837265372276306, 'eval_accuracy': 0.8727770177838577, 'eval_runtime': 2.3022, 'eval_samples_per_second': 317.521, 'eval_steps_per_second': 39.962, 'epoch': 3.0}


In [13]:
# デバイスの設定
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [14]:
def classify_text(text: str):
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=128
    ).to(device)
    outputs = model(**inputs)
    preds = torch.argmax(outputs.logits, dim=1)
    return "満足" if preds.item() == 1 else "不満"

In [18]:
# テスト文章の分類
test_text = "子供が喜んで遊べる公園が近くにはない"
result = classify_text(test_text)
print(f"分類結果: {result}")

分類結果: 不満
