In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, List, Optional, Type, Union
# from ._api import register_model, Weights, WeightsEnum

In [18]:
device = torch.device('cuda')

if device == 'cuda':
    print(torch.cuda.get_device_name())
else:
    print('CUDA not available')

CUDA not available


In [19]:
class conv3x3(nn.Module):
    def __init__(self, in_c, out_c, kernel = 3, stride=1, padding=1):
        super().__init__()

        self.conv = nn.Conv2d(
            in_c, out_c,
            kernel_size = kernel,
            stride = stride,
            padding = padding
        )

    def forward(self, x):
        return self.conv(x)


class conv1x1(nn.Module):
    def __init__(self, in_c, out_c, kernel = 1, stride=1, padding=1):
        super().__init__()

        self.conv = nn.Conv2d(
            in_c, out_c,
            kernel_size = kernel,
            stride = stride,
            apdding = padding
        )

    def forward(self, x):
        return self.conv(x)

In [20]:
class Bottleneck(nn.Module):
    def __init__(self, in_c, out_c, down_sample = None, stride = 1):
        super().__init__()

        self.expansion = 4                            #what's the utility of it?

        self.conv1 = conv1x1(in_c, out_c, stride = stride, padding = 0)
        self.batch_norm1 = nn.BatchNorm2d(out_c)

        self.conv2 = conv3x3(out_c, out_c, stride = stride, padding = 0)                    #stride values need to be checked
        self.batch_norm2 = nn.BatchNorm2d(out_c)

        self.conv1 = conv3x3(in_c, out_c * self.expansion, stride = stride, padding = 0)
        self.batch_norm3 = nn.BatchNorm2d(out_c * self.expansion)

        self.downsample = down_sample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x):
        x_cpy = x
        
        x = self.conv1(x)
        x = self.batch_norm1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.batch_norm2(x)
        x = self.relu(x)

        x = self.conv3(x)
        x = self.batch_norm3(x)

        if self.downsample != None:
            x_cpy = self.downsample(x_cpy)

        x += x_cpy
        x = self.relu(x)

        return x       
        

In [21]:
class BasicBlock(nn.Module):
    def __init__(self, in_c, out_c, kernel, stride, padding, norm_layer, groups, base_width):
        super().__init__(BasicBlock, self)

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        if groups != 1 and base_width != 64:
            raise ValueError('Only groups=1 & width=64 are supported.')

        if dilation>1:
            raise NotImplementedError("Dilation>1 can't be handled.")

        self.conv1 = conv3x3(in_c, out_c, kernel, stride, padding)
        self.bn1 = norm_layer(out_c)
        self.relu = nn.ReLU(inplace = True)
        self.conv2 = conv1x1(out_c, out_c, kernel, stride, padding)
        self.bn2 = norm_layer(out_c)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        x_cpy = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)

        if self.downsample != None:
            x_cpy = self.downsample(x_cpy)

        x += x_cpy
        x = self.relu(x)

        return x
        

In [22]:
class resnet34(nn.Module):
    def __init__(self, ):
        super().__init(resnet34, self)

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        self._norm_layer = norm_layer

        self.in_c = 64
        self.dialation = 1
        
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]

        if len(replace_stride_with_dilation) != 3:
            raise ValueError("Length of replace_stride_with_dilation isn't 3")

        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = conv3x3(3, in_c, kernel_size=7, stride=2, padding=3, bias = False)
        self.bn1 = norm_layer(self.in_c)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=3, bias=False)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layer[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 128, layer[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 128, layer[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)


        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode = 'fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, BottleNeck):
                    nn.init.constant_(m.bn3,weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)


    def make_layer(self, block, out_c, blocks, stride, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        prev_dilation = self.dilation

        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.in_c != out_c * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.in_c, planes * block.expansion, stride),
                norm_layer(out_c * block.expansion)
            )

        layers = []
        layers.append(
            block(
                self.in_c, out_c, stride, downsample, self.groups, self.base_width, prev_dilation, norm_layer
            )
        )
        self.in_c = out_c * block.expansion

        for _ in range(1, blocks):
            layers.append(
                block(
                    self.in_c,
                    out_c,
                    groups = self.groups,
                    base_width = self.base_width,
                    dilation = self.dilation,
                    norm_layer = norm_layer
                )
            )
            
        return nn.Sequential(*layers)


    def _forward_impl(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)

def _resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> ResNet:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = ResNet(block, layers, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


_COMMON_META = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
}


# class ResNet34_Weights(WeightsEnum):
#     IMAGENET1K_V1 = Weights(
#         url="https://download.pytorch.org/models/resnet34-b627a593.pth",
#         transforms=partial(ImageClassification, crop_size=224),
#         meta={
#             **_COMMON_META,
#             "num_params": 21797672,
#             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
#             "_metrics": {
#                 "ImageNet-1K": {
#                     "acc@1": 73.314,
#                     "acc@5": 91.420,
#                 }
#             },
#             "_ops": 3.664,
#             "_file_size": 83.275,
#             "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
#         },
#     )
#     DEFAULT = IMAGENET1K_V1



        

NameError: name 'WeightsEnum' is not defined