Paper link: https://arxiv.org/pdf/1512.00567

Problem: Although VGGNet has the compelling feature of
architectural simplicity, this comes at a high cost: evaluating the network requires a lot of computation. 

The original motivation was to push useful gradients to the lower layers to make them immediately useful and improve the convergence during training by combating the vanishing gradient problem in very deep networks.

The sources explain that a standard n×n convolution can be factorized into two successive asymmetric convolutions.

Sliding a 3×1 convolution followed by a 1×3 convolution results in the same receptive field as a single 3×3 convolution

In [5]:
import torch
from torch import nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Subset

class InceptionBlock1(nn.Module):
    def __init__(self, in_channels, ch1, ch2, ch3, ch4):
        """
        ch1, ch2, ch3, ch4 = output channels for each of the 4 paths
        Total output channels = ch1 + ch2 + ch3 + ch4
        """
        super().__init__()

        self.path1 = nn.Sequential(
            nn.Conv2d(in_channels, ch1//2, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(ch1//2, ch1//2, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(ch1//2, ch1, kernel_size=3, padding=1),
            nn.ReLU()            
        )

        self.path2 = nn.Sequential(
            nn.Conv2d(in_channels, ch2, kernel_size=1, padding=0),
            nn.ReLU()            
        )

        self.path3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, padding=1, stride=1),
            nn.Conv2d(in_channels, ch3, kernel_size=1),
            nn.ReLU()            
        )

        self.path4 = nn.Sequential(
            nn.Conv2d(in_channels, ch4//2, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(ch4//2, ch4, kernel_size=3, padding=1),
            nn.ReLU(),            
        )

    def forward(self, x):
        return torch.cat([self.path1(x), self.path2(x), self.path3(x), self.path4(x)], dim=1)

class InceptionBlock2(nn.Module):
    def __init__(self, in_channels, ch1, ch2, ch3, ch4, kernel_size=7):
        super().__init__()
        pad = kernel_size // 2

        self.path1 = nn.Sequential(
            nn.Conv2d(in_channels, ch1//4, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(ch1//4, ch1//4, kernel_size=(1, kernel_size), padding=(0, pad)),
            nn.ReLU(),
            nn.Conv2d(ch1//4, ch1//4, kernel_size=(kernel_size, 1), padding=(pad, 0)),
            nn.ReLU(),
            nn.Conv2d(ch1//4, ch1//4, kernel_size=(1, kernel_size), padding=(0, pad)),
            nn.ReLU(),
            nn.Conv2d(ch1//4, ch1, kernel_size=(kernel_size, 1), padding=(pad, 0)),
            nn.ReLU()
        )

        self.path2 = nn.Sequential(
            nn.Conv2d(in_channels, ch2, kernel_size=1, padding=0),
            nn.ReLU()            
        )

        self.path3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, padding=1, stride=1),
            nn.Conv2d(in_channels, ch3, kernel_size=1),
            nn.ReLU()            
        )

        self.path4 = nn.Sequential(
            nn.Conv2d(in_channels, ch4//2, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(ch4//2, ch4//2, kernel_size=(1, kernel_size), padding=(0, pad)),
            nn.ReLU(),
            nn.Conv2d(ch4//2, ch4, kernel_size=(kernel_size, 1), padding=(pad, 0)),
            nn.ReLU(),                
        )

    def forward(self, x):
        return torch.cat([self.path1(x), self.path2(x), self.path3(x), self.path4(x)], dim=1)

class InceptionBlock3(nn.Module):
    def __init__(self, in_channels, ch1, ch2, ch3, ch4, ch5, ch6):
        super().__init__()

        self.path1 = nn.Sequential(
            nn.Conv2d(in_channels, ch1, kernel_size=1, padding=0),
            nn.ReLU()             
        )

        self.path2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, padding=1, stride=1),
            nn.Conv2d(in_channels, ch2, kernel_size=1),
            nn.ReLU()            
        )

        self.path3 = nn.Sequential(
            nn.Conv2d(in_channels, ch3//2, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(ch3//2, ch3, kernel_size=(1, 3), padding=(0, 1)),
            nn.ReLU(),                        
        )

        self.path4 = nn.Sequential(
            nn.Conv2d(in_channels, ch4//2, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(ch4//2, ch4, kernel_size=(3, 1), padding=(1, 0)),
            nn.ReLU(),                        
        )

        self.path5 = nn.Sequential(
            nn.Conv2d(in_channels, ch5//4, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(ch5//4, ch5//2, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(ch5//2, ch5, kernel_size=(1, 3), padding=(0, 1)),
            nn.ReLU(), 
        )

        self.path6 = nn.Sequential(
            nn.Conv2d(in_channels, ch6//4, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(ch6//4, ch6//2, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(ch6//2, ch6, kernel_size=(3, 1), padding=(1, 0)),
            nn.ReLU(), 
        )   

    def forward(self, x):
        return torch.cat([self.path1(x), self.path2(x), self.path3(x), 
                         self.path4(x), self.path5(x), self.path6(x)], dim=1)

In [8]:
class InceptionV3(nn.Module):
    def __init__(self):
        super().__init__()

        self.layers = nn.Sequential(
            # Stem (same as yours)
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),   
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),  
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(64, 80, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(80),
            nn.ReLU(),
            nn.Conv2d(80, 192, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(),
            # nn.Conv2d(192, 288, kernel_size=3, stride=1, padding=1),
            # nn.BatchNorm2d(288),
            # nn.ReLU(),

            # Inception blocks with exact channel control
            # Example: Want 288 total output? Use ch1=64, ch2=64, ch3=80, ch4=80 → 288
            InceptionBlock1(192, ch1=64, ch2=64, ch3=80, ch4=80),     # Output: 288
            InceptionBlock2(288, ch1=96, ch2=96, ch3=96, ch4=96),     # Output: 384
            InceptionBlock3(384, ch1=64, ch2=64, ch3=64, ch4=64, ch5=64, ch6=64),  # Output: 384
        )

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # Changed to (1,1)
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(384, 100)  # Match last block output

    def forward(self, x):
        x = self.layers(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)
        return x    

In [None]:
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    torch.cuda.empty_cache()

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Mild to avoid over-distortion
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2761]),
    transforms.RandomErasing(p=0.25)  # Apply after normalization for consistency
])

test_transform = transforms.Compose([
    transforms.ToTensor(), # Moved ToTensor before Normalize (good practice)
    transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])
])

# Load raw datasets
cifar_train_raw = datasets.CIFAR100(root="./data", train=True, download=True, transform=None)

train_size = int(0.9 * len(cifar_train_raw))  # 48,000

train_indices = list(range(0, train_size))
val_indices = list(range(train_size, len(cifar_train_raw)))

# Create datasets with appropriate transforms
cifar_train = Subset(
    datasets.CIFAR100(root="./data", train=True, transform=train_transform),
    train_indices
)
cifar_val = Subset(
    datasets.CIFAR100(root="./data", train=True, transform=test_transform),
    val_indices
)
# Use original test set (10,000 samples) - close to 10% of 60,000
cifar_test = datasets.CIFAR100(root="./data", train=False, transform=test_transform)

train_loader = DataLoader(
    cifar_train,
    batch_size=512,  # Changed from 512
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=6
)

val_loader = DataLoader(
    cifar_val,  # Use directly
    batch_size=512,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=6
)

test_loader = DataLoader(
    cifar_test,
    batch_size=512,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=6
)

num_classes = 100

model = InceptionV3().to(device)

num_epochs = 40
loss_function = nn.CrossEntropyLoss(label_smoothing=0.1)
base_lr = 4e-3

batch_scale = 1024 / 256  # 4x larger batches
scaled_lr = base_lr * batch_scale**0.5  # Square root scaling

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-3,  # Changed from 1e-3
    weight_decay=1e-4
)

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-2,  # Changed from 3e-3
    epochs=num_epochs,
    steps_per_epoch=len(train_loader),
    pct_start=0.3,
    anneal_strategy='cos',
    div_factor=25.0,
    final_div_factor=1000.0
)

best_val_loss = float('inf')

for epoch in range(num_epochs):
    print(f'Starting Epoch {epoch+1}')
    model.train()

    current_loss = 0.0
    num_batches = 0

    for i, data in enumerate(train_loader):
        inputs, targets = data
        inputs, targets = inputs.to(device), targets.to(device)
            
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)  # Changed from 1.0
        
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()

        current_loss += loss.item()
        num_batches += 1

        if i % 50 == 0:
            print(f'Batch {i}/{len(train_loader)}, Loss: {loss.item():.4f}')


    avg_train_loss = current_loss / num_batches
    print(f'Epoch {epoch+1} finished')
    print(f'Training - Loss: {avg_train_loss:.4f}')

    if (epoch + 1) % 2 == 0:
        model.eval()
        val_loss = 0.0
        val_batches = 0

        print(f'Epoch {epoch+1} finished')
        print(f'average training loss is {avg_train_loss:.4f}')

        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_targets = val_data
                val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)  # Convert inputs to FP16

                val_outputs = model(val_inputs)
                val_batch_loss = loss_function(val_outputs, val_targets)

                val_loss += val_batch_loss.item()
                val_batches += 1


        avg_val_loss = val_loss / val_batches

        print(f'Epoch {epoch+1} finished')
        print(f'Training - Loss: {avg_train_loss:.4f}')
        print(f'Validation - Loss: {avg_val_loss:.4f}')

if torch.cuda.is_available():
    torch.cuda.empty_cache()


Using device: cuda
GPU Memory: 15.8 GB
Starting Epoch 1
Batch 0/88, Loss: 4.6075
Batch 50/88, Loss: 4.3889
Epoch 1 finished
Training - Loss: 4.4297
Starting Epoch 2
Batch 0/88, Loss: 4.2225
Batch 50/88, Loss: 3.9996
Epoch 2 finished
Training - Loss: 4.0662
Epoch 2 finished
average training loss is 4.0662
Epoch 2 finished
Training - Loss: 4.0662
Validation - Loss: 3.8655
Starting Epoch 3
Batch 0/88, Loss: 3.8620
Batch 50/88, Loss: 3.8129
Epoch 3 finished
Training - Loss: 3.8815
Starting Epoch 4
Batch 0/88, Loss: 3.7847
Batch 50/88, Loss: 3.7783
Epoch 4 finished
Training - Loss: 3.7461
Epoch 4 finished
average training loss is 3.7461
Epoch 4 finished
Training - Loss: 3.7461
Validation - Loss: 3.6653
Starting Epoch 5
Batch 0/88, Loss: 3.5691
Batch 50/88, Loss: 3.6544
Epoch 5 finished
Training - Loss: 3.6142
Starting Epoch 6
Batch 0/88, Loss: 3.5104
Batch 50/88, Loss: 3.4136
Epoch 6 finished
Training - Loss: 3.4690
Epoch 6 finished
average training loss is 3.4690
Epoch 6 finished
Training 

In [10]:
def evaluate_test_set(model):
    model.eval()
    correct = 0
    total = 0
    
    print("Starting evaluation...")
    
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            
            # Use autocast for consistency if you trained with it
            outputs = model(images)
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy of the network on the test images: {accuracy:.2f}%')
    return accuracy

print("\n=== Running standard evaluation ===")
standard_accuracy = evaluate_test_set(model)   
print(f'Standard Test Accuracy: {standard_accuracy:.4f} ({standard_accuracy:.2f}%)')


=== Running standard evaluation ===
Starting evaluation...
Accuracy of the network on the test images: 49.25%
Standard Test Accuracy: 49.2500 (49.25%)
