# Install packages

In [None]:
!pip install torchsummary commentjson torcheval python-box[all]~=7.0

# Training configs

In [None]:
from box import Box

# Experiment configuration, wrapped in a Box so sections can be read with
# attribute access (e.g. ``options.CHECKPOINT.LOAD``). Under each "ARGS" map the
# string keys "0", "1", ... appear to index the sibling "NAME_LIST" by position.
options = Box({
    "EXPERIMENTS": {"PROJECT_NAME": "face_attribute"},

    "DATA": {
        "DATASET_NAME": "celeb_A",
        "INPUT_SHAPE": [3, 224, 224],
        "TRAIN_SIZE": 0.9,
        "BATCH_SIZE": 360,
        "NUM_WORKERS": 4,
        "TRANSFORM": {
            "NAME_LIST": ["PILToTensor", "ToDtype", "RandomRotation"],
            "ARGS": {
                "0": {},
                # scale=True normalizes pixel values from [0, 255] into [0, 1]
                "1": {"dtype": "float32", "scale": True},
                "2": {"degrees": [-10, 10], "interpolation": "NEAREST"},
            },
        },
    },

    "CHECKPOINT": {
        "SAVE": True,
        "LOAD": True,
        "SAVE_ALL": False,
        "RESUME_NAME": "epoch_30.pt",
    },

    "EPOCH": {"START": 1, "EPOCHS": 1},

    "METRICS": {
        "NAME_LIST": ["BinaryAccuracy", "BinaryF1Score"],
        "ARGS": {
            "0": {"threshold": 0.5},
            "1": {"threshold": 0.5},
        },
    },

    "SOLVER": {
        "MODEL": {
            "BASE": "vgg",
            "NAME": "vgg13",
            "PRETRAINED": False,
            "ARGS": {"num_classes": 1},
        },
        "OPTIMIZER": {
            "NAME": "Adam",
            "ARGS": {"lr": 1e-7, "amsgrad": True},
        },
        "LR_SCHEDULER": {
            "NAME": "CosineAnnealingWarmRestarts",
            "ARGS": {"T_0": 100, "T_mult": 3},
        },
        "LOSS": {
            "NAME": "BCELoss",
            "ARGS": {"reduction": "mean"},
        },
        "EARLY_STOPPING": {"PATIENCE": 5, "MIN_DELTA": 0},
    },

    "MISC": {
        "SEED": 12345,
        "APPLY_EARLY_STOPPING": True,
        "CUDA": True,
    },
})

# Copy existing checkpoints to resume training

In [None]:
# Copy previously saved checkpoints from the Kaggle input directory into the
# working directory so training can resume from them (IPython `!` shell escape).
if options.CHECKPOINT.LOAD:
    !cp -r /kaggle/input/checkpoints /kaggle/working

# VGG base

In [None]:
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

# Names exported by this VGG cell. Only affects ``from <module> import *``;
# the ResNet cell further down rebinds ``__all__``.
__all__ = [
    "VGG",
    "vgg11", "vgg11_bn",
    "vgg13", "vgg13_bn",
    "vgg16", "vgg16_bn",
    "vgg19_bn", "vgg19",
    "get_vgg_model",
]


# Download locations of the ImageNet checkpoints, keyed by architecture name.
model_urls = {
    # plain variants
    "vgg11": "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
    "vgg13": "https://download.pytorch.org/models/vgg13-c768596a.pth",
    "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
    "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
    # batch-normalized variants
    "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
    "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
    "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
    "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
}

# Convolutional layouts consumed by make_layers: an int is a 3x3 conv with that
# many output channels, "M" is a 2x2 max-pool.
cfg = {
    # 8 convs -> vgg11 / vgg11_bn
    "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    # 10 convs -> vgg13 / vgg13_bn
    "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    # 13 convs -> vgg16 / vgg16_bn
    "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
    # 16 convs -> vgg19 / vgg19_bn
    "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}


class VGG(nn.Module):
    """VGG-style classifier: a convolutional ``features`` trunk plus a deep MLP head.

    Args:
        features: module applied first; its output is flattened per sample and
            must yield ``512 * 7 * 7`` values to match the first Linear layer
            (e.g. a 512x7x7 map — confirm for your input size).
        dropout: dropout probability after each hidden Linear + ReLU.
        num_classes: number of output units.
        init_weights: if True, run ``_initialize_weights`` after construction.

    The head ends in a probability activation: ``Sigmoid`` when
    ``num_classes == 1`` (single-unit binary output) and ``Softmax(dim=1)``
    otherwise. Neither activation has parameters, so the state dict is the
    same whichever one is used.
    """

    def __init__(self, features, dropout=.5, num_classes=1000, init_weights=True):
        super().__init__()
        self.features = features
        # Softmax over a single unit is identically 1.0, so a loss such as
        # BCELoss would get no gradient; a one-unit head needs an element-wise
        # sigmoid instead.
        if num_classes == 1:
            output_activation = nn.Sigmoid()
        else:
            output_activation = nn.Softmax(dim=1)
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(dropout),

            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(dropout),

            nn.Linear(4096, 2048),
            nn.ReLU(True),
            nn.Dropout(dropout),

            nn.Linear(2048, 1024),
            nn.ReLU(True),
            nn.Dropout(dropout),

            nn.Linear(1024, 512),
            nn.ReLU(True),
            nn.Dropout(dropout),

            nn.Linear(512, num_classes),
            output_activation,
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Run ``features``, flatten each sample, and apply the classifier head.

        Returns a ``(batch, num_classes)`` tensor of probabilities.
        """
        x = self.features(x)
        # Unlike ``view``, ``flatten`` also accepts non-contiguous feature maps.
        x = torch.flatten(x, 1)
        return self.classifier(x)

    def _initialize_weights(self):
        """Initialize parameters in place.

        Conv2d: Kaiming-normal weights (fan_out, ReLU gain), zero bias.
        BatchNorm2d: weight 1, bias 0.
        Linear: weights drawn from N(0, 0.01), zero bias.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def make_layers(cfg, batch_norm=False):
    """Build the VGG convolutional trunk described by ``cfg``.

    Each integer entry adds a 3x3 convolution (padding 1) with that many output
    channels, optionally followed by ``BatchNorm2d``, then an in-place ReLU.
    Each ``'M'`` entry adds a 2x2 max-pool with stride 2. The first convolution
    expects 3 input channels.

    Returns:
        nn.Sequential holding the layers in ``cfg`` order.
    """
    modules = []
    channels = 3
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.append(nn.Conv2d(channels, spec, kernel_size=3, padding=1))
        if batch_norm:
            modules.append(nn.BatchNorm2d(spec))
        modules.append(nn.ReLU(inplace=True))
        channels = spec
    return nn.Sequential(*modules)




def _vgg(arch, cfg_key, batch_norm, pretrained, **kwargs):
    """Build a ``VGG`` from ``cfg[cfg_key]`` and optionally load ImageNet weights.

    Shared implementation behind the ``vgg*`` factory functions below.

    Args:
        arch: key into ``model_urls`` naming the checkpoint to download.
        cfg_key: key into ``cfg`` selecting the convolutional layout.
        batch_norm: insert ``BatchNorm2d`` after every convolution.
        pretrained: if True, copy every checkpoint tensor whose name and shape
            match a parameter/buffer of this model; everything else keeps the
            initialization done by ``VGG``.
        **kwargs: forwarded to ``VGG`` (e.g. ``num_classes``, ``dropout``).
    """
    model = VGG(make_layers(cfg[cfg_key], batch_norm=batch_norm), **kwargs)
    if pretrained:
        # This file's VGG head (six Linear layers, configurable num_classes)
        # presumably differs from the head stored in the ImageNet checkpoints,
        # so a strict load_state_dict failed with missing keys / size
        # mismatches. Keep only the tensors that fit.
        checkpoint = model_zoo.load_url(model_urls[arch])
        own_state = model.state_dict()
        compatible = {
            key: tensor
            for key, tensor in checkpoint.items()
            if key in own_state and own_state[key].shape == tensor.shape
        }
        model.load_state_dict(compatible, strict=False)
    return model


def vgg11(pretrained=False, **kwargs):
    """VGG 11-layer model (configuration "A")

    Args:
        pretrained (bool): If True, load the ImageNet weights that fit this model
    """
    return _vgg('vgg11', 'A', False, pretrained, **kwargs)


def vgg11_bn(pretrained=False, **kwargs):
    """VGG 11-layer model (configuration "A") with batch normalization

    Args:
        pretrained (bool): If True, load the ImageNet weights that fit this model
    """
    return _vgg('vgg11_bn', 'A', True, pretrained, **kwargs)


def vgg13(pretrained=False, **kwargs):
    """VGG 13-layer model (configuration "B")

    Args:
        pretrained (bool): If True, load the ImageNet weights that fit this model
    """
    return _vgg('vgg13', 'B', False, pretrained, **kwargs)


def vgg13_bn(pretrained=False, **kwargs):
    """VGG 13-layer model (configuration "B") with batch normalization

    Args:
        pretrained (bool): If True, load the ImageNet weights that fit this model
    """
    return _vgg('vgg13_bn', 'B', True, pretrained, **kwargs)


def vgg16(pretrained=False, **kwargs):
    """VGG 16-layer model (configuration "D")

    Args:
        pretrained (bool): If True, load the ImageNet weights that fit this model
    """
    return _vgg('vgg16', 'D', False, pretrained, **kwargs)


def vgg16_bn(pretrained=False, **kwargs):
    """VGG 16-layer model (configuration "D") with batch normalization

    Args:
        pretrained (bool): If True, load the ImageNet weights that fit this model
    """
    return _vgg('vgg16_bn', 'D', True, pretrained, **kwargs)


def vgg19(pretrained=False, **kwargs):
    """VGG 19-layer model (configuration "E")

    Args:
        pretrained (bool): If True, load the ImageNet weights that fit this model
    """
    return _vgg('vgg19', 'E', False, pretrained, **kwargs)


def vgg19_bn(pretrained=False, **kwargs):
    """VGG 19-layer model (configuration "E") with batch normalization

    Args:
        pretrained (bool): If True, load the ImageNet weights that fit this model
    """
    return _vgg('vgg19_bn', 'E', True, pretrained, **kwargs)


def get_vgg_model(cuda: bool, name: str, pretrained: bool = True, model_state_dict: dict = None, **kwargs):
    """Build a VGG variant by name, optionally restore weights, and place it on a device.

    Args:
        cuda: move the model to ``"cuda"`` when True.
        name: one of ``vgg11``, ``vgg13``, ``vgg16``, ``vgg19`` or their ``_bn``
            variants.
        pretrained: forwarded to the factory. When True (and no
            ``model_state_dict`` is given) the factory's weights are kept
            rather than being Xavier re-initialized.
        model_state_dict: state dict to load into the model (e.g. from a saved
            checkpoint); an empty/None value is treated as "not provided".
        **kwargs: forwarded to the factory and on to ``VGG`` (e.g. ``num_classes``).

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``name`` is not a known variant.
    """
    models = {
        "vgg11": vgg11,
        "vgg13": vgg13,
        "vgg16": vgg16,
        "vgg19": vgg19,
        "vgg11_bn": vgg11_bn,
        "vgg13_bn": vgg13_bn,
        "vgg16_bn": vgg16_bn,
        "vgg19_bn": vgg19_bn,
    }
    # Explicit check: an ``assert`` would be stripped under ``python -O``.
    if name not in models:
        raise ValueError("Your selected vgg model derivative is unavailable")

    model = models[name](pretrained=pretrained, **kwargs)

    if model_state_dict:
        print("Loading pretrained model...")
        model.load_state_dict(model_state_dict)
        print("Finished.")
    elif not pretrained:
        # Re-initializing a pretrained model would discard the weights the
        # factory just loaded, so only fresh models get Xavier init.
        print("Initializing parameters...")
        for para in model.parameters():
            if para.dim() > 1:
                nn.init.xavier_uniform_(para)
        print("Finished.")

    if cuda:
        model = model.to("cuda")
    return model


# ResNet base

In [None]:
from functools import partial
from typing import Any, Callable, List, Optional, Type, Union

import torch
import torch.nn as nn
from torch import Tensor

from torchvision.transforms._presets import ImageClassification

from torchvision.utils import _log_api_usage_once
from torchvision.models._api import register_model, Weights, WeightsEnum
from torchvision.models._utils import _ovewrite_named_param, handle_legacy_interface
from torchvision.models._meta import _IMAGENET_CATEGORIES


# Public API of the ResNet cell; rebinds the ``__all__`` set by the VGG cell.
__all__ = [
    # model class
    "ResNet",
    # pretrained-weight enums
    "ResNet18_Weights", "ResNet34_Weights", "ResNet50_Weights",
    "ResNet101_Weights", "ResNet152_Weights",
    "ResNeXt50_32X4D_Weights", "ResNeXt101_32X8D_Weights", "ResNeXt101_64X4D_Weights",
    "Wide_ResNet50_2_Weights", "Wide_ResNet101_2_Weights",
    # builder function names
    "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    "resnext50_32x4d", "resnext101_32x8d", "resnext101_64x4d",
    "wide_resnet50_2", "wide_resnet101_2",
]


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """Return a bias-free 3x3 convolution whose padding equals its dilation.

    Because padding == dilation, a stride-1 instance preserves spatial size.
    """
    settings = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **settings)


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Return a bias-free pointwise (1x1) convolution with the given stride."""
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer


class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions with an additive shortcut.

    Layout: ``conv1`` (applies ``stride``) -> ``bn1`` -> ReLU -> ``conv2`` ->
    ``bn2``; the input (passed through ``downsample`` when one is given) is then
    added and a final ReLU applied. ``expansion`` is 1, so the block outputs
    ``planes`` channels.

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # When stride != 1, conv1 and (if supplied) downsample are the layers
        # that reduce the spatial resolution.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Return ``relu(residual_branch(x) + shortcut(x))``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(branch + shortcut)


class Bottleneck(nn.Module):
    """Residual block of 1x1 -> 3x3 -> 1x1 convolutions with an additive shortcut.

    The stride is applied by the middle 3x3 convolution (``conv2``) rather than
    by the first 1x1 convolution as in the original ResNet paper
    (https://arxiv.org/abs/1512.03385). This variant is known as ResNet V1.5
    and, per https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch,
    improves accuracy.

    The inner width is ``int(planes * base_width / 64) * groups``; the block
    outputs ``planes * expansion`` (= ``4 * planes``) channels.
    """

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        inner = int(planes * (base_width / 64.0)) * groups
        outer = planes * self.expansion

        # When stride != 1, conv2 and (if supplied) downsample are the layers
        # that reduce the spatial resolution.
        self.conv1 = conv1x1(inplanes, inner)
        self.bn1 = norm_layer(inner)
        self.conv2 = conv3x3(inner, inner, stride, groups, dilation)
        self.bn2 = norm_layer(inner)
        self.conv3 = conv1x1(inner, outer)
        self.bn3 = norm_layer(outer)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Return ``relu(bottleneck_branch(x) + shortcut(x))``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(branch + shortcut)


class ResNet(nn.Module):
    """Residual network: 7x7 stem, max-pool, four residual stages, average pool, linear head.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``); its
            ``expansion`` scales the channel count of every stage.
        layers: number of blocks in each of the four stages (entries 0-3 are read).
        num_classes: output size of the final ``fc`` layer. No activation is
            applied, so ``forward`` returns raw scores.
        zero_init_residual: zero the last norm weight in every block so each
            residual branch starts out as zero.
        groups: passed to every block.
        width_per_group: passed to every block as ``base_width``.
        replace_stride_with_dilation: three flags for stages 2-4; a True flag
            builds that stage with stride 1 and multiplies the dilation instead.
        norm_layer: normalization constructor; ``nn.BatchNorm2d`` when None.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have 3 entries.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Running channel count / dilation; _make_layer reads and updates these,
        # so the four stages below must be built in order.
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Kaiming-normal conv weights; norm layers start as identity (weight 1, bias 0).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        """Build one stage of ``blocks`` residual blocks with ``planes`` base channels.

        Only the first block receives ``stride`` and the previous dilation. It
        also gets a 1x1 conv + norm shortcut when the stride or channel count
        changes. Side effects: updates ``self.inplanes`` and, when ``dilate`` is
        True, multiplies ``self.dilation`` by ``stride`` (the stride itself then
        becomes 1).
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Stem -> four stages -> average pool -> flatten -> ``fc``; returns raw scores."""
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        """Delegate to ``_forward_impl``."""
        return self._forward_impl(x)


def _resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> ResNet:
    """Construct a ``ResNet`` and, when ``weights`` is given, load its state dict.

    With weights, ``num_classes`` is first passed through torchvision's
    ``_ovewrite_named_param`` with the number of categories in
    ``weights.meta["categories"]`` (presumably setting it, or rejecting a
    conflicting user value — verify against torchvision). The state dict is
    fetched with hash checking; ``progress`` controls the download progress bar.
    """
    if weights is None:
        return ResNet(block, layers, **kwargs)

    category_count = len(weights.meta["categories"])
    _ovewrite_named_param(kwargs, "num_classes", category_count)
    network = ResNet(block, layers, **kwargs)
    network.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
    return network


# Metadata merged into every ResNet weights entry below.
_COMMON_META = dict(
    min_size=(1, 1),
    categories=_IMAGENET_CATEGORIES,
)


class ResNet18_Weights(WeightsEnum):
    """ImageNet-1K weights for ResNet-18; ``DEFAULT`` is set to ``IMAGENET1K_V1``."""
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 11689512,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.758,
                    "acc@5": 89.078,
                }
            },
            "_ops": 1.814,
            "_file_size": 44.661,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class ResNet34_Weights(WeightsEnum):
    """ImageNet-1K weights for ResNet-34; ``DEFAULT`` is set to ``IMAGENET1K_V1``."""
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet34-b627a593.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 21797672,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 73.314,
                    "acc@5": 91.420,
                }
            },
            "_ops": 3.664,
            "_file_size": 83.275,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class ResNet50_Weights(WeightsEnum):
    """ImageNet-1K weights for ResNet-50; ``DEFAULT`` is set to ``IMAGENET1K_V2``."""
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 25557032,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 76.130,
                    "acc@5": 92.862,
                }
            },
            "_ops": 4.089,
            "_file_size": 97.781,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # V2: newer training recipe; its transform resizes to 232 before the 224 crop.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 25557032,
            "recipe": "https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.858,
                    "acc@5": 95.434,
                }
            },
            "_ops": 4.089,
            "_file_size": 97.79,
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class ResNet101_Weights(WeightsEnum):
    """ImageNet-1K weights for ResNet-101; ``DEFAULT`` is set to ``IMAGENET1K_V2``."""
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet101-63fe2227.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 44549160,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.374,
                    "acc@5": 93.546,
                }
            },
            "_ops": 7.801,
            "_file_size": 170.511,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # V2: newer training recipe; its transform resizes to 232 before the 224 crop.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnet101-cd907fc2.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 44549160,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.886,
                    "acc@5": 95.780,
                }
            },
            "_ops": 7.801,
            "_file_size": 170.53,
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class ResNet152_Weights(WeightsEnum):
    """ImageNet-1K weights for ResNet-152; ``DEFAULT`` is set to ``IMAGENET1K_V2``."""
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet152-394f9c45.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 60192808,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.312,
                    "acc@5": 94.046,
                }
            },
            "_ops": 11.514,
            "_file_size": 230.434,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # V2: newer training recipe; its transform resizes to 232 before the 224 crop.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnet152-f82ba261.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 60192808,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.284,
                    "acc@5": 96.002,
                }
            },
            "_ops": 11.514,
            "_file_size": 230.474,
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class ResNeXt50_32X4D_Weights(WeightsEnum):
    """ImageNet-1K weights for ResNeXt-50 32x4d; ``DEFAULT`` is set to ``IMAGENET1K_V2``."""
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 25028904,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.618,
                    "acc@5": 93.698,
                }
            },
            "_ops": 4.23,
            "_file_size": 95.789,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # V2: newer training recipe; its transform resizes to 232 before the 224 crop.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 25028904,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.198,
                    "acc@5": 95.340,
                }
            },
            "_ops": 4.23,
            "_file_size": 95.833,
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class ResNeXt101_32X8D_Weights(WeightsEnum):
    """ImageNet-1K weights for ResNeXt-101 32x8d; ``DEFAULT`` is set to ``IMAGENET1K_V2``."""
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 88791336,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 79.312,
                    "acc@5": 94.526,
                }
            },
            "_ops": 16.414,
            "_file_size": 339.586,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # V2: newer training recipe; its transform resizes to 232 before the 224 crop.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 88791336,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.834,
                    "acc@5": 96.228,
                }
            },
            "_ops": 16.414,
            "_file_size": 339.673,
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class ResNeXt101_64X4D_Weights(WeightsEnum):
    """ImageNet-1K weights for ResNeXt-101 64x4d; ``DEFAULT`` is set to ``IMAGENET1K_V1``.

    The only entry uses the 232-resize / 224-crop transform.
    """
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnext101_64x4d-173b62eb.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 83455272,
            "recipe": "https://github.com/pytorch/vision/pull/5935",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.246,
                    "acc@5": 96.454,
                }
            },
            "_ops": 15.46,
            "_file_size": 319.318,
            "_docs": """
                These weights were trained from scratch by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


class Wide_ResNet50_2_Weights(WeightsEnum):
    """ImageNet-1K weights for Wide ResNet-50-2; ``DEFAULT`` is set to ``IMAGENET1K_V2``."""
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 68883240,
            "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.468,
                    "acc@5": 94.086,
                }
            },
            "_ops": 11.398,
            "_file_size": 131.82,
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # V2: newer training recipe; its transform resizes to 232 before the 224 crop.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 68883240,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.602,
                    "acc@5": 95.758,
                }
            },
            "_ops": 11.398,
            "_file_size": 263.124,
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class Wide_ResNet101_2_Weights(WeightsEnum):
    """ImageNet-1K weight sets for ``wide_resnet101_2``.

    Each member bundles a checkpoint URL, a preprocessing factory (``transforms``) and a
    metadata dict (``meta``). ``DEFAULT`` refers to the same ``Weights`` object as
    ``IMAGENET1K_V2``.
    """

    # First release: ImageClassification with a 224 crop and no explicit resize_size.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 126886696,
            "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.848,
                    "acc@5": 94.284,
                }
            },
            "_ops": 22.753,  # presumably GFLOPs per forward pass — TODO confirm
            "_file_size": 242.896,  # presumably checkpoint size in MB — TODO confirm
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # Newer-recipe weights: 224 crop taken after resizing to 232; higher reported acc@1/acc@5 than V1.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 126886696,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.510,
                    "acc@5": 96.020,
                }
            },
            "_ops": 22.753,  # presumably GFLOPs per forward pass — TODO confirm
            "_file_size": 484.747,  # presumably checkpoint size in MB — TODO confirm
            "_docs": """
                These weights improve upon the results of the original paper by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    # Same Weights object as IMAGENET1K_V2 (the higher-accuracy set above).
    DEFAULT = IMAGENET1K_V2


# @register_model()
@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1))
def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    """ResNet-18 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.

    Built from ``BasicBlock`` units with block counts ``[2, 2, 2, 2]`` handed to ``_resnet``.

    Args:
        weights (:class:`~torchvision.models.ResNet18_Weights`, optional): Pretrained
            weights to load. See :class:`~torchvision.models.ResNet18_Weights` below
            for the possible values. When omitted, no pre-trained weights are used.
        progress (bool, optional): Whether to show a download progress bar on
            stderr. Default is True.
        **kwargs: Extra arguments forwarded to the ``torchvision.models.resnet.ResNet``
            base class. See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for details.

    .. autoclass:: torchvision.models.ResNet18_Weights
        :members:
    """
    checked_weights = ResNet18_Weights.verify(weights)
    block_counts = [2, 2, 2, 2]
    return _resnet(BasicBlock, block_counts, checked_weights, progress, **kwargs)


# @register_model()
@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1))
def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    """ResNet-34 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.

    Built from ``BasicBlock`` units with block counts ``[3, 4, 6, 3]`` handed to ``_resnet``.

    Args:
        weights (:class:`~torchvision.models.ResNet34_Weights`, optional): Pretrained
            weights to load. See :class:`~torchvision.models.ResNet34_Weights` below
            for the possible values. When omitted, no pre-trained weights are used.
        progress (bool, optional): Whether to show a download progress bar on
            stderr. Default is True.
        **kwargs: Extra arguments forwarded to the ``torchvision.models.resnet.ResNet``
            base class. See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for details.

    .. autoclass:: torchvision.models.ResNet34_Weights
        :members:
    """
    checked_weights = ResNet34_Weights.verify(weights)
    block_counts = [3, 4, 6, 3]
    return _resnet(BasicBlock, block_counts, checked_weights, progress, **kwargs)


# @register_model()
@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1))
def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    """ResNet-50 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.

    Built from ``Bottleneck`` units with block counts ``[3, 4, 6, 3]`` handed to ``_resnet``.

    .. note::
       The bottleneck of TorchVision places the stride for downsampling to the second 3x3
       convolution while the original paper places it to the first 1x1 convolution.
       This variant improves the accuracy and is known as `ResNet V1.5
       <https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch>`_.

    Args:
        weights (:class:`~torchvision.models.ResNet50_Weights`, optional): Pretrained
            weights to load. See :class:`~torchvision.models.ResNet50_Weights` below
            for the possible values. When omitted, no pre-trained weights are used.
        progress (bool, optional): Whether to show a download progress bar on
            stderr. Default is True.
        **kwargs: Extra arguments forwarded to the ``torchvision.models.resnet.ResNet``
            base class. See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for details.

    .. autoclass:: torchvision.models.ResNet50_Weights
        :members:
    """
    checked_weights = ResNet50_Weights.verify(weights)
    block_counts = [3, 4, 6, 3]
    return _resnet(Bottleneck, block_counts, checked_weights, progress, **kwargs)


# @register_model()
@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1))
def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    """ResNet-101 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.

    Built from ``Bottleneck`` units with block counts ``[3, 4, 23, 3]`` handed to ``_resnet``.

    .. note::
       The bottleneck of TorchVision places the stride for downsampling to the second 3x3
       convolution while the original paper places it to the first 1x1 convolution.
       This variant improves the accuracy and is known as `ResNet V1.5
       <https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch>`_.

    Args:
        weights (:class:`~torchvision.models.ResNet101_Weights`, optional): Pretrained
            weights to load. See :class:`~torchvision.models.ResNet101_Weights` below
            for the possible values. When omitted, no pre-trained weights are used.
        progress (bool, optional): Whether to show a download progress bar on
            stderr. Default is True.
        **kwargs: Extra arguments forwarded to the ``torchvision.models.resnet.ResNet``
            base class. See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for details.

    .. autoclass:: torchvision.models.ResNet101_Weights
        :members:
    """
    checked_weights = ResNet101_Weights.verify(weights)
    block_counts = [3, 4, 23, 3]
    return _resnet(Bottleneck, block_counts, checked_weights, progress, **kwargs)


# @register_model()
@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1))
def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    """ResNet-152 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.

    Built from ``Bottleneck`` units with block counts ``[3, 8, 36, 3]`` handed to ``_resnet``.

    .. note::
       The bottleneck of TorchVision places the stride for downsampling to the second 3x3
       convolution while the original paper places it to the first 1x1 convolution.
       This variant improves the accuracy and is known as `ResNet V1.5
       <https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch>`_.

    Args:
        weights (:class:`~torchvision.models.ResNet152_Weights`, optional): Pretrained
            weights to load. See :class:`~torchvision.models.ResNet152_Weights` below
            for the possible values. When omitted, no pre-trained weights are used.
        progress (bool, optional): Whether to show a download progress bar on
            stderr. Default is True.
        **kwargs: Extra arguments forwarded to the ``torchvision.models.resnet.ResNet``
            base class. See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for details.

    .. autoclass:: torchvision.models.ResNet152_Weights
        :members:
    """
    checked_weights = ResNet152_Weights.verify(weights)
    block_counts = [3, 8, 36, 3]
    return _resnet(Bottleneck, block_counts, checked_weights, progress, **kwargs)


# @register_model()
@handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.IMAGENET1K_V1))
def resnext50_32x4d(
    *, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    """ResNeXt-50 32x4d model from
    `Aggregated Residual Transformation for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_.

    Uses ``Bottleneck`` blocks with block counts ``[3, 4, 6, 3]``. ``groups`` is fixed to 32 and
    ``width_per_group`` to 4 (the "32x4d" in the name) via ``_ovewrite_named_param``.

    Args:
        weights (:class:`~torchvision.models.ResNeXt50_32X4D_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.ResNeXt50_32X4D_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ResNeXt50_32X4D_Weights
        :members:
    """
    weights = ResNeXt50_32X4D_Weights.verify(weights)

    _ovewrite_named_param(kwargs, "groups", 32)
    _ovewrite_named_param(kwargs, "width_per_group", 4)
    return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)


# @register_model()
@handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.IMAGENET1K_V1))
def resnext101_32x8d(
    *, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    """ResNeXt-101 32x8d model from
    `Aggregated Residual Transformation for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_.

    Uses ``Bottleneck`` blocks with block counts ``[3, 4, 23, 3]``, with ``groups=32`` and
    ``width_per_group=8`` set through ``_ovewrite_named_param``.

    Args:
        weights (:class:`~torchvision.models.ResNeXt101_32X8D_Weights`, optional): Pretrained
            weights to load. See :class:`~torchvision.models.ResNeXt101_32X8D_Weights` below
            for the possible values. When omitted, no pre-trained weights are used.
        progress (bool, optional): Whether to show a download progress bar on
            stderr. Default is True.
        **kwargs: Extra arguments forwarded to the ``torchvision.models.resnet.ResNet``
            base class. See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for details.

    .. autoclass:: torchvision.models.ResNeXt101_32X8D_Weights
        :members:
    """
    checked_weights = ResNeXt101_32X8D_Weights.verify(weights)

    # Fixed ResNeXt hyper-parameters, applied in the same order as before.
    for param_name, value in (("groups", 32), ("width_per_group", 8)):
        _ovewrite_named_param(kwargs, param_name, value)
    return _resnet(Bottleneck, [3, 4, 23, 3], checked_weights, progress, **kwargs)


# @register_model()
@handle_legacy_interface(weights=("pretrained", ResNeXt101_64X4D_Weights.IMAGENET1K_V1))
def resnext101_64x4d(
    *, weights: Optional[ResNeXt101_64X4D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    """ResNeXt-101 64x4d model from
    `Aggregated Residual Transformation for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_.

    Uses ``Bottleneck`` blocks with block counts ``[3, 4, 23, 3]``, with ``groups=64`` and
    ``width_per_group=4`` set through ``_ovewrite_named_param``.

    Args:
        weights (:class:`~torchvision.models.ResNeXt101_64X4D_Weights`, optional): Pretrained
            weights to load. See :class:`~torchvision.models.ResNeXt101_64X4D_Weights` below
            for the possible values. When omitted, no pre-trained weights are used.
        progress (bool, optional): Whether to show a download progress bar on
            stderr. Default is True.
        **kwargs: Extra arguments forwarded to the ``torchvision.models.resnet.ResNet``
            base class. See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for details.

    .. autoclass:: torchvision.models.ResNeXt101_64X4D_Weights
        :members:
    """
    checked_weights = ResNeXt101_64X4D_Weights.verify(weights)

    # Fixed ResNeXt hyper-parameters, applied in the same order as before.
    for param_name, value in (("groups", 64), ("width_per_group", 4)):
        _ovewrite_named_param(kwargs, param_name, value)
    return _resnet(Bottleneck, [3, 4, 23, 3], checked_weights, progress, **kwargs)


# @register_model()
@handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.IMAGENET1K_V1))
def wide_resnet50_2(
    *, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    """Wide ResNet-50-2 model from
    `Wide Residual Networks <https://arxiv.org/abs/1605.07146>`_.

    Identical to ResNet-50 apart from the bottleneck width, which is doubled in
    every block (``width_per_group = 64 * 2``). The outer 1x1 convolutions keep
    their channel counts: the last ResNet-50 block is 2048-512-2048 while the
    last Wide ResNet-50-2 block is 2048-1024-2048.

    Args:
        weights (:class:`~torchvision.models.Wide_ResNet50_2_Weights`, optional): Pretrained
            weights to load. See :class:`~torchvision.models.Wide_ResNet50_2_Weights` below
            for the possible values. When omitted, no pre-trained weights are used.
        progress (bool, optional): Whether to show a download progress bar on
            stderr. Default is True.
        **kwargs: Extra arguments forwarded to the ``torchvision.models.resnet.ResNet``
            base class. See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for details.

    .. autoclass:: torchvision.models.Wide_ResNet50_2_Weights
        :members:
    """
    checked_weights = Wide_ResNet50_2_Weights.verify(weights)

    doubled_width = 64 * 2
    _ovewrite_named_param(kwargs, "width_per_group", doubled_width)
    return _resnet(Bottleneck, [3, 4, 6, 3], checked_weights, progress, **kwargs)


# @register_model()
@handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.IMAGENET1K_V1))
def wide_resnet101_2(
    *, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    """Wide ResNet-101-2 model from
    `Wide Residual Networks <https://arxiv.org/abs/1605.07146>`_.

    Identical to ResNet-101 apart from the bottleneck width, which is doubled in
    every block (``width_per_group = 64 * 2``). The outer 1x1 convolutions keep
    their channel counts: the last ResNet-101 block is 2048-512-2048 while the
    last Wide ResNet-101-2 block is 2048-1024-2048.

    Args:
        weights (:class:`~torchvision.models.Wide_ResNet101_2_Weights`, optional): Pretrained
            weights to load. See :class:`~torchvision.models.Wide_ResNet101_2_Weights` below
            for the possible values. When omitted, no pre-trained weights are used.
        progress (bool, optional): Whether to show a download progress bar on
            stderr. Default is True.
        **kwargs: Extra arguments forwarded to the ``torchvision.models.resnet.ResNet``
            base class. See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for details.

    .. autoclass:: torchvision.models.Wide_ResNet101_2_Weights
        :members:
    """
    checked_weights = Wide_ResNet101_2_Weights.verify(weights)

    doubled_width = 64 * 2
    _ovewrite_named_param(kwargs, "width_per_group", doubled_width)
    return _resnet(Bottleneck, [3, 4, 23, 3], checked_weights, progress, **kwargs)


def get_resnet_model(cuda: bool, model_derivative: str, pretrained: bool = True, model_state_dict: Optional[dict] = None, **kwargs):
    """Build a ResNet variant and optionally load weights and move it to the GPU.

    Args:
        cuda: target placement. ``True`` moves the model to ``"cuda"`` and ``False`` leaves it where
            it was built. A device string (e.g. ``"cpu"``, ``"cuda:0"``) is used directly as the
            target device; ``init_model`` passes its ``device`` string positionally here.
        model_derivative: one of ``"resnet18"``, ``"resnet34"``, ``"resnet50"``, ``"resnet101"``,
            ``"resnet152"``.
        pretrained: if True, build the model with the ``DEFAULT`` member of the selected
            variant's own Weights enum.
            NOTE(review): a custom ``num_classes`` in ``kwargs`` presumably conflicts with the
            1000-class pretrained head inside ``_resnet`` — confirm before combining them.
        model_state_dict: optional state dict (e.g. from a checkpoint) loaded after construction.
            It takes precedence over pretrained weights and random initialisation.
        **kwargs: forwarded to the selected resnet factory.

    Returns:
        The constructed ``torch.nn.Module``.

    Raises:
        AssertionError: if ``model_derivative`` is not one of the supported names.
    """
    # Each variant is paired with its own Weights enum so `pretrained` works for all of them
    # (previously "ResNet50_Weights.DEFAULT" was hard-coded, which only resolves for resnet50).
    models = {
        "resnet18": (resnet18, ResNet18_Weights),
        "resnet34": (resnet34, ResNet34_Weights),
        "resnet50": (resnet50, ResNet50_Weights),
        "resnet101": (resnet101, ResNet101_Weights),
        "resnet152": (resnet152, ResNet152_Weights),
    }
    assert model_derivative in models.keys(), "Your selected resnet model derivative is unavailable"

    builder, weights_enum = models[model_derivative]
    model = builder(weights=weights_enum.DEFAULT, **kwargs) if pretrained else builder(**kwargs)

    if model_state_dict:
        print("Loading pretrained model...")
        model.load_state_dict(model_state_dict)
        print("Finished.")
    elif not pretrained:
        # Only re-initialise a freshly built model: running Xavier init after loading pretrained
        # weights would silently discard them.
        print("Initializing parameters...")
        for para in model.parameters():
            if para.dim() > 1:  # weight matrices/kernels only; biases and norm params keep defaults
                nn.init.xavier_uniform_(para)
        print("Finished.")

    # A device string must not be treated as a bool: "cpu" is truthy.
    if isinstance(cuda, str):
        model = model.to(cuda)
    elif cuda:
        model = model.to("cuda")
    return model

# Utils

In [None]:
### Early stopping
class EarlyStopper:
    """Decide when to stop training because validation loss is no longer improving.

    The lowest loss seen so far is tracked. A new lowest loss clears the strike counter. A loss
    more than ``MIN_DELTA`` above the best adds one strike. Anything in between leaves the counter
    unchanged. ``check`` reports True once the strike count reaches ``PATIENCE``.
    """

    def __init__(self, PATIENCE=1, MIN_DELTA=0):
        self.patience = PATIENCE  # strikes tolerated before stopping is signalled
        self.min_delta = MIN_DELTA  # slack above the best loss before a strike is counted
        self.min_val_loss = float('inf')
        self.counter = 0

    def check(self, val_loss: float):
        """Record ``val_loss``; return True when training should stop, otherwise False."""
        if val_loss < self.min_val_loss:
            # New best: remember it and forgive earlier strikes.
            self.min_val_loss = val_loss
            self.counter = 0
            return False

        # Written as a positive comparison so NaN losses never count as a strike.
        is_strike = val_loss > (self.min_val_loss + self.min_delta)
        if not is_strike:
            return False

        self.counter += 1
        return self.counter >= self.patience

### Logger
import json

from datetime import datetime
from multimethod import multimethod


class Logger:
    """Write training/testing logs to a file.

    ``multimethod`` selects the ``write`` overload from the runtime type of ``log_info``:

    * ``dict``: the record is merged with this logger's timestamp header and written as indented
      JSON followed by ``",\\n"``. The default mode is append. The file as a whole is therefore a
      comma-separated sequence of objects, not a single JSON document.
    * ``str``: the text is written verbatim. The default mode is overwrite.
    """
    # Timestamp header merged into every dict record, e.g. {"Train at": "31/12/2023 23:59:59"}.
    # It is captured once at construction time.
    __time: dict

    def __init__(self, phase: str = "train"):
        """
        phase: "train" || "test" (capitalized and used as the prefix of the timestamp key)
        """
        self.__time = {f"{phase.capitalize()} at": datetime.now().strftime("%d/%m/%Y %H:%M:%S")}


    @multimethod
    def write(self,  file: str, log_info: dict, writing_mode: str = "a") -> None:
        """
        This is used for writing multiple time log from model training
        """
        # `{**a, **b}` (unlike `dict(a, **b)`) also accepts non-string keys in log_info;
        # keys in log_info still override the header on collision.
        record = {**self.__time, **log_info}
        with open(file=file, mode=writing_mode, encoding="UTF-8", errors="ignore") as f:
            f.write(json.dumps(record, indent=4))
            f.write(",\n")
        return None

    @multimethod
    def write(self,  file: str, log_info: str, writing_mode: str = "w") -> None:
        """
        This is used for writing one-time log from model testing
        """
        with open(file=file, mode=writing_mode, encoding="UTF-8", errors="ignore") as f:
            f.write(log_info)
        return None


### Utils
import os

from box import Box
from typing import Tuple, Dict, List
# from src.modelling.vgg import get_vgg_model
# from src.modelling.resnet import get_resnet_model

import torch
import torcheval

from torchsummary import summary
from torchvision.datasets import ImageFolder
from torchvision.transforms import InterpolationMode, Compose
from torch.utils.data import DataLoader, random_split, Dataset

from torch.optim import (
    Adam,
    AdamW,
    NAdam,
    RAdam,
    SparseAdam,
    Adadelta,
    Adagrad,
    Adamax,
    ASGD,
    RMSprop,
    Rprop,
    LBFGS,
    SGD
)

from torch.nn.modules import (
    NLLLoss,
    NLLLoss2d,
    CTCLoss,
    KLDivLoss,
    GaussianNLLLoss,
    PoissonNLLLoss,
    L1Loss,
    MSELoss,
    HuberLoss,
    SmoothL1Loss,
    CrossEntropyLoss,
    BCELoss,
    BCEWithLogitsLoss
)

from torcheval.metrics import (
    BinaryAccuracy,
    BinaryF1Score,
    BinaryPrecision,
    BinaryRecall,
    BinaryConfusionMatrix,
    MulticlassAccuracy,
    MulticlassF1Score,
    MulticlassPrecision,
    MulticlassRecall,
    MulticlassConfusionMatrix,
    BinaryBinnedPrecisionRecallCurve,
    MulticlassBinnedPrecisionRecallCurve,
    BinaryPrecisionRecallCurve
)

from torch.optim.lr_scheduler import (
    LambdaLR,
    MultiplicativeLR,
    StepLR,
    MultiStepLR,
    ConstantLR,
    LinearLR,
    ExponentialLR,
    PolynomialLR,
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    ChainedScheduler,
    SequentialLR,
    ReduceLROnPlateau,
    OneCycleLR
)

from torchvision.transforms.v2 import (
    # Color
    ColorJitter,
    Grayscale,
    RandomAdjustSharpness,
    RandomAutocontrast,
    RandomChannelPermutation,
    RandomEqualize,
    RandomGrayscale,
    RandomInvert,
    RandomPhotometricDistort,
    RandomPosterize,
    RandomSolarize,

    # Geometry
    CenterCrop,
    ElasticTransform,
    FiveCrop,
    Pad,
    RandomAffine,
    RandomCrop,
    RandomHorizontalFlip,
    RandomIoUCrop,
    RandomPerspective,
    RandomResize,
    RandomResizedCrop,
    RandomRotation,
    RandomShortestSize,
    RandomVerticalFlip,
    RandomZoomOut,
    Resize,
    ScaleJitter,
    TenCrop,

    # Meta
    ClampBoundingBoxes,
    ConvertBoundingBoxFormat,

    # Misc
    ConvertImageDtype,
    GaussianBlur,
    Identity,
    Lambda,
    LinearTransformation,
    Normalize,
    SanitizeBoundingBoxes,
    ToDtype,

    # Temporal
    UniformTemporalSubsample,

    # Type conversion
    PILToTensor,
    ToImage,
    ToPILImage,
    ToPureTensor
)


# Public helpers of this utilities section. This only matters for `from ... import *` when the
# code lives in its own module (cf. the commented-out `src.utils.utils` imports in the Trainer cell).
__all__ = ["get_dataset", "get_transformation", "get_train_val_loader", "get_test_loader", "get_model_summary",
           "init_loss", "init_lr_scheduler", "init_metrics", "init_model", "init_optimizer",
           "init_model_optimizer_start_epoch"
           ]


def get_model_summary(model: torch.nn.Module, input_size: Tuple, device: str):
    """Delegate to ``torchsummary.summary`` for ``model`` and return whatever it returns.

    ``input_size`` is forwarded unchanged; presumably one sample's shape without the batch
    dimension, e.g. ``(3, 224, 224)`` — confirm against torchsummary's docs.
    """
    summary_kwargs = {"model": model, "input_size": input_size, "device": device}
    return summary(**summary_kwargs)


def get_transformation(transform_dict: Box = None) -> Compose:
    """Build a ``Compose`` pipeline from a transform config section.

    ``transform_dict`` has the shape of ``options.DATA.TRANSFORM``:

    * ``NAME_LIST``: ordered transform class names (keys of ``available_transform`` below).
    * ``ARGS``: maps each entry's *string* index (``"0"``, ``"1"``, ...) to its keyword arguments.
      A missing index means the transform is built without arguments.

    Any string ``interpolation`` argument (e.g. ``"NEAREST"``) is mapped to an ``InterpolationMode``
    member, and any string ``dtype`` argument (e.g. ``"float32"``) is mapped to a ``torch`` dtype.
    Non-string values are passed through unchanged. The config is never modified, so the same
    config can build several pipelines (e.g. for train and test datasets).

    Returns:
        A ``Compose`` of the configured transforms, or an empty ``Compose`` (identity) when
        ``transform_dict`` is None.

    Raises:
        AssertionError: on an unknown transform, interpolation mode or dtype name.
    """
    available_transform = {
        # Color
        "ColorJitter": ColorJitter,
        "Grayscale": Grayscale,
        "RandomAdjustSharpness": RandomAdjustSharpness,
        "RandomAutocontrast": RandomAutocontrast,
        "RandomChannelPermutation": RandomChannelPermutation,
        "RandomEqualize": RandomEqualize,
        "RandomGrayscale": RandomGrayscale,
        "RandomInvert": RandomInvert,
        "RandomPhotometricDistort": RandomPhotometricDistort,
        "RandomPosterize": RandomPosterize,
        "RandomSolarize": RandomSolarize,

        # Geometry
        "CenterCrop": CenterCrop,
        "ElasticTransform": ElasticTransform,
        "FiveCrop": FiveCrop,
        "Pad": Pad,
        "RandomAffine": RandomAffine,
        "RandomCrop": RandomCrop,
        "RandomHorizontalFlip": RandomHorizontalFlip,
        "RandomIoUCrop": RandomIoUCrop,
        "RandomPerspective": RandomPerspective,
        "RandomResize": RandomResize,
        "RandomResizedCrop": RandomResizedCrop,
        "RandomRotation": RandomRotation,
        "RandomShortestSize": RandomShortestSize,
        "RandomVerticalFlip": RandomVerticalFlip,
        "RandomZoomOut": RandomZoomOut,
        "Resize": Resize,
        "ScaleJitter": ScaleJitter,
        "TenCrop": TenCrop,

        # Meta
        "ClampBoundingBoxes": ClampBoundingBoxes,
        "ConvertBoundingBoxFormat": ConvertBoundingBoxFormat,

        # Misc
        "ConvertImageDtype": ConvertImageDtype,
        "GaussianBlur": GaussianBlur,
        "Identity": Identity,
        "Lambda": Lambda,
        "LinearTransformation": LinearTransformation,
        "Normalize": Normalize,
        "SanitizeBoundingBoxes": SanitizeBoundingBoxes,
        "ToDtype": ToDtype,

        # Temporal
        "UniformTemporalSubsample": UniformTemporalSubsample,

        # Type conversion
        "PILToTensor": PILToTensor,
        "ToImage": ToImage,
        "ToPILImage": ToPILImage,
        "ToPureTensor": ToPureTensor
    }

    # Names mirror torchvision's InterpolationMode members
    available_interpolation = {
        "NEAREST": InterpolationMode.NEAREST,
        "BILINEAR": InterpolationMode.BILINEAR,
        "BICUBIC": InterpolationMode.BICUBIC,
        # For PIL compatibility
        "BOX": InterpolationMode.BOX,
        "HAMMING": InterpolationMode.HAMMING,
        "LANCZOS": InterpolationMode.LANCZOS,
    }

    available_dtype = {
        "complex64": torch.complex64,
        "complex128": torch.complex128,
        "float16": torch.float16,
        "float32": torch.float32,
        "float64": torch.float64,
        "uint8": torch.uint8,
        "int8": torch.int8,
        "int16": torch.int16,
        "int32": torch.int32,
        "int64": torch.int64,
        # Kept only so configs written against the former misspelled key keep working.
        "int6": torch.int64
    }

    if transform_dict is None:
        return Compose([])

    transform_lst: List[str] = transform_dict.NAME_LIST
    args = transform_dict.ARGS

    transforms = []
    for i, name in enumerate(transform_lst):
        assert name in available_transform, "Your selected transform is unavailable"

        # Shallow copy: resolving names below must not overwrite the caller's config, otherwise
        # a second call with the same config would fail its own name checks.
        kwargs = dict(args.get(str(i), {}))

        interpolation = kwargs.get("interpolation")
        if isinstance(interpolation, str):
            assert interpolation in available_interpolation, "Your selected interpolation mode is unavailable"
            kwargs["interpolation"] = available_interpolation[interpolation]

        dtype = kwargs.get("dtype")
        if isinstance(dtype, str):
            assert dtype in available_dtype, "Your selected dtype is unavailable"
            kwargs["dtype"] = available_dtype[dtype]

        transforms.append(available_transform[name](**kwargs))
    return Compose(transforms)


def get_dataset(root: str,
                transform: Box = None,
                target_transform: Box = None
                ) -> Dataset:
    """Build an ``ImageFolder`` dataset with config-driven transforms.

    Args:
        root: dataset directory, as expected by ``torchvision.datasets.ImageFolder``.
        transform: transform config section (``NAME_LIST`` + ``ARGS``) applied to images.
            ``None`` gives an empty ``Compose`` (identity).
        target_transform: the same kind of config, applied to labels/targets.
            ``None`` gives an empty ``Compose`` (identity).

    Returns:
        The ``ImageFolder`` dataset.
    """
    return ImageFolder(root=root,
                       transform=get_transformation(transform_dict=transform),
                       target_transform=get_transformation(transform_dict=target_transform)
                       )


def get_train_val_loader(dataset: Dataset,
                         train_size: float, batch_size: int,
                         seed: int, cuda: bool, num_workers=1
                         ) -> Tuple[DataLoader, DataLoader]:
    """Split ``dataset`` into train/validation subsets and wrap each in a ``DataLoader``.

    Args:
        dataset: the full dataset to split.
        train_size: fraction of samples used for training. The training count is
            ``round(len(dataset) * train_size)`` and the remainder goes to validation.
        batch_size: batch size for both loaders.
        seed: seeds the split's generator, so a given seed always yields the same split.
        cuda: ``True``, or a device string starting with ``"cuda"``, enables page-locked (pinned)
            host memory for the loaders.
        num_workers: number of loader worker processes.

    Returns:
        ``(train_loader, validation_loader)``. Only the training loader shuffles. The validation
        loader iterates in a fixed order so evaluation runs are reproducible.

    NOTE(review): both subsets share ``dataset``'s transform, so random augmentations configured
    for training also apply to validation samples.
    """
    n_train = round(len(dataset) * train_size)
    pin_memory = cuda is True or (isinstance(cuda, str) and cuda.startswith("cuda"))

    train_subset, validation_subset = random_split(dataset=dataset,
                                                   generator=torch.Generator().manual_seed(seed),
                                                   lengths=[n_train, len(dataset) - n_train])

    train_loader = DataLoader(dataset=train_subset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers,
                              pin_memory=pin_memory
                              )

    # Shuffling gains nothing for evaluation and only makes the order non-deterministic.
    validation_loader = DataLoader(dataset=validation_subset,
                                   batch_size=batch_size,
                                   shuffle=False,
                                   num_workers=num_workers,
                                   pin_memory=pin_memory
                                   )
    return train_loader, validation_loader


def get_test_loader(dataset: Dataset, batch_size: int, cuda: bool, num_workers=1) -> DataLoader:
    """Wrap ``dataset`` in a non-shuffling ``DataLoader`` for evaluation.

    Args:
        dataset: the test dataset.
        batch_size: batch size.
        cuda: ``True``, or a device string starting with ``"cuda"``, enables page-locked (pinned)
            host memory.
        num_workers: number of loader worker processes.

    Returns:
        A ``DataLoader`` that yields samples in dataset order, so repeated test runs (and any
        per-sample outputs) are reproducible.
    """
    # Use page-locked memory only when targeting CUDA.
    pin_memory = cuda is True or (isinstance(cuda, str) and cuda.startswith("cuda"))
    return DataLoader(dataset=dataset,
                      batch_size=batch_size,
                      shuffle=False,
                      num_workers=num_workers,
                      pin_memory=pin_memory
                      )
##########################################################################################################################


def init_loss(name: str, args: Dict) -> torch.nn.Module:
    """Instantiate the loss class called ``name`` with keyword arguments ``args``.

    Raises:
        AssertionError: if ``name`` is not one of the supported loss classes.
    """
    # Registry keyed by each class's own name, so the key can never drift from the class.
    loss_classes = {
        cls.__name__: cls
        for cls in (
            NLLLoss, NLLLoss2d, CTCLoss, KLDivLoss, GaussianNLLLoss, PoissonNLLLoss,
            CrossEntropyLoss, BCELoss, BCEWithLogitsLoss,
            L1Loss, MSELoss, HuberLoss, SmoothL1Loss,
        )
    }
    assert name in loss_classes.keys(), "Your selected loss function is unavailable"
    loss_cls = loss_classes[name]
    return loss_cls(**args)


def init_lr_scheduler(name: str, args: Dict, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler.LRScheduler:
    """Instantiate the learning-rate scheduler called ``name`` as ``cls(optimizer, **args)``.

    NOTE(review): LambdaLR/MultiplicativeLR need a callable, and ChainedScheduler presumably takes a
    list of schedulers rather than an optimizer as its first argument, so those entries likely
    cannot be driven from a JSON config — confirm before relying on them.

    Raises:
        AssertionError: if ``name`` is not one of the supported schedulers.
    """
    scheduler_classes = {
        cls.__name__: cls
        for cls in (
            LambdaLR, MultiplicativeLR, StepLR, MultiStepLR, ConstantLR,
            LinearLR, ExponentialLR, PolynomialLR, CosineAnnealingLR,
            CosineAnnealingWarmRestarts, ChainedScheduler, SequentialLR,
            ReduceLROnPlateau, OneCycleLR,
        )
    }
    assert name in scheduler_classes.keys(), "Your selected lr scheduler is unavailable"
    scheduler_cls = scheduler_classes[name]
    return scheduler_cls(optimizer, **args)


def init_metrics(name_lst: List[str], args: Dict, device: str) -> List[torcheval.metrics.Metric]:
    """Instantiate torcheval metrics by name and move each to ``device``.

    Args:
        name_lst: metric class names, in order.
        args: maps each metric's *string* index (``"0"``, ``"1"``, ...) to its keyword arguments.
        device: target device passed to each metric's ``.to``.

    Returns:
        The metrics in the same order as ``name_lst``.

    Raises:
        AssertionError: if any name is not a supported metric (checked before anything is built).
    """
    available_metrics = {
        "BinaryAccuracy": BinaryAccuracy,
        "BinaryF1Score": BinaryF1Score,
        "BinaryPrecision": BinaryPrecision,
        "BinaryRecall": BinaryRecall,
        "BinaryConfusionMatrix": BinaryConfusionMatrix,
        "BinaryPrecisionRecallCurve": BinaryPrecisionRecallCurve,
        "BinaryBinnedPrecisionRecallCurve": BinaryBinnedPrecisionRecallCurve,

        "MulticlassAccuracy": MulticlassAccuracy,
        "MulticlassF1Score": MulticlassF1Score,
        "MulticlassPrecision": MulticlassPrecision,
        "MulticlassRecall": MulticlassRecall,
        "MulticlassConfusionMatrix": MulticlassConfusionMatrix,
        "MulticlassBinnedPrecisionRecallCurve": MulticlassBinnedPrecisionRecallCurve
    }

    # Validate every requested name up front
    assert all(metric_name in available_metrics.keys() for metric_name in name_lst), \
        "Your selected metric is unavailable"

    return [
        available_metrics[metric_name](**args[str(position)]).to(device)
        for position, metric_name in enumerate(name_lst)
    ]


def init_model(device: str, pretrained: bool, base: str,
               name: str, state_dict: dict, **kwargs) -> torch.nn.Module:
    """Build a model through the builder registered for ``base`` ("vgg" or "resnet").

    The builder is called positionally as ``builder(device, name, pretrained, state_dict, **kwargs)``.
    ``device`` therefore lands in the builder's first parameter, which ``get_resnet_model`` calls
    ``cuda``.

    Raises:
        AssertionError: if ``base`` has no registered builder.
    """
    builders = {"vgg": get_vgg_model, "resnet": get_resnet_model}
    assert base in builders.keys(), "Your selected base is unavailable"
    build = builders[base]
    return build(device, name, pretrained, state_dict, **kwargs)


def init_optimizer(name: str, model_paras, state_dict: Dict = None, **kwargs) -> torch.optim.Optimizer:
    """Instantiate the optimizer called ``name`` over ``model_paras``, optionally restoring its state.

    Args:
        name: optimizer class name (e.g. ``"Adam"``, ``"SGD"``).
        model_paras: parameters (or param groups) handed to the optimizer constructor.
        state_dict: optional saved optimizer state, loaded after construction. Its saved
            hyper-parameters take precedence over ``kwargs``.
        **kwargs: constructor keyword arguments (e.g. ``lr``).

    Raises:
        AssertionError: if ``name`` is not a supported optimizer.
    """
    optimizer_classes = {
        cls.__name__: cls
        for cls in (Adam, AdamW, NAdam, Adadelta, Adagrad, Adamax,
                    RAdam, SparseAdam, RMSprop, Rprop, ASGD, LBFGS, SGD)
    }
    assert name in optimizer_classes.keys(), "Your selected optimizer is unavailable."

    optimizer = optimizer_classes[name](model_paras, **kwargs)
    if state_dict is None:
        return optimizer

    optimizer.load_state_dict(state_dict)
    return optimizer


def init_model_optimizer_start_epoch(device: str,
                                     checkpoint_load: bool, checkpoint_path: str, resume_name: str,
                                     optimizer_name: str, optimizer_args: Dict,
                                     model_base: str, model_name: str, model_args: Dict,
                                     pretrained: bool = False
                                     ) -> Tuple[int, torch.nn.Module, torch.optim.Optimizer]:
    """Create model and optimizer, optionally resuming both from a checkpoint.

    When ``checkpoint_load`` is true, ``<checkpoint_path>/<resume_name>`` is loaded with
    ``map_location=device``; it must contain the keys "epoch", "model_state_dict" and
    "optimizer_state_dict". Training then resumes at the saved epoch + 1. Otherwise the
    model/optimizer start fresh and the start epoch is 1.

    Returns (start_epoch, model, optimizer).
    """
    if checkpoint_load:
        resume_file = os.path.join(checkpoint_path, resume_name)
        checkpoint = torch.load(f=resume_file, map_location=device)
        start_epoch = checkpoint["epoch"] + 1
        model_state = checkpoint["model_state_dict"]
        optim_state = checkpoint["optimizer_state_dict"]
    else:
        start_epoch, model_state, optim_state = 1, None, None

    model: torch.nn.Module = init_model(device=device, pretrained=pretrained, base=model_base,
                                        name=model_name, state_dict=model_state, **model_args)
    optimizer: torch.optim.Optimizer = init_optimizer(name=optimizer_name, model_paras=model.parameters(),
                                                      state_dict=optim_state, **optimizer_args)
    return start_epoch, model, optimizer



# Trainer

In [None]:
import os

from box import Box
from tqdm import tqdm
from time import sleep
from typing import List, Dict
# from src.utils.logger import Logger
# from src.utils.early_stopping import EarlyStopper
# from src.utils.utils import init_loss, init_metrics, init_lr_scheduler, init_model_optimizer_start_epoch

import torch
import torcheval

from torch.nn.functional import sigmoid, softmax
from torch.utils.data import DataLoader


class Trainer:
    """Epoch-based trainer that alternates a training pass and a validation pass.

    Every component (model, optimizer, loss, LR scheduler, metrics, checkpointing and
    early stopping) is built from the ``options`` Box given to ``__init__``. Per-phase
    results are handed to ``Logger.write`` together with the phase's log path.
    """
    # Attribute declarations (type hints only; values are assigned in __init__).
    # The double-underscore prefix makes Python store them as ``_Trainer__<name>``.
    __options: Box
    __train_log_path: str
    __eval_log_path: str
    __checkpoint_path: str
    __device: str

    __train_loader: DataLoader
    __validation_loader: DataLoader

    __early_stopper: EarlyStopper
    __logger: Logger

    __loss: torch.nn.Module
    __optimizer: torch.optim.Optimizer
    __lr_schedulers: torch.optim.lr_scheduler.LRScheduler
    __start_epoch: int
    __model: torch.nn.Module

    __best_val_loss: float

    def __init__(self, options: Box,
                 train_log_path: str, eval_log_path: str, checkpoint_path: str,
                 train_loader: DataLoader, val_loader: DataLoader
                 ):
        """Build all training components from ``options``.

        options: experiment config; the MISC, SOLVER, CHECKPOINT, EPOCH and METRICS
            sections are read
        train_log_path / eval_log_path: paths passed to ``Logger.write`` for each phase
        checkpoint_path: directory checkpoints are loaded from and saved to
        train_loader / val_loader: data for the training and validation passes
        """
        self.__options: Box = options
        self.__train_log_path: str = train_log_path
        self.__eval_log_path: str = eval_log_path
        self.__checkpoint_path: str = checkpoint_path
        self.__device: str = "cuda" if options.MISC.CUDA else "cpu"

        self.__train_loader: DataLoader = train_loader
        self.__validation_loader: DataLoader = val_loader

        self.__early_stopper: EarlyStopper = EarlyStopper(**options.SOLVER.EARLY_STOPPING)
        self.__logger: Logger = Logger()

        solver = options.SOLVER
        self.__loss = init_loss(name=solver.LOSS.NAME, args=solver.LOSS.ARGS)
        self.__start_epoch, self.__model, self.__optimizer = init_model_optimizer_start_epoch(
            device=self.__device,
            checkpoint_load=options.CHECKPOINT.LOAD,
            checkpoint_path=checkpoint_path,
            resume_name=options.CHECKPOINT.RESUME_NAME,
            optimizer_name=solver.OPTIMIZER.NAME,
            optimizer_args=solver.OPTIMIZER.ARGS,
            model_base=solver.MODEL.BASE,
            model_name=solver.MODEL.NAME,
            model_args=solver.MODEL.ARGS,
            # Forward SOLVER.MODEL.PRETRAINED; without it the helper's default (False) always applied.
            pretrained=solver.MODEL.get("PRETRAINED", False),
        )
        self.__lr_schedulers: torch.optim.lr_scheduler.LRScheduler = init_lr_scheduler(
            name=solver.LR_SCHEDULER.NAME, args=solver.LR_SCHEDULER.ARGS,
            optimizer=self.__optimizer)
        self.__best_val_loss: float = self.__get_best_val_loss()

    # Setter & Getter
    @property
    def model(self) -> torch.nn.Module:
        """The model being trained."""
        return self.__model

    # Public methods
    def train(self, sleep_time: int = None, metric_in_train: bool = False) -> None:
        """Run ``EPOCH.EPOCHS`` epochs, starting at the (possibly resumed) start epoch.

        Each epoch runs a "train" pass then an "eval" pass. After the eval pass a
        checkpoint is saved (if CHECKPOINT.SAVE) and early stopping is checked
        (if MISC.APPLY_EARLY_STOPPING); when it triggers, training stops and this
        method returns.

        sleep_time: seconds to pause after each phase; None disables pausing
        metric_in_train: also compute metrics in the train phase (metrics are always
            computed in the eval phase)
        """
        print("Start training model ...")

        first_epoch = self.__start_epoch
        for epoch in range(first_epoch, first_epoch + self.__options.EPOCH.EPOCHS):
            print("Epoch:", epoch)

            for phase, dataset_loader, log_path in zip(("train", "eval"),
                                                       (self.__train_loader, self.__validation_loader),
                                                       (self.__train_log_path, self.__eval_log_path)):
                # train(False) is equivalent to eval()
                self.__model.train(phase == "train")

                # Fresh metric objects per phase so states never leak across phases/epochs.
                compute_metrics = phase == "eval" or metric_in_train
                metrics: List[torcheval.metrics.Metric] = init_metrics(
                    name_lst=self.__options.METRICS.NAME_LIST,
                    args=self.__options.METRICS.ARGS,
                    device=self.__device) if compute_metrics else None

                run_epoch_result: Dict = self.__run_epoch(phase=phase, epoch=epoch,
                                                          dataset_loader=dataset_loader, metrics=metrics)

                self.__logger.write(log_path, {"epoch": epoch, **run_epoch_result})

                if phase == "eval":
                    val_loss = run_epoch_result["loss"]

                    if self.__options.CHECKPOINT.SAVE:
                        self.__save_checkpoint(epoch=epoch, val_loss=val_loss,
                                               save_all=self.__options.CHECKPOINT.SAVE_ALL,
                                               obj={"epoch": epoch, "val_loss": val_loss,
                                                    "model_state_dict": self.__model.state_dict(),
                                                    "optimizer_state_dict": self.__optimizer.state_dict()
                                                    }
                                               )

                    # Return instead of exit(): exit() would kill the process / notebook kernel.
                    if self.__options.MISC.APPLY_EARLY_STOPPING and \
                            self.__early_stopper.check(val_loss=val_loss):
                        print(f"Early stopping triggered at epoch {epoch}.")
                        return None

                if sleep_time is not None:
                    sleep(sleep_time)
        return None

    # Private methods
    def __run_epoch(self, phase: str, epoch: int, dataset_loader: DataLoader,
                    metrics: List[torcheval.metrics.Metric] = None) -> Dict:
        """Run one pass over ``dataset_loader``.

        phase: "train" (forward, backward, optimizer and scheduler step) or "eval" (forward only)
        epoch: current epoch number; the scheduler is stepped with the fractional epoch
            ``epoch + batch_index / num_batches``
        dataset_loader: train_loader or val_loader; batch[0] = images, batch[1] = labels
        metrics: torcheval metrics updated on every batch, or None to skip metrics

        Returns {"loss": mean of the per-batch losses, <metric name>: value, ...}.
        """
        num_class = self.__options.SOLVER.MODEL.ARGS.num_classes
        num_batches = len(dataset_loader)
        is_train = phase == "train"
        total_loss = 0.0

        for index, batch in tqdm(enumerate(dataset_loader), total=num_batches, colour="cyan",
                                 desc=phase.capitalize()):
            imgs = batch[0].to(device=self.__device, dtype=torch.float32)
            # Labels must live on the same device as the predictions for both branches below.
            labels = batch[1].to(self.__device)

            # reset gradients prior to forward pass
            self.__optimizer.zero_grad()

            with torch.set_grad_enabled(is_train):
                logits = self.__model(imgs)
                if num_class == 1:
                    # Shape: (N, 1) -> (N,) positive-class probabilities
                    pred_labels = sigmoid(logits).squeeze(dim=1)
                    labels = labels.float()
                else:
                    # Shape stays (N, C): per-class probabilities.
                    # NOTE(review): the loss receives probabilities, not logits — confirm
                    # init_loss is configured for that (e.g. not CrossEntropyLoss).
                    pred_labels = softmax(logits, dim=1)

                mini_batch_loss = self.__loss(pred_labels, labels)

                # backprop + optimize only in the training phase
                if is_train:
                    mini_batch_loss.backward()
                    self.__optimizer.step()
                    self.__lr_schedulers.step(epoch=epoch + index / num_batches)

            if metrics is not None:
                detached_preds = pred_labels.detach()
                for metric in metrics:
                    metric.update(detached_preds, labels)

            total_loss += mini_batch_loss.item()

        result = {"loss": total_loss / num_batches}
        if metrics is not None:
            for metric_name, metric in zip(self.__options.METRICS.NAME_LIST, metrics):
                # .item() assumes each configured metric computes a single-element tensor.
                result[metric_name] = metric.compute().item()
        return result

    def __save_checkpoint(self, epoch: int, val_loss: float, obj: dict, save_all: bool = False) -> None:
        """Save ``obj`` as ``epoch_<epoch>.pt`` and, on a new best loss, as ``best_checkpoint.pt``.

        save_all:
            True: keep every epoch file
            False: keep only the latest epoch file (``epoch_<epoch - 1>.pt`` is deleted)
        ``best_checkpoint.pt`` is written whenever val_loss improves, regardless of save_all.
        """
        torch.save(obj=obj, f=os.path.join(self.__checkpoint_path, f"epoch_{epoch}.pt"))

        if val_loss < self.__best_val_loss:
            torch.save(obj=obj, f=os.path.join(self.__checkpoint_path, "best_checkpoint.pt"))
            # Track the lowest validation loss seen so far
            self.__best_val_loss = val_loss

        if not save_all and epoch > 1:
            previous = os.path.join(self.__checkpoint_path, f"epoch_{epoch - 1}.pt")
            # The previous epoch file may legitimately be absent (e.g. resumed from
            # best_checkpoint.pt or pruned by hand); that must not abort training.
            try:
                os.remove(previous)
            except FileNotFoundError:
                pass
        return None

    def __get_best_val_loss(self) -> float:
        """Return the val_loss stored in ``best_checkpoint.pt``, or +inf if there is none."""
        best_path = os.path.join(self.__checkpoint_path, "best_checkpoint.pt")
        if not os.path.isfile(best_path):
            return float("inf")
        # map_location="cpu" so a checkpoint saved on GPU can be read on a CPU-only machine.
        return torch.load(f=best_path, map_location="cpu")["val_loss"]

# Main

In [None]:
import os
import shutil
from box import Box


import os
import commentjson

from box import Box
# from src.tools.train import Trainer
# from src.tools.inference import inference
# from src.utils.utils import get_train_val_loader, get_test_loader, get_dataset

from torch.utils.data import DataLoader


def train(data_root: str = "/kaggle/input/celeb-a") -> None:
    """Train the model described by the global ``options`` config.

    Checkpoints are written to ``<cwd>/checkpoints/<model name>`` and logs to
    ``<cwd>/logs/<model name>``; both directories are created when missing.

    data_root: directory containing ``<options.DATA.DATASET_NAME>/train``. The default is
        the Kaggle input path this notebook was originally run with.
    """
    model_name = options.SOLVER.MODEL.NAME
    checkpoint_path = os.path.join(os.getcwd(), "checkpoints", model_name)
    log_path = os.path.join(os.getcwd(), "logs", model_name)

    if not os.path.isdir(checkpoint_path):
        os.makedirs(checkpoint_path, 0o777, True)
        print(f"Checkpoint dir for {model_name} was created.")

    if not os.path.isdir(log_path):
        os.makedirs(log_path, 0o777, True)
        print(f"Log dir checkpoint for {model_name} was created.")

    train_log_path = os.path.join(log_path, "training_log.json")
    eval_log_path = os.path.join(log_path, "eval_log.json")

    train_set = get_dataset(root=os.path.join(data_root, options.DATA.DATASET_NAME, "train"),
                            transform=options.DATA.TRANSFORM)

    # Split the training set into train/validation loaders (TRAIN_SIZE fraction for training).
    train_loader, val_loader = get_train_val_loader(dataset=train_set,
                                                    train_size=options.DATA.TRAIN_SIZE,
                                                    batch_size=options.DATA.BATCH_SIZE, seed=options.MISC.SEED,
                                                    cuda=options.MISC.CUDA, num_workers=options.DATA.NUM_WORKERS
                                                    )
    print(f"""Train batch: {len(train_loader)}, Validation batch: {len(val_loader)}
Training model {model_name}
""")

    trainer = Trainer(options=options,
                      train_log_path=train_log_path,
                      eval_log_path=eval_log_path,
                      checkpoint_path=checkpoint_path,
                      train_loader=train_loader,
                      val_loader=val_loader
                      )
    trainer.train(metric_in_train=True)
    print("Training finished")
    return None


def test(option_path: str) -> None:
    """Evaluate a trained checkpoint on the test split of several datasets.

    For each dataset the inference config at ``option_path`` is parsed into a fresh Box,
    DATA.DATASET_NAME is overridden, and ``inference`` runs on ``<cwd>/<dataset>/test``.
    Results go to ``<cwd>/logs/<MODEL.NAME>/testing_log_<dataset>.json``.

    option_path: path to a commentjson (JSON with comments) inference config; the keys
        MODEL.NAME, CHECKPOINT.NAME, DATA.* and MISC.CUDA are read
    """
    # Read the config once with a context manager so the file handle is closed.
    with open(option_path, mode="r", encoding="utf-8") as config_file:
        config_text = config_file.read()

    for dataset in ("celeb_A", "collected_v3", "collected_v4"):
        # Re-parse per dataset so each iteration starts from an unmodified config.
        options = Box(commentjson.loads(config_text))
        checkpoint_path = os.path.join(os.getcwd(), "checkpoints", options.MODEL.NAME, options.CHECKPOINT.NAME)

        options.DATA.DATASET_NAME = dataset
        log_path = os.path.join(os.getcwd(), "logs", options.MODEL.NAME, f"testing_log_{dataset}.json")

        test_set = get_dataset(root=os.path.join(os.getcwd(), options.DATA.DATASET_NAME, "test"),
                               transform=options.DATA.TRANSFORM)

        test_loader: DataLoader = get_test_loader(dataset=test_set,
                                                  batch_size=options.DATA.BATCH_SIZE,
                                                  cuda=options.MISC.CUDA,
                                                  num_workers=options.DATA.NUM_WORKERS
                                                  )
        print(f"""Test batch: {len(test_loader)}""")

        inference(options=options, checkpoint_path=checkpoint_path, log_path=log_path,
                  test_loader=test_loader, num_threshold=2)
    return None


def main() -> None:
    """Script entry point: run training only."""
    train()
    # Evaluation is disabled; to run it, uncomment:
    # test(option_path=os.path.join(os.getcwd(), "configs", "inference_config.json"))


if __name__ == "__main__":
    # Runs when this file/cell is executed directly (a notebook kernel's __name__ is
    # "__main__"); skipped when the code is imported as a module.
    main()