<a href="https://colab.research.google.com/github/laurelkeys/machine-learning/blob/master/assignment-4/Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transformer implementation
This notebook replicates the code (and explanations) from the TensorFlow tutorial [Transformer model for language understanding](https://www.tensorflow.org/tutorials/text/transformer), as a base to build upon.

In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals

try:
    %tensorflow_version 2.x
except Exception:
    pass
import tensorflow as tf

import time
import numpy as np
import matplotlib.pyplot as plt

TensorFlow 2.x selected.


In [0]:
import warnings

## Positional encoding

Since this model doesn't contain any recurrence or convolution, positional encoding is added to give the model some information about the relative position of the words in the sentence. 

The positional encoding vector is added to the embedding vector. Embeddings represent a token in a d-dimensional space where tokens with similar meaning will be closer to each other. But the embeddings do not encode the relative position of words in a sentence. So after adding the positional encoding, words will be closer to each other based on the *similarity of their meaning and their position in the sentence*, in the d-dimensional space.

See the notebook on [positional encoding](https://github.com/tensorflow/examples/blob/master/community/en/position_encoding.ipynb) to learn more about it. The formula for calculating the positional encoding is as follows:

${PE_{(pos,\,2i)} = \sin(pos / 10000^{2i / d_{model}})}\,,\;\;\;{PE_{(pos,\,2i+1)} = \cos(pos / 10000^{2i / d_{model}})}$

In [0]:
def get_angles(pos, i, d_model):
    print("[get_angles] pos", pos.shape)
    print("[get_angles] i", i.shape)
    angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
    return pos * angle_rates

In [0]:
def positional_encoding(position, d_model):
    angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                            np.arange(d_model)[np.newaxis, :],
                            d_model)
    print("[positional_encoding] angle_rads", angle_rads.shape)
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

    pos_encoding = angle_rads[np.newaxis, ...]
    return tf.cast(pos_encoding, dtype=tf.float32)
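
As a quick sanity check of the formula, the sketch below recomputes the encoding with plain NumPy (no TensorFlow, no debug prints). The helper name `positional_encoding_np` is hypothetical and not part of this notebook; it only mirrors the logic of `positional_encoding` above.

```python
import numpy as np

def positional_encoding_np(num_positions, d_model):
    # Hypothetical NumPy-only helper mirroring positional_encoding above.
    pos = np.arange(num_positions)[:, np.newaxis]   # (num_positions, 1)
    i = np.arange(d_model)[np.newaxis, :]           # (1, d_model)
    # Columns 2i and 2i+1 share the same rate 1 / 10000^(2i / d_model).
    angle_rates = 1.0 / np.power(10000.0, (2 * (i // 2)) / d_model)
    angles = pos * angle_rates                      # (num_positions, d_model)
    pe = np.empty_like(angles)
    pe[:, 0::2] = np.sin(angles[:, 0::2])           # even columns: sine
    pe[:, 1::2] = np.cos(angles[:, 1::2])           # odd columns: cosine
    return pe

pe = positional_encoding_np(num_positions=50, d_model=8)
print(pe.shape)   # (50, 8)
print(pe[0])      # position 0: sin(0) = 0 on even columns, cos(0) = 1 on odd columns
```

Position 0 alternates between 0 and 1, and the first column pair of position 1 holds `sin(1)` and `cos(1)`, which matches the formula for `i = 0`.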

## Masking


In [0]:
#    atari image    ==>   resized image   ==MobileNetV2==>  feature vector  ==GlobalAveragePooling2D==>   token
# (bs, 210, 160, 3) ==> (bs, 160, 160, 3) ==MobileNetV2==> (bs, 5, 5, 1280) ==GlobalAveragePooling2D==> (bs, 1280)

# (bs, input_seq_len, d_model) ==Transformer==> ...
# bs            : batch_size
# input_seq_len : 1000 for training (could be less if we want to use it as a policy network, predicting actions on-the-fly)
# d_model       : 1280 (MobileNetV2 output with GlobalAveragePooling2D)

Mask all the pad tokens in the batch of sequences. This ensures that the model does not treat padding as input. The mask marks where the pad value `PAD_VALUE` is present: it is `1` at those positions and `0` elsewhere.

In [0]:
PAD_VALUE = 0 # FIXME

def create_padding_mask(seq):
    seq = tf.cast(tf.math.equal(seq, PAD_VALUE), tf.float32)

    # add extra dimensions to add the padding to the attention logits
    return seq[:, tf.newaxis, tf.newaxis, :] # (batch_size, 1, 1, seq_len)
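
To make the resulting shape concrete, here is a minimal NumPy sketch of the same idea on a small batch of token ids. The sample batch and the name `PAD` are made up for illustration, and a pad value of 0 is assumed to match `PAD_VALUE` above.

```python
import numpy as np

PAD = 0  # assumed pad value, matching PAD_VALUE above

seq = np.array([[7, 6, 0, 0, 1],
                [1, 2, 3, 0, 0],
                [0, 0, 0, 4, 5]])

# 1.0 where the token is padding, 0.0 elsewhere
mask = (seq == PAD).astype(np.float32)
# Two singleton axes so the mask broadcasts over heads and query positions
mask = mask[:, np.newaxis, np.newaxis, :]   # (batch_size, 1, 1, seq_len)

print(mask.shape)      # (3, 1, 1, 5)
print(mask[0, 0, 0])   # [0. 0. 1. 1. 0.]
```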

The look-ahead mask is used to mask the future tokens in a sequence. In other words, the mask indicates which entries should not be used.

This means that to predict the third token, only the first and second tokens will be used. Similarly, to predict the fourth token, only the first, second, and third tokens will be used, and so on.

In [0]:
def create_look_ahead_mask(size):
    mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return mask # (seq_len, seq_len)

In [8]:
create_look_ahead_mask(size=4)

<tf.Tensor: id=7, shape=(4, 4), dtype=float32, numpy=
array([[0., 1., 1., 1.],
       [0., 0., 1., 1.],
       [0., 0., 0., 1.],
       [0., 0., 0., 0.]], dtype=float32)>
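
The decoder's self-attention usually needs both masks at once. One common way to combine them is an element-wise maximum, so an entry is masked if either mask flags it. The NumPy sketch below assumes a pad value of 0 and uses a made-up target sequence.

```python
import numpy as np

size = 4
# Strictly upper-triangular ones: position i may not attend to positions j > i
look_ahead = np.triu(np.ones((size, size), dtype=np.float32), k=1)

# Example target sequence whose last token is padding (pad value 0 assumed)
target = np.array([[5, 3, 9, 0]])
padding = (target == 0).astype(np.float32)[:, np.newaxis, np.newaxis, :]  # (1, 1, 1, 4)

# An entry is masked if either mask flags it
combined = np.maximum(look_ahead, padding)  # (1, 1, 4, 4)
print(combined[0, 0])
```

In the combined mask the last column is 1 in every row, because the padded position is hidden from all queries, while the upper triangle still hides future tokens.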

## Scaled dot product attention
The attention function used by the transformer takes three inputs: Q (query), K (key), V (value). The attention output is computed as: ${\mathrm{Attention}(Q, K, V) = \mathrm{softmax}_k\left(\frac{QK^T}{\sqrt{d_k}}\right) V}$

The mask is multiplied by $-10^9$ and added to the scaled product of Q and K immediately before the softmax. The goal is to zero out the masked cells: large negative inputs to softmax produce outputs close to zero.
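
The effect of the additive mask can be checked on a single row of scores. This is a NumPy sketch with made-up logits; the local `softmax` helper is defined here only for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row maximum for numerical stability
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)

logits = np.array([[2.0, 1.0, 0.5, 3.0]])   # scaled scores for one query over four keys
mask = np.array([[0.0, 0.0, 1.0, 1.0]])     # 1 marks keys that must be ignored

weights = softmax(logits + mask * -1e9)
print(weights)   # the last two weights are effectively 0; the first two sum to 1
```

Note that the fourth key had the largest raw score, yet it receives no weight once masked.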

<img src="https://www.tensorflow.org/images/tutorials/transformer/scaled_attention.png" width="300" alt="scaled_dot_product_attention">

In [0]:
def scaled_dot_product_attention(q, k, v, mask=None):
    ''' Calculate Attention(q, k, v) and the attention weights.
        Args:
          q: query shape == (..., seq_len_q, depth)
          k: key shape   == (..., seq_len_k, depth)
          v: value shape == (..., seq_len_v, depth_v)
          mask: float tensor with shape broadcastable to (..., seq_len_q, seq_len_k)
        Returns:
          output, attention_weights '''
    print("[scaled_dot_product_attention]")
    print("- q", q.shape)
    print("- k", k.shape)
    print("- v", v.shape)
    
    # assert q.shape[:-2] == k.shape[:-2] == v.shape[:-2], "q, k, v must have matching leading dimensions"
    # assert k.shape[-2] == v.shape[-2], "k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v"

    matmul_qk = tf.matmul(q, k, transpose_b=True) # (..., seq_len_q, seq_len_k)

    # scale matmul_qk
    d_k = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(d_k)

    # add the (padding or look ahead) mask to the scaled tensor
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)

    # softmax is normalized on the last axis (seq_len_k) so that the scores add up to 1
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) # (..., seq_len_q, seq_len_k)
    
    output = tf.matmul(attention_weights, v) # (..., seq_len_q, depth_v)
    
    return output, attention_weights

Because the softmax is normalized over the key axis, each query's attention weights over the keys sum to 1, so the keys determine how much importance each value receives for a given query.

The output is the product of the attention weights and the V (value) matrix, i.e. a weighted sum of the values. Tokens with weights close to 1 are kept almost as-is, and tokens with weights close to 0 are suppressed.

### Example

In [0]:
def print_out(q, k, v):
    temp_out, temp_attn = scaled_dot_product_attention(q, k, v, None)
    print('Attention weights are:', temp_attn)
    print('Output is:            ', temp_out)
np.set_printoptions(suppress=True)

In [0]:
# def print_out_shape(q, k, v):
#     temp_out, temp_attn = scaled_dot_product_attention(q, k, v, None)
#     print('Attention weights are:', temp_attn.shape)
#     print('Output is:            ', temp_out.shape)

# print_out_shape(
#     q=tf.random.uniform((64, 10, 24 // 8)), # (batch_size, seq_len, d_model_q)
#     k=tf.random.uniform((64, 10, 1280 // 8)), # (batch_size, seq_len, d_model)
#     v=tf.random.uniform((64, 10, 1280 // 8))) # (batch_size, seq_len, d_model)

In [0]:
temp_k = tf.constant([[10,0,0], [0,10,0], [0,0,10], [0,0,10]], dtype=tf.float32) # (4, 3)

temp_v = tf.constant([[   1,0], [  10,0], [ 100,5], [1000,6]], dtype=tf.float32) # (4, 2)

In [13]:
# This query aligns with the second key, so the second value is returned
temp_q = tf.constant([[0, 10, 0]], dtype=tf.float32) # (1, 3)
print_out(temp_q, temp_k, temp_v)

[scaled_dot_product_attention]
- q (1, 3)
- k (4, 3)
- v (4, 2)
Attention weights are: tf.Tensor([[0. 1. 0. 0.]], shape=(1, 4), dtype=float32)
Output is:             tf.Tensor([[10.  0.]], shape=(1, 2), dtype=float32)


In [14]:
# This query aligns with a repeated key (third and fourth), so all associated values get averaged
temp_q = tf.constant([[0, 0, 10]], dtype=tf.float32) # (1, 3)
print_out(temp_q, temp_k, temp_v)

[scaled_dot_product_attention]
- q (1, 3)
- k (4, 3)
- v (4, 2)
Attention weights are: tf.Tensor([[0.  0.  0.5 0.5]], shape=(1, 4), dtype=float32)
Output is:             tf.Tensor([[550.    5.5]], shape=(1, 2), dtype=float32)


In [15]:
# This query aligns equally with the first and second key, so their values get averaged
temp_q = tf.constant([[10, 10, 0]], dtype=tf.float32) # (1, 3)
print_out(temp_q, temp_k, temp_v)

[scaled_dot_product_attention]
- q (1, 3)
- k (4, 3)
- v (4, 2)
Attention weights are: tf.Tensor([[0.5 0.5 0.  0. ]], shape=(1, 4), dtype=float32)
Output is:             tf.Tensor([[5.5 0. ]], shape=(1, 2), dtype=float32)


In [16]:
# Pass all the queries together
temp_q = tf.constant([[0, 0, 10], [0, 10, 0], [10, 10, 0]], dtype=tf.float32) # (3, 3)
print_out(temp_q, temp_k, temp_v)

[scaled_dot_product_attention]
- q (3, 3)
- k (4, 3)
- v (4, 2)
Attention weights are: tf.Tensor(
[[0.  0.  0.5 0.5]
 [0.  1.  0.  0. ]
 [0.5 0.5 0.  0. ]], shape=(3, 4), dtype=float32)
Output is:             tf.Tensor(
[[550.    5.5]
 [ 10.    0. ]
 [  5.5   0. ]], shape=(3, 2), dtype=float32)


## Multi-head attention
Multi-head attention consists of four parts:
* Linear layers and split into heads
* Scaled dot-product attention
* Concatenation of heads
* Final linear layer

<img src="https://www.tensorflow.org/images/tutorials/transformer/multi_head_attention.png" width="300" alt="multi-head attention">

Each multi-head attention block gets three inputs: Q (query), K (key), and V (value). These are put through linear (Dense) layers and split up into multiple heads.

The `scaled_dot_product_attention` defined above is applied to each head (broadcast for efficiency). An appropriate mask must be used in the attention step. The attention output for each head is then concatenated (using `tf.transpose` and `tf.reshape`) and put through a final `Dense` layer.

Q, K, and V are split into multiple heads, rather than using a single attention head, because this allows the model to jointly attend to information at different positions from different representational spaces. After the split, each head has a reduced dimensionality (`d_model // num_heads`), so the total computation cost is about the same as single-head attention with full dimensionality.
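
The split and the later concatenation are pure reshapes and transposes, so they can be sketched in NumPy without any learned weights. The sizes below are arbitrary and chosen only for illustration.

```python
import numpy as np

batch_size, seq_len, d_model, num_heads = 2, 5, 24, 8
depth = d_model // num_heads   # 3 features per head

x = np.arange(batch_size * seq_len * d_model, dtype=np.float32)
x = x.reshape(batch_size, seq_len, d_model)

# Split: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, depth)
heads = x.reshape(batch_size, seq_len, num_heads, depth).transpose(0, 2, 1, 3)
print(heads.shape)   # (2, 8, 5, 3)

# Merge: undo the transpose, then flatten (num_heads, depth) back into d_model
merged = heads.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, d_model)
print(np.array_equal(merged, x))   # True
```

Each head sees a contiguous slice of `depth` features for every position, and merging the heads recovers the original layout exactly.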

In [0]:
class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads

        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        # NOTE I think this is d_model_q (and not d_model)
        self.dense = tf.keras.layers.Dense(d_model)
        
    def split_heads(self, x, batch_size):
        ''' Split the last dimension into (num_heads, depth).
            Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth) '''
        print("[split_heads] batch_size", batch_size)
        print("[split_heads] x", x)
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        print("[split_heads] x (reshape)", x)
        return tf.transpose(x, perm=[0, 2, 1, 3])
    
    def call(self, v, k, q, mask=None):
        print("[MultiHeadAttention.call]")
        print("- v", v.shape)
        print("- k", k.shape)
        print("- q", q.shape)
        batch_size = tf.shape(q)[0]
        
        q = self.wq(q) # (batch_size, seq_len, d_model)
        k = self.wk(k) # (batch_size, seq_len, d_model)
        v = self.wv(v) # (batch_size, seq_len, d_model)

        # obs.: seq_len_k == seq_len_v
        q = self.split_heads(q, batch_size)   # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)   # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)   # (batch_size, num_heads, seq_len_v, depth)

        # attention.shape         == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        attention, attention_weights = scaled_dot_product_attention(q, k, v, mask)

        print("- scaled_dot_product_attention", attention.shape, attention_weights.shape)

        attention = tf.transpose(attention, perm=[0, 2, 1, 3]) # (batch_size, seq_len_q, num_heads, depth)
        # NOTE I think this is d_model_q (and not d_model)
        print("attention (transposed)", attention.shape)
        concat_attention = tf.reshape(attention, (batch_size, -1, self.d_model)) # (batch_size, seq_len_q, d_model)
        print("concat_attention", concat_attention.shape)
        output = self.dense(concat_attention) # (batch_size, seq_len_q, d_model)
        print("output", output.shape)

        return output, attention_weights

In [18]:
sample_multiheadattention_layer = MultiHeadAttention(24, 8)

x_sample = tf.random.uniform((64, 100, 24))
sample_multiheadattention_layer_output, _ = sample_multiheadattention_layer(
    x_sample, x_sample, x_sample, None)

sample_multiheadattention_layer_output.shape # (batch_size, input_seq_len, d_model)

[MultiHeadAttention.call]
- v (64, 100, 24)
- k (64, 100, 24)
- q (64, 100, 24)
[split_heads] batch_size tf.Tensor(64, shape=(), dtype=int32)

TensorShape([64, 100, 24])

In [19]:
sample_multiheadattention_layer = MultiHeadAttention(24, 8)

x_sample = tf.random.uniform((64, 100, 24))
y_sample = tf.random.uniform((64, 100, 1280))
sample_multiheadattention_layer_output, _ = sample_multiheadattention_layer(
    y_sample, y_sample, x_sample, None)

sample_multiheadattention_layer_output.shape # (batch_size, input_seq_len, d_model)

[MultiHeadAttention.call]
- v (64, 100, 1280)
- k (64, 100, 1280)
- q (64, 100, 24)
[split_heads] batch_size tf.Tensor(64, shape=(), dtype=int32)

TensorShape([64, 100, 24])

## Point wise feed forward network
The point wise feed forward network consists of two fully-connected layers with a ReLU activation in between, applied to each position independently.

In [0]:
def point_wise_feed_forward_network(d_model, dff):
    return tf.keras.Sequential([
        tf.keras.layers.Dense(dff, activation='relu'), # (batch_size, seq_len, dff)
        tf.keras.layers.Dense(d_model) # (batch_size, seq_len, d_model)
    ])
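
"Point wise" means the same two layers are applied to every position separately. The NumPy sketch below uses random weights as stand-ins for the two `Dense` layers (an assumption for illustration, not trained values) and checks that processing one position alone gives the same result as processing the whole batch.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, seq_len, d_model, dff = 2, 4, 8, 16

# Randomly initialised weights stand in for the two Dense layers
w1, b1 = rng.normal(size=(d_model, dff)), np.zeros(dff)
w2, b2 = rng.normal(size=(dff, d_model)), np.zeros(d_model)

def ffn(x):
    hidden = np.maximum(x @ w1 + b1, 0.0)   # Dense(dff) + ReLU
    return hidden @ w2 + b2                 # Dense(d_model)

x = rng.normal(size=(batch_size, seq_len, d_model))
out = ffn(x)            # all positions at once
single = ffn(x[0, 2])   # one position on its own

print(out.shape)                        # (2, 4, 8)
print(np.allclose(out[0, 2], single))   # True
```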

## Encoder and decoder
* The input sentence is passed through N encoder layers that generate an output for each token in the sequence.
* The decoder attends to the encoder's output and its own input (self-attention) to predict the next token.

<img src="https://www.tensorflow.org/images/tutorials/transformer/transformer.png" width="450" alt="transformer">

### Encoder layer
Each encoder layer consists of sublayers:
1. Multi-head attention 
2. Point wise feed forward networks. 

Each of these sublayers has a residual connection around it followed by a layer normalization. Residual connections help in avoiding the vanishing gradient problem in deep networks.

The output of each sublayer is `LayerNorm(x + Sublayer(x))`. The normalization is done on the `d_model` (last) axis. There are N encoder layers in the transformer.
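
A minimal NumPy sketch of `LayerNorm(x + Sublayer(x))` follows. It omits the learned scale and offset that `tf.keras.layers.LayerNormalization` also applies, and it uses random arrays as stand-ins for the input and the sublayer output.

```python
import numpy as np

def layer_norm(x, epsilon=1e-6):
    # Normalise over the last (d_model) axis; learned scale and offset are omitted
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 8))              # (batch_size, seq_len, d_model)
sublayer_out = rng.normal(size=(2, 4, 8))   # stand-in for attention or FFN output

out = layer_norm(x + sublayer_out)          # LayerNorm(x + Sublayer(x))
print(out.shape)                            # (2, 4, 8)
print(np.allclose(out.mean(axis=-1), 0.0))  # True
```

Every position ends up with zero mean and roughly unit variance across its `d_model` features, independent of the other positions in the batch.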

In [0]:
class EncoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(EncoderLayer, self).__init__()

        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        # x.shape == (batch_size, input_seq_len, d_model)
        attn_output, _ = self.mha(x, x, x, mask) # (batch_size, input_seq_len, d_model)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output) # (batch_size, input_seq_len, d_model)

        ffn_output = self.ffn(out1) # (batch_size, input_seq_len, d_model)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layernorm2(out1 + ffn_output) # (batch_size, input_seq_len, d_model)
        
        return out2 # shape same as x

In [22]:
sample_encoder_layer = EncoderLayer(1280, 8, 256)

sample_encoder_layer_output = sample_encoder_layer(
    tf.random.uniform((32, 100, 1280)), False, None)

sample_encoder_layer_output.shape # (batch_size, input_seq_len, d_model)

[MultiHeadAttention.call]
- v (32, 100, 1280)
- k (32, 100, 1280)
- q (32, 100, 1280)
[split_heads] batch_size tf.Tensor(32, shape=(), dtype=int32)

TensorShape([32, 100, 1280])

### Decoder layer
Each decoder layer consists of sublayers:
1. Masked multi-head attention (with look ahead mask and padding mask)
2. Multi-head attention (with padding mask)  
   V (value) and K (key) receive the *encoder output* as inputs  
   Q (query) receives the *output from the masked multi-head attention sublayer*
3. Point wise feed forward networks

Each of these sublayers has a residual connection around it followed by a layer normalization. The output of each sublayer is `LayerNorm(x + Sublayer(x))`. The normalization is done on the `d_model` (last) axis.

There are N decoder layers in the transformer.

Since Q receives the output from the decoder's first attention block, and K receives the encoder output, the attention weights represent the importance given to each encoder position for each decoder position. In other words, the decoder predicts the next token by looking at the encoder output and self-attending to its own output.
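
The shapes involved in this second attention block can be illustrated with a single-head NumPy sketch. It assumes q, k, and v have already been projected to a common `depth` (the `Dense` projections are left out), and all values are random placeholders.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)

rng = np.random.default_rng(0)
target_seq_len, input_seq_len, depth = 3, 5, 4

q = rng.normal(size=(target_seq_len, depth))   # from the decoder's first attention block
k = rng.normal(size=(input_seq_len, depth))    # from the encoder output
v = rng.normal(size=(input_seq_len, depth))    # from the encoder output

weights = softmax(q @ k.T / np.sqrt(depth))    # (target_seq_len, input_seq_len)
context = weights @ v                          # (target_seq_len, depth)

print(weights.shape, context.shape)   # (3, 5) (3, 4)
print(weights.sum(axis=-1))           # each row sums to 1
```

Each row of `weights` belongs to one decoder position and spreads a total weight of 1 over the encoder positions.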

In [0]:
# d_model == 1280 (comes from the encoder, i.e. images' feature vectors)
# d_model_q == 18 (comes from the previous decoder output, i.e. actions)
class DecoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1, d_model_q=None):
        super(DecoderLayer, self).__init__()

        if d_model_q != 18:
            warnings.warn(f"d_model_q must be 18 as we have actions from 0 to 17 (not {d_model_q})")

        self.mha1 = MultiHeadAttention(d_model_q, num_heads)
        self.mha2 = MultiHeadAttention(d_model_q, num_heads)

        self.ffn = point_wise_feed_forward_network(d_model_q, dff)

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)
        self.dropout3 = tf.keras.layers.Dropout(rate)    
    
    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        # x == (batch_size, target_seq_len, d_model_q)
        # enc_output.shape == (batch_size, input_seq_len, d_model)
        print("[DecoderLayer.call]")
        print("- x", x.shape)
        print("- enc_output", enc_output.shape)

        attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask) # (batch_size, target_seq_len, d_model_q)
        print("- attn1", attn1.shape)
        attn1 = self.dropout1(attn1, training=training)
        out1 = self.layernorm1(attn1 + x)
        print("- out1", out1.shape)


        attn2, attn_weights_block2 = self.mha2(enc_output, enc_output, out1, padding_mask) # (batch_size, target_seq_len, d_model_q)
        print("- attn2", attn2.shape)
        attn2 = self.dropout2(attn2, training=training)
        out2 = self.layernorm2(attn2 + out1) # (batch_size, target_seq_len, d_model_q)
        print("- out2", out2.shape)

        ffn_output = self.ffn(out2) # (batch_size, target_seq_len, d_model_q)
        ffn_output = self.dropout3(ffn_output, training=training)
        out3 = self.layernorm3(ffn_output + out2) # (batch_size, target_seq_len, d_model_q)
        print("- out3", out3.shape)

        return out3, attn_weights_block1, attn_weights_block2

In [24]:
sample_decoder_layer = DecoderLayer(1280, 8, 256, d_model_q=24) # 24 since 18 isn't divisible by 8

sample_decoder_layer_output, _, _ = sample_decoder_layer(tf.random.uniform((32, 100, 24)), 
                                                         sample_encoder_layer_output, False, None, None)

sample_decoder_layer_output.shape # (batch_size, target_seq_len, d_model_q)

[DecoderLayer.call]
- x (32, 100, 24)
- enc_output (32, 100, 1280)
[MultiHeadAttention.call]
- v (32, 100, 24)
- k (32, 100, 24)
- q (32, 100, 24)
[split_heads] batch_size tf.Tensor(32, shape=(), dtype=int32)

TensorShape([32, 100, 24])

### Encoder
The `Encoder` consists of:
1. Input Embedding
2. Positional Encoding
3. N encoder layers

The input is put through an embedding which is summed with the positional encoding. The output of this summation is the input to the encoder layers. The output of the encoder is the input to the decoder.

In [0]:
# NOTE replacing Embedding inside the Encoder with MobileNetV2 + GlobalAveragePooling2D
# The output should have shape == (batch_size, input_seq_len, d_model), where d_model will be 1280 as we're using MobileNetV2
class FeatureExtractor(tf.keras.layers.Layer):
    def __init__(self):
        super(FeatureExtractor, self).__init__()

        self.mobile_net_v2 = tf.keras.applications.MobileNetV2(input_shape=(160, 160, 3), 
                                                               include_top=False, 
                                                               weights='imagenet')
        self.mobile_net_v2.trainable = False
        
        self.global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
        
    def call(self, x):
        # reshaping (batch_size, input_seq_len, 160, 160, 3) to (batch_size * input_seq_len, 160, 160, 3)
        print("[FeatureExtractor.call]")
        print("- x", x.shape)
        x_shape = tf.shape(x)
        x = tf.reshape(x, (-1, x_shape[2], x_shape[3], x_shape[4]))
        print("- x (reshaped)", x.shape)
        x = self.mobile_net_v2(x) # (batch_size * input_seq_len, 5, 5, 1280)
        print("- x (mobile_net_v2)", x.shape)

        x = self.global_average_layer(x) # (batch_size * input_seq_len, 1280)
        print("- x (global_average_layer)", x.shape)
        x = tf.reshape(x, (x_shape[0], x_shape[1], -1))
        print("- x (reshaped)", x.shape)
        return x # (batch_size, input_seq_len, d_model), with d_model == 1280
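
The reshaping around the pooling step can be sketched in NumPy. Random values stand in for the MobileNetV2 feature maps (they are not real features), and the check at the end confirms that frame `j` of sequence `i` ends up at `tokens[i, j]`.

```python
import numpy as np

batch_size, seq_len = 2, 10
# Stand-in for the MobileNetV2 feature maps of every frame (random values)
feature_maps = np.random.default_rng(0).normal(size=(batch_size * seq_len, 5, 5, 1280))

# GlobalAveragePooling2D: average over the two spatial axes
tokens = feature_maps.mean(axis=(1, 2))            # (batch_size * seq_len, 1280)

# Restore the (batch, sequence) layout expected by the encoder layers
tokens = tokens.reshape(batch_size, seq_len, -1)   # (batch_size, seq_len, 1280)
print(tokens.shape)   # (2, 10, 1280)

# Frame 3 of sequence 1 was row 1 * seq_len + 3 = 13 of the flattened batch
print(np.allclose(tokens[1, 3], feature_maps[13].mean(axis=(0, 1))))   # True
```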

In [26]:
sample_feature_extractor = FeatureExtractor()

temp_input = tf.random.uniform((2, 10, 160, 160, 3), dtype=tf.float32, minval=-1, maxval=1)

sample_feature_extractor_output = sample_feature_extractor(temp_input)

print(sample_feature_extractor_output.shape) # (batch_size, input_seq_len, d_model)

Downloading data from https://github.com/JonathanCMitchell/mobilenet_v2_keras/releases/download/v1.1/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
[FeatureExtractor.call]
- x (2, 10, 160, 160, 3)
- x (reshaped) (20, 160, 160, 3)
- x (mobile_net_v2) (20, 5, 5, 1280)
- x (global_average_layer) (20, 1280)
- x (reshaped) (2, 10, 1280)
(2, 10, 1280)


In [0]:
class Encoder(tf.keras.layers.Layer):
    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, 
                 maximum_position_encoding, rate=0.1):
        super(Encoder, self).__init__()

        assert d_model == 1280, "d_model must be 1280 as we're using MobileNetV2"
        warnings.warn(f"input_vocab_size is no longer used (its value of {input_vocab_size} will be ignored)")

        self.d_model = d_model
        self.num_layers = num_layers

        # NOTE replacing Embedding with MobileNetV2 + GlobalAveragePooling2D
        # > self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
        self.feature_extractor = FeatureExtractor()
        self.pos_encoding = positional_encoding(maximum_position_encoding, self.d_model)

        self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]

        self.dropout = tf.keras.layers.Dropout(rate)
        
    def call(self, x, training, mask=None):
        seq_len = tf.shape(x)[1]

        # adding embedding and position encoding        
        # NOTE replacing Embedding with MobileNetV2 + GlobalAveragePooling2D
        # > x = self.embedding(x) # (batch_size, input_seq_len, d_model)
        x = self.feature_extractor(x)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]

        x = self.dropout(x, training=training)

        for i in range(self.num_layers):
            x = self.enc_layers[i](x, training, mask)

        return x # (batch_size, input_seq_len, d_model)

In [28]:
sample_encoder = Encoder(num_layers=2, d_model=1280, num_heads=8, 
                         dff=256, input_vocab_size=None,
                         maximum_position_encoding=10000)

temp_input = tf.random.uniform((8, 10, 160, 160, 3), dtype=tf.float32, minval=-1, maxval=1)

sample_encoder_output = sample_encoder(temp_input, training=False, mask=None)

print(sample_encoder_output.shape) # (batch_size, input_seq_len, d_model)

[get_angles] pos (10000, 1)
[get_angles] i (1, 1280)
[positional_encoding] angle_rads (10000, 1280)
[FeatureExtractor.call]
- x (8, 10, 160, 160, 3)
- x (reshaped) (80, 160, 160, 3)
- x (mobile_net_v2) (80, 5, 5, 1280)
- x (global_average_layer) (80, 1280)
- x (reshaped) (8, 10, 1280)
[MultiHeadAttention.call]
- v (8, 10, 1280)
- k (8, 10, 1280)
- q (8, 10, 1280)
[split_heads] batch_size tf.Tensor(8, shape=(), dtype=int32)

### Decoder
The `Decoder` consists of:
1. Output Embedding
2. Positional Encoding
3. N decoder layers

The target is put through an embedding which is summed with the positional encoding. The output of this summation is the input to the decoder layers. The output of the decoder is the input to the final linear layer.

In [0]:
# d_model == 1280 (comes from the encoder, i.e. images' feature vectors)
# d_model_q == 18 (comes from the previous decoder output, i.e. actions)
class Decoder(tf.keras.layers.Layer):
    def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size,
                 maximum_position_encoding, rate=0.1, d_model_q=None):
        super(Decoder, self).__init__()

        if d_model_q != 18:
            warnings.warn(f"d_model_q must be 18 as we have actions from 0 to 17 (not {d_model_q})")
        self.d_model_q = d_model_q

        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model_q)
        self.pos_encoding = positional_encoding(maximum_position_encoding, d_model_q)

        self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate, d_model_q) for _ in range(num_layers)]

        self.dropout = tf.keras.layers.Dropout(rate)
    
    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        print("[Decoder.call]")
        seq_len = tf.shape(x)[1]
        attention_weights = {}

        print("- x", x.shape)
        x = self.embedding(x) # (batch_size, target_seq_len, d_model_q)
        print("- embedding", x.shape)
        x *= tf.math.sqrt(tf.cast(self.d_model_q, tf.float32)) # scale by the embedding size (d_model_q)
        x += self.pos_encoding[:, :seq_len, :]
        print("- pos_encoding", x.shape)

        x = self.dropout(x, training=training)

        for i in range(self.num_layers):
            x, block1, block2 = self.dec_layers[i](x, enc_output, training, look_ahead_mask, padding_mask)
            attention_weights['decoder_layer{}_block1'.format(i+1)] = block1
            attention_weights['decoder_layer{}_block2'.format(i+1)] = block2

        # x.shape == (batch_size, target_seq_len, d_model_q)
        return x, attention_weights

In [30]:
sample_decoder = Decoder(num_layers=2, d_model=1280, num_heads=8, 
                         dff=256, target_vocab_size=24, 
                         maximum_position_encoding=1000, 
                         d_model_q=24)
temp_input = tf.random.uniform((8, 10), dtype=tf.int64, minval=0, maxval=18)

output, attn = sample_decoder(temp_input, 
                              enc_output=sample_encoder_output, 
                              training=False,
                              look_ahead_mask=None, 
                              padding_mask=None)

output.shape, attn['decoder_layer2_block2'].shape

[get_angles] pos (1000, 1)
[get_angles] i (1, 24)
[positional_encoding] angle_rads (1000, 24)
[Decoder.call]
- x (8, 10)
- embedding (8, 10, 24)
- pos_encoding (8, 10, 24)
[DecoderLayer.call]
- x (8, 10, 24)
- enc_output (8, 10, 1280)
[MultiHeadAttention.call]
- v (8, 10, 24)
- k (8, 10, 24)
- q (8, 10, 24)
[split_heads] batch_size tf.Tensor(8, shape=(), dtype=int32)
[split_heads] x tf.Tensor(...)
[tensor values omitted; debug output truncated]

(TensorShape([8, 10, 24]), TensorShape([8, 8, 10, 10]))

## Create the Transformer
The Transformer consists of the encoder, the decoder, and a final linear layer. The decoder output is the input to the linear layer, and the linear layer's output is returned.
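
As a rough illustration of that last step, the sketch below uses NumPy only and random stand-in values; it is not the Keras layer itself. It treats the final linear layer as a matrix product plus bias applied to the decoder output. The names `kernel`, `bias` and `predicted_ids` are made up for this example, and the sizes are the ones used in this notebook's sample cells.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, tar_seq_len = 1, 10
d_model_q, target_vocab_size = 24, 24   # values used in this notebook

# Stand-in for the decoder output (random values, shapes only).
dec_output = rng.normal(size=(batch_size, tar_seq_len, d_model_q))

# A dense layer without activation is a matrix product plus a bias.
kernel = rng.normal(size=(d_model_q, target_vocab_size))
bias = np.zeros(target_vocab_size)
logits = dec_output @ kernel + bias

# One score per action id at every target position.
predicted_ids = logits.argmax(axis=-1)
print(logits.shape, predicted_ids.shape)
```

The logits keep the batch and sequence axes, so every target position gets its own distribution over action ids.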

In [0]:
class Transformer(tf.keras.Model):
    def __init__(self, num_layers, d_model, num_heads, dff, 
                 input_vocab_size, target_vocab_size, pe_input, pe_target, rate=0.1, 
                 d_model_q=None):
        super(Transformer, self).__init__()

        self.encoder = Encoder(num_layers, d_model, num_heads, dff, input_vocab_size, pe_input, rate)
        self.decoder = Decoder(num_layers, d_model, num_heads, dff, target_vocab_size, pe_target, rate, d_model_q)
        self.final_layer = tf.keras.layers.Dense(target_vocab_size)
    
    def call(self, inp, tar, training, enc_padding_mask, look_ahead_mask, dec_padding_mask):
        enc_output = self.encoder(inp, training, enc_padding_mask) # (batch_size, inp_seq_len, d_model)

        # dec_output.shape == (batch_size, tar_seq_len, d_model_q)
        dec_output, attention_weights = self.decoder(tar, enc_output, training, look_ahead_mask, dec_padding_mask)

        final_output = self.final_layer(dec_output) # (batch_size, tar_seq_len, target_vocab_size)

        return final_output, attention_weights

In [32]:
sample_transformer = Transformer(
    num_layers=2, d_model=1280, num_heads=8, dff=256, 
    input_vocab_size=None, target_vocab_size=24, 
    pe_input=100, pe_target=100, # https://www.youtube.com/watch?v=HNtz05bhI1k
    d_model_q=24)

#temp_input = tf.random.uniform((64, 10, 160, 160, 3), dtype=tf.int64, minval=-1, maxval=1)
temp_input = tf.random.uniform((1, 10, 160, 160, 3), dtype=tf.float32, minval=-1, maxval=1)
temp_target = tf.random.uniform((1, 10), dtype=tf.int64, minval=0, maxval=18) # maxval is exclusive: actions 0..17

fn_out, _ = sample_transformer(temp_input, temp_target, training=False, 
                               enc_padding_mask=None, 
                               look_ahead_mask=None,
                               dec_padding_mask=None)

fn_out.shape # (batch_size, tar_seq_len, target_vocab_size)

[get_angles] pos (100, 1)
[get_angles] i (1, 1280)
[positional_encoding] angle_rads (100, 1280)
[get_angles] pos (100, 1)
[get_angles] i (1, 24)
[positional_encoding] angle_rads (100, 24)
[FeatureExtractor.call]
- x (1, 10, 160, 160, 3)
- x (reshaped) (10, 160, 160, 3)
- x (mobile_net_v2) (10, 5, 5, 1280)
- x (global_average_layer) (10, 1280)
- x (reshaped) (1, 10, 1280)
[MultiHeadAttention.call]
- v (1, 10, 1280)
- k (1, 10, 1280)
- q (1, 10, 1280)
[split_heads] batch_size tf.Tensor(1, shape=(), dtype=int32)
[split_heads] x tf.Tensor(...)
[tensor values omitted; debug output truncated]
output (1, 10, 1280)
[MultiHeadAttention.call]
- v (1, 10, 1280)
- k (1, 10, 1280)
- q (1, 10, 1280)
[split_heads] batch_size tf.Tensor(1, shape=(), dtype=int32)
[split_heads] x tf.Tensor(..., shape=(1, 10, 1280), dtype=float32)
[split_heads] x (reshape) tf.Tensor(...)
[tensor values omitted; debug output truncated]

TensorShape([1, 10, 24])

## Set hyperparameters

In [0]:
num_layers = 4
d_model = 1280
d_model_q = 24
dff = 512
#num_heads = 8
num_heads = 2 # FIXME fix asserts above to use 8 instead of 2

In [0]:
N_OF_INP_IMAGES = 1000
INP_IMAGE_SHAPE = (160, 160, 3) # Atari images are 210x160 RGB.. so you should resize this somewhere mate ;)
INP_IMAGE_SIZE  = np.prod(INP_IMAGE_SHAPE)

A total of 18 actions can be performed with the joystick: doing nothing, pressing the action button, going in one of 8 directions (up, down, left and right as well as the 4 diagonals) and going in any of these directions while pressing the button.
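
The count above can be checked by enumerating the actions. This is a minimal sketch: the action names and their ordering are an assumption (they follow the Arcade Learning Environment convention as commonly documented), and `DIRECTIONS`, `ACTIONS` and `ACTION_TO_ID` are names made up for this example. The notebook itself only refers to actions by index.

```python
# Assumption: names and ordering follow the Arcade Learning Environment
# convention; the notebook only uses the integer ids.
DIRECTIONS = ["UP", "RIGHT", "LEFT", "DOWN",
              "UPRIGHT", "UPLEFT", "DOWNRIGHT", "DOWNLEFT"]

ACTIONS = (
    ["NOOP", "FIRE"]                    # do nothing, press the button
    + DIRECTIONS                        # 8 directions without the button
    + [d + "FIRE" for d in DIRECTIONS]  # 8 directions with the button
)

ACTION_TO_ID = {name: idx for idx, name in enumerate(ACTIONS)}

print(len(ACTIONS))  # 1 + 1 + 8 + 8 = 18
print(ACTION_TO_ID["NOOP"], ACTION_TO_ID["DOWNLEFTFIRE"])  # 0 17
```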

In [0]:
input_vocab_size = None # NOTE this is no longer used
target_vocab_size = 24 # actions to take
dropout_rate = 0.1

pe_size = 100

## Optimizer
Use the Adam optimizer with a custom learning rate scheduler according to the formula in the [paper](https://arxiv.org/abs/1706.03762): ${lrate = d_{model}^{-0.5} \cdot min(step{\_}num^{-0.5},\, step{\_}num \cdot warmup{\_}steps^{-1.5})}$


In [0]:
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        # The optimizer may pass an integer step; rsqrt needs a float.
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

In [0]:
learning_rate = CustomSchedule(d_model)

optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
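
To get a feel for the schedule, the formula can be evaluated directly. The sketch below is a plain NumPy restatement of the formula, not the `CustomSchedule` class; the helper name `lrate` and the sampled steps are chosen for illustration. It uses `d_model = 1280` and `warmup_steps = 4000`, as in this notebook.

```python
import numpy as np

def lrate(step, d_model=1280, warmup_steps=4000):
    """NumPy version of the schedule formula (step must be >= 1)."""
    step = np.asarray(step, dtype=np.float64)
    return d_model ** -0.5 * np.minimum(step ** -0.5, step * warmup_steps ** -1.5)

steps = np.array([1, 1000, 4000, 16000])
rates = lrate(steps)
peak_step = steps[np.argmax(rates)]

for s, r in zip(steps, rates):
    print(f"step {s:>6}: lr = {r:.3e}")
print("largest sampled rate at step", peak_step)
```

The rate grows linearly during warmup, peaks where the two arguments of `min` meet (at `step == warmup_steps`), and then decays with the inverse square root of the step.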

## Loss and metrics
Since the target sequences are padded, it is important to apply a padding mask when calculating the loss.

In [0]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

In [0]:
def loss_function(real, pred):
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)

    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask

    return tf.reduce_mean(loss_)
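
The effect of the mask can be seen on a tiny hand-made example. This sketch uses NumPy and hypothetical per-token losses; `per_token_loss`, `mean_over_all` and `mean_over_real` are names invented here. It also shows an alternative normalization (dividing by the number of unmasked positions), which is not what `loss_function` does. Note that the mask treats id 0 as padding, so in this notebook positions whose target is action 0 are also excluded from the loss.

```python
import numpy as np

# Hypothetical per-token losses for one sequence of 5 targets.
per_token_loss = np.array([0.5, 1.0, 1.5, 2.0, 3.0])
real = np.array([3, 1, 4, 0, 0])   # id 0 is treated as padding

mask = (real != 0).astype(per_token_loss.dtype)   # 1 for real tokens, 0 for padding
masked = per_token_loss * mask

mean_over_all = masked.mean()                # same reduction as tf.reduce_mean
mean_over_real = masked.sum() / mask.sum()   # alternative: average over real tokens only

print(mean_over_all, mean_over_real)
```

With `reduce_mean`, masked positions still count in the denominator, so the reported loss shrinks as the share of padding grows.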

In [0]:
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

## Training and checkpointing


In [41]:
transformer = Transformer(num_layers, d_model, num_heads, dff,
                          input_vocab_size, target_vocab_size, 
                          pe_input=pe_size, 
                          pe_target=pe_size,
                          rate=dropout_rate,
                          d_model_q=d_model_q)

[get_angles] pos (100, 1)
[get_angles] i (1, 1280)
[positional_encoding] angle_rads (100, 1280)
[get_angles] pos (100, 1)
[get_angles] i (1, 24)
[positional_encoding] angle_rads (100, 24)

In [0]:
def create_masks(inp, tar):
    # Encoder padding mask
    enc_padding_mask = create_padding_mask(inp)

    # Used in the 2nd attention block in the decoder.
    # This padding mask is used to mask the encoder outputs.
    dec_padding_mask = create_padding_mask(inp)

    # Used in the 1st attention block in the decoder.
    # It is used to pad and mask future tokens in the input received by 
    # the decoder.
    look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
    dec_target_padding_mask = create_padding_mask(tar)
    combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

    return enc_padding_mask, combined_mask, dec_padding_mask
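
The combination of the two decoder masks can be illustrated with NumPy. The helpers below are stand-ins written for this sketch and assume the same conventions as the TensorFlow tutorial's `create_padding_mask` and `create_look_ahead_mask` (1 marks a position to hide, id 0 is padding); they are not the notebook's own definitions.

```python
import numpy as np

def padding_mask(seq):
    # 1.0 where the token id is 0 (padding); shape (batch, 1, 1, seq_len)
    return (seq == 0).astype(np.float32)[:, np.newaxis, np.newaxis, :]

def look_ahead_mask(size):
    # 1.0 above the diagonal: position i may not attend to positions j > i
    return np.triu(np.ones((size, size), dtype=np.float32), k=1)

tar = np.array([[5, 7, 2, 0]])   # one sequence, last token is padding

# Element-wise maximum hides a position if either mask hides it.
combined = np.maximum(padding_mask(tar), look_ahead_mask(tar.shape[1]))

print(combined.shape)   # (1, 1, 4, 4)
print(combined[0, 0])
```

Broadcasting expands the `(batch, 1, 1, seq_len)` padding mask against the `(seq_len, seq_len)` look-ahead mask, giving one `seq_len x seq_len` mask per sequence.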

Create the checkpoint path and the checkpoint manager. This will be used to save checkpoints every `n` epochs.

In [0]:
checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(transformer=transformer,
                           optimizer=optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print('Latest checkpoint restored!!')

### Setup input pipeline
Load data

In [0]:
import os

In [45]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [46]:
!ls drive/My\ Drive/unicamp/MC886/atari/data/

Breakout-v0_1000	      Breakout-v4_1000_actions.npy
Breakout-v0_1000_actions.npy  Pong-v0_1000
Breakout-v4_1000	      Pong-v0_1000_actions.npy


In [0]:
actions_taken = np.load('drive/My Drive/unicamp/MC886/atari/data/Breakout-v4_1000_actions.npy')
actions_taken = np.array(actions_taken, dtype='int8')

actions_data = np.array([actions_taken[:100] for _ in range(10)], dtype='int8')

In [0]:
def get_all_png(path):
    """Walk `path` and return (file names, full paths) of all PNG files."""
    files = []
    files_full = []
    for root, dirs, file_names in os.walk(path):
        # match on the extension only, not on '.png' anywhere in the name
        png_names = [name for name in file_names if name.lower().endswith('.png')]
        files.extend(png_names)
        files_full.extend([os.path.join(root, name) for name in png_names])
    return files, files_full

In [0]:
IMG_SIZE = 160 # All images will be resized to 160x160

def load_image(image_path):#, action):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, channels=3)
    image = tf.cast(image, tf.float32)
    image = (image/127.5) - 1
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    return image#, action
  
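def load_all_imgs(paths, actions):
    print(paths[0])
    # paths are strings but load_image returns floats, so the output dtype must be given
    imgs = tf.map_fn(load_image, paths, dtype=tf.float32)
    return imgs, actions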

In [50]:
_, img_name_vector = get_all_png('drive/My Drive/unicamp/MC886/atari/data/Breakout-v4_1000/')
print(len(img_name_vector))

loaded_imgs = [load_image(path) for path in img_name_vector[:100]]

img_name_all = [loaded_imgs for _ in range(10)]

1000


In [51]:
# Feel free to change batch_size according to your system configuration
image_dataset = tf.data.Dataset.from_tensor_slices((img_name_all, actions_data))

print(image_dataset.element_spec)

image_dataset = image_dataset.batch(2)
# .map(load_all_imgs, num_parallel_calls=tf.data.experimental.AUTOTUNE).batch(2)

(TensorSpec(shape=(100, 160, 160, 3), dtype=tf.float32, name=None), TensorSpec(shape=(100,), dtype=tf.int8, name=None))


In [52]:
print(actions_taken.shape)
print(len(img_name_vector))

(1000,)
1000


In [53]:
train_dataset = image_dataset

train_dataset.element_spec

(TensorSpec(shape=(None, 100, 160, 160, 3), dtype=tf.float32, name=None),
 TensorSpec(shape=(None, 100), dtype=tf.int8, name=None))

### Training

The target is divided into `tar_inp` and `tar_real`. `tar_inp` is passed as an input to the decoder. `tar_real` is that same input shifted by 1: at each location in `tar_inp`, `tar_real` contains the next token that should be predicted.

The transformer is an auto-regressive model: it makes predictions one part at a time, and uses its output so far to decide what to do next. 

During training, this example uses teacher forcing: the true output is passed to the next time step regardless of what the model predicts at the current time step.

As the transformer predicts each token, *self-attention* allows it to look at the previous tokens in the input sequence to better predict the next one.

To prevent the model from peeking at the expected output, the model uses a look-ahead mask.
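
The split into decoder input and prediction target can be sketched with NumPy on a made-up sequence of action ids (the values in `tar` are hypothetical). The slicing is the same as in `train_step`.

```python
import numpy as np

tar = np.array([[2, 3, 3, 1, 4]])   # hypothetical action ids, one sequence

tar_inp = tar[:, :-1]    # decoder input: everything except the last token
tar_real = tar[:, 1:]    # targets: everything except the first token

# With the look-ahead mask, step t only sees tar_inp[:t + 1].
for t in range(tar_inp.shape[1]):
    visible = tar_inp[0, :t + 1].tolist()
    print(f"step {t}: decoder sees {visible} -> target {tar_real[0, t]}")
```

Because the true tokens in `tar_inp` are fed in regardless of the model's predictions, all positions can be trained in a single forward pass.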

In [0]:
EPOCHS = 2

In [0]:
# The @tf.function trace-compiles train_step into a TF graph for faster
# execution. The function specializes to the precise shape of the argument
# tensors. To avoid re-tracing due to the variable sequence lengths or variable
# batch sizes (the last batch is smaller), use input_signature to specify
# more generic shapes.

train_step_signature = [
    tf.TensorSpec(shape=(None, None, 160, 160, 3), dtype=tf.float32),
    tf.TensorSpec(shape=(None, None), dtype=tf.int8),
]

_feature_extractor = FeatureExtractor()

@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
    tar_inp = tar[:, :-1]
    tar_real = tar[:, 1:]

    # > enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)
    inp_feat_vec = _feature_extractor(inp)
    print("[train_step]")
    print("- inp", inp.shape)
    print("- inp_feat_vec", inp_feat_vec.shape)

    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp_feat_vec, tar_inp)
    print("- enc_padding_mask", enc_padding_mask.shape)
    print("- combined_mask", combined_mask.shape)
    print("- dec_padding_mask", dec_padding_mask.shape)

    with tf.GradientTape() as tape:
        predictions, _ = transformer(inp, tar_inp, 
                                     True, 
                                     enc_padding_mask, 
                                     combined_mask, 
                                     dec_padding_mask)
        loss = loss_function(tar_real, predictions)

    gradients = tape.gradient(loss, transformer.trainable_variables)    
    optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

    train_loss(loss)
    train_accuracy(tar_real, predictions)

In [56]:
for (batch, (inp, tar)) in enumerate(train_dataset):
    print(inp.shape, tar.shape)
    print(_feature_extractor(inp).shape)
    train_step(inp, tar)
    break

(2, 100, 160, 160, 3) (2, 100)
[FeatureExtractor.call]
- x (2, 100, 160, 160, 3)
- x (reshaped) (200, 160, 160, 3)
- x (mobile_net_v2) (200, 5, 5, 1280)
- x (global_average_layer) (200, 1280)
- x (reshaped) (2, 100, 1280)
(2, 100, 1280)
[FeatureExtractor.call]
- x (None, None, 160, 160, 3)
- x (reshaped) (None, None, None, None)
- x (mobile_net_v2) (None, None, None, 1280)
- x (global_average_layer) (None, 1280)
- x (reshaped) (None, None, None)
[train_step]
- inp (None, None, 160, 160, 3)
- inp_feat_vec (None, None, None)
- enc_padding_mask (None, 1, 1, None, None)
- combined_mask (None, 1, None, None)
- dec_padding_mask (None, 1, 1, None, None)
[FeatureExtractor.call]
- x (None, None, 160, 160, 3)
- x (reshaped) (None, None, None, None)
- x (mobile_net_v2) (None, None, None, 1280)
- x (global_average_layer) (None, 1280)
- x (reshaped) (None, None, None)
[MultiHeadAttention.call]
- v (None, None, 1280)
- k (None, None, 1280)
- q (None, None, 1280)
[split_heads] batch_size Tensor("tran

ValueError: ignored

Atari frame images are used as the input language and Atari actions are the target language.

In [0]:
for epoch in range(EPOCHS):
    start = time.time()

    train_loss.reset_states()
    train_accuracy.reset_states()

    # inp -> observations, tar -> actions
    for (batch, (inp, tar)) in enumerate(train_dataset):
        print(inp.shape, tar.shape)
        train_step(inp, tar)

        if batch % 50 == 0:
            print('Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}'.format(
                  epoch + 1, batch, train_loss.result(), train_accuracy.result()))
        
    # Save a checkpoint every 5 epochs (never reached while EPOCHS = 2)
    if (epoch + 1) % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print('Saving checkpoint for epoch {} at {}'.format(epoch + 1,
                                                            ckpt_save_path))

    print('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1,
                                                        train_loss.result(),
                                                        train_accuracy.result()))

    print('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))

## Evaluate
TODO continue from [this section](https://www.tensorflow.org/tutorials/text/transformer#evaluate) onwards.