In [1]:
from typing import Optional

import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CIFAR10

In [2]:
class CIFAR10Dataset(Dataset):
    """In-memory CIFAR-10 split with normalized images and one-hot labels.

    The (optionally truncated) split is decoded once at construction time and
    kept as tensors, so ``__getitem__`` is a plain tensor lookup.

    Args:
        is_train: Load the training split when True, the test split otherwise.
        limit: Keep only the first ``limit`` samples (default 100);
            ``None`` keeps the whole split.
        root: Directory torchvision downloads to / reads the dataset from.

    Raises:
        ValueError: if ``limit`` is not None and smaller than 1.
    """

    # CIFAR-10 has 10 classes; width of the one-hot label vectors.
    NUM_CLASSES = 10

    def __init__(self, is_train: bool, limit: Optional[int] = 100, root: str = "data"):
        if limit is not None and limit < 1:
            raise ValueError(f"limit must be a positive int or None, got {limit}")

        # ToTensor yields float images in [0, 1]; Normalize with mean/std 0.5
        # then maps every channel to [-1, 1] via (x - 0.5) / 0.5.
        normalization = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        dataset = CIFAR10(root=root, download=True, train=is_train, transform=normalization)
        self.data, self.labels = self._materialize(dataset, limit, self.NUM_CLASSES)

    @staticmethod
    def _materialize(dataset, limit: Optional[int], num_classes: int):
        """Fetch the first ``limit`` samples of ``dataset`` once and stack them.

        Args:
            dataset: Indexable sequence of ``(image_tensor, int_label)`` pairs.
            limit: Number of leading samples to keep; ``None`` keeps all.
            num_classes: Width of the one-hot label vectors.

        Returns:
            ``(data, labels)``: ``data`` stacks the images along a new leading
            dimension; ``labels`` is a one-hot matrix of shape
            ``(n, num_classes)`` with the dtype of ``torch.eye``.

        Raises:
            ValueError: if no samples would be kept.
        """
        count = len(dataset) if limit is None else min(limit, len(dataset))
        if count == 0:
            raise ValueError("dataset is empty")
        # Each sample (and therefore each transform) is evaluated exactly once.
        samples = [dataset[i] for i in range(count)]
        data = torch.stack([image for image, _ in samples])
        targets = torch.tensor([label for _, label in samples], dtype=torch.long)
        labels = torch.eye(num_classes)[targets]
        return data, labels

    def __len__(self):
        """Number of stored samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(image, one_hot_label)`` pair at ``index``."""
        x = self.data[index]
        y = self.labels[index]
        return x, y


# Training split; CIFAR10Dataset keeps only its first 100 samples.
train_dataset = CIFAR10Dataset(is_train=True)

Files already downloaded and verified


In [15]:
from typing import NamedTuple


class ResNetConvSizes(NamedTuple):
    """Residual blocks per ResNet stage plus conv layers per block.

    The outer stages are symmetric: conv2 and conv5 always hold the same
    number of blocks.

    Fields:
        conv2_conv5: blocks in each of the conv2 and conv5 stages.
        conv3: blocks in the conv3 stage.
        conv4: blocks in the conv4 stage.
        block_size: conv layers per block (2 matches a basic block,
            3 a bottleneck block).
    """

    conv2_conv5: int
    conv3: int
    conv4: int
    # possible values: 2, 3
    block_size: int

    def layers_count(self) -> int:
        """Total layer count: every conv inside the blocks plus 2 extra layers."""
        stage_blocks = 2 * self.conv2_conv5 + self.conv3 + self.conv4
        return stage_blocks * self.block_size + 2

    def check_conv_sizes(self, resnet_layers: int) -> bool:
        """Verify that this layout has exactly ``resnet_layers`` layers.

        Returns:
            True when the count matches.

        Raises:
            AssertionError: if the count differs. Raised explicitly (not via
                ``assert``) so the check still runs under ``python -O``.
        """
        size = self.layers_count()
        if size != resnet_layers:
            raise AssertionError(f"Wrong ResNetConvSizes. Current: {size}, expected: {resnet_layers}")
        return True

    @staticmethod
    def find_resnet_conv_sizes(resnet_layers: int) -> list:
        """Enumerate stage layouts whose total depth equals ``resnet_layers``.

        Every candidate satisfies:
            * conv2_conv5 >= 2
            * conv2_conv5 <= conv3 <= conv4
            * conv3 is a power of 2
            * block_size (2 or 3) divides the block-layer budget

        Args:
            resnet_layers: Target depth; must be at least 18.

        Returns:
            List of ``ResNetConvSizes`` ordered by block_size, then
            conv2_conv5, then ascending conv3.

        Raises:
            AssertionError: if ``resnet_layers < 18``.
        """
        if resnet_layers < 18:
            raise AssertionError("Minimum is 18 layers")

        # 2 layers sit outside the residual blocks. NOTE(review): the original
        # note attributed them to conv1 and conv2; in the ResNet paper they are
        # conv1 and the final fc layer - confirm intent.
        sub_resnet_layers = resnet_layers - 2
        conv_sizes = []

        for base in (2, 3):
            if sub_resnet_layers % base != 0:
                continue

            # Blocks available across all four stages.
            total_blocks = sub_resnet_layers // base
            # conv2 + conv3 + conv4 + conv5 >= 4 * conv2_conv5, hence the bound.
            conv2_conv5_max = total_blocks // 4
            for conv2_conv5 in range(2, conv2_conv5_max + 1):
                # Blocks left over for conv3 + conv4.
                remain_blocks = total_blocks - 2 * conv2_conv5
                # Smallest power of 2 that is >= conv2_conv5.
                conv3 = conv2_conv5
                while ResNetConvSizes.__not_power_of_2(conv3):
                    conv3 += 1

                # conv4 = remain_blocks - conv3 must stay >= conv3.
                while 2 * conv3 <= remain_blocks:
                    conv4 = remain_blocks - conv3
                    conv_size = ResNetConvSizes(conv2_conv5, conv3, conv4, base)
                    conv_size.check_conv_sizes(resnet_layers)
                    conv_sizes.append(conv_size)
                    # Double to the next power of 2; multiplying by `base`
                    # would break the power-of-2 invariant when base is 3.
                    conv3 *= 2

        return conv_sizes

    @staticmethod
    def __not_power_of_2(n: int) -> bool:
        """Return True unless ``n`` is a positive power of 2 (1, 2, 4, ...)."""
        return not (n != 0 and (n & (n - 1)) == 0)

In [16]:
from torch.nn import Module, Sequential, Conv2d, ReLU, ModuleList


class ResNetModule(Module):
    """Work-in-progress ResNet built from a ``ResNetConvSizes`` layout.

    Only the residual-branch factories are implemented so far; the stem,
    the stages and ``forward`` are still TODO.
    """

    def __init__(self, conv_sizes: "ResNetConvSizes"):
        super().__init__()
        self.conv_sizes = conv_sizes
        # TODO: populate the stem. `ModuleList(Sequential())` produced exactly
        # this empty ModuleList, because ModuleList iterates the empty
        # Sequential's children.
        self.conv1 = ModuleList()

    def forward(self, x):
        """Not implemented yet; raises instead of silently returning None."""
        raise NotImplementedError("ResNetModule.forward is not implemented yet")

    # option A
    @staticmethod
    def __create_basic_block(in_channels: int) -> Sequential:
        """Residual branch of a basic block: conv3x3 -> ReLU -> conv3x3.

        ``padding=1`` keeps the spatial size (stride 1), so the branch output
        has the same shape as its input and can be added to the identity
        shortcut. The shortcut addition itself is not applied here.
        """
        return Sequential(
            Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            ReLU(),
            Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
        )

    # option B
    @staticmethod
    def __create_bottleneck_block(in_channels: int) -> Sequential:
        """Residual branch of a bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand.

        The inner width is ``in_channels // 4``; the output width equals
        ``in_channels`` and the spatial size is preserved, so the result can
        be added to the identity shortcut (not applied here).

        Raises:
            ValueError: if ``in_channels < 4`` (the inner width would be 0).
        """
        internal_channels = in_channels // 4
        if internal_channels < 1:
            raise ValueError(f"bottleneck block needs in_channels >= 4, got {in_channels}")
        return Sequential(
            Conv2d(in_channels, internal_channels, kernel_size=1),
            ReLU(),
            Conv2d(internal_channels, internal_channels, kernel_size=3, padding=1),
            ReLU(),
            Conv2d(internal_channels, in_channels, kernel_size=1),
        )

In [17]:
# Every stage layout whose depth adds up to 30 layers.
a = ResNetConvSizes.find_resnet_conv_sizes(resnet_layers=30)
print(a)

[ResNetConvSizes(conv2_conv5=2, conv3=2, conv4=8, block_size=2), ResNetConvSizes(conv2_conv5=2, conv3=4, conv4=6, block_size=2), ResNetConvSizes(conv2_conv5=3, conv3=4, conv4=4, block_size=2)]
