In [13]:
# Train ResNet-18 on CIFAR-10
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from torch import Tensor
from typing import Optional, Callable
from torchsummary import summary

In [14]:
# settings
seed = 42

torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

if torch.cuda.is_available():
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)

use_cuda = True
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")

# Hyperparameters
batch_size = 1000
learning_rate = 0.1
epochs = 20

In [15]:
# dataset loader
data_path = '../data/'

train_data_cifar10 = datasets.CIFAR10(data_path, train=True, download=True, transform=transforms.ToTensor())
test_data_cifar10 = datasets.CIFAR10(data_path, train=False, download=True, transform=transforms.ToTensor())

len(train_data_cifar10), len(test_data_cifar10)

Files already downloaded and verified
Files already downloaded and verified


(50000, 10000)

In [16]:
# data normalization
train_data_cifar10_mean = train_data_cifar10.data.mean(axis=(0,1,2))/255
train_data_cifar10_std = train_data_cifar10.data.std(axis=(0,1,2))/255

test_data_cifar10_mean = test_data_cifar10.data.mean(axis=(0,1,2))/255
test_data_cifar10_std = test_data_cifar10.data.std(axis=(0,1,2))/255

print(train_data_cifar10_mean, train_data_cifar10_std)
print(test_data_cifar10_mean, test_data_cifar10_std)

[0.49139968 0.48215841 0.44653091] [0.24703223 0.24348513 0.26158784]
[0.49421428 0.48513139 0.45040909] [0.24665252 0.24289226 0.26159238]


In [17]:
# data loader

# training data
train_data_cifar10.transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(train_data_cifar10_mean, train_data_cifar10_std)
])
train_loader = torch.utils.data.DataLoader(train_data_cifar10, batch_size=batch_size, shuffle=True)

# test data
test_data_cifar10.transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(train_data_cifar10_mean, train_data_cifar10_std)
])
test_loader = torch.utils.data.DataLoader(test_data_cifar10, batch_size=batch_size, shuffle=True)


In [18]:
class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # BatchNorm에 bias가 포함되어 있으므로, conv2d는 bias=False로 설정합니다.
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels * BasicBlock.expansion),
        )

        # identity mapping, input과 output의 feature map size, filter 수가 동일한 경우 사용.
        self.shortcut = nn.Sequential()

        self.relu = nn.ReLU()

        # projection mapping using 1x1conv
        if stride != 1 or in_channels != BasicBlock.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * BasicBlock.expansion)
            )

    def forward(self, x):
        x = self.residual_function(x) + self.shortcut(x)
        x = self.relu(x)
        return x


class BottleNeck(nn.Module):
    expansion = 4
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels * BottleNeck.expansion),
        )

        self.shortcut = nn.Sequential()

        self.relu = nn.ReLU()

        if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels*BottleNeck.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels*BottleNeck.expansion)
            )
            
    def forward(self, x):
        x = self.residual_function(x) + self.shortcut(x)
        x = self.relu(x)
        return x

In [19]:
class ResNet(nn.Module):
    def __init__(self, block, num_block, num_classes=10, init_weights=True):
        super().__init__()

        self.in_channels=64

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # weights inittialization
        if init_weights:
            self._initialize_weights()

    def _make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*layers)

    def forward(self,x):
        output = self.conv1(x)
        output = self.conv2_x(output)
        x = self.conv3_x(output)
        x = self.conv4_x(x)
        x = self.conv5_x(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    # define weight initialization function
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

def resnet18():
    return ResNet(BasicBlock, [2,2,2,2])

In [20]:
model = resnet18().to(device)
x = torch.randn(3, 3, 32, 32).to(device)
output = model(x)
print(output.size())

torch.Size([3, 10])


In [21]:
summary(model, (3, 32, 32), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 16, 16]           9,408
       BatchNorm2d-2           [-1, 64, 16, 16]             128
              ReLU-3           [-1, 64, 16, 16]               0
         MaxPool2d-4             [-1, 64, 8, 8]               0
            Conv2d-5             [-1, 64, 8, 8]          36,864
       BatchNorm2d-6             [-1, 64, 8, 8]             128
              ReLU-7             [-1, 64, 8, 8]               0
            Conv2d-8             [-1, 64, 8, 8]          36,864
       BatchNorm2d-9             [-1, 64, 8, 8]             128
             ReLU-10             [-1, 64, 8, 8]               0
       BasicBlock-11             [-1, 64, 8, 8]               0
           Conv2d-12             [-1, 64, 8, 8]          36,864
      BatchNorm2d-13             [-1, 64, 8, 8]             128
             ReLU-14             [-1, 6

In [22]:
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4, nesterov=True)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=3, threshold=0.001)

# training
for epoch in range(epochs):
    for i, data in enumerate(train_loader):
        # get the inputs
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()
        
        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # print statistics
    print('Epoch: {}/{}, Loss: {:.3f}'.format(epoch+1, epochs, loss.item()))
    lr_scheduler.step(loss.item())

print('Finished Training')


Epoch: 1/20, Loss: 1.996
Epoch: 2/20, Loss: 1.576
Epoch: 3/20, Loss: 1.394
Epoch: 4/20, Loss: 1.124
Epoch: 5/20, Loss: 1.055
Epoch: 6/20, Loss: 0.799
Epoch: 7/20, Loss: 0.772
Epoch: 8/20, Loss: 0.699
Epoch: 9/20, Loss: 0.701
Epoch: 10/20, Loss: 0.557
Epoch: 11/20, Loss: 0.582
Epoch: 12/20, Loss: 0.534
Epoch: 13/20, Loss: 0.561
Epoch: 14/20, Loss: 0.529
Epoch: 15/20, Loss: 0.581
Epoch: 16/20, Loss: 0.584
Epoch: 17/20, Loss: 0.469
Epoch: 18/20, Loss: 0.566
Epoch: 19/20, Loss: 0.513
Epoch: 20/20, Loss: 0.580
Finished Training


In [25]:
# test data
test_loader = torch.utils.data.DataLoader(test_data_cifar10, batch_size=batch_size, shuffle=True)

# test model
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:            
            data, target = data.to(device=device), target.to(device=device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            
    test_loss /= len(test_loader.dataset)
    test_accuracy = 100. * correct / len(test_loader.dataset)
    return test_loss, test_accuracy

test_loss, test_accuracy = test(model, test_loader)
print('Test loss: {:.4f}, Accuracy: {:.2f}%'.format(test_loss, test_accuracy))


Test loss: 0.9874, Accuracy: 68.74%


In [24]:
# save model
torch.save(model.state_dict(), '../model/ResNet-18.pt')