<a href="https://colab.research.google.com/github/haeniKim/metaverse-academy/blob/master/DL/230703_ResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ResNet

* CNN 모델 중에서 많이 사용되는 모델

In [None]:
import torch
import torch.nn as nn

In [None]:
# ResNet의 Basic Block
# input(저장) - 합성곱 3 * 3 - 배치정규화 - ReLU - 합성곱 3 * 3 - 배치정규화 - skip connection(저장해 둔 input을 1 * 1 합성곱으로 채널을 맞춘 뒤 더하는 과정) - ReLU

class BasicBlock(nn.Module):
  """Residual block: conv -> BN -> ReLU -> conv -> BN, plus a projected skip path.

  The input is passed through a 1x1 convolution so that its channel count
  matches the main path, added to the main-path output, and the sum is passed
  through a final ReLU.

  Args:
    in_channels: number of channels of the incoming feature map.
    out_channels: number of channels produced by the block.
    kernel_sizes: side length of the square kernel used by both main-path
      convolutions. It should be odd so that the derived padding keeps the
      spatial size unchanged; otherwise the residual add in ``forward`` fails
      on a shape mismatch.
  """

  def __init__(self, in_channels, out_channels, kernel_sizes = 3):
    super(BasicBlock, self).__init__()

    # Padding is derived from the kernel size so that the main path keeps the
    # input's height/width. A fixed padding of 1 is only correct for 3x3
    # kernels; any other size would make the skip-connection add fail.
    padding = kernel_sizes // 2

    self.c1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_sizes, padding=padding)
    self.c2 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel_sizes, padding=padding)

    # 1x1 projection: changes only the channel count of the skip path
    # (e.g. 3 -> 64 for the first block) and leaves the spatial size as is.
    self.downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    self.bn1 = nn.BatchNorm2d(num_features=out_channels)
    self.bn2 = nn.BatchNorm2d(num_features=out_channels)

    self.relu = nn.ReLU()

  def forward(self, x):
    """Apply the block to ``x`` and return the activated residual sum."""
    identity = x  # keep the block input for the skip connection

    out = self.c1(x)
    out = self.bn1(out)
    out = self.relu(out)
    out = self.c2(out)
    out = self.bn2(out)

    # Skip connection: project the saved input to out_channels, then add.
    out = out + self.downsample(identity)
    out = self.relu(out)

    return out


* PyTorch는 TensorFlow(Keras)와 달리 각 레이어의 입력 크기(`in_channels`, `in_features`)가 자동으로 추론되지 않으므로, 레이어를 만들 때 직접 지정해줘야 함.

In [None]:
class ResNet(nn.Module):
  """Reduced ResNet-style classifier: three residual blocks and an MLP head.

  Each BasicBlock (3 -> 64 -> 128 -> 256 channels) is followed by a 2x2
  average pooling that halves the height and width. The flattened features
  go through three fully connected layers that output ``num_classes`` scores.

  The first linear layer expects 4096 input features (256 channels * 4 * 4),
  which is what a 3x32x32 image yields after the three poolings.

  Args:
    num_classes: number of output scores produced by the last linear layer.
  """

  def __init__(self, num_classes = 10):
    super().__init__()

    # Feature extractor: channel width grows with every block.
    self.b1 = BasicBlock(in_channels=3, out_channels=64)
    self.b2 = BasicBlock(in_channels=64, out_channels=128)
    self.b3 = BasicBlock(in_channels=128, out_channels=256)

    # One pooling layer is reused after every block; it halves H and W.
    self.pool = nn.AvgPool2d(kernel_size=2, stride=2)

    # Classifier head.
    self.fc1 = nn.Linear(in_features=4096, out_features=2048)
    self.fc2 = nn.Linear(in_features=2048, out_features=512)
    self.fc3 = nn.Linear(in_features=512, out_features=num_classes)

    self.relu = nn.ReLU()

  def forward(self, x):
    """Return the unnormalised class scores for a batch of images ``x``."""
    # Residual block followed by pooling, three times in a row.
    for block in (self.b1, self.b2, self.b3):
      x = self.pool(block(x))

    # Flatten everything except the batch dimension.
    features = x.flatten(start_dim=1)

    # Classification: two hidden layers with ReLU, then the output layer.
    hidden = self.relu(self.fc1(features))
    hidden = self.relu(self.fc2(hidden))
    return self.fc3(hidden)

In [None]:
# Instantiate the network with 10 output classes.
model = ResNet(10)

In [None]:
from torchsummary import summary

# Print every layer's output shape and parameter count for a 3x32x32 input.
summary(model, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,792
       BatchNorm2d-2           [-1, 64, 32, 32]             128
              ReLU-3           [-1, 64, 32, 32]               0
            Conv2d-4           [-1, 64, 32, 32]          36,928
       BatchNorm2d-5           [-1, 64, 32, 32]             128
            Conv2d-6           [-1, 64, 32, 32]             256
              ReLU-7           [-1, 64, 32, 32]               0
        BasicBlock-8           [-1, 64, 32, 32]               0
         AvgPool2d-9           [-1, 64, 16, 16]               0
           Conv2d-10          [-1, 128, 16, 16]          73,856
      BatchNorm2d-11          [-1, 128, 16, 16]             256
             ReLU-12          [-1, 128, 16, 16]               0
           Conv2d-13          [-1, 128, 16, 16]         147,584
      BatchNorm2d-14          [-1, 128,

### 팁

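* torchsummary의 summary를 이용해서 첫 번째 `nn.Linear`의 in_features 구하기: 평탄화 직전 레이어의 Output Shape을 모두 곱하면 됨.
* 예 `summary(model, (3, 32, 32))` → 마지막 풀링 출력이 `[-1, 256, 4, 4]` 이므로 in_features = 256 * 4 * 4 = 4096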

* data load

In [None]:
from torch.utils.data.dataloader import DataLoader
from torch.optim.adam import Adam
from torchvision.datasets.cifar import CIFAR10
from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip, Normalize
import torchvision.transforms as T

In [None]:
# Preprocessing pipeline applied to each image, in order:
#   1. pad by 4 pixels on every side and take a random 32x32 crop,
#   2. flip horizontally with probability 0.5,
#   3. convert the image to a tensor,
#   4. standardise each of the three channels with the given mean and std.
transforms = Compose(
    [
        RandomCrop(size=(32, 32), padding=4),
        RandomHorizontalFlip(p=0.5),
        T.ToTensor(),
        Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.247, 0.243, 0.261),
        ),
    ]
)

In [None]:
# Training images use the augmenting pipeline held in `transforms`.
train_data = CIFAR10(root = './', train = True, download = True, transform=transforms)

# Test images must not be randomly cropped or flipped: evaluation has to be
# deterministic and run on the unmodified pictures. Only the tensor
# conversion and the same per-channel normalisation are applied.
test_transforms = Compose([
    T.ToTensor(),
    Normalize(mean=(0.4914,0.4822,0.4465),std=(0.247,0.243,0.261))
  ])
test_data = CIFAR10(root = './', train = False, download = True, transform = test_transforms)

# Shuffle only the training set; keep the test order fixed.
train_loader = DataLoader(train_data, batch_size = 32, shuffle = True)
test_loader = DataLoader(test_data, batch_size = 32, shuffle = False)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else 'cpu'

# Fresh model with 10 output classes, moved to the selected device.
model = ResNet(10)

model.to(device)

Files already downloaded and verified
Files already downloaded and verified


ResNet(
  (b1): BasicBlock(
    (c1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (c2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (downsample): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (b2): BasicBlock(
    (c1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (c2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (downsample): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (b3): BasicBlock(
    (c1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))


In [None]:
# Create a SummaryWriter instance for TensorBoard logging.
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [None]:
from tqdm import tqdm

lr = 1e-4

optim = Adam(model.parameters(), lr = lr)

# The loss module has no per-batch state, so build it once instead of
# re-creating it on every iteration.
criterion = nn.CrossEntropyLoss()

# Make train mode explicit so BatchNorm uses batch statistics even if the
# model was switched to eval mode earlier in the session.
model.train()

for epoch in range(30):
  for data, label in tqdm(train_loader):
    data, label = data.to(device), label.to(device)

    optim.zero_grad()

    preds = model(data)

    loss = criterion(preds, label)
    loss.backward()
    optim.step()

  # NOTE: this reports the loss of the last batch of the epoch, not the mean.
  print(f'epoch : {epoch+1}, loss : {loss.item()}')
  # TensorBoard logging is currently disabled:
  #writer.add_scalar("Loss/train", loss.item(), epoch+1)

# torch.save needs both the object and the destination; save the learned
# weights (state_dict) so they can be restored with load_state_dict later.
torch.save(model.state_dict(), 'cifar-resnet.pth')

100%|██████████| 1563/1563 [24:12<00:00,  1.08it/s]


epoch : 1, loss : 1.4115791320800781


100%|██████████| 1563/1563 [23:32<00:00,  1.11it/s]


epoch : 2, loss : 1.0136336088180542


100%|██████████| 1563/1563 [23:21<00:00,  1.12it/s]


epoch : 3, loss : 0.47634297609329224


100%|██████████| 1563/1563 [23:19<00:00,  1.12it/s]


epoch : 4, loss : 0.4536926746368408


100%|██████████| 1563/1563 [23:22<00:00,  1.11it/s]


epoch : 5, loss : 0.6580432057380676


100%|██████████| 1563/1563 [23:25<00:00,  1.11it/s]


epoch : 6, loss : 0.4621102213859558


100%|██████████| 1563/1563 [23:23<00:00,  1.11it/s]


epoch : 7, loss : 0.337603360414505


100%|██████████| 1563/1563 [23:24<00:00,  1.11it/s]


epoch : 8, loss : 0.6295395493507385


100%|██████████| 1563/1563 [23:24<00:00,  1.11it/s]


epoch : 9, loss : 0.4576990604400635


100%|██████████| 1563/1563 [23:26<00:00,  1.11it/s]


epoch : 10, loss : 0.5037466287612915


100%|██████████| 1563/1563 [23:26<00:00,  1.11it/s]


epoch : 11, loss : 0.48238715529441833


100%|██████████| 1563/1563 [23:25<00:00,  1.11it/s]


epoch : 12, loss : 0.555030107498169


 53%|█████▎    | 825/1563 [12:23<11:15,  1.09it/s]

### 중요

* batch normalization을 사용하는 이유
* skip connection을 하는 이유
* down sampling을 하는 이유