In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.torchvision import *
from typing import List

# DenseNet

Dense Convolutional Network (DenseNet) connects each layer to every other layer in a feed-forward fashion. Whereas traditional convolutional networks with $L$ layers have $L$ connections—one between each layer and its subsequent layer—our network has $\frac{L(L+1)}{2}$ direct connections. For each layer, the feature-maps of all preceding layers are used as inputs, and its own feature-maps are used as inputs into all subsequent layers. DenseNets have several compelling advantages: they alleviate the vanishing-gradient problem, strengthen feature propagation, encourage feature reuse, and substantially reduce the number of parameters. We evaluate our proposed architecture on four highly competitive object recognition benchmark tasks (CIFAR-10, CIFAR-100, SVHN, and ImageNet). DenseNets obtain significant improvements over the state-of-the-art on most of them, whilst requiring less computation to achieve high performance. [Paper](https://arxiv.org/pdf/1608.06993)

<center>
<img width="800" src="https://i.ibb.co/7GsdcsZ/image-2024-06-13-141410503.png" alt="image-2024-06-13-141410503" border="0">
</center>

In [2]:
class FeatureExtraction(nn.Module):
  """DenseNet stem: 7x7 stride-2 convolution, BatchNorm, ReLU, then 3x3 stride-2 max-pool.

  Maps a 3-channel image to ``num_init_channels`` feature maps; spatial size is
  reduced by a factor of 4 (exactly, for heights/widths divisible by 4).
  """
  def __init__(self, num_init_channels):
    super().__init__()
    out_ch = num_init_channels
    self.conv = nn.Conv2d(3, out_ch, kernel_size=7, stride=2, padding=3, bias=False)
    self.norm = nn.BatchNorm2d(out_ch)
    self.relu = nn.ReLU(inplace=True)
    self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

  def forward(self, x):
    h = self.conv(x)
    h = self.norm(h)
    h = self.relu(h)
    return self.pool(h)

class DenseLayer(nn.Module):
  """One DenseNet-BC layer: BN-ReLU-Conv1x1 bottleneck followed by BN-ReLU-Conv3x3.

  Args:
    input_channels: total channels of the (concatenated) incoming feature maps.
    growth_rate: number of new feature maps this layer produces (``k`` in the paper).
    bn_size: bottleneck multiplier; the 1x1 conv outputs ``bn_size * growth_rate`` channels.
    drop_rate: dropout probability on the new feature maps, applied in training mode only.
  """
  def __init__(self, input_channels, growth_rate, bn_size, drop_rate):
    super().__init__()
    fix_channel = bn_size * growth_rate
    # Bottleneck: BN -> ReLU -> 1x1 conv down/up to bn_size * growth_rate channels
    self.norm1 = nn.BatchNorm2d(input_channels)
    self.relu1 = nn.ReLU(inplace=True)
    self.conv1 = nn.Conv2d(input_channels, fix_channel, kernel_size=1, stride=1, bias=False)

    # BN -> ReLU -> 3x3 conv producing growth_rate maps (padding=1 keeps H, W)
    self.norm2 = nn.BatchNorm2d(fix_channel)
    self.relu2 = nn.ReLU(inplace=True)
    self.conv2 = nn.Conv2d(fix_channel, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)

    self.drop_rate = float(drop_rate)

  def bn_function(self, inputs: List[Tensor]) -> Tensor:
    """Concatenate ``inputs`` along the channel dim and apply norm1 -> relu1 -> conv1."""
    concated_features = torch.cat(inputs, 1)
    return self.conv1(self.relu1(self.norm1(concated_features)))

  def forward(self, input):
    """Return ``growth_rate`` new feature maps from a tensor or a list of tensors."""
    prev_features = [input] if isinstance(input, Tensor) else input
    bottleneck_output = self.bn_function(prev_features)
    new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
    if self.drop_rate > 0:
      # F.dropout defaults to training=True; tie it to the module mode so
      # model.eval() actually disables dropout.
      new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
    return new_features

class DenseBlock(nn.ModuleDict):
  """Stack of ``num_layers`` DenseLayers with dense (concatenative) connectivity.

  Layer ``i`` (0-based) receives ``input_channels + i * growth_rate`` channels.
  The block returns its input concatenated with every layer's output along the
  channel dimension.
  """
  def __init__(self, num_layers, input_channels, bn_size, growth_rate, drop_rate=0.1):
    super().__init__()
    for idx in range(1, num_layers + 1):
      in_ch = input_channels + (idx - 1) * growth_rate
      self.add_module("denselayer%d" % idx,
                      DenseLayer(in_ch, growth_rate, bn_size, drop_rate))

  def forward(self, init_features):
    collected = [init_features]
    for layer in self.values():
      # Each layer receives the full list of maps produced so far.
      collected.append(layer(collected))
    return torch.cat(collected, 1)

class Transition(nn.Sequential):
  """Between-block transition: BN -> ReLU -> 1x1 conv -> 2x2 stride-2 average pool.

  Maps ``input_features`` channels to ``output_features`` channels and halves
  (floor) the spatial resolution.
  """
  def __init__(self, input_features, output_features):
    super().__init__()
    self.add_module("norm", nn.BatchNorm2d(input_features))
    self.add_module("relu", nn.ReLU(inplace=True))
    self.add_module("conv", nn.Conv2d(input_features, output_features,
                                      kernel_size=1, stride=1, bias=False))
    self.add_module("pool", nn.AvgPool2d(kernel_size=2, stride=2))

  def forward(self, x):
    activated = self.relu(self.norm(x))
    projected = self.conv(activated)
    return self.pool(projected)

class SimplifiedDenseNet(nn.Module):
  """Compact DenseNet-BC classifier built from FeatureExtraction, DenseBlock and Transition.

  Args:
    num_classes: number of output logits.
    growth_rate: channels each DenseLayer adds (``k`` in the paper).
    num_init_channels: channels produced by the stem convolution.
    layer_each_block: number of DenseLayers per dense block. A Transition
      (halving channels and spatial resolution) follows every block except the last.
    bn_size: bottleneck multiplier (``bn_size * growth_rate`` channels in each 1x1 conv).
    drop_rate: dropout probability after each DenseLayer; forwarded to every block.
  """
  def __init__(self, num_classes, growth_rate=32, num_init_channels=64,
               layer_each_block=(6, 12, 24, 16), bn_size=4, drop_rate=0.1):
    super().__init__()

    # Feature extraction (stem)
    self.feature_extraction = FeatureExtraction(num_init_channels)

    # DenseBlocks + Transitions. The channel count is tracked explicitly rather
    # than assumed to double after each transition, so non-default configs
    # (e.g. DenseNet-169's (6, 12, 32, 32)) are wired correctly.
    # Submodule names (denseblockN / transitionN) match the original layout.
    self.num_blocks = len(layer_each_block)
    num_channels = num_init_channels
    for i, num_layers in enumerate(layer_each_block, start=1):
      self.add_module("denseblock%d" % i,
                      DenseBlock(num_layers, num_channels, bn_size, growth_rate, drop_rate))
      num_channels += num_layers * growth_rate
      if i < self.num_blocks:
        self.add_module("transition%d" % i, Transition(num_channels, num_channels // 2))
        num_channels //= 2

    # Output head: BN -> ReLU -> global average pool -> linear
    # (ReLU after the final BN mirrors torchvision's DenseNet.forward)
    self.final_norm = nn.BatchNorm2d(num_channels)
    self.final_relu = nn.ReLU(inplace=True)
    self.final_pool = nn.AdaptiveAvgPool2d((1, 1))
    self.classifier = nn.Linear(num_channels, num_classes)

  def forward(self, x):
    x = self.feature_extraction(x)

    for i in range(1, self.num_blocks + 1):
      x = getattr(self, "denseblock%d" % i)(x)
      if i < self.num_blocks:
        x = getattr(self, "transition%d" % i)(x)

    x = self.final_pool(self.final_relu(self.final_norm(x)))
    x = torch.flatten(x, 1)
    return self.classifier(x)

In [3]:
# Forward-only shape check: eval() + no_grad() avoid building an autograd
# graph for the whole 12x3x224x224 batch, which would waste a lot of memory.
x = torch.randn(12, 3, 224, 224)
model = SimplifiedDenseNet(num_classes=1000).eval()
with torch.no_grad():
    y = model(x)
assert y.shape == (12, 1000)

y.shape

torch.Size([12, 1000])

The DenseNet architecture is highly computationally efficient as a result of feature reuse. However, a naïve DenseNet implementation can require a significant amount of GPU memory: if not properly managed, pre-activation batch normalization and contiguous convolution operations can produce feature maps that grow quadratically with network depth. The memory-efficient implementation of [Pleiss et al. (2017)](https://arxiv.org/pdf/1707.06990) uses shared memory allocations to reduce the memory cost of storing these feature maps from quadratic to linear. The torchvision version below offers a related option, `memory_efficient=True`, which uses checkpointing of the bottleneck instead (lower memory, slower).

<center>
<img width="800" src="https://i.ibb.co/PjxygZD/image.png" alt="image" border="0">
<img width="800" src="https://i.ibb.co/G207C92/image.png" alt="image" border="0">
</center>


In [4]:
# Text file with one class name per line, presumably formatted like `"tench",`
LABELS_PATH = "/workspace/dataset/vision-reg/imagenet-1k/labels.txt"

with open(LABELS_PATH, "r", encoding="utf-8") as file:
    # Strip whitespace, then the trailing comma, then the quotes; skip blank lines.
    # Unlike slicing off a fixed 2 characters, this also handles a last line
    # without a trailing comma and never yields empty "categories".
    _IMAGENET_CATEGORIES = [
        line.strip().rstrip(",").strip('"')
        for line in file
        if line.strip()
    ]

print(_IMAGENET_CATEGORIES[:10])

['tench', 'goldfish', 'great white shark', 'tiger shark', 'hammerhead', 'electric ray', 'stingray', 'cock', 'hen', 'ostrich']


## Torchvision version

In [7]:
class _DenseLayer(nn.Module):
    """Torchvision-style DenseNet-BC layer: BN-ReLU-Conv1x1 bottleneck, then BN-ReLU-Conv3x3.

    Args:
        num_input_features: channels of the concatenated incoming feature maps.
        growth_rate: number of new feature maps this layer emits.
        bn_size: bottleneck multiplier; ``conv1`` outputs ``bn_size * growth_rate`` channels.
        drop_rate: dropout probability on the new features. Skipped when 0, and
            otherwise only active in training mode (``training=self.training``).
        memory_efficient: if True and any input requires grad, the bottleneck is run
            through ``cp.checkpoint`` (presumably ``torch.utils.checkpoint``, brought in
            by the star import at the top) to trade recomputation for memory.
    """
    def __init__(self, num_input_features: int, growth_rate: int, bn_size: int,
                 drop_rate: float, memory_efficient: bool = False) -> None:
        super().__init__()
        # Bottleneck: BN -> ReLU -> 1x1 conv to bn_size * growth_rate channels
        self.norm1 = nn.BatchNorm2d(num_input_features)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)

        # BN -> ReLU -> 3x3 conv producing growth_rate maps (padding=1 keeps H, W)
        self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)

        self.drop_rate = float(drop_rate)
        self.memory_efficient = memory_efficient

    def bn_function(self, inputs: List[Tensor]) -> Tensor:
        """Concatenate ``inputs`` along dim 1 and apply norm1 -> relu1 -> conv1."""
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
        return bottleneck_output
    
    def any_requires_grad(self, input: List[Tensor]) -> bool:
        """Return True if at least one tensor in ``input`` requires grad."""
        for tensor in input:
            if tensor.requires_grad:
                return True
        return False

    @torch.jit.unused
    def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor:
        """Run ``bn_function`` on ``input`` under ``cp.checkpoint`` (non-reentrant).

        Marked ``@torch.jit.unused`` so TorchScript skips compiling it.
        """
        def closure(*inputs):
            return self.bn_function(inputs)

        return cp.checkpoint(closure, *input, use_reentrant=False)

    # TorchScript overload declarations: forward accepts a list of tensors or a
    # single tensor. Keep these bodies as a bare `pass` -- torch.jit's overload
    # check rejects any other statement (docstrings included).
    @torch.jit._overload_method
    def forward(self, input: List[Tensor]) -> Tensor:
        pass

    @torch.jit._overload_method
    def forward(self, input: Tensor) -> Tensor:
        pass

    def forward(self, input: Tensor) -> Tensor: 
        """Compute ``growth_rate`` new feature maps from a tensor or list of tensors.

        A single tensor is wrapped in a list; the list is concatenated along
        channels inside ``bn_function``.

        Raises:
            Exception: if memory-efficient mode would be used while scripting.
        """
        if isinstance(input, Tensor):
            prev_features = [input]
        else:
            prev_features = input

        # Checkpointing only helps when gradients will be computed.
        if self.memory_efficient and self.any_requires_grad(prev_features):
            if torch.jit.is_scripting():
                raise Exception("Memory Efficient not supported in JIT")

            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)

        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return new_features


class _DenseBlock(nn.ModuleDict):
    _version = 2

    def __init__(
        self,
        num_layers: int,
        num_input_features: int,
        bn_size: int,
        growth_rate: int,
        drop_rate: float,
        memory_efficient: bool = False,
    ) -> None:
        super().__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.add_module("denselayer%d" % (i + 1), layer)

    def forward(self, init_features: Tensor) -> Tensor:
        print("New Block ===")
        features = [init_features]
        for name, layer in self.items():
            new_features = layer(features)
            print()
            print(f"Feature: {new_features.shape}")
            print()
            features.append(new_features)
        return torch.cat(features, 1)


class _Transition(nn.Sequential):
    def __init__(self, num_input_features: int, num_output_features: int) -> None:
        super().__init__()
        self.norm = nn.BatchNorm2d(num_input_features)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)

class DenseNet(nn.Module):
    """
    Densely Connected Convolutional Networks (torchvision reference layout).

    Args:
        growth_rate (int) - feature maps added by every dense layer (`k` in paper)
        block_config (list of 4 ints) - number of dense layers in each dense block
        num_init_features (int) - output channels of the stem convolution
        bn_size (int) - bottleneck multiplier; each layer's 1x1 conv emits
            bn_size * k channels
        drop_rate (float) - dropout probability after each dense layer
        num_classes (int) - number of output logits
        memory_efficient (bool) - if True, dense layers use checkpointing, trading
            speed for a much smaller memory footprint. Default: *False*.
            See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
    """

    def __init__(
        self,
        growth_rate: int = 32,
        block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
        num_init_features: int = 64,
        bn_size: int = 4,
        drop_rate: float = 0,
        num_classes: int = 1000,
        memory_efficient: bool = False,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        # Stem: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool
        stem_layers = [
            ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ("norm0", nn.BatchNorm2d(num_init_features)),
            ("relu0", nn.ReLU(inplace=True)),
            ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]
        self.features = nn.Sequential(OrderedDict(stem_layers))

        # Dense blocks; every block except the last is followed by a transition
        # that halves the channel count.
        channels = num_init_features
        last_block = len(block_config) - 1
        for idx, num_layers in enumerate(block_config):
            self.features.add_module(
                "denseblock%d" % (idx + 1),
                _DenseBlock(
                    num_layers=num_layers,
                    num_input_features=channels,
                    bn_size=bn_size,
                    growth_rate=growth_rate,
                    drop_rate=drop_rate,
                    memory_efficient=memory_efficient,
                ),
            )
            channels += num_layers * growth_rate
            if idx == last_block:
                continue
            halved = channels // 2
            self.features.add_module(
                "transition%d" % (idx + 1),
                _Transition(num_input_features=channels, num_output_features=halved),
            )
            channels = halved

        # Final batch norm, then the linear classifier
        self.features.add_module("norm5", nn.BatchNorm2d(channels))
        self.classifier = nn.Linear(channels, num_classes)

        # Weight initialisation as in the upstream torch repo
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.constant_(module.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        activated = F.relu(self.features(x), inplace=True)
        pooled = F.adaptive_avg_pool2d(activated, (1, 1))
        return self.classifier(torch.flatten(pooled, 1))

### Torchvision User Usage

In [31]:
import torch
from torchvision import models, transforms
from PIL import Image
import numpy as np

# Load the pre-trained DenseNet121 model (ImageNet-1k weights) in inference mode
model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1, progress=True)
model.eval()

# Define the image preprocessing steps (standard ImageNet resize / crop / normalise)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load an image. convert("RGB") guarantees 3 channels for grayscale / RGBA /
# palette files (Normalize above expects exactly 3), and the context manager
# closes the underlying file handle.
img_path = '/workspace/dataset/samples/fish.jpeg'
with Image.open(img_path) as img:
    img_t = preprocess(img.convert("RGB"))
input_tensor = img_t.unsqueeze(0)

# Perform inference
with torch.no_grad():
    output = model(input_tensor)
probabilities = torch.nn.functional.softmax(output[0], dim=0)

# Print the top 5 most probable classes (topk already returns their probabilities)
top_probs, indices = torch.topk(probabilities, 5)
print(f'Top 5 classes: {np.array(_IMAGENET_CATEGORIES)[indices].tolist()}')
print(f'Probabilities: {top_probs}')

Top 5 classes: ['goldfish', 'tench', 'axolotl', 'anemone fish', 'coho']
Probabilities: tensor([1.0000e+00, 2.1803e-06, 9.4948e-07, 1.8117e-07, 1.7783e-07])
