In [1]:
import torch as th
from torch import nn
import einops
from einops.layers.torch import Reduce

In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if th.cuda.is_available():
    device = th.device("cuda")
else:
    device = th.device("cpu")
print(device)

cuda


In [3]:
def count_params(model: nn.Module) -> int:
    '''
    Count the trainable parameters of a model.

    model: Module whose parameters are inspected
    Returns the total number of elements over all parameters that have requires_grad set.
    Frozen parameters (requires_grad = False) are not counted.
    '''
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [4]:
class ConvNeXtBlock(nn.Module):
    '''
    Residual ConvNeXt-style block.
    A depthwise spatial convolution is followed by a normalization layer and a pointwise
    inverted bottleneck (expand -> activation -> project). The result is added to a
    shortcut branch computed from the same input.
    '''
    def __init__(self, input_channels: int, output_channels: int, kernel_size: int = 7, expansion_factor: int = 4, activation_fn = nn.GELU) -> None:
        '''
        input_channels: Number of channels of the incoming feature map
        output_channels: Number of channels of the returned feature map
        kernel_size: Kernel size of the depthwise convolution
        expansion_factor: Width of the hidden bottleneck layer as a multiple of input_channels
        activation_fn: Callable that returns the activation module when called without arguments
        '''
        super().__init__()
        hidden_channels = input_channels * expansion_factor

        # The layers below are created in the same order in which they are applied.
        # groups = input_channels makes the convolution depthwise (one spatial filter per channel);
        # padding = 'same' keeps the spatial size unchanged.
        depthwise = nn.Conv2d(input_channels, input_channels, kernel_size = kernel_size, groups = input_channels, padding = 'same')
        # A single group means the statistics are shared by all channels and spatial positions of a sample.
        norm = nn.GroupNorm(num_groups = 1, num_channels = input_channels)
        # Pointwise (1x1) convolutions form the inverted bottleneck: widen, apply the non-linearity, narrow.
        expand = nn.Conv2d(input_channels, hidden_channels, kernel_size = 1)
        non_linearity = activation_fn()
        project = nn.Conv2d(hidden_channels, output_channels, kernel_size = 1)
        self.residual = nn.Sequential(depthwise, norm, expand, non_linearity, project)

        # The shortcut must produce output_channels channels so both branches can be summed.
        if input_channels == output_channels:
            self.shortcut = nn.Identity() # Channel counts already agree, pass the input through
        else:
            self.shortcut = nn.Conv2d(input_channels, output_channels, kernel_size = 1) # 1x1 conv to match the channel count

    def forward(self, x):
        transformed = self.residual(x)
        skipped = self.shortcut(x)
        return transformed + skipped

class Nino_classifier(nn.Module):
    '''
    ConvNeXt-style classifier for SST data.
    An input projection is followed by a stack of ConvNeXt blocks; the resulting feature map
    is averaged over its spatial axes and classified by a two-layer MLP.
    '''
    def __init__(self,
                input_dim: int = 1,
                latent_dim: int = 128,
                num_classes: int = 3,
                num_layers: int = 4,
                downsampling: int = -1,
                expansion_factor: int = 4,
                kernel_size: int = 7,
                activation_fn = nn.GELU):
        '''
        input_dim: Number of input channels
        latent_dim: Number of channels in the latent feature map
        num_classes: Number of classes to classify
        num_layers: Number of ConvNeXt blocks
        downsampling: Stride of the input convolution. Positive values must be even; values <= 0 disable downsampling and a 1x1 projection is used instead
        expansion_factor: Expansion factor of the inverted bottleneck layers (blocks and classifier MLP)
        kernel_size: Kernel size of the depthwise convolutions and of the strided input convolution
        activation_fn: Callable that returns the activation module when called without arguments
        '''
        super().__init__()

        # Stem: lift the input from input_dim to latent_dim channels.
        if downsampling > 0:
            # Strided convolution: shrinks the spatial size, which makes every later layer cheaper.
            assert downsampling % 2 == 0, 'Downsampling factor must be even'
            stem = nn.Conv2d(input_dim, latent_dim, kernel_size = kernel_size, stride = downsampling, padding = kernel_size // 2)
        else:
            # Pointwise convolution: only the channel count changes, the spatial size is untouched.
            stem = nn.Conv2d(input_dim, latent_dim, kernel_size = 1)
        self.input_projection = stem

        # Body: num_layers residual blocks that all keep latent_dim channels.
        blocks = []
        for _ in range(num_layers):
            blocks.append(ConvNeXtBlock(latent_dim, latent_dim, kernel_size, expansion_factor, activation_fn))
        self.cnn_blocks = nn.ModuleList(blocks)

        # Head: pool away the spatial axes, then classify with an inverted bottleneck MLP.
        hidden_dim = latent_dim * expansion_factor
        self.classifier = nn.Sequential(
            Reduce('b c h w -> b c', 'mean'), # Global average pooling over height and width
            nn.Linear(latent_dim, hidden_dim), # Expand to the hidden width
            activation_fn(), # Non-linearity
            nn.Linear(hidden_dim, num_classes), # One logit per class
        )

    def forward(self, x: th.Tensor) -> th.Tensor:
        # x: (batch_size, input_dim, height, width)
        # After the projection the tensor has latent_dim channels; height and width are
        # unchanged for the 1x1 stem and reduced for the strided stem.
        features = self.input_projection(x)
        for layer in self.cnn_blocks:
            features = layer(features)
        return self.classifier(features) # (batch_size, num_classes)

In [5]:
# Smoke test: push one random batch through the default model and backpropagate a loss.
cnn = Nino_classifier()
cnn = cnn.to(device)
x = th.randn(64, 1, 64, 160, device = device) # (batch, channels, height, width)
y = cnn(x)
targets = th.randint(low = 0, high = 3, size = (64,), device = device) # Random class index per sample
criterion = nn.CrossEntropyLoss()
loss = criterion(y, targets)
loss.backward()

In [6]:
# Number of trainable parameters of the default-configured classifier created above.
# NOTE(review): the saved output (627459) equals the count obtained with a 7x7 input projection;
# with the default downsampling = -1 (1x1 projection) the count works out to 621315 - re-run to refresh.
count_params(cnn)

627459