<a href="https://colab.research.google.com/github/komazawa-deep-learning/komazawa-deep-learning.github.io/blob/master/2021notebooks/2021_1008seq2seq_attention_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# -*- coding: utf-8 -*-
# Google Colaboratory では GPU モードで実行してください

---
title: NLP From Scratch: Translation with a sequence to sequence network and attention
original: [Sean Robertson](https://github.com/spro/practical-pytorch)
URL: https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

---

# Seq2seq モデルによる翻訳デモ: 注意付きリカレントニューラルネットワーク

<center>
<img src="https://komazawa-deep-learning.github.io/assets/2015Bahdanau_attention.jpg" width="25%">
<img src="https://komazawa-deep-learning.github.io/assets/2015Loung_fig2.svg" width="28%">
<img src="https://komazawa-deep-learning.github.io/assets/2015Loung_fig3.svg" width="28%"><br/>
左: Bahdanau et al. (2014) Fig. 1，中・右: Luong et al. (2015) Fig. 2, Fig. 3
<!--
    右: Itti & Koch (1998) Fig. 1
-->
<!--
<img src="https://komazawa-deep-learning.github.io/assets/1998IttiKoch_fig1.jpg" width="44%"><br/>
-->
</center>                                                                                
<!-- https://github.com/komazawa-deep-learning/komazawa-deep-learning.github.io/blob/master/assets/2015Bahdanau_attention.jpg  -->


Luong (2015) Fig. 2: 大域注意モデル - 各時間ステップ $t$ において，モデルは現在の標的状態 $h_t$ とすべてのソース状態 $h_s$ に基づいて，可変長の整列重みベクトル $a_t$ を推論する。
大域文脈ベクトル $c_t$ は，すべてのソース状態に対する $a_t$ に基づく加重平均として計算される。

Luong (2015) Fig. 3: 局所的注意モデル - このモデルはまず，現在の標的単語に対応する単一の整列位置 $p_t$ を予測する。
次に，ソース位置 $p_t$ を中心としたウィンドウを用いて，ウィンドウ内のソース隠れ状態の加重平均である文脈ベクトル $c_t$ を計算する。
重み $a_t$ は，現在の標的状態 $h_t$ とウィンドウ内のソース状態 $h_s$ から推論される。

<center>
<img src="https://komazawa-deep-learning.github.io/assets/2015Greff_LSTM_ja.svg" width="25%"><br/>
</center>

- [オリジナル Sean Robertson](https://github.com/spro/practical-pytorch)

## 0.1 文献

- [Cho et al. (2014) Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation](https://arxiv.org/abs/1406.1078)
- [Sutskever et al. (2014) Sequence to Sequence Learning with Neural Networks](https://arxiv.org/abs/1409.3215)
- [Bahdanau et al. (2014) Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473)
- [Vinyals & Le (2015) A Neural Conversational Model](https://arxiv.org/abs/1506.05869)
- [Luong et al. (2015) Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/abs/1508.04025)

## 0.2 翻訳データセット

- https://tatoeba.org/eng/downloads
- Tatoeba のデータを言語ペアごとのテキストファイルに分割したものが，以下のサイトで公開されています: https://www.manythings.org/anki/


In [None]:
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import random

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
import sys
import requests
# Download the raw English-Japanese / English-Chinese corpora into the
# working directory as ___eng-jpn.txt / ___eng-chn.txt.
url_base = 'https://komazawa-deep-learning.github.io/2021/'
jpntxt = 'eng-jpn.txt'
chntxt = 'eng-chn.txt'
for _lang in [jpntxt, chntxt]:
    # requests.get's second positional parameter is `params`; passing the
    # file name there appended it as a query string (`...txt?eng-jpn.txt`).
    # A timeout keeps the cell from hanging forever on a stalled connection.
    r = requests.get(url_base + _lang, timeout=60)
    with open('___' + _lang, 'wb') as f:
        f.write(r.content)

# Fetch the helper module that builds the vocabularies / sentence pairs and
# the normalized corpora it reads, all into ./data.
!mkdir data        
!wget https://komazawa-deep-learning.github.io/2021/lang_dict.py -O data/lang_dict.py
!wget https://komazawa-deep-learning.github.io/2021/eng-jpn_normalized.txt -O data/eng-jpn_normalized.txt
!wget https://komazawa-deep-learning.github.io/2021/eng-chn_normalized.txt -O data/eng-chn_normalized.txt

# IPython autoreload: modules (including data.lang_dict) are reloaded before
# each cell runs, so edits to the downloaded file are picked up.
%load_ext autoreload
%autoreload 2
import data.lang_dict

# Short aliases for the objects built by data/lang_dict.py.
# NOTE(review): their structure is defined in the downloaded module. Later
# cells use .n_words / .wrd2idx / .idx2wrd on the *_dict objects and index
# each element of *_pairs as [source sentence, target sentence] — confirm
# against lang_dict.py.
C2E_dict = data.lang_dict.C2E_dict
E2C_dict = data.lang_dict.E2C_dict
J2E_dict = data.lang_dict.J2E_dict
E2J_dict = data.lang_dict.E2J_dict
jpn = data.lang_dict.jpn
chn = data.lang_dict.chn
E2J_pairs = data.lang_dict.E2J_pairs
E2C_pairs = data.lang_dict.E2C_pairs
print(E2J_pairs[:5])

# Sequence-length bound from lang_dict.py; later cells size the encoder-output
# buffer and the decoder's attention layer with it.
MAX_LENGTH = data.lang_dict.MAX_LENGTH
# Probability with which train() uses teacher forcing for a given pair.
teacher_forcing_ratio = 0.5
# Special-token indices. NOTE(review): these must match the indices that
# lang_dict.py reserves in its vocabularies — confirm.
SOS_token = 1
EOS_token = 0

In [None]:
class EncoderRNN(nn.Module):
    """Single-layer GRU encoder that consumes one source token per call.

    Args:
        n_inp: size of the source vocabulary (rows of the embedding table).
        n_hid: embedding size, which is also the GRU hidden size.
        dropout_p: dropout probability applied both to the embedding and to
            the GRU output.
    """

    def __init__(self, n_inp, n_hid, dropout_p=0.1):
        super().__init__()
        self.n_hid = n_hid
        self.embed = nn.Embedding(n_inp, n_hid)
        self.rnn = nn.GRU(n_hid, n_hid)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, inp, hid):
        """Encode a single token.

        Args:
            inp: index tensor holding exactly one token (its embedding is
                reshaped to (1, 1, n_hid), i.e. seq_len = batch = 1).
            hid: previous hidden state of shape (1, 1, n_hid).

        Returns:
            (output, hidden), each of shape (1, 1, n_hid); dropout is applied
            to the output but not to the returned hidden state.
        """
        x = self.embed(inp).view(1, 1, -1)
        x, new_hid = self.rnn(self.dropout(x), hid)
        return self.dropout(x), new_hid

    def initHidden(self):
        """Return a zero hidden state on the parameters' device and dtype."""
        return next(self.parameters()).new_zeros(1, 1, self.n_hid)
    

In [None]:
class AttnDecoderRNN(nn.Module):
    """GRU decoder that attends over a fixed-length buffer of encoder outputs.

    One target token is decoded per call. The attention weights come from a
    linear layer applied to the concatenation of the current (dropped-out)
    input embedding and the previous hidden state. This gives one weight per
    row of the ``max_length``-row ``encoder_outputs`` buffer.

    Args:
        n_hid: embedding size and GRU hidden size; must match the hidden size
            of the encoder whose outputs / final hidden state are fed in.
        output_size: size of the target vocabulary.
        dropout_p: dropout probability applied to the input embedding.
        max_length: number of encoder-output rows the attention layer scores.
    """

    def __init__(self, n_hid, output_size, dropout_p=0.1, max_length=MAX_LENGTH):
        super(AttnDecoderRNN, self).__init__()
        self.n_hid = n_hid
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.n_hid)
        self.attn = nn.Linear(self.n_hid * 2, self.max_length)
        self.attn_combine = nn.Linear(self.n_hid * 2, self.n_hid)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.n_hid, self.n_hid)
        self.out = nn.Linear(self.n_hid, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        """Decode one step.

        Args:
            input: index tensor holding a single target token.
            hidden: previous hidden state, shape (1, 1, n_hid).
            encoder_outputs: tensor of shape (max_length, n_hid).

        Returns:
            (log_probs, hidden, attn_weights) with shapes (1, output_size),
            (1, 1, n_hid) and (1, max_length).
        """
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        # The softmax spans all max_length positions, so rows of
        # encoder_outputs beyond the actual source length also get weight.
        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
        # (1, 1, max_length) x (1, max_length, n_hid) -> (1, 1, n_hid)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))

        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)

        output = F.relu(output)
        output, hidden = self.gru(output, hidden)

        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights

    def initHidden(self):
        """Return a zero hidden state of shape (1, 1, n_hid).

        The tensor is allocated on the same device and dtype as the module's
        parameters (as EncoderRNN.initHidden does). It no longer relies on a
        module-level ``device`` that may differ from where the model lives.
        """
        return next(self.parameters()).new_zeros(1, 1, self.n_hid)

In [None]:
def indexesFromSentence(lang, sentence):
    """Map each single-space-separated token of ``sentence`` to its index.

    Tokens are looked up in ``lang.wrd2idx``; an unknown token raises
    ``KeyError``. Because the split is on a single ``' '``, consecutive
    spaces yield empty-string tokens.
    """
    tokens = sentence.split(' ')
    vocab = lang.wrd2idx
    return [vocab[tok] for tok in tokens]


def tensorFromSentence(lang, sentence):
    """Encode ``sentence`` as a (num_tokens + 1, 1) long tensor on ``device``.

    The final row is ``EOS_token``.
    """
    ids = indexesFromSentence(lang, sentence) + [EOS_token]
    return torch.tensor(ids, dtype=torch.long, device=device).view(-1, 1)


def tensorsFromPair(pair, src_lang=None, dst_lang=None):
    """Convert a (source, target) sentence pair into two index tensors.

    Args:
        pair: sequence whose element 0 is the source sentence and element 1
            the target sentence.
        src_lang: vocabulary for ``pair[0]``; defaults to ``E2J_dict``.
        dst_lang: vocabulary for ``pair[1]``; defaults to ``J2E_dict``.
            Passing both lets the same helper serve other language pairs
            (e.g. the English/Chinese dictionaries).

    Returns:
        (input_tensor, target_tensor), each EOS-terminated, as built by
        ``tensorFromSentence``.
    """
    if src_lang is None:
        src_lang = E2J_dict
    if dst_lang is None:
        dst_lang = J2E_dict
    input_tensor = tensorFromSentence(src_lang, pair[0])
    target_tensor = tensorFromSentence(dst_lang, pair[1])
    return (input_tensor, target_tensor)

In [None]:
# Probability that a given training pair is decoded with teacher forcing.
teacher_forcing_ratio = 0.5


def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):
    """Run one optimisation step on a single (source, target) pair.

    With probability ``teacher_forcing_ratio`` the ground-truth target token
    is fed as the next decoder input. Otherwise the decoder's own greedy
    prediction is fed and decoding stops as soon as it predicts EOS.

    Args:
        input_tensor: source indices, shape (src_len, 1), EOS-terminated.
            NOTE(review): src_len must not exceed ``max_length``, otherwise
            the encoder-output buffer overflows.
        target_tensor: target indices, shape (tgt_len, 1), EOS-terminated.
        encoder, decoder: EncoderRNN / AttnDecoderRNN pair being trained.
        encoder_optimizer, decoder_optimizer: optimizers for each module.
        criterion: per-step loss on (log_probs, target index), e.g. NLLLoss.
        max_length: number of rows in the encoder-output buffer.

    Returns:
        float: summed loss divided by the number of decoder steps actually
        run. When free-running decoding stops early at EOS, only the executed
        steps are counted, so the value is comparable with teacher-forced steps.
    """
    encoder_hidden = encoder.initHidden()

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    # Fixed-size buffer the attention decoder scores; rows past
    # input_length remain zero.
    encoder_outputs = torch.zeros(max_length, encoder.n_hid, device=device)

    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(
            input_tensor[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]

    decoder_input = torch.tensor([[SOS_token]], device=device)
    decoder_hidden = encoder_hidden

    use_teacher_forcing = random.random() < teacher_forcing_ratio

    loss = 0
    n_steps = 0
    for di in range(target_length):
        decoder_output, decoder_hidden, _ = decoder(
            decoder_input, decoder_hidden, encoder_outputs)
        loss += criterion(decoder_output, target_tensor[di])
        n_steps += 1

        if use_teacher_forcing:
            # Teacher forcing: feed the ground-truth token as the next input.
            decoder_input = target_tensor[di]
        else:
            # Free running: feed the greedy prediction, cut from the graph.
            _, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()
            if decoder_input.item() == EOS_token:
                break

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / n_steps

In [None]:
import time
import math

def asMinutes(s):
    """Format a duration in seconds as '<minutes>m <seconds>s' (truncated)."""
    minutes = math.floor(s / 60)
    return '%dm %ds' % (minutes, s - minutes * 60)


def timeSince(since, percent):
    """Describe elapsed time since ``since`` and the estimated time remaining.

    ``percent`` is the completed fraction of the work (0 < percent <= 1). The
    total time is extrapolated linearly as elapsed / percent.
    """
    elapsed = time.time() - since
    remaining = elapsed / percent - elapsed
    return '経過時間:%s (残り時間:%s)' % (asMinutes(elapsed), asMinutes(remaining))


In [None]:
def trainIters(encoder, decoder, 
               SRC, DST,
               pairs=E2J_pairs,
               train_f=train,
               n_iters=75000, 
               print_interval=1000, plot_interval=100, 
               lr=0.01):
    """Train an encoder/decoder pair on randomly sampled sentence pairs.

    Args:
        encoder: nn.Module. Encoder for the source language.
        decoder: nn.Module. Decoder for the target language.
        SRC: vocabulary used to tensorize ``pair[0]`` (source sentence).
        DST: vocabulary used to tensorize ``pair[1]`` (target sentence).
        pairs: sequence of sentence pairs sampled (with replacement) from.
        train_f: single-step training function with the signature of
            ``train``; called once per iteration.
        n_iters: int. Number of iterations (one sampled pair each).
        print_interval: print the average loss every this many iterations.
        plot_interval: record the average loss every this many iterations.
        lr: learning rate for both Adam optimizers (created anew per call).

    Returns:
        list[float]: the recorded average losses (also passed to showPlot).
    """
    start = time.time()

    # Loss accumulators
    plot_losses = []
    print_loss_total = 0  # reset at every print
    plot_loss_total = 0   # reset at every plot_interval

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=lr)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=lr)

    # Sample every pair up front, then tensorize with the caller's
    # vocabularies. SRC/DST were previously ignored, which made training on
    # any other language pair silently use the E2J/J2E vocabularies.
    sampled_pairs = [random.choice(pairs) for _ in range(n_iters)]
    training_pairs = [
        (tensorFromSentence(SRC, pair[0]), tensorFromSentence(DST, pair[1]))
        for pair in sampled_pairs
    ]
    criterion = nn.NLLLoss()

    for it, (input_tensor, target_tensor) in enumerate(training_pairs, start=1):
        # Use the injected step function; it was previously ignored in favour
        # of the global `train`.
        loss = train_f(input_tensor, target_tensor,
                       encoder, decoder,
                       encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if it % print_interval == 0:
            print_loss_avg = print_loss_total / print_interval
            print_loss_total = 0
            print(f'平均損失: {print_loss_avg:.3f}',
                  f'(反復回数: {it:5d} {it/n_iters * 100}%)',
                  f'{timeSince(start, it / n_iters)}'
                 )

        if it % plot_interval == 0:
            plot_loss_avg = plot_loss_total / plot_interval
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    showPlot(plot_losses)
    return plot_losses

In [None]:
import matplotlib.pyplot as plt
# 目盛りの調整のため matplotlib.ticker を使用
import matplotlib.ticker as ticker
!pip install japanize_matplotlib
import japanize_matplotlib
import numpy as np

def showPlot(points):
    """Plot a sequence of loss values with a major y tick every 0.25."""
    # plt.subplots() already creates a fresh figure; an extra plt.figure()
    # beforehand left an empty, unused figure behind on every call.
    fig, ax = plt.subplots()
    ax.yaxis.set_major_locator(ticker.MultipleLocator(base=0.25))  # evenly spaced ticks
    ax.plot(points)

In [None]:
def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH, 
             input_lang=E2J_dict, output_lang=J2E_dict):
    """Greedily translate ``sentence`` and collect the attention weights.

    Both modules are switched to eval mode for the call, which disables their
    dropout layers so the output is deterministic. Afterwards each is
    restored to its previous train/eval mode.

    Args:
        encoder, decoder: EncoderRNN / AttnDecoderRNN pair.
        sentence: space-separated source sentence; every token must be in
            ``input_lang.wrd2idx``.
        max_length: size of the encoder-output buffer and the maximum number
            of decoding steps. NOTE(review): a sentence with more than
            ``max_length`` tokens (EOS included) does not fit the buffer.
        input_lang: vocabulary for the source sentence.
        output_lang: vocabulary used to turn predicted indices into words.

    Returns:
        (decoded_words, attentions): the decoded tokens, which end with
        ``'<EOS>'`` if EOS was produced within ``max_length`` steps, and a
        CPU tensor with one row of attention weights per decoded token.
    """
    was_training = (encoder.training, decoder.training)
    encoder.eval()
    decoder.eval()
    try:
        with torch.no_grad():
            input_tensor = tensorFromSentence(input_lang, sentence)
            input_length = input_tensor.size(0)
            encoder_hidden = encoder.initHidden()

            encoder_outputs = torch.zeros(max_length, encoder.n_hid, device=device)

            for ei in range(input_length):
                encoder_output, encoder_hidden = encoder(input_tensor[ei],
                                                         encoder_hidden)
                encoder_outputs[ei] = encoder_output[0, 0]

            decoder_input = torch.tensor([[SOS_token]], device=device)  # SOS

            decoder_hidden = encoder_hidden

            decoded_words = []
            decoder_attentions = torch.zeros(max_length, max_length)

            for di in range(max_length):
                decoder_output, decoder_hidden, decoder_attention = decoder(
                    decoder_input, decoder_hidden, encoder_outputs)
                decoder_attentions[di] = decoder_attention.squeeze(0)
                _, topi = decoder_output.topk(1)
                if topi.item() == EOS_token:
                    decoded_words.append('<EOS>')
                    break
                decoded_words.append(output_lang.idx2wrd[topi.item()])

                decoder_input = topi.squeeze().detach()

            # Exactly one attention row per emitted token. Unlike slicing
            # with the loop variable, this also works when max_length == 0.
            return decoded_words, decoder_attentions[:len(decoded_words)]
    finally:
        encoder.train(was_training[0])
        decoder.train(was_training[1])

In [None]:
def evaluateRandomly(encoder, decoder, pairs=E2J_pairs, n=10, 
                     input_lang=E2J_dict, output_lang=J2E_dict):
    """Print source ('>'), reference ('=') and model output ('<') for ``n`` random pairs.

    Both vocabularies are forwarded to ``evaluate``. Previously
    ``output_lang`` was accepted but dropped, so output words were always
    looked up in ``evaluate``'s default target vocabulary.
    """
    for _ in range(n):
        pair = random.choice(pairs)
        print('>', pair[0])
        print('=', pair[1])
        output_words, _attentions = evaluate(encoder, decoder, pair[0],
                                             input_lang=input_lang,
                                             output_lang=output_lang)
        output_sentence = ' '.join(output_words)
        print('<', output_sentence)
        print('')

In [None]:
# Hidden size shared by both embeddings and both GRUs.
n_hid = 64

encoder = EncoderRNN(E2J_dict.n_words, n_hid).to(device)
attn_decoder = AttnDecoderRNN(n_hid, J2E_dict.n_words, dropout_p=0.1).to(device)

# Short warm-up run: 1,000 iterations, logging every 250.
_ = trainIters(
    encoder, attn_decoder,
    E2J_dict, J2E_dict,
    pairs=E2J_pairs,
    n_iters=1000,
    print_interval=250,
)

In [None]:
# Main run: continue training the same encoder/attn_decoder for 20,000
# iterations (trainIters creates fresh Adam optimizers on every call).
_ = trainIters(
    encoder, attn_decoder,
    E2J_dict, J2E_dict,
    pairs=E2J_pairs,
    n_iters=20000,
    print_interval=250,
)

In [None]:
# Show source / reference / model output for a few random training pairs.
evaluateRandomly(
    encoder,
    attn_decoder,
    input_lang=E2J_dict,
    output_lang=J2E_dict,
)

In [None]:
def showAttention(input_sentence, output_words, attentions):
    """Draw an attention matrix as a heat map with token labels.

    Rows are labelled with ``output_words``. Columns are labelled with the
    tokens of ``input_sentence`` followed by ``'<EOS>'``; any further columns
    (the zero-padded tail of the encoder buffer) are drawn unlabelled.

    Args:
        input_sentence: space-separated source sentence.
        output_words: decoded tokens, one per row of ``attentions``.
        attentions: 2-D CPU tensor of attention weights.
    """
    input_labels = input_sentence.split(' ') + ['<EOS>']

    # Set up figure with colorbar
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions.numpy(), cmap='bone')
    fig.colorbar(cax)

    # Pin exactly one tick per label (a FixedLocator) before assigning the
    # labels. Pairing set_*ticklabels with a MultipleLocator and a leading ''
    # placeholder is fragile, and recent matplotlib warns about it.
    ax.set_xticks(range(len(input_labels)))
    ax.set_xticklabels(input_labels, rotation=90)
    ax.set_yticks(range(len(output_words)))
    ax.set_yticklabels(output_words)

    plt.show()


def evaluateAndShowAttention(SRC, DST, input_sentence, 
                             encoder=encoder,
                             decoder=attn_decoder):
    """Translate ``input_sentence``, print the result and plot its attention.

    Args:
        SRC: vocabulary for the source sentence (forwarded as ``input_lang``).
        DST: vocabulary for the output words (forwarded as ``output_lang``).
            Both were previously accepted but ignored.
        input_sentence: space-separated source sentence.
        encoder, decoder: models to evaluate (defaults are the module-level
            ``encoder`` / ``attn_decoder`` bound when this cell runs).
    """
    output_words, attentions = evaluate(
        encoder, decoder, input_sentence, input_lang=SRC, output_lang=DST)
    print('input =', input_sentence)
    print('output =', ' '.join(output_words))
    showAttention(input_sentence, output_words, attentions)


# Visualise the attention for one example English sentence.
evaluateAndShowAttention(
    E2J_dict, J2E_dict, "are you serious ?",
    encoder=encoder, decoder=attn_decoder,
)