<a href="https://colab.research.google.com/github/mjmousavi97/Deep-Learning-Tehran-uni/blob/main/HomeWorks/03%20HW/src/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class SCNNB(nn.Module):
    """
    Shallow Convolutional Neural Network with Batch Normalization (SCNNB).

    Architecture (shown for a 1x28x28 input such as MNIST):
        - Conv(3x3, 32 filters, padding=1) + BN + ReLU
        - MaxPool(2x2)
        - Conv(3x3, 64 filters, padding=1) + BN + ReLU
        - MaxPool(2x2)
        - Flatten
        - FC(feat_dim -> 1280) + ReLU   (feat_dim = 3136 for MNIST)
        - Dropout(0.5)
        - FC(1280 -> num_classes)

    Notes:
        - Padding=1 keeps the spatial size unchanged through each conv; each
          2x2 pooling layer then halves it (floor division).
        - For MNIST (1x28x28) -> Flatten=3136 (64*7*7).
        - For CIFAR-10 (3x32x32) -> Flatten=4096 (64*8*8).
        - fc1 is sized from ``input_size`` at construction time; inputs whose
          pooled feature map has a different size will not match fc1.
        - Forward returns raw logits; use nn.CrossEntropyLoss for training.
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10, input_size: int = 28):
        """
        Initialize the SCNNB model.

        Args:
            in_channels (int): Number of input channels (1 for grayscale, 3 for RGB).
            num_classes (int): Number of output classes.
            input_size (int): Height/Width of the input image (assumed square).
        """
        super().__init__()
        self.input_size = input_size
        self.in_channels = in_channels
        self.num_classes = num_classes

        # ---- Feature extractor ----
        # Conv bias is disabled because the following BatchNorm has its own shift.
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # ---- Classifier ----
        feat_dim = self._calc_feat_dim()  # Flattened feature size for input_size x input_size inputs
        self.fc1 = nn.Linear(feat_dim, 1280)
        self.drop1 = nn.Dropout(p=0.5)
        self.fc_out = nn.Linear(1280, num_classes)

    def _features(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the two conv -> BN -> ReLU -> max-pool stages to ``x`` (N, C, H, W)."""
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        return x

    def _calc_feat_dim(self) -> int:
        """
        Compute the number of features per sample after the convolution + pooling layers.

        A single zero image of shape (1, in_channels, input_size, input_size) is
        pushed through the feature extractor to measure the flattened size.

        Returns:
            int: Flattened feature dimension of one sample.
        """
        # Probe in eval mode: in training mode BatchNorm would fold the dummy
        # batch's statistics into running_mean / running_var and increment
        # num_batches_tracked, corrupting the freshly initialised buffers.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                ref = next(self.parameters())  # Match the device/dtype of the model
                probe = torch.zeros(
                    (1, self.in_channels, self.input_size, self.input_size),
                    device=ref.device,
                    dtype=ref.dtype,
                )
                return self._features(probe)[0].numel()
        finally:
            # Restore the original mode for every submodule.
            self.train(was_training)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of SCNNB.

        Args:
            x (torch.Tensor): Input tensor of shape (N, C, H, W).

        Returns:
            torch.Tensor: Logits of shape (N, num_classes).
        """
        # ---- Feature extraction ----
        x = self._features(x)

        # ---- Classification head ----
        x = torch.flatten(x, start_dim=1)   # Flatten to (N, feat_dim)
        x = F.relu(self.fc1(x))             # Fully connected hidden layer
        x = self.drop1(x)                   # Dropout for regularization
        logits = self.fc_out(x)             # Output logits (no softmax)
        return logits
