- https://github.com/ndb796/Deep-Learning-Paper-Review-and-Practice/blob/master/code_practices/ResNet18_MNIST_Train.ipynb

In [1]:
# ResNet18 모델 정의 및 인스턴스 초기화
import torch
import torch.nn as nn 
import torch.nn.functional as F 
import torch.backends.cudnn as cudnn 
import torch.optim as optim 
import os 

from torch.utils.data import DataLoader

In [2]:
# ResNet18을 위해 최대한 간단히 수정한 BasicBlock 클래스 정의 
class BasicBlock(nn.Module):
  def __init__(self, in_planes, planes, stride=1):
    super(BasicBlock, self).__init__()

    # 3x3 필터를 사용(너비와 높이를 줄일 때는 stride 값 조절)
    self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(planes) # 배치 정규화(Batch normalization)

    # 3x3 필터를 사용(패딩을 1만큼 주기 때문에 너비와 높이가 동일)
    self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(planes) # 배치 정규화(Batch normalization)

    self.shortcut = nn.Sequential() # identity인 경우
    if stride != 1: # stride가 1이 아니라면, Identity mapping이 아닌 경우
      self.shortcut = nn.Sequential(
        nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
        nn.BatchNorm2d(planes)
      )

  def forward(self, x):
    out = F.relu(self.bn1(self.conv1(x)))
    out = self.bn2(self.conv2(out))
    out += self.shortcut(x) # (핵심) skip connection
    out = F.relu(out)
    return out

In [22]:
# ResNet 클래스 정의
class ResNet(nn.Module):
  def _make_layer(self, block, planes, num_blocks, stride):
    strides = [stride] + [1] * (num_blocks -1)
    layers = []

    for _stride in strides:
      layers.append(block(self.in_planes, planes, _stride))
      self.in_planes = planes # 다음 레이어를 위해 채널 수 변경
    
    return nn.Sequential(*layers)
  
  def __init__(self, block, num_blocks, num_classes=10):
    super(ResNet, self).__init__()
    self.in_planes = 64

    # 64개의 3x3 필터(filter)를 사용
    self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(64)
    self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
    self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
    self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
    self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
    self.linear = nn.Linear(512, num_classes)

  def forward(self, x):
    print(f'x: {x.size()}')
    out = F.relu(self.bn1(self.conv1(x)))
    out = self.layer1(out)
    out = self.layer2(out)
    out = self.layer3(out)
    out = self.layer4(out)
    out = F.avg_pool2d(out, 4)
    out = out.view(out.size(0), -1)
    print(f'out: {out.size()}')
    out = self.linear(out)

    return out

In [23]:
# ResNet 18 함수 정의
def ResNet18():
  return ResNet(BasicBlock, [2, 2, 2, 2])

In [31]:
device = torch.device(
  'cuda' if torch.cuda.is_available() else 'cpu'
)
model = ResNet18().to(device)
x = torch.randn(128, 1, 28, 28).to(device)
output = model(x)
print(output.size())

x: torch.Size([128, 1, 28, 28])
out: torch.Size([128, 512])
torch.Size([128, 10])


In [32]:
from torchsummary import summary 

In [33]:
summary(model, (1, 28, 28), device=device.type)

x: torch.Size([2, 1, 28, 28])
out: torch.Size([2, 512])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 28, 28]             576
       BatchNorm2d-2           [-1, 64, 28, 28]             128
            Conv2d-3           [-1, 64, 28, 28]          36,864
       BatchNorm2d-4           [-1, 64, 28, 28]             128
            Conv2d-5           [-1, 64, 28, 28]          36,864
       BatchNorm2d-6           [-1, 64, 28, 28]             128
        BasicBlock-7           [-1, 64, 28, 28]               0
            Conv2d-8           [-1, 64, 28, 28]          36,864
       BatchNorm2d-9           [-1, 64, 28, 28]             128
           Conv2d-10           [-1, 64, 28, 28]          36,864
      BatchNorm2d-11           [-1, 64, 28, 28]             128
       BasicBlock-12           [-1, 64, 28, 28]               0
           Conv2d-13          [-1, 128, 14, 14]

In [25]:
# 데이터셋 다운로드 및 불러오기
import torchvision 
import torchvision.transforms as transforms 

transform_train = transforms.Compose([
  transforms.ToTensor()
])

transform_test = transforms.Compose([
  transforms.ToTensor()
])

train_dataset = torchvision.datasets.MNIST(root='../../data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.MNIST(root='../../data', train=False, download=True, transform=transform_test)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=4)

In [26]:
# 환경 설정 및 학습 함수 정의
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net = ResNet18()
net = net.to(device)
net = torch.nn.DataParallel(net)
cudnn.benchmark = True 

learning_rate = 0.01
file_name = 'resnet18_mnist.pt'

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0002)


In [27]:
def train(epoch):
  print('\n[ Train epoch: %d ]' % epoch)
  net.train()
  train_loss = 0
  correct = 0
  total = 0

  for batch_idx, (inputs, targets) in enumerate(train_loader):
    inputs, targets = inputs.to(device), targets.to(device)
    optimizer.zero_grad()

    benign_outputs = net(inputs)
    loss = criterion(benign_outputs, targets)
    loss.backward()

    optimizer.step()
    train_loss += loss.item()
    _, predicted = benign_outputs.max(1)

    total += targets.size(0)
    correct += predicted.eq(targets).sum().item()

    if batch_idx % 100 == 0:
      print('\nCurrent batch: ', str(batch_idx))
      print('Current benign train accuracy: ', str(predicted.eq(targets).sum().item() / targets.size(0)))
      print('Current benign train loss: ', loss.item())

  print('\nTotal benign train accuarcy: ', 100. * correct / total)
  print('Total benign train loss: ', train_loss)

In [28]:
def test(epoch):
  print('\n[ Test epoch: %d' % epoch)
  net.eval()
  loss = 0
  correct = 0
  total = 0

  for batch_idx, (inputs, targets) in enumerate(test_loader):
    inputs, targets = inputs.to(device), targets.to(device)
    total += targets.size(0)

    outputs = net(inputs)
    loss += criterion(outputs, targets).item()

    _, predicted = outputs.max(1)
    correct += predicted.eq(targets).sum().item()

  print('\nTest accuarcy: ', 100. * correct / total)
  print('Test average loss: ', loss / total)

  state = {
    'net': net.state_dict()
  }
  if not os.path.isdir('checkpoint'):
    os.mkdir('checkpoint')
  torch.save(state, './checkpoint/'+file_name)
  print('Model Saved')

In [29]:
def adjust_learning_rate(optimizer, epoch):
  lr = learning_rate
  if epoch >= 5:
    lr /= 10
  
  for param_group in optimizer.param_groups:
    param_group['lr'] = lr

In [30]:
# 학습 진행
# MNIST 데이터셋에 대하여 전체 10 epoch로 99.5% test accuracy를 시연할 수 있습니다.
for epoch in range(0, 10):
  adjust_learning_rate(optimizer, epoch)
  train(epoch)
  test(epoch)


[ Train epoch: 0 ]
x: torch.Size([128, 1, 28, 28])
out: torch.Size([128, 512])

Current batch:  0
Current benign train accuracy:  0.140625
Current benign train loss:  2.3682894706726074
x: torch.Size([128, 1, 28, 28])
out: torch.Size([128, 512])
x: torch.Size([128, 1, 28, 28])
out: torch.Size([128, 512])
x: torch.Size([128, 1, 28, 28])
out: torch.Size([128, 512])
x: torch.Size([128, 1, 28, 28])
out: torch.Size([128, 512])
x: torch.Size([128, 1, 28, 28])
out: torch.Size([128, 512])
x: torch.Size([128, 1, 28, 28])
out: torch.Size([128, 512])
x: torch.Size([128, 1, 28, 28])
out: torch.Size([128, 512])
x: torch.Size([128, 1, 28, 28])
out: torch.Size([128, 512])


Traceback (most recent call last):
  File "/Users/gyoungwon-cho/.pyenv/versions/3.9-dev/lib/python3.9/multiprocessing/queues.py", line 251, in _feed
    send_bytes(obj)
  File "/Users/gyoungwon-cho/.pyenv/versions/3.9-dev/lib/python3.9/multiprocessing/connection.py", line 205, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/Users/gyoungwon-cho/.pyenv/versions/3.9-dev/lib/python3.9/multiprocessing/connection.py", line 416, in _send_bytes
    self._send(header + buf)
  File "/Users/gyoungwon-cho/.pyenv/versions/3.9-dev/lib/python3.9/multiprocessing/connection.py", line 373, in _send
    n = write(self._handle, buf)
Traceback (most recent call last):
Traceback (most recent call last):
  File "/Users/gyoungwon-cho/.pyenv/versions/3.9-dev/lib/python3.9/multiprocessing/queues.py", line 251, in _feed
    send_bytes(obj)
  File "/Users/gyoungwon-cho/.pyenv/versions/3.9-dev/lib/python3.9/multiprocessing/connection.py", line 205, in send_bytes
    self._send_bytes(m[offset:o

KeyboardInterrupt: 