In [53]:
import numpy as np
import math
import re
import time
import tensorflow as tf
import tensorflow
import tensorflow_datasets as tfds
import h5py

from tensorflow.keras.utils import GeneratorEnqueuer
from tensorflow.keras.models import load_model
from tensorflow.keras import layers
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.optimizers import SGD
from textwrap import indent
from tensorflow.keras.callbacks import BaseLogger
import matplotlib.pyplot as plt
import json
import os
import tensorflow.keras.backend as K
import gc

In [None]:
tf.config.list_physical_devices('GPU')
# if oom error still persists, try this:
# tf.keras.backend.set_floatx('float16')

### GPU Related Configuration

In [3]:
configuration = tf.compat.v1.ConfigProto()
configuration.gpu_options.allow_growth = True
configuration.gpu_options.per_process_gpu_memory_fraction = 0.8
session = tf.compat.v1.Session(config=configuration)
tf.config.run_functions_eagerly(True)

### Tokenization Processs

In [4]:
MAX_VOCAB_SIZE = 2**13
MAX_SEQUENCE_LENGTH = 20
BATCH_SIZE = 128
BUFFER_SIZE = 20000

In [5]:
def text_gen(file_path):
    with open(file_path, mode='r', encoding='utf-8') as f:
        return (row for row in f.readlines()) 

In [6]:
with open("./nonbreaking_prefix.en",
          mode='r',
          encoding='utf-8') as f:
    non_breaking_prefix_en = f.read()
with open("./nonbreaking_prefix.fr",
          mode='r',
          encoding='utf-8') as f:
    non_breaking_prefix_fr = f.read()

non_breaking_prefix_en = non_breaking_prefix_en.split("\n")
non_breaking_prefix_en = [' ' + prefix + '.' for prefix in non_breaking_prefix_en]
non_breaking_prefix_fr = non_breaking_prefix_fr.split("\n")
non_breaking_prefix_fr = [' ' + prefix + '.' for prefix in non_breaking_prefix_fr]

In [7]:
non_breaking_prefix_en

[' a.',
 ' b.',
 ' c.',
 ' d.',
 ' e.',
 ' f.',
 ' g.',
 ' h.',
 ' i.',
 ' j.',
 ' k.',
 ' l.',
 ' m.',
 ' n.',
 ' o.',
 ' p.',
 ' q.',
 ' r.',
 ' s.',
 ' t.',
 ' u.',
 ' v.',
 ' w.',
 ' x.',
 ' y.',
 ' z.',
 ' messrs.',
 ' mlle.',
 ' mme.',
 ' mr.',
 ' mrs.',
 ' ms.',
 ' ph.',
 ' prof.',
 ' sr.',
 ' st.',
 ' a.m.',
 ' p.m.',
 ' vs.',
 ' i.e.',
 ' e.g.']

In [8]:
# prepare generator for extremely large text file

def reset_generators():
    europarl_en = text_gen("./europarl-v7.fr-en.en")
    europarl_fr = text_gen("./europarl-v7.fr-en.fr")

    corpus_en = europarl_en

    for prefix in non_breaking_prefix_en:
        corpus_en = (row.replace(prefix, prefix + "###") for row in corpus_en)

    corpus_en = (re.sub(r"\.(?=[0-9a-zA-Z])", ".###", row) for row in corpus_en)    
    corpus_en = (re.sub(".###", "", row) for row in corpus_en)
    corpus_en = (re.sub(r"\s+", " ", row) for row in corpus_en)


    corpus_fr = europarl_fr

    for prefix in non_breaking_prefix_en:
        corpus_fr = (row.replace(prefix, prefix + "###") for row in corpus_fr)

    corpus_fr = (re.sub(r"\.(?=[0-9a-zA-Z])", ".###", row) for row in corpus_fr)    
    corpus_fr = (re.sub(".###", "", row) for row in corpus_fr)
    corpus_fr = (re.sub(r"\s+", " ", row) for row in corpus_fr)
    #     this mimic the shift right process so that we don't need this shift_right layer as in trax
    corpus_fr = ("<sos> " + row +" <eos>" for row in corpus_fr)
    
    return corpus_en, corpus_fr

In [9]:
corpus_en, corpus_fr = reset_generators()

In [10]:
corpus_en, corpus_fr = reset_generators()
corpus_fr = (re.sub(" <eos>", "", row) for row in corpus_fr)
for _ in range(0,100):
    print(next(corpus_fr))

<sos> Reprise de la session 
<sos> Je déclare reprise la session du Parlement européen qui avait été interrompue le vendredi 17 décembre dernier et je vous renouvelle tous mes vux en espérant que vous avez passé de bonnes vacances. 
<sos> Comme vous avez pu le constater, le grand "bogue de l'an 2000" ne s'est pas produit. En revanche, les citoyens d'un certain nombre de nos pays ont été victimes de catastrophes naturelles qui ont vraiment été terribles. 
<sos> Vous avez souhaité un débat à ce sujet dans les prochains jours, au cours de cette période de session. 
<sos> En attendant, je souhaiterais, comme un certain nombre de collègues me l'ont demandé, que nous observions une minute de silence pour toutes les victimes, des tempêtes notamment, dans les différents pays de l'Union européenne qui ont été touchés. 
<sos> Je vous invite à vous lever pour cette minute de silence. 
<sos> (Le Parlement, debout, observe une minute de silence) 
<sos> Madame la Présidente, c'est une motion de proc

In [11]:
tokenizer_en = Tokenizer(num_words=MAX_VOCAB_SIZE, filters="")
tokenizer_fr = Tokenizer(num_words=MAX_VOCAB_SIZE, filters="")

In [12]:
corpus_en, corpus_fr = reset_generators()
tokenizer_en.fit_on_texts(corpus_en)
tokenizer_fr.fit_on_texts(corpus_fr)

In [13]:
VOCAB_SIZE_EN = min(MAX_VOCAB_SIZE, len(tokenizer_en.word_index))
VOCAB_SIZE_FR = min(MAX_VOCAB_SIZE, len(tokenizer_fr.word_index))
print(VOCAB_SIZE_EN, VOCAB_SIZE_FR)

8192 8192


In [14]:
corpus_en, corpus_fr = reset_generators()

encoder_input_ = tokenizer_en.texts_to_sequences(corpus_en)
decoder_input_ = tokenizer_fr.texts_to_sequences((re.sub(" <eos>", "", row) for row in corpus_fr))

_, corpus_fr = reset_generators()
decoder_input_real_ = tokenizer_fr.texts_to_sequences((re.sub("<sos> ", "", row) for row in corpus_fr))

In [15]:
encoder_input = pad_sequences(
    encoder_input_, 
    maxlen=MAX_SEQUENCE_LENGTH, 
    padding="post"
)

decoder_input = pad_sequences(
    decoder_input_, 
    maxlen=MAX_SEQUENCE_LENGTH, 
    padding="post"
)

decoder_input_real = pad_sequences(
    decoder_input_real_, 
    maxlen=MAX_SEQUENCE_LENGTH, 
    padding="post"
)

In [74]:
print(encoder_input.shape)
print(decoder_input.shape)
print(decoder_input_real.shape)

(2007723, 20)
(2007723, 20)
(2007723, 20)


In [17]:
tokenizer_en.index_word[0] = ""
tokenizer_fr.index_word[0] = ""

In [18]:
decoder_input[:,0] = tokenizer_fr.word_index["<sos>"]

In [19]:
for i in range(0,1000, 10):
    print(str(f"[{i}-en]   \t")+ " " + " ".join([tokenizer_en.index_word[index] for index in encoder_input[i]]))
    print(str(f"[{i}-fr-out]\t")+ " " + " ".join([tokenizer_fr.index_word[index] for index in decoder_input_real[i]]))
    print(str(f"[{i}-fr-fed]\t")+ " " + " ".join([tokenizer_fr.index_word[index] for index in decoder_input[i]]))
    print("---------------------")

[0-en]   	 resumption of the session                
[0-fr-out]	 reprise de la session <eos>               
[0-fr-fed]	 <sos> reprise de la session               
---------------------
[10-en]   	 in sri and urging her to do everything she possibly can to seek a peaceful reconciliation to a very difficult
[10-fr-out]	 qui est en son pouvoir pour chercher une réconciliation pacifique et mettre un terme à cette situation particulièrement difficile. <eos>
[10-fr-fed]	 <sos> qui est en son pouvoir pour chercher une réconciliation pacifique et mettre un terme à cette situation particulièrement difficile.
---------------------
[20-en]   	 may, if you wish, raise this question, ie. on thursday prior to the start of the presentation of the report.
[20-fr-out]	 vous en effet, si vous le soulever cette question, c'est-à-dire jeudi avant le début de la présentation du rapport. <eos>
[20-fr-fed]	 <sos> vous en effet, si vous le soulever cette question, c'est-à-dire jeudi avant le début de la prése

### Remark
Even we have fed in the parameter ```num_words=MAX_VOCAB_SIZE``` in the constructor of ```Tokenizer```, both ```len(tokenizer_en.word_index)``` and ```len(tokenizer_fr.word_index)``` still exceed ```MAX_VOCAB_SIZE```.

However, the number of output indexes in the text to seq transformation still agress with the number ```MAX_VOCAB_SIZE```.

### HDF5 Dataset Writer

In [55]:
class HDF5DatasetWriter:
    def __init__(self, dims, outputPath, dataKey="rows", bufferSize=1000):
        """
        dims = shape of training dataset

        outputPath = destination of our hdf5 db file

        dataKey = name of the data to be stored in hdf5 format, like images, features, etc

        bufferSize = size of our in-memory buffer, default to save 1000 feature vectors/images
        """

        if os.path.exists(outputPath):
            raise ValueError(
                "The supplied outoutPath already exists and cannot be overwritten.")

        # "w": create file, truncate if exists
        self.db = h5py.File(outputPath, "w")
        self.data = self.db.create_dataset(dataKey, dims, dtype="float")

        self.bufferSize = bufferSize
        self.buffer = {"data": []}
        self.idx = 0

    def add(self, rows):
        # array.extend(array), concatenate two arrays
        # both append and extend modify the original array, no return
        # i.e., a = [1], then a.extend([2]), we have a = [1, 2]
        self.buffer["data"].extend(rows)

        if len(self.buffer["data"]) >= self.bufferSize:
            self.flush()

    def flush(self):
        i = self.idx + len(self.buffer["data"])
        self.data[self.idx: i] = self.buffer["data"]
        self.idx = i
        self.buffer = {"data": []}

        
    def close(self):
        if len(self.buffer["data"]) > 0:
            self.flush()

        self.db.close()

In [66]:
class HDF5DatasetGenerator:
    def __init__(self, dbPath, batch_size, dataKey="rows"):
        self.batch_size = batch_size
        self.db = h5py.File(dbPath, "r")
        self.dataKey = dataKey
        self.DATA_SIZE = self.db[dataKey].shape[0]

    def generator(self, passes=np.inf):
        # this function generate a list of batches of data a
        # passes = number of epoch that we want
        epochs = 0

        while epochs < passes:
            for i in np.arange(0, self.DATA_SIZE, self.batch_size):
                rows = self.db[self.dataKey][i:i + self.batch_size]
                yield rows

            epochs = epochs + 1

    def close(self):
        self.db.close()

In [57]:
def save_data_into_hdf5(dims, datas, file_path):
    writer = HDF5DatasetWriter(dims, file_path)
    
    for data in datas:   
        writer.add([data])
        
    writer.close()

In [67]:
def get_gen_from_hdf5(hdf5_file_path, batch_size):
    dataset_gen = HDF5DatasetGenerator(hdf5_file_path, batch_size)
    def dataset_close():
        dataset_gen.close()
    
    return dataset_gen.generator(), dataset_close    

In [58]:
ENCODER_INPUT_HDF5_PATH = "./encoder_input.hdf5"
DECODER_INPUT_HDF5_PATH = "./decoder_input.hdf5"
DECODER_INPUT_REAL_HDF5_PATH = "./decoder_input_real.hdf5"

save_data_into_hdf5(encoder_input.shape, encoder_input, ENCODER_INPUT_HDF5_PATH)
save_data_into_hdf5(decoder_input.shape, decoder_input, DECODER_INPUT_HDF5_PATH)
save_data_into_hdf5(decoder_input_real.shape, decoder_input_real, DECODER_INPUT_REAL_HDF5_PATH)

In [88]:
encoder_input_gen, encoder_input_gen_close = get_gen_from_hdf5(ENCODER_INPUT_HDF5_PATH, BATCH_SIZE)
decoder_input_gen, decoder_input_gen_close = get_gen_from_hdf5(DECODER_INPUT_HDF5_PATH, BATCH_SIZE)
decoder_input_real_gen, encoder_input_real_gen_close = get_gen_from_hdf5(DECODER_INPUT_REAL_HDF5_PATH, BATCH_SIZE)

In [89]:
train_gen = ([x[0], x[1]] for x in zip(encoder_input_gen, decoder_input_gen))
train_target_gen = zip(train_gen, decoder_input_real_gen)

### Dataset Generator (Deprecated, Cause Memory Leak Problem in my Computer)

In [20]:
# dataset_X_1 = tensorflow.data.Dataset.from_tensor_slices((encoder_input))
# dataset_X_2 = tensorflow.data.Dataset.from_tensor_slices((decoder_input))
# dataset_Y = tensorflow.data.Dataset.from_tensor_slices((decoder_input_real))

In [21]:
# # can shuffle dataset, check the API if desired

# dataset_X_1 = dataset_X_1.batch(BATCH_SIZE)
# # dataset_X_1 = dataset_X_1.prefetch(tensorflow.data.experimental.AUTOTUNE)

# dataset_X_2 = dataset_X_2.batch(BATCH_SIZE)
# # dataset_X_2 = dataset_X_2.prefetch(tensorflow.data.experimental.AUTOTUNE)

# dataset_Y = dataset_Y.batch(BATCH_SIZE)
# # dataset_Y = dataset_Y.prefetch(tensorflow.data.experimental.AUTOTUNE)

In [47]:
# train_gen = ([x[0], x[1]] for x in zip(dataset_X_1, dataset_X_2))
# target_gen = dataset_Y

# train_target_gen = (item for item in zip(train_gen, target_gen))

### Define Transfomer's Component

$$
\mathrm{Attention}(Q, K, V)=\mathrm{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}+{M}\right) V
$$

$Q$ - query, 
$K$ - key, 
$V$ - values, 
$M$ - mask, 
${d_k}$ - depth/dimension of the queries and keys (used for scaling down)


In [28]:
#batch_size =3, seq_length= 3, n_head = 2, d_head = 3
# (batch_size, seq_length, n_heads x d_head) = (3, 3, 2 x 3) = (3, 3, 6)
inputs = tf.constant([
    [[1,2,3,4,5,6],[11,12,13,14,15,16],[21,22,23,24,25,26]],
    [[101,102,103,104,105,106],[111,112,113,114,115,116],[121,122,123,124,125,126]],
    [[201,202,203,204,205,206],[211,212,213,214,215,216],[221,222,223,224,225,226]]
])
inputs = tf.cast(inputs, dtype=tf.float32)
batch_size = tf.cast(inputs.shape[0], tf.float32)
seq_length = tf.cast(inputs.shape[1], tf.float32)
n_head=2
d_head = tf.cast(inputs.shape[-1]/n_head, tf.float32)

inputs = tf.reshape(inputs, (batch_size, seq_length, -1, d_head))
print("[intermediate shape]")
print(inputs)
print("\n")
inputs = tf.transpose(inputs, (0, 2, 1 , 3))
inputs = tf.reshape(inputs, (-1, seq_length, d_head))
print("[multihead version]")
print(inputs)

[intermediate shape]
tf.Tensor(
[[[[  1.   2.   3.]
   [  4.   5.   6.]]

  [[ 11.  12.  13.]
   [ 14.  15.  16.]]

  [[ 21.  22.  23.]
   [ 24.  25.  26.]]]


 [[[101. 102. 103.]
   [104. 105. 106.]]

  [[111. 112. 113.]
   [114. 115. 116.]]

  [[121. 122. 123.]
   [124. 125. 126.]]]


 [[[201. 202. 203.]
   [204. 205. 206.]]

  [[211. 212. 213.]
   [214. 215. 216.]]

  [[221. 222. 223.]
   [224. 225. 226.]]]], shape=(3, 3, 2, 3), dtype=float32)


[multihead version]
tf.Tensor(
[[[  1.   2.   3.]
  [ 11.  12.  13.]
  [ 21.  22.  23.]]

 [[  4.   5.   6.]
  [ 14.  15.  16.]
  [ 24.  25.  26.]]

 [[101. 102. 103.]
  [111. 112. 113.]
  [121. 122. 123.]]

 [[104. 105. 106.]
  [114. 115. 116.]
  [124. 125. 126.]]

 [[201. 202. 203.]
  [211. 212. 213.]
  [221. 222. 223.]]

 [[204. 205. 206.]
  [214. 215. 216.]
  [224. 225. 226.]]], shape=(6, 3, 3), dtype=float32)


In [29]:
mask_size=4
mask = 1 - tf.experimental.numpy.tril(tf.ones((1, mask_size, mask_size)), k=0)
mask *= -1e9
print(mask)
# the masked below is then fed into tf.where

tf.Tensor(
[[[-0.e+00 -1.e+09 -1.e+09 -1.e+09]
  [-0.e+00 -0.e+00 -1.e+09 -1.e+09]
  [-0.e+00 -0.e+00 -0.e+00 -1.e+09]
  [-0.e+00 -0.e+00 -0.e+00 -0.e+00]]], shape=(1, 4, 4), dtype=float32)


In [30]:
@tf.function
def dot_product_self_attention(queries, keys, values, masked=False):
    product = tf.linalg.matmul(queries, keys, transpose_b=True)
    keys_dim = tf.cast(tf.shape(keys)[-1], tf.float32)
    scaled_product = product / tf.math.sqrt(keys_dim)
    
    if masked == True:
        mask_size = scaled_product.get_shape().as_list()[-1]
        mask = 1 - tf.experimental.numpy.tril(tf.ones((1, mask_size, mask_size)), k=0)
        mask *= -1e9
        scaled_product +=  mask * -1e9
        
    softmax = tf.nn.softmax(scaled_product, axis=-1)
    attention = tf.matmul(softmax, values)
    
    return attention

In [31]:
a = tf.constant([[1,2],[2,2],[3,2]]) 
b = tf.constant([[4,4],[5,4],[6,4]]) 
c = tf.constant([[4,4],[5,4],[6,4]])
c_ = tf.constant([[4,4],[5,4],[6,4]])

e = tf.concat([a,b],axis=-1)

d=tf.concat([a[tf.newaxis,...], b[tf.newaxis,...], c[tf.newaxis,...],c_[tf.newaxis,...]], axis=0)
print(d)
print("\n")

tf.Tensor(
[[[1 2]
  [2 2]
  [3 2]]

 [[4 4]
  [5 4]
  [6 4]]

 [[4 4]
  [5 4]
  [6 4]]

 [[4 4]
  [5 4]
  [6 4]]], shape=(4, 3, 2), dtype=int32)




In [32]:
class MultiHeadAttention(layers.Layer):
    def __init__(self, 
                 d_feature,
                 n_heads, 
                 masked=False):
        super(MultiHeadAttention, self).__init__()        
        self.d_feature = d_feature
        self.n_heads = n_heads
        self.masked = masked   
        
        assert self.d_feature % self.n_heads == 0 
        
        self.d_head = self.d_feature // self.n_heads
    
    @tf.function
    def compute_heads(self, x):
        # x is of shape (batch_size, seq_length, d_features)
        batch_size = tf.shape(x)[0]
        seq_length = x.get_shape().as_list()[1]

        x = tf.reshape(x, (batch_size, seq_length, self.n_heads, self.d_head))
        x = tf.transpose(x, (0, 2, 1, 3))
        # x is of shape (batch_size, n_heads, seq_length, d_head)
        return x
    
    def build(self, input_shape):
        self.branch_1_dense = layers.Dense(self.d_feature)
        self.branch_2_dense = layers.Dense(self.d_feature)
        self.branch_3_dense = layers.Dense(self.d_feature)
        
        self.output_dense = layers.Dense(self.d_feature)
        
        

    def call(self, inputs_1, inputs_2, inputs_3):
        # inputs are supposed to be embedded results (shifted + word_embedding + position_embedding)  
        batch_size = tf.shape(inputs_1)[0]
        seq_length = inputs_1.get_shape().as_list()[1]
        
        branch_1 = self.branch_1_dense(inputs_1)
        branch_1 = self.compute_heads(branch_1)

        branch_2 = self.branch_2_dense(inputs_2)
        branch_2 = self.compute_heads(branch_2)
 
        
        branch_3 = self.branch_3_dense(inputs_3)
        branch_3 = self.compute_heads(branch_3)
        
        x = dot_product_self_attention(branch_1, branch_2, branch_3, self.masked)

        # shape becomes (batch_size, n_heads, seq_length, d_head)
        x = tf.transpose(x, (0, 2, 1, 3))
        concat = tf.reshape(x, (batch_size, seq_length, self.d_feature))
        outputs = self.output_dense(concat)

        return outputs

In [33]:
class PositionalEncoding(layers.Layer):
    def __init__(self):
        super(PositionalEncoding, self).__init__()
    
    def get_angles(self, pos, i, d_model):
        # pos of size (seg_length, 1) and 
        # i of size (1, d_model)
        # pos * angles of size (seq_length, d_model)
        
        angles = 1/np.power(10000., 2*(i//2)/np.float32(d_model))
        return pos * angles
    
    def build(self, input_shape):
        self.seq_length = input_shape[-2]
        self.d_model = input_shape[-1]
        
    def call(self, inputs):
        seq_length = self.seq_length
        d_model = self.d_model

        angles = self.get_angles(
            np.arange(seq_length)[:, np.newaxis],
            np.arange(d_model)[np.newaxis, :],
            d_model
        )

        angles[:, 0::2] = np.sin(angles[:, 0::2])
        angles[:, 1::2] = np.cos(angles[:, 1::2]) 
        pos_encoding = angles[np.newaxis, ...]


        return inputs + tf.cast(pos_encoding, tf.float32)

In [34]:
class EncoderBlock(layers.Layer):
    def __init__(self, d_model, d_ff, n_heads, dropout, ff_activation=layers.ReLU, name="encoder_block"):
        super(EncoderBlock, self).__init__(name=name)
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_ff = d_ff
        self.dropout = dropout
        self.ff_activation = ff_activation
        
    def build(self, input_shape):
        self.multi_head = MultiHeadAttention(self.d_model, self.n_heads, masked=False)
        self.dropout_1 = layers.Dropout(self.dropout)
        self.layer_norm = layers.LayerNormalization(epsilon = 1e-6)
        
        self.dense_1 = layers.Dense(self.d_ff)
        self.dense_2 = layers.Dense(self.d_model)
        self.dropout_2 = layers.Dropout(self.dropout)
        self.layer_norm_2 = layers.LayerNormalization(epsilon = 1e-6)
        
    def call(self, inputs):
        x = self.multi_head(inputs, inputs, inputs)
        x = self.dropout_1(x)
        out_1 = self.layer_norm(x + inputs)
        
        x = self.dense_1(out_1)
        x = self.ff_activation()(x)
        x = self.dense_2(x)
        x = self.dropout_2(x)
        x = self.layer_norm_2(x + out_1)

        return x

In [35]:
class Encoder(layers.Layer):
    def __init__(self, d_model, d_ff, n_heads, dropout, n_layers, vocab_size, ff_activation=layers.ReLU,name="encoder"):
        super(Encoder, self).__init__(name=name)
        self.d_model = d_model
        self.d_ff = d_ff
        self.n_heads = n_heads
        self.dropout = dropout
        self.n_layers = n_layers
        self.vocab_size = vocab_size
        self.ff_activation = ff_activation
    
    def build(self, input_shape):
        self.embedding_word = layers.Embedding(self.vocab_size, self.d_model)
        self.dropout_1 = layers.Dropout(self.dropout)
        
        self.enc_blocks = [
            EncoderBlock(self.d_model, self.d_ff, self.n_heads, self.dropout, self.ff_activation)
            for _ in range(self.n_layers)
        ]
        
    def call(self, inputs):
        x = self.embedding_word(inputs)
        x = x * tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x = PositionalEncoding()(x)
        x = self.dropout_1(x)
                
        for i in range(self.n_layers):
            x = self.enc_blocks[i](x)
        
        return x

In [36]:
class DecoderBlock(layers.Layer):
    # d_model = MultiheadAttention model's qkv feature dimension
    def __init__(self, d_model, d_ff, n_heads, dropout, ff_activation=layers.ReLU, name="decoder_block"): 
        super(DecoderBlock, self).__init__(name=name)
        """
        Args:
        d_model (int):  depth of embedding.
        d_ff (int): depth of feed-forward layer.
        n_heads (int): number of attention heads.
        dropout (float): dropout rate (how much to drop out).
        mode (str): 'train' or 'eval'.
        ff_activation (function): the non-linearity in feed-forward layer.

        Returns:
            list: list of trax.layers.combinators.Serial that maps an activation tensor to an activation tensor.
        """
        self.d_model = d_model
        self.d_ff = d_ff
        self.n_heads = n_heads
        self.dropout = dropout
        self.ff_activation = ff_activation        
    
    def build(self, input_shape):
        self.multi_head_masked = MultiHeadAttention(self.d_model, self.n_heads, masked=True)
        self.dropout_1 = layers.Dropout(self.dropout)
        self.layer_norm_1 = layers.LayerNormalization(epsilon = 1e-6)
        
        self.multi_head_unmasked = MultiHeadAttention(self.d_model, self.n_heads, masked=False)
        self.dropout_2 = layers.Dropout(self.dropout)
        self.layer_norm_2 = layers.LayerNormalization(epsilon = 1e-6)
    
        self.dense_1 = layers.Dense(self.d_ff)
        self.dense_2 = layers.Dense(self.d_model)
        self.dropout_3 = layers.Dropout(self.dropout)
        self.layer_norm_3 = layers.LayerNormalization(epsilon = 1e-6)
        
    def call(self, inputs, enc_outputs):
        x = self.multi_head_masked(
            inputs,
            inputs,
            inputs
        )
        x = self.dropout_1(x)
        out_1 = self.layer_norm_1(x + inputs)
          
        x = self.multi_head_unmasked(
            enc_outputs,
            enc_outputs,
            out_1
        )
        x = self.dropout_2(x)
        out_2 = self.layer_norm_2(x + out_1)
             
        x = self.dense_1(out_2)
        x = self.dense_2(x)
        x = self.ff_activation()(x)
        x = self.dropout_3(x)
        x = self.layer_norm_3(x + out_2)
        
        return x 

In [37]:
class Decoder(layers.Layer):
    def __init__(self, d_model, d_ff, n_heads, dropout, n_layers, vocab_size, ff_activation=layers.ReLU, name="decoder"):
        super(Decoder, self).__init__(name=name)
        self.d_model = d_model
        self.d_ff = d_ff
        self.n_heads = n_heads
        self.dropout = dropout
        self.n_layers = n_layers
        self.vocab_size = vocab_size
        self.ff_activation = ff_activation
        
    def build(self, input_shape):
        self.dec_word_embedding = layers.Embedding(self.vocab_size, self.d_model)
        
    def call(self, inputs, enc_outputs):
        inputs = self.dec_word_embedding(inputs)
        inputs = inputs * tf.math.sqrt(tf.cast(self.d_model, tf.float32))  
        inputs = PositionalEncoding()(inputs)
        
        self.dec_blocks = [
            DecoderBlock(self.d_model, self.d_ff, self.n_heads, self.dropout, ff_activation=self.ff_activation)
            for _ in range(self.n_layers)
        ]
        
        x = inputs
              
        for i in range(self.n_layers):
            x = self.dec_blocks[i](x, enc_outputs)
            
        return x

In [38]:
class Transformer(tf.keras.Model):
    def __init__(self, d_model, d_ff, n_heads, dropout, n_layers, vocab_size_enc, vocab_size_dec, ff_activation=layers.ReLU, name="transformer"):
        super(Transformer, self).__init__(name=name)  
        self.vocab_size_dec = vocab_size_dec
        
        self.encoder = Encoder(d_model, d_ff, n_heads, dropout, n_layers, vocab_size_enc, ff_activation)
        self.decoder = Decoder(d_model, d_ff, n_heads, dropout, n_layers, vocab_size_dec, ff_activation)
        self.output_dense = layers.Dense(self.vocab_size_dec)


    def call(self, enc_inputs, dec_inputs):
        enc_outputs = self.encoder(enc_inputs)
        dec_outputs = self.decoder(dec_inputs, enc_outputs)
        x = self.output_dense(dec_outputs)
        x = layers.Softmax()(x)
        return x

### Define EpochCheckpoint and Monitor Classes

#### Define Classes

In [39]:
class EpochCheckpoint(Callback):
    def __init__(self, output_dir, every=5, startAt=0, model_title=""):
        # call the parent constructor
        super(Callback, self).__init__()

        # store the base output path for the model, the number of
        # epochs that must pass before the model is serialized to
        # disk and the current epoch value
        self.output_dir = output_dir
        self.every = every
        self.intEpoch = startAt
        self.model_title = model_title

    def on_epoch_end(self, epoch, logs={}):
        # check to see if the model should be serialized to disk
        if (self.intEpoch + 1) % self.every == 0:
            p = os.path.sep.join([self.output_dir, self.model_title + "epoch-{}.hdf5".format(self.intEpoch + 1)])
            self.model.save(p, overwrite=True)

        # increment the internal epoch counter
        self.intEpoch += 1

In [40]:
# import the necessary packages


class TrainingMonitorCallback(BaseLogger):
    def __init__(self, figPath, jsonPath=None, startAt=0):
        # store the output path for the figure, the path to the JSON
        # serialized file, and the starting epoch
        super(TrainingMonitorCallback, self).__init__()
        self.figPath = figPath
        self.jsonPath = jsonPath
        self.startAt = startAt

    def on_train_begin(self, logs={}):
        # initialize the history dictionary
        self.H = {}

        # if the JSON history path exists, load the training history
        if self.jsonPath is not None:
            if os.path.exists(self.jsonPath):
                self.H = json.loads(open(self.jsonPath).read())

                # check to see if a starting epoch was supplied
                if self.startAt > 0:
                    # loop over the entries in the history log and
                    # trim any entries that are past the starting
                    # epoch
                    for k in self.H.keys():
                        self.H[k] = self.H[k][:self.startAt]

    def on_epoch_end(self, epoch, logs={}):
        print("[INFO] learning rate: {}".format(K.get_value(self.model.optimizer.lr)))
        # loop over the logs and update the loss, accuracy, etc.
        # for the entire training process
        for (k, v) in logs.items():
            l = self.H.get(k, [])
            l.append(float(v))
            self.H[k] = l

        # check to see if the training history should be serialized
        # to file
        if self.jsonPath is not None:
            f = open(self.jsonPath, "w")
            f.write(json.dumps(self.H, indent=4))
            f.close()

        # ensure at least two epochs have passed before plotting
        # (epoch starts at zero)
        if len(self.H["loss"]) > 1:
            # plot the training loss and accuracy
            N = np.arange(0, len(self.H["loss"]))
            plt.style.use("ggplot")
            plt.figure()
            plt.plot(N, self.H["loss"], label="train_loss")
            plt.plot(N, self.H["val_loss"], label="val_loss")
            plt.plot(N, self.H["accuracy"], label="train_acc")
            plt.plot(N, self.H["val_accuracy"], label="val_acc")
            plt.title("Training Loss and Accuracy [Epoch {}]".format(
                len(self.H["loss"])))
            plt.xlabel("Epoch #")
            plt.ylabel("Loss/Accuracy")
            plt.legend()

            # save the figure
            plt.savefig(self.figPath)
            plt.close()


#### Config the Model

In [50]:
D_MODEL = 128
D_FF = 512
N_LAYERS = 4
N_HEADS = 8
DROPOUT = 0.1

X_1 = layers.Input((MAX_SEQUENCE_LENGTH,))
X_2 = layers.Input((MAX_SEQUENCE_LENGTH,))
             
Y = Transformer(
    D_MODEL,
    D_FF,
    N_HEADS,
    DROPOUT,
    N_LAYERS,
    VOCAB_SIZE_EN,
    VOCAB_SIZE_FR  
)(X_1, X_2)

model = Model(inputs=[X_1,X_2], outputs=Y)

#### Training Configuration

In [51]:
START_AT = 0
STORE_CHECKPOINT_EVERY_EPOCH = 1
FIGURE_PATH = "./figures/results.png"

PREV_MODEL_PATH=None
LR = 1e-3

checkpoint_callback = EpochCheckpoint(
    "./checkpoints_hdf5/", 
    every=STORE_CHECKPOINT_EVERY_EPOCH,
    model_title="en_to_fr-"+"-", 
    startAt=START_AT
)

traing_callback = TrainingMonitorCallback(
    FIGURE_PATH,
    jsonPath=None,
    startAt=START_AT
)

callbacks = [checkpoint_callback, traing_callback]

#### Train

In [None]:
if PREV_MODEL_PATH is not None:
    print("[INFO] loading model from: {}".format(PREV_MODEL_PATH))
    model = load_model(PREV_MODEL_PATH)
    print("[INFO] start at epoch: {}".format(START_AT))
    print("[INFO] old learning rate: {}".format(K.get_value(model.optimizer.lr)))
    K.set_value(model.optimizer.lr, LR)
    print("[INFO] new learning rate: {}".format(K.get_value(model.optimizer.lr)))

else:    
    opt = SGD(learning_rate=1e-3, momentum=0.9, nesterov=True)

    model.compile(loss="sparse_categorical_crossentropy",
                  optimizer=opt,
                  metrics=["accuracy"])

TRAINING_SIZE = encoder_input.shape[0]
    
model.fit(train_target_gen,
          batch_size=BATCH_SIZE,  
          epochs=10,   
          use_multiprocessing=False,
          workers=1,
          steps_per_epoch=TRAINING_SIZE//BATCH_SIZE,
          max_queue_size=10,
          callbacks=[callbacks])



Epoch 1/10
 1585/15685 [==>...........................] - ETA: 1:34:07 - loss: 8.6930 - accuracy: 0.1157

In [None]:
encoder_input_gen_close()
decoder_input_gen_close()
encoder_input_real_gen_close()