In [100]:
import tensorflow as tf

In [101]:
class Darknet_BN_Leaky(tf.keras.Model):
    """Conv2D -> BatchNormalization -> LeakyReLU building block.

    Args:
        filter: Number of output filters, forwarded to ``Conv2D``.
            (The name shadows the ``filter`` builtin but is kept because
            callers pass it by keyword.)
        kernel: Kernel size, forwarded to ``Conv2D``.
        strides: Convolution strides, forwarded to ``Conv2D``.
        padding: Convolution padding mode, forwarded to ``Conv2D``.
    """

    def __init__(self,
                 filter,
                 kernel,
                 strides = 1,
                 padding = 'valid'):
        # The Keras base initializer must run before any layer is assigned
        # as an attribute; otherwise attribute tracking is not set up and
        # Keras raises an error on the first assignment below.
        super().__init__()

        self.conv = tf.keras.layers.Conv2D(filter,
                                            kernel,
                                            strides = strides,
                                            padding = padding)

        self.batch_norm = tf.keras.layers.BatchNormalization()
        # LeakyReLU is created with the Keras default negative slope.
        self.lrelu = tf.keras.layers.LeakyReLU()

    def call(self, input_tensor, training=False):
        """Apply convolution, batch normalization and leaky ReLU in order.

        Args:
            input_tensor: Tensor fed to the convolution.
            training: Forwarded to ``BatchNormalization`` so it can switch
                between batch statistics and moving averages.

        Returns:
            The activated tensor.
        """
        x = self.conv(input_tensor)
        x = self.batch_norm(x, training=training)
        x = self.lrelu(x)
        return x

In [102]:
class Res_unit(tf.keras.Model):
    """Residual unit: 1x1 block -> 3x3 block, then add the input back.

    The first block uses ``filters`` output channels with a 1x1 kernel; the
    second uses ``filters * 2`` output channels with a 3x3 kernel and
    ``'same'`` padding. Because the result is added element-wise to the
    input, the input tensor must already have ``filters * 2`` channels.

    Args:
        filters: Channel count of the 1x1 block; the 3x3 block and the
            unit's output use twice this value.
        strides: Accepted for interface compatibility but currently unused;
            both convolutions run with their default stride.
    """

    def __init__(self,
                 filters,
                 strides = 1):
        # The Keras base initializer must run before sub-layers are assigned
        # as attributes, otherwise Keras raises on the first assignment.
        super().__init__()

        self.conv1_1 = Darknet_BN_Leaky(filter = filters,
                                        kernel = 1)

        self.conv1_2 = Darknet_BN_Leaky(filter = filters * 2,
                                        kernel = 3,
                                        padding = 'same'
                                        )

    def call(self, input_tensor, training=False):
        """Run both blocks and add the skip connection.

        Args:
            input_tensor: Tensor whose shape must match the output of the
                3x3 block (``filters * 2`` channels) for the addition to
                succeed.
            training: Forwarded explicitly to both blocks so their batch
                normalization layers run in the requested mode.

        Returns:
            ``input_tensor`` plus the output of the two stacked blocks.
        """
        x = self.conv1_1(input_tensor, training=training)
        x = self.conv1_2(x, training=training)

        # Skip connection; a plain (non in-place) add keeps the input intact.
        return x + input_tensor

In [103]:
class ResBlock_N(tf.keras.Model):
    """A leading convolution followed by a chain of ``Res_unit`` layers.

    Args:
        filters: Sequence of channel counts. ``filters[0]`` is the filter
            count of the leading convolution; every following entry creates
            one ``Res_unit`` with that value.
        kernel: Kernel size of the leading convolution only.

    NOTE(review): each ``Res_unit(f)`` adds its input to an output with
    ``2 * f`` channels, so consecutive entries of ``filters`` must be chosen
    so that the channel counts match at every residual addition.
    """

    def __init__(self,
                 filters,
                 kernel):
        # The Keras base initializer must run before sub-layers are assigned
        # as attributes, otherwise Keras raises on the first assignment.
        super().__init__()

        # NOTE(review): despite the attribute name, this is a bare Conv2D
        # (no batch norm / leaky ReLU) -- confirm whether Darknet_BN_Leaky
        # was intended here.
        self.DBL = tf.keras.layers.Conv2D(filters[0],
                                          kernel)

        # Build the whole list before assigning it so that every residual
        # unit is registered as a tracked sub-layer in one step.
        self.res_body = [Res_unit(n_filters) for n_filters in filters[1:]]

    def call(self, input_tensor, training=False):
        """Apply the leading convolution, then each residual unit in order.

        Args:
            input_tensor: Tensor fed to the leading convolution.
            training: Forwarded explicitly to every residual unit.

        Returns:
            Output tensor of the last residual unit (or of the leading
            convolution when ``filters`` has a single entry).
        """
        x = self.DBL(input_tensor)

        for res_unit in self.res_body:
            x = res_unit(x, training=training)

        return x

In [1]:
from tensorflow.keras.layers import Input

def build_model(image_height, image_width, channels=3):
    """Assemble the functional Keras model from the blocks defined above.

    Args:
        image_height: Height of the input images in pixels.
        image_width: Width of the input images in pixels.
        channels: Number of input channels. Defaults to 3 on the assumption
            of RGB input -- adjust for other image types.

    Returns:
        A ``tf.keras.Model`` mapping the input image tensor to the output of
        the final ``Darknet_BN_Leaky`` block.

    NOTE(review): ``Res_unit`` adds its input to an output with twice its
    ``filters`` argument as channels, so the filter lists below must satisfy
    that constraint at every residual addition -- verify them against the
    intended architecture.
    """
    # Conv2D needs a channel axis: the per-sample shape is (H, W, C), not
    # (H, W).
    input_1 = Input(shape = (image_height, image_width, channels),
                    name = 'Input')

    DBL_1 = Darknet_BN_Leaky(32, 3)(input_1)

    res1 = ResBlock_N([32, 64], 3)(DBL_1)
    res2 = ResBlock_N([64, 64, 128], 3)(res1)
    res8 = ResBlock_N([128, 128, 256, 256, 512, 512, 1024, 1024], 3)(res2)
    res4 = ResBlock_N([1024, 2048, 2048, 2048, 2048 ], 3)(res8)

    # Five stacked conv/BN/leaky blocks form the head of the network.
    DBL_2 = Darknet_BN_Leaky(2048, 3)(res4)
    DBL_3 = Darknet_BN_Leaky(2048, 3)(DBL_2)
    DBL_4 = Darknet_BN_Leaky(2048, 3)(DBL_3)
    DBL_5 = Darknet_BN_Leaky(2048, 3)(DBL_4)
    DBL_6 = Darknet_BN_Leaky(2048, 3)(DBL_5)

    # `Model` was never imported as a bare name; use the fully qualified
    # class through the `tf` module the notebook already imports.
    model = tf.keras.Model(inputs = input_1, outputs = DBL_6)

    return model
