必要ModuleをImport

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

import math

import random
import time
import pickle
import tqdm
from collections import Counter

from torch.utils.data import Dataset
import numpy as np

from utils import GELU, PositionwiseFeedForward, LayerNorm, SublayerConnection, LayerNorm

import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
from ipywidgets import FloatProgress
from IPython.display import display, clear_output

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

read_dir = './data/merged_bert/'
output_dir = './output/merged_bert/'
processed_train_txt = read_dir + 'train_X.txt'
processed_valid_txt = read_dir + 'valid_X.txt'
processed_test_txt = read_dir + 'test_X.txt'

Attentionセルを定義する

In [3]:
class Attention(nn.Module):
    """
    Scaled Dot Product Attention
    """
    def forward(self, query, key, value, mask=None, dropout=None):
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        p_attn = F.softmax(scores, dim=-1)

        if dropout is not None:
            p_attn = dropout(p_attn)

        return torch.matmul(p_attn, value), p_attn


Multi Head Attentionを定義する

In [4]:
class MultiHeadedAttention(nn.Module):

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0

        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h

        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
        self.output_linear = nn.Linear(d_model, d_model)
        self.attention = Attention()

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2) for l, x in zip(self.linear_layers, (query, key, value))]

        x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)

        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)

        return self.output_linear(x)


Transformerを定義する

In [5]:
class TransformerBlock(nn.Module):
    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        super().__init__()
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
        x = self.output_sublayer(x, self.feed_forward)
        return self.dropout(x)


BERTクラスを定義する

In [6]:
class BERT(nn.Module):

    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
        super().__init__()
        self.hidden = hidden
        self.n_layers = n_layers
        self.attn_heads = attn_heads
        self.feed_forward_hidden = hidden * 4
        # embedding for BERT
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden, dropout=dropout)

        self.transformer_blocks = nn.ModuleList([TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])

    def forward(self, x):
        # xの中で0以上は1, 0未満は0として, maskテンソルを作る
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
        x = self.embedding(x)

        for transformer in self.transformer_blocks:
            x = transformer.forward(x, mask)
        return x


BERTのEmbedding層を定義する

In [7]:
class TokenEmbedding(nn.Embedding):
    def __init__(self, vocab_size, embed_size=512):
        super().__init__(vocab_size, embed_size, padding_idx=0)

class PositionalEmbedding(nn.Module):

    def __init__(self, d_model, max_len=512):
        super().__init__()

        pe = torch.zeros(max_len, d_model).float()
        pe.require_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)).float().exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return self.pe[:, :x.size(1)]


class BERTEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_size, dropout=0.1):
        super().__init__()
        self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
        self.position = PositionalEmbedding(d_model=self.token.embedding_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.embed_size = embed_size

    #def forward(self, sequence, segment_label):
    def forward(self, sequence):
        x = self.token(sequence) + self.position(sequence)
        return self.dropout(x)


学習用にマスク予測の層を追加する<br>
Next Sentence Prediction用のクラスは削除してある

In [8]:
class BERTLM(nn.Module):
    def __init__(self, bert: BERT, vocab_size):
        super().__init__()
        self.bert = bert
        self.mask_lm = MaskedLanguageModel(self.bert.hidden, vocab_size)

    def forward(self, x):
        x = self.bert(x)
        return self.mask_lm(x)

class MaskedLanguageModel(nn.Module):
    def __init__(self, hidden, vocab_size):
        super().__init__()
        self.linear = nn.Linear(hidden, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x))

In [9]:
class BERTESD(nn.Module):
    def __init__(self, bert: BERT):
        super().__init__()
        self.bert = bert
        self.pr = ErrorSentenceDetectionHead(self.bert.hidden)

    def forward(self, x):
        x = self.bert(x)
        return self.pr(x)

class ErrorSentenceDetectionHead(nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.linear = nn.Linear(hidden, 2)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x[:, 0]))

In [10]:
class BERTHybridESD(nn.Module):
    def __init__(self, epd):
        super().__init__()
        self.epd = epd
        self.esd = ErrorSentenceDetectionHead(9)

    def forward(self, x):
        x = self.epd(x)
        return self.esd(x)

class ErrorSentenceDetectionHead(nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.linear = nn.Linear(hidden, 2)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x[:, 0]))

In [11]:
class BERTEPD(nn.Module):
    def __init__(self, bert: BERT, head):
        super().__init__()
        self.bert = bert
        self.epd = ErrorPositionDetectionHead(self.bert.hidden, head)

    def forward(self, x):
        x = self.bert(x)
        return self.epd(x)

class ErrorPositionDetectionHead(nn.Module):
    def __init__(self, hidden, head):
        super().__init__()
        self.linear = nn.Linear(hidden, head)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x))

BERT用のVocabを生成するクラスを定義する

In [12]:
import pickle
import tqdm
from collections import Counter

class TorchVocab(object):
    """
    :property freqs: collections.Counter, コーパス中の単語の出現頻度を保持するオブジェクト
    :property stoi: collections.defaultdict, string → id の対応を示す辞書
    :property itos: collections.defaultdict, id → string の対応を示す辞書
    """
    def __init__(self, counter, max_size=None, min_freq=1, specials=['<pad>', '<oov>'],
                 vectors=None, unk_init=None, vectors_cache=None):
        """
        :param coutenr: collections.Counter, データ中に含まれる単語の頻度を計測するためのcounter
        :param max_size: int, vocabularyの最大のサイズ. Noneの場合は最大値なし. defaultはNone
        :param min_freq: int, vocabulary中の単語の最低出現頻度. この数以下の出現回数の単語はvocabularyに加えられない.
        :param specials: list of str, vocabularyにあらかじめ登録するtoken
        :param vecors: list of vectors, 事前学習済みのベクトル. ex)Vocab.load_vectors
        """
        self.freqs = counter
        counter = counter.copy()
        min_freq = max(min_freq, 1)

        self.itos = list(specials)
        # special tokensの出現頻度はvocabulary作成の際にカウントされない
        for tok in specials:
            del counter[tok]

        max_size = None if max_size is None else max_size + len(self.itos)

        # まず頻度でソートし、次に文字順で並び替える
        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
        
        # 出現頻度がmin_freq未満のものはvocabに加えない
        for word, freq in words_and_frequencies:
            if freq < min_freq or len(self.itos) == max_size:
                break
            self.itos.append(word)

        # dictのk,vをいれかえてstoiを作成する
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

        self.vectors = None
        if vectors is not None:
            self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
        else:
            assert unk_init is None and vectors_cache is None

    def __eq__(self, other):
        if self.freqs != other.freqs:
            return False
        if self.stoi != other.stoi:
            return False
        if self.itos != other.itos:
            return False
        if self.vectors != other.vectors:
            return False
        return True

    def __len__(self):
        return len(self.itos)

    def vocab_rerank(self):
        self.stoi = {word: i for i, word in enumerate(self.itos)}

    def extend(self, v, sort=False):
        words = sorted(v.itos) if sort else v.itos
        for w in words:
            if w not in self.stoi:
                self.itos.append(w)
                self.stoi[w] = len(self.itos) - 1


class Vocab(TorchVocab):
    def __init__(self, counter, max_size=None, min_freq=1):
        self.pad_index = 0
        self.unk_index = 1
        self.eos_index = 2
        self.sos_index = 3
        self.mask_index = 4
        super().__init__(counter, specials=["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"], max_size=max_size, min_freq=min_freq)

    # override用
    def to_seq(self, sentece, seq_len, with_eos=False, with_sos=False) -> list:
        pass

    # override用
    def from_seq(self, seq, join=False, with_pad=False):
        pass

    @staticmethod
    def load_vocab(vocab_path: str) -> 'Vocab':
        with open(vocab_path, "rb") as f:
            return pickle.load(f)

    def save_vocab(self, vocab_path):
        with open(vocab_path, "wb") as f:
            pickle.dump(self, f)


# テキストファイルからvocabを作成する
class WordVocab(Vocab):
    def __init__(self, texts, max_size=None, min_freq=1):
        print("Building Vocab")
        counter = Counter()
        for line in texts:
            if isinstance(line, list):
                words = line
            else:
                words = line.replace("\n", "").replace("\t", "").split()

            for word in words:
                counter[word] += 1
        super().__init__(counter, max_size=max_size, min_freq=min_freq)

    def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False):
        if isinstance(sentence, str):
            sentence = sentence.split()

        seq = [self.stoi.get(word, self.unk_index) for word in sentence]

        if with_eos:
            seq += [self.eos_index]  # this would be index 1
        if with_sos:
            seq = [self.sos_index] + seq

        origin_seq_len = len(seq)

        if seq_len is None:
            pass
        elif len(seq) <= seq_len:
            seq += [self.pad_index for _ in range(seq_len - len(seq))]
        else:
            seq = seq[:seq_len]

        return (seq, origin_seq_len) if with_len else seq

    def from_seq(self, seq, join=False, with_pad=False):
        words = [self.itos[idx]
                 if idx < len(self.itos)
                 else "<%d>" % idx
                 for idx in seq
                 if not with_pad or idx != self.pad_index]

        return " ".join(words) if join else words

    @staticmethod
    def load_vocab(vocab_path: str) -> 'WordVocab':
        with open(vocab_path, "rb") as f:
            return pickle.load(f)


def build(corpus_path, output_path, vocab_size=None, encoding='utf-8', min_freq=1):
    with open(corpus_path, "r", encoding=encoding) as f:
        vocab = WordVocab(f, max_size=vocab_size, min_freq=min_freq)

    print("VOCAB SIZE:", len(vocab))
    vocab.save_vocab(output_path)

Dataloaderを定義する.<br>
ここで文章中の単語をMASKする処理を行う<br>
Windowsでの並列化処理のために外部ファイル(dataset.py)で定義する  
MACやLinuxならこのipynbに直接定義しても良いはず

Trainerクラスを定義する.  
Masked Language Model : 文章中の一部の単語をマスクして,予測を行うタスク.  
- `save_bert`はbert本体のみを保存する
- `save_pretrain`は事前学習用のHeadも含んだ全体を保存する
- `load_bert`と`load_pretrain`も同様

In [13]:
import datetime
dt_now = str(datetime.datetime.now()).replace(' ', '')
dt_now = dt_now.replace(':','_')

In [14]:
from torch.utils.data import DataLoader
class ESDTrainer:
    def __init__(self, bert: BERT, train_dataloader: DataLoader, 
                 valid_dataloader: DataLoader = None,
                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01,
                 with_cuda: bool = True, log_freq: int = 10):
        # GPU環境において、GPUを指定しているかのフラグ
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")

        self.bert = bert
        self.model = BERTESD(bert).to(self.device)

        if torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.model = nn.DataParallel(self.model)

        self.train_data = train_dataloader
        self.valid_data = valid_dataloader

        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)

        # masked_token予測のためのLoss関数を設定
        self.criterion = nn.NLLLoss()
        self.log_freq = log_freq
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
        
        self.train_losses = []
        self.valid_losses = []
        self.train_accs = []
        self.valid_accs = []

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def valid(self, epoch):
        self.iteration(epoch, self.valid_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        str_code = "train" if train else "valid"
        data_iter = tqdm.tqdm(enumerate(data_loader), desc="EP_%s:%d" % (str_code, epoch), total=len(data_loader), bar_format="{l_bar}{r_bar}")
        avg_loss = 0.0
        total_correct = 0
        total_element = 0

        for i, data in data_iter:
            data = {key: value.to(self.device) for key, value in data.items()}
            pr_output = self.model.forward(data["bert_input"])
            pr_loss = self.criterion(pr_output, data["label"])
            loss = pr_loss
            # training時のみ,backwardとoptimizer更新を行う
            if train:
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
            
            correct = pr_output.argmax(dim=-1).eq(data["label"]).sum().item()
            avg_loss += loss.item()
            total_correct += correct
            total_element += data["label"].nelement()
            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "avg_acc": total_correct / total_element * 100,
                "loss": loss.item()
            }

            if i % self.log_freq == 0:
                data_iter.write(str(post_fix))

        print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter),  "total_acc=", total_correct * 100.0 / total_element)
        if train:
            self.train_losses.append(avg_loss / len(data_iter))
            self.train_accs.append(total_correct * 100.0 / total_element)
        else:
            self.valid_losses.append(avg_loss / len(data_iter))
            self.valid_accs.append(total_correct * 100.0 / total_element)
        
    def save_bert(self, epoch, file_path=output_dir + "bert_model"):
        output_path = file_path + ".ep%d" % epoch
        torch.save(self.bert.state_dict(), output_path)
        self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
    
    def save_pretrain(self, epoch, file_path=output_dir + "mlm_model"):
        state = {
            'epoch' : epoch,
            'state_dict' : self.model.module.state_dict(),
            'optimizer' : self.optim.state_dict()
        }
        output_path = file_path + ".ep%d" % epoch
        torch.save(state, output_path)
        self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
 
    def load_bert(self, file_path):
        self.bert.state_dict(torch.load(file_path))
        
    def load_pretrain(self, file_path):
        state = torch.load(file_path)
        #state = self.fix_model_state_dict(state)
        self.model.module.load_state_dict(state['state_dict'])
        self.optim.load_state_dict(state['optimizer'])
        return state

In [15]:
class HybridESDTrainer:
    def __init__(self, bert: BERT,train_dataloader: DataLoader, 
                 valid_dataloader: DataLoader = None,
                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01,
                 with_cuda: bool = True, log_freq: int = 10, head = 9):
        # GPU環境において、GPUを指定しているかのフラグ
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")

        self.bert = bert
        self.epd = BERTEPD(bert, head)
        self.model = BERTHybridESD(self.epd).to(self.device)

        if torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.model = nn.DataParallel(self.model)

        self.train_data = train_dataloader
        self.valid_data = valid_dataloader

        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)

        # masked_token予測のためのLoss関数を設定
        self.criterion = nn.NLLLoss()
        self.log_freq = log_freq
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
        
        self.train_losses = []
        self.valid_losses = []
        self.train_accs = []
        self.valid_accs = []

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def valid(self, epoch):
        self.iteration(epoch, self.valid_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        str_code = "train" if train else "valid"
        data_iter = tqdm.tqdm(enumerate(data_loader), desc="EP_%s:%d" % (str_code, epoch), total=len(data_loader), bar_format="{l_bar}{r_bar}")
        avg_loss = 0.0
        total_correct = 0
        total_element = 0

        for i, data in data_iter:
            data = {key: value.to(self.device) for key, value in data.items()}
            pr_output = self.model.forward(data["bert_input"])
            pr_loss = self.criterion(pr_output, data["label"])
            loss = pr_loss
            # training時のみ,backwardとoptimizer更新を行う
            if train:
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
            
            correct = pr_output.argmax(dim=-1).eq(data["label"]).sum().item()
            avg_loss += loss.item()
            total_correct += correct
            total_element += data["label"].nelement()
            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "avg_acc": total_correct / total_element * 100,
                "loss": loss.item()
            }

            if i % self.log_freq == 0:
                data_iter.write(str(post_fix))

        print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter),  "total_acc=", total_correct * 100.0 / total_element)
        if train:
            self.train_losses.append(avg_loss / len(data_iter))
            self.train_accs.append(total_correct * 100.0 / total_element)
        else:
            self.valid_losses.append(avg_loss / len(data_iter))
            self.valid_accs.append(total_correct * 100.0 / total_element)
        
    def save_bert(self, epoch, file_path=output_dir + "bert_model"):
        output_path = file_path + ".ep%d" % epoch
        torch.save(self.bert.state_dict(), output_path)
        self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
    
    def save_pretrain(self, epoch, file_path=output_dir + "mlm_model"):
        state = {
            'epoch' : epoch,
            'state_dict' : self.model.module.state_dict(),
            'optimizer' : self.optim.state_dict()
        }
        output_path = file_path + ".ep%d" % epoch
        torch.save(state, output_path)
        self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
 
    def load_bert(self, file_path):
        self.bert.state_dict(torch.load(file_path))
    
    def load_epd(self, file_path):
        state = torch.load(file_path)
        self.epd.load_state_dict(state['state_dict'])
    
    def load_pretrain(self, file_path):
        state = torch.load(file_path)
        #state = self.fix_model_state_dict(state)
        self.model.module.load_state_dict(state['state_dict'])
        self.optim.load_state_dict(state['optimizer'])
        return state

In [16]:
class EPDTrainer:
    def __init__(self, bert: BERT, train_dataloader: DataLoader, 
                 valid_dataloader: DataLoader = None,
                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01,
                 with_cuda: bool = True, log_freq: int = 10, head = 2):
        # GPU環境において、GPUを指定しているかのフラグ
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")

        self.bert = bert
        self.model = BERTEPD(bert, head).to(self.device)

        if torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.model = nn.DataParallel(self.model)

        self.train_data = train_dataloader
        self.valid_data = valid_dataloader

        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)

        # masked_token予測のためのLoss関数を設定
        self.criterion = nn.NLLLoss()
        self.log_freq = log_freq
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
        
        self.train_losses = []
        self.valid_losses = []
        self.train_accs = []
        self.valid_accs = []

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def valid(self, epoch):
        self.iteration(epoch, self.valid_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        str_code = "train" if train else "valid"
        data_iter = tqdm.tqdm(enumerate(data_loader), desc="EP_%s:%d" % (str_code, epoch), total=len(data_loader), bar_format="{l_bar}{r_bar}")
        avg_loss = 0.0
        total_correct = 0
        total_element = 0

        for i, data in data_iter:
            data = {key: value.to(self.device) for key, value in data.items()}
            pr_output = self.model.forward(data["bert_input"])
            pr_loss = self.criterion(pr_output.transpose(1, 2), data["token_label"])
            loss = pr_loss
            # training時のみ,backwardとoptimizer更新を行う
            if train:
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
            
            correct = pr_output.argmax(dim=-1).eq(data["token_label"]).sum().item()


            avg_loss += loss.item()
            total_correct += correct
            total_element += data["token_label"].nelement() * 128
            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "avg_acc": total_correct / total_element * 100,
                "loss": loss.item()
            }

            if i % self.log_freq == 0:
                data_iter.write(str(post_fix))

        print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter),  "total_acc=", total_correct * 100.0 / total_element)
        if train:
            self.train_losses.append(avg_loss / len(data_iter))
            self.train_accs.append(total_correct * 100.0 / total_element)
        else:
            self.valid_losses.append(avg_loss / len(data_iter))
            self.valid_accs.append(total_correct * 100.0 / total_element)
        
    def save_bert(self, epoch, file_path=output_dir + "bert_model"):
        output_path = file_path + ".ep%d" % epoch
        torch.save(self.bert.state_dict(), output_path)
        self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
    
    def save_pretrain(self, epoch, file_path=output_dir + "mlm_model"):
        state = {
            'epoch' : epoch,
            'state_dict' : self.model.module.state_dict(),
            'optimizer' : self.optim.state_dict()
        }
        output_path = file_path + ".ep%d" % epoch
        torch.save(state, output_path)
        self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
 
    def load_bert(self, file_path):
        self.bert.state_dict(torch.load(file_path))
        
    def load_pretrain(self, file_path):
        state = torch.load(file_path)
        self.model.module.load_state_dict(state['state_dict'])
        self.optim.load_state_dict(state['optimizer'])
        return state

In [17]:
class MLMTrainer:
    def __init__(self, bert: BERT, vocab_size: int,
                 train_dataloader: DataLoader, valid_dataloader: DataLoader = None,
                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01,
                 with_cuda: bool = True, log_freq: int = 10):
        # GPU環境において、GPUを指定しているかのフラグ
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")
        self.bert = bert
        self.model = BERTLM(bert, vocab_size).to(self.device)

        if torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.model = nn.DataParallel(self.model)

        self.train_data = train_dataloader
        self.valid_data = valid_dataloader

        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)

        # masked_token予測のためのLoss関数を設定
        self.criterion = nn.NLLLoss()
        self.log_freq = log_freq
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
        
        self.train_losses = []
        self.valid_losses = []
        self.train_accs = []
        self.valid_accs = []

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def valid(self, epoch):
        self.iteration(epoch, self.valid_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        """
        :param epoch: 現在のepoch
        :param data_loader: torch.utils.data.DataLoader
        :param train: trainかtestかのbool値
        """
        str_code = "train" if train else "valid"
        data_iter = tqdm.tqdm(enumerate(data_loader), desc="EP_%s:%d" % (str_code, epoch), total=len(data_loader), bar_format="{l_bar}{r_bar}")
        avg_loss = 0.0
        total_correct = 0
        total_element = 0

        for i, data in data_iter:
            # 0. batch_dataはGPU or CPUに載せる
            data = {key: value.to(self.device) for key, value in data.items()}
            # 1. forward the next_sentence_prediction and masked_lm model
            mask_lm_output = self.model.forward(data["bert_input"])
            # 2-2. NLLLoss(negative log likelihood) : predicting masked token word
            mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"])
            # 2-3. next_lossとmask_lossの合計をlossとする
            loss = mask_loss
            # 3. training時のみ,backwardとoptimizer更新を行う
            if train:
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

            correct = mask_lm_output.argmax(dim=-1).eq(data["bert_label"]).sum().item()
            avg_loss += loss.item()
            total_element += data["bert_label"].nelement() * 128
            total_correct += correct
            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "avg_acc": total_correct / total_element * 100,
                "loss": loss.item()
            }

            if i % self.log_freq == 0:
                data_iter.write(str(post_fix))
            
        if train:
            self.train_losses.append(avg_loss / len(data_iter))
            self.train_accs.append(total_correct * 100.0 / total_element)
        else:
            self.valid_losses.append(avg_loss / len(data_iter))
            self.valid_accs.append(total_correct * 100.0 / total_element)

        print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter))

        
    def save_bert(self, epoch, file_path=output_dir + "bert"):
        output_path = file_path + ".ep%d" % epoch
        torch.save(self.bert.state_dict(), output_path)
        self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
    
    def save_pretrain(self, epoch, file_path=output_dir + "mlm"):
        state = {
            'epoch' : epoch,
            'state_dict' : self.model.module.state_dict(),
            'optimizer' : self.optim.state_dict(),
            'train_loss' : self.train_losses,
            'valid_loss' : self.valid_losses
        }
        output_path = file_path + ".ep%d" % epoch
        torch.save(state, output_path)
        self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
 
    def load_bert(self, file_path):
        self.bert.state_dict(torch.load(file_path))
        
    def load_pretrain(self, file_path):
        state = torch.load(file_path)
        self.model.module.load_state_dict(state['state_dict'])
        self.optim.load_state_dict(state['optimizer'])
        self.train_losses = state['train_loss']
        self.valid_losses = state['valid_loss']
        return state

In [18]:
import matplotlib.pyplot as plt

import matplotlib.ticker as ticker
import numpy as np
plt.style.use('ggplot')

def showPlot(points):

    fig, ax = plt.subplots()
    plt.plot(points)
    plt.show()
    
def savePlot(points, figure_path):
    plt.switch_backend('Agg')
    plt.figure()
    fig, ax = plt.subplots()
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)
    plt.savefig(figure_path)

In [19]:
# 訓練用パラメタを定義する
vocab_path= read_dir + 'vocab.txt'

hidden=256 #768
layers=2 #12
attn_heads=4 #12
seq_len=128

batch_size=128
num_workers=0
with_cuda=True
log_freq=100
corpus_lines=None

lr=1e-3
adam_weight_decay=0.00
adam_beta1=0.9
adam_beta2=0.999

dropout=0.0

min_freq=7

label_path=None

In [20]:
#corpus_path=processed_train_txt
#build(corpus_path, vocab_path, min_freq=min_freq)

In [21]:

print("Loading Vocab", vocab_path)
vocab = WordVocab.load_vocab(vocab_path)

from dataset import MixDataset
print("Loading Train Dataset", processed_train_txt)
train_dataset = MixDataset(processed_train_txt, vocab, seq_len=seq_len, label_path=label_path, corpus_lines=corpus_lines)

print("Loading Valid Dataset", processed_valid_txt)
valid_dataset = MixDataset(processed_valid_txt, vocab, seq_len=seq_len, label_path=label_path) if processed_valid_txt is not None else None

print("Creating Dataloader")
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)
valid_data_loader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=num_workers) if processed_valid_txt is not None else None

Loading Vocab ./data/merged_bert/vocab.txt
Loading Train Dataset ./data/merged_bert/train_X.txt
Loading Valid Dataset ./data/merged_bert/valid_X.txt
Creating Dataloader


In [22]:
print("Building BERT model")
epd_bert = BERT(len(vocab), hidden=hidden, n_layers=layers, attn_heads=attn_heads, dropout=dropout)
epd_trainer = EPDTrainer(epd_bert, train_dataloader=train_data_loader, valid_dataloader=valid_data_loader,
                 lr=lr, betas=(adam_beta1, adam_beta2), weight_decay=adam_weight_decay,
                 with_cuda=with_cuda, log_freq=log_freq, head=9)

esd_bert = BERT(len(vocab), hidden=hidden, n_layers=layers, attn_heads=attn_heads, dropout=dropout)
esd_trainer = HybridESDTrainer(esd_bert, train_dataloader=train_data_loader, valid_dataloader=valid_data_loader,
                                  lr=lr, betas=(adam_beta1, adam_beta2), weight_decay=adam_weight_decay,
                                  with_cuda=with_cuda, log_freq=log_freq, head=9)

layers=4 #12
attn_heads=8 #12
mlm_bert = BERT(len(vocab), hidden=hidden, n_layers=layers, attn_heads=attn_heads, dropout=dropout)
mlm_trainer = MLMTrainer(mlm_bert, len(vocab), train_dataloader=train_data_loader, valid_dataloader=valid_data_loader,
                  lr=lr, betas=(adam_beta1, adam_beta2), weight_decay=adam_weight_decay,
                  with_cuda=with_cuda, log_freq=log_freq)


Building BERT model
Using 2 GPUS for BERT
Total Parameters: 2761481
Using 2 GPUS for BERT
Total Parameters: 2761501
Using 2 GPUS for BERT
Total Parameters: 5522944


In [23]:
esd_model_path = output_dir + 'mix/esd/small/small.ep15'
epd_model_path = output_dir + 'mix/epd/small/small.ep34'
mlm_model_path = output_dir + 'mlm/mid/.ep16'
_ = epd_trainer.load_pretrain(epd_model_path)

_ = esd_trainer.load_pretrain(esd_model_path)


_ = mlm_trainer.load_pretrain(mlm_model_path)


In [26]:
test_dataset = MixDataset(processed_test_txt, vocab, seq_len=seq_len, label_path=label_path, is_train=False)
test_loader = DataLoader(test_dataset, batch_size=1, num_workers=num_workers, shuffle=True)

In [83]:
def test_all():
    esd_trainer.model.eval()
    epd_trainer.model.eval()
    mlm_trainer.model.eval()
    with torch.no_grad():
        for batch in test_loader:
            input_line = []
            text = batch['bert_input'][0]
            original = batch['original'][0]
            esd_output = esd_trainer.model(batch['bert_input'])
            org_text = [vocab.itos[i] for i in original if i!=0]
            print("".join(org_text))
            processed_text = [vocab.itos[i] for i in text if i!=0]
            processed_text = processed_text[1:-1]
            print("".join(processed_text))
            print(batch['label'].item(), esd_output.argmax(dim=-1).item())
            
            processed_index = [vocab.sos_index] +[vocab.stoi[i] for i in processed_text] + [vocab.eos_index]
            processed_batch = torch.tensor(processed_index)
            processed_batch = processed_batch.unsqueeze(0)
            esd_output = esd_trainer.model(processed_batch)
            #NG文だけ次のステップへ
            if esd_output.argmax(dim=-1).item()==1:
                epd_output = epd_trainer.model(processed_batch)[0]
                predict_list = []
                masked_list = []
                masked_label = []
                text = processed_batch[0]
                for j in range(len(processed_batch[0])):
                    predict = epd_output[j].argmax(dim=-1).item()
                    if predict == 0:
                        # 正しい場合
                        masked_list.append(text[j])
                        masked_label.append(-1)
                        predict_list.append(0)
                    elif predict == 1:
                        # 脱字(1文字)
                        masked_list.extend([text[j], vocab.mask_index])
                        masked_label.extend([-1,-2])
                        predict_list.extend([1,-1])
                    elif predict == 2:
                        # 脱字(2文字)
                        masked_list.extend([text[j],vocab.mask_index, vocab.mask_index])
                        masked_label.extend([-1,-2,-2])
                        predict_list.extend([2,-1,-1])
                    elif predict == 3:
                        # 誤字
                        masked_list.append(vocab.mask_index)
                        masked_label.append(text[j])
                        predict_list.append(3)
                    elif predict == 4:
                        # 誤字かつ脱字(1文字)
                        masked_list.extend([vocab.mask_index, vocab.mask_index])
                        masked_label.extend([text[j],-2])
                        predict_list.extend([4,-1])
                    elif predict == 5:
                        # 誤字かつ脱字(2文字)
                        masked_list.extend([vocab.mask_index, vocab.mask_index, vocab.mask_index])
                        masked_label.extend([text[j],-2,-2])
                        predict_list.extend([5,-1,-1])
                    elif predict == 6:
                        # 衍字
                        predict_list.append(6)
                    elif predict == 7:
                        # 衍字かつ脱字(1文字)
                        masked_list.append(vocab.mask_index)
                        masked_label.append(-2)
                        predict_list.extend([7,-1])
                    elif predict == 8:
                        # 衍字かつ脱字(2文字)
                        masked_list.extend([vocab.mask_index, vocab.mask_index])
                        masked_label.extend(-2,-2)
                        predict_list.extend([8,-1,-1])
                no_pad_text = [vocab.itos[i] for i in masked_list if i!=0]
                no_pad_text = no_pad_text[1:-1]
                print("".join(no_pad_text))
                masked_list = masked_list[:128]
                masked_label = masked_label[:128]
                predict_list = predict_list[1:129]
                masked_text = [vocab.itos[i] for i in masked_list]
                masked_batch = torch.tensor(masked_list)
                masked_batch = masked_batch.unsqueeze(0)
                mlm_output = mlm_trainer.model(masked_batch)
                print(len(epd_output[0]))
                print(len(masked_text))
                enji_flag = False
                for j in range(len(masked_text)):
                    _, topi = mlm_output[0, j].topk(3)
                    predict = [vocab.itos[index.item()] for index in topi]
                    if masked_list[j] in [0,2,3]:
                        continue

                    if predict_list[j] in [6,7,8]:
                        print(predict_list[j],":<del>")

                    print(predict_list[j],":",end="")
                    if masked_label[j] == -1:
                        print(masked_text[j])
                    elif masked_label[j] == -2:
                        print(masked_text[j],':脱:',predict)
                    else:
                        print(masked_text[j],':誤:',predict)

test_all()

野中広務官房長官と同会議の樋口広太郎議長が１９日、首相官邸で会談して決めた。
野中広務官房長官と同会議の樋口広太郎議長や１９日、首相官邸で会談して決めっ。
1 1
野中広務官房長官と同会議の樋口広太郎議長<mask>１９日、首相官邸で会談して決め<mask>。
9
40
0 :野
0 :中
0 :広
0 :務
0 :官
0 :房
0 :長
0 :官
0 :と
0 :同
0 :会
0 :議
0 :の
0 :樋
0 :口
0 :広
0 :太
0 :郎
0 :議
3 :長
0 :<mask> :誤: ['は', 'が', 'も']
0 :１
0 :９
0 :日
0 :、
0 :首
0 :相
0 :官
0 :邸
0 :で
0 :会
0 :談
0 :し
0 :て
0 :決
3 :め
0 :<mask> :誤: ['た', 'る', 'て']
0 :。
文部省の現職幹部は明かす。
文部省の現職幹部は明かす。
0 1
文部省の現職幹部は明かす。
9
15
0 :文
0 :部
0 :省
0 :の
0 :現
0 :職
0 :幹
0 :部
0 :は
0 :明
0 :か
0 :す
0 :。
ドナーとなり得る患者が発生した場合、同センターではこの原則を家族に説明する。
ドナーとなり得る患者が発生した場合、同センターではこの原則を家族に説明する。
0 0
一方、人権、台湾問題では双方が原則的立場を主張し、協議は平行線に終わった。クリントン大統領は「政治的意見を表明しようとした人々を逮捕するなどダライ・ラマとの対話が進んでいないことも遺憾だ」と述べた。
一方、人権、台湾問題ででは双方原則的立場を主張、協議は平行線に終わった。クリントン大統領は「政治的意見をを表明しｙうとしした人々を逮捕するながダライ・ラマとの対話進んでいないことも遺憾だ」と述べべた。
1 1
一方、人権、台湾問題では双方<mask><mask>原則的立場を主張、協議は平行線に終わった。クリントン大統領は「政治的意見を表明し<mask>うとした人々を逮捕するながダライ・ラマとの対話<mask><mask>進んでいないことも遺憾だ」と述べた。
9
102
0 :一
0 :方
0 :、
0 :人
0 :権
0 :、
0 :台
0 :湾
0 :問
0 :題
6 :<del>
6 :で
0 :は
0 :双
2 :方
-1 :<mask>

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



 :教
0 :育
0 :の
0 :性
0 :質
0 :上
0 :、
0 :学
0 :校
0 :内
0 :で
0 :あ
0 :る
0 :種
0 :の
0 :強
0 :制
0 :が
0 :が
0 :働
0 :く
6 :<del>
6 :の
0 :は
0 :当
0 :然
0 :で
0 :あ
0 :り
0 :、
0 :入
0 :学
0 :式
0 :や
6 :<del>
6 :卒
0 :業
0 :式
0 :児
0 :童
0 :・
0 :生
0 :徒
0 :全
0 :員
0 :に
0 :国
0 :歌
0 :を
0 :斉
0 :唱
0 :さ
0 :<mask> :誤: ['せ', 'れ', 'る']
3 :<mask> :誤: ['て', 'に', 'か']
3 :も
0 :問
0 :題
0 :は
0 :は
0 :な
2 :<mask> :脱: ['い', 'な', 'っ']
-1 :<mask> :脱: ['か', 'た', 'い']
-1 :。
0 :ど
0 :う
1 :<mask> :脱: ['し', 'み', 'っ']
-1 :て
0 :も
0 :国
0 :歌
0 :を
0 :斉
0 :唱
0 :し
0 :た
6 :<del>
6 :く
0 :け
0 :い
0 :生
0 :徒
0 :<mask> :誤: ['の', 'が', 'は']
0 :力
3 :ず
0 :く
0 :で
0 :歌
0 :わ
0 :た
0 :<mask> :脱: ['り', 'く', 'れ']
1 :、
-1 :国
0 :歌
0 :を
0 :斉
0 :唱
0 :し
0 :な
0 :い
0 :か
0 :ら
0 :と
0 :い
0 :<mask> :誤: ['っ', 'う', 'わ']
0 :<mask> :誤: ['た', 'て', 'れ']
3 :停
3 :学
0 :・
0 :退
0 :学
0 :処
0 :分
0 :を
0 :課
0 :す
0 :と
0 :い
6 :<del>
6 :う
0 :こ
0 :と
6 :<del>
6 :が
0 :あ
0 :れ
0 :ば
6 :<del>
6 :当
0 :然
自宅ではバラを栽培。
自宅でｈバラを栽培。
1 1
自宅で<mask>バラを栽培。
9
12
0 :自
0 :宅
3 :で
0 :<mask> :誤: ['は', 'も', 'の']
0 :バ
0 :ラ
0

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



 :参
0 :加
0 :し
0 :た
0 :約
0 :１
0 :５
0 :０
0 :０
0 :人
0 :が
0 :華
0 :道
0 :精
0 :進
0 :を
6 :<del>
6 :誓
0 :っ
0 :た
0 :。
ほかに大正時代の流行歌や童謡なども。
ほかに大正時代の流行歌で童謡なども。
1 0
広い敷地を一つの「家」と見立て、さまざまな趣向を凝らした七つのオープンスペースを設ける。
広ぇ敷地を一つの「家」と見立て、さまざまな趣向をを凝らをた七ｔのオープンスペースをを設けるる。
1 1
広<mask>敷地を一つの「家」と見立て、さまざまな趣向を凝ら<mask>た七<mask>のオープンスペースを設ける。
9
46
3 :広
0 :<mask> :誤: ['い', 'く', 'の']
0 :敷
0 :地
0 :を
0 :一
0 :つ
0 :の
0 :「
0 :家
0 :」
0 :と
0 :見
0 :立
0 :て
0 :、
0 :さ
0 :ま
0 :ざ
0 :ま
0 :な
0 :趣
0 :向
6 :<del>
6 :を
0 :凝
0 :ら
3 :<mask> :誤: ['し', 'せ', 'す']
0 :た
0 :七
3 :<mask> :誤: ['つ', 'て', 'め']
0 :の
0 :オ
0 :ー
0 :プ
0 :ン
0 :ス
0 :ペ
0 :ー
0 :ス
0 :を
6 :<del>
6 :設
0 :け
0 :る
0 :。
この映画は、すごい発明をしたジョー
この映画は、すごい発明をしたジョー
0 0
意外だったのはアルバニア系代表団が３年間の暫定期間後に独立を問う住民投票実施に固執した点だ。
意外だったたのアルバニア系代表団が３年間の暫定期間後に独立を問う住民投票実施に固執しした点だ。
1 1
意外だったのアルバニア系代表団が３年間の暫定期間後に独立を問う住民投票実施に固執した点だ。
9
47
0 :意
0 :外
0 :だ
0 :っ
6 :<del>
6 :た
0 :の
0 :ア
0 :ル
0 :バ
0 :ニ
0 :ア
0 :系
0 :代
0 :表
0 :団
0 :が
0 :３
0 :年
0 :間
0 :の
0 :暫
0 :定
0 :期
0 :間
0 :後
0 :に
0 :独
0 :立
0 :を
0 :問
0 :う
0 :住
0 :民
0 :投
0 :票
0

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



た
0 :あ
0 :と
0 :、
0 :昭
0 :和
0 :２
0 :９
0 :年
0 :氏
0 :子
0 :さ
0 :ん
0 :ち
0 :の
0 :要
0 :望
0 :で
0 :現
0 :在
0 :の
0 :形
0 :に
0 :<mask> :誤: ['な', 'あ', 'と']
3 :り
0 :ま
0 :し
6 :<del>
6 :た
0 :。
0 :１
6 :<del>
6 :０
0 :月
0 :の
0 :秋
0 :祭
0 :り
0 :は
0 :熱
0 :心
0 :に
0 :や
0 :っ
0 :て
0 :い
0 :ま
伊吹文明労相が２８日の会見で明らかにした。
伊吹文明労相がが２８日の会見でで明らにした。
1 1
伊吹文明労相が２８日の会見で明ら<mask>にした。
9
23
0 :伊
0 :吹
0 :文
0 :明
0 :労
0 :相
6 :<del>
6 :が
0 :２
0 :８
0 :日
0 :の
0 :会
0 :見
0 :で
6 :<del>
6 :明
0 :ら
1 :<mask> :脱: ['か', 'ぎ', 'み']
-1 :に
0 :し
0 :た
0 :。
シャスのスウィサ内相はラジオで、新政権で内相ポストを狙う移民党を「シャスの内相を怖がっている。豚肉を売る店が閉められ、売春婦が入国できなくなるからだ」と非難した。
シャスｎスウィサ内相はラジオ、新政権で内相ポストを狙うう移民党を「シャスの内相をを怖がっていいる。豚肉を売る店が閉められ、売春婦は入国できくなるからだ」と非難した。
1 1
シャス<mask>スウィサ内相はラジオ、新政権で内相ポストを狙う移民党を「シャスの内相を怖がっている。豚肉を売る店が閉められ、売春婦は入国でき<mask><mask>くなるからだ」と非難した。
9
83
0 :シ
0 :ャ
3 :ス
0 :<mask> :誤: ['の', 'と', 'や']
0 :ス
0 :ウ
0 :ィ
0 :サ
0 :内
0 :相
0 :は
0 :ラ
0 :ジ
0 :オ
0 :、
0 :新
0 :政
0 :権
0 :で
0 :内
0 :相
0 :ポ
0 :ス
0 :ト
0 :を
0 :狙
6 :<del>
6 :う
0 :移
0 :民
0 :党
0 :を
0 :「
0 :シ
0 :ャ
0 :ス
0 :の
0 :内
0 :相
0 :を
6 :<d

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



 :事
0 :の
0 :合
0 :間
3 :を
0 :<mask> :誤: ['も', 'や', 'と']
0 :っ
0 :て
0 :体
0 :力
0 :維
0 :持
0 :に
0 :努
0 :め
0 :、
0 :戦
0 :術
0 :を
0 :勉
0 :強
0 :し
0 :て
0 :い
0 :る
0 :。
調べでは、金容疑者は１０日午前９時ごろ、大阪市中央区宗右衛門町の雑居ビル地下１階で、兵頭さんを殴ったりけったりして、キャッシュカードなどが入った財布を奪った疑い。兵頭さんは同日夜、意識不明で倒れているのを発見されたが、１２日に死亡した。
調べは、金容疑者は１０日午前９時ご、大阪市中央区宗右衛門町の雑居ビル地下１階で、兵頭ささんを殴ったりりけったたりしｔ、キャッシュカードなどどが入った財布を奪っｔ疑い。兵頭さん同日夜、意識不明でで倒れてちるるのを発見されたが、１２日に死亡しｔ。
1 1
調べ<mask>は、金容疑者は１０日午前９時ご<mask><mask>、大阪市中央区宗右衛門町の雑居ビル地下１階で、兵頭さんを殴ったりけったりし<mask>、キャッシュカードなどが入った財布を奪っ<mask>疑い。兵頭さん<mask><mask>同日夜、意識不明で倒れて<mask>るのを発見されたが、１２日に死亡し<mask>。
9
122
1 :調
-1 :べ
0 :<mask> :脱: ['で', 'に', 'ら']
0 :は
0 :、
0 :金
0 :容
0 :疑
0 :者
0 :は
0 :１
0 :０
0 :日
0 :午
0 :前
0 :９
2 :時
-1 :ご
-1 :<mask> :脱: ['ろ', 'ご', 'と']
0 :<mask> :脱: ['に', 'ろ', 'で']
0 :、
0 :大
0 :阪
0 :市
0 :中
0 :央
0 :区
0 :宗
0 :右
0 :衛
0 :門
0 :町
0 :の
0 :雑
0 :居
0 :ビ
0 :ル
0 :地
0 :下
0 :１
0 :階
0 :で
0 :、
0 :兵
0 :頭
6 :<del>
6 :さ
0 :ん
0 :を
0 :殴
0 :っ
0 :た
0 :り
6 :<del>
6 :け
0 :っ
0 :た
0 :り
6 :<del>
6 :し
0 :<mask> :誤: ['て', 'た', 'く']
0 :、
3 :キ

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [122]:
def test_input(original_text ,processed_text):
    esd_trainer.model.eval()
    epd_trainer.model.eval()
    mlm_trainer.model.eval()
    with torch.no_grad():
        input_line = []
        original = original_text
        processed_index = [vocab.sos_index] +[vocab.stoi[i] for i in processed_text] + [vocab.eos_index]
        processed_batch = torch.tensor(processed_index)
        processed_batch = processed_batch.unsqueeze(0)
        esd_output = esd_trainer.model(processed_batch)
        print(original_text)
        print(processed_text)
        #NG文だけ次のステップへ
        #if esd_output.argmax(dim=-1).item()==1:
        epd_output = epd_trainer.model(processed_batch)[0]
        predict_list = []
        masked_list = []
        masked_label = []
        text = processed_batch[0]
        for j in range(len(processed_batch[0])):
            predict = epd_output[j].argmax(dim=-1).item()
            if predict == 0:
                # 正しい場合
                masked_list.append(text[j])
                masked_label.append(-1)
                predict_list.append(0)
            elif predict == 1:
                # 脱字(1文字)
                masked_list.extend([text[j], vocab.mask_index])
                masked_label.extend([-1,-2])
                predict_list.extend([1,-1])
            elif predict == 2:
                # 脱字(2文字)
                masked_list.extend([text[j],vocab.mask_index, vocab.mask_index])
                masked_label.extend([-1,-2,-2])
                predict_list.extend([2,-1,-1])
            elif predict == 3:
                # 誤字
                masked_list.append(vocab.mask_index)
                masked_label.append(text[j])
                predict_list.append(3)
            elif predict == 4:
                # 誤字かつ脱字(1文字)
                masked_list.extend([vocab.mask_index, vocab.mask_index])
                masked_label.extend([text[j],-2])
                predict_list.extend([4,-1])
            elif predict == 5:
                # 誤字かつ脱字(2文字)
                masked_list.extend([vocab.mask_index, vocab.mask_index, vocab.mask_index])
                masked_label.extend([text[j],-2,-2])
                predict_list.extend([5,-1,-1])
            elif predict == 6:
                # 衍字
                predict_list.append(6)
            elif predict == 7:
                # 衍字かつ脱字(1文字)
                masked_list.append(vocab.mask_index)
                masked_label.append(-2)
                predict_list.extend([7,-1])
            elif predict == 8:
                # 衍字かつ脱字(2文字)
                masked_list.extend([vocab.mask_index, vocab.mask_index])
                masked_label.extend(-2,-2)
                predict_list.extend([8,-1,-1])
        
        no_pad_text = [vocab.itos[i] for i in masked_list if i!=0]
        no_pad_text = no_pad_text[1:-1]
        print("".join(no_pad_text))
        masked_list = masked_list[:128]
        masked_label = masked_label[:128]
        predict_list = predict_list[1:129]
        masked_text = [vocab.itos[i] for i in masked_list]
        masked_batch = torch.tensor(masked_list)
        masked_batch = masked_batch.unsqueeze(0)
        mlm_output = mlm_trainer.model(masked_batch)
        print(len(epd_output[0]))
        print(len(masked_text))
        enji_flag = False
        for j in range(len(masked_text)):
            _, topi = mlm_output[0, j].topk(3)
            predict = [vocab.itos[index.item()] for index in topi]
            if masked_list[j] in [0,2,3]:
                continue

            if predict_list[j] in [6,7,8]:
                print(predict_list[j],":<del>")
                
            print(predict_list[j],":",end="")
            if masked_label[j] == -1:
                print(masked_text[j])
            elif masked_label[j] == -2:
                print(masked_text[j],':脱:',predict)
            else:
                print(masked_text[j],':誤:',predict)

test_input("","大分空港ケ岳の山中に墜落ているのが発見さた。")


大分空港ケ岳の山中に墜落ているのが発見さた。
大分空港ケ岳の山中に墜落<mask><mask>ているのが発見さ<mask><mask>た。
9
28
0 :大
0 :分
0 :空
0 :港
0 :ケ
0 :岳
0 :の
0 :山
0 :中
0 :に
2 :墜
-1 :落
-1 :<mask> :脱: ['さ', 'を', 'と']
0 :<mask> :脱: ['せ', 'し', 'み']
0 :て
0 :い
0 :る
0 :の
0 :が
0 :発
2 :見
-1 :さ
-1 :<mask> :脱: ['れ', 'え', 'ま']
0 :<mask> :脱: ['っ', 'れ', 'た']
0 :た
0 :。


In [None]:
def test_esd():
    esd_trainer.model.eval()
    with torch.no_grad():
        n_correct = 0
        true_positive = 0
        true_negative = 0
        false_positive = 0
        false_negative = 0
        for i, batch in enumerate(tqdm.tqdm(test_loader)):
            if i==10000:
                break
            input_line = []
            esd_output = esd_trainer.model(batch['bert_input'])
            if batch['label'][0] == 1:
                if esd_output.argmax(dim=-1).item() == 1:
                    #NG文をNGと判定したとき
                    true_positive += 1
                    n_correct += 1
                else:
                    #NG文をOKと判定したとき
                    false_negative += 1
            else:
                if esd_output.argmax(dim=-1).item() == 1:
                    #OK文をNGと判定したとき
                    false_positive += 1
                else:
                    #OK文をOKと判定したとき
                    n_correct += 1
                    true_negative += 1
            print("検出率："+ str(n_correct/(i+1)))

    with open("./esd.txt","w", encoding='utf-8') as f:
        f.write("検出率："+ str(n_correct/i)+"\n")
        f.write("再現率："+ str(true_positive/(true_positive+false_negative))+"\n")
        f.write("適合率："+ str(true_positive/(true_positive+false_positive))+"\n")
        f.write("特異度："+ str(true_negative/(false_positive+true_negative))+"\n")
        f.write("F値："+ str(2*true_positive/(2*true_positive+false_negative+false_positive))+"\n")

test_esd()

In [None]:
def test_epd():
    epd_trainer.model.eval()
    with torch.no_grad():
        all_num = 0
        n_correct = 0
        p_correct = 0
        true_positive = 0
        true_negative = 0
        false_positive = 0
        false_negative = 0
        for i, batch in enumerate(tqdm.tqdm(test_loader)):
            if i==10000:
                break
            output = epd_trainer.model(batch['bert_input'])
            for j in range(128):
                all_num += 1
                true =  batch["token_label"][0,j].item()
                predict = output[0,j].argmax(dim=-1).item()
                if true != 0:
                    if predict != 0:
                        true_positive += 1
                        n_correct += 1
                        if predict == true:
                            p_correct += 1
                    else:
                        false_negative += 1
                else:
                    if predict != 0:
                        false_positive += 1
                    else:
                        true_negative += 1
                        n_correct += 1
            print("検出率：", n_correct/all_num)
            #print("完全検出率：", p_correct/all_num)
        print("再現率：", true_positive/(true_positive+false_negative))
        print("適合率：", true_positive/(true_positive+false_positive))
        print("特異度：", true_negative/(false_positive+true_negative))
        print("F値：", 2*true_positive/(2*true_positive+false_negative+false_positive))
    
    with open("./epd.txt","w", encoding='utf-8') as f:
        f.write("検出率："+ str(n_correct/all_num)+"\n")
        f.write("再現率："+ str(true_positive/(true_positive+false_negative))+"\n")
        f.write("適合率："+ str(true_positive/(true_positive+false_positive))+"\n")
        f.write("特異度："+ str(true_negative/(false_positive+true_negative))+"\n")
        f.write("F値："+ str(2*true_positive/(2*true_positive+false_negative+false_positive))+"\n")
test_epd()

In [75]:
import copy
def test_mlm_proofread(original_text ,processed_text):
    mlm_trainer.model.eval()
    processed_list = [vocab.stoi[i] for i in processed_text]
    error_list = []
    for i in range(len(processed_list)):
        masked_list = copy.copy(processed_list)
        masked_list[i] = vocab.mask_index 
        masked_batch = torch.tensor(masked_list)
        masked_batch = masked_batch.unsqueeze(0)
        mlm_output = mlm_trainer.model(masked_batch)

        _, topi = mlm_output[0, i].topk(3)
        topi = [vocab.itos[j.item()] for j in topi]
        print(processed_text[i],topi)
        #if processed_list[i] in topi:
        #    error_list.append(1)
        # else:
        #    error_list.append(0)
    print(error_list)
test_mlm_proofread("","待てども結果が出ｒことはなかった。")

待 ['あ', 'か', 'ん']
て ['<pad>', 'ふ', 'ゆ']
ど ['ぬ', 'に', 'は']
も ['<pad>', 'に', 'く']
結 ['<pad>', 'ぱ', 'は']
果 ['ぽ', 'ぞ', 'ぐ']
が ['ぞ', 'ぽ', 'ぐ']
出 ['<pad>', 'は', 'ぱ']
ｒ ['<pad>', 'に', 'く']
こ ['ぬ', 'に', 'は']
と ['<pad>', 'ふ', 'ゆ']
は ['あ', 'か', 'ん']
な ['ぞ', 'ぽ', 'ぐ']
か ['あ', '<pad>', 'ぞ']
っ ['<pad>', 'で', 'が']
た ['ぬ', 'に', 'は']
。 ['<pad>', 'ふ', 'お']
[]


In [None]:
import matplotlib.pyplot as plt
plt.switch_backend('Agg')
import matplotlib.ticker as ticker
import numpy as np
plt.style.use('ggplot')

def showPlot(points, figure_path):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)
    plt.savefig(figure_path)

In [None]:
showPlot(trainer.train_losses, "./results/"+str(epochs)+"_train.png")
showPlot(trainer.valid_losses, "./results/"+str(epochs)+"_valid.png")