必要ModuleをImport

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

import math

import random
import time
import pickle
import tqdm
from collections import Counter

from torch.utils.data import Dataset
import numpy as np

from utils import GELU, PositionwiseFeedForward, LayerNorm, SublayerConnection, LayerNorm

import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
from ipywidgets import FloatProgress
from IPython.display import display, clear_output

read_dirでディレクトリを指定<br>
このディレクトリ内のコーパスを読み込む<br>
vocabファイルもここに作成される

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#read_dir = './data/single/'
read_dir = './data/merged_bert/'
#output_dir = './output/single/'
output_dir = './output/merged_bert/'
processed_train_txt = read_dir + 'train_X.txt'
processed_valid_txt = read_dir + 'valid_X.txt'
processed_test_txt = read_dir + 'test_X.txt'

Attentionセルを定義する

In [3]:
class Attention(nn.Module):
    """
    Scaled Dot Product Attention
    """
    def forward(self, query, key, value, mask=None, dropout=None):
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        p_attn = F.softmax(scores, dim=-1)

        if dropout is not None:
            p_attn = dropout(p_attn)

        return torch.matmul(p_attn, value), p_attn


Multi Head Attentionを定義する

In [4]:
class MultiHeadedAttention(nn.Module):

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0

        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h

        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
        self.output_linear = nn.Linear(d_model, d_model)
        self.attention = Attention()

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2) for l, x in zip(self.linear_layers, (query, key, value))]

        x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)

        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)

        return self.output_linear(x)


Transformerを定義する

In [5]:
class TransformerBlock(nn.Module):
    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        super().__init__()
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
        x = self.output_sublayer(x, self.feed_forward)
        return self.dropout(x)


BERTクラスを定義する

In [6]:
class BERT(nn.Module):

    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
        super().__init__()
        self.hidden = hidden
        self.n_layers = n_layers
        self.attn_heads = attn_heads
        self.feed_forward_hidden = hidden * 4
        # embedding for BERT
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden, dropout=dropout)

        self.transformer_blocks = nn.ModuleList([TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])

    def forward(self, x):
        # xの中で0以上は1, 0未満は0として, maskテンソルを作る
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
        x = self.embedding(x)

        for transformer in self.transformer_blocks:
            x = transformer.forward(x, mask)
        return x


BERTのEmbedding層を定義する

In [7]:
class TokenEmbedding(nn.Embedding):
    def __init__(self, vocab_size, embed_size=512):
        super().__init__(vocab_size, embed_size, padding_idx=0)

class PositionalEmbedding(nn.Module):

    def __init__(self, d_model, max_len=512):
        super().__init__()

        pe = torch.zeros(max_len, d_model).float()
        pe.require_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)).float().exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return self.pe[:, :x.size(1)]


class BERTEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_size, dropout=0.1):
        super().__init__()
        self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
        self.position = PositionalEmbedding(d_model=self.token.embedding_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.embed_size = embed_size

    #def forward(self, sequence, segment_label):
    def forward(self, sequence):
        x = self.token(sequence) + self.position(sequence)
        return self.dropout(x)


学習用にマスク予測の層を追加する<br>
Next Sentence Prediction用のクラスは削除してある

In [8]:
class BERTLM(nn.Module):
    def __init__(self, bert: BERT, vocab_size):
        super().__init__()
        self.bert = bert
        self.mask_lm = MaskedLanguageModel(self.bert.hidden, vocab_size)

    def forward(self, x):
        x = self.bert(x)
        return self.mask_lm(x)

class MaskedLanguageModel(nn.Module):
    def __init__(self, hidden, vocab_size):
        super().__init__()
        self.linear = nn.Linear(hidden, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x))

In [9]:
class BERTESD(nn.Module):
    def __init__(self, bert: BERT):
        super().__init__()
        self.bert = bert
        self.esd = ErrorSentenceDetectionHead(self.bert.hidden)

    def forward(self, x):
        x = self.bert(x)
        return self.esd(x)

class ErrorSentenceDetectionHead(nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.linear = nn.Linear(hidden, 2)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x[:, 0]))

In [10]:
class BERTEPD(nn.Module):
    def __init__(self, bert: BERT, head):
        super().__init__()
        self.bert = bert
        self.epd = ErrorPositionDetectionHead(self.bert.hidden, head)

    def forward(self, x):
        x = self.bert(x)
        return self.epd(x)

class ErrorPositionDetectionHead(nn.Module):
    def __init__(self, hidden, head):
        super().__init__()
        self.linear = nn.Linear(hidden, head)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x))

In [11]:
class BERTHybridESD(nn.Module):
    def __init__(self, epd):
        super().__init__()
        self.epd = epd
        self.esd = ErrorSentenceDetectionHead(9)

    def forward(self, x):
        x = self.epd(x)
        return self.esd(x)

class ErrorSentenceDetectionHead(nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.linear = nn.Linear(hidden, 2)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self.softmax(self.linear(x[:, 0]))

BERT用のVocabを生成するクラスを定義する

In [12]:
import pickle
import tqdm
from collections import Counter

class TorchVocab(object):
    """
    :property freqs: collections.Counter, コーパス中の単語の出現頻度を保持するオブジェクト
    :property stoi: collections.defaultdict, string → id の対応を示す辞書
    :property itos: collections.defaultdict, id → string の対応を示す辞書
    """
    def __init__(self, counter, max_size=None, min_freq=1, specials=['<pad>', '<oov>'],
                 vectors=None, unk_init=None, vectors_cache=None):
        """
        :param coutenr: collections.Counter, データ中に含まれる単語の頻度を計測するためのcounter
        :param max_size: int, vocabularyの最大のサイズ. Noneの場合は最大値なし. defaultはNone
        :param min_freq: int, vocabulary中の単語の最低出現頻度. この数以下の出現回数の単語はvocabularyに加えられない.
        :param specials: list of str, vocabularyにあらかじめ登録するtoken
        :param vecors: list of vectors, 事前学習済みのベクトル. ex)Vocab.load_vectors
        """
        self.freqs = counter
        counter = counter.copy()
        min_freq = max(min_freq, 1)

        self.itos = list(specials)
        # special tokensの出現頻度はvocabulary作成の際にカウントされない
        for tok in specials:
            del counter[tok]

        max_size = None if max_size is None else max_size + len(self.itos)

        # まず頻度でソートし、次に文字順で並び替える
        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
        
        # 出現頻度がmin_freq未満のものはvocabに加えない
        for word, freq in words_and_frequencies:
            if freq < min_freq or len(self.itos) == max_size:
                break
            self.itos.append(word)

        # dictのk,vをいれかえてstoiを作成する
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

        self.vectors = None
        if vectors is not None:
            self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
        else:
            assert unk_init is None and vectors_cache is None

    def __eq__(self, other):
        if self.freqs != other.freqs:
            return False
        if self.stoi != other.stoi:
            return False
        if self.itos != other.itos:
            return False
        if self.vectors != other.vectors:
            return False
        return True

    def __len__(self):
        return len(self.itos)

    def vocab_rerank(self):
        self.stoi = {word: i for i, word in enumerate(self.itos)}

    def extend(self, v, sort=False):
        words = sorted(v.itos) if sort else v.itos
        for w in words:
            if w not in self.stoi:
                self.itos.append(w)
                self.stoi[w] = len(self.itos) - 1


class Vocab(TorchVocab):
    def __init__(self, counter, max_size=None, min_freq=1):
        self.pad_index = 0
        self.unk_index = 1
        self.eos_index = 2
        self.sos_index = 3
        self.mask_index = 4
        super().__init__(counter, specials=["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"], max_size=max_size, min_freq=min_freq)

    # override用
    def to_seq(self, sentece, seq_len, with_eos=False, with_sos=False) -> list:
        pass

    # override用
    def from_seq(self, seq, join=False, with_pad=False):
        pass

    @staticmethod
    def load_vocab(vocab_path: str) -> 'Vocab':
        with open(vocab_path, "rb") as f:
            return pickle.load(f)

    def save_vocab(self, vocab_path):
        with open(vocab_path, "wb") as f:
            pickle.dump(self, f)


# テキストファイルからvocabを作成する
class WordVocab(Vocab):
    def __init__(self, texts, max_size=None, min_freq=1):
        print("Building Vocab")
        counter = Counter()
        for line in texts:
            if isinstance(line, list):
                words = line
            else:
                words = line.replace("\n", "").replace("\t", "").split()

            for word in words:
                counter[word] += 1
        super().__init__(counter, max_size=max_size, min_freq=min_freq)

    def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False):
        if isinstance(sentence, str):
            sentence = sentence.split()

        seq = [self.stoi.get(word, self.unk_index) for word in sentence]

        if with_eos:
            seq += [self.eos_index]  # this would be index 1
        if with_sos:
            seq = [self.sos_index] + seq

        origin_seq_len = len(seq)

        if seq_len is None:
            pass
        elif len(seq) <= seq_len:
            seq += [self.pad_index for _ in range(seq_len - len(seq))]
        else:
            seq = seq[:seq_len]

        return (seq, origin_seq_len) if with_len else seq

    def from_seq(self, seq, join=False, with_pad=False):
        words = [self.itos[idx]
                 if idx < len(self.itos)
                 else "<%d>" % idx
                 for idx in seq
                 if not with_pad or idx != self.pad_index]

        return " ".join(words) if join else words

    @staticmethod
    def load_vocab(vocab_path: str) -> 'WordVocab':
        with open(vocab_path, "rb") as f:
            return pickle.load(f)


def build(corpus_path, output_path, vocab_size=None, encoding='utf-8', min_freq=1):
    with open(corpus_path, "r", encoding=encoding) as f:
        vocab = WordVocab(f, max_size=vocab_size, min_freq=min_freq)

    print("VOCAB SIZE:", len(vocab))
    vocab.save_vocab(output_path)

Dataloaderを定義する.<br>
ここで文章中の単語をMASKする処理を行う<br>
Windowsでの並列化処理のために外部ファイル(dataset.py)で定義する  
MACやLinuxならこのipynbに直接定義しても良いはず

Trainerクラスを定義する.  
Masked Language Model : 文章中の一部の単語をマスクして,予測を行うタスク.  
- `save_bert`はbert本体のみを保存する
- `save_pretrain`は事前学習用のHeadも含んだ全体を保存する
- `load_bert`と`load_pretrain`も同様

In [13]:
from torch.utils.data import DataLoader
class ESDTrainer:
    def __init__(self, bert: BERT, train_dataloader: DataLoader, 
                 valid_dataloader: DataLoader = None,
                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01,
                 with_cuda: bool = True, log_freq: int = 10):
        # GPU環境において、GPUを指定しているかのフラグ
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")

        self.bert = bert
        self.model = BERTESD(bert).to(self.device)

        if torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.model = nn.DataParallel(self.model)

        self.train_data = train_dataloader
        self.valid_data = valid_dataloader

        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)

        # masked_token予測のためのLoss関数を設定
        self.criterion = nn.NLLLoss()
        self.log_freq = log_freq
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
        
        self.train_losses = []
        self.valid_losses = []
        self.train_accs = []
        self.valid_accs = []

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def valid(self, epoch):
        self.iteration(epoch, self.valid_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        str_code = "train" if train else "valid"
        data_iter = tqdm.tqdm(enumerate(data_loader), desc="EP_%s:%d" % (str_code, epoch), total=len(data_loader), bar_format="{l_bar}{r_bar}")
        avg_loss = 0.0
        total_correct = 0
        total_element = 0

        for i, data in data_iter:
            data = {key: value.to(self.device) for key, value in data.items()}
            pr_output = self.model.forward(data["bert_input"])
            pr_loss = self.criterion(pr_output, data["label"])
            loss = pr_loss
            # training時のみ,backwardとoptimizer更新を行う
            if train:
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
            
            correct = pr_output.argmax(dim=-1).eq(data["label"]).sum().item()
            avg_loss += loss.item()
            total_correct += correct
            total_element += data["label"].nelement()
            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "avg_acc": total_correct / total_element * 100,
                "loss": loss.item()
            }

            if i % self.log_freq == 0:
                data_iter.write(str(post_fix))

        print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter),  "total_acc=", total_correct * 100.0 / total_element)
        if train:
            self.train_losses.append(avg_loss / len(data_iter))
            self.train_accs.append(total_correct * 100.0 / total_element)
        else:
            self.valid_losses.append(avg_loss / len(data_iter))
            self.valid_accs.append(total_correct * 100.0 / total_element)
        
    def save_bert(self, epoch, file_path=output_dir + "bert_model"):
        output_path = file_path + ".ep%d" % epoch
        torch.save(self.bert.state_dict(), output_path)
        self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
    
    def save_pretrain(self, epoch, file_path=output_dir + "mlm_model"):
        state = {
            'epoch' : epoch,
            'state_dict' : self.model.module.state_dict(),
            'optimizer' : self.optim.state_dict()
        }
        output_path = file_path + ".ep%d" % epoch
        torch.save(state, output_path)
        self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
 
    def load_bert(self, file_path):
        self.bert.state_dict(torch.load(file_path))
        
    def load_pretrain(self, file_path):
        state = torch.load(file_path)
        #state = self.fix_model_state_dict(state)
        self.model.module.load_state_dict(state['state_dict'])
        self.optim.load_state_dict(state['optimizer'])
        return state

In [14]:
import datetime
dt_now = str(datetime.datetime.now()).replace(' ', '')
dt_now = dt_now.replace(':','_')

* 最初から学習する場合：`train_mode = 'first'`  
* 途中から学習を再開する場合：`train_mode = 'continue'`  
    - 保存してある学習済みモデルのパスを指定する
    - vocabファイルのパスを指定する

In [15]:
class EPDTrainer:
    def __init__(self, bert: BERT, train_dataloader: DataLoader,
                 valid_dataloader: DataLoader = None,
                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01,
                 with_cuda: bool = True, log_freq: int = 10, head = 2):
        # GPU環境において、GPUを指定しているかのフラグ
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")

        self.bert = bert
        self.model = BERTEPD(bert, head).to(self.device)

        if torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.model = nn.DataParallel(self.model)

        self.train_data = train_dataloader
        self.valid_data = valid_dataloader

        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)

        # masked_token予測のためのLoss関数を設定
        self.criterion = nn.NLLLoss()
        self.log_freq = log_freq
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
        
        self.train_losses = []
        self.valid_losses = []
        self.train_accs = []
        self.valid_accs = []

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def valid(self, epoch):
        self.iteration(epoch, self.valid_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        str_code = "train" if train else "valid"
        data_iter = tqdm.tqdm(enumerate(data_loader), desc="EP_%s:%d" % (str_code, epoch), total=len(data_loader), bar_format="{l_bar}{r_bar}")
        avg_loss = 0.0
        total_correct = 0
        total_element = 0

        for i, data in data_iter:
            data = {key: value.to(self.device) for key, value in data.items()}
            pr_output = self.model.forward(data["bert_input"])
            pr_loss = self.criterion(pr_output.transpose(1, 2), data["token_label"])
            loss = pr_loss
            # training時のみ,backwardとoptimizer更新を行う
            if train:
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
            
            correct = pr_output.argmax(dim=-1).eq(data["token_label"]).sum().item()


            avg_loss += loss.item()
            total_correct += correct
            total_element += data["token_label"].nelement() * 128
            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "avg_acc": total_correct / total_element * 100,
                "loss": loss.item()
            }

            if i % self.log_freq == 0:
                data_iter.write(str(post_fix))

        print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter),  "total_acc=", total_correct * 100.0 / total_element)
        if train:
            self.train_losses.append(avg_loss / len(data_iter))
            self.train_accs.append(total_correct * 100.0 / total_element)
        else:
            self.valid_losses.append(avg_loss / len(data_iter))
            self.valid_accs.append(total_correct * 100.0 / total_element)
        
    def save_bert(self, epoch, file_path=output_dir + "bert_model"):
        output_path = file_path + ".ep%d" % epoch
        torch.save(self.bert.state_dict(), output_path)
        self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
    
    def save_pretrain(self, epoch, file_path=output_dir + "mlm_model"):
        state = {
            'epoch' : epoch,
            'state_dict' : self.model.module.state_dict(),
            'optimizer' : self.optim.state_dict()
        }
        output_path = file_path + ".ep%d" % epoch
        torch.save(state, output_path)
        self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
 
    def load_bert(self, file_path):
        self.bert.state_dict(torch.load(file_path))
        
    def load_pretrain(self, file_path):
        state = torch.load(file_path)
        #state = self.fix_model_state_dict(state)
        self.model.module.load_state_dict(state['state_dict'])
        self.optim.load_state_dict(state['optimizer'])
        return state

In [16]:
class HybridESDTrainer:
    def __init__(self, bert: BERT, train_dataloader: DataLoader, 
                 valid_dataloader: DataLoader = None,
                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01,
                 with_cuda: bool = True, log_freq: int = 10, head = 9):
        # GPU環境において、GPUを指定しているかのフラグ
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")

        self.bert = bert
        self.epd = BERTEPD(bert, head)
        self.model = BERTHybridESD(self.epd).to(self.device)

        if torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.model = nn.DataParallel(self.model)

        self.train_data = train_dataloader
        self.valid_data = valid_dataloader

        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)

        # masked_token予測のためのLoss関数を設定
        self.criterion = nn.NLLLoss()
        self.log_freq = log_freq
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
        
        self.train_losses = []
        self.valid_losses = []
        self.train_accs = []
        self.valid_accs = []

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def valid(self, epoch):
        self.iteration(epoch, self.valid_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        str_code = "train" if train else "valid"
        data_iter = tqdm.tqdm(enumerate(data_loader), desc="EP_%s:%d" % (str_code, epoch), total=len(data_loader), bar_format="{l_bar}{r_bar}")
        avg_loss = 0.0
        total_correct = 0
        total_element = 0

        for i, data in data_iter:
            data = {key: value.to(self.device) for key, value in data.items()}
            pr_output = self.model.forward(data["bert_input"])
            pr_loss = self.criterion(pr_output, data["label"])
            loss = pr_loss
            # training時のみ,backwardとoptimizer更新を行う
            if train:
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
            
            correct = pr_output.argmax(dim=-1).eq(data["label"]).sum().item()
            avg_loss += loss.item()
            total_correct += correct
            total_element += data["label"].nelement()
            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "avg_acc": total_correct / total_element * 100,
                "loss": loss.item()
            }

            if i % self.log_freq == 0:
                data_iter.write(str(post_fix))

        print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter),  "total_acc=", total_correct * 100.0 / total_element)
        if train:
            self.train_losses.append(avg_loss / len(data_iter))
            self.train_accs.append(total_correct * 100.0 / total_element)
        else:
            self.valid_losses.append(avg_loss / len(data_iter))
            self.valid_accs.append(total_correct * 100.0 / total_element)
        
    def save_bert(self, epoch, file_path=output_dir + "bert_model"):
        output_path = file_path + ".ep%d" % epoch
        torch.save(self.bert.state_dict(), output_path)
        self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
    
    def save_pretrain(self, epoch, file_path=output_dir + "mlm_model"):
        state = {
            'epoch' : epoch,
            'state_dict' : self.model.module.state_dict(),
            'optimizer' : self.optim.state_dict(),
            'train_acc': self.train_accs[-1],
            'valid_acc': self.valid_accs[-1],
        }
        output_path = file_path + ".ep%d" % epoch
        torch.save(state, output_path)
        self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
 
    def load_bert(self, file_path):
        self.bert.state_dict(torch.load(file_path))
    
    def load_epd(self, file_path):
        state = torch.load(file_path)
        self.epd.load_state_dict(state['state_dict'])
    
    def load_pretrain(self, file_path):
        state = torch.load(file_path)
        #state = self.fix_model_state_dict(state)
        self.model.module.load_state_dict(state['state_dict'])
        self.optim.load_state_dict(state['optimizer'])
        return state

In [17]:
# 訓練用パラメタを定義する
vocab_path= read_dir + 'vocab.txt'


hidden=256 #768
layers=2 #12 2
attn_heads=4 #12 4
seq_len=128

batch_size=128
epochs=50
num_workers=0
with_cuda=True
log_freq=100
corpus_lines=None

lr=1e-3
adam_weight_decay=0.00
adam_beta1=0.9
adam_beta2=0.999

dropout=0.0

min_freq=7

corpus_path=processed_train_txt
label_path=None

# 最初から学習する場合は'first'、学習を続きから再開する場合は'continue'
#train_mode = 'first' 
train_mode = 'first' 

In [18]:
#if train_mode == 'first':
#build(corpus_path, vocab_path, min_freq=min_freq)


In [19]:
print("Loading Vocab", vocab_path)

vocab = WordVocab.load_vocab(vocab_path)
from dataset import ReplaceDataset, MixDataset 

print("Loading Train Dataset", processed_train_txt)
#train_dataset = ReplaceDataset(processed_train_txt, vocab, seq_len=seq_len, label_path=label_path, corpus_lines=corpus_lines)
#train_dataset = DeleteDataset(processed_train_txt, vocab, seq_len=seq_len, label_path=label_path, corpus_lines=corpus_lines)
train_dataset = MixDataset(processed_train_txt, vocab, seq_len=seq_len, label_path=label_path, corpus_lines=corpus_lines)

print("Loading Valid Dataset", processed_valid_txt)
#valid_dataset = ReplaceDataset(processed_valid_txt, vocab, seq_len=seq_len, label_path=label_path) if processed_valid_txt is not None else None
#valid_dataset = DeleteDataset(processed_valid_txt, vocab, seq_len=seq_len, label_path=label_path) if processed_valid_txt is not None else None
valid_dataset = MixDataset(processed_valid_txt, vocab, seq_len=seq_len, label_path=label_path) if processed_valid_txt is not None else None


print("Creating Dataloader")
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers)
valid_data_loader = DataLoader(valid_dataset, batch_size=batch_size, num_workers=num_workers) if processed_valid_txt is not None else None

Loading Vocab ./data/merged_bert/vocab.txt
Loading Train Dataset ./data/merged_bert/train_X.txt
Loading Valid Dataset ./data/merged_bert/valid_X.txt
Creating Dataloader


In [20]:
print("Building BERT model")
bert = BERT(len(vocab), hidden=hidden, n_layers=layers, attn_heads=attn_heads, dropout=dropout)
#trainer = ESDTrainer(bert, train_dataloader=train_data_loader, valid_dataloader=valid_data_loader,lr=lr, betas=(adam_beta1, adam_beta2), weight_decay=adam_weight_decay,with_cuda=with_cuda, log_freq=log_freq)
trainer = EPDTrainer(bert, train_dataloader=train_data_loader, valid_dataloader=valid_data_loader,lr=lr, betas=(adam_beta1, adam_beta2), weight_decay=adam_weight_decay,with_cuda=with_cuda, log_freq=log_freq, head=9)
#trainer = HybridESDTrainer(bert, train_dataloader=train_data_loader, valid_dataloader=valid_data_loader,lr=lr, betas=(adam_beta1, adam_beta2), weight_decay=adam_weight_decay,with_cuda=with_cuda, log_freq=log_freq, head=9)


output_model_path = output_dir + 'mix/epd/small/small'

if train_mode == 'first':
    trainer.load_bert(output_dir + 'bert/mid/bert.ep16')
    #trainer.load_epd(output_dir + 'mix/epd/small/small.ep34')
    trained_epoch = 0

else:
    state = trainer.load_pretrain(output_model_path + '.ep34')
    trained_epoch = state['epoch'] +1
    print(trained_epoch)

Building BERT model
Using 2 GPUS for BERT
Total Parameters: 2761225


In [21]:
print("Training Start")
for epoch in range(trained_epoch, epochs):
    epoch_start = time.time()
    trainer.train(epoch)
    # Model Save
    trainer.save_pretrain(epoch, output_model_path
                         )
    trainer.valid(epoch)
    print(time.time() - epoch_start)

Training Start


                                           

{'epoch': 0, 'iter': 0, 'avg_loss': 2.3968453407287598, 'avg_acc': 0.031757354736328125, 'loss': 2.3968453407287598}


                                                      

{'epoch': 0, 'iter': 100, 'avg_loss': 0.20024701377542892, 'avg_acc': 0.7458346905094562, 'loss': 0.16066795587539673}


                                                       

{'epoch': 0, 'iter': 200, 'avg_loss': 0.17212165532568793, 'avg_acc': 0.750356408494029, 'loss': 0.13187795877456665}


                                                     

{'epoch': 0, 'iter': 300, 'avg_loss': 0.16212434648378346, 'avg_acc': 0.7518560783411578, 'loss': 0.124681256711483}


                                                     

{'epoch': 0, 'iter': 400, 'avg_loss': 0.15698987262281694, 'avg_acc': 0.7525687800381249, 'loss': 0.13871435821056366}


                                                     

{'epoch': 0, 'iter': 500, 'avg_loss': 0.15335853090007862, 'avg_acc': 0.7530554088051924, 'loss': 0.13516204059123993}


                                                     

{'epoch': 0, 'iter': 600, 'avg_loss': 0.15080905173860254, 'avg_acc': 0.7533420143825639, 'loss': 0.13195250928401947}


                                                     

{'epoch': 0, 'iter': 700, 'avg_loss': 0.14861334003667517, 'avg_acc': 0.7535550121574021, 'loss': 0.14707638323307037}


                                                     

{'epoch': 0, 'iter': 800, 'avg_loss': 0.14677570079373065, 'avg_acc': 0.7537069689766149, 'loss': 0.1366102397441864}


                                                     

{'epoch': 0, 'iter': 900, 'avg_loss': 0.1450824712575739, 'avg_acc': 0.7538426597163362, 'loss': 0.1373373121023178}


                                                     

{'epoch': 0, 'iter': 1000, 'avg_loss': 0.14364935608057827, 'avg_acc': 0.7539282311926355, 'loss': 0.11972714960575104}


                                                      

{'epoch': 0, 'iter': 1100, 'avg_loss': 0.1421022433093956, 'avg_acc': 0.7540569426686411, 'loss': 0.12282160669565201}


                                                      

{'epoch': 0, 'iter': 1200, 'avg_loss': 0.14094105402595494, 'avg_acc': 0.7541144718039938, 'loss': 0.11693091690540314}


                                                      

{'epoch': 0, 'iter': 1300, 'avg_loss': 0.1398501885420629, 'avg_acc': 0.7541772679674177, 'loss': 0.12487543374300003}


                                                      

{'epoch': 0, 'iter': 1400, 'avg_loss': 0.1388213958147166, 'avg_acc': 0.7542463475512573, 'loss': 0.11512469500303268}


                                                        

{'epoch': 0, 'iter': 1500, 'avg_loss': 0.138078941857513, 'avg_acc': 0.7542694353564908, 'loss': 0.11803661286830902}


                                                      

{'epoch': 0, 'iter': 1600, 'avg_loss': 0.13724823778659237, 'avg_acc': 0.7543163847580766, 'loss': 0.13336512446403503}


                                                      

{'epoch': 0, 'iter': 1700, 'avg_loss': 0.13661392278824744, 'avg_acc': 0.7543335936477366, 'loss': 0.10741885006427765}


                                                      

{'epoch': 0, 'iter': 1800, 'avg_loss': 0.1360374567476926, 'avg_acc': 0.7543504271223967, 'loss': 0.13723847270011902}


                                                      

{'epoch': 0, 'iter': 1900, 'avg_loss': 0.13536472197255606, 'avg_acc': 0.754399226879459, 'loss': 0.1225067526102066}


                                                      

{'epoch': 0, 'iter': 2000, 'avg_loss': 0.1348439316483571, 'avg_acc': 0.7544179370199544, 'loss': 0.13906340301036835}


                                                      

{'epoch': 0, 'iter': 2100, 'avg_loss': 0.1341990492085455, 'avg_acc': 0.7544681834130558, 'loss': 0.11312884837388992}


                                                      

{'epoch': 0, 'iter': 2200, 'avg_loss': 0.13375572026832058, 'avg_acc': 0.7544801539586166, 'loss': 0.11215327680110931}


                                                      

{'epoch': 0, 'iter': 2300, 'avg_loss': 0.1333208946864898, 'avg_acc': 0.7544997877164904, 'loss': 0.13544219732284546}


                                                      

{'epoch': 0, 'iter': 2400, 'avg_loss': 0.1328808999915562, 'avg_acc': 0.7545295430937692, 'loss': 0.13052555918693542}


                                                      

{'epoch': 0, 'iter': 2500, 'avg_loss': 0.13247675590160513, 'avg_acc': 0.7545487016070037, 'loss': 0.11861809343099594}


                                                      

{'epoch': 0, 'iter': 2600, 'avg_loss': 0.13207969513765805, 'avg_acc': 0.7545692652108714, 'loss': 0.11904822289943695}


                                                      

{'epoch': 0, 'iter': 2700, 'avg_loss': 0.1316305161796866, 'avg_acc': 0.7546075138062383, 'loss': 0.12410170584917068}


                                                      

{'epoch': 0, 'iter': 2800, 'avg_loss': 0.13122643441525753, 'avg_acc': 0.7546381625223484, 'loss': 0.11399337649345398}


                                                      

{'epoch': 0, 'iter': 2900, 'avg_loss': 0.13088681586587564, 'avg_acc': 0.7546606494491982, 'loss': 0.11792612820863724}


                                                      

{'epoch': 0, 'iter': 3000, 'avg_loss': 0.13060081011551453, 'avg_acc': 0.7546687833232428, 'loss': 0.12652848660945892}


                                                      

{'epoch': 0, 'iter': 3100, 'avg_loss': 0.13027382789072164, 'avg_acc': 0.7546861876752522, 'loss': 0.1069667711853981}


                                                      

{'epoch': 0, 'iter': 3200, 'avg_loss': 0.12993182786104643, 'avg_acc': 0.7547074353385813, 'loss': 0.11970771849155426}


                                                      

{'epoch': 0, 'iter': 3300, 'avg_loss': 0.12963794164813242, 'avg_acc': 0.7547240010249691, 'loss': 0.12228089570999146}


                                                      

{'epoch': 0, 'iter': 3400, 'avg_loss': 0.1293012507411004, 'avg_acc': 0.7547505846166569, 'loss': 0.11957265436649323}


                                                      

{'epoch': 0, 'iter': 3500, 'avg_loss': 0.12905352180879548, 'avg_acc': 0.7547608037500237, 'loss': 0.0859709233045578}


                                                      

{'epoch': 0, 'iter': 3600, 'avg_loss': 0.1287773732686665, 'avg_acc': 0.7547789168205569, 'loss': 0.1309727281332016}


                                                      

{'epoch': 0, 'iter': 3700, 'avg_loss': 0.12850879207837262, 'avg_acc': 0.7547979707880235, 'loss': 0.12619277834892273}


                                                      

{'epoch': 0, 'iter': 3800, 'avg_loss': 0.12823687278183157, 'avg_acc': 0.7548195473365613, 'loss': 0.11586001515388489}


                                                      

{'epoch': 0, 'iter': 3900, 'avg_loss': 0.12798001199303696, 'avg_acc': 0.7548369129203399, 'loss': 0.1187761053442955}


                                                      

{'epoch': 0, 'iter': 4000, 'avg_loss': 0.127708505575611, 'avg_acc': 0.7548618721860673, 'loss': 0.12080484628677368}


                                                      

{'epoch': 0, 'iter': 4100, 'avg_loss': 0.12742567695704532, 'avg_acc': 0.7548899744755174, 'loss': 0.11934375017881393}


                                                      

{'epoch': 0, 'iter': 4200, 'avg_loss': 0.1271878767122799, 'avg_acc': 0.7549086572782847, 'loss': 0.12944120168685913}


                                                      

{'epoch': 0, 'iter': 4300, 'avg_loss': 0.1269489478227206, 'avg_acc': 0.7549270256484173, 'loss': 0.11861283332109451}


                                                      

{'epoch': 0, 'iter': 4400, 'avg_loss': 0.12671065752721866, 'avg_acc': 0.7549509409475641, 'loss': 0.1049809679389}


                                                      

{'epoch': 0, 'iter': 4500, 'avg_loss': 0.1264722993644151, 'avg_acc': 0.7549744398163255, 'loss': 0.1219373270869255}


                                                      

{'epoch': 0, 'iter': 4600, 'avg_loss': 0.12625887726531343, 'avg_acc': 0.7549934557172895, 'loss': 0.1244179829955101}


                                                      

{'epoch': 0, 'iter': 4700, 'avg_loss': 0.12605279321126345, 'avg_acc': 0.7550122712019073, 'loss': 0.13258056342601776}


                                                      

{'epoch': 0, 'iter': 4800, 'avg_loss': 0.12585542111397038, 'avg_acc': 0.7550324283274877, 'loss': 0.11600080132484436}


                                                      

{'epoch': 0, 'iter': 4900, 'avg_loss': 0.1256197716609796, 'avg_acc': 0.75505944909526, 'loss': 0.09763708710670471}


                                                      

{'epoch': 0, 'iter': 5000, 'avg_loss': 0.12540729915981316, 'avg_acc': 0.7550833964676792, 'loss': 0.11712465435266495}


                                                      

{'epoch': 0, 'iter': 5100, 'avg_loss': 0.12520363502287674, 'avg_acc': 0.7551062646928288, 'loss': 0.11308625340461731}


                                                      

{'epoch': 0, 'iter': 5200, 'avg_loss': 0.12497560177524832, 'avg_acc': 0.7551331218447371, 'loss': 0.11620582640171051}


                                                      

{'epoch': 0, 'iter': 5300, 'avg_loss': 0.12475565485719509, 'avg_acc': 0.7551628066737389, 'loss': 0.11228800565004349}


                                                      

{'epoch': 0, 'iter': 5400, 'avg_loss': 0.12452208688194676, 'avg_acc': 0.755194561765318, 'loss': 0.12277583032846451}


                                                      

{'epoch': 0, 'iter': 5500, 'avg_loss': 0.12430577484087996, 'avg_acc': 0.7552240354714448, 'loss': 0.10757356882095337}


                                                      

{'epoch': 0, 'iter': 5600, 'avg_loss': 0.12408934673136461, 'avg_acc': 0.7552535805057743, 'loss': 0.11079442501068115}


                                                      

{'epoch': 0, 'iter': 5700, 'avg_loss': 0.1238796689208297, 'avg_acc': 0.7552825908998798, 'loss': 0.11210567504167557}


                                                      

{'epoch': 0, 'iter': 5800, 'avg_loss': 0.1236773318675464, 'avg_acc': 0.7553118751945423, 'loss': 0.11918868124485016}


                                                      

{'epoch': 0, 'iter': 5900, 'avg_loss': 0.12345157178468248, 'avg_acc': 0.7553486920194976, 'loss': 0.11203055083751678}


                                                      

{'epoch': 0, 'iter': 6000, 'avg_loss': 0.12326451731044001, 'avg_acc': 0.7553758432101934, 'loss': 0.1166946068406105}


                                                      

{'epoch': 0, 'iter': 6100, 'avg_loss': 0.12305527556931224, 'avg_acc': 0.7554092791760052, 'loss': 0.10127455741167068}


                                                      

{'epoch': 0, 'iter': 6200, 'avg_loss': 0.12285890975818964, 'avg_acc': 0.7554394451769451, 'loss': 0.11115330457687378}


                                                      

{'epoch': 0, 'iter': 6300, 'avg_loss': 0.12268963573925307, 'avg_acc': 0.7554657779748697, 'loss': 0.1121760681271553}


                                                      

{'epoch': 0, 'iter': 6400, 'avg_loss': 0.12250809488101334, 'avg_acc': 0.7554952808887283, 'loss': 0.12331507354974747}


                                                      

{'epoch': 0, 'iter': 6500, 'avg_loss': 0.12229372759362107, 'avg_acc': 0.7555321938550357, 'loss': 0.11576683819293976}


                                                      

{'epoch': 0, 'iter': 6600, 'avg_loss': 0.1221098457310005, 'avg_acc': 0.7555638131111467, 'loss': 0.11332907527685165}


                                                      

{'epoch': 0, 'iter': 6700, 'avg_loss': 0.12191598674060337, 'avg_acc': 0.7555969436388197, 'loss': 0.11259106546640396}


                                                      

{'epoch': 0, 'iter': 6800, 'avg_loss': 0.12173336772682246, 'avg_acc': 0.7556275714234278, 'loss': 0.101250559091568}


                                                      

{'epoch': 0, 'iter': 6900, 'avg_loss': 0.12154051303915347, 'avg_acc': 0.7556632815400961, 'loss': 0.09096652269363403}


                                                      

{'epoch': 0, 'iter': 7000, 'avg_loss': 0.12133532890552726, 'avg_acc': 0.755698468715459, 'loss': 0.10560303926467896}


                                                      

{'epoch': 0, 'iter': 7100, 'avg_loss': 0.12120294191941058, 'avg_acc': 0.7557195838849791, 'loss': 0.11418378353118896}


                                                      

{'epoch': 0, 'iter': 7200, 'avg_loss': 0.12099243682501233, 'avg_acc': 0.7557605541575172, 'loss': 0.10313649475574493}


                                                      

{'epoch': 0, 'iter': 7300, 'avg_loss': 0.12080272215696218, 'avg_acc': 0.7557968883673764, 'loss': 0.10322147607803345}


                                                      

{'epoch': 0, 'iter': 7400, 'avg_loss': 0.12060899803194222, 'avg_acc': 0.7558345214819009, 'loss': 0.10917773097753525}


                                                      

{'epoch': 0, 'iter': 7500, 'avg_loss': 0.12042488446395852, 'avg_acc': 0.7558694729374625, 'loss': 0.11956381052732468}


                                                      

{'epoch': 0, 'iter': 7600, 'avg_loss': 0.12024274603494954, 'avg_acc': 0.7559054996634514, 'loss': 0.09100145101547241}


                                                      

{'epoch': 0, 'iter': 7700, 'avg_loss': 0.12007159875663995, 'avg_acc': 0.7559382254512972, 'loss': 0.1070605143904686}


                                                      

{'epoch': 0, 'iter': 7800, 'avg_loss': 0.1198866978706022, 'avg_acc': 0.7559760169117379, 'loss': 0.10087554156780243}


                                                      

{'epoch': 0, 'iter': 7900, 'avg_loss': 0.11970206975103284, 'avg_acc': 0.7560132017861045, 'loss': 0.10921834409236908}


                                                      

{'epoch': 0, 'iter': 8000, 'avg_loss': 0.11952022375922369, 'avg_acc': 0.7560496299866661, 'loss': 0.10029669851064682}


                                                      

{'epoch': 0, 'iter': 8100, 'avg_loss': 0.11936504572148, 'avg_acc': 0.7560801144044145, 'loss': 0.10553261637687683}


                                                      

{'epoch': 0, 'iter': 8200, 'avg_loss': 0.11920752830566277, 'avg_acc': 0.7561113380568883, 'loss': 0.08677563071250916}


                                                      

{'epoch': 0, 'iter': 8300, 'avg_loss': 0.1190501819937477, 'avg_acc': 0.7561436476096836, 'loss': 0.10695136338472366}


                                                      

{'epoch': 0, 'iter': 8400, 'avg_loss': 0.11889963208610935, 'avg_acc': 0.7561744784843409, 'loss': 0.1002187728881836}


                                                      

{'epoch': 0, 'iter': 8500, 'avg_loss': 0.11872146926481097, 'avg_acc': 0.7562111916366486, 'loss': 0.10920249670743942}


                                                      

{'epoch': 0, 'iter': 8600, 'avg_loss': 0.11854013040230421, 'avg_acc': 0.756250028207737, 'loss': 0.10404020547866821}


                                                      

{'epoch': 0, 'iter': 8700, 'avg_loss': 0.11836883545850094, 'avg_acc': 0.7562860868783015, 'loss': 0.10145155340433121}


                                                      

{'epoch': 0, 'iter': 8800, 'avg_loss': 0.1181897678792497, 'avg_acc': 0.7563258664010127, 'loss': 0.08674938231706619}


                                                      

{'epoch': 0, 'iter': 8900, 'avg_loss': 0.11803195735448238, 'avg_acc': 0.7563597324818336, 'loss': 0.1118120327591896}


                                                      

{'epoch': 0, 'iter': 9000, 'avg_loss': 0.11787625047852232, 'avg_acc': 0.7563936565998011, 'loss': 0.09754858911037445}


                                                      

{'epoch': 0, 'iter': 9100, 'avg_loss': 0.11770932570406008, 'avg_acc': 0.7564304608746474, 'loss': 0.09700588881969452}


                                                      

{'epoch': 0, 'iter': 9200, 'avg_loss': 0.11753673793430394, 'avg_acc': 0.7564682219938252, 'loss': 0.09833085536956787}


                                                      

{'epoch': 0, 'iter': 9300, 'avg_loss': 0.11737763172912925, 'avg_acc': 0.7565037971695909, 'loss': 0.10200539976358414}


                                                      

{'epoch': 0, 'iter': 9400, 'avg_loss': 0.1172204043650016, 'avg_acc': 0.7565384937745309, 'loss': 0.0940345972776413}


                                                      

{'epoch': 0, 'iter': 9500, 'avg_loss': 0.11705346751568782, 'avg_acc': 0.7565747234849174, 'loss': 0.11192771792411804}


                                                      

{'epoch': 0, 'iter': 9600, 'avg_loss': 0.11690064422097905, 'avg_acc': 0.7566089270549718, 'loss': 0.10689074546098709}


                                                      

{'epoch': 0, 'iter': 9700, 'avg_loss': 0.11674731566468265, 'avg_acc': 0.7566426466597858, 'loss': 0.10677316784858704}


                                                      

{'epoch': 0, 'iter': 9800, 'avg_loss': 0.11659220670856464, 'avg_acc': 0.7566773761305173, 'loss': 0.0922844260931015}


                                                      

{'epoch': 0, 'iter': 9900, 'avg_loss': 0.11644945448455805, 'avg_acc': 0.7567079943029005, 'loss': 0.10342154651880264}


                                                      

{'epoch': 0, 'iter': 10000, 'avg_loss': 0.11630429293963686, 'avg_acc': 0.7567399550099312, 'loss': 0.11023253202438354}


                                                       

{'epoch': 0, 'iter': 10100, 'avg_loss': 0.11615349675874735, 'avg_acc': 0.7567736243578529, 'loss': 0.08974181115627289}


                                                       

{'epoch': 0, 'iter': 10200, 'avg_loss': 0.11599181823482842, 'avg_acc': 0.7568079284003828, 'loss': 0.10534709692001343}


                                                       

{'epoch': 0, 'iter': 10300, 'avg_loss': 0.11583030361054002, 'avg_acc': 0.7568444271548699, 'loss': 0.09615160524845123}


                                                       

{'epoch': 0, 'iter': 10400, 'avg_loss': 0.11568423802492057, 'avg_acc': 0.7568757037293072, 'loss': 0.09323802590370178}


                                                       

{'epoch': 0, 'iter': 10500, 'avg_loss': 0.11551734767077299, 'avg_acc': 0.7569132958261323, 'loss': 0.11683249473571777}


                                                       

{'epoch': 0, 'iter': 10600, 'avg_loss': 0.11537326367092451, 'avg_acc': 0.7569457526342361, 'loss': 0.09174921363592148}


                                                       

{'epoch': 0, 'iter': 10700, 'avg_loss': 0.11525502371424304, 'avg_acc': 0.7569715872221757, 'loss': 0.09802934527397156}


                                                       

{'epoch': 0, 'iter': 10800, 'avg_loss': 0.11510166549845963, 'avg_acc': 0.7570058921345648, 'loss': 0.09211859107017517}


                                                       

{'epoch': 0, 'iter': 10900, 'avg_loss': 0.11494582238060412, 'avg_acc': 0.7570405956058697, 'loss': 0.08601070195436478}


                                                       

{'epoch': 0, 'iter': 11000, 'avg_loss': 0.11480809778378731, 'avg_acc': 0.7570714259643943, 'loss': 0.10339472442865372}


                                                       

{'epoch': 0, 'iter': 11100, 'avg_loss': 0.11468657595573185, 'avg_acc': 0.7570986940607103, 'loss': 0.08861679583787918}


                                                       

{'epoch': 0, 'iter': 11200, 'avg_loss': 0.11456527933445097, 'avg_acc': 0.7571258839512816, 'loss': 0.09444498270750046}


                                                       

{'epoch': 0, 'iter': 11300, 'avg_loss': 0.11442606024916811, 'avg_acc': 0.7571576601762116, 'loss': 0.08340657502412796}


                                                       

{'epoch': 0, 'iter': 11400, 'avg_loss': 0.11428739866798107, 'avg_acc': 0.757187879375173, 'loss': 0.10274304449558258}


                                                       

{'epoch': 0, 'iter': 11500, 'avg_loss': 0.11414679918338108, 'avg_acc': 0.7572192895330561, 'loss': 0.11727585643529892}


                                                       

{'epoch': 0, 'iter': 11600, 'avg_loss': 0.11401557072171177, 'avg_acc': 0.7572487688981174, 'loss': 0.10514870285987854}


                                                       

{'epoch': 0, 'iter': 11700, 'avg_loss': 0.11388999420909757, 'avg_acc': 0.7572771127321162, 'loss': 0.10918369889259338}


                                                       

{'epoch': 0, 'iter': 11800, 'avg_loss': 0.11375800098059956, 'avg_acc': 0.7573068874305148, 'loss': 0.11260082572698593}


                                                       

{'epoch': 0, 'iter': 11900, 'avg_loss': 0.11362225351293752, 'avg_acc': 0.7573376602610062, 'loss': 0.09459137916564941}


                                                       

{'epoch': 0, 'iter': 12000, 'avg_loss': 0.11350485602772918, 'avg_acc': 0.7573636966232021, 'loss': 0.10649165511131287}


                                                       

{'epoch': 0, 'iter': 12100, 'avg_loss': 0.11337445522894417, 'avg_acc': 0.757392202859161, 'loss': 0.09437194466590881}


                                                       

{'epoch': 0, 'iter': 12200, 'avg_loss': 0.11326569005689428, 'avg_acc': 0.7574155324600279, 'loss': 0.0935046523809433}


                                                       

{'epoch': 0, 'iter': 12300, 'avg_loss': 0.1131432636465089, 'avg_acc': 0.7574421420793942, 'loss': 0.0965500995516777}


                                                       

{'epoch': 0, 'iter': 12400, 'avg_loss': 0.11301784383360008, 'avg_acc': 0.7574700797799037, 'loss': 0.09984453022480011}


                                                       

{'epoch': 0, 'iter': 12500, 'avg_loss': 0.11290388936131451, 'avg_acc': 0.7574955679571025, 'loss': 0.09662211686372757}


                                                       

{'epoch': 0, 'iter': 12600, 'avg_loss': 0.11278307565374324, 'avg_acc': 0.7575228009732146, 'loss': 0.10882196575403214}


                                                       

{'epoch': 0, 'iter': 12700, 'avg_loss': 0.11266755509138127, 'avg_acc': 0.7575484150348241, 'loss': 0.0973634123802185}


                                                       

{'epoch': 0, 'iter': 12800, 'avg_loss': 0.1125414200076488, 'avg_acc': 0.7575774358572527, 'loss': 0.1076808050274849}


                                                       

{'epoch': 0, 'iter': 12900, 'avg_loss': 0.11241548048642318, 'avg_acc': 0.7576051899355776, 'loss': 0.0847194567322731}


                                                       

{'epoch': 0, 'iter': 13000, 'avg_loss': 0.11229828997968518, 'avg_acc': 0.7576312443704093, 'loss': 0.09761401265859604}


                                                       

{'epoch': 0, 'iter': 13100, 'avg_loss': 0.11219066897211398, 'avg_acc': 0.7576552741118637, 'loss': 0.08994071185588837}


                                                       

{'epoch': 0, 'iter': 13200, 'avg_loss': 0.11208979612173037, 'avg_acc': 0.7576767255587882, 'loss': 0.09575797617435455}


                                                       

{'epoch': 0, 'iter': 13300, 'avg_loss': 0.11197325915185574, 'avg_acc': 0.7577036119172398, 'loss': 0.0928146094083786}


                                                       

{'epoch': 0, 'iter': 13400, 'avg_loss': 0.11186156308036732, 'avg_acc': 0.7577290508998134, 'loss': 0.0972261130809784}


                                                       

{'epoch': 0, 'iter': 13500, 'avg_loss': 0.111749999159469, 'avg_acc': 0.7577536185753185, 'loss': 0.09467682987451553}


                                                       

{'epoch': 0, 'iter': 13600, 'avg_loss': 0.11163497955082498, 'avg_acc': 0.7577794201715353, 'loss': 0.08338374644517899}


                                                       

{'epoch': 0, 'iter': 13700, 'avg_loss': 0.11152381820917025, 'avg_acc': 0.7578033381557387, 'loss': 0.09637077897787094}


                                                       

{'epoch': 0, 'iter': 13800, 'avg_loss': 0.1114223650104973, 'avg_acc': 0.7578264672759132, 'loss': 0.09813173115253448}


                                                       

{'epoch': 0, 'iter': 13900, 'avg_loss': 0.11131860079917365, 'avg_acc': 0.7578486976379303, 'loss': 0.10534453392028809}


                                                       

{'epoch': 0, 'iter': 14000, 'avg_loss': 0.11122870429978865, 'avg_acc': 0.7578687577254091, 'loss': 0.10323213040828705}


                                                       

{'epoch': 0, 'iter': 14100, 'avg_loss': 0.11112240925204052, 'avg_acc': 0.7578918776793832, 'loss': 0.09827148914337158}


                                                       

{'epoch': 0, 'iter': 14200, 'avg_loss': 0.1110200745611491, 'avg_acc': 0.7579142355128129, 'loss': 0.10137403756380081}


                                                       

{'epoch': 0, 'iter': 14300, 'avg_loss': 0.11092056083269498, 'avg_acc': 0.7579363040111838, 'loss': 0.10083679854869843}


                                                       

{'epoch': 0, 'iter': 14400, 'avg_loss': 0.11080247790302206, 'avg_acc': 0.7579616089428757, 'loss': 0.08496630191802979}


                                                       

{'epoch': 0, 'iter': 14500, 'avg_loss': 0.11071457087574363, 'avg_acc': 0.7579806130324205, 'loss': 0.09495709091424942}


                                                       

{'epoch': 0, 'iter': 14600, 'avg_loss': 0.11060941815539571, 'avg_acc': 0.7580035076215818, 'loss': 0.09884414821863174}


                                                       

{'epoch': 0, 'iter': 14700, 'avg_loss': 0.1105194371493931, 'avg_acc': 0.7580233888474949, 'loss': 0.09674839675426483}


                                                       

{'epoch': 0, 'iter': 14800, 'avg_loss': 0.11042072824314796, 'avg_acc': 0.758045952462335, 'loss': 0.0900091677904129}


                                                       

{'epoch': 0, 'iter': 14900, 'avg_loss': 0.11033307669652916, 'avg_acc': 0.7580652723983745, 'loss': 0.08418430387973785}


                                                       

{'epoch': 0, 'iter': 15000, 'avg_loss': 0.11024309660794615, 'avg_acc': 0.7580850753901156, 'loss': 0.09867922961711884}


                                                       

{'epoch': 0, 'iter': 15100, 'avg_loss': 0.11014778841661622, 'avg_acc': 0.7581059833716323, 'loss': 0.07304386049509048}


                                                       

{'epoch': 0, 'iter': 15200, 'avg_loss': 0.11005633227435486, 'avg_acc': 0.7581264907910191, 'loss': 0.0846879854798317}


                                                       

{'epoch': 0, 'iter': 15300, 'avg_loss': 0.10996203406847627, 'avg_acc': 0.7581473316182405, 'loss': 0.09135600924491882}


                                                       

{'epoch': 0, 'iter': 15400, 'avg_loss': 0.1098745075454728, 'avg_acc': 0.7581666881144414, 'loss': 0.09456709772348404}


                                                       

{'epoch': 0, 'iter': 15500, 'avg_loss': 0.10978282045177179, 'avg_acc': 0.758187849747274, 'loss': 0.08565270900726318}


                                                       

{'epoch': 0, 'iter': 15600, 'avg_loss': 0.10970133048346788, 'avg_acc': 0.7582054482951439, 'loss': 0.09758547693490982}


                                                       

{'epoch': 0, 'iter': 15700, 'avg_loss': 0.10961079386070453, 'avg_acc': 0.7582259902482761, 'loss': 0.0987098217010498}


                                                       

{'epoch': 0, 'iter': 15800, 'avg_loss': 0.10953018119343323, 'avg_acc': 0.7582439998155696, 'loss': 0.10073285549879074}


                                                       

{'epoch': 0, 'iter': 15900, 'avg_loss': 0.10944371500912496, 'avg_acc': 0.7582632942504384, 'loss': 0.09407927095890045}


                                                       

{'epoch': 0, 'iter': 16000, 'avg_loss': 0.10935889492164694, 'avg_acc': 0.7582822819589324, 'loss': 0.08067578822374344}


                                                       

{'epoch': 0, 'iter': 16100, 'avg_loss': 0.10926829784568086, 'avg_acc': 0.7583029765785425, 'loss': 0.07839294523000717}


                                                       

{'epoch': 0, 'iter': 16200, 'avg_loss': 0.10917464057456695, 'avg_acc': 0.7583237012207754, 'loss': 0.09948913007974625}


                                                       

{'epoch': 0, 'iter': 16300, 'avg_loss': 0.10908911349676363, 'avg_acc': 0.7583426592591482, 'loss': 0.08889659494161606}


                                                       

{'epoch': 0, 'iter': 16400, 'avg_loss': 0.10901941580502669, 'avg_acc': 0.7583583653623698, 'loss': 0.1064402312040329}


                                                       

{'epoch': 0, 'iter': 16500, 'avg_loss': 0.10895981950406905, 'avg_acc': 0.7583705405525479, 'loss': 0.10431725531816483}


                                                       

{'epoch': 0, 'iter': 16600, 'avg_loss': 0.10887611068955287, 'avg_acc': 0.7583891294870123, 'loss': 0.10420001298189163}


                                                       

{'epoch': 0, 'iter': 16700, 'avg_loss': 0.1087947601416669, 'avg_acc': 0.7584073216492141, 'loss': 0.1055116280913353}


                                                       

{'epoch': 0, 'iter': 16800, 'avg_loss': 0.10870336380365185, 'avg_acc': 0.7584275933122586, 'loss': 0.08522643893957138}


                                                       

{'epoch': 0, 'iter': 16900, 'avg_loss': 0.10863295552404482, 'avg_acc': 0.7584423717257045, 'loss': 0.09842368215322495}


                                                       

{'epoch': 0, 'iter': 17000, 'avg_loss': 0.10855554723743481, 'avg_acc': 0.7584591948504336, 'loss': 0.11490710824728012}


                                                       

{'epoch': 0, 'iter': 17100, 'avg_loss': 0.10847911510796696, 'avg_acc': 0.7584761363094672, 'loss': 0.09376880526542664}


                                                       

{'epoch': 0, 'iter': 17200, 'avg_loss': 0.10839360460884162, 'avg_acc': 0.7584943223031452, 'loss': 0.0894496887922287}


                                                       

{'epoch': 0, 'iter': 17300, 'avg_loss': 0.10831845905762051, 'avg_acc': 0.7585109833947262, 'loss': 0.1039181724190712}


                                                       

{'epoch': 0, 'iter': 17400, 'avg_loss': 0.1082349288659523, 'avg_acc': 0.7585296644012748, 'loss': 0.08813481032848358}


                                                       

{'epoch': 0, 'iter': 17500, 'avg_loss': 0.10815746914839092, 'avg_acc': 0.7585473363314088, 'loss': 0.10345838963985443}


                                                       

{'epoch': 0, 'iter': 17600, 'avg_loss': 0.1080964913894511, 'avg_acc': 0.7585606462040732, 'loss': 0.11571861803531647}


                                                       

{'epoch': 0, 'iter': 17700, 'avg_loss': 0.10803521955766204, 'avg_acc': 0.7585736736928829, 'loss': 0.10280651599168777}


                                                       

{'epoch': 0, 'iter': 17800, 'avg_loss': 0.1079540190669226, 'avg_acc': 0.7585921961767915, 'loss': 0.08305365592241287}


                                                       

{'epoch': 0, 'iter': 17900, 'avg_loss': 0.10788275914034773, 'avg_acc': 0.7586077973600354, 'loss': 0.08392823487520218}


                                                       

{'epoch': 0, 'iter': 18000, 'avg_loss': 0.10781509218335278, 'avg_acc': 0.7586228755452714, 'loss': 0.0814346894621849}


                                                       

{'epoch': 0, 'iter': 18100, 'avg_loss': 0.10773492130482239, 'avg_acc': 0.758640763904632, 'loss': 0.09252245724201202}


                                                       

{'epoch': 0, 'iter': 18200, 'avg_loss': 0.10766583703129601, 'avg_acc': 0.7586554638416346, 'loss': 0.09237112104892731}


                                                       

{'epoch': 0, 'iter': 18300, 'avg_loss': 0.10758793176858714, 'avg_acc': 0.7586730125139179, 'loss': 0.10521000623703003}


                                                       

{'epoch': 0, 'iter': 18400, 'avg_loss': 0.10751427626444279, 'avg_acc': 0.7586890462623797, 'loss': 0.08927955478429794}


                                                       

{'epoch': 0, 'iter': 18500, 'avg_loss': 0.10743970007008781, 'avg_acc': 0.7587054247315153, 'loss': 0.09835158288478851}


                                                       

{'epoch': 0, 'iter': 18600, 'avg_loss': 0.10736035656939591, 'avg_acc': 0.7587231062386705, 'loss': 0.083841972053051}


                                                       

{'epoch': 0, 'iter': 18700, 'avg_loss': 0.10728912145611717, 'avg_acc': 0.7587388902862763, 'loss': 0.09694387018680573}


                                                       

{'epoch': 0, 'iter': 18800, 'avg_loss': 0.10720997917254893, 'avg_acc': 0.7587566774428243, 'loss': 0.09749281406402588}


                                                       

{'epoch': 0, 'iter': 18900, 'avg_loss': 0.10714641912867358, 'avg_acc': 0.7587704543215681, 'loss': 0.08654481917619705}


                                                       

{'epoch': 0, 'iter': 19000, 'avg_loss': 0.10707618905963925, 'avg_acc': 0.7587858202783542, 'loss': 0.10806410759687424}


                                                       

{'epoch': 0, 'iter': 19100, 'avg_loss': 0.10700366196621588, 'avg_acc': 0.758802153715806, 'loss': 0.09326416254043579}


                                                       

{'epoch': 0, 'iter': 19200, 'avg_loss': 0.10693733733052121, 'avg_acc': 0.7588165960276626, 'loss': 0.08837232738733292}


EP_train:0:  38%|| 19247/50252 [21:38<34:31, 14.97it/s]

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

import matplotlib.ticker as ticker
import numpy as np
plt.style.use('ggplot')

def showPlot(points):

    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    #loc = ticker.MultipleLocator(base=0.2)
    #ax.yaxis.set_major_locator(loc)
    plt.plot(points)
    plt.show()
    
def savePlot(points, figure_path):
    plt.switch_backend('Agg')
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)
    plt.savefig(figure_path)

In [None]:
def save_results():
    with open(output_model_path + "results.txt","w",encoding='utf-8') as f:
        f.write("----------Parameters list----------\n")
        f.write("hidden=" +str(hidden) + "\n")
        f.write("layers="+str(layers)+"\n")
        f.write("attn_heads="+str(attn_heads)+"\n")
        f.write("seq_len="+str(seq_len)+"\n")
        f.write("batch_size="+str(batch_size)+"\n")
        f.write("epochs="+str(epochs)+"\n")
        f.write("lr="+str(lr)+"\n")
        f.write("\n")
        
        f.write("-"*10+"Train Loss"+"-"*10+"\n")
        f.write(str(trainer.train_losses[-1]))
        f.write("\n")
        
        f.write("-"*10+"Valid Loss"+"-"*10+"\n")
        f.write(str(trainer.valid_losses[-1]))
        f.write("\n")
        
        f.write("-"*10+"Train Acc"+"-"*10+"\n")
        f.write(str(trainer.train_accs[-1]))
        f.write("\n")
        
        f.write("-"*10+"Valid Acc"+"-"*10+"\n")
        f.write(str(trainer.valid_accs[-1]))
    savePlot(trainer.train_accs, output_model_path + "train_accs")
    savePlot(trainer.valid_accs, output_model_path + "valid_accs")
    savePlot(trainer.train_losses, output_model_path + "train_losses")
    savePlot(trainer.valid_losses, output_model_path + "valid_losses")
save_results()

TestデータからDataLoaderを作成

In [None]:
test_dataset = MixDataset(processed_test_txt, vocab, seq_len=seq_len, label_path=label_path, is_train=True)
test_loader = DataLoader(test_dataset, batch_size=1, num_workers=num_workers, shuffle=False)

In [None]:
def test_esd():
    trainer.model.eval()
    with torch.no_grad():
        for batch in test_loader:
            #batch = next(iter(test_loader))
            input_line = []
            text = batch['bert_input'][0]
            output = trainer.model(batch['bert_input'])
            for j in range(len(text)):
                input_line.append(vocab.itos[text[j].item()])
            #print("".join(input_line))
            print(batch['label'].item(), output.argmax(dim=-1).item())
test_esd()

In [None]:
def test_epd():
    trainer.model.eval()
    with torch.no_grad():
        for batch in test_loader:
            #batch = next(iter(test_loader))
            input_line = []
            text = batch['bert_input'][0]
            output = trainer.model(batch['bert_input'])
            for j in range(len(text)):
                input_line.append(vocab.itos[text[j].item()])
            print("".join(input_line))
            #print(batch['token_label'][0], output.argmax(dim=-1))
            for j in range(128):
                predict = output[0,j].argmax(dim=-1).item()
                print(input_line[j], ':',predict)
            break

test_epd()

In [None]:
import matplotlib.pyplot as plt
plt.switch_backend('Agg')
import matplotlib.ticker as ticker
import numpy as np
plt.style.use('ggplot')

def showPlot(points, figure_path):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)
    plt.savefig(figure_path)

In [None]:
showPlot(trainer.train_losses, "./results/"+str(epochs)+"_train.png")
showPlot(trainer.valid_losses, "./results/"+str(epochs)+"_valid.png")