# ResNet-18

In [1]:
import torch.nn as nn
from torch.nn import functional as F

- Residual connection block

In [3]:
class Residual(nn.Module):
  """Two 3x3 conv + batch-norm layers with a skip connection around them.

  The main path is conv1 -> bn1 -> ReLU -> conv2 -> bn2. The block input is
  added to that result (after an optional 1x1 convolution) and a final ReLU
  is applied.

  Args:
    in_channel: Number of channels of the input.
    out_channel: Number of channels produced by both 3x3 convolutions.
    use_1x1Conv: If True, the skip path goes through a 1x1 convolution that
      maps ``in_channel`` to ``out_channel`` using the same ``stride`` as
      ``conv1``. Without it the input is added unchanged, so the caller must
      make sure the shapes already match.
    stride: Stride of the first 3x3 convolution (and of the 1x1 convolution
      when it is enabled).
  """

  def __init__(self, in_channel, out_channel, use_1x1Conv=False, stride=1):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, stride=stride)
    self.bn1 = nn.BatchNorm2d(out_channel)
    self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1)
    self.bn2 = nn.BatchNorm2d(out_channel)

    # Optional projection so the skip path matches the main path's shape.
    if use_1x1Conv:
      self.conv3 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride)
    else:
      self.conv3 = None

  def forward(self, X):
    """Return ``relu(main_path(X) + shortcut(X))``."""
    out = F.relu(self.bn1(self.conv1(X)))
    out = self.bn2(self.conv2(out))
    # Explicit None check: do not rely on the truthiness of a module object.
    if self.conv3 is not None:
      X = self.conv3(X)
    out = out + X
    return F.relu(out)

In [4]:
def residualBlock(in_channel, out_channel, num_residuals, first_block=False):
  """Build the list of ``Residual`` units that make up one stage.

  Args:
    in_channel: Number of channels entering the stage.
    out_channel: Number of channels every unit in the stage outputs.
    num_residuals: How many ``Residual`` units to create.
    first_block: If False, the first unit uses stride 2 and a 1x1
      convolution on its skip path. If True, the first unit keeps stride 1.

  Returns:
    A plain Python list of ``Residual`` modules, in order.
  """
  blks = []
  for i in range(num_residuals):
    if i > 0:
      # Every unit after the first works at the stage's output width.
      blks.append(Residual(out_channel, out_channel))
    elif not first_block:
      blks.append(Residual(in_channel, out_channel, use_1x1Conv=True, stride=2))
    elif in_channel != out_channel:
      # Stride stays 1, but the skip path still needs a projection so that
      # the addition inside Residual sees matching channel counts.
      blks.append(Residual(in_channel, out_channel, use_1x1Conv=True))
    else:
      blks.append(Residual(in_channel, out_channel))

  return blks

In [7]:
class ResNet(nn.Module):
  """ResNet-18 style classifier.

  Layout: a stem (7x7 stride-2 convolution, batch norm, ReLU, 3x3 stride-2
  max pool), four stages of two residual units each with 64, 128, 256 and
  512 channels, then global average pooling, flatten and a linear layer.

  Args:
    input_channel: Number of channels of the input images.
    n_classes: Number of outputs of the final linear layer.
  """

  def __init__(self, input_channel, n_classes):
    super().__init__()
    self.b1 = nn.Sequential(
        nn.Conv2d(input_channel, 64, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    )
    # The first stage follows the max pool, so it is built with
    # first_block=True (stride 1 in its first unit).
    self.b2 = nn.Sequential(*residualBlock(64, 64, 2, first_block=True))
    self.b3 = nn.Sequential(*residualBlock(64, 128, 2))
    self.b4 = nn.Sequential(*residualBlock(128, 256, 2))
    self.b5 = nn.Sequential(*residualBlock(256, 512, 2))
    self.finalLayer = nn.Sequential(
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(512, n_classes)
    )

    # Module.apply walks every submodule in registration order (b1 ... b5,
    # finalLayer), so one call covers all of them.
    self.apply(self.init_weights)

  def init_weights(self, layer):
    """Initialise the weights of a single layer in place.

    Conv2d weights get Kaiming-normal (fan_out) initialisation, Linear
    weights a normal distribution with std 1e-3, and BatchNorm2d is reset to
    weight 1 and bias 0. Any other layer type is left untouched.
    """
    if isinstance(layer, nn.Conv2d):
      nn.init.kaiming_normal_(layer.weight, mode='fan_out')
    elif isinstance(layer, nn.Linear):
      nn.init.normal_(layer.weight, std=1e-3)
    elif isinstance(layer, nn.BatchNorm2d):
      nn.init.constant_(layer.weight, 1)
      nn.init.constant_(layer.bias, 0)

  def forward(self, X):
    """Run the stem, the four residual stages and the classifier head.

    Returns a tensor with ``n_classes`` values per input sample.
    """
    out = self.b1(X)
    out = self.b2(out)
    out = self.b3(out)
    out = self.b4(out)
    out = self.b5(out)
    out = self.finalLayer(out)

    return out