In [4]:
from enum import Enum
from sklearn.model_selection import GridSearchCV
from skorch import NeuralNetRegressor
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
from torchvision.datasets import CIFAR10
import torch
import torchvision.transforms as transforms
from torch.nn import (
    Module,
    Sequential,
    Conv2d,
    ReLU,
    ModuleList,
    BatchNorm2d,
    ReLU,
    Linear,
    Flatten,
    AdaptiveAvgPool2d,
    Softmax,
)

In [5]:
class CIFAR10Dataset(Dataset):
    """CIFAR-10 split pre-loaded into memory as tensors with one-hot float labels.

    Every image is passed through the transform pipeline once, at construction time. The results are stored in
    ``self.data``, and ``self.labels`` holds the matching rows of ``torch.eye(10)``.

    Args:
        is_train: load the training split (with random flip/crop) instead of the test split.
        limit: keep only the first ``limit`` images of the split; ``None`` keeps the whole split.
            The default of 2222 preserves the notebook's deliberately truncated dataset.

    NOTE(review): the random flip/crop run only once per image here, so each training image keeps a single,
    fixed augmented version for every epoch. This is not on-the-fly augmentation.
    """

    def __init__(self, is_train: bool, limit: int | None = 2222):
        all_transforms = [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
        if is_train:
            all_transforms = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4)] + all_transforms

        dataset = CIFAR10(root="data", download=True, train=is_train, transform=transforms.Compose(all_transforms))
        # TODO: unlimit data (pass limit=None); slicing with None keeps every element.
        dataset.data = dataset.data[:limit]
        dataset.targets = dataset.targets[:limit]

        # Iterate the loader exactly once, so every transform runs once per image and images/labels stay paired.
        # Batching only affects speed: without shuffling, the sample order is unchanged.
        images, targets = [], []
        for batch_images, batch_targets in DataLoader(dataset, batch_size=256, shuffle=False):
            images.append(batch_images)
            targets.append(batch_targets)

        self.data = torch.cat(images)
        # Class indices -> one-hot float rows (targets for the MSE-based regression used in this notebook).
        self.labels = torch.eye(10)[torch.cat(targets)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return the pre-transformed image tensor and its one-hot label at ``index``."""
        return self.data[index], self.labels[index]


# In-memory train/test splits (downloaded on first use) and per-sample loaders (default batch_size=1).
train_dataset = CIFAR10Dataset(is_train=True)
test_dataset = CIFAR10Dataset(is_train=False)

train_dataloader, test_dataloader = DataLoader(train_dataset), DataLoader(test_dataset)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Files already downloaded and verified
Files already downloaded and verified


In [6]:
class ResNetConvSizes:
    """Block layout of a ResNet variant: how many residual blocks each convolutional stage holds.

    The total depth is ``(conv2 + conv3 + conv4 + conv5) * block_size + 2``. The ``+ 2`` conventionally counts
    the stem convolution and the final linear layer. The constructor rejects any layout whose depth differs
    from ``resnet_layers``.

    Args:
        resnet_layers: expected total depth (at least 5).
        block_size: layers per residual block; 2 (basic block) or 3 (bottleneck block).
        conv2, conv3, conv4: blocks in each of the first three stages (each at least 1).
        conv5: blocks in the optional fourth stage (0 disables it).

    Raises:
        ValueError: for an unsupported block size, too few layers, invalid block counts, or a depth mismatch.
    """

    def __init__(self, resnet_layers: int, block_size: int, conv2: int, conv3: int, conv4: int, conv5: int) -> None:
        if block_size != 2 and block_size != 3:
            raise ValueError(f"Possible block sizes are [2, 3]. Provided: {block_size}")
        if resnet_layers < 5:
            raise ValueError(f"Possible lowest layers number: 5. Provided: {resnet_layers}")
        if any(value < 1 for value in (conv2, conv3, conv4)) or conv5 < 0:
            raise ValueError("Wrong layers count")

        self.resnet_layers = resnet_layers
        self.block_size = block_size
        self.conv2 = conv2
        self.conv3 = conv3
        self.conv4 = conv4
        self.conv5 = conv5

        size = self.layers_count()
        if size != resnet_layers:
            raise ValueError(f"Wrong summary ResNet size. Current: {size}, expected: {resnet_layers}")

    def layers_count(self) -> int:
        """Depth implied by the block counts: all residual layers plus the two non-residual ones."""
        return ((self.conv2 + self.conv3 + self.conv4 + self.conv5) * self.block_size) + 2

    def __str__(self):
        return f"ConvSize(resnet_layers={self.resnet_layers}, block_size={self.block_size}, conv2={self.conv2}, conv3={self.conv3}, conv4={self.conv4}, conv5={self.conv5})"

    def __repr__(self):
        # Containers (e.g. GridSearchCV's best_params_ dict) print their items with repr(), which
        # otherwise would show an opaque "<ResNetConvSizes object at 0x...>".
        return str(self)


class ShortcutTypeEnum(Enum):
    """How a residual stage's first block projects its input to the new (channels, height, width) shape."""

    Padding = 1  # parameter-free: spatial subsampling plus zero-padded channels
    Convolution = 2  # learned: strided 1x1 convolution followed by batch norm

In [7]:
class PaddingLayer(Module):
    """Parameter-free shortcut: keep every second row/column and zero-pad the channel dimension.

    The input's channels are padded with ``in_channels // 2`` zero channels on each side, so an even channel
    count doubles while the spatial size is halved.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        half = in_channels // 2
        # F.pad reads pairs from the last dimension backwards: (W_left, W_right, H_top, H_bottom, C_front, C_back).
        self.pad = (0,) * 4 + (half, half)

    def forward(self, x: Tensor) -> Tensor:
        subsampled = x[:, :, ::2, ::2]
        return torch.nn.functional.pad(subsampled, pad=self.pad, mode="constant", value=0.0)


class ResNetModule(Module):
    def __init__(
        self,
        conv_sizes: ResNetConvSizes,
        shortcut_type: ShortcutTypeEnum = ShortcutTypeEnum.Padding,
        batchnorm_momentum: float = 0.1,
    ):
        super(ResNetModule, self).__init__()

        if batchnorm_momentum <= 0 or batchnorm_momentum >= 1:
            raise ValueError(f"Momentum must be value between (0, 1). Provided: {batchnorm_momentum}")

        self.conv_sizes = conv_sizes
        self.momentum = batchnorm_momentum
        self.shortcut_type = shortcut_type

        self.initial_channels = self.latest_channels = 16

        self.conv1 = Sequential(
            # bias is redundant when using batch normalization
            Conv2d(3, self.latest_channels, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm2d(self.latest_channels, momentum=self.momentum),
            ReLU()
            # no pooling there
        ).apply(self.__init_weights)

        self.conv2 = self.__create_blocks(conv_sizes.conv2)
        self.conv3 = self.__create_blocks(conv_sizes.conv3)
        self.conv4 = self.__create_blocks(conv_sizes.conv4)
        self.conv5 = self.__create_blocks(conv_sizes.conv5)

        self.relu = ReLU()
        self.shortcuts = self.__create_shortcuts()

        self.output = Sequential(AdaptiveAvgPool2d((1, 1)), Flatten(), Linear(self.latest_channels, 10), Softmax(dim=1))

    @staticmethod
    def load_model(path: str):
        return torch.load(path)

    def save_model(self, path: str) -> None:
        torch.save(self, path)

    def forward(self, x: Tensor):
        all_convs = [self.conv2, self.conv3, self.conv4]
        if self.conv_sizes.conv5 > 0:
            all_convs.append(self.conv5)

        x = self.conv1(x)

        for index, conv_blocks in enumerate(all_convs):
            previous_x = self.shortcuts[index](x)
            for block in conv_blocks:
                x = self.relu(block(x) + previous_x)
                previous_x = x.clone()

        return self.output(x)

    def __create_blocks(self, conv_size: int) -> ModuleList:
        modules = ModuleList()
        if conv_size == 0:
            return modules

        create_block = self.__create_basic_block if self.conv_sizes.block_size == 2 else self.__create_bottleneck_block

        modules.append(create_block(self.latest_channels, True))
        self.latest_channels *= 2

        for _ in range(1, conv_size):
            modules.append(create_block(self.latest_channels))

        return modules.apply(self.__init_weights)

    def __init_weights(self, module):
        if isinstance(module, (Conv2d, Linear)):
            torch.nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")

    def __create_basic_block(self, in_channels: int, downsample_dimensions: bool = False) -> Sequential:
        if not self.__is_power_of_2(in_channels):
            raise ValueError("Input channels number is not power of 2")

        first_stride = 1
        out_channels = in_channels

        if downsample_dimensions:
            first_stride *= 2
            out_channels *= 2

        return Sequential(
            Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=first_stride, bias=False),
            BatchNorm2d(out_channels, momentum=self.momentum),
            ReLU(),
            Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            BatchNorm2d(out_channels, momentum=self.momentum),
        )

    def __create_bottleneck_block(self, in_channels: int, downsample_dimensions: bool = False) -> Sequential:
        first_stride = 1
        internal_channels = in_channels // 4
        out_channels = in_channels

        if downsample_dimensions:
            first_stride *= 2
            internal_channels *= 2
            out_channels *= 2

        if not all(self.__is_power_of_2(num) for num in [in_channels, internal_channels, out_channels]):
            raise ValueError("Channels number is not power of 2")

        return Sequential(
            Conv2d(in_channels, internal_channels, padding=0, kernel_size=1, stride=first_stride, bias=False),
            BatchNorm2d(internal_channels, momentum=self.momentum),
            ReLU(),
            Conv2d(internal_channels, internal_channels, padding=1, kernel_size=3, stride=1, bias=False),
            BatchNorm2d(internal_channels, momentum=self.momentum),
            ReLU(),
            Conv2d(internal_channels, out_channels, padding=0, kernel_size=1, stride=1, bias=False),
            BatchNorm2d(out_channels, momentum=self.momentum),
        )

    def __create_shortcuts(self) -> ModuleList:
        in_channels = self.initial_channels
        shortcuts = ModuleList()

        iters = 4 if self.conv_sizes.conv5 > 0 else 3
        for _ in range(iters):
            match self.shortcut_type:
                case ShortcutTypeEnum.Convolution:
                    out_channels = in_channels * 2
                    seq = Sequential(
                        Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False),
                        BatchNorm2d(out_channels, momentum=self.momentum),
                    ).apply(self.__init_weights)
                    shortcuts = shortcuts.append(seq)

                case ShortcutTypeEnum.Padding:
                    shortcuts = shortcuts.append(Sequential(PaddingLayer(in_channels)))

                case _:
                    raise ValueError("Not supported shortcut type")

            in_channels *= 2

        return shortcuts

    def __is_power_of_2(self, n: int) -> bool:
        return (n & (n - 1) == 0) and n != 0

In [8]:
def _grid_search_resnet(params: dict, samples: int) -> GridSearchCV:
    """Run a 3-fold grid search of ``ResNetModule`` settings on the first ``samples`` training images.

    Shared by ``find_optimal_parameters`` and ``sort_best_conv_sizes_descending``. ``refit=False`` because
    both callers only read the cross-validation scores.
    """
    net = NeuralNetRegressor(
        module=ResNetModule, optimizer=torch.optim.SGD, criterion=torch.nn.MSELoss(), device=device, verbose=1
    )
    grid_search = GridSearchCV(net, params, refit=False, scoring="neg_mean_squared_error", cv=3, n_jobs=-1, verbose=1)
    grid_search.fit(train_dataset.data[:samples], train_dataset.labels[:samples])
    return grid_search


def find_optimal_parameters(conv_sizes: list[ResNetConvSizes], samples: int = 400) -> dict:
    """Grid-search optimiser and architecture settings for the given layouts, then print and return the best.

    Args:
        conv_sizes: candidate layouts; each is one value of ``module__conv_sizes`` in the grid.
        samples: number of leading training images used for the cross-validated search.

    Returns:
        ``best_params_`` of the search: the combination with the lowest mean cross-validated MSE.
    """
    params = {
        "max_epochs": [2, 3],
        "batch_size": [128],
        "optimizer__momentum": [0.9],
        "optimizer__lr": [0.1, 0.01],
        "optimizer__weight_decay": [10**-5],
        "module__conv_sizes": conv_sizes,
        "module__shortcut_type": [ShortcutTypeEnum.Convolution, ShortcutTypeEnum.Padding],
        "module__batchnorm_momentum": [0.1],
    }
    grid_search = _grid_search_resnet(params, samples)

    best_params = grid_search.best_params_
    # The scorer is the *negated* MSE (higher is better), so flip the sign back for reporting.
    best_mse = -grid_search.best_score_
    print("Best Mean Squared Error:", best_mse)
    print("Best Hyperparameters:", best_params)
    print("Best conv_sizes:", best_params["module__conv_sizes"])

    return best_params


def sort_best_conv_sizes_descending(conv_sizes: list[ResNetConvSizes], samples: int = 200) -> list[ResNetConvSizes]:
    """Rank layouts with a short, fixed-hyper-parameter grid search, best (lowest CV MSE) first.

    "Descending" refers to the score: the layout with ``rank_test_score == 1`` comes first. Layouts with
    equal rank keep the grid order, because the sort is stable.

    Args:
        conv_sizes: candidate layouts to rank.
        samples: number of leading training images used for the screening search.
    """
    params = {
        "max_epochs": [2],
        "batch_size": [128],
        "optimizer__momentum": [0.9],
        "optimizer__lr": [0.1],
        "optimizer__weight_decay": [10**-5],
        "module__conv_sizes": conv_sizes,
        "module__shortcut_type": [ShortcutTypeEnum.Padding],
        "module__batchnorm_momentum": [0.1],
    }
    results = _grid_search_resnet(params, samples).cv_results_

    candidates = [combo["module__conv_sizes"] for combo in results["params"]]
    # Sort on the rank alone. ResNetConvSizes defines no ordering, so sorting the (rank, layout) tuples
    # directly would raise TypeError as soon as two layouts tie.
    ranked = sorted(zip(results["rank_test_score"], candidates), key=lambda pair: pair[0])
    return [layout for _, layout in ranked]

In [24]:
# Two ResNet-18 layouts with basic blocks: one without a conv5 stage, one spread evenly over four stages.
resnet18_conv_sizes = [
    ResNetConvSizes(resnet_layers=18, block_size=2, conv2=2, conv3=4, conv4=2, conv5=0),
    ResNetConvSizes(resnet_layers=18, block_size=2, conv2=2, conv3=2, conv4=2, conv5=2),
]
resnet18_best_params = find_optimal_parameters(resnet18_conv_sizes)

resnet18_conv = resnet18_best_params["module__conv_sizes"]
print(resnet18_best_params)
print(resnet18_conv)

ResNetModule(
  (conv1): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv2): ModuleList(
    (0): Sequential(
      (0): Conv2d(16, 8, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm2d(8, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(8, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (7): BatchNorm2d(32, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    )
    (1-2): 2 x Sequential(
      (0): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(8, eps=1e-05, momentum=0.9, affine=True, track_running_sta

In [10]:
# Candidate 56-layer layouts (two with basic blocks, two with bottleneck blocks).
# ResNetConvSizes raises if a layout does not add up to 56 layers.
resnet56_candidate_conv_sizes = [
    ResNetConvSizes(56, 2, 3, 10, 11, 3),
    ResNetConvSizes(56, 2, 6, 15, 6, 0),
    ResNetConvSizes(56, 3, 3, 6, 6, 3),
    ResNetConvSizes(56, 3, 3, 12, 3, 0),
]

# A smaller screening run (fewer samples and epochs) ranks the layouts best-first.
# Only the top two then get the full hyper-parameter search.
resnet56_conv_sizes = sort_best_conv_sizes_descending(resnet56_candidate_conv_sizes)
for ranked_layout in resnet56_conv_sizes:
    print(ranked_layout)
resnet56_best_params = find_optimal_parameters(resnet56_conv_sizes[:2])

resnet56_conv = resnet56_best_params["module__conv_sizes"]
print(resnet56_best_params)
print(resnet56_conv)

KeyboardInterrupt: 

In [None]:
def check_accuracy(model: ResNetModule, dataloader: DataLoader | None = None) -> float:
    """Print and return the top-1 accuracy (in percent) of ``model`` on a labelled dataloader.

    Args:
        model: network whose highest-scoring output is taken as the predicted class.
        dataloader: yields ``(images, labels)`` batches. Labels may be one-hot rows (as ``CIFAR10Dataset``
            produces) or class indices. Defaults to the notebook's ``test_dataloader``.

    Returns:
        Accuracy as a percentage in ``[0, 100]``; ``0.0`` for an empty loader.
    """
    if dataloader is None:
        dataloader = test_dataloader

    was_training = model.training
    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():
        for data, labels in dataloader:
            data = data.to(device=device)
            labels = labels.to(device=device)
            # One-hot targets -> class indices, so predictions and labels are compared per sample,
            # not per probability entry.
            if labels.dim() > 1:
                labels = labels.argmax(dim=1)

            predicted = model(data).argmax(dim=1)
            correct += (predicted == labels).sum().item()
            # Count samples, not batches: len(dataloader) is the number of batches.
            total += labels.size(0)

    # Restore the caller's mode so a model that is still being trained is not left in eval mode.
    model.train(was_training)

    accuracy = 100.0 * correct / total if total else 0.0
    print(f"Model: {model}")
    print(f"Accuracy of the model: {accuracy}%")
    return accuracy

In [None]:
# ResNet-18 (basic blocks, no conv5 stage) with convolutional shortcuts, followed by a per-layer summary.
resnet18 = ResNetModule(ResNetConvSizes(18, 2, 2, 4, 2, 0), ShortcutTypeEnum.Convolution, 0.85).to(device)
# torchsummary.summary defaults to device="cuda"; pass the actual device type so the cell also runs on CPU.
summary(resnet18, (3, 32, 32), device=device.type)

In [None]:
# Build ResNet-56 from the best layout found by the ResNet-56 grid search above.
# NOTE(review): shortcut type and momentum are still hard-coded rather than taken from resnet56_best_params.
resnet56 = ResNetModule(resnet56_best_params["module__conv_sizes"], ShortcutTypeEnum.Convolution, 0.85).to(device)

# Smoke test: one forward pass on a few training images (call the module, not .forward, so hooks run).
with torch.no_grad():
    resnet56(train_dataset.data[:10].to(device))

# torchsummary.summary defaults to device="cuda"; pass the actual device type so the cell also runs on CPU.
summary(resnet56, (3, 32, 32), device=device.type)