Commit 4ce9f51 (parent: 6999246)

* add OCR example for Captcha images
* format the code
* use tfdataset for preprocessing, modify model inputs and refactor
* remove seed, rename network i/o
* add generated files

Showing 5 changed files with 1,498 additions and 0 deletions.
""" | ||
Title: OCR model for reading captcha | ||
Author: [A_K_Nain](https://twitter.com/A_K_Nain) | ||
Date created: 2020/06/14 | ||
Last modified: 2020/06/14 | ||
Description: How to implement an OCR model using CNNs, RNNs and CTC loss. | ||
""" | ||
|
||
""" | ||
## Introduction | ||
This example demonstrates a simple OCR model using Functional API. Apart from | ||
combining CNN and RNN, it also illustrates how you can instantiate a new layer | ||
and use it as an `Endpoint` layer for implementing CTC loss. For a detailed | ||
description on layer subclassing, please check out this | ||
[example](https://keras.io/guides/making_new_layers_and_models_via_subclassing/#the-addmetric-method) | ||
in the developer guides. | ||
""" | ||
|
||
""" | ||
## Setup | ||
""" | ||
|
||
import os | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
from pathlib import Path | ||
from collections import Counter | ||
|
||
import tensorflow as tf | ||
from tensorflow import keras | ||
from tensorflow.keras import layers | ||
|
||
|
||
""" | ||
## Load the data: [Captcha Images](https://www.kaggle.com/fournierp/captcha-version-2-images) | ||
Let's download the data. | ||
""" | ||
|
||
|
||
"""shell | ||
curl -LO https://github.com/AakashKumarNain/CaptchaCracker/raw/master/captcha_images_v2.zip | ||
unzip -qq captcha_images_v2.zip | ||
""" | ||
|
||
|
||
""" | ||
The dataset contains 1040 captcha files as png images. The label for each sample is the | ||
name of the file (excluding the '.png' part). The label for each sample is a string. | ||
We will map each character in the string to a number for training the model. Similary, | ||
we would be required to map the predictions of the model back to string. For this purpose | ||
would maintain two dictionary mapping characters to numbers and numbers to characters | ||
respectively. | ||
""" | ||
|
||
|
||
# Path to the data directory | ||
data_dir = Path("./captcha_images_v2/") | ||
|
||
# Get list of all the images | ||
images = sorted(list(map(str, list(data_dir.glob("*.png"))))) | ||
labels = [img.split(os.path.sep)[-1].split(".png")[0] for img in images] | ||
characters = set(char for label in labels for char in label) | ||
|
||
print("Number of images found: ", len(images)) | ||
print("Number of labels found: ", len(labels)) | ||
print("Number of unique characters: ", len(characters)) | ||
print("Characters present: ", characters) | ||
|
||
# Batch size for training and validation | ||
batch_size = 16 | ||
|
||
# Desired image dimensions | ||
img_width = 200 | ||
img_height = 50 | ||
|
||
# Factor by which the image is going to be downsampled | ||
# by the convolutional blocks. We will be using two | ||
# convolution blocks and each convolution block will have | ||
# a pooling layer which downsample the features by a factor of 2. | ||
# Hence total downsampling factor would be 4. | ||
downsample_factor = 4 | ||
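# For example, with img_width = 200 the sequence fed to the RNNs below will have
# 200 // downsample_factor = 50 time steps.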

# Maximum length of any captcha in the dataset
max_length = max([len(label) for label in labels])


"""
## Preprocessing
"""


# Mapping characters to numbers
char_to_num = layers.experimental.preprocessing.StringLookup(
    vocabulary=list(characters), num_oov_indices=0, mask_token=None
)

# Mapping numbers back to original characters
num_to_char = layers.experimental.preprocessing.StringLookup(
    vocabulary=char_to_num.get_vocabulary(), mask_token=None, invert=True
)
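
# As a quick sanity check (this snippet is not part of the original example), we can
# round-trip one of the labels through the two lookup layers defined above:
example_label = labels[0]
encoded = char_to_num(tf.strings.unicode_split(example_label, input_encoding="UTF-8"))
decoded = tf.strings.reduce_join(num_to_char(encoded)).numpy().decode("utf-8")
print(example_label, "->", encoded.numpy(), "->", decoded)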


def split_data(images, labels, train_size=0.9, shuffle=True):
    # 1. Get the total size of the dataset
    size = len(images)
    # 2. Make an indices array and shuffle it, if required
    indices = np.arange(size)
    if shuffle:
        np.random.shuffle(indices)
    # 3. Get the size of training samples
    train_samples = int(size * train_size)
    # 4. Split data into training and validation sets
    x_train, y_train = images[indices[:train_samples]], labels[indices[:train_samples]]
    x_valid, y_valid = images[indices[train_samples:]], labels[indices[train_samples:]]
    return x_train, x_valid, y_train, y_valid


# Splitting data into training and validation sets
x_train, x_valid, y_train, y_valid = split_data(np.array(images), np.array(labels))


def encode_single_sample(img_path, label):
    # 1. Read image
    img = tf.io.read_file(img_path)
    # 2. Decode and convert to grayscale
    img = tf.io.decode_png(img, channels=1)
    # 3. Convert to float32 in [0, 1] range
    img = tf.image.convert_image_dtype(img, tf.float32)
    # 4. Resize to the desired size
    img = tf.image.resize(img, [img_height, img_width])
    # 5. Transpose the image because we want the time
    # dimension to correspond to the width of the image.
    img = tf.transpose(img, perm=[1, 0, 2])
    # 6. Map the characters in label to numbers
    label = char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
    # 7. Return a dict as our model is expecting two inputs
    return {"image": img, "label": label}


"""
## Data Generators
"""


train_data_generator = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_data_generator = (
    train_data_generator.map(
        encode_single_sample, num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    .batch(batch_size)
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)

valid_data_generator = tf.data.Dataset.from_tensor_slices((x_valid, y_valid))
valid_data_generator = (
    valid_data_generator.map(
        encode_single_sample, num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    .batch(batch_size)
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)

"""
## Visualize the data
"""


_, ax = plt.subplots(4, 4, figsize=(10, 5))
for batch in train_data_generator.take(1):
    images = batch["image"]
    labels = batch["label"]
    for i in range(16):
        img = (images[i] * 255).numpy().astype("uint8")
        label = tf.strings.reduce_join(num_to_char(labels[i])).numpy().decode("utf-8")
        ax[i // 4, i % 4].imshow(img[:, :, 0].T, cmap="gray")
        ax[i // 4, i % 4].set_title(label)
        ax[i // 4, i % 4].axis("off")
plt.show()

"""
## Model
"""


class CTCLayer(layers.Layer):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.loss_fn = keras.backend.ctc_batch_cost

    def call(self, y_true, y_pred):
        # Compute the training-time loss value and add it
        # to the layer using `self.add_loss()`.
        batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
        input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
        label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")

        input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
        label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")

        loss = self.loss_fn(y_true, y_pred, input_length, label_length)
        self.add_loss(loss)

        # At test time, just return the computed predictions
        return y_pred
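
# Note (added for clarity): CTCLayer is an "endpoint" layer. It takes both the targets
# and the predictions, registers the CTC loss via `self.add_loss()`, and passes the
# predictions through unchanged. This is why `model.compile()` below is called without
# a `loss` argument, and why the inference model built later can stop at the `dense2`
# layer and drop this layer entirely.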


def build_model():
    # Inputs to the model
    input_img = layers.Input(
        shape=(img_width, img_height, 1), name="image", dtype="float32"
    )
    labels = layers.Input(name="label", shape=(None,), dtype="float32")

    # First conv block
    x = layers.Conv2D(
        32,
        (3, 3),
        activation="relu",
        kernel_initializer="he_normal",
        padding="same",
        name="Conv1",
    )(input_img)
    x = layers.MaxPooling2D((2, 2), name="pool1")(x)

    # Second conv block
    x = layers.Conv2D(
        64,
        (3, 3),
        activation="relu",
        kernel_initializer="he_normal",
        padding="same",
        name="Conv2",
    )(x)
    x = layers.MaxPooling2D((2, 2), name="pool2")(x)

    # We have used two max-pooling layers with pool size and strides of 2.
    # Hence, the downsampled feature maps are 4x smaller. The number of
    # filters in the last layer is 64. Reshape accordingly before
    # passing the output to the RNNs.
    new_shape = ((img_width // 4), (img_height // 4) * 64)
    x = layers.Reshape(target_shape=new_shape, name="reshape")(x)
    x = layers.Dense(64, activation="relu", name="dense1")(x)
    x = layers.Dropout(0.2)(x)

    # RNNs
    x = layers.Bidirectional(layers.LSTM(128, return_sequences=True, dropout=0.2))(x)
    x = layers.Bidirectional(layers.LSTM(64, return_sequences=True, dropout=0.25))(x)

    # Output layer
    x = layers.Dense(len(characters) + 1, activation="softmax", name="dense2")(x)
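    # (Comment added for clarity) The "+ 1" above reserves an extra class for the CTC
    # blank token, which `ctc_batch_cost` and `ctc_decode` treat as the last index.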

    # Add CTC layer for calculating CTC loss at each step
    output = CTCLayer(name="ctc_loss")(labels, x)

    # Define the model
    model = keras.models.Model(
        inputs=[input_img, labels], outputs=output, name="ocr_model_v1"
    )
    # Optimizer
    opt = keras.optimizers.Adam()
    # Compile the model and return
    model.compile(optimizer=opt)
    return model


# Get the model
model = build_model()
model.summary()

"""
## Training
"""


epochs = 100
es_patience = 10
# Add early stopping
es = keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=es_patience, restore_best_weights=True
)

# Train the model
history = model.fit(
    train_data_generator,
    validation_data=valid_data_generator,
    epochs=epochs,
    callbacks=[es],
)


"""
## Let's test-drive it
"""

# Get the prediction model by extracting the layers up to the output layer
prediction_model = keras.models.Model(
    model.get_layer(name="image").input, model.get_layer(name="dense2").output
)
prediction_model.summary()

# A utility function to decode the output of the network
def decode_batch_predictions(pred):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search
    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
        :, :max_length
    ]
    # Iterate over the results and get back the text
    output_text = []
    for res in results:
        res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
        output_text.append(res)
    return output_text


# Let's check results on some validation samples
for batch in valid_data_generator.take(1):
    batch_images = batch["image"]
    batch_labels = batch["label"]

    preds = prediction_model.predict(batch_images)
    pred_texts = decode_batch_predictions(preds)

    orig_texts = []
    for label in batch_labels:
        label = tf.strings.reduce_join(num_to_char(label)).numpy().decode("utf-8")
        orig_texts.append(label)

    _, ax = plt.subplots(4, 4, figsize=(15, 5))
    for i in range(len(pred_texts)):
        img = (batch_images[i, :, :, 0] * 255).numpy().astype(np.uint8)
        img = img.T
        title = f"Prediction: {pred_texts[i]}"
        ax[i // 4, i % 4].imshow(img, cmap="gray")
        ax[i // 4, i % 4].set_title(title)
        ax[i // 4, i % 4].axis("off")
plt.show()