In [None]:
# alexnet.py

import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import (
    Conv2D,
    MaxPooling2D,
    Flatten,
    Dense,
    Dropout,
    BatchNormalization,
    Activation,
    Input,
)
from tensorflow.keras.regularizers import l2


class AlexNet(Sequential):
    """AlexNet-style image classifier assembled as a Keras ``Sequential`` model.

    Feature extractor: five Conv2D layers, each followed by BatchNormalization
    and ReLU, with 3x3 / stride-2 max pooling after the 1st, 2nd and 5th conv.
    Classifier: Flatten, two 4096-unit ReLU Dense layers each followed by 0.5
    dropout, and a softmax output layer with ``num_classes`` units.
    """

    def __init__(
        self,
        input_shape=(224, 224, 3),
        num_classes=1000,
        first_conv_stride=2,
        weight_decay=5e-4,
        name="AlexNet_modern",
    ):
        """Build the full layer stack.

        Args:
            input_shape: tuple, e.g. (224, 224, 3); presumably
                (height, width, channels) under Keras' default channels-last
                image data format.
            num_classes: int, number of output classes (units of the softmax
                output layer).
            first_conv_stride: int, stride of the first 11x11 Conv2D (the
                original AlexNet used 4). The default of 2 preserves more
                spatial information, which also makes the flattened feature map,
                and therefore the first Dense layer, larger. Set to 4 to match
                the original downsampling.
            weight_decay: float, L2 regularization factor applied to every
                Conv2D kernel and to the two hidden Dense kernels. The softmax
                output layer is not regularized. ``None`` or ``0`` disables
                regularization entirely.
            name: str, name given to the Keras model.
        """
        super().__init__(name=name)

        # A falsy weight_decay (None / 0) disables regularization. Handing None
        # straight to l2() would instead silently fall back to Keras' default
        # factor of 0.01.
        reg = l2(weight_decay) if weight_decay else None
        kernel_init = "he_normal"

        def conv_bn_relu(filters, kernel_size, strides=1, padding="same"):
            # Append Conv2D -> BatchNormalization -> ReLU to the model.
            # use_bias=False because the following BatchNormalization learns
            # its own offset, making a conv bias redundant.
            self.add(
                Conv2D(
                    filters=filters,
                    kernel_size=kernel_size,
                    strides=strides,
                    padding=padding,
                    kernel_initializer=kernel_init,
                    kernel_regularizer=reg,
                    use_bias=False,
                )
            )
            self.add(BatchNormalization())
            self.add(Activation("relu"))

        def max_pool():
            # Overlapping 3x3 pooling window with stride 2.
            self.add(MaxPooling2D(pool_size=(3, 3), strides=2, padding="same"))

        # Declare the input explicitly rather than passing `input_shape=` to the
        # first Conv2D; recent Keras versions warn about that layer kwarg.
        self.add(Input(shape=input_shape))

        # 1) First conv block: large 11x11 kernel with "valid" padding.
        conv_bn_relu(96, (11, 11), strides=first_conv_stride, padding="valid")
        max_pool()

        # 2) Second conv block
        conv_bn_relu(256, (5, 5))
        max_pool()

        # 3) Third, fourth, fifth conv layers; pooling only after the fifth.
        conv_bn_relu(384, (3, 3))
        conv_bn_relu(384, (3, 3))
        conv_bn_relu(256, (3, 3))
        max_pool()

        # 4) Classifier: two regularized hidden Dense layers with dropout.
        self.add(Flatten())
        for _ in range(2):
            self.add(
                Dense(
                    4096,
                    activation="relu",
                    kernel_initializer=kernel_init,
                    kernel_regularizer=reg,
                )
            )
            self.add(Dropout(0.5))

        # Output layer: intentionally left without a kernel_regularizer.
        self.add(Dense(num_classes, activation="softmax", kernel_initializer=kernel_init))


if __name__ == "__main__":
    # Smoke check: build the default configuration and print its layer table.
    # Runs the same in Colab, VS Code, Spyder or a plain terminal.
    # Use first_conv_stride=4 instead to reproduce the original AlexNet's
    # first-layer downsampling.
    model = AlexNet(
        input_shape=(224, 224, 3),
        num_classes=1000,
        first_conv_stride=2,
    )
    model.summary()


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
