In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers.experimental import preprocessing

In [4]:
data = np.random.randint(0, 100, (100, 5))

In [7]:
class Encoder(tf.keras.layers.Layer):
    def __init__(self, input_vocab_size, embedding_dim, enc_units):
        super(Encoder, self).__init__()
        self.enc_units = enc_units
        self.input_vocab_size = input_vocab_size

        # The embedding layer converts tokens to vectors
        self.embedding = tf.keras.layers.Embedding(self.input_vocab_size,
                                                   embedding_dim)

        # The GRU RNN layer processes those vectors sequentially.
        self.gru = tf.keras.layers.GRU(self.enc_units,
                                       # Return the sequence and state
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')

    def call(self, tokens, state=None):
#         shape_checker = ShapeChecker()
#         shape_checker(tokens, ('batch', 's'))

        # 2. The embedding layer looks up the embedding for each token.
        vectors = self.embedding(tokens)
#         shape_checker(vectors, ('batch', 's', 'embed_dim'))

        # 3. The GRU processes the embedding sequence.
        #    output shape: (batch, s, enc_units)
        #    state shape: (batch, enc_units)
        output, state = self.gru(vectors, initial_state=state)
#         shape_checker(output, ('batch', 's', 'enc_units'))
#         shape_checker(state, ('batch', 'enc_units'))

        # 4. Returns the new sequence and its state.
        return output, state

In [11]:
embedding_dim = 256
units = 1024

encoder = Encoder(100, 256, 1024)
ex_encode_output, ex_encode_state = encoder(data)

In [13]:
ex_encode_output.shape, ex_encode_state.shape

(TensorShape([100, 5, 1024]), TensorShape([100, 1024]))

In [14]:
class BahdanauAttention(tf.keras.layers.Layer):
    def __init__(self, units):
        super().__init__()
        # For Eqn. (4), the  Bahdanau attention
        self.W1 = tf.keras.layers.Dense(units, use_bias=False)
        self.W2 = tf.keras.layers.Dense(units, use_bias=False)

        self.attention = tf.keras.layers.AdditiveAttention()

    def call(self, query, value, mask):
#         shape_checker = ShapeChecker()
#         shape_checker(query, ('batch', 't', 'query_units'))
#         shape_checker(value, ('batch', 's', 'value_units'))
#         shape_checker(mask, ('batch', 's'))

        # From Eqn. (4), `W1@ht`.
        w1_query = self.W1(query)
#         shape_checker(w1_query, ('batch', 't', 'attn_units'))

        # From Eqn. (4), `W2@hs`.
        w2_key = self.W2(value)
#         shape_checker(w2_key, ('batch', 's', 'attn_units'))

        query_mask = tf.ones(tf.shape(query)[:-1], dtype=bool)
        value_mask = mask

        context_vector, attention_weights = self.attention(
            inputs = [w1_query, value, w2_key],
            mask=[query_mask, value_mask],
            return_attention_scores = True,
        )
#         shape_checker(context_vector, ('batch', 't', 'value_units'))
#         shape_checker(attention_weights, ('batch', 't', 's'))

        return context_vector, attention_weights

In [15]:
attention_layer = BahdanauAttention(units)

In [17]:
(data != 0).shape

(100, 5)

In [19]:
# Later, the decoder will generate this attention query
example_attention_query = tf.random.normal(shape=[len(data), 2, 10])

# Attend to the encoded tokens

context_vector, attention_weights = attention_layer(
    query=example_attention_query,
    value=ex_encode_output,
    mask=(data != 0))

print(f'Attention result shape: (batch_size, query_seq_length, units):           {context_vector.shape}')
print(f'Attention weights shape: (batch_size, query_seq_length, value_seq_length): {attention_weights.shape}')

Attention result shape: (batch_size, query_seq_length, units):           (100, 2, 1024)
Attention weights shape: (batch_size, query_seq_length, value_seq_length): (100, 2, 5)


In [20]:
attention_weights.shape

TensorShape([100, 2, 5])

In [21]:
class Decoder(tf.keras.layers.Layer):
    def __init__(self, output_vocab_size, embedding_dim, dec_units):
        super(Decoder, self).__init__()
        self.dec_units = dec_units
        self.output_vocab_size = output_vocab_size
        self.embedding_dim = embedding_dim

        # For Step 1. The embedding layer convets token IDs to vectors
        self.embedding = tf.keras.layers.Embedding(self.output_vocab_size,
                                                   embedding_dim)

        # For Step 2. The RNN keeps track of what's been generated so far.
        self.gru = tf.keras.layers.GRU(self.dec_units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')

        # For step 3. The RNN output will be the query for the attention layer.
        self.attention = BahdanauAttention(self.dec_units)

        # For step 4. Eqn. (3): converting `ct` to `at`
        self.Wc = tf.keras.layers.Dense(dec_units, activation=tf.math.tanh,
                                        use_bias=False)

        # For step 5. This fully connected layer produces the logits for each
        # output token.
        self.fc = tf.keras.layers.Dense(self.output_vocab_size)

In [24]:
import typing
from typing import Any, Tuple

class DecoderInput(typing.NamedTuple):
    new_tokens: Any
    enc_output: Any
    mask: Any

class DecoderOutput(typing.NamedTuple):
    logits: Any
    attention_weights: Any

In [43]:
def call(self,
         inputs: DecoderInput,
         state=None) -> Tuple[DecoderOutput, tf.Tensor]:
    vectors = self.embedding(inputs.new_tokens)
    
    rnn_output, state = self.gru(vectors, initial_state=state)
    print("Hello world")
    print(rnn_output)
    
    context_vector, attention_weights = self.attention(
      query=rnn_output, value=inputs.enc_output, mask=inputs.mask)
    context_and_rnn_output = context_vector
    attention_vector = self.Wc(context_and_rnn_output)
    logits = self.fc(attention_vector)
    return DecoderOutput(logits, attention_weights), state

In [44]:
Decoder.call = call

In [45]:
decoder = Decoder(100, embedding_dim, units)

In [46]:
first_token = tf.constant([[0]] * data.shape[0])

dec_result, dec_state = decoder(
    inputs = DecoderInput(new_tokens=first_token,
                          enc_output=ex_encode_output,
                          mask=(data != 0)),
    state = ex_encode_state
)

print(f'logits shape: (batch_size, t, output_vocab_size) {dec_result.logits.shape}')
print(f'state shape: (batch_size, dec_units) {dec_state.shape}')

Hello world
tf.Tensor(
[[[-6.67360332e-03  1.58940423e-02  7.53728114e-03 ...  5.91070578e-03
   -6.21128175e-03  1.41745375e-03]]

 [[-1.07842200e-02  1.24648772e-02  6.80857105e-04 ...  4.60236892e-03
    3.35511356e-03  1.41524838e-03]]

 [[-9.22174286e-03  7.77802430e-03  5.61453914e-03 ...  1.14401691e-02
   -4.61135665e-03 -1.67042250e-04]]

 ...

 [[-5.00767538e-03  9.74254124e-03  5.27418358e-03 ...  1.18234633e-02
    9.50882491e-03  2.69417092e-03]]

 [[-2.22809007e-03  6.96638180e-03  2.26637488e-03 ...  7.26535823e-03
   -2.24366202e-03 -5.79853076e-05]]

 [[-6.45158347e-03  1.34356488e-02  4.12107119e-03 ...  1.05279125e-02
   -3.18481121e-04 -2.19650287e-03]]], shape=(100, 1, 1024), dtype=float32)
logits shape: (batch_size, t, output_vocab_size) (100, 1, 100)
state shape: (batch_size, dec_units) (100, 1024)


In [47]:
sampled_token = tf.random.categorical(dec_result.logits[:, 0, :], num_samples=1)

In [50]:
sampled_token[:5]

<tf.Tensor: shape=(5, 1), dtype=int64, numpy=
array([[45],
       [ 6],
       [95],
       [36],
       [57]])>

In [51]:
class MaskedLoss(tf.keras.losses.Loss):
    def __init__(self):
        self.name = 'masked_loss'
        self.loss = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction='none')

    def __call__(self, y_true, y_pred):
#         shape_checker = ShapeChecker()
#         shape_checker(y_true, ('batch', 't'))
#         shape_checker(y_pred, ('batch', 't', 'logits'))

        # Calculate the loss for each item in the batch.
        loss = self.loss(y_true, y_pred)
#         shape_checker(loss, ('batch', 't'))

        # Mask off the losses on padding.
        mask = tf.cast(y_true != 0, tf.float32)
#         shape_checker(mask, ('batch', 't'))
        loss *= mask

        # Return the total.
        return tf.reduce_sum(loss)

In [66]:
class Recommendation(tf.keras.Model):
    def __init__(self, embedding_dim, units,
               input_text_processor,
               output_text_processor, 
               use_tf_function=True):
        super().__init__()
        # Build the encoder and decoder
        encoder = Encoder(5000, embedding_dim, units)
        decoder = Decoder(5000, embedding_dim, units)

        self.encoder = encoder
        self.decoder = decoder
        self.input_text_processor = input_text_processor
        self.output_text_processor = output_text_processor
        self.use_tf_function = use_tf_function
#         self.shape_checker = ShapeChecker()

    def train_step(self, inputs):
#         self.shape_checker = ShapeChecker()
        if self.use_tf_function:
            return self._tf_train_step(inputs)
        else:
            return self._train_step(inputs)

In [69]:
def _preprocess(self, input_tokens, target_tokens):

    # Convert the text to token IDs
#     input_tokens = self.input_text_processor(input_text)
#     target_tokens = self.output_text_processor(target_text)
#     self.shape_checker(input_tokens, ('batch', 's'))
#     self.shape_checker(target_tokens, ('batch', 't'))

    # Convert IDs to masks.
    input_mask = input_tokens != 0
#     self.shape_checker(input_mask, ('batch', 's'))

    target_mask = target_tokens != 0
#     self.shape_checker(target_mask, ('batch', 't'))

    return input_tokens, input_mask, target_tokens, target_mask

In [70]:
Recommendation._preprocess = _preprocess

In [77]:
def _train_step(self, inputs):
    input_text, target_text = inputs  

    (input_tokens, input_mask,
    target_tokens, target_mask) = self._preprocess(input_text, target_text)

    max_target_length = tf.shape(target_tokens)[1]

    with tf.GradientTape() as tape:
        # Encode the input
        enc_output, enc_state = self.encoder(input_tokens)
#         self.shape_checker(enc_output, ('batch', 's', 'enc_units'))
#         self.shape_checker(enc_state, ('batch', 'enc_units'))

        # Initialize the decoder's state to the encoder's final state.
        # This only works if the encoder and decoder have the same number of
        # units.
        dec_state = enc_state
        loss = tf.constant(0.0)

        for t in tf.range(max_target_length-1):
            new_tokens = target_tokens[:, t:t+2]
            step_loss, dec_state = self._loop_step(new_tokens, input_mask,
                                                 enc_output, dec_state)
            loss = loss + step_loss

        # Average the loss over all non padding tokens.
        average_loss = loss / tf.reduce_sum(tf.cast(target_mask, tf.float32))

    # Apply an optimization step
    variables = self.trainable_variables 
    gradients = tape.gradient(average_loss, variables)
    self.optimizer.apply_gradients(zip(gradients, variables))

    # Return a dict mapping metric names to current value
    return {'batch_loss': average_loss}

In [78]:
Recommendation._train_step = _train_step

In [73]:
def _loop_step(self, new_tokens, input_mask, enc_output, dec_state):
    input_token, target_token = new_tokens[:, 0:1], new_tokens[:, 1:2]

    # Run the decoder one step.
    decoder_input = DecoderInput(new_tokens=input_token,
                               enc_output=enc_output,
                               mask=input_mask)

    dec_result, dec_state = self.decoder(decoder_input, state=dec_state)
#     self.shape_checker(dec_result.logits, ('batch', 't1', 'logits'))
#     self.shape_checker(dec_result.attention_weights, ('batch', 't1', 's'))
#     self.shape_checker(dec_state, ('batch', 'dec_units'))

    # `self.loss` returns the total for non-padded tokens
    y = target_token
    y_pred = dec_result.logits
    step_loss = self.loss(y, y_pred)

    return step_loss, dec_state

In [74]:
Recommendation._loop_step = _loop_step

In [75]:
input_data = np.random.randint(0, 5000, size=(10000, 10))
target_data = np.random.randint(0, 5000, size=(10000, 10))


translator = Recommendation(
    embedding_dim, units,
    input_text_processor=None,
    output_text_processor=None,
    use_tf_function=False)

# Configure the loss and optimizer
translator.compile(
    optimizer=tf.optimizers.Adam(),
    loss=MaskedLoss(),
)

In [None]:
%%time
for n in range(10):
    print(translator.train_step([input_data, target_data]))
print()