# 設定

In [None]:
%%capture
# Install transformers plus fugashi/ipadic (MeCab bindings and dictionary),
# presumably needed by BertJapaneseTokenizer's word tokenizer.
# %%capture above hides the install log.
!pip install transformers fugashi ipadic

In [None]:
import tensorflow as tf
import os
import datetime
import random
import unicodedata
import pandas as pd
import numpy as np

from transformers import BertJapaneseTokenizer, TFBertForMaskedLM

In [None]:
# Hugging Face model identifier passed to every from_pretrained() call below.
MODEL_NAME = "cl-tohoku/bert-base-japanese-whole-word-masking"
# Token length that the datasets' inputs and labels are padded/truncated to
# (see to_input / to_output).
MAX_LENGTH = 128

# トークナイザ

In [None]:
class SC_tokenizer(BertJapaneseTokenizer):
    """BertJapaneseTokenizer with helpers for spelling-correction data.

    Adds pair encoding (wrong text as input, correct text's token ids as
    labels), single-text encoding that records the character span of every
    token, and the inverse mapping from predicted token ids back to text.
    """

    def encode_plus_tagged(self, wrong_text, correct_text, max_length=128):
        """Encode a (wrong, correct) sentence pair for training.

        Both texts are padded/truncated to ``max_length``. The returned
        encoding is that of ``wrong_text`` with an extra ``"label"`` entry
        holding the ``input_ids`` of ``correct_text``.
        """
        encoding = self(
            wrong_text, max_length=max_length,
            padding="max_length", truncation=True, return_tensors="tf"
        )
        encoding_correct = self(
            correct_text, max_length=max_length,
            padding="max_length", truncation=True, return_tensors="tf"
        )
        encoding["label"] = encoding_correct["input_ids"]
        return encoding

    def encode_plus_untagged(self, text, max_length=None):
        """Encode ``text`` and return ``(encoding, spans)``.

        ``spans[i]`` is the ``[start, end)`` character range of ``text``
        covered by position ``i`` of ``encoding["input_ids"]``. Positions that
        do not correspond to text get ``[-1, -1]``: the leading special token,
        the trailing special token and any padding. When ``max_length`` is
        given, the sequence is padded and truncated to exactly that length.

        Raises:
            ValueError: if a token's surface text cannot be found in ``text``
                (e.g. if the word tokenizer altered characters). Without this
                check the alignment would loop forever.
        """
        tokens = []
        tokens_original = []
        for word in self.word_tokenizer.tokenize(text):
            tokens_word = self.subword_tokenizer.tokenize(word)
            tokens.extend(tokens_word)
            if tokens_word[0] == "[UNK]":
                # The surface form of [UNK] is lost; locate the whole word instead.
                tokens_original.append(word)
            else:
                # Drop the "##" continuation marker to recover the surface text.
                tokens_original.extend(
                    token.replace("##", "") for token in tokens_word
                )

        # Left-to-right scan: each token takes the next occurrence of its text.
        position = 0
        spans = []
        for token in tokens_original:
            start = text.find(token, position)
            if start == -1:
                raise ValueError(
                    f"token {token!r} not found in text after position {position}"
                )
            end = start + len(token)
            spans.append([start, end])
            position = end

        input_ids = self.convert_tokens_to_ids(tokens)
        encoding = self.prepare_for_model(
            input_ids,
            max_length=max_length,
            padding="max_length" if max_length else False,
            # Truncate only when a target length was requested.
            truncation=bool(max_length),
            return_tensors="tf",
        )
        # Align spans with the final sequence: one slot for the leading special
        # token, at most sequence_length - 2 text tokens, then [-1, -1] for the
        # trailing special token and padding.
        sequence_length = len(encoding["input_ids"])
        spans = [[-1, -1]] + spans[:sequence_length - 2]
        spans = spans + [[-1, -1]] * (sequence_length - len(spans))

        return encoding, spans

    def convert_bert_output_to_text(self, text, labels, spans):
        """Rebuild ``text`` with every token span replaced by its predicted token.

        Args:
            text: the original text that ``spans`` refer to.
            labels: predicted token ids, one per sequence position.
            spans: per-position ``[start, end)`` ranges as returned by
                ``encode_plus_untagged``; ``[-1, -1]`` positions are ignored.

        Returns:
            The corrected text. Characters not covered by any span (between
            tokens, or after the last one, e.g. text cut off by truncation) are
            copied from ``text`` unchanged. Predicted tokens have "##" removed
            and are NFKC-normalized.
        """
        assert len(spans) == len(labels)

        aligned = [
            (label, span) for label, span in zip(labels, spans) if span[0] != -1
        ]

        pieces = []
        position = 0
        for label, (start, end) in aligned:
            if position != start:
                pieces.append(text[position:start])
            predicted_token = self.convert_ids_to_tokens(label)
            predicted_token = predicted_token.replace("##", "")
            predicted_token = unicodedata.normalize("NFKC", predicted_token)
            pieces.append(predicted_token)
            position = end
        # Keep whatever follows the last span instead of silently dropping it.
        pieces.append(text[position:])

        return "".join(pieces)

In [None]:
# Instantiate the span-aware tokenizer subclass from the pretrained vocabulary.
tokenizer = SC_tokenizer.from_pretrained(MODEL_NAME)

In [None]:
# Demo: encode a (wrong, correct) pair; the correct sentence's input_ids are
# stored under the "label" key. wrong_text is reused by the next demo cells.
wrong_text = "優勝トロフィーを変換した"
correct_text = "優勝トロフィーを返還した"
encoding = tokenizer.encode_plus_tagged(
    wrong_text, correct_text, max_length=12
)
print(encoding)

In [None]:
# Demo: encode the wrong sentence alone (no padding) and show each token's
# character span; non-text positions are marked [-1, -1].
encoding, spans = tokenizer.encode_plus_untagged(wrong_text)
print("# encoding")
print(encoding)
print("# spans")
print(spans)

In [None]:
# Demo: decode a hand-picked id sequence (one id per position of the encoding
# above, as the length assertion requires) back into text via the spans.
predicted_labels = [2, 759, 18204, 11, 8274, 15, 10, 3]
predicted_text = tokenizer.convert_bert_output_to_text(
    wrong_text, predicted_labels, spans
)
print(predicted_text)

# BertForMaskedLM

In [None]:
# Pretrained masked-LM model for the same checkpoint (not fine-tuned yet).
bert_mlm = TFBertForMaskedLM.from_pretrained(MODEL_NAME)

In [None]:
# Sanity check: run the pretrained (not yet fine-tuned) MLM on one sentence
# and decode its per-position argmax back into text.
text = "優勝トロフィーを変換した。"

encoding, spans = tokenizer.encode_plus_untagged(text)

# encode_plus_untagged returns un-batched 1-D tensors. Add a leading batch axis
# instead of hard-coding the token count, so any sentence works here.
output = bert_mlm(
    tf.expand_dims(encoding["input_ids"], 0),
    tf.expand_dims(encoding["attention_mask"], 0),
    tf.expand_dims(encoding["token_type_ids"], 0),
)
scores = output.logits
labels_predicted = tf.argmax(scores[0], axis=-1).numpy().tolist()

predicted_text = tokenizer.convert_bert_output_to_text(
    text, labels_predicted, spans
)
print(predicted_text)

In [None]:
# Compute the MLM loss on a tiny batch: wrong sentences as inputs and the
# corresponding correct sentences' token ids as labels.
wrong_texts = ["優勝トロフィーを変換した。", "人と森は強制している。"]
correct_texts = ["優勝トロフィーを返還した。", "人と森は共生している。"]

max_length = 32
input_shape = (len(wrong_texts), max_length)
dataset = {
    "input_ids": np.zeros(input_shape, dtype=np.int32),
    "attention_mask": np.zeros(input_shape, dtype=np.int32),
    "token_type_ids": np.zeros(input_shape, dtype=np.int32),
    # The model only returns a loss when "labels" is passed; without this
    # key output.loss is None.
    "labels": np.zeros(input_shape, dtype=np.int32),
}

for i, (wrong, correct) in enumerate(zip(wrong_texts, correct_texts)):
    encoding = tokenizer.encode_plus_tagged(wrong, correct, max_length)
    for k in ("input_ids", "attention_mask", "token_type_ids"):
        dataset[k][i] = encoding[k].numpy()[0]
    # encode_plus_tagged stores the correct ids under "label" (singular).
    dataset["labels"][i] = encoding["label"].numpy()[0]

output = bert_mlm(**dataset)
loss = output.loss

# データセット

データセットのうち、漢字の誤変換（カテゴリ kanji-conversion_a / kanji-conversion_b）の事例だけを対象にする。

In [None]:
!curl -L -o JWTD.tar.gz https://nlp.ist.i.kyoto-u.ac.jp/DLcounter/lime.cgi?down=https://nlp.ist.i.kyoto-u.ac.jp/nl-resource/JWTD/jwtd_v2.0.tar.gz&name=JWTDv2.0.tar.gz
!tar zxvf JWTD.tar.gz

In [None]:
# Load the JWTD training split and shuffle it once. The fixed seed makes the
# later 80/20 train/validation split reproducible across runs.
train_df = pd.read_json("jwtd_v2.0/train.jsonl", orient="records", lines=True)
train_df = train_df.sample(frac=1, random_state=0, ignore_index=True)

In [None]:
# Load the JWTD test split (kept in file order).
test_df = pd.read_json("jwtd_v2.0/test.jsonl",  orient='records', lines=True)

In [None]:
# Peek at the first training records.
train_df.head()

In [None]:
# Peek at the first test record.
test_df.head(1)

In [None]:
# Gather every distinct error category seen in the first diff of each training
# row. List-valued categories contribute only their string members; any other
# value is recorded as-is.
category_types = set()
for row_diffs in train_df["diffs"]:
    first_category = row_diffs[0]["category"]
    if not isinstance(first_category, list):
        category_types.add(first_category)
        continue
    category_types.update(c for c in first_category if isinstance(c, str))

In [None]:
# Inspect which error categories occur in the training data.
print(category_types)

In [None]:
def pickup_kanji_conversion(dataset, tokenizer):
    """Select kanji-conversion error pairs whose two sides have equal token counts.

    Only the first entry of each row's ``diffs`` is inspected. A row is kept
    when that entry's ``category`` (a string or a list of strings) names
    ``kanji-conversion_a`` or ``kanji-conversion_b``, and when
    ``tokenizer.tokenize`` yields the same number of tokens for ``pre_text``
    and ``post_text``. The equal-length filter is presumably there so the
    correct text's ids can serve as per-position labels for the wrong text.

    Args:
        dataset: DataFrame with ``pre_text``, ``post_text`` and ``diffs``
            columns. Any index is fine; rows are read positionally.
        tokenizer: object providing ``tokenize(str) -> sequence``.

    Returns:
        ``(wrong_texts, correct_texts)``: the NFKC-normalized ``pre_text`` and
        ``post_text`` of the kept rows, in dataset order. Rows with no diffs or
        with a missing/non-string category are skipped.
    """
    target_categories = {"kanji-conversion_a", "kanji-conversion_b"}
    correct_texts = []
    wrong_texts = []

    for pre_text, post_text, diffs in zip(
        dataset["pre_text"], dataset["post_text"], dataset["diffs"]
    ):
        if not isinstance(diffs, list) or not diffs:
            continue
        category = diffs[0]["category"]
        if isinstance(category, str):
            categories = {category}
        elif isinstance(category, list):
            categories = {c for c in category if isinstance(c, str)}
        else:
            # NaN/None category: nothing to match against.
            continue
        if not categories & target_categories:
            continue

        if len(tokenizer.tokenize(post_text)) == len(tokenizer.tokenize(pre_text)):
            correct_texts.append(unicodedata.normalize("NFKC", post_text))
            wrong_texts.append(unicodedata.normalize("NFKC", pre_text))

    return wrong_texts, correct_texts

In [None]:
# Keep only kanji-conversion pairs from both the (shuffled) train and test splits.
wrong_texts, correct_texts = pickup_kanji_conversion(train_df, tokenizer)
wrong_texts_test, correct_texts_test = pickup_kanji_conversion(test_df, tokenizer)

In [None]:
# The first 80% of the filtered training pairs are used for training; the rest for validation.
num_train = int(len(wrong_texts) * 0.8)

In [None]:
# Split the (already shuffled) training pairs at num_train into train/validation parts.
wrong_texts_train, wrong_texts_valid = wrong_texts[:num_train], wrong_texts[num_train:]
correct_texts_train, correct_texts_valid = correct_texts[:num_train], correct_texts[num_train:]

In [None]:
def to_input(texts, tokenizer, max_length):
    """Batch-encode ``texts`` into the three BERT inputs.

    Every text is padded/truncated to ``max_length``. Returns the list
    ``[input_ids, attention_mask, token_type_ids]`` taken from the tokenizer's
    output, requested as TensorFlow tensors.
    """
    batch = tokenizer(
        texts,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="tf",
    )
    input_names = ("input_ids", "attention_mask", "token_type_ids")
    return [batch[name] for name in input_names]

In [None]:
def to_output(texts, tokenizer, max_length):
    """Encode ``texts`` into an int32 label matrix of shape ``(len(texts), max_length)``.

    Row ``i`` holds the token ids of ``texts[i]`` from ``tokenizer.encode``,
    padded/truncated to ``max_length``. These rows serve as the MLM targets.
    """
    labels = np.zeros((len(texts), max_length), dtype=np.int32)
    for row, text in enumerate(texts):
        # Request a plain list of ids: building a TF tensor per sentence only
        # to convert it straight back to NumPy is wasted work.
        labels[row] = tokenizer.encode(
            text,
            max_length=max_length,
            padding="max_length",
            truncation=True,
        )
    return labels

In [None]:
# Encode every split: X_* is [input_ids, attention_mask, token_type_ids] for the
# wrong (input) texts; y_* is the int32 token-id matrix of the correct texts,
# used as per-position labels.
X_train = to_input(wrong_texts_train, tokenizer, MAX_LENGTH)
y_train = to_output(correct_texts_train, tokenizer, MAX_LENGTH)
X_valid = to_input(wrong_texts_valid, tokenizer, MAX_LENGTH)
y_valid = to_output(correct_texts_valid, tokenizer, MAX_LENGTH)
X_test = to_input(wrong_texts_test, tokenizer, MAX_LENGTH)
y_test = to_output(correct_texts_test, tokenizer, MAX_LENGTH)

In [None]:
# Sanity check: shape of the validation input_ids tensor.
X_valid[0].shape

In [None]:
# Sanity check: label matrix shape, expected (number of validation pairs, MAX_LENGTH).
y_valid.shape

# ファインチューニング

In [None]:
# Remove TensorBoard logs left over from previous runs.
!rm -rf logs

In [None]:
# Timestamped per-run directories for TensorBoard logs and weight checkpoints.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = os.path.join('logs/', current_time)
ckpt_dir = os.path.join('ckpt/', current_time)

In [None]:
# Load a fresh copy of the pretrained masked-LM model to fine-tune.
bert_mlm = TFBertForMaskedLM.from_pretrained(MODEL_NAME)

In [None]:
# Per-position sparse cross-entropy on the MLM logits (targets are the correct
# sentences' token ids) with a small Adam learning rate for fine-tuning.
# NOTE(review): padded label positions are not masked out, so they count
# toward both loss and accuracy — confirm this is intended.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
bert_mlm.compile(loss=loss_fn, optimizer=optimizer, metrics=["accuracy"])

In [None]:
EPOCHS = 1

callbacks = [
    # Stop once val_loss has not decreased for 5 consecutive epochs.
    tf.keras.callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=5),
    # Scalars plus per-epoch weight histograms for TensorBoard.
    tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1),
    # Save weights only, and only when the monitored metric improves.
    tf.keras.callbacks.ModelCheckpoint(
        filepath=ckpt_dir, save_best_only=True, save_weights_only=True
    ),
]

history = bert_mlm.fit(
    X_train,
    y_train,
    batch_size=64,
    epochs=EPOCHS,
    validation_data=(X_valid, y_valid),
    validation_batch_size=32,
    callbacks=callbacks,
)

In [None]:
# Load the TensorBoard notebook extension.
%load_ext tensorboard

In [None]:
# Show this run's training logs inline.
%tensorboard --logdir {log_dir}

# 性能評価

In [None]:
# Character spans of the first test sentence's tokens, encoded with the same
# MAX_LENGTH as X_test so they are meant to line up with the model's
# per-position output.
_, spans = tokenizer.encode_plus_untagged(
    wrong_texts_test[0], max_length=MAX_LENGTH
)

In [None]:
# Predict on the first test example. Slicing [:1] keeps the batch axis for any
# sequence length instead of hard-coding MAX_LENGTH in a reshape.
output = bert_mlm([inputs[:1] for inputs in X_test]).logits
# Most likely token id at every position of the single example.
predicted_labels = tf.argmax(output, axis=2)[0].numpy().tolist()

In [None]:
# Show input, model prediction and ground truth for the first test sentence.
print(f"入力：{wrong_texts_test[0]}")
print(f"予想：{tokenizer.convert_bert_output_to_text(wrong_texts_test[0], predicted_labels, spans)}")
print(f"正解：{correct_texts_test[0]}")