# Transformerを使った翻訳

In [3]:
from pathlib import Path
import math
import io
import time
from tqdm import tqdm
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import (
    TransformerEncoder, TransformerDecoder,
    TransformerEncoderLayer, TransformerDecoderLayer
)
from load_data import *

In [3]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## パラメータの設定

In [5]:
torch.manual_seed(0)
torch.use_deterministic_algorithms(True)

model_dir_path = Path('model')
if not model_dir_path.exists():
    model_dir_path.mkdir(parents=True)

## 関数定義

### Transoformerの設定

In [7]:
class Seq2SeqTransformer(nn.Module):
    
    def __init__(
        self, num_encoder_layers: int, num_decoder_layers: int,
        embedding_size: int, vocab_size_src: int, vocab_size_tgt: int,
        dim_feedforward:int = 512, dropout:float = 0.1, nhead:int = 8
    ):
        
        super(Seq2SeqTransformer, self).__init__()

        self.token_embedding_src = TokenEmbedding(vocab_size_src, embedding_size)
        self.positional_encoding = PositionalEncoding(embedding_size, dropout=dropout)
        encoder_layer = TransformerEncoderLayer(
            d_model=embedding_size, nhead=nhead, dim_feedforward=dim_feedforward
        )
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
        
        self.token_embedding_tgt = TokenEmbedding(vocab_size_tgt, embedding_size)
        decoder_layer = TransformerDecoderLayer(
            d_model=embedding_size, nhead=nhead, dim_feedforward=dim_feedforward
        )
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
        
        self.output = nn.Linear(embedding_size, vocab_size_tgt)

    def forward(
        self, src: Tensor, tgt: Tensor,
        mask_src: Tensor, mask_tgt: Tensor,
        padding_mask_src: Tensor, padding_mask_tgt: Tensor,
        memory_key_padding_mask: Tensor
    ):
        
        embedding_src = self.positional_encoding(self.token_embedding_src(src))
        memory = self.transformer_encoder(embedding_src, mask_src, padding_mask_src)
        embedding_tgt = self.positional_encoding(self.token_embedding_tgt(tgt))
        outs = self.transformer_decoder(
            embedding_tgt, memory, mask_tgt, None,
            padding_mask_tgt, memory_key_padding_mask
        )
        return self.output(outs)

    def encode(self, src: Tensor, mask_src: Tensor):
        return self.transformer_encoder(self.positional_encoding(self.token_embedding_src(src)), mask_src)

    def decode(self, tgt: Tensor, memory: Tensor, mask_tgt: Tensor):
        return self.transformer_decoder(self.positional_encoding(self.token_embedding_tgt(tgt)), memory, mask_tgt)

In [8]:
class PositionalEncoding(nn.Module):
    
    def __init__(self, embedding_size: int, dropout: float, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        
        den = torch.exp(-torch.arange(0, embedding_size, 2) * math.log(10000) / embedding_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        embedding_pos = torch.zeros((maxlen, embedding_size))
        embedding_pos[:, 0::2] = torch.sin(pos * den)
        embedding_pos[:, 1::2] = torch.cos(pos * den)
        embedding_pos = embedding_pos.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('embedding_pos', embedding_pos)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.embedding_pos[: token_embedding.size(0), :])

    
class TokenEmbedding(nn.Module):
    
    def __init__(self, vocab_size, embedding_size):
        
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.embedding_size = embedding_size
        
    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.embedding_size)

In [9]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(src, tgt, PAD_IDX):
    
    seq_len_src = src.shape[0]
    seq_len_tgt = tgt.shape[0]

    mask_tgt = generate_square_subsequent_mask(seq_len_tgt)
    mask_src = torch.zeros((seq_len_src, seq_len_src), device=device).type(torch.bool)

    padding_mask_src = (src == PAD_IDX).transpose(0, 1)
    padding_mask_tgt = (tgt == PAD_IDX).transpose(0, 1)
    
    return mask_src, mask_tgt, padding_mask_src, padding_mask_tgt

### Transformerの学習

In [10]:
def train(model, data, optimizer, criterion, PAD_IDX):
    
    model.train()
    losses = 0
    for src, tgt in tqdm(data):
        
        src = src.to(device)
        tgt = tgt.to(device)

        input_tgt = tgt[:-1, :]

        mask_src, mask_tgt, padding_mask_src, padding_mask_tgt = create_mask(src, input_tgt, PAD_IDX)

        logits = model(
            src=src, tgt=input_tgt,
            mask_src=mask_src, mask_tgt=mask_tgt,
            padding_mask_src=padding_mask_src, padding_mask_tgt=padding_mask_tgt,
            memory_key_padding_mask=padding_mask_src
        )

        optimizer.zero_grad()

        output_tgt = tgt[1:, :]
        loss = criterion(logits.reshape(-1, logits.shape[-1]), output_tgt.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()
        
    return losses / len(data)

In [11]:
def evaluate(model, data, criterion, PAD_IDX):
    
    model.eval()
    losses = 0
    for src, tgt in data:
        
        src = src.to(device)
        tgt = tgt.to(device)

        input_tgt = tgt[:-1, :]

        mask_src, mask_tgt, padding_mask_src, padding_mask_tgt = create_mask(src, input_tgt, PAD_IDX)

        logits = model(
            src=src, tgt=input_tgt,
            mask_src=mask_src, mask_tgt=mask_tgt,
            padding_mask_src=padding_mask_src, padding_mask_tgt=padding_mask_tgt,
            memory_key_padding_mask=padding_mask_src
        )
        
        output_tgt = tgt[1:, :]
        loss = criterion(logits.reshape(-1, logits.shape[-1]), output_tgt.reshape(-1))
        losses += loss.item()
        
    return losses / len(data)

### 翻訳の実行

In [None]:
def translate(
    model, text, vocab_src, vocab_tgt, tokenizer_src, seq_len_tgt,
    START_IDX, END_IDX
):
    
    model.eval()
    tokens = [START_IDX] + [vocab_src.stoi[token] for token in tokenizer_src(text)] + [END_IDX]
    num_tokens = len(tokens)
    src = torch.LongTensor(tokens).reshape(num_tokens, 1)
    mask_src = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    
    predicts = greedy_decode(
        model=model, src=src,
        mask_src=mask_src, seq_len_tgt=seq_len_tgt,
        START_IDX=START_IDX, END_IDX=END_IDX
    ).flatten()
    
    return ' '.join([vocab_tgt.itos[token] for token in predicts]).replace("<start>", "").replace("<end>", "")


def greedy_decode(model, src, mask_src, seq_len_tgt, START_IDX, END_IDX):
    
    src = src.to(device)
    mask_src = mask_src.to(device)

    memory = model.encode(src, mask_src)
    memory = model.transformer_encoder(model.positional_encoding(model.token_embedding_src(src)), mask_src)
    ys = torch.ones(1, 1).fill_(START_IDX).type(torch.long).to(device)
    
    for i in range(seq_len_tgt - 1):
        
        memory = memory.to(device)
        memory_mask = torch.zeros(ys.shape[0], memory.shape[0]).to(device).type(torch.bool)
        mask_tgt = (generate_square_subsequent_mask(ys.size(0)).type(torch.bool)).to(device)
        
        output = model.decode(ys, memory, mask_tgt)
        output = output.transpose(0, 1)
        output = model.output(output[:, -1])
        _, next_word = torch.max(output, dim = 1)
        next_word = next_word.item()

        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == END_IDX:
            break
            
    return ys

## Transfomerを実行する

In [14]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

vocab_size_src = len(vocab_src)
vocab_size_tgt = len(vocab_tgt)
embedding_size = 240
nhead = 8
dim_feedforward = 100
num_encoder_layers = 2
num_decoder_layers = 2
dropout = 0.1

model = Seq2SeqTransformer(
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    embedding_size=embedding_size,
    vocab_size_src=vocab_size_src, vocab_size_tgt=vocab_size_tgt,
    dim_feedforward=dim_feedforward,
    dropout=dropout, nhead=nhead
)

for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

model = model.to(device)

criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(model.parameters())

In [16]:
epoch = 100
best_loss = float('Inf')
best_model = None
patience = 10
counter = 0

for loop in range(1, epoch + 1):
    
    start_time = time.time()
    
    loss_train = train(
        model=model, data=train_iter, optimizer=optimizer,
        criterion=criterion, PAD_IDX=PAD_IDX
    )
    
    elapsed_time = time.time() - start_time
    
    loss_valid = evaluate(
        model=model, data=valid_iter, criterion=criterion, PAD_IDX=PAD_IDX
    )
    
    print('[{}/{}] train loss: {:.2f}, valid loss: {:.2f}  [{}{:.0f}s] count: {}, {}'.format(
        loop, epoch,
        loss_train, loss_valid,
        str(int(math.floor(elapsed_time / 60))) + 'm' if math.floor(elapsed_time / 60) > 0 else '',
        elapsed_time % 60,
        counter,
        '**' if best_loss > loss_valid else ''
    ))
    
    if best_loss > loss_valid:
        best_loss = loss_valid
        best_model = model
        counter = 0
        
    if counter > patience:
        break
    
    counter += 1

100%|██████████| 227/227 [08:51<00:00,  2.34s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[1/100] train loss: 4.16, valid loss: 2.78  [8m52s] **


100%|██████████| 227/227 [07:23<00:00,  1.96s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[2/100] train loss: 2.41, valid loss: 2.13  [7m24s] **


100%|██████████| 227/227 [07:17<00:00,  1.93s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[3/100] train loss: 1.75, valid loss: 1.90  [7m17s] **


100%|██████████| 227/227 [06:45<00:00,  1.79s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[4/100] train loss: 1.34, valid loss: 1.88  [6m45s] **


100%|██████████| 227/227 [06:27<00:00,  1.71s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[5/100] train loss: 1.06, valid loss: 1.89  [6m27s] 


100%|██████████| 227/227 [06:26<00:00,  1.70s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[6/100] train loss: 0.86, valid loss: 1.97  [6m27s] 


100%|██████████| 227/227 [06:28<00:00,  1.71s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[7/100] train loss: 0.72, valid loss: 2.02  [6m28s] 


100%|██████████| 227/227 [06:28<00:00,  1.71s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[8/100] train loss: 0.62, valid loss: 2.08  [6m29s] 


100%|██████████| 227/227 [06:45<00:00,  1.79s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[9/100] train loss: 0.54, valid loss: 2.18  [6m46s] 


100%|██████████| 227/227 [09:43<00:00,  2.57s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[10/100] train loss: 0.48, valid loss: 2.25  [9m43s] 


100%|██████████| 227/227 [10:08<00:00,  2.68s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[11/100] train loss: 0.43, valid loss: 2.32  [10m9s] 


100%|██████████| 227/227 [10:04<00:00,  2.66s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[12/100] train loss: 0.39, valid loss: 2.38  [10m4s] 


100%|██████████| 227/227 [09:53<00:00,  2.61s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[13/100] train loss: 0.36, valid loss: 2.45  [9m53s] 


100%|██████████| 227/227 [09:44<00:00,  2.57s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[14/100] train loss: 0.33, valid loss: 2.51  [9m44s] 


100%|██████████| 227/227 [09:57<00:00,  2.63s/it]


[15/100] train loss: 0.31, valid loss: 2.59  [9m58s] 


In [21]:
torch.save(best_model.state_dict(), model_dir_path.joinpath('translation_transfomer.pth'))

## 学習したモデルを使って翻訳をする

In [18]:
seq_len_tgt = max([len(x[1]) for x in train_data])

In [20]:
text = 'Eine Gruppe von Menschen steht vor einem Iglu .'

translate(
    model=best_model, text=text, vocab_src=vocab_src, vocab_tgt=vocab_tgt,
    tokenizer_src=tokenizer_src, seq_len_tgt=seq_len_tgt,
    START_IDX=START_IDX, END_IDX=END_IDX
)

' A group of people stand in front of an igloo . '