In [None]:
import os

import tensorflow as tf
from prepare_dataset import create_dataset_from_tfrecord

## Build CNN Model

In [None]:
def build_conv_layers(x, kernels, prefix, kernel_size=(3, 3), pool_size=(2, 2)):
    """Stack conv -> max-pool -> batch-norm -> ReLU stages on top of ``x``.

    Args:
        x: Input Keras tensor (presumably channels-last, as built by ``tf.keras.Input``).
        kernels: Iterable of filter counts; one stage is added per entry.
        prefix: Prefix for every layer name created here; must be unique within the model.
        kernel_size: Convolution window of each stage, default ``(3, 3)``.
        pool_size: Max-pool window of each stage, default ``(2, 2)``. With
            ``padding='same'`` the spatial dims shrink by this factor, rounding up.

    Returns:
        The output tensor of the last stage (``x`` itself if ``kernels`` is empty).
    """
    for i, num in enumerate(kernels, start=1):
        x = tf.keras.layers.Conv2D(num, kernel_size, padding='same', name=f'{prefix}_conv_{i}')(x)
        x = tf.keras.layers.MaxPool2D(pool_size, padding='same', name=f'{prefix}_max_pool_{i}')(x)
        x = tf.keras.layers.BatchNormalization(name=f'{prefix}_BN_{i}')(x)
        # Named like the other layers of the stage so summary()/plot_model group them.
        x = tf.keras.layers.Activation('relu', name=f'{prefix}_relu_{i}')(x)

    return x


def _dense_head(x, prefix, units=256, dropout_rate=0.3):
    """Flatten ``x`` then apply Dense(ReLU) and Dropout; layers are named ``{prefix}_*``."""
    x = tf.keras.layers.Flatten(name=f'{prefix}_flatten')(x)
    x = tf.keras.layers.Dense(units, activation='relu', name=f'{prefix}_dense')(x)
    return tf.keras.layers.Dropout(dropout_rate, name=f'{prefix}_dropout')(x)


def CNN_crop_inputs(image_shape, n_class=26, name='captcha', output_label='labels',
                    dense_units=256, dropout_rate=0.3):
    """Build a four-output CNN that reads the left and right halves of the image separately.

    The left half feeds output 0 and the right half feeds output 3. The early
    feature maps of both halves are concatenated side by side (along width) and
    feed outputs 1 and 2.

    Args:
        image_shape: ``(H, W, C)`` of a single input image.
        n_class: Number of classes predicted by each output.
        name: Name of the returned model.
        output_label: Prefix of the output layer names
            (``{output_label}_0`` ... ``{output_label}_3``).
        dense_units: Width of the hidden Dense layer in each branch head.
        dropout_rate: Dropout rate applied after each hidden Dense layer.

    Returns:
        A ``tf.keras.Model`` with four outputs, each ``(batch, n_class)`` raw
        logits (no softmax) -- pair with a ``from_logits=True`` loss.
    """
    image_input = tf.keras.Input(shape=image_shape, name='input_image')

    # Split into two crops:
    # - left half of the image: used for the first label
    # - right half of the image: used for the last label
    # For odd W each crop keeps ceil(W/2) columns, so they share the middle column.
    _, W, _ = image_shape
    half_w = W // 2
    x0 = tf.keras.layers.Cropping2D(cropping=((0, 0), (0, half_w)), name='left_half')(image_input)
    x3 = tf.keras.layers.Cropping2D(cropping=((0, 0), (half_w, 0)), name='right_half')(image_input)

    # Early feature extractors, one per half.
    x0 = build_conv_layers(x0, [16, 32], 'Pre_A0')
    x3 = build_conv_layers(x3, [16, 32], 'Pre_A3')
    # (batch, h, w, c): join both halves' features along width for the middle labels.
    x12 = tf.keras.layers.Concatenate(axis=2)([x0, x3])

    x0 = build_conv_layers(x0, [32, 64, 64], 'A0')
    x12 = build_conv_layers(x12, [32, 64, 64], 'A1_2')
    x3 = build_conv_layers(x3, [32, 64, 64], 'A3')

    # Same Flatten -> Dense -> Dropout head on every branch.
    x0 = _dense_head(x0, 'A0', dense_units, dropout_rate)
    x12 = _dense_head(x12, 'A1_2', dense_units, dropout_rate)
    x3 = _dense_head(x3, 'A3', dense_units, dropout_rate)

    # One classifier per character position; no activation, so these are logits.
    labels = [
        tf.keras.layers.Dense(n_class, name=f'{output_label}_0')(x0),
        tf.keras.layers.Dense(n_class, name=f'{output_label}_1')(x12),
        tf.keras.layers.Dense(n_class, name=f'{output_label}_2')(x12),
        tf.keras.layers.Dense(n_class, name=f'{output_label}_3')(x3),
    ]

    return tf.keras.Model(inputs=image_input, outputs=labels, name=name)


In [None]:
image_shape = (60, 120, 1)  # (height, width, channels) -- single channel input
n_class = 26                # classes per output; presumably the letters A-Z

model = CNN_crop_inputs(
    image_shape=image_shape,
    n_class=n_class,
    name='cnn_crop_inputs',
    output_label='A',
)
model.summary()

In [None]:
# Render the layer graph with tensor shapes to '<model name>.png' (needs pydot + graphviz).
tf.keras.utils.plot_model(model, to_file=f'{model.name}.png', show_shapes=True)

## Train Model

In [None]:
# Load the train/test splits. image_size is derived from the model's image_shape
# so the dataset can never disagree with the network's input size.
# NOTE(review): label_prefix presumably has to match the model's output_label ('A')
# so label keys line up with output names -- confirm against prepare_dataset.
BATCH_SIZE = 128
dataset_kwargs = dict(batch_size=BATCH_SIZE, image_size=tuple(image_shape[:2]), label_prefix='A')
train_ds = create_dataset_from_tfrecord('dataset/qq_captcha_train.tfrecords', **dataset_kwargs)
test_ds = create_dataset_from_tfrecord('dataset/qq_captcha_test.tfrecords', **dataset_kwargs)

In [None]:
# A single loss instance is applied to every output (A_0 ... A_3). The output
# layers are Dense without activation, i.e. they emit logits, hence from_logits=True.
# To use a different loss (or weight) per output, pass a dict keyed by output name:
#   loss={f'A_{i}': tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
#         for i in range(len(model.outputs))}
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.005, rho=0.9),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [None]:
# Resume from a previously saved model when one exists, train, then save it back
# so the next run continues from here (previously nothing was ever saved, so the
# resume branch could not trigger and trained weights were lost).
# NOTE(review): the directory path is the TF 2.x SavedModel format; Keras 3
# requires a '.keras' file path for save/load_model -- revisit when upgrading.
model_dir = os.path.join('models', model.name)
if os.path.exists(model_dir):
    model = tf.keras.models.load_model(model_dir)

history = model.fit(
    train_ds,
    validation_data=test_ds,  # report held-out metrics per epoch; test_ds was otherwise unused
    epochs=2,
    # One sub-directory per model so TensorBoard can compare runs side by side.
    callbacks=[tf.keras.callbacks.TensorBoard(log_dir=os.path.join('tensorboard', model.name))],
)

model.save(model_dir)