<a href="https://colab.research.google.com/github/machine-perception-robotics-group/MPRGDeepLearningLectureNotebook/blob/master/13_rnn/05_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Attention Seq2seq
前セクションでは，Seq2seqについて学びました．本セクションでは，計算機がどの数字，演算記号に着目したかをAttention機構を導入することで調査します．

# Attention機構
Seq2Seqはエンコーダが長期のパターンを学習する際に直近の情報に強く影響されるため，過去の特徴をうまく捉えることが難しいとされています．そこで，各時刻のエンコーダの出力を保持し，デコーダ側へ情報を伝搬するAttention機構を導入することで，この問題を解決します．Attention機構は保持したエンコーダの出力をデコーダの出力に対して重みづけすることで，どの時刻のエンコーダに着目してデコーダが文字等の情報を生成したかを可視化することもできます．

<img src="https://github.com/himidev/Lecture/blob/main/13_rnn/05_Attention/Atten_Seq2seq.png?raw=true" width = 100%>


###データローダの作成
まず，データローダを用意します．データは0から9までの数字と加算記号，開始，終了のフラグです．また，３桁の数字の足し算を行うため，各桁の値を１つずつランダムに生成して連結しています．


In [1]:
import sys
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

word2id = {str(i): i for i in range(10)}
word2id.update({"<pad>": 10, "+": 11, "<eos>": 12})
id2word = {v: k for k, v in word2id.items()}

class CalcDataset(torch.utils.data.Dataset):

    def transform(self, string, seq_len=7):
        tmp = []
        for i, c in enumerate(string):
            try:
                tmp.append(word2id[c])
            except:
                tmp += [word2id["<pad>"]] * (seq_len - i)
                break
        return tmp

    def __init__(self, data_num, train=True):
        super().__init__()
        self.data_num = data_num
        self.train = train
        self.data = []
        self.label = []

        for _ in range(data_num):
            x = int("".join([random.choice(list("0123456789")) for _ in range(random.randint(1, 3))] ))
            y = int("".join([random.choice(list("0123456789")) for _ in range(random.randint(1, 3))] ))
            left = ("{:*<7s}".format(str(x) + "+" + str(y))).replace("*", "<pad>")
            self.data.append(self.transform(left))

            z = x + y
            right = ("{:*<6s}".format(str(z))).replace("*", "<pad>")
            right = self.transform(right, seq_len=5)
            right = [12] + right
            right[right.index(10)] = 12
            self.label.append(right)
        


        self.data = np.asarray(self.data)
        self.label = np.asarray(self.label)

    def __getitem__(self, item):
        d = self.data[item]
        l = self.label[item]
        return d, l

    def __len__(self):
        return self.data.shape[0]

### エンコーダ・デコーダの作成
基本的な構造は前セクションのエンコーダ・デコーダ構造と同様です．ただし，エンコーダは各時刻の出力値を保持しておきます．デコーダでは，保持したエンコーダの出力値とデコーダの出力値で内積計算します．この内積計算によって，各時刻のエンコーダの出力値に重み付けすることができます．これにより，どの時刻のエンコーダの出力に着目したかをデコーダ側が自動で決定することができます．

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


embedding_dim = 16
hidden_dim = 128
vocab_size = len(word2id)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, batch_size=30):
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=word2id["<pad>"])
        self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)

    def forward(self, indices):
        embedding = self.word_embeddings(indices)
        if embedding.dim() == 2:
            embedding = torch.unsqueeze(embedding, 1)
        hs, state = self.gru(embedding, torch.zeros(1, self.batch_size, self.hidden_dim, device=device))

        return hs, state


class Decoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, batch_size=100):
        super(Decoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=word2id["<pad>"])
        self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
        self.output = nn.Linear(hidden_dim * 2, vocab_size)

        self.softmax = nn.Softmax(dim=1)

    def forward(self, index, hs, state):
        embedding = self.word_embeddings(index)
        if embedding.dim() == 2:
            embedding = torch.unsqueeze(embedding, 1)
        gruout, state = self.gru(embedding, state)

        t_output = torch.transpose(gruout, 1, 2)
        s = torch.bmm(hs, t_output)
        attention_weight = self.softmax(s)

        c = torch.zeros(self.batch_size, 1, self.hidden_dim, device=device)

        # attention weight
        for i in range(attention_weight.size()[2]):
          unsq_weight = attention_weight[:,:,i].unsqueeze(2)
          weighted_hs = hs * unsq_weight
          weight_sum = torch.sum(weighted_hs, axis=1).unsqueeze(1)
          c = torch.cat([c, weight_sum], dim=1)
        c = c[:,1:,:]
        gruout = torch.cat([gruout, c], dim=2)
        output = self.output(gruout)
        return output, state, attention_weight


encoder = Encoder(vocab_size, embedding_dim, hidden_dim, batch_size=100).to(device)
decoder = Decoder(vocab_size, embedding_dim, hidden_dim, batch_size=100).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=word2id["<pad>"])

# Initialize opotimizers
encoder_optimizer = optim.Adam(encoder.parameters(), lr=0.001)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=0.001)

###学習
前セクション同様の条件で学習を行います．


In [12]:
import numpy as np
from time import time

# GPUの確認
use_cuda = torch.cuda.is_available()
print('Use CUDA:', use_cuda)

batch_size=100
epoch_num = 200

train_data = CalcDataset(data_num = 20000)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

all_losses = []
start = time()
for epoch in range(1, epoch_num+1):
    epoch_loss = 0
    for data, label in train_loader:
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        if use_cuda:
            data = data.cuda()
            label = label.cuda()

        hs, encoder_hidden = encoder(data)
        source = label[:, :-1]
        target = label[:, 1:]
        decoder_hidden = encoder_hidden


        loss = 0
        decoder_output, _, attention_weight = decoder(source, hs, decoder_hidden)
        decoder_output = torch.squeeze(decoder_output)
        for j in range(decoder_output.size()[1]):
            loss += criterion(decoder_output[:, j, :], target[:, j])
        #for i in range(source.size(1)):
        #    decoder_output, decoder_hidden = decoder(source[:, i], hs, decoder_hidden)
        #    decoder_output = torch.squeeze(decoder_output)
        #    loss += criterion(decoder_output, target[:, i])

        # Perform backpropagation
        loss.backward()
        epoch_loss += loss.item()
        
        # Adjust model weights
        encoder_optimizer.step()
        decoder_optimizer.step()

  
    elapsed_time = time() - start
    all_losses.append(epoch_loss)
    if epoch % 10 == 0:
        print("epoch: {}, mean loss: {}, elapsed_time: {}".format(epoch, loss.item(), elapsed_time))
        
model_name = "seq2seq_calculator_v{}.pt".format(epoch)
torch.save({
    'encoder_model': encoder.state_dict(),
    'decoder_model': decoder.state_dict(),
}, model_name)


Use CUDA: True
epoch: 10, mean loss: 3.172910451889038, elapsed_time: 30.177438735961914
epoch: 20, mean loss: 1.3510839939117432, elapsed_time: 60.827641010284424
epoch: 30, mean loss: 0.9135118126869202, elapsed_time: 90.97140645980835
epoch: 40, mean loss: 0.47546422481536865, elapsed_time: 121.44973468780518
epoch: 50, mean loss: 0.40520748496055603, elapsed_time: 151.65692138671875
epoch: 60, mean loss: 0.13845081627368927, elapsed_time: 181.8737678527832
epoch: 70, mean loss: 0.16695518791675568, elapsed_time: 212.7437777519226
epoch: 80, mean loss: 0.24229463934898376, elapsed_time: 242.9985547065735
epoch: 90, mean loss: 0.10309318453073502, elapsed_time: 275.9088475704193
epoch: 100, mean loss: 0.04000218212604523, elapsed_time: 310.8333098888397
epoch: 110, mean loss: 0.010310840792953968, elapsed_time: 341.2810893058777
epoch: 120, mean loss: 0.02676214836537838, elapsed_time: 371.54746437072754
epoch: 130, mean loss: 0.06269810348749161, elapsed_time: 401.7816140651703
epoc

###評価
こちらも前セクション同様に学習モデルを評価します．



In [13]:
batch_size = 1
test_data = CalcDataset(data_num = 2000)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

encoder = Encoder(vocab_size, embedding_dim, hidden_dim, batch_size=1).to(device)
decoder = Decoder(vocab_size, embedding_dim, hidden_dim, batch_size=1).to(device)

model_name = "seq2seq_calculator_v{}.pt".format(epoch)
checkpoint = torch.load(model_name)
encoder.load_state_dict(checkpoint["encoder_model"])
decoder.load_state_dict(checkpoint["decoder_model"])

accuracy = 0
        
# 評価の実行   
with torch.no_grad():
    for data, label in test_loader:
        if use_cuda:
            data = data.cuda()

        hs, state = encoder(data)

        right = []
        token = "<eos>"
        for _ in range(7):
            index = word2id[token]
            input_tensor = torch.tensor([index], device=device)
            output, state, _ = decoder(input_tensor, hs, state)
            prob = F.softmax(torch.squeeze(output))
            index = torch.argmax(prob.cpu().detach()).item()
            token = id2word[index]
            if token == "<eos>":
                break
            right.append(token)
        right = "".join(right)
        
        x = list(data[0].to('cpu').detach().numpy() )
        try:
            padded_idx_x = x.index(word2id["<pad>"])
        except ValueError:
            padded_idx_x = len(x)
        left = "".join(map(lambda c: str(id2word[c]), x[:padded_idx_x]))

        flag = ["F", "T"][eval(left) == int(right)]
        print("{:>7s} = {:>4s} :{}".format(left, right, flag))
        if flag == "T":
            accuracy += 1
print("Accuracy: {:.2f}".format(accuracy / len(test_loader)))



478+767 = 1335 :F
 35+665 =  700 :T
  599+9 =  608 :T
  8+288 =  296 :T
  263+8 =  271 :T
174+492 =  676 :F
291+777 = 1089 :F
  551+1 =  552 :T
  26+45 =   71 :T
 32+947 =  979 :T
 50+962 = 1012 :T
 26+385 =  411 :T
 716+41 =  757 :T
   4+92 =   96 :T
  5+471 =  476 :T
  85+34 =  119 :T
    5+7 =   12 :T
  573+0 =  573 :T




 893+49 =  942 :T
   3+36 =   39 :T
  854+2 =  856 :T
  6+837 =  843 :T
426+841 = 1287 :F
   4+95 =   99 :T
   96+5 =  101 :T
 219+11 =  230 :T
 166+20 =  176 :F
 71+559 =  630 :T
   69+9 =   78 :T
  12+62 =   74 :T
 685+90 =  785 :F
   96+4 =  100 :T
  108+7 =  115 :T
  19+82 =   91 :F
  31+65 =   96 :T
   0+60 =   60 :T
982+361 = 1333 :F
  833+3 =  836 :T
  934+5 =  938 :F
 55+263 =  318 :T
  789+0 =  790 :F
824+101 = 1094 :F
721+270 = 1001 :F
    7+9 =   16 :T
   5+82 =   87 :T
  68+50 =  118 :T
 62+847 =  908 :F
867+535 = 1412 :F
763+902 = 1665 :T
481+619 = 1090 :F
 645+46 =  691 :T
  30+25 =   55 :T
  932+2 =  934 :T
  178+8 =  186 :T
  495+6 =  501 :T
 932+16 =  948 :T
  4+830 =  834 :T
   1+19 =   20 :T
    0+3 =    3 :T
  796+7 =  803 :T
 58+334 =  392 :T
  9+495 =  504 :T
  885+3 =  888 :T
  864+7 =  871 :T
  3+317 =  310 :F
    1+6 =    7 :T
 60+354 =  414 :T
  6+201 =  207 :T
 25+471 =  496 :T
 97+979 = 1067 :F
    1+2 =    3 :T
  263+7 =  270 :T
   9+47 =   56 :T
690+750 = 

  615+2 =  617 :T
  58+99 =  157 :T
621+728 = 1369 :F
    2+8 =   10 :T
    7+5 =   12 :T
 75+104 =  179 :T
 225+54 =  279 :T
    2+3 =    5 :T
  136+0 =  136 :T
 624+75 =  709 :F
   2+44 =   46 :T
  8+575 =  583 :T
 23+133 =  146 :F
   71+5 =   76 :T
   36+7 =   43 :T
417+312 =  739 :F
   93+0 =   93 :T
    8+3 =   11 :T
336+902 = 1238 :T
  824+3 =  827 :T
 670+47 =  717 :T
  384+2 =  386 :T
   1+19 =   20 :T
910+558 = 1468 :T
   82+8 =   90 :T
    6+2 =    8 :T
 98+896 =  994 :T
 398+70 =  468 :T
   64+7 =   71 :T
    4+9 =   13 :T
  99+18 =  117 :T
   1+63 =   64 :T
   8+67 =   75 :T
374+298 =  842 :F
  1+300 =  301 :T
  5+618 =  623 :T
  93+39 =  132 :T
  5+323 =  328 :T
   0+85 =   85 :T
 745+88 =  833 :T
   52+8 =   60 :T
  76+65 =  141 :T
   1+65 =   66 :T
  78+15 =   93 :T
    6+9 =   15 :T
   0+25 =   25 :T
 16+672 =  688 :T
840+705 = 1555 :F
  6+336 =  342 :T
  10+95 =  105 :T
   66+4 =   70 :T
   87+2 =   89 :T
  22+70 =   92 :T
500+480 =  980 :T
  66+47 =  113 :T
622+886 = 

500+493 =  993 :T
    3+8 =   11 :T
  6+949 =  955 :T
   95+7 =  102 :T
 82+920 = 1012 :F
   9+42 =   51 :T
   93+1 =   94 :T
    8+9 =   17 :T
   62+5 =   67 :T
 423+60 =  473 :F
  401+1 =  402 :T
989+280 = 1269 :T
  833+7 =  840 :T
194+304 =  607 :F
   7+12 =   19 :T
  9+422 =  431 :T
859+392 = 1251 :T
   56+6 =   62 :T
   1+65 =   66 :T
   8+23 =   31 :T
  252+0 =  251 :F
    8+9 =   17 :T
555+564 = 1209 :F
    9+4 =   13 :T
  786+8 =  794 :T
  9+956 =  965 :T
  381+3 =  384 :T
  112+9 =  121 :T
  9+683 =  692 :T
  106+5 =  111 :T
  6+166 =  172 :T
144+314 =  468 :F
    0+2 =    2 :T
   4+99 =  103 :T
    9+7 =   16 :T
 803+97 =  890 :F
   20+6 =   26 :T
 122+46 =  168 :T
   62+6 =   68 :T
  1+920 =  921 :T
  68+72 =  130 :F
    9+3 =   12 :T
  53+21 =   74 :T
 756+48 =  793 :F
   5+54 =   59 :T
  5+573 =  578 :T
    3+8 =   11 :T
  7+719 =  726 :T
444+823 = 1266 :F
  0+912 =  912 :T
443+267 =  790 :F
  531+4 =  535 :T
477+323 =  810 :F
 88+219 =  397 :F
  650+4 =  654 :T
913+289 = 

536+947 = 1473 :F
   9+74 =   83 :T
 70+438 =  498 :F
 657+54 =  711 :T
812+229 = 1041 :T
   49+7 =   56 :T
  6+896 =  902 :T
 473+92 =  565 :T
    8+1 =    9 :T
  67+97 =  164 :T
534+309 =  843 :T
 61+389 =  440 :F
  89+77 =  166 :T
  220+8 =  228 :T
    5+0 =    5 :T
285+293 =  868 :F
  47+41 =   88 :T
  43+35 =   78 :T
    7+0 =    7 :T
 23+632 =  655 :T
  42+39 =   71 :F
    3+1 =    4 :T
676+320 = 1005 :F
885+600 = 1495 :F
  81+19 =  100 :T
  425+0 =  425 :T
  341+4 =  345 :T
  2+547 =  549 :T
 381+29 =  410 :T
  295+1 =  296 :T
 81+181 =  241 :F
 451+29 =  470 :F
    2+0 =    2 :T
  37+69 =  116 :F
  176+2 =  178 :T
 13+332 =  345 :T
   5+53 =   58 :T
  8+542 =  550 :T
   29+9 =   38 :T
  578+8 =  586 :T
    4+6 =   10 :T
  50+98 =  148 :T
 742+68 =  810 :T
  274+3 =  277 :T
  202+6 =  208 :T
  27+38 =   65 :T
   91+4 =   95 :T
 52+482 =  535 :F
  98+78 =  176 :T
  954+9 =  963 :T
  670+5 =  675 :T
  101+9 =  119 :F
  76+46 =  122 :T
   5+40 =   45 :T
   5+72 =   77 :T
  97+77 = 

 49+662 =  711 :T
328+973 = 1201 :F
    1+3 =    4 :T
   1+70 =   71 :T
  87+76 =  163 :T
  94+21 =  115 :T
  44+51 =   95 :T
199+998 = 1118 :F
   3+40 =   43 :T
 408+70 =  478 :T
  66+22 =   88 :T
   51+8 =   59 :T
  566+2 =  568 :T
165+991 = 1106 :F
  9+734 =  743 :T
   1+15 =   16 :T
   52+5 =   57 :T
    0+5 =    5 :T
 27+651 =  678 :T
  301+1 =  302 :T
   47+1 =   48 :T
  752+5 =  757 :T
  2+189 =  191 :T
373+295 =  698 :F
  18+67 =   85 :T
916+576 = 1482 :F
  752+9 =  761 :T
  550+5 =  555 :T
703+749 = 1453 :F
 957+76 = 1033 :T
 40+799 =  839 :T
 880+70 =  940 :F
   28+5 =   33 :T
 93+118 =  211 :T
 362+73 =  435 :T
   51+7 =   58 :T
  650+5 =  655 :T
190+421 =  721 :F
    9+6 =   15 :T
 303+16 =  319 :T
   3+29 =   32 :T
 286+70 =  366 :F
   0+89 =   89 :T
 650+97 =  757 :F
    7+7 =   14 :T
  646+6 =  652 :T
   22+2 =   24 :T
  5+866 =  871 :T
    1+7 =    8 :T
    3+4 =    7 :T
 302+31 =  333 :T
   50+3 =   53 :T
  97+20 =  117 :T
 743+60 =  713 :F
  154+4 =  158 :T
359+113 = 

# Attentionの可視化
Decoder内のAttention weightの可視化をします．Attention weightを見ることで，デコーダがどのエンコーダの入力に着目したかを確認することができます．Attention　weightの可視化にはヒートマップがよく用いられるので，ヒートマップで可視化してみます．ただし，全ての評価サンプルを確認すると時間もかかるので，今回は5サンプルを実行するごとにランダム表示します．ヒートマップは縦軸がエンコーダの入力，横軸がデコーダの出力を表しています．１数字ずつ見たとき、左に並んでいるボックスの色が一番明るいところの文字が最も着目して生成された数値を表しています．プロット毎に数値をランダムにしているので，各自ヒートマップの結果を考察してみてください．

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

batch_size = 1
test_data = CalcDataset(data_num = 2000)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True)

encoder = Encoder(vocab_size, embedding_dim, hidden_dim, batch_size=1).to(device)
decoder = Decoder(vocab_size, embedding_dim, hidden_dim, batch_size=1).to(device)

model_name = "seq2seq_calculator_v{}.pt".format(epoch)
checkpoint = torch.load(model_name)
encoder.load_state_dict(checkpoint["encoder_model"])
decoder.load_state_dict(checkpoint["decoder_model"])

accuracy = 0
        
# 評価の実行   
with torch.no_grad():
    for ind, (data, label) in enumerate(test_loader):
        if use_cuda:
            data = data.cuda()

        hs, state = encoder(data)

        right = []
        Atten = []
        token = "<eos>"
        for _ in range(7):
            index = word2id[token]
            input_tensor = torch.tensor([index], device=device)
            output, state, attention_weight = decoder(input_tensor, hs, state)
            prob = F.softmax(torch.squeeze(output))
            index = torch.argmax(prob.cpu().detach()).item()
            token = id2word[index]
            if token == "<eos>":
                break
            right.append(token)
            Atten.append(attention_weight.cpu().detach().numpy())
        str_right = right
        right = "".join(right)
        
        x = list(data[0].to('cpu').detach().numpy() )
        try:
            padded_idx_x = x.index(word2id["<pad>"])
        except ValueError:
            padded_idx_x = len(x)
        left = "".join(map(lambda c: str(id2word[c]), x[:padded_idx_x]))
        str_left = []
        for s in range(len(x)):
          if str(x[s]) == '11':
            str_left.append('+')
          elif str(x[s]) == '10':
            str_left.append('=')
          else:
            str_left.append(str(x[s]))

        flag = ["F", "T"][eval(left) == int(right)]
        print("{:>7s} = {:>4s} :{}".format(left, right, flag))
        Atten = np.concatenate(Atten, axis=0)
        Atten = Atten[:, :, 0].transpose(1, 0)
        df = pd.DataFrame(Atten, index=str_left, columns=str_right)
        plt.figure(figsize=[12, 8])
        sns.heatmap(df)
        if ind == 4:
          sys.exit()


# 課題
* 四則演算を変えてみよう