In [1]:
import tensorflow as tf

In [12]:
class Convolution_layer(tf.keras.layers.Layer):
    def __init__(self, n_conv:int,
                 n_filters:int):
        super(Convolution_layer, self).__init__()
        self.n_conv = n_conv
        self.n_filters = n_filters
        
        self.Conv = [tf.keras.layers.Conv2D(self.n_filters,
                                            3, 
                                            padding = 'same',
                                            activation = 'relu')
                    for _ in range(self.n_conv)]
        self.MP = tf.keras.layers.MaxPool2D(2,
                                            2, 
                                            padding = 'same',)
        
    def call(self, X):
        for conv in self.Conv:
            y = conv(X)
        y = self.MP(y)
        return y

In [13]:
class VGGNet16(tf.keras.models.Model):
    def __init__(self,
                n_labels:int,
                last_activation:str):
        super(VGGNet16, self).__init__()
        self.n_labels = n_labels
        self.last_activation = last_activation
        
        self.Conv = tf.keras.Sequential([
            Convolution_layer(2, 64),
            Convolution_layer(2, 128),
            Convolution_layer(3, 256),
            Convolution_layer(3, 512),
            Convolution_layer(3, 512)
        ])
        self.FC = tf.keras.Sequential([
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(4096, activation = 'relu'),
            tf.keras.layers.Dropout(.5),
            tf.keras.layers.Dense(4096, activation = 'relu'),
            tf.keras.layers.Dropout(.5),
            tf.keras.layers.Dense(self.n_labels, activation = self.last_activation)
        ])
    
    def call(self, X):
        y = self.Conv(X)
        y = self.FC(y)
        return y

In [14]:
vgg = VGGNet16(10, 'softmax')

In [15]:
vgg.build([16,224,224,3])

In [16]:
vgg.summary()

Model: "vgg_net16_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential_4 (Sequential)    (16, 7, 7, 512)           11656704  
_________________________________________________________________
sequential_5 (Sequential)    (16, 10)                  119586826 
Total params: 131,243,530
Trainable params: 131,243,530
Non-trainable params: 0
_________________________________________________________________
