In [None]:
import os
from google.colab import drive
# Mount Google Drive inside the Colab VM at /content/gdrive/; force_remount=True asks Colab
# to mount again even when a mount already exists.
drive.mount('/content/gdrive/', force_remount=True)

# Setting

In [None]:
# Switch to the project folder on the mounted drive and list its contents. Later cells use
# paths relative to this directory ('wmt14/src.vocab', './wmt14/', './save').
%cd "/content/gdrive/My Drive/ML"
%ls

In [None]:
# Remove any installed torch/torchtext, then install a pre-release torchtext, using the
# PyTorch nightly (cu101) wheel index as an additional place to look for packages.
# NOTE(review): nothing is version-pinned, so re-running this cell later may install a
# different build — consider pinning exact versions.
!pip uninstall --y torch torchtext
!pip install --pre torchtext -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html

In [None]:
# Extra dependencies: spaCy with its 'en' and 'de' models (the Field definitions in this
# notebook use tokenize='spacy' with tokenizer_language 'en' / 'de') and GPUtil (the Trainer
# calls GPUtil.showUtilization() while logging).
!pip install spacy  # spacy tokenizer
!python -m spacy download en # src language
!python -m spacy download de # trg language
!pip install GPUtil # for print GPU usage

In [None]:
import copy
import math
import time
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchtext

In [None]:
# Report cuDNN / CUDA availability, then choose the device used by the rest of the notebook:
# the first CUDA GPU when one is available, otherwise the CPU.
cuda_available = torch.cuda.is_available()
print(torch.backends.cudnn.enabled)
print(cuda_available)

device = torch.device('cuda:0' if cuda_available else 'cpu')
print(device)

# Model

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder model assembled from injected components.

    Args:
        src_embed: module applied to the source input before the encoder.
        trg_embed: module applied to the target input before the decoder.
        encoder: module called as ``encoder(embedded_src, src_mask)``.
        decoder: module called as
            ``decoder(embedded_trg, trg_mask, encoder_output, src_mask)``.
        fc_layer: module applied to the decoder output before the log-softmax.
    """

    def __init__(self, src_embed, trg_embed, encoder, decoder, fc_layer):
        super().__init__()
        self.src_embed = src_embed
        self.trg_embed = trg_embed
        self.encoder = encoder
        self.decoder = decoder
        self.fc_layer = fc_layer

    def forward(self, src, trg, src_mask, trg_mask):
        """Return log-probabilities (log-softmax over the last dimension)."""
        memory = self.encode(src, src_mask)
        # The decoder attends over the encoder output, masked with the source mask.
        decoded = self.decode(trg, trg_mask, memory, src_mask)
        logits = self.fc_layer(decoded)
        return F.log_softmax(logits, dim=-1)

    def encode(self, x, mask):
        """Embed ``x`` with ``src_embed`` and run the encoder."""
        embedded = self.src_embed(x)
        return self.encoder(embedded, mask)

    def decode(self, x, mask, encoder_output, encoder_mask):
        """Embed ``x`` with ``trg_embed`` and run the decoder against ``encoder_output``."""
        embedded = self.trg_embed(x)
        return self.decoder(embedded, mask, encoder_output, encoder_mask)

In [None]:
class Encoder(nn.Module):
    """Stack of ``n_layer`` independent deep copies of ``sub_layer`` applied in sequence.

    Args:
        sub_layer: module called as ``sub_layer(x, mask)``; it is deep-copied, never shared.
        n_layer: number of copies to stack.
    """

    def __init__(self, sub_layer, n_layer):
        super(Encoder, self).__init__()
        # nn.ModuleList (instead of a plain Python list) registers every copy as a
        # sub-module, so parameters(), state_dict(), .to(device) and train()/eval()
        # actually reach the stacked layers.
        self.layers = nn.ModuleList([copy.deepcopy(sub_layer) for _ in range(n_layer)])

    def forward(self, x, mask):
        """Feed ``x`` through every layer in order, passing the same ``mask`` to each."""
        out = x
        for layer in self.layers:
            out = layer(out, mask)
        return out

In [None]:
class Decoder(nn.Module):
    """Stack of ``n_layer`` independent deep copies of ``sub_layer`` applied in sequence.

    Args:
        sub_layer: module called as ``sub_layer(x, mask, encoder_output, encoder_mask)``;
            it is deep-copied, never shared.
        n_layer: number of copies to stack.
    """

    def __init__(self, sub_layer, n_layer):
        super(Decoder, self).__init__()
        # nn.ModuleList (instead of a plain Python list) registers every copy as a
        # sub-module, so parameters(), state_dict(), .to(device) and train()/eval()
        # actually reach the stacked layers.
        self.layers = nn.ModuleList([copy.deepcopy(sub_layer) for _ in range(n_layer)])

    def forward(self, x, mask, encoder_output, encoder_mask):
        """Feed ``x`` through every layer in order; each layer sees the same encoder output."""
        out = x
        for layer in self.layers:
            # Chain the layers: each one consumes the previous layer's output. Feeding the
            # original ``x`` to every layer would discard all but the last layer's work.
            out = layer(out, mask, encoder_output, encoder_mask)
        return out

In [None]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention followed by a position-wise feed-forward network.

    Both sub-layers are wrapped in a ``ResidualConnectionLayer``; each wrapper receives
    its own deep copy of ``norm_layer`` and the shared ``dropout_rate``.
    """

    def __init__(self, multi_head_attention_layer, position_wise_feed_forward_layer, norm_layer, dropout_rate):
        super().__init__()
        attention_norm = copy.deepcopy(norm_layer)
        feed_forward_norm = copy.deepcopy(norm_layer)
        self.multi_head_attention_layer = ResidualConnectionLayer(
            multi_head_attention_layer, attention_norm, dropout_rate)
        self.position_wise_feed_forward_layer = ResidualConnectionLayer(
            position_wise_feed_forward_layer, feed_forward_norm, dropout_rate)

    def forward(self, x, mask):
        """Self-attend over ``x`` (query = key = value) under ``mask``, then apply the feed-forward block."""
        attended = self.multi_head_attention_layer(query=x, key=x, value=x, mask=mask)
        return self.position_wise_feed_forward_layer(x=attended)

In [None]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder-decoder attention, feed-forward.

    Every sub-layer is wrapped in a ``ResidualConnectionLayer`` that gets its own deep
    copy of ``norm_layer`` and the shared ``dropout_rate``.

    Args:
        masked_multi_head_attention_layer: attention module used for the self-attention
            over the target sequence.
        multi_head_attention_layer: attention module used to attend over the encoder output.
        position_wise_feed_forward_layer: feed-forward module, called with keyword ``x``.
        norm_layer: normalisation module, deep-copied once per sub-layer.
        dropout_rate: dropout probability handed to each residual wrapper.
    """

    def __init__(self, masked_multi_head_attention_layer, multi_head_attention_layer, position_wise_feed_forward_layer, norm_layer, dropout_rate):
        super(DecoderLayer, self).__init__()
        # Each constructor argument is wired to the attribute of the same name (they were
        # previously crossed, so the "masked" attribute wrapped the other module).
        self.masked_multi_head_attention_layer = ResidualConnectionLayer(masked_multi_head_attention_layer, copy.deepcopy(norm_layer), dropout_rate)
        self.multi_head_attention_layer = ResidualConnectionLayer(multi_head_attention_layer, copy.deepcopy(norm_layer), dropout_rate)
        self.position_wise_feed_forward_layer = ResidualConnectionLayer(position_wise_feed_forward_layer, copy.deepcopy(norm_layer), dropout_rate)

    def forward(self, x, mask, encoder_output, encoder_mask):
        """Run the three sub-layers in order and return the result."""
        # Self-attention over the target sequence, restricted by ``mask``.
        out = self.masked_multi_head_attention_layer(query=x, key=x, value=x, mask=mask)
        # Queries come from the decoder; keys and values come from the encoder output.
        out = self.multi_head_attention_layer(query=out, key=encoder_output, value=encoder_output, mask=encoder_mask)
        out = self.position_wise_feed_forward_layer(x=out)
        return out

In [None]:
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: total width of the projected query/key/value; split evenly over the heads.
        n_head: number of attention heads (each head has width ``d_model // n_head``).
        qkv_fc_layer: template projection; deep-copied three times so query, key and
            value each get an independent projection.
        fc_layer: output projection applied to the concatenated heads.
        dropout_rate: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, n_head, qkv_fc_layer, fc_layer, dropout_rate):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.query_fc_layer = copy.deepcopy(qkv_fc_layer)
        self.key_fc_layer = copy.deepcopy(qkv_fc_layer)
        self.value_fc_layer = copy.deepcopy(qkv_fc_layer)
        self.fc_layer = fc_layer
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, query, key, value, mask=None):
        """Attend ``query`` over ``key``/``value``.

        Inputs have shape (n_batch, seq_len, d_embed); the result has the shape produced
        by ``fc_layer`` on a (n_batch, seq_len, d_model) tensor.
        """
        n_batch = query.shape[0]
        d_k = self.d_model // self.n_head

        # Project each input, then reshape
        # (n_batch, seq_len, d_model) -> (n_batch, n_head, seq_len, d_k).
        projections = (
            (query, self.query_fc_layer),
            (key, self.key_fc_layer),
            (value, self.value_fc_layer),
        )
        q_heads, k_heads, v_heads = [
            projection(tensor).view(n_batch, -1, self.n_head, d_k).transpose(1, 2)
            for tensor, projection in projections
        ]

        # Insert a head axis so the same mask broadcasts over every head.
        head_mask = None if mask is None else mask.unsqueeze(1)

        attended = self.calculate_attention(q_heads, k_heads, v_heads, head_mask, self.dropout)
        # (n_batch, n_head, seq_len, d_k) -> (n_batch, seq_len, d_model)
        merged = attended.transpose(1, 2).contiguous().view(n_batch, -1, self.d_model)
        return self.fc_layer(merged)

    def calculate_attention(self, query, key, value, mask, dropout=None):
        """Scaled dot-product attention on (n_batch, n_head, seq_len, d_k) tensors.

        Positions where ``mask == 0`` get a score of -1e9 before the softmax, which makes
        their attention weight effectively zero.
        """
        scale = math.sqrt(key.size(-1))
        score = torch.matmul(query, key.transpose(-2, -1)) / scale
        if mask is not None:
            score = score.masked_fill(mask == 0, -1e9)
        weights = F.softmax(score, dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        return torch.matmul(weights, value)

In [None]:
class PositionWiseFeedForwardLayer(nn.Module):
    """Two-layer feed-forward block: first layer -> ReLU -> dropout -> second layer.

    ``forward`` accepts keyword arguments only and reads its input from ``x``, so it can
    be called through a wrapper that forwards ``**kwargs``.
    """

    def __init__(self, first_fc_layer, second_fc_layer, dropout_rate):
        super().__init__()
        self.first_fc_layer = first_fc_layer
        self.second_fc_layer = second_fc_layer
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, **kwargs):
        """Apply the block to ``kwargs['x']`` (raises KeyError when ``x`` is missing)."""
        hidden = F.relu(self.first_fc_layer(kwargs['x']))
        return self.second_fc_layer(self.dropout(hidden))

In [None]:
class ResidualConnectionLayer(nn.Module):
    """Wrap a sub-layer as ``norm(residual + dropout(sub_layer(**kwargs)))``.

    The residual input is taken from the keyword ``x`` when present, otherwise from
    ``query`` (the calling convention of the attention layers). All keyword arguments
    are forwarded unchanged to ``sub_layer``.

    Args:
        sub_layer: module invoked with the same keyword arguments as ``forward``.
        norm_layer: module applied to the residual sum.
        dropout_rate: dropout probability applied to the sub-layer output.
    """

    def __init__(self, sub_layer, norm_layer, dropout_rate):
        super(ResidualConnectionLayer, self).__init__()
        self.sub_layer = sub_layer
        self.norm_layer = norm_layer
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, **kwargs):
        """Return the normalised residual sum.

        Raises:
            TypeError: if neither ``x`` nor ``query`` is supplied.
        """
        if 'x' in kwargs:
            residual = kwargs['x']
        elif 'query' in kwargs:
            residual = kwargs['query']
        else:
            # Fail with a clear message instead of an UnboundLocalError further down.
            raise TypeError("ResidualConnectionLayer.forward() requires an 'x' or 'query' keyword argument")
        out = residual + self.dropout(self.sub_layer(**kwargs))
        return self.norm_layer(out)

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns ``cos(pos * w_i)`` with
    ``w_i = exp(-2i * ln(10000) / d_embed)``.

    Args:
        d_embed: feature width of the input.
        dropout_rate: dropout probability applied after adding the encoding.
        max_seq_len: number of positions precomputed; longer inputs are not supported.
    """

    def __init__(self, d_embed, dropout_rate, max_seq_len=5000):
        super(PositionalEncoding, self).__init__()
        encoding = torch.zeros(max_seq_len, d_embed)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_embed, 2, dtype=torch.float) * -(math.log(10000.0) / d_embed))
        angles = position * div_term
        encoding[:, 0::2] = torch.sin(angles)
        # For an odd d_embed there is one cos column fewer than sin columns.
        encoding[:, 1::2] = torch.cos(angles)[:, :d_embed // 2]
        # Registered as a buffer so the table follows the module across devices.
        # persistent=False keeps it out of state_dict(), so the checkpoint layout is the
        # same as when this was a plain attribute.
        self.register_buffer('encoding', encoding.unsqueeze(0), persistent=False)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Return ``dropout(x + encoding[:, :seq_len])`` for ``x`` of shape (n_batch, seq_len, d_embed)."""
        # .to(x.device) is a no-op when the buffer already lives on x's device; it removes
        # the dependency on a notebook-level ``device`` variable.
        out = x + self.encoding[:, :x.size(1)].to(x.device)
        out = self.dropout(out)
        return out

In [None]:
class Embedding(nn.Module):
    """Token embedding table whose output is scaled by ``sqrt(d_embed)``.

    Args:
        d_embed: width of each embedding vector.
        vocab: vocabulary object; only ``len(vocab)`` is used to size the table, and the
            object itself is kept on ``self.vocab``.
    """

    def __init__(self, d_embed, vocab):
        super().__init__()
        n_tokens = len(vocab)
        self.embedding = nn.Embedding(n_tokens, d_embed)
        self.vocab = vocab
        self.d_embed = d_embed

    def forward(self, x):
        """Look up the token indices in ``x`` and multiply the vectors by ``sqrt(d_embed)``."""
        scale = math.sqrt(self.d_embed)
        return self.embedding(x) * scale

In [None]:
class TransformerEmbedding(nn.Module):
    """Chain a token embedding and a positional encoding into one module.

    The two modules are stored as an ``nn.Sequential`` on ``self.embedding`` and applied
    in that order.
    """

    def __init__(self, embedding, positional_encoding):
        super().__init__()
        stages = [embedding, positional_encoding]
        self.embedding = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``positional_encoding(embedding(x))``."""
        return self.embedding(x)

In [None]:
def make_model(
    src_vocab, 
    trg_vocab, 
    d_embed = 512, 
    n_layer = 6, 
    d_model = 512, 
    n_head = 8, 
    d_ff = 2048,
    dropout_rate = 0.1):
    """Assemble a Transformer from freshly built components.

    Args:
        src_vocab: source vocabulary; ``len(src_vocab)`` sizes the source embedding table.
        trg_vocab: target vocabulary; ``len(trg_vocab)`` sizes the target embedding table
            and the output projection.
        d_embed: width of the embeddings and of every sub-layer input/output.
        n_layer: number of encoder layers and of decoder layers.
        d_model: width of the attention projections (split over ``n_head`` heads).
        n_head: number of attention heads; must divide ``d_model``.
        d_ff: hidden width of the position-wise feed-forward layers.
        dropout_rate: dropout probability used throughout the model.

    Returns:
        The assembled ``Transformer``. Components are copied onto the notebook-level
        ``device``; the embeddings are not, so callers still call ``model.to(device)``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_head``.
    """
    if d_model % n_head != 0:
        # The attention layer reshapes d_model into n_head chunks of d_model // n_head.
        raise ValueError("d_model (%d) must be divisible by n_head (%d)" % (d_model, n_head))

    def cp(module):
        # Independent copy of a template component, moved to the notebook-level device.
        return copy.deepcopy(module).to(device)

    # Template components: every place that uses one below receives its own deep copy.
    multi_head_attention_layer = MultiHeadAttentionLayer(
                                    d_model = d_model,
                                    n_head = n_head,
                                    qkv_fc_layer = nn.Linear(d_embed, d_model),
                                    fc_layer = nn.Linear(d_model, d_embed),
                                    dropout_rate = dropout_rate)
    
    position_wise_feed_forward_layer = PositionWiseFeedForwardLayer(
                                         first_fc_layer = nn.Linear(d_embed, d_ff),
                                         second_fc_layer = nn.Linear(d_ff, d_embed),
                                         dropout_rate = dropout_rate)
    
    norm_layer = nn.LayerNorm(d_embed, eps=1e-6)

    model = Transformer(
                src_embed = TransformerEmbedding(
                                embedding = Embedding(
                                                d_embed = d_embed, 
                                                vocab = src_vocab), 
                                positional_encoding = PositionalEncoding(
                                                d_embed = d_embed,
                                                dropout_rate = dropout_rate)), 
                trg_embed = TransformerEmbedding(
                                embedding = Embedding(
                                                d_embed = d_embed, 
                                                vocab = trg_vocab), 
                                positional_encoding = PositionalEncoding(
                                                d_embed = d_embed,
                                                dropout_rate = dropout_rate)),
                encoder = Encoder(
                                sub_layer = EncoderLayer(
                                                multi_head_attention_layer = cp(multi_head_attention_layer),
                                                position_wise_feed_forward_layer = cp(position_wise_feed_forward_layer),
                                                norm_layer = cp(norm_layer),
                                                dropout_rate = dropout_rate),
                                n_layer = n_layer),
                decoder = Decoder(
                                sub_layer = DecoderLayer(
                                                masked_multi_head_attention_layer = cp(multi_head_attention_layer),
                                                multi_head_attention_layer = cp(multi_head_attention_layer),
                                                position_wise_feed_forward_layer = cp(position_wise_feed_forward_layer),
                                                norm_layer = cp(norm_layer),
                                                dropout_rate = dropout_rate),
                                n_layer = n_layer),
                # The decoder emits d_embed-wide vectors (attention projects d_model back to
                # d_embed, the feed-forward block maps d_embed -> d_ff -> d_embed), so the
                # output projection must take d_embed inputs, not d_model.
                fc_layer = nn.Linear(d_embed, len(trg_vocab)).to(device))
    
    return model

# Vocab

In [None]:
from torchtext.data import Field
import spacy

# Source and target fields are configured identically except for the spaCy tokenizer
# language: 'en' for the source side and 'de' for the target side.
_SHARED_FIELD_OPTIONS = dict(
    tokenize='spacy',
    init_token='<sos>',
    pad_token='<pad>',
    eos_token='<eos>',
    unk_token='<unk>',
    lower=True,
)

SRC = Field(tokenizer_language='en', **_SHARED_FIELD_OPTIONS)
TRG = Field(tokenizer_language='de', **_SHARED_FIELD_OPTIONS)

In [None]:
import pickle

def save_vocab(vocab, path):
    """Pickle ``vocab`` to the file at ``path``, overwriting any existing file."""
    with open(path, 'wb') as handle:
        pickle.dump(vocab, handle)

def load_vocab(path):
    """Return the object unpickled from the file at ``path``.

    Unpickling can execute arbitrary code, so only use this on files you trust.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)

In [None]:
# Attach previously pickled vocabularies to the fields instead of rebuilding them.
# The paths are relative to the current working directory.
# NOTE(review): load_vocab uses pickle.load, which can execute arbitrary code — only
# load vocab files that you created yourself.
SRC.vocab = load_vocab('wmt14/src.vocab')
TRG.vocab = load_vocab('wmt14/trg.vocab')

# Batch & Trainer

In [None]:
class Batch:
    """Hold one batch together with the attention masks used during training.

    ``src`` and ``trg`` arrive as (seq_len, n_batch) index tensors and are transposed to
    (n_batch, seq_len). When ``trg`` is given, it is split into the decoder input
    (``trg``, all but the last token) and the prediction target (``trg_y``, all but the
    first token).
    """

    def __init__(self, src, trg=None, pad=1):
        self.src = src.T
        # True for real tokens, False for <pad>; shape (n_batch, 1, src_len).
        self.src_mask = (self.src != pad).unsqueeze(-2)
        if trg is None:
            return
        trg_batch_first = trg.T
        self.trg = trg_batch_first[:, :-1]
        self.trg_y = trg_batch_first[:, 1:]
        self.trg_mask = self.make_std_mask(self.trg, pad)
        is_token = self.trg_y != pad
        self.lengths = torch.sum(is_token, dim=1)  # non-pad target tokens per sentence
        self.ntokens = is_token.data.sum()  # non-pad target tokens in the whole batch

    def __len__(self):
        """Number of sentences in the batch."""
        return len(self.src)

    def subsequent_mask(self, size):
        """Return a (1, size, size) bool mask that is True on and below the diagonal."""
        above_diagonal = torch.triu(torch.ones(1, size, size), diagonal=1)
        return above_diagonal == 0

    def make_std_mask(self, tgt, pad):
        """Return a (n_batch, seq_len, seq_len) mask hiding both padding and future positions."""
        not_pad = (tgt != pad).unsqueeze(-2)  # (n_batch, 1, seq_len)
        no_peek = self.subsequent_mask(tgt.size(-1)).type_as(not_pad)
        return not_pad & no_peek

In [None]:
import json
import GPUtil
from torchtext import data
from torchtext.data import Dataset, BucketIterator
import nltk.translate.bleu_score as bleu

class Trainer:
    """Drive training, validation and testing of the translation model.

    The training data may be split over several JSON files ("splits"); progress is
    tracked per epoch and per split so a checkpoint can be written after every split.

    Args:
        device: device the batches are moved to.
        model: module called as ``model(src, trg, src_mask, trg_mask)``.
        src_field, trg_field: torchtext fields with a built ``vocab``.
        batch_size: batch size handed to ``BucketIterator``.
        criterion: loss called as ``criterion(out.transpose(-2, -1), trg_y)``.
        optimizer: optimizer stepping the model parameters.
        train_dataset_path: one path or a list of paths to training split files.
        valid_dataset_path: path to the validation file.
        test_dataset_path: path to the test file.
        check_point_path: directory checkpoints are written to.
    """

    def __init__(self,
                 device,
                 model,
                 src_field,
                 trg_field,
                 batch_size,
                 criterion,
                 optimizer,
                 train_dataset_path,
                 valid_dataset_path,
                 test_dataset_path,
                 check_point_path='./'):
        self.model = model
        self.device = device
        self.batch_size = batch_size
        self.criterion = criterion
        self.optimizer = optimizer
        self.src_field = src_field
        self.trg_field = trg_field
        self.src_vocab = self.src_field.vocab
        self.trg_vocab = self.trg_field.vocab
        # A single path is normalised to a one-element list of splits.
        self.train_dataset_path = [train_dataset_path] if isinstance(train_dataset_path, str) else train_dataset_path
        self.valid_dataset_path = valid_dataset_path
        self.test_dataset_path = test_dataset_path
        self.check_point_path = check_point_path
        self.now_epoch = 1                # epoch currently being trained (1-based)
        self.now_train_split_num = 1      # next training split to process (1-based)
        # Running totals of the current epoch, carried across splits.
        self.train_split_loss = 0
        self.train_split_bleu_score = 0
        self.train_split_batch_cnt = 0
        # Per-epoch history.
        self.train_loss = []
        self.train_bleu_score = []
        self.valid_loss = []
        self.valid_bleu_score = []
        self.test_loss = 0
        self.test_bleu_score = 0

    def itos(self, field, batch):
        """Turn a (n_batch, seq_len) tensor of token indices back into sentences.

        Each row is cut at the first ``field.eos_token``; ``field.init_token`` and
        ``field.pad_token`` are dropped; the remaining tokens are joined with spaces.
        """
        sentences = []
        for row in batch.tolist():
            words = []
            for index in row:
                word = field.vocab.itos[index]  # denumericalize
                if word == field.eos_token:
                    break
                if word not in (field.init_token, field.pad_token):
                    words.append(word)
            sentences.append(' '.join(words))
        return sentences

    def load_dataset(self, filename, print_time=False):
        """Load a JSON-lines dataset file into a torchtext ``Dataset``.

        The first line holds the number of examples; each following line is one JSON
        example that ``Example.fromlist`` maps onto the ``src`` and ``trg`` fields.
        When ``print_time`` is true, the time spent reading the file is printed.
        """
        start = time.time()
        # JSON text is UTF-8 by specification; do not depend on the platform default.
        with open(filename, 'r', encoding='utf-8') as f:
            total = json.loads(f.readline())
            raw_examples = [json.loads(f.readline()) for _ in range(total)]
        if print_time:
            print(time.time() - start)

        fields = [('src', self.src_field), ('trg', self.trg_field)]
        examples = [data.Example.fromlist(raw, fields) for raw in raw_examples]
        return Dataset(examples, fields)

    def save_status(self):
        """Write a checkpoint named ``<epoch>_<split>.pt`` for the split just completed."""
        epoch = self.now_epoch
        train_split_num = self.now_train_split_num - 1
        if self.now_train_split_num == 1:
            # The counters were already advanced to the next epoch: label the file with
            # the last split of the epoch that just finished.
            epoch -= 1
            train_split_num = len(self.train_dataset_path)
        # torch.save does not create missing directories.
        os.makedirs(self.check_point_path, exist_ok=True)
        check_point_path = "%s/%d_%d.pt" % (self.check_point_path, epoch, train_split_num)
        torch.save({
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'now_epoch': self.now_epoch,
            'now_train_split_num': self.now_train_split_num,
            'train_loss': self.train_loss,
            'train_bleu_score': self.train_bleu_score,
            'valid_loss': self.valid_loss,
            'valid_bleu_score': self.valid_bleu_score
        }, check_point_path)
        print("%s is saved" % check_point_path)

    def load_status(self, check_point_path):
        """Restore model, optimizer and progress counters from a checkpoint file."""
        # map_location lets a checkpoint written on a GPU be restored on this trainer's device.
        check_point = torch.load(check_point_path, map_location=self.device)
        self.model.load_state_dict(check_point['model'])
        self.optimizer.load_state_dict(check_point['optimizer'])
        self.now_epoch = check_point['now_epoch']
        self.now_train_split_num = check_point['now_train_split_num']
        self.train_loss = check_point['train_loss']
        self.train_bleu_score = check_point['train_bleu_score']
        self.valid_loss = check_point['valid_loss']
        self.valid_bleu_score = check_point['valid_bleu_score']
        print("%s is loaded" % check_point_path)

    def get_batch_bleu_score(self, pred, label):
        """Average smoothed sentence-BLEU over the scorable sentence pairs.

        A pair is scored only when both strings are longer than one character. Returns
        0.0 when no pair qualifies (instead of dividing by zero).
        """
        assert len(pred) == len(label)  # same batch size
        score = 0
        cnt = 0
        smoothing = None
        for (x, y) in zip(pred, label):
            if len(x) > 1 and len(y) > 1:
                if smoothing is None:
                    # Created once, and only when at least one pair is scored.
                    smoothing = bleu.SmoothingFunction().method3
                score += bleu.sentence_bleu([y.split()], x.split(), smoothing_function=smoothing)
                cnt += 1
        if cnt == 0:
            return 0.0
        return score / cnt

    def _forward_batch(self, batch_without_mask, pad_index):
        """Build the masked batch on ``self.device``, run the model and compute the loss."""
        batch = Batch(batch_without_mask.src, batch_without_mask.trg, pad_index)
        batch.src = batch.src.to(self.device)
        batch.trg = batch.trg.to(self.device)
        batch.src_mask = batch.src_mask.to(self.device)
        batch.trg_mask = batch.trg_mask.to(self.device)

        out = self.model(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
        # The criterion receives the vocabulary axis at dim 1.
        loss = self.criterion(out.contiguous().transpose(-2, -1), batch.trg_y.contiguous())
        return batch, out, loss

    def _batch_bleu(self, out, batch):
        """BLEU of the greedy (argmax) predictions against the reference targets."""
        predictions = self.itos(self.trg_field, torch.argmax(out, dim=-1))
        references = self.itos(self.trg_field, batch.trg_y)
        return float(self.get_batch_bleu_score(predictions, references))

    def train(self, n_epoch, valid=True, save=False, log_interval=200):
        """Train for ``n_epoch`` epochs, starting at the current epoch and split.

        Args:
            n_epoch: number of epochs to run.
            valid: run ``evaluation(is_valid=True)`` after each epoch.
            save: write a checkpoint after every split and after every epoch.
            log_interval: print running averages every this many batches.

        Returns:
            ``(loss_avr, bleu_avr)`` of the last epoch, or ``(0.0, 0.0)`` if no epoch ran.
        """
        loss_avr, bleu_avr = 0.0, 0.0
        pad_index = self.trg_vocab['<pad>']
        for epoch in range(self.now_epoch, self.now_epoch + n_epoch):
            self.model.train()
            print("\n\n******[Train Epoch: %d Start]*****\n\n" % (epoch))
            # Continue from the totals of the splits already processed in this epoch.
            total_loss = self.train_split_loss
            total_bleu_score = self.train_split_bleu_score
            batch_cnt = self.train_split_batch_cnt

            for idx in range(self.now_train_split_num, len(self.train_dataset_path) + 1):
                file_path = self.train_dataset_path[idx - 1]
                train_dataset = self.load_dataset(file_path)
                train_iterator = BucketIterator(train_dataset, batch_size=self.batch_size, device=self.device)
                print('\n[%s is loaded]\n' % (file_path))

                for batch_without_mask in train_iterator:
                    batch_cnt += 1
                    batch, out, loss = self._forward_batch(batch_without_mask, pad_index)

                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    total_loss += float(loss)
                    total_bleu_score += self._batch_bleu(out, batch)

                    if batch_cnt % log_interval == 0:
                        print("[%d'th Batch]\nLoss: %f\tBLEU Score: %f" % (batch_cnt, total_loss / batch_cnt, total_bleu_score / batch_cnt))
                        GPUtil.showUtilization()
                        print()

                # Drop references to the split's data and the last batch's tensors so the
                # memory can be reclaimed before the next split or the validation pass.
                batch = out = loss = None
                train_dataset = train_iterator = None

                self.now_train_split_num += 1
                self.train_split_loss = total_loss
                self.train_split_bleu_score = total_bleu_score
                self.train_split_batch_cnt = batch_cnt
                if self.now_train_split_num <= len(self.train_dataset_path):
                    if save:
                        self.save_status()

            # max(..., 1) avoids a ZeroDivisionError when the epoch produced no batch.
            loss_avr = total_loss / max(batch_cnt, 1)
            bleu_avr = total_bleu_score / max(batch_cnt, 1)
            self.train_loss.append(loss_avr)
            self.train_bleu_score.append(bleu_avr)
            print("[Train Epoch: %d]\nLoss: %f\tBLEU Score: %f" % (epoch, loss_avr, bleu_avr))
            if valid:
                self.evaluation(is_valid=True)
            self.now_epoch += 1
            self.now_train_split_num = 1
            self.train_split_loss = 0
            self.train_split_bleu_score = 0
            self.train_split_batch_cnt = 0
            if save:
                self.save_status()

        return loss_avr, bleu_avr

    def evaluation(self, is_valid=True, log_interval=200):
        """Evaluate on the validation set (``is_valid=True``) or on the test set.

        Runs one pass without gradient tracking. Validation results are appended to
        ``valid_loss`` / ``valid_bleu_score``; test results are stored in ``test_loss`` /
        ``test_bleu_score``.

        Returns:
            ``(loss_avr, bleu_avr)`` averaged over the batches.
        """
        self.model.eval()
        if is_valid:
            print("\n\n******[Validation Start]*****\n\n")
        else:
            print("\n\n******[Test Start]*****\n\n")
        total_loss = 0
        pad_index = self.trg_vocab['<pad>']
        total_bleu_score = 0
        batch_cnt = 0

        file_path = self.valid_dataset_path if is_valid else self.test_dataset_path
        eval_dataset = self.load_dataset(file_path)
        eval_iterator = BucketIterator(eval_dataset, batch_size=self.batch_size, device=self.device)

        # No parameter update happens here, so skip building the autograd graph.
        with torch.no_grad():
            for batch_without_mask in eval_iterator:
                batch_cnt += 1
                batch, out, loss = self._forward_batch(batch_without_mask, pad_index)

                total_loss += float(loss)
                total_bleu_score += self._batch_bleu(out, batch)

                if batch_cnt % log_interval == 0:
                    print("[%d'st Batch]\nLoss: %f\tBLEU Score: %f" % (batch_cnt, total_loss / batch_cnt, total_bleu_score / batch_cnt))
                    GPUtil.showUtilization()
                    print()

        # max(..., 1) avoids a ZeroDivisionError on an empty dataset.
        loss_avr = total_loss / max(batch_cnt, 1)
        bleu_avr = total_bleu_score / max(batch_cnt, 1)
        if is_valid:
            self.valid_loss.append(loss_avr)
            self.valid_bleu_score.append(bleu_avr)
        else:
            self.test_loss = loss_avr
            self.test_bleu_score = bleu_avr
        return loss_avr, bleu_avr

In [None]:
class LabelSmoothing(nn.Module):
    """KL-divergence loss against a label-smoothed target distribution.

    For every non-pad position the target distribution puts ``1 - smoothing`` on the
    true token, 0 on the padding index and ``smoothing / (size - 2)`` on every other
    token. Padding positions contribute nothing. The summed KL divergence is divided by
    the number of non-pad target tokens.

    Args:
        size: vocabulary size; must equal ``x.size(1)`` in ``forward``.
        padding_idx: index of the padding token.
        smoothing: probability mass spread over the non-target, non-pad tokens.
    """

    def __init__(self, size, padding_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(reduction='sum')
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size # vocab size
        self.true_dist = None

    def forward(self, x, target):
        """Return the average loss per non-pad token.

        Args:
            x: log-probabilities of shape (n_batch, n_vocab, seq).
            target: token indices of shape (n_batch, seq).
        """
        # The loss is computed on the CPU, as before.
        # NOTE(review): presumably this keeps the vocabulary-sized target distribution
        # off the GPU — confirm before moving the computation onto x's device.
        x = x.cpu()
        target = target.cpu()
        assert x.size(1) == self.size # vocab size correct
        with torch.no_grad():
            # Every entry starts with the smoothing share (2: the true token and <pad> get none).
            true_dist = torch.full_like(x, self.smoothing / (self.size - 2))
            # Scatter the scalar confidence at the target indices along the vocabulary axis
            # (no second vocabulary-sized tensor is needed for this).
            true_dist.scatter_(dim=1, index=target.unsqueeze(1), value=self.confidence)
            true_dist[:, self.padding_idx] = 0 # <pad> is never a valid prediction
            mask = target != self.padding_idx # padding positions: False, else: True
            true_dist = true_dist * mask.unsqueeze(1) # zero the whole distribution at padding positions
        self.true_dist = true_dist
        # Guard against a batch made only of padding: the summed loss is then 0 and
        # dividing by a zero token count would yield NaN.
        n_tokens = max(int(mask.sum()), 1)
        avr_loss = self.criterion(x, true_dist) / n_tokens
        return avr_loss

# Train

In [None]:
# Build the Transformer for the loaded vocabularies (the hyper-parameters passed here are
# the same as make_model's defaults), move it to the selected device and print its structure.
model = make_model(SRC.vocab, TRG.vocab, d_embed=512, n_layer=6, d_model=512, n_head=8, d_ff=2048)
model.to(device)
print(model)

In [None]:
# Optimisation settings.
# NOTE(review): a constant Adam learning rate of 0.01 with no warm-up schedule is high
# for a Transformer — confirm this value is intended.
batch_size = 32
learning_rate  = 0.01
optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.98), eps=1e-9)
# Label-smoothed loss over the target vocabulary; <pad> positions are excluded from it.
criterion = LabelSmoothing(size=len(TRG.vocab), padding_idx=TRG.vocab['<pad>'], smoothing=0.1).to(device)

dataset_dir_path = './wmt14/'

# The training data is split into nine files: train_01.json ... train_09.json.
train_dataset_path = ['%strain_%02d.json' % (dataset_dir_path, i+1) for i in range(9)]

# Checkpoints are written under ./save (relative to the working directory).
trainer = Trainer(device = device,
                 model = model,
                 src_field = SRC,
                 trg_field = TRG,
                 batch_size = batch_size,
                 criterion = criterion,
                 optimizer = optimizer,
                 train_dataset_path = train_dataset_path,
                 valid_dataset_path = dataset_dir_path + 'valid.json',
                 test_dataset_path = dataset_dir_path + 'test.json',
                 check_point_path='./save')

In [None]:
import warnings
# Silence all Python warnings for the training run.
warnings.filterwarnings('ignore')

# Train for 5 epochs, validating after each epoch, saving checkpoints, and logging
# running averages every 200 batches.
trainer.train(n_epoch=5, valid=True, save=True, log_interval=200)