In [12]:
#importing the required libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
import sys

In [2]:
#comfigurations
IMG_SIZE = 32 #CIFAR-10 standard resolution for stratch training
NUM_CLASSES = 10
BATCH_SIZE = 64
LEARNING_RATE = 0.01
NUM_EPOCHS = 1
DATA_SAMPLE_SIZE = 5000
TEST_SAMPLE_SIZE = 1000

In [3]:
#setting the device to CPU
device = 'cpu'

Residual block definition

In [25]:
class BasicBlock(nn.Module):
  """
    The basic block for ResNet-18 and ResNet-34.
    It contains two 3x3 convolutional layers.
  """
  expansion = 1 #no channel expansion in BasicBlock

  def __init__(self, in_channels, out_channels, stride=1):
    super(BasicBlock, self).__init__()

    #determine if the shortcut paths need a 1x1 convolution (downsampling/channel change)
    self.shortcut = nn.Sequential()
    if stride!=1 or in_channels!=self.expansion*out_channels:
      self.shortcut = nn.Sequential(
          nn.Conv2d(in_channels, self.expansion*out_channels, kernel_size=1, stride=stride, bias=False),
          nn.BatchNorm2d(self.expansion*out_channels)
      )

    #first convolutional layer (with potential stride for downsampling
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(out_channels)

    #second convolutional layer
    self.conv2 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(self.expansion*out_channels)

  def forward(self, x):
    #the main F(x) function
    out = nn.ReLU()(self.bn1(self.conv1(x)))
    out = self.bn2(self.conv2(out))

    #adding the shortcut path F(x)+x
    out += self.shortcut(x)
    out = nn.ReLU()(out)
    return out

Resnet model definition

In [26]:
class ResNet(nn.Module):
  def __init__(self, block, num_blocks, num_classes=10):
    super(ResNet, self).__init__()
    self.in_channels = 64

    #initial 3x3 convolution (optimized for CIFAR's 32x32 input)
    self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(64)

    #resnet stages (stacks of residual blocks)
    self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
    self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) #downsampling
    self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) #downsampling
    self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) #downsampling

    #using AdaptiveAvgPooling2D for robust global average pooling
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    #final layer
    self.linear = nn.Linear(512 * block.expansion, num_classes)

  def _make_layer(self, block, out_channels, num_blocks, stride):
    """Creates a sequential layer of residual blocks."""

    strides = [stride] + [1]*(num_blocks-1)
    layers = []

    for stride in strides:
      layers.append(block(self.in_channels, out_channels, stride))
      self.in_channels = out_channels * block.expansion

    return nn.Sequential(*layers)

  def forward(self, x):
    out = nn.ReLU()(self.bn1(self.conv1(x)))
    out = self.layer1(out)
    out = self.layer2(out)
    out = self.layer3(out)
    out = self.layer4(out)

    #global average pooling (reducing HxW to 1x1)
    out = self.avgpool(out) #changed to size 4 to handle 32x32 downsampling
    out = out.view(out.size(0), -1)

    out = self.linear(out)
    return out

In [27]:
def ResNet18():
  """Factory function to build a ResNet-18 model."""
  #ResNet18 uses: [2 blocks, 2 blocks, 2 blocks, 2 blocks]
  return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=NUM_CLASSES)

Data loading & preprocessing

In [18]:
def load_cifar10_datasets():
    """
    Loads CIFAR-10, applies transformations, and subsamples the data.
    """
    #Standard normalization for CIFAR-10 when training from scratch
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])

    #Loading full datasets
    train_dataset_full = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    test_dataset_full = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

    #Subsampling the data to reduce training time ---
    train_indices = np.arange(DATA_SAMPLE_SIZE)
    train_dataset = Subset(train_dataset_full, train_indices)

    test_indices = np.arange(TEST_SAMPLE_SIZE)
    test_dataset = Subset(test_dataset_full, test_indices)

    #Creating DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    return train_loader, test_loader

Training & evaluation functions

In [19]:
def train_model(model, train_loader, test_loader):
    """
    Trains the model for the configured number of epochs.
    """
    #Optimizing ALL parameters (since we are training from scratch)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(NUM_EPOCHS):
        model.train() # Set model to training mode
        running_loss = 0.0

        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)

            if (i + 1) % 100 == 0:
                 print(f"Batch {i+1}/{len(train_loader)}, Loss: {loss.item():.4f}")


        epoch_loss = running_loss / len(train_loader.dataset)
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} Loss: {epoch_loss:.4f}")

    evaluate_model(model, test_loader)

In [20]:
def evaluate_model(model, test_loader):
    """
    Evaluates the model's performance on the test set.
    """
    model.eval() # Setting model to evaluation mode
    correct = 0
    total = 0

    with torch.no_grad(): # Disabling gradient calculation during evaluation
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    loss = nn.CrossEntropyLoss()(outputs, labels).item()

    print(f"Test Loss: {loss:.4f}")
    print(f"Test Accuracy: {accuracy:.2f}%")

Execution

In [21]:
#Loading and prepare data (using 32x32 resolution)
train_loader, test_loader = load_cifar10_datasets()

In [28]:
#Building the ResNet-18 model
model = ResNet18()

In [29]:
# Print the model structure to the console
print(model, file=sys.stderr)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (shortcut): Sequential()
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (shortcut): Sequential()
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, e

In [30]:
#Training and evaluating the model
train_model(model, train_loader, test_loader)

Epoch 1/1 Loss: 2.4528
Test Loss: 2.0301 (Last Batch)
Test Accuracy: 24.00%
