In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Pick the compute device: GPU when CUDA is usable, otherwise CPU.
# NOTE: torch.cuda.is_available must be *called*; the bare function object is
# always truthy, which would select 'cuda' even on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Side length, in pixels, of the square input images fed to the model.
IMG_SIZE = 224

In [2]:
class BasicBlock(nn.Module):
    """Two-convolution residual block: ``relu(residual(x) + shortcut(x))``.

    The residual branch is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. The first
    convolution applies ``stride`` and maps ``in_channels`` to
    ``out_channels``; the second keeps the shape.

    The shortcut branch is the identity when the residual branch leaves the
    tensor shape unchanged. Otherwise it is a 1x1 convolution (with the same
    stride) followed by batch norm, so both branches can be summed.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first 3x3 convolution (and of the projection
            shortcut, when one is needed). Defaults to 1.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        self.residual = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

        # An empty Sequential acts as the identity mapping.
        self.shortcut = nn.Sequential()
        # A projection is required whenever the residual branch changes the
        # tensor shape: either the spatial size (stride != 1) or the channel
        # count. Checking only the stride would break the addition in
        # forward() for a stride-1 block that changes the number of channels.
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""
        return F.relu(self.residual(x) + self.shortcut(x))

In [3]:
class ResNet(nn.Module):
    """Residual network with a ResNet-18 layout (four stages of two BasicBlocks).

    Pipeline: 7x7 stride-2 stem with max pooling, four residual stages
    (64, 128, 256, 512 channels), global average pooling, and a linear
    classifier producing ``num_class`` logits.

    Args:
        in_channels: Number of channels of the input images.
        num_class: Number of output logits.
    """

    def __init__(self, in_channels, num_class):
        super().__init__()

        # Stem: the stride-2 convolution and the stride-2 max pool together
        # reduce the spatial resolution by a factor of 4.
        self.conv_layer_1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # Only the first stage keeps the resolution (first_stride=1); each
        # later stage halves it and doubles the channel count.
        self.conv_layer_2 = self.make_conv_layer(in_channels=64, out_channels=64, num_blocks=2, first_stride=1)
        self.conv_layer_3 = self.make_conv_layer(in_channels=64, out_channels=128, num_blocks=2, first_stride=2)
        self.conv_layer_4 = self.make_conv_layer(in_channels=128, out_channels=256, num_blocks=2, first_stride=2)
        self.conv_layer_5 = self.make_conv_layer(in_channels=256, out_channels=512, num_blocks=2, first_stride=2)

        self.fc_layer = nn.Linear(512, num_class)

    def forward(self, x):
        """Map a batch of images ``(N, in_channels, H, W)`` to logits ``(N, num_class)``.

        The spatial sizes in the comments below are for a 224x224 input.
        """
        x = self.conv_layer_1(x)
        x = self.conv_layer_2(x)  # 56 x 56 x 64
        x = self.conv_layer_3(x)  # 28 x 28 x 128
        x = self.conv_layer_4(x)  # 14 x 14 x 256
        x = self.conv_layer_5(x)  #  7 x  7 x 512
        # Global average pooling. The adaptive variant averages over the whole
        # feature map whatever its size, so the network is not tied to inputs
        # that yield exactly a 7x7 map (a fixed 7x7 window only suits 224x224).
        x = F.adaptive_avg_pool2d(x, 1)  # 1 x 1 x 512
        x = x.flatten(1)                 # (N, 512)
        return self.fc_layer(x)          # fully connected classifier

    def make_conv_layer(self, in_channels, out_channels, num_blocks, first_stride):
        """Build one residual stage as an ``nn.Sequential`` of ``BasicBlock``s.

        Args:
            in_channels: Channels entering the stage.
            out_channels: Channels produced by every block of the stage.
            num_blocks: Number of blocks in the stage.
            first_stride: Stride of the first block; the remaining blocks
                use stride 1.

        Returns:
            ``nn.Sequential`` containing the stage's blocks in order.
        """
        layers = []
        # Only the first block of a stage may down-sample.
        for stride in [first_stride] + [1] * (num_blocks - 1):
            layers.append(BasicBlock(in_channels=in_channels, out_channels=out_channels, stride=stride))
            # The channel count changes in the first block only; the
            # following blocks map out_channels -> out_channels.
            in_channels = out_channels

        return nn.Sequential(*layers)


In [4]:
# Smoke test: build a 10-class network for 3-channel images and place it on
# the selected device.
model = ResNet(num_class=10, in_channels=3)
model = model.to(device)
# Random batch of 10 images of size IMG_SIZE x IMG_SIZE, created on the CPU
# and then moved to the same device as the model.
x = torch.randn((10, 3, IMG_SIZE, IMG_SIZE)).to(device)
# Forward pass; the resulting (10, 10) logits are displayed as the cell output.
model(x)

tensor([[-0.1674,  0.2047,  0.5747,  0.0848, -0.2590, -1.4141,  0.0249,  0.0279,
          0.9875,  0.3798],
        [-0.0854,  0.4315,  0.5267,  0.1610, -0.1972, -1.3579, -0.1614,  0.0467,
          1.0248,  0.3960],
        [-0.2657,  0.3580,  0.6110,  0.1961, -0.2925, -1.3087, -0.1232,  0.0716,
          1.1121,  0.3374],
        [-0.0781,  0.4652,  0.4923,  0.1827, -0.1818, -1.3926, -0.0498,  0.1606,
          1.0780,  0.5255],
        [-0.0219,  0.4431,  0.6013,  0.0537, -0.1950, -1.3376, -0.1097,  0.0094,
          1.0720,  0.3729],
        [-0.1575,  0.3927,  0.5286,  0.0448, -0.2207, -1.4736,  0.0104,  0.0328,
          1.0741,  0.4644],
        [-0.0610,  0.3712,  0.4117,  0.1535, -0.1969, -1.2586, -0.1222,  0.0631,
          1.1849,  0.4746],
        [-0.1268,  0.4474,  0.5318,  0.1427, -0.3674, -1.3568, -0.1783,  0.1475,
          1.1886,  0.4684],
        [-0.1611,  0.4957,  0.5890,  0.1434, -0.0644, -1.5604, -0.2058,  0.0920,
          1.1791,  0.5571],
        [-0.0925,  