In [None]:
#-------------------
# fixed params
#------------------
# Side length (pixels) every decoded image / target map is reshaped to.
dim = 512
# Per-channel offsets subtracted from the decoded image in the input pipeline.
mean = [103.939, 116.779, 123.68]
# Bounds of the range the decoded threshold map is rescaled into.
thresh_min, thresh_max = 0.3, 0.7
#----------------
# imports
#---------------
import tensorflow as tf
import random
import json
import os
import numpy as np
import matplotlib.pyplot as plt

#from kaggle_datasets import KaggleDatasets
from glob import glob
from tqdm.auto import tqdm

%matplotlib inline
#-------------

In [None]:
# Turn on memory growth for every visible GPU.
gpus = tf.config.experimental.list_physical_devices('GPU')
print(gpus)
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
    # Memory growth has to be configured before the GPUs are initialized;
    # re-running this cell on a live kernel raises RuntimeError, which should
    # be reported rather than abort the notebook.
    print(err)

In [None]:
#----------------------------------------------------------
# Detect hardware, return appropriate distribution strategy
#----------------------------------------------------------
# TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
# A ValueError raised while resolving the TPU is treated as "no TPU available".
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if not tpu:
    # default distribution strategy in Tensorflow. Works on CPU and single GPU.
    strategy = tf.distribute.get_strategy()
else:
    # Connect to the TPU cluster, initialize it, then build the strategy on top of it.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
    tf.config.optimizer.set_jit(True)

print("REPLICAS: ", strategy.num_replicas_in_sync)

In [None]:
#-------------
# GCS and files and synth
#-------------
def get_tfrecs(_path):
    """Return the `*.tfrecord` files found directly under `_path`, in random order.

    Args:
        _path: directory to search.

    Returns:
        List of matching file paths, shuffled with `random.shuffle`.
    """
    matches = tf.io.gfile.glob(os.path.join(_path, '*.tfrecord'))
    random.shuffle(matches)
    return matches

# synth
#GCS_PATH = KaggleDatasets().get_gcs_path('dbnet-dataset')  
GCS_PATH = "/backup/RAW/DET/DATA/synth/temp/tfrecords/synthtext/"
_recs = get_tfrecs(GCS_PATH)

# The first file of the shuffled list is held out for evaluation; the rest is used for training.
eval_recs, train_recs = _recs[:1], _recs[1:]
print(len(eval_recs), len(train_recs))

In [None]:
def data_input_fn(recs): 
    '''
      Builds the batched tf.data pipeline for a list of tfrecord files.

      Every serialized example is expected to hold five PNG-encoded byte
      strings under the keys 'image', 'gt', 'mask', 'thresh_map' and
      'thresh_mask'.

      Args:
          recs: list of tfrecord file paths.

      Returns:
          A shuffled, endlessly repeating tf.data.Dataset whose elements are
          dicts of float32 tensors:
            'image'                                  -> (BATCH_SIZE, dim, dim, 3)
            'thresh_map','thresh_mask','gt','mask'   -> (BATCH_SIZE, dim, dim)

      Reads the notebook globals dim, mean, thresh_min, thresh_max and
      BATCH_SIZE (BATCH_SIZE must be defined before this function is called).
    '''

    def _decode_map(png_bytes):
        # Single-channel PNG -> float32 (dim, dim) tensor scaled by 1/255.
        decoded = tf.image.decode_png(png_bytes, channels=1)
        decoded = tf.cast(decoded, tf.float32) / 255.0
        return tf.reshape(decoded, (dim, dim))

    def _parser(example):
        feature = {key: tf.io.FixedLenFeature([], tf.string)
                   for key in ('image', 'gt', 'mask', 'thresh_map', 'thresh_mask')}
        parsed_example = tf.io.parse_single_example(example, feature)
        ret = {}
        # image
        image = tf.image.decode_png(parsed_example['image'], channels=3)
        image = tf.cast(image, tf.float32)
        image = tf.reshape(image, [dim, dim, 3])
        # Subtract mean[c] from channel c (broadcast over the last axis).
        # The previous code derived the 2nd and 3rd channel from the already
        # shifted 1st channel, discarding the real green/blue data.
        # NOTE(review): decode_png yields RGB while these mean values look like
        # the Caffe-style BGR ImageNet means - confirm the intended channel order.
        image = image - tf.constant(mean, dtype=tf.float32)
        image = image / 255
        ret["image"] = image
        # thresh_map: rescaled from [0, 1] into [thresh_min, thresh_max]
        thresh_map = _decode_map(parsed_example['thresh_map'])
        ret["thresh_map"] = thresh_map * (thresh_max - thresh_min) + thresh_min
        # thresh_mask
        ret["thresh_mask"] = _decode_map(parsed_example['thresh_mask'])
        # gt
        ret["gt"] = _decode_map(parsed_example['gt'])
        # mask
        ret["mask"] = _decode_map(parsed_example['mask'])
        return ret

    dataset = tf.data.TFRecordDataset(recs)
    dataset = dataset.map(_parser)
    dataset = dataset.shuffle(2048, reshuffle_each_iteration=True)
    dataset = dataset.repeat()
    # drop_remainder keeps the batch dimension static.
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset

In [None]:
#-------------------------
# train and data params
#-----------------------
EPOCHS = 100
# Global batch size: 16 on a single replica, otherwise 128 per replica.
BATCH_SIZE = 16 if strategy.num_replicas_in_sync == 1 else 128 * strategy.num_replicas_in_sync

# Step counts are derived from the number of record files times 1024
# (presumably the number of examples stored per tfrecord file - TODO confirm).
STEPS_PER_EPOCH = len(train_recs) * 1024 // BATCH_SIZE
EVAL_STEPS = len(eval_recs) * 1024 // BATCH_SIZE
print("Steps:", STEPS_PER_EPOCH)
print("Eval Steps:", EVAL_STEPS)

In [None]:
#-----------------------
# visual
#-----------------------
# Build both input pipelines (data_input_fn reads the global BATCH_SIZE, so it must already be defined).
train_ds  =   data_input_fn(train_recs)
eval_ds   =   data_input_fn(eval_recs)
# Sanity check: plot the first sample of one evaluation batch for every feature key.
for ret in eval_ds.take(1):
    for k,v in ret.items():
        print(k)
        # v[0] is the first element of the batch; squeeze drops size-1 axes before plotting.
        data=np.squeeze(v[0])
        plt.imshow(data)
        plt.show()
        print(f'{k} Batch Shape:',v.shape)
        

# Modeling

In [None]:
import tensorflow as tf
from tensorflow import keras as K
from tensorflow.keras import layers as KL

def balanced_crossentropy_loss(pred, gt, mask, negative_ratio=3.):
    """Binary cross-entropy balanced between positives and the hardest negatives.

    Args:
        pred: prediction tensor; only index 0 of its last axis is used.
        gt: ground-truth map (1 marks a positive pixel).
        mask: per-pixel weights selecting which pixels contribute.
        negative_ratio: upper bound on the number of negatives kept, expressed
            as a multiple of the positive count.

    Returns:
        (balanced_loss, loss): the scalar balanced loss and the unreduced
        per-pixel cross-entropy.
    """
    scores = pred[..., 0]
    pos_mask = gt * mask
    neg_mask = (1 - gt) * mask
    n_pos = tf.reduce_sum(pos_mask)
    # Never keep more negatives than exist, nor more than ratio * positives.
    n_neg = tf.reduce_min([tf.reduce_sum(neg_mask), n_pos * negative_ratio])

    loss = K.backend.binary_crossentropy(gt, scores)
    pos_total = tf.reduce_sum(loss * pos_mask)
    # Keep only the n_neg largest negative losses.
    hardest_neg, _ = tf.nn.top_k(tf.reshape(loss * neg_mask, (-1,)), tf.cast(n_neg, tf.int32))

    balanced_loss = (pos_total + tf.reduce_sum(hardest_neg)) / (n_pos + n_neg + 1e-6)
    return balanced_loss, loss


def dice_loss(pred, gt, mask, weights):
    """Weighted dice loss.

    Args:
        pred: (b, h, w, 1)
        gt: (b, h, w)
        mask: (b, h, w)
        weights: (b, h, w)
    Returns:
        Scalar `1 - 2 * intersection / union`, computed under the mask after it
        has been scaled by the weights (min-max normalised, then shifted by 1).
    """
    pred = pred[..., 0]
    # Normalise the weights to [0, 1) and shift by 1 so no masked pixel is zeroed out.
    w_min = tf.reduce_min(weights)
    w_span = tf.reduce_max(weights) - w_min + 1e-6
    weighted_mask = mask * ((weights - w_min) / w_span + 1.)
    intersection = tf.reduce_sum(pred * gt * weighted_mask)
    union = tf.reduce_sum(pred * weighted_mask) + tf.reduce_sum(gt * weighted_mask) + 1e-6
    return 1 - 2.0 * intersection / union


def l1_loss(pred, gt, mask):
    """Masked mean absolute error between channel 0 of `pred` and `gt`.

    Returns 0 when the mask sums to zero (nothing to supervise).
    """
    pred = pred[..., 0]
    mask_sum = tf.reduce_sum(mask)
    masked_mae = tf.reduce_sum(tf.abs(pred - gt) * mask) / (mask_sum + 1e-6)
    return K.backend.switch(mask_sum > 0, masked_mae, tf.constant(0.))


def compute_cls_acc(pred, gt, mask):
    """Fraction of pixels whose binarised prediction agrees with `gt` under `mask`.

    Args:
        pred: (b, h, w, 1) or (b, h, w) scores; values below 0.3 become 0, all others 1.
        gt: (b, h, w) ground-truth map.
        mask: (b, h, w) mask applied to both sides before comparing.

    Returns:
        Scalar mean of the element-wise equality. Pixels where `mask` is 0
        compare equal on both sides and therefore count as correct.
    """
    # The network maps carry a trailing channel axis that gt/mask lack; drop it
    # (as the loss functions do) so the products below do not mis-broadcast.
    if len(pred.shape) == len(gt.shape) + 1 and pred.shape[-1] == 1:
        pred = pred[..., 0]

    zero = tf.zeros_like(pred, tf.float32)
    one = tf.ones_like(pred, tf.float32)

    pred = tf.where(pred < 0.3, x=zero, y=one)
    acc = tf.reduce_mean(tf.cast(tf.equal(pred * mask, gt * mask), tf.float32))

    return acc


def db_loss(args, alpha=5.0, beta=10.0, ohem_ratio=3.0):
    """Total DB training loss.

    Args:
        args: sequence of seven tensors, in order: gt, mask, thresh map,
            thresh mask, binarize map, thresh-binary map, threshold map.
        alpha: weight of the balanced cross-entropy term.
        beta: weight of the L1 threshold-map term.
        ohem_ratio: negative/positive ratio passed to the cross-entropy term.

    Returns:
        alpha * binarize_loss + beta * threshold_loss + thresh_binary_loss
    """
    (input_gt, input_mask, input_thresh, input_thresh_mask,
     binarize_map, thresh_binary, threshold_map) = args

    # The per-pixel cross-entropy doubles as the weighting map for the dice term.
    binarize_loss, pixel_bce = balanced_crossentropy_loss(
        binarize_map, input_gt, input_mask, negative_ratio=ohem_ratio)
    thresh_binary_loss = dice_loss(thresh_binary, input_gt, input_mask, pixel_bce)
    threshold_loss = l1_loss(threshold_map, input_thresh, input_thresh_mask)

    return alpha * binarize_loss + beta * threshold_loss + thresh_binary_loss


def db_acc(args):
    """Return `(binarize_acc, thresh_binary_acc)` for `args = (gt, mask, binarize_map, thresh_binary)`."""
    input_gt, input_mask, binarize_map, thresh_binary = args
    return tuple(
        compute_cls_acc(score_map, input_gt, input_mask)
        for score_map in (binarize_map, thresh_binary)
    )

In [None]:
def DBNet(k=50,
          dim=512,
          outs=["conv2_block3_out",
                "conv3_block4_out",
                "conv4_block6_out",
                "conv5_block3_out"]):
    """Build the trainable DBNet graph whose single output is the `db_loss` value.

    Args:
        k: steepness of the sigmoid `1 / (1 + exp(-k * (binarize_map - threshold_map)))`
            used to form the thresh-binary map.
        dim: side length of the square image and target inputs.
        outs: names of the four ResNet50 layers tapped as C2, C3, C4, C5 (in that order).
            The fixed up-sampling factors below assume each tap has half the spatial
            resolution of the previous one.

    Returns:
        A `K.Model` with the inputs `image`, `gt`, `mask`, `thresh_map` and `thresh_mask`
        whose output is the `db_loss` Lambda layer; that output is also registered
        on the model through `add_loss`.
    """
    # input layer
    input_image = KL.Input(shape=[dim,dim, 3], name='image')

    # ResNet50 feature extractor (ImageNet weights, no classification head).
    backbone = K.applications.resnet50.ResNet50(input_tensor=input_image,weights='imagenet',include_top=False)
    C2, C3, C4, C5 = [backbone.get_layer(out).output for out in outs]

    # 1x1 conv + BN + ReLU brings every backbone tap to 256 channels.
    # in2
    in2 = KL.Conv2D(256, (1, 1), padding='same', kernel_initializer='he_normal', name='in2')(C2)
    in2 = KL.BatchNormalization()(in2)
    in2 = KL.ReLU()(in2)
    # in3
    in3 = KL.Conv2D(256, (1, 1), padding='same', kernel_initializer='he_normal', name='in3')(C3)
    in3 = KL.BatchNormalization()(in3)
    in3 = KL.ReLU()(in3)
    # in4
    in4 = KL.Conv2D(256, (1, 1), padding='same', kernel_initializer='he_normal', name='in4')(C4)
    in4 = KL.BatchNormalization()(in4)
    in4 = KL.ReLU()(in4)
    # in5
    in5 = KL.Conv2D(256, (1, 1), padding='same', kernel_initializer='he_normal', name='in5')(C5)
    in5 = KL.BatchNormalization()(in5)
    in5 = KL.ReLU()(in5)

    # Top-down pathway: each level adds the 2x up-sampled coarser level, is reduced
    # to 64 channels, and is up-sampled (8x / 4x / 2x / none) so P2..P5 can be concatenated.
    # P5
    P5 = KL.Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal')(in5)
    P5 = KL.BatchNormalization()(P5)
    P5 = KL.ReLU()(P5)
    P5 = KL.UpSampling2D(size=(8, 8))(P5)
    # P4
    out4 = KL.Add()([in4, KL.UpSampling2D(size=(2, 2))(in5)])
    P4 = KL.Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal')(out4)
    P4 = KL.BatchNormalization()(P4)
    P4 = KL.ReLU()(P4)
    P4 = KL.UpSampling2D(size=(4, 4))(P4)
    # P3
    out3 = KL.Add()([in3, KL.UpSampling2D(size=(2, 2))(out4)])
    P3 = KL.Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal')(out3)
    P3 = KL.BatchNormalization()(P3)
    P3 = KL.ReLU()(P3)
    P3 = KL.UpSampling2D(size=(2, 2))(P3)
    # P2
    out2 = KL.Add()([in2, KL.UpSampling2D(size=(2, 2))(out3)])
    P2 = KL.Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal')(out2)
    P2 = KL.BatchNormalization()(P2)
    P2 = KL.ReLU()(P2)

    # Fused multi-scale feature map shared by both prediction heads.
    fuse = KL.Concatenate()([P2, P3, P4, P5])

    # binarize map: 3x3 conv, then two stride-2 transposed convs (4x up-sampling), sigmoid output
    p = KL.Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal', use_bias=False)(fuse)
    p = KL.BatchNormalization()(p)
    p = KL.ReLU()(p)
    p = KL.Conv2DTranspose(64, (2, 2), strides=(2, 2), kernel_initializer='he_normal', use_bias=False)(p)
    p = KL.BatchNormalization()(p)
    p = KL.ReLU()(p)
    binarize_map  = KL.Conv2DTranspose(1, (2, 2), strides=(2, 2), kernel_initializer='he_normal',activation='sigmoid', name='binarize_map')(p)

    # threshold map: same head layout as the binarize map, with separate weights
    t = KL.Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal', use_bias=False)(fuse)
    t = KL.BatchNormalization()(t)
    t = KL.ReLU()(t)
    t = KL.Conv2DTranspose(64, (2, 2), strides=(2, 2), kernel_initializer='he_normal', use_bias=False)(t)
    t = KL.BatchNormalization()(t)
    t = KL.ReLU()(t)
    threshold_map  = KL.Conv2DTranspose(1, (2, 2), strides=(2, 2), kernel_initializer='he_normal',activation='sigmoid', name='threshold_map')(t)

    # thresh binary map: sigmoid of k * (binarize_map - threshold_map)
    thresh_binary = KL.Lambda(lambda x: 1 / (1 + tf.exp(-k * (x[0] - x[1]))))([binarize_map, threshold_map])

    # Training targets enter the graph as extra inputs so the loss can be computed inside the model.
    input_gt = KL.Input(shape=[dim,dim], name='gt')
    input_mask = KL.Input(shape=[dim,dim], name='mask')
    input_thresh = KL.Input(shape=[dim,dim], name='thresh_map')
    input_thresh_mask = KL.Input(shape=[dim,dim], name='thresh_mask')

    # The argument order must match the unpacking order inside db_loss.
    loss_layer = KL.Lambda(db_loss, name='db_loss')([input_gt, input_mask, input_thresh, input_thresh_mask, binarize_map, thresh_binary, threshold_map])

    db_model = K.Model(inputs=[input_image, input_gt, input_mask, input_thresh, input_thresh_mask],outputs=[loss_layer])

    # Register the output of each named loss layer as a model loss.
    loss_names = ["db_loss"]
    for layer_name in loss_names:
        layer = db_model.get_layer(layer_name)
        db_model.add_loss(layer.output)
    return db_model


In [None]:
# Build and compile the model inside the distribution strategy scope.
with strategy.scope():
    model = DBNet()
    model.compile(
        optimizer=K.optimizers.Adam(),
        # No external loss function is passed (DBNet registers its loss via add_loss):
        # the list holds one None per dimension of the model output's shape.
        loss=[None] * len(model.output.shape),
    )
model.summary()

In [None]:
# reduces learning rate on plateau
lr_reducer = tf.keras.callbacks.ReduceLROnPlateau(
    factor=0.1,
    patience=3,
    cooldown=10,
    min_lr=0.1e-7,
    verbose=1,
)
# early stopping
# NOTE(review): patience=1 here is shorter than the LR reducer's patience=3, so training
# presumably stops before the learning rate is ever reduced - confirm this is intended.
early_stopping = tf.keras.callbacks.EarlyStopping(
    patience=1,
    mode='auto',
    verbose=1,
)


class SaveBestModel(tf.keras.callbacks.Callback):
    """Save the inference weights every time `val_loss` reaches a new minimum.

    Only the sub-network from the `image` input to the `binarize_map` output is
    saved (to "dbnet.h5"); the training-only inputs and the loss layer are left out.
    """

    def __init__(self):
        # Let the Keras base class set up its own attributes first.
        super().__init__()
        # Lowest validation loss seen so far.
        self.best = float('inf')

    def on_epoch_end(self, epoch, logs=None):
        """Compare this epoch's `val_loss` with the best so far and save on improvement."""
        metric_value = (logs or {}).get('val_loss')
        if metric_value is None:
            # No validation result for this epoch (e.g. fit() ran without
            # validation data): nothing to compare, so skip instead of raising.
            return
        if metric_value < self.best:
            print(f"Loss Improved epoch:{epoch} from {self.best} to {metric_value}")
            self.best = metric_value
            # Cut the inference graph out of the training model.
            inp = self.model.get_layer("image").input
            out = self.model.get_layer('binarize_map').output
            net = tf.keras.Model(inputs=inp, outputs=out)
            net.save_weights("dbnet.h5")
            print("Saved Best Weights")

    def set_model(self, model):
        """Attach the training model; delegates to the Keras base implementation."""
        super().set_model(model)
            
# Checkpoint callback, bound to the model explicitly before training starts.
model_save = SaveBestModel()
model_save.set_model(model)
# Callbacks handed to fit(): LR schedule, early stopping, then checkpointing.
callbacks = [
    lr_reducer,
    early_stopping,
    model_save,
]

In [None]:
# Overrides the EPOCHS value of 100 set with the other training parameters.
EPOCHS = 5
# Both datasets repeat endlessly, so the epoch length comes from the explicit step counts.
history = model.fit(
    train_ds,
    validation_data=eval_ds,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=EVAL_STEPS,
    callbacks=callbacks,
    verbose=1,
)