In [1]:
from typing import List

import tensorflow as tf
from tensorflow.keras.layers import Input, Add, Activation, Concatenate, Dropout, BatchNormalization, Dense, Conv2D, MaxPool2D, AveragePooling2D, GlobalMaxPooling2D
from tensorflow.keras.models import Model

In [2]:
class BottleneckBlock(Model):
    """Bottleneck residual block (reference: ResNet), pre-activation style.

    Residual branch: BN -> ReLU -> 1x1 conv -> BN -> ReLU -> 3x3 conv
    -> BN -> ReLU -> 1x1 conv. Its output is added element-wise to the
    shortcut branch. All convolutions use stride 1, so the spatial size is
    preserved.

    arguments
    ----------
    in_channel(int) : number of channels of the input tensor
    out_channel(int) : number of channels of the output tensor

    Typically in_channel equals out_channel. When they differ, the shortcut
    uses a 1x1 convolution so both branches have out_channel channels and
    can be added.
    """
    def __init__(self, in_channel, out_channel):
        """Initialize and define layers.

        arguments
        ----------
        in_channel(int) : input tensor channel size
        out_channel(int) : output tensor channel size

        params
        ----------
        hidden_channel : filter count of the input and center convolutions
                         (in_channel // 4, but at least 1)
        bn : tf.keras.layers BatchNormalization
        av : tf.keras.layers Activation
        in_conv : input (1x1) convolution layer
        hidden_conv : center (3x3) convolution layer
        out_conv : output (1x1) convolution layer
        shortcut : shortcut path (input to output)
        add : tf.keras.layers Add
        """

        super().__init__()

        # Bottleneck width is a quarter of the input width. Clamp to 1 so a
        # narrow input (in_channel < 4) does not produce a 0-filter Conv2D,
        # which Keras rejects.
        self.hidden_channel = max(1, in_channel // 4)
        self.bn1 = BatchNormalization()
        self.bn2 = BatchNormalization()
        self.bn3 = BatchNormalization()
        self.av1 = Activation(tf.nn.relu)
        self.av2 = Activation(tf.nn.relu)
        self.av3 = Activation(tf.nn.relu)

        self.in_conv = Conv2D(filters=self.hidden_channel, kernel_size=1, strides=1, padding='valid')
        self.hidden_conv = Conv2D(filters=self.hidden_channel, kernel_size=3, strides=1, padding='same')
        self.out_conv = Conv2D(filters=out_channel, kernel_size=1, strides=1, padding='valid')

        self.shortcut = self._scblock(in_channel, out_channel)
        self.add = Add()


    def _scblock(self, in_channel, out_channel):
        """Build the shortcut path.

        arguments
        ----------
        in_channel(int) : input tensor channel size
        out_channel(int) : output tensor channel size

        params (only when in_channel != out_channel)
        ----------
        bn_sc : tf.keras.layers BatchNormalization
        conv_sc : tf.keras.layers Conv2D

        return
        ----------
        conv_sc : if in_channel != out_channel, a 1x1 convolution that maps
                  the input to out_channel channels
        identity : if in_channel == out_channel, a callable that returns its
                   input unchanged
        """

        if in_channel != out_channel:
            # NOTE(review): bn_sc is created but not applied on the shortcut
            # path. It is kept so the attribute still exists; confirm whether
            # the projection was meant to be conv_sc(bn_sc(x)).
            self.bn_sc = BatchNormalization()
            # 1x1 projection so the shortcut matches the residual branch's
            # channel count and Add() can combine them.
            self.conv_sc = Conv2D(filters=out_channel, kernel_size=1, strides=1, padding='same')
            return self.conv_sc

        return lambda x: x


    def call(self, x, training=None):
        """Run the Bottleneck Block.

        arguments
        ----------
        x : input tensor with in_channel channels in the last axis
        training(bool or None) : forwarded to the BatchNormalization layers.
                                 None lets Keras resolve the phase from the
                                 calling context.

        return
        ----------
        residual branch output + shortcut(x), with out_channel channels
        """
        inputs = self.in_conv(self.av1(self.bn1(x, training=training)))
        hidden = self.hidden_conv(self.av2(self.bn2(inputs, training=training)))
        outputs = self.out_conv(self.av3(self.bn3(hidden, training=training)))
        shortcut = self.shortcut(x)
        outputs = self.add([outputs, shortcut])

        return outputs

In [3]:
class DenseBlock(Model):
    """Dense Block (reference: DenseNet), bottleneck composite layer.

    Computes BN -> ReLU -> 1x1 conv -> BN -> ReLU -> 3x3 conv and
    concatenates the result onto the input along the last axis.

    define : inputs_channel=n, growth_rate=k.
             outputs_channel=n+k

    example : input_shape(None, 28, 28, 16)
              growth_rate=16
              output_shape(None, 28, 28, 32)
    """

    def __init__(self, growth_rate, bottleneck_channel=128):
        """Initialize and define layers.

        arguments
        ----------
        :growth_rate : number of channels appended to the input; the output
                       channel count equals the input channel count plus this
        :bottleneck_channel : filter count of the intermediate 1x1 convolution
                              (default 128, the previously hard-coded width)

        params
        ----------
        :bn : tf.keras.layers BatchNormalization
        :av : tf.keras.layers Activation
        :in_conv : input (1x1 bottleneck) convolution layer
        :out_conv : output (3x3) convolution layer producing growth_rate channels
        :concat : tf.keras.layers Concatenate (last axis)
        """

        super().__init__()

        self.bn1 = BatchNormalization()
        self.bn2 = BatchNormalization()
        self.av1 = Activation(tf.nn.relu)
        self.av2 = Activation(tf.nn.relu)
        self.in_conv = Conv2D(filters=bottleneck_channel, kernel_size=1, strides=1, padding='same')
        self.out_conv = Conv2D(filters=growth_rate, kernel_size=3, strides=1, padding='same')
        self.concat = Concatenate()


    def call(self, x, training=None):
        """Run the Dense Block.

        arguments
        ----------
        x : input tensor; its last axis is the channel axis that gets extended
        training(bool or None) : forwarded to the BatchNormalization layers.
                                 None lets Keras resolve the phase from the
                                 calling context.

        return
        ----------
        concatenation of x and the new growth_rate feature maps (last axis)
        """
        inputs = self.in_conv(self.av1(self.bn1(x, training=training)))
        outputs = self.out_conv(self.av2(self.bn2(inputs, training=training)))
        outputs = self.concat([x, outputs])

        return outputs