In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


In [2]:
# ------------------------------
# Device Configuration
# ------------------------------
# Run on the current CUDA device when PyTorch can see one; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [3]:
# ------------------------------
# Residual Block Definition
# ------------------------------
class ResidualBlock(nn.Module):
    """Basic residual block: two 3x3 conv/BN stages plus a shortcut connection.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count produced by both convolutions.
        stride: stride of the first convolution; values > 1 reduce spatial size.
        downsample: optional module applied to the input to form the shortcut.
            Its output is added in place to the main-path output, so it must
            yield a compatible shape whenever stride != 1 or the channel count
            changes. When None, the input itself is the shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        conv_opts = dict(kernel_size=3, padding=1, bias=False)  # BN follows, so conv bias is redundant
        # Only the first convolution carries the stride, so any downsampling happens up front.
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv_opts)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv_opts)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))."""
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return self.relu(residual)


In [4]:
# ------------------------------
# Custom ResNet Model for CIFAR-10
# ------------------------------
class CustomResNet(nn.Module):
    """Four-stage ResNet intended for small images such as 32x32 CIFAR-10.

    The stem is a single 3x3, stride-1 convolution with no pooling, so the first
    stage sees the full input resolution. Stages 2-4 each start with a stride-2 block.

    Args:
        block: residual block class. It is called as
            ``block(in_channels, out_channels, stride, downsample)`` for the first
            block of a stage and ``block(channels, channels)`` for the rest.
        layers: number of blocks in each of the four stages, e.g. ``[2, 2, 2, 2]``.
        num_classes: number of output logits.
        input_channels: channel count of the input images (default 3, i.e. RGB).
    """

    def __init__(self, block, layers, num_classes=10, input_channels=3):
        super(CustomResNet, self).__init__()
        # Builder cursor: channel count flowing into the next stage (mutated by _make_layer).
        self.in_channels = 64
        # 3x3 stride-1 stem: no early downsampling for small inputs.
        self.conv = nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # Define layers using residual blocks
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # After the stages are built, the cursor equals layer4's output width, so the
        # classifier cannot drift out of sync with the stage widths above.
        self.fc = nn.Linear(self.in_channels, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks producing ``out_channels`` channels.

        Only the first block may change stride or width. When it does, a 1x1 conv + BN
        projection is attached as its shortcut. Advances ``self.in_channels``.
        """
        downsample = None
        # If input and output dimensions differ, project the shortcut to match them.
        if stride != 1 or self.in_channels != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        layers = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_channels
        layers.extend(block(out_channels, out_channels) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of images ``(N, input_channels, H, W)`` to logits ``(N, num_classes)``."""
        out = self.relu(self.bn(self.conv(x)))

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        # Global average pool, then flatten to (N, C) for the linear classifier.
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        return self.fc(out)

def get_resnet18():
    """Build a ResNet-18: four stages of two ``ResidualBlock``s each, 10 output classes."""
    blocks_per_stage = [2] * 4
    return CustomResNet(block=ResidualBlock, layers=blocks_per_stage)


In [7]:
# ------------------------------
# Main Training Script
# ------------------------------
def _build_loaders(data_root, batch_size, num_workers):
    """Return ``(train_loader, test_loader)`` for CIFAR-10, downloading into ``data_root`` if needed."""
    # Normalization statistics shared by both pipelines (commonly quoted CIFAR-10 per-channel mean/std).
    mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)

    # Training: random crop with 4px padding + horizontal flip for augmentation.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    # Testing: deterministic, normalization only.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_dataset = torchvision.datasets.CIFAR10(root=data_root, train=True,
                                                 download=True, transform=transform_train)
    test_dataset = torchvision.datasets.CIFAR10(root=data_root, train=False,
                                                download=True, transform=transform_test)
    # Pinned host memory only helps when batches are copied to a GPU.
    pin = device.type == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=pin)
    return train_loader, test_loader


def _train_one_epoch(model, loader, criterion, optimizer, epoch, num_epochs, log_interval):
    """Run one optimization pass over ``loader``.

    Prints the mean loss of the last ``log_interval`` steps each time that many
    steps complete. A trailing partial window is not printed.
    """
    model.train()
    running_loss = 0.0
    for step, (images, labels) in enumerate(loader, start=1):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # ----- Forward Pass -----
        outputs = model(images)
        loss = criterion(outputs, labels)

        # ----- Backward Pass and Optimization -----
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if step % log_interval == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step}/{len(loader)}], "
                  f"Loss: {running_loss/log_interval:.4f}")
            running_loss = 0.0


def _evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` on ``loader`` as a percentage (0.0 if empty)."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Guard the division so an empty loader does not raise ZeroDivisionError.
    return 100 * correct / total if total else 0.0


def main(num_epochs=30, batch_size=128, learning_rate=0.001, data_root='./data',
         checkpoint_path="best_resnet18_cifar10.pth", num_workers=2,
         log_interval=100, seed=None):
    """Train ResNet-18 on CIFAR-10 with Adam + StepLR and checkpoint the best epoch.

    All arguments default to the values this notebook originally hard-coded, so
    ``main()`` behaves as before. Pass ``seed`` to make weight init and data
    shuffling repeatable.

    NOTE(review): the "best" checkpoint is chosen by test-set accuracy, so the
    reported best accuracy is optimistically biased (the test set is used for
    model selection). Hold out a validation split from the training data if an
    unbiased test estimate is needed.

    Returns:
        The best test accuracy (percent) observed across epochs.
    """
    if seed is not None:
        torch.manual_seed(seed)

    train_loader, test_loader = _build_loaders(data_root, batch_size, num_workers)

    # Initialize the model, loss function, optimizer, and learning-rate scheduler.
    model = get_resnet18().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Multiply the LR by 0.1 every 10 epochs; stepped once per epoch after training.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    best_accuracy = 0.0
    for epoch in range(num_epochs):
        _train_one_epoch(model, train_loader, criterion, optimizer,
                         epoch, num_epochs, log_interval)
        scheduler.step()

        # ----- Evaluation on Test Set -----
        accuracy = _evaluate(model, test_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Test Accuracy: {accuracy:.2f}%")

        # Save the best model based on test accuracy (see NOTE in docstring).
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), checkpoint_path)
            print("Saved Best Model with Accuracy: {:.2f}%".format(best_accuracy))

    return best_accuracy

In [8]:
# Entry point. Inside a Jupyter/IPython kernel the module name is "__main__",
# so executing this cell starts training. The guard also keeps DataLoader worker
# processes from re-running training when they re-import this code.
if "__main__" == __name__:
    main()

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:05<00:00, 29.9MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Epoch [1/30], Step [100/391], Loss: 1.7875
Epoch [1/30], Step [200/391], Loss: 1.4672
Epoch [1/30], Step [300/391], Loss: 1.2844
Epoch [1/30], Test Accuracy: 57.77%
Saved Best Model with Accuracy: 57.77%
Epoch [2/30], Step [100/391], Loss: 1.0326
Epoch [2/30], Step [200/391], Loss: 0.9832
Epoch [2/30], Step [300/391], Loss: 0.9265
Epoch [2/30], Test Accuracy: 69.39%
Saved Best Model with Accuracy: 69.39%
Epoch [3/30], Step [100/391], Loss: 0.8052
Epoch [3/30], Step [200/391], Loss: 0.7732
Epoch [3/30], Step [300/391], Loss: 0.7104
Epoch [3/30], Test Accuracy: 70.08%
Saved Best Model with Accuracy: 70.08%
Epoch [4/30], Step [100/391], Loss: 0.6403
Epoch [4/30], Step [200/391], Loss: 0.6347
Epoch [4/30], Step [300/391], Loss: 0.6082
Epoch [4/30], Test Accuracy: 80.10%
Saved Best Model with Accuracy: 80.10%
Epoch [5/30], Step [100/391], Loss: 0.5490
Epoch [5/30], Step [200/391], Loss: 0.5506
Epoch [5/