Skip to content

Commit

Permalink
Add ocr example (#91)
Browse files Browse the repository at this point in the history
* add OCR example for Captcha images

* format the code

* use tfdataset for preprocessing, modify model inputs and refactor

* remove seed, rename network i/o

* add generated files
  • Loading branch information
AakashKumarNain committed Jun 26, 2020
1 parent 6999246 commit 4ce9f51
Show file tree
Hide file tree
Showing 5 changed files with 1,498 additions and 0 deletions.
340 changes: 340 additions & 0 deletions examples/vision/captcha_ocr.py
@@ -0,0 +1,340 @@
"""
Title: OCR model for reading captcha
Author: [A_K_Nain](https://twitter.com/A_K_Nain)
Date created: 2020/06/14
Last modified: 2020/06/14
Description: How to implement an OCR model using CNNs, RNNs and CTC loss.
"""

"""
## Introduction
This example demonstrates a simple OCR model using Functional API. Apart from
combining CNN and RNN, it also illustrates how you can instantiate a new layer
and use it as an `Endpoint` layer for implementing CTC loss. For a detailed
description on layer subclassing, please check out this
[example](https://keras.io/guides/making_new_layers_and_models_via_subclassing/#the-addmetric-method)
in the developer guides.
"""

"""
## Setup
"""

import os
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path
from collections import Counter

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


"""
## Load the data: [Captcha Images](https://www.kaggle.com/fournierp/captcha-version-2-images)
Let's download the data.
"""


"""shell
curl -LO https://github.com/AakashKumarNain/CaptchaCracker/raw/master/captcha_images_v2.zip
unzip -qq captcha_images_v2.zip
"""


"""
The dataset contains 1040 captcha files as png images. The label for each sample is the
name of the file (excluding the '.png' part). The label for each sample is a string.
We will map each character in the string to a number for training the model. Similary,
we would be required to map the predictions of the model back to string. For this purpose
would maintain two dictionary mapping characters to numbers and numbers to characters
respectively.
"""


# Path to the data directory
data_dir = Path("./captcha_images_v2/")

# Get list of all the images
images = sorted(list(map(str, list(data_dir.glob("*.png")))))
labels = [img.split(os.path.sep)[-1].split(".png")[0] for img in images]
characters = set(char for label in labels for char in label)

print("Number of images found: ", len(images))
print("Number of labels found: ", len(labels))
print("Number of unique characters: ", len(characters))
print("Characters present: ", characters)

# Batch size for training and validation
batch_size = 16

# Desired image dimensions
img_width = 200
img_height = 50

# Factor by which the image is going to be downsampled
# by the convolutional blocks. We will be using two
# convolution blocks and each convolution block will have
# a pooling layer which downsample the features by a factor of 2.
# Hence total downsampling factor would be 4.
downsample_factor = 4

# Maximum length of any captcha in the dataset
max_length = max([len(label) for label in labels])


"""
## Preprocessing
"""


# Mapping characters to numbers
char_to_num = layers.experimental.preprocessing.StringLookup(
vocabulary=list(characters), num_oov_indices=0, mask_token=None
)

# Mapping numbers back to original characters
num_to_char = layers.experimental.preprocessing.StringLookup(
vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
)


def split_data(images, labels, train_size=0.9, shuffle=True):
# 1. Get the total size of the dataset
size = len(images)
# 2. Make an indices array and shuffle it, if required
indices = np.arange(size)
if shuffle:
np.random.shuffle(indices)
# 3. Get the size of training samples
train_samples = int(size * train_size)
# 4. Split data into training and validation sets
x_train, y_train = images[indices[:train_samples]], labels[indices[:train_samples]]
x_valid, y_valid = images[indices[train_samples:]], labels[indices[train_samples:]]
return x_train, x_valid, y_train, y_valid


# Splitting data into training and validation sets
x_train, x_valid, y_train, y_valid = split_data(np.array(images), np.array(labels))


def encode_single_sample(img_path, label):
# 1. Read image
img = tf.io.read_file(img_path)
# 2. Decode and convert to grayscale
img = tf.io.decode_png(img, channels=1)
# 3. Convert to float32 in [0, 1] range
img = tf.image.convert_image_dtype(img, tf.float32)
# 4. Resize to the desired size
img = tf.image.resize(img, [img_height, img_width])
# 5. Transpose the image because we want the time
# dimension to correspond to the width of the image.
img = tf.transpose(img, perm=[1, 0, 2])
# 6. Map the characters in label to numbers
label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
# 7. Return a dict as our model is expecting two inputs
return {"image": img, "label": label}


"""
## Data Generators
"""


train_data_generator = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_data_generator = (
train_data_generator.map(
encode_single_sample, num_parallel_calls=tf.data.experimental.AUTOTUNE
)
.batch(batch_size)
.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)

valid_data_generator = tf.data.Dataset.from_tensor_slices((x_valid, y_valid))
valid_data_generator = (
valid_data_generator.map(
encode_single_sample, num_parallel_calls=tf.data.experimental.AUTOTUNE
)
.batch(batch_size)
.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)

"""
## Visualize the data
"""


_, ax = plt.subplots(4, 4, figsize=(10, 5))
for batch in train_data_generator.take(1):
images = batch["image"]
labels = batch["label"]
for i in range(16):
img = (images[i] * 255).numpy().astype("uint8")
label = tf.strings.reduce_join(num_to_char(labels[i])).numpy().decode("utf-8")
ax[i // 4, i % 4].imshow(img[:, :, 0].T, cmap="gray")
ax[i // 4, i % 4].set_title(label)
ax[i // 4, i % 4].axis("off")
plt.show()

"""
## Model
"""


class CTCLayer(layers.Layer):
def __init__(self, name=None):
super().__init__(name=name)
self.loss_fn = keras.backend.ctc_batch_cost

def call(self, y_true, y_pred):
# Compute the training-time loss value and add it
# to the layer using `self.add_loss()`.
batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")

input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")

loss = self.loss_fn(y_true, y_pred, input_length, label_length)
self.add_loss(loss)

# On test time, just return the computed loss
return y_pred


def build_model():
# Inputs to the model
input_img = layers.Input(
shape=(img_width, img_height, 1), name="image", dtype="float32"
)
labels = layers.Input(name="label", shape=(None,), dtype="float32")

# First conv block
x = layers.Conv2D(
32,
(3, 3),
activation="relu",
kernel_initializer="he_normal",
padding="same",
name="Conv1",
)(input_img)
x = layers.MaxPooling2D((2, 2), name="pool1")(x)

# Second conv block
x = layers.Conv2D(
64,
(3, 3),
activation="relu",
kernel_initializer="he_normal",
padding="same",
name="Conv2",
)(x)
x = layers.MaxPooling2D((2, 2), name="pool2")(x)

# We have used two max pool with pool size and strides of 2.
# Hence, downsampled feature maps are 4x smaller. The number of
# filters in the last layer is 64. Reshape accordingly before
# passing it to RNNs
new_shape = ((img_width // 4), (img_height // 4) * 64)
x = layers.Reshape(target_shape=new_shape, name="reshape")(x)
x = layers.Dense(64, activation="relu", name="dense1")(x)
x = layers.Dropout(0.2)(x)

# RNNs
x = layers.Bidirectional(layers.LSTM(128, return_sequences=True, dropout=0.2))(x)
x = layers.Bidirectional(layers.LSTM(64, return_sequences=True, dropout=0.25))(x)

# Output layer
x = layers.Dense(len(characters) + 1, activation="softmax", name="dense2")(x)

# Add CTC layer for calculating CTC loss at each step
output = CTCLayer(name="ctc_loss")(labels, x)

# Define the model
model = keras.models.Model(
inputs=[input_img, labels], outputs=output, name="ocr_model_v1"
)
# Optimizer
opt = keras.optimizers.Adam()
# Compile the model and return
model.compile(optimizer=opt)
return model


# Get the model
model = build_model()
model.summary()

"""
## Training
"""


epochs = 100
es_patience = 10
# Add early stopping
es = keras.callbacks.EarlyStopping(
monitor="val_loss", patience=es_patience, restore_best_weights=True
)

# Train the model
history = model.fit(
train_data_generator,
validation_data=valid_data_generator,
epochs=epochs,
callbacks=[es],
)


"""
## Let's test-drive it
"""


# Get the prediction model by extracting layers till the output layer
prediction_model = keras.models.Model(
model.get_layer(name="image").input, model.get_layer(name="dense2").output
)
prediction_model.summary()

# A utility function to decode the output of the network
def decode_batch_predictions(pred):
input_len = np.ones(pred.shape[0]) * pred.shape[1]
# Use greedy search. For complex tasks, you can use beam search
results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
:, :max_length
]
# Iterate over the results and get back the text
output_text = []
for res in results:
res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
output_text.append(res)
return output_text


# Let's check results on some validation samples
for batch in valid_data_generator.take(1):
batch_images = batch["image"]
batch_labels = batch["label"]

preds = prediction_model.predict(batch_images)
pred_texts = decode_batch_predictions(preds)

orig_texts = []
for label in batch_labels:
label = tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
orig_texts.append(label)

_, ax = plt.subplots(4, 4, figsize=(15, 5))
for i in range(len(pred_texts)):
img = (batch_images[i, :, :, 0] * 255).numpy().astype(np.uint8)
img = img.T
title = f"Prediction: {pred_texts[i]}"
ax[i // 4, i % 4].imshow(img, cmap="gray")
ax[i // 4, i % 4].set_title(title)
ax[i // 4, i % 4].axis("off")
plt.show()
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 4ce9f51

Please sign in to comment.