In [None]:
!wget -q https://git.io/J0fjL -O IAM_Words.zip
!unzip -qq IAM_Words.zip
!mkdir data
!mkdir data/words
!tar -xf IAM_Words.tgz -C data/words
!mv IAM_Words.txt data
    

In [None]:
!head -20 data/words.txt

# Imports

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras

try:
    from tensorflow.keras.layers import StringLookup
except ImportError:  # older TF only exposes it under experimental.preprocessing
    from tensorflow.keras.layers.experimental.preprocessing import StringLookup

# Seed NumPy (used to shuffle the sample list) and TensorFlow (weight
# initialization, dropout) so that runs are reproducible.
np.random.seed(42)
tf.random.set_seed(42)


# Dataset splitting

In [None]:
base_path = "data"

# Keep every annotation line except comment lines ("#..."), blank lines and
# entries flagged "err" in the second field (we don't deal with errored entries).
words_list = []
with open(f"{base_path}/words.txt", "r") as words_file:
    for line in words_file:
        if line.startswith("#") or not line.strip():
            continue
        fields = line.split(" ")
        if len(fields) > 1 and fields[1] != "err":
            words_list.append(line)

np.random.shuffle(words_list)
len(words_list)

In [None]:
# Peek at a few shuffled annotation lines (rich display instead of print).
words_list[:10]

In [None]:
# Sanity check: number of annotation lines kept after filtering.
len(words_list)

In [None]:
# 90% of the samples go to training; the remaining 10% is split evenly into
# validation and test.
split_idx = int(0.9 * len(words_list))
train_samples = words_list[:split_idx]
holdout_samples = words_list[split_idx:]

val_split_idx = int(0.5 * len(holdout_samples))
validation_samples = holdout_samples[:val_split_idx]
test_samples = holdout_samples[val_split_idx:]

assert len(words_list) == len(train_samples) + len(validation_samples) + len(test_samples)

print(f"Total training samples: {len(train_samples)}")
print(f"Total validation samples: {len(validation_samples)}")
print(f"Total test samples: {len(test_samples)}")



# Data input pipeline

We start building our data input pipeline by first preparing the image paths.

In [None]:
base_image_path = os.path.join(base_path, "words")


def get_image_paths_and_labels(samples):
    """Resolve the image file for each annotation line.

    Args:
        samples: raw annotation lines; the first space-separated field is the
            image name.

    Returns:
        (paths, corrected_samples): paths of images that exist and are
        non-empty, and the matching annotation lines without the trailing
        newline.
    """
    paths = []
    corrected_samples = []
    for file_line in samples:
        image_name = file_line.strip().split(" ")[0]
        # An image named part1-part2-... lives at:
        #   <base_image_path>/part1/part1-part2/<image_name>.png
        name_parts = image_name.split("-")
        part1, part2 = name_parts[0], name_parts[1]
        img_path = os.path.join(
            base_image_path, part1, f"{part1}-{part2}", image_name + ".png"
        )
        # Skip missing or zero-byte image files instead of crashing.
        if os.path.isfile(img_path) and os.path.getsize(img_path):
            paths.append(img_path)
            corrected_samples.append(file_line.split("\n")[0])

    # Return only after every sample has been processed.
    return paths, corrected_samples


train_img_paths, train_labels = get_image_paths_and_labels(train_samples)
# The `validation_lables` spelling is kept because later cells read that name.
validation_img_paths, validation_lables = get_image_paths_and_labels(validation_samples)
test_img_paths, test_labels = get_image_paths_and_labels(test_samples)

In [None]:
# First few resolved training image paths.
train_img_paths[:10]

In [None]:
# The annotation lines returned alongside those paths.
train_labels[:10]

Then we prepare the ground-truth labels.

In [None]:
# Find the maximum label length and the character vocabulary of the training data.
train_labels_cleaned = []
characters = set()
max_len = 0

for label in train_labels:
    # Keep only the last space-separated field (the transcription); the rest of
    # the annotation line must not leak into the vocabulary or max_len.
    label = label.split(" ")[-1].strip()
    characters.update(label)
    max_len = max(max_len, len(label))
    train_labels_cleaned.append(label)

# Sort so the character -> index mapping is identical across sessions
# (iteration order of a set of str is not stable between Python processes).
characters = sorted(characters)

print("Maximum length: ", max_len)
print("Vocab size:", len(characters))

# Check some label samples.
train_labels_cleaned[:10]

Now we clean the validation and the test labels as well.

In [None]:
def clean_labels(labels):
    """Extract the transcription from each raw annotation line.

    Args:
        labels: iterable of annotation lines.

    Returns:
        list[str]: for each line, its last space-separated field with
        surrounding whitespace stripped, in input order.
    """
    return [raw_line.split(" ")[-1].strip() for raw_line in labels]

# Reduce the validation and test annotation lines to bare transcriptions.
validation_labels_cleaned, test_labels_cleaned = (
    clean_labels(validation_lables),
    clean_labels(test_labels),
)

# Building the character vocabulary

Keras provides different preprocessing layers to deal with different modalities of data. [This guide](https://keras.io/guides/preprocessing_layers/) provides a comprehensive introduction. Our example involves preprocessing labels at the character level. This means that if there are two labels, e.g. "cat" and "dog", then our character vocabulary should be {a, c, d, g, o, t} (without any special tokens). We use the `StringLookup` layer for this purpose.

In [None]:
AUTOTUNE = tf.data.AUTOTUNE

# Map characters to integers. Sorting makes the mapping deterministic across
# sessions even when `characters` is an unordered set.
char_to_num = StringLookup(vocabulary=sorted(characters), mask_token=None)

# Map integers back to the original characters (same vocabulary, inverted).
num_to_char = StringLookup(
    vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
)

# Resizing images without distortion

Instead of square images, many OCR models work with rectangular images. This will become clear in a moment when we visualize a few samples from the dataset. While aspect-unaware resizing of square images does not introduce much distortion, this is not the case for rectangular images. But resizing images to a uniform size is a requirement for mini-batching. So we need to perform our resizing such that the following criteria are met:

- Aspect ratio is preserved.
- Content of the images is not affected.
        

In [None]:
def distortion_free_resize(image, img_size):
    """Resize `image` into `img_size` without distorting it, then pad.

    The image is scaled with its aspect ratio preserved so it fits inside
    (h, w), padded symmetrically up to exactly (h, w), then transposed and
    flipped left-right.

    Args:
        image: rank-3 image tensor (height, width, channels).
        img_size: target (width, height) pair.

    Returns:
        Tensor of shape (w, h, channels).
    """
    w, h = img_size
    image = tf.image.resize(image, size=(h, w), preserve_aspect_ratio=True)

    # Amount of padding needed to reach the target size on each axis.
    pad_height = h - tf.shape(image)[0]
    pad_width = w - tf.shape(image)[1]

    # Split the padding evenly; an odd leftover goes to the top / left side.
    # Plain arithmetic also avoids a Python `if` on tensors inside tf.data.map.
    pad_height_bottom = pad_height // 2
    pad_height_top = pad_height - pad_height_bottom
    pad_width_right = pad_width // 2
    pad_width_left = pad_width - pad_width_right

    image = tf.pad(
        image,
        paddings=[
            [pad_height_top, pad_height_bottom],
            [pad_width_left, pad_width_right],
            [0, 0],
        ],
    )

    image = tf.transpose(image, perm=[1, 0, 2])
    image = tf.image.flip_left_right(image)
    return image

In [None]:
batch_size = 64
padding_token = 99
image_width = 128
image_height = 32


def preprocess_image(image_path, img_size=(image_width, image_height)):
    """Read a PNG as one channel, resize it distortion-free and scale to [0, 1]."""
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, 1)
    image = distortion_free_resize(image, img_size)
    image = tf.cast(image, tf.float32) / 255.0
    return image


def vectorize_label(label):
    """Encode a label string as character ids, right-padded with `padding_token` to `max_len`.

    NOTE(review): a label longer than the training `max_len` yields a negative
    pad amount and tf.pad fails — confirm val/test labels never exceed it.
    """
    label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
    length = tf.shape(label)[0]
    pad_amount = max_len - length
    label = tf.pad(label, paddings=[[0, pad_amount]], constant_values=padding_token)
    return label


def process_images_labels(image_path, label):
    """Map one (image path, label) pair to the model's input dict."""
    image = preprocess_image(image_path)
    label = vectorize_label(label)
    return {"image": image, "label": label}


def prepare_dataset(image_paths, labels):
    """Build a batched, cached and prefetched tf.data pipeline."""
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels)).map(
        process_images_labels, num_parallel_calls=AUTOTUNE
    )
    return dataset.batch(batch_size).cache().prefetch(AUTOTUNE)
        

# Prepare tf.data.Dataset objects

In [None]:
# One pipeline per split, built in train / validation / test order.
train_ds, validation_ds, test_ds = (
    prepare_dataset(train_img_paths, train_labels_cleaned),
    prepare_dataset(validation_img_paths, validation_labels_cleaned),
    prepare_dataset(test_img_paths, test_labels_cleaned),
)

# Visualize a few samples

In [None]:
for data in train_ds.take(1):
    images, labels = data["image"], data["label"]

    _, ax = plt.subplots(4, 4, figsize=(15, 8))

    for i in range(16):
        # Undo the transpose + flip applied in distortion_free_resize.
        img = images[i]
        img = tf.image.flip_left_right(img)
        img = tf.transpose(img, perm=[1, 0, 2])
        img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)
        img = img[:, :, 0]

        # Gather the ids where label != padding_token.
        label = labels[i]
        indices = tf.gather(label, tf.where(tf.math.not_equal(label, padding_token)))
        # Convert to string.
        label = tf.strings.reduce_join(num_to_char(indices))
        label = label.numpy().decode("utf-8")

        ax[i // 4, i % 4].imshow(img, cmap="gray")
        ax[i // 4, i % 4].set_title(label)
        ax[i // 4, i % 4].axis("off")

plt.show()


# Model

Our model will use the CTC loss as an endpoint layer. For a detailed understanding of the CTC loss, refer to [this post](https://distill.pub/2017/ctc/).


In [None]:
class CTCLayer(keras.layers.Layer):
    """Pass-through layer that attaches the CTC loss to the model via add_loss."""

    def __init__(self, name=None):
        super().__init__(name=name)
        self.loss_fn = keras.backend.ctc_batch_cost

    def call(self, y_true, y_pred):
        """Register the CTC loss for this batch and return `y_pred` unchanged.

        Args:
            y_true: padded label ids, (batch, label_len).
            y_pred: per-timestep class probabilities, (batch, time, classes).
        """
        batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
        input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
        label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")

        # ctc_batch_cost takes per-sample lengths shaped (batch, 1).
        input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
        # NOTE(review): this is the padded length, so padding tokens are counted
        # as label characters — confirm this is intended.
        label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")
        loss = self.loss_fn(y_true, y_pred, input_length, label_length)
        self.add_loss(loss)

        # At test time just return the computed predictions.
        return y_pred


def build_model():
    """Build and compile the CNN + BiLSTM handwriting recognizer trained with CTC.

    Returns:
        A compiled keras Model with inputs "image" and "label".
    """
    # Inputs to the model.
    input_img = keras.Input(shape=(image_width, image_height, 1), name="image")
    labels = keras.layers.Input(name="label", shape=(None,))

    # First conv block.
    x = keras.layers.Conv2D(
        32,
        (3, 3),
        activation="relu",
        kernel_initializer="he_normal",
        padding="same",
        name="Conv1",
    )(input_img)
    x = keras.layers.MaxPooling2D((2, 2), name="pool1")(x)

    # Second conv block (continues from the first block's output).
    x = keras.layers.Conv2D(
        64,
        (3, 3),
        activation="relu",
        kernel_initializer="he_normal",
        padding="same",
        name="Conv2",
    )(x)
    x = keras.layers.MaxPooling2D((2, 2), name="pool2")(x)

    # Two 2x2 poolings shrink each spatial axis by 4 and the last conv has 64
    # filters: fold the second axis into the features before the RNN part.
    new_shape = ((image_width // 4), (image_height // 4) * 64)
    x = keras.layers.Reshape(target_shape=new_shape, name="reshape")(x)
    x = keras.layers.Dense(64, activation="relu", name="dense1")(x)
    x = keras.layers.Dropout(0.2)(x)

    # RNNs.
    x = keras.layers.Bidirectional(
        keras.layers.LSTM(128, return_sequences=True, dropout=0.25)
    )(x)
    x = keras.layers.Bidirectional(
        keras.layers.LSTM(128, return_sequences=True, dropout=0.25)
    )(x)

    # +2 is to account for the two special tokens introduced by the CTC loss.
    # The recommendation comes from here: https://git.io/J0exp.
    x = keras.layers.Dense(
        len(char_to_num.get_vocabulary()) + 2, activation="softmax", name="dense2"
    )(x)

    # Add CTC layer for calculating CTC loss at each step.
    output = CTCLayer(name="ctc_loss")(labels, x)

    # Define the model.
    model = keras.models.Model(
        inputs=[input_img, labels], outputs=output, name="handwriting_recognizer"
    )

    # Optimizer.
    opt = keras.optimizers.Adam()
    # Compile the model and return.
    model.compile(optimizer=opt)
    return model


# Get the model.
model = build_model()
model.summary()


# Evaluation metric

Edit distance is the most widely used metric for evaluating OCR models. In this section, we will implement it and use it as a callback to monitor
our model.

We first segregate the validation images and their labels for convenience.


In [None]:
# Materialize the validation set batch by batch, keeping images and labels
# in two parallel lists.
validation_images, validation_labels = [], []

for batch in validation_ds:
    validation_images.append(batch["image"])
    validation_labels.append(batch["label"])

Now, we create a callback to monitor the edit distances.

In [None]:
def _dense_to_sparse(dense, pad_value):
    """Convert a padded 2-D id tensor to an int64 SparseTensor, dropping `pad_value` entries."""
    dense = tf.cast(dense, tf.int64)
    indices = tf.where(tf.not_equal(dense, pad_value))
    return tf.SparseTensor(
        indices=indices,
        values=tf.gather_nd(dense, indices),
        dense_shape=tf.shape(dense, out_type=tf.int64),
    )


def calculate_edit_distance(labels, predictions):
    """Mean unnormalized edit distance between greedy-decoded predictions and labels.

    Padding is removed from both sides first. tf.sparse.from_dense only drops
    zeros, so label padding (`padding_token`) and decoder padding (-1) would
    otherwise be counted as characters and dominate the distance.

    Args:
        labels: padded label ids for one batch, (batch, max_len).
        predictions: prediction-model softmax outputs, (batch, time, classes).

    Returns:
        Scalar tensor: the mean edit distance over the batch.
    """
    # Convert the labels to sparse tensors, skipping padding.
    sparse_labels = _dense_to_sparse(labels, padding_token)

    # Decode the predictions and convert them to sparse tensors.
    input_len = np.ones(predictions.shape[0]) * predictions.shape[1]
    predictions_decoded = keras.backend.ctc_decode(
        predictions, input_length=input_len, greedy=True
    )[0][0][:, :max_len]
    sparse_predictions = _dense_to_sparse(predictions_decoded, -1)

    # Compute individual edit distances and average them out.
    edit_distances = tf.edit_distance(
        sparse_predictions, sparse_labels, normalize=False
    )
    return tf.reduce_mean(edit_distances)

class EditDistanceCallback(keras.callbacks.Callback):
    """Print the mean validation edit distance at the end of every epoch."""

    def __init__(self, pred_model):
        super().__init__()
        self.prediction_model = pred_model

    def on_epoch_end(self, epoch, logs=None):
        edit_distances = []

        for images, labels in zip(validation_images, validation_labels):
            predictions = self.prediction_model.predict(images, verbose=0)
            edit_distances.append(calculate_edit_distance(labels, predictions).numpy())

        # Report once per epoch, after all validation batches are scored.
        print(
            f"Mean edit distance for epoch {epoch + 1}: {np.mean(edit_distances):.4f}"
        )
        

# Training

Now we are ready to kick off model training.

In [None]:
epochs = 10  # To get good results this should be at least 50.

model = build_model()
# Inference model: "image" input -> "dense2" softmax output (the CTC layer is left out).
prediction_model = keras.models.Model(
    model.get_layer(name="image").input, model.get_layer(name="dense2").output
)
edit_distance_callback = EditDistanceCallback(prediction_model)

# Train the model.
history = model.fit(
    train_ds,
    validation_data=validation_ds,
    epochs=epochs,
    callbacks=[edit_distance_callback],
)

# Inference

In [None]:
# A utility function to decode the output of the network.
def decode_batch_predictions(pred):
    """Greedy-decode a batch of prediction-model outputs into strings.

    Args:
        pred: softmax outputs, (batch, time, classes).

    Returns:
        list[str]: one decoded string per sample, in batch order.
    """
    # Every sample uses the full number of time steps.
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search.
    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
        :, :max_len
    ]
    # Iterate over the results and get back the text.
    output_text = []
    for res in results:
        # Drop the -1 entries before mapping ids back to characters.
        res = tf.gather(res, tf.where(tf.math.not_equal(res, -1)))
        res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
        output_text.append(res)
    return output_text

# Let's check results on some test samples.
for batch in test_ds.take(1):
    batch_images = batch["image"]
    _, ax = plt.subplots(4, 4, figsize=(15, 8))

    preds = prediction_model.predict(batch_images)
    pred_texts = decode_batch_predictions(preds)

    for i in range(16):
        # Undo the transpose + flip applied in distortion_free_resize.
        img = batch_images[i]
        img = tf.image.flip_left_right(img)
        img = tf.transpose(img, perm=[1, 0, 2])
        img = (img * 255.0).numpy().clip(0, 255).astype(np.uint8)
        img = img[:, :, 0]

        title = f"Prediction: {pred_texts[i]}"
        ax[i // 4, i % 4].imshow(img, cmap="gray")
        ax[i // 4, i % 4].set_title(title)
        ax[i // 4, i % 4].axis("off")

plt.show()
