In [6]:
# https://www.tensorflow.org/text/tutorials/nmt_with_attention
import sys
import tensorflow as tf
import tensorflow_text as tf_text
sys.path.append('..')
tf.get_logger().setLevel('ERROR')

# Check that the GPU is working

In [2]:
# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
# Applied to every visible GPU; with no GPU this is a no-op (instead of the
# IndexError `physical_devices[0]` used to raise), and the next cell reports it.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

In [3]:
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0': raise SystemError('GPU device not found')
print('Found GPU at:', device_name)
!nvcc -V

Found GPU at: /device:GPU:0
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Nov_30_19:15:10_Pacific_Standard_Time_2020
Cuda compilation tools, release 11.2, V11.2.67
Build cuda_11.2.r11.2/compiler.29373293_0


# Import the dataset

In [None]:
from dataset_handler import create_dataset, remove_rare_chars
# Root directory of the dataset, passed to `create_dataset` below.
DATA_PATH = '../../Dataset/trdg'
# Font file the visualizer uses to render label text.
FONT_PATH = '../../Dataset/NomNaTong-Regular.ttf'
# NOTE(review): presumably the image size in pixels (tall and narrow) — confirm
# against the dataset. HEIGHT is not referenced by the cells below; WIDTH is
# used as the visualizer's text x-offset.
HEIGHT = 148
WIDTH = 32

## Load and remove records with rare characters

In [None]:
# Load image paths, their text labels and the character vocabulary.
# NOTE(review): `sim2tra=True` presumably converts simplified to traditional
# characters — confirm in dataset_handler.
img_paths, labels, vocabs = create_dataset(DATA_PATH, sim2tra=True)
# Drop records whose labels contain rare characters; the exact cutoff rule for
# `threshold` lives in dataset_handler.remove_rare_chars (TODO confirm < vs <=).
img_paths, labels, vocabs = remove_rare_chars(img_paths, labels, vocabs, threshold=3)
print('Number of images found:', len(img_paths))
print('Number of labels found:', len(labels))
print('Number of unique characters:', len(vocabs))
print('Characters present:', vocabs, sep='')

## Visualize the data

In [None]:
from visualizer import visualize_images_labels
# Plot images with their labels, drawing label text with FONT_PATH at x = WIDTH + 3
# (presumably just to the right of each image — confirm in visualizer).
visualize_images_labels(img_paths, labels, font_path=FONT_PATH, text_x=WIDTH + 3)

# Text preprocessing

In [8]:
def standardize_text(text):
    """Normalize label strings and wrap each one in [START]/[END] markers.

    Applies NFKD Unicode normalization, lower-casing and whitespace stripping,
    then joins ``'[START]'``, the text and ``'[END]'`` with single spaces so the
    markers come out as separate tokens under whitespace splitting.

    Args:
        text: string tensor (used as the ``standardize`` callable of a
            ``TextVectorization`` layer).

    Returns:
        String tensor of the same shape with the markers added.
    """
    text = tf_text.normalize_utf8(text, 'NFKD')
    text = tf.strings.lower(text)
    text = tf.strings.strip(text)
    # The default separator is '', which would fuse the markers with the text
    # ("[START]abc[END]") into a single token.
    text = tf.strings.join(['[START]', text, '[END]'], separator=' ')
    return text

In [None]:
# Vocabulary cap for the text processor; None keeps every distinct token seen.
# (The previous cell referenced an undefined `max_vocab_size`.)
max_vocab_size = None

# NOTE(review): TextVectorization splits on whitespace by default, so a label
# without spaces becomes one token; character-level OCR probably needs the labels
# split per character — confirm.
text_processor = tf.keras.layers.TextVectorization(
    standardize=standardize_text,
    max_tokens=max_vocab_size,
)
# Build the vocabulary from the labels loaded above (the tutorial's `targ` does
# not exist in this notebook).
text_processor.adapt(tf.constant(labels))
text_processor.get_vocabulary()[:10]

# Define the model

In [None]:
from tensorflow.keras.layers import Embedding, GRU, Dense, AdditiveAttention
from typing import Any, Tuple
import typing
# Size of the token embedding vectors in both the encoder and the decoder.
embedding_dim = 256
# GRU units shared by the encoder and decoder; they must match because the
# decoder state is initialized from the encoder's final state.
units = 1024

## Shape checker

In [None]:
class ShapeChecker:
    """Assert tensor shapes against symbolic axis names (eager mode only).

    A string axis name is bound to the length it has the first time it is
    checked; every later check using that name must see the same length.
    Integer entries in ``names`` are literal expected lengths. When not
    executing eagerly (e.g. while tracing a ``tf.function``) every check is a
    no-op.
    """

    def __init__(self):
        # axis name -> length recorded the first time that name was checked
        self.shapes = {}

    def __call__(self, tensor, names, broadcast=False):
        """Check ``tensor`` against ``names`` (one name or a sequence of names).

        With ``broadcast=True``, axes of length 1 are skipped without being
        compared or recorded.

        Raises:
            ValueError: on a rank mismatch, or when an axis length disagrees
                with its earlier binding or integer literal.
        """
        if not tf.executing_eagerly():
            return
        axis_names = (names,) if isinstance(names, str) else names

        actual_shape = tf.shape(tensor)
        actual_rank = tf.rank(tensor)
        if actual_rank != len(axis_names):
            raise ValueError(
                f'Rank mismatch:\n'
                f'\t found {actual_rank}: {actual_shape.numpy()}\n'
                f'\t expected {len(axis_names)}: {axis_names}\n'
            )

        for axis, axis_name in enumerate(axis_names):
            observed = actual_shape[axis]
            if broadcast and observed == 1:
                continue

            expected = axis_name if isinstance(axis_name, int) else self.shapes.get(axis_name)
            if expected is None:
                # First sighting of this name: remember its length.
                self.shapes[axis_name] = observed
            elif observed != expected:
                raise ValueError(
                    f"Shape mismatch for dimension: '{axis_name}'\n"
                    f"\t found: {observed}\n"
                    f"\t expected: {expected}\n"
                )

## The Encoder


In [None]:
class Encoder(tf.keras.layers.Layer):
    """Embed a batch of token IDs and run the embeddings through a GRU.

    Args:
        input_vocab_size: number of rows in the embedding table.
        embedding_dim: size of each embedding vector.
        enc_units: GRU width; also the size of the returned state.
    """

    def __init__(self, input_vocab_size, embedding_dim, enc_units):
        super().__init__()
        self.enc_units = enc_units
        self.input_vocab_size = input_vocab_size

        # Token ID -> vector lookup, followed by a GRU that returns both the
        # full output sequence and its final state.
        self.embedding = Embedding(input_vocab_size, embedding_dim)
        self.gru = GRU(
            enc_units,
            return_sequences=True,
            return_state=True,
            recurrent_initializer='glorot_uniform',
        )

    def call(self, tokens, state=None):
        """Encode ``tokens`` of shape ``(batch, s)``.

        Args:
            tokens: integer token IDs, ``(batch, s)``.
            state: optional initial GRU state.

        Returns:
            ``(output, state)``: per-step outputs ``(batch, s, enc_units)`` and
            the final GRU state ``(batch, enc_units)``.
        """
        check = ShapeChecker()
        check(tokens, ('batch', 's'))

        embedded = self.embedding(tokens)
        check(embedded, ('batch', 's', 'embed_dim'))

        sequence, final_state = self.gru(embedded, initial_state=state)
        check(sequence, ('batch', 's', 'enc_units'))
        check(final_state, ('batch', 'enc_units'))
        return sequence, final_state

## The Attention head 

![](https://www.tensorflow.org/text/tutorials/images/attention_equation_1.jpg)
![](https://www.tensorflow.org/text/tutorials/images/attention_equation_2.jpg)
![](https://www.tensorflow.org/text/tutorials/images/attention_equation_4.jpg)

In [None]:
class BahdanauAttention(tf.keras.Model):
    """Additive (Bahdanau) attention of a hidden state over a set of features.

    NOTE(review): the next cell redefines ``BahdanauAttention`` as a
    ``tf.keras.layers.Layer``, so this version is shadowed and unused by the rest
    of the notebook — consider deleting it. Zero-argument ``super()`` is used so
    that this class can still be constructed after the global name is rebound
    (``super(BahdanauAttention, self)`` would resolve to the other class).

    Args:
        units: width of the W1/W2 projections.
    """

    def __init__(self, units):
        super().__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, features, hidden):
        """Return ``(context_vector, attention_weights)``.

        Presumably ``features`` is ``(batch, num_features, feature_dim)`` (the
        original comments mention CNN-encoder output with 64 positions) and
        ``hidden`` is ``(batch, hidden_size)`` — confirm against the caller.
        The weights are softmax-normalized over axis 1, and the context vector
        is the weighted sum of ``features`` over that axis.
        """
        # Insert a length-1 axis at position 1 so the projected hidden state
        # broadcasts against every feature position.
        hidden_with_time_axis = tf.expand_dims(hidden, 1)

        attention_hidden_layer = tf.nn.tanh(
            self.W1(features) + self.W2(hidden_with_time_axis)
        )
        # One unnormalized score per feature position.
        score = self.V(attention_hidden_layer)
        attention_weights = tf.nn.softmax(score, axis=1)

        context_vector = tf.reduce_sum(attention_weights * features, axis=1)
        return context_vector, attention_weights

In [None]:
class BahdanauAttention(tf.keras.layers.Layer):
    """Bahdanau (additive) attention built on ``AdditiveAttention``.

    The query and the values are each projected by a bias-free Dense layer of
    size ``units`` (the ``W1 @ ht`` and ``W2 @ hs`` terms of Eqn. (4)). The
    projected values act as keys, and ``AdditiveAttention`` averages the raw
    ``value`` vectors with the resulting weights.

    Args:
        units: width of the W1/W2 projections.
    """

    def __init__(self, units):
        super().__init__()
        self.W1 = Dense(units, use_bias=False)
        self.W2 = Dense(units, use_bias=False)
        self.attention = AdditiveAttention()

    def call(self, query, value, mask):
        """Attend from ``query`` (batch, t, ·) over ``value`` (batch, s, ·).

        Args:
            query: query sequence, ``(batch, t, query_units)``.
            value: value sequence, ``(batch, s, value_units)``.
            mask: boolean mask for ``value``, ``(batch, s)``.

        Returns:
            ``(context_vector, attention_weights)`` with shapes
            ``(batch, t, value_units)`` and ``(batch, t, s)``.
        """
        check = ShapeChecker()
        check(query, ('batch', 't', 'query_units'))
        check(value, ('batch', 's', 'value_units'))
        check(mask, ('batch', 's'))

        projected_query = self.W1(query)
        check(projected_query, ('batch', 't', 'attn_units'))

        projected_keys = self.W2(value)
        check(projected_keys, ('batch', 's', 'attn_units'))

        # Every query position is valid; the value side uses the caller's mask.
        all_queries_valid = tf.ones(tf.shape(query)[:-1], dtype=bool)

        # AdditiveAttention takes inputs as [query, value, key].
        context_vector, attention_weights = self.attention(
            inputs=[projected_query, value, projected_keys],
            mask=[all_queries_valid, mask],
            return_attention_scores=True,
        )

        check(context_vector, ('batch', 't', 'value_units'))
        check(attention_weights, ('batch', 't', 's'))
        return context_vector, attention_weights

## The Decoder

![](https://www.tensorflow.org/text/tutorials/images/attention_equation_3.jpg)


In [None]:
class DecoderInput(typing.NamedTuple):
    """Arguments for a single ``Decoder`` call."""
    # Target token IDs fed to the decoder's embedding layer.
    new_tokens: Any
    # Encoder output sequence that the decoder attends over.
    enc_output: Any
    # Mask for ``enc_output``, forwarded to the attention layer.
    mask: Any

class DecoderOutput(typing.NamedTuple):
    """Results of a single ``Decoder`` call."""
    # Unnormalized scores over the output vocabulary.
    logits: Any
    # Attention weights over the encoder positions.
    attention_weights: Any

In [None]:
class Decoder(tf.keras.layers.Layer):
    """GRU decoder with Bahdanau attention over the encoder output.

    Args:
        output_vocab_size: number of output tokens (embedding rows and logits).
        embedding_dim: size of each token embedding.
        dec_units: GRU width, attention width and attention-vector width.
    """

    def __init__(self, output_vocab_size, embedding_dim, dec_units):
        super().__init__()
        self.dec_units = dec_units
        self.output_vocab_size = output_vocab_size
        self.embedding_dim = embedding_dim

        # Token IDs -> vectors.
        self.embedding = Embedding(output_vocab_size, embedding_dim)
        # Tracks what has been generated so far; its output is the attention query.
        self.gru = GRU(
            dec_units,
            return_sequences=True,
            return_state=True,
            recurrent_initializer='glorot_uniform',
        )
        self.attention = BahdanauAttention(dec_units)
        # Eqn. (3): a_t = tanh(W_c @ [c_t; h_t]).
        self.Wc = Dense(dec_units, activation=tf.math.tanh, use_bias=False)
        # Attention vector -> logits over the output vocabulary.
        self.fc = Dense(output_vocab_size)

    def call(self, inputs:DecoderInput, state=None) -> Tuple[DecoderOutput, tf.Tensor]:
        """Run the decoder over ``inputs.new_tokens``.

        Returns:
            ``(DecoderOutput(logits, attention_weights), state)`` where logits are
            ``(batch, t, output_vocab_size)``, attention weights are
            ``(batch, t, s)`` and the new GRU state is ``(batch, dec_units)``.
            The context vector's last axis is checked against ``dec_units``,
            so in eager mode this also requires ``enc_units == dec_units``.
        """
        check = ShapeChecker()
        check(inputs.new_tokens, ('batch', 't'))
        check(inputs.enc_output, ('batch', 's', 'enc_units'))
        check(inputs.mask, ('batch', 's'))
        if state is not None:
            check(state, ('batch', 'dec_units'))

        embedded = self.embedding(inputs.new_tokens)
        check(embedded, ('batch', 't', 'embedding_dim'))

        rnn_output, new_state = self.gru(embedded, initial_state=state)
        check(rnn_output, ('batch', 't', 'dec_units'))
        check(new_state, ('batch', 'dec_units'))

        # The RNN output is the query for attention over the encoder output.
        context, attn_weights = self.attention(
            query=rnn_output,
            value=inputs.enc_output,
            mask=inputs.mask,
        )
        check(context, ('batch', 't', 'dec_units'))
        check(attn_weights, ('batch', 't', 's'))

        # [c_t; h_t] -> attention vector (Eqn. (3)).
        attention_vector = self.Wc(tf.concat([context, rnn_output], axis=-1))
        check(attention_vector, ('batch', 't', 'dec_units'))

        logits = self.fc(attention_vector)
        check(logits, ('batch', 't', 'output_vocab_size'))
        return DecoderOutput(logits, attn_weights), new_state

# Training

## Implement the training step

In [None]:
class AttentionOCR(tf.keras.Model):
    """Attention encoder-decoder trained with a custom teacher-forcing step.

    Args:
        embedding_dim: embedding size for both encoder and decoder.
        units: GRU width for both encoder and decoder (they must match because
            the decoder state is initialized from the encoder state).
        input_text_processor: layer mapping input strings to token IDs; its
            ``vocabulary_size()`` sizes the encoder embedding.
        output_text_processor: same, for target strings and the decoder.

    ``train_step`` / ``test_step`` call ``self._compute_loss``, which is not
    defined in this class body; it must be attached to the class before
    training.
    """

    def __init__(
        self, embedding_dim, units,
        input_text_processor, output_text_processor,
    ):
        super().__init__()
        self.shape_checker = ShapeChecker()
        self.input_text_processor = input_text_processor
        self.output_text_processor = output_text_processor

        self.encoder = Encoder(
            input_text_processor.vocabulary_size(),
            embedding_dim, units
        )
        self.decoder = Decoder(
            output_text_processor.vocabulary_size(),
            embedding_dim, units
        )

    @tf.function
    def train_step(self, inputs):
        """One optimization step; returns ``{'loss': <_compute_loss value>}``."""
        with tf.GradientTape() as tape:
            # Bound-method call: `self` is supplied automatically, so it must not
            # be passed again (that raised a TypeError).
            average_loss = self._compute_loss(inputs)

        variables = self.trainable_variables
        gradients = tape.gradient(average_loss, variables)
        self.optimizer.apply_gradients(zip(gradients, variables))

        # Metric name -> current value; Keras shows these in the fit logs.
        return {'loss': average_loss}

    @tf.function
    def test_step(self, inputs):
        """Evaluation step; returns ``{'loss': <_compute_loss value>}``.

        Keras itself prefixes validation logs with ``val_``, so returning
        ``'loss'`` shows up as ``val_loss`` during ``fit`` (``'val_loss'``
        here would have become ``val_val_loss``).
        """
        average_loss = self._compute_loss(inputs)
        return {'loss': average_loss}

In [None]:
def _preprocess(self, inputs):
    input_text, target_text = inputs 
    self.shape_checker(input_text, ('batch',))
    self.shape_checker(target_text, ('batch',))

    # Convert the text to token IDs
    input_tokens = self.input_text_processor(input_text)
    target_tokens = self.output_text_processor(target_text)
    self.shape_checker(input_tokens, ('batch', 's'))
    self.shape_checker(target_tokens, ('batch', 't'))

    # Convert IDs to masks.
    input_mask = input_tokens != 0
    self.shape_checker(input_mask, ('batch', 's'))

    target_mask = target_tokens != 0
    self.shape_checker(target_mask, ('batch', 't'))
    return input_tokens, input_mask, target_tokens, target_mask

In [None]:
def _loop_step(self, new_tokens, input_mask, enc_output, dec_state):
    """Run the decoder for one teacher-forced step and score its prediction.

    Args:
        new_tokens: token-ID window; column 0 is fed to the decoder and
            column 1 is the token it should predict (presumably a
            ``(batch, 2)`` slice — confirm with the caller).
        input_mask: mask for ``enc_output``.
        enc_output: encoder output sequence to attend over.
        dec_state: decoder GRU state from the previous step.

    Returns:
        ``(step_loss, dec_state)``: the value of ``self.loss`` for this step and
        the updated decoder state.
    """
    decoder_tokens = new_tokens[:, 0:1]
    expected_next = new_tokens[:, 1:2]

    dec_result, dec_state = self.decoder(
        DecoderInput(new_tokens=decoder_tokens, enc_output=enc_output, mask=input_mask),
        state=dec_state,
    )
    self.shape_checker(dec_result.logits, ('batch', 't1', 'logits'))
    self.shape_checker(dec_result.attention_weights, ('batch', 't1', 's'))
    self.shape_checker(dec_state, ('batch', 'dec_units'))

    # `self.loss` is the loss passed to compile(); MaskedLoss returns a sum over
    # non-padded tokens.
    step_loss = self.loss(expected_next, dec_result.logits)
    return step_loss, dec_state

In [None]:
def _compute_loss(self, inputs):
    """Teacher-forced loss for one batch, averaged over the predicted tokens.

    Args:
        inputs: ``(input_text, target_text)`` batch, tokenized by ``_preprocess``.

    Returns:
        Scalar loss: the summed per-step ``self.loss`` values divided by the
        number of non-padding tokens in ``target_tokens[:, 1:]``.
    """
    input_tokens, input_mask, target_tokens, target_mask = self._preprocess(inputs)
    max_target_length = tf.shape(target_tokens)[1]

    enc_output, enc_state = self.encoder(input_tokens)
    self.shape_checker(enc_output, ('batch', 's', 'enc_units'))
    self.shape_checker(enc_state, ('batch', 'enc_units'))

    # Initialize the decoder's state to the encoder's final state.
    # This only works if the encoder and decoder have the same number of units.
    dec_state = enc_state
    loss = tf.constant(0.0)

    for t in tf.range(max_target_length - 1):
        # Pass in 2 tokens from the target sequence: the current decoder input
        # and the target for the decoder's next prediction.
        step_loss, dec_state = self._loop_step(
            target_tokens[:, t:t+2], input_mask,
            enc_output, dec_state
        )
        loss += step_loss

    # Each step's target is column 1 of its window, so only target_tokens[:, 1:]
    # is ever predicted. Assuming self.loss sums over non-padding targets (as
    # MaskedLoss does), normalize by exactly those tokens; counting column 0 too
    # would understate the per-token average.
    predicted_mask = tf.cast(target_mask[:, 1:], tf.float32)
    return loss / tf.reduce_sum(predicted_mask)

In [None]:
# Attach the helpers defined above as methods, each under its own name. The
# previous version stored _loop_step as `_preprocess` and _compute_loss as
# `_train_step`, so train_step's call to `self._compute_loss` could not resolve.
AttentionOCR._preprocess = _preprocess
AttentionOCR._loop_step = _loop_step
AttentionOCR._compute_loss = _compute_loss

## Define the loss function

In [None]:
class MaskedLoss(tf.keras.losses.Loss):
    """Sparse categorical cross-entropy summed over non-padding targets.

    Target ID 0 is treated as padding and contributes nothing. The result is a
    scalar *sum*, not a mean; callers normalize it themselves.
    """

    def __init__(self):
        # Initialize the Loss base class (sets `name`, `reduction`, ...); the
        # original skipped this, leaving base-class attributes unset.
        super().__init__(name='masked_loss')
        # Per-element losses from logits; masking and summation happen in __call__.
        self.loss = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True,
            reduction='none',
        )

    def __call__(self, y_true, y_pred):
        """Return the summed cross-entropy of ``y_pred`` logits vs. ``y_true`` IDs.

        Args:
            y_true: integer target IDs, ``(batch, t)``.
            y_pred: logits, ``(batch, t, vocab)``.
        """
        shape_checker = ShapeChecker()
        shape_checker(y_true, ('batch', 't'))
        shape_checker(y_pred, ('batch', 't', 'logits'))

        loss = self.loss(y_true, y_pred)
        shape_checker(loss, ('batch', 't'))

        # Zero out positions whose target is padding (ID 0).
        mask = tf.cast(y_true != 0, tf.float32)
        shape_checker(mask, ('batch', 't'))
        loss *= mask

        return tf.reduce_sum(loss)

## Train the model

In [None]:
class BatchLogs(tf.keras.callbacks.Callback):
    """Record one entry of the training logs after every batch.

    Args:
        key: name of the per-batch log entry to record (e.g. a key of the dict
            returned by the model's ``train_step``).

    Attributes:
        logs: list of recorded values, one per training batch.
    """

    def __init__(self, key):
        # Initialize the Callback base class so the attributes Keras expects
        # on callbacks (model, validation_data, ...) exist.
        super().__init__()
        self.key = key
        self.logs = []

    def on_train_batch_end(self, n, logs):
        # Raises KeyError if `key` is not among this batch's logs.
        self.logs.append(logs[self.key])

In [None]:
# NOTE(review): `input_text_processor`, `output_text_processor` and `dataset` are
# not created by the cells above (the preprocessing cell builds `text_processor`);
# define them before running this cell.
attention_ocr = AttentionOCR(
    embedding_dim, units,
    input_text_processor=input_text_processor,
    output_text_processor=output_text_processor,
)
attention_ocr.compile(optimizer=tf.optimizers.Adam(), loss=MaskedLoss())

# train_step reports its loss under 'loss' ('batch_loss' raised a KeyError on
# the first batch). Keep a handle so the per-batch losses can be plotted later.
batch_loss = BatchLogs('loss')
attention_ocr.fit(dataset, epochs=3, callbacks=[batch_loss])