In [1]:
import tensorflow as tf
import tensorflow.keras.layers as layers
import tensorflow.keras.models as models
import tensorflow.keras.optimizers as optim

In [None]:
class AlexNet(models.Model):
    """AlexNet classifier for 227x227 RGB images.

    Five convolutional layers (conv1, conv2 and conv5 followed by max-pooling)
    and three fully connected layers (the first two with 50% dropout), ending
    in a softmax over ``output_dim`` classes.

    Args:
        output_dim: Number of output classes (units of the final Dense layer).

    Inputs to ``call`` are expected as ``(batch, 227, 227, 3)`` tensors,
    assuming Keras' default ``channels_last`` data format; a different spatial
    size changes the flattened feature length fed to ``linear1``.
    """

    def __init__(self, output_dim):
        """Create the layers of the network."""
        super().__init__()

        # Keras Conv2D only accepts the strings 'valid' / 'same' for padding.
        # The intended symmetric zero paddings are: 0 for conv1 ('valid'),
        # 2 for the 5x5 conv2 and 1 for the 3x3 conv3-5 -- for stride-1, odd
        # kernels that is exactly what 'same' produces.
        self.conv1 = layers.Conv2D(filters=96, kernel_size=11, strides=4,
                                   padding='valid', activation='relu')  # (BS, 55, 55, 96)
        self.pool1 = layers.MaxPooling2D(pool_size=3, strides=2)  # (BS, 27, 27, 96)
        self.conv2 = layers.Conv2D(filters=256, kernel_size=5, strides=1,
                                   padding='same', activation='relu')  # (BS, 27, 27, 256)
        self.pool2 = layers.MaxPooling2D(pool_size=3, strides=2)  # (BS, 13, 13, 256)
        self.conv3 = layers.Conv2D(filters=384, kernel_size=3, strides=1,
                                   padding='same', activation='relu')  # (BS, 13, 13, 384)
        self.conv4 = layers.Conv2D(filters=384, kernel_size=3, strides=1,
                                   padding='same', activation='relu')  # (BS, 13, 13, 384)
        self.conv5 = layers.Conv2D(filters=256, kernel_size=3, strides=1,
                                   padding='same', activation='relu')  # (BS, 13, 13, 256)
        self.pool3 = layers.MaxPooling2D(pool_size=3, strides=2)  # (BS, 6, 6, 256)
        self.flatten = layers.Flatten()  # (BS, 9216)
        self.linear1 = layers.Dense(units=4096, activation='relu')
        self.dr1 = layers.Dropout(rate=0.5)
        self.linear2 = layers.Dense(units=4096, activation='relu')
        self.dr2 = layers.Dropout(rate=0.5)
        # Named linear3 rather than `outputs`: assigning `self.outputs`
        # (or `self.inputs`) would clash with keras.Model's own attributes.
        self.linear3 = layers.Dense(units=output_dim, activation='softmax')

    def call(self, x, training=None):
        """Run the forward pass.

        Args:
            x: Image batch, presumably shaped (batch, 227, 227, 3).
            training: Forwarded to the dropout layers; when None, Keras
                resolves it from the surrounding call context (e.g. ``fit``).

        Returns:
            Softmax class probabilities of shape (batch, output_dim).
        """
        hidden = self.pool1(self.conv1(x))
        hidden = self.pool2(self.conv2(hidden))
        hidden = self.conv3(hidden)
        hidden = self.conv4(hidden)
        hidden = self.pool3(self.conv5(hidden))
        hidden = self.flatten(hidden)
        hidden = self.dr1(self.linear1(hidden), training=training)
        hidden = self.dr2(self.linear2(hidden), training=training)
        return self.linear3(hidden)
        
        