In [2]:
import tensorflow
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact

%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Define model

In [3]:
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Activation
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D

def text_detection_model(channels=8):
    """ Defines a simple CNN that inputs an image and outputs a binary mask. """

    input_shape = (None, None, 1)
    activation = 'relu'
    padding = 'same'

    model = Sequential()
    model.add(Conv2D(channels, (5, 5), padding=padding, activation=activation, input_shape=input_shape))
    model.add(Conv2D(channels, (5, 5), activation=activation, padding=padding))
   
    model.add(MaxPooling2D(pool_size=(2, 2)))
    
    model.add(Conv2D(channels, (5, 5), activation=activation, padding=padding))
    model.add(Conv2D(channels * 2, (5, 5), activation=activation, padding=padding))
    model.add(Conv2D(channels * 2, (5, 5), activation=activation, padding=padding))
    model.add(Conv2D(channels * 2, (5, 5), activation=activation, padding=padding))
    model.add(Conv2D(256, (1, 1), activation=activation, padding=padding))    

    model.add(Conv2D(1, (1, 1), activation='sigmoid', padding=padding))

    model.compile(loss='binary_crossentropy', optimizer='adam')

    return model


In [4]:
text_detection_model()

<tensorflow.python.keras._impl.keras.models.Sequential at 0x7fbfde692e80>

# Generate some data

In [5]:
from data_generator import StringGenerator, StringImageBatchGenerator, StringRenderer

string_generator = StringGenerator()
string_renderer = StringRenderer(image_size=(160, 160),
                                 target_size=(80, 80),
                                 fonts_folder='fonts',
                                 backgrounds_folder='backgrounds')
data_generator = StringImageBatchGenerator(string_generator=string_generator, 
                                           string_renderer=string_renderer)


In [6]:
images, targets = data_generator.get_batch(32)

# Visualized data

In [7]:
def view_image(i):
    image = images[i,:,:,0]
    
    print(image.shape)
    target = targets[i,:,:,0]
    plt.figure()
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='bilinear')
    plt.show()
    plt.figure()
    plt.imshow(target, cmap=plt.cm.gray_r, interpolation='bilinear')
    plt.show()
    
q = interact(view_image, i=(0, len(images) - 1))

# Create model

In [8]:
print("Building model")
model_channels = 16
model = text_detection_model(model_channels)

Building model


## Train model

In [28]:
from tensorflow.python.keras.callbacks import ModelCheckpoint

steps_per_epoch = 8
validation_image = 256
num_epochs = 1

print("Generating validation set")
val_data = data_generator.get_batch(256)

check_point_callback = ModelCheckpoint('best_model.h5',
                                       monitor='val_loss',
                                       verbose=1,
                                       save_best_only=True, 
                                       mode='auto')

print("Starting training")
model.fit_generator(generator=data_generator.generate(),
                    steps_per_epoch=steps_per_epoch,
                    validation_data=val_data,
                    epochs=num_epochs,
                    verbose=1,
                    callbacks=[check_point_callback])

Building model
Generating validation set


## Load pre-trained model

In [9]:
model.load_weights('best_model.h5')

# View trained model predictions

In [10]:
def view_image(i):
    image = images[i,:,:,0]
    target = targets[i,:,:,0]
    
    plt.figure()
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='bilinear')
    plt.title('Input image')
    plt.show()
    
    plt.figure()
    plt.imshow(target, cmap=plt.cm.gray_r, interpolation='bilinear')
    plt.title('Target')
    plt.show()
    
    input_img = np.expand_dims(image, axis=2)
    input_img = np.expand_dims(input_img, axis=0)
    
    prediction = model.predict(input_img)[0,:,:,0]
    plt.figure()
    plt.imshow(prediction, cmap=plt.cm.gray_r, interpolation='bilinear')
    plt.title('Prediction')
    plt.show()

In [11]:
q = interact(view_image, i=(0, len(images) - 1))

In [12]:
from skimage import io, transform
import glob
from image import normalize_pixels
from ipywidgets import interact

test_images_folder = 'test_images'

test_image_files = list(glob.iglob('{}/*.jpg'.format(test_images_folder)))
test_image_files += list(glob.iglob('{}/*.png'.format(test_images_folder)))
test_image_files += list(glob.iglob('{}/*.jpeg'.format(test_images_folder)))

IMAGE_WIDTH = 520


def view_image(i):
    
    test_image_file = test_image_files[i]
    test_image = io.imread(test_image_file, as_gray=True)

    if test_image.shape[1] > IMAGE_WIDTH:
        size = (np.int32(test_image.shape[0] * IMAGE_WIDTH / test_image.shape[1]), IMAGE_WIDTH)
        test_image = np.float32(transform.resize(test_image, size, mode='reflect'))


    image = normalize_pixels(test_image)
    print(image.shape)
    
    plt.figure()
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='bilinear')
    plt.show()
    
    input_img = np.expand_dims(np.expand_dims(image, axis=2), axis=0)
    prediction = model.predict(input_img)[0,:,:,0]
    prediction[prediction < 0.2] = 0


    plt.figure()
    plt.imshow(prediction, cmap=plt.cm.gray_r, interpolation='bilinear')
    plt.title('Prediction')
    plt.show()
    
q = interact(view_image, i=(0, len(test_image_files) - 1))

