Wide Residual Networks (WRN)

Paper link: https://arxiv.org/pdf/1605.07146

-  decrease depth and increase width of residual networks

Problem: As gradient flows through the network there is nothing to force it to go through residual block weights and it can avoid learning anything during training, so it is possible that there is either only a few
blocks that learn useful representations, or many blocks share very little information with
small contribution to the final goal.

- Instead, we argue here that dropout should be inserted between convolutional layers.



In [1]:
import torch
from torch.utils.data import DataLoader
from torch import nn
from collections import OrderedDict
from torchvision import transforms, datasets
from torch.utils.data import Subset

'''
3 groups of 4 blocks
'''

class Block(nn.Module):
    def __init__(self, in_channel, num_filters, stride=1):
        super().__init__()

        self.layers = nn.Sequential(
            nn.BatchNorm2d(in_channel),
            nn.ReLU(),
            nn.Conv2d(in_channel, num_filters, kernel_size=3, padding=1, stride=stride),
            nn.Dropout(0.3),

            nn.BatchNorm2d(num_filters),
            nn.ReLU(),
            nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1, stride=1),
        )

        if stride != 1 or in_channel != num_filters:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, num_filters, kernel_size=(1, 1), stride=stride),
                nn.BatchNorm2d(num_filters)
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        x = self.layers(x) + self.shortcut(x)
        return x

class WideResNet(nn.Module):
    def __init__(self):
        super().__init__()

        self.layers = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),

            Block(16, 80),
            Block(80, 80),
            Block(80, 80),
            Block(80, 80),

            Block(80, 160, stride=2),
            Block(160, 160),
            Block(160, 160),
            Block(160, 160),       

            Block(160, 320, stride=2),
            Block(320, 320),
            Block(320, 320),
            Block(320, 320),     
        )

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(320, 100)

    def forward(self, x):
        x = self.layers(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x                     

In [None]:
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    torch.cuda.empty_cache()

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2761]),
    transforms.RandomErasing(p=0.25)
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])
])

# Load raw datasets
cifar_train_raw = datasets.CIFAR100(root="./data", train=True, download=True, transform=None)

train_size = int(0.9 * len(cifar_train_raw))  # 48,000

train_indices = list(range(0, train_size))
val_indices = list(range(train_size, len(cifar_train_raw)))

# Create datasets with appropriate transforms
cifar_train = Subset(
    datasets.CIFAR100(root="./data", train=True, transform=train_transform),
    train_indices
)
cifar_val = Subset(
    datasets.CIFAR100(root="./data", train=True, transform=test_transform),
    val_indices
)
# Use original test set (10,000 samples) - close to 10% of 60,000
cifar_test = datasets.CIFAR100(root="./data", train=False, transform=test_transform)

train_loader = DataLoader(
    cifar_train,
    batch_size=128,   # was 1024
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    persistent_workers=False,  # free worker memory
    prefetch_factor=2
)
val_loader = DataLoader(
    cifar_val,
    batch_size=128,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    persistent_workers=False,
    prefetch_factor=2
)
test_loader = DataLoader(
    cifar_test,
    batch_size=128,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    persistent_workers=False,
    prefetch_factor=2
)

num_classes = 100

def init_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        nn.init.constant_(m.bias, 0)

model = WideResNet().to(device)
model.apply(init_weights)

num_epochs = 40
loss_function = nn.CrossEntropyLoss(label_smoothing=0.1)
base_lr = 4e-3

batch_scale = 1024 / 256  # 4x larger batches
scaled_lr = base_lr * batch_scale**0.5  # Square root scaling

optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)

scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[60, 120, 160],  # Decay at these epochs
    gamma=0.2  # Multiply LR by 0.2
)

best_val_loss = float('inf')

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

for epoch in range(num_epochs):
    print(f'Starting Epoch {epoch+1}')
    model.train()

    current_loss = 0.0
    num_batches = 0

    for i, data in enumerate(train_loader):
        inputs, targets = data
        inputs, targets = inputs.to(device), targets.to(device)
            
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        current_loss += loss.item()
        num_batches += 1

        if i % 50 == 0:
            print(f'Batch {i}/{len(train_loader)}, Loss: {loss.item():.4f}')

    scheduler.step()

    avg_train_loss = current_loss / num_batches
    print(f'Epoch {epoch+1} finished')
    print(f'Training - Loss: {avg_train_loss:.4f}')

    if (epoch + 1) % 2 == 0:
        model.eval()
        val_loss = 0.0
        val_batches = 0

        print(f'Epoch {epoch+1} finished')
        print(f'average training loss is {avg_train_loss:.4f}')

        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_targets = val_data
                val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)  # Convert inputs to FP16

                val_outputs = model(val_inputs)
                val_batch_loss = loss_function(val_outputs, val_targets)

                val_loss += val_batch_loss.item()
                val_batches += 1


        avg_val_loss = val_loss / val_batches

        print(f'Epoch {epoch+1} finished')
        print(f'Training - Loss: {avg_train_loss:.4f}')
        print(f'Validation - Loss: {avg_val_loss:.4f}')

if torch.cuda.is_available():
    torch.cuda.empty_cache()


  self.setter(val)


Using device: cuda
GPU Memory: 15.8 GB
Model parameters: 9,167,060
Starting Epoch 1
Batch 0/704, Loss: 4.6001
Batch 50/704, Loss: 4.5886
Batch 100/704, Loss: 4.7027
Batch 150/704, Loss: 4.2960
Batch 200/704, Loss: 4.4945
Batch 250/704, Loss: 4.4332
Batch 300/704, Loss: 4.0695
Batch 350/704, Loss: 4.0572
Batch 400/704, Loss: 4.1386
Batch 450/704, Loss: 4.1587
Batch 500/704, Loss: 4.0140


KeyboardInterrupt: 

In [None]:
def evaluate_test_set(model):
    model.eval()
    correct = 0
    total = 0
    
    print("Starting evaluation...")
    
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            
            # Use autocast for consistency if you trained with it
            outputs = model(images)
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy of the network on the test images: {accuracy:.2f}%')
    return accuracy

print("\n=== Running standard evaluation ===")
standard_accuracy = evaluate_test_set(resnet)   
print(f'Standard Test Accuracy: {standard_accuracy:.4f} ({standard_accuracy:.2f}%)')