BERT（東北大学の日本語BERT-base: cl-tohoku/bert-base-japanese-whole-word-masking）のファインチューニング

In [None]:
import random
import glob
from tqdm import tqdm

import torch
import pandas as pd
import numpy as np 
from torch.utils.data import TensorDataset, random_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from mlflow import log_metric, log_param, log_artifact
from transformers import BertJapaneseTokenizer, BertForSequenceClassification
import pytorch_lightning as pl

In [None]:
# Load the training data: a headerless TSV with columns (label, t1, t2),
# where t1/t2 are the two sentences of a pair.
df = pd.read_csv("ファイル名", delimiter='\t', names=["label", "t1", "t2"])

# Map NLI label strings to the class ids the classifier is trained on.
mapping = {
    'neutral': 2,
    'contradiction': 0,
    'entailment': 1,
}
df.label = df.label.map(mapping)

# Series.map turns any label not present in `mapping` into NaN, which would
# silently make the whole label column float and break training later on.
# Fail here with a clear message instead.
unmapped = df.label.isna()
if unmapped.any():
    raise ValueError(
        f"{int(unmapped.sum())} training row(s) have a label outside {sorted(mapping)}"
    )

t1 = df.t1.values
t2 = df.t2.values
labels = df.label.values

In [None]:
# Step 1: split text into sub-words and convert them to ids with the BERT tokenizer.
# The tokenizer matches the pretrained model (Tohoku Japanese BERT, whole-word masking).
tokenizer = BertJapaneseTokenizer.from_pretrained(
    'cl-tohoku/bert-base-japanese-whole-word-masking'
)

In [None]:
# Find the longest tokenized sentence pair so every example can be padded
# to one common length.
pair_lengths = [
    len(tokenizer.tokenize(sent1)) + len(tokenizer.tokenize(sent2))
    for sent1, sent2 in zip(t1, t2)
]

# A pair is encoded as [CLS] t1 [SEP] t2 [SEP], i.e. 3 special tokens on top
# of the sub-word tokens. bert-base has 512 position embeddings, so padding
# beyond that would crash the model. Cap the length; longer pairs are cut by
# truncation=True when encoding.
BERT_MAX_POSITIONS = 512
max_length = min(max(pair_lengths) + 3, BERT_MAX_POSITIONS)

# Show the resulting maximum length.
print('最大単語数: ', max_length)

In [None]:
# Encode every sentence pair into fixed-length tensors, one dict per example.
dataset_for_loader = []

for x, y, label in zip(t1, t2, labels):
    # Passing the two sentences as separate arguments lets the tokenizer
    # build the pair encoding, special tokens included.
    encoding = tokenizer(
        x,
        y,
        max_length=max_length,
        padding='max_length',
        truncation=True,
    )
    encoding['labels'] = label  # target class id, used by the model to compute the loss
    dataset_for_loader.append({k: torch.tensor(v) for k, v in encoding.items()})

In [None]:
# Use 80% of the examples for training and the remaining 20% for validation.
train_size = int(0.8 * len(dataset_for_loader))
val_size = len(dataset_for_loader) - train_size

# Split with a seeded generator so the train/validation partition (and hence
# the val_loss used for checkpointing) is reproducible across runs.
SPLIT_SEED = 42
train_dataset, val_dataset = random_split(
    dataset_for_loader,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
# Build the DataLoaders: training batches are reshuffled every epoch,
# while validation batches keep the dataset order.
BATCH_SIZE = 16
dataloader_train = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
dataloader_val = DataLoader(val_dataset, batch_size=BATCH_SIZE)

In [None]:
MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'

# Load the pretrained model with a 3-class classification head onto GPU 2
# (Module.cuda returns the module itself, so the call can be chained).
# NOTE(review): this instance is never trained or used. The Lightning module
# loads its own copy, and `bert_sc` is reassigned from ./model_transformers
# before inference. It only occupies memory on cuda:2; consider removing it.
bert_sc = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3).cuda(2)


In [None]:
class BertForSequenceClassification_pl(pl.LightningModule):
    """LightningModule wrapping a Transformers ``BertForSequenceClassification``.

    Args:
        model_name: Model id or local path passed to ``from_pretrained``.
        num_labels: Number of output classes of the classification head.
        lr: Learning rate for AdamW.

    The constructor arguments are stored with ``save_hyperparameters`` so the
    module can be rebuilt with ``load_from_checkpoint(path)`` alone.

    Batches are expected to be dicts of tensors that the wrapped model accepts
    as keyword arguments, including ``labels``.
    """

    def __init__(self, model_name, num_labels, lr):
        super().__init__()
        self.save_hyperparameters()

        # Load pretrained BERT with a classification head of `num_labels` outputs.
        self.bert_sc = BertForSequenceClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
        )

    def training_step(self, batch, batch_idx):
        """Run the model on the batch and return the loss it computes from ``labels``."""
        output = self.bert_sc(**batch)
        loss = output.loss
        # Pass batch_size explicitly: Lightning cannot reliably infer it from a dict batch.
        self.log('train_loss', loss, batch_size=batch['labels'].size(0))
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss (monitored by the checkpoint and early-stopping callbacks)."""
        output = self.bert_sc(**batch)
        val_loss = output.loss
        # sync_dist reduces the metric across devices, so with multi-GPU training
        # every rank sees the same val_loss when choosing checkpoints / stopping.
        self.log(
            'val_loss',
            val_loss,
            batch_size=batch['labels'].size(0),
            sync_dist=True,
        )

    def test_step(self, batch, batch_idx):
        """Log the accuracy of argmax predictions against the batch's ``labels``."""
        labels = batch['labels']
        # Feed the model everything except the targets, without mutating `batch`.
        inputs = {k: v for k, v in batch.items() if k != 'labels'}
        output = self.bert_sc(**inputs)
        labels_predicted = output.logits.argmax(-1)
        num_correct = (labels_predicted == labels).sum().item()
        accuracy = num_correct / labels.size(0)
        self.log('accuracy', accuracy, batch_size=labels.size(0), sync_dist=True)

    def configure_optimizers(self):
        """Optimize all parameters with AdamW at the configured learning rate."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)
        

In [None]:
# Keep only the single best checkpoint (lowest val_loss), weights only, in model/.
checkpoint = pl.callbacks.ModelCheckpoint(
    dirpath='model/',
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_weights_only=True,
)

# Stop after 10 validation checks without a val_loss improvement.
# NOTE(review): patience (10) exceeds max_epochs (5), so with one validation
# run per epoch this callback can never fire. Lower it if early stopping is wanted.
early_stopping = pl.callbacks.EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=10,
)

# Train on 2 GPUs for at most 5 epochs.
# NOTE(review): `devices=2` presumably selects the first two visible GPUs,
# while the inference code uses cuda:2. Confirm which GPUs are intended.
trainer = pl.Trainer(
    accelerator='gpu',
    devices=2,
    max_epochs=5,
    callbacks=[checkpoint, early_stopping],
)

In [None]:
# Build the Lightning model (3 classes, learning rate 2e-5) and fine-tune it.
model = BertForSequenceClassification_pl(MODEL_NAME, num_labels=3, lr=2e-5)
trainer.fit(
    model,
    train_dataloaders=dataloader_train,
    val_dataloaders=dataloader_val,
)

In [None]:
# 6-17
# Report the best checkpoint kept by the ModelCheckpoint callback.
best_model_path = checkpoint.best_model_path  # file of the best model
best_model_score = checkpoint.best_model_score  # its monitored val_loss
print('ベストモデルのファイル: ', best_model_path)
print('ベストモデルの検証データに対する損失: ', best_model_score)

In [None]:
# 6-20
# Rebuild the LightningModule from the best checkpoint. No constructor
# arguments are passed; they come from the hyperparameters saved in it.
model = BertForSequenceClassification_pl.load_from_checkpoint(best_model_path)

# Export the wrapped Transformers model to ./model_transformers so it can be
# reloaded with from_pretrained without Lightning.
model.bert_sc.save_pretrained('./model_transformers')

In [None]:
# 6-21
# Reload the exported classifier and place it on GPU 2 for inference.
bert_sc = BertForSequenceClassification.from_pretrained('./model_transformers')
bert_sc = bert_sc.cuda(2)

テスト

In [None]:
# Load the test data: a headerless TSV with columns (label, t1, t2).
df = pd.read_csv("ファイル名", delimiter='\t', names=["label", "t1", "t2"])

# Same label -> class id mapping as used for training.
mapping = {
    'neutral': 2,
    'contradiction': 0,
    'entailment': 1,
}
df.label = df.label.map(mapping)

# Labels missing from `mapping` become NaN. A NaN gold label never equals a
# prediction, so it would silently distort every metric below; fail early instead.
unmapped = df.label.isna()
if unmapped.any():
    raise ValueError(
        f"{int(unmapped.sum())} test row(s) have a label outside {sorted(mapping)}"
    )

t1_test = df.t1.values
t2_test = df.t2.values
labels_test = df.label.values

In [None]:
# Predict each test pair one at a time, keeping predictions, gold labels,
# and the indices of misclassified examples.
predicted = []       # predicted class id per test example
correct_labels = []  # gold class id per test example
wrong = []           # indices of test examples the model got wrong

for i, (x, y, label) in enumerate(zip(t1_test, t2_test, labels_test)):
    correct_labels.append(label)

    encoding = tokenizer(
        x,
        y,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    # Send the inputs to whichever device the classifier was moved to.
    encoding = {k: v.to(bert_sc.device) for k, v in encoding.items()}

    with torch.no_grad():
        # Call the module rather than .forward() so registered hooks run.
        output = bert_sc(**encoding)
    labels_predicted = output.logits[0].argmax(-1).item()
    predicted.append(labels_predicted)

    # A prediction that differs from the gold label is an error.
    if labels_predicted != label:
        wrong.append(i)

In [None]:
# Accuracy over all classes, plus precision / recall / F1 that treat class
# id 1 ('entailment' in the label mapping) as the positive class.
POSITIVE_LABEL = 1


def _safe_div(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is 0."""
    return numerator / denominator if denominator else 0.0


test_num = len(predicted)
num_correct = tp = fp = fn = tn = 0

for pred, gold in zip(predicted, correct_labels):
    if pred == gold:
        num_correct += 1

    if pred == POSITIVE_LABEL:
        if gold == POSITIVE_LABEL:
            tp += 1  # predicted positive, actually positive
        else:
            fp += 1  # predicted positive, actually negative
    elif gold == POSITIVE_LABEL:
        fn += 1      # predicted negative, actually positive
    else:
        tn += 1      # predicted negative, actually negative

# Guard every ratio: a class may never be predicted or present, or the test set may be empty.
accuracy = _safe_div(num_correct, test_num)
recall = _safe_div(tp, tp + fn)
precision = _safe_div(tp, tp + fp)
f_value = _safe_div(2 * recall * precision, precision + recall)
print("accuracy: " + str(accuracy))
print("recall: " + str(recall))
print("precision: " + str(precision))
print("f_value: " + str(f_value))