We implement a slight modification of DenseNet-121, following the 2018 [paper](https://arxiv.org/pdf/1608.06993.pdf) by Huang, Liu, van der Maaten, and Weinberger.

Addressing the same problems arising from increasing network depth as ResNet, DenseNet opts instead for concatenation in its equivalent of a shortcut connection, in contrast to ResNet's concatenation. Effectively, as the layers in a dense block append their output onto their input, it is as though a layer feeds its output directly into each subsequent layer, hence the name "dense".

In [1]:
import torch
from torch import nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
import numpy as np

In [2]:
device = (
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)

print(f"Using {device}")

if device == "cuda":
    print(torch.cuda.get_device_name(torch.cuda.current_device()))

Using cuda
Tesla P100-PCIE-16GB


In [63]:
class DenseBlock(nn.Module):
    # in_channels: number of input channels
    # num_layers: number of layers in the dense block
    # growth_rate: denoted k in the paper
    def __init__(
        self,
        in_channels: int,
        num_layers: int,
        growth_rate: int
    ) -> None:
        super().__init__()

        self.layers = []

        for n in range(num_layers):
            # 1x1 conv followed by 3x3 conv, each with batchnorm and relu
            self.layers.append(nn.Sequential(
                # Adding nk is the result of the concatenation
                nn.BatchNorm2d(in_channels + n * growth_rate),
                nn.ReLU(),
                # Each 1x1 produces 4k output channels per the paper
                nn.Conv2d(
                    in_channels + n * growth_rate,
                    4 * growth_rate,
                    kernel_size=1
                ),
                nn.BatchNorm2d(4 * growth_rate),
                nn.ReLU(),
                # Output k channels
                nn.Conv2d(
                    4 * growth_rate, growth_rate,
                    kernel_size=3, padding=1
                )
            ))
            
        self.layers = nn.ModuleList(self.layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for l in self.layers:
            x = torch.cat((x, l(x)), 1)

        return x

In [64]:
class TransitionLayer(nn.Module):
    # in_channels: number of input channels
    # compression: denoted theta in the paper
    def __init__(self, in_channels: int, compression: float) -> None:
        super().__init__()

        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU()
        self.conv = nn.Conv2d(
            in_channels,
            int(in_channels * compression),
            kernel_size=1
        )
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.bn(x)
        x = self.relu(x)
        x = self.conv(x)
        x = self.avgpool(x)
        return x

In [65]:
class DenseNet(nn.Module):
    # in_channels: number of input channels
    # layers: list of number of layers per dense block
    # growth_rate: denoted k in the paper
    # compression: denoted theta in the paper
    def __init__(
        self,
        in_channels: int,
        layers: list[int],
        growth_rate: int,
        compression: float
    ) -> None:
        super().__init__()

        num_ch = 64
        self.components = []

        # Preprocessing, with some parameters modified to acount for data size
        # In particular, we apply no stride
        self.conv = nn.Conv2d(in_channels, num_ch, kernel_size=5, padding=2)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        # Add in blocks and transition layers
        for i, num_layers in enumerate(layers):
            self.components.append(DenseBlock(num_ch, num_layers, growth_rate))
            num_ch += num_layers * growth_rate

            # Transition for all but last block
            if i != len(layers) - 1:
                self.components.append(TransitionLayer(num_ch, compression))
                num_ch = int(num_ch * compression)

        self.components = nn.ModuleList(self.components)
                
        # Postprocessing
        self.avgpool = nn.AvgPool2d(kernel_size=4, stride=1)
        self.fc = nn.Linear(num_ch, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.maxpool(x)

        for component in self.components:
            x = component(x)

        x = self.avgpool(x)
        x = x.reshape(x.shape[0], (-1))
        x = self.fc(x)
        return x

In [6]:
transform_CIFAR10 = transforms.Compose([
    transforms.ToTensor(),
    # Mean and standard deviation for CIFAR10 dataset
    # Sourced from gist.github.com/weiaicunzai/e623931921efefd4c331622c344d8151
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2470, 0.2435, 0.2616]
    )
])

train_data = datasets.CIFAR10(
    root="data",
    train=True,
    transform=transform_CIFAR10,
    download=True
)

test_data = datasets.CIFAR10(
    root="data",
    train=False,
    transform=transform_CIFAR10,
    download=True
)

BATCH_SIZE = 64
VALID_SIZE = 0.9

train_indices = list(range(len(train_data)))
np.random.shuffle(train_indices)
valid_split = int(len(train_data) * VALID_SIZE)
valid_indices = train_indices[valid_split:]
train_indices = train_indices[:valid_split]
valid_data = Subset(train_data, valid_indices)
train_data = Subset(train_data, train_indices)

print(f"Training data: {len(train_data)}")
print(f"Validation data: {len(valid_data)}")
print(f"Test data: {len(test_data)}")

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 31855805.02it/s]


Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified
Training data: 45000
Validation data: 5000
Test data: 10000


In [92]:
EPOCHS = 20
LR = 0.01
WEIGHT_DECAY = 0.001
MOMENTUM = 0.9

model = DenseNet(3, [6, 12, 24, 16], 12, 0.5)
model = nn.DataParallel(model)
model.to(device)

metric = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
    momentum=MOMENTUM
)

In [93]:
def train(
    loader: DataLoader,
    model: nn.Module,
    metric: nn.Module,
    optimizer: torch.optim.Optimizer
) -> None:
    total = len(loader.dataset)
    model.train()

    for batch, (x, y) in enumerate(loader):
        x = x.to(device)
        y = y.to(device)

        pred = model(x)
        loss = metric(pred, y)

        if batch % 100 == 99:
            progress = (batch + 1) * len(x)
            print(f"\tLoss: {loss.item():>7f} [{progress:>5d} / {total:>5d}]")
            print(f"\t\tLearning rate: {optimizer.param_groups[0]['lr']:>8f}")

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [94]:
def test(
    loader: DataLoader,
    model: nn.Module,
    metric: nn.Module,
) -> None:
    total = len(loader.dataset)
    batch_total = len(loader)
    total_loss = 0
    total_correct = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            pred = model(x)
            total_loss += metric(pred, y).item()
            pred_correct = pred.argmax(1) == y
            total_correct += pred_correct.type(torch.float).sum().item()

        total_loss /= batch_total
        total_correct /= total
        print(f"\tAccuracy: {(100 * total_correct):>0.1f}%")
        print(f"\tAverage loss: {total_loss:>8f}")

In [95]:
for t in range(EPOCHS):
    print(f"Epoch: {t + 1}")
    train(train_loader, model, metric, optimizer)
    test(valid_loader, model, metric)

torch.save(model.state_dict(), "densenet121.pth")

Epoch: 1
	Loss: 1.754052 [ 6400 / 45000]
		Learning rate: 0.010000
	Loss: 1.485655 [12800 / 45000]
		Learning rate: 0.010000
	Loss: 1.486960 [19200 / 45000]
		Learning rate: 0.010000
	Loss: 1.320729 [25600 / 45000]
		Learning rate: 0.010000
	Loss: 1.134369 [32000 / 45000]
		Learning rate: 0.010000
	Loss: 1.221157 [38400 / 45000]
		Learning rate: 0.010000
	Loss: 1.081045 [44800 / 45000]
		Learning rate: 0.010000
	Accuracy: 54.4%
	Average loss: 1.290990
Epoch: 2
	Loss: 1.136029 [ 6400 / 45000]
		Learning rate: 0.010000
	Loss: 1.096946 [12800 / 45000]
		Learning rate: 0.010000
	Loss: 0.888631 [19200 / 45000]
		Learning rate: 0.010000
	Loss: 0.912430 [25600 / 45000]
		Learning rate: 0.010000
	Loss: 1.046178 [32000 / 45000]
		Learning rate: 0.010000
	Loss: 0.918515 [38400 / 45000]
		Learning rate: 0.010000
	Loss: 0.998655 [44800 / 45000]
		Learning rate: 0.010000
	Accuracy: 66.7%
	Average loss: 0.946098
Epoch: 3
	Loss: 0.714837 [ 6400 / 45000]
		Learning rate: 0.010000
	Loss: 0.622070 [1280

In [96]:
test(test_loader, model, metric)

	Accuracy: 84.2%
	Average loss: 0.492691
