# <center>Transformer from scratch２：モデルの学習</center>
## <center>最終更新日（超未完成版）：2021/5/8</center>

自己アテンションに基づくTransformerアーキテクチャをtf.Kerasで実装するノートブックです。同じような内容は

https://www.tensorflow.org/tutorials/text/transformer?hl=ja

にあります。また、スクラッチ実装をしなくても多くのライブラリでは簡単にSelf-Attention層が利用できます。しかしちゃんとTransformerを理解したりTransformerベースのモデルを開発したいなら、スクラッチ実装の経験がないといけないと思いますので、勉強していきましょう！

[参考文献]

https://arxiv.org/pdf/1706.03762.pdf

このノートブックでは学習の実装を扱います。

---


前回のノートブックでモデルの準備ができましたので、いよいよ訓練の実装を行い、本物のデータで訓練してみましょう。

# 0. 準備

## 0.1 実装したモジュール等

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers

# Default hyper-parameters. Several of these are reassigned later in the
# notebook (max_seq_len after the data is loaded; d_model / d_ff / do_rate
# right before each model is built). Some layer classes read `d_ff` from
# module scope when they are constructed, so it has to exist up front.
d_model = 512             # width of embeddings and hidden states
d_ff = 2048               # hidden width of the point-wise feed-forward sublayer
do_rate = 0.2             # dropout rate
max_seq_len = 100         # number of rows in the positional-encoding table
source_vocab_size = 100000

utils

In [None]:
def positional_encoding_function(t, i, d):
    """Return the sinusoidal positional-encoding value for position t, dimension i.

    Dimensions are used in pairs (2k, 2k+1) that share the angle
    t / 10000**(2k/d). The even member takes the sine of that angle and the
    odd member takes the cosine (a sine shifted by pi/2). Works element-wise
    on broadcastable NumPy arrays as well as on scalars.
    """
    pair_index = i // 2
    denominator = 10000 ** (2 * pair_index / d)
    phase_shift = (i % 2) * (np.pi / 2)
    return np.sin(t / denominator + phase_shift)

def positional_encoding(T, d):
    """Build the (1, T, d) sinusoidal positional-encoding table.

    Entry [0, t, i] equals positional_encoding_function(t, i, d). The
    leading axis of size 1 lets the table broadcast over a batch.
    """
    positions = np.arange(T)[:, np.newaxis]   # (T, 1)
    dims = np.arange(d)[np.newaxis, :]        # (1, d)
    table = positional_encoding_function(positions, dims, d)  # broadcasts to (T, d)
    return table[np.newaxis, ...]

layers

In [None]:
class PositionalEncoding(layers.Layer):
    """Scale embeddings by sqrt(d_model) and add sinusoidal position codes.

    The table is precomputed for ``max_seq_len`` positions. Inputs shorter
    than that receive only the first ``seq_len`` rows, so the layer also
    works when the sequence is shorter than the table (e.g. step-by-step
    decoding).
    """

    def __init__(self, max_seq_len, d_model):
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        # NumPy array of shape (1, max_seq_len, d_model).
        self.pos_encoding = positional_encoding(max_seq_len, d_model)

    def call(self, inputs):
        # assumes inputs is (batch, seq_len, d_model) with seq_len <= max_seq_len
        input_dtype = inputs.dtype
        seq_len = tf.shape(inputs)[1]
        # Cast first so the slice is a TF op and accepts a dynamic seq_len.
        pos_encoding = tf.cast(self.pos_encoding, dtype=input_dtype)[:, :seq_len, :]
        inputs *= tf.math.sqrt(tf.cast(self.d_model, dtype=input_dtype))

        return inputs + pos_encoding

class MultiHeadSelfAttention(layers.Layer):
    """Multi-head scaled dot-product attention.

    Queries and keys/values are separate inputs, so the same layer also
    serves as encoder-decoder (cross) attention.
    """

    def __init__(self, n_heads, d_model):
        super(MultiHeadSelfAttention, self).__init__()
        self.n_heads = n_heads
        self.d_model = d_model

        assert self.d_model % self.n_heads == 0, 'n_headsはd_modelの因数'

        # Feature size of each head.
        self.d = self.d_model // self.n_heads

        self.dense_q = layers.Dense(d_model, use_bias=False)
        self.dense_k = layers.Dense(d_model, use_bias=False)
        self.dense_v = layers.Dense(d_model, use_bias=False)
        self.dense_o = layers.Dense(d_model, use_bias=False)

    def call(self, x_q, x_k, x_v, mask):
        """Attend from x_q to (x_k, x_v).

        Args:
            x_q: query input, (batch, len_q, features).
            x_k, x_v: key / value inputs, (batch, len_k, features).
            mask: None, or a tensor broadcastable to
                (batch, n_heads, len_q, len_k) that is 1 where attention
                must be blocked and 0 elsewhere.

        Returns:
            (output of shape (batch, len_q, d_model),
             attention weights of shape (batch, n_heads, len_q, len_k))
        """
        q = self.dense_q(x_q)
        k = self.dense_k(x_k)
        v = self.dense_v(x_v)

        # Queries and keys can have different lengths (cross-attention), so
        # each is split with its own length; v shares k's length.
        len_q = tf.shape(q)[1]
        len_k = tf.shape(k)[1]

        q = self.split_to_heads(q, len_q)
        k = self.split_to_heads(k, len_k)
        v = self.split_to_heads(v, len_k)

        logit = tf.matmul(q, k, transpose_b=True)

        d_k = tf.shape(k)[-1]
        k_dtype = k.dtype
        scale = 1/tf.math.sqrt(tf.cast(d_k, k_dtype))
        logit *= scale

        if mask is not None:
            # Push blocked positions to the most negative finite value so
            # softmax gives them ~0 weight.
            logit += mask*k_dtype.min

        attention_weight = tf.nn.softmax(logit, axis=-1)
        attention_output = tf.einsum('nhst,nhtd->nshd', attention_weight, v)
        # The output has one row per *query* position.
        attention_output = tf.reshape(attention_output, (-1, len_q, self.n_heads*self.d))
        attention_output = self.dense_o(attention_output)

        return attention_output, attention_weight

    def split_to_heads(self, x, max_seq_len):
        """Reshape (batch, max_seq_len, d_model) -> (batch, n_heads, max_seq_len, d)."""
        x = tf.reshape(x, (-1, max_seq_len, self.n_heads, self.d))
        x = tf.transpose(x, perm=[0, 2, 1, 3])
        return x

class MHSAModule(layers.Layer):
    """Self-attention sublayer: attention -> dropout -> residual add -> LayerNorm.

    ``source_vocab_size`` and ``max_seq_len`` are accepted but not used.
    """

    def __init__(self, source_vocab_size, max_seq_len, d_model, n_heads, do_rate):
        super(MHSAModule, self).__init__()
        self.mhsa = MultiHeadSelfAttention(n_heads, d_model)
        self.dropout = layers.Dropout(do_rate)
        self.add = layers.Add()
        self.norm = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs, training, mask):
        # Queries, keys and values all come from the same sequence.
        attended, att = self.mhsa(inputs, inputs, inputs, mask)
        residual = self.add([self.dropout(attended, training=training), inputs])
        return self.norm(residual), att

class PointWiseFeedForwardModule(layers.Layer):
    """Position-wise feed-forward sublayer with a residual connection.

    Computes LayerNorm(inputs + Dropout(Dense_d_model(ReLU(Dense_d_ff(inputs))))).
    """

    def __init__(self, d_model, d_ff, do_rate):
        super(PointWiseFeedForwardModule, self).__init__()
        self.pwff_1 = layers.Dense(d_ff, activation='relu')
        self.pwff_2 = layers.Dense(d_model)
        self.dropout = layers.Dropout(do_rate)
        self.add = layers.Add()
        self.norm = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs, training):
        hidden = self.pwff_1(inputs)        # expand to d_ff with ReLU
        projected = self.pwff_2(hidden)     # project back to d_model
        projected = self.dropout(projected, training=training)
        return self.norm(self.add([projected, inputs]))

class EncoderModule(layers.Layer):
    """One encoder layer: self-attention sublayer followed by a feed-forward sublayer.

    The feed-forward width is read from the module-level ``d_ff`` when the
    layer is constructed.
    """

    def __init__(self, source_vocab_size, max_seq_len, d_model, n_heads, do_rate):
        super(EncoderModule, self).__init__()
        self.mhsa = MHSAModule(source_vocab_size, max_seq_len, d_model, n_heads, do_rate)
        self.pwff = PointWiseFeedForwardModule(d_model, d_ff, do_rate)

    def call(self, inputs, training, mask):
        # Pass by keyword. MHSAModule.call takes (inputs, training, mask), so a
        # positional (inputs, mask, training) call would swap the padding
        # mask with the training flag.
        x, att = self.mhsa(inputs, training=training, mask=mask)
        x = self.pwff(x, training=training)
        return x, att

class Encoder(layers.Layer):
    """Transformer encoder: embedding + positional encoding + stacked encoder layers."""

    def __init__(self, n_layers, source_vocab_size, max_seq_len, d_model, n_heads, do_rate):
        super(Encoder, self).__init__()
        self.source_vocab_size = source_vocab_size
        self.max_seq_len = max_seq_len
        self.d_model = d_model
        self.n_heads = n_heads
        self.do_rate = do_rate
        self.do = layers.Dropout(do_rate)
        self.embedding = layers.Embedding(source_vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(max_seq_len, d_model)
        self.mhsa_modules = [EncoderModule(source_vocab_size, max_seq_len, d_model, n_heads, do_rate) for _ in range(n_layers)]

    def call(self, inputs, training, mask_enc):
        """Encode token ids.

        Returns:
            (encoder output, list of per-layer attention weights)
        """
        x = self.embedding(inputs)
        x = self.pos_encoding(x)
        # Pass the training flag explicitly instead of relying on implicit
        # propagation.
        x = self.do(x, training=training)
        attention_weights = []
        for module in self.mhsa_modules:
            # Keyword arguments keep this call correct regardless of the
            # positional order of EncoderModule.call's parameters.
            x, att = module(x, training=training, mask=mask_enc)
            attention_weights.append(att)
        return x, attention_weights

class MHSAModuleDec(layers.Layer):
    """Encoder-decoder attention sublayer with residual connection and LayerNorm.

    Queries come from ``inputs``; keys and values come from ``context``.
    ``source_vocab_size`` and ``max_seq_len`` are accepted but not used.
    """

    def __init__(self, source_vocab_size, max_seq_len, d_model, n_heads, do_rate):
        super(MHSAModuleDec, self).__init__()
        self.mhsa = MultiHeadSelfAttention(n_heads, d_model)
        self.dropout = layers.Dropout(do_rate)
        self.add = layers.Add()
        self.norm = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs, context, training, mask):
        attended, att = self.mhsa(inputs, context, context, mask)
        attended = self.dropout(attended, training=training)
        return self.norm(self.add([attended, inputs])), att

class DecoderModule(layers.Layer):
    """One decoder layer: masked self-attention -> encoder-decoder attention -> FFN.

    The feed-forward width is read from the module-level ``d_ff`` when the
    layer is constructed.
    """

    def __init__(self, source_vocab_size, max_seq_len, d_model, n_heads, do_rate):
        super(DecoderModule, self).__init__()
        self.mhsa1 = MHSAModule(source_vocab_size, max_seq_len, d_model, n_heads, do_rate)
        self.mhsa2 = MHSAModuleDec(source_vocab_size, max_seq_len, d_model, n_heads, do_rate)
        self.pwff = PointWiseFeedForwardModule(d_model, d_ff, do_rate)

    def call(self, inputs, enc_outputs, training, mask, mask_look_ahead):
        """Run one decoder layer.

        Args:
            inputs: decoder hidden states.
            enc_outputs: encoder output, used as keys/values of the
                encoder-decoder attention.
            training: whether dropout is active.
            mask: mask over *encoder* positions (1 = ignore), applied in the
                encoder-decoder attention. The notebook passes the encoder
                padding mask here.
            mask_look_ahead: mask over *decoder* positions (1 = ignore),
                applied in self-attention so a position cannot see later
                target tokens.

        Returns:
            (output, self-attention weights, encoder-decoder attention weights)
        """
        # Masked self-attention over the target sequence.
        x, att1 = self.mhsa1(inputs, training=training, mask=mask_look_ahead)
        # Encoder-decoder attention must consume the self-attention output x,
        # not the raw inputs; the mask here covers encoder positions.
        x, att2 = self.mhsa2(x, enc_outputs, training=training, mask=mask)
        x = self.pwff(x, training=training)
        return x, att1, att2

class Decoder(layers.Layer):
    """Transformer decoder: embedding + positional encoding + stacked decoder layers."""

    def __init__(self, n_layers, target_vocab_size, max_seq_len, d_model, n_heads, do_rate):
        super(Decoder, self).__init__()
        self.target_vocab_size = target_vocab_size
        self.max_seq_len = max_seq_len
        self.d_model = d_model
        self.n_heads = n_heads
        self.do_rate = do_rate
        self.do = layers.Dropout(do_rate)
        self.embedding = layers.Embedding(target_vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(max_seq_len, d_model)
        # Use this decoder's own vocabulary size rather than the module-level
        # source_vocab_size global.
        self.mhsa_modules = [DecoderModule(target_vocab_size, max_seq_len, d_model, n_heads, do_rate) for _ in range(n_layers)]

    def call(self, inputs, enc_outputs, training, mask_dec, mask_look_ahead):
        """Decode target token ids given the encoder output.

        Returns:
            (decoder output, list of [self-attention, cross-attention] weights per layer)
        """
        x = self.embedding(inputs)
        x = self.pos_encoding(x)
        x = self.do(x, training=training)
        attention_weights = []
        for module in self.mhsa_modules:
            # Keyword arguments keep this call correct regardless of the
            # positional order of DecoderModule.call's parameters.
            x, att1, att2 = module(x, enc_outputs, training=training,
                                   mask=mask_dec, mask_look_ahead=mask_look_ahead)
            attention_weights += [[att1, att2]]
        return x, attention_weights

model

In [None]:
class Transformer(layers.Layer):
    """Encoder-decoder Transformer returning next-token probabilities.

    The classifier has ``target_vocab_size - 1`` outputs because the padding
    id 0 is never predicted.
    """

    def __init__(self, n_layers, source_vocab_size, target_vocab_size, max_seq_len, d_model, n_heads, do_rate):
        super(Transformer, self).__init__()
        self.target_vocab_size = target_vocab_size
        self.max_seq_len = max_seq_len
        self.d_model = d_model
        self.n_heads = n_heads
        self.do_rate = do_rate
        self.encoder = Encoder(n_layers, source_vocab_size, max_seq_len, d_model, n_heads, do_rate)
        self.decoder = Decoder(n_layers, target_vocab_size, max_seq_len, d_model, n_heads, do_rate)
        self.classification = layers.Dense(target_vocab_size-1)

    def call(self, enc_inputs, dec_inputs, mask_enc, mask_dec, mask_look_ahead, training=False):
        """Return (probabilities over non-padding tokens, list of attention weights)."""
        enc_outputs, att_enc = self.encoder(enc_inputs, training=training, mask_enc=mask_enc)
        dec_outputs, att_dec = self.decoder(dec_inputs, enc_outputs, training=training, mask_dec=mask_dec, mask_look_ahead=mask_look_ahead)
        y = self.classification(dec_outputs)
        # custom_loss takes log(y_pred), so the model must emit probabilities,
        # not raw logits.
        y = tf.keras.activations.softmax(y, axis=-1)

        return y, att_enc + att_dec

## 0.2 データの準備

In [None]:
# Mount Google Drive at /content/drive so the dataset under
# /content/drive/MyDrive/... can be read (Colab only).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Install MeCab with the mecab-ipadic-utf8 dictionary (system packages) and
# the Python binding pinned to 0.7. The data-preparation cell uses
# MeCab.Tagger('-Owakati') to split Japanese sentences into words.
!apt install aptitude
!aptitude install mecab libmecab-dev mecab-ipadic-utf8 git make curl xz-utils file -y
!pip install mecab-python3==0.7

Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following package was automatically installed and is no longer required:
  libnvidia-common-460
Use 'apt autoremove' to remove it.
The following additional packages will be installed:
  aptitude-common libcgi-fast-perl libcgi-pm-perl libclass-accessor-perl
  libcwidget3v5 libencode-locale-perl libfcgi-perl libhtml-parser-perl
  libhtml-tagset-perl libhttp-date-perl libhttp-message-perl libio-html-perl
  libio-string-perl liblwp-mediatypes-perl libparse-debianchangelog-perl
  libsigc++-2.0-0v5 libsub-name-perl libtimedate-perl liburi-perl libxapian30
Suggested packages:
  aptitude-doc-en | aptitude-doc apt-xapian-index debtags tasksel
  libcwidget-dev libdata-dump-perl libhtml-template-perl libxml-simple-perl
  libwww-perl xapian-tools
The following NEW packages will be installed:
  aptitude aptitude-common libcgi-fast-perl libcgi-pm-perl
  libclass-accessor-perl libcwidget3v5 libencode-l

In [None]:
import numpy as np
import MeCab
import re
#from tensorflow.keras.preprocessing.text import Tokenizer
#from tensorflow.keras.utils import to_categorical

# Number of sentence pairs to load. (The original note says 10% is meant for
# validation; no validation split is passed to fit() further below.)
num_samples = 10000

text_path = '/content/drive/MyDrive/ml_datasets/jpn-eng/jpn.txt'

# Each line is "English<TAB>Japanese<TAB>attribution". The file ends with a
# newline, so the last element of the split is an empty string.
with open(text_path, 'r', encoding='utf-8') as f:
    lines = f.read().split('\n')

input_texts = []
target_texts = []
input_characters = set()   # English word vocabulary (words, despite the name)
target_characters = set()  # Japanese word vocabulary

tagger = MeCab.Tagger('-Owakati')  # '-Owakati': words separated by spaces

for line in lines[: min(num_samples, len(lines) - 1)]:
    input_text, target_text, _ = line.split('\t')

    # Split the English sentence on whitespace and emit a sentence-final
    # '.', '!' or '?' as its own token. Only the LAST character is removed;
    # str.replace would also delete the same symbol inside the sentence
    # (e.g. "Mr. Smith left." -> "Mr Smith left"). str.split() (rather than
    # splitting on single whitespace chars) avoids empty tokens.
    last_char = input_text[-1:]
    if last_char in {'.', '!', '?'}:
        input_text_sp = input_text[:-1].split()
        input_text_sp.append(last_char)
    else:
        input_text_sp = input_text.split()

    input_characters.update(input_text_sp)
    input_texts.append(input_text_sp)

    # '\t' is the start-of-sequence (SOS) token and '\n' the end-of-sequence
    # (EOS) token of the target sentence.
    wakachi = ['\t'] + tagger.parse(target_text).split() + ['\n']
    target_characters.update(wakachi)
    target_texts.append(wakachi)

input_characters = sorted(input_characters)
target_characters = sorted(target_characters)

print('num of input characters w/out pad:', len(input_characters))
print('num of output characters w/out pad:', len(target_characters))

# +1 for the padding id 0.
num_encoder_tokens = len(input_characters) + 1
num_decoder_tokens = len(target_characters) + 1

max_encoder_seq_length = max(len(txt) for txt in input_texts)
max_decoder_seq_length = max(len(txt) for txt in target_texts)

# Encoder and decoder share one padded length (could be a hyper-parameter).
max_seq_len = max(max_encoder_seq_length, max_decoder_seq_length)

# Token ids start at 1 so that 0 can serve as the padding id.
input_token_index = {char: i for i, char in enumerate(input_characters, start=1)}
target_token_index = {char: i for i, char in enumerate(target_characters, start=1)}

encoder_input_data = np.zeros((len(input_texts), max_seq_len), dtype='float32')
decoder_input_data = np.zeros((len(input_texts), max_seq_len), dtype='float32')
# One-hot targets over the num_decoder_tokens-1 non-padding tokens; padding
# positions stay all-zero rows.
decoder_target_data = np.zeros(
    (len(input_texts), max_seq_len, num_decoder_tokens - 1),
    dtype='float32')

for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    for t, char in enumerate(input_text):
        encoder_input_data[i, t] = input_token_index[char]
    for t, char in enumerate(target_text):
        # decoder_target_data is ahead of decoder_input_data by one timestep
        decoder_input_data[i, t] = target_token_index[char]
        if t > 0:
            # Next-token prediction: the target sequence starts at the 2nd
            # token. Class index is token id - 1 because padding has no class.
            decoder_target_data[i, t - 1, target_token_index[char] - 1] = 1.

num of input characters w/out pad: 3014
num of output characters w/out pad: 3982


Padding記号も含めた後の語彙数は次のとおりです：

In [None]:
# Vocabulary sizes including the padding id 0.
print(f'num_encoder_tokens {num_encoder_tokens}')
print(f'num_decoder_tokens {num_decoder_tokens}')

num_encoder_tokens 3015
num_decoder_tokens 3983


In [None]:
np.shape(encoder_input_data)  # (num_sentences, max_seq_len)

(10000, 18)

In [None]:
np.shape(decoder_input_data)  # (num_sentences, max_seq_len)

(10000, 18)

In [None]:
np.shape(decoder_target_data)  # (num_sentences, max_seq_len, num_decoder_tokens - 1)

(10000, 18, 3982)

In [None]:
def custom_loss(y_true, y_pred):
    """Categorical cross-entropy averaged over non-padding positions.

    Padding positions have all-zero one-hot rows in ``y_true``, so their
    cross-entropy term is 0. The sum is divided by the number of real
    (non-padding) positions rather than by all positions, so the loss is not
    diluted by how much padding a batch contains.

    Args:
        y_true: one-hot targets (..., num_classes); all-zero rows mark padding.
        y_pred: predicted probabilities of the same shape (not logits).

    Returns:
        Scalar loss tensor.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)
    # 1e-16 keeps log() finite where a predicted probability is exactly 0.
    per_position = -tf.reduce_sum(y_true * tf.math.log(y_pred + 1e-16), axis=-1)
    # Each real position contributes exactly one 1 to y_true; padding none.
    n_real = tf.reduce_sum(y_true)
    return tf.reduce_sum(per_position) / tf.maximum(n_real, 1.0)

In [None]:
# Sanity check: a prediction identical to the target gives zero loss.
custom_loss(decoder_target_data[:10], decoder_target_data[:10]).numpy()

-0.0

In [None]:
# Sanity check: scoring against different samples gives a positive loss.
custom_loss(decoder_target_data[:10], decoder_target_data[8:18]).numpy()

6.344901

# 1. 訓練プロセスの実装

## 1-1. マスクについて

In [None]:
# Encoder padding mask: 1.0 where the token id is 0 (padding), else 0.0,
# shaped (batch, 1, 1, seq_len) to broadcast over heads and query positions.
encoder_mask_data = (encoder_input_data == 0).astype('float32')[:, np.newaxis, np.newaxis, :]

In [None]:
# Display the encoder padding mask (1 = padding position).
encoder_mask_data

array([[[[0., 0., 1., ..., 1., 1., 1.]]],


       [[[0., 0., 1., ..., 1., 1., 1.]]],


       [[[0., 0., 1., ..., 1., 1., 1.]]],


       ...,


       [[[0., 0., 0., ..., 1., 1., 1.]]],


       [[[0., 0., 0., ..., 1., 1., 1.]]],


       [[[0., 0., 0., ..., 1., 1., 1.]]]], dtype=float32)

In [None]:
# Display the encoder token ids (0 = padding).
encoder_input_data

array([[ 216.,    6.,    0., ...,    0.,    0.,    0.],
       [ 216.,    6.,    0., ...,    0.,    0.,    0.],
       [ 247.,    6.,    0., ...,    0.,    0.,    0.],
       ...,
       [ 157., 3004., 1194., ...,    0.,    0.,    0.],
       [ 157., 3004., 1194., ...,    0.,    0.,    0.],
       [ 157., 3004., 1532., ...,    0.,    0.,    0.]], dtype=float32)

In [None]:
# Decoder padding mask: 1.0 where the token id is 0 (padding), else 0.0,
# shaped (batch, 1, 1, seq_len).
decoder_mask_data = (decoder_input_data == 0).astype('float32')[:, np.newaxis, np.newaxis, :]

In [None]:
# Display the decoder padding mask (1 = padding position).
decoder_mask_data

array([[[[0., 0., 0., ..., 1., 1., 1.]]],


       [[[0., 0., 0., ..., 1., 1., 1.]]],


       [[[0., 0., 0., ..., 1., 1., 1.]]],


       ...,


       [[[0., 0., 0., ..., 1., 1., 1.]]],


       [[[0., 0., 0., ..., 1., 1., 1.]]],


       [[[0., 0., 0., ..., 1., 1., 1.]]]], dtype=float32)

In [None]:
# Ones strictly above the diagonal: query position s may not attend to key
# positions t > s. Shaped (1, 1, T, T) to broadcast over batch and heads.
look_ahead = np.triu(np.ones((max_seq_len, max_seq_len)), k=1)[np.newaxis, np.newaxis]

# Merge with the decoder padding mask; clipping keeps values in {0, 1}.
decoder_mask_look_ahead_data = np.clip(decoder_mask_data + look_ahead, 0, 1)

In [None]:
# Display the combined look-ahead + decoder padding mask (1 = blocked).
decoder_mask_look_ahead_data

array([[[[0., 1., 1., ..., 1., 1., 1.],
         [0., 0., 1., ..., 1., 1., 1.],
         [0., 0., 0., ..., 1., 1., 1.],
         ...,
         [0., 0., 0., ..., 1., 1., 1.],
         [0., 0., 0., ..., 1., 1., 1.],
         [0., 0., 0., ..., 1., 1., 1.]]],


       [[[0., 1., 1., ..., 1., 1., 1.],
         [0., 0., 1., ..., 1., 1., 1.],
         [0., 0., 0., ..., 1., 1., 1.],
         ...,
         [0., 0., 0., ..., 1., 1., 1.],
         [0., 0., 0., ..., 1., 1., 1.],
         [0., 0., 0., ..., 1., 1., 1.]]],


       [[[0., 1., 1., ..., 1., 1., 1.],
         [0., 0., 1., ..., 1., 1., 1.],
         [0., 0., 0., ..., 1., 1., 1.],
         ...,
         [0., 0., 0., ..., 1., 1., 1.],
         [0., 0., 0., ..., 1., 1., 1.],
         [0., 0., 0., ..., 1., 1., 1.]]],


       ...,


       [[[0., 1., 1., ..., 1., 1., 1.],
         [0., 0., 1., ..., 1., 1., 1.],
         [0., 0., 0., ..., 1., 1., 1.],
         ...,
         [0., 0., 0., ..., 1., 1., 1.],
         [0., 0., 0., ..., 1., 1., 1.]

In [None]:
# Display the decoder token ids; every row starts with the SOS id.
decoder_input_data

array([[1.000e+00, 3.422e+03, 3.800e+01, ..., 0.000e+00, 0.000e+00,
        0.000e+00],
       [1.000e+00, 3.419e+03, 6.620e+02, ..., 0.000e+00, 0.000e+00,
        0.000e+00],
       [1.000e+00, 3.660e+02, 3.800e+01, ..., 0.000e+00, 0.000e+00,
        0.000e+00],
       ...,
       [1.000e+00, 1.034e+03, 3.899e+03, ..., 0.000e+00, 0.000e+00,
        0.000e+00],
       [1.000e+00, 1.034e+03, 9.190e+02, ..., 0.000e+00, 0.000e+00,
        0.000e+00],
       [1.000e+00, 2.958e+03, 7.090e+02, ..., 0.000e+00, 0.000e+00,
        0.000e+00]], dtype=float32)

In [None]:
n_heads = 8
n_layers = 2
d_ff = 1024  # previously 2048
d_model = 256
do_rate = 0.2

encoder_inputs = layers.Input(shape=(max_seq_len,), name='encoder_input')
decoder_inputs = layers.Input(shape=(max_seq_len,), name='decoder_input')

encoder_mask = layers.Input(shape=(1,1,max_seq_len))
decoder_look_ahead_mask = layers.Input(shape=(1,max_seq_len,max_seq_len))

transformer = Transformer(n_layers, num_encoder_tokens, num_decoder_tokens, max_seq_len, d_model, n_heads, do_rate)
# `training` is deliberately NOT passed. In the functional API an explicit
# training=True is baked into the model, so dropout would stay on in
# predict()/evaluate(). Left unset, Keras supplies True during fit() and
# False otherwise.
y = transformer(encoder_inputs, decoder_inputs, mask_enc=encoder_mask, mask_dec=encoder_mask, mask_look_ahead=decoder_look_ahead_mask)

transformer_model = tf.keras.models.Model(inputs=[encoder_inputs, decoder_inputs, encoder_mask, decoder_look_ahead_mask], outputs=y[0], name='trasnformer')

transformer_model.summary()

Model: "trasnformer"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_input (InputLayer)      [(None, 18)]         0                                            
__________________________________________________________________________________________________
decoder_input (InputLayer)      [(None, 18)]         0                                            
__________________________________________________________________________________________________
input_33 (InputLayer)           [(None, 1, 1, 18)]   0                                            
__________________________________________________________________________________________________
input_34 (InputLayer)           [(None, 1, 18, 18)]  0                                            
________________________________________________________________________________________

In [None]:
transformer_model.compile(loss=custom_loss, optimizer='rmsprop')

上の実装では、targetの3982次元「one-hot」はpadding記号の位置では全てゼロになっています。そのためpadding位置では交差エントロピーの項が0になり、意味のないpadding記号はlossに寄与しません。このようなcrossentropy関数として`custom_loss`を使うことにしました。

In [None]:
# Teacher forcing: the decoder receives the whole target sequence and is
# trained to predict the next token at every position.
transformer_model.fit(
    x=[encoder_input_data, decoder_input_data,
       encoder_mask_data, decoder_mask_look_ahead_data],
    y=decoder_target_data,
    batch_size=10,
    epochs=16,
)

Epoch 1/16
Epoch 2/16
Epoch 3/16
Epoch 4/16
Epoch 5/16
Epoch 6/16
Epoch 7/16
Epoch 8/16
Epoch 9/16
Epoch 10/16
Epoch 11/16
Epoch 12/16
Epoch 13/16
Epoch 14/16
Epoch 15/16
Epoch 16/16


<tensorflow.python.keras.callbacks.History at 0x7f7c7db6b550>

In [None]:
n_heads = 8
n_layers = 2
d_ff = 512
d_model = 128
do_rate = 0.2

encoder_inputs = layers.Input(shape=(max_seq_len,), name='encoder_input')
decoder_inputs = layers.Input(shape=(max_seq_len,), name='decoder_input')

encoder_mask = layers.Input(shape=(1,1,max_seq_len))
decoder_look_ahead_mask = layers.Input(shape=(1,max_seq_len,max_seq_len))

transformer = Transformer(n_layers, num_encoder_tokens, num_decoder_tokens, max_seq_len, d_model, n_heads, do_rate)
# `training` is deliberately NOT passed. In the functional API an explicit
# training=True is baked into the model, so dropout would stay on in
# predict()/evaluate(). Left unset, Keras supplies True during fit() and
# False otherwise.
y = transformer(encoder_inputs, decoder_inputs, mask_enc=encoder_mask, mask_dec=encoder_mask, mask_look_ahead=decoder_look_ahead_mask)

transformer_model = tf.keras.models.Model(inputs=[encoder_inputs, decoder_inputs, encoder_mask, decoder_look_ahead_mask], outputs=y[0], name='trasnformer')

transformer_model.compile(optimizer='rmsprop', loss=custom_loss)

In [None]:
# Train the smaller model with the same inputs, masks and targets.
model_inputs = [encoder_input_data, decoder_input_data,
                encoder_mask_data, decoder_mask_look_ahead_data]
transformer_model.fit(x=model_inputs, y=decoder_target_data,
                      batch_size=10, epochs=16)

Epoch 1/16
Epoch 2/16
Epoch 3/16
Epoch 4/16
Epoch 5/16
Epoch 6/16

In [None]:
class PositionalEncoding(layers.Layer):
    """Scale embeddings by sqrt(d_model) and add sinusoidal position codes.

    The table is precomputed for ``max_seq_len`` positions. Inputs shorter
    than that receive only the first ``seq_len`` rows, so the layer also
    works when the sequence is shorter than the table (e.g. step-by-step
    decoding).
    """

    def __init__(self, max_seq_len, d_model):
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        # NumPy array of shape (1, max_seq_len, d_model).
        self.pos_encoding = positional_encoding(max_seq_len, d_model)

    def call(self, inputs):
        # assumes inputs is (batch, seq_len, d_model) with seq_len <= max_seq_len
        input_dtype = inputs.dtype
        seq_len = tf.shape(inputs)[1]
        # Cast first so the slice is a TF op and accepts a dynamic seq_len.
        pos_encoding = tf.cast(self.pos_encoding, dtype=input_dtype)[:, :seq_len, :]
        inputs *= tf.math.sqrt(tf.cast(self.d_model, dtype=input_dtype))

        return inputs + pos_encoding

class MultiHeadSelfAttention(layers.Layer):
    """Multi-head scaled dot-product attention.

    Queries and keys/values are separate inputs, so the same layer also
    serves as encoder-decoder (cross) attention.
    """

    def __init__(self, n_heads, d_model):
        super(MultiHeadSelfAttention, self).__init__()
        self.n_heads = n_heads
        self.d_model = d_model

        assert self.d_model % self.n_heads == 0, 'n_headsはd_modelの因数'

        # Feature size of each head.
        self.d = self.d_model // self.n_heads

        self.dense_q = layers.Dense(d_model, use_bias=False)
        self.dense_k = layers.Dense(d_model, use_bias=False)
        self.dense_v = layers.Dense(d_model, use_bias=False)
        self.dense_o = layers.Dense(d_model, use_bias=False)

    def call(self, x_q, x_k, x_v, mask):
        """Attend from x_q to (x_k, x_v).

        Args:
            x_q: query input, (batch, len_q, features).
            x_k, x_v: key / value inputs, (batch, len_k, features).
            mask: None, or a tensor broadcastable to
                (batch, n_heads, len_q, len_k) that is 1 where attention
                must be blocked and 0 elsewhere.

        Returns:
            (output of shape (batch, len_q, d_model),
             attention weights of shape (batch, n_heads, len_q, len_k))
        """
        q = self.dense_q(x_q)
        k = self.dense_k(x_k)
        v = self.dense_v(x_v)

        # Queries and keys can have different lengths (cross-attention), so
        # each is split with its own length; v shares k's length.
        len_q = tf.shape(q)[1]
        len_k = tf.shape(k)[1]

        q = self.split_to_heads(q, len_q)
        k = self.split_to_heads(k, len_k)
        v = self.split_to_heads(v, len_k)

        logit = tf.matmul(q, k, transpose_b=True)

        d_k = tf.shape(k)[-1]
        k_dtype = k.dtype
        scale = 1/tf.math.sqrt(tf.cast(d_k, k_dtype))
        logit *= scale

        if mask is not None:
            # Push blocked positions to the most negative finite value so
            # softmax gives them ~0 weight.
            logit += mask*k_dtype.min

        attention_weight = tf.nn.softmax(logit, axis=-1)
        attention_output = tf.einsum('nhst,nhtd->nshd', attention_weight, v)
        # The output has one row per *query* position.
        attention_output = tf.reshape(attention_output, (-1, len_q, self.n_heads*self.d))
        attention_output = self.dense_o(attention_output)

        return attention_output, attention_weight

    def split_to_heads(self, x, max_seq_len):
        """Reshape (batch, max_seq_len, d_model) -> (batch, n_heads, max_seq_len, d)."""
        x = tf.reshape(x, (-1, max_seq_len, self.n_heads, self.d))
        x = tf.transpose(x, perm=[0, 2, 1, 3])
        return x

class MHSAModule(layers.Layer):
    """Self-attention sublayer: attention -> dropout -> residual add -> LayerNorm.

    ``source_vocab_size`` and ``max_seq_len`` are accepted but not used.
    """

    def __init__(self, source_vocab_size, max_seq_len, d_model, n_heads, do_rate):
        super(MHSAModule, self).__init__()
        self.mhsa = MultiHeadSelfAttention(n_heads, d_model)
        self.dropout = layers.Dropout(do_rate)
        self.add = layers.Add()
        self.norm = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs, mask, training=False):
        # Queries, keys and values all come from the same sequence.
        attended, att = self.mhsa(inputs, inputs, inputs, mask)
        residual = self.add([self.dropout(attended, training=training), inputs])
        return self.norm(residual), att

class PointWiseFeedForwardModule(layers.Layer):
    """Position-wise feed-forward sublayer with a residual connection.

    Computes LayerNorm(inputs + Dropout(Dense_d_model(ReLU(Dense_d_ff(inputs))))).
    """

    def __init__(self, d_model, d_ff, do_rate):
        super(PointWiseFeedForwardModule, self).__init__()
        self.pwff_1 = layers.Dense(d_ff, activation='relu')
        self.pwff_2 = layers.Dense(d_model)
        self.dropout = layers.Dropout(do_rate)
        self.add = layers.Add()
        self.norm = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs, training=False):
        hidden = self.pwff_1(inputs)        # expand to d_ff with ReLU
        projected = self.pwff_2(hidden)     # project back to d_model
        projected = self.dropout(projected, training=training)
        return self.norm(self.add([projected, inputs]))

class EncoderModule(layers.Layer):
    """One encoder layer: self-attention sublayer, then a feed-forward sublayer.

    The feed-forward width is read from the module-level ``d_ff`` when the
    layer is constructed.
    """

    def __init__(self, source_vocab_size, max_seq_len, d_model, n_heads, do_rate):
        super(EncoderModule, self).__init__()
        self.mhsa = MHSAModule(source_vocab_size, max_seq_len, d_model, n_heads, do_rate)
        self.pwff = PointWiseFeedForwardModule(d_model, d_ff, do_rate)

    def call(self, inputs, mask, training=False):
        attended, att = self.mhsa(inputs, mask, training=training)
        return self.pwff(attended, training=training), att

class Encoder(layers.Layer):
    """Transformer encoder: embedding + positional encoding + stacked encoder layers."""

    def __init__(self, n_layers, source_vocab_size, max_seq_len, d_model, n_heads, do_rate):
        super(Encoder, self).__init__()
        self.source_vocab_size = source_vocab_size
        self.max_seq_len = max_seq_len
        self.d_model = d_model
        self.n_heads = n_heads
        self.do_rate = do_rate
        self.do = layers.Dropout(do_rate)
        self.embedding = layers.Embedding(source_vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(max_seq_len, d_model)
        self.mhsa_modules = [EncoderModule(source_vocab_size, max_seq_len, d_model, n_heads, do_rate) for _ in range(n_layers)]

    def call(self, inputs, mask_enc, training=False):
        """Return (encoder output, list of per-layer attention weights)."""
        hidden = self.do(self.pos_encoding(self.embedding(inputs)), training=training)
        attention_weights = []
        for encoder_layer in self.mhsa_modules:
            hidden, weights = encoder_layer(hidden, mask_enc, training=training)
            attention_weights.append(weights)
        return hidden, attention_weights

class MHSAModuleDec(layers.Layer):
    """Encoder-decoder attention sublayer with residual connection and LayerNorm.

    Queries come from ``inputs``; keys and values come from ``context``.
    ``source_vocab_size`` and ``max_seq_len`` are accepted but not used.
    """

    def __init__(self, source_vocab_size, max_seq_len, d_model, n_heads, do_rate):
        super(MHSAModuleDec, self).__init__()
        self.mhsa = MultiHeadSelfAttention(n_heads, d_model)
        self.dropout = layers.Dropout(do_rate)
        self.add = layers.Add()
        self.norm = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs, context, mask, training=False):
        attended, att = self.mhsa(inputs, context, context, mask)
        attended = self.dropout(attended, training=training)
        return self.norm(self.add([attended, inputs])), att

class DecoderModule(layers.Layer):
    """One decoder layer: masked self-attention -> encoder-decoder attention -> FFN.

    The feed-forward width is read from the module-level ``d_ff`` when the
    layer is constructed.
    """

    def __init__(self, source_vocab_size, max_seq_len, d_model, n_heads, do_rate):
        super(DecoderModule, self).__init__()
        self.mhsa1 = MHSAModule(source_vocab_size, max_seq_len, d_model, n_heads, do_rate)
        self.mhsa2 = MHSAModuleDec(source_vocab_size, max_seq_len, d_model, n_heads, do_rate)
        self.pwff = PointWiseFeedForwardModule(d_model, d_ff, do_rate)

    def call(self, inputs, enc_outputs, mask, mask_look_ahead, training=False):
        """Run one decoder layer.

        Args:
            inputs: decoder hidden states.
            enc_outputs: encoder output, used as keys/values of the
                encoder-decoder attention.
            mask: mask over *encoder* positions (1 = ignore), applied in the
                encoder-decoder attention. The notebook passes the encoder
                padding mask here.
            mask_look_ahead: mask over *decoder* positions (1 = ignore),
                applied in self-attention so a position cannot see later
                target tokens.
            training: whether dropout is active.

        Returns:
            (output, self-attention weights, encoder-decoder attention weights)
        """
        # Masked self-attention over the target sequence.
        x, att1 = self.mhsa1(inputs, mask=mask_look_ahead, training=training)
        # Encoder-decoder attention must consume the self-attention output x,
        # not the raw inputs; the mask here covers encoder positions.
        x, att2 = self.mhsa2(x, enc_outputs, mask=mask, training=training)
        x = self.pwff(x, training=training)
        return x, att1, att2

class Decoder(layers.Layer):
    """Transformer decoder: embedding + positional encoding + stacked decoder layers."""

    def __init__(self, n_layers, target_vocab_size, max_seq_len, d_model, n_heads, do_rate):
        super(Decoder, self).__init__()
        self.target_vocab_size = target_vocab_size
        self.max_seq_len = max_seq_len
        self.d_model = d_model
        self.n_heads = n_heads
        self.do_rate = do_rate
        self.do = layers.Dropout(do_rate)
        self.embedding = layers.Embedding(target_vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(max_seq_len, d_model)
        # Use this decoder's own vocabulary size rather than the module-level
        # source_vocab_size global.
        self.mhsa_modules = [DecoderModule(target_vocab_size, max_seq_len, d_model, n_heads, do_rate) for _ in range(n_layers)]

    def call(self, inputs, enc_outputs, mask_dec, mask_look_ahead, training=False):
        """Decode target token ids given the encoder output.

        Returns:
            (decoder output, list of [self-attention, cross-attention] weights per layer)
        """
        x = self.embedding(inputs)
        x = self.pos_encoding(x)
        # Pass the training flag explicitly, as the Encoder does.
        x = self.do(x, training=training)
        attention_weights = []
        for module in self.mhsa_modules:
            x, att1, att2 = module(x, enc_outputs, mask=mask_dec,
                                   mask_look_ahead=mask_look_ahead, training=training)
            attention_weights += [[att1, att2]]
        return x, attention_weights

class Transformer(layers.Layer):
    """Encoder-decoder Transformer returning next-token probabilities.

    The classifier has ``target_vocab_size - 1`` outputs because the padding
    id 0 is never predicted. A softmax is applied because custom_loss expects
    probabilities.
    """

    def __init__(self, n_layers, source_vocab_size, target_vocab_size, max_seq_len, d_model, n_heads, do_rate):
        super(Transformer, self).__init__()
        self.target_vocab_size = target_vocab_size
        self.max_seq_len = max_seq_len
        self.d_model = d_model
        self.n_heads = n_heads
        self.do_rate = do_rate
        self.encoder = Encoder(n_layers, source_vocab_size, max_seq_len, d_model, n_heads, do_rate)
        self.decoder = Decoder(n_layers, target_vocab_size, max_seq_len, d_model, n_heads, do_rate)
        self.classification = layers.Dense(target_vocab_size-1)

    def call(self, enc_inputs, dec_inputs, mask_enc, mask_dec, mask_look_ahead, training=False):
        """Return (probabilities over non-padding tokens, list of attention weights)."""
        enc_outputs, att_enc = self.encoder(enc_inputs, mask_enc=mask_enc, training=training)
        dec_outputs, att_dec = self.decoder(
            dec_inputs, enc_outputs,
            mask_dec=mask_dec, mask_look_ahead=mask_look_ahead, training=training)
        logits = self.classification(dec_outputs)
        probabilities = tf.keras.activations.softmax(logits, axis=-1)
        return probabilities, att_enc + att_dec

# 問題：
学習済みの`transformer`層を使って予測用のモデルを作り、翻訳させてみましょう。授業のSeq2Seqの予測の実装を理解していれば、ほとんど同じように書けます。ただし、予測時のマスクの扱いには注意しましょう（自己回帰的に1トークンずつ生成するため`mask_look_ahead`は不要です。ゼロテンソルを渡しましょう）。