In [1]:
# Related DBNet (Differentiable Binarization) implementations:
# https://github.com/MhLiao/DB
# https://github.com/zonasw/DBNet
# https://github.com/xuannianz/DifferentiableBinarization
import tensorflow as tf
tf.get_logger().setLevel('ERROR')  # Only surface TensorFlow log records at ERROR level or above
APPROACH_NAME = 'DBNet'  # Label for this approach; not referenced by the cells visible here

# Check that the GPU is working

In [2]:
physical_devices = tf.config.list_physical_devices('GPU') 
tf.config.experimental.set_memory_growth(physical_devices[0], True)
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0': raise SystemError('GPU device not found')
print('Found GPU at:', device_name)
!nvcc -V

Found GPU at: /device:GPU:0
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Nov_30_19:15:10_Pacific_Standard_Time_2020
Cuda compilation tools, release 11.2, V11.2.67
Build cuda_11.2.r11.2/compiler.29373293_0


# Data input pipeline

In [3]:
# Settings shared by the data generators and the model definition.
BATCH_SIZE = 4  # Forwarded to DBNetDataGenerator
IMAGE_SIZE = 640  # Side length of the square model inputs (see the Input layers in DBNet)
# The three values below are only forwarded to DBNetDataGenerator here.
# NOTE(review): presumably the threshold-map value range and the polygon
# shrink ratio used to build the training targets — confirm in loader.py.
THRESH_MIN = 0.3
THRESH_MAX = 0.7
SHRINK_RATIO = 0.4

In [4]:
from loader import DataImporter, DBNetDataGenerator
# Index the samples under the 'Datasets' directory.
# NOTE(review): `pattern` presumably selects the annotation/cache files to
# read — confirm the '*.cach' extension against loader.py.
dataset = DataImporter('Datasets', pattern='*.cach')
print(dataset)  # Prints the sample-count summary shown below

Samples count (1 image can have multiple bounding boxes):
- Number of images found: 100
- Number of image bounding boxes: 100
- Number of bounding boxes in all images: 2506


In [5]:
# Split the dataset into a training part and a validation part
# (0.8 yields the 80 / 20 split reported below).
(
    train_img_paths, all_train_bboxes,
    valid_img_paths, all_valid_bboxes,
) = dataset.split(0.8)
print(f'Number of training samples: {len(train_img_paths)}')
print(f'Number of validate samples: {len(valid_img_paths)}')

Number of training samples: 80
Number of validate samples: 20


In [6]:
# Both generators share the same batching / target-map settings; only the
# sample lists and the trailing positional flag differ.
generator_settings = (BATCH_SIZE, IMAGE_SIZE, THRESH_MIN, THRESH_MAX, SHRINK_RATIO)

train_generator = DBNetDataGenerator(train_img_paths, all_train_bboxes, *generator_settings)

# NOTE(review): the extra `False` is passed positionally; presumably it turns
# off a training-only behavior (e.g. shuffling/augmentation) — confirm in loader.py.
valid_generator = DBNetDataGenerator(valid_img_paths, all_valid_bboxes, *generator_settings, False)

# Define the model

In [7]:
from tensorflow.keras.layers import (
    Conv2D, Conv2DTranspose, UpSampling2D, Concatenate,
    Input, BatchNormalization, Activation, Add, Lambda
)
from keras_resnet.models import ResNet50
from losses import db_loss

In [8]:
def DBNet(k=50, is_training=True, name='DBNet'):
    """Build the DBNet text-detection model on top of a ResNet50 backbone.

    The backbone feature maps are projected to 256 channels, merged top-down
    (FPN style), reduced to four 64-channel maps at a common resolution,
    concatenated, and fed to two identical heads that predict the probability
    map and the threshold map at the input resolution.

    Args:
        k: Amplification factor of the approximate (differentiable) binarization
            ``1 / (1 + exp(-k * (P - T)))`` applied to the probability map ``P``
            and the threshold map ``T``.
        is_training: If True, the returned model additionally takes the four
            target inputs (``gt_input``, ``mask_input``, ``thresh_input``,
            ``thresh_mask_input``) and its only output is the ``db_loss`` layer,
            which is also registered through ``add_loss``. If False, the
            returned model maps the image to the probability map only.
        name: Name given to the returned ``tf.keras.Model``.

    Returns:
        A ``tf.keras.Model`` (uncompiled).

    Note:
        The pyramid construction assumes the backbone exposes exactly four
        feature maps (C2..C5); unpacking into P5..P2 fails otherwise.
    """
    image_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3), name='image')
    # NOTE(review): `freeze_bn=True` presumably keeps the backbone's batch-norm
    # layers frozen — confirm in keras_resnet.
    backbone = ResNet50(inputs=image_input, include_top=False, freeze_bn=True)
    num_backbone_outputs = len(backbone.outputs)

    # Lateral connections: project every backbone stage to 256 channels.
    # Layers are named in2, in3, in4, in5 (shallowest to deepest stage).
    laterals = []
    for level, feature in enumerate(backbone.outputs, start=2):
        prefix = f'in{level}'
        x = Conv2D(256, 1, padding='same', kernel_initializer='he_normal', name=f'{prefix}_conv')(feature)
        x = BatchNormalization(name=f'{prefix}_bn')(x)
        laterals.append(Activation('relu', name=f'{prefix}_relu')(x))

    # Top-down pathway, deepest level first. Each level adds the upsampled
    # *merged* map of the level above (not just its raw lateral), so coarse
    # context propagates all the way down to the finest level.
    pyramid = []  # Filled in the order P5, P4, P3, P2
    merged, merged_name = None, None
    for i, lateral in enumerate(reversed(laterals)):
        level = num_backbone_outputs - i + 1  # 5, 4, 3, 2
        up_size = 8 // (2 ** i)  # 8, 4, 2, 1: brings every P-map to the resolution of P2
        if i == 0:
            merged, merged_name = lateral, f'in{level}'
        else:
            merged = Add(name=f'out{level}')([
                lateral,
                UpSampling2D(2, name=f'{merged_name}_up')(merged)
            ])
            merged_name = f'out{level}'
        x = Conv2D(64, 3, padding='same', kernel_initializer='he_normal', name=f'P{level}_conv')(merged)
        x = BatchNormalization(name=f'P{level}_bn')(x)
        x = Activation('relu', name=f'P{level}_relu')(x)
        if up_size > 1:
            x = UpSampling2D(up_size, name=f'P{level}_up')(x)
        pyramid.append(x)
    P5, P4, P3, P2 = pyramid

    fuse = Concatenate(name='fuse')([P2, P3, P4, P5])

    def _prediction_head(features, output_name):
        # 3x3 conv followed by two stride-2 transposed convs (4x upsampling in
        # total), ending in a single-channel sigmoid map named `output_name`.
        x = Conv2D(64, 3, padding='same', kernel_initializer='he_normal', use_bias=False)(features)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = Conv2DTranspose(64, 2, strides=2, kernel_initializer='he_normal', use_bias=False)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        return Conv2DTranspose(
            1, 2, strides=2, kernel_initializer='he_normal',
            activation='sigmoid', name=output_name
        )(x)

    # Creation order matters for the auto-generated names of the unnamed head
    # layers: probability head first, threshold head second.
    binarize_map = _prediction_head(fuse, 'probability_map')
    threshold_map = _prediction_head(fuse, 'threshold_map')

    if not is_training:
        return tf.keras.Model(inputs=image_input, outputs=binarize_map, name=name)

    # Target tensors consumed by the loss layer during training.
    gt_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE), name='gt_input')
    mask_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE), name='mask_input')
    thresh_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE), name='thresh_input')
    thresh_mask_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE), name='thresh_mask_input')

    # Approximate binary map: a steep sigmoid of (probability - threshold).
    b_hat = Lambda(
        function = lambda x: 1 / (1 + tf.exp(-k * (x[0] - x[1]))),
        name = 'approximate_binary_map'
    )([binarize_map, threshold_map])

    loss_layer = Lambda(db_loss, name='db_loss')([
        gt_input, mask_input, thresh_input, thresh_mask_input,
        binarize_map, b_hat, threshold_map
    ])

    model = tf.keras.Model(
        inputs = [image_input, gt_input, mask_input, thresh_input, thresh_mask_input],
        outputs = [loss_layer],
        name = name
    )
    # The loss is computed inside the graph, so it is attached with `add_loss`
    # rather than through a compiled loss function.
    model.add_loss(model.get_layer('db_loss').output)
    return model

In [9]:
# Build the training variant of the network (default is_training=True) and
# print its layer table with wider lines so the columns are not truncated.
model = DBNet()
model.summary(line_length=120)

Model: "DBNet"
________________________________________________________________________________________________________________________
Layer (type)                           Output Shape               Param #       Connected to                            
image (InputLayer)                     [(None, 640, 640, 3)]      0                                                     
________________________________________________________________________________________________________________________
conv1 (Conv2D)                         (None, 320, 320, 64)       9408          image[0][0]                             
________________________________________________________________________________________________________________________
bn_conv1 (BatchNormalization)          (None, 320, 320, 64)       256           conv1[0][0]                             
________________________________________________________________________________________________________________________
conv1_relu (Activ

# Callbacks

In [10]:
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

# Halt training once `val_loss` has gone 5 epochs without improving and
# restore the best weights seen so far.
early_stopping_callback = EarlyStopping(
    monitor = 'val_loss',
    patience = 5,
    restore_best_weights = True
)

# Halve the learning rate (new_lr = lr * factor) whenever `val_loss` has not
# improved for 2 epochs, never going below `min_lr`.
reduce_lr_callback = ReduceLROnPlateau(
    monitor = 'val_loss',
    factor = 0.5,
    patience = 2,
    min_lr = 1e-6,
    verbose = 1
)

# Training

In [11]:
from tensorflow.keras.optimizers import Adam
LEARNING_RATE = 2e-4
EPOCHS = 100
# The training model's only output is the `db_loss` layer, which is already
# registered through `model.add_loss`, so no compiled loss is attached:
# one `None` per model output (the output tensor's rank is irrelevant here).
model.compile(optimizer=Adam(LEARNING_RATE), loss=[None] * len(model.outputs))

In [None]:
%%time
history = model.fit(
    train_generator,
    validation_data = valid_generator,
    validation_steps = len(valid_generator),
    steps_per_epoch = len(train_generator),
    epochs = EPOCHS,
    callbacks = [reduce_lr_callback, early_stopping_callback],
    verbose = 1
).history

Epoch 1/100
Epoch 2/100
Epoch 3/100
 1/20 [>.............................] - ETA: 13s - loss: 3.8469