## 라이브러리 import

In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
from torchsummary import summary
from math import sqrt

## 공용함수 선언

### 기본 공용함수 선언

In [4]:
def get_mean_std(channel, training_dataset):
  """Estimate per-channel normalization statistics from a dataset.

  For every sample the mean and standard deviation are reduced over axes
  (1, 2) of the image array (axis 0 is treated as the channel axis), and
  those per-image values are then averaged over the whole dataset. The
  returned std is therefore the mean of the per-image stds, not the std of
  all pixels pooled together.

  Args:
    channel: Colour mode, compared case-insensitively. 'rgb' yields
      statistics for channels 0, 1 and 2; any other value yields statistics
      for channel 0 only.
    training_dataset: Iterable of (image, label) pairs where image provides
      a .numpy() method returning an array with at least three axes.

  Returns:
    For 'rgb': ([mean_r, mean_g, mean_b], [std_r, std_g, std_b]). These
    lists are already one-value-per-channel, so pass them on as-is rather
    than wrapping them in another list.
    Otherwise: (mean, std) as two scalars.
  """
  # `==` compares string contents. The previous `is` compared object identity,
  # which is False for the new string returned by .lower(), so the RGB branch
  # could never run.
  is_rgb = channel.lower() == 'rgb'

  # Walk the dataset once and keep both statistics for every image.
  per_image_mean = []
  per_image_std = []
  for image, _ in training_dataset:
    pixels = image.numpy()
    per_image_mean.append(np.mean(pixels, axis=(1, 2)))
    per_image_std.append(np.std(pixels, axis=(1, 2)))

  def channel_average(stats, index):
    # Average one channel's per-image statistic over the dataset.
    return np.mean([s[index] for s in stats])

  if is_rgb:
    return ([channel_average(per_image_mean, c) for c in range(3)],
            [channel_average(per_image_std, c) for c in range(3)])
  return channel_average(per_image_mean, 0), channel_average(per_image_std, 0)

In [5]:
def get_device():
    """Return the 'cuda' device when torch reports CUDA as available, else the 'cpu' device."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')

## 모델 구현

### VGG16 Base 클래스

In [6]:
class VGG16Base(nn.Module):
  """VGG-16 style backbone that produces the two lowest-level SSD feature maps.

  Five groups of 3x3 convolutions (each followed by ReLU and a max-pool) are
  followed by a dilated 3x3 convolution (conv_6) and a 1x1 convolution
  (conv_7), both with 1024 output channels.

  Args:
    size_of_channel: Number of channels of the input image tensor.

  forward(x) returns a tuple (conv4_3_output, conv7_output):
    conv4_3_output: activation after conv_4_3 + ReLU, 512 channels, taken
      before the fourth pooling layer.
    conv7_output: activation after conv_7 + ReLU, 1024 channels.
  """

  def __init__(self, size_of_channel):
    super().__init__()
    self.size_of_channel = size_of_channel
    # conv 1
    self.conv_1_1 = nn.Conv2d(size_of_channel, 64, kernel_size=(3, 3), padding=(1, 1))
    self.conv_1_2 = nn.Conv2d(64, 64, kernel_size=(3, 3), padding=(1, 1))
    self.max_pooling_1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
    # conv 2
    self.conv_2_1 = nn.Conv2d(64, 128, kernel_size=(3, 3), padding=(1, 1))
    self.conv_2_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), padding=(1, 1))
    self.max_pooling_2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
    # conv 3
    self.conv_3_1 = nn.Conv2d(128, 256, kernel_size=(3, 3), padding=(1, 1))
    self.conv_3_2 = nn.Conv2d(256, 256, kernel_size=(3, 3), padding=(1, 1))
    self.conv_3_3 = nn.Conv2d(256, 256, kernel_size=(3, 3), padding=(1, 1))
    # ceil_mode rounds odd sizes up (e.g. 75 -> 38 instead of 37).
    self.max_pooling_3 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), ceil_mode=True)
    # conv 4
    self.conv_4_1 = nn.Conv2d(256, 512, kernel_size=(3, 3), padding=(1, 1))
    self.conv_4_2 = nn.Conv2d(512, 512, kernel_size=(3, 3), padding=(1, 1))
    self.conv_4_3 = nn.Conv2d(512, 512, kernel_size=(3, 3), padding=(1, 1))
    self.max_pooling_4 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
    # conv 5
    self.conv_5_1 = nn.Conv2d(512, 512, kernel_size=(3, 3), padding=(1, 1))
    self.conv_5_2 = nn.Conv2d(512, 512, kernel_size=(3, 3), padding=(1, 1))
    self.conv_5_3 = nn.Conv2d(512, 512, kernel_size=(3, 3), padding=(1, 1))
    # 3x3 pooling with stride 1 and padding 1 keeps the spatial size.
    self.max_pooling_5 = nn.MaxPool2d(kernel_size=(3, 3), stride=(1, 1), padding=1)

    # Dilated 3x3 conv; padding 6 with dilation 6 keeps the spatial size.
    self.conv_6 = nn.Conv2d(512, 1024, kernel_size=(3, 3), padding=6, dilation=6)
    self.conv_7 = nn.Conv2d(1024, 1024, kernel_size=(1, 1))

  def forward(self, x):
    y = F.relu(self.conv_1_1(x))
    y = F.relu(self.conv_1_2(y))
    y = self.max_pooling_1(y)

    y = F.relu(self.conv_2_1(y))
    y = F.relu(self.conv_2_2(y))
    y = self.max_pooling_2(y)

    y = F.relu(self.conv_3_1(y))
    y = F.relu(self.conv_3_2(y))
    y = F.relu(self.conv_3_3(y))
    y = self.max_pooling_3(y)

    y = F.relu(self.conv_4_1(y))
    y = F.relu(self.conv_4_2(y))
    y = F.relu(self.conv_4_3(y))
    # First feature map handed to the detection heads (taken before pooling).
    conv4_3_output = y
    y = self.max_pooling_4(y)

    y = F.relu(self.conv_5_1(y))
    y = F.relu(self.conv_5_2(y))
    y = F.relu(self.conv_5_3(y))
    y = self.max_pooling_5(y)

    # Without an activation between them, conv_6 and conv_7 would collapse
    # into a single linear map; both are rectified like every other conv here.
    y = F.relu(self.conv_6(y))
    conv7_output = F.relu(self.conv_7(y))
    return conv4_3_output, conv7_output

### SSD 클래스 모듈 - Auxiliary Conv

In [7]:
class AuxiliaryConv(nn.Module):
  """Extra feature layers (conv8 .. conv11) applied on top of a 1024-channel map.

  Every stage is a 1x1 convolution followed by a 3x3 convolution, each with a
  ReLU. Stages 8 and 9 downsample with stride 2 (padding 1); stages 10 and 11
  use an unpadded 3x3 convolution. forward(x) returns the four stage outputs
  as a tuple (conv8, conv9, conv10, conv11).
  """

  def __init__(self):
    super().__init__()
    # Stage 8: 1024 -> 256 -> 512, strided 3x3.
    self.conv_8_1 = nn.Conv2d(1024, 256, kernel_size=1)
    self.conv_8_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)

    # Stage 9: 512 -> 128 -> 256, strided 3x3.
    self.conv_9_1 = nn.Conv2d(512, 128, kernel_size=1)
    self.conv_9_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)

    # Stage 10: 256 -> 128 -> 256, unpadded 3x3.
    self.conv_10_1 = nn.Conv2d(256, 128, kernel_size=1)
    self.conv_10_2 = nn.Conv2d(128, 256, kernel_size=3)

    # Stage 11: 256 -> 128 -> 256, unpadded 3x3.
    self.conv_11_1 = nn.Conv2d(256, 128, kernel_size=1)
    self.conv_11_2 = nn.Conv2d(128, 256, kernel_size=3)

    self.__init_weights__()

  def forward(self, x):
    stages = (
      (self.conv_8_1, self.conv_8_2),
      (self.conv_9_1, self.conv_9_2),
      (self.conv_10_1, self.conv_10_2),
      (self.conv_11_1, self.conv_11_2),
    )
    stage_outputs = []
    current = x
    for squeeze, expand in stages:
      # 1x1 channel reduction, then the 3x3 convolution; ReLU after each.
      current = F.relu(expand(F.relu(squeeze(current))))
      stage_outputs.append(current)
    return tuple(stage_outputs)

  def __init_weights__(self):
    """Xavier-uniform weights and zero biases for every direct Conv2d child."""
    conv_layers = [layer for layer in self.children() if isinstance(layer, nn.Conv2d)]
    for layer in conv_layers:
      nn.init.xavier_uniform_(layer.weight)
      nn.init.constant_(layer.bias, 0)

### SSD 클래스 모듈 - Prediction Conv

In [8]:
# Number of boxes predicted at every spatial location of each feature map.
PREV_BOXES = {
    'conv4_3': 4,
    'conv7': 6,
    'conv8_2': 6,
    'conv9_2': 6,
    'conv10_2': 4,
    'conv11_2': 4
}
class PredictionConv(nn.Module):
  """Localization and class-score heads for the six feature maps.

  For every feature map there is one 3x3 convolution producing 4 values per
  box and one 3x3 convolution producing `number_of_class` values per box, with
  the number of boxes per location taken from PREV_BOXES.

  Args:
    number_of_class: Number of class scores predicted for each box.
  """

  def __init__(self, number_of_class):
    super().__init__()
    self.number_of_class = number_of_class

    # Input channels of each feature map, in the order the heads are registered.
    in_channels = {
      'conv4_3': 512,
      'conv7': 1024,
      'conv8_2': 512,
      'conv9_2': 256,
      'conv10_2': 256,
      'conv11_2': 256,
    }
    # All localization heads are created first, then all class heads.
    for name, channels in in_channels.items():
      setattr(self, 'localization_' + name,
              nn.Conv2d(channels, PREV_BOXES[name] * 4, kernel_size=3, padding=1))
    for name, channels in in_channels.items():
      setattr(self, 'class_prediction_' + name,
              nn.Conv2d(channels, PREV_BOXES[name] * number_of_class, kernel_size=3, padding=1))

    self.__init_weights__()

  def forward(self, features_4_3, features_7, features_8_2, features_9_2, features_10_2, features_11_2):
    """Run every head and concatenate the results along the box axis.

    Returns:
      A tuple (locs, class_scores) of shapes (batch, total_boxes, 4) and
      (batch, total_boxes, number_of_class).
    """
    n = features_4_3.size(0)
    sources = (
      ('conv4_3', features_4_3),
      ('conv7', features_7),
      ('conv8_2', features_8_2),
      ('conv9_2', features_9_2),
      ('conv10_2', features_10_2),
      ('conv11_2', features_11_2),
    )
    box_outputs = [
      self.__predict_box_bounds__(n, getattr(self, 'localization_' + name), fmap, 4)
      for name, fmap in sources
    ]
    score_outputs = [
      self.__predict_box_bounds__(n, getattr(self, 'class_prediction_' + name), fmap, self.number_of_class)
      for name, fmap in sources
    ]
    return torch.cat(box_outputs, dim=1), torch.cat(score_outputs, dim=1)

  def __predict_box_bounds__(self, batch_size, conv, input_features, shape):
    """Apply `conv` and reshape its output to (batch_size, -1, shape)."""
    # Move channels last so each location's values are contiguous before the view.
    channels_last = conv(input_features).permute(0, 2, 3, 1).contiguous()
    return channels_last.view(batch_size, -1, shape)

  def __init_weights__(self):
    """Xavier-uniform weights and zero biases for every direct Conv2d child."""
    conv_layers = [layer for layer in self.children() if isinstance(layer, nn.Conv2d)]
    for layer in conv_layers:
      nn.init.xavier_uniform_(layer.weight)
      nn.init.constant_(layer.bias, 0)

### SSD 클래스

In [9]:
class MySSD(nn.Module):
  """SSD-style network: VGG16Base backbone, AuxiliaryConv extras, PredictionConv heads.

  Args:
    size_of_channel: Number of channels of the input image, forwarded to the backbone.
    number_of_class: Number of class scores per box, forwarded to the prediction heads.
  """

  def __init__(self, size_of_channel, number_of_class):
    super().__init__()
    self.vgg16_base = VGG16Base(size_of_channel)
    self.extra_feature_layers = AuxiliaryConv()
    self.predict_conv = PredictionConv(number_of_class)

  def forward(self, x):
    """Return the (localization output, class scores) pair from the prediction heads."""
    # (conv4_3, conv7) from the backbone; the extras are built from conv7.
    backbone_maps = self.vgg16_base(x)
    extra_maps = self.extra_feature_layers(backbone_maps[1])
    return self.predict_conv(*backbone_maps, *extra_maps)

## CIFAR10

In [10]:
# Device used for the model below: CUDA when available, otherwise CPU (see get_device).
device = get_device()

In [11]:
# Load CIFAR10 with a plain ToTensor transform first, so the normalization
# statistics below are computed on un-normalized image tensors.
# NOTE(review): basic_path is not defined in the visible cells - confirm it is
# set earlier in the notebook.
training_dataset_cifar10 = datasets.CIFAR10(
    root=basic_path + '/data', train=True, download=True, transform=transforms.ToTensor(),
)
test_dataset_cifar10 = datasets.CIFAR10(
    root=basic_path + '/data', train=False, download=True, transform=transforms.ToTensor(),
)

rgb_mean, rgb_std = get_mean_std('rgb', training_dataset_cifar10)
# Normalize expects a flat sequence (one value per channel, or one value that
# is broadcast to every channel). Flatten whatever get_mean_std returned
# (scalars or per-channel lists) instead of wrapping it in another list.
normalize = transforms.Normalize(
  mean=np.atleast_1d(rgb_mean).tolist(),
  std=np.atleast_1d(rgb_std).tolist(),
)

# Training pipeline: random horizontal flip as augmentation, resize to 300.
rgb_transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.RandomHorizontalFlip(),
  transforms.Resize(300),
  normalize,
])
# Evaluation pipeline: identical preprocessing but without the random flip,
# so test results are deterministic.
rgb_test_transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Resize(300),
  normalize,
])
training_dataset_cifar10.transform = rgb_transform
test_dataset_cifar10.transform = rgb_test_transform

Files already downloaded and verified
Files already downloaded and verified


In [12]:
# Both loaders use batches of 64 samples and iterate in dataset order.
loader_options = dict(batch_size=64)
training_dataloader_cifar10 = DataLoader(training_dataset_cifar10, **loader_options)
test_dataloader_cifar10 = DataLoader(test_dataset_cifar10, **loader_options)

# Peek at the first test batch to check the tensor shapes and the label dtype.
first_batch = next(iter(test_dataloader_cifar10), None)
if first_batch is not None:
  X, y = first_batch
  print('Shape of X [N, C, H, W]:', X.shape)
  print('Shape of y:', y.shape, y.dtype)

Shape of X [N, C, H, W]: torch.Size([64, 3, 300, 300])
Shape of y: torch.Size([64]) torch.int64


In [25]:
# Build the SSD model for 3-channel input with 10 class scores per box and
# move it to the selected device.
number_of_cifar10_classes = 10
model_cifar10 = MySSD(size_of_channel=3, number_of_class=number_of_cifar10_classes).to(device)

In [26]:
# Print the per-layer output shapes and parameter counts for a 3x300x300 input.
summary(model_cifar10, (3, 300, 300))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 300, 300]           1,792
            Conv2d-2         [-1, 64, 300, 300]          36,928
         MaxPool2d-3         [-1, 64, 150, 150]               0
            Conv2d-4        [-1, 128, 150, 150]          73,856
            Conv2d-5        [-1, 128, 150, 150]         147,584
         MaxPool2d-6          [-1, 128, 75, 75]               0
            Conv2d-7          [-1, 256, 75, 75]         295,168
            Conv2d-8          [-1, 256, 75, 75]         590,080
            Conv2d-9          [-1, 256, 75, 75]         590,080
        MaxPool2d-10          [-1, 256, 38, 38]               0
           Conv2d-11          [-1, 512, 38, 38]       1,180,160
           Conv2d-12          [-1, 512, 38, 38]       2,359,808
           Conv2d-13          [-1, 512, 38, 38]       2,359,808
        MaxPool2d-14          [-1, 512,