In [1]:
import numpy  as np

import typing
from typing import Any, Tuple

import tensorflow as tf
import tensorflow_text as tf_text

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

# Retrieve data

In [2]:
import pathlib

# Data from http://www.manythings.org/anki/.
# NOTE(review): this is an absolute, machine-specific location; presumably it
# has to be adjusted before the notebook can run on another machine.
path_to_file = pathlib.Path("/tf/code/piplinjen/notebooks/deu-eng/deu.txt")

def load_data(path):
    """Read a tab-separated sentence-pair file and return its first two columns.

    Each non-blank line must contain at least two tab-separated fields. Any
    further fields (e.g. the attribution column of recent Anki exports) are
    ignored, so both the two-column and the three-column layouts are accepted.
    Blank or whitespace-only lines are skipped.

    Args:
        path: A ``pathlib.Path``-like object exposing ``read_text``.

    Returns:
        A ``(targ, inp)`` tuple of equally long lists of strings: ``targ``
        holds the first column and ``inp`` the second column of every line
        (English and German respectively for the deu-eng file used here).

    Raises:
        ValueError: If a non-blank line has fewer than two tab-separated fields.
    """
    text = path.read_text(encoding='utf-8')

    targ, inp = [], []
    # Split on '\n' only. ``read_text`` already normalizes '\r\n' and '\r',
    # whereas ``str.splitlines`` would additionally break on characters such
    # as U+2028 or U+0085 that can occur inside a sentence.
    for line_number, line in enumerate(text.split('\n'), start=1):
        if not line.strip():
            continue

        fields = line.split('\t')
        if len(fields) < 2:
            raise ValueError(
                f"Line {line_number} of {path} has {len(fields)} tab-separated "
                f"field(s); expected at least 2: {line!r}")

        targ.append(fields[0])
        inp.append(fields[1])

    return targ, inp

# Load both text columns of the corpus; `targ` and `inp` are reused by later cells.
targ, inp = load_data(path_to_file)

# Show the last sentence pair of the file as a quick sanity check.
print("Example data:")
for language_label, sentence_list in (("Deutsch: ", inp), ("English: ", targ)):
    print(language_label, sentence_list[-1])



Example data:
Deutsch:  Ohne Zweifel findet sich auf dieser Welt zu jedem Mann genau die richtige Ehefrau und umgekehrt; wenn man jedoch in Betracht zieht, dass ein Mensch nur Gelegenheit hat, mit ein paar hundert anderen bekannt zu sein, von denen ihm nur ein Dutzend oder weniger nahesteht, darunter höchstens ein oder zwei Freunde, dann erahnt man eingedenk der Millionen Einwohner dieser Welt leicht, dass seit Erschaffung ebenderselben wohl noch nie der richtige Mann der richtigen Frau begegnet ist.
English:  Doubtless there exists in this world precisely the right woman for any given man to marry and vice versa; but when you consider that a human being has the opportunity of being acquainted with only a few hundred people, and out of the few hundred that there are but a dozen or less whom he knows intimately, and out of the dozen, one or two friends at most, it will easily be seen, when we remember the number of millions who inhabit this world, that probably, since the earth was crea

In [3]:
# Build a shuffled, batched tf.data pipeline over the (input, target) string pairs.
BUFFER_SIZE = len(inp)  # shuffle buffer spans the whole corpus
BATCH_SIZE = 64

dataset = (
    tf.data.Dataset
    .from_tensor_slices((inp, targ))
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

# Peek at the first five pairs of a single batch. `take(1)` yields at most one
# batch, so the loop body runs once; `inp_batch` and `targ_batch` stay bound
# afterwards and are reused by later cells.
for inp_batch, targ_batch in dataset.take(1):
    print("Input: ", inp_batch[:5])
    print("Target: ", targ_batch[:5])

Input:  tf.Tensor(
[b'Lesen Sie diese Anweisungen.'
 b'Tom hat immer seine Unschuld bekr\xc3\xa4ftigt und behauptet bis heute, dass er verleumdet wurde.'
 b'Tom ist fettleibig.'
 b'Tom nahm mir das Versprechen ab, dort nicht wieder hinzugehen.'
 b'Ich lernte eifrig, als ich noch zur Schule ging.'], shape=(5,), dtype=string)
Target:  tf.Tensor(
[b'Read these instructions.'
 b'Tom has always maintained that he is innocent, and claims to this day that he was framed.'
 b'Tom is obese.' b'Tom made me promise not to go there again.'
 b'I studied hard when I was in school.'], shape=(5,), dtype=string)


# Text preprocessing

In [4]:
# Sample sentence containing an umlaut, a typographic apostrophe and an eszett.
example_text = tf.constant("Verkrümele dich! Mach ’ne Fliege! Mir ist heiß.")

# Raw UTF-8 bytes first, then the bytes after NFKD normalization.
print(example_text.numpy())
nfkd_example = tf_text.normalize_utf8(example_text, 'NFKD')
print(nfkd_example.numpy())

def tf_lower_and_split_punct(text):
    """Standardize raw text and wrap it in ``[START]`` / ``[END]`` markers.

    The text is NFKD-normalized and lowercased, 'ß' is rewritten to 'ss',
    every character other than space, a-z and ``. ? ! ,`` is removed, the
    kept punctuation marks are surrounded by spaces, and the result is
    stripped of leading/trailing whitespace before the markers are attached.

    Args:
        text: A string tensor (scalar or batch) to standardize.

    Returns:
        A string tensor holding the standardized text.
    """
    # Split accented characters into base letter + combining mark, then lowercase.
    cleaned = tf.strings.lower(tf_text.normalize_utf8(text, "NFKD"))

    # Ordered (pattern, rewrite) rules; the order matters because the second
    # rule would otherwise delete 'ß' before it can be replaced.
    rewrite_rules = (
        ('ß', 'ss'),                 # replace special characters
        ('[^ a-z.?!,]', ''),         # keep space, a to z, and select punctuation
        ('[.?!,]', r' \0 '),         # add spaces around punctuation
    )
    for pattern, rewrite in rewrite_rules:
        cleaned = tf.strings.regex_replace(cleaned, pattern, rewrite)

    trimmed = tf.strings.strip(cleaned)
    return tf.strings.join(['[START]', trimmed, '[END]'], separator=' ')

# Compare the original sentence with its standardized form.
print(example_text.numpy().decode())
standardized_example = tf_lower_and_split_punct(example_text)
print(standardized_example.numpy().decode())

b'Verkr\xc3\xbcmele dich! Mach \xe2\x80\x99ne Fliege! Mir ist hei\xc3\x9f.'
b'Verkru\xcc\x88mele dich! Mach \xe2\x80\x99ne Fliege! Mir ist hei\xc3\x9f.'
Verkrümele dich! Mach ’ne Fliege! Mir ist heiß.
[START] verkrumele dich !  mach ne fliege !  mir ist heiss . [END]


In [5]:
# Text vectorization: cap the vocabulary size and reuse the custom standardizer.
max_vocab_size = 5000

input_text_processor = tf.keras.layers.TextVectorization(
    max_tokens=max_vocab_size,
    standardize=tf_lower_and_split_punct,
)

In [6]:
# Fit the vectorization layer's vocabulary on the input-side sentences.
input_text_processor.adapt(inp)

# First 10 words from the vocabulary.
input_text_processor.get_vocabulary()[:10]

['', '[UNK]', '[START]', '[END]', '.', ',', 'ich', 'tom', '?', 'nicht']

In [7]:
# Target-side vectorizer: same standardization and vocabulary cap as the input side.
output_text_processor = tf.keras.layers.TextVectorization(
    max_tokens=max_vocab_size,
    standardize=tf_lower_and_split_punct,
)

# Fit on the target sentences, then display the first 10 vocabulary entries.
output_text_processor.adapt(targ)
output_text_processor.get_vocabulary()[:10]

['', '[UNK]', '[START]', '[END]', '.', 'tom', 'to', 'you', 'the', 'i']

In [8]:
# The vectorization layers turn a batch of strings into a batch of
# zero-padded token IDs.
example_tokens = input_text_processor(inp_batch)
first_example_ids = example_tokens[0]
print(first_example_ids)

# Map the IDs of the first example back to vocabulary entries; padding ID 0
# maps to the empty string, hence the trailing spaces in the joined output.
input_vocab = np.array(input_text_processor.get_vocabulary())
tokens = input_vocab[first_example_ids.numpy()]
print(' '.join(tokens))

tf.Tensor(
[   2  310   13   97 2464    4    3    0    0    0    0    0    0    0
    0    0    0    0], shape=(18,), dtype=int64)
[START] lesen sie diese anweisungen . [END]           


# Debug utils

In [9]:
# Useful class that enforces the right tensor dimensions.
class ShapeChecker():
    """Callable that validates tensor shapes against named axes.

    The first time an axis name is seen its length is remembered; every later
    tensor checked with the same name must have the same length on that axis.
    An integer in place of a name is treated as the required length itself.
    Checks are skipped entirely outside eager execution.
    """

    def __init__(self):
        # Maps axis name -> length recorded the first time the name was checked.
        self.shapes = {}

    def __call__(self, tensor, names, broadcast=False):
        """Check ``tensor`` against ``names``.

        Args:
            tensor: The tensor to validate.
            names: One axis name, or a sequence with one entry per axis; each
                entry is either a name (str) or a required length (int).
            broadcast: If true, axes of length 1 are accepted without being
                compared or recorded.

        Raises:
            ValueError: On a rank mismatch, or when an axis length differs
                from the expected one.
        """
        if not tf.executing_eagerly():
            return

        axis_names = (names,) if isinstance(names, str) else names

        actual_shape = tf.shape(tensor)
        actual_rank = tf.rank(tensor)
        expected_rank = len(axis_names)

        if actual_rank != expected_rank:
            raise ValueError(f"Rank mismatch:\n"
                             f"     found {actual_rank}: {actual_shape.numpy()}\n"
                             f"     expected {expected_rank}: {axis_names}\n")

        for axis, axis_name in enumerate(axis_names):
            observed = actual_shape[axis]

            # Broadcastable axes are always accepted.
            if broadcast and observed == 1:
                continue

            if isinstance(axis_name, int):
                expected = axis_name
            else:
                expected = self.shapes.get(axis_name, None)

            if expected is None:
                # New axis name: remember its length for later checks.
                self.shapes[axis_name] = observed
            elif observed != expected:
                raise ValueError(f"Shape mismatch for dimension: '{axis_name}'\n"
                                f"     found: {observed}\n"
                                f"     expected: {expected}\n")

# NMT Model

In [12]:
class Encoder(tf.keras.layers.Layer):
    """Embeds input token IDs and processes the sequence with a GRU."""

    def __init__(self, input_vocab_size, embedding_dim, enc_units):
        """Create the embedding and GRU sub-layers.

        Args:
            input_vocab_size: Number of rows of the embedding table.
            embedding_dim: Size of each token embedding vector.
            enc_units: Number of GRU units.
        """
        super().__init__()
        self.enc_units = enc_units
        self.input_vocab_size = input_vocab_size

        # Token ID -> dense vector lookup.
        self.embedding = tf.keras.layers.Embedding(
            input_dim=self.input_vocab_size,
            output_dim=embedding_dim,
        )

        # Sequential processing of the embedded tokens; returns both the
        # per-step outputs and the final state.
        self.gru = tf.keras.layers.GRU(
            self.enc_units,
            return_sequences=True,
            return_state=True,
            recurrent_initializer='glorot_uniform',
        )

    def call(self, tokens, state=None):
        """Encode a batch of token ID sequences.

        Args:
            tokens: Token IDs, checked as ('batch', 's').
            state: Optional initial GRU state.

        Returns:
            ``(output, state)``: the GRU output sequence, checked as
            ('batch', 's', 'enc_units'), and the final GRU state, checked as
            ('batch', 'enc_units').
        """
        check = ShapeChecker()
        check(tokens, ('batch', 's'))

        # Embedding lookup for every token.
        embedded = self.embedding(tokens)
        check(embedded, ('batch', 's', 'embed_dim'))

        # Run the GRU over the embedded sequence.
        sequence_output, final_state = self.gru(embedded, initial_state=state)
        check(sequence_output, ('batch', 's', 'enc_units'))
        check(final_state, ('batch', 'enc_units'))

        return sequence_output, final_state


In [13]:
# Example usage of the encoder.

# Model sizes shared by the encoder and decoder examples in this notebook.
embedding_dim = 256
units = 1024

example_tokens = input_text_processor(inp_batch)

encoder = Encoder(
    input_vocab_size=input_text_processor.vocabulary_size(),
    embedding_dim=embedding_dim,
    enc_units=units,
)
example_enc_output, example_enc_state = encoder(example_tokens)

# Report the shape at every stage: strings -> token IDs -> encoder output/state.
print(f'Input batch shape: (batch): {inp_batch.shape}')
print(f'Input batch tokens shape: (batch, s): {example_tokens.shape}')
print(f'Encoder output shape: (batch, s, units): {example_enc_output.shape}')
print(f'Encoder state shape: (batch, units): {example_enc_state.shape}')


Input batch shape: (batch): (64,)
Input batch tokens shape: (batch, s): (64, 18)
Encoder output shape: (batch, s, units): (64, 18, 1024)
Encoder state shape: (batch, units): (64, 1024)


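The attention layer first calculates the attention weights,
$$\alpha_{ts} = \frac{\exp\left(\text{score}(\mathbf{h}_t, \bar{\mathbf{h}}_s)\right)}{\sum_{s'}\exp\left(\text{score}(\mathbf{h}_t, \bar{\mathbf{h}}_{s'})\right)},$$
where the score is $\mathbf{v}_a^\top \tanh(\mathbf{W}_1 \mathbf{h}_t + \mathbf{W}_2 \bar{\mathbf{h}}_s)$, with $\mathbf{h}_t$ the decoder query at target step $t$ and $\bar{\mathbf{h}}_s$ the encoder output at source step $s$.
It then computes the context vector as the attention-weighted sum of the encoder outputs,
$$\mathbf{c}_t = \sum_s \alpha_{ts}\bar{\mathbf{h}}_s.$$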

In [14]:
class BahdanauAttention(tf.keras.layers.Layer):
    """Additive attention over the encoder output with learned projections."""

    def __init__(self, units):
        """Create the query/key projections and the additive attention layer.

        Args:
            units: Output size of both bias-free projection layers.
        """
        super().__init__()
        self.W1 = tf.keras.layers.Dense(units, use_bias=False)
        self.W2 = tf.keras.layers.Dense(units, use_bias=False)

        self.attention = tf.keras.layers.AdditiveAttention()

    def call(self, query, value, mask):
        """Attend over ``value`` using ``query``.

        Args:
            query: Generated by the decoder; checked as
                ('batch', 't', 'query_units').
            value: Output of the encoder; checked as
                ('batch', 's', 'value_units').
            mask: Excludes padding positions; checked as ('batch', 's').

        Returns:
            ``(context_vector, attention_weights)``, checked as
            ('batch', 't', 'value_units') and ('batch', 't', 's').
        """
        check = ShapeChecker()
        check(query, ('batch', 't', 'query_units'))
        check(value, ('batch', 's', 'value_units'))
        check(mask, ('batch', 's'))

        # Project the query (W1) and the keys (W2) into the attention space.
        projected_query = self.W1(query)
        check(projected_query, ('batch', 't', 'attn_units'))
        projected_key = self.W2(value)
        check(projected_key, ('batch', 's', 'attn_units'))

        # Every query position is valid; only the value side carries padding.
        attention_masks = [tf.ones(tf.shape(query)[:-1], dtype=bool), mask]

        context_vector, attention_weights = self.attention(
            inputs=[projected_query, value, projected_key],
            mask=attention_masks,
            return_attention_scores=True,
        )
        check(context_vector, ('batch', 't', 'value_units'))
        check(attention_weights, ('batch', 't', 's'))

        return context_vector, attention_weights

In [15]:
# Test of attention layer.
attention_layer = BahdanauAttention(units)
# Shape of the padding mask (token ID 0 is compared against) passed to the
# attention layer in the next cell.
(example_tokens != 0).shape

TensorShape([64, 18])

In [16]:
# Stand-in for the query the decoder will generate later: random values for
# 2 query steps of width 10 per batch element.
example_attention_query = tf.random.normal(shape=[len(example_tokens), 2, 10])
example_attention_mask = example_tokens != 0

context_vector, attention_weights = attention_layer(
    query=example_attention_query,
    value=example_enc_output,
    mask=example_attention_mask,
)

print(f"Attention result shape: (batch_size, query_seq_length, units):  {context_vector.shape}")
print(f"Attention weights shape: (batch_size, query_seq_length, value_seq_length):  {attention_weights.shape}")

Attention result shape: (batch_size, query_seq_length, units):  (64, 2, 1024)
Attention weights shape: (batch_size, query_seq_length, value_seq_length):  (64, 2, 18)


In [28]:
# The decoder will generate predictions for the next output token.
#   - The output of the encoder is used as input for the decoder.
#   - The RNN layer keeps track of what has been generated so far.
#   - The RNN output serves as the query to the attention over the encoder's output to produce the context vector.
#   - It combines the RNN output and the context vector to generate the attention vector.
#       attention vector, a_t = f(c_t, h_t) = tanh(W_c[c_t;h_t])
#   - It generates logit predictions for the next token based on the attention vector.

class DecoderInput(typing.NamedTuple):
    """Bundle of tensors passed as ``inputs`` to ``Decoder.call``."""
    # Token IDs fed to the decoder; shape-checked there as ('batch', 't').
    new_tokens: Any
    # Encoder output sequence; shape-checked as ('batch', 's', 'enc_units').
    enc_output: Any
    # Mask over the source positions; shape-checked as ('batch', 's').
    mask: Any

class DecoderOutput(typing.NamedTuple):
    """First element of the pair returned by ``Decoder.call``."""
    # Next-token logits; shape-checked as ('batch', 't', 'output_vocab_size').
    logits: Any
    # Attention weights over the source positions; shape-checked as ('batch', 't', 's').
    attention_weights: Any 

class Decoder(tf.keras.layers.Layer):
    """GRU decoder with attention over the encoder output.

    The GRU output is used as the attention query; the resulting context
    vector is concatenated with the GRU output and passed through a tanh
    projection (``a_t = tanh(W_c [c_t; h_t])``), from which the logits for the
    next token are computed.
    """

    def __init__(self, output_vocab_size, embedding_dim, dec_units):
        """Create the decoder's sub-layers.

        Args:
            output_vocab_size: Rows of the embedding table and number of logits.
            embedding_dim: Size of each token embedding vector.
            dec_units: GRU units; also the size of the attention projections
                and of the attention vector.
        """
        super().__init__()
        self.dec_units = dec_units
        self.output_vocab_size = output_vocab_size
        self.embedding_dim = embedding_dim

        self.embedding = tf.keras.layers.Embedding(
            input_dim=self.output_vocab_size,
            output_dim=embedding_dim,
        )
        self.gru = tf.keras.layers.GRU(
            self.dec_units,
            return_sequences=True,
            return_state=True,
            recurrent_initializer='glorot_uniform',
        )

        self.attention = BahdanauAttention(self.dec_units)
        # W_c: turns [context; rnn output] into the attention vector.
        self.Wc = tf.keras.layers.Dense(dec_units, activation=tf.math.tanh, use_bias=False)
        # Final projection to vocabulary logits.
        self.fc = tf.keras.layers.Dense(self.output_vocab_size)

    def call(self, inputs: DecoderInput, state=None) -> Tuple[DecoderOutput, tf.Tensor]:
        """Run the decoder over ``inputs.new_tokens``.

        Args:
            inputs: A ``DecoderInput`` holding the new tokens, the encoder
                output and the source mask.
            state: Optional initial GRU state, checked as ('batch', 'dec_units').

        Returns:
            ``(DecoderOutput(logits, attention_weights), state)`` where
            ``state`` is the GRU state after processing the new tokens.
        """
        check = ShapeChecker()
        check(inputs.new_tokens, ('batch', 't'))
        check(inputs.enc_output, ('batch', 's', 'enc_units'))
        check(inputs.mask, ('batch', 's'))
        if state is not None:
            check(state, ('batch', 'dec_units'))

        # Embed the new tokens.
        embedded = self.embedding(inputs.new_tokens)
        check(embedded, ('batch', 't', 'embedding_dim'))

        # Advance the RNN over the embedded tokens.
        rnn_output, new_state = self.gru(embedded, initial_state=state)
        check(rnn_output, ('batch', 't', 'dec_units'))
        check(new_state, ('batch', 'dec_units'))

        # The RNN output queries the encoder output through the attention layer.
        context_vector, attention_weights = self.attention(
            query=rnn_output,
            value=inputs.enc_output,
            mask=inputs.mask,
        )
        check(context_vector, ('batch', 't', 'dec_units'))
        check(attention_weights, ('batch', 't', 's'))

        # a_t = tanh(W_c [c_t; h_t])
        attention_vector = self.Wc(tf.concat([context_vector, rnn_output], axis=-1))
        check(attention_vector, ('batch', 't', 'dec_units'))

        logits = self.fc(attention_vector)
        check(logits, ('batch', 't', 'output_vocab_size'))

        return DecoderOutput(logits, attention_weights), new_state


In [29]:
decoder = Decoder(
    output_vocab_size=output_text_processor.vocabulary_size(),
    embedding_dim=embedding_dim,
    dec_units=units,
)

example_output_tokens = output_text_processor(targ_batch)

# One [START] token ID per sequence in the batch, shaped (batch, 1).
start_index = output_text_processor.get_vocabulary().index('[START]')
first_token = tf.constant([[start_index]] * example_output_tokens.shape[0])

# Run a single decoder step, initialized with the encoder's final state.
example_decoder_input = DecoderInput(
    new_tokens=first_token,
    enc_output=example_enc_output,
    mask=(example_tokens != 0),
)
dec_result, dec_state = decoder(inputs=example_decoder_input, state=example_enc_state)

print(f"logits shape: (batch_size, t, output_vocab_size) {dec_result.logits.shape}")
print(f"state shape: (batch_size, dec_units) {dec_state.shape}")

logits shape: (batch_size, t, output_vocab_size) (64, 1, 5000)
state shape: (batch_size, dec_units) (64, 1024)
