# GoogLeNet

This notebook accompanies the blog post [GoogLeNet: Going Deeper with Inception](https://derekzhouai.github.io/posts/googlenet/).

In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt

## Model Implementation

In [None]:
class Inception(nn.Module):
    """Inception block: four parallel branches joined along the channel axis.

    Every branch keeps the spatial size of its input (stride 1 with matching
    padding), so the branch outputs can be concatenated on dim 1. The block
    emits ``out_1_1x1 + out_2_3x3 + out_3_5x5 + out_4_pool`` channels.

    Args:
        in_channels: Channels of the incoming feature map.
        out_1_1x1: Output channels of the plain 1x1 branch.
        out_2_1x1: Channels of the 1x1 reduction placed before the 3x3 conv.
        out_2_3x3: Output channels of the 3x3 branch.
        out_3_1x1: Channels of the 1x1 reduction placed before the 5x5 conv.
        out_3_5x5: Output channels of the 5x5 branch.
        out_4_pool: Output channels of the 1x1 conv that follows the max pool.
    """

    def __init__(self, in_channels, out_1_1x1, out_2_1x1, out_2_3x3, out_3_1x1, out_3_5x5, out_4_pool):
        super().__init__()

        def conv_relu(c_in, c_out, k):
            # Size-preserving conv (padding = k // 2 for odd k) followed by ReLU.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=k, padding=k // 2),
                nn.ReLU(inplace=True),
            ]

        # Branch 1: 1x1 conv
        self.branch1 = nn.Sequential(*conv_relu(in_channels, out_1_1x1, 1))

        # Branch 2: 1x1 reduction -> 3x3 conv
        self.branch2 = nn.Sequential(
            *conv_relu(in_channels, out_2_1x1, 1),
            *conv_relu(out_2_1x1, out_2_3x3, 3),
        )

        # Branch 3: 1x1 reduction -> 5x5 conv
        self.branch3 = nn.Sequential(
            *conv_relu(in_channels, out_3_1x1, 1),
            *conv_relu(out_3_1x1, out_3_5x5, 5),
        )

        # Branch 4: 3x3 max pool (stride 1) -> 1x1 conv
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            *conv_relu(in_channels, out_4_pool, 1),
        )

    def forward(self, x):
        """Run all four branches on ``x`` and concatenate them on the channel axis."""
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], dim=1)

In [None]:
class GoogLeNet(nn.Module):
    """GoogLeNet classifier assembled as one flat ``nn.Sequential`` (``self.net``).

    Two convolutional stages and three groups of Inception blocks are followed
    by global average pooling and a linear classifier. This implementation has
    no auxiliary classifiers. Shape comments below assume a 224x224 input.

    Args:
        input_channels: Channels of the input image.
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, input_channels=3, num_classes=1000):
        super().__init__()

        def halve():
            # 3x3 max pool, stride 2: halves H and W for the sizes noted below.
            return nn.MaxPool2d(3, stride=2, padding=1)

        # Stage 1: input_channelsx224x224 → 64x112x112 → 64x56x56
        stage1 = [
            nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3),
            nn.ReLU(inplace=True),
            halve(),
        ]

        # Stage 2: 64x56x56 → 64x56x56 → 192x56x56 → 192x28x28
        stage2 = [
            nn.Conv2d(64, 64, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            halve(),
        ]

        # Stage 3: 192x28x28 → 256x28x28 → 480x28x28 → 480x14x14
        stage3_cfg = [
            (192, 64, 96, 128, 16, 32, 32),
            (256, 128, 128, 192, 32, 96, 64),
        ]
        stage3 = [Inception(*cfg) for cfg in stage3_cfg]
        stage3.append(halve())

        # Stage 4: 480x14x14 → 512 → 512 → 512 → 528 → 832x14x14 → 832x7x7
        stage4_cfg = [
            (480, 192, 96, 208, 16, 48, 64),
            (512, 160, 112, 224, 24, 64, 64),
            (512, 128, 128, 256, 24, 64, 64),
            (512, 112, 144, 288, 32, 64, 64),
            (528, 256, 160, 320, 32, 128, 128),
        ]
        stage4 = [Inception(*cfg) for cfg in stage4_cfg]
        stage4.append(halve())

        # Stage 5: 832x7x7 → 832x7x7 → 1024x7x7 → 1024x1x1
        stage5_cfg = [
            (832, 256, 160, 320, 32, 128, 128),
            (832, 384, 192, 384, 48, 128, 128),
        ]
        stage5 = [Inception(*cfg) for cfg in stage5_cfg]
        stage5.append(nn.AdaptiveAvgPool2d((1, 1)))

        # Classifier: 1024x1x1 → 1024 → num_classes
        head = [nn.Flatten(), nn.Linear(1024, num_classes)]

        # Keep the container flat so callers can iterate `net` layer by layer.
        self.net = nn.Sequential(*stage1, *stage2, *stage3, *stage4, *stage5, *head)

    def forward(self, x):
        """Return the class scores produced by ``self.net`` for ``x``."""
        return self.net(x)

In [None]:
# Shape check: push one random 1x3x224x224 tensor through a freshly built
# GoogLeNet one layer at a time and print the output shape after each layer.
X = torch.randn(1, 3, 224, 224)
for layer in GoogLeNet().net:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape:\t', X.shape)

## Example: Model Training on FashionMNIST

### Preparing the Dataset

In [None]:
def get_data_loaders(batch_size=128):
    """Build the FashionMNIST train/test ``DataLoader`` pair.

    Images are resized to 224 and converted to tensors. Both splits are read
    from ``./data`` with ``download=True``. Only the training loader shuffles.

    Args:
        batch_size: Samples per batch, used for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    # Upscale 28x28 → 224x224, then convert to a tensor.
    preprocess = transforms.Compose([transforms.Resize(224), transforms.ToTensor()])

    def make_loader(is_train):
        # The split flag doubles as the shuffle flag: shuffle training data only.
        split = datasets.FashionMNIST(root='./data', train=is_train, download=True, transform=preprocess)
        return DataLoader(split, batch_size=batch_size, shuffle=is_train)

    return make_loader(True), make_loader(False)

In [None]:
# Build the loaders once at notebook scope and report the dataset sizes.
# `test_loader` is reused by the prediction-visualisation cell at the end;
# train() builds its own loaders internally.
train_loader, test_loader = get_data_loaders(batch_size=128)
print(f"Number of training samples: {len(train_loader.dataset)}")
print(f"Number of test samples: {len(test_loader.dataset)}")

In [None]:
# Pull a single batch from the training loader and print the tensor shapes.
# With batch_size=128 and the 224 resize, X is expected to be
# (128, 1, 224, 224) and y (128,) — FashionMNIST images have 1 channel.
X, y = next(iter(train_loader))
print(f"X.shape: {X.shape}")
print(f"y.shape: {y.shape}")

### Training the Model

In [None]:
def evaluate(model, loader, loss, device):
    """Compute the sample-weighted mean loss and the accuracy of ``model``.

    The model is switched to eval mode (and left there) and gradients are
    disabled while iterating over ``loader``.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        loss: Callable ``loss(scores, labels)`` returning a scalar tensor; its
            value is weighted by the batch size when averaging.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen.
    """
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            count = inputs.size(0)

            scores = model(inputs)
            batch_loss = loss(scores, targets)

            # Weight by batch size so a short final batch is not over-counted.
            loss_sum += batch_loss.item() * count
            num_correct += scores.argmax(dim=1).eq(targets).sum().item()
            num_seen += count

    mean_loss = loss_sum / num_seen
    accuracy = num_correct / num_seen
    return mean_loss, accuracy

In [None]:
def train(model, num_epochs, batch_size, lr, device):
    """Train ``model`` with SGD and cross-entropy, printing metrics each epoch.

    The loaders come from ``get_data_loaders(batch_size)``. After every epoch
    the running training loss/accuracy and the test loss/accuracy (computed by
    ``evaluate``) are printed. The model is modified in place.

    Args:
        model: Module to optimise; moved to ``device`` first.
        num_epochs: Number of passes over the training loader.
        batch_size: Batch size forwarded to ``get_data_loaders``.
        lr: Learning rate for ``torch.optim.SGD``.
        device: Device used for the model and every batch.
    """
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    train_loader, test_loader = get_data_loaders(batch_size)

    for epoch in range(num_epochs):
        # evaluate() leaves the model in eval mode, so re-enable train mode here.
        model.train()
        running_loss = 0.0
        running_correct = 0
        seen = 0

        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            count = inputs.size(0)

            optimizer.zero_grad()
            scores = model(inputs)
            batch_loss = criterion(scores, targets)
            batch_loss.backward()
            optimizer.step()

            # Training metrics are accumulated on the fly, batch by batch.
            running_loss += batch_loss.item() * count
            running_correct += scores.argmax(dim=1).eq(targets).sum().item()
            seen += count

        train_loss = running_loss / seen
        train_acc = running_correct / seen
        test_loss, test_acc = evaluate(model, test_loader, criterion, device)

        print(f"Epoch {epoch+1}/{num_epochs}: "
              f"Train => Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | "
              f"Test => Loss: {test_loss:.4f}, Acc: {test_acc:.4f}")

In [None]:
# Pick the first available backend: CUDA, then Apple MPS, then CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available()
    else torch.device("mps") if torch.backends.mps.is_available()
    else torch.device("cpu")
)

# Use the GoogLeNet defined in this notebook: FashionMNIST has 1 channel and
# 10 classes. torchvision's models.GoogLeNet is not a drop-in replacement
# here: its stem expects 3-channel input and, with the default
# aux_logits=True, it returns a GoogLeNetOutputs tuple in training mode,
# which train() would pass straight to the loss and to argmax.
model = GoogLeNet(input_channels=1, num_classes=10)
num_epochs = 10
batch_size = 128
lr = 0.1

train(model, num_epochs, batch_size, lr, device)

### Evaluating the Model

In [None]:
# FashionMNIST class names, indexed by integer label: position i holds the
# display name used for label / predicted class index i.
FASHION_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"
]

def show_fashionmnist_preds(model, test_loader, device, n=8):
    """
    Show up to n FashionMNIST test images with predicted and true labels.

    Only the first batch of ``test_loader`` is used. If that batch holds fewer
    than ``n`` samples, every sample in it is shown instead of failing.

    Args:
        model: Classifier returning per-class scores; moved to ``device`` and
            put in eval mode.
        test_loader: Loader yielding ``(images, labels)`` batches.
        device: Device on which the forward pass runs.
        n: Maximum number of images to display in a single row.
    """
    model.to(device)
    model.eval()
    with torch.no_grad():
        X, y = next(iter(test_loader))
        # Clamp to the batch size: slicing past the end silently shortens X,
        # and the plotting loop below would then index out of range.
        n = min(n, X.size(0))
        X, y = X[:n], y[:n]
        pred = model(X.to(device)).argmax(1).cpu()

    plt.figure(figsize=(2*n, 2.6))
    for i in range(n):
        plt.subplot(1, n, i + 1)
        # Drop the singleton channel dimension so imshow receives an HxW image.
        img = X[i].squeeze().cpu()
        plt.imshow(img, cmap="gray")
        p_idx, t_idx = pred[i].item(), y[i].item()
        plt.title(f"P:{FASHION_CLASSES[p_idx]}\nT:{FASHION_CLASSES[t_idx]}", fontsize=9)
        plt.axis("off")
    plt.tight_layout()
    plt.show()

# Plot predictions for the first 8 images of the first test batch (the test
# loader is built with shuffle=False, so the selection is the same every run).
show_fashionmnist_preds(model, test_loader, device, n=8)