In [15]:
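# ResNet
# Uses skip connections so that far more layers can be stacked than in VGG.
# The network is defined by subclassing nn.Module, whose forward() controls how data flows through the layers.
# Vanishing gradient: as the hidden layers get deeper, the gradients of the weights close to the input layer approach zero.
# Batch normalization: normalizes each layer's activations using mini-batch statistics, which makes training more stable.
# Skip connection: a structure that adds the input, which bypasses the hidden layers, to the output of those hidden layers.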

In [16]:
import torch
import torch.nn as nn


class BasicBlock(nn.Module):
  """Residual block: conv -> BN -> ReLU -> conv -> BN, plus a 1x1-conv skip path, then ReLU.

  Height and width are preserved (stride 1, "same" padding); only the channel count changes
  from `in_channels` to `out_channels`.

  Args:
    in_channels: number of channels of the input feature map.
    out_channels: number of channels of the output feature map.
    kernel_size: kernel size of the two main-path convolutions (int or (h, w) pair). Every
      dimension must be odd so that padding = kernel_size // 2 keeps the spatial size unchanged,
      which the skip-connection addition requires.

  Raises:
    ValueError: if any kernel dimension is even.
  """
  def __init__(self,in_channels, out_channels, kernel_size=3) -> None:
    super(BasicBlock, self).__init__()
    ks = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
    if any(k % 2 == 0 for k in ks):
      raise ValueError(f"kernel_size must be odd to preserve spatial size, got {kernel_size}")
    # "Same" padding for any odd kernel, so the main path keeps H x W and can be added to the skip path.
    padding = tuple(k // 2 for k in ks)
    # Main-path convolutions.
    self.c1 = nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,padding=padding)
    self.c2 = nn.Conv2d(out_channels,out_channels,kernel_size=kernel_size,padding=padding)
    # 1x1 conv on the skip path that projects the input to out_channels so it can be added to the
    # main-path output. Despite its name it does not reduce the spatial size (stride 1); the
    # attribute name is kept so existing state_dict checkpoints still load.
    self.downsample = nn.Conv2d(in_channels,out_channels,kernel_size=1)
    # Batch normalization after each main-path convolution.
    self.bn1 = nn.BatchNorm2d(num_features=out_channels)
    self.bn2 = nn.BatchNorm2d(num_features=out_channels)

    self.relu = nn.ReLU()

  def forward(self,x):
    """Apply the residual block; output has out_channels channels and the input's H x W."""
    identity = x  # keep the block input for the skip connection

    x = self.c1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.c2(x)
    x = self.bn2(x)

    # Project the input to out_channels so it matches the main-path output.
    identity = self.downsample(identity)

    # Skip connection: add the (projected) input to the convolutional output, then ReLU.
    x = x + identity
    x = self.relu(x)
    return x

In [17]:
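# Input -> [basic block + average pooling] -> [basic block + average pooling] -> [basic block + average pooling] -> flatten -> classifier -> output
# Repeat until the 32x32 image has been reduced to 4x4:
# 32x32 (input) -> basic block + avg pooling (16x16) -> basic block + avg pooling (8x8) -> basic block + avg pooling (4x4)
# i.e. three blocks in total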

In [18]:
class ResNet(nn.Module):
  """Small ResNet-style image classifier.

  Layout: three BasicBlocks (3 -> 64 -> 128 -> 256 channels), each followed by 2x2 average
  pooling that halves height and width, then flatten and a three-layer fully connected
  classifier that outputs raw logits.

  Args:
    num_classes: number of output logits.
    input_size: height/width of the square input images. Defaults to 32 (CIFAR-10), which gives
      4x4 feature maps after the three pooling stages. The classifier's input size is derived
      from this value, so other square sizes >= 8 also work.

  Raises:
    ValueError: if input_size < 8 (the pooled feature map would be empty).
  """
  def __init__(self,num_classes=10, input_size=32) -> None:
    super(ResNet,self).__init__()
    if input_size < 8:
      raise ValueError(f"input_size must be at least 8, got {input_size}")
    # Basic blocks; each keeps the spatial size and only changes the channel count.
    self.b1 = BasicBlock(in_channels=3, out_channels=64)
    self.b2 = BasicBlock(in_channels=64, out_channels=128)
    self.b3 = BasicBlock(in_channels=128, out_channels=256)

    # Average pooling: each output is the mean of every pixel in its 2x2 window
    # (max pooling would keep only the largest one).
    self.pool = nn.AvgPool2d(kernel_size=2,stride=2)

    # Three stride-2 poolings (with floor) shrink each side to input_size // 8 (4 for 32x32 inputs).
    feat_side = input_size // 8
    # Classifier
    self.fc1 = nn.Linear(in_features=256*feat_side*feat_side, out_features=2048)
    self.fc2 = nn.Linear(in_features=2048, out_features=512)
    self.fc3 = nn.Linear(in_features=512, out_features=num_classes)

    self.relu = nn.ReLU()

  def forward(self, x):
    """Map a (batch, 3, input_size, input_size) image batch to (batch, num_classes) logits."""
    # Feature extractor: basic block followed by pooling, three times.
    for block in (self.b1, self.b2, self.b3):
      x = self.pool(block(x))
    x = torch.flatten(x,start_dim=1)  # flatten to (batch, features)
    # Classifier: two hidden layers with ReLU, then raw class scores (no softmax).
    x = self.relu(self.fc1(x))
    x = self.relu(self.fc2(x))
    return self.fc3(x)

In [19]:
# 데이터 전처리
# 라이브러리 로드
import tqdm
from torchvision.datasets.cifar import CIFAR10
from torchvision.transforms import Compose, ToTensor, Resize
from torchvision.transforms import RandomHorizontalFlip,RandomCrop,Normalize
from torch.utils.data.dataloader import DataLoader
from torch.optim.adam import Adam

In [20]:
# training_data = CIFAR10(root = './', train=True, download=True, transform=ToTensor)
# rgb_m = training_data.data.mean(axis=(0,1,2)) / 255
# rgb_s = training_data.data.std(axis=(0,1,2)) / 255
# rgb_m, rgb_s

In [21]:
# Training-time preprocessing pipeline:
#   1. pad the 32x32 image by 4 pixels on each side and take a random 32x32 crop,
#   2. flip horizontally with probability 0.5,
#   3. convert the image to a tensor,
#   4. normalize each channel with the given mean / std (presumably the per-channel
#      CIFAR-10 training-set statistics from the commented-out cell above — verify).
transforms = Compose(
    [
        RandomCrop((32, 32), padding=4),
        RandomHorizontalFlip(p=0.5),
        ToTensor(),
        Normalize(
            mean=(0.49139968, 0.48215841, 0.44653091),
            std=(0.24703223, 0.24348513, 0.26158784),
        ),
    ]
)

In [22]:
# Define the CIFAR-10 datasets and data loaders.
# Training images go through `transforms` (random crop + horizontal flip + normalization).
# Test images must not be randomly augmented; otherwise accuracy is measured on randomly
# cropped/flipped images and changes from run to run. Keep only the deterministic steps
# (ToTensor + Normalize) of the training pipeline, reusing its normalization constants.
test_transforms = Compose([
    t for t in transforms.transforms
    if not isinstance(t, (RandomCrop, RandomHorizontalFlip))
])
training_data = CIFAR10(root = './', train=True, download=True, transform=transforms)
test_data = CIFAR10(root = './', train=False, download=True, transform=test_transforms)

train_loader = DataLoader(training_data, batch_size=64,shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [23]:
# Use the GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = 'cpu'
# Build the 10-class network and move its parameters to the chosen device.
model = ResNet(num_classes=10).to(device)
model

ResNet(
  (b1): BasicBlock(
    (c1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (c2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (downsample): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (b2): BasicBlock(
    (c1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (c2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (downsample): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
  )
  (b3): BasicBlock(
    (c1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))


In [24]:
# Sanity check: draw one batch from the training loader and inspect the image tensor's shape
# (batch, channels, height, width).
(a, b) = next(iter(train_loader))
a.size()

torch.Size([64, 3, 32, 32])

In [25]:
# lr = 1e-4
# optim = Adam(model.parameters(), lr=lr)
# data, label = next(iter(train_loader))
# optim.zero_grad()
# preds = model(data.to(device))
# loss = nn.CrossEntropyLoss()(preds, label.to(device))
# loss.backward()
# optim.step()

In [None]:
# Training loop: Adam optimizer with cross-entropy loss, 30 epochs over the training loader.
lr = 1e-4
optim = Adam(model.parameters(), lr=lr)
# The loss module holds no per-batch state, so build it once instead of on every iteration.
criterion = nn.CrossEntropyLoss()
# Ensure training mode (BatchNorm uses and updates batch statistics) even if model.eval()
# was called earlier in this kernel session, e.g. by the evaluation cell.
model.train()
for epoch in range(30):
  it = tqdm.tqdm(train_loader)
  for data, label in it:
    optim.zero_grad()
    preds = model(data.to(device))
    loss = criterion(preds, label.to(device))
    loss.backward()
    optim.step()

    it.set_description(f"epoch:{epoch+1} loss:{loss.item()}")


epoch:1 loss:1.0946922302246094: 100%|██████████| 782/782 [00:45<00:00, 17.36it/s]
epoch:2 loss:1.149419903755188: 100%|██████████| 782/782 [00:44<00:00, 17.77it/s]
epoch:3 loss:0.8546673059463501: 100%|██████████| 782/782 [00:43<00:00, 17.95it/s]
epoch:4 loss:0.6087036728858948: 100%|██████████| 782/782 [00:44<00:00, 17.66it/s]
epoch:5 loss:0.4584844708442688: 100%|██████████| 782/782 [00:44<00:00, 17.67it/s]
epoch:6 loss:0.23697921633720398: 100%|██████████| 782/782 [00:44<00:00, 17.63it/s]
epoch:7 loss:0.37208831310272217: 100%|██████████| 782/782 [00:44<00:00, 17.71it/s]
epoch:8 loss:0.6237072944641113: 100%|██████████| 782/782 [00:44<00:00, 17.77it/s]
epoch:9 loss:0.21885251998901367: 100%|██████████| 782/782 [00:44<00:00, 17.49it/s]
epoch:10 loss:0.2856753468513489: 100%|██████████| 782/782 [00:44<00:00, 17.68it/s]
epoch:11 loss:0.7609376907348633: 100%|██████████| 782/782 [00:44<00:00, 17.72it/s]
epoch:12 loss:0.23405851423740387: 100%|██████████| 782/782 [00:43<00:00, 17.79it/s

In [None]:
# Save only the model's learned parameters (its state_dict), not the whole module object.
# NOTE(review): assumes Google Drive is mounted at /content/drive in this Colab session.
checkpoint_path = "/content/drive/MyDrive/Colab Notebooks/CIFAR10_ResNet.pth"
torch.save(obj=model.state_dict(), f=checkpoint_path)

In [None]:
# Load the saved checkpoint and report top-1 accuracy on the CIFAR-10 test split.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source
# (such as the one saved by this notebook).
model.load_state_dict(torch.load('/content/drive/MyDrive/Colab Notebooks/CIFAR10_ResNet.pth',map_location = device))
# Inference mode: BatchNorm uses the running statistics learned during training instead of
# per-batch statistics, and stops overwriting those running statistics during evaluation.
model.eval()
num_corr = 0
with torch.no_grad():
  for data, label in test_loader:
    output = model(data.to(device))
    preds = output.argmax(dim=1)  # predicted class = index of the largest logit
    corr = (preds == label.to(device)).sum().item()
    num_corr += corr
print(f"accuracy : {num_corr / len(test_data)}")