# Encoder-Decoder with Attention

In [1]:
import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions

In [2]:
# Map variable-length word-id sequences to variable-length word-vector sequences.
def sequence_embed(embed, xs):
    """Embed a batch of variable-length id sequences with a single lookup.

    The sequences are joined along axis 0, passed through ``embed`` once,
    and the result is cut back into one piece per input sequence.

    Args:
        embed: Callable applied to the concatenated ids (an embedding link
            in this file).
        xs: Sequence of id arrays, one per sentence; each must support
            ``len`` and concatenation along axis 0.

    Returns:
        The pieces produced by ``F.split_axis``, one per entry of ``xs`` and
        in the same order.
    """
    lengths = [len(seq) for seq in xs]
    # Split points are the running totals of every length except the last.
    boundaries = np.cumsum(lengths[:-1])
    flat_embedded = embed(F.concat(xs, axis=0))
    return F.split_axis(flat_embedded, boundaries, 0)

In [354]:
class Encoder(Chain):
    """Bidirectional LSTM encoder over source-language id sequences.

    Args:
        n_layers: Number of stacked BiLSTM layers.
        n_source_vocab: Size of the source vocabulary.
        n_units: Embedding size, also used as the per-direction LSTM size.
        dropout: Dropout ratio handed to the BiLSTM.
    """

    def __init__(self, n_layers, n_source_vocab, n_units, dropout):
        super(Encoder, self).__init__()
        # Plain hyper-parameters kept for reference by other components.
        self.n_source_vocab = n_source_vocab
        self.n_units = n_units
        with self.init_scope():
            self.source_embed = L.EmbedID(n_source_vocab, n_units)
            self.encoder_lstm = L.NStepBiLSTM(
                n_layers, n_units, n_units, dropout)

    def __call__(self, source_xs):
        """Encode a batch of source sequences.

        Args:
            source_xs: Sequence of source word-id arrays, one per sentence.

        Returns:
            The ``(hy, cy, ys)`` triple produced by the BiLSTM: final hidden
            states, final cell states, and the per-time-step outputs of each
            sequence.
        """
        # Turn every word id into its embedding vector.
        embedded = sequence_embed(self.source_embed, source_xs)
        # Passing None for both initial states lets the LSTM start from zeros.
        return self.encoder_lstm(None, None, embedded)
        

Encoderからは、最後のh, c と、各タイムステップのhが返ってくる。

Decodingにおいて、最後の層以外は普通にLSTMしたい。

ひとまず、最後の層のみのDecoder + ATTを書く。

encoderのysと一つ前の隠れ状態からattentionを計算する。

encoderのysは系列ごとの配列のリストで、各要素は(系列長, n_units*2)（BiLSTMなので順方向と逆方向が連結されている）  

Decoder側ではinput(系列arrayのリスト)を処理していく。  
forで回すことになるだろう。

In [355]:
class Attention(Chain):
    """Additive attention over the encoder outputs of one sequence.

    Args:
        n_units: Per-direction encoder size; both projections produce
            ``n_units * 2`` features.
    """

    def __init__(self, n_units):
        super(Attention, self).__init__()
        hidden_size = n_units * 2
        with self.init_scope():
            # Projection of the encoder (BiLSTM) outputs.
            self.eW = L.Linear(hidden_size)
            # Projection of the previous decoder hidden state.
            self.dW = L.Linear(hidden_size)
            # Reduces each combined feature row to one scalar score.
            self.aW = L.Linear(1)

    def __call__(self, ehs, dh):
        """Return the context vector for one decoding step.

        Args:
            ehs: Encoder outputs of a single sequence, one row per time step
                (the caller passes BiLSTM outputs, presumably
                ``(len, n_units * 2)``).
            dh: Previous decoder hidden state as a single row.

        Returns:
            A one-row matrix: the attention-weighted sum of the rows of
            ``ehs``.
        """
        projected_enc = self.eW(ehs)
        # Repeat the single decoder row so it can be added to every step.
        projected_dec = F.broadcast_to(self.dW(dh), projected_enc.shape)
        energies = self.aW(F.tanh(projected_enc + projected_dec))
        # Normalise over the time axis so the weights sum to one.
        weights = F.softmax(energies, axis=0)
        # (1, len) x (len, features) -> (1, features)
        return F.matmul(F.transpose(weights), ehs)

Decoder  
target言語の系列のリストと、一つ前の隠れ状態、セルを受け取る。  
各タイムステップの隠れ状態を返す。

入力: xs = [np.array(1,3,4...), np.array(...),...], h, c  
出力: os = [(len, n_units*2), (len, n_units*2), ...]

h = (1, n_units*2)
c = (1, n_units*2)

In [357]:
class Decoder(Chain):
    """LSTM decoder whose top layer attends over the encoder outputs.

    The lower ``n_layers - 1`` layers are an ordinary ``NStepLSTM``. The top
    layer is unrolled step by step so that an attention context, computed
    from the previous hidden state, can be fed into every LSTM update.

    Args:
        n_layers: Total number of decoder layers, including the attention
            layer.
        n_target_vocab: Size of the target vocabulary.
        n_units: Embedding size; the hidden size is ``n_units * 2`` so that
            it matches the concatenated BiLSTM encoder states.
        attention: Link called as ``attention(encoder_outputs, h)`` that
            returns the context row.
        dropout: Dropout ratio used by the lower LSTM and on the input of
            the top layer.
    """

    def __init__(self, n_layers, n_target_vocab, n_units, attention, dropout):
        super(Decoder, self).__init__()
        # Plain hyper-parameters; they need no parameter registration.
        self.n_layers = n_layers
        self.n_target_vocab = n_target_vocab
        self.dropout = dropout
        with self.init_scope():
            # The third positional parameter of EmbedID is ``initialW``, so
            # the dropout ratio must not be passed here: a ratio of 0 would
            # initialise every embedding to zero.
            self.target_embed = L.EmbedID(n_target_vocab, n_units)
            if self.n_layers > 1:
                self.pre_lstm = L.NStepLSTM(
                    self.n_layers - 1, n_units, n_units * 2, dropout)

            self.Att = attention

            # Top LSTM layer. The hidden size is n_units * 2 and F.lstm
            # expects the four gate pre-activations stacked, hence the * 8.
            self.lstm_input = L.Linear(n_units * 8)
            self.lstm_previous = L.Linear(n_units * 8)
            self.lstm_context = L.Linear(n_units * 8)

    def _merge_directions(self, state, batch_size):
        """Bring a recurrent state into shape (n_layers, batch, features)."""
        if state.shape[0] == 2 * self.n_layers:
            # Raw BiLSTM state: rows alternate forward/backward per layer.
            # Join the two directions on the feature axis; a plain reshape
            # would pair up neighbouring batch items instead.
            return F.concat([state[0::2], state[1::2]], axis=2)
        return F.reshape(state, (self.n_layers, batch_size, -1))

    def __call__(self, hy, cy, ys, target_xs):
        """Decode every target sequence with attention on the top layer.

        Args:
            hy: Initial hidden states: either the raw BiLSTM states with
                ``2 * n_layers`` leading entries, or states previously
                returned by this decoder.
            cy: Initial cell states, laid out like ``hy``.
            ys: Per-sequence encoder outputs; ``len(ys)`` is the batch size.
            target_xs: Target-side word-id sequences fed to the decoder.

        Returns:
            ``(ho, co, os)``: final hidden and cell states stacked over all
            layers, and a list holding, per sequence, the top-layer hidden
            states of every time step concatenated along axis 0.
        """
        batch_size = len(ys)

        exs = sequence_embed(self.target_embed, target_xs)
        hy = self._merge_directions(hy, batch_size)
        cy = self._merge_directions(cy, batch_size)

        # Layers below the attention layer run as an ordinary stacked LSTM.
        if self.n_layers > 1:
            n_lower = self.n_layers - 1
            after_h, after_c, pre_os = self.pre_lstm(
                hy[:n_lower], cy[:n_lower], exs)
        else:
            pre_os = exs

        # Top layer: unrolled by hand so attention can be applied per step.
        high_hy = hy[self.n_layers - 1]
        high_cy = cy[self.n_layers - 1]
        os = []
        last_h = []
        last_c = []
        for i, pre_eos in enumerate(pre_os):
            h = F.reshape(high_hy[i], (1, -1))
            # The cell state starts from the encoder cell state (the
            # original code seeded it from the hidden state by mistake).
            c = F.reshape(high_cy[i], (1, -1))
            now_ys = ys[i]
            temp_os = []

            pre_eos = F.dropout(pre_eos, self.dropout)
            for x in pre_eos:
                x = F.reshape(x, (1, -1))
                context = self.Att(now_ys, h)
                gates = (self.lstm_input(x)
                         + self.lstm_previous(h)
                         + self.lstm_context(context))
                c, h = F.lstm(c, gates)
                temp_os.append(h)
            last_h.append(h)
            last_c.append(c)
            os.append(F.concat(temp_os, axis=0))

        last_h = F.reshape(F.concat(last_h, axis=0), (1, batch_size, -1))
        last_c = F.reshape(F.concat(last_c, axis=0), (1, batch_size, -1))
        if self.n_layers > 1:
            ho = F.concat([after_h, last_h], axis=0)
            co = F.concat([after_c, last_c], axis=0)
        else:
            ho = last_h
            co = last_c

        return ho, co, os

        

        
                

In [353]:
# Reserved vocabulary ids.
UNK = 0  # id for unknown words (not referenced by the code shown here)
EOS = 1  # marker used both as the decoder start token and the end-of-sequence token

In [402]:
class Encoder_Decoder_withAttention(Chain):
    """Sequence-to-sequence model: encoder, attention decoder, output layer.

    Args:
        encoder: Link called as ``encoder(xs)`` returning ``(hy, cy, ys)``.
        decoder: Link called as ``decoder(hy, cy, ys, target_xs)`` returning
            ``(ho, co, os)``; it must expose ``n_target_vocab``.
    """

    def __init__(self, encoder, decoder):
        super(Encoder_Decoder_withAttention, self).__init__()
        with self.init_scope():
            self.encoder = encoder
            self.decoder = decoder

            # Projects decoder hidden states to target-vocabulary scores.
            self.W = L.Linear(self.decoder.n_target_vocab)

    def __call__(self, xs, ys):
        """Compute the training loss for a batch.

        Args:
            xs: Sequence of source word-id arrays.
            ys: Sequence of target word-id arrays, aligned with ``xs``.

        Returns:
            Summed softmax cross entropy over all target tokens, divided by
            the batch size.
        """
        eos = self.xp.array([EOS], 'i')
        # Decoder input starts with EOS; the expected output ends with it.
        ys_in = [F.concat([eos, y], axis=0) for y in ys]
        ys_out = [F.concat([y, eos], axis=0) for y in ys]

        # Separate names so the target argument ``ys`` is not overwritten.
        hy, cy, enc_ys = self.encoder(xs)
        _, _, os = self.decoder(hy, cy, enc_ys, ys_in)

        # loss calculation
        batch_size = len(xs)
        concat_os = F.concat(os, axis=0)
        concat_ys_out = F.concat(ys_out, axis=0)
        loss = F.sum(F.softmax_cross_entropy(
            self.W(concat_os), concat_ys_out, reduce='no')) / batch_size

        return loss

    def translate(self, xs, max_length=100):
        """Greedily decode a batch of source sequences.

        Args:
            xs: Sequence of source word-id arrays.
            max_length: Number of decoding steps to run.

        Returns:
            List of NumPy id arrays, one per source sequence, each cut off
            before its first EOS (or ``max_length`` long if none was
            produced).
        """
        batch_size = len(xs)

        with chainer.no_backprop_mode(), chainer.using_config('train', False):
            # Encode
            hy, cy, enc_ys = self.encoder(xs)

            # Every sequence starts decoding from a single EOS token.
            target_xs = self.xp.full((batch_size, 1), EOS, 'i')
            result = []

            ho = hy
            co = cy
            for _ in range(max_length):
                ho, co, os = self.decoder(ho, co, enc_ys, target_xs)
                wy = self.W(F.concat(os, axis=0))
                word_ids = self.xp.argmax(wy.data, axis=1).astype('i')
                result.append(word_ids)
                # The predicted word becomes the next one-token input.
                target_xs = word_ids.reshape(-1, 1)

            # Move to host memory so the NumPy calls below also work when
            # the model runs on a GPU.
            result = cuda.to_cpu(self.xp.stack(result).T)

            # Remove EOS tags
            outs = []
            for y in result:
                inds = np.argwhere(y == EOS)
                if len(inds) > 0:
                    y = y[:inds[0, 0]]
                outs.append(y)
        return outs
            
            
            
            

In [403]:
# Toy input: 20 random id sequences with ids in [0, 50) and lengths in [1, 15).
xs = [np.random.randint(0, 50, size=np.random.randint(1, 15, 1)).astype('i')
      for i in range(20)]

In [404]:
xs

[array([ 9, 34, 45, 35, 43,  5, 13, 45, 33], dtype=int32),
 array([21,  8,  8, 12, 42], dtype=int32),
 array([20, 24,  9, 15], dtype=int32),
 array([0], dtype=int32),
 array([38, 20, 12, 28, 32, 36,  5, 40,  0,  8, 42, 23,  3], dtype=int32),
 array([22, 13, 14, 23, 27, 48, 40,  8, 29, 39], dtype=int32),
 array([38, 31, 46, 18, 26, 16, 31, 40, 19,  4,  5, 28, 43], dtype=int32),
 array([ 7, 10,  1, 11, 16, 23, 24], dtype=int32),
 array([ 7, 47, 31,  0, 48, 25,  7, 22, 16, 31, 44, 19, 24, 27], dtype=int32),
 array([44], dtype=int32),
 array([49, 18, 40, 19, 17, 27, 28, 20, 48, 37, 18], dtype=int32),
 array([41, 49], dtype=int32),
 array([36, 32, 32, 28, 48,  1, 25,  6, 37, 41,  3, 32], dtype=int32),
 array([15, 18, 16,  7, 45, 36, 17, 45, 12, 41, 25], dtype=int32),
 array([24, 23, 25,  0, 10], dtype=int32),
 array([ 4, 23, 13, 13], dtype=int32),
 array([ 4, 20, 32, 47, 37, 33, 34], dtype=int32),
 array([24, 45, 49, 20, 49,  5, 45, 17,  9,  6, 21], dtype=int32),
 array([46, 22,  9, 45, 18,

In [405]:
# 2-layer encoder, source vocabulary of 50, 100 units, no dropout.
encoder = Encoder(2, 50, 100, 0)

In [406]:
# Encode the toy batch: final hidden states, final cell states, per-step outputs.
hy, cy, ys = encoder(xs)

In [407]:
# Attention module built with the same n_units (100) as the encoder.
att = Attention(100)

In [408]:
# 2-layer decoder with a target vocabulary of 100 that shares the attention module above.
decoder = Decoder(n_layers=2, n_target_vocab=100, n_units=100, attention=att, dropout=0)

In [409]:
# Smoke test: reuse the source ids as decoder input and inspect the shape of the returned cell states.
decoder(hy, cy, ys, xs)[1].shape

(2, 20, 200)

In [410]:
# Assemble the full model from the encoder and decoder built above.
eda = Encoder_Decoder_withAttention(encoder, decoder)

In [413]:
# Greedy decoding of the toy batch with the untrained model (max_length defaults to 100).
eda.translate(xs)

[array([23, 21, 21,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,
         7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,
         7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,
         7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,
         7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,
         7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7], dtype=int32),
 array([49, 49, 49, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73,
        73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73,
        73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73,
        73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73,
        73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73,
        73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73, 73], dtype=int32),
 array([17,  8, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30