In [2]:
import torch
import torch.nn as nn

In [30]:
class BasicBlock(nn.Module):
  """Residual block made of two 3x3 convolutions plus a skip connection.

  The main path is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. Its result is
  added to the shortcut and passed through a final ReLU.

  Args:
    in_channels: number of channels of the incoming feature map.
    out_channels: number of channels produced by the block.
    stride: stride of the first 3x3 convolution (and of the 1x1 shortcut
      convolution when one is created). Defaults to 1.

  Attributes:
    residual_function: ``nn.Sequential`` holding the main path.
    downsample: ``None`` when the shortcut is the identity, otherwise an
      ``nn.Sequential`` of a 1x1 convolution followed by batch norm.
  """

  def __init__(self,in_channels,out_channels,stride=1):
    super().__init__()

    main_path = [
      nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
      nn.BatchNorm2d(out_channels),
      nn.ReLU(inplace=True),
      nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
      nn.BatchNorm2d(out_channels),
    ]
    self.residual_function = nn.Sequential(*main_path)

    # The skip connection can only be added to the main path when both have
    # the same channel count and spatial size. If the stride shrinks the
    # feature map or the channel count changes, project the input with a
    # strided 1x1 convolution; otherwise keep None, meaning f(x) = x.
    shapes_match = stride == 1 and in_channels == out_channels
    if shapes_match:
      self.downsample = None
    else:
      self.downsample = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
        nn.BatchNorm2d(out_channels),
      )

  def forward(self,x):
    """Return ReLU(residual_function(x) + shortcut(x))."""
    residual = self.residual_function(x)
    shortcut = x if self.downsample is None else self.downsample(x)
    # In-place is safe here: the sum is a fresh temporary tensor.
    return nn.functional.relu(residual + shortcut, inplace=True)

class ResNet(nn.Module):
  """ResNet-style classifier: stem, four residual stages, pooling, linear head.

  Args:
    block: class (or factory) called as ``block(in_channels, out_channels,
      stride)`` that returns an ``nn.Module`` producing ``out_channels``
      channels.
    num_block: sequence of four ints, the number of blocks in layer1..layer4.
    num_classes: size of the final linear layer's output.

  Raises:
    ValueError: if ``num_block`` does not contain exactly four entries.
  """

  def __init__(self,block,num_block,num_classes):
    super(ResNet,self).__init__()

    if len(num_block) != 4:
      raise ValueError(
        "num_block must have exactly 4 entries (one per stage), got %d" % len(num_block)
      )

    # Channel count fed to the next block; updated by _make_layer.
    self.in_channels = 64

    # Stem: 7x7 stride-2 convolution followed by a 3x3 stride-2 max pool.
    self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
    self.bn1 = nn.BatchNorm2d(64)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1)

    # Each stage uses its own entry of num_block. Previously layer3 and
    # layer4 both reused num_block[1], which gave wrong stage depths.
    self.layer1 = self._make_layer(block,64,num_block[0],1)
    self.layer2 = self._make_layer(block,128,num_block[1],2)
    self.layer3 = self._make_layer(block,256,num_block[2],2)
    self.layer4 = self._make_layer(block,512,num_block[3],2)

    self.avgpool = nn.AdaptiveAvgPool2d((1,1))
    self.fc = nn.Linear(512,num_classes)

  def forward(self,x):
    """Run the network on ``x`` (3 input channels) and return the fc output.

    The result has shape ``(batch, num_classes)``.
    """
    output = self.conv1(x)
    output = self.bn1(output)
    output = self.relu(output)
    output = self.maxpool(output)
    output = self.layer1(output)
    output = self.layer2(output)
    output = self.layer3(output)
    output = self.layer4(output)
    output = self.avgpool(output)
    # Flatten everything except the batch dimension for the linear head.
    output = output.reshape(output.size(0),-1)
    output = self.fc(output)

    return output

  def _make_layer(self,block,out_channels,num_blocks,stride):
    """Build one stage as an ``nn.Sequential`` of ``block`` instances.

    Only the first block receives ``stride``; the remaining ones use
    stride 1. ``self.in_channels`` is set to ``out_channels`` so the next
    block (and the next stage) sees the right input width.
    """
    strides = [stride] + [1]*(num_blocks-1)
    layers = []
    for block_stride in strides:
      layers.append(block(self.in_channels,out_channels,block_stride))
      self.in_channels=out_channels
    return nn.Sequential(*layers)
  
def resnet34(num_classes=1000):
  """Build a ResNet from BasicBlock with per-stage block counts [3, 4, 6, 3].

  Args:
    num_classes: output size of the final linear layer. Defaults to 1000.
  """
  stage_depths = [3,4,6,3]
  return ResNet(BasicBlock,stage_depths,num_classes=num_classes)

def resnet18(num_classes=1000):
  """Build a ResNet from BasicBlock with per-stage block counts [2, 2, 2, 2].

  Args:
    num_classes: output size of the final linear layer. Defaults to 1000.
  """
  stage_depths = [2,2,2,2]
  return ResNet(BasicBlock,stage_depths,num_classes=num_classes)

In [31]:
# Instantiate both variants and print their module trees for inspection.
# NOTE(review): `res32` holds the ResNet-34 model; the name is kept in case
# other cells refer to it.
res32 = resnet34(num_classes=1000)
print(res32)
res18 = resnet18(num_classes=1000)
print(res18)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (residual_function): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (residual_function): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    

In [32]:
# Random dummy batch: 10 samples, 3 channels, 224x224 pixels.
a = torch.randn((10, 3, 224, 224))

In [33]:
# Smoke test: run the dummy batch through the 18-layer model and show the
# shape of the result.
output = res18(a)
print(output.shape)

torch.Size([10, 1000])
