# ResNet

## Theoretical Introduction

The **ResNet** (Residual Network) architecture, introduced in 2015 by **Kaiming He** and his collaborators at Microsoft Research, represents a decisive turning point in the design of deep neural networks for computer vision. While LeNet-5 established the systematic use of convolutional layers and architectures such as VGG demonstrated that increasing depth can significantly improve performance, ResNet introduces a fundamentally new structural mechanism that enables the effective and stable training of extremely deep networks, including models with more than one hundred layers.

The importance of ResNet lies in its ability to overcome the optimization challenges that arise as network depth increases. Prior to its introduction, very deep architectures often suffered from severe training difficulties, including slow convergence, numerical instability, and degradation of performance as additional layers were added. These issues made it impractical to exploit the full representational power theoretically offered by deep models. ResNet addresses these limitations by reformulating how layers learn transformations, allowing information and gradients to propagate more effectively through the network.

The practical impact of this innovation was clearly demonstrated when ResNet won the classification task of the **ImageNet 2015** Large Scale Visual Recognition Challenge (ILSVRC 2015), with an ensemble built on networks of up to **152 layers** that reached 3.57% top-5 error. This depth, roughly eight times that of VGG, had previously been considered impractical to train, given the known difficulties of optimizing such deep architectures. The success of ResNet not only established a new state of the art in image recognition performance, but also reshaped prevailing assumptions about the feasible depth of neural networks, paving the way for subsequent generations of very deep models in computer vision and beyond.

## The Degradation Problem in Deep Networks

Prior to the introduction of ResNet, it was widely assumed that increasing the number of layers in a neural network should, at least in principle, enhance its representational capacity and improve performance. Deeper networks can theoretically capture increasingly complex hierarchical features, enabling more sophisticated modeling of input data. However, empirical studies revealed a surprising and counterintuitive phenomenon: beyond a certain depth, adding layers to a plain (non-residual) network often **degrades performance**, even on the training set itself. In the original ResNet experiments, for example, a 56-layer plain network had higher training error on CIFAR-10 than a 20-layer one, and on ImageNet a 34-layer plain network trained worse than an 18-layer one. This effect is distinct from overfitting, because it is the training error itself that increases, not just the gap between training and test performance.

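This phenomenon is referred to as **degradation**, and it reflects an optimization difficulty rather than a lack of capacity. A deeper network could, in principle, reproduce a shallower one by copying its layers and setting the extra layers to the identity mapping, so its training error should never be higher; in practice, gradient-based solvers fail to find such solutions. Signal propagation contributes to the problem: during backpropagation, the error signal used to update the parameters is repeatedly multiplied by layer Jacobians, so it can shrink or become unstable before it reaches the layers close to the input, which then update slowly. He et al. note, however, that vanishing gradients alone do not explain degradation in their experiments: with normalized initialization and batch normalization, the gradient norms of their plain networks remained healthy, and they conjectured instead that deep plain networks have exponentially low convergence rates.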

## Residual Connections and Shortcut Blocks

The central innovation of ResNet lies in the introduction of **shortcut connections**, also referred to as **skip connections**, which form the fundamental building block known as the **residual block**. The underlying idea is conceptually straightforward yet profoundly impactful: rather than requiring each group of layers to learn a complete input-to-output mapping, the network is allowed to learn only the **residual**, that is, the difference between the desired output of the block and its input.

Mathematically, if the desired mapping is denoted by $H(x)$ and the input to a residual block is $x$, the block is designed to learn a function $F(x) = H(x) - x$. The output of the block is then expressed as:

$$
y = F(x) + x
$$

This formulation offers two key advantages. First, it simplifies the learning problem: it is often easier for a block to learn small deviations from the identity mapping than to approximate a completely new transformation. Second, the shortcut connection provides a direct pathway for both the forward signal and the backward gradients to propagate through the network. This mechanism mitigates the vanishing gradient problem and enhances numerical stability, allowing extremely deep networks to be trained effectively.

By enabling layers to focus on learning residual functions instead of full transformations, ResNet overcomes the degradation problem observed in very deep networks and establishes a structural principle that has become foundational in modern deep learning architectures. Residual blocks can be stacked to create networks with hundreds of layers, achieving remarkable performance while maintaining stable and efficient training dynamics.
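The equation $y = F(x) + x$ translates almost literally into code. The following is a minimal PyTorch sketch, assuming a hypothetical `Residual` wrapper and a small two-layer MLP as the residual branch $F$ (not the convolutional blocks used later in this notebook). Zero-initializing the last layer of $F$ makes the block start out as an exact identity mapping.

```python
import torch
from torch import nn


class Residual(nn.Module):
    """Hypothetical wrapper computing y = F(x) + x for any shape-preserving branch F."""

    def __init__(self, branch: nn.Module) -> None:
        super().__init__()
        self.branch = branch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The shortcut adds the unmodified input to the branch output
        return self.branch(x) + x


# Residual branch F: a small MLP that preserves the feature dimension
branch = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
# Zero the last layer so that F(x) = 0 and the block starts as the identity
nn.init.zeros_(branch[2].weight)
nn.init.zeros_(branch[2].bias)

block = Residual(branch)
x = torch.randn(4, 8)
y = block(x)
print(torch.equal(y, x))  # True: F(x) = 0, so y = x
```

The only structural requirement is that $F(x)$ and $x$ have the same shape; when they do not, the shortcut needs a projection, as discussed for the convolutional blocks below.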

### Advantages of Residual Learning

Residual learning provides several fundamental advantages. First, it facilitates optimization. If the optimal transformation in a certain block is close to the identity, it is easier for the network to learn a residual function with $F(x) \approx 0$ than to learn a complete transformation $H(x) \approx x$ from scratch. Driving the weights of the residual branch toward zero is enough to recover the identity, so the optimizer only needs to learn a small perturbation around it, whereas a stack of nonlinear layers has no equally simple way to represent the identity mapping.

Second, gradient propagation improves significantly. During backpropagation, the gradient of the loss $L$ with respect to the input $x$ of a residual block satisfies, in simplified form,

$$
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \left( \frac{\partial F}{\partial x} + I \right),
$$

where $I$ is the identity matrix. This expression guarantees that, even in the limiting case where $\frac{\partial F}{\partial x}$ tends to zero, there is always a direct gradient path through the identity term $I$. In practice, this prevents the signal from vanishing completely and contributes to stabilizing the training of very deep networks.
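The identity term in the Jacobian can be checked numerically with autograd. This toy sketch assumes a linear residual branch $F(x) = Wx$, chosen only because its Jacobian $W$ is known in closed form, so the Jacobian of the whole block must equal $W + I$.

```python
import torch

torch.manual_seed(0)
W = 0.01 * torch.randn(5, 5)  # weights of a linear residual branch F(x) = W x


def residual_block(x: torch.Tensor) -> torch.Tensor:
    # y = F(x) + x with F linear, so dy/dx = W + I exactly
    return W @ x + x


x = torch.randn(5)
J = torch.autograd.functional.jacobian(residual_block, x)
print(torch.allclose(J, W + torch.eye(5)))  # True

# Even if the branch contributes nothing (W = 0), the Jacobian is still the identity
J_zero = torch.autograd.functional.jacobian(lambda v: torch.zeros(5, 5) @ v + v, x)
print(torch.equal(J_zero, torch.eye(5)))  # True
```

Multiplying the upstream gradient $\frac{\partial L}{\partial y}$ by `J_zero` returns it unchanged, which is the "direct gradient path" described above.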

Third, residual connections introduce a form of adaptive depth. If a block is not necessary for the task, it can approximate $F(x) \approx 0$ and effectively behave as an identity transformation, so that $y \approx x$. The network thus retains the ability to neutralize blocks that do not provide improvements, without compromising the information flow. Taken together, these mechanisms make it possible to train networks with more than one hundred layers without suffering the severe degradation that affected previous architectures: the input signal can traverse the network through the identity paths without being progressively attenuated, and the gradient has direct paths that mitigate vanishing effects.
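The contrast between plain and residual stacks can be illustrated with a deliberately simplified experiment. The setup below is an assumption made for illustration, not a model of ResNet training: 100 linear layers with very small random weights and no nonlinearity, so that every layer, used as a residual branch, is close to the zero map. The plain stack multiplies the signal by these near-zero matrices and it collapses, while the residual stack keeps it close to the input.

```python
import torch
from torch import nn

torch.manual_seed(0)
DEPTH, WIDTH = 100, 16

# Toy setup: bias-free linear layers with tiny random weights, no activations
layers = [nn.Linear(WIDTH, WIDTH, bias=False) for _ in range(DEPTH)]
for layer in layers:
    nn.init.normal_(layer.weight, std=0.001)

x = torch.randn(1, WIDTH)

with torch.no_grad():
    plain = x
    for layer in layers:
        plain = layer(plain)  # plain stack: h <- W h

    residual = x
    for layer in layers:
        residual = layer(residual) + residual  # residual stack: h <- W h + h

print(f"plain    ||h|| / ||x||     = {(plain.norm() / x.norm()).item():.3e}")
print(f"residual ||h - x|| / ||x|| = {((residual - x).norm() / x.norm()).item():.3f}")
```

In the plain stack the signal shrinks by a large factor at every layer and underflows to zero in float32; in the residual stack each block only adds a small correction to $h$, so the output stays near $x$.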

## Architectural Variants of ResNet

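The ResNet family includes several configurations that differ mainly in depth and in the type of residual block used. The standard configurations are summarized below; the parameter counts refer to the original ImageNet models with 1000 output classes, so the CIFAR-10 variants built later in this notebook are slightly smaller.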

|      Model | Layers | Parameters | Blocks per stage | Block type |
| ---------: | -----: | ---------: | ---------------- | ---------- |
|  ResNet-18 |     18 |      ~11 M | [2, 2, 2, 2]     | Basic      |
|  ResNet-34 |     34 |      ~21 M | [3, 4, 6, 3]     | Basic      |
|  ResNet-50 |     50 |      ~25 M | [3, 4, 6, 3]     | Bottleneck |
| ResNet-101 |    101 |      ~44 M | [3, 4, 23, 3]    | Bottleneck |
| ResNet-152 |    152 |      ~60 M | [3, 8, 36, 3]    | Bottleneck |
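The depth in each model name follows from the "Blocks per stage" column: each basic block contributes two weighted layers and each bottleneck block three, plus the initial convolution and the final fully connected layer (the $1 \times 1$ projection shortcuts are conventionally not counted). A quick check in plain Python:

```python
# Weighted layers contributed by one residual block of each type
LAYERS_PER_BLOCK = {"Basic": 2, "Bottleneck": 3}

# Block type and blocks per stage, copied from the table above
CONFIGS = {
    "ResNet-18": ("Basic", [2, 2, 2, 2]),
    "ResNet-34": ("Basic", [3, 4, 6, 3]),
    "ResNet-50": ("Bottleneck", [3, 4, 6, 3]),
    "ResNet-101": ("Bottleneck", [3, 4, 23, 3]),
    "ResNet-152": ("Bottleneck", [3, 8, 36, 3]),
}


def count_layers(block_type: str, blocks_per_stage: list) -> int:
    # +2 for the initial convolution and the final fully connected layer
    return LAYERS_PER_BLOCK[block_type] * sum(blocks_per_stage) + 2


for name, (block_type, stages) in CONFIGS.items():
    print(f"{name}: {count_layers(block_type, stages)} weighted layers")
```

ResNet-34 and ResNet-50 share the same stage layout; the extra depth of ResNet-50 comes entirely from replacing two-layer basic blocks with three-layer bottlenecks.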

### Basic Block versus Bottleneck Block

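In the shallower versions, ResNet-18 and ResNet-34, the basic block is used. It consists of two $3 \times 3$ convolutions, each followed by batch normalization; a ReLU follows the first one, and a second ReLU is applied after the residual sum with the shortcut branch. Its expansion factor is 1, meaning that the block outputs as many channels as its nominal width. In the first block of each later stage, where the spatial resolution is halved and the number of channels doubles, the shortcut typically uses a $1 \times 1$ projection (as in the implementation below) so that both terms of the sum have the same shape.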

In the deeper variants, such as ResNet-50, ResNet-101, and ResNet-152, the bottleneck block is used, whose purpose is to reduce computational cost while preserving representational capacity. This block combines three consecutive convolutions. The first, of size $1 \times 1$, reduces the channel dimensionality, for example from 256 to 64 channels. The second, of size $3 \times 3$, performs the main processing on a reduced number of channels. The third, again of size $1 \times 1$, restores the original dimensionality, for example from 64 back to 256 channels. The typical expansion factor is 4: the number of output channels is four times the number of intermediate channels.
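The saving is easy to quantify by counting convolution weights, ignoring the small number of batch-normalization parameters. The sketch compares a bottleneck operating on 256 channels with a basic block on 64 channels and with a hypothetical basic block at the full 256-channel width.

```python
def conv_weights(c_in: int, c_out: int, k: int) -> int:
    """Number of weights in a bias-free k x k convolution."""
    return c_in * c_out * k * k


# Basic block operating on 64 channels: two 3x3 convolutions
basic_64 = 2 * conv_weights(64, 64, 3)

# Bottleneck operating on 256 channels: 1x1 reduce, 3x3 on 64 channels, 1x1 restore
bottleneck_256 = (
    conv_weights(256, 64, 1) + conv_weights(64, 64, 3) + conv_weights(64, 256, 1)
)

# Hypothetical basic block at the full 256-channel width, for comparison
basic_256 = 2 * conv_weights(256, 256, 3)

print(f"basic (64 ch.):       {basic_64:,}")  # 73,728
print(f"bottleneck (256 ch.): {bottleneck_256:,}")  # 69,632
print(f"basic (256 ch.):      {basic_256:,}")  # 1,179,648
```

The bottleneck processes a representation four times wider than the 64-channel basic block at a slightly lower weight count, and about 17 times fewer weights than a basic block at 256 channels. At equal spatial resolution, the number of multiply-adds scales in the same proportion.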

## Current Impact and Importance of ResNet

ResNet is currently regarded as a reference architecture in both academic and industrial contexts. The balance between depth, training stability, and efficiency makes it the backbone of numerous systems for face recognition, autonomous driving, medical imaging diagnosis, and large-scale visual analysis in multiple domains.

Its advantages include efficient parameter usage (ResNet-50 has roughly five times fewer parameters than VGG-16, about 25 million versus 138 million), the ability to train networks with more than one hundred layers without severe degradation, suitability as a backbone for transfer learning, numerical stability during training, and a versatility that has inspired residual variants in vision, natural language processing, and other modalities. Beyond solving a specific technical problem, ResNet redefines deep architecture design by explicitly incorporating identity paths that facilitate the flow of information and gradients through the network.

## Practical Implementation of ResNet for CIFAR-10

The following provides a complete and fully functional implementation of **ResNet** (ResNet-18, ResNet-34, and ResNet-50) for the **CIFAR-10** dataset using PyTorch. The code is structured for easy conversion into a Jupyter Notebook and can be executed sequentially—from data loading to training, evaluation, and result visualization.

### Importing Libraries

The first step is to import the necessary modules for defining the network architecture, handling data, performing training, and visualizing results.

In [None]:
# Standard libraries
import time
from typing import Any, List, Type, Union

# Third-party libraries
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
from sklearn.manifold import TSNE
from sklearn.metrics import classification_report, confusion_matrix
from torch import nn
from torch.utils.data import DataLoader
from torchinfo import summary
from torchvision import datasets, transforms
from tqdm import tqdm


print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

### Global Hyperparameter Configuration

The constants and hyperparameters that will be used throughout the experiment are defined next.

In [None]:
# Global configuration
BATCH_SIZE: int = 128
NUM_EPOCHS: int = 1
LEARNING_RATE: float = 0.1
WEIGHT_DECAY: float = 1e-4
MOMENTUM: float = 0.9
NUM_CLASSES: int = 10
INPUT_SIZE: int = 32

# CIFAR-10 class names
CIFAR10_CLASSES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

print("Configuration:")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Initial learning rate: {LEARNING_RATE}")
print(f"  Momentum: {MOMENTUM}")
print(f"  Weight decay: {WEIGHT_DECAY}")
print(f"  Number of classes: {NUM_CLASSES}")

### Visualization Helper Function

The `show_images` function allows visualization of CIFAR-10 images together with their ground-truth labels and, optionally, model predictions.

In [None]:
def show_images(images, labels, predictions=None, classes=CIFAR10_CLASSES):
    """
    Visualize a set of images with their labels and predictions.

    Args:
        images: Image tensor [N, C, H, W].
        labels: Label tensor [N].
        predictions: Optional tensor of predictions [N].
        classes: List of class names.
    """
    n_images = min(len(images), 8)
    fig, axes = plt.subplots(1, n_images, figsize=(2 * n_images, 3))
    if n_images == 1:
        axes = [axes]

    for idx in range(n_images):
        img = images[idx]
        label = labels[idx]
        ax = axes[idx]

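        # Undo the per-channel normalization (x = z * std + mean); relies on
        # CIFAR10_MEAN and CIFAR10_STD, defined in the dataset-preparation cell
        mean = torch.tensor(CIFAR10_MEAN).view(3, 1, 1)
        std = torch.tensor(CIFAR10_STD).view(3, 1, 1)
        img = (img.cpu() * std + mean).clamp(0, 1)
        img = img.numpy().transpose((1, 2, 0))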
        ax.imshow(img)

        title = f"True: {classes[label]}"
        if predictions is not None:
            pred = predictions[idx]
            color = "green" if pred == label else "red"
            title += f"\nPred: {classes[pred]}"
            ax.set_title(title, fontsize=9, color=color, fontweight="bold")
        else:
            ax.set_title(title, fontsize=9, fontweight="bold")

        ax.axis("off")

    plt.tight_layout()
    plt.show()


print("Visualization function defined correctly")

### Preparing the CIFAR-10 Dataset

CIFAR-10 is then loaded and the preprocessing and data augmentation transformations for training and validation are defined.

In [None]:
# CIFAR-10 normalization statistics
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Training transformations with data augmentation
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

# Validation/test transformations
transform_test = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)]
)

print("Downloading CIFAR-10 dataset...")
train_dataset = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform_train
)

test_dataset = datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform_test
)

print("\nDataset statistics:")
print(f"  Training samples: {len(train_dataset):,}")
print(f"  Test samples: {len(test_dataset):,}")
print(f"  Number of classes: {len(train_dataset.classes)}")
print("  Image size: 32×32 pixels (RGB)")

Data augmentation introduces random cropping and horizontal flipping to enhance generalization. Normalizing each channel using the CIFAR-10 mean and standard deviation centers and scales the data, promoting faster and more stable convergence.
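`transforms.Normalize` computes $(x - \text{mean}) / \text{std}$ for each channel, so displaying a normalized image requires the inverse $x = z \cdot \text{std} + \text{mean}$ with the same statistics; a shortcut such as `z / 2 + 0.5` is only correct when mean and std are both 0.5. A plain-Python sketch of both directions for a single RGB pixel (the pixel value is arbitrary):

```python
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)


def normalize(pixel):
    """Per-channel standardization, as applied by transforms.Normalize."""
    return tuple((v - m) / s for v, m, s in zip(pixel, CIFAR10_MEAN, CIFAR10_STD))


def denormalize(z):
    """Inverse transform, needed to display normalized images."""
    return tuple(v * s + m for v, m, s in zip(z, CIFAR10_MEAN, CIFAR10_STD))


pixel = (0.8, 0.5, 0.1)  # an RGB pixel after ToTensor, values in [0, 1]
z = normalize(pixel)
print([round(v, 3) for v in z])
print([round(v, 3) for v in denormalize(z)])  # recovers the original pixel
```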

### Creating DataLoaders

The training and test `DataLoader` objects are then defined, configuring the number of worker processes and other performance-oriented options.

In [None]:
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
)

test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
)

print("DataLoaders configured:")
print(f"  Training batches: {len(train_dataloader)}")
print(f"  Test batches: {len(test_dataloader)}")
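The batch counts printed above follow from ceiling division, because `DataLoader` keeps the final partial batch by default (`drop_last=False`). A plain-Python sketch, using the standard CIFAR-10 split of 50,000 training and 10,000 test images:

```python
import math

BATCH_SIZE = 128
TRAIN_SAMPLES, TEST_SAMPLES = 50_000, 10_000  # standard CIFAR-10 split

# With drop_last=False (the DataLoader default), the last partial batch is kept
train_batches = math.ceil(TRAIN_SAMPLES / BATCH_SIZE)
test_batches = math.ceil(TEST_SAMPLES / BATCH_SIZE)
last_train_batch = TRAIN_SAMPLES - (train_batches - 1) * BATCH_SIZE

print(train_batches, test_batches, last_train_batch)  # 391 79 80
```

Because the final training batch holds only 80 images, dividing the summed batch losses by `len(train_dataloader)`, as the training loop does, gives a mean of per-batch means rather than an exact per-sample average; the difference is negligible at this scale.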

### Visual Exploration of the Dataset

Before training the model, it is helpful to inspect a batch of images to ensure that data loading and preprocessing are correctly configured.

In [None]:
# Get one training batch
data_iter = iter(train_dataloader)
train_images, train_labels = next(data_iter)

print("\nBatch dimensions:")
print(f"  Images: {train_images.shape}")
print(f"  Labels: {train_labels.shape}")

print("\nDisplaying first 8 samples...")
show_images(train_images[:8], train_labels[:8])

## ResNet Implementation

The BasicBlock class implements the basic residual block used in ResNet-18 and ResNet-34. This block consists of two $3 \times 3$ convolutions with batch normalization and a shortcut connection that adds the input to the output.

In [None]:
class BasicBlock(nn.Module):
    """
    Basic residual block for ResNet-18 and ResNet-34.
    """

    expansion: int = 1

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        downsample: nn.Module = None,
    ) -> None:
        super().__init__()

        # First 3×3 convolution
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

        # Second 3×3 convolution
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut branch (dimensionality adjustment if needed)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = F.relu(out)

        return out

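The `Bottleneck` class implements the bottleneck block used in deeper ResNet variants such as ResNet-50, ResNet-101, and ResNet-152. It uses three convolutions with an expansion factor of 4 to reduce computational cost while maintaining representational capacity. Note that `out_channels` here is the width of the inner $3 \times 3$ convolution; the block outputs `out_channels * 4` channels. As in the torchvision implementation, the stride is applied in the $3 \times 3$ convolution, whereas the original paper applied it in the first $1 \times 1$ convolution.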

In [None]:
class Bottleneck(nn.Module):
    """
    Bottleneck block for ResNet-50, ResNet-101, and ResNet-152.
    """

    expansion: int = 4

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        downsample: nn.Module = None,
    ) -> None:
        super().__init__()

        # 1×1 conv to reduce dimensionality
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # 3×3 conv for main processing
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # 1×1 conv to restore dimensionality
        self.conv3 = nn.Conv2d(
            out_channels, out_channels * self.expansion, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = F.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = F.relu(out)

        return out

The ResNet class implements the complete ResNet architecture adapted for CIFAR-10. The main differences from the original ImageNet version include a smaller initial convolution, no initial max pooling, and a final classification layer adapted to 10 classes.
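The reason for dropping the ImageNet stem can be checked with the standard output-size formula $\lfloor (n + 2p - k)/s \rfloor + 1$, which applies to both convolutions and pooling. The sketch below, assuming the stage strides used in this notebook (1, 2, 2, 2), traces the spatial size of a 32×32 image through the CIFAR-10 stem and through the original 7×7 stride-2 convolution followed by 3×3 stride-2 max pooling, and shows a 224×224 ImageNet input for reference.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output size of a convolution or pooling layer: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def stage_sizes(size: int) -> list:
    """Spatial size after each stage; only the first 3x3 conv of a stage is strided."""
    sizes = []
    for stride in (1, 2, 2, 2):
        size = conv_out(size, kernel=3, stride=stride, padding=1)
        sizes.append(size)
    return sizes


def imagenet_stem(size: int) -> int:
    # 7x7 conv (stride 2, padding 3) followed by 3x3 max pooling (stride 2, padding 1)
    return conv_out(conv_out(size, kernel=7, stride=2, padding=3), kernel=3, stride=2, padding=1)


cifar = conv_out(32, kernel=3, stride=1, padding=1)  # CIFAR-10 stem used here
print("CIFAR stem, 32x32:     ", cifar, stage_sizes(cifar))
print("ImageNet stem, 32x32:  ", imagenet_stem(32), stage_sizes(imagenet_stem(32)))
print("ImageNet stem, 224x224:", imagenet_stem(224), stage_sizes(imagenet_stem(224)))
```

With the ImageNet stem, a CIFAR-10 image would be reduced to 8×8 before the first stage and to a single pixel in the last one; the 3×3 stride-1 stem keeps a 4×4 map before global average pooling.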

In [None]:
class ResNet(nn.Module):
    """
    ResNet implementation adapted for CIFAR-10.

    Differences with respect to the original ImageNet version:
      - First layer: Conv 3×3 instead of Conv 7×7.
      - No initial MaxPooling (images are 32×32).
      - Final classification layer adapted to CIFAR-10.

    Args:
        block: Block type (BasicBlock or Bottleneck).
        layers: List with the number of blocks per stage.
        num_classes: Number of output classes.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 10,
    ) -> None:
        super().__init__()
        self.in_channels = 64

        # Initial layer adapted to CIFAR-10 (32×32)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        # Four stages of residual blocks
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Global pooling and final classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Weight initialization
        self._initialize_weights()

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        out_channels: int,
        num_blocks: int,
        stride: int = 1,
    ) -> nn.Sequential:
        """
        Build one stage of residual blocks.

        Args:
            block: Residual block type.
            out_channels: Number of output channels.
            num_blocks: Number of blocks in the stage.
            stride: Stride of the first block (downsampling).

        Returns:
            nn.Sequential containing the stage blocks.
        """
        downsample = None

        # Dimensionality adjustment in the shortcut branch
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.in_channels,
                    out_channels * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels * block.expansion),
            )

        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels * block.expansion

        for _ in range(1, num_blocks):
            layers.append(block(self.in_channels, out_channels))

        return nn.Sequential(*layers)

    def _initialize_weights(self) -> None:
        """
        Initialize weights using He (Kaiming) initialization.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        ResNet forward pass.

        Args:
            x: Input tensor [B, 3, 32, 32].

        Returns:
            Classification logits [B, num_classes].
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)  # 32×32 → 32×32
        x = self.layer2(x)  # 32×32 → 16×16
        x = self.layer3(x)  # 16×16 →  8×8
        x = self.layer4(x)  #  8×8 →  4×4

        x = self.avgpool(x)  # 4×4 → 1×1
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def get_features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract features before the classification layer.
        Useful for visualization of embeddings and transfer learning.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        return x

Factory functions are defined to instantiate the different ResNet variants with the appropriate block types and layer configurations:

In [None]:
def resnet18(num_classes: int = 10) -> ResNet:
    """Construct a ResNet-18 for the given number of classes."""
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)


def resnet34(num_classes: int = 10) -> ResNet:
    """Construct a ResNet-34 for the given number of classes."""
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)


def resnet50(num_classes: int = 10) -> ResNet:
    """Construct a ResNet-50 for the given number of classes."""
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)


print("ResNet architecture defined correctly")

This implementation clearly separates the fundamental components: `BasicBlock` for ResNet-18/34, `Bottleneck` for ResNet-50/101/152, and the `ResNet` class, which assembles the stages, manages downsampling, and applies global average pooling before the final classification.

### Model Instantiation and Analysis

A ResNet-18 instance adapted to CIFAR-10 is created and its structure and parameter count are examined using `torchinfo.summary`.

In [None]:
# Create ResNet-18
model = resnet18(num_classes=NUM_CLASSES)

# Select device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

print(f"Device used: {device}")
print(f"\n{'='*70}")
print("RESNET-18 ARCHITECTURE SUMMARY")
print(f"{'='*70}\n")

summary(model, input_size=(BATCH_SIZE, 3, 32, 32), device=str(device))


def count_parameters(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


print(f"\n{'='*70}")
print("PARAMETER ANALYSIS BY COMPONENT")
print(f"{'='*70}")
print(f"  Initial conv:     {count_parameters(model.conv1):>12,} parameters")
print(f"  Initial BN:       {count_parameters(model.bn1):>12,} parameters")
print(f"  Layer 1 (64 ch.): {count_parameters(model.layer1):>12,} parameters")
print(f"  Layer 2 (128 ch.):{count_parameters(model.layer2):>12,} parameters")
print(f"  Layer 3 (256 ch.):{count_parameters(model.layer3):>12,} parameters")
print(f"  Layer 4 (512 ch.):{count_parameters(model.layer4):>12,} parameters")
print(f"  FC classifier:    {count_parameters(model.fc):>12,} parameters")
print(f"  {'-'*66}")

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"  TOTAL:            {total_params:>12,} parameters")
print(f"  Trainable:        {trainable_params:>12,} parameters")
print(f"  Memory (float32): {total_params * 4 / (1024**2):>10.2f} MB")
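As a hand check of the total printed above, the following sketch recomputes the parameter count of the CIFAR-10 ResNet-18 defined in this notebook (3×3 stem, $1 \times 1$ projection shortcuts with batch normalization, 10 classes) from layer shapes alone. It counts only learnable parameters: BatchNorm running statistics are buffers, not parameters, so they are excluded, just as they are by `model.parameters()`.

```python
def conv_params(c_in: int, c_out: int, k: int) -> int:
    """Weights of a bias-free k x k convolution."""
    return c_in * c_out * k * k


def bn_params(c: int) -> int:
    """Learnable scale and shift of BatchNorm2d (running statistics are buffers)."""
    return 2 * c


def basic_block_params(c_in: int, c_out: int, stride: int) -> int:
    total = conv_params(c_in, c_out, 3) + bn_params(c_out)
    total += conv_params(c_out, c_out, 3) + bn_params(c_out)
    if stride != 1 or c_in != c_out:
        total += conv_params(c_in, c_out, 1) + bn_params(c_out)  # projection shortcut
    return total


def cifar_resnet18_params(num_classes: int = 10) -> int:
    total = conv_params(3, 64, 3) + bn_params(64)  # 3x3 stem
    c_in = 64
    for width, stride in [(64, 1), (128, 2), (256, 2), (512, 2)]:
        for i in range(2):  # two basic blocks per stage
            total += basic_block_params(c_in, width, stride if i == 0 else 1)
            c_in = width
    total += 512 * num_classes + num_classes  # fully connected weights + bias
    return total


print(f"{cifar_resnet18_params():,}")  # 11,173,962
```

Almost 75% of the parameters sit in the last stage, where the 512-channel 3×3 convolutions each hold about 2.4 million weights.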

### Training Configuration

The optimizer, learning rate scheduler, and loss function are set up. Stochastic Gradient Descent (SGD) with Nesterov momentum is used, along with a MultiStepLR scheduler that decreases the learning rate at predefined epochs.

In [None]:
print("TRAINING CONFIGURATION")
print(f"{'='*70}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Initial learning rate: {LEARNING_RATE}")
print(f"  Momentum: {MOMENTUM}")
print(f"  Weight decay (L2): {WEIGHT_DECAY}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"{'='*70}\n")

# Optimizer: SGD with Nesterov momentum
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
    nesterov=True,
)

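# Scheduler: MultiStepLR (divides the LR by 10 at epochs 60 and 80).
# With NUM_EPOCHS = 1 (the quick-demo setting above) the milestones are never
# reached; they are meant for longer training runs.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[60, 80], gamma=0.1
)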

# Loss function
loss_function = nn.CrossEntropyLoss()

print("Optimizer: SGD with Nesterov momentum")
print("  Nesterov momentum adds a 'look-ahead' in the descent direction")
print("\nScheduler: MultiStepLR")
print("  Reduces learning rate ×0.1 at epochs 60 and 80")
print("  Classic strategy for training ResNet on CIFAR-10")
print("\nLoss function: CrossEntropyLoss")

### Training and Validation Loop

The training loop is implemented with metric tracking and includes saving the model that achieves the highest test accuracy.

In [None]:
# Metric storage
train_losses, train_accuracies = [], []
test_losses, test_accuracies = [], []
learning_rates = []

# Variables for saving the best model
best_test_acc = 0.0
best_epoch = 0


def calculate_accuracy(outputs: torch.Tensor, labels: torch.Tensor):
    _, predicted = torch.max(outputs, 1)
    correct = (predicted == labels).sum().item()
    total = labels.size(0)
    return correct, total


print("STARTING TRAINING\n")
print(f"{'='*70}\n")

start_time = time.time()

for epoch in range(NUM_EPOCHS):
    epoch_start_time = time.time()

    # Training phase
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    train_loop = tqdm(
        train_dataloader, desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} [TRAIN]", leave=False
    )

    for batch_image, batch_label in train_loop:
        batch_image = batch_image.to(device)
        batch_label = batch_label.to(device)

        optimizer.zero_grad()
        outputs = model(batch_image)
        loss = loss_function(outputs, batch_label)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        batch_correct, batch_total = calculate_accuracy(outputs, batch_label)
        correct += batch_correct
        total += batch_total

        train_loop.set_postfix(
            {"loss": f"{loss.item():.4f}", "acc": f"{100 * correct / total:.2f}%"}
        )

    epoch_train_loss = running_loss / len(train_dataloader)
    epoch_train_acc = 100 * correct / total
    train_losses.append(epoch_train_loss)
    train_accuracies.append(epoch_train_acc)

    # Validation phase
    model.eval()
    test_loss, correct_test, total_test = 0.0, 0, 0

    test_loop = tqdm(
        test_dataloader, desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} [TEST]", leave=False
    )

    with torch.no_grad():
        for images, labels in test_loop:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = loss_function(outputs, labels)

            test_loss += loss.item()
            batch_correct, batch_total = calculate_accuracy(outputs, labels)
            correct_test += batch_correct
            total_test += batch_total

            test_loop.set_postfix(
                {
                    "loss": f"{loss.item():.4f}",
                    "acc": f"{100 * correct_test / total_test:.2f}%",
                }
            )

    epoch_test_loss = test_loss / len(test_dataloader)
    epoch_test_acc = 100 * correct_test / total_test
    test_losses.append(epoch_test_loss)
    test_accuracies.append(epoch_test_acc)

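    # Record the learning rate used during this epoch, then update the scheduler
    # (reading it after scheduler.step() would report the next epoch's rate)
    current_lr = optimizer.param_groups[0]["lr"]
    learning_rates.append(current_lr)
    scheduler.step()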

    # Save best model according to test accuracy
    if epoch_test_acc > best_test_acc:
        best_test_acc = epoch_test_acc
        best_epoch = epoch + 1
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "test_acc": best_test_acc,
            },
            "resnet18_cifar10_best.pth",
        )

    epoch_time = time.time() - epoch_start_time
    print(f"Epoch [{epoch + 1}/{NUM_EPOCHS}] - Time: {epoch_time:.2f}s")
    print(f"  Train → Loss: {epoch_train_loss:.4f} | Acc: {epoch_train_acc:.2f}%")
    print(f"  Test  → Loss: {epoch_test_loss:.4f} | Acc: {epoch_test_acc:.2f}%")
    print(
        f"  LR: {current_lr:.6f} | Best test acc: {best_test_acc:.2f}% (epoch {best_epoch})"
    )
    print(f"  {'─'*66}\n")

total_time = time.time() - start_time

print(f"\n{'='*70}")
print("TRAINING COMPLETED")
print(f"{'='*70}")
print(f"  Total time: {total_time / 60:.2f} minutes")
print(f"  Average time per epoch: {total_time / NUM_EPOCHS:.2f} seconds")
print(f"  Final test accuracy: {test_accuracies[-1]:.2f}%")
print(f"  Best test accuracy: {best_test_acc:.2f}% at epoch {best_epoch}")

# Save final model and metrics
torch.save(
    {
        "epoch": NUM_EPOCHS,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "train_losses": train_losses,
        "train_accuracies": train_accuracies,
        "test_losses": test_losses,
        "test_accuracies": test_accuracies,
        "best_test_acc": best_test_acc,
        "best_epoch": best_epoch,
    },
    "resnet18_cifar10_final.pth",
)

print("\nSaved models:")
print("  - resnet18_cifar10_best.pth (best model)")
print("  - resnet18_cifar10_final.pth (final model + metrics)")

Compared with architectures such as VGG, ResNet generally trains more stably and efficiently on CIFAR-10 at comparable depth. This is largely due to its residual connections, which ease gradient propagation through many layers, and to its comparatively moderate parameter count.
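To put the parameter-count remark in perspective, the sketch below tallies the trainable parameters of a ResNet-18 by hand. It assumes the common CIFAR-style layout: a 3×3 stem convolution without max-pooling, bias-free convolutions, BatchNorm after every convolution, 1×1 projection shortcuts where the channel count changes, and a 10-class linear head. If the notebook's model differs (for example, torchvision's ImageNet-style 7×7 stem), the total will differ slightly. On the actual model, `sum(p.numel() for p in model.parameters())` gives the exact figure. For reference, the original ImageNet VGG-16 has roughly 138 million parameters.

```python
# Hand count of trainable parameters for a CIFAR-style ResNet-18.
# Assumed layout (may not match the notebook's model exactly): 3x3 stem conv
# without max-pool, bias-free convs, BatchNorm after each conv, 1x1 projection
# shortcuts when the channel count changes, 10-class linear head.


def conv_params(k: int, c_in: int, c_out: int) -> int:
    """Weights of a bias-free k x k convolution."""
    return k * k * c_in * c_out


def bn_params(c: int) -> int:
    """Trainable BatchNorm parameters (gamma and beta); running stats are buffers."""
    return 2 * c


def basic_block_params(c_in: int, c_out: int) -> int:
    """Two 3x3 conv+BN pairs, plus a 1x1 conv+BN projection if channels change."""
    total = conv_params(3, c_in, c_out) + bn_params(c_out)
    total += conv_params(3, c_out, c_out) + bn_params(c_out)
    if c_in != c_out:
        total += conv_params(1, c_in, c_out) + bn_params(c_out)
    return total


def resnet18_cifar_params(num_classes: int = 10) -> int:
    total = conv_params(3, 3, 64) + bn_params(64)  # stem
    c_in = 64
    for c_out in (64, 128, 256, 512):  # four stages, two BasicBlocks each
        total += basic_block_params(c_in, c_out)
        total += basic_block_params(c_out, c_out)
        c_in = c_out
    total += 512 * num_classes + num_classes  # linear classifier with bias
    return total


print(f"{resnet18_cifar_params():,}")  # 11,173,962
```

Almost three quarters of these parameters sit in the last stage (512 channels), which is typical for this family of networks.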

### Visualization of Training Metrics

Training metrics are visualized using four plots—loss progression, accuracy progression, learning rate schedule, and the train-test accuracy gap—to analyze model performance and identify potential overfitting.
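Beyond eyeballing the gap plot, it can help to pin down *when* the gap first becomes large. The helpers below are a hypothetical sketch (the names `first_epoch_above` and `largest_gap` are not part of the notebook) meant to be applied to the `train_accuracies` and `test_accuracies` lists collected during training. They reuse the 5% threshold drawn in the gap plot, which is a heuristic rather than a formal criterion.

```python
from typing import Optional, Sequence, Tuple


def first_epoch_above(
    train_acc: Sequence[float], test_acc: Sequence[float], threshold: float = 5.0
) -> Optional[int]:
    """Return the first 1-based epoch whose train-test gap exceeds `threshold`, else None."""
    for epoch, (tr, te) in enumerate(zip(train_acc, test_acc), start=1):
        if tr - te > threshold:
            return epoch
    return None


def largest_gap(
    train_acc: Sequence[float], test_acc: Sequence[float]
) -> Tuple[int, float]:
    """Return (1-based epoch, gap) for the largest train-test accuracy gap."""
    gaps = [tr - te for tr, te in zip(train_acc, test_acc)]
    idx = max(range(len(gaps)), key=lambda i: gaps[i])
    return idx + 1, gaps[idx]


# Made-up accuracies for illustration only (not results from this notebook)
train_demo = [60.0, 75.0, 85.0, 92.0]
test_demo = [58.0, 72.0, 80.0, 84.0]
print(first_epoch_above(train_demo, test_demo))  # 4 (a gap of exactly 5.0 does not count)
print(largest_gap(train_demo, test_demo))  # (4, 8.0)
```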

In [None]:
epochs_range = range(1, NUM_EPOCHS + 1)

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

# Loss
ax1.plot(
    epochs_range,
    train_losses,
    "o-",
    label="Train Loss",
    linewidth=2,
    markersize=3,
    alpha=0.7,
)
ax1.plot(
    epochs_range,
    test_losses,
    "s-",
    label="Test Loss",
    linewidth=2,
    markersize=3,
    alpha=0.7,
)
ax1.axvline(
    x=best_epoch,
    color="green",
    linestyle="--",
    alpha=0.5,
    label=f"Best epoch ({best_epoch})",
)
ax1.set_xlabel("Epoch", fontsize=12, fontweight="bold")
ax1.set_ylabel("Loss", fontsize=12, fontweight="bold")
ax1.set_title("Loss Evolution", fontsize=14, fontweight="bold")
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Accuracy
ax2.plot(
    epochs_range,
    train_accuracies,
    "o-",
    label="Train Accuracy",
    linewidth=2,
    markersize=3,
    alpha=0.7,
)
ax2.plot(
    epochs_range,
    test_accuracies,
    "s-",
    label="Test Accuracy",
    linewidth=2,
    markersize=3,
    alpha=0.7,
)
ax2.axvline(
    x=best_epoch,
    color="green",
    linestyle="--",
    alpha=0.5,
    label=f"Best epoch ({best_epoch})",
)
ax2.axhline(
    y=best_test_acc,
    color="red",
    linestyle="--",
    alpha=0.5,
    label=f"Best acc: {best_test_acc:.2f}%",
)
ax2.set_xlabel("Epoch", fontsize=12, fontweight="bold")
ax2.set_ylabel("Accuracy (%)", fontsize=12, fontweight="bold")
ax2.set_title("Accuracy Evolution", fontsize=14, fontweight="bold")
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)

# Learning rate
ax3.plot(
    epochs_range,
    learning_rates,
    "o-",
    color="red",
    linewidth=2,
    markersize=3,
    alpha=0.7,
)
ax3.set_xlabel("Epoch", fontsize=12, fontweight="bold")
ax3.set_ylabel("Learning Rate", fontsize=12, fontweight="bold")
ax3.set_title("Learning Rate Schedule", fontsize=14, fontweight="bold")
ax3.set_yscale("log")
ax3.grid(True, alpha=0.3)
ax3.axvline(x=60, color="orange", linestyle="--", alpha=0.5, label="LR decay")
ax3.axvline(x=80, color="orange", linestyle="--", alpha=0.5)
ax3.legend(fontsize=10)

# Train–test gap
gap = np.array(train_accuracies) - np.array(test_accuracies)
ax4.plot(epochs_range, gap, "o-", color="purple", linewidth=2, markersize=3, alpha=0.7)
ax4.axhline(y=0, color="black", linestyle="-", linewidth=0.5)
ax4.axhline(
    y=5, color="red", linestyle="--", alpha=0.5, label="Overfitting threshold (5%)"
)
ax4.set_xlabel("Epoch", fontsize=12, fontweight="bold")
ax4.set_ylabel("Train–Test Gap (%)", fontsize=12, fontweight="bold")
ax4.set_title("Train–Test Accuracy Difference", fontsize=14, fontweight="bold")
ax4.legend(fontsize=10)
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig("resnet18_training_history.png", dpi=300, bbox_inches="tight")
plt.show()

print("\nResult analysis:")
final_gap = train_accuracies[-1] - test_accuracies[-1]
# Use the same 5% threshold that is drawn in the train–test gap plot
print(f"  Overfitting detected: {'YES' if final_gap > 5 else 'NO'}")
print(f"  Final train–test gap: {final_gap:.2f}%")
print(f"  Best epoch: {best_epoch}")
print(f"  Test accuracy gain since epoch 1: {test_accuracies[-1] - test_accuracies[0]:.2f}%")

A moderate train–test accuracy gap that stays roughly stable across epochs indicates a healthy balance between fitting the training data and generalizing to unseen samples. The learning-rate reductions at epochs 60 and 80 typically coincide with visible changes in training dynamics: the loss drops sharply and test accuracy jumps shortly after each decay.
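The step pattern in the learning-rate plot can be reproduced with a closed-form step-decay rule. The sketch below assumes a base learning rate of 0.1, a decay factor `gamma` of 0.1, and milestones at epochs 60 and 80; only the milestones are implied by the plot, the other values are assumptions. The formula mirrors the closed form used by PyTorch's `MultiStepLR`, in case the notebook's scheduler is of that type.

```python
from bisect import bisect_right
from typing import Sequence


def multistep_lr(
    epoch: int,
    base_lr: float = 0.1,  # assumed initial learning rate
    milestones: Sequence[int] = (60, 80),  # must be sorted
    gamma: float = 0.1,  # assumed decay factor
) -> float:
    """Learning rate after `epoch` scheduler steps under step decay.

    Each milestone already reached multiplies the base rate by `gamma` once.
    """
    return base_lr * gamma ** bisect_right(milestones, epoch)


for steps in (0, 59, 60, 79, 80, 99):
    print(f"after {steps:3d} steps: lr = {multistep_lr(steps):.4g}")
```

Under these assumptions the rate is 0.1 for the first 60 epochs, 0.01 until epoch 80, and 0.001 afterwards. Because the notebook appends the learning rate *after* calling `scheduler.step()`, the value recorded for 1-based epoch `e` corresponds to `multistep_lr(e)`, which is why the drops appear exactly at x = 60 and x = 80 in the plot.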

### Visualization of Model Predictions

Finally, model predictions on the test set are visualized, including both correctly classified examples and some errors, allowing qualitative inspection of model behavior.
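A quantitative complement to this qualitative inspection is a confusion matrix with per-class accuracy. The NumPy sketch below uses hypothetical helper names (`confusion_matrix`, `per_class_accuracy`). It could be applied to the `test_labels` and `predictions` tensors produced in the prediction cell of this section, e.g. `confusion_matrix(test_labels.numpy(), predictions.numpy())`, since both live on the CPU. A single batch gives only a noisy estimate, so accumulating labels and predictions over the whole test loader is more reliable.

```python
import numpy as np


def confusion_matrix(labels, preds, num_classes: int = 10) -> np.ndarray:
    """Count matrix: rows are true classes, columns are predicted classes."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when the same (true, pred) pair repeats
    np.add.at(cm, (np.asarray(labels), np.asarray(preds)), 1)
    return cm


def per_class_accuracy(cm: np.ndarray) -> np.ndarray:
    """Fraction of each true class predicted correctly (0.0 if the class is absent)."""
    support = cm.sum(axis=1)
    return np.divide(
        np.diag(cm), support, out=np.zeros(len(cm), dtype=float), where=support > 0
    )


# Toy example with 3 classes (illustrative values only)
cm = confusion_matrix([0, 0, 1, 1, 2], [0, 1, 1, 1, 0], num_classes=3)
print(cm)
print(per_class_accuracy(cm))
```

Off-diagonal entries show which class pairs are confused most often; on CIFAR-10, visually similar classes are the usual suspects.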

In [None]:
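print("\nVisualizing predictions of the best model...")

# Restore the best checkpoint: after training, `model` holds the weights of the
# last epoch, which are not necessarily the best ones
checkpoint = torch.load("resnet18_cifar10_best.pth", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])

# Get one test batch
data_iter = iter(test_dataloader)
test_images, test_labels = next(data_iter)

model.eval()
with torch.no_grad():
    test_images_device = test_images.to(device)
    outputs = model(test_images_device)
    # Predicted class = index of the highest logit
    _, predictions = torch.max(outputs, 1)
    predictions = predictions.cpu()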

print("\nFirst 8 predictions:")
show_images(test_images[:8], test_labels[:8], predictions[:8])

# Examples of misclassifications
incorrect_indices = (predictions != test_labels).nonzero(as_tuple=True)[0]

if len(incorrect_indices) >= 8:
    print("\nExamples of incorrect predictions:")
    error_indices = incorrect_indices[:8]
    show_images(
        test_images[error_indices],
        test_labels[error_indices],
        predictions[error_indices],
    )
else:
    print(f"\nOnly {len(incorrect_indices)} misclassifications in this batch")