In [3]:
import torch
from torch import nn
from typing import List
from torch import Tensor

In [4]:

from torchvision.ops import StochasticDepth

class LayerScaler(nn.Module):
    """Learnable per-channel scaling of a feature map.

    Holds a trainable vector ``gamma`` whose entries all start at
    ``init_value``; ``forward`` multiplies the input by it channel-wise.

    Args:
        init_value: starting value of every entry of ``gamma``.
        dimensions: number of entries in ``gamma`` (one per channel).
    """

    def __init__(self, init_value: float, dimensions: int):
        super().__init__()
        initial_scale = torch.ones(dimensions) * init_value
        self.gamma = nn.Parameter(initial_scale, requires_grad=True)

    def forward(self, x):
        """Return ``x`` scaled by ``gamma`` along dim 1."""
        # (C,) -> (1, C, 1, 1): broadcasts over the batch and the two trailing dims.
        per_channel = self.gamma[None, ..., None, None]
        return per_channel * x
    
class BottleNeckBlock(nn.Module):
    """Inverted bottleneck residual block.

    The residual branch is: 7x7 depth-wise conv -> GroupNorm(1 group) ->
    1x1 conv expanding to ``out_features * expansion`` channels -> GELU ->
    1x1 conv projecting to ``out_features`` channels. The branch output is
    scaled by a ``LayerScaler``, passed through stochastic depth, and added
    to the block input.

    Args:
        in_features: channels of the input tensor.
        out_features: channels produced by the residual branch. The skip
            connection adds the input back, so this is expected to match
            ``in_features``.
        expansion: width multiplier of the hidden 1x1 convolution.
        drop_p: drop probability handed to ``StochasticDepth``.
        layer_scaler_init_value: initial value of the ``LayerScaler`` weights.
    """

    def __init__(self, in_features: int, out_features: int,
                 expansion: int = 4, drop_p: float = .0,
                 layer_scaler_init_value: float = 1e-6):
        super().__init__()
        hidden_features = out_features * expansion

        # narrow -> narrow: depth-wise (groups == channels) conv with a large kernel
        depthwise = nn.Conv2d(
            in_features, in_features, kernel_size=7, padding=3, bias=False, groups=in_features
        )
        # GroupNorm with num_groups=1 is the same as LayerNorm but works for 2D data
        norm = nn.GroupNorm(num_groups=1, num_channels=in_features)
        # narrow -> wide
        expand = nn.Conv2d(in_features, hidden_features, kernel_size=1)
        # GELU rather than ReLU
        activation = nn.GELU()
        # wide -> narrow
        project = nn.Conv2d(hidden_features, out_features, kernel_size=1)

        self.block = nn.Sequential(depthwise, norm, expand, activation, project)
        self.layer_scaler = LayerScaler(layer_scaler_init_value, out_features)
        self.drop_path = StochasticDepth(drop_p, mode="batch")

    def forward(self, x: Tensor) -> Tensor:
        """Apply the residual branch and add the input back."""
        branch = self.drop_path(self.layer_scaler(self.block(x)))
        branch += x
        return branch


In [5]:
class ConvNextStage(nn.Sequential):
    """One encoder stage: a downsampling layer followed by ``depth`` blocks.

    The downsampling layer is GroupNorm (single group) followed by a 2x2
    convolution with stride 2 that maps ``in_features`` to ``out_features``
    channels. It is followed by ``depth`` ``BottleNeckBlock`` modules working
    at ``out_features`` channels.

    Args:
        in_features: channels entering the stage.
        out_features: channels leaving the stage.
        depth: number of ``BottleNeckBlock`` modules.
        **kwargs: forwarded unchanged to every ``BottleNeckBlock``.
    """

    def __init__(self, in_features: int, out_features: int, depth: int, **kwargs):
        downsample = nn.Sequential(
            nn.GroupNorm(num_groups=1, num_channels=in_features),
            nn.Conv2d(in_features, out_features, kernel_size=2, stride=2),
        )
        blocks = []
        for _ in range(depth):
            blocks.append(BottleNeckBlock(out_features, out_features, **kwargs))
        super().__init__(downsample, *blocks)

In [6]:
class ConvNextStem(nn.Sequential):
    """Patchifying stem: a 4x4 convolution with stride 4, then BatchNorm2d.

    Kernel size equals stride, so the convolution reads non-overlapping
    4x4 patches. This takes the place of the overlapping stem the original
    note refers to (a 7x7 convolution with stride 2 followed by a 3x3 max
    pooling with stride 2).

    Args:
        in_features: channels of the input image.
        out_features: channels produced by the stem.
    """

    def __init__(self, in_features: int, out_features: int):
        patchify = nn.Conv2d(in_features, out_features, kernel_size=4, stride=4)
        norm = nn.BatchNorm2d(out_features)
        super().__init__(patchify, norm)

In [7]:
class ConvNextEncoder(nn.Module):
    """Stem followed by a sequence of ``ConvNextStage`` modules.

    Args:
        in_channels: channels of the input image.
        stem_features: channels produced by the stem, i.e. the input width
            of the first stage.
        depths: number of bottleneck blocks in each stage.
        widths: output channels of each stage.
        drop_p: maximum stochastic-depth drop probability. The probability
            grows linearly from 0 at the first block of the encoder to
            ``drop_p`` at the last one.
    """

    def __init__(self, in_channels: int, stem_features: int, 
                depths: List[int], widths: List[int], drop_p: float = .0):
        super().__init__()
        self.stem = ConvNextStem(in_channels, stem_features)

        # Stage i consumes the width produced by stage i-1 (the stem for stage 0).
        stage_in_widths = [stem_features, *widths[:-1]]
        self.stages = nn.ModuleList(
            [
                ConvNextStage(in_features, out_features, depth)
                for in_features, out_features, depth in zip(stage_in_widths, widths, depths)
            ]
        )

        # The drop schedule has one entry per *block*, not per stage. Handing each
        # stage a single entry (schedule[0], schedule[1], ...) would give every
        # block in a stage the same near-zero probability and never reach drop_p.
        # `modules()` yields sub-modules in registration order, so the collected
        # StochasticDepth layers are already ordered from first block to last.
        drop_paths = [m for m in self.stages.modules() if isinstance(m, StochasticDepth)]
        drop_probs = torch.linspace(0, drop_p, len(drop_paths)).tolist()
        for drop_path, prob in zip(drop_paths, drop_probs):
            drop_path.p = prob
    
    def forward(self, x):
        """Run the stem, then every stage in order, and return the final feature map."""
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
        return x

In [13]:
class ClassificationHead(nn.Sequential):
    """Classifier head: global average pool -> flatten -> LayerNorm -> Linear.

    Args:
        num_channels: channels of the incoming feature map; also the size
            normalized by ``LayerNorm`` and the input size of the linear layer.
        num_classes: number of output logits.
    """

    def __init__(self, num_channels: int, num_classes: int = 1000):
        layers = [
            # Collapse each channel's spatial grid to a single value.
            nn.AdaptiveAvgPool2d((1, 1)),
            # Flatten everything after the batch dim.
            nn.Flatten(start_dim=1),
            nn.LayerNorm(num_channels),
            nn.Linear(num_channels, num_classes),
        ]
        super().__init__(*layers)

class ConvNext(nn.Sequential):
    """ConvNext for image classification: an encoder followed by a classification head.

    Args:
        in_channels: channels of the input image.
        stem_features: channels produced by the encoder's stem.
        depths: number of blocks in each encoder stage.
        widths: output channels of each encoder stage; the last entry is the
            number of channels fed to the head.
        drop_p: drop probability forwarded to the encoder.
        num_classes: number of logits produced by the head.
    """

    def __init__(self,
                in_channels: int = 3,
                stem_features: int = 64,
                depths: List[int] = [3, 3, 9, 3],
                widths: List[int] = [96, 192, 384, 768],
                drop_p: float = .0,
                num_classes: int = 1000):
        # Start as an empty Sequential; sub-modules assigned as attributes are
        # registered in assignment order, so the encoder runs before the head.
        super().__init__()
        encoder = ConvNextEncoder(
            in_channels=in_channels,
            stem_features=stem_features,
            depths=depths,
            widths=widths,
            drop_p=drop_p,
        )
        head = ClassificationHead(num_channels=widths[-1], num_classes=num_classes)
        self.encoder = encoder
        self.head = head

In [19]:
# Tiny variant; the bare last expression displays the module tree.
convnext_tiny = ConvNext(
    depths=[3, 3, 9, 3],
    widths=[96, 192, 384, 768],
)
convnext_tiny

ConvNext(
  (encoder): ConvNextEncoder(
    (stem): ConvNextStem(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (stages): ModuleList(
      (0): ConvNextStage(
        (0): Sequential(
          (0): GroupNorm(1, 64, eps=1e-05, affine=True)
          (1): Conv2d(64, 96, kernel_size=(2, 2), stride=(2, 2))
        )
        (1): BottleNeckBlock(
          (block): Sequential(
            (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96, bias=False)
            (1): GroupNorm(1, 96, eps=1e-05, affine=True)
            (2): Conv2d(96, 384, kernel_size=(1, 1), stride=(1, 1))
            (3): GELU(approximate='none')
            (4): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1))
          )
          (layer_scaler): LayerScaler()
          (drop_path): StochasticDepth(p=0.0, mode=batch)
        )
        (2): BottleNeckBlock(
          (block

In [18]:
# Small variant: same widths as tiny, with 27 blocks in the third stage.
convnext_small = ConvNext(
    depths=[3, 3, 27, 3],
    widths=[96, 192, 384, 768],
)
convnext_small

ConvNext(
  (encoder): ConvNextEncoder(
    (stem): ConvNextStem(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (stages): ModuleList(
      (0): ConvNextStage(
        (0): Sequential(
          (0): GroupNorm(1, 64, eps=1e-05, affine=True)
          (1): Conv2d(64, 96, kernel_size=(2, 2), stride=(2, 2))
        )
        (1): BottleNeckBlock(
          (block): Sequential(
            (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96, bias=False)
            (1): GroupNorm(1, 96, eps=1e-05, affine=True)
            (2): Conv2d(96, 384, kernel_size=(1, 1), stride=(1, 1))
            (3): GELU(approximate='none')
            (4): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1))
          )
          (layer_scaler): LayerScaler()
          (drop_path): StochasticDepth(p=0.0, mode=batch)
        )
        (2): BottleNeckBlock(
          (block

In [17]:
# Base variant: depths [3, 3, 27, 3] with widths starting at 128.
convnext_base = ConvNext(
    depths=[3, 3, 27, 3],
    widths=[128, 256, 512, 1024],
)
convnext_base

ConvNext(
  (encoder): ConvNextEncoder(
    (stem): ConvNextStem(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (stages): ModuleList(
      (0): ConvNextStage(
        (0): Sequential(
          (0): GroupNorm(1, 64, eps=1e-05, affine=True)
          (1): Conv2d(64, 128, kernel_size=(2, 2), stride=(2, 2))
        )
        (1): BottleNeckBlock(
          (block): Sequential(
            (0): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128, bias=False)
            (1): GroupNorm(1, 128, eps=1e-05, affine=True)
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
            (3): GELU(approximate='none')
            (4): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
          )
          (layer_scaler): LayerScaler()
          (drop_path): StochasticDepth(p=0.0, mode=batch)
        )
        (2): BottleNeckBlock(
         

In [16]:
# Large variant: depths [3, 3, 27, 3] with widths starting at 192.
convnext_large = ConvNext(
    depths=[3, 3, 27, 3],
    widths=[192, 384, 768, 1536],
)
convnext_large

ConvNext(
  (encoder): ConvNextEncoder(
    (stem): ConvNextStem(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (stages): ModuleList(
      (0): ConvNextStage(
        (0): Sequential(
          (0): GroupNorm(1, 64, eps=1e-05, affine=True)
          (1): Conv2d(64, 192, kernel_size=(2, 2), stride=(2, 2))
        )
        (1): BottleNeckBlock(
          (block): Sequential(
            (0): Conv2d(192, 192, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=192, bias=False)
            (1): GroupNorm(1, 192, eps=1e-05, affine=True)
            (2): Conv2d(192, 768, kernel_size=(1, 1), stride=(1, 1))
            (3): GELU(approximate='none')
            (4): Conv2d(768, 192, kernel_size=(1, 1), stride=(1, 1))
          )
          (layer_scaler): LayerScaler()
          (drop_path): StochasticDepth(p=0.0, mode=batch)
        )
        (2): BottleNeckBlock(
         

In [20]:
# XLarge variant: depths [3, 3, 27, 3] with widths starting at 256.
convnext_xlarge = ConvNext(
    depths=[3, 3, 27, 3],
    widths=[256, 512, 1024, 2048],
)
convnext_xlarge

ConvNext(
  (encoder): ConvNextEncoder(
    (stem): ConvNextStem(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (stages): ModuleList(
      (0): ConvNextStage(
        (0): Sequential(
          (0): GroupNorm(1, 64, eps=1e-05, affine=True)
          (1): Conv2d(64, 256, kernel_size=(2, 2), stride=(2, 2))
        )
        (1): BottleNeckBlock(
          (block): Sequential(
            (0): Conv2d(256, 256, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=256, bias=False)
            (1): GroupNorm(1, 256, eps=1e-05, affine=True)
            (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
            (3): GELU(approximate='none')
            (4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (layer_scaler): LayerScaler()
          (drop_path): StochasticDepth(p=0.0, mode=batch)
        )
        (2): BottleNeckBlock(
       