In [1]:
import tensorflow as tf

In [161]:
class BottleneckLayer(tf.keras.layers.Layer):
    """Bottleneck unit: BN -> ReLU -> 1x1 conv -> BN -> ReLU -> 3x3 conv -> concat.

    The 1x1 convolution produces ``4 * n_filters`` feature maps and the 3x3
    convolution reduces them to ``n_filters``. The result is concatenated with
    the layer input along the last axis, so the output carries the input's
    feature maps plus ``n_filters`` new ones.

    Args:
        n_filters: Number of feature maps produced by the 3x3 convolution.
        **kwargs: Standard ``tf.keras.layers.Layer`` arguments such as
            ``name``, ``dtype`` or ``trainable``.
    """

    def __init__(self, n_filters:int, **kwargs):
        # Forward base-layer kwargs so the layer can be named and rebuilt from its config.
        super(BottleneckLayer, self).__init__(**kwargs)
        self.n_filters = n_filters

        self.BN1 = tf.keras.layers.BatchNormalization()
        # 1x1 convolution widening to 4 * n_filters; no bias because BN2 follows it.
        self.Conv1 = tf.keras.layers.Conv2D(filters = self.n_filters * 4,
                                            kernel_size = (1,1),
                                            padding = 'valid',
                                            use_bias = False
                                           )
        self.BN2 = tf.keras.layers.BatchNormalization()
        # 3x3 convolution with 'same' padding keeps the spatial size so the concat below is valid.
        self.Conv2 = tf.keras.layers.Conv2D(filters = self.n_filters,
                                            kernel_size = (3,3),
                                            padding = 'same',
                                            use_bias = False
                                           )

    def call(self, X, training=None):
        """Apply the bottleneck and append the new feature maps to ``X``.

        Args:
            X: Input tensor; concatenation happens on its last axis.
            training: Passed to both batch-normalization layers. ``None``
                defers to Keras' default handling of the training flag.

        Returns:
            ``X`` concatenated with the ``n_filters`` new feature maps.
        """
        y = self.BN1(X, training = training)
        y = tf.nn.relu(y)
        y = self.Conv1(y)

        y = self.BN2(y, training = training)
        y = tf.nn.relu(y)
        y = self.Conv2(y)
        # Dense connectivity: keep the input and stack the new features after it.
        return tf.concat([X, y], axis = -1)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(BottleneckLayer, self).get_config()
        config.update({'n_filters': self.n_filters})
        return config
    
    
class DenseBlock(tf.keras.layers.Layer):
    """A stack of ``n_layers`` ``BottleneckLayer`` units applied in sequence.

    Every bottleneck is created with the same ``n_filters`` and the units are
    chained inside a ``tf.keras.Sequential``.

    Args:
        n_filters: ``n_filters`` passed to every ``BottleneckLayer``.
        n_layers: Number of ``BottleneckLayer`` units in the block.
        **kwargs: Standard ``tf.keras.layers.Layer`` arguments such as
            ``name``, ``dtype`` or ``trainable``.
    """

    def __init__(self, n_filters:int, n_layers:int, **kwargs):
        # Forward base-layer kwargs so the block can be named and rebuilt from its config.
        super(DenseBlock, self).__init__(**kwargs)
        self.n_filters = n_filters
        self.n_layers = n_layers

        self.Blocks = tf.keras.Sequential([
            BottleneckLayer(self.n_filters) for _ in range(self.n_layers)
        ])

    def call(self, X, training=None):
        """Run ``X`` through every bottleneck unit in order.

        Args:
            X: Input tensor.
            training: Passed to the inner ``Sequential``. ``None`` defers to
                Keras' default handling of the training flag.

        Returns:
            The output of the last bottleneck unit.
        """
        return self.Blocks(X, training = training)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(DenseBlock, self).get_config()
        config.update({'n_filters': self.n_filters,
                       'n_layers': self.n_layers})
        return config
    

class TransitionLayer(tf.keras.layers.Layer):
    """Transition unit: BN -> ReLU -> 1x1 conv -> 2x2 average pooling (stride 2).

    The 1x1 convolution sets the number of feature maps to ``out_channel`` and
    the pooling step downsamples with a 2x2 window and stride 2.

    Args:
        out_channel: Number of filters of the 1x1 convolution.
        **kwargs: Standard ``tf.keras.layers.Layer`` arguments such as
            ``name``, ``dtype`` or ``trainable``.
    """

    def __init__(self, out_channel:int, **kwargs):
        # Forward base-layer kwargs so the layer can be named and rebuilt from its config.
        super(TransitionLayer, self).__init__(**kwargs)
        self.out_channel = out_channel

        self.BN1 = tf.keras.layers.BatchNormalization()
        self.Conv1 = tf.keras.layers.Conv2D(self.out_channel,
                                            kernel_size = (1,1),
                                            padding = 'valid',
                                            use_bias = False
                                           )
        self.AP = tf.keras.layers.AveragePooling2D(pool_size = (2,2),
                                                   strides = (2,2)
                                                  )

    def call(self, X, training=None):
        """Apply the 1x1 convolution and then downsample ``X``.

        Args:
            X: Input tensor.
            training: Passed to the batch-normalization layer. ``None``
                defers to Keras' default handling of the training flag.

        Returns:
            The pooled tensor with ``out_channel`` feature maps.
        """
        y = self.BN1(X, training = training)
        y = tf.nn.relu(y)
        y = self.Conv1(y)
        return self.AP(y)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(TransitionLayer, self).get_config()
        config.update({'out_channel': self.out_channel})
        return config

In [162]:
class DenseNet(tf.keras.models.Model):
    """DenseNet-style classifier built from ``DenseBlock`` and ``TransitionLayer``.

    Pipeline: 7x7/stride-2 convolution with 64 filters -> 3x3/stride-2 max
    pooling -> one ``DenseBlock`` per entry of ``n_layers`` (with a
    ``TransitionLayer`` between consecutive blocks, none after the last) ->
    global average pooling -> dense output layer.

    Block ``i`` is created with ``64 + k * i`` filters per bottleneck, and each
    transition halves the channel count reached at the end of its block.

    NOTE(review): the per-bottleneck width grows with the block index instead
    of being one constant growth rate; confirm this is intended.

    Args:
        n_layers: Sequence giving the number of bottleneck units in each dense
            block. Its length sets the number of blocks.
        n_labels: Number of units of the final dense layer.
        k: Per-block increment of the bottleneck filter count.
        output_activation: Activation of the final dense layer.
        **kwargs: Standard ``tf.keras.Model`` arguments such as ``name``.

    Raises:
        ValueError: If ``n_layers`` is empty.
    """

    def __init__(self, n_layers, n_labels:int, k:int = 12 , output_activation:str='softmax', **kwargs):
        super(DenseNet, self).__init__(**kwargs)
        n_blocks = len(n_layers)
        if n_blocks == 0:
            raise ValueError('n_layers must contain at least one dense block size')

        self.k = k
        self.n_layers = n_layers
        # Running channel count. After construction it holds half of the width
        # reached at the end of the last block (no transition follows that block).
        self.output_channels = 64
        self.n_labels = n_labels
        self.output_activation = output_activation

        self.Conv1 = tf.keras.layers.Conv2D(64,
                                            kernel_size = (7,7),
                                            strides = (2,2),
                                            padding = 'same'
                                           )
        self.MP1 = tf.keras.layers.MaxPool2D(pool_size = (3,3),
                                             strides = (2,2),
                                             padding = 'same'
                                            )
        self.Blocks = tf.keras.Sequential()
        # One dense block per entry of n_layers instead of a fixed count of four,
        # so extra entries are no longer ignored and shorter lists no longer fail.
        for i in range(n_blocks):
            block_depth = n_layers[i]
            block_filters = 64 + self.k * i
            # Channels after the block are input + depth * filters; keep half of them.
            self.output_channels = int((self.output_channels + block_depth * block_filters)/2)
            self.Blocks.add(DenseBlock(block_filters, block_depth))
            if i != n_blocks - 1:
                self.Blocks.add(TransitionLayer(self.output_channels))

        self.classifier = tf.keras.Sequential([
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(self.n_labels, activation = self.output_activation)
        ])

    def call(self, X, training=None):
        """Run the full network on ``X``.

        Args:
            X: Input tensor fed to the first convolution.
            training: Passed to the dense-block stack and the classifier.
                ``None`` defers to Keras' default handling of the training flag.

        Returns:
            The output of the final dense layer, with ``n_labels`` units.
        """
        y = self.Conv1(X)
        y = self.MP1(y)
        y = self.Blocks(y, training = training)
        return self.classifier(y, training = training)

In [163]:
# Four dense blocks with 6, 12, 24 and 16 bottleneck units, and 1000 output labels.
densenet = DenseNet(n_layers=[6, 12, 24, 16], n_labels=1000)

In [164]:
# Create the weights for a fixed input shape; presumably (batch, height, width, channels) — confirm.
densenet.build(input_shape=[16, 224, 224, 3])

In [165]:
# Print the per-layer table of output shapes and parameter counts for the built model.
densenet.summary()

Model: "dense_net_60"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_3666 (Conv2D)         multiple                  9472      
_________________________________________________________________
max_pooling2d_59 (MaxPooling multiple                  0         
_________________________________________________________________
sequential_186 (Sequential)  (16, 7, 7, 2940)          49996320  
_________________________________________________________________
sequential_191 (Sequential)  (16, 1000)                2941000   
Total params: 52,946,792
Trainable params: 52,737,384
Non-trainable params: 209,408
_________________________________________________________________
