In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact

%matplotlib inline
%load_ext autoreload
%autoreload 2

# Define model

In [None]:
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Activation
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D

def text_detection_model(channels=8, input_shape=(160, 160, 1), batch_size=1,
                         bottleneck_channels=256):
    """ Defines a simple CNN that inputs an image and outputs a binary mask.

    All convolutions use 'same' padding and there is a single 2x2 max-pooling
    layer, so the output mask has half the height and width of the input
    (160x160 input -> 80x80 mask with the default input_shape).

    @param channels Number of filters in the first three 5x5 convolutions; the
                    next four 5x5 convolutions use ``channels * 2``.
    @param input_shape Per-sample input shape (height, width, channels). Since
                       every layer is convolutional/pooling, ``(None, None, 1)``
                       can be passed to accept images of arbitrary size.
    @param batch_size Batch dimension baked into the input layer, or None for a
                      variable batch size.
                      NOTE(review): the default of 1 fixes the input placeholder's
                      batch dimension to 1; confirm this matches the batches the
                      training generator yields.
    @param bottleneck_channels Number of filters of the 1x1 convolution that
                               precedes the output layer.
    @return A Keras ``Sequential`` model compiled with binary cross-entropy loss
            and the Adam optimizer.
    """
    activation = 'relu'
    padding = 'same'

    model = Sequential()
    model.add(Conv2D(channels, (5, 5), padding=padding, activation=activation,
                     batch_size=batch_size, input_shape=input_shape))
    model.add(Conv2D(channels, (5, 5), activation=activation, padding=padding))

    # The only down-sampling step: halves the height and width of the feature maps.
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Conv2D(channels, (5, 5), activation=activation, padding=padding))
    for _ in range(4):
        model.add(Conv2D(channels * 2, (5, 5), activation=activation, padding=padding))

    # 1x1 convolution: a per-pixel fully-connected layer across the feature channels.
    model.add(Conv2D(bottleneck_channels, (1, 1), activation=activation, padding=padding))

    # Sigmoid output: one value in [0, 1] per output pixel, trained against the
    # binary mask target with binary cross-entropy.
    model.add(Conv2D(1, (1, 1), activation='sigmoid', padding=padding))

    model.compile(loss='binary_crossentropy', optimizer='adam')

    return model


# Generate some data

In [None]:
from data_generator import StringGenerator, StringImageBatchGenerator, StringRenderer

# Strings from StringGenerator are rendered by StringRenderer into 160x160
# images with 80x80 targets (half resolution, matching the model's single
# 2x2 max-pooling layer); StringImageBatchGenerator combines the two.
string_generator = StringGenerator()
string_renderer = StringRenderer(
    image_size=(160, 160),
    target_size=(80, 80),
    fonts_folder='fonts',
    backgrounds_folder='backgrounds',
    max_background_mixture=0.8,
    max_noise_sigma=0.04,
)
data_generator = StringImageBatchGenerator(
    string_renderer=string_renderer,
    string_generator=string_generator,
)


In [None]:
# Draw a sample batch (get_batch(32)) of images and targets for visual inspection.
(images, targets) = data_generator.get_batch(32)

# Visualize data

In [None]:
def view_image(i):
    """ Plots sample i of the batch drawn above: the input image, then its target mask. """
    image = images[i, :, :, 0]
    target = targets[i, :, :, 0]

    plt.figure()
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='bilinear')
    plt.title('Input image')
    plt.show()

    plt.figure()
    plt.imshow(target, cmap=plt.cm.gray_r, interpolation='bilinear')
    plt.title('Target')
    plt.show()

# Trailing ';' stops the notebook from echoing the value returned by interact.
interact(view_image, i=(0, len(images) - 1));

# Create model

In [None]:
# Number of filters in the first convolutional stages of the network.
model_channels = 32
print("Building model")
model = text_detection_model(channels=model_channels)

## Train model

In [None]:
from tensorflow.python.keras.callbacks import ModelCheckpoint

# Training schedule: `num_epochs` epochs of `steps_per_epoch` generator batches each.
steps_per_epoch = 16
# Size of the fixed validation set, drawn once here and passed to training below.
validation_image = 1024
num_epochs = 20000

print("Generating validation set")
val_data = data_generator.get_batch(validation_image)

In [None]:


# Writes best_model.h5 only when val_loss improves (save_best_only), so the
# file always holds the best epoch seen so far.
check_point_callback = ModelCheckpoint(
    'best_model.h5',
    verbose=1,
    monitor='val_loss',
    mode='auto',
    save_best_only=True,
)


In [None]:

print("Starting training")

# Keep the History object so per-epoch loss / val_loss can be inspected or
# plotted after training instead of only having its repr echoed.
history = model.fit_generator(generator=data_generator.generate(),
                              steps_per_epoch=steps_per_epoch,
                              epochs=num_epochs,
                              validation_data=val_data,
                              callbacks=[check_point_callback],
                              verbose=1)

## Load the best checkpoint saved during training

In [None]:
# Restore the weights from the checkpoint file written by ModelCheckpoint during training.
model.load_weights(filepath='best_model.h5')

## Export model graph to a .tflite file

In [None]:
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.
    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs. The list passed in
                        is not modified.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        global_vars = tf.global_variables()
        freeze_var_names = list(set(v.op.name for v in global_vars).difference(keep_var_names or []))
        # Build a new list rather than using `+=`, which would mutate the
        # caller's list. Every variable op is also requested as an output, so
        # pruning presumably keeps all (now constant) variables, not only those
        # reachable from the real outputs.
        output_names = list(output_names or []) + [v.op.name for v in global_vars]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph

In [None]:
import tempfile
import subprocess
# Imported here as well: the cell that exports the .pb file imports K later in
# the notebook, so without this a fresh Restart & Run All fails with NameError.
from tensorflow.python.keras import backend as K

# NOTE(review): presumably works around tf.contrib.lite in some TF 1.x releases
# using `tempfile`/`subprocess` without importing them — confirm for the TF
# version in use.
tf.contrib.lite.tempfile = tempfile
tf.contrib.lite.subprocess = subprocess

frozen_graphdef = freeze_session(K.get_session(), output_names=[model.output.op.name])

tflite_model = tf.contrib.lite.toco_convert(frozen_graphdef, model.inputs, model.outputs)

# `with` closes the file deterministically (and avoids echoing the byte count).
with open("converted_model.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_model)


# Export model graph to a .pb file

In [None]:
from tensorflow.python.keras import backend as K
sess = K.get_session()
# NOTE(review): FREEZE_DIR is not used below; the .pb file is written to the
# current directory ('./').
FREEZE_DIR = 'frozen'

# Reuse `sess` instead of fetching the Keras session a second time.
frozen_graph = freeze_session(sess, output_names=[model.output.op.name])
tf.train.write_graph(frozen_graph, './', "skcc_model.pb", as_text=False)


# View trained model predictions

In [None]:
def view_image(i):
    """ Plots sample i of the batch: input image, target mask and the model's prediction. """

    def show(array, title):
        # One titled grayscale figure per array.
        plt.figure()
        plt.imshow(array, cmap=plt.cm.gray_r, interpolation='bilinear')
        plt.title(title)
        plt.show()

    image = images[i, :, :, 0]
    show(image, 'Input image')
    show(targets[i, :, :, 0], 'Target')

    # Restore the channel axis and add a batch axis: (H, W) -> (1, H, W, 1).
    model_input = image[np.newaxis, :, :, np.newaxis]
    show(model.predict(model_input)[0, :, :, 0], 'Prediction')

In [None]:
# Slider over every sample index of the batch; `q` holds what interact returns.
q = interact(view_image,
             i=(0, len(images) - 1))

In [None]:
from skimage import io, transform
import glob
from image import normalize_pixels
from ipywidgets import interact

import os

test_images_folder = 'test_images'
# Extensions are matched case-insensitively, so e.g. "photo.JPG" is found too.
TEST_IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

# Test images wider than this are shrunk to this width (aspect ratio kept).
IMAGE_WIDTH = 520


def _find_test_images(folder, extensions=TEST_IMAGE_EXTENSIONS):
    """ Returns the sorted paths of files in `folder` whose extension is in `extensions`. """
    # Sorted so that slider index i maps to the same file on every run.
    return sorted(path for path in glob.glob(os.path.join(folder, '*'))
                  if os.path.splitext(path)[1].lower() in extensions)


def _load_test_image(path):
    """ Reads `path` as grayscale, shrinks it to IMAGE_WIDTH columns if wider and normalizes it. """
    test_image = io.imread(path, as_gray=True)
    if test_image.shape[1] > IMAGE_WIDTH:
        size = (np.int32(test_image.shape[0] * IMAGE_WIDTH / test_image.shape[1]), IMAGE_WIDTH)
        test_image = np.float32(transform.resize(test_image, size, mode='reflect'))
    return normalize_pixels(test_image)


def _show(array, title):
    """ Displays `array` as a titled grayscale figure. """
    plt.figure()
    plt.imshow(array, cmap=plt.cm.gray_r, interpolation='bilinear')
    plt.title(title)
    plt.show()


def view_image(i):
    """ Plots test image i, the model's raw prediction and the prediction with values < 0.5 zeroed. """
    image = _load_test_image(test_image_files[i])
    _show(image, 'Input image {}'.format(image.shape))

    # NOTE(review): text_detection_model builds its input layer with
    # input_shape (160, 160, 1); a model built that way will presumably be
    # rejected by Keras for test images of any other size. All layers are
    # convolutional, so a variable-size input layer (e.g. input_shape
    # (None, None, 1)) with the same weights should work — TODO confirm.
    input_img = np.expand_dims(np.expand_dims(image, axis=2), axis=0)
    prediction = model.predict(input_img)[0, :, :, 0]
    _show(prediction, 'Prediction raw')

    # Zero out low-confidence pixels; np.where returns a new array and leaves
    # `prediction` untouched.
    _show(np.where(prediction < 0.5, 0, prediction), 'Prediction cleaned up')


test_image_files = _find_test_images(test_images_folder)

if test_image_files:
    q = interact(view_image, i=(0, len(test_image_files) - 1))
else:
    # interact would otherwise receive the invalid slider range (0, -1).
    print('No {} images found in {}'.format('/'.join(TEST_IMAGE_EXTENSIONS), test_images_folder))

