#### Positional Encoding
$$PE_{(pos,2i)}=\sin\left(pos/10000^{2i/d_{\mathrm{model}}}\right)$$
$$PE_{(pos,2i+1)}=\cos\left(pos/10000^{2i/d_{\mathrm{model}}}\right)$$

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding.

    The table covering positions ``0 .. max_len - 1`` is computed once and
    sliced to the input's sequence length in ``forward``.

    Args:
        max_len: Number of positions the table covers.
        dim_model: Encoding width. Odd values are supported; the last column
            then only receives a sine term.
        device: Device on which the table is initially created.
    """

    def __init__(self, max_len, dim_model, device):
        super().__init__()
        encoding = torch.zeros(max_len, dim_model, device=device, dtype=torch.float32)
        pos = torch.arange(0, max_len, device=device, dtype=torch.float32).unsqueeze(1)
        # div_term[i] = 10000^(2i / dim_model), one entry per even column.
        div_term = torch.exp(
            torch.arange(0, dim_model, 2, device=device, dtype=torch.float32)
            * (math.log(10000) / dim_model)
        )
        encoding[:, 0::2] = torch.sin(pos / div_term)
        # With an odd dim_model there is one fewer odd column than even
        # column, so only the first dim_model // 2 frequencies feed the cosine.
        encoding[:, 1::2] = torch.cos(pos / div_term[: dim_model // 2])
        # Non-persistent buffer: moves with module.to(device) and never
        # requires grad, but is not stored in state_dict (checkpoint keys stay
        # the same as before).
        self.register_buffer('encoding', encoding, persistent=False)

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions.

        Args:
            x: 2-D tensor ``(batch_size, seq_len)`` of token ids.

        Returns:
            Tensor ``(seq_len, dim_model)``; it broadcasts over the batch when
            added to ``(batch_size, seq_len, dim_model)`` embeddings.
        """
        _, seq_len = x.size()
        return self.encoding[:seq_len, :]

#### Embedding

In [None]:
class TokenEmbedding(nn.Embedding):
    """Token lookup table: a named subclass of ``nn.Embedding``.

    Equivalent to ``nn.Embedding(num_embeddings=num_embeddings,
    embedding_dim=embedding_dim, padding_idx=padding_idx)``; per the
    ``nn.Embedding`` documentation the ``padding_idx`` row starts as zeros
    and is not updated by gradients.
    """

    def __init__(self, num_embeddings, embedding_dim, padding_idx):
        super().__init__(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            padding_idx=padding_idx,
        )

### TransformerEmbedding

<img src="./data/Transformer_figure/PosEmb.png" width="500" height="200">

In [None]:
class TransformerEmbedding(nn.Module):
    """Token embedding plus sinusoidal positional encoding, followed by dropout.

    Args:
        max_len: Size of the positional-encoding table.
        dim_model: Embedding width.
        vocab_size: Number of token ids.
        padding_idx: Id whose embedding row is the padding vector.
        drop_prob: Dropout probability applied to the sum.
        device: Device for the positional-encoding table.
    """

    def __init__(self, max_len, dim_model, vocab_size, padding_idx, drop_prob, device):
        super().__init__()
        self.pos = PositionalEncoding(max_len=max_len, dim_model=dim_model, device=device)
        self.emb = TokenEmbedding(
            num_embeddings=vocab_size,
            embedding_dim=dim_model,
            padding_idx=padding_idx,
        )
        self.drop = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Embed a ``(batch_size, seq_len)`` id tensor into ``(batch_size, seq_len, dim_model)``."""
        return self.drop(self.pos(x) + self.emb(x))

##### ScaledDotProductAttention
<img src="./data/Transformer_figure/SDPA.png" width="200" height="200">
$$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}(\frac {QK^{T}} {\sqrt{d_{k}}})V$$

In [None]:
class ScaledDotProductAttention(nn.Module):
    """
    Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    input q shape=[..., seq_len_q, split_dim_model]
          k,v shape=[..., seq_len_k, split_dim_model]
          (typically [batch_size, num_heads, seq_len, split_dim_model]; any
          number of leading dimensions is accepted)
    mask: optional tensor broadcastable to score; entries equal to 0/False
          are excluded from attention
    output v shape=[..., seq_len_q, split_dim_model]
           score shape=[..., seq_len_q, seq_len_k]
    """

    def __init__(self):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None):
        d_k = k.size(-1)
        # Swap only the last two axes so any leading batch/head dims work.
        k_t = k.transpose(-2, -1)
        score = (q @ k_t) / math.sqrt(d_k)
        if mask is not None:
            # `mask == 0` works for bool, integer and float masks alike. A large
            # finite negative (not -inf) keeps fully masked rows NaN-free.
            score = score.masked_fill(mask == 0, -1e6)
        score = self.softmax(score)
        v = score @ v

        return v, score

#### MultiheadAttention
<img src="./data/Transformer_figure/MHA.png" width="200" height="200">
$$\mathrm{MultiHead}(Q,K,V)=\mathrm{Concat}(\mathrm{head}_1,\dots,\mathrm{head}_h)W^O$$
$$\mathrm{head}_i=\mathrm{Attention}(QW_i^Q,KW_i^K,VW_i^V)$$

In [None]:
class MultiheadAttention(nn.Module):
    """Multi-head attention.

    Projects queries, keys and values with separate linear layers, splits the
    feature axis into ``num_heads`` heads, runs scaled dot-product attention
    per head, concatenates the heads and applies an output projection.

    Args:
        dim_model: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``dim_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim_model, num_heads):
        super().__init__()
        if dim_model % num_heads != 0:
            raise ValueError(
                f'dim_model ({dim_model}) must be divisible by num_heads ({num_heads})'
            )
        self.num_heads = num_heads
        self.attention = ScaledDotProductAttention()
        self.w_q = nn.Linear(dim_model, dim_model)
        self.w_k = nn.Linear(dim_model, dim_model)
        self.w_v = nn.Linear(dim_model, dim_model)
        self.w_o = nn.Linear(dim_model, dim_model)

    def forward(self, q, k, v, mask):
        """Attend from ``q`` over ``k``/``v``; all inputs are (batch, seq_len, dim_model).

        ``mask`` is forwarded to the attention kernel (0/False = blocked) and
        must broadcast to (batch, num_heads, q_len, k_len). Returns a
        (batch, q_len, dim_model) tensor; the attention weights are discarded.
        """
        q = self.split(self.w_q(q))
        k = self.split(self.w_k(k))
        v = self.split(self.w_v(v))
        out, _ = self.attention(q, k, v, mask)
        return self.w_o(self.concat(out))

    def split(self, x):
        """(batch, seq_len, dim_model) -> (batch, num_heads, seq_len, dim_model // num_heads).

        The feature axis is cut into per-head chunks for each position, and
        then the head axis is moved in front of the sequence axis. Reshaping
        directly to (batch, num_heads, seq_len, ...) would regroup features of
        different positions into one head "position", scrambling token order
        and defeating the attention masks.
        """
        batch_size, seq_len, dim_model = x.size()
        split_dim_model = dim_model // self.num_heads
        x = x.reshape(batch_size, seq_len, self.num_heads, split_dim_model)
        return x.transpose(1, 2)

    def concat(self, x):
        """Inverse of ``split``: (batch, num_heads, seq_len, d) -> (batch, seq_len, num_heads * d)."""
        batch_size, num_heads, seq_len, split_dim_model = x.size()
        x = x.transpose(1, 2)

        return x.reshape(batch_size, seq_len, num_heads * split_dim_model)

#### LayerNorm
$$y=\frac{x-\mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]+\epsilon}}*\gamma+\beta$$

In [None]:
class LayerNorm(nn.Module):
    """Normalise over the last dimension, then apply a learnable scale and shift.

    y = (x - mean) / sqrt(var + epsilon) * gamma + beta, where var is the
    biased (population) variance.

    Args:
        dim_model: Size of the normalised (last) dimension.
        epsilon: Added to the variance for numerical stability.
    """

    def __init__(self, dim_model, epsilon=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim_model))   # per-feature scale
        self.beta = nn.Parameter(torch.zeros(dim_model))   # per-feature shift
        self.epsilon = epsilon

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        variance = x.var(dim=-1, unbiased=False, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.epsilon)
        return normalized * self.gamma + self.beta

#### PWFFN (Position-wise Feed-Forward Network)
$$\mathrm{FFN}(x)=\max(0,xW_1+b_1)W_2+b_2$$

In [None]:
class PWFFN(nn.Module):
    """Position-wise feed-forward network: fc2(dropout(relu(fc1(x)))).

    The linear layers act on the last dimension, so every position is
    transformed independently: ``dim_model -> ffn_hidden -> dim_model``.
    """

    def __init__(self, dim_model, ffn_hidden, drop_prob=0.1):
        super().__init__()
        self.fc1 = nn.Linear(dim_model, ffn_hidden)
        self.fc2 = nn.Linear(ffn_hidden, dim_model)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(p=drop_prob)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(self.drop(hidden))

### TransformerEncoderLayer
<img src="./data/Transformer_figure/EncoderLayer.png" width="200" height="200">
$$\mathrm{LayerNorm}(x+\mathrm{Sublayer}(x))$$

In [None]:
class TransformerEncoderLayer(nn.Module):
    """One post-norm encoder layer: self-attention, then a feed-forward block.

    Each sub-layer computes LayerNorm(x + Dropout(Sublayer(x))).
    """

    def __init__(self, dim_model, num_heads, ffn_hidden, drop_prob):
        super().__init__()
        self.attention = MultiheadAttention(dim_model=dim_model, num_heads=num_heads)
        self.pwffn = PWFFN(dim_model=dim_model, ffn_hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm1 = LayerNorm(dim_model=dim_model)
        self.norm2 = LayerNorm(dim_model=dim_model)
        self.drop1 = nn.Dropout(p=drop_prob)
        self.drop2 = nn.Dropout(p=drop_prob)

    def forward(self, src, src_mask):
        # Sub-layer 1: self-attention (queries, keys and values are all src).
        attended = self.drop1(self.attention(src, src, src, src_mask))
        hidden = self.norm1(src + attended)
        # Sub-layer 2: position-wise feed-forward.
        fed = self.drop2(self.pwffn(hidden))
        return self.norm2(hidden + fed)

## TransformerEncoder
<img src="./data/Transformer_figure/Encoder.png" width="200" height="200">

In [None]:
class TransformerEncoder(nn.Module):
    """Source embedding followed by ``num_layers`` stacked encoder layers."""

    def __init__(self,max_len,dim_model,num_heads,num_layers,ffn_hidden,vocab_size,padding_idx,drop_prob,device):
        super().__init__()
        self.posemb = TransformerEmbedding(
            max_len=max_len,
            dim_model=dim_model,
            vocab_size=vocab_size,
            padding_idx=padding_idx,
            drop_prob=drop_prob,
            device=device,
        )
        self.layers = nn.ModuleList(
            TransformerEncoderLayer(
                dim_model=dim_model,
                num_heads=num_heads,
                ffn_hidden=ffn_hidden,
                drop_prob=drop_prob,
            )
            for _ in range(num_layers)
        )

    def forward(self, src, src_mask):
        """Encode ``src`` token ids; the same ``src_mask`` is used in every layer."""
        hidden = self.posemb(src)
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, src_mask)
        return hidden

### TransformerDecoderLayer
<img src="./data/Transformer_figure/DecoderLayer.png" width="200" height="200">
$$\mathrm{LayerNorm}(x+\mathrm{Sublayer}(x))$$

In [None]:
class TransformerDecoderLayer(nn.Module):
    """One post-norm decoder layer.

    Runs masked self-attention, encoder-decoder attention and a feed-forward
    block. Each sub-layer computes LayerNorm(x + Dropout(Sublayer(x))).
    """

    def __init__(self, dim_model, num_heads, ffn_hidden, drop_prob):
        super().__init__()
        self.attention1 = MultiheadAttention(dim_model=dim_model, num_heads=num_heads)
        self.attention2 = MultiheadAttention(dim_model=dim_model, num_heads=num_heads)
        self.pwffn = PWFFN(dim_model=dim_model, ffn_hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm1 = LayerNorm(dim_model=dim_model)
        self.norm2 = LayerNorm(dim_model=dim_model)
        self.norm3 = LayerNorm(dim_model=dim_model)
        self.drop1 = nn.Dropout(p=drop_prob)
        self.drop2 = nn.Dropout(p=drop_prob)
        self.drop3 = nn.Dropout(p=drop_prob)

    def forward(self, enc_src, tgt, src_mask, tgt_mask):
        # Sub-layer 1: self-attention over the target, restricted by tgt_mask.
        self_attn = self.drop1(self.attention1(tgt, tgt, tgt, tgt_mask))
        hidden = self.norm1(tgt + self_attn)
        # Sub-layer 2: queries from the decoder, keys/values from the encoder.
        cross_attn = self.drop2(self.attention2(hidden, enc_src, enc_src, src_mask))
        hidden = self.norm2(hidden + cross_attn)
        # Sub-layer 3: position-wise feed-forward.
        fed = self.drop3(self.pwffn(hidden))
        return self.norm3(hidden + fed)

## TransformerDecoder
<img src="./data/Transformer_figure/Decoder.png" width="200" height="200">

In [None]:
class TransformerDecoder(nn.Module):
    """Target embedding, ``num_layers`` decoder layers and a vocabulary projection.

    ``forward`` returns unnormalised scores of shape
    (batch, tgt_len, vocab_size); no softmax is applied here.
    """

    def __init__(self,max_len,dim_model,num_heads,num_layers,ffn_hidden,vocab_size,padding_idx,drop_prob,device):
        super().__init__()
        self.posemb = TransformerEmbedding(
            max_len=max_len,
            dim_model=dim_model,
            vocab_size=vocab_size,
            padding_idx=padding_idx,
            drop_prob=drop_prob,
            device=device,
        )
        self.layers = nn.ModuleList(
            TransformerDecoderLayer(
                dim_model=dim_model,
                num_heads=num_heads,
                ffn_hidden=ffn_hidden,
                drop_prob=drop_prob,
            )
            for _ in range(num_layers)
        )
        self.generator = nn.Linear(dim_model, vocab_size)

    def forward(self, enc_src, tgt, src_mask, tgt_mask):
        hidden = self.posemb(tgt)
        for decoder_layer in self.layers:
            hidden = decoder_layer(enc_src, hidden, src_mask, tgt_mask)
        return self.generator(hidden)

# Transformer
<img src="./data/Transformer_figure/Model.png" width="400" height="200">

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer that builds its own attention masks.

    ``forward(src, tgt)`` takes (batch, src_len) and (batch, tgt_len) token-id
    tensors and returns decoder scores of shape (batch, tgt_len, tgt_vocab_size).
    """

    def __init__(self,max_len,dim_model,num_heads,num_layers,ffn_hidden,src_vocab_size,src_padding_idx,tgt_vocab_size,tgt_padding_idx,drop_prob,device):
        super().__init__()
        self.src_padding_idx = src_padding_idx
        self.tgt_padding_idx = tgt_padding_idx
        self.encoder = TransformerEncoder(max_len=max_len,dim_model=dim_model,num_heads=num_heads,
                       num_layers=num_layers,ffn_hidden=ffn_hidden,vocab_size=src_vocab_size,
                       padding_idx=src_padding_idx,drop_prob=drop_prob,device=device)
        self.decoder = TransformerDecoder(max_len=max_len,dim_model=dim_model,num_heads=num_heads,
                       num_layers=num_layers,ffn_hidden=ffn_hidden,vocab_size=tgt_vocab_size,
                       padding_idx=tgt_padding_idx,drop_prob=drop_prob,device=device)

    def forward(self, src, tgt):
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_tgt_mask(tgt)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(enc_src, tgt, src_mask, tgt_mask)

        return out

    # make mask
    def make_src_mask(self, src):
        """Bool mask (batch, 1, 1, src_len): True where ``src`` is not padding."""
        src_mask = (src != self.src_padding_idx).unsqueeze(1).unsqueeze(2)

        return src_mask

    def make_tgt_mask(self, tgt):
        """Bool mask (batch, 1, tgt_len, tgt_len) combining padding and causality.

        Entry ``[b, 0, i, j]`` is True iff token ``j`` of row ``b`` is not
        padding and ``j <= i`` (a position may not attend to later positions).
        """
        tgt_pad_mask = (tgt != self.tgt_padding_idx).unsqueeze(1).unsqueeze(2)
        tgt_seq_len = tgt.size(1)
        # Built on tgt's own device, so no module-level `device` global is needed
        # and both operands of `&` are bool.
        tgt_seq_mask = torch.tril(torch.ones(tgt_seq_len, tgt_seq_len, device=tgt.device)).bool()
        tgt_mask = tgt_pad_mask & tgt_seq_mask

        return tgt_mask

In [None]:
import os
import math
import torch
import torch.nn as nn
import torch.optim as optim

#### test

In [None]:
# make mask
def make_src_mask(src,src_pad_idx):
    """Return a bool mask, True at non-padding positions, with two axes inserted after dim 0.

    For a (batch, src_len) input the result is (batch, 1, 1, src_len).
    """
    not_padding = src != src_pad_idx
    return not_padding[:, None, None]
def make_tgt_mask(tgt,tgt_pad_idx):
    """Bool mask (batch, 1, tgt_len, tgt_len) combining padding and causality.

    Entry ``[b, 0, i, j]`` is True iff token ``j`` of row ``b`` is not padding
    and ``j <= i``.
    """
    tgt_pad_mask = (tgt != tgt_pad_idx).unsqueeze(1).unsqueeze(2)
    tgt_seq_len = tgt.size(1)
    # Use tgt's device instead of a global `device`, and keep both operands bool.
    tgt_seq_mask = torch.tril(torch.ones(tgt_seq_len, tgt_seq_len, device=tgt.device)).bool()
    tgt_mask = tgt_pad_mask & tgt_seq_mask
    return tgt_mask

In [None]:
# Toy batch used to smoke-test the model end to end. Id 1 is the padding id
# on both sides (see virtual_src_pad_idx / virtual_tgt_pad_idx below).
virtual_src = torch.Tensor([
    [9,4,3,2,5,6,1,1],
    [5,3,2,6,8,4,1,1],
    [2,3,4,5,6,7,8,9],
    [5,6,7,4,8,6,2,1]
])
virtual_tgt = torch.Tensor([
    [10,6,4,3,9,7,8,5,1,1],
    [5,7,4,2,8,6,9,3,1,1],
    [3,5,2,6,7,4,11,9,8,10],
    [5,7,2,6,4,9,3,5,6,1]
])
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
virtual_src,virtual_tgt = virtual_src.long().to(device),virtual_tgt.long().to(device)
# Derive vocabulary sizes from the data so that every id is a valid embedding
# row. The target batch uses id 11, which a hard-coded size of 11 would
# reject with an index error.
virtual_src_voc_size = int(virtual_src.max()) + 1
virtual_tgt_voc_size = int(virtual_tgt.max()) + 1
virtual_batch_size = 4
virtual_dmodel = 6
virtual_num_layers = 3
virtual_nheads = 2
virtual_ffn_hidden = 24
virtual_max_len = 100
virtual_drop_prob=0.1
virtual_src_pad_idx=1
virtual_tgt_pad_idx=1
virtual_tgt_sos_idx=2
# Masks are built inside Transformer.forward (make_src_mask / make_tgt_mask),
# so none are precomputed here.

In [None]:
# Tiny model built from the toy configuration, used only as a smoke test.
test = Transformer(
    max_len=virtual_max_len,
    dim_model=virtual_dmodel,
    num_heads=virtual_nheads,
    num_layers=virtual_num_layers,
    ffn_hidden=virtual_ffn_hidden,
    src_vocab_size=virtual_src_voc_size,
    src_padding_idx=virtual_src_pad_idx,
    tgt_vocab_size=virtual_tgt_voc_size,
    tgt_padding_idx=virtual_tgt_pad_idx,
    drop_prob=virtual_drop_prob,
    device=device,
).to(device)

In [None]:
# Forward pass on the toy batch; the cell displays the decoder output scores.
test(src=virtual_src, tgt=virtual_tgt)

#### test

## data

In [None]:
# ------------------------------------------------------------------ config
# Model hyper-parameters
batch_size = 128
max_len = 256        # rows in the positional-encoding table
dim_model = 512
num_layers = 6       # used for both the encoder and the decoder stack
num_heads = 8
ffn_hidden = 2048
drop_prob = 0.1
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

# Optimiser / scheduler hyper-parameters
init_lr = 1e-5
factor = 0.9         # ReduceLROnPlateau decay factor
adam_eps = 5e-9
patience = 10        # ReduceLROnPlateau patience
warmup = 100         # epochs before the scheduler starts stepping
epoch = 1000         # total training epochs
clip = 1.0           # max gradient norm
weight_decay = 5e-4
inf = float('inf')   # initial "best" validation loss


In [None]:
import spacy
import collections
import time
import sacrebleu
from torch.utils.data import Dataset,DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split

In [None]:
# Tab-separated English / French sentence pairs, one pair per line.
raw_data_path = "../data/fra-eng/fra.txt"

In [None]:
# Load the whole corpus as a list of lines (trailing newlines kept).
with open(raw_data_path, encoding='utf-8') as f:
    raw_data = list(f)

def get_char_set(raw_data):
    """Collect the distinct characters used in the source and target columns.

    Args:
        raw_data: Iterable of tab-separated lines; column 0 is the source
            sentence and column 1 the target sentence.

    Returns:
        Tuple ``(src_chars, tgt_chars)`` of sets of single characters.
    """
    src_chars, tgt_chars = set(), set()
    for line in raw_data:
        columns = line.strip('\n').split('\t')
        src_chars.update(columns[0])
        tgt_chars.update(columns[1])
    return src_chars, tgt_chars

# Display the character inventory of both columns (this call is the cell's
# output). Characters of interest noted from it: '\xa0', '\xad', '\u2009',
# '\u200b', '\u202f' -- EnFrDataset.get_clean_data replaces each of these
# with a plain space.
get_char_set(raw_data)

In [None]:
# Hold out 10% of the sentence pairs for evaluation; the fixed seed keeps the split reproducible.
raw_train, raw_test = train_test_split(
    raw_data,
    train_size=0.9,
    test_size=0.1,
    shuffle=True,
    random_state=42,
)

In [None]:
class Tokenizer:
    """Load a spaCy pipeline and use only its tokenizer component."""

    def __init__(self, token_model):
        super().__init__()
        self.tokenizer = spacy.load(token_model)

    def tokenize(self, x):
        """Return the token strings of ``x``."""
        doc = self.tokenizer.tokenizer(x)
        return [tok.text for tok in doc]

In [None]:
class Vocab:
    """Bidirectional token <-> index mapping built from a corpus.

    Indices ``0 .. len(special_tokens) - 1`` hold ``special_tokens`` in the
    given order. Corpus tokens follow in descending frequency order, with ties
    kept in first-seen order.

    Args:
        data_list: Sentences to tokenize and count.
        token_model: spaCy model name passed to ``Tokenizer``.
        special_tokens: Reserved tokens placed first. Include ``'<unk>'`` so
            unknown lookups resolve to an index; otherwise they return None.
        min_freq: Minimum number of occurrences for a corpus token to be kept.
    """

    def __init__(self, data_list, token_model, special_tokens, min_freq):
        super().__init__()
        self.tokenizer = Tokenizer(token_model)
        counter = collections.Counter(
            token for item in data_list for token in self.tokenizer.tokenize(item)
        )
        self.idx2token = self._build_idx2token(counter, special_tokens, min_freq)
        self.token2idx = {token: idx for idx, token in enumerate(self.idx2token)}

    @staticmethod
    def _build_idx2token(counter, special_tokens, min_freq):
        """Return the special tokens followed by corpus tokens seen at least ``min_freq`` times.

        Corpus tokens that equal a special token are skipped so that no token
        appears twice in the index.
        """
        specials = set(special_tokens)
        # sorted() is stable even with reverse=True, so ties keep Counter order.
        by_freq = sorted(counter.items(), key=lambda kv: kv[1], reverse=True)
        kept = [token for token, count in by_freq if count >= min_freq and token not in specials]
        return list(special_tokens) + kept

    def __len__(self):
        return len(self.idx2token)

    def __getitem__(self, tokens):
        """Map a token, or a list/tuple of tokens, to indices (unknown -> ``unk``)."""
        if not isinstance(tokens, (list, tuple)):
            return self.token2idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def get_idx2token(self, idx):
        """Map an index, or a list of indices, back to token strings."""
        if not isinstance(idx, list):
            return self.idx2token[int(idx)]
        return [self.get_idx2token(i) for i in idx]

    @property
    def unk(self):
        """Index of ``'<unk>'``, or None if it is not in the vocabulary."""
        return self.token2idx.get('<unk>')
        

In [None]:
class EnFrDataset(Dataset):
    """English -> French sentence-pair dataset that tokenizes and indexes in ``collate_fn``.

    Args:
        raw_data: Lines of the form ``"<english>\\t<french>[\\t...]"``.
        src_tokenizer_model, tgt_tokenizer_model: spaCy model names used when
            vocabularies have to be built.
        special_tokens: Must include ``'<pad>'``, ``'<sos>'`` and ``'<eos>'``.
        min_freq: Minimum token frequency for newly built vocabularies.
        src_vocab, tgt_vocab: Optional prebuilt vocabularies. Pass the training
            set's vocabularies when building an evaluation set so that both
            use the same token ids; when omitted, vocabularies are built from
            ``raw_data``.
    """

    # Unusual whitespace/format characters, each replaced by a plain space.
    _CLEAN_TABLE = str.maketrans(dict.fromkeys('\u202f\u2009\u200b\xad\xa0', ' '))

    def __init__(self, raw_data, src_tokenizer_model, tgt_tokenizer_model, special_tokens, min_freq,
                 src_vocab=None, tgt_vocab=None):
        super().__init__()
        self.src_list, self.tgt_list = self.get_clean_data(raw_data)
        if src_vocab is None:
            src_vocab = Vocab(data_list=self.src_list, token_model=src_tokenizer_model,
                              special_tokens=special_tokens, min_freq=min_freq)
        if tgt_vocab is None:
            tgt_vocab = Vocab(data_list=self.tgt_list, token_model=tgt_tokenizer_model,
                              special_tokens=special_tokens, min_freq=min_freq)
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.src_pad_idx = self.src_vocab['<pad>']
        self.tgt_pad_idx = self.tgt_vocab['<pad>']
        self.tgt_sos_idx = self.tgt_vocab['<sos>']
        self.tgt_eos_idx = self.tgt_vocab['<eos>']

    def __len__(self):
        assert len(self.src_list)==len(self.tgt_list),'length is not equal!'
        return len(self.src_list)

    def __getitem__(self, idx):
        return self.src_list[idx], self.tgt_list[idx]

    def collate_fn(self, batch):
        """Tokenize, index and pad a list of ``(src_sentence, tgt_sentence)`` pairs.

        Target sequences are wrapped in ``<sos>`` ... ``<eos>``. Returns
        ``(src_padded, tgt_padded)``, each a (batch, longest_in_batch)
        LongTensor. The explicit dtype keeps sentences that tokenize to nothing
        from becoming float tensors.
        """
        src_idx_list = [
            torch.tensor(self.src_vocab[self.src_vocab.tokenizer.tokenize(src)], dtype=torch.long)
            for src, _ in batch
        ]
        tgt_idx_list = [
            torch.tensor(
                [self.tgt_sos_idx]
                + self.tgt_vocab[self.tgt_vocab.tokenizer.tokenize(tgt)]
                + [self.tgt_eos_idx],
                dtype=torch.long,
            )
            for _, tgt in batch
        ]

        src_padded = pad_sequence(sequences=src_idx_list, batch_first=True, padding_value=self.src_pad_idx)
        tgt_padded = pad_sequence(sequences=tgt_idx_list, batch_first=True, padding_value=self.tgt_pad_idx)

        return src_padded, tgt_padded

    def get_vocab(self):
        return self.src_vocab, self.tgt_vocab

    def get_clean_data(self, raw_data):
        """Normalise whitespace, lower-case, and split each line into source / target.

        Raises IndexError for a line without a tab, as before.
        """
        src_list, tgt_list = [], []
        for raw_sentence in raw_data:
            columns = raw_sentence.translate(self._CLEAN_TABLE).lower().strip('\n').split('\t')
            src_list.append(columns[0])
            tgt_list.append(columns[1])

        return src_list, tgt_list

In [None]:
# spaCy pipelines used for tokenization: English source, French target.
src_tokenizer_model, tgt_tokenizer_model = 'en_core_web_sm', 'fr_core_news_sm'
# Reserved tokens; their order in this list fixes their vocabulary indices.
special_tokens = ['<pad>', '<unk>', '<sos>', '<eos>']
min_freq = 2  # frequency threshold passed to Vocab

In [None]:
# Vocabularies come from the training split only.
train_dataset = EnFrDataset(raw_data=raw_train,src_tokenizer_model=src_tokenizer_model,min_freq=min_freq,
                            tgt_tokenizer_model=tgt_tokenizer_model,special_tokens=special_tokens)
test_dataset = EnFrDataset(raw_data=raw_test,src_tokenizer_model=src_tokenizer_model,min_freq=min_freq,
                            tgt_tokenizer_model=tgt_tokenizer_model,special_tokens=special_tokens)
# The test split must be indexed with the *training* vocabularies. The model's
# embedding and output rows are sized and ordered by them, so a vocabulary
# rebuilt from the test text would give the same word a different id. The
# vocabulary the constructor built from raw_test is discarded here.
test_dataset.src_vocab, test_dataset.tgt_vocab = train_dataset.get_vocab()
test_dataset.src_pad_idx = test_dataset.src_vocab['<pad>']
test_dataset.tgt_pad_idx = test_dataset.tgt_vocab['<pad>']
test_dataset.tgt_sos_idx = test_dataset.tgt_vocab['<sos>']
test_dataset.tgt_eos_idx = test_dataset.tgt_vocab['<eos>']

In [None]:
# Training batches are reshuffled every epoch; evaluation order stays fixed.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=train_dataset.collate_fn,
)
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=test_dataset.collate_fn,
)

In [None]:
# Vocabulary- and loader-derived settings used to build and train the model.
src_voc, tgt_voc = train_dataset.get_vocab()
src_pad_idx, tgt_pad_idx, tgt_sos_idx = src_voc['<pad>'], tgt_voc['<pad>'], tgt_voc['<sos>']
enc_voc_size = len(src_voc)
dec_voc_size = len(tgt_voc)
len_train_dataloader, len_test_dataloader = len(train_dataloader), len(test_dataloader)

In [None]:
def count_parameters(model):
    """Return how many scalar parameters of ``model`` have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def initialize_weights(m):
    """Kaiming-uniform re-initialise module ``m``'s ``weight`` if it has more than one dimension.

    Meant for ``model.apply``.
    - Modules whose ``weight`` is missing, None or 1-D are left untouched;
      this includes parameters with other names such as LayerNorm's
      gamma/beta.
    - For ``nn.Embedding`` modules with a ``padding_idx``, that row is zeroed
      again afterwards so the padding vector stays zero, as ``nn.Embedding``
      initialises it.
    """
    weight = getattr(m, 'weight', None)
    if not isinstance(weight, torch.Tensor) or weight.dim() <= 1:
        return
    with torch.no_grad():
        # kaiming_uniform_ is the non-deprecated, in-place initialiser.
        nn.init.kaiming_uniform_(weight)
        padding_idx = getattr(m, 'padding_idx', None)
        if isinstance(m, nn.Embedding) and padding_idx is not None:
            weight[padding_idx].zero_()


# Full-size model on the configured device.
model = Transformer(
    max_len=max_len,
    dim_model=dim_model,
    num_heads=num_heads,
    num_layers=num_layers,
    ffn_hidden=ffn_hidden,
    src_vocab_size=enc_voc_size,
    src_padding_idx=src_pad_idx,
    tgt_vocab_size=dec_voc_size,
    tgt_padding_idx=tgt_pad_idx,
    drop_prob=drop_prob,
    device=device,
).to(device)

print(f'The model has {count_parameters(model):,} trainable parameters')
# Re-initialise every weight matrix via initialize_weights.
model.apply(initialize_weights)

optimizer = optim.Adam(
    params=model.parameters(),
    lr=init_lr,
    weight_decay=weight_decay,
    eps=adam_eps,
)

# Stepped with the validation loss in run().
# NOTE(review): `verbose` is deprecated in recent PyTorch releases; confirm
# the installed version still accepts it.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    verbose=True,
    factor=factor,
    patience=patience,
)

# Padding positions in the target do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tgt_pad_idx)


def train(model, iterator, optimizer, criterion, clip):
    """Run one teacher-forced training epoch and return the mean batch loss.

    Args:
        model: Called as ``model(src, tgt[:, :-1])``. It must return scores
            shaped (batch, tgt_len - 1, vocab) so they line up with
            ``tgt[:, 1:]``.
        iterator: Sized iterable of ``(src, tgt)`` id-tensor batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss taking (N, vocab) scores and (N,) targets.
        clip: Maximum global gradient norm.

    Returns:
        Sum of batch losses divided by ``len(iterator)``.
    """
    model.train()
    # Send batches to wherever the model's parameters live instead of a global.
    model_device = next(model.parameters()).device
    num_batches = len(iterator)
    epoch_loss = 0
    for i, (src, tgt) in enumerate(iterator):
        src, tgt = src.to(model_device), tgt.to(model_device)
        optimizer.zero_grad()
        # Decoder input drops the last token and targets drop the first, so
        # output position t is scored against token t + 1.
        output = model(src, tgt[:, :-1])
        output_reshape = output.contiguous().view(-1, output.shape[-1])
        tgt = tgt[:, 1:].contiguous().view(-1)

        loss = criterion(output_reshape, tgt)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()

        if (i + 1) % 100 == 0:
            print(f'step : {(i+1)}/{num_batches}, {round(((i + 1) / num_batches) * 100, 2)}% , loss : {loss.item()}')
    return epoch_loss / num_batches


def _ids_to_sentence(ids, vocab, eos_idx, pad_idx):
    """Decode ids into a space-joined sentence, stopping at <eos> and skipping padding."""
    tokens = []
    for idx in ids:
        if idx == eos_idx:
            break
        if idx != pad_idx:
            tokens.append(vocab.get_idx2token(idx))
    return ' '.join(tokens)


def evaluate(model, iterator, criterion, vocab=None):
    """Compute the mean validation loss and mean sentence-level BLEU over ``iterator``.

    Args:
        model: Called as ``model(src, tgt[:, :-1])`` like in training.
        iterator: Iterable of ``(src, tgt)`` id-tensor batches.
        criterion: Loss taking (N, vocab) scores and (N,) targets.
        vocab: Target vocabulary used to decode ids for BLEU. Defaults to the
            module-level ``tgt_voc``.

    Returns:
        ``(mean_batch_loss, mean_batch_bleu)``; both are 0.0 if ``iterator``
        yields no batches.
    """
    if vocab is None:
        vocab = tgt_voc
    eos_idx, pad_idx = vocab['<eos>'], vocab['<pad>']
    model.eval()
    model_device = next(model.parameters()).device
    epoch_loss = 0.0
    batch_bleu = []
    with torch.no_grad():
        for src, tgt in iterator:
            src, tgt = src.to(model_device), tgt.to(model_device)
            output = model(src, tgt[:, :-1])
            # Output position t predicts tgt[:, t + 1], so compare against the
            # target without its leading <sos>.
            gold = tgt[:, 1:]
            pred = output.argmax(-1)
            scores = []
            for gold_ids, pred_ids in zip(gold.tolist(), pred.tolist()):
                reference = _ids_to_sentence(gold_ids, vocab, eos_idx, pad_idx)
                hypothesis = _ids_to_sentence(pred_ids, vocab, eos_idx, pad_idx)
                scores.append(sacrebleu.sentence_bleu(hypothesis, [reference]).score)
            batch_bleu.append(sum(scores) / len(scores))

            output_reshape = output.contiguous().view(-1, output.shape[-1])
            loss = criterion(output_reshape, gold.contiguous().view(-1))
            epoch_loss += loss.item()

    if not batch_bleu:
        return 0.0, 0.0
    # One entry per batch, so len(batch_bleu) is the number of batches seen.
    return epoch_loss / len(batch_bleu), sum(batch_bleu) / len(batch_bleu)

In [None]:
def epoch_time(start_time, end_time):
    """Split the seconds elapsed between two timestamps into whole (minutes, seconds).

    Both parts are truncated toward zero with ``int``.
    """
    elapsed = end_time - start_time
    minutes = int(elapsed / 60)
    return minutes, int(elapsed - minutes * 60)

In [None]:
def run(total_epoch, best_loss):
    """Train for ``total_epoch`` epochs and checkpoint whenever validation loss improves.

    Uses the module-level model, dataloaders, optimizer, scheduler, criterion,
    ``clip`` and ``warmup``. The LR scheduler is stepped with the validation
    loss only once the epoch index exceeds ``warmup``.

    Args:
        total_epoch: Number of epochs to run.
        best_loss: Validation loss a checkpoint must beat; pass ``inf`` so the
            first finite loss is saved.

    Returns:
        ``(train_losses, test_losses, bleus)``: per-epoch histories.
    """
    save_dir = 'saved'
    # torch.save does not create missing directories.
    os.makedirs(save_dir, exist_ok=True)
    train_losses, test_losses, bleus = [], [], []
    for step in range(total_epoch):
        start_time = time.time()
        train_loss = train(model, train_dataloader, optimizer, criterion, clip)
        valid_loss, bleu = evaluate(model, test_dataloader, criterion)
        end_time = time.time()

        if step > warmup:
            scheduler.step(valid_loss)

        train_losses.append(train_loss)
        test_losses.append(valid_loss)
        bleus.append(bleu)
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(model.state_dict(), f'{save_dir}/model_{step+1}-{valid_loss}.pt')

        print(f'Epoch: {step + 1} | Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
        print(f'\tVal Loss: {valid_loss:.3f} |  Val PPL: {math.exp(valid_loss):7.3f}')
        print(f'\tBLEU Score: {bleu:.3f}')

    return train_losses, test_losses, bleus


if __name__ == '__main__':
    # Full training run; starting from inf, the first finite validation loss is saved.
    run(best_loss=inf, total_epoch=epoch)
