In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim



In [3]:
# CIFAR-10 dataset
# Preprocessing shared by the train and test splits: ToTensor yields float
# tensors in [0, 1]; Normalize then applies (x - 0.5) / 0.5 to each of the
# three channels, mapping pixel values into [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)



# Hyper-parameters
num_epochs = 10
batch_size = 64
learning_rate = 0.001

# Seed torch's global RNG before any model is built or any loader is iterated,
# so weight initialisation and DataLoader shuffling repeat under
# Restart & Run All. GPU kernels (MPS/CUDA) may still introduce small
# run-to-run differences.
SEED = 0
torch.manual_seed(SEED)

In [4]:
# Build both CIFAR-10 splits with the same preprocessing. The training split is
# created first; the test split reuses the archive it downloaded into ./data.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(root='./data', train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

# Shuffle only the training stream; the test stream keeps a fixed order.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset=split, batch_size=batch_size, shuffle=reshuffle)
    for split, reshuffle in ((train_dataset, True), (test_dataset, False))
)



Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:14<00:00, 11771641.99it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [5]:
# Pick the best available accelerator: Apple-silicon GPU (MPS) first, then
# CUDA, else CPU. Hard-coding "mps" makes every later `.to(device)` fail on
# machines without Apple's Metal backend.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


class CNN(nn.Module):
    """Small two-stage convolutional classifier for 3x32x32 images.

    Architecture: [3x3 conv -> ReLU -> 2x2 max-pool] twice, then two fully
    connected layers. The two poolings shrink a 32x32 input to 8x8, which is
    why ``fc1`` expects ``64 * 8 * 8`` input features.

    Args:
        num_classes: Number of output logits (default 10, as for CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes) for x of shape (batch, 3, 32, 32)."""
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Flatten per sample. The old view(-1, 64 * 8 * 8) silently folded extra
        # spatial positions into the batch dimension for non-32x32 inputs;
        # keeping the batch size fixed makes a wrong input size fail in fc1.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
    

In [6]:

class ResidualBlock(nn.Module):
    """Two-convolution residual block: relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x)).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        stride: Stride of the first 3x3 convolution (and of the shortcut
            projection when one is needed).

    When ``stride != 1`` or the channel count changes, the shortcut is a 1x1
    convolution plus BatchNorm, so its output matches the main path's shape.
    Otherwise it is an empty ``nn.Sequential``, which returns its input unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main path; convolutions have bias=False and feed straight into BatchNorm.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # Shortcut path, built after the main path's layers.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            shortcut = nn.Sequential()
        self.downsample = shortcut

    def forward(self, x):
        """Apply the main path to x, add the shortcut of x, and apply ReLU."""
        identity = self.downsample(x)
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        main += identity
        return self.relu(main)
    
# ResNet
class ResNet(nn.Module):
    """ResNet-style classifier built from a stack of residual blocks.

    The stem is a single 3x3 convolution + BatchNorm + ReLU (no max-pool). Four
    stages follow, with 64/128/256/512 channels; stages 2-4 start with stride 2.
    A global average pool then feeds a linear classifier.

    Args:
        block: Block class, called as ``block(in_channels, out_channels, stride)``.
            It must return a module whose output has ``out_channels`` channels.
        num_blocks: Four positive integers, the number of blocks in each stage
            (e.g. ``[2, 2, 2, 2]``).
        num_classes: Number of output logits.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running channel count; make_layer advances it as stages are appended.
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self.make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self.make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self.make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self.make_layer(block, 512, num_blocks[3], stride=2)
        # Adaptive pooling reduces any final feature map to 1x1. For 32x32 inputs
        # that map is 4x4, so this computes the same mean as a fixed
        # AvgPool2d(4). Larger inputs no longer yield a vector fc cannot accept.
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage of ``num_blocks`` blocks as an ``nn.Sequential``.

        Only the first block uses ``stride`` and changes the channel count; the
        rest use stride 1. Updates ``self.in_channels`` to ``out_channels``.

        Raises:
            ValueError: if ``num_blocks`` is less than 1.
        """
        if num_blocks < 1:
            raise ValueError("num_blocks must be >= 1, got {}".format(num_blocks))
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for an image batch x."""
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = torch.flatten(self.avg_pool(out), 1)
        return self.fc(out)
    

In [7]:

# Model initialization
model_cnn = CNN().to(device)
model_resnet = ResNet(ResidualBlock, [2, 2, 2, 2]).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer_cnn = optim.Adam(model_cnn.parameters(), lr=learning_rate)
optimizer_resnet = optim.Adam(model_resnet.parameters(), lr=learning_rate)

# Train both models side by side on identical mini-batches, so their losses are
# directly comparable. Keys are the labels used in the progress log.
runs = {
    "CNN": (model_cnn, optimizer_cnn),
    "ResNet": (model_resnet, optimizer_resnet),
}

total_step = len(train_loader)
for epoch in range(num_epochs):
    # Make the mode explicit: the evaluation cell switches both models to eval(),
    # and the ResNet's BatchNorm layers must use batch statistics while training.
    for model, _ in runs.values():
        model.train()

    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # The models share no parameters, so each runs its own
        # forward / zero_grad / backward / step independently.
        batch_loss = {}
        for name, (model, optimizer) in runs.items():
            loss = criterion(model(images), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Keep the loss on-device; .item() forces a host sync, so defer it to logging.
            batch_loss[name] = loss.detach()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], CNN Loss: {:.4f}, ResNet Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, total_step,
                          batch_loss["CNN"].item(), batch_loss["ResNet"].item()))


Epoch [1/10], Step [100/782], CNN Loss: 1.7586, ResNet Loss: 1.6547
Epoch [1/10], Step [200/782], CNN Loss: 1.4186, ResNet Loss: 1.4513
Epoch [1/10], Step [300/782], CNN Loss: 1.4383, ResNet Loss: 1.4597
Epoch [1/10], Step [400/782], CNN Loss: 1.0583, ResNet Loss: 1.1438
Epoch [1/10], Step [500/782], CNN Loss: 1.1229, ResNet Loss: 1.0431
Epoch [1/10], Step [600/782], CNN Loss: 1.3079, ResNet Loss: 1.1127
Epoch [1/10], Step [700/782], CNN Loss: 1.1000, ResNet Loss: 0.9219
Epoch [2/10], Step [100/782], CNN Loss: 0.9960, ResNet Loss: 1.0230
Epoch [2/10], Step [200/782], CNN Loss: 0.7567, ResNet Loss: 0.7659
Epoch [2/10], Step [300/782], CNN Loss: 0.8800, ResNet Loss: 0.6950
Epoch [2/10], Step [400/782], CNN Loss: 0.7475, ResNet Loss: 0.6428
Epoch [2/10], Step [500/782], CNN Loss: 0.9395, ResNet Loss: 0.8675
Epoch [2/10], Step [600/782], CNN Loss: 1.2115, ResNet Loss: 1.0971
Epoch [2/10], Step [700/782], CNN Loss: 0.7813, ResNet Loss: 0.5120
Epoch [3/10], Step [100/782], CNN Loss: 0.8452, 

In [8]:

# Test the models
# A single pass over the test set scores both models on the same batches.
models = {"CNN": model_cnn, "ResNet": model_resnet}
for model in models.values():
    model.eval()  # BatchNorm uses its running statistics during evaluation

correct = {name: 0 for name in models}
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        total += labels.size(0)
        for name, model in models.items():
            predicted = model(images).argmax(dim=1)
            correct[name] += (predicted == labels).sum().item()

# Keep the per-model counters under their original names.
correct_cnn = correct["CNN"]
correct_resnet = correct["ResNet"]

# Report the counted number of test images rather than a hard-coded 10000.
for name, n_correct in correct.items():
    print('Accuracy of the {} on the {} test images: {} %'.format(name, total, 100 * n_correct / total))

Accuracy of the CNN on the 10000 test images: 72.09 %
Accuracy of the ResNet on the 10000 test images: 83.59 %
