# Pytorchを使ってSeq2Seqを実装する

In [8]:
from pathlib import Path
import time
import math
import pandas as pd
import numpy as np
from torch import optim
import random
import requests
import zipfile
import urllib
import codecs
import re
import neologdn
import unicodedata
import MeCab
from collections import Counter
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchtext.vocab import Vocab
from load_data import *

## モデルの設定

In [9]:
class Encoder(nn.Module):
    
    def __init__(
        self, vocab_size_src, embedding_dim, hidden_dim_enc, dropout, PAD_IDX
    ):
        
        super().__init__()

        self.vocab_size_src = vocab_size_src
        
        self.embedding = nn.Embedding(vocab_size_src, embedding_dim, padding_idx=PAD_IDX)
        self.dropout = nn.Dropout(dropout)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim_enc)

    def forward(self, src):

        embedded = self.dropout(self.embedding(src))
        _, state = self.rnn(embedded)
        
        return state

In [10]:
class Decoder(nn.Module):
    
    def __init__(
        self, vocab_size_trg, embedding_dim, hidden_dim_dec, dropout, PAD_IDX
    ):
        
        super().__init__()

        self.vocab_size_trg = vocab_size_trg

        self.embedding = nn.Embedding(vocab_size_trg, embedding_dim, padding_idx=PAD_IDX)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim_dec)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(hidden_dim_dec, vocab_size_trg)

    def forward(self, x, encoder_state):

        x = x.unsqueeze(0)
        embedded = self.dropout(self.embedding(x))
        output, state = self.rnn(embedded, encoder_state)
        output = self.out(output[0])

        return output, state

In [11]:
class Seq2Seq(nn.Module):
    
    def __init__(self, encoder, decoder):
        
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, src, trg, teacher_forcing_ratio=0.5):

        batch_size = src.shape[1]
        seq_len = trg.shape[0]
        vocab_size_trg = self.decoder.vocab_size_trg
        
        state = self.encoder(src)
        
        outputs = torch.zeros(seq_len, batch_size, vocab_size_trg)
        output = trg[0, :]
        for t in range(1, seq_len):
            output, state = self.decoder(output, state)
            outputs[t] = output
            teacher_force = random.random() < teacher_forcing_ratio
            output = trg[t] if teacher_force else output.max(1)[1]

        return outputs

In [12]:
def init_weights(m):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.02)
        else:
            nn.init.constant_(param.data, 0)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


In [13]:
def train(model, data, optimizer, criterion, clip):

    model.train()

    losses = []
    for src, trg in tqdm(data):

        optimizer.zero_grad()

        output = model(src, trg)

        output = output[1:].view(-1, output.shape[-1])
        trg = trg[1:].view(-1)

        loss = criterion(output, trg)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        losses.append(loss.item())

    return np.mean(losses)


def evaluate(model, data, criterion):

    model.eval()

    losses = []
    with torch.no_grad():

        for src, trg in data:

            output = model(src, trg, 0)

            output = output[1:].view(-1, output.shape[-1])
            trg = trg[1:].view(-1)

            loss = criterion(output, trg)

            losses.append(loss.item())

    return np.mean(losses)

## SeqSeq2の実行

In [15]:
embedding_dim_enc = 32
embedding_dim_dec = 32
hidden_dim_enc = 64
hidden_dim_dec = 64
dropout_enc = 0.5
dropout_dec = 0.5
vocab_size_src = len(de_vocab)
vocab_size_trg = len(en_vocab)

encoder = Encoder(
    vocab_size_src=vocab_size_src, 
    embedding_dim=embedding_dim_enc, 
    hidden_dim_enc=hidden_dim_enc,
    dropout=dropout_enc,
    PAD_IDX=PAD_IDX
)

decoder = Decoder(
    vocab_size_trg=vocab_size_trg, 
    embedding_dim=embedding_dim_dec,
    hidden_dim_dec=hidden_dim_dec, 
    dropout=dropout_dec,
    PAD_IDX=PAD_IDX
)

model = Seq2Seq(encoder=encoder, decoder=decoder)
model.apply(init_weights)
optimizer = optim.Adam(model.parameters())

In [16]:
epochs = 100
clip = 1
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
patience = 10
best_loss = float('inf')
best_model = None

losses_train = []
losses_valid = []
counter = 0
for epoch in range(epochs):

    start_time = time.time()

    loss_train = train(
        model=model, 
        data=train_iter, 
        optimizer=optimizer, 
        criterion=criterion, 
        clip=clip
    )
    
    elapsed_time = time.time() - start_time
    
    loss_valid = evaluate(
        model=model, 
        data=valid_iter, 
        criterion=criterion
    )
    
    print('[{}/{}] train loss: {:.2f}, valid loss: {:.2f}  [{}{:.0f}s] {}'.format(
        epoch + 1, epochs,
        loss_train, loss_valid,
        str(int(math.floor(elapsed_time / 60))) + 'm' if math.floor(elapsed_time / 60) > 0 else '',
        elapsed_time % 60,
        '**' if best_loss > loss_valid else ''
    ))
    
    if best_loss > loss_valid:
        best_loss = loss_valid
        best_model = model
        counter = 0
        
    if counter > patience:
        break
        
    counter += 1

100%|██████████| 227/227 [09:14<00:00,  2.44s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[1/100] train loss: 6.04, valid loss: 5.35  [9m15s] **


100%|██████████| 227/227 [09:08<00:00,  2.41s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[2/100] train loss: 5.24, valid loss: 5.22  [9m8s] **


100%|██████████| 227/227 [09:03<00:00,  2.39s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[3/100] train loss: 5.09, valid loss: 5.13  [9m3s] **


100%|██████████| 227/227 [09:02<00:00,  2.39s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[4/100] train loss: 4.98, valid loss: 5.09  [9m3s] **


100%|██████████| 227/227 [08:56<00:00,  2.36s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[5/100] train loss: 4.90, valid loss: 5.08  [8m56s] **


100%|██████████| 227/227 [08:50<00:00,  2.34s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[6/100] train loss: 4.79, valid loss: 5.11  [8m51s] 


100%|██████████| 227/227 [08:55<00:00,  2.36s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[7/100] train loss: 4.72, valid loss: 5.04  [8m55s] **


100%|██████████| 227/227 [09:54<00:00,  2.62s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[8/100] train loss: 4.65, valid loss: 5.04  [9m55s] **


100%|██████████| 227/227 [08:38<00:00,  2.28s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[9/100] train loss: 4.60, valid loss: 5.01  [8m39s] **


100%|██████████| 227/227 [08:38<00:00,  2.29s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[10/100] train loss: 4.54, valid loss: 5.00  [8m39s] **


100%|██████████| 227/227 [08:41<00:00,  2.30s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[11/100] train loss: 4.49, valid loss: 4.95  [8m42s] **


100%|██████████| 227/227 [08:42<00:00,  2.30s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[12/100] train loss: 4.42, valid loss: 4.92  [8m43s] **


100%|██████████| 227/227 [08:41<00:00,  2.30s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[13/100] train loss: 4.39, valid loss: 4.92  [8m41s] 


100%|██████████| 227/227 [08:42<00:00,  2.30s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[14/100] train loss: 4.34, valid loss: 4.88  [8m43s] **


100%|██████████| 227/227 [08:44<00:00,  2.31s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[15/100] train loss: 4.28, valid loss: 4.87  [8m44s] **


100%|██████████| 227/227 [08:47<00:00,  2.32s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[16/100] train loss: 4.25, valid loss: 4.88  [8m47s] 


100%|██████████| 227/227 [08:40<00:00,  2.29s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[17/100] train loss: 4.21, valid loss: 4.86  [8m40s] **


100%|██████████| 227/227 [08:38<00:00,  2.28s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[18/100] train loss: 4.20, valid loss: 4.83  [8m38s] **


100%|██████████| 227/227 [08:35<00:00,  2.27s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[19/100] train loss: 4.17, valid loss: 4.86  [8m35s] 


100%|██████████| 227/227 [08:43<00:00,  2.30s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[20/100] train loss: 4.14, valid loss: 4.81  [8m43s] **


100%|██████████| 227/227 [08:41<00:00,  2.30s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[21/100] train loss: 4.12, valid loss: 4.80  [8m41s] **


100%|██████████| 227/227 [08:42<00:00,  2.30s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[22/100] train loss: 4.10, valid loss: 4.81  [8m42s] 


100%|██████████| 227/227 [08:43<00:00,  2.31s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[23/100] train loss: 4.08, valid loss: 4.83  [8m43s] 


100%|██████████| 227/227 [08:43<00:00,  2.31s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[24/100] train loss: 4.05, valid loss: 4.80  [8m44s] 


100%|██████████| 227/227 [08:44<00:00,  2.31s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[25/100] train loss: 4.03, valid loss: 4.76  [8m45s] **


100%|██████████| 227/227 [08:39<00:00,  2.29s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[26/100] train loss: 4.01, valid loss: 4.76  [8m40s] **


100%|██████████| 227/227 [09:03<00:00,  2.39s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[27/100] train loss: 3.98, valid loss: 4.73  [9m4s] **


100%|██████████| 227/227 [08:35<00:00,  2.27s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[28/100] train loss: 3.95, valid loss: 4.71  [8m36s] **


100%|██████████| 227/227 [08:41<00:00,  2.30s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[29/100] train loss: 3.93, valid loss: 4.71  [8m41s] 


100%|██████████| 227/227 [08:39<00:00,  2.29s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[30/100] train loss: 3.90, valid loss: 4.69  [8m39s] **


100%|██████████| 227/227 [08:49<00:00,  2.33s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[31/100] train loss: 3.87, valid loss: 4.67  [8m49s] **


100%|██████████| 227/227 [09:13<00:00,  2.44s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[32/100] train loss: 3.86, valid loss: 4.67  [9m13s] 


100%|██████████| 227/227 [08:56<00:00,  2.36s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[33/100] train loss: 3.82, valid loss: 4.63  [8m56s] **


100%|██████████| 227/227 [09:00<00:00,  2.38s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[34/100] train loss: 3.81, valid loss: 4.60  [9m0s] **


100%|██████████| 227/227 [08:54<00:00,  2.36s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[35/100] train loss: 3.78, valid loss: 4.60  [8m55s] **


100%|██████████| 227/227 [09:49<00:00,  2.60s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[36/100] train loss: 3.75, valid loss: 4.60  [9m50s] **


100%|██████████| 227/227 [09:14<00:00,  2.44s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[37/100] train loss: 3.72, valid loss: 4.57  [9m15s] **


100%|██████████| 227/227 [08:52<00:00,  2.34s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[38/100] train loss: 3.71, valid loss: 4.58  [8m52s] 


100%|██████████| 227/227 [09:02<00:00,  2.39s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[39/100] train loss: 3.68, valid loss: 4.58  [9m2s] 


100%|██████████| 227/227 [09:16<00:00,  2.45s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[40/100] train loss: 3.66, valid loss: 4.57  [9m16s] **


100%|██████████| 227/227 [09:21<00:00,  2.47s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[41/100] train loss: 3.66, valid loss: 4.54  [9m21s] **


100%|██████████| 227/227 [09:28<00:00,  2.50s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[42/100] train loss: 3.62, valid loss: 4.53  [9m28s] **


100%|██████████| 227/227 [09:41<00:00,  2.56s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[43/100] train loss: 3.59, valid loss: 4.54  [9m41s] 


100%|██████████| 227/227 [08:57<00:00,  2.37s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[44/100] train loss: 3.59, valid loss: 4.51  [8m57s] **


100%|██████████| 227/227 [09:04<00:00,  2.40s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[45/100] train loss: 3.58, valid loss: 4.46  [9m4s] **


100%|██████████| 227/227 [09:08<00:00,  2.42s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[46/100] train loss: 3.54, valid loss: 4.49  [9m9s] 


100%|██████████| 227/227 [09:20<00:00,  2.47s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[47/100] train loss: 3.55, valid loss: 4.49  [9m21s] 


100%|██████████| 227/227 [09:16<00:00,  2.45s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[48/100] train loss: 3.49, valid loss: 4.48  [9m17s] 


100%|██████████| 227/227 [09:44<00:00,  2.57s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[49/100] train loss: 3.49, valid loss: 4.48  [9m44s] 


100%|██████████| 227/227 [09:37<00:00,  2.54s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[50/100] train loss: 3.48, valid loss: 4.44  [9m37s] **


100%|██████████| 227/227 [09:50<00:00,  2.60s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[51/100] train loss: 3.46, valid loss: 4.45  [9m50s] 


100%|██████████| 227/227 [09:50<00:00,  2.60s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[52/100] train loss: 3.45, valid loss: 4.45  [9m51s] 


100%|██████████| 227/227 [10:00<00:00,  2.65s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[53/100] train loss: 3.42, valid loss: 4.45  [10m1s] 


100%|██████████| 227/227 [10:04<00:00,  2.66s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[54/100] train loss: 3.42, valid loss: 4.42  [10m4s] **


100%|██████████| 227/227 [10:09<00:00,  2.69s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[55/100] train loss: 3.41, valid loss: 4.40  [10m10s] **


100%|██████████| 227/227 [10:00<00:00,  2.65s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[56/100] train loss: 3.37, valid loss: 4.42  [10m1s] 


100%|██████████| 227/227 [10:05<00:00,  2.67s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[57/100] train loss: 3.40, valid loss: 4.39  [10m6s] **


100%|██████████| 227/227 [09:54<00:00,  2.62s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[58/100] train loss: 3.37, valid loss: 4.44  [9m55s] 


100%|██████████| 227/227 [09:51<00:00,  2.61s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[59/100] train loss: 3.36, valid loss: 4.40  [9m52s] 


100%|██████████| 227/227 [09:48<00:00,  2.59s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[60/100] train loss: 3.35, valid loss: 4.42  [9m49s] 


100%|██████████| 227/227 [09:38<00:00,  2.55s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[61/100] train loss: 3.34, valid loss: 4.37  [9m39s] **


100%|██████████| 227/227 [09:30<00:00,  2.51s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[62/100] train loss: 3.35, valid loss: 4.34  [9m31s] **


100%|██████████| 227/227 [09:21<00:00,  2.47s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[63/100] train loss: 3.30, valid loss: 4.40  [9m21s] 


100%|██████████| 227/227 [09:04<00:00,  2.40s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[64/100] train loss: 3.28, valid loss: 4.39  [9m5s] 


100%|██████████| 227/227 [09:10<00:00,  2.42s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[65/100] train loss: 3.30, valid loss: 4.37  [9m10s] 


100%|██████████| 227/227 [09:04<00:00,  2.40s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[66/100] train loss: 3.28, valid loss: 4.38  [9m4s] 


100%|██████████| 227/227 [09:13<00:00,  2.44s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[67/100] train loss: 3.29, valid loss: 4.38  [9m13s] 


100%|██████████| 227/227 [09:02<00:00,  2.39s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[68/100] train loss: 3.25, valid loss: 4.36  [9m2s] 


100%|██████████| 227/227 [08:55<00:00,  2.36s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[69/100] train loss: 3.24, valid loss: 4.35  [8m55s] 


100%|██████████| 227/227 [08:52<00:00,  2.34s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[70/100] train loss: 3.24, valid loss: 4.38  [8m52s] 


100%|██████████| 227/227 [08:56<00:00,  2.36s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[71/100] train loss: 3.24, valid loss: 4.36  [8m56s] 


100%|██████████| 227/227 [08:50<00:00,  2.34s/it]
  0%|          | 0/227 [00:00<?, ?it/s]

[72/100] train loss: 3.23, valid loss: 4.35  [8m50s] 


100%|██████████| 227/227 [08:53<00:00,  2.35s/it]


[73/100] train loss: 3.22, valid loss: 4.38  [8m53s] 


In [33]:
def translate(
    model, text, vocab_src, vocab_trg, tokenizer_src, max_seq_length_trg
):
    
    tokens = [BOS_IDX] + [vocab_src.stoi[tok] for tok in tokenizer_src(text)] + [EOS_IDX]
    num_tokens = len(tokens)
    src = (torch.LongTensor(tokens).reshape(num_tokens, 1) )
    
    model.eval()
    predict_words = greedy_decode(
        model=model, src=src, vocab_trg=vocab_trg,
        max_seq_length_trg=max_seq_length_trg
    )
    return ' '.join(predict_words)


def greedy_decode(model, src, vocab_trg, max_seq_length_trg):
    
    state = model.encoder(src)

    predict_words = []
    output = torch.tensor([BOS_IDX])
    for t in range(1, max_seq_length_trg):
        output, state = model.decoder(output, state)
        output = output.max(1)[1]
        if output == EOS_IDX:
            break
        predict_words.append(vocab_trg.itos[output])

    return predict_words

In [34]:
text = "Eine Gruppe von Menschen steht vor einem Iglu ."
translate(
    model=best_model, text=text, vocab_src=de_vocab,
    vocab_trg=en_vocab, tokenizer_src=de_tokenizer,
    max_seq_length_trg=100
)

'A crowd of a a a . .'

In [37]:
model=best_model
text=text
vocab_src=de_vocab
vocab_trg=en_vocab
tokenizer_src=de_tokenizer

In [38]:
tokens = [BOS_IDX] + [vocab_src.stoi[tok] for tok in tokenizer_src(text)] + [EOS_IDX]
num_tokens = len(tokens)
src = (torch.LongTensor(tokens).reshape(num_tokens, 1) )

In [40]:
model.eval()

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(19215, 32, padding_idx=1)
    (dropout): Dropout(p=0.5, inplace=False)
    (rnn): LSTM(32, 64)
  )
  (decoder): Decoder(
    (embedding): Embedding(10838, 32, padding_idx=1)
    (rnn): LSTM(32, 64)
    (dropout): Dropout(p=0.5, inplace=False)
    (out): Linear(in_features=64, out_features=10838, bias=True)
  )
)

In [41]:
state = model.encoder(src)

In [43]:
output = torch.tensor([BOS_IDX])

In [45]:
output, state = model.decoder(output, state)

In [49]:
output = output.max(1)[1]

In [51]:
output, state = model.decoder(output, state)

In [56]:
output.max(1)[1]

tensor([94])

In [57]:
en_vocab.itos[94]

'crowd'