Using PyTorch, we implement the architectures for ResNet50, ResNet101, and ResNet152 as presented by He, Zhang, Ren, and Sun in their 2015 paper [*Deep Residual Learning for Image Recognition*](https://arxiv.org/pdf/1512.03385.pdf). We then proceed to train a ResNet50 model on the CIFAR10 dataset. We refer to both [Nouman](https://blog.paperspace.com/writing-resnet-from-scratch-in-pytorch/)'s and [Persson](https://www.youtube.com/watch?v=DkNIBBBvcPs)'s implementations for guidance when necessary.

In theory, an neural network can be expanded to an arbitrary depth by appending layers that learn just the identity function. In practice, the vanishing or exploding gradients arising from networks that are too deep impede the model's convergence. ResNet addresses this issue by adding in shortcut connections that add the input of a block directly to the output of a block. Then, instead of learning something close to identity function, a ResNet layer learns the deviation of the output from the identity.

In [1]:
import torch
from torch import nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
import numpy as np

In [2]:
device = (
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)

print(f"Using {device}")

if device == "cuda":
    print(torch.cuda.get_device_name(torch.cuda.current_device()))

Using cuda
Tesla T4


In [3]:
# Represents a single building block of a residual network
class ResNetBlock(nn.Module):
    # in_channels: number of input channels
    # out_channels: number of output channels for intermediate layers
    # downsample: sequence by which to reduce the identity map if necessary
    # stride: stride of second conv of the three-layer stack forming a block
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1
    ) -> None:
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        # This is by convention, see (Table 1) of the original paper
        self.last_out_channels = out_channels * 4

        self.conv0 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1
        )
        # Per the paper, we normalize after every conv and before activation
        self.bn0 = nn.BatchNorm2d(out_channels)
        # padding=1 adds back the two rows and columns lost by kernel_size=3
        self.conv1 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels,
            self.last_out_channels,
            kernel_size=1
        )
        self.bn2 = nn.BatchNorm2d(self.last_out_channels)
        self.relu = nn.ReLU()

        # In the forward of each block, the residual is added to the output
        # Downsampling may be necessary to match the dimension of the residual
        #     to that of the output
        self.downsample_conv = nn.Conv2d(
            in_channels,
            self.last_out_channels,
            kernel_size=1,
            stride=stride
        )
        self.downsample_bn = nn.BatchNorm2d(self.last_out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.conv0(x)
        x = self.bn0(x)
        x = self.relu(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)

        residual = self.downsample_conv(residual)
        residual = self.downsample_bn(residual)

        x += residual
        x = self.relu(x)
        return x

In [4]:
class ResNet(nn.Module):
    # in_channels: number of channels for an image in the dataset
    # layers: list of length 4 representing the layers of ResNet blocks
    # classes: number of output classes for the dataset
    def __init__(
        self,
        in_channels: int,
        layers: list[int],
        classes: int
    ) -> None:
        super().__init__()

        # Initial two steps of convolution and maxpool for any ResNet
        self.conv = nn.Conv2d(
            in_channels,
            64,
            kernel_size=7,
            stride=2,
            padding=3
        )
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        # Because maxpool already halves the output size, the layer represented
        #     by layers[0] will only have stride=1
        self.maxpool = nn.MaxPool2d(
            kernel_size=3,
            stride=2,
            padding=1
        )

        self.layer0 = self._make_layer(layers[0], 64, 64, stride=1)
        self.layer1 = self._make_layer(layers[1], 256, 128, stride=2)
        self.layer2 = self._make_layer(layers[2], 512, 256, stride=2)
        self.layer3 = self._make_layer(layers[3], 1024, 512, stride=2)

        # Final two steps of avgpool and fully connected layer
        self.avgpool = nn.AvgPool2d(7, stride=1)
        # Number of output channels are scaled up by 4
        self.fc = nn.Linear(2048, classes)

    # Make a single layer of ResNet blocks
    # num_blocks: number of blocks in the layer
    # in_channels: number of input channels to the whole layer
    # out_channels: number of output channels for intermediate convs of blocks
    # stride: stride of first block of layer and of downsample if necessary
    def _make_layer(
        self,
        num_blocks: int,
        in_channels: int,
        out_channels: int,
        stride: int = 1
    ) -> nn.Module:
        blocks = []
        blocks.append(ResNetBlock(in_channels, out_channels, stride=stride))

        # stride=1 for remaining blocks
        for _ in range(num_blocks - 1):
            blocks.append(ResNetBlock(out_channels * 4, out_channels))

        return nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        # Reshape for fully connected layer
        x = x.reshape(x.shape[0], (-1))
        x = self.fc(x)
        return x

In [5]:
transform_CIFAR10 = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Mean and standard deviation for CIFAR10 dataset
    # Sourced from gist.github.com/weiaicunzai/e623931921efefd4c331622c344d8151
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2470, 0.2435, 0.2616]
    )
])

train_data = datasets.CIFAR10(
    root="data",
    train=True,
    transform=transform_CIFAR10,
    download=True
)

test_data = datasets.CIFAR10(
    root="data",
    train=False,
    transform=transform_CIFAR10,
    download=True
)

BATCH_SIZE = 64
VALID_SIZE = 0.9

train_indices = list(range(len(train_data)))
np.random.shuffle(train_indices)
valid_split = int(len(train_data) * VALID_SIZE)
valid_indices = train_indices[valid_split:]
train_indices = train_indices[:valid_split]
valid_data = Subset(train_data, valid_indices)
train_data = Subset(train_data, train_indices)

print(f"Training data: {len(train_data)}")
print(f"Validation data: {len(valid_data)}")
print(f"Test data: {len(test_data)}")

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:04<00:00, 37764211.31it/s]


Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified
Training data: 45000
Validation data: 5000
Test data: 10000


In [7]:
RESNET50_LAYERS = [3, 4, 6, 3]
RESNET101_LAYERS = [3, 4, 23, 3]
RESNET152_LAYERS = [3, 8, 36, 3]

CLASSES = 10
EPOCHS = 20
LR = 0.1
WEIGHT_DECAY = 0.001
MOMENTUM = 0.9

# ResNet50 for CIFAR100
model = ResNet(3, RESNET50_LAYERS, CLASSES)
model = nn.DataParallel(model)
model.to(device)

metric = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
    momentum=MOMENTUM
)
# Per the paper, "the learning rate starts from 0.1 and is divided by 10
#     when the error plateaus"
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.1,
    patience=8,
    threshold=0.01,
    threshold_mode="abs"
)

In [8]:
def train(
    loader: DataLoader,
    model: nn.Module,
    metric: nn.Module,
    optimizer: torch.optim.Optimizer
) -> None:
    total = len(loader.dataset)
    model.train()

    for batch, (x, y) in enumerate(loader):
        x = x.to(device)
        y = y.to(device)

        pred = model(x)
        loss = metric(pred, y)

        if batch % 100 == 99:
            scheduler.step(loss)
            progress = (batch + 1) * len(x)
            print(f"\tLoss: {loss.item():>7f} [{progress:>5d} / {total:>5d}]")
            print(f"\t\tLearning rate: {optimizer.param_groups[0]['lr']:>8f}")

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [9]:
def test(
    loader: DataLoader,
    model: nn.Module,
    metric: nn.Module,
) -> None:
    total = len(loader.dataset)
    batch_total = len(loader)
    total_loss = 0
    total_correct = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            pred = model(x)
            total_loss += metric(pred, y).item()
            pred_correct = pred.argmax(1) == y
            total_correct += pred_correct.type(torch.float).sum().item()

        total_loss /= batch_total
        total_correct /= total
        print(f"\tAccuracy: {(100 * total_correct):>0.1f}%")
        print(f"\tAverage loss: {total_loss:>8f}")

In [10]:
for t in range(EPOCHS):
    print(f"Epoch: {t + 1}")
    train(train_loader, model, metric, optimizer)
    test(valid_loader, model, metric)

torch.save(model.state_dict(), "resnet50.pth")

Epoch: 1
	Loss: 2.078911 [ 6400 / 45000]
		Learning rate: 0.100000
	Loss: 2.078148 [12800 / 45000]
		Learning rate: 0.100000
	Loss: 1.936066 [19200 / 45000]
		Learning rate: 0.100000
	Loss: 1.939066 [25600 / 45000]
		Learning rate: 0.100000
	Loss: 1.808320 [32000 / 45000]
		Learning rate: 0.100000
	Loss: 1.854816 [38400 / 45000]
		Learning rate: 0.100000
	Loss: 1.799024 [44800 / 45000]
		Learning rate: 0.100000
	Accuracy: 30.2%
	Average loss: 1.787851
Epoch: 2
	Loss: 1.773424 [ 6400 / 45000]
		Learning rate: 0.100000
	Loss: 1.641893 [12800 / 45000]
		Learning rate: 0.100000
	Loss: 1.618807 [19200 / 45000]
		Learning rate: 0.100000
	Loss: 1.633263 [25600 / 45000]
		Learning rate: 0.100000
	Loss: 1.836801 [32000 / 45000]
		Learning rate: 0.100000
	Loss: 1.670297 [38400 / 45000]
		Learning rate: 0.100000
	Loss: 1.529169 [44800 / 45000]
		Learning rate: 0.100000
	Accuracy: 38.9%
	Average loss: 1.636970
Epoch: 3
	Loss: 1.665519 [ 6400 / 45000]
		Learning rate: 0.100000
	Loss: 1.688651 [1280

In [11]:
test(test_loader, model, metric)

	Accuracy: 81.5%
	Average loss: 0.538406


We note that our accuracy is relatively low for the given dataset due to our implementation of an unmodified ResNet50. In particular, there are far too many parameters compared to the size of each image in CIFAR10.