## Fine-tuning

In [1]:
import torch

# Show the installed PyTorch build, then pick the first CUDA GPU when one is
# available and fall back to the CPU otherwise.
print(torch.__version__)
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

2.0.1+cu117
cuda:0


In [2]:
from transformers import BertForSequenceClassification, BertTokenizerFast, BertJapaneseTokenizer, Trainer, TrainingArguments
from datasets import load_dataset

# Prepare the BERT model and its tokenizer.
# Japanese checkpoint: Tohoku University BERT-base (v3). Model and tokenizer
# are loaded from the same checkpoint name so they cannot drift apart.
# English alternative: 'bert-base-uncased' with BertTokenizerFast and num_labels=2.
# NOTE(review): the dataset cell loads 'imdb', which is presumably English
# binary-sentiment data, while this is a Japanese checkpoint with a 3-way
# classification head - confirm that this pairing is intended.
checkpoint = 'cl-tohoku/bert-base-japanese-v3'

model = BertForSequenceClassification.from_pretrained(checkpoint, num_labels=3)
tokenizer = BertJapaneseTokenizer.from_pretrained(checkpoint)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at cl-tohoku/bert-base-japanese-v3 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Download the dataset; tokenization and tensor formatting follow below.
dataset = load_dataset(path='imdb')

def tokenize(batch, max_length=512):
    """Tokenize the ``'text'`` column of a batch to a fixed sequence length.

    Intended for ``Dataset.map(tokenize, batched=True)``, in which case
    ``batch['text']`` is expected to be a list of strings.

    Every example is truncated and padded to exactly ``max_length`` tokens.
    ``padding=True`` would only pad to the longest sequence inside each map
    batch, so rows coming from different map batches could end up with
    different lengths and fail to stack into one tensor at training time.

    Args:
        batch: Mapping that exposes the raw texts under the ``'text'`` key.
        max_length: Target length for every encoded sequence. The default of
            512 is assumed to match the BERT-base position limit - adjust it
            if a checkpoint with a different limit is used.

    Returns:
        Whatever the module-level ``tokenizer`` returns for the batch.
    """
    return tokenizer(
        batch['text'],
        padding='max_length',
        truncation=True,
        max_length=max_length,
    )

# Tokenize each split in batches.
train_dataset = dataset['train'].map(tokenize, batched=True)
test_dataset = dataset['test'].map(tokenize, batched=True)

# Expose only the listed columns, returned as torch tensors.
tensor_columns = ['input_ids', 'attention_mask', 'label']
for split in (train_dataset, test_dataset):
    split.set_format('torch', columns=tensor_columns)

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

In [6]:
# Training configuration.
training_args = TrainingArguments(
    # Output locations.
    output_dir='./results',
    logging_dir='./logs',
    # Schedule and regularisation.
    num_train_epochs=3,
    warmup_steps=500,
    weight_decay=0.01,
    # Batching: 4 samples per step, gradients accumulated over 2 steps.
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    per_device_eval_batch_size=16,
)

# Build the trainer, then start fine-tuning.
# NOTE(review): training_args sets no evaluation schedule, so eval_dataset is
# presumably only used by an explicit trainer.evaluate() call - confirm.
trainer = Trainer(model=model, args=training_args,
                  train_dataset=train_dataset, eval_dataset=test_dataset)

trainer.train()



  0%|          | 0/9375 [00:00<?, ?it/s]

KeyboardInterrupt: 