In [None]:
import torch
from torch.utils.data import DataLoader
from torch import nn
from collections import OrderedDict
from torchvision import transforms, datasets
from torch.utils.data import Subset

'''
https://arxiv.org/pdf/1608.06993
'''

class DenseLayer(nn.Module):
    def __init__(self, in_channels, growth_rate):
        super().__init__()
        self.batch_norm = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU()
        self.conv2d = nn.Conv2d(in_channels, out_channels=growth_rate, kernel_size=(3, 3), padding=1)
    
    def forward(self, x):
        x = self.batch_norm(x)
        x = self.relu(x)
        x = self.conv2d(x)
        return x

class DenseBlock(nn.Module):
    def __init__(self, in_channels, growth_rate=16):
        super().__init__()
        self.growth_rate = growth_rate
        self.layers = nn.ModuleList()

        for i in range(6):
            self.layers.append(DenseLayer(in_channels + (i * growth_rate), growth_rate))
        
    def forward(self, x):
        features = x
        for layer in self.layers:
            new_features = layer(features)
            features = torch.cat([features, new_features], dim=1)
        
        return features
    
class TransitionLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.out_channels = out_channels
        self.layers = nn.Sequential(OrderedDict([
            ('bnorm1_1', nn.BatchNorm2d(in_channels)),
            ('relu', nn.ReLU()),
            ('conv1_1', nn.Conv2d(in_channels, self.out_channels, kernel_size=(1, 1))),
            ('adaptive1_1', nn.AvgPool2d(kernel_size=2, stride=2))
        ]))
    
    def forward(self, x):
        x = self.layers(x)
        return x
'''
Use smaller version like DenseNet-40: [6, 6, 6] blocks with k=12
Initial convolution (7x7 or 3x3 for CIFAR)
Dense Block 1 → Transition 1
Dense Block 2 → Transition 2
Dense Block 3 → Transition 3
Dense Block 4 (no transition after last block)
Global average pooling
Fully connected layer

Start with k₀ channels (e.g., 64)

Shape tracking:
'''    
class DenseNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=(3, 3), stride=2),
            DenseBlock(128),
            TransitionLayer(224, 112),
            DenseBlock(112),
            TransitionLayer(208, 104),
            DenseBlock(104),
            TransitionLayer(200, 100),     
            DenseBlock(100),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.classifier = nn.Linear(196, 100)
    
    def forward(self, x):
        x = self.layers(x) # shape is [batch, 98, 1, 1]
        x = x.view(x.size(0), -1) # shape is [batch, 98]
        x = self.classifier(x) # output is [batch, 100]
        return x

In [None]:
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    torch.cuda.empty_cache()

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Mild to avoid over-distortion
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2761]),
    transforms.RandomErasing(p=0.5)  # Apply after normalization for consistency
])

test_transform = transforms.Compose([
    transforms.ToTensor(), # Moved ToTensor before Normalize (good practice)
    transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])
])   

cifar_train_raw = datasets.CIFAR100(root='./data', train=True, download=True, transform=None)

train_size = int(0.9 * len(cifar_train_raw))

train_indices = list(range(train_size))
val_indices = list(range(train_size, len(cifar_train_raw)))

cifar_train = Subset(
    datasets.CIFAR100(root="./data", train=True, transform=train_transform),
    train_indices
)

cifar_val = Subset(
    datasets.CIFAR100(root="./data", train=True, transform=test_transform),
    val_indices
)

cifar_test = datasets.CIFAR100(root="./data", train=True, transform=test_transform)

train_loader = DataLoader(
    cifar_train,  # Use directly
    batch_size=1024,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=6
)

val_loader = DataLoader(
    cifar_val,  # Use directly
    batch_size=1024,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=6
)

test_loader = DataLoader(
    cifar_test,
    batch_size=1024,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=6
)

num_classes = 100

denseNet = DenseNet().to(device)

num_epochs = 40
loss_function = nn.CrossEntropyLoss(label_smoothing=0.1)
base_lr = 4e-3

batch_scale = 1024 / 256  # 4x larger batches
scaled_lr = base_lr * batch_scale**0.5  # Square root scaling

optimizer = torch.optim.AdamW(
    denseNet.parameters(),
    lr=3e-3,  # Keep this for now, let OneCycleLR handle it
    weight_decay=5e-4
)

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=5e-3,                # Slightly lower peak LR for stability
    epochs=num_epochs,          # Keep 45 epochs
    steps_per_epoch=len(train_loader),
    pct_start=0.4,              # Increase warmup to 40% (18 epochs)
    anneal_strategy='cos',
    div_factor=12.0,            # Start LR = 5e-3 / 12 = 4.2e-4
    final_div_factor=400.0      # Final LR = 5e-3 / 400 = 1.25e-5
)

best_val_loss = float('inf')

for epoch in range(num_epochs):
    print(f'Starting Epoch {epoch+1}')
    denseNet.train()

    current_loss = 0.0
    num_batches = 0

    for i, data in enumerate(train_loader):
        inputs, targets = data
        inputs, targets = inputs.to(device), targets.to(device)
            
        outputs = denseNet(inputs)
        loss = loss_function(outputs, targets)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()

        current_loss += loss.item()
        num_batches += 1

        if i % 50 == 0:
            print(f'Batch {i}/{len(train_loader)}, Loss: {loss.item():.4f}')

    avg_train_loss = current_loss / num_batches
    print(f'Epoch {epoch+1} finished')
    print(f'Training - Loss: {avg_train_loss:.4f}')

    if (epoch + 1) % 2 == 0:
        denseNet.eval()
        val_loss = 0.0
        val_batches = 0

        print(f'Epoch {epoch+1} finished')
        print(f'average training loss is {avg_train_loss:.4f}')

        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_targets = val_data
                val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)  # Convert inputs to FP16

                val_outputs = denseNet(val_inputs)
                val_batch_loss = loss_function(val_outputs, val_targets)

                val_loss += val_batch_loss.item()
                val_batches += 1

        avg_val_loss = val_loss / val_batches

        print(f'Epoch {epoch+1} finished')
        print(f'Training - Loss: {avg_train_loss:.4f}')
        print(f'Validation - Loss: {avg_val_loss:.4f}')

if torch.cuda.is_available():
    torch.cuda.empty_cache()


Using device: cuda
GPU Memory: 15.8 GB
Starting Epoch 1
Batch 0/44, Loss: 4.6182
Epoch 1 finished
Training - Loss: 4.3794
Starting Epoch 2
Batch 0/44, Loss: 4.1531
Epoch 2 finished
Training - Loss: 4.0497
Epoch 2 finished
average training loss is 4.0497
Epoch 2 finished
Training - Loss: 4.0497
Validation - Loss: 3.8909
Starting Epoch 3
Batch 0/44, Loss: 3.9330
Epoch 3 finished
Training - Loss: 3.8483
Starting Epoch 4
Batch 0/44, Loss: 3.7237
Epoch 4 finished
Training - Loss: 3.6663
Epoch 4 finished
average training loss is 3.6663
Epoch 4 finished
Training - Loss: 3.6663
Validation - Loss: 3.5365
Starting Epoch 5
Batch 0/44, Loss: 3.6231
Epoch 5 finished
Training - Loss: 3.5005
Starting Epoch 6
Batch 0/44, Loss: 3.3913
Epoch 6 finished
Training - Loss: 3.3758
Epoch 6 finished
average training loss is 3.3758
Epoch 6 finished
Training - Loss: 3.3758
Validation - Loss: 3.3541
Starting Epoch 7
Batch 0/44, Loss: 3.3047
Epoch 7 finished
Training - Loss: 3.2584
Starting Epoch 8
Batch 0/44, Los

In [13]:
def evaluate_test_set(model):
    model.eval()
    correct = 0
    total = 0
    
    print("Starting evaluation...")
    
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            
            # Use autocast for consistency if you trained with it
            outputs = model(images)
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy of the network on the test images: {accuracy:.2f}%')
    return accuracy

print("\n=== Running standard evaluation ===")
standard_accuracy = evaluate_test_set(denseNet)   
print(f'Standard Test Accuracy: {standard_accuracy:.4f} ({standard_accuracy*100:.2f}%)')


=== Running standard evaluation ===
Starting evaluation...
Accuracy of the network on the test images: 69.94%
Standard Test Accuracy: 69.9440 (6994.40%)
