In [None]:
import torch
from torch.utils.data import DataLoader
from torch import nn
from collections import OrderedDict
from torchvision import transforms, datasets
from torch.utils.data import Subset

class Block(nn.Module):
    """ResNet basic block: two 3x3 convolutions with BatchNorm plus a residual shortcut.

    Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.
    The shortcut is the identity when input and output shapes match; otherwise it is
    a strided 1x1 conv + BN projection so it can be added to the main path. A final
    ReLU is applied after the addition.

    Args:
        in_channels: number of channels of the input tensor.
        num_filters: number of output channels of the block.
        stride: stride of the first 3x3 conv and of the projection shortcut
            (2 roughly halves the spatial resolution).
    """

    def __init__(self, in_channels, num_filters, stride=1):
        super().__init__()
        # Each BatchNorm directly follows a conv that outputs num_filters channels,
        # so its channel count always matches regardless of in_channels.
        self.layers = nn.Sequential(OrderedDict([
            ('conv1_1', nn.Conv2d(in_channels, num_filters, kernel_size=(3, 3), stride=stride, padding=1)),
            ('bn1_1', nn.BatchNorm2d(num_filters)),
            ('relu1_1', nn.ReLU()),
            ('conv1_2', nn.Conv2d(num_filters, num_filters, kernel_size=(3, 3), padding=1)),
            ('bn1_2', nn.BatchNorm2d(num_filters)),
        ]))
        self.last_relu = nn.ReLU()
        # Projection shortcut is needed whenever the residual would otherwise have a
        # different spatial size (stride) or channel count than the main path.
        if stride != 1 or in_channels != num_filters:
            self.shortcut = nn.Sequential(OrderedDict([
                ('sho_conv_1', nn.Conv2d(in_channels, num_filters, kernel_size=(1, 1), stride=stride)),
                ('sho_bn1', nn.BatchNorm2d(num_filters))
            ]))
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        x = self.layers(x) + self.shortcut(x)
        x = self.last_relu(x)
        return x

class ResNet18(nn.Module):
    """ResNet-18-style classifier for small images such as 32x32 CIFAR-100.

    Stem: 3x3 conv (stride 2, padding 1) -> BN -> ReLU, then eight basic ``Block``s in
    four stages of 64/128/256/512 channels (the first block of stages 2-4 has
    stride 2), global average pooling and a linear classifier. A 32x32 input is
    therefore reduced 32 -> 16 -> 8 -> 4 -> 2 before pooling.

    Args:
        num_classes: number of output logits (default 100, i.e. CIFAR-100).
    """

    def __init__(self, num_classes=100):
        super().__init__()
        self.layers = nn.Sequential(OrderedDict([
            # padding=1 lets the stride-2 window cover the whole image; without it a
            # 32x32 input maps to 15x15 and the last pixel row/column is never read.
            ('conv1_1', nn.Conv2d(3, 64, kernel_size=(3, 3), stride=2, padding=1)),
            ('bn1_1', nn.BatchNorm2d(64)),
            ('relu1_1', nn.ReLU()),
            ('block1', Block(in_channels=64, num_filters=64)),
            ('block2', Block(in_channels=64, num_filters=64)),
            ('block3', Block(in_channels=64, num_filters=128, stride=2)),
            ('block4', Block(in_channels=128, num_filters=128)),
            ('block5', Block(in_channels=128, num_filters=256, stride=2)),
            ('block6', Block(in_channels=256, num_filters=256)),
            ('block7', Block(in_channels=256, num_filters=512, stride=2)),
            ('block8', Block(in_channels=512, num_filters=512)),
        ]))
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512, num_classes)

    def forward(self, x):
        """Map a batch of images (N, 3, H, W) to class logits (N, num_classes)."""
        x = self.layers(x)
        x = self.avgpool(x)          # (N, 512, 1, 1)
        x = torch.flatten(x, 1)      # (N, 512)
        return self.linear(x)

In [None]:
# Reproducibility: seed torch's global RNG (weight init, DataLoader shuffling).
SEED = 42
torch.manual_seed(SEED)

# Speed knobs: let cuDNN autotune conv algorithms for fixed input shapes, and allow
# reduced-precision TF32 matmuls/convolutions where the GPU supports them.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    torch.cuda.empty_cache()

# CIFAR-100 per-channel normalization statistics. Training and evaluation must use
# the *same* constants, otherwise the model sees slightly shifted inputs at test time.
CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
CIFAR100_STD = (0.2673, 0.2564, 0.2761)

# Training-time augmentation: random crop + flip, mild color jitter, then
# normalization and random erasing (RandomErasing operates on tensors, so it must
# come after ToTensor).
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # mild, to avoid over-distortion
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
    transforms.RandomErasing(p=0.5),
])

# Deterministic pipeline for validation and test: tensor conversion + normalization only.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])

# Raw (untransformed) training split: downloads the data once and sizes the split.
cifar_train_raw = datasets.CIFAR100(root="./data", train=True, download=True, transform=None)

train_size = int(0.9 * len(cifar_train_raw))  # 45,000 of the 50,000 training images

# Seeded random 90/10 train/val split, so validation is a uniform sample rather than
# whatever happens to be stored at the end of the training file.
SPLIT_SEED = 0
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
permutation = torch.randperm(len(cifar_train_raw), generator=split_generator).tolist()
train_indices = permutation[:train_size]
val_indices = permutation[train_size:]
assert len(train_indices) + len(val_indices) == len(cifar_train_raw)

# Two views of the training split with different transforms: augmented for training,
# deterministic for validation. The two index sets are disjoint.
cifar_train = Subset(
    datasets.CIFAR100(root="./data", train=True, transform=train_transform),
    train_indices
)
cifar_val = Subset(
    datasets.CIFAR100(root="./data", train=True, transform=test_transform),
    val_indices
)
# Official held-out test split (10,000 images); keep it out of model selection.
cifar_test = datasets.CIFAR100(root="./data", train=False, transform=test_transform)

# Every loader shares the same batching/worker configuration; only the dataset and
# the shuffle flag differ between them.
_loader_kwargs = dict(
    batch_size=1024,
    num_workers=2,
    pin_memory=True,          # page-locked host buffers for faster host->GPU copies
    persistent_workers=True,  # keep worker processes alive across epochs
    prefetch_factor=6,        # batches loaded in advance by each worker
)

train_loader = DataLoader(cifar_train, shuffle=True, **_loader_kwargs)   # reshuffled every epoch
val_loader = DataLoader(cifar_val, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(cifar_test, shuffle=False, **_loader_kwargs)

num_classes = 100

resnet = ResNet18().to(device)
# Guard against the classifier head drifting out of sync with the label count.
assert resnet.linear.out_features == num_classes, "model head does not match num_classes"

num_epochs = 40
loss_function = nn.CrossEntropyLoss(label_smoothing=0.1)

# Square-root batch-size LR scaling relative to a 256-image reference batch.
# Informational only: neither the optimizer nor the scheduler below reads scaled_lr.
base_lr = 4e-3
batch_scale = 1024 / 256  # 4x larger batches
scaled_lr = base_lr * batch_scale**0.5  # square-root scaling -> 8e-3

optimizer = torch.optim.AdamW(
    resnet.parameters(),
    lr=3e-3,  # placeholder: OneCycleLR overwrites it with max_lr / div_factor
    weight_decay=5e-4,
)

# OneCycleLR warms up from max_lr/div_factor to max_lr over the first pct_start of
# all steps, then cosine-anneals to (max_lr/div_factor)/final_div_factor.
# It must be stepped once per batch: total steps = num_epochs * len(train_loader).
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=5e-3,                # peak LR
    epochs=num_epochs,
    steps_per_epoch=len(train_loader),
    pct_start=0.4,              # warm up for 40% of training (16 of 40 epochs)
    anneal_strategy='cos',
    div_factor=12.0,            # initial LR = 5e-3 / 12 ≈ 4.2e-4
    final_div_factor=400.0,     # final LR = (5e-3 / 12) / 400 ≈ 1.0e-6
)

VAL_EVERY = 2  # validate every VAL_EVERY epochs, and always after the final epoch

best_val_loss = float('inf')
best_state_dict = None  # CPU snapshot of the weights with the lowest validation loss

for epoch in range(num_epochs):
    print(f'Starting Epoch {epoch+1}')
    resnet.train()

    # Accumulate on the device so there is no forced GPU sync (.item()) every step.
    current_loss = torch.zeros((), device=device)
    num_samples = 0

    for i, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        outputs = resnet(inputs)
        loss = loss_function(outputs, targets)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()  # OneCycleLR is configured per batch, not per epoch

        # loss is a per-batch mean; weight it by batch size so a short final batch
        # does not skew the epoch average.
        current_loss += loss.detach() * targets.size(0)
        num_samples += targets.size(0)

        if i % 50 == 0:
            print(f'Batch {i}/{len(train_loader)}, Loss: {loss.item():.4f}')

    avg_train_loss = current_loss.item() / num_samples
    print(f'Epoch {epoch+1} finished')
    print(f'Training - Loss: {avg_train_loss:.4f}')

    if (epoch + 1) % VAL_EVERY == 0 or epoch + 1 == num_epochs:
        resnet.eval()
        val_loss = 0.0
        val_correct = 0
        val_samples = 0

        with torch.no_grad():
            for val_inputs, val_targets in val_loader:
                val_inputs = val_inputs.to(device, non_blocking=True)
                val_targets = val_targets.to(device, non_blocking=True)

                val_outputs = resnet(val_inputs)
                batch_size = val_targets.size(0)
                val_loss += loss_function(val_outputs, val_targets).item() * batch_size
                val_correct += (val_outputs.argmax(dim=1) == val_targets).sum().item()
                val_samples += batch_size

        avg_val_loss = val_loss / val_samples
        val_accuracy = 100.0 * val_correct / val_samples
        print(f'Validation - Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.2f}%')

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # Clone to CPU so further training steps cannot overwrite the snapshot.
            best_state_dict = {k: v.detach().cpu().clone() for k, v in resnet.state_dict().items()}
            print(f'New best validation loss: {best_val_loss:.4f}')

# The final weights stay loaded; to evaluate the best checkpoint instead, call
# resnet.load_state_dict(best_state_dict).

if torch.cuda.is_available():
    torch.cuda.empty_cache()

Using device: cuda
GPU Memory: 15.8 GB
Starting Epoch 1
Epoch 1 finished
Training - Loss: 4.2126
Starting Epoch 2
Epoch 2 finished
Training - Loss: 3.8563
Epoch 2 finished
average training loss is 3.8563
Epoch 2 finished
Training - Loss: 3.8563
Validation - Loss: 3.6542
Starting Epoch 3
Epoch 3 finished
Training - Loss: 3.6387
Starting Epoch 4
Epoch 4 finished
Training - Loss: 3.4621
Epoch 4 finished
average training loss is 3.4621
Epoch 4 finished
Training - Loss: 3.4621
Validation - Loss: 3.2964
Starting Epoch 5
Epoch 5 finished
Training - Loss: 3.2996
Starting Epoch 6
Epoch 6 finished
Training - Loss: 3.1720
Epoch 6 finished
average training loss is 3.1720
Epoch 6 finished
Training - Loss: 3.1720
Validation - Loss: 3.3164
Starting Epoch 7
Epoch 7 finished
Training - Loss: 3.0592
Starting Epoch 8
Epoch 8 finished
Training - Loss: 2.9488
Epoch 8 finished
average training loss is 2.9488
Epoch 8 finished
Training - Loss: 2.9488
Validation - Loss: 2.9748
Starting Epoch 9
Epoch 9 finished

In [None]:
def evaluate_test_set(model, loader=None, device=None):
    """Compute the top-1 accuracy of ``model`` over a labelled dataset.

    Args:
        model: classifier whose output is a 2-D score tensor (batch, classes);
            predictions are taken as the argmax over dim 1.
        loader: iterable of ``(images, labels)`` batches. Defaults to the
            notebook's global ``test_loader``.
        device: device to move each batch to. Defaults to the device of the
            model's first parameter (CPU for a parameter-free model).

    Returns:
        Accuracy as a percentage in [0, 100] (not a fraction in [0, 1]).

    Raises:
        ValueError: if ``loader`` yields no samples.

    Side effect: leaves ``model`` in eval mode.
    """
    if loader is None:
        loader = test_loader
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    correct = 0
    total = 0

    print("Starting evaluation...")

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("evaluate_test_set: loader yielded no samples")

    accuracy = 100 * correct / total
    print(f'Accuracy of the network on the test images: {accuracy:.2f}%')
    return accuracy

print("\n=== Running standard evaluation ===")
# evaluate_test_set already returns a percentage in [0, 100]; do not scale it again.
standard_accuracy = evaluate_test_set(resnet)
print(f'Standard Test Accuracy: {standard_accuracy:.2f}%')


=== Running standard evaluation ===
Starting evaluation...
Accuracy of the network on the test images: 65.22%
Standard Test Accuracy: 65.2200 (6522.00%)
