In [27]:
from torch import nn
from fft_conv import FFTConv2d

class ResNetBlock(nn.Module):
    """Residual block that returns act_fn(F(x) + x), where F is built from two FFTConv2d layers.

    F is FFTConv2d -> BatchNorm2d -> act_fn -> FFTConv2d -> BatchNorm2d. When subsampling, the
    first layer of F uses stride 2 and the skip connection is passed through a strided 1x1
    nn.Conv2d so that both branches have the same shape before they are added.
    """

    def __init__(self, c_in, act_fn, subsample=False, c_out=-1):
        """
        Inputs:
            c_in - Number of input features
            act_fn - Activation class constructor (e.g. nn.ReLU)
            subsample - If True, we want to apply a stride inside the block and reduce the output shape by 2 in height and width
            c_out - Number of output features. Note that this is only relevant if subsample is True, as otherwise, c_out = c_in
        Raises:
            ValueError - If subsample is True but c_out was left at its placeholder default (or is otherwise < 1)
        """
        super().__init__()
        if not subsample:
            c_out = c_in
        elif c_out < 1:
            # The default of -1 is only a placeholder; fail early with a clear message instead of
            # handing an invalid channel count to the layers below.
            raise ValueError(
                "c_out must be a positive number of channels when subsample=True, got {}".format(c_out)
            )

        stride = 2 if subsample else 1

        # Network representing F
        self.net = nn.Sequential(
            FFTConv2d(c_in, c_out, kernel_size=3, padding=1, stride=stride, bias=False),  # No bias needed as the Batch Norm handles it
            nn.BatchNorm2d(c_out),
            act_fn(),
            FFTConv2d(c_out, c_out, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(c_out)
        )

        # 1x1 convolution with stride 2 means we take the upper left value, and transform it to new output size
        self.downsample = nn.Conv2d(c_in, c_out, kernel_size=1, stride=2) if subsample else None
        self.act_fn = act_fn()

    def forward(self, x):
        """Apply F to x, add the (optionally downsampled) input, and apply the final activation."""
        z = self.net(x)
        if self.downsample is not None:
            # Bring the skip connection to the same channel count / resolution as z
            x = self.downsample(x)
        out = z + x
        out = self.act_fn(out)
        return out


In [28]:
from types import SimpleNamespace

# String -> class registries, so that a ResNet can be configured with plain names.
# Residual block implementations selectable through ResNet's "block_name" argument.
resnet_blocks_by_name = {"ResNetBlock": ResNetBlock}
# Activation constructors selectable through ResNet's "act_fn_name" argument.
act_fn_by_name = {"tanh": nn.Tanh, "relu": nn.ReLU, "leakyrelu": nn.LeakyReLU, "gelu": nn.GELU}
class ResNet(nn.Module):
    """ResNet classifier: input convolution -> groups of residual blocks -> pooled linear head."""

    # Activation name used in this notebook -> (nonlinearity, negative slope) as accepted by
    # nn.init.kaiming_normal_. torch has no gain entry for "leakyrelu" (it is spelled
    # "leaky_relu") or for "gelu", so passing the raw names raises a ValueError.
    # GELU falls back to the ReLU gain; 0.01 is the default negative slope of nn.LeakyReLU.
    _KAIMING_ARGS = {
        "tanh": ("tanh", 0),
        "relu": ("relu", 0),
        "leakyrelu": ("leaky_relu", 0.01),
        "gelu": ("relu", 0),
    }

    def __init__(self, num_classes=10, num_blocks=None, c_hidden=None, act_fn_name="relu", block_name="ResNetBlock", **kwargs):
        """
        Inputs:
            num_classes - Number of classification outputs (10 for CIFAR10)
            num_blocks - List with the number of ResNet blocks to use. The first block of each group uses downsampling, except the first.
                         Defaults to [3,3,3] if None.
            c_hidden - List with the hidden dimensionalities in the different blocks. Usually multiplied by 2 the deeper we go.
                       Defaults to [16,32,64] if None.
            act_fn_name - Name of the activation function to use, looked up in "act_fn_by_name"
            block_name - Name of the ResNet block, looked up in "resnet_blocks_by_name"
            kwargs - Accepted for call compatibility and ignored
        """
        super().__init__()
        assert block_name in resnet_blocks_by_name, "Unknown block_name \"{}\"".format(block_name)
        # None used to crash on the first index access; fall back to the configuration used in this notebook.
        if num_blocks is None:
            num_blocks = [3, 3, 3]
        if c_hidden is None:
            c_hidden = [16, 32, 64]
        self.hparams = SimpleNamespace(num_classes=num_classes,
                                       c_hidden=c_hidden,
                                       num_blocks=num_blocks,
                                       act_fn_name=act_fn_name,
                                       act_fn=act_fn_by_name[act_fn_name],
                                       block_class=resnet_blocks_by_name[block_name])
        self._create_network()
        self._init_params()

    def _create_network(self):
        """Build input_net, the sequence of residual blocks, and output_net from self.hparams."""
        c_hidden = self.hparams.c_hidden

        # Map the 3 input channels to the first hidden dimensionality
        self.input_net = nn.Sequential(
                nn.Conv2d(3, c_hidden[0], kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_hidden[0]),
                self.hparams.act_fn()
            )

        # Creating the ResNet blocks
        blocks = []
        for block_idx, block_count in enumerate(self.hparams.num_blocks):
            for bc in range(block_count):
                subsample = (bc == 0 and block_idx > 0) # Subsample the first block of each group, except the very first one.
                # A subsampling block takes the previous group's width as input and outputs the current one
                blocks.append(
                    self.hparams.block_class(c_in=c_hidden[block_idx if not subsample else (block_idx-1)],
                                             act_fn=self.hparams.act_fn,
                                             subsample=subsample,
                                             c_out=c_hidden[block_idx])
                )
        self.blocks = nn.Sequential(*blocks)

        # Mapping to classification output
        self.output_net = nn.Sequential(
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(c_hidden[-1], self.hparams.num_classes)
        )

    @staticmethod
    def _is_conv_like(module):
        """Return True for nn.Conv2d and for any other module that owns a 4-D "weight" parameter."""
        if isinstance(module, nn.Conv2d):
            return True
        # NOTE(review): the FFT conv layers print as "_FFTConv" and are presumably not nn.Conv2d
        # subclasses, so they are recognized by their weight shape instead - confirm that
        # _FFTConv exposes a 4-D "weight" parameter. BatchNorm (1-D) and Linear (2-D) weights never match.
        weight = getattr(module, "weight", None)
        return isinstance(weight, nn.Parameter) and weight.dim() == 4

    def _init_params(self):
        """Kaiming-initialize every convolution-like weight and reset BatchNorm to scale 1 / shift 0."""
        # Based on our discussion in Tutorial 4, we should initialize the convolutions according to the activation function
        # Fan-out focuses on the gradient distribution, and is commonly used in ResNets
        act_fn_name = self.hparams.act_fn_name
        # Unknown names are passed through unchanged, as before
        nonlinearity, negative_slope = self._KAIMING_ARGS.get(act_fn_name, (act_fn_name, 0))
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif self._is_conv_like(m):
                nn.init.kaiming_normal_(m.weight, a=negative_slope, mode='fan_out', nonlinearity=nonlinearity)

    def forward(self, x):
        """Run x through input_net, the residual blocks and output_net, in that order."""
        x = self.input_net(x)
        x = self.blocks(x)
        x = self.output_net(x)
        return x


In [29]:
import torch
from ptflops import get_model_complexity_info
import re

# Profile the model stored in the "fftconv2d-all" checkpoint.
first_model = ResNet(num_blocks=[3,3,3], c_hidden=[16,32,64], act_fn_name="relu")

# Forward slashes work on Windows and POSIX alike (the original "\\" separator only worked on Windows).
# map_location='cpu' lets a checkpoint saved on a GPU load on a machine without CUDA.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
first_checkpoint = torch.load('FResNet-models/fftconv2d-all.ckpt', map_location='cpu')

# strict=False tolerates key mismatches; report them so a checkpoint that does not match
# the architecture is not silently ignored.
first_incompatible = first_model.load_state_dict(first_checkpoint['state_dict'], strict=False)
if first_incompatible.missing_keys or first_incompatible.unexpected_keys:
    print('Warning: {} missing and {} unexpected keys when loading the checkpoint'.format(
        len(first_incompatible.missing_keys), len(first_incompatible.unexpected_keys)))

# NOTE(review): the recorded per-layer output lists the _FFTConv layers with 0 params and 0 MACs,
# so ptflops does not account for them and the totals below understate this model's cost.
macs, params = get_model_complexity_info(first_model, (3, 32, 32), as_strings=True,
                                         print_per_layer_stat=True, verbose=True)
# ptflops formats the total as "<value> <prefix>Mac", e.g. "1.29 MMac", "491.52 KMac" or "0.0 Mac"
macs_match = re.search(r'([\d.]+)\s*([A-Za-z]*)', macs)
# Extract the numerical value (float() instead of eval(): no code execution on a parsed string)
flops = float(macs_match.group(1))*2
# Extract the unit prefix; stripping "Mac" keeps a bare "Mac" from being reported as "MFlops"
macs_unit = macs_match.group(2)
flops_unit = macs_unit[:-len('Mac')] if macs_unit.endswith('Mac') else macs_unit[:1]

print('Computational complexity: {:<8}'.format(macs))
print('Computational complexity: {} {}Flops'.format(flops, flops_unit))
print('Number of parameters: {:<8}'.format(params))

ResNet(
  5.11 k, 1.878% Params, 1.29 MMac, 86.986% MACs, 
  (input_net): Sequential(
    464, 0.170% Params, 491.52 KMac, 33.226% MACs, 
    (0): Conv2d(432, 0.159% Params, 442.37 KMac, 29.904% MACs, 3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, 0.012% Params, 32.77 KMac, 2.215% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(0, 0.000% Params, 16.38 KMac, 1.108% MACs, )
  )
  (blocks): Sequential(
    4.0 k, 1.469% Params, 790.53 KMac, 53.439% MACs, 
    (0): ResNetBlock(
      64, 0.023% Params, 98.3 KMac, 6.645% MACs, 
      (net): Sequential(
        64, 0.023% Params, 81.92 KMac, 5.538% MACs, 
        (0): _FFTConv(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
        (1): BatchNorm2d(32, 0.012% Params, 32.77 KMac, 2.215% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(0, 0.000% Params, 16.38 KMac, 1.108% MACs, )
        (3): _FFTConv(0, 0.000% Params, 0

In [21]:
# Profile the model stored in the "fftconv2d-input" checkpoint.
# NOTE(review): the recorded output of this cell lists nn.Conv2d layers inside the residual blocks,
# while the ResNetBlock defined in this notebook builds FFTConv2d layers. The output was therefore
# produced with a different class definition (the execution counts are also out of order);
# re-run from a fresh kernel with the intended block definition before trusting the numbers.
second_model = ResNet(num_blocks=[3,3,3], c_hidden=[16,32,64], act_fn_name="relu")

# Forward slashes work on Windows and POSIX alike (the original "\\" separator only worked on Windows).
# map_location='cpu' lets a checkpoint saved on a GPU load on a machine without CUDA.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
second_checkpoint = torch.load('FResNet-models/fftconv2d-input.ckpt', map_location='cpu')

# strict=False tolerates key mismatches; report them so a checkpoint that does not match
# the architecture is not silently ignored.
second_incompatible = second_model.load_state_dict(second_checkpoint['state_dict'], strict=False)
if second_incompatible.missing_keys or second_incompatible.unexpected_keys:
    print('Warning: {} missing and {} unexpected keys when loading the checkpoint'.format(
        len(second_incompatible.missing_keys), len(second_incompatible.unexpected_keys)))

macs, params = get_model_complexity_info(second_model, (3, 32, 32), as_strings=True,
                                         print_per_layer_stat=True, verbose=True)
# ptflops formats the total as "<value> <prefix>Mac", e.g. "41.39 MMac" or "491.52 KMac"
macs_match = re.search(r'([\d.]+)\s*([A-Za-z]*)', macs)
# Extract the numerical value (float() instead of eval(): no code execution on a parsed string)
flops = float(macs_match.group(1))*2
# Extract the unit prefix; stripping "Mac" keeps a bare "Mac" from being reported as "MFlops"
macs_unit = macs_match.group(2)
flops_unit = macs_unit[:-len('Mac')] if macs_unit.endswith('Mac') else macs_unit[:1]

print('Computational complexity: {:<8}'.format(macs))
print('Computational complexity: {} {}Flops'.format(flops, flops_unit))
print('Number of parameters: {:<8}'.format(params))

ResNet(
  272.38 k, 100.000% Params, 41.39 MMac, 99.537% MACs, 
  (input_net): Sequential(
    464, 0.170% Params, 491.52 KMac, 1.182% MACs, 
    (0): Conv2d(432, 0.159% Params, 442.37 KMac, 1.064% MACs, 3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, 0.012% Params, 32.77 KMac, 0.079% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(0, 0.000% Params, 16.38 KMac, 0.039% MACs, )
  )
  (blocks): Sequential(
    271.26 k, 99.591% Params, 40.9 MMac, 98.344% MACs, 
    (0): ResNetBlock(
      4.67 k, 1.715% Params, 4.82 MMac, 11.583% MACs, 
      (net): Sequential(
        4.67 k, 1.715% Params, 4.8 MMac, 11.543% MACs, 
        (0): Conv2d(2.3 k, 0.846% Params, 2.36 MMac, 5.673% MACs, 16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, 0.012% Params, 32.77 KMac, 0.079% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(

In [26]:
# Profile the model stored in the "ResNet" checkpoint.
# NOTE(review): the recorded output of this cell lists nn.Conv2d layers inside the residual blocks,
# while the ResNetBlock defined in this notebook builds FFTConv2d layers. The output was therefore
# produced with a different class definition (the execution counts are also out of order);
# re-run from a fresh kernel with the intended block definition before trusting the numbers.
third_model = ResNet(num_blocks=[3,3,3], c_hidden=[16,32,64], act_fn_name="relu")

# Forward slashes work on Windows and POSIX alike (the original "\\" separator only worked on Windows).
# map_location='cpu' lets a checkpoint saved on a GPU load on a machine without CUDA.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
third_checkpoint = torch.load('FResNet-models/ResNet.ckpt', map_location='cpu')

# strict=False tolerates key mismatches; report them so a checkpoint that does not match
# the architecture is not silently ignored.
third_incompatible = third_model.load_state_dict(third_checkpoint['state_dict'], strict=False)
if third_incompatible.missing_keys or third_incompatible.unexpected_keys:
    print('Warning: {} missing and {} unexpected keys when loading the checkpoint'.format(
        len(third_incompatible.missing_keys), len(third_incompatible.unexpected_keys)))

macs, params = get_model_complexity_info(third_model, (3, 32, 32), as_strings=True,
                                         print_per_layer_stat=True, verbose=True)
# ptflops formats the total as "<value> <prefix>Mac", e.g. "41.39 MMac" or "491.52 KMac"
macs_match = re.search(r'([\d.]+)\s*([A-Za-z]*)', macs)
# Extract the numerical value (float() instead of eval(): no code execution on a parsed string)
flops = float(macs_match.group(1))*2
# Extract the unit prefix; stripping "Mac" keeps a bare "Mac" from being reported as "MFlops"
macs_unit = macs_match.group(2)
flops_unit = macs_unit[:-len('Mac')] if macs_unit.endswith('Mac') else macs_unit[:1]

print('Computational complexity: {:<8}'.format(macs))
print('Computational complexity: {} {}Flops'.format(flops, flops_unit))
print('Number of parameters: {:<8}'.format(params))


ResNet(
  272.38 k, 100.000% Params, 41.39 MMac, 99.537% MACs, 
  (input_net): Sequential(
    464, 0.170% Params, 491.52 KMac, 1.182% MACs, 
    (0): Conv2d(432, 0.159% Params, 442.37 KMac, 1.064% MACs, 3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, 0.012% Params, 32.77 KMac, 0.079% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(0, 0.000% Params, 16.38 KMac, 0.039% MACs, )
  )
  (blocks): Sequential(
    271.26 k, 99.591% Params, 40.9 MMac, 98.344% MACs, 
    (0): ResNetBlock(
      4.67 k, 1.715% Params, 4.82 MMac, 11.583% MACs, 
      (net): Sequential(
        4.67 k, 1.715% Params, 4.8 MMac, 11.543% MACs, 
        (0): Conv2d(2.3 k, 0.846% Params, 2.36 MMac, 5.673% MACs, 16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, 0.012% Params, 32.77 KMac, 0.079% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(