# BERT Fine-Tuning
PyTorch Lightning を用いて BERT モデルの Fine-Tuning を行います。本ノートブックでは簡略化のため、学習には学習データのみを使用し、検証データによる評価は行いません。

## 必要なライブラリのインポート

In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split
from src import datasets, models

import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl

import numpy as np

# Use one shared seed for Lightning, PyTorch and NumPy so runs are repeatable.
SEED = 1234

pl.seed_everything(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Detect whether to run on CPU or GPU: use the first CUDA device when one is
# available, otherwise fall back to the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(device)

## データ準備
加工済みの Livedoor ニュースのデータを読み込み、学習データと検証データに分割します。

In [None]:
# Read the tab-separated file and drop every row that has a missing value.
df = pd.read_csv("../data/processed/livedoor.tsv", sep="\t").dropna()
# Last expression of the cell: shows the first rows in the notebook output.
df.head()

In [None]:
# Hold out 20% of the rows, stratifying on the 'label_index' column.
X_train, X_test = train_test_split(
    df,
    test_size=0.2,
    stratify=df["label_index"],
)
# Renumber both splits from 0, discarding the original row index.
X_train = X_train.reset_index(drop=True)
X_test = X_test.reset_index(drop=True)

In [None]:
# Write each split to its own tab-separated file, without the index column.
X_train.to_csv("../data/processed/livedoor-train.tsv", sep='\t', index=False)
# The test file must contain the held-out split (X_test), not a second copy
# of the training split.
X_test.to_csv("../data/processed/livedoor-test.tsv", sep='\t', index=False)

学習データ用の PyTorch の DataLoader を定義します。

In [None]:
# Wrap the training split in the project's dataset class and serve it in
# shuffled mini-batches of 32 samples.
train_dataset = datasets.LivedoorDataset(X_train)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
)

## Fine Tuning
BERT の Fine Tuning を実施します。

In [None]:
bert_model = models.LitBert()

# Turn off gradients for every parameter under bert_model.bert.bert so they
# are not updated during training (presumably the pretrained BERT encoder —
# confirm against src/models).
for param in bert_model.bert.bert.parameters():
    param.requires_grad = False

# Move the model to the device selected earlier.
bert_model.to(device)

# Both configurations share the output directory; only the epoch budget and
# the GPU flag differ (1 epoch on CPU, 30 epochs on a single GPU).
trainer_kwargs = {"default_root_dir": 'pl-model'}
if device.type == "cpu":
    trainer_kwargs["max_epochs"] = 1
else:
    trainer_kwargs["gpus"] = 1
    trainer_kwargs["max_epochs"] = 30
trainer = pl.Trainer(**trainer_kwargs)

In [None]:
%%time
# Run training on the training DataLoader only; no validation loader is
# passed. The %%time cell magic above reports how long the cell takes.
trainer.fit(bert_model, train_loader)

In [None]:
# Save the trainer's current state to a checkpoint file.
# NOTE(review): the filename says "epoch01", but the GPU configuration trains
# for 30 epochs — confirm the intended name.
trainer.save_checkpoint("../models/bert-livedoor-epoch01.ckpt")