# Multi-Branch Networks (GoogLeNet)
- Minh Nguyen
- 1/25/2025
- Idea: The core innovation of GoogLeNet is the Inception module, which allows the network to efficiently capture features at multiple scales. Each module consists of parallel convolutional layers with different filter sizes (1x1, 3x3, 5x5) and a max pooling operation. Their outputs are concatenated, providing a richer feature representation.
- Key Features:
    - Combines ideas from NiN (1x1 convolutions), repeated blocks, and a cocktail of convolution kernel sizes.
    - Design pattern: the stem is given by the first 2 or 3 convolutions that operate on the image; they extract low-level features from the underlying images. This is followed by a body of convolutional blocks. Finally, the head maps the features obtained so far to the required classification, segmentation, detection, or tracking problem at hand.

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchsummary import summary

# -----------------------------
# Inception Module (Simplified)
# -----------------------------
class SimpleInception(nn.Module):
    """Simplified Inception block with four parallel branches.

    Every branch produces ``out_channels`` feature maps and preserves the
    spatial size. The branch outputs are concatenated along the channel
    axis, so the block emits ``4 * out_channels`` channels.

    Branches:
        * 1x1 convolution
        * 3x3 convolution (padding 1)
        * 5x5 convolution (padding 2)
        * 3x3 max pooling (stride 1, padding 1) followed by a 1x1 convolution

    As in GoogLeNet, each branch output passes through a ReLU, so every
    branch is a conv + nonlinearity rather than a purely linear map.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Channels produced by each branch (default 16, i.e.
            64 output channels in total).
    """

    def __init__(self, in_channels, out_channels=16):
        super(SimpleInception, self).__init__()
        # The module layout (and therefore the state_dict keys) is kept
        # unchanged so previously saved weights still load.

        # 1x1 convolution branch
        self.branch1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        # 3x3 convolution branch; padding=1 keeps height/width unchanged
        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        )

        # 5x5 convolution branch; padding=2 keeps height/width unchanged
        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=5, padding=2)
        )

        # Max pooling branch (stride 1, so no downsampling) followed by a
        # 1x1 convolution that sets the channel count
        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
        )

    def forward(self, x):
        """Run all branches on ``x`` and concatenate them along dim 1.

        Args:
            x: Input batch of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, 4 * out_channels, H, W).
        """
        # ReLU is applied functionally (not as nn.ReLU modules) so that the
        # registered submodules stay exactly as before.
        branch1x1 = torch.relu(self.branch1x1(x))
        branch3x3 = torch.relu(self.branch3x3(x))
        branch5x5 = torch.relu(self.branch5x5(x))
        branch_pool = torch.relu(self.branch_pool(x))

        # Concatenate all branches along the channel dimension
        return torch.cat([branch1x1, branch3x3, branch5x5, branch_pool], dim=1)

# -----------------------------
# Simple Network with Inception
# -----------------------------
class SimpleInceptionNet(nn.Module):
    """Small classifier built around two simplified Inception blocks.

    Layout: 3x3 stem conv + ReLU -> Inception -> 2x2 max-pool ->
    Inception -> 2x2 max-pool -> flatten -> linear head.

    The defaults (1 input channel, 10 classes, 28x28 images) reproduce the
    original FashionMNIST configuration exactly.

    Args:
        in_channels: Channels of the input images.
        num_classes: Number of output logits.
        image_size: Height/width of the square input images. The two
            stride-2 max-pools shrink it to ``image_size // 4``, which
            determines the input size of the linear head.

    Raises:
        ValueError: If ``image_size`` is too small to survive both pools.
    """

    def __init__(self, in_channels=1, num_classes=10, image_size=28):
        super(SimpleInceptionNet, self).__init__()
        if image_size < 4:
            raise ValueError(
                f"image_size must be at least 4 to survive two 2x2 pools, got {image_size}"
            )
        # Each SimpleInception concatenates four 16-channel branches -> 64.
        inception_channels = 64
        # Two stride-2 max-pools floor-divide the spatial size by 4 overall.
        pooled_size = image_size // 4

        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)  # Stem
        self.inception1 = SimpleInception(16)  # First Inception block
        self.inception2 = SimpleInception(inception_channels)  # Second Inception block
        # Classification head over the flattened final feature maps
        self.fc = nn.Linear(inception_channels * pooled_size * pooled_size, num_classes)

    def forward(self, x):
        """Return class logits for a batch of shape (N, in_channels, H, W)."""
        x = torch.relu(self.conv1(x))
        x = self.inception1(x)
        x = nn.functional.max_pool2d(x, kernel_size=2, stride=2)  # Downsample
        x = self.inception2(x)
        x = nn.functional.max_pool2d(x, kernel_size=2, stride=2)  # Downsample
        # flatten (unlike view) also works on non-contiguous tensors
        x = torch.flatten(x, 1)
        return self.fc(x)

# -----------------------------
# Data Loading and Preparation
# -----------------------------
# Convert images to tensors, then standardize the single channel with
# mean 0.5 and std 0.5, i.e. x -> (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Both splits live under ./data; download=True fetches them if missing.
train_dataset = datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = datasets.FashionMNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Only the training split is shuffled; evaluation aggregates over all
# batches, so its order is irrelevant.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# -----------------------------
# View Model Architecture
# -----------------------------
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = SimpleInceptionNet()
model = model.to(device)

# Per-layer table (output shapes and parameter counts) for one
# single-channel 28x28 input.
summary(model, input_size=(1, 28, 28))

# -----------------------------
# Training Setup
# -----------------------------
# The model returns raw logits (no softmax), which is what
# CrossEntropyLoss expects.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# -----------------------------
# Training Loop
# -----------------------------
epochs = 5
for epoch in range(epochs):
    # Put the model in training mode at the start of every epoch.
    model.train()
    running_loss = 0.0

    for images, labels in train_loader:
        # Move the mini-batch onto the same device as the model.
        images = images.to(device)
        labels = labels.to(device)

        # One optimization step: clear old gradients, forward pass,
        # compute the loss, backpropagate, update the weights.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean of the per-batch losses for this epoch.
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss / len(train_loader):.4f}")

# -----------------------------
# Testing Loop
# -----------------------------
# Evaluation mode, and no autograd bookkeeping while scoring the test set.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        # Predicted class = index of the largest logit in each row.
        _, predicted = torch.max(outputs, 1)

        total += labels.shape[0]
        correct += (predicted == labels).sum().item()

# Percentage of correctly classified test images.
accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 28, 28]             160
            Conv2d-2           [-1, 16, 28, 28]             272
            Conv2d-3           [-1, 16, 28, 28]           2,320
            Conv2d-4           [-1, 16, 28, 28]           6,416
         MaxPool2d-5           [-1, 16, 28, 28]               0
            Conv2d-6           [-1, 16, 28, 28]             272
   SimpleInception-7           [-1, 64, 28, 28]               0
            Conv2d-8           [-1, 16, 14, 14]           1,040
            Conv2d-9           [-1, 16, 14, 14]           9,232
           Conv2d-10           [-1, 16, 14, 14]          25,616
        MaxPool2d-11           [-1, 64, 14, 14]               0
           Conv2d-12           [-1, 16, 14, 14]           1,040
  SimpleInception-13           [-1, 64, 14, 14]               0
           Linear-14                   

In [4]:
from torchviz import make_dot

# A single random single-channel 28x28 image on the model's device; it is
# used only to trace one forward pass.
x = torch.randn(1, 1, 28, 28).to(device)

# Build the autograd graph of that forward pass, labelling parameter nodes
# with the model's parameter names.
model_graph = make_dot(
    model(x),
    params={name: param for name, param in model.named_parameters()},
)

# Write the PNG under images/; render returns the path of the written file.
model_graph.render("images/8_4_inception_net", format="png")


'images/8_4_inception_net.png'