In [None]:
# https://github.com/bai-shang/crnn_seq2seq_ocr_pytorch
# https://github.com/chenjun2hao/Attention_ocr.pytorch
# https://github.com/alleveenstra/attentionocr
# https://github.com/koibiki/CRNN-ATTENTION
import sys
import tensorflow as tf
sys.path.append('..')
tf.get_logger().setLevel('ERROR')
APPROACH_NAME = 'CNNxSeq2Seq'

# Check that the GPU is working

In [None]:
physical_devices = tf.config.list_physical_devices('GPU') 
tf.config.experimental.set_memory_growth(physical_devices[0], True)
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0': raise SystemError('GPU device not found')
print('Found GPU at:', device_name)
!nvcc -V

# Data input pipeline

In [None]:
DATASET_DIR = r'../../Datasets/Patches'
ALL_TRANSCRIPTS_PATH = f'{DATASET_DIR}/All.txt'
VALID_TRANSCRIPTS_PATH = f'{DATASET_DIR}/Validate.txt'
FONT_PATH = r'../../NomNaTong-Regular.ttf'

## Load and remove records with rare characters

In [None]:
from loader import DataImporter
dataset = DataImporter(DATASET_DIR, ALL_TRANSCRIPTS_PATH, min_length=1)
print(dataset)

## Data constants and input pipeline

In [None]:
HEIGHT, WIDTH = 432, 48
PADDING_CHAR = '[PAD]' 
START_CHAR = '[START]'
END_CHAR = '[END]' 

In [None]:
from loader import DataHandler
data_handler = DataHandler(
    dataset, 
    img_size = (HEIGHT, WIDTH), 
    padding_char = PADDING_CHAR,
    start_char = START_CHAR,
    end_char = END_CHAR
)

In [None]:
# Number of validation records: the size of a DataImporter built on the
# validation transcripts (the importer itself is only used for its `.size`)
NUM_VALIDATE = DataImporter(DATASET_DIR, VALID_TRANSCRIPTS_PATH, min_length=1).size

# Special tokens and vocabulary size are read from the data handler
START_TOKEN = data_handler.start_token
END_TOKEN = data_handler.end_token
VOCAB_SIZE = data_handler.char2num.vocab_size()

# Batch size passed to prepare_tf_dataset and to the evaluators below
BATCH_SIZE = 32

## Visualize the data

In [None]:
from visualizer import visualize_images_labels
visualize_images_labels(
    dataset.img_paths, 
    dataset.labels, 
    figsize = (15, 15),
    subplot_size = (2, 8),
    font_path = FONT_PATH
)

# Define model components

In [None]:
from tensorflow.keras.layers import (
    Input, Embedding, Dense, GRU, 
    Bidirectional, Concatenate, Flatten
)
from layers import custom_cnn, reshape_features, AdditiveAttention
EMBEDDING_DIM = 512
UNITS = 256

## The encoder

In [None]:
def Encoder(imagenet_model=None, imagenet_output_layer=None, name='Encoder'):
    """Build the image encoder: a CNN feature extractor followed by 2 stacked BiGRUs.

    Args:
        imagenet_model: Optional Keras model used as the CNN part. When it is
            falsy, the CNN is built by `custom_cnn` on a (HEIGHT, WIDTH, 3) input.
        imagenet_output_layer: Name of the layer of `imagenet_model` whose output
            is fed to the RNN. Only read when `imagenet_model` is given.
        name: Name of the returned Keras model.

    Returns:
        A `tf.keras.Model` taking the image input and producing 2 outputs: the
        output sequence of the second BiGRU, and the concatenation of its
        forward and backward final states.
    """
    if not imagenet_model:
        # No pretrained backbone => build the CNN layers from scratch
        image_input = Input(shape=(HEIGHT, WIDTH, 3), dtype='float32', name='image')

        # (num_conv, filters, pool_size) of each Conv block, in order.
        # Last Conv block: 2x2 kernel, without padding and without pooling layer
        block_specs = [
            (1,  64, (2, 2)),
            (1, 128, (2, 2)),
            (2, 256, (2, 2)),
            (2, 512, (2, 2)),
            (2, 512, None),
        ]
        conv_blocks_config = {
            f'block{block_no}': {'num_conv': num_conv, 'filters': filters, 'pool_size': pool_size}
            for block_no, (num_conv, filters, pool_size) in enumerate(block_specs, start=1)
        }
        cnn_output = custom_cnn(conv_blocks_config, image_input)
    else:
        # Use Imagenet model as CNN layers
        image_input = imagenet_model.input
        imagenet_model.layers[0]._name = 'image'
        cnn_output = imagenet_model.get_layer(imagenet_output_layer).output

    # Reshape accordingly before passing output to RNN
    rnn_input = reshape_features(cnn_output, dim_to_keep=1, name='rnn_input')

    # 1st BiGRU only returns its output sequence
    first_bigru = Bidirectional(GRU(UNITS, return_sequences=True), name='bigru1')
    first_seq = first_bigru(rnn_input)

    # 2nd BiGRU also returns the final state of each direction
    second_gru = GRU(units=UNITS, return_sequences=True, return_state=True)
    second_bigru = Bidirectional(second_gru, name='bigru2')
    encoded_seq, forward_state, backward_state = second_bigru(first_seq)

    # Concat states of 2 directions
    final_state = Concatenate(name='encoder_state')([forward_state, backward_state])
    return tf.keras.Model(inputs=image_input, outputs=[encoded_seq, final_state], name=name)

## The decoder

In [None]:
def Decoder(enc_seq_shape, name='Decoder'):
    """Build the decoder that processes one token per call, with attention.

    Args:
        enc_seq_shape: Shape (without the batch dimension) of the encoder's
            output sequence, used for the 'encoder_sequence' input.
        name: Name of the returned Keras model.

    Returns:
        A `tf.keras.Model` with inputs [new_token, encoder_sequence, previous_state]
        and outputs [y_pred, state, attention_weights], where `y_pred` comes from
        a Dense layer of VOCAB_SIZE units with no activation, `state` is the
        new GRU state and `attention_weights` is returned by `AdditiveAttention`.
    """
    # The encoder is Bidirectional, so its concatenated final state has UNITS * 2
    # features; the single decoder GRU uses that state as its initial state
    dec_units = UNITS * 2

    token_input = Input(shape=(1,), name='new_token')
    enc_seq_input = Input(shape=enc_seq_shape, name='encoder_sequence')
    pre_hidden_input = Input(shape=(dec_units,), name='previous_state')

    # Process one step with the RNN
    token_embedding = Embedding(VOCAB_SIZE, EMBEDDING_DIM)(token_input)
    dec_gru = GRU(units=dec_units, return_state=True, return_sequences=True, name='dec_gru')
    rnn_output, state = dec_gru(token_embedding, initial_state=pre_hidden_input)

    # The RNN output is the query for the attention over the encoder output
    attention_layer = AdditiveAttention(dec_units)
    context_vector, attention_weights = attention_layer(rnn_output, enc_seq_input)

    # Form the attention vector from the context vector and the RNN output
    context_and_rnn_output = tf.concat([context_vector, rnn_output], axis=-1)
    attention_dense = Dense(dec_units, activation='tanh', use_bias=False, name='attention_vector')
    attention_vector = attention_dense(context_and_rnn_output)

    # Generate predictions
    flat_attention_vector = Flatten()(attention_vector)
    y_pred = Dense(VOCAB_SIZE, name='prediction')(flat_attention_vector)

    model_inputs = [token_input, enc_seq_input, pre_hidden_input]
    model_outputs = [y_pred, state, attention_weights]
    return tf.keras.Model(inputs=model_inputs, outputs=model_outputs, name=name)

# Build the model

In [None]:
from models import get_imagenet_model, EncoderDecoderModel
imagenet_model, imagenet_output_layer = None, None
# # Pick a model from https://keras.io/api/applications
# imagenet_model = get_imagenet_model('VGG16', (HEIGHT, WIDTH, 3))
# imagenet_output_layer = 'block5_pool'
# imagenet_model.summary(line_length=100)

In [None]:
# Build the encoder first: the decoder's 'encoder_sequence' input takes its shape
# from the encoder's first output, with the leading dimension dropped by [1:]
encoder = Encoder(imagenet_model, imagenet_output_layer)
decoder = Decoder(encoder.outputs[0].shape[1:])

# Initialize both parts from the checkpoints stored under ./IHR-NomDB
encoder.load_weights(f'./IHR-NomDB/IHR-NomDB_{APPROACH_NAME}_enc.h5')

# Decoder weights are matched by layer name, and layers whose weights do not
# match the checkpoint are skipped instead of raising.
# NOTE(review): presumably needed because the vocabulary-dependent layers
# (embedding / prediction) differ in shape from the checkpoint - confirm.
# NOTE(review): the Embedding and AdditiveAttention layers in Decoder have no
# explicit name, so by-name matching relies on auto-generated names - verify.
decoder.load_weights(
    f'./IHR-NomDB/IHR-NomDB_{APPROACH_NAME}_dec.h5', 
    skip_mismatch = True,
    by_name = True,
)

In [None]:
model = EncoderDecoderModel(encoder, decoder, data_handler)
encoder.summary(line_length=120)
print()
decoder.summary(line_length=125)

# Training

In [None]:
# Positional split: the first records are used for training and the last
# NUM_VALIDATE records for validation.
# NOTE(review): assumes the validation records are the trailing entries of
# the full transcript file - confirm against how All.txt is assembled.
num_train = dataset.size - NUM_VALIDATE

# Derive the boundary from the count instead of train_idxs[-1], which raised an
# opaque IndexError whenever the training split was empty
if num_train <= 0:
    raise ValueError(
        f'No training samples left: dataset has {dataset.size} records '
        f'but {NUM_VALIDATE} are reserved for validation'
    )

train_idxs = list(range(num_train))
valid_idxs = list(range(num_train, dataset.size))
print('Number of training samples:', len(train_idxs))
print('Number of validate samples:', len(valid_idxs))

In [None]:
import random

# Always shuffle from the sorted order with a fixed seed, so re-running this
# cell gives the same training order instead of shuffling an already
# shuffled list (the first run is unchanged: train_idxs starts out sorted)
random.seed(2022)
train_idxs = sorted(train_idxs)
random.shuffle(train_idxs)

# When run on a small RAM machine, you can set use_cache=False to 
# not run out of memory but it will slow down the training speed
train_tf_dataset = data_handler.prepare_tf_dataset(train_idxs, BATCH_SIZE, drop_remainder=True)
valid_tf_dataset = data_handler.prepare_tf_dataset(valid_idxs, BATCH_SIZE, drop_remainder=True)

## Callbacks

In [None]:
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

early_stopping_callback = EarlyStopping(
    monitor = 'val_loss', 
    min_delta = 1e-3, # Change that less than 1e-3, will count as no improvement
    patience = 5, # Stop if no improvement after 5 epochs
    restore_best_weights = True, # Restore weights from the epoch with the best value
    verbose = 1
)

# Reduce the learning rate once learning stagnates
reduce_lr_callback = ReduceLROnPlateau(
    monitor = 'val_loss', 
    patience = 2, # Reduce if no improvement after 2 epochs
    min_lr = 1e-6, # Lower bound on the learning rate 
    factor = 0.5, # => new_lr = lr * factor
    verbose = 1
)

## Fine-tuning on the NomNaOCR dataset

In [None]:
from losses import MaskedLoss
from metrics import SequenceAccuracy, CharacterAccuracy, LevenshteinDistance
from tensorflow.keras.optimizers import Adam
LEARNING_RATE = 2e-4
EPOCHS = 100

In [None]:
model.compile(
    optimizer = Adam(LEARNING_RATE), 
    loss = MaskedLoss(), 
    metrics = [
        SequenceAccuracy(),
        CharacterAccuracy(),
        LevenshteinDistance(normalize=True, name='lev_distance')
    ]
)

In [None]:
%%time
history = model.fit(
    train_tf_dataset,
    validation_data = valid_tf_dataset,
    epochs = EPOCHS,
    callbacks = [reduce_lr_callback, early_stopping_callback],
    verbose = 1
).history

## Save the training results

In [None]:
best_epoch = early_stopping_callback.best_epoch
print(f'- Loss on validation\t: {history["val_loss"][best_epoch]}')
print(f'- Sequence accuracy\t: {history["val_seq_acc"][best_epoch]}')
print(f'- Character accuracy\t: {history["val_char_acc"][best_epoch]}')
print(f'- Levenshtein distance\t: {history["val_lev_distance"][best_epoch]}')

In [None]:
import os
from visualizer import plot_training_results

# The Inference section loads the fine-tuned weights from this folder, so save
# them there directly; otherwise a fresh Restart & Run All cannot find them
FINETUNE_DIR = 'Fine-tuning'
os.makedirs(FINETUNE_DIR, exist_ok=True)

plot_training_results(history, f'finetune_{APPROACH_NAME}.png')
model.encoder.save_weights(f'{FINETUNE_DIR}/finetune_{APPROACH_NAME}_enc.h5')
model.decoder.save_weights(f'{FINETUNE_DIR}/finetune_{APPROACH_NAME}_dec.h5')

# Inference

In [None]:
encoder = Encoder(imagenet_model, imagenet_output_layer)
decoder = Decoder(encoder.outputs[0].shape[1:])
encoder.load_weights(f'Fine-tuning/finetune_{APPROACH_NAME}_enc.h5')
decoder.load_weights(f'Fine-tuning/finetune_{APPROACH_NAME}_dec.h5')

In [None]:
reset_model = EncoderDecoderModel(encoder, decoder, data_handler)
reset_model.compile(
    optimizer = Adam(LEARNING_RATE), 
    loss = MaskedLoss(), 
    metrics = [
        SequenceAccuracy(),
        CharacterAccuracy(),
        LevenshteinDistance(normalize=True, name='lev_distance')
    ]
)

## On test dataset

In [None]:
# Predict the first batch of valid_tf_dataset, keep the results (with attention
# weights) in batch_results and show the images with true/predicted labels.
# Note: this iterates over the validation dataset built above.
batch_results = []
for idx, (batch_images, batch_tokens) in enumerate(valid_tf_dataset.take(1)):
    # Dataset indices of the samples in this batch, used to look up image paths.
    # NOTE(review): assumes valid_tf_dataset yields samples in the same order
    # as valid_idxs (no shuffling inside prepare_tf_dataset) - confirm
    idxs_in_batch = valid_idxs[idx * BATCH_SIZE: (idx + 1) * BATCH_SIZE]
    labels = data_handler.tokens2texts(batch_tokens)
    pred_tokens, attentions = reset_model.predict(batch_images, return_attention=True)
    pred_labels = data_handler.tokens2texts(pred_tokens)
    
    batch_results.append({'true': labels, 'pred': pred_labels, 'attentions': attentions})
    visualize_images_labels(
        img_paths = dataset.img_paths[idxs_in_batch], 
        labels = labels, 
        pred_labels = pred_labels,
        figsize = (11.6, 30),
        subplot_size = (4, 8),
        legend_loc = (3.8, 4.38),
        annotate_loc = (4, 2.75),
        font_path = FONT_PATH, 
    )
    # Text dump of the same batch, numbered from 1
    print(
        f'Batch {idx + 1:02d}:\n'
        f'- True: {dict(enumerate(labels, start=1))}\n'
        f'- Pred: {dict(enumerate(pred_labels, start=1))}\n'
    )

## On random image

In [None]:
random_path = '../囷𦝄苔惮󰞺𧍋𦬑囊.jpg'
random_label = '囷𦝄苔惮󰞺𧍋𦬑囊'
random_image = data_handler.process_image(random_path)
pred_tokens = reset_model.predict(tf.expand_dims(random_image, axis=0))
pred_labels = data_handler.tokens2texts(pred_tokens)

In [None]:
visualize_images_labels(
    img_paths = [random_path], 
    labels = [random_label], 
    pred_labels = pred_labels,
    figsize = (5, 4),
    subplot_size = (1, 1), 
    font_path = FONT_PATH, 
)
print('Predicted text:', ''.join(pred_labels))

# Detailed evaluation

In [None]:
import pandas as pd
from evaluator import Evaluator
GT10_TRANSCRIPTS_PATH = f'{DATASET_DIR}/Validate_gt10.txt'
LTE10_TRANSCRIPTS_PATH = f'{DATASET_DIR}/Validate_lte10.txt'

In [None]:
# One evaluator per transcript file: Validate_gt10.txt and Validate_lte10.txt
gt10_evaluator = Evaluator(reset_model, DATASET_DIR, GT10_TRANSCRIPTS_PATH)
lte10_evaluator = Evaluator(reset_model, DATASET_DIR, LTE10_TRANSCRIPTS_PATH)

# One row per evaluation: the full validation dataset, then each subset.
# NOTE(review): with drop_remainder=True the last partial batch is presumably
# left out of these metrics - confirm whether that is intended
df = pd.DataFrame([
    reset_model.evaluate(valid_tf_dataset, return_dict=True),
    gt10_evaluator.evaluate(data_handler, BATCH_SIZE, drop_remainder=True),
    lte10_evaluator.evaluate(data_handler, BATCH_SIZE, drop_remainder=True),
])
df.index = ['Full', 'Length > 10', 'Length ≤ 10']
df