# Fine-tuning Kyoto University BERT
This example fine-tunes [Kyoto University BERT](https://nlp.ist.i.kyoto-u.ac.jp/?ku_bert_japanese) for named entity recognition, using the [ner-wikipedia-dataset published by Stockmark Inc.](https://github.com/stockmarkteam/ner-wikipedia-dataset)  
It uses PyTorch + transformers (not TensorFlow).

## Preparation
Set up everything the training needs.
The main requirements are:
- [Kyoto University BERT model](https://nlp.ist.i.kyoto-u.ac.jp/?ku_bert_japanese)
- [Juman++](https://nlp.ist.i.kyoto-u.ac.jp/?JUMAN%2B%2B)
- [pyknp](https://github.com/ku-nlp/pyknp)

```
# BERT-LARGE (24 layers, hidden size 1024, 16 heads), BPE vocabulary, whole word masking, transformers format
!wget "http://nlp.ist.i.kyoto-u.ac.jp/DLcounter/lime.cgi?down=http://lotus.kuee.kyoto-u.ac.jp/nl-resource/JapaneseBertPretrainedModel/Japanese_L-24_H-1024_A-16_E-30_BPE_WWM_transformers.zip&name=Japanese_L-24_H-1024_A-16_E-30_BPE_WWM_transformers.zip" -O bert.zip
!unzip bert.zip
```

```
!mkdir kyoto
!mv Japanese_L-24_H-1024_A-16_E-30_BPE_WWM_transformers kyoto/bert
```

Note: building Juman++ may take a long time, or the machine may freeze, unless you use a reasonably large instance.

```
!wget "https://github.com/ku-nlp/jumanpp/releases/download/v2.0.0-rc2/jumanpp-2.0.0-rc2.tar.xz"
!tar xvf jumanpp-2.0.0-rc2.tar.xz
!apt-get update -y
!apt-get install -y cmake gcc build-essential
%cd jumanpp-2.0.0-rc2
!mkdir bld
%cd bld
!cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/local
!make install -j
%cd ../..
```

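```
!pip install --upgrade pip
# scikit-learn is the PyPI package that provides sklearn (the "sklearn" package name is deprecated)
!pip install "transformers[ja]" numpy noyaki scikit-learn pyknp
!pip install -U jupyter ipywidgets
```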

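This installs PyTorch 1.9.0 built for CUDA 11.1; pick the build that matches your CUDA environment. Recent transformers releases may require a newer PyTorch, so if imports fail, pin versions of the two that work together.

```
!pip3 install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
```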

```
!mkdir outputs
!mkdir ckpt
```

## Sanity check
Check that Juman++ works.

```
!echo "こんにちは" | jumanpp
```

## Downloading the training data
As mentioned at the beginning, this example uses the [ner-wikipedia-dataset published by Stockmark Inc.](https://github.com/stockmarkteam/ner-wikipedia-dataset)

```
!wget "https://github.com/stockmarkteam/ner-wikipedia-dataset/raw/main/ner.json"
```

## Inspecting the training data
Let's take a quick look at how the downloaded `ner.json` is structured.

```
!head -15 ner.json
```
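`head` only shows the first few lines. For an overview of the whole file, a short stdlib sketch can count the entity types and print the text each span points at. It assumes the structure the training code below relies on (a list of records with `text` and `entities`, each entity carrying `span` and `type`) and that spans are end-exclusive character offsets; `summarize_entities` is a hypothetical helper, not part of the dataset's tooling.

```python
import json
import os
from collections import Counter


def summarize_entities(records):
    """Count entity types and collect (surface string, type) pairs.

    Assumes each record looks like {"text": ..., "entities": [{"span": [start, end], "type": ...}]}
    with end-exclusive character offsets into "text" (an assumption, not checked here).
    """
    type_counts = Counter()
    examples = []
    for record in records:
        for entity in record["entities"]:
            start, end = entity["span"]
            type_counts[entity["type"]] += 1
            examples.append((record["text"][start:end], entity["type"]))
    return type_counts, examples


# Only runs when ner.json has been downloaded into the current directory
if os.path.exists("ner.json"):
    with open("ner.json", encoding="utf-8") as f:
        records = json.load(f)
    counts, examples = summarize_entities(records)
    print(len(records), "sentences")
    print(counts.most_common())
    print(examples[:5])
```

If the printed surface strings look cut off by one character, the end-exclusive assumption is wrong for this file.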

## Training
Now let's actually train the model.

```python
from transformers import (
    BertForTokenClassification, BertTokenizer, BertConfig,
    TrainingArguments, Trainer,
    EarlyStoppingCallback
)
from pyknp import Juman
from sklearn.model_selection import train_test_split

import torch
import noyaki
import os
import numpy as np
import json
```

Define the helper functions.

```python
def load_from_json(path: str) -> list:
    """Tokenize each sentence with Juman++ and convert its entity spans to per-token labels."""
    jumanpp = Juman()
    with open(path, "r", encoding="utf-8") as f:
        json_dict = json.load(f)
    features = []
    for unit in json_dict:
        result = jumanpp.analysis(unit["text"])
        tokenized_text = [mrph.midasi for mrph in result.mrph_list()]
        # Each entity becomes [start, end, type], the span format passed to noyaki.convert
        spans = [[*entity["span"], entity["type"]] for entity in unit["entities"]]
        label = noyaki.convert(tokenized_text, spans)
        features.append({"x": tokenized_text, "y": label})
    return features
```
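`noyaki.convert` turns the character spans into one label per Juman++ token. The sketch below is a simplified, hypothetical stand-in (`spans_to_bio` is not part of noyaki) that illustrates the idea with plain BIO tags, assuming end-exclusive character offsets and tokens that concatenate back to the original text. noyaki's own tagging scheme and its handling of tokens that only partly overlap a span may differ; here such tokens simply get `O`.

```python
def spans_to_bio(tokens, spans):
    """Hypothetical illustration: map [start, end, type] character spans to per-token BIO labels.

    A token is labelled only if it lies entirely inside a span; the first such token
    gets "B-", later ones "I-". Tokens are assumed to concatenate to the original text.
    """
    labels = []
    offset = 0
    for token in tokens:
        start, end = offset, offset + len(token)  # character range covered by this token
        offset = end
        label = "O"
        for span_start, span_end, ent_type in spans:
            if span_start <= start and end <= span_end:
                prefix = "B" if start == span_start else "I"
                label = f"{prefix}-{ent_type}"
                break
        labels.append(label)
    return labels


print(spans_to_bio(["田中", "さん", "は", "京都", "に", "住む"], [[0, 2, "人名"], [5, 7, "地名"]]))
```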

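```python
def create_label_vocab(features: list) -> tuple:
    # Collect every label in the data; sorting keeps the IDs identical across runs
    # (the iteration order of a set of strings can change between Python processes)
    unique_labels = sorted({label for f in features for label in f["y"]})
    label2id = {label: i for i, label in enumerate(unique_labels)}
    id2label = {i: label for label, i in label2id.items()}
    return label2id, id2label
```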

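```python
MAX_LENGTH = 64

def data_collator(features: list) -> dict:
    # Uses the global tokenizer and label2id defined below
    x = [f["x"] for f in features]
    y = [f["y"] for f in features]
    # The text is already split into Juman++ words, hence is_split_into_words=True
    inputs = tokenizer(x, return_tensors=None, padding="max_length", truncation=True,
                       max_length=MAX_LENGTH, is_split_into_words=True)
    input_labels = []
    for labels in y:
        # The label of the i-th word goes to token position i (position 0 is [CLS]);
        # the remaining positions get -100, which the cross-entropy loss ignores
        ids = [label2id[label] for label in labels][:MAX_LENGTH]
        input_labels.append(ids + [-100] * (MAX_LENGTH - len(ids)))
    inputs["labels"] = input_labels
    return {k: torch.tensor(v, dtype=torch.int64) for k, v in inputs.items()}
```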
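`data_collator` places the label of word i at token position i. Because position 0 holds `[CLS]`, the label of word i sits on the token of word i-1 when every word is a single subword, and any word split into several subwords shifts the later positions further. The inference cell at the end reads predictions with the same convention, so the two are consistent with each other. A common alternative is first-subword labelling, sketched below with a hypothetical helper `align_labels_to_subwords`. It assumes `tokenize` splits each word the same way the tokenizer does with `is_split_into_words=True` (for example `tokenizer.tokenize`); adopting it would also require reading each word's prediction at its first-subword position during inference.

```python
def align_labels_to_subwords(words, word_labels, tokenize, max_length=64, ignore_index=-100):
    """Hypothetical helper: align word-level label IDs to subword token positions.

    Position 0 is reserved for [CLS]; only the first subword of each word keeps the
    word's label, and every other position (later subwords, [SEP], padding) gets ignore_index.
    """
    aligned = [ignore_index]  # position 0: [CLS]
    for word, label_id in zip(words, word_labels):
        pieces = tokenize(word)
        if not pieces:  # a word that yields no subwords has no position to label
            continue
        aligned.append(label_id)  # first subword carries the word's label
        aligned.extend([ignore_index] * (len(pieces) - 1))  # later subwords are ignored
    aligned = aligned[: max_length - 1]  # truncate, keeping one slot for [SEP]
    return aligned + [ignore_index] * (max_length - len(aligned))  # [SEP] and padding


# Toy tokenizer for illustration: splits a word into 2-character pieces
chunk2 = lambda w: [w[i:i + 2] for i in range(0, len(w), 2)]
print(align_labels_to_subwords(["ab", "cdef", "g"], [1, 2, 3], chunk2, max_length=8))
```

Inside the collator this would replace the label loop, e.g. `align_labels_to_subwords(words, [label2id[l] for l in labels], tokenizer.tokenize)` for each `(words, labels)` pair in `zip(x, y)`.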

Define the variables.

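```python
model_output_dir = "./outputs"         # where the final model is saved
ckpt_dir = "./ckpt"                    # Trainer checkpoints
training_data_directory = "./"         # directory containing ner.json
base_model_directory = "./kyoto/bert"  # pretrained Kyoto University BERT
batch_size = 32
epochs = 3
learning_rate = 3e-5
save_freq = 200                        # evaluate and save a checkpoint every 200 steps
```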

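```python
# tokenize_chinese_chars=False: do not split kanji into single characters
# do_lower_case=False: no lowercasing or accent stripping (stripping would remove dakuten, e.g. が -> か)
tokenizer = BertTokenizer.from_pretrained(base_model_directory, tokenize_chinese_chars=False, do_lower_case=False)
features = load_from_json(os.path.join(training_data_directory, "ner.json"))
label2id, id2label = create_label_vocab(features)

# 20% validation; 10% of the remaining 80% as test -> 72% train / 20% validation / 8% test
train_data, val_data = train_test_split(features, test_size=0.2, random_state=123)
train_data, test_data = train_test_split(train_data, test_size=0.1, random_state=123)
```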
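Since evaluation and checkpointing happen every `save_freq` = 200 steps, it helps to estimate how many optimizer steps training will take. A rough sketch, assuming scikit-learn rounds a float `test_size` up to a whole number of samples, a single device, and no gradient accumulation; `split_and_steps` is a hypothetical helper, and in the notebook you would pass `len(features)`.

```python
import math


def split_and_steps(n_samples, batch_size=32, epochs=3, eval_steps=200):
    """Estimate split sizes and step counts for the two train_test_split calls above.

    Assumes the test size is ceil(test_size * n), one device, no gradient accumulation.
    """
    n_val = math.ceil(0.2 * n_samples)
    n_rest = n_samples - n_val
    n_test = math.ceil(0.1 * n_rest)
    n_train = n_rest - n_test
    steps_per_epoch = math.ceil(n_train / batch_size)  # last partial batch still counts
    total_steps = steps_per_epoch * epochs
    return {
        "train": n_train, "val": n_val, "test": n_test,
        "steps_per_epoch": steps_per_epoch, "total_steps": total_steps,
        "evaluations": total_steps // eval_steps,  # evaluations at steps 200, 400, ...
    }


print(split_and_steps(1000))
```

If `evaluations` comes out below 3, `EarlyStoppingCallback(early_stopping_patience=2)` can never trigger, because it needs one evaluation to set the best loss and two more that fail to improve on it.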

Check the contents of `features`.

```python
print(features[:10])
```

Check the contents of `label2id` and `id2label`.

```python
print(label2id)
print(id2label)
```

Prepare the model.

```python
config = BertConfig.from_pretrained(base_model_directory, label2id=label2id, id2label=id2label)
model = BertForTokenClassification.from_pretrained(base_model_directory, config=config)
print(model)
```
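`label2id` and `id2label` are stored in the model config, which is written to `config.json` when the model is saved. JSON object keys are always strings, so the integer IDs come back as `"0"`, `"1"` and so on; transformers' `PretrainedConfig` converts the `id2label` keys back to `int` when it loads them, which is why `model.config.id2label[...]` can be indexed with predicted class IDs in the inference cell. The stdlib sketch below only reproduces that round trip; `json_roundtrip_id2label` is hypothetical and the labels are made up.

```python
import json


def json_roundtrip_id2label(id2label):
    """Simulate writing id2label to JSON and reading it back (hypothetical helper)."""
    restored = json.loads(json.dumps(id2label, ensure_ascii=False))  # keys become strings
    return restored, {int(k): v for k, v in restored.items()}  # keys converted back to int


raw, fixed = json_roundtrip_id2label({0: "B-人名", 1: "O"})
print(raw)    # {'0': 'B-人名', '1': 'O'}
print(fixed)  # {0: 'B-人名', 1: 'O'}
```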

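Create the training arguments. (Newer transformers releases rename `evaluation_strategy` to `eval_strategy`; use the name your installed version accepts.)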

```python
args = TrainingArguments(output_dir=ckpt_dir,
                         do_train=True,
                         do_eval=True,
                         do_predict=True,
                         per_device_train_batch_size=batch_size,
                         per_device_eval_batch_size=batch_size,
                         learning_rate=learning_rate,
                         num_train_epochs=epochs,
                         evaluation_strategy="steps",
                         eval_steps=save_freq,
                         save_strategy="steps",
                         save_steps=save_freq,
                         load_best_model_at_end=True,
                        )
```

Create the Trainer.

```python
trainer = Trainer(model=model,
                  args=args,
                  data_collator=data_collator,
                  train_dataset=train_data,
                  eval_dataset=val_data,
                  callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
                 )
```
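`EarlyStoppingCallback(early_stopping_patience=2)` watches the evaluation metric; here that is the eval loss, because `load_best_model_at_end=True` with no `metric_for_best_model` makes transformers default to `"loss"`. The function below is a simplified model of patience-based early stopping, written for illustration (it is not transformers' code): an evaluation counts as an improvement only if it beats the best loss so far by more than `threshold`, and training stops once `patience` evaluations in a row fail to improve.

```python
def early_stopping_step(eval_losses, patience=2, threshold=0.0):
    """Return the index of the evaluation after which training would stop, or None."""
    best = None
    bad_evals = 0
    for i, loss in enumerate(eval_losses):
        if best is None or (loss < best and best - loss > threshold):
            best = loss  # new best loss resets the patience counter
            bad_evals = 0
        else:
            bad_evals += 1
            if bad_evals >= patience:
                return i
    return None


print(early_stopping_step([1.0, 0.8, 0.85, 0.9, 0.7]))  # 3
```

In this example training stops after the fourth evaluation, so the later, better loss of 0.7 is never reached.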

Run the training.

```python
trainer.train()
```

Evaluate on the test set.

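```python
# No compute_metrics is passed to Trainer, so the metrics contain only the loss and runtime statistics
_, _, metrics = trainer.predict(test_data, metric_key_prefix="test")
print(metrics)
```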
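Without a `compute_metrics` function, `trainer.predict` reports only the loss. A minimal, hypothetical `compute_metrics` that measures token-level accuracy while skipping positions labelled `-100` is sketched below; it could be passed as `Trainer(..., compute_metrics=compute_token_accuracy)`, and Trainer prefixes the key, giving e.g. `test_token_accuracy`. Token accuracy is dominated by `O` labels, so entity-level F1 (for example with the seqeval library) is the more usual NER metric.

```python
import numpy as np


def compute_token_accuracy(eval_pred):
    """Token-level accuracy that ignores positions labelled -100.

    eval_pred unpacks into (logits, label_ids) with shapes
    (batch, seq_len, num_labels) and (batch, seq_len).
    """
    logits, label_ids = eval_pred
    preds = np.argmax(logits, axis=-1)
    mask = label_ids != -100  # only score positions that carry a real label
    correct = (preds == label_ids) & mask
    accuracy = float(correct.sum() / mask.sum()) if mask.any() else 0.0
    return {"token_accuracy": accuracy}
```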

```python
# Save the fine-tuned model (weights and config.json) to ./outputs
trainer.save_model(model_output_dir)
```

```
!ls outputs
```

## Inference
Run inference with the trained model.

```python
text = "田中さんはhogehoge株式会社の社員です"
model = BertForTokenClassification.from_pretrained("outputs")

jumanpp = Juman()
result = jumanpp.analysis(text)
tokenized_text = [mrph.midasi for mrph in result.mrph_list()]
inputs = tokenizer(tokenized_text, return_tensors="pt", padding="max_length", truncation=True, max_length=64, is_split_into_words=True)
with torch.no_grad():  # no gradients are needed for inference
    logits = model(**inputs).logits[0]
pred = np.argmax(logits.numpy(), axis=-1)
# Read the prediction for word i at position i, the same convention data_collator used for the labels
labels = []
for word, label_id in zip(tokenized_text, pred):
    label = model.config.id2label[int(label_id)]
    labels.append(label)
    print(f"{word}: {label}")
```
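The loop above prints one label per Juman++ word. To get entity strings, consecutive labels of the same type can be merged. A minimal sketch, assuming a `B-`/`I-` prefix scheme (with `U-`/`L-` also accepted in case noyaki produced BILOU tags); `decode_entities` is a hypothetical helper and the labels in the example are illustrative only.

```python
def decode_entities(tokens, labels):
    """Group per-token labels such as "B-X"/"I-X" (or BILOU "U-X"/"L-X") into entities.

    Returns (surface string, entity type) tuples; tokens are joined without spaces,
    which suits Japanese text.
    """
    entities = []
    current_type = None
    current_tokens = []

    def flush():
        nonlocal current_type, current_tokens
        if current_type is not None:
            entities.append(("".join(current_tokens), current_type))
        current_type, current_tokens = None, []

    for token, label in zip(tokens, labels):
        prefix, _, ent_type = label.partition("-")
        if not ent_type:  # "O" (or any label without a type) ends the current entity
            flush()
            continue
        if prefix in ("B", "U") or ent_type != current_type:
            flush()  # a new entity starts at this token
            current_type = ent_type
        current_tokens.append(token)
        if prefix in ("L", "U"):
            flush()  # BILOU: this token closes the entity
    flush()
    return entities


print(decode_entities(["田中", "さん", "は", "hogehoge", "株式", "会社"],
                      ["B-人名", "O", "O", "B-法人名", "I-法人名", "I-法人名"]))
```

In the notebook, `decode_entities(tokenized_text, labels)` would be applied to the output of the cell above.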

```python
print(tokenized_text)
print(labels)
```