In [1]:
# https://github.com/MhLiao/DB
# https://github.com/zonasw/DBNet
# https://github.com/xuannianz/DifferentiableBinarization
import tensorflow as tf
# Only let TensorFlow log messages of ERROR severity or higher through, keeping cell output readable
tf.get_logger().setLevel('ERROR')
# Label for this approach; not referenced in the cells visible here —
# presumably used later (e.g. to name saved artifacts), TODO confirm
APPROACH_NAME = 'DBNet'

# Check that the GPU is working

In [2]:
physical_devices = tf.config.list_physical_devices('GPU') 
tf.config.experimental.set_memory_growth(physical_devices[0], True)
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0': raise SystemError('GPU device not found')
print('Found GPU at:', device_name)
!nvcc -V

Found GPU at: /device:GPU:0
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Nov_30_19:15:10_Pacific_Standard_Time_2020
Cuda compilation tools, release 11.2, V11.2.67
Build cuda_11.2.r11.2/compiler.29373293_0


# Data input pipeline

In [3]:
# Settings shared by the data generators and the model definition below.
BATCH_SIZE = 4      # Passed to DBNetDataGenerator as the batch size
IMAGE_SIZE = 640    # Passed to DBNetDataGenerator; also the height/width of the label-map inputs in DBNet()
THRESH_MIN = 0.3    # Passed to DBNetDataGenerator — presumably the lower bound of threshold-map values, TODO confirm
THRESH_MAX = 0.7    # Passed to DBNetDataGenerator — presumably the upper bound of threshold-map values, TODO confirm
SHRINK_RATIO = 0.4  # Passed to DBNetDataGenerator — presumably the polygon shrink ratio for the label maps, TODO confirm

In [4]:
from loader import DataImporter, DBNetDataGenerator
# Collect the samples under the 'Datasets' folder, matching files with the '*.txt' pattern
# (presumably the bounding-box annotation files — TODO confirm in loader.DataImporter)
dataset = DataImporter('Datasets', pattern='*.txt')
# Printing the importer shows its summary of the images and bounding boxes it found
print(dataset)

Samples count (1 image can have multiple bounding boxes):
- Number of images found: 951
- Number of image bounding boxes: 951
- Number of bounding boxes in all images: 15741


In [5]:
# Split the samples into training and validation parts; the 0.9 argument yields a
# roughly 90/10 split (855 vs 96 of 951 images in the recorded run).
# NOTE(review): whether split() shuffles, and with which seed, is not visible here — confirm for reproducibility.
train_img_paths, all_train_bboxes, valid_img_paths, all_valid_bboxes = dataset.split(0.9)
print('Number of training samples:', len(train_img_paths))
print('Number of validate samples:', len(valid_img_paths))

Number of training samples: 855
Number of validate samples: 96


In [6]:
# Build the training and validation batch generators from one shared configuration.
# They differ only in their sample lists and in one extra trailing positional
# argument (False) given to the validation generator
# (presumably it disables shuffling/augmentation — TODO confirm in loader.DBNetDataGenerator).
train_generator, valid_generator = [
    DBNetDataGenerator(
        img_paths, bboxes, BATCH_SIZE, IMAGE_SIZE,
        THRESH_MIN, THRESH_MAX, SHRINK_RATIO, *extra_args
    )
    for img_paths, bboxes, extra_args in (
        (train_img_paths, all_train_bboxes, ()),        # training: constructor defaults
        (valid_img_paths, all_valid_bboxes, (False,)),  # validation: extra False flag
    )
]

# Define the model

In [7]:
from tensorflow.keras.layers import (
    Input, Convolution2D, Conv2DTranspose, UpSampling2D,
    BatchNormalization, Activation, Add, Concatenate, Lambda
)
from layers import ConvBnRelu, DeconvolutionalMap
from keras_resnet.models import ResNet50
from losses import db_loss

In [8]:
def DBNet(k=50, is_training=True, name='DBNet'):
    """Build the DBNet text-detection model on top of a ResNet50 backbone.

    The four backbone feature maps (C2..C5) are reduced to 256 channels, merged
    top-down by upsample-and-add, reduced to 64 channels each, brought to a
    common resolution and concatenated into one fused feature map. Two
    `DeconvolutionalMap` heads on that fused map predict the probability map
    and the threshold map.

    Args:
        k: Amplifying factor of the approximate binarization
            B_hat = sigmoid(k * (P - T)), i.e. 1 / (1 + e^(-k(P - T))).
        is_training: If True, build the training graph: it takes the image plus
            four label maps as inputs and registers the value computed by
            `db_loss` as the model loss. If False, build the inference graph,
            which maps an image to the probability map only.
        name: Name given to the returned Keras model.

    Returns:
        A `tf.keras.Model`. In inference mode its single output is the
        probability map; in training mode its single output is the output of
        the `db_loss` layer, which is also added to the model via `add_loss`.

    Note:
        The label-map inputs of the training graph are shaped with the
        module-level `IMAGE_SIZE`, read when this function is called.
    """
    image_input = Input(shape=(None, None, 3), name='image')
    backbone = ResNet50(inputs=image_input, include_top=False, freeze_bn=True)
    
    # Lateral 1x1 convolutions: bring every backbone stage to 256 channels
    C2, C3, C4, C5 = backbone.outputs
    in2 = ConvBnRelu(256, kernel_size=1, name='in2')(C2)
    in3 = ConvBnRelu(256, kernel_size=1, name='in3')(C3)
    in4 = ConvBnRelu(256, kernel_size=1, name='in4')(C4)
    in5 = ConvBnRelu(256, kernel_size=1, name='in5')(C5)
        
    P5 = ConvBnRelu(64, kernel_size=3, name='P5_conv')(in5)
    P5 = UpSampling2D(8, name='P5_up')(P5) # 1 / 32 * 8 = 1 / 4
    
    # Top-down pathway: upsample the coarser map by 2 and add it to the next finer one
    out4 = Add(name='out4')([in4, UpSampling2D(2, name='in5_up')(in5)])
    P4 = ConvBnRelu(64, kernel_size=3, name='P4_conv')(out4)
    P4 = UpSampling2D(4, name='P4_up')(P4) # 1 / 16 * 4 = 1 / 4
    
    out3 = Add(name='out3')([in3, UpSampling2D(2, name='out4_up')(out4)])
    P3 = ConvBnRelu(64, kernel_size=3, name='P3_conv')(out3)
    P3 = UpSampling2D(2, name='P3_up')(P3) # 1 / 8 * 2 = 1 / 4
    
    out2 = Add(name='out2')([in2, UpSampling2D(2, name='out3_up')(out3)])
    P2 = ConvBnRelu(64, kernel_size=3, name='P2_conv')(out2) # 1 / 4
    
    fuse = Concatenate(name='fuse')([P2, P3, P4, P5]) # (batch_size, /4, /4, 256)
    binarize_map = DeconvolutionalMap(64, name='probability_map')(fuse)
    # Inference only needs the probability map, so stop before the threshold head
    if not is_training: return tf.keras.Model(inputs=image_input, outputs=binarize_map, name=name)
    threshold_map = DeconvolutionalMap(64, name='threshold_map')(fuse)
    
    # Label maps fed alongside the image during training
    gt_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE), name='gt_input')
    mask_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE), name='mask_input')
    thresh_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE), name='thresh_input')
    thresh_mask_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE), name='thresh_mask_input')
    
    # Approximate binary map: 1 / (1 + e^(-k(P - T))) == sigmoid(k(P - T)).
    # tf.sigmoid is used instead of the hand-written formula because it stays
    # finite (value and gradient) when k * |P - T| is large, where tf.exp
    # would overflow (large k, or reduced-precision floats).
    b_hat = Lambda( 
        function = lambda x: tf.sigmoid(k * (x[0] - x[1])),
        name = 'approximate_binary_map'
    )([binarize_map, threshold_map]) 
    
    # Compute the loss inside the graph from the labels and the three predicted maps
    loss_layer = Lambda(db_loss, name='db_loss')([
        gt_input, mask_input, thresh_input, thresh_mask_input, 
        binarize_map, b_hat, threshold_map
    ])

    model = tf.keras.Model(
        inputs = [image_input, gt_input, mask_input, thresh_input, thresh_mask_input], 
        outputs = [loss_layer],
        name = name
    )
    # Register the in-graph loss so the model can be compiled without a `loss` argument
    model.add_loss(model.get_layer('db_loss').output)
    return model

In [9]:
# Build the training variant of the network (label-map inputs + in-graph loss)
model = DBNet(is_training=True)
# Widen the summary table to 120 characters so its columns fit on one line
model.summary(line_length=120)

Model: "DBNet"
________________________________________________________________________________________________________________________
 Layer (type)                          Output Shape               Param #       Connected to                            
 image (InputLayer)                    [(None, None, None, 3)]    0             []                                      
                                                                                                                        
 conv1 (Conv2D)                        (None, None, None, 64)     9408          ['image[0][0]']                         
                                                                                                                        
 bn_conv1 (BatchNormalization)         (None, None, None, 64)     256           ['conv1[0][0]']                         
                                                                                                                        
 conv1_relu (Acti

                                                                                                                        
 padding2c_branch2b (ZeroPadding2D)    (None, None, None, 64)     0             ['res2c_branch2a_relu[0][0]']           
                                                                                                                        
 res2c_branch2b (Conv2D)               (None, None, None, 64)     36864         ['padding2c_branch2b[0][0]']            
                                                                                                                        
 bn2c_branch2b (BatchNormalization)    (None, None, None, 64)     256           ['res2c_branch2b[0][0]']                
                                                                                                                        
 res2c_branch2b_relu (Activation)      (None, None, None, 64)     0             ['bn2c_branch2b[0][0]']                 
                                

 res3c_branch2a (Conv2D)               (None, None, None, 128)    65536         ['res3b_relu[0][0]']                    
                                                                                                                        
 bn3c_branch2a (BatchNormalization)    (None, None, None, 128)    512           ['res3c_branch2a[0][0]']                
                                                                                                                        
 res3c_branch2a_relu (Activation)      (None, None, None, 128)    0             ['bn3c_branch2a[0][0]']                 
                                                                                                                        
 padding3c_branch2b (ZeroPadding2D)    (None, None, None, 128)    0             ['res3c_branch2a_relu[0][0]']           
                                                                                                                        
 res3c_branch2b (Conv2D)        

 res4a (Add)                           (None, None, None, 1024)   0             ['bn4a_branch2c[0][0]',                 
                                                                                 'bn4a_branch1[0][0]']                  
                                                                                                                        
 res4a_relu (Activation)               (None, None, None, 1024)   0             ['res4a[0][0]']                         
                                                                                                                        
 res4b_branch2a (Conv2D)               (None, None, None, 256)    262144        ['res4a_relu[0][0]']                    
                                                                                                                        
 bn4b_branch2a (BatchNormalization)    (None, None, None, 256)    1024          ['res4b_branch2a[0][0]']                
                                

                                                                                                                        
 res4d (Add)                           (None, None, None, 1024)   0             ['bn4d_branch2c[0][0]',                 
                                                                                 'res4c_relu[0][0]']                    
                                                                                                                        
 res4d_relu (Activation)               (None, None, None, 1024)   0             ['res4d[0][0]']                         
                                                                                                                        
 res4e_branch2a (Conv2D)               (None, None, None, 256)    262144        ['res4d_relu[0][0]']                    
                                                                                                                        
 bn4e_branch2a (BatchNormalizati

 res5a_branch1 (Conv2D)                (None, None, None, 2048)   2097152       ['res4f_relu[0][0]']                    
                                                                                                                        
 bn5a_branch2c (BatchNormalization)    (None, None, None, 2048)   8192          ['res5a_branch2c[0][0]']                
                                                                                                                        
 bn5a_branch1 (BatchNormalization)     (None, None, None, 2048)   8192          ['res5a_branch1[0][0]']                 
                                                                                                                        
 res5a (Add)                           (None, None, None, 2048)   0             ['bn5a_branch2c[0][0]',                 
                                                                                 'bn5a_branch1[0][0]']                  
                                

 out4_up (UpSampling2D)                (None, None, None, 256)    0             ['out4[0][0]']                          
                                                                                                                        
 out3 (Add)                            (None, None, None, 256)    0             ['in3[0][0]',                           
                                                                                 'out4_up[0][0]']                       
                                                                                                                        
 in2 (ConvBnRelu)                      (None, None, None, 256)    66816         ['res2c_relu[0][0]']                    
                                                                                                                        
 out3_up (UpSampling2D)                (None, None, None, 256)    0             ['out3[0][0]']                          
                                

# Callbacks

In [10]:
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

# Halt training once val_loss has gone 5 epochs without improving,
# restoring the weights of the best epoch when it stops
early_stopping_callback = EarlyStopping(
    monitor = 'val_loss',
    patience = 5,
    restore_best_weights = True,
    verbose = 1
)

# Halve the learning rate whenever val_loss stalls for 2 epochs
reduce_lr_callback = ReduceLROnPlateau(
    monitor = 'val_loss',
    factor = 0.5,   # new_lr = lr * factor
    patience = 2,   # Epochs without improvement before reducing
    min_lr = 1e-6,  # The learning rate is never reduced below this value
    verbose = 1
)

# Training

In [11]:
from tensorflow.keras.optimizers import Adam
# Optimisation hyper-parameters
EPOCHS = 100          # Upper bound on the number of training epochs
LEARNING_RATE = 2e-4  # Initial learning rate handed to Adam

# No `loss` argument is given: the model carries its own loss, registered with add_loss()
model.compile(optimizer=Adam(learning_rate=LEARNING_RATE))

In [12]:
%%time
# Train the model and keep only the per-epoch metrics dictionary of the returned History object
history = model.fit(
    train_generator,
    epochs = EPOCHS,
    # One epoch runs through every batch each generator reports via len()
    steps_per_epoch = len(train_generator),
    validation_data = valid_generator,
    validation_steps = len(valid_generator),
    # The callbacks lower the learning rate on plateaus and stop training early
    callbacks = [reduce_lr_callback, early_stopping_callback],
    verbose = 1
).history

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 22: ReduceLROnPlateau reducing learning rate to 9.999999747378752e-05.
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 29: ReduceLROnPlateau reducing learning rate to 4.999999873689376e-05.
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 33: ReduceLROnPlateau reducing learning rate to 2.499999936844688e-05.
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 37: ReduceLROnPlateau reducing learning rate to 1.249999968422344e-05.
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 41: ReduceLROnPlateau reducing learning rate to 6.24999984211172e-06.
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 45: ReduceLROnPlateau reducing lea