In [30]:
import torch
import torch.nn as nn

In [31]:
# ConvBlock: convolution + batch normalisation, no activation
class ConvBlock(nn.Module):
    """2-D convolution immediately followed by batch normalisation.

    No non-linearity is applied here; callers add their own activation.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution
            (also the number of features normalised by the batch norm).
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Padding applied by the convolution.
        groups: Number of groups for a grouped convolution (1 = dense).
        bias: Whether the convolution carries a learnable bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups=1, bias=False):
        super().__init__()
        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias,
        )
        # Attribute names ``c`` and ``bn`` are part of the state_dict keys.
        self.c = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return the batch-normalised convolution of ``x``."""
        convolved = self.c(x)
        return self.bn(convolved)

In [32]:
# Bottleneck ResidualBlock
class ResidualBlock(nn.Module):
    """Bottleneck residual block with a grouped 3x3 convolution.

    The residual path is ``1x1 conv -> ReLU -> grouped 3x3 conv -> ReLU ->
    1x1 conv`` (each convolution followed by batch norm via ``ConvBlock``).
    Its result is added to the shortcut path and passed through a final ReLU.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels of the returned tensor. The inner bottleneck
            width is ``out_channels // 2``; ``nn.Conv2d`` requires that width
            to be divisible by ``cardinality``.
        stride: Stride of the 3x3 convolution (and of the shortcut projection).
        first: Force a projection shortcut even when the shapes already match.
        cardinality: Number of groups of the 3x3 convolution.

    Attributes set here:
        downsample: ``True`` when the shortcut is a strided 1x1 ``ConvBlock``
            (``self.p``) instead of the identity.
    """

    def __init__(self, in_channels, out_channels, stride, first=False, cardinality=32):
        super().__init__()
        res_channels = out_channels // 2
        # The identity shortcut is only valid when the residual path keeps
        # both the channel count and the spatial size. Previously a projection
        # was created only for ``stride == 2 or first``, so a stride-1 block
        # that changed the channel count (or any other stride) reached the
        # addition below with mismatched shapes.
        self.downsample = bool(first or stride != 1 or in_channels != out_channels)
        self.c1 = ConvBlock(in_channels, res_channels, 1, 1, 0)
        self.c2 = ConvBlock(res_channels, res_channels, 3, stride, 1, cardinality)
        self.c3 = ConvBlock(res_channels, out_channels, 1, 1, 0)

        if self.downsample:
            # 1x1 projection: matches channels and, through ``stride``, the
            # spatial size produced by the padded 3x3 convolution.
            self.p = ConvBlock(in_channels, out_channels, 1, stride, 0)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""
        f = self.relu(self.c1(x))
        f = self.relu(self.c2(f))
        f = self.c3(f)  # no activation before the residual addition

        if self.downsample:
            x = self.p(x)

        h = self.relu(torch.add(f, x))
        return h

In [33]:
# ResNeXt
class ResNeXt(nn.Module):
    """Four-stage ResNeXt classifier built from ``ResidualBlock`` bottlenecks.

    Layout: 7x7 stride-2 ``ConvBlock`` stem -> ReLU -> 3x3 stride-2 max pool ->
    four stages of residual blocks with output widths 256/512/1024/2048 ->
    global average pool -> fully connected classifier.

    Args:
        no_blocks: Number of residual blocks in each of the four stages,
            e.g. ``[3, 4, 23, 3]``.
        in_channels: Channels of the input image.
        classes: Number of output logits.

    Raises:
        ValueError: If ``no_blocks`` does not have exactly four entries or
            contains a count smaller than 1.
    """

    def __init__(self, no_blocks, in_channels=3, classes=1000):
        super().__init__()
        stem_channels = 64
        out_features = [256, 512, 1024, 2048]

        # Fail early with a clear message instead of an IndexError inside the
        # loop (too few entries), silently dropped stages (too many entries)
        # or a stage that gets a block although 0 were requested.
        if len(no_blocks) != len(out_features):
            raise ValueError(
                "no_blocks must have {} entries (one per stage), got {}".format(
                    len(out_features), len(no_blocks)
                )
            )
        if any(count < 1 for count in no_blocks):
            raise ValueError("every entry of no_blocks must be >= 1")

        # Stage 0 opens with a stride-1 block that only widens the channels;
        # every later stage opens with a stride-2 block.
        self.blocks = nn.ModuleList([ResidualBlock(stem_channels, out_features[0], 1, True)])

        for i, (width, count) in enumerate(zip(out_features, no_blocks)):
            if i > 0:
                self.blocks.append(ResidualBlock(out_features[i - 1], width, 2))
            # The stage's opening block is already in place, hence ``count - 1``.
            for _ in range(count - 1):
                self.blocks.append(ResidualBlock(width, width, 1))

        self.conv1 = ConvBlock(in_channels, stem_channels, 7, 2, 3)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Classifier input width follows the last stage instead of a literal.
        self.fc = nn.Linear(out_features[-1], classes)

        self.relu = nn.ReLU()

        self.init_weight()

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        x = self.relu(self.conv1(x))
        x = self.maxpool(x)
        for block in self.blocks:
            x = block(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = self.fc(x)
        return x

    def init_weight(self):
        """Re-initialise all ``Conv2d`` and ``Linear`` weights with Kaiming-normal.

        Only the ``weight`` tensors are touched; biases and batch-norm
        parameters keep their default initialisation.
        """
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(layer.weight)

In [34]:
# Blocks per stage (3-4-23-3), single-channel input, 10 output classes.
no_blocks = [3, 4, 23, 3]
my_resnet = ResNeXt(no_blocks, 1, 10)

# Shape smoke test only: disable autograd so the forward pass over a
# 16 x 1 x 224 x 224 batch does not keep every intermediate activation alive.
with torch.no_grad():
    print(my_resnet(torch.rand(16, 1, 224, 224)).shape)

torch.Size([16, 10])


In [35]:
from ptflops import get_model_complexity_info

# `net` is never moved to a GPU in this cell, so the former
# `with torch.cuda.device(0):` wrapper contributed nothing to the measurement
# and only made the cell fail on machines without CUDA.
net = ResNeXt(no_blocks, 1, 10)
macs, params = get_model_complexity_info(
    net,
    (1, 224, 224),  # input resolution without the batch dimension
    as_strings=True,
    print_per_layer_stat=True,
    verbose=True,
)
print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
print('{:<30}  {:<8}'.format('Number of parameters: ', params))

ResNeXt(
  42.143 M, 100.000% Params, 7.952 GMac, 100.000% MACs, 
  (blocks): ModuleList(
    42.119 M, 99.944% Params, 7.91 GMac, 99.463% MACs, 
    (0): ResidualBlock(
      0.063 M, 0.151% Params, 0.201 GMac, 2.524% MACs, 
      (c1): ConvBlock(
        0.008 M, 0.020% Params, 0.026 GMac, 0.333% MACs, 
        (c): Conv2d(0.008 M, 0.019% Params, 0.026 GMac, 0.323% MACs, 64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(0.0 M, 0.001% Params, 0.001 GMac, 0.010% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (c2): ConvBlock(
        0.005 M, 0.012% Params, 0.015 GMac, 0.192% MACs, 
        (c): Conv2d(0.005 M, 0.011% Params, 0.014 GMac, 0.182% MACs, 128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (bn): BatchNorm2d(0.0 M, 0.001% Params, 0.001 GMac, 0.010% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (c3): ConvBlock(
        0.0

In [36]:
import numpy as np

In [51]:
# Scratch check: draw a single sample from the standard normal distribution.
# No seed is set in the visible cells, so the displayed value differs per run.
np.random.randn(1)

array([1.04394446])