#### IMDb (Internet Movie Database) の映画レビューを使った Positive / Negative 分類

##### 使用するアーキテクチャ
- RNN
- LSTM

In [132]:
import random
import numpy as np
import string
import re
from collections import Counter
from typing import List
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torchtext
from torchtext import data
from torchtext import datasets
from torchtext.vocab import vocab
from torchtext.data.utils import get_tokenizer
from sklearn.metrics import f1_score
from tqdm import tqdm

# Seed every RNG used below so that Restart & Run All reproduces the same
# initial weights and batch order.
seed = 1234
torch.manual_seed(seed)  # PyTorch generator (embeddings, nn layers, DataLoader shuffle)
np.random.seed(seed)     # NumPy global RNG (used by the hand-written RNN initialiser)
random.seed(seed)        # Python's `random` module

In [133]:
class Config:
    """Hyper-parameters and runtime settings shared by the cells below."""

    # mini-batch size used by both DataLoaders
    batch_size = 128
    # maximum number of training epochs (early stopping may end training sooner)
    n_epoch = 20
    # first GPU when one is visible, otherwise the CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # size of each word-embedding vector
    emb_dim = 100
    # size of the RNN / LSTM hidden state
    hid_dim = 50

    # default file EarlyStopper writes the best weights to
    checkpoint_path = 'model/IMDB.pt'


config = Config()

print(config.device)

cuda:0


In [134]:
def torch_log(x):
    """Numerically safe natural log.

    Values below 1e-10 are raised to 1e-10 before taking the log, so log(0)
    gives about -23.03 instead of -inf and the hand-written BCE loss stays finite.
    """
    floor = 1e-10
    safe = x.clamp(min=floor)
    return safe.log()

In [135]:
# Load the IMDb training split and hold out 20% of it as a validation set.
train_iter = datasets.IMDB(split='train')

# The total number of examples is passed explicitly, so the pipe is
# materialised once to count it.
train_iter, valid_iter = train_iter.random_split(
    total_length=len(list(train_iter)),
    weights={"train": 0.8, "valid": 0.2},
    seed=seed,
)

In [136]:
# Build the vocabulary from the training portion only (validation texts are not counted).
tokenizer = get_tokenizer("basic_english")

# Token frequencies over every training review (the label is not needed here).
counter = Counter(
    token
    for _, line in train_iter
    for token in tokenizer(line)
)

# Keep tokens seen at least 25 times; the four special tokens are listed first.
vocabulary = vocab(
    counter,
    min_freq=25,
    specials=('<unk>', '<PAD>', '<BOS>', '<EOS>'),
)
# Any token not in the vocabulary (including ones rarer than min_freq) maps to '<unk>'.
vocabulary.set_default_index(vocabulary['<unk>'])

word_num = len(vocabulary)

print(f"単語種数: {word_num}")
print(*vocabulary.get_itos()[:100], sep=', ')

単語種数: 9937
<unk>, <PAD>, <BOS>, <EOS>, i, am, curious, yellow, is, a, and, pretentious, steaming, pile, ., it, doesn, ', t, matter, what, one, s, political, views, are, because, this, film, can, hardly, be, taken, seriously, on, any, level, as, for, the, claim, that, frontal, male, nudity, an, automatic, ,, isn, true, ve, seen, films, with, granted, they, only, offer, some, fleeting, but, where, ?, nowhere, don, exist, same, goes, those, crappy, cable, shows, swinging, in, not, sight, indie, movies, like, brown, bunny, which, we, re, treated, to, site, of, vincent, johnson, trace, pink, visible, chloe, before, crying, (, or, ), matters


In [137]:
def text_transform(_text, max_length=256):
    """Tokenise a review and convert it to vocabulary ids wrapped in <BOS>/<EOS>.

    Args:
        _text: raw review string.
        max_length: maximum length of the returned id list *including* the
            <BOS>/<EOS> markers; longer reviews are truncated at the end.

    Returns:
        (ids, length): list of int ids and its length.
    """
    # Reserve two slots for <BOS>/<EOS>. Clamp at 0 so that a max_length below 2
    # does not turn the slice bound negative (which would drop tokens from the end).
    body_len = max(max_length - 2, 0)
    ids = [vocabulary[token] for token in tokenizer(_text)][:body_len]
    ids = [vocabulary['<BOS>']] + ids + [vocabulary['<EOS>']]
    return ids, len(ids)


def collate_batch(batch):
    """Collate (label, text) pairs into tensors for the DataLoader.

    Returns:
        labels:  LongTensor (batch,) with values shifted to {0, 1}.
        texts:   LongTensor (batch, max_len), padded with the '<PAD>' id.
        lengths: LongTensor (batch,), unpadded length of each sequence
                 (including <BOS>/<EOS>).
    """
    label_list, text_list, len_seq_list = [], [], []

    for _label, _text in batch:
        # From torchtext 0.15.1 on, IMDB labels are negative=1 / positive=2; shift to {0, 1}.
        label_list.append(_label - 1)

        processed_text, len_seq = text_transform(_text)
        text_list.append(torch.tensor(processed_text))
        len_seq_list.append(len_seq)

    # Pad with the vocabulary's own '<PAD>' id rather than a hard-coded 1, and
    # produce (batch, seq) directly instead of transposing afterwards.
    padded = pad_sequence(text_list, batch_first=True,
                          padding_value=vocabulary['<PAD>'])
    return torch.tensor(label_list), padded, torch.tensor(len_seq_list)



In [138]:
def _make_loader(datapipe, shuffle):
    """Materialise a torchtext datapipe into a list and wrap it in a DataLoader."""
    return DataLoader(
        list(datapipe),
        batch_size=config.batch_size,
        shuffle=shuffle,
        collate_fn=collate_batch,
    )


# Training batches are shuffled; validation order is kept fixed.
train_dataloader = _make_loader(train_iter, shuffle=True)
valid_dataloader = _make_loader(valid_iter, shuffle=False)

#### RNNとLSTMで同じ訓練ループを使うため、早期終了（EarlyStopper）と訓練ループを共通化する

In [139]:
class EarlyStopper:
    """Stop training when the validation F1 score stops improving, keeping the best weights.

    Every epoch whose score is not below the best so far saves
    ``model.state_dict()`` to ``path``. Each epoch scoring strictly below the best
    increments a counter, and once that counter reaches ``patience`` the
    ``early_stop`` flag becomes True.

    Args:
        verbose: print a message on every save/load and every non-improving epoch.
        path: file the best weights are written to (its directory is created if missing).
        patience: number of non-improving epochs tolerated since the last improvement.
    """

    def __init__(self, verbose=True, path=config.checkpoint_path, patience=1):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0          # epochs below the best score since the last save
        self.best_score = None    # best validation F1 so far; None before the first update
        self.__early_stop = False
        # Score of the last saved checkpoint; -inf until something is saved.
        # np.Inf was removed in NumPy 2.0, whereas np.inf works on every version.
        self.val_f1_score = -np.inf
        self.path = path

    @property
    def early_stop(self):
        """True once ``patience`` non-improving epochs have accumulated."""
        return self.__early_stop

    def update(self, val_f1_score, model):
        """Record one epoch's validation F1 and checkpoint ``model`` if it is not worse.

        Ties count as improvements: the model is re-saved and the counter reset.
        """
        if self.best_score is None:
            self.best_score = val_f1_score
            self.save_checkpoint(model, val_f1_score)
        elif val_f1_score < self.best_score:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.__early_stop = True
        else:
            self.best_score = val_f1_score
            self.save_checkpoint(model, val_f1_score)
            self.counter = 0

    def save_checkpoint(self, model, val_f1_score):
        """Write ``model.state_dict()`` to ``self.path``, creating its directory if needed."""
        import os

        if self.verbose:
            print(f'Validation f1score increased ({self.val_f1_score:.6f} --> {val_f1_score:.6f}).  Saving model ...')
        # torch.save does not create parent directories (e.g. 'model/' on a fresh checkout).
        parent = os.path.dirname(self.path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        torch.save(model.state_dict(), self.path)
        self.val_f1_score = val_f1_score

    def load_checkpoint(self, model):
        """Load the last saved weights into ``model`` (in place) and return it."""
        if self.verbose:
            print(f'Loading model from last checkpoint with validation f1score {self.val_f1_score:.6f}')
        model.load_state_dict(torch.load(self.path))
        return model

early_stopping=EarlyStopper(patience=7,verbose=True)

In [140]:
from collections import OrderedDict


def _to_device(label, line, len_seq):
    """Move one collated batch to ``config.device``.

    Returns (t, x, len_seq, len_seq_max), where len_seq_max is a plain int
    holding the longest unpadded length in the batch.
    """
    len_seq_max = int(len_seq.max())
    return (label.to(config.device),
            line.to(config.device),
            len_seq.to(config.device),
            len_seq_max)


def _bce_from_logits(h, t):
    """Binary cross-entropy between logits ``h`` (batch, 1) and 0/1 labels ``t`` (batch,).

    Returns (loss, y), where y are the sigmoid probabilities with shape (batch,).
    """
    # squeeze(-1) rather than squeeze(): a final batch of size 1 must stay 1-D,
    # otherwise .tolist() returns a bare float and list.extend() raises.
    y = torch.sigmoid(h).squeeze(-1)
    loss = -torch.mean(t * torch_log(y) + (1 - t) * torch_log(1 - y))
    return loss, y


def _train_one_epoch(net, optimizer, dataloader):
    """Run one optimisation pass over ``dataloader`` and return the per-batch losses."""
    net.train()
    losses = []
    with tqdm(dataloader) as pbar:
        for label, line, len_seq in pbar:
            t, x, len_seq, len_seq_max = _to_device(label, line, len_seq)

            net.zero_grad()
            h = net(x, len_seq_max, len_seq)
            loss, _ = _bce_from_logits(h, t)
            loss.backward()   # back-propagate the error
            optimizer.step()  # update the parameters

            losses.append(loss.item())
            # Show the running mean, not the whole loss history.
            pbar.set_postfix(OrderedDict(train_loss=np.mean(losses)))
    return losses


def _validate(net, dataloader):
    """Evaluate ``net`` on ``dataloader``; return (per-batch losses, macro F1)."""
    net.eval()
    losses, t_valid, y_pred = [], [], []
    # No autograd graph is needed for evaluation.
    with torch.no_grad(), tqdm(dataloader) as pbar:
        for label, line, len_seq in pbar:
            t, x, len_seq, len_seq_max = _to_device(label, line, len_seq)
            h = net(x, len_seq_max, len_seq)
            loss, y = _bce_from_logits(h, t)
            # Probability above 0.5 -> positive (torch.round maps exactly 0.5 to 0).
            pred = y.round()

            t_valid.extend(t.tolist())
            y_pred.extend(pred.tolist())
            losses.append(loss.item())
            pbar.set_postfix(OrderedDict(valid_loss=np.mean(losses)))
    # F1 over the whole validation set, computed once after all batches.
    return losses, f1_score(t_valid, y_pred, average='macro')


def train(net, optimizer, n_epoch=config.n_epoch,
          train_loader=None, valid_loader=None, stopper=None):
    """Train ``net`` with early stopping on validation macro-F1, then reload the best weights.

    Args:
        net: model called as ``net(x, len_seq_max, len_seq)`` and returning (batch, 1) logits.
        optimizer: optimiser over ``net``'s parameters.
        n_epoch: maximum number of epochs.
        train_loader / valid_loader: DataLoaders yielding (label, text, len_seq);
            default to the global ``train_dataloader`` / ``valid_dataloader``.
        stopper: EarlyStopper instance; defaults to the global ``early_stopping``.

    The best checkpoint is loaded into ``net`` in place at the end.
    """
    train_loader = train_dataloader if train_loader is None else train_loader
    valid_loader = valid_dataloader if valid_loader is None else valid_loader
    stopper = early_stopping if stopper is None else stopper

    for epoch in range(n_epoch):
        losses_train = _train_one_epoch(net, optimizer, train_loader)
        losses_valid, val_f1_score = _validate(net, valid_loader)

        stopper.update(val_f1_score, model=net)

        print('EPOCH: {}, Train Loss: {:.3f}, Valid Loss: {:.3f}, Validation F1: {:.3f}'.format(
            epoch + 1,
            np.mean(losses_train),
            np.mean(losses_valid),
            val_f1_score,
        ))
        if stopper.early_stop:
            print('Early Stopping!')
            break

    stopper.load_checkpoint(net)

#### Embedding層の実装
- 単語を離散的なIDから連続的なベクトル（埋め込みベクトル）に変換する
- 埋め込み行列も他の重みと同様に学習されるパラメータであり、学習後のベクトル同士の近さが単語間の類似度を表す
- 埋め込み行列はランダムに初期化する（ここでは `torch.rand` により一様分布 $[0, 1)$ から生成する）

以下の仮想的な状況を考える<br>
辞書: {'猫': 0, 'は': 1, '魚': 2, 'が': 3, '好き': 4, 'です': 5}

V (埋め込み行列):<br>
0: [0.1, 0.3, 0.2]  # 「猫」の埋め込みベクトル<br>
1: [0.2, 0.1, 0.3]  # 「は」の埋め込みベクトル<br>
2: [0.3, 0.2, 0.1]  # 「魚」の埋め込みベクトル<br>
3: [0.1, 0.2, 0.3]  # 「が」の埋め込みベクトル<br>
4: [0.2, 0.3, 0.1]  # 「好き」の埋め込みベクトル<br>
5: [0.3, 0.1, 0.2]  # 「です」の埋め込みベクトル<br>

これに対して以下の文があったとすると<br>
x (単語ID):<br>
文1: [0, 1, 2, 3, 4, 5]  # 「猫は魚が好きです」<br>
文2: [2, 1, 0, 3, 4, 5]  # 「魚は猫が好きです」<br>

出力のテンソルは以下のようになる<br>

出力テンソル:<br>
文1: [[0.1, 0.3, 0.2], [0.2, 0.1, 0.3], [0.3, 0.2, 0.1], [0.1, 0.2, 0.3], [0.2, 0.3, 0.1], [0.3, 0.1, 0.2]]  # 「猫は魚が好きです」<br>
文2: [[0.3, 0.2, 0.1], [0.2, 0.1, 0.3], [0.1, 0.3, 0.2], [0.1, 0.2, 0.3], [0.2, 0.3, 0.1], [0.3, 0.1, 0.2]]  # 「魚は猫が好きです」

In [141]:
class Embedding(nn.Module):
    """Hand-written word-embedding lookup.

    Holds a trainable ``(vocab_size, emb_dim)`` matrix, initialised uniformly in
    [0, 1), and maps each integer token id to the corresponding row.
    """

    def __init__(self, emb_dim, vocab_size):
        super().__init__()
        init_weights = torch.rand((vocab_size, emb_dim), dtype=torch.float)
        self.embedding_matrix = nn.Parameter(init_weights)

    def forward(self, x):
        """Look up every id in ``x``; the output shape is ``x.shape + (emb_dim,)``."""
        return F.embedding(input=x, weight=self.embedding_matrix)

In [142]:
class RNN(nn.Module):
    """Hand-written Elman RNN: h_t = tanh([h_{t-1}, x_t] @ W + b).

    Args:
        in_dim: size of each input vector x_t.
        hid_dim: size of the hidden state h_t.
    """

    def __init__(self, in_dim, hid_dim):
        super().__init__()
        self.hid_dim = hid_dim
        # Glorot/Xavier uniform bound for W of shape (in_dim + hid_dim, hid_dim):
        # sqrt(6 / (fan_in + fan_out)) with fan_in = in_dim + hid_dim, fan_out = hid_dim.
        glorot = 6 / (in_dim + hid_dim*2)
        self.W = nn.Parameter(torch.tensor(np.random.uniform(
                        low=-np.sqrt(glorot),
                        high=np.sqrt(glorot),
                        size=(in_dim + hid_dim, hid_dim)
                    ).astype('float32')))
        self.b = nn.Parameter(torch.tensor(np.zeros([hid_dim]).astype('float32')))

    def function(self, h, x):
        """One recurrence step: next hidden state from previous state ``h`` and input ``x``."""
        return torch.tanh(torch.matmul(torch.cat([h, x], dim=1), self.W) + self.b)

    def forward(self, x, len_seq_max=0, init_state=None):
        """Run the recurrence over the first ``len_seq_max`` time steps.

        Args:
            x: input of shape (batch, seq, in_dim).
            len_seq_max: number of steps to run; 0 means the whole sequence.
            init_state: optional initial hidden state (batch, hid_dim); zeros if None.

        Returns:
            Tensor (steps, batch, hid_dim) holding the hidden state after each step
            (time-major, unlike the batch-major input).
        """
        x = x.transpose(0, 1)  # -> (seq, batch, in_dim) so that x[i] is time step i
        state = init_state
        if state is None:  # default initial state: zeros
            state = torch.zeros((x.size(1), self.hid_dim)).to(x.device)

        if len_seq_max == 0:
            len_seq_max = x.size(0)

        # Collect the states in a list and stack once. Growing the output with
        # torch.cat would re-copy the whole tensor every step (quadratic in length).
        outputs = []
        for i in range(len_seq_max):
            state = self.function(state, x[i])
            outputs.append(state)

        if not outputs:
            return torch.empty((0, *state.shape), dtype=torch.float, device=x.device)
        return torch.stack(outputs)

In [143]:
class SequenceTaggingNet(nn.Module):
    """Sentiment classifier built from the hand-written Embedding and RNN layers.

    Pipeline: Embedding -> RNN -> Linear(hid_dim, 1). ``forward`` returns one
    logit per sequence, shape (batch, 1).
    """

    def __init__(self, word_num, emb_dim, hid_dim):
        super().__init__()
        # Layers are created in this order so they draw from the seeded RNGs reproducibly.
        self.emb = Embedding(emb_dim, word_num)
        self.rnn = RNN(emb_dim, hid_dim)
        self.linear = nn.Linear(hid_dim, 1)

    def forward(self, x, len_seq_max=0, len_seq=None, init_state=None):
        """Classify a batch of token-id sequences.

        Args:
            x: token ids (batch, seq).
            len_seq_max: number of RNN steps to run (0 = all).
            len_seq: per-sequence unpadded lengths; when given, the state at
                index len_seq - 1 is used, otherwise the state after the last step.
            init_state: optional initial RNN hidden state.
        """
        states = self.rnn(self.emb(x), len_seq_max, init_state)  # (steps, batch, hid_dim)
        if len_seq is None:
            last_state = states[-1]
        else:
            # For every sequence b, take the state at its own last real token, len_seq[b] - 1.
            batch_index = list(range(len(x)))
            last_state = states[len_seq - 1, batch_index, :]
        return self.linear(last_state)
    
# Same classifier as SequenceTaggingNet, but built on PyTorch's nn.RNN module.

class SeqenceTaggingNet2(nn.Module):
    """Sentiment classifier: Embedding -> nn.RNN (batch_first) -> Linear(hid_dim, 1).

    ``forward`` returns one logit per sequence, shape (batch, 1). When ``len_seq``
    is given, that logit is computed from the RNN output at index ``len_seq - 1``.
    """

    def __init__(self, word_num, emb_dim, hid_dim):
        # nn.Module.__init__ must run before sub-modules are assigned; without it
        # the first `self.emb = ...` raises AttributeError.
        super().__init__()
        self.emb=Embedding(emb_dim,word_num)
        self.rnn=nn.RNN(emb_dim,hid_dim,1,batch_first=True)
        self.linear=nn.Linear(hid_dim,1)

    def forward(self, x, len_seq_max=0, len_seq=None, init_state=None):
        """Classify a batch of token-id sequences (same arguments as SequenceTaggingNet)."""
        h = self.emb(x)
        if len_seq_max > 0:
            h, _ = self.rnn(h[:, 0:len_seq_max, :], init_state)
        else:
            h, _ = self.rnn(h, init_state)
        h = h.transpose(0, 1)  # (batch, seq, hid) -> (seq, batch, hid)
        if len_seq is not None:
            # Take the output at the position where each sequence ends, using len_seq.
            h = h[len_seq - 1, list(range(len(x))), :]
        else:
            h = h[-1]
        y = self.linear(h)

        return y


# Correctly spelled alias; the original (misspelled) name is kept for compatibility.
SequenceTaggingNet2 = SeqenceTaggingNet2

In [144]:
# Train the hand-written RNN classifier (uses the global early_stopping instance).
net = SequenceTaggingNet(word_num, config.emb_dim, config.hid_dim).to(config.device)
optimizer = optim.Adam(net.parameters())

train(net, optimizer)

  1%|▏         | 2/157 [00:00<00:16,  9.67it/s, train_loss=[0.7773711681365967, 0.7266910076141357]]

100%|██████████| 157/157 [00:18<00:00,  8.69it/s, train_loss=[0.7773711681365967, 0.7266910076141357, 0.6985199451446533, 0.7040643692016602, 0.7098147869110107, 0.696237325668335, 0.6894975900650024, 0.6836028695106506, 0.6911406517028809, 0.7015177011489868, 0.7033514976501465, 0.6876899600028992, 0.692599892616272, 0.6902141571044922, 0.6902273893356323, 0.6943600177764893, 0.6947373151779175, 0.6907326579093933, 0.691839873790741, 0.6896409392356873, 0.6962161064147949, 0.6905859708786011, 0.6916601657867432, 0.6941063404083252, 0.6852084398269653, 0.6860990524291992, 0.6978064775466919, 0.677242636680603, 0.6982951164245605, 0.7019615173339844, 0.6973644495010376, 0.6887200474739075, 0.6959929466247559, 0.6985212564468384, 0.6910181045532227, 0.7050608396530151, 0.6922539472579956, 0.6885624527931213, 0.6890206336975098, 0.6941676139831543, 0.6854497194290161, 0.6889085173606873, 0.6893584728240967, 0.6926817297935486, 0.6886997222900391, 0.6929694414138794, 0.6879715919494629, 0.

Validation f1score increased (-inf --> 0.462919).  Saving model ...
EPOCH: 1, Train Loss: 0.692, Valid Loss: 0.687, Validation F1: 0.463


100%|██████████| 157/157 [00:17<00:00,  8.75it/s, train_loss=[0.6932771801948547, 0.6553357839584351, 0.6666536331176758, 0.6777725219726562, 0.6985442042350769, 0.676143229007721, 0.6632679104804993, 0.674257755279541, 0.6850392818450928, 0.6649065017700195, 0.6585918664932251, 0.675749659538269, 0.6676524877548218, 0.6720970869064331, 0.6985001564025879, 0.683413565158844, 0.6842633485794067, 0.6856293678283691, 0.6797950863838196, 0.6966791152954102, 0.6754579544067383, 0.6908379793167114, 0.6784708499908447, 0.6740370392799377, 0.6555649042129517, 0.6669313907623291, 0.6809806227684021, 0.6692136526107788, 0.68584144115448, 0.6635119915008545, 0.6540162563323975, 0.6650574207305908, 0.6587514877319336, 0.6916971206665039, 0.697352945804596, 0.664478063583374, 0.6601776480674744, 0.6608054637908936, 0.6431913375854492, 0.6556723713874817, 0.6751958727836609, 0.6462963819503784, 0.6440573930740356, 0.6756724119186401, 0.652511715888977, 0.6478677988052368, 0.632858157157898, 0.661856

Validation f1score increased (0.462919 --> 0.596929).  Saving model ...
EPOCH: 2, Train Loss: 0.656, Valid Loss: 0.638, Validation F1: 0.597


100%|██████████| 157/157 [00:17<00:00,  8.75it/s, train_loss=[0.5630350112915039, 0.5241871476173401, 0.629736065864563, 0.5956842303276062, 0.6103527545928955, 0.5958457589149475, 0.6206311583518982, 0.5689160823822021, 0.5740084052085876, 0.633751392364502, 0.5596684813499451, 0.6153867244720459, 0.5730315446853638, 0.5940850973129272, 0.5989795923233032, 0.5603944063186646, 0.5450886487960815, 0.5557183623313904, 0.6020739078521729, 0.5651879906654358, 0.6323451995849609, 0.5865999460220337, 0.5451043844223022, 0.5626718401908875, 0.5612818002700806, 0.5438638925552368, 0.5920727252960205, 0.49862605333328247, 0.5365114212036133, 0.5437295436859131, 0.5845812559127808, 0.4910160005092621, 0.5091699361801147, 0.572566032409668, 0.5909469723701477, 0.5384236574172974, 0.6586071252822876, 0.6684675216674805, 0.6370872259140015, 0.5252100229263306, 0.5664170384407043, 0.559807300567627, 0.6580307483673096, 0.5687152147293091, 0.6064207553863525, 0.6310041546821594, 0.6019335389137268, 0

EarlyStopping counter: 1 out of 7
EPOCH: 3, Train Loss: 0.636, Valid Loss: 0.684, Validation F1: 0.532


100%|██████████| 157/157 [00:18<00:00,  8.52it/s, train_loss=[0.6607320308685303, 0.6752121448516846, 0.6530832052230835, 0.6716729402542114, 0.6692419052124023, 0.6641315221786499, 0.6654626131057739, 0.6790928840637207, 0.6798938512802124, 0.676298975944519, 0.6492323875427246, 0.6438759565353394, 0.6724892258644104, 0.6478465795516968, 0.6441630125045776, 0.6669182777404785, 0.6553927063941956, 0.6616431474685669, 0.663818359375, 0.6875261068344116, 0.6460745334625244, 0.6759970188140869, 0.636277437210083, 0.6674977540969849, 0.667823076248169, 0.6550562381744385, 0.6462736129760742, 0.6456701159477234, 0.6522848606109619, 0.6492360234260559, 0.666826605796814, 0.655573844909668, 0.6318085789680481, 0.6448764801025391, 0.6564043760299683, 0.6424844264984131, 0.6534214019775391, 0.640613317489624, 0.624362587928772, 0.6386486291885376, 0.6448438167572021, 0.6337491273880005, 0.6589888334274292, 0.602719247341156, 0.5918943881988525, 0.6862702965736389, 0.6475563049316406, 0.62986886

Validation f1score increased (0.596929 --> 0.606960).  Saving model ...
EPOCH: 4, Train Loss: 0.625, Valid Loss: 0.665, Validation F1: 0.607


100%|██████████| 157/157 [00:19<00:00,  8.13it/s, train_loss=[0.5461822152137756, 0.5349159240722656, 0.5710569620132446, 0.5286579132080078, 0.5359379053115845, 0.5237665772438049, 0.5184533596038818, 0.5299915075302124, 0.5864918231964111, 0.5639288425445557, 0.4778372645378113, 0.5285179615020752, 0.5149267911911011, 0.5581369400024414, 0.6725971698760986, 0.5109881162643433, 0.5160168409347534, 0.537002444267273, 0.5139050483703613, 0.5099262595176697, 0.5270183682441711, 0.47869443893432617, 0.5431281328201294, 0.4659014046192169, 0.5520938634872437, 0.5218560695648193, 0.4961414337158203, 0.5502879619598389, 0.49086645245552063, 0.49992233514785767, 0.5986135005950928, 0.5250354409217834, 0.4713640511035919, 0.4445295035839081, 0.49806657433509827, 0.5013097524642944, 0.5423562526702881, 0.468226820230484, 0.5664151310920715, 0.5044514536857605, 0.5516500473022461, 0.5059163570404053, 0.48656386137008667, 0.5826298594474792, 0.5303933620452881, 0.4834469258785248, 0.4468919336795

Validation f1score increased (0.606960 --> 0.624342).  Saving model ...
EPOCH: 5, Train Loss: 0.537, Valid Loss: 0.658, Validation F1: 0.624


100%|██████████| 157/157 [00:18<00:00,  8.27it/s, train_loss=[0.49728909134864807, 0.4802243113517761, 0.5637224316596985, 0.5441957712173462, 0.6120895147323608, 0.5550075173377991, 0.5214131474494934, 0.5317763090133667, 0.5322127342224121, 0.5952542424201965, 0.5079452991485596, 0.5187685489654541, 0.4537237286567688, 0.4507439136505127, 0.4459677040576935, 0.4697801172733307, 0.5008552074432373, 0.5308931469917297, 0.46265995502471924, 0.5258580446243286, 0.4993318021297455, 0.5460420846939087, 0.48065242171287537, 0.4673210680484772, 0.48246458172798157, 0.4823335111141205, 0.43435829877853394, 0.3892406225204468, 0.48503828048706055, 0.4745340347290039, 0.4937714636325836, 0.4568933844566345, 0.47350046038627625, 0.4858398735523224, 0.44126150012016296, 0.4584062993526459, 0.4296499490737915, 0.4587002098560333, 0.459960401058197, 0.38161003589630127, 0.4549712538719177, 0.42256277799606323, 0.4754537343978882, 0.41249507665634155, 0.4561234414577484, 0.40799853205680847, 0.53414

EarlyStopping counter: 1 out of 7
EPOCH: 6, Train Loss: 0.478, Valid Loss: 0.720, Validation F1: 0.589


100%|██████████| 157/157 [00:19<00:00,  8.11it/s, train_loss=[0.44129613041877747, 0.4366747736930847, 0.4269894063472748, 0.37686020135879517, 0.4381883144378662, 0.47146105766296387, 0.5458182096481323, 0.4702723026275635, 0.45232370495796204, 0.44945213198661804, 0.4071561098098755, 0.45051872730255127, 0.41509151458740234, 0.4061104655265808, 0.3826141357421875, 0.4567878842353821, 0.4222868084907532, 0.3736279308795929, 0.31702589988708496, 0.43975692987442017, 0.3498768210411072, 0.4758588671684265, 0.37794560194015503, 0.3615012764930725, 0.3967489004135132, 0.3550674319267273, 0.3416396379470825, 0.3902430534362793, 0.32885199785232544, 0.3531142473220825, 0.4011421203613281, 0.3773566484451294, 0.34887468814849854, 0.3567718267440796, 0.42403239011764526, 0.4112381339073181, 0.31813836097717285, 0.34725967049598694, 0.4037395119667053, 0.34891557693481445, 0.32391357421875, 0.3438565135002136, 0.343812495470047, 0.418401837348938, 0.3466445803642273, 0.44449687004089355, 0.471

Validation f1score increased (0.624342 --> 0.657565).  Saving model ...
EPOCH: 7, Train Loss: 0.401, Valid Loss: 0.692, Validation F1: 0.658


100%|██████████| 157/157 [00:19<00:00,  8.23it/s, train_loss=[0.3195885419845581, 0.325825959444046, 0.2835841774940491, 0.3621329069137573, 0.3102928698062897, 0.35718756914138794, 0.24911624193191528, 0.2836964726448059, 0.3460092544555664, 0.29754799604415894, 0.28317946195602417, 0.3179391026496887, 0.31777238845825195, 0.2998524308204651, 0.33708012104034424, 0.32614666223526, 0.3332980275154114, 0.2714913487434387, 0.2895217537879944, 0.4696432948112488, 0.3400498032569885, 0.2745872139930725, 0.23661547899246216, 0.3299432396888733, 0.2339407354593277, 0.2866974472999573, 0.25831907987594604, 0.2666015028953552, 0.3117761015892029, 0.2917984127998352, 0.22725391387939453, 0.3004572093486786, 0.28310248255729675, 0.268072247505188, 0.29786068201065063, 0.3265193700790405, 0.3093796968460083, 0.4085089862346649, 0.34553706645965576, 0.3110625147819519, 0.21718262135982513, 0.27464330196380615, 0.3051760196685791, 0.24745985865592957, 0.3059212565422058, 0.27084267139434814, 0.3009

EarlyStopping counter: 1 out of 7
EPOCH: 8, Train Loss: 0.319, Valid Loss: 0.778, Validation F1: 0.638


100%|██████████| 157/157 [00:18<00:00,  8.28it/s, train_loss=[0.3031945526599884, 0.30386245250701904, 0.3106968402862549, 0.3235756754875183, 0.21890321373939514, 0.3188316822052002, 0.34805506467819214, 0.22100821137428284, 0.28372687101364136, 0.26216959953308105, 0.27071839570999146, 0.1725735068321228, 0.19555959105491638, 0.23100662231445312, 0.18887510895729065, 0.2862181067466736, 0.3383558392524719, 0.20013464987277985, 0.18038882315158844, 0.23424556851387024, 0.2693289518356323, 0.2647663354873657, 0.30481261014938354, 0.2775879502296448, 0.21578532457351685, 0.28452810645103455, 0.2645101547241211, 0.19234395027160645, 0.2951121926307678, 0.21679899096488953, 0.2032168060541153, 0.26867109537124634, 0.20180436968803406, 0.168643981218338, 0.21534910798072815, 0.2648758292198181, 0.23678074777126312, 0.2779906094074249, 0.23149113357067108, 0.20460514724254608, 0.27042078971862793, 0.234024316072464, 0.23017506301403046, 0.22348880767822266, 0.3286520838737488, 0.32654336094

EarlyStopping counter: 2 out of 7
EPOCH: 9, Train Loss: 0.266, Valid Loss: 0.847, Validation F1: 0.651


100%|██████████| 157/157 [00:19<00:00,  8.19it/s, train_loss=[0.21733208000659943, 0.26378047466278076, 0.33291351795196533, 0.29843443632125854, 0.21620520949363708, 0.23912829160690308, 0.2001212239265442, 0.16360700130462646, 0.2510431408882141, 0.20386190712451935, 0.2531743347644806, 0.22249647974967957, 0.1686115562915802, 0.2010621279478073, 0.2352772355079651, 0.15620727837085724, 0.19418221712112427, 0.26923155784606934, 0.1957111656665802, 0.2043933868408203, 0.26026493310928345, 0.2444716989994049, 0.19588662683963776, 0.16855618357658386, 0.22464638948440552, 0.2245205193758011, 0.26685282588005066, 0.17435681819915771, 0.24570411443710327, 0.23839694261550903, 0.13057753443717957, 0.25888824462890625, 0.2698751986026764, 0.15266898274421692, 0.1636943817138672, 0.16094309091567993, 0.2127777487039566, 0.14107000827789307, 0.18581676483154297, 0.1680508404970169, 0.15201853215694427, 0.21204952895641327, 0.1996566653251648, 0.18841874599456787, 0.1716928482055664, 0.1550276

Validation f1score increased (0.657565 --> 0.726960).  Saving model ...
EPOCH: 10, Train Loss: 0.214, Valid Loss: 0.752, Validation F1: 0.727


100%|██████████| 157/157 [00:18<00:00,  8.43it/s, train_loss=[0.18854737281799316, 0.1580294519662857, 0.12909270823001862, 0.170453280210495, 0.18063293397426605, 0.09869541227817535, 0.13003839552402496, 0.1535877287387848, 0.1625174731016159, 0.1389649510383606, 0.16508899629116058, 0.21594107151031494, 0.14990249276161194, 0.13702884316444397, 0.1299055814743042, 0.09917865693569183, 0.20006324350833893, 0.136514350771904, 0.20717783272266388, 0.23413079977035522, 0.12039853632450104, 0.10353884845972061, 0.15206179022789001, 0.09176735579967499, 0.12783929705619812, 0.11562863737344742, 0.1184997484087944, 0.216262087225914, 0.29365983605384827, 0.18568825721740723, 0.1992916762828827, 0.189474955201149, 0.18686458468437195, 0.2064509093761444, 0.14199046790599823, 0.1850806474685669, 0.29770198464393616, 0.29107898473739624, 0.24366001784801483, 0.15134724974632263, 0.22121760249137878, 0.1985258013010025, 0.20993061363697052, 0.16122300922870636, 0.2462310492992401, 0.1846741735

EarlyStopping counter: 1 out of 7
EPOCH: 11, Train Loss: 0.304, Valid Loss: 0.915, Validation F1: 0.575


100%|██████████| 157/157 [00:18<00:00,  8.30it/s, train_loss=[0.4509819746017456, 0.30308011174201965, 0.2898770272731781, 0.31036877632141113, 0.33233073353767395, 0.3349110782146454, 0.25880712270736694, 0.3222244083881378, 0.32139551639556885, 0.29454463720321655, 0.33593785762786865, 0.3025300204753876, 0.3797544240951538, 0.3157956898212433, 0.29411518573760986, 0.35997143387794495, 0.2958695888519287, 0.279832124710083, 0.39898842573165894, 0.3078464865684509, 0.2778915762901306, 0.2710028290748596, 0.30272889137268066, 0.276279091835022, 0.31925076246261597, 0.3419024348258972, 0.3207859694957733, 0.22496634721755981, 0.26431286334991455, 0.35651326179504395, 0.33250272274017334, 0.24924568831920624, 0.32655495405197144, 0.23820456862449646, 0.30387744307518005, 0.27400118112564087, 0.3187780976295471, 0.2903355360031128, 0.28533250093460083, 0.3021695911884308, 0.34322023391723633, 0.2644239068031311, 0.2293204963207245, 0.321410596370697, 0.26026251912117004, 0.342945456504821

EarlyStopping counter: 2 out of 7
EPOCH: 12, Train Loss: 0.302, Valid Loss: 0.860, Validation F1: 0.653


100%|██████████| 157/157 [00:20<00:00,  7.78it/s, train_loss=[0.2016838788986206, 0.13150857388973236, 0.13916870951652527, 0.1807960867881775, 0.24917155504226685, 0.21489062905311584, 0.16071277856826782, 0.2080216109752655, 0.23716789484024048, 0.14459377527236938, 0.12239202111959457, 0.24123868346214294, 0.1776370108127594, 0.20835599303245544, 0.17259293794631958, 0.1420171558856964, 0.27260899543762207, 0.22417788207530975, 0.15157108008861542, 0.3556612730026245, 0.2893933653831482, 0.21640026569366455, 0.20639535784721375, 0.24449732899665833, 0.2653694450855255, 0.15491077303886414, 0.29234832525253296, 0.34495651721954346, 0.3210882544517517, 0.4263342618942261, 0.33692285418510437, 0.3055422902107239, 0.33636265993118286, 0.29400384426116943, 0.3506644070148468, 0.2796742618083954, 0.3281111717224121, 0.259827584028244, 0.301851361989975, 0.3451244831085205, 0.3479507267475128, 0.2472432553768158, 0.31998252868652344, 0.26308420300483704, 0.2911165952682495, 0.3566944897174

EarlyStopping counter: 3 out of 7
EPOCH: 13, Train Loss: 0.266, Valid Loss: 1.107, Validation F1: 0.600


100%|██████████| 157/157 [00:19<00:00,  8.06it/s, train_loss=[0.21136587858200073, 0.2764263451099396, 0.24601958692073822, 0.24573761224746704, 0.23168964684009552, 0.2864004075527191, 0.25284451246261597, 0.2589685022830963, 0.18570329248905182, 0.1948947310447693, 0.15161150693893433, 0.13712261617183685, 0.17509785294532776, 0.2441333830356598, 0.21294300258159637, 0.1903153359889984, 0.21462517976760864, 0.2207510769367218, 0.2652394771575928, 0.12075112760066986, 0.13858070969581604, 0.141974538564682, 0.1867958903312683, 0.16068045794963837, 0.14060208201408386, 0.17179298400878906, 0.22145861387252808, 0.1854582130908966, 0.1482173204421997, 0.20349639654159546, 0.11144928634166718, 0.1392236053943634, 0.10620464384555817, 0.16880889236927032, 0.22475531697273254, 0.12844964861869812, 0.2055942714214325, 0.21418803930282593, 0.12257325649261475, 0.17718863487243652, 0.17419877648353577, 0.18359749019145966, 0.22769811749458313, 0.11761102825403214, 0.15885306894779205, 0.239018

EarlyStopping counter: 4 out of 7
EPOCH: 14, Train Loss: 0.192, Valid Loss: 1.055, Validation F1: 0.638


100%|██████████| 157/157 [00:19<00:00,  8.22it/s, train_loss=[0.17925819754600525, 0.14192244410514832, 0.11400061845779419, 0.1363220065832138, 0.10994052141904831, 0.1226305291056633, 0.175301194190979, 0.09774549305438995, 0.08642429858446121, 0.10551296174526215, 0.14810240268707275, 0.1132887601852417, 0.11133237183094025, 0.15577229857444763, 0.1300571858882904, 0.10001590847969055, 0.14152005314826965, 0.1690429300069809, 0.14197930693626404, 0.12465986609458923, 0.07026758044958115, 0.11953255534172058, 0.15473046898841858, 0.1300182342529297, 0.09839567542076111, 0.12750646471977234, 0.18259581923484802, 0.144185870885849, 0.23999759554862976, 0.1783418357372284, 0.11249538511037827, 0.18657232820987701, 0.3638489246368408, 0.32243672013282776, 0.2266806960105896, 0.16040928661823273, 0.09416324645280838, 0.14565350115299225, 0.19983696937561035, 0.2674548327922821, 0.1426994800567627, 0.180955171585083, 0.20847564935684204, 0.17400982975959778, 0.20074231922626495, 0.13404726

EarlyStopping counter: 5 out of 7
EPOCH: 15, Train Loss: 0.146, Valid Loss: 1.111, Validation F1: 0.640


100%|██████████| 157/157 [00:18<00:00,  8.68it/s, train_loss=[0.1190328299999237, 0.19095031917095184, 0.06827342510223389, 0.08685432374477386, 0.13925287127494812, 0.08377478271722794, 0.07908552885055542, 0.10248889774084091, 0.13851036131381989, 0.0939319059252739, 0.11402024328708649, 0.16211852431297302, 0.19203394651412964, 0.16785213351249695, 0.18779706954956055, 0.14278572797775269, 0.06735105812549591, 0.06565693020820618, 0.22068414092063904, 0.21883746981620789, 0.10935014486312866, 0.0713667944073677, 0.07079748809337616, 0.16292545199394226, 0.18652303516864777, 0.15458446741104126, 0.25169092416763306, 0.13351601362228394, 0.09409520030021667, 0.08960872143507004, 0.14156392216682434, 0.23259907960891724, 0.1847418248653412, 0.1002504974603653, 0.13508783280849457, 0.10265891253948212, 0.12965483963489532, 0.08338938653469086, 0.23771797120571136, 0.11080894619226456, 0.07872457802295685, 0.08660173416137695, 0.06586308032274246, 0.16691647469997406, 0.16323524713516235

EarlyStopping counter: 6 out of 7
EPOCH: 16, Train Loss: 0.122, Valid Loss: 1.187, Validation F1: 0.642


100%|██████████| 157/157 [00:18<00:00,  8.72it/s, train_loss=[0.12594513595104218, 0.08535489439964294, 0.16008292138576508, 0.09861099720001221, 0.07156838476657867, 0.08432057499885559, 0.05961477383971214, 0.13416185975074768, 0.20459331572055817, 0.18334564566612244, 0.043547727167606354, 0.07637998461723328, 0.16194242238998413, 0.13841721415519714, 0.18351392447948456, 0.11682888865470886, 0.09305907040834427, 0.11410879343748093, 0.10508385300636292, 0.23490068316459656, 0.17295435070991516, 0.18747764825820923, 0.06838235259056091, 0.07359451055526733, 0.08884681761264801, 0.15003186464309692, 0.062111712992191315, 0.07760965079069138, 0.07516790181398392, 0.06014012172818184, 0.10500684380531311, 0.131808802485466, 0.12141262739896774, 0.09479940682649612, 0.07499152421951294, 0.089419424533844, 0.04571559652686119, 0.05667939782142639, 0.07369452714920044, 0.11084583401679993, 0.17621377110481262, 0.0401175394654274, 0.08132348954677582, 0.1238490641117096, 0.0784286558628082

EarlyStopping counter: 7 out of 7
EPOCH: 17, Train Loss: 0.101, Valid Loss: 1.320, Validation F1: 0.631
Early Stopping!
Loading model from last checkpoint with validation f1score 0.726960





### LSTMを用いた実装

- 入力ゲート: $\hspace{20mm}\boldsymbol{i}_t = \mathrm{\sigma} \left(\boldsymbol{W}_i \left[\begin{array}{c} \boldsymbol{x}_t \\ \boldsymbol{h}_{t-1} \end{array}\right] + \boldsymbol{b}_i\right)$
- 忘却ゲート: $\hspace{20mm}\boldsymbol{f}_t = \mathrm{\sigma} \left(\boldsymbol{W}_f \left[\begin{array}{c} \boldsymbol{x}_t \\ \boldsymbol{h}_{t-1} \end{array}\right] + \boldsymbol{b}_f\right)$  
- 出力ゲート: $\hspace{20mm}\boldsymbol{o}_t = \mathrm{\sigma} \left(\boldsymbol{W}_o \left[\begin{array}{c} \boldsymbol{x}_t \\ \boldsymbol{h}_{t-1} \end{array}\right] + \boldsymbol{b}_o\right)$  
- セル:　　　 $\hspace{20mm}\boldsymbol{c}_t = \boldsymbol{f}_t \odot \boldsymbol{c}_{t-1} + \boldsymbol{i}_t \odot \tanh \left(\boldsymbol{W}_c \left[\begin{array}{c} \boldsymbol{x}_t \\ \boldsymbol{h}_{t-1} \end{array}\right] + \boldsymbol{b}_c\right)$
- 隠れ状態: 　$\hspace{20mm}\boldsymbol{h}_t = \boldsymbol{o}_t \odot \tanh \left(\boldsymbol{c}_t \right)$

In [145]:
class SequenceTaggingNet4(nn.Module):
    """Single-layer LSTM classifier producing one output value per token-id sequence.

    Token ids are embedded, run through a unidirectional ``nn.LSTM`` with
    ``batch_first=True``, and the LSTM output at one selected time step per
    sequence is projected by a linear layer to a single value.

    Args:
        word_num: vocabulary size (number of embedding rows).
        emb_dim: embedding dimension.
        hid_dim: LSTM hidden-state dimension.
    """

    def __init__(self, word_num, emb_dim, hid_dim):
        super().__init__()
        self.emb = nn.Embedding(word_num, emb_dim)
        # One-layer LSTM; batch_first=True means tensors are laid out (batch, seq, feature).
        self.lstm = nn.LSTM(emb_dim, hid_dim, 1, batch_first=True)
        self.linear = nn.Linear(hid_dim, 1)

    def forward(self, x, len_seq_max=0, len_seq=None, init_state=None):
        """Return one linear output per sequence.

        Args:
            x: integer token ids indexed as (batch, seq).
            len_seq_max: if > 0, only the first ``len_seq_max`` steps are fed to the LSTM.
            len_seq: optional per-sequence lengths (tensor or sequence of ints). When given,
                the LSTM output at step ``len_seq - 1`` of each sequence is used. Lengths are
                clamped to ``[1, processed_steps]``, so truncation via ``len_seq_max`` cannot
                index past the processed steps and a zero length cannot wrap to the last step.
                When ``None``, the last processed step is used for every sequence
                (NOTE(review): for right-padded batches that step may be padding — prefer
                passing ``len_seq``).
            init_state: optional ``(h_0, c_0)`` forwarded unchanged to ``nn.LSTM``.

        Returns:
            Tensor of shape (batch, 1): the raw linear-layer output (no activation applied here).
        """
        h = self.emb(x)
        if len_seq_max > 0:
            h = h[:, :len_seq_max, :]
        h, _ = self.lstm(h, init_state)  # (batch, steps, hid_dim)

        if len_seq is not None:
            n_steps = h.size(1)
            # Clamp so every index points at a real processed step, then convert length -> index.
            last = torch.as_tensor(len_seq, device=h.device).long().clamp(1, n_steps) - 1
            batch_idx = torch.arange(h.size(0), device=h.device)
            h = h[batch_idx, last, :]
        else:
            h = h[:, -1, :]

        return self.linear(h)

In [146]:

# Train the LSTM classifier with a freshly initialised model and optimizer.
# Per the logged output, EarlyStopper (defined in an earlier cell) tracks validation F1,
# saves a checkpoint on improvement (presumably to `path`) and stops after `patience` stale epochs.
# NOTE(review): 'IMBD' looks like a typo for 'IMDB'; kept as-is in case other cells load this file.
# NOTE(review): train() is not given early_stopping, so it presumably reads this global — confirm.
early_stopping = EarlyStopper(
    patience=7,
    verbose=True,
    path='./model/IMBD_LSTM.pt',
)
net = SequenceTaggingNet4(word_num, config.emb_dim, config.hid_dim).to(config.device)
optimizer = optim.Adam(net.parameters())

train(net, optimizer)

100%|██████████| 157/157 [00:05<00:00, 29.46it/s, train_loss=[0.698729395866394, 0.69236159324646, 0.6948516964912415, 0.6941598653793335, 0.6915837526321411, 0.6916572451591492, 0.695105791091919, 0.6923444271087646, 0.693397581577301, 0.6929824352264404, 0.6925323605537415, 0.695821225643158, 0.6913526058197021, 0.6946244835853577, 0.6820111870765686, 0.6821329593658447, 0.6947636604309082, 0.6897933483123779, 0.6928807497024536, 0.7040039896965027, 0.6910910606384277, 0.7030431628227234, 0.6881298422813416, 0.6914405822753906, 0.6949384212493896, 0.685042142868042, 0.6884819269180298, 0.6966086030006409, 0.6988370418548584, 0.686540961265564, 0.6900820136070251, 0.6863418817520142, 0.6941786408424377, 0.6828898191452026, 0.6957176923751831, 0.6930797696113586, 0.6960475444793701, 0.6846505999565125, 0.7019962072372437, 0.6980888843536377, 0.6930607557296753, 0.6872420310974121, 0.6806379556655884, 0.6852893829345703, 0.6875762939453125, 0.6945504546165466, 0.6893776059150696, 0.6940

Validation f1score increased (-inf --> 0.618530).  Saving model ...
EPOCH: 1, Train Loss: 0.674, Valid Loss: 0.635, Validation F1: 0.619


100%|██████████| 157/157 [00:05<00:00, 29.28it/s, train_loss=[0.6231216192245483, 0.531691312789917, 0.6251899600028992, 0.5958290100097656, 0.49257728457450867, 0.6170104146003723, 0.5878537893295288, 0.6336519122123718, 0.6445177793502808, 0.6853920817375183, 0.6634631752967834, 0.6703441739082336, 0.6430200338363647, 0.6143584251403809, 0.6247613430023193, 0.6119644641876221, 0.6167840957641602, 0.6253660321235657, 0.6049641966819763, 0.5657706260681152, 0.5639371871948242, 0.6088065505027771, 0.6339655518531799, 0.6392848491668701, 0.6250555515289307, 0.642001748085022, 0.5642144083976746, 0.5647358894348145, 0.5980889797210693, 0.575366735458374, 0.561823844909668, 0.6403914093971252, 0.5697743892669678, 0.6176519393920898, 0.5504074096679688, 0.713493824005127, 0.6137531995773315, 0.6155554056167603, 0.5681366324424744, 0.6174235939979553, 0.5614286661148071, 0.5475415587425232, 0.5962712168693542, 0.5776998996734619, 0.5962190628051758, 0.5572700500488281, 0.5584982633590698, 0.

Validation f1score increased (0.618530 --> 0.710577).  Saving model ...
EPOCH: 2, Train Loss: 0.599, Valid Loss: 0.585, Validation F1: 0.711


100%|██████████| 157/157 [00:05<00:00, 29.49it/s, train_loss=[0.5158020257949829, 0.5605247616767883, 0.6384954452514648, 0.5680780410766602, 0.533652663230896, 0.5277923941612244, 0.5930191874504089, 0.5198675394058228, 0.5698422789573669, 0.4568105936050415, 0.5726643800735474, 0.5480008125305176, 0.5547996759414673, 0.551186203956604, 0.6438713669776917, 0.5180065631866455, 0.6609342098236084, 0.6225119233131409, 0.5051686763763428, 0.6304517984390259, 0.5378227829933167, 0.5102876424789429, 0.49635419249534607, 0.5896368026733398, 0.5949442982673645, 0.6382412314414978, 0.6074955463409424, 0.5873839259147644, 0.6252151727676392, 0.5053228139877319, 0.5580157041549683, 0.4482951760292053, 0.5719900131225586, 0.5900362730026245, 0.5571698546409607, 0.5793827772140503, 0.6133642196655273, 0.570102334022522, 0.5539973974227905, 0.5430811643600464, 0.4532851576805115, 0.48410049080848694, 0.482596755027771, 0.5175492763519287, 0.5347731113433838, 0.5517439842224121, 0.5045017004013062, 

Validation f1score increased (0.710577 --> 0.766393).  Saving model ...
EPOCH: 3, Train Loss: 0.530, Valid Loss: 0.514, Validation F1: 0.766


100%|██████████| 157/157 [00:05<00:00, 29.39it/s, train_loss=[0.5617698431015015, 0.5293222665786743, 0.5467472076416016, 0.4572775661945343, 0.44520318508148193, 0.4889025390148163, 0.5060163140296936, 0.4226808547973633, 0.3977445363998413, 0.563907265663147, 0.438960999250412, 0.5483431816101074, 0.502405047416687, 0.5211468935012817, 0.49309849739074707, 0.5028129816055298, 0.4985419809818268, 0.5167187452316284, 0.4707568287849426, 0.5225836038589478, 0.5630233287811279, 0.5062241554260254, 0.5390112400054932, 0.49851423501968384, 0.44030117988586426, 0.3922829329967499, 0.5624959468841553, 0.45108652114868164, 0.47109296917915344, 0.5079076290130615, 0.4439154863357544, 0.5058572292327881, 0.46074551343917847, 0.5271002054214478, 0.5388810634613037, 0.4933779835700989, 0.4080665409564972, 0.5191569328308105, 0.44077634811401367, 0.406653493642807, 0.4904279410839081, 0.5393408536911011, 0.43017706274986267, 0.4687937796115875, 0.535585880279541, 0.3809807002544403, 0.442286074161

Validation f1score increased (0.766393 --> 0.795515).  Saving model ...
EPOCH: 4, Train Loss: 0.463, Valid Loss: 0.477, Validation F1: 0.796


100%|██████████| 157/157 [00:05<00:00, 29.76it/s, train_loss=[0.5328482389450073, 0.44458532333374023, 0.3821716904640198, 0.3707275092601776, 0.38371413946151733, 0.5477564334869385, 0.5364640355110168, 0.3599517345428467, 0.4418720602989197, 0.38077446818351746, 0.36073601245880127, 0.45753419399261475, 0.41742056608200073, 0.359405517578125, 0.45648786425590515, 0.4168592393398285, 0.39328283071517944, 0.4364064931869507, 0.44591307640075684, 0.48632165789604187, 0.39262253046035767, 0.38416117429733276, 0.3987298309803009, 0.40836745500564575, 0.3508427143096924, 0.4078233242034912, 0.5075550079345703, 0.40276646614074707, 0.35330164432525635, 0.39952895045280457, 0.48533394932746887, 0.4163054823875427, 0.423323392868042, 0.4311296045780182, 0.444754034280777, 0.32599717378616333, 0.35281121730804443, 0.3969133198261261, 0.36945414543151855, 0.3994481563568115, 0.35208773612976074, 0.3774513900279999, 0.3368452191352844, 0.4408305287361145, 0.40772396326065063, 0.40800005197525024

Validation f1score increased (0.795515 --> 0.803995).  Saving model ...
EPOCH: 5, Train Loss: 0.398, Valid Loss: 0.447, Validation F1: 0.804


100%|██████████| 157/157 [00:05<00:00, 29.67it/s, train_loss=[0.3536042273044586, 0.31991109251976013, 0.3662112355232239, 0.3359900116920471, 0.3653064966201782, 0.3404502272605896, 0.35128462314605713, 0.3316754102706909, 0.4039442837238312, 0.2794460654258728, 0.26820239424705505, 0.3588349223136902, 0.466321736574173, 0.35809755325317383, 0.32329148054122925, 0.38292133808135986, 0.28054386377334595, 0.32953941822052, 0.3531590700149536, 0.3532572090625763, 0.3044546842575073, 0.456038236618042, 0.3006269335746765, 0.4043506383895874, 0.27549970149993896, 0.4171064496040344, 0.40703701972961426, 0.36343586444854736, 0.4019814133644104, 0.5178364515304565, 0.41959869861602783, 0.4649195075035095, 0.3380431532859802, 0.4043262302875519, 0.4912799596786499, 0.3316957950592041, 0.3868902325630188, 0.4262259006500244, 0.3792470097541809, 0.44472724199295044, 0.3836967349052429, 0.39085501432418823, 0.3097766041755676, 0.3843565583229065, 0.39522799849510193, 0.34941619634628296, 0.38591

Validation f1score increased (0.803995 --> 0.825185).  Saving model ...
EPOCH: 6, Train Loss: 0.372, Valid Loss: 0.420, Validation F1: 0.825


100%|██████████| 157/157 [00:05<00:00, 29.59it/s, train_loss=[0.2818051874637604, 0.3601857125759125, 0.4435563385486603, 0.39016294479370117, 0.32196587324142456, 0.43340396881103516, 0.31476372480392456, 0.3832433223724365, 0.2815614938735962, 0.32230427861213684, 0.3458278775215149, 0.3434099853038788, 0.3390394151210785, 0.33290594816207886, 0.3200798034667969, 0.22669407725334167, 0.2737419605255127, 0.3224413990974426, 0.2938989996910095, 0.21783016622066498, 0.35729146003723145, 0.34391623735427856, 0.3979162871837616, 0.36213964223861694, 0.30549725890159607, 0.3848404586315155, 0.24401021003723145, 0.19342730939388275, 0.2847858667373657, 0.2791464030742645, 0.3063986897468567, 0.31804367899894714, 0.32655274868011475, 0.3037671148777008, 0.39947959780693054, 0.23402778804302216, 0.3512642979621887, 0.2701064646244049, 0.2765745222568512, 0.3177591860294342, 0.2544547915458679, 0.26468440890312195, 0.31139951944351196, 0.3499857485294342, 0.33873897790908813, 0.349993079900741

EarlyStopping counter: 1 out of 7
EPOCH: 7, Train Loss: 0.322, Valid Loss: 0.429, Validation F1: 0.816


100%|██████████| 157/157 [00:05<00:00, 29.76it/s, train_loss=[0.3274475336074829, 0.22702202200889587, 0.2956578731536865, 0.24409961700439453, 0.2559770941734314, 0.35201752185821533, 0.32790273427963257, 0.3094385862350464, 0.3806529939174652, 0.22904352843761444, 0.2927340567111969, 0.23669973015785217, 0.2680363357067108, 0.23616720736026764, 0.2253655046224594, 0.25885194540023804, 0.2778559923171997, 0.24663569033145905, 0.21167504787445068, 0.17318373918533325, 0.33595776557922363, 0.18871194124221802, 0.28979283571243286, 0.256747305393219, 0.31855887174606323, 0.2699474096298218, 0.32021307945251465, 0.3132345378398895, 0.2372947335243225, 0.17635495960712433, 0.2976287305355072, 0.33035728335380554, 0.19304263591766357, 0.2987405061721802, 0.23546230792999268, 0.2575387954711914, 0.3846251368522644, 0.3229849338531494, 0.18796025216579437, 0.19918707013130188, 0.23704665899276733, 0.2363780438899994, 0.27567559480667114, 0.28624531626701355, 0.24959921836853027, 0.31100171804

EarlyStopping counter: 2 out of 7
EPOCH: 8, Train Loss: 0.288, Valid Loss: 0.446, Validation F1: 0.814


100%|██████████| 157/157 [00:05<00:00, 29.51it/s, train_loss=[0.24971459805965424, 0.33276939392089844, 0.25930601358413696, 0.2629486918449402, 0.2821865379810333, 0.2353093922138214, 0.24614563584327698, 0.336934894323349, 0.3416843116283417, 0.2522696852684021, 0.36190229654312134, 0.32037603855133057, 0.22865688800811768, 0.25086814165115356, 0.32658714056015015, 0.20207875967025757, 0.21636807918548584, 0.20230214297771454, 0.24986925721168518, 0.3655511140823364, 0.3370947241783142, 0.2818242907524109, 0.24435780942440033, 0.2597353458404541, 0.24315574765205383, 0.2713271975517273, 0.257932186126709, 0.22584575414657593, 0.28088003396987915, 0.2310309112071991, 0.2446824014186859, 0.29542112350463867, 0.28821298480033875, 0.2860214412212372, 0.34154778718948364, 0.252049058675766, 0.19803354144096375, 0.1774243861436844, 0.26861685514450073, 0.25101238489151, 0.21413850784301758, 0.2562398910522461, 0.23588509857654572, 0.2922671139240265, 0.2543303370475769, 0.3419836759567261,

Validation f1score increased (0.825185 --> 0.828163).  Saving model ...
EPOCH: 9, Train Loss: 0.257, Valid Loss: 0.442, Validation F1: 0.828


100%|██████████| 157/157 [00:05<00:00, 29.20it/s, train_loss=[0.3520541191101074, 0.28472310304641724, 0.21034102141857147, 0.19360414147377014, 0.22222431004047394, 0.2600303888320923, 0.2534201741218567, 0.1941109597682953, 0.19157381355762482, 0.20275646448135376, 0.22692915797233582, 0.13759645819664001, 0.3537750840187073, 0.20868930220603943, 0.13945436477661133, 0.26295459270477295, 0.15793967247009277, 0.16210800409317017, 0.3203541040420532, 0.1898963898420334, 0.17565718293190002, 0.22177374362945557, 0.21992528438568115, 0.19286739826202393, 0.21696841716766357, 0.29739004373550415, 0.20507119596004486, 0.18844759464263916, 0.19448065757751465, 0.24081343412399292, 0.27439942955970764, 0.17242072522640228, 0.21628384292125702, 0.17735174298286438, 0.2133726179599762, 0.23468884825706482, 0.2501606047153473, 0.28306323289871216, 0.1850375235080719, 0.25552472472190857, 0.12238326668739319, 0.28686439990997314, 0.25876718759536743, 0.21952496469020844, 0.19629138708114624, 0.2

EarlyStopping counter: 1 out of 7
EPOCH: 10, Train Loss: 0.225, Valid Loss: 0.484, Validation F1: 0.821


100%|██████████| 157/157 [00:05<00:00, 28.61it/s, train_loss=[0.27875417470932007, 0.1870516687631607, 0.26466190814971924, 0.28158873319625854, 0.18080519139766693, 0.2290969341993332, 0.2502976059913635, 0.155122309923172, 0.20677265524864197, 0.2341071218252182, 0.27452075481414795, 0.13959965109825134, 0.10956093668937683, 0.2002142071723938, 0.12318764626979828, 0.18543994426727295, 0.17272531986236572, 0.20786748826503754, 0.1224055290222168, 0.28136295080184937, 0.20277363061904907, 0.2627274990081787, 0.18321481347084045, 0.05764368176460266, 0.23934143781661987, 0.1733270138502121, 0.1775008589029312, 0.20409303903579712, 0.177435964345932, 0.1535767912864685, 0.18811877071857452, 0.198513001203537, 0.1792159080505371, 0.21047443151474, 0.2765468657016754, 0.28177186846733093, 0.21953816711902618, 0.16915278136730194, 0.16015620529651642, 0.18687456846237183, 0.25049296021461487, 0.14876113831996918, 0.21423932909965515, 0.17576324939727783, 0.1861770898103714, 0.2693227529525

EarlyStopping counter: 2 out of 7
EPOCH: 11, Train Loss: 0.208, Valid Loss: 0.421, Validation F1: 0.825


100%|██████████| 157/157 [00:05<00:00, 28.96it/s, train_loss=[0.2019236981868744, 0.2724367678165436, 0.1815301775932312, 0.23027408123016357, 0.11817224323749542, 0.1955944299697876, 0.13048750162124634, 0.19202451407909393, 0.2464059442281723, 0.22735311090946198, 0.19104337692260742, 0.19625107944011688, 0.24911724030971527, 0.24004217982292175, 0.2049640268087387, 0.16667479276657104, 0.24038565158843994, 0.19177967309951782, 0.1991950273513794, 0.17971986532211304, 0.15867845714092255, 0.14464208483695984, 0.1700211763381958, 0.16409866511821747, 0.2192482203245163, 0.1326335221529007, 0.14833660423755646, 0.14178603887557983, 0.07444504648447037, 0.21012139320373535, 0.11074043065309525, 0.21814540028572083, 0.21880191564559937, 0.2005433887243271, 0.17992262542247772, 0.10198397934436798, 0.11770623922348022, 0.1606675535440445, 0.1820896863937378, 0.1481921374797821, 0.2018158733844757, 0.21914272010326385, 0.16041147708892822, 0.16833987832069397, 0.1084442213177681, 0.1474713

Validation f1score increased (0.828163 --> 0.851991).  Saving model ...
EPOCH: 12, Train Loss: 0.177, Valid Loss: 0.431, Validation F1: 0.852


100%|██████████| 157/157 [00:05<00:00, 29.03it/s, train_loss=[0.11317026615142822, 0.09505587071180344, 0.17427632212638855, 0.11189384758472443, 0.11086718738079071, 0.14291653037071228, 0.17018260061740875, 0.1911158263683319, 0.19810029864311218, 0.10254812240600586, 0.24205408990383148, 0.22362260520458221, 0.12208673357963562, 0.14849939942359924, 0.18004979193210602, 0.15391451120376587, 0.09257045388221741, 0.13212527334690094, 0.13516269624233246, 0.21433639526367188, 0.18443551659584045, 0.12025352567434311, 0.1601017415523529, 0.16895455121994019, 0.18490926921367645, 0.1856241226196289, 0.1581544429063797, 0.10655446350574493, 0.09676893800497055, 0.12056506425142288, 0.12489227205514908, 0.11075092852115631, 0.18819370865821838, 0.1579146385192871, 0.12333135306835175, 0.09842238575220108, 0.13511979579925537, 0.08028963953256607, 0.20793388783931732, 0.15209351480007172, 0.19461141526699066, 0.10312013328075409, 0.18825232982635498, 0.24312573671340942, 0.17474856972694397

EarlyStopping counter: 1 out of 7
EPOCH: 13, Train Loss: 0.158, Valid Loss: 0.463, Validation F1: 0.841


100%|██████████| 157/157 [00:05<00:00, 28.77it/s, train_loss=[0.09035438299179077, 0.21934227645397186, 0.20966282486915588, 0.12289993464946747, 0.11037224531173706, 0.2327779233455658, 0.2769705355167389, 0.1627320945262909, 0.12844112515449524, 0.21467216312885284, 0.21446242928504944, 0.15180149674415588, 0.1051727756857872, 0.15913304686546326, 0.07254470884799957, 0.12344728410243988, 0.13591653108596802, 0.24653738737106323, 0.1108095571398735, 0.18230268359184265, 0.10738661885261536, 0.10541650652885437, 0.10132826864719391, 0.16823211312294006, 0.16245189309120178, 0.12401063740253448, 0.13304026424884796, 0.11221115291118622, 0.10239963233470917, 0.15545597672462463, 0.1869630515575409, 0.18145698308944702, 0.13662658631801605, 0.10409190505743027, 0.12484123557806015, 0.1286880075931549, 0.1266421377658844, 0.11370182037353516, 0.19994854927062988, 0.14102357625961304, 0.12312189489603043, 0.1207023411989212, 0.14028504490852356, 0.22077934443950653, 0.17914117872714996, 0.

EarlyStopping counter: 2 out of 7
EPOCH: 14, Train Loss: 0.181, Valid Loss: 0.562, Validation F1: 0.706


100%|██████████| 157/157 [00:05<00:00, 28.85it/s, train_loss=[0.47995877265930176, 0.4671536684036255, 0.4890919029712677, 0.45131146907806396, 0.4982001781463623, 0.4723905920982361, 0.43996191024780273, 0.4304419755935669, 0.5086848735809326, 0.48983222246170044, 0.5131378173828125, 0.4821571707725525, 0.42711102962493896, 0.46529334783554077, 0.4195075035095215, 0.45180943608283997, 0.34699133038520813, 0.41328662633895874, 0.4231872260570526, 0.36900147795677185, 0.45512938499450684, 0.3054716885089874, 0.33362406492233276, 0.3798002600669861, 0.3619500398635864, 0.2744405269622803, 0.350924015045166, 0.35675856471061707, 0.38274475932121277, 0.26874327659606934, 0.35067394375801086, 0.32396602630615234, 0.22026005387306213, 0.2723258137702942, 0.2343113273382187, 0.24399948120117188, 0.21477636694908142, 0.269814133644104, 0.14749443531036377, 0.38125699758529663, 0.3462243974208832, 0.1634332537651062, 0.11249875277280807, 0.3354458212852478, 0.23285481333732605, 0.28438779711723

EarlyStopping counter: 3 out of 7
EPOCH: 15, Train Loss: 0.485, Valid Loss: 0.655, Validation F1: 0.600


100%|██████████| 157/157 [00:05<00:00, 29.08it/s, train_loss=[0.6395754814147949, 0.6471556425094604, 0.6348226070404053, 0.6444286108016968, 0.6578325033187866, 0.621428370475769, 0.6335731744766235, 0.6318620443344116, 0.6600674390792847, 0.6583366990089417, 0.6343588829040527, 0.6427829265594482, 0.6435447931289673, 0.6487576961517334, 0.6017832159996033, 0.6487360000610352, 0.6091230511665344, 0.6616056561470032, 0.6572167873382568, 0.648544430732727, 0.6307101249694824, 0.6119900345802307, 0.6262170672416687, 0.6417970657348633, 0.6663821339607239, 0.6078526973724365, 0.6149547100067139, 0.5939285159111023, 0.6318396329879761, 0.6033446788787842, 0.6190108060836792, 0.6282562017440796, 0.6348931789398193, 0.6416423320770264, 0.630833625793457, 0.6231962442398071, 0.6200498342514038, 0.601387083530426, 0.6170440316200256, 0.5950775146484375, 0.6277980208396912, 0.6214437484741211, 0.5799522399902344, 0.6052670478820801, 0.6258320212364197, 0.6263021230697632, 0.6370069980621338, 0.

EarlyStopping counter: 4 out of 7
EPOCH: 16, Train Loss: 0.604, Valid Loss: 0.579, Validation F1: 0.690


100%|██████████| 157/157 [00:05<00:00, 29.15it/s, train_loss=[0.5105963945388794, 0.48414716124534607, 0.5445696115493774, 0.5799206495285034, 0.5533887147903442, 0.546892523765564, 0.6244048476219177, 0.5343009233474731, 0.5829444527626038, 0.5653732419013977, 0.583213210105896, 0.6123428344726562, 0.5667098760604858, 0.5155577659606934, 0.4964222311973572, 0.49372440576553345, 0.4790865182876587, 0.5239182114601135, 0.48584747314453125, 0.4156247675418854, 0.4655701816082001, 0.4496644139289856, 0.5222388505935669, 0.4532836079597473, 0.41432660818099976, 0.45208632946014404, 0.42170250415802, 0.4558994174003601, 0.43608731031417847, 0.3710898160934448, 0.4988820254802704, 0.6299285292625427, 0.599706768989563, 0.3932884931564331, 0.4231507182121277, 0.43070298433303833, 0.48634347319602966, 0.4174114465713501, 0.5201988816261292, 0.43982017040252686, 0.4217791259288788, 0.515681266784668, 0.4435793161392212, 0.47308382391929626, 0.38994431495666504, 0.36621129512786865, 0.4214681982

EarlyStopping counter: 5 out of 7
EPOCH: 17, Train Loss: 0.404, Valid Loss: 0.432, Validation F1: 0.808


100%|██████████| 157/157 [00:05<00:00, 28.78it/s, train_loss=[0.3550812900066376, 0.3061400055885315, 0.38735657930374146, 0.2626553177833557, 0.3475789427757263, 0.30413490533828735, 0.3067575693130493, 0.3578227162361145, 0.3208737075328827, 0.2794742286205292, 0.32502084970474243, 0.33017581701278687, 0.3744429349899292, 0.3325144648551941, 0.2722526788711548, 0.31418871879577637, 0.3491418957710266, 0.3367350101470947, 0.2478785365819931, 0.3996313512325287, 0.20005351305007935, 0.358268678188324, 0.4734545350074768, 0.3035343289375305, 0.24411579966545105, 0.33553236722946167, 0.2906672954559326, 0.2593798339366913, 0.3312363624572754, 0.32127898931503296, 0.39075174927711487, 0.4531291425228119, 0.3855685591697693, 0.30332493782043457, 0.2419230341911316, 0.4161282479763031, 0.34321117401123047, 0.38762661814689636, 0.47415482997894287, 0.3485196530818939, 0.34354305267333984, 0.28322553634643555, 0.30616986751556396, 0.4947284758090973, 0.3482034206390381, 0.30515140295028687, 0

EarlyStopping counter: 6 out of 7
EPOCH: 18, Train Loss: 0.301, Valid Loss: 0.402, Validation F1: 0.839


100%|██████████| 157/157 [00:05<00:00, 28.85it/s, train_loss=[0.2239537537097931, 0.2757214605808258, 0.31779396533966064, 0.2685692310333252, 0.27614787220954895, 0.19396042823791504, 0.2075130045413971, 0.21510444581508636, 0.22341278195381165, 0.3184662461280823, 0.39359867572784424, 0.2981640100479126, 0.21447807550430298, 0.22589808702468872, 0.20788021385669708, 0.22795692086219788, 0.2910388708114624, 0.37914127111434937, 0.3167056739330292, 0.24496227502822876, 0.29274192452430725, 0.3557252287864685, 0.3640207052230835, 0.24155524373054504, 0.23567074537277222, 0.22090908885002136, 0.2670629620552063, 0.2655462622642517, 0.3632611334323883, 0.34909558296203613, 0.31158989667892456, 0.30317625403404236, 0.28374117612838745, 0.23384001851081848, 0.28796178102493286, 0.25833049416542053, 0.2472342848777771, 0.19470423460006714, 0.3033245801925659, 0.2665923237800598, 0.21212485432624817, 0.26442334055900574, 0.28014248609542847, 0.29124388098716736, 0.37434816360473633, 0.2512272

EarlyStopping counter: 7 out of 7
EPOCH: 19, Train Loss: 0.249, Valid Loss: 0.430, Validation F1: 0.834
Early Stopping!
Loading model from last checkpoint with validation f1score 0.851991



