In [0]:
from __future__ import absolute_import, division, print_function, unicode_literals

try:
    %tensorflow_version 2.x
except Exception:
    pass
import tensorflow_datasets as tfds
import tensorflow as tf

import re, string
import time
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

In [0]:
import sys, os
from google.colab import drive

# Mount Google Drive inside the Colab VM so the project files can be read.
drive.mount('/content/gdrive')
# NOTE(review): this path is relative while the mount point above is absolute,
# so it presumably relies on the working directory being /content — confirm.
root_path = 'gdrive/My Drive/projet-alexandrin/'  # your new root path

sys.path.append(os.path.join(root_path, 'notebooks')) # for importing from utils.py

## Setup input pipeline

In [0]:
# Short language tags; they are also used to build the checkpoint directory names.
lang1_symbol = "fr"
lang2_symbol = "mm"


def _read_corpus(path):
    """Read a UTF-8 text file into a list of lower-cased, whitespace-stripped lines.

    Args:
        path: location of the text file, one sentence per line.

    Returns:
        A list with one cleaned string per line of the file.
    """
    with open(path, encoding='utf-8') as fin:
        # strip() already removes the trailing newline. Slicing with [:-1]
        # instead would drop a real character whenever the last line of the
        # file does not end with a newline.
        return [line.lower().strip() for line in fin]


corpus_lang1 = _read_corpus(root_path + 'data/corneille_racine_fr.txt')
corpus_lang2 = _read_corpus(root_path + 'data/corneille_racine_mm.txt')

# verifying the corpus
# NOTE(review): the two files are presumably line-aligned, so both counts
# should be equal — confirm.
print(len(corpus_lang1), len(corpus_lang2))
# Spot-check one pair; this raises IndexError if a corpus has 50000 lines or fewer.
print(corpus_lang1[50000])
print(corpus_lang2[50000])

# train_lang1, test_lang1, train_lang2, test_lang2 = train_test_split(np.array(corpus_lang1), np.array(corpus_lang2), test_size=0)
# No train/test split is applied: the whole corpus is copied into NumPy
# arrays and used as training data.
train_lang1 = np.copy(corpus_lang1)
train_lang2 = np.copy(corpus_lang2)

# Dataset of (lang1, lang2) string pairs, sliced along the first axis of both arrays.
train_examples = tf.data.Dataset.from_tensor_slices((train_lang1, train_lang2))
# val_examples = tf.data.Dataset.from_tensor_slices((test_lang1, test_lang2))

# Print the first pair of the dataset as a sanity check, then stop.
for lang1, lang2 in train_examples:
    print(lang1.numpy().decode("utf-8") )
    print(lang2.numpy().decode("utf-8") )
    break

# Create a custom subwords tokenizer from the training dataset.
# One tokenizer per language, each built with a target vocabulary of 2**13 entries.
tokenizer_lang1 = tfds.features.text.SubwordTextEncoder.build_from_corpus(
    train_lang1, target_vocab_size=2**13)

tokenizer_lang2 = tfds.features.text.SubwordTextEncoder.build_from_corpus(
    train_lang2, target_vocab_size=2**13)

# Report the vocabulary sizes that were actually obtained.
print(tokenizer_lang1.vocab_size, tokenizer_lang2.vocab_size)

In [0]:
# Buffer size handed to Dataset.shuffle() when the training pipeline is built.
BUFFER_SIZE = 80000
# Number of sentence pairs in each padded batch.
BATCH_SIZE = 64

In [0]:
def encode(lang1, lang2):
    """Tokenize one sentence pair and wrap each side in start/end token ids.

    For each tokenizer, `vocab_size` is used as the start id and
    `vocab_size + 1` as the end id.

    Args:
        lang1: eager tensor holding the lang1 sentence (read through `.numpy()`).
        lang2: eager tensor holding the lang2 sentence (read through `.numpy()`).

    Returns:
        A pair of Python lists of token ids: (lang1 ids, lang2 ids).
    """
    def _wrap(tokenizer, sentence):
        # Surround the subword ids with the two ids just past the vocabulary.
        start_id = tokenizer.vocab_size
        return [start_id] + tokenizer.encode(sentence.numpy()) + [start_id + 1]

    return _wrap(tokenizer_lang1, lang1), _wrap(tokenizer_lang2, lang2)

In [0]:
# Default length limit of filter_max_length() and the maximum number of
# decoding steps performed by evaluate().
MAX_LENGTH = 100

In [0]:
def filter_max_length(x, y, max_length=MAX_LENGTH):
    """Tell whether both sequences hold at most `max_length` elements.

    Args:
        x: first tensor of the pair.
        y: second tensor of the pair.
        max_length: inclusive upper bound on the element count of each tensor.

    Returns:
        A scalar boolean tensor, True only when both tensors fit.
    """
    x_fits = tf.size(x) <= max_length
    y_fits = tf.size(y) <= max_length
    return tf.logical_and(x_fits, y_fits)

In [0]:
def tf_encode(lang1, lang2):
    """Graph-compatible wrapper around `encode` for use in `Dataset.map`.

    `encode` needs eager tensors (it calls `.numpy()`), so it is run through
    `tf.py_function`.

    Args:
        lang1: scalar string tensor with the lang1 sentence.
        lang2: scalar string tensor with the lang2 sentence.

    Returns:
        A pair of int64 tensors with the encoded lang1 and lang2 sequences.
    """
    result_lang1, result_lang2 = tf.py_function(
        encode, [lang1, lang2], [tf.int64, tf.int64])

    # py_function outputs carry no static shape information. `encode` returns
    # flat lists of ids, so declare both results as rank-1 tensors of unknown
    # length; downstream ops that need a known rank can then rely on it.
    result_lang1.set_shape([None])
    result_lang2.set_shape([None])

    return result_lang1, result_lang2

In [0]:
# Tokenize every sentence pair, then drop pairs where either side is longer
# than MAX_LENGTH tokens.
# NOTE(review): training() builds its own local pipeline with the same steps,
# so this module-level dataset is not what training() iterates over.
train_dataset = train_examples.map(tf_encode)
train_dataset = train_dataset.filter(filter_max_length)
# cache the dataset to memory to get a speedup while reading from it.
train_dataset = train_dataset.cache()
# Shuffle, then batch; [-1] pads each side to the longest sequence of its batch.
train_dataset = train_dataset.shuffle(BUFFER_SIZE).padded_batch(
    BATCH_SIZE, padded_shapes=([-1], [-1]))
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

# val_dataset = val_examples.map(tf_encode)
# val_dataset = val_dataset.filter(filter_max_length).padded_batch(
#     BATCH_SIZE, padded_shapes=([-1], [-1]))

In [0]:
def get_angles(pos, i, d_model):
    """Compute the raw angles used by the sinusoidal positional encoding.

    Dimension indices are paired (2k and 2k+1 share a rate); each pair uses
    the rate 1 / 10000 ** (2k / d_model).

    Args:
        pos: position indices (array or scalar), broadcast against `i`.
        i: embedding-dimension indices (array or scalar).
        d_model: embedding size used to normalise the exponent.

    Returns:
        `pos` multiplied by the per-dimension rate, with NumPy broadcasting.
    """
    exponent = (2 * (i // 2)) / np.float32(d_model)
    inverse_wavelength = 1 / np.power(10000, exponent)
    return pos * inverse_wavelength

In [0]:
def positional_encoding(position, d_model):
    """Build the sinusoidal positional-encoding table.

    Args:
        position: number of positions to generate (rows of the table).
        d_model: embedding size (columns of the table).

    Returns:
        A float32 tensor of shape (1, position, d_model): sine of the angle
        in even columns, cosine in odd columns.
    """
    positions = np.arange(position)[:, np.newaxis]
    dimensions = np.arange(d_model)[np.newaxis, :]
    table = get_angles(positions, dimensions, d_model)

    even_cols = slice(0, None, 2)  # columns 2i
    odd_cols = slice(1, None, 2)   # columns 2i+1
    table[:, even_cols] = np.sin(table[:, even_cols])
    table[:, odd_cols] = np.cos(table[:, odd_cols])

    # Leading axis of size 1 so the table broadcasts over the batch.
    return tf.cast(table[np.newaxis, ...], dtype=tf.float32)

In [0]:
def create_padding_mask(seq):
    """Mark the padding positions (token id 0) of a batch of sequences.

    Args:
        seq: tensor of token ids, indexed here as (batch_size, seq_len).

    Returns:
        A float32 tensor of shape (batch_size, 1, 1, seq_len) holding 1.0
        where the id is 0 and 0.0 elsewhere. The two singleton axes let it
        broadcast against the attention logits.
    """
    is_padding = tf.math.equal(seq, 0)
    padding_flags = tf.cast(is_padding, tf.float32)
    return padding_flags[:, tf.newaxis, tf.newaxis, :]

In [0]:
def create_look_ahead_mask(size):
    """Build a (size, size) mask with 1.0 strictly above the diagonal.

    Row r therefore flags every column greater than r, i.e. the positions
    that come after position r.

    Args:
        size: sequence length.

    Returns:
        A float tensor of shape (seq_len, seq_len).
    """
    lower_triangle = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return 1 - lower_triangle

In [0]:
def scaled_dot_product_attention(q, k, v, mask):
    """Compute softmax(q·kᵀ / sqrt(depth)) · v with optional masking.

    q, k, v must have matching leading dimensions, and k, v must share their
    penultimate dimension (seq_len_k == seq_len_v).

    Args:
        q: query, shape (..., seq_len_q, depth).
        k: key, shape (..., seq_len_k, depth).
        v: value, shape (..., seq_len_v, depth_v).
        mask: float tensor broadcastable to (..., seq_len_q, seq_len_k),
            with 1.0 at positions to hide, or None to skip masking.

    Returns:
        (output, attention_weights): output has shape
        (..., seq_len_q, depth_v) and attention_weights has shape
        (..., seq_len_q, seq_len_k).
    """
    key_depth = tf.cast(tf.shape(k)[-1], tf.float32)
    logits = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(key_depth)

    if mask is not None:
        # A large negative offset drives masked positions to ~0 after softmax.
        logits = logits + (mask * -1e9)

    # Normalise over the key axis so each query's weights sum to 1.
    attention_weights = tf.nn.softmax(logits, axis=-1)

    return tf.matmul(attention_weights, v), attention_weights

In [0]:
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head attention: project q/k/v, attend per head, merge, project."""

    def __init__(self, d_model, num_heads):
        """Create the projection layers.

        Args:
            d_model: model width; must be divisible by `num_heads`.
            num_heads: number of attention heads.
        """
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        # Every head receives an equal slice of the model width.
        assert d_model % self.num_heads == 0
        self.depth = d_model // self.num_heads

        # Input projections for queries, keys and values.
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        # Output projection applied once the heads are merged back.
        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Turn (batch_size, seq_len, d_model) into (batch_size, num_heads, seq_len, depth)."""
        per_head = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(per_head, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        """Attend from `q` over `k`/`v`.

        Args:
            v: values, (batch_size, seq_len_v, d_model).
            k: keys, (batch_size, seq_len_k, d_model).
            q: queries, (batch_size, seq_len_q, d_model).
            mask: mask forwarded to `scaled_dot_product_attention`.

        Returns:
            (output, attention_weights) with shapes
            (batch_size, seq_len_q, d_model) and
            (batch_size, num_heads, seq_len_q, seq_len_k).
        """
        batch_size = tf.shape(q)[0]

        # Project, then split into heads: (batch_size, num_heads, seq_len, depth).
        queries = self.split_heads(self.wq(q), batch_size)
        keys = self.split_heads(self.wk(k), batch_size)
        values = self.split_heads(self.wv(v), batch_size)

        attended, attention_weights = scaled_dot_product_attention(
            queries, keys, values, mask)

        # Back to (batch_size, seq_len_q, num_heads, depth), then merge the heads.
        attended = tf.transpose(attended, perm=[0, 2, 1, 3])
        merged = tf.reshape(attended, (batch_size, -1, self.d_model))

        return self.dense(merged), attention_weights

In [0]:
def point_wise_feed_forward_network(d_model, dff):
    """Build the two-layer feed-forward block applied at every position.

    Args:
        d_model: output width.
        dff: width of the hidden ReLU layer.

    Returns:
        A `tf.keras.Sequential` mapping (batch_size, seq_len, d_model) to
        (batch_size, seq_len, dff) and back to (batch_size, seq_len, d_model).
    """
    hidden = tf.keras.layers.Dense(dff, activation='relu')
    projection = tf.keras.layers.Dense(d_model)
    return tf.keras.Sequential([hidden, projection])

In [0]:
class EncoderLayer(tf.keras.layers.Layer):
    """One encoder block: self-attention then feed-forward, each followed by
    dropout, a residual connection and layer normalisation."""

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        """Create the sub-layers.

        Args:
            d_model: model width.
            num_heads: number of attention heads.
            dff: hidden width of the feed-forward block.
            rate: dropout rate.
        """
        super(EncoderLayer, self).__init__()

        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        """Run the block on `x` of shape (batch_size, input_seq_len, d_model).

        Args:
            x: input activations.
            training: forwarded to the dropout layers.
            mask: attention mask forwarded to the self-attention.

        Returns:
            A tensor of shape (batch_size, input_seq_len, d_model).
        """
        # Self-attention sub-layer.
        attended, _ = self.mha(x, x, x, mask)
        attended = self.dropout1(attended, training=training)
        after_attention = self.layernorm1(x + attended)

        # Feed-forward sub-layer.
        transformed = self.ffn(after_attention)
        transformed = self.dropout2(transformed, training=training)
        return self.layernorm2(after_attention + transformed)

In [0]:
class DecoderLayer(tf.keras.layers.Layer):
    """One decoder block: masked self-attention, attention over the encoder
    output, then feed-forward; each is followed by dropout, a residual
    connection and layer normalisation."""

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        """Create the sub-layers.

        Args:
            d_model: model width.
            num_heads: number of attention heads.
            dff: hidden width of the feed-forward block.
            rate: dropout rate.
        """
        super(DecoderLayer, self).__init__()

        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.mha2 = MultiHeadAttention(d_model, num_heads)

        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)
        self.dropout3 = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training,
           look_ahead_mask, padding_mask):
        """Run the block.

        Args:
            x: decoder activations, (batch_size, target_seq_len, d_model).
            enc_output: encoder output, (batch_size, input_seq_len, d_model).
            training: forwarded to the dropout layers.
            look_ahead_mask: mask for the self-attention over `x`.
            padding_mask: mask for the attention over `enc_output`.

        Returns:
            (output, self_attention_weights, encoder_attention_weights);
            output has shape (batch_size, target_seq_len, d_model).
        """
        # Block 1: self-attention over the target sequence.
        self_attended, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)
        self_attended = self.dropout1(self_attended, training=training)
        after_self = self.layernorm1(self_attended + x)

        # Block 2: queries come from the decoder, keys/values from the encoder.
        cross_attended, attn_weights_block2 = self.mha2(
            enc_output, enc_output, after_self, padding_mask)
        cross_attended = self.dropout2(cross_attended, training=training)
        after_cross = self.layernorm2(cross_attended + after_self)

        # Block 3: position-wise feed-forward.
        transformed = self.ffn(after_cross)
        transformed = self.dropout3(transformed, training=training)
        output = self.layernorm3(transformed + after_cross)

        return output, attn_weights_block1, attn_weights_block2

In [0]:
class Encoder(tf.keras.layers.Layer):
    """Token embedding + positional encoding followed by a stack of EncoderLayer."""

    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                   maximum_position_encoding, rate=0.1):
        """Create the embedding, the positional table and the layer stack.

        Args:
            num_layers: number of EncoderLayer blocks.
            d_model: model width.
            num_heads: number of attention heads per block.
            dff: hidden width of each feed-forward block.
            input_vocab_size: number of rows of the embedding table.
            maximum_position_encoding: number of positions in the
                positional-encoding table.
            rate: dropout rate.
        """
        super(Encoder, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding,
                                                self.d_model)

        self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate)
                           for _ in range(num_layers)]

        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        """Encode a batch of token ids.

        Args:
            x: token ids, indexed as (batch_size, input_seq_len).
            training: forwarded to dropout and to every layer.
            mask: attention mask forwarded to every layer.

        Returns:
            A tensor of shape (batch_size, input_seq_len, d_model).
        """
        seq_len = tf.shape(x)[1]
        embedding_scale = tf.math.sqrt(tf.cast(self.d_model, tf.float32))

        # Scaled embeddings plus the first `seq_len` rows of the position table.
        hidden = self.embedding(x) * embedding_scale
        hidden = hidden + self.pos_encoding[:, :seq_len, :]
        hidden = self.dropout(hidden, training=training)

        for layer in self.enc_layers:
            hidden = layer(hidden, training, mask)

        return hidden

In [0]:
class Decoder(tf.keras.layers.Layer):
    """Token embedding + positional encoding followed by a stack of DecoderLayer."""

    def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size,
                   maximum_position_encoding, rate=0.1):
        """Create the embedding, the positional table and the layer stack.

        Args:
            num_layers: number of DecoderLayer blocks.
            d_model: model width.
            num_heads: number of attention heads per block.
            dff: hidden width of each feed-forward block.
            target_vocab_size: number of rows of the embedding table.
            maximum_position_encoding: number of positions in the
                positional-encoding table.
            rate: dropout rate.
        """
        super(Decoder, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)

        self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate)
                           for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training,
               look_ahead_mask, padding_mask):
        """Decode a batch of target token ids against the encoder output.

        Args:
            x: target token ids, indexed as (batch_size, target_seq_len).
            enc_output: encoder output passed to every layer.
            training: forwarded to dropout and to every layer.
            look_ahead_mask: mask for each layer's self-attention.
            padding_mask: mask for each layer's encoder attention.

        Returns:
            (output, attention_weights): output has shape
            (batch_size, target_seq_len, d_model); attention_weights maps
            'decoder_layer{n}_block1' / 'decoder_layer{n}_block2' (n starting
            at 1) to the attention weights of that layer.
        """
        seq_len = tf.shape(x)[1]
        embedding_scale = tf.math.sqrt(tf.cast(self.d_model, tf.float32))

        hidden = self.embedding(x) * embedding_scale
        hidden = hidden + self.pos_encoding[:, :seq_len, :]
        hidden = self.dropout(hidden, training=training)

        attention_weights = {}
        for layer_number, layer in enumerate(self.dec_layers, start=1):
            hidden, self_attn, cross_attn = layer(
                hidden, enc_output, training, look_ahead_mask, padding_mask)

            attention_weights['decoder_layer{}_block1'.format(layer_number)] = self_attn
            attention_weights['decoder_layer{}_block2'.format(layer_number)] = cross_attn

        return hidden, attention_weights

In [0]:
class Transformer(tf.keras.Model):
    """Encoder-decoder Transformer with a final projection onto the target vocabulary."""

    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
               target_vocab_size, pe_input, pe_target, rate=0.1):
        """Create the encoder, the decoder and the output projection.

        Args:
            num_layers: number of blocks in both the encoder and the decoder.
            d_model: model width.
            num_heads: number of attention heads.
            dff: hidden width of the feed-forward blocks.
            input_vocab_size: vocabulary size of the encoder embedding.
            target_vocab_size: vocabulary size of the decoder embedding and
                width of the output logits.
            pe_input: number of positions in the encoder's positional table.
            pe_target: number of positions in the decoder's positional table.
            rate: dropout rate.
        """
        super(Transformer, self).__init__()

        self.encoder = Encoder(num_layers, d_model, num_heads, dff,
                               input_vocab_size, pe_input, rate)

        self.decoder = Decoder(num_layers, d_model, num_heads, dff,
                               target_vocab_size, pe_target, rate)

        self.final_layer = tf.keras.layers.Dense(target_vocab_size)

    def call(self, inp, tar, training, enc_padding_mask,
               look_ahead_mask, dec_padding_mask):
        """Run a full forward pass.

        Args:
            inp: source token ids fed to the encoder.
            tar: target token ids fed to the decoder.
            training: forwarded to the encoder and decoder.
            enc_padding_mask: mask for the encoder self-attention.
            look_ahead_mask: mask for the decoder self-attention.
            dec_padding_mask: mask for the decoder's attention over the
                encoder output.

        Returns:
            (logits, attention_weights): logits have shape
            (batch_size, tar_seq_len, target_vocab_size); attention_weights
            is the dictionary returned by the decoder.
        """
        # (batch_size, inp_seq_len, d_model)
        encoded = self.encoder(inp, training, enc_padding_mask)

        # decoded: (batch_size, tar_seq_len, d_model)
        decoded, attention_weights = self.decoder(
            tar, encoded, training, look_ahead_mask, dec_padding_mask)

        logits = self.final_layer(decoded)

        return logits, attention_weights

## Set hyperparameters

To keep this example small and relatively fast, the values for *num_layers, d_model, and dff* have been reduced. 

The values used in the base Transformer model were: *num_layers = 6*, *d_model = 512*, *dff = 2048*. See the [paper](https://arxiv.org/abs/1706.03762) for all the other versions of the Transformer.

Note: By changing the values below, you can get the model that achieved state of the art on many tasks.

In [0]:
# Smaller configuration that was tried earlier:
# num_layers = 4
# d_model = 128
# dff = 512
# num_heads = 8

# Model size shared by both Transformer instances built later in the notebook.
num_layers = 4
d_model = 256
dff = 1024
num_heads = 8

# +2 leaves room for the start and end token ids that encode() adds
# (vocab_size and vocab_size + 1).
input_vocab_size = tokenizer_lang1.vocab_size + 2
target_vocab_size = tokenizer_lang2.vocab_size + 2
dropout_rate = 0.1

In [0]:
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Warm-up learning-rate schedule.

    lrate = d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

    The rate therefore grows linearly for `warmup_steps` steps and then
    decays with the inverse square root of the step.
    """

    def __init__(self, d_model, warmup_steps=4000):
        """Store the schedule parameters.

        Args:
            d_model: model width; kept as a float32 tensor.
            warmup_steps: number of steps of linear warm-up.
        """
        super(CustomSchedule, self).__init__()

        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        """Return the learning rate for `step`.

        Args:
            step: current optimizer step (integer or float, Python number or tensor).

        Returns:
            A float32 scalar tensor.
        """
        # rsqrt is only defined for floating-point tensors, and the step may
        # arrive as an integer, so convert it first.
        step = tf.cast(step, tf.float32)

        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

In [0]:
# One schedule object is handed to both optimizers.
learning_rate = CustomSchedule(d_model)

# Separate Adam optimizer for each translation direction (lang1 -> lang2 and
# lang2 -> lang1), configured identically.
optimizer_1to2 = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, 
                                     epsilon=1e-9)
optimizer_2to1 = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, 
                                     epsilon=1e-9)

In [0]:
# Cross-entropy on raw logits with integer targets. reduction='none' keeps
# the individual loss values so loss_function() can mask padding before
# reducing them.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

In [0]:
def loss_function(real, pred):
    """Cross-entropy averaged over the non-padding target tokens.

    Args:
        real: integer target ids; id 0 is treated as padding and ignored.
        pred: logits passed to `loss_object` together with `real`.

    Returns:
        A scalar tensor: the summed loss of the real tokens divided by the
        number of real tokens (0 when the batch holds no real token).
    """
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)

    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask

    # Divide by the number of real tokens rather than by every position:
    # a plain mean also counts the zeroed padding entries, which makes the
    # loss shrink as a batch contains more padding.
    return tf.math.divide_no_nan(tf.reduce_sum(loss_), tf.reduce_sum(mask))

In [0]:
# Running metrics shared by both training directions; training() resets them
# at the start of every epoch.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='train_accuracy')

## Training and checkpointing

In [0]:
# Model translating lang1 -> lang2.
# pe_input / pe_target end up as the number of positions of the positional
# encoding tables; here they are set to the vocabulary sizes.
# NOTE(review): sequences are filtered to MAX_LENGTH tokens, so a much smaller
# table would presumably be enough — confirm before changing.
transformer_1to2 = Transformer(num_layers, d_model, num_heads, dff,
                          input_vocab_size, target_vocab_size, 
                          pe_input=input_vocab_size, 
                          pe_target=target_vocab_size,
                          rate=dropout_rate)

# Model translating lang2 -> lang1: same hyperparameters with the two
# vocabulary sizes swapped.
transformer_2to1 = Transformer(num_layers, d_model, num_heads, dff,
                          target_vocab_size, input_vocab_size, 
                          pe_input=target_vocab_size, 
                          pe_target=input_vocab_size,
                          rate=dropout_rate)

In [0]:
def create_masks(inp, tar):
    """Build the three attention masks needed for one forward pass.

    Args:
        inp: source token ids given to the encoder.
        tar: target token ids given to the decoder.

    Returns:
        (enc_padding_mask, combined_mask, dec_padding_mask):
        - enc_padding_mask hides source padding in the encoder;
        - combined_mask hides both target padding and future target
          positions in the decoder's first attention block;
        - dec_padding_mask hides source padding when the decoder attends
          over the encoder output (second attention block).
    """
    # Both padding masks are derived from the source sequence.
    enc_padding_mask = create_padding_mask(inp)
    dec_padding_mask = create_padding_mask(inp)

    # A target position is hidden when it is padding OR lies in the future.
    target_len = tf.shape(tar)[1]
    combined_mask = tf.maximum(create_padding_mask(tar),
                               create_look_ahead_mask(target_len))

    return enc_padding_mask, combined_mask, dec_padding_mask

Create the checkpoint path and the checkpoint manager. This will be used to save checkpoints every `n` epochs.

In [0]:
# Checkpoint directory for the lang1 -> lang2 model; the name encodes the
# language pair and the model size (num_layers, d_model, dff).
checkpoint_path_1to2 = root_path + "transformer_" + lang1_symbol + "_to_" + lang2_symbol + "_" + str(num_layers) + "_" + str(d_model) + "_" + str(dff) + "_checkpoints/train"

# Both the model and its optimizer are saved and restored together.
ckpt_1to2 = tf.train.Checkpoint(transformer=transformer_1to2,
                           optimizer=optimizer_1to2)

# Keeps at most the 5 most recent checkpoints.
ckpt_manager_1to2 = tf.train.CheckpointManager(ckpt_1to2, checkpoint_path_1to2, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager_1to2.latest_checkpoint:
    ckpt_1to2.restore(ckpt_manager_1to2.latest_checkpoint)
    print ('1to2: Latest checkpoint restored!!')

In [0]:
# Checkpoint directory for the lang2 -> lang1 model; the name encodes the
# language pair and the model size (num_layers, d_model, dff).
checkpoint_path_2to1 = root_path + "transformer_" + lang2_symbol + "_to_" + lang1_symbol + "_" + str(num_layers) + "_" + str(d_model) + "_" + str(dff) + "_checkpoints/train"

# Both the model and its optimizer are saved and restored together.
ckpt_2to1 = tf.train.Checkpoint(transformer=transformer_2to1,
                           optimizer=optimizer_2to1)

# NOTE(review): this direction keeps only 2 checkpoints whereas the 1to2
# manager keeps 5 — confirm the difference is intentional.
ckpt_manager_2to1 = tf.train.CheckpointManager(ckpt_2to1, checkpoint_path_2to1, max_to_keep=2)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager_2to1.latest_checkpoint:
    ckpt_2to1.restore(ckpt_manager_2to1.latest_checkpoint)
    print ('2to1: Latest checkpoint restored!!')

The target is divided into `tar_inp` and `tar_real`. `tar_inp` is passed as an input to the decoder. `tar_real` is that same input shifted by 1: at each location in `tar_inp`, `tar_real` contains the next token that should be predicted.

For example, `sentence` = "SOS A lion in the jungle is sleeping EOS"

`tar_inp` =  "SOS A lion in the jungle is sleeping"

`tar_real` = "A lion in the jungle is sleeping EOS"

The transformer is an auto-regressive model: it makes predictions one part at a time, and uses its output so far to decide what to do next. 

During training this example uses teacher-forcing (like in the [text generation tutorial](./text_generation.ipynb)). Teacher forcing is passing the true output to the next time step regardless of what the model predicts at the current time step.

As the transformer predicts each word, *self-attention* allows it to look at the previous words in the input sequence to better predict the next word.

To prevent the model from peeking at the expected output, the model uses a look-ahead mask.

In [0]:
def top_p_logits(logits, p=0.9, temperature=0.7):
    """Apply temperature scaling and top-p (nucleus) filtering to logits.

    The vocabulary is sorted by decreasing probability and a token is kept
    while the probability mass of the tokens ranked before it is below `p`;
    every other token gets a logit of -1e10. All reductions run along the
    last axis, so any leading batch/sequence dimensions are supported.

    Args:
        logits: float tensor of shape (..., vocab_size).
        p: cumulative probability threshold.
        temperature: divisor applied to the logits before filtering.

    Returns:
        A tensor shaped like `logits`, holding the temperature-scaled logits
        of the kept tokens and -1e10 everywhere else.
    """
    logits = logits / temperature

    logits_sort = tf.sort(logits, axis=-1, direction='DESCENDING')
    probs_sort = tf.nn.softmax(logits_sort, axis=-1)
    # Exclusive cumsum = mass of the strictly better-ranked tokens, so the
    # top token (mass 0 before it) is always kept when p > 0.
    probs_sort_sums = tf.cumsum(probs_sort, axis=-1, exclusive=True)

    # Replace rejected entries by a huge value so that the minimum below is
    # the smallest logit that was kept.
    logits_sort_masked = tf.where(
        probs_sort_sums < p, logits_sort,
        tf.ones_like(logits_sort, dtype=logits.dtype) * 1e10)  # (..., vocab)
    min_logits = tf.reduce_min(logits_sort_masked, axis=-1, keepdims=True)  # (..., 1)

    return tf.where(
        logits < min_logits,
        tf.ones_like(logits, dtype=logits.dtype) * -1e10,
        logits,
    )

In [0]:
# The @tf.function trace-compiles train_step into a TF graph for faster
# execution. The function specializes to the precise shape of the argument
# tensors. To avoid re-tracing due to the variable sequence lengths or variable
# batch sizes (the last batch is smaller), use input_signature to specify
# more generic shapes.

train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]


def _run_train_step(transformer, optimizer, inp, tar):
    """Run one teacher-forced optimisation step and update the running metrics.

    Args:
        transformer: model to train.
        optimizer: optimizer that owns the update for `transformer`.
        inp: (batch_size, inp_seq_len) int64 source token ids.
        tar: (batch_size, tar_seq_len) int64 target token ids, including the
            start and end tokens.
    """
    # Decoder input drops the last token; the expected output drops the first.
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]

    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

    with tf.GradientTape() as tape:
        predictions, _ = transformer(inp, tar_inp,
                                     True,
                                     enc_padding_mask,
                                     combined_mask,
                                     dec_padding_mask)
        loss = loss_function(tar_real, predictions)

    # Only the forward pass needs recording, so the gradient computation and
    # the parameter update happen outside the tape context.
    gradients = tape.gradient(loss, transformer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

    train_loss(loss)
    # Weight padding positions (id 0) with 0 so they do not count towards
    # the accuracy, matching the masking done in loss_function().
    not_padding = tf.cast(tf.math.not_equal(tar_real, 0), tf.float32)
    train_accuracy(tar_real, predictions, sample_weight=not_padding)


@tf.function(input_signature=train_step_signature)
def train_step_1to2(inp, tar):
    """One training step of the lang1 -> lang2 model (inp: lang1, tar: lang2)."""
    _run_train_step(transformer_1to2, optimizer_1to2, inp, tar)


@tf.function(input_signature=train_step_signature)
def train_step_2to1(inp, tar):
    """One training step of the lang2 -> lang1 model (inp: lang2, tar: lang1)."""
    _run_train_step(transformer_2to1, optimizer_2to1, inp, tar)

In [0]:
def training(epochs = 10, direction_1to2=True, save_every=1):
    """Train one translation direction and checkpoint it periodically.

    Args:
        epochs: number of passes over the training data.
        direction_1to2: True trains the lang1 -> lang2 model, False trains
            the lang2 -> lang1 model.
        save_every: save a checkpoint every `save_every` epochs; a falsy
            value disables saving.
    """
    ckpt_manager = ckpt_manager_1to2 if direction_1to2 else ckpt_manager_2to1

    # Build the input pipeline once. Rebuilding it inside the epoch loop
    # re-tokenizes the whole corpus every epoch and throws the cache away;
    # shuffle() already reshuffles on every new iteration over the dataset.
    train_dataset = train_examples.map(tf_encode)
    train_dataset = train_dataset.filter(filter_max_length)
    train_dataset = train_dataset.cache()
    train_dataset = train_dataset.shuffle(BUFFER_SIZE).padded_batch(
        BATCH_SIZE, padded_shapes=([-1], [-1]))
    train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

    for epoch in range(epochs):
        start = time.time()

        train_loss.reset_states()
        train_accuracy.reset_states()

        # Dataset elements are (lang1, lang2) pairs; the reverse direction
        # swaps source and target.
        for (batch, (inp, tar)) in enumerate(train_dataset):
            if direction_1to2:
                train_step_1to2(inp, tar)
            else:
                train_step_2to1(tar, inp)

            if batch % 50 == 0:
                print ('Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}'.format(
                    epoch + 1, batch, train_loss.result(), train_accuracy.result()))

        if save_every and (epoch + 1) % save_every == 0:
            ckpt_save_path = ckpt_manager.save()
            print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                                 ckpt_save_path))

        print ('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1,
                                                    train_loss.result(),
                                                    train_accuracy.result()))

        print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))

## Evaluate

The following steps are used for evaluation:

* Encode the input sentence using the source-language tokenizer (`tokenizer_lang1` when translating lang1 → lang2, `tokenizer_lang2` for the reverse direction). Moreover, add the start and end token so the input is equivalent to what the model is trained with. This is the encoder input.
* The decoder input is the start token of the target language, i.e. the target tokenizer's `vocab_size`.
* Calculate the padding masks and the look ahead masks.
* The `decoder` then outputs the predictions by looking at the `encoder output` and its own output (self-attention).
* Select the last word and sample the next token from its top-p (nucleus) filtered logits.
* Concatenate the predicted word to the decoder input and pass it to the decoder.
* In this approach, the decoder predicts the next word based on the previous words it predicted.

Note: The model used here has less capacity, to keep the example relatively fast, so the predictions may be less accurate. To reproduce the results in the paper, use the entire dataset and the base transformer model or transformer XL, by changing the hyperparameters above.

In [0]:
def evaluate(inp_sentence, direction_1to2=True, top_p=0.9, temp=0.7):
    """Translate one sentence by sampling token by token from the model.

    Args:
        inp_sentence: source sentence as a Python string.
        direction_1to2: True translates lang1 -> lang2 with
            `transformer_1to2`; False translates lang2 -> lang1 with
            `transformer_2to1`.
        top_p: nucleus threshold forwarded to `top_p_logits`.
        temp: sampling temperature forwarded to `top_p_logits`.

    Returns:
        (output, attention_weights): output is a 1-D int32 tensor of target
        token ids that starts with the target start token and does not
        contain the end token; attention_weights is the dictionary returned
        by the model on the last decoding step.
    """
    # In the lang1 -> lang2 direction an end token drawn during the first
    # steps is rejected and the token is drawn again.
    # NOTE(review): 12 presumably relates to the twelve syllables of an
    # alexandrine — confirm.
    min_steps_1to2 = 12
    # Upper bound on the redraws of one step, so decoding cannot loop forever
    # when the end token dominates the filtered distribution.
    max_resamples = 10

    if direction_1to2:
        tok_lang1 = tokenizer_lang1
        tok_lang2 = tokenizer_lang2
        transformer = transformer_1to2
    else:
        tok_lang1 = tokenizer_lang2
        tok_lang2 = tokenizer_lang1
        transformer = transformer_2to1

    start_token = [tok_lang1.vocab_size]
    end_token = [tok_lang1.vocab_size + 1]

    # Wrap the source sentence in the source start/end tokens, as in training.
    inp_sentence = start_token + tok_lang1.encode(inp_sentence) + end_token
    encoder_input = tf.expand_dims(inp_sentence, 0)

    # The decoder starts from the target-language start token.
    decoder_input = [tok_lang2.vocab_size]
    output = tf.expand_dims(decoder_input, 0)
    target_end_id = tok_lang2.vocab_size + 1

    for i in range(MAX_LENGTH):
        enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
            encoder_input, output)

        # predictions.shape == (batch_size, seq_len, vocab_size)
        predictions, attention_weights = transformer(encoder_input,
                                                     output,
                                                     False,
                                                     enc_padding_mask,
                                                     combined_mask,
                                                     dec_padding_mask)

        # select the last word from the seq_len dimension
        predictions = predictions[: ,-1:, :]  # (batch_size, 1, vocab_size)

        # Filter, then sample one id from the remaining distribution.
        predictions = top_p_logits(predictions[0], p=top_p, temperature=temp)
        predicted_id = tf.cast(tf.random.categorical(predictions, 1), tf.int32)

        # A single redraw could itself be the end token, which would then be
        # appended to the decoder input. Redraw until another token comes up
        # or the retry budget is spent.
        too_short = direction_1to2 and i < min_steps_1to2
        attempts = 0
        while (too_short and attempts < max_resamples
               and bool(predicted_id == target_end_id)):
            predicted_id = tf.cast(tf.random.categorical(predictions, 1), tf.int32)
            attempts += 1

        # Stop at the end token; it is never appended to the output.
        if predicted_id == target_end_id:
            return tf.squeeze(output, axis=0), attention_weights

        # concatenate the predicted_id to the output which is given to the
        # decoder as its input.
        output = tf.concat([output, predicted_id], axis=-1)

    return tf.squeeze(output, axis=0), attention_weights

In [0]:
def plot_attention_weights(attention, sentence, result, layer, direction_1to2=True):
    """Plot one heat map per attention head for a decoder attention block.

    Args:
        attention: dictionary of attention weights as returned by `evaluate`.
        sentence: source sentence as a Python string.
        result: target token ids as returned by `evaluate`.
        layer: key of `attention` to plot, e.g. 'decoder_layer4_block2'.
        direction_1to2: True when `sentence` is lang1 and `result` is lang2;
            False for the reverse direction.
    """
    if direction_1to2:
        tok_lang1 = tokenizer_lang1
        tok_lang2 = tokenizer_lang2
    else:
        tok_lang1 = tokenizer_lang2
        tok_lang2 = tokenizer_lang1

    fig = plt.figure(figsize=(16, 8))

    sentence = tok_lang1.encode(sentence)

    attention = tf.squeeze(attention[layer], axis=0)

    # Tick labels are the same for every head, so build them once.
    x_labels = ['<start>'] + [tok_lang1.decode([i]) for i in sentence] + ['<end>']
    y_labels = [tok_lang2.decode([i]) for i in result
                if i < tok_lang2.vocab_size]
    fontdict = {'fontsize': 10}

    # Four heads per row; at least two rows, which is the 2x4 grid used for
    # up to eight heads, and more rows when there are more heads.
    head_count = attention.shape[0]
    n_cols = 4
    n_rows = max(2, -(-head_count // n_cols))

    for head in range(head_count):
        ax = fig.add_subplot(n_rows, n_cols, head+1)

        # plot the attention weights
        ax.matshow(attention[head][:-1, :], cmap='viridis')

        ax.set_xticks(range(len(sentence)+2))
        # One tick per label: recent matplotlib versions reject a label list
        # whose length differs from the number of fixed ticks.
        ax.set_yticks(range(len(y_labels)))

        ax.set_ylim(len(result)-1.5, -0.5)

        ax.set_xticklabels(x_labels, fontdict=fontdict, rotation=90)
        ax.set_yticklabels(y_labels, fontdict=fontdict)

        ax.set_xlabel('Head {}'.format(head+1))

    plt.tight_layout()
    plt.show()

In [0]:
def translate(sentence, plot='', top_p=0.9, temp=0.7, verbose=True, direction_1to2=True, real_translation=''):
    """Translate a sentence, optionally printing it and plotting attention.

    Args:
        sentence: source sentence as a Python string.
        plot: key of the attention block to plot (for example
            'decoder_layer4_block2'); an empty string disables plotting.
        top_p: nucleus threshold forwarded to `evaluate`.
        temp: sampling temperature forwarded to `evaluate`.
        verbose: when True, print the input, the prediction and, if given,
            the reference translation.
        direction_1to2: True translates lang1 -> lang2, False the reverse.
        real_translation: optional reference translation, only printed.

    Returns:
        The predicted sentence as a string.
    """
    # Tokenizer of the *target* language for the requested direction.
    tok_lang2 = tokenizer_lang2 if direction_1to2 else tokenizer_lang1

    result, attention_weights = evaluate(sentence, direction_1to2=direction_1to2, top_p=top_p, temp=temp)

    # Drop the start/end ids (>= vocab_size). The bound must come from the
    # target tokenizer of this direction: using tokenizer_lang2 for the
    # reverse direction filters with the wrong vocabulary size.
    predicted_sentence = ''.join([tok_lang2.decode([i]) for i in result
                                  if i < tok_lang2.vocab_size])

    if verbose:
        print('Input: {}'.format(sentence))
        print('Predicted translation: {}'.format(predicted_sentence))
        if len(real_translation) > 0:
            print('Real translation: {}'.format(real_translation))

    if plot:
        plot_attention_weights(attention_weights, sentence, result, plot, direction_1to2=direction_1to2)

    return predicted_sentence

You can pass different layers and attention blocks of the decoder to the `plot` parameter.

### Testing with random verses from Corneille / Racine

In [0]:
# Train the lang2 -> lang1 model for 10 epochs; a checkpoint is saved after every epoch.
training(epochs = 10, direction_1to2=False)

In [0]:
# Translate five randomly drawn lang2 lines back to lang1 and print each
# prediction next to the corpus line at the same index.
# No random seed is set, so the drawn lines differ from run to run.
for i in range(5):
    rand_idx = np.random.randint(len(corpus_lang1))
    translate(corpus_lang2[rand_idx], top_p=0.8, temp=0.6, verbose=True, direction_1to2=False, real_translation=corpus_lang1[rand_idx])

### Testing with text