In [1]:
from torchvision.datasets import CIFAR10
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

In [2]:
class CIFAR10Dataset(Dataset):
    def __init__(self, is_train: bool):
        normalization = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        dataset = CIFAR10(root="data", download=True, train=is_train, transform=normalization)
        dataset.data = dataset.data[:100]
        dataset.targets = dataset.targets[:100]

        dataloader = DataLoader(dataset)

        self.data = torch.cat([X for X, _ in dataloader])
        self.labels = torch.eye(10)[torch.cat([y for _, y in dataloader])]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        x = self.data[index]
        y = self.labels[index]
        return x, y


train_dataset = CIFAR10Dataset(True)

Files already downloaded and verified


In [15]:
class ResNetConvSizes:
    def __init__(self, resnet_layers: int, block_size: int, conv2: int, conv3: int, conv4: int, conv5: int = 0) -> None:
        if block_size != 2 and block_size != 3:
            raise ValueError(f"Possible block sizes are [2, 3]. Provided: {block_size}")
        if resnet_layers < 5:
            raise ValueError(f"Possible lowest layers number: 5. Provided: {resnet_layers}")
        if any(value < 1 for value in (conv2, conv3, conv4)) or conv5 < 0:
            raise ValueError("Wrong layers count")

        self.block_size = block_size
        self.conv2 = conv2
        self.conv3 = conv3
        self.conv4 = conv4
        self.conv5 = conv5

        size = self.layers_count()
        if size != resnet_layers:
            raise ValueError(f"Wrong summary ResNet size. Current: {size}, expected: {resnet_layers}")

    def layers_count(self) -> int:
        return ((self.conv2 + self.conv3 + self.conv4 + self.conv5) * self.block_size) + 2

In [1]:
from torch.nn import (
    Module,
    Sequential,
    Conv2d,
    ReLU,
    ModuleList,
    BatchNorm2d,
    ReLU,
    Linear,
    Flatten,
    AdaptiveAvgPool2d,
    Softmax,
)


class ResNetModule(Module):
    def __init__(self, conv_sizes: ResNetConvSizes):
        super(ResNetModule, self).__init__()

        self.conv_sizes = conv_sizes
        self.latest_channels = 16

        self.conv1 = Sequential(
            # bias is redundant when using batch normalization
            Conv2d(3, self.latest_channels, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm2d(self.latest_channels),
            ReLU()
            # no pooling there
        )

        self.conv2 = self.__create_blocks(conv_sizes.conv2)
        self.conv3 = self.__create_blocks(conv_sizes.conv3)
        self.conv4 = self.__create_blocks(conv_sizes.conv4)
        self.conv5 = self.__create_blocks(conv_sizes.conv5)
        self.output = Sequential(AdaptiveAvgPool2d((1, 1)), Flatten(), Linear(self.latest_channels, 10), Softmax(dim=1))

    def forward(self, x):
        pass

    def __create_blocks(self, conv_size: int) -> ModuleList:
        modules = ModuleList()
        if conv_size == 0:
            return modules

        create_block = self.__create_basic_block if self.conv_sizes.block_size == 2 else self.__create_bottleneck_block

        modules.append(create_block(self.latest_channels, True))
        self.latest_channels *= 2

        for _ in range(1, conv_size):
            modules.append(create_block(self.latest_channels))

        return modules

    # option A
    def __create_basic_block(self, in_channels: int, downsample_dimensions: bool = False) -> Sequential:
        if not self.__is_power_of_2(in_channels):
            raise ValueError("Input channels number is not power of 2")

        first_stride = 1
        out_channels = in_channels
        if downsample_dimensions:
            first_stride = 2
            out_channels = out_channels * 2

        return Sequential(
            Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=first_stride),
            BatchNorm2d(out_channels),
            ReLU(),
            Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
            BatchNorm2d(out_channels),
        )

    # option B
    def __create_bottleneck_block(self, in_channels: int, downsample_dimensions: bool = False) -> Sequential:
        first_stride = 1
        internal_channels = in_channels // 4
        out_channels = in_channels
        if downsample_dimensions:
            first_stride = 2
            internal_channels = internal_channels * 2
            out_channels = out_channels * 2

        if not all(self.__is_power_of_2(num) for num in [in_channels, internal_channels, out_channels]):
            raise ValueError("Channels number is not power of 2")

        return Sequential(
            Conv2d(in_channels, internal_channels, padding=1, kernel_size=1, stride=first_stride),
            BatchNorm2d(internal_channels),
            ReLU(),
            Conv2d(internal_channels, internal_channels, padding=1, kernel_size=3, stride=1),
            BatchNorm2d(internal_channels),
            ReLU(),
            Conv2d(internal_channels, out_channels, padding=1, kernel_size=1, stride=1),
            BatchNorm2d(out_channels),
        )

    def __is_power_of_2(self, n: int) -> bool:
        return (n & (n - 1) == 0) and n != 0
    

IndentationError: expected an indented block after 'if' statement on line 37 (891259096.py, line 40)

[ResNetConvSizes(conv2_conv5=2, conv3=2, conv4=8, block_size=2), ResNetConvSizes(conv2_conv5=2, conv3=4, conv4=6, block_size=2), ResNetConvSizes(conv2_conv5=3, conv3=4, conv4=4, block_size=2)]
