# 感情分析

In [1]:
import torch

print(torch.__version__)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

2.0.1+cu117
cuda:0


In [2]:
from transformers import BertForSequenceClassification, BertTokenizerFast, BertJapaneseTokenizer, Trainer, TrainingArguments
from datasets import load_dataset

# 日本語版(東北大BERT-base)
model = BertForSequenceClassification.from_pretrained('cl-tohoku/bert-base-japanese-v3', num_labels=3)
tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-v3')

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cl-tohoku/bert-base-japanese-v3 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
from datasets import load_dataset

dataset = load_dataset('dataset_loader.py', name='sentiment_dataset')

In [4]:
def tokenize(batch):
    return tokenizer(batch['text'], padding='max_length', truncation=True)

train_dataset, test_dataset, eval_dataset = dataset['train'].map(tokenize, batched=True), dataset['test'].map(tokenize, batched=True), dataset['validation'].map(tokenize, batched=True)
train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
eval_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])

In [5]:
train_dataset, test_dataset, eval_dataset

(Dataset({
     features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
     num_rows: 649
 }),
 Dataset({
     features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
     num_rows: 324
 }),
 Dataset({
     features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
     num_rows: 325
 }))

In [60]:
# トレーニングの設定
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=2,  # accumulate gradients over 2 batches
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
)

# トレーナーの初期化とトレーニング開始
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset
)

# チェックポイントから学習を再開したいときと
# trainer.train(ignore_keys_for_eval=['last_hidden_state', 'hidden_states', 'attentions'],
            #   resume_from_checkpoint=True)

trainer.train()



  0%|          | 0/243 [00:00<?, ?it/s]

{'train_runtime': 65.1361, 'train_samples_per_second': 29.891, 'train_steps_per_second': 3.731, 'train_loss': 0.7231104344497492, 'epoch': 2.98}


TrainOutput(global_step=243, training_loss=0.7231104344497492, metrics={'train_runtime': 65.1361, 'train_samples_per_second': 29.891, 'train_steps_per_second': 3.731, 'train_loss': 0.7231104344497492, 'epoch': 2.98})

In [45]:
trainer.save_state()
trainer.save_model()

In [50]:
# 保存したモデルを読み込む
# model_path = 'results/'
# model = BertForSequenceClassification.from_pretrained(model_path)
# tokenizer = BertJapaneseTokenizer.from_pretrained('cl-tohoku/bert-base-japanese-v3')

# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=train_dataset,
#     eval_dataset=test_dataset
# )

In [61]:
predictions = trainer.predict(test_dataset)
predictions

  0%|          | 0/21 [00:00<?, ?it/s]

PredictionOutput(predictions=array([[-1.82261038e+00,  1.89144635e+00, -7.11811900e-01],
       [ 1.44777104e-01, -7.88384140e-01,  4.27118272e-01],
       [-2.26370430e+00,  3.52510715e+00, -1.42627966e+00],
       [-1.85360134e+00,  2.84092736e+00, -1.02738726e+00],
       [-2.10023594e+00,  1.60221970e+00,  3.03018242e-01],
       [ 5.85632861e-01, -1.43055356e+00,  4.91691381e-01],
       [-9.69299853e-01,  2.26190853e+00, -1.26744282e+00],
       [ 2.03745937e+00, -2.26675677e+00,  4.76884633e-01],
       [-1.54400408e+00,  1.88790500e+00, -3.71967182e-02],
       [-2.30509591e+00,  3.01106286e+00, -1.34506440e+00],
       [ 1.41407359e+00, -1.49115944e+00, -4.00320381e-01],
       [-1.66096985e+00,  2.25473571e+00, -6.53830171e-01],
       [-2.03043103e+00,  3.29224420e+00, -1.39032519e+00],
       [-8.92628312e-01, -1.43843031e+00,  2.57924032e+00],
       [ 2.37966239e-01, -1.02401578e+00,  9.38977301e-01],
       [-2.35551310e+00,  3.05863667e+00, -8.95049751e-01],
       [ 1.