# Implementing ConvNext in PyTorch

### Starting point: ResNet

### DataSet - CIFAR-10

(Initial dataset: ResNet)

# ResNet with all the stages

In [None]:
# function for model training

from TrainModel import train_model

In [None]:
import torch
from torch import nn
from torch import Tensor
from typing import List
from torch import optim
from torch import nn
from tqdm import tqdm
from torchvision.ops import StochasticDepth
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

In [None]:
class ConvNormAct(nn.Sequential):
  '''This class defines a sequence of operations: convolution, normalization, and activation.'''
    def __init__(
        self,
        in_features: int,
        out_features: int,
        kernel_size: int,
        norm = nn.BatchNorm2d,
        act = nn.ReLU,
        **kwargs
    ):
        super().__init__(
            nn.Conv2d(
                in_features,
                out_features,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                **kwargs
            ),
            norm(out_features),
            act(),
        )

class BottleNeckBlock(nn.Module):
  '''Purpose:
        This class defines a bottleneck block, a common building block in ResNet architectures.
      -----

      Parameters:
        in_features:
            Number of input channels.
        out_features:
            Number of output channels.
        reduction:
            Reduction factor for the bottleneck block.
        stride:
            Stride for the first convolutional layer.

       ------

      Operation:
        First convolutional layer:
          Reduces input channels to reduced_features using a 1x1 kernel.
        Second convolutional layer:
          Applies 3x3 convolution to reduced_features.
        Third convolutional layer:
          Expands back to out_features using a 1x1 kernel and no activation function (nn.Identity()).
        Shortcut connection:
          Performs a 1x1 convolution on the input if the number of input and output channels differ, otherwise, it is an identity connection.
        ReLU activation is applied.
        '''
    def __init__(
        self,
        in_features: int,
        out_features: int,
        reduction: int = 4,
        stride: int = 1,
    ):
        super().__init__()
        reduced_features = out_features // reduction
        self.block = nn.Sequential(
            ConvNormAct(
                in_features, reduced_features, kernel_size=1, stride=stride, bias=False
            ),
            ConvNormAct(reduced_features, reduced_features, kernel_size=3, bias=False),
            ConvNormAct(reduced_features, out_features, kernel_size=1, bias=False, act=nn.Identity),
        )
        self.shortcut = (
            nn.Sequential(
                ConvNormAct(
                    in_features, out_features, kernel_size=1, stride=stride, bias=False
                )
            )
            if in_features != out_features
            else nn.Identity()
        )

        self.act = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        res = x
        x = self.block(x)
        res = self.shortcut(res)
        x += res
        x = self.act(x)
        return x

class ConvNexStage(nn.Sequential):
    '''Purpose:
            This class defines a stage of the ResNet architecture, composed of multiple bottleneck blocks.

        Parameters:
          in_features:
              Number of input channels.
          out_features:
              Number of output channels.
          depth:
              Number of bottleneck blocks in the stage.
          stride:
              Stride for the first bottleneck block.

        Operation:
          First bottleneck block: Applies with the specified stride.
          Subsequent bottleneck blocks: Each applies with a stride of 1.
          The stage consists of a sequence of bottleneck blocks.
        '''
    def __init__(
        self, in_features: int, out_features: int, depth: int, stride: int = 2, **kwargs
    ):
        super().__init__(
            BottleNeckBlock(in_features, out_features, stride=stride, **kwargs),
            *[
                BottleNeckBlock(out_features, out_features, **kwargs)
                for _ in range(depth - 1)
            ],
        )

class ConvNextStem(nn.Sequential):
    '''Purpose:
          This class defines the stem of the ResNet architecture, responsible for the initial transformation of input features.
      Parameters:
          in_features:
              Number of input channels.
          out_features:
              Number of output channels.
      Operation:
          Applies a convolutional layer followed by normalization and activation.'''

    def __init__(self, in_features: int, out_features: int):
        super().__init__(
            ConvNormAct(
                in_features, out_features, kernel_size=3, stride=1  # Adjust kernel_size and stride
            ),
        )

class ConvNextEncoder(nn.Module):
  '''Purpose:
        This class defines the complete ResNet-like encoder architecture.
      Parameters:
          in_channels: Number of input channels.
          stem_features: Number of features after the stem.
          depths: List specifying the number of bottleneck blocks in each stage.
          widths: List specifying the number of output channels for each stage.
          num_classes: Number of output classes (default is 10 for CIFAR-10).
      Operation:
          Stem: Applies initial transformations to input features.
          Stages: A series of ConvNexStages, each consisting of multiple bottleneck blocks.
          Global Average Pooling (GAP): Reduces spatial dimensions to 1x1.
          Fully Connected Layer: Produces final output predictions.
      '''
    def __init__(
        self,
        in_channels: int,
        stem_features: int,
        depths: List[int],
        widths: List[int],
        num_classes: int = 10  # Adjust for CIFAR-10
    ):
        super().__init__()
        self.stem = ConvNextStem(in_channels, stem_features)

        in_out_widths = list(zip(widths, widths[1:]))

        self.stages = nn.ModuleList(
            [
                ConvNexStage(stem_features, widths[0], depths[0], stride=1),
                *[
                    ConvNexStage(in_features, out_features, depth)
                    for (in_features, out_features), depth in zip(
                        in_out_widths, depths[1:]
                    )
                ],
            ]
        )

        # Global Average Pooling (GAP)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(widths[-1], num_classes)

    def forward(self, x):
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)

        # Global Average Pooling
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)

        # Fully Connected Layer
        x = self.fc(x)

        return x

In [None]:
model_1 = ConvNextEncoder(in_channels=3, stem_features=64, depths=[3, 4, 6, 4], widths=[256, 512, 1024, 2048], num_classes=10)
optimizer_1 = optim.SGD(model_1.parameters(), lr=0.01, momentum=0.9)
train_model(model_1, optimizer_1, num_epochs=10, learning_rate=0.01, batch_size=64)


Files already downloaded and verified
Files already downloaded and verified


Epoch 1/10 - Training: 100%|██████████| 782/782 [02:51<00:00,  4.55it/s]
Epoch 1/10 - Validation: 100%|██████████| 157/157 [00:10<00:00, 14.85it/s]


Epoch 1/10, Training Loss: 1.9989, Training Accuracy: 29.87%, Validation Loss: 3.2034, Validation Accuracy: 42.51


Epoch 2/10 - Training: 100%|██████████| 782/782 [02:49<00:00,  4.61it/s]
Epoch 2/10 - Validation: 100%|██████████| 157/157 [00:10<00:00, 15.14it/s]


Epoch 2/10, Training Loss: 1.3789, Training Accuracy: 49.87%, Validation Loss: 1.2811, Validation Accuracy: 54.16


Epoch 3/10 - Training: 100%|██████████| 782/782 [02:49<00:00,  4.61it/s]
Epoch 3/10 - Validation: 100%|██████████| 157/157 [00:10<00:00, 15.21it/s]


Epoch 3/10, Training Loss: 1.0901, Training Accuracy: 61.06%, Validation Loss: 1.2147, Validation Accuracy: 61.31


Epoch 4/10 - Training: 100%|██████████| 782/782 [02:49<00:00,  4.61it/s]
Epoch 4/10 - Validation: 100%|██████████| 157/157 [00:10<00:00, 15.02it/s]


Epoch 4/10, Training Loss: 0.8953, Training Accuracy: 68.35%, Validation Loss: 0.8347, Validation Accuracy: 70.75


Epoch 5/10 - Training: 100%|██████████| 782/782 [02:48<00:00,  4.64it/s]
Epoch 5/10 - Validation: 100%|██████████| 157/157 [00:10<00:00, 15.15it/s]


Epoch 5/10, Training Loss: 0.7537, Training Accuracy: 73.47%, Validation Loss: 0.7181, Validation Accuracy: 75.55


Epoch 6/10 - Training: 100%|██████████| 782/782 [02:47<00:00,  4.66it/s]
Epoch 6/10 - Validation: 100%|██████████| 157/157 [00:10<00:00, 15.18it/s]


Epoch 6/10, Training Loss: 0.6448, Training Accuracy: 77.60%, Validation Loss: 0.7129, Validation Accuracy: 75.57


Epoch 7/10 - Training: 100%|██████████| 782/782 [02:48<00:00,  4.64it/s]
Epoch 7/10 - Validation: 100%|██████████| 157/157 [00:10<00:00, 15.07it/s]


Epoch 7/10, Training Loss: 0.5729, Training Accuracy: 80.16%, Validation Loss: 0.6304, Validation Accuracy: 79.05


Epoch 8/10 - Training: 100%|██████████| 782/782 [02:49<00:00,  4.63it/s]
Epoch 8/10 - Validation: 100%|██████████| 157/157 [00:10<00:00, 15.21it/s]


Epoch 8/10, Training Loss: 0.5168, Training Accuracy: 82.17%, Validation Loss: 0.5838, Validation Accuracy: 81.04


Epoch 9/10 - Training: 100%|██████████| 782/782 [02:48<00:00,  4.65it/s]
Epoch 9/10 - Validation: 100%|██████████| 157/157 [00:10<00:00, 15.36it/s]


Epoch 9/10, Training Loss: 0.4631, Training Accuracy: 84.04%, Validation Loss: 0.4925, Validation Accuracy: 83.02


Epoch 10/10 - Training: 100%|██████████| 782/782 [02:47<00:00,  4.66it/s]
Epoch 10/10 - Validation: 100%|██████████| 157/157 [00:10<00:00, 15.49it/s]

Epoch 10/10, Training Loss: 0.4259, Training Accuracy: 85.32%, Validation Loss: 0.6067, Validation Accuracy: 80.55





# ConvNeXt
-----
## Patchify + AdamW + Changing stage compute ratio
-----
### 1. Parchify

- basically means that we increase kernel size and the kernel size is equal to stride

### 2. AdamW started to be used when Transformers came to CV

### 3. Changing stage compute ratio

- For larger Swin Transformers, the ratio is 1:1:9:1. Following the design, we adjust the number of blocks in each stage from (3, 4, 6, 3) in ResNet-50 to (3, 3, 9, 3), which also aligns the FLOPs with Swin-T.

---
### Expectation:

Accuracy increase by 2.7%

In [None]:
class ConvNextStem(nn.Sequential):
    def __init__(self, in_features: int, out_features: int):
        super().__init__(
            nn.Conv2d(in_features, out_features, kernel_size=4, stride=4),
            nn.BatchNorm2d(out_features)
        )

# from now on
LR = 0.005
EPOCH = 20
BATCH = 64

In [None]:
model_2 = ConvNextEncoder(in_channels=3, stem_features=64, depths=[3,3,9,3], widths=[256, 512, 1024, 2048], num_classes=10)
optimizer = optim.AdamW(model_2.parameters(), lr=LR)
train_model(model_2, optimizer, num_epochs=EPOCH, learning_rate=LR, batch_size=BATCH)

Files already downloaded and verified
Files already downloaded and verified


Epoch 1/20 - Training: 100%|██████████| 782/782 [00:49<00:00, 15.66it/s]
Epoch 1/20 - Validation: 100%|██████████| 157/157 [00:02<00:00, 52.37it/s]


Epoch 1/20, Training Loss: 2.6066, Training Accuracy: 16.72%, Validation Loss: 1.9899, Validation Accuracy: 21.72


Epoch 2/20 - Training: 100%|██████████| 782/782 [00:49<00:00, 15.82it/s]
Epoch 2/20 - Validation: 100%|██████████| 157/157 [00:03<00:00, 50.86it/s]


Epoch 2/20, Training Loss: 1.9268, Training Accuracy: 25.13%, Validation Loss: 1.7953, Validation Accuracy: 30.51


Epoch 3/20 - Training: 100%|██████████| 782/782 [00:49<00:00, 15.75it/s]
Epoch 3/20 - Validation: 100%|██████████| 157/157 [00:03<00:00, 51.56it/s]


Epoch 3/20, Training Loss: 1.7057, Training Accuracy: 35.34%, Validation Loss: 1.6111, Validation Accuracy: 41.21


Epoch 4/20 - Training: 100%|██████████| 782/782 [00:48<00:00, 16.15it/s]
Epoch 4/20 - Validation: 100%|██████████| 157/157 [00:03<00:00, 44.65it/s]


Epoch 4/20, Training Loss: 1.5967, Training Accuracy: 40.44%, Validation Loss: 1.7732, Validation Accuracy: 35.01


Epoch 5/20 - Training: 100%|██████████| 782/782 [00:49<00:00, 15.93it/s]
Epoch 5/20 - Validation: 100%|██████████| 157/157 [00:03<00:00, 40.56it/s]


Epoch 5/20, Training Loss: 1.5606, Training Accuracy: 41.85%, Validation Loss: 1.4346, Validation Accuracy: 47.21


Epoch 6/20 - Training: 100%|██████████| 782/782 [00:49<00:00, 15.94it/s]
Epoch 6/20 - Validation: 100%|██████████| 157/157 [00:03<00:00, 45.85it/s]


Epoch 6/20, Training Loss: 1.4305, Training Accuracy: 47.52%, Validation Loss: 1.4856, Validation Accuracy: 46.07


Epoch 7/20 - Training: 100%|██████████| 782/782 [00:48<00:00, 16.00it/s]
Epoch 7/20 - Validation: 100%|██████████| 157/157 [00:03<00:00, 52.32it/s]


Epoch 7/20, Training Loss: 1.4125, Training Accuracy: 48.42%, Validation Loss: 1.4159, Validation Accuracy: 49.87


Epoch 8/20 - Training: 100%|██████████| 782/782 [00:49<00:00, 15.85it/s]
Epoch 8/20 - Validation: 100%|██████████| 157/157 [00:03<00:00, 49.26it/s]


Epoch 8/20, Training Loss: 1.3154, Training Accuracy: 52.34%, Validation Loss: 1.3946, Validation Accuracy: 50.19


Epoch 9/20 - Training: 100%|██████████| 782/782 [00:49<00:00, 15.77it/s]
Epoch 9/20 - Validation: 100%|██████████| 157/157 [00:03<00:00, 50.98it/s]


Epoch 9/20, Training Loss: 1.2238, Training Accuracy: 55.86%, Validation Loss: 1.2369, Validation Accuracy: 56.85


Epoch 10/20 - Training: 100%|██████████| 782/782 [00:49<00:00, 15.65it/s]
Epoch 10/20 - Validation: 100%|██████████| 157/157 [00:03<00:00, 51.11it/s]


Epoch 10/20, Training Loss: 1.1608, Training Accuracy: 58.31%, Validation Loss: 1.1455, Validation Accuracy: 60.37


Epoch 11/20 - Training: 100%|██████████| 782/782 [00:50<00:00, 15.41it/s]
Epoch 11/20 - Validation: 100%|██████████| 157/157 [00:03<00:00, 51.15it/s]


Epoch 11/20, Training Loss: 1.1067, Training Accuracy: 60.73%, Validation Loss: 1.1786, Validation Accuracy: 58.74


Epoch 12/20 - Training: 100%|██████████| 782/782 [00:49<00:00, 15.77it/s]
Epoch 12/20 - Validation: 100%|██████████| 157/157 [00:03<00:00, 51.21it/s]


Epoch 12/20, Training Loss: 1.0665, Training Accuracy: 62.22%, Validation Loss: 1.0440, Validation Accuracy: 62.82


Epoch 13/20 - Training: 100%|██████████| 782/782 [00:49<00:00, 15.83it/s]
Epoch 13/20 - Validation: 100%|██████████| 157/157 [00:03<00:00, 48.05it/s]


Epoch 13/20, Training Loss: 1.0253, Training Accuracy: 63.79%, Validation Loss: 1.0837, Validation Accuracy: 61.49


Epoch 14/20 - Training: 100%|██████████| 782/782 [00:49<00:00, 15.91it/s]
Epoch 14/20 - Validation: 100%|██████████| 157/157 [00:03<00:00, 44.53it/s]


Epoch 14/20, Training Loss: 0.9823, Training Accuracy: 65.20%, Validation Loss: 0.9372, Validation Accuracy: 66.98


Epoch 15/20 - Training: 100%|██████████| 782/782 [00:48<00:00, 15.97it/s]
Epoch 15/20 - Validation: 100%|██████████| 157/157 [00:03<00:00, 40.58it/s]


Epoch 15/20, Training Loss: 0.9547, Training Accuracy: 66.32%, Validation Loss: 0.9350, Validation Accuracy: 67.77


Epoch 16/20 - Training: 100%|██████████| 782/782 [00:48<00:00, 16.11it/s]
Epoch 16/20 - Validation: 100%|██████████| 157/157 [00:03<00:00, 46.38it/s]


Epoch 16/20, Training Loss: 0.9333, Training Accuracy: 67.01%, Validation Loss: 0.8544, Validation Accuracy: 70.53


Epoch 17/20 - Training: 100%|██████████| 782/782 [00:48<00:00, 16.02it/s]
Epoch 17/20 - Validation: 100%|██████████| 157/157 [00:02<00:00, 52.36it/s]


Epoch 17/20, Training Loss: 0.8722, Training Accuracy: 69.53%, Validation Loss: 0.8036, Validation Accuracy: 72.11


Epoch 18/20 - Training: 100%|██████████| 782/782 [00:49<00:00, 15.88it/s]
Epoch 18/20 - Validation: 100%|██████████| 157/157 [00:03<00:00, 51.60it/s]


Epoch 18/20, Training Loss: 0.8432, Training Accuracy: 70.63%, Validation Loss: 0.8230, Validation Accuracy: 71.82


Epoch 19/20 - Training: 100%|██████████| 782/782 [00:49<00:00, 15.89it/s]
Epoch 19/20 - Validation: 100%|██████████| 157/157 [00:03<00:00, 51.38it/s]


Epoch 19/20, Training Loss: 0.8373, Training Accuracy: 70.98%, Validation Loss: 0.8124, Validation Accuracy: 72.37


Epoch 20/20 - Training: 100%|██████████| 782/782 [00:49<00:00, 15.93it/s]
Epoch 20/20 - Validation: 100%|██████████| 157/157 [00:03<00:00, 51.49it/s]

Epoch 20/20, Training Loss: 0.7885, Training Accuracy: 72.50%, Validation Loss: 0.7729, Validation Accuracy: 73.32





## ResNeXt-ify + Inverted Bottleneck
-----

### * ResNeXt-ify
In ConvNext, they use depth-wise convolution (like in MobileNet and later in EfficientNet). Depth-wise convs are grouped convolutions where the number of groups is equal to the number of input channels.

The authors notice that is very similar to the weighted sum operation in self-attention, which mixes information only in the spatial dimension. Using depth-wise convs reduce the accuracy (since we are not increasing the widths like in ResNetXt), this is expected.

So we change our 3x3 conv inside BottleNeck block to

```ConvNormAct(reduced_features, reduced_features, kernel_size=3, bias=False, groups=reduced_features)```

### * Inverted Bottleneck

Our BottleNeck first reduces the features via a 1x1 conv, then it applies the heavy 3x3 conv and finally expands the features to the original size. An inverted bottleneck block, does the opposite. I have a whole article with nice visualization about them.

So we go from ```wide -> narrow -> wide``` to ```narrow -> wide -> narrow```

This is similar to Transformers, since the MLP layer follows the narrow -> wide -> narrow design, the second dense layer in the MLP expands the input's feature by a factor of four.


In [None]:
class BottleNeckBlock(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        expansion: int = 4,
        stride: int = 1,
    ):
        super().__init__()
        expanded_features = out_features * expansion
        self.block = nn.Sequential(
            # narrow -> wide
            ConvNormAct(
                in_features, expanded_features, kernel_size=1, stride=stride, bias=False
            ),
            # wide -> wide (with depth-wise)
            ConvNormAct(expanded_features, expanded_features, kernel_size=3, bias=False, groups=in_features), # groups refer to  ResNexT
            # wide -> narrow
            ConvNormAct(expanded_features, out_features, kernel_size=1, bias=False, act=nn.Identity),
        )
        self.shortcut = (
            nn.Sequential(
                ConvNormAct(
                    in_features, out_features, kernel_size=1, stride=stride, bias=False
                )
            )
            if in_features != out_features
            else nn.Identity()
        )

        self.act = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        res = x
        x = self.block(x)
        res = self.shortcut(res)
        x += res
        x = self.act(x)
        return x

In [None]:
model_3 = ConvNextEncoder(in_channels=3, stem_features=64, depths=[3,3,9,3], widths=[256, 512, 1024, 2048], num_classes=10)
optimizer = optim.AdamW(model_3.parameters(), lr=LR)
train_model(model_3, optimizer, num_epochs=EPOCH, learning_rate=LR, batch_size=BATCH)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:04<00:00, 36638959.03it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


Epoch 1/20 - Training: 100%|██████████| 782/782 [04:18<00:00,  3.03it/s]
Epoch 1/20 - Validation: 100%|██████████| 157/157 [00:20<00:00,  7.81it/s]


Epoch 1/20, Training Loss: 2.5493, Training Accuracy: 19.54%, Validation Loss: 2.0077, Validation Accuracy: 27.11


Epoch 2/20 - Training: 100%|██████████| 782/782 [04:11<00:00,  3.11it/s]
Epoch 2/20 - Validation: 100%|██████████| 157/157 [00:20<00:00,  7.77it/s]


Epoch 2/20, Training Loss: 1.8892, Training Accuracy: 29.38%, Validation Loss: 1.7289, Validation Accuracy: 36.24


Epoch 3/20 - Training: 100%|██████████| 782/782 [04:12<00:00,  3.10it/s]
Epoch 3/20 - Validation: 100%|██████████| 157/157 [00:20<00:00,  7.70it/s]


Epoch 3/20, Training Loss: 1.8473, Training Accuracy: 30.11%, Validation Loss: 1.6320, Validation Accuracy: 38.07


Epoch 4/20 - Training: 100%|██████████| 782/782 [04:12<00:00,  3.10it/s]
Epoch 4/20 - Validation: 100%|██████████| 157/157 [00:20<00:00,  7.71it/s]


Epoch 4/20, Training Loss: 1.5876, Training Accuracy: 40.70%, Validation Loss: 1.4603, Validation Accuracy: 45.97


Epoch 5/20 - Training: 100%|██████████| 782/782 [04:12<00:00,  3.10it/s]
Epoch 5/20 - Validation: 100%|██████████| 157/157 [00:20<00:00,  7.78it/s]


Epoch 5/20, Training Loss: 1.4047, Training Accuracy: 48.44%, Validation Loss: 1.3657, Validation Accuracy: 50.88


Epoch 6/20 - Training: 100%|██████████| 782/782 [04:11<00:00,  3.11it/s]
Epoch 6/20 - Validation: 100%|██████████| 157/157 [00:20<00:00,  7.77it/s]


Epoch 6/20, Training Loss: 1.2575, Training Accuracy: 54.76%, Validation Loss: 1.5803, Validation Accuracy: 51.23


Epoch 7/20 - Training: 100%|██████████| 782/782 [04:11<00:00,  3.11it/s]
Epoch 7/20 - Validation: 100%|██████████| 157/157 [00:20<00:00,  7.75it/s]


Epoch 7/20, Training Loss: 1.1299, Training Accuracy: 59.70%, Validation Loss: 1.0024, Validation Accuracy: 64.47


Epoch 8/20 - Training: 100%|██████████| 782/782 [04:11<00:00,  3.11it/s]
Epoch 8/20 - Validation: 100%|██████████| 157/157 [00:20<00:00,  7.82it/s]


Epoch 8/20, Training Loss: 1.0309, Training Accuracy: 63.45%, Validation Loss: 1.1143, Validation Accuracy: 61.65


Epoch 9/20 - Training: 100%|██████████| 782/782 [04:11<00:00,  3.11it/s]
Epoch 9/20 - Validation: 100%|██████████| 157/157 [00:20<00:00,  7.79it/s]


Epoch 9/20, Training Loss: 0.9724, Training Accuracy: 65.93%, Validation Loss: 0.9546, Validation Accuracy: 67.87


Epoch 10/20 - Training: 100%|██████████| 782/782 [04:11<00:00,  3.11it/s]
Epoch 10/20 - Validation: 100%|██████████| 157/157 [00:20<00:00,  7.78it/s]


Epoch 10/20, Training Loss: 0.8911, Training Accuracy: 69.00%, Validation Loss: 0.8089, Validation Accuracy: 71.92


Epoch 11/20 - Training: 100%|██████████| 782/782 [04:11<00:00,  3.11it/s]
Epoch 11/20 - Validation: 100%|██████████| 157/157 [00:20<00:00,  7.79it/s]


Epoch 11/20, Training Loss: 0.8593, Training Accuracy: 70.03%, Validation Loss: 0.7814, Validation Accuracy: 73.16


Epoch 12/20 - Training: 100%|██████████| 782/782 [04:10<00:00,  3.12it/s]
Epoch 12/20 - Validation: 100%|██████████| 157/157 [00:20<00:00,  7.77it/s]


Epoch 12/20, Training Loss: 0.7625, Training Accuracy: 73.72%, Validation Loss: 0.7584, Validation Accuracy: 74.08


Epoch 13/20 - Training: 100%|██████████| 782/782 [04:10<00:00,  3.12it/s]
Epoch 13/20 - Validation: 100%|██████████| 157/157 [00:20<00:00,  7.80it/s]


Epoch 13/20, Training Loss: 0.7211, Training Accuracy: 75.17%, Validation Loss: 0.6958, Validation Accuracy: 76.68


Epoch 14/20 - Training: 100%|██████████| 782/782 [04:10<00:00,  3.12it/s]
Epoch 14/20 - Validation: 100%|██████████| 157/157 [00:20<00:00,  7.80it/s]


Epoch 14/20, Training Loss: 0.6946, Training Accuracy: 76.10%, Validation Loss: 0.6861, Validation Accuracy: 76.28


Epoch 15/20 - Training: 100%|██████████| 782/782 [04:10<00:00,  3.12it/s]
Epoch 15/20 - Validation: 100%|██████████| 157/157 [00:20<00:00,  7.82it/s]


Epoch 15/20, Training Loss: 0.6608, Training Accuracy: 77.33%, Validation Loss: 0.6710, Validation Accuracy: 77.01


Epoch 16/20 - Training: 100%|██████████| 782/782 [04:10<00:00,  3.13it/s]
Epoch 16/20 - Validation: 100%|██████████| 157/157 [00:19<00:00,  7.87it/s]


Epoch 16/20, Training Loss: 0.6606, Training Accuracy: 77.49%, Validation Loss: 0.6319, Validation Accuracy: 78.34


Epoch 17/20 - Training: 100%|██████████| 782/782 [04:10<00:00,  3.13it/s]
Epoch 17/20 - Validation: 100%|██████████| 157/157 [00:20<00:00,  7.83it/s]


Epoch 17/20, Training Loss: 0.6335, Training Accuracy: 78.25%, Validation Loss: 0.6663, Validation Accuracy: 77.83


Epoch 18/20 - Training: 100%|██████████| 782/782 [04:09<00:00,  3.13it/s]
Epoch 18/20 - Validation: 100%|██████████| 157/157 [00:20<00:00,  7.82it/s]


Epoch 18/20, Training Loss: 0.5949, Training Accuracy: 79.90%, Validation Loss: 3.0171, Validation Accuracy: 69.24


Epoch 19/20 - Training: 100%|██████████| 782/782 [04:09<00:00,  3.14it/s]
Epoch 19/20 - Validation: 100%|██████████| 157/157 [00:20<00:00,  7.85it/s]


Epoch 19/20, Training Loss: 0.5943, Training Accuracy: 79.66%, Validation Loss: 0.6413, Validation Accuracy: 78.24


Epoch 20/20 - Training: 100%|██████████| 782/782 [04:09<00:00,  3.13it/s]
Epoch 20/20 - Validation: 100%|██████████| 157/157 [00:20<00:00,  7.78it/s]

Epoch 20/20, Training Loss: 0.5432, Training Accuracy: 81.32%, Validation Loss: 0.5725, Validation Accuracy: 80.36





## Large Kernel Sizes
-------

Modern Vision Transfomer, like Swin, uses a bigger kernel size (7x7). Increasing the kernel size will make the computation more expensive, so we move up the big depth-wise conv, by doing so we will have fewer channels. The authors note this is similar to Transformers model where the Multihead Self Attention (MSA) is done before the MLP layers.

In [None]:
# on VM
# best score is
# colab crahes
# Epoch 19/20, Training Loss: 0.5628, Training Accuracy: 80.58%, Validation Loss: 0.6100, Validation Accuracy: 78.99

class BottleNeckBlock(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        expansion: int = 4,
        stride: int = 1,
    ):
        super().__init__()
        expanded_features = out_features * expansion
        self.block = nn.Sequential(
            # narrow -> wide (with depth-wise and bigger kernel)
            ConvNormAct(
                in_features, in_features, kernel_size=7, stride=stride, bias=False, groups=in_features
            ),
            # wide -> wide
            ConvNormAct(in_features, expanded_features, kernel_size=1),
            # wide -> narrow
            ConvNormAct(expanded_features, out_features, kernel_size=1, bias=False, act=nn.Identity),
        )
        self.shortcut = (
            nn.Sequential(
                ConvNormAct(
                    in_features, out_features, kernel_size=1, stride=stride, bias=False
                )
            )
            if in_features != out_features
            else nn.Identity()
        )

        self.act = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        res = x
        x = self.block(x)
        res = self.shortcut(res)
        x += res
        x = self.act(x)
        return x

In [None]:
model_4 = ConvNextEncoder(in_channels=3, stem_features=64, depths=[3,3,9,3], widths=[256, 512, 1024, 2048], num_classes=10)
optimizer = optim.AdamW(model_3.parameters(), lr=LR)
train_model(model_4, optimizer, num_epochs=EPOCH, learning_rate=LR, batch_size=BATCH)

In [None]:
print('Epoch 19/20, Training Loss: 0.5628, Training Accuracy: 80.58%, Validation Loss: 0.6100, Validation Accuracy: 78.99')

Epoch 19/20, Training Loss: 0.5628, Training Accuracy: 80.58%, Validation Loss: 0.6100, Validation Accuracy: 78.99


# Micro Design
----
##Replacing ReLU with GELU

- Since GELU is used by the most advanced transformers, why not use it in our model? The authors report the accuracy stays unchanged. In PyTorch GELU in nn.GELU.

## Fewer activation functions

- Our block has three activation functions. While, in Transformer block, there is only one activation function, the one inside the MLP block. The authors removed all the activations except for the one after the middle conv layer. This improves accuracy to 81.3% matching Swin-T!

## Fewer normalization layers

- Similar to activations, Transformers blocks have fewer normalization layers. The authors decide the remove all the BatchNorm and kept only the one before the middle conv.

## Substituting BN with LN

- Well, they substitute the BatchNorm layers with LinearyNorm. They note that doing so in the original ResNet hurts performance, but after all our changes, the performance increases to 81.5%

## Separate downsampling layers.

- In ResNet the downsampling is done by the stride=2 conv.

- Transformers (and other conv nets too) have a separate downsampling block.

- The authors removed the stride=2 and add a downsampling block before the three convs using a 2x2 stride=2 conv. Normalization is needed before the downsampling operation to maintain stability during training.

We can add this module to our ConvNexStage.


## Stochastic Depth, also known as Drop Path, and Layer Scale.

In [None]:
BATCH = 64

class ConvNormAct(nn.Sequential):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        kernel_size: int,
        norm = nn.BatchNorm2d,
        act = nn.ReLU,
        **kwargs
    ):
        super().__init__(
            nn.Conv2d(
                in_features,
                out_features,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                **kwargs
            ),
            norm(out_features),
            act(),
        )


#####################

class BottleNeckBlock(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        expansion: int = 4,
        drop_p: float = .0,
        layer_scaler_init_value: float = 1e-6,
    ):
        super().__init__()
        expanded_features = out_features * expansion
        self.block = nn.Sequential(
            # narrow -> wide (with depth-wise and bigger kernel)
            nn.Conv2d(
                in_features, in_features, kernel_size=7, padding=3, bias=False, groups=in_features
            ),
            # GroupNorm with num_groups=1 is the same as LayerNorm but works for 2D data
            nn.GroupNorm(num_groups=1, num_channels=in_features),
            # wide -> wide
            nn.Conv2d(in_features, expanded_features, kernel_size=1),
            nn.GELU(),
            # wide -> narrow
            nn.Conv2d(expanded_features, out_features, kernel_size=1),
        )
        self.layer_scaler = LayerScaler(layer_scaler_init_value, out_features)
        self.drop_path = StochasticDepth(drop_p, mode="batch")


    def forward(self, x: Tensor) -> Tensor:
        res = x
        x = self.block(x)
        x = self.layer_scaler(x)
        x = self.drop_path(x)
        x += res
        return x

#################

class ConvNexStage(nn.Sequential):
    def __init__(
        self, in_features: int, out_features: int, depth: int, **kwargs
    ):
        super().__init__(
            # add the downsampler
            nn.Sequential(
                nn.GroupNorm(num_groups=1, num_channels=in_features),
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1)
            ),
            *[
                BottleNeckBlock(out_features, out_features, **kwargs)
                for _ in range(depth)
            ],
        )

class ConvNextStem(nn.Sequential):
    def __init__(self, in_features: int, out_features: int):
        super().__init__(
            nn.Conv2d(in_features, out_features, kernel_size=4, stride=4),
            nn.BatchNorm2d(out_features)
        )


class LayerScaler(nn.Module):
    def __init__(self, init_value: float, dimensions: int):
        super().__init__()
        self.gamma = nn.Parameter(init_value * torch.ones((dimensions)),
                                    requires_grad=True)

    def forward(self, x):
        return self.gamma[None,...,None,None] * x

class ConvNextEncoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        stem_features: int,
        depths: List[int],
        widths: List[int],
        drop_p: float = .0,
    ):
        super().__init__()
        self.stem = ConvNextStem(in_channels, stem_features)

        in_out_widths = list(zip(widths, widths[1:]))
        # create drop paths probabilities (one for each stage)
        drop_probs = [x.item() for x in torch.linspace(0, drop_p, sum(depths))]

        self.stages = nn.ModuleList(
            [
                ConvNexStage(stem_features, widths[0], depths[0], drop_p=drop_probs[0]),
                *[
                    ConvNexStage(in_features, out_features, depth, drop_p=drop_p)
                    for (in_features, out_features), depth, drop_p in zip(
                        in_out_widths, depths[1:], drop_probs[1:]
                    )
                ],
            ]
        )


    def forward(self, x):
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
        return x

class ClassificationHead(nn.Sequential):
    def __init__(self, num_channels: int, num_classes: int = 10):
        super().__init__(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(1),
            nn.LayerNorm(num_channels),
            nn.Linear(num_channels, num_classes)
        )

class ConvNextForImageClassification(nn.Sequential):
    def __init__(self,
                 in_channels: int,
                 stem_features: int,
                 depths: List[int],
                 widths: List[int],
                 drop_p: float = .0,
                 num_classes: int = 10):
        super().__init__()
        self.encoder = ConvNextEncoder(in_channels, stem_features, depths, widths, drop_p)
        self.head = ClassificationHead(widths[-1], num_classes)


# CIFAR-10 Dataset
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
])

train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH, shuffle=True)

test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transforms.ToTensor(), download=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH, shuffle=False)

# Create the model
model = ConvNextForImageClassification(in_channels=3, stem_features=64, depths=[3, 3, 9, 3], widths=[256, 512, 1024, 2048], num_classes=10)
optimizer = optim.AdamW(model.parameters(), lr=0.005)
criterion = nn.CrossEntropyLoss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
criterion = criterion.to(device)

# Training loop
num_epochs = 50
for epoch in range(num_epochs):
    model.train()
    for inputs, targets in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
        # Move inputs and targets to the GPU
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for inputs, targets in test_loader:
            # Move inputs and targets to the GPU
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

        accuracy = correct / total
        print(f'Epoch {epoch + 1}/{num_epochs}, Validation Accuracy: {accuracy * 100:.2f}%')



Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 45206921.47it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


Epoch 1/50: 100%|██████████| 782/782 [02:26<00:00,  5.36it/s]


Epoch 1/50, Validation Accuracy: 35.77%


Epoch 2/50: 100%|██████████| 782/782 [02:22<00:00,  5.49it/s]


Epoch 2/50, Validation Accuracy: 42.50%


Epoch 3/50: 100%|██████████| 782/782 [02:22<00:00,  5.48it/s]


Epoch 3/50, Validation Accuracy: 43.96%


Epoch 4/50: 100%|██████████| 782/782 [02:22<00:00,  5.48it/s]


Epoch 4/50, Validation Accuracy: 48.58%


Epoch 5/50: 100%|██████████| 782/782 [02:22<00:00,  5.48it/s]


Epoch 5/50, Validation Accuracy: 50.19%


Epoch 6/50: 100%|██████████| 782/782 [02:22<00:00,  5.48it/s]


Epoch 6/50, Validation Accuracy: 52.96%


Epoch 7/50: 100%|██████████| 782/782 [02:22<00:00,  5.48it/s]


Epoch 7/50, Validation Accuracy: 53.77%


Epoch 8/50: 100%|██████████| 782/782 [02:22<00:00,  5.48it/s]


Epoch 8/50, Validation Accuracy: 53.33%


Epoch 9/50: 100%|██████████| 782/782 [02:22<00:00,  5.48it/s]


Epoch 9/50, Validation Accuracy: 55.14%


Epoch 10/50: 100%|██████████| 782/782 [02:22<00:00,  5.48it/s]


Epoch 10/50, Validation Accuracy: 54.54%


Epoch 11/50: 100%|██████████| 782/782 [02:22<00:00,  5.49it/s]


Epoch 11/50, Validation Accuracy: 57.45%


Epoch 12/50: 100%|██████████| 782/782 [02:22<00:00,  5.49it/s]


Epoch 12/50, Validation Accuracy: 58.25%


Epoch 13/50: 100%|██████████| 782/782 [02:22<00:00,  5.48it/s]


Epoch 13/50, Validation Accuracy: 57.85%


Epoch 14/50: 100%|██████████| 782/782 [02:22<00:00,  5.49it/s]


Epoch 14/50, Validation Accuracy: 59.21%


Epoch 15/50: 100%|██████████| 782/782 [02:22<00:00,  5.48it/s]


Epoch 15/50, Validation Accuracy: 60.41%


Epoch 16/50: 100%|██████████| 782/782 [02:22<00:00,  5.48it/s]


Epoch 16/50, Validation Accuracy: 60.13%


Epoch 17/50: 100%|██████████| 782/782 [02:22<00:00,  5.48it/s]


Epoch 17/50, Validation Accuracy: 60.69%


Epoch 18/50: 100%|██████████| 782/782 [02:22<00:00,  5.49it/s]


Epoch 18/50, Validation Accuracy: 60.55%


Epoch 19/50: 100%|██████████| 782/782 [02:22<00:00,  5.48it/s]


Epoch 19/50, Validation Accuracy: 60.97%


Epoch 20/50: 100%|██████████| 782/782 [02:22<00:00,  5.48it/s]


Epoch 20/50, Validation Accuracy: 61.57%


Epoch 21/50: 100%|██████████| 782/782 [02:22<00:00,  5.49it/s]


Epoch 21/50, Validation Accuracy: 63.04%


Epoch 22/50: 100%|██████████| 782/782 [02:22<00:00,  5.49it/s]


Epoch 22/50, Validation Accuracy: 62.33%


Epoch 23/50: 100%|██████████| 782/782 [02:22<00:00,  5.49it/s]


Epoch 23/50, Validation Accuracy: 63.14%


Epoch 24/50: 100%|██████████| 782/782 [02:22<00:00,  5.49it/s]


Epoch 24/50, Validation Accuracy: 63.70%


Epoch 25/50: 100%|██████████| 782/782 [02:22<00:00,  5.49it/s]


Epoch 25/50, Validation Accuracy: 63.61%


Epoch 26/50: 100%|██████████| 782/782 [02:22<00:00,  5.48it/s]


Epoch 26/50, Validation Accuracy: 64.35%


Epoch 27/50: 100%|██████████| 782/782 [02:22<00:00,  5.48it/s]


Epoch 27/50, Validation Accuracy: 63.61%


Epoch 28/50: 100%|██████████| 782/782 [02:22<00:00,  5.49it/s]


Epoch 28/50, Validation Accuracy: 63.61%


Epoch 29/50: 100%|██████████| 782/782 [02:22<00:00,  5.48it/s]


Epoch 29/50, Validation Accuracy: 64.94%


Epoch 30/50: 100%|██████████| 782/782 [02:22<00:00,  5.48it/s]


Epoch 30/50, Validation Accuracy: 65.52%


Epoch 31/50: 100%|██████████| 782/782 [02:22<00:00,  5.48it/s]


Epoch 31/50, Validation Accuracy: 64.38%


Epoch 32/50:   3%|▎         | 21/782 [00:03<02:18,  5.48it/s]

## SWIN transformer


In [7]:
import torch
import torch.nn as nn

class SwinBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super(SwinBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=kernel_size // 2)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding=kernel_size // 2)
        self.norm = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.norm(x)
        x = self.act(x)
        return x

class SwinTransformer(nn.Module):
    def __init__(self, in_channels=3, num_classes=10, embed_dim=96, num_blocks=[2, 2, 6, 2], window_size=7):
        super(SwinTransformer, self).__init__()

        self.stem = nn.Conv2d(in_channels, embed_dim, kernel_size=window_size, stride=4, padding=window_size // 2)

        layers = []
        for i in range(len(num_blocks)):
            num_blocks_i = num_blocks[i]
            stride = 2 if i != 0 else 1
            for _ in range(num_blocks_i):
                layers.append(SwinBlock(embed_dim, embed_dim, stride=stride))
                stride = 1

        self.blocks = nn.Sequential(*layers)
        self.global_avg_pooling = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.stem(x)
        x = self.blocks(x)
        x = self.global_avg_pooling(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Example usage:
model = SwinTransformer()
print(model)


SwinTransformer(
  (stem): Conv2d(3, 96, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
  (blocks): Sequential(
    (0): SwinBlock(
      (conv1): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU()
    )
    (1): SwinBlock(
      (conv1): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU()
    )
    (2): SwinBlock(
      (conv1): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    