In [1]:
import os

import numpy as np

import torch
import torchvision

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [2]:
BATCH_SIZE = 10

In [3]:
base_path = os.environ['HOME'] + '/cifar100/test'
if os.path.isdir(base_path):
  pass
else:
  os.makedirs(base_path)

In [4]:
writer = SummaryWriter(log_dir=base_path + '/runs/test')

In [5]:
transform = transforms.Compose(
    [transforms.Pad(padding=(2, 2, 2, 2)), 
     transforms.RandomCrop(size=32),
     torchvision.transforms.RandomHorizontalFlip(p=0.5),
     torchvision.transforms.Resize(size=[224, 224]),
     transforms.ToTensor(),
     transforms.Normalize(
       mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
train = torchvision.datasets.CIFAR100(
  root=base_path + '/data', train=True, download=True, transform=transform) 
test = torchvision.datasets.CIFAR100(
  root=base_path + '/data', train=False, download=True, transform=transform) 

trainlist = torch.utils.data.random_split(train, [40000, 10000])
train, val = trainlist[0], trainlist[1]

trainloader = torch.utils.data.DataLoader(
  train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
valloader = torch.utils.data.DataLoader(
  val, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(
  test, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to /Users/shinjisoo/cifar100/test/data/cifar-100-python.tar.gz


100.0%

Extracting /Users/shinjisoo/cifar100/test/data/cifar-100-python.tar.gz to /Users/shinjisoo/cifar100/test/data
Files already downloaded and verified


In [6]:
# ResNet
class ResNet34(nn.Module):
  """
  input tensor shape. (batch, 3, 224, 224)
  output logit shape. (batch, 100)
  """
  def __init__(self):
    super(ResNet34, self).__init__()
    self.conv1 = nn.Conv2d(
      in_channels=3, out_channels=64, kernel_size=(7, 7), stride=2)
    # batch x 64 x 109 x 109
    self.maxpool = nn.MaxPool2d(kernel_size=(3, 3), stride=2)
    # batch x 64 x 54 x 54
 
    
    self.idt2 = nn.Identity()
    # batch x 64 x 54 x 54
    self.conv2list = [
      nn.Conv2d(64, 64, (3, 3), padding=1) for x in range(2)]
    # batch x 64 x 54 x 54
    
    self.conv2to3 = nn.Conv2d(64, 128, (3, 3), padding=1)
    # batch x 128 x 54 x 54 
    
    self.idt3 = nn.Identity()
    # batch x 128 x 54 x 54
    self.conv3list = [
      nn.Conv2d(128, 128, (3, 3), padding=1) for x in range(2)]
    # batch x 128 x 54 x 54
    self.conv3to4 = nn.Conv2d(128, 256, (3, 3), padding=1)
    # batch x 256 x 54 x 54
    
    self.idt4 = nn.Identity()
    # batch x 256 x 54 x 54
    self.conv4list = [
      nn.Conv2d(256, 256, (3, 3), padding=1) for x in range(2)]
    # batch x 256 x 54 x 54
    self.conv4to5 = nn.Conv2d(256, 512, (3, 3))
    # batch x 512 x 54 x 54
    
    self.idt5 = nn.Identity()
    # batch x 512 x 54 x 54
    self.conv5list = [
      nn.Conv2d(512, 512, (3, 3), padding=1) for x in range(2)]
    # batch x 512 x 54 x 54

    self.batchnorm1 = nn.BatchNorm2d(num_features=64)
    self.batchnorm2 = nn.BatchNorm2d(num_features=128)
    self.batchnorm3 = nn.BatchNorm2d(num_features=256)
    self.batchnorm4 = nn.BatchNorm2d(num_features=512)
    
    # global average pooling 에서 나오는 최종 output은 [batch, 512]가 되도록
    self.globalavgpool = nn.AvgPool2d(kernel_size=(52, 52)) 
    self.fc = nn.Linear(512, 100)
    self.softmax = nn.Softmax()
    
  def forward(self, x):
    def _list_to_layer(x, layer_list, batchnorm_obj):
      for val in layer_list:
        x = val(x)
        x = batchnorm_obj(x)
        x = F.relu(x)
      return x
    # 1st conv, maxpool
    x = self.conv1(x)
    x = F.relu(self.batchnorm1(x))
    x = self.maxpool(x)

    # 2nd residual block
    idt_2_1 = self.idt2(x)
    x = _list_to_layer(x, self.conv2list, self.batchnorm1)
    x = idt_2_1 + x
    idt_2_2 = self.idt2(x)
    x = _list_to_layer(x, self.conv2list, self.batchnorm1)
    x = idt_2_2 + x
    idt_2_3 = self.idt2(x)
    x = _list_to_layer(x, self.conv2list, self.batchnorm1)
    x = idt_2_3 + x
    x = self.conv2to3(x)
    
    # 3rd redsidual block
    idt_3_1 = self.idt3(x)
    x = _list_to_layer(x, self.conv3list, self.batchnorm2)
    x = idt_3_1 + x
    idt_3_2 = self.idt3(x)
    x = _list_to_layer(x, self.conv3list, self.batchnorm2)
    x = idt_3_2 + x
    idt_3_3 = self.idt3(x)
    x = _list_to_layer(x, self.conv3list, self.batchnorm2)
    x = idt_3_3 + x
    idt_3_4 = self.idt3(x)
    x = _list_to_layer(x, self.conv3list, self.batchnorm2)
    x = idt_3_4 + x
    x = self.conv3to4(x)
    
    # 4th residual block
    idt_4_1 = self.idt4(x)
    x = _list_to_layer(x, self.conv4list, self.batchnorm3)
    x = idt_4_1 + x
    idt_4_2 = self.idt4(x)
    x = _list_to_layer(x, self.conv4list, self.batchnorm3)
    x = idt_4_2 + x
    idt_4_3 = self.idt4(x)
    x = _list_to_layer(x, self.conv4list, self.batchnorm3)
    x = idt_4_3 + x
    idt_4_4 = self.idt4(x)
    x = _list_to_layer(x, self.conv4list, self.batchnorm3)
    x = idt_4_4 + x
    idt_4_5 = self.idt4(x)
    x = _list_to_layer(x, self.conv4list, self.batchnorm3)
    x = idt_4_5 + x
    idt_4_6 = self.idt4(x)
    x = _list_to_layer(x, self.conv4list, self.batchnorm3)
    x = idt_4_6 + x
    x = self.conv4to5(x)
    
    # 5th residual block
    idt_5_1 = self.idt5(x)
    x = _list_to_layer(x, self.conv5list, self.batchnorm4)
    x = idt_5_1 + x
    idt_5_2 = self.idt5(x)
    x = _list_to_layer(x, self.conv5list, self.batchnorm4)
    x = idt_5_2 + x
    idt_5_3 = self.idt5(x)
    x = _list_to_layer(x, self.conv5list, self.batchnorm4)
    x = idt_5_3 + x
    
    x = self.globalavgpool(x).view([-1, 512])
    x = self.fc(x)
    x = self.softmax(x)
    
    return x

In [7]:
resnet = ResNet34()

In [49]:
resnet

ResNet34(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2))
  (maxpool): MaxPool2d(kernel_size=(3, 3), stride=2, padding=0, dilation=1, ceil_mode=False)
  (idt2): Identity()
  (conv2to3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (idt3): Identity()
  (conv3to4): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (idt4): Identity()
  (conv4to5): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
  (idt5): Identity()
  (batchnorm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (batchnorm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (batchnorm3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (batchnorm4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (globalavgpool): AvgPool2d(kernel_size=(52, 52), stride=(52, 52), padding=0)
  (fc): Linear(in_features=512, out_features=100, bias=Tru

In [None]:
hypothesis = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet.parameters())

In [None]:
for epoch in range(10):
  running_loss = 0.
  for i, data in enumerate(trainloader):
    inputs, labels = data
    optimizer.zero_grad()
    logits = resnet(inputs)
    loss = hypothesis(logits, labels)
    loss.backward()
    optimizer.step()
    
    running_loss += loss.item()
    
    if i % 10 == 9:
      val_inputs, val_labels = next(iter(valloader))
      val_logits = resnet(val_inputs)
      val_loss = hypothesis(val_logits, val_labels)
      correct_case = (val_logits.max(dim=1)[1] == val_labels).sum()
      all_case = val_labels.shape[0]
      accuracy = (correct_case/all_case)*100
      print(
        "Epoch: {} / step: {}".format(epoch, i), 
        "train loss:  {0:.3f}".format(loss), "\t",
        "Epoch: {} / step: {}".format(epoch, i),
        "validation accuracy: {0:.3f}%".format(accuracy))
      writer.add_scalar(
        'training loss', loss.item(), epoch*len(trainloader)+i)
      writer.add_scalar(
        'validation loss', val_loss.item(), epoch*len(trainloader)+i)
      writer.add_scalar(
        'accuracy', accuracy, epoch*len(trainloader)+i)
      
      torch.save({'epoch': epoch, 'model_state_dict': resnet.state_dict(),
          'optimizer_state_dict': optimizer.state_dict(),
          'loss': loss}, base_path + "/resnet34_{}.pt".format(epoch))
      running_loss = 0.
print("Finish")