In [1]:
# importing the necessary libraries

import tensorflow as tf
try:
    # Let TensorFlow grow GPU memory on demand instead of reserving it all at start-up.
    for gpu in tf.config.experimental.list_physical_devices("GPU"):
        tf.config.experimental.set_memory_growth(gpu, True)
except (RuntimeError, ValueError) as gpu_error:
    # RuntimeError is raised when the GPUs are already initialised (e.g. this cell is re-run
    # in a live kernel). Memory growth is only a tuning hint, so report it and keep going
    # instead of silently swallowing every possible exception.
    print(f"Could not enable GPU memory growth: {gpu_error}")

import os
import typing
import requests
import tarfile
from tqdm import tqdm
from datetime import datetime
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard
from keras import layers
from keras.models import Model
from mltu.preprocessors import ImageReader
from mltu.transformers import ImageResizer, LabelIndexer, LabelPadding, ImageShowCV2
from mltu.augmentors import RandomBrightness, RandomRotate, RandomErodeDilate, RandomSharpen
from mltu.annotations.images import CVImage
from mltu.tensorflow.dataProvider import DataProvider
from mltu.tensorflow.losses import CTCloss
from mltu.tensorflow.callbacks import Model2onnx, TrainLogger
from mltu.tensorflow.metrics import CERMetric, WERMetric
from mltu.configs import BaseModelConfigs




In [2]:
# authentication details for downloading the dataset
# Credentials are taken from the IAM_USER / IAM_PASSWORD environment variables, or prompted
# for interactively, so no secret is ever stored in the notebook or its version history.
import getpass

USER = os.environ.get("IAM_USER") or input("IAM database e-mail: ")
PASSWORD = os.environ.get("IAM_PASSWORD") or getpass.getpass("IAM database password: ")

In [3]:
# download the ASCII dataset and the sentences dataset

# URL of the files to download
ascii_url = "https://fki.tic.heia-fr.ch/DBs/iamDB/data/ascii.tgz"
sentences_url = "https://fki.tic.heia-fr.ch/DBs/iamDB/data/sentences.tgz"

# Connect/read timeout (seconds) for every request, so a stalled server cannot hang the cell.
REQUEST_TIMEOUT = 60


def _download_file(session, url, destination, label):
    """Stream ``url`` into the local file ``destination`` using ``session``.

    The body is written in chunks to ``destination + ".part"`` and renamed into place only
    after the whole response was received, so the archive is never held in memory in full
    and an interrupted transfer never leaves a truncated ``destination`` behind.

    Args:
        session: requests.Session carrying the login cookies.
        url: URL of the file to download.
        destination: local path to write.
        label: dataset name used in the status messages (e.g. "ASCII", "sentences").

    Returns:
        True if the file was downloaded, False if the server answered with a non-200 status.
    """
    with session.get(url, stream=True, timeout=REQUEST_TIMEOUT) as response:
        if response.status_code != 200:
            print(f"Failed to download the {label} dataset. Status code: {response.status_code}")
            return False
        partial_path = destination + ".part"
        with open(partial_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                file.write(chunk)
    os.replace(partial_path, destination)
    print(f"{label[:1].upper()}{label[1:]} dataset downloaded successfully.")
    return True


# Session to handle cookies; closed automatically once both downloads are done.
with requests.Session() as session:
    # Perform login
    login_url = "https://fki.tic.heia-fr.ch/login"
    login_payload = {"email": USER, "password": PASSWORD}
    login_response = session.post(login_url, data=login_payload, timeout=REQUEST_TIMEOUT)

    # Treat an HTTP error status, or the word "error" in the returned page, as a failed login.
    if not login_response.ok or "error" in login_response.text:
        print("Login failed.")
    else:
        _download_file(session, ascii_url, "ascii.tgz", "ASCII")
        _download_file(session, sentences_url, "sentences.tgz", "sentences")


ASCII dataset downloaded successfully.
Sentences dataset downloaded successfully.


In [4]:
# extract the downloaded files

ascii_file = "ascii.tgz"
sentences_file = "sentences.tgz"

ascii_extracted_folder = "ascii"
sentences_extracted_folder = "sentences"


def _extract_archive(archive, destination, label):
    """Extract the gzipped tar ``archive`` into ``destination`` unless that folder exists.

    The archive was downloaded from the network, so when the running Python provides
    tarfile extraction filters the "data" filter is used. It rejects absolute paths,
    members or links that would land outside ``destination``, and device files.

    Args:
        archive: path of the .tgz file.
        destination: folder to extract into; an existing folder means "already extracted".
        label: dataset name used in the status messages (e.g. "ASCII", "sentences").
    """
    if os.path.exists(destination):
        return
    print(f"Extracting the {label} dataset...")
    with tarfile.open(archive, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=destination, filter="data")
        else:
            tar.extractall(path=destination)
    print(f"{label[:1].upper()}{label[1:]} dataset extracted successfully.")


_extract_archive(ascii_file, ascii_extracted_folder, "ASCII")
_extract_archive(sentences_file, sentences_extracted_folder, "sentences")

In [5]:
# remove the downloaded archives; only the extracted folders are needed from here on

for archive_path, archive_label in ((ascii_file, "ASCII"), (sentences_file, "Sentences")):
    if not os.path.exists(archive_path):
        continue
    os.remove(archive_path)
    print(f"{archive_label} dataset file removed.")

ASCII dataset file removed.
Sentences dataset file removed.


In [6]:
# prepare the dataset

ascii_sentences_file = os.path.join(ascii_extracted_folder, "sentences.txt")
sentences_folder = sentences_extracted_folder

# read the sentences file
# Each non-comment line is space-separated. Field 0 is the sentence id: its first "-"-part
# and its first two "-"-parts name the nested image folders. The transcription starts at
# field 9, with "|" used between words.

dataset, vocab, maximum_length = [], set(), 0

# Read all lines inside `with` so the file handle is closed instead of leaked.
with open(ascii_sentences_file) as sentences_txt:
    sentence_lines = sentences_txt.readlines()

for raw_line in tqdm(sentence_lines):
    if raw_line.startswith("#"):
        continue

    fields = raw_line.split(" ")
    sentence_id = fields[0]
    id_parts = sentence_id.split("-")

    image_path = os.path.join(
        sentences_folder,
        id_parts[0],             # top-level folder
        "-".join(id_parts[:2]),  # second-level folder
        sentence_id + ".png",
    )

    text = " ".join(fields[9:]).replace("|", " ").replace("\n", "")

    # check if the image exists
    if not os.path.exists(image_path):
        print(f"Image {image_path} does not exist.")
        continue

    # check if the text is empty
    if not text:
        print(f"Text is empty for image {image_path}.")
        continue

    dataset.append([image_path, text])
    vocab.update(text)
    maximum_length = max(maximum_length, len(text))
    
    

  0%|          | 0/16777 [00:00<?, ?it/s]

100%|██████████| 16777/16777 [00:02<00:00, 6226.69it/s]


In [7]:
# create a configuration class

class ModelConfigs(BaseModelConfigs):
    """Hyper-parameters and output location for the sentence-recognition model.

    ``vocab`` and ``max_text_length`` start empty/zero and are meant to be filled in
    from the prepared dataset before the configuration is saved.
    """

    def __init__(self):
        super().__init__()
        # Each run writes into its own minute-resolution timestamped folder under models/.
        run_stamp = datetime.now().strftime("%Y%m%d%H%M")
        self.model_path = os.path.join("models", run_stamp)
        self.vocab = ""           # filled from the dataset
        self.height = 96          # input image height
        self.width = 1408         # input image width
        self.max_text_length = 0  # filled from the dataset
        self.batch_size = 32
        self.learning_rate = 0.0005
        self.train_epochs = 1000
        self.train_workers = 20
        self.channels = 3         # input image channels

In [8]:
# create a ModelConfigs object to store the configurations and update the values

configs = ModelConfigs()

# update the vocab and maximum text length from the prepared dataset.
# The vocab is sorted because iterating a set of strings has no stable order across
# interpreter runs (str hashing is randomised). Without sorting, each run would produce a
# different character-to-index mapping.
configs.vocab = "".join(sorted(vocab))
configs.max_text_length = maximum_length
configs.save()

In [9]:
# Create a data provider for the dataset

# Per-sample transforms, applied in order: ImageResizer to the configured height/width
# (keeping aspect ratio), LabelIndexer against the vocab, then LabelPadding to
# max_text_length with len(vocab) as the padding value.
sample_transformers = [
    ImageResizer(height=configs.height, width=configs.width, keep_aspect_ratio=True),
    LabelIndexer(vocab=configs.vocab),
    LabelPadding(max_word_length=configs.max_text_length, padding_value=len(configs.vocab)),
]

data_provider = DataProvider(
    dataset=dataset,
    skip_validation=True,
    batch_size=configs.batch_size,
    data_preprocessors=[ImageReader(CVImage)],
    transformers=sample_transformers,
)

In [10]:
# split the dataset into training and validation sets
# (0.9 is presumably the fraction kept for training — confirm against DataProvider.split)

train_fraction = 0.9
train_dataset, validation_dataset = data_provider.split(train_fraction)

In [11]:
# Augment the training dataset; augmentors are assigned to the training split only.
# (RandomRotate is imported in the first cell but not used here.)

training_augmentors = [RandomBrightness(), RandomErodeDilate(), RandomSharpen()]
train_dataset.augmentors = training_augmentors

In [12]:
# define an activation layer

def activation_layer(layer, activation: str="relu", alpha: float=0.1) -> "tf.Tensor":
    """ Activation layer wrapper for LeakyReLU and ReLU activation functions
    Args:
        layer: tf.Tensor, the tensor to activate
        activation: str, "relu" or "leaky_relu" (default: 'relu')
        alpha: float, negative slope for LeakyReLU; ignored for "relu"
    Returns:
        tf.Tensor, the activated tensor
    Raises:
        ValueError: if `activation` is not a supported name. Raising here keeps a typo
            from silently producing an un-activated (linear) layer.
    """
    if activation == "relu":
        return layers.ReLU()(layer)
    if activation == "leaky_relu":
        return layers.LeakyReLU(alpha=alpha)(layer)
    raise ValueError(f"Unsupported activation {activation!r}; expected 'relu' or 'leaky_relu'.")

In [13]:
# define a residual block

def residual_block(
        x: tf.Tensor,
        filter_num: int,
        strides: typing.Union[int, list] = 2,
        kernel_size: typing.Union[int, list] = 3,
        skip_conv: bool = True,
        padding: str = "same",
        kernel_initializer: str = "he_uniform",
        activation: str = "relu",
        dropout: float = 0.2):
    """Residual block: conv -> BN -> act -> conv -> BN, add shortcut, act, optional dropout.

    Args:
        x: input tensor.
        filter_num: number of filters for every convolution in the block.
        strides: stride of the first main-path convolution and of the 1x1 shortcut conv.
        kernel_size: kernel size of the two main-path convolutions.
        skip_conv: if True, the shortcut goes through a 1x1 convolution so it matches the
            main path; if False the input is added unchanged, so its shape must already
            match the main path's output.
        padding: padding mode forwarded to every Conv2D.
        kernel_initializer: initializer forwarded to every Conv2D.
        activation: activation name passed to `activation_layer`.
        dropout: dropout rate applied to the block output; a falsy value disables it.

    Returns:
        The block's output tensor.
    """
    shortcut = x
    conv_options = dict(padding=padding, kernel_initializer=kernel_initializer)

    # Main path: the first convolution carries the stride.
    out = layers.Conv2D(filter_num, kernel_size, strides=strides, **conv_options)(x)
    out = layers.BatchNormalization()(out)
    out = activation_layer(out, activation=activation)

    # Main path: the second convolution uses the default stride.
    out = layers.Conv2D(filter_num, kernel_size, **conv_options)(out)
    out = layers.BatchNormalization()(out)

    # Optional 1x1 projection, matching the number of filters and the stride of the main path.
    if skip_conv:
        shortcut = layers.Conv2D(filter_num, 1, strides=strides, **conv_options)(shortcut)

    out = layers.Add()([out, shortcut])
    out = activation_layer(out, activation=activation)

    return layers.Dropout(dropout)(out) if dropout else out

In [14]:
# define model architecture

def build_model(input_dim, output_dim, activation="leaky_relu", dropout=0.2):
    """Build the residual-CNN + BiLSTM sentence-recognition model.

    Args:
        input_dim: shape passed to `layers.Input` (height, width, channels).
        output_dim: vocabulary size. The softmax has `output_dim + 1` units; the extra unit is
            presumably the CTC blank (TODO confirm against the loss).
        activation: activation name used in every residual block.
        dropout: dropout rate for the residual blocks and after each BiLSTM.

    Returns:
        A keras Model producing per-timestep softmax probabilities.
    """
    # (filters, skip_conv, strides) of the nine residual blocks, in build order.
    residual_plan = (
        (32, True, 1),
        (32, True, 2),
        (32, False, 1),
        (64, True, 2),
        (64, False, 1),
        (128, True, 2),
        (128, True, 1),
        (128, True, 2),
        (128, False, 1),
    )

    inputs = layers.Input(shape=input_dim, name="input")

    # Pixel scaling lives inside the graph instead of in a preprocessing step.
    features = layers.Lambda(lambda x: x / 255)(inputs)

    for filters, use_skip_conv, stride in residual_plan:
        features = residual_block(
            features, filters,
            activation=activation, skip_conv=use_skip_conv, strides=stride, dropout=dropout,
        )

    # Collapse the (H, W, C) feature map into a sequence of H*W steps of C features.
    # NOTE(review): Reshape flattens row-major, so the sequence walks each feature-map row
    # left to right before moving to the next one; CRNNs usually permute to (W, H*C) instead.
    sequence = layers.Reshape((features.shape[-3] * features.shape[-2], features.shape[-1]))(features)

    sequence = layers.Bidirectional(layers.LSTM(256, return_sequences=True))(sequence)
    sequence = layers.Dropout(dropout)(sequence)

    sequence = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(sequence)
    sequence = layers.Dropout(dropout)(sequence)

    output = layers.Dense(output_dim + 1, activation="softmax", name="output")(sequence)

    return Model(inputs=inputs, outputs=output)

In [15]:
# create a model sized from the saved configuration

model_input_shape = (configs.height, configs.width, configs.channels)
model = build_model(
    input_dim=model_input_shape,
    output_dim=len(configs.vocab),
    activation="leaky_relu",
)





In [16]:
# compile the model and print the summary

adam = tf.keras.optimizers.Adam(learning_rate=configs.learning_rate)
# Both error-rate metrics are given the same vocabulary used to index the labels.
training_metrics = [
    CERMetric(vocabulary=configs.vocab),
    WERMetric(vocabulary=configs.vocab),
]

model.compile(optimizer=adam, loss=CTCloss(), metrics=training_metrics, run_eagerly=False)
model.summary(line_length=110)

Model: "model"
______________________________________________________________________________________________________________
 Layer (type)                    Output Shape                     Param #    Connected to                     
 input (InputLayer)              [(None, 96, 1408, 3)]            0          []                               


                                                                                                              
 lambda (Lambda)                 (None, 96, 1408, 3)              0          ['input[0][0]']                  
                                                                                                              
 conv2d (Conv2D)                 (None, 96, 1408, 32)             896        ['lambda[0][0]']                 
                                                                                                              
 batch_normalization (BatchNorm  (None, 96, 1408, 32)             128        ['conv2d[0][0]']                 
 alization)                                                                                                   
                                                                                                              
 leaky_re_lu (LeakyReLU)         (None, 96, 1408, 32)             0          ['batch_normalization[0][0]']    
 

In [17]:
# Define callbacks
# Every monitor tracks validation CER, where lower is better.

# Single source of truth for the checkpoint file, shared by ModelCheckpoint and Model2onnx
# so the two can never point at different paths.
checkpoint_path = os.path.join(configs.model_path, "model.h5")

earlystopper = EarlyStopping(monitor="val_CER", patience=20, verbose=1, mode="min")
checkpoint = ModelCheckpoint(checkpoint_path, monitor="val_CER", verbose=1, save_best_only=True, mode="min")
trainLogger = TrainLogger(configs.model_path)
# NOTE(review): an integer update_freq writes TensorBoard summaries every that-many batches,
# i.e. every batch here; "epoch" is cheaper if per-batch curves are not needed.
tb_callback = TensorBoard(os.path.join(configs.model_path, "logs"), update_freq=1)
# mode="min" spelled out (instead of "auto") to match the other val_CER monitors.
reduceLROnPlat = ReduceLROnPlateau(monitor="val_CER", factor=0.9, min_delta=1e-10, patience=5, verbose=1, mode="min")
# Presumably exports the checkpoint at checkpoint_path to ONNX — confirm against mltu.
model2onnx = Model2onnx(checkpoint_path)





In [18]:
# train the model with all callbacks defined above

training_callbacks = [earlystopper, checkpoint, trainLogger, tb_callback, reduceLROnPlat, model2onnx]

model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=configs.train_epochs,
    callbacks=training_callbacks,
    verbose=1,
    workers=configs.train_workers,
)

Epoch 1/1000

  4/472 [..............................] - ETA: 8:56:07 - loss: 1356.7601 - CER: 4.1739 - WER: 1.0267