In [None]:
import torch
import torch.nn as nn
import torchvision

In [None]:
# Randomly initialised torchvision ResNet-18 (no pretrained weights).
# Called without arguments so it works on both older torchvision releases
# (where `pretrained` defaults to False) and newer ones (where `weights`
# defaults to None); the `pretrained=` keyword is deprecated since 0.13.
model = torchvision.models.resnet18()

In [36]:
class Resnet18(nn.Module):
    """ResNet-18 variant with a custom stem and a linear classification head.

    The torchvision stem (conv1 / bn1 / relu / maxpool) is replaced by a
    3x3, stride-1 convolution block; the stock ResNet-18 stages
    (layer1..layer4 and avgpool) are reused; the original fc head is replaced
    by ``fc_layer``.
    """

    def __init__(self, bottleneck_connection_channel=1, num_classes=10, backbone=None):
        """
        Args:
            bottleneck_connection_channel: connection channel for VOneBlock,
                i.e. the number of input channels of the custom stem conv.
                Defaults to 1.
            num_classes: output size of the final linear layer. Defaults to 10.
            backbone: optional torchvision-style ResNet-18 whose children
                ``[4:-1]`` are reused as ``origin_layer``. When None, a fresh,
                randomly initialised ``torchvision.models.resnet18()`` is built
                for this instance.
        """
        super().__init__()

        self.customized_layer = nn.Sequential(
            nn.Conv2d(bottleneck_connection_channel, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
            nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
        )

        if backbone is None:
            # Build a private backbone: slicing the children of a shared,
            # module-level model would make every Resnet18 instance reuse
            # (and co-train) the very same layer objects.
            backbone = torchvision.models.resnet18()
        # Skip the original stem (first four children) and the final fc head.
        self.origin_layer = nn.Sequential(
            *list(backbone.children())[4:-1]
        )

        self.fc_layer = nn.Sequential(
            nn.Linear(in_features=512, out_features=num_classes, bias=True)
        )

    def forward(self, x):
        """Run stem -> ResNet stages -> flatten -> linear head."""
        out = self.customized_layer(x)
        out = self.origin_layer(out)
        # Keep the batch dimension, collapse the pooled (C, 1, 1) features.
        out = torch.flatten(out, 1)
        out = self.fc_layer(out)
        return out

In [37]:
# Instantiate the custom ResNet-18 variant with a connection channel of 1.
resnet18 = Resnet18(
    bottleneck_connection_channel=1,
)

Resnet18(
  (customized_layer): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (origin_layer): Sequential(
    (0): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum