In [12]:
import os
import time
import torch
import torch.nn as nn

from vocabulary import Vocab
from utils import *
from embedding import PositionalEncoding, Embeddings
from layers import *
from criterion import KLLossMasked
from optimizer import NoamOpt
from bert import Bert, Generator

In [13]:
# Print the running average training loss every this many iterations (used in run()).
log_every_iter = 100
# Evaluate on the dev set, and save the model if it improved, every this many iterations (used in run()).
validate_every_iter = 10000

In [14]:
#create_data()

In [15]:
# Directory that holds every checkpoint written by this notebook.
directory = 'model/'
# exist_ok=True makes the cell idempotent and avoids the check-then-create race
# of an isdir()/mkdir() pair; it still raises if the path exists as a regular file.
os.makedirs(directory, exist_ok=True)
# Path of the "best so far" checkpoint (also the one read back when resuming).
model_save_path = os.path.join(directory, 'bert.checkpoint')

In [16]:
# --- Run configuration -------------------------------------------------------
small_size = False       # forwarded as `small=` to the data readers
use_checkpoint = False   # True: resume from `model_save_path` instead of starting fresh
use_cuda = True          # set to False to force CPU
# Fall back to CPU when CUDA is requested but not available, instead of
# crashing later on `.to(device)`.
use_cuda = use_cuda and torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")


# --- Model hyper-parameters --------------------------------------------------
vocab = Vocab()
V = len(vocab.char2id)   # vocabulary size (number of entries in char2id)
d_model = 256            # model / embedding width
d_ff = 1024              # feed-forward hidden width
h = 4                    # number of attention heads
n_encoders = 4           # passed to Bert as n_layers

batch_size = 256

In [17]:
# Building blocks, created in the same order as before so parameter
# initialisation consumes the RNG identically.
self_attn = MultiHeadedAttention(
    h=h,
    d_model=d_model,
    d_k=d_model // h,
    d_v=d_model // h,
    dropout=0.1,
)
feed_forward = FullyConnectedFeedForward(d_model=d_model, d_ff=d_ff)
position = PositionalEncoding(d_model, dropout=0.1)
embedding = nn.Sequential(
    Embeddings(d_model=d_model, vocab=V),
    position,
)

encoder = Encoder(
    self_attn=self_attn,
    feed_forward=feed_forward,
    size=d_model,
    dropout=0.1,
)
generator = Generator(d_model=d_model, vocab_size=V)
model = Bert(
    encoder=encoder,
    embedding=embedding,
    generator=generator,
    n_layers=n_encoders,
)

# Xavier-initialise every parameter with more than one dimension;
# 1-D parameters keep whatever initialisation their module gave them.
for p in model.parameters():
    if p.dim() <= 1:
        continue
    nn.init.xavier_uniform_(p)

model = model.to(device)

In [18]:
# Base Adam optimizer, wrapped by NoamOpt.
# NOTE(review): presumably NoamOpt overrides the learning rate on every step,
# making lr=1e-9 a placeholder -- confirm in optimizer.py.
opt = torch.optim.Adam(
    model.parameters(),
    lr=1e-9,
    betas=(0.9, 0.98),
    eps=1e-9,
)
model_opt = NoamOpt(d_model, 2, 4000, opt)

# Summed KL divergence; its input must be log-probabilities.
# Alternative objective that was tried: nn.CrossEntropyLoss()
criterion = nn.KLDivLoss(reduction="sum")
if use_cuda:
    criterion.cuda(device=device)

In [19]:
# Load the training and validation pair files; `small_size` is forwarded as `small`.
train_data = read_train_data(
    filepath="./pairs_train.txt",
    small=small_size,
)
dev_data = read_dev_data(
    filepath="./pairs_valid.txt",
    small=small_size,
)

In [20]:
# History of dev-set accuracies; run() uses it to decide whether a new
# validation result is the best so far.
hist_valid_scores = []

if use_checkpoint:
    # map_location lets a checkpoint written on one device (e.g. a GPU) be
    # restored on whatever `device` this session is using.
    checkpoint = torch.load(model_save_path, map_location=device)
    current_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    opt.load_state_dict(checkpoint['optimizer_state_dict'])

    # Restore the learning-rate schedule state kept on the NoamOpt wrapper.
    step = checkpoint['_step']
    rate = checkpoint['_rate']
    current_train_iter = checkpoint['train_iter']
    model_opt._step = step
    model_opt._rate = rate

    # Restore the validation history saved with the checkpoint; otherwise the
    # first validation after resuming would always count as "best" and
    # overwrite the saved model. `.get` tolerates checkpoints without the key.
    hist_valid_scores = list(checkpoint.get('hist_valid_scores', []))
    print(f'reading checkpoint from epoch {current_epoch}, iter {current_train_iter}')
else:
    current_epoch = 0
    current_train_iter = 0

In [21]:
def _save_checkpoint(path, epoch, train_iter, loss):
    """Write model, optimizer and schedule state to ``path`` with ``torch.save``.

    The dictionary layout (keys) is exactly what the resume cell reads back.
    """
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': model_opt.optimizer.state_dict(),
                'loss': loss,
                '_rate': model_opt._rate,
                '_step': model_opt._step,
                'train_iter': train_iter,
                'hist_valid_scores': hist_valid_scores,
                }, path)


def run(max_epochs=1000):
    """Train ``model``, periodically logging, validating and checkpointing.

    Relies on notebook-level state: ``model``, ``model_opt``, ``criterion``,
    ``train_data``, ``dev_data``, ``vocab``, ``current_epoch``,
    ``current_train_iter`` and ``hist_valid_scores``.

    Args:
        max_epochs: exclusive upper bound on the epoch index (default 1000,
            the previously hard-coded value).

    Behaviour:
        * Every ``log_every_iter`` iterations the average loss per sample is printed.
        * Every ``validate_every_iter`` iterations dev accuracy is computed; if it
          beats every earlier score the model is saved to ``model_save_path``.
        * At the end of each epoch a ``real_model_{epoch}.checkpoint`` is written.

    Resuming: the iteration counter starts at ``current_train_iter`` (0 for a
    fresh run) and the epoch loop starts at ``current_epoch``, so the
    checkpointed epoch is replayed from its first batch. The earlier
    skip-ahead test never advanced the counter and therefore skipped every
    batch of every epoch after a resume.
    """
    train_iter = current_train_iter
    report_loss = cum_loss = 0.0
    report_samples = cum_samples = 0
    # Built once; the module is stateless, so re-creating it per batch was wasted work.
    log_softmax = nn.LogSoftmax(dim=1)

    for epoch in range(current_epoch, max_epochs):
        print("=" * 30)
        model.train()

        start = time.time()
        train_data_iter = create_words_batch(train_data, vocab, mini_batch=batch_size, shuffle=False, device=model.device)
        for batch in train_data_iter:
            # Call the module itself rather than .forward() so nn.Module hooks run.
            out = model(batch.src, batch.src_mask)

            # Actual row count of this batch (the last batch may be smaller
            # than `batch_size`).
            n_samples = batch.src.shape[0]

            # (n_samples, V) mask: column j of row i is set to `mask_token`
            # when vocabulary id j occurs in batch.src[i].
            generator_mask = torch.zeros(n_samples, V, device=model.device)
            generator_mask = generator_mask.scatter_(1, batch.src, mask_token)

            # Predicted log-distribution over the vocabulary, with the marked
            # ids pushed to ~zero probability.
            x = model.generator(out, generator_mask)
            x = x.masked_fill(generator_mask == mask_token, -1e9)
            x = log_softmax(x)

            # Target distribution: zero out the same ids and renormalise each row.
            # NOTE(review): a row whose remaining targets sum to 0 would give
            # NaN here -- confirm the data pipeline rules that out.
            y = batch.tgt.masked_fill(generator_mask == mask_token, 0)
            y = y / torch.sum(y, dim=1, keepdim=True)

            batch_loss = criterion(x, y)
            batch_loss.backward()

            model_opt.step()
            model_opt.optimizer.zero_grad()

            batch_loss_val = batch_loss.item()
            report_loss += batch_loss_val
            cum_loss += batch_loss_val
            report_samples += n_samples
            cum_samples += n_samples

            train_iter += 1

            if train_iter % log_every_iter == 0:
                elapsed = time.time() - start
                print(f'epoch {epoch}, iter {train_iter}, avg. loss {report_loss / report_samples:.2f} time elapsed {elapsed:.2f}sec')
                start = time.time()
                report_loss = 0.0
                report_samples = 0

            if train_iter % validate_every_iter == 0:
                print(f'epoch {epoch}, iter {train_iter}, cum. loss {cum_loss / cum_samples:.2f} examples {cum_samples}')
                # Keep the window's loss for the checkpoint before resetting;
                # previously the best-model checkpoint always stored 0.0.
                window_loss = cum_loss
                cum_loss = 0.0
                cum_samples = 0

                print('begin evaluation...')
                acc = evaluate_acc(model, vocab, dev_data, device=model.device)
                print(f'validation: iter {train_iter}, dev. acc {acc:.4f}')
                # NOTE(review): evaluate_acc may switch the model to eval mode
                # (not visible here); re-assert training mode for the rest of
                # the epoch. This is a no-op if the mode was left untouched.
                model.train()

                valid_metric = acc

                is_better = len(hist_valid_scores) == 0 or valid_metric > max(hist_valid_scores)
                hist_valid_scores.append(valid_metric)

                if is_better:
                    print('save currently the best model to [%s]' % model_save_path)
                    _save_checkpoint(model_save_path, epoch, train_iter, window_loss)

        # End-of-epoch snapshot, written regardless of validation results.
        _save_checkpoint(os.path.join(directory, f'real_model_{epoch}.checkpoint'),
                         epoch, train_iter, cum_loss)


In [22]:
# Start training; progress lines are printed every `log_every_iter` iterations.
run()

epoch 0, iter 100, avg. loss 1.72 time elapsed 1.76sec
epoch 0, iter 200, avg. loss 1.45 time elapsed 1.74sec
epoch 0, iter 300, avg. loss 1.43 time elapsed 1.72sec
epoch 0, iter 400, avg. loss 1.41 time elapsed 1.72sec
epoch 0, iter 500, avg. loss 1.41 time elapsed 1.71sec
epoch 0, iter 600, avg. loss 1.40 time elapsed 1.73sec
epoch 0, iter 700, avg. loss 1.41 time elapsed 1.77sec
epoch 0, iter 800, avg. loss 1.40 time elapsed 1.70sec
epoch 0, iter 900, avg. loss 1.39 time elapsed 1.70sec
epoch 0, iter 1000, avg. loss 1.39 time elapsed 1.74sec
epoch 0, iter 1100, avg. loss 1.39 time elapsed 1.68sec
epoch 0, iter 1200, avg. loss 1.38 time elapsed 1.62sec
epoch 0, iter 1300, avg. loss 1.37 time elapsed 1.59sec
epoch 0, iter 1400, avg. loss 1.38 time elapsed 1.58sec
epoch 0, iter 1500, avg. loss 1.37 time elapsed 1.59sec
epoch 0, iter 1600, avg. loss 1.38 time elapsed 1.60sec
epoch 0, iter 1700, avg. loss 1.37 time elapsed 1.61sec
epoch 0, iter 1800, avg. loss 1.37 time elapsed 1.83sec
e