In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import os
import glob
import PIL
from PIL import Image
from torch.utils import data as D
from torch.utils.data.sampler import SubsetRandomSampler
import random
import torchsummary

In [2]:
print(torch.__version__)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

1.0.0
cuda:0


In [3]:
batch_size = 64
validation_ratio = 0.1
random_seed = 10
initial_lr = 0.1
num_epoch = 100

In [26]:
transform_train = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

transform_validation = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])


transform_test = transforms.Compose([
        transforms.Resize(224),     
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)

validset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_validation)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)


num_train = len(trainset)
indices = list(range(num_train))
split = int(np.floor(validation_ratio * num_train))

np.random.seed(random_seed)
np.random.shuffle(indices)

train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, sampler=train_sampler, num_workers=0
)

valid_loader = torch.utils.data.DataLoader(
    validset, batch_size=batch_size, sampler=valid_sampler, num_workers=0
)

test_loader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=0
)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [27]:
class conv_bn_relu(nn.Module):
    def __init__(self, nin, nout, kernel_size, stride, padding, bias=False):
        super(conv_bn_relu, self).__init__()
        self.conv = nn.Conv2d(nin, nout, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        self.batch_norm = nn.BatchNorm2d(nout)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        out = self.conv(x)
        out = self.batch_norm(out)
        out = self.relu(out)

        return out

In [28]:
class Transition_layer(nn.Sequential):
  def __init__(self, nin, theta=1):    
      super(Transition_layer, self).__init__()
      
      self.add_module('conv_1x1', conv_bn_relu(nin=nin, nout=int(nin*theta), kernel_size=1, stride=1, padding=0, bias=False))
      self.add_module('avg_pool_2x2', nn.AvgPool2d(kernel_size=2, stride=2, padding=0))

In [29]:
class StemBlock(nn.Module):
    def __init__(self):
        super(StemBlock, self).__init__()
        
        self.conv_3x3_first = conv_bn_relu(nin=3, nout=32, kernel_size=3, stride=2, padding=1, bias=False)
        
        self.conv_1x1_left = conv_bn_relu(nin=32, nout=16, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_3x3_left = conv_bn_relu(nin=16, nout=32, kernel_size=3, stride=2, padding=1, bias=False)
        
        self.max_pool_right = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        
        self.conv_1x1_last = conv_bn_relu(nin=64, nout=32, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x):
        out_first = self.conv_3x3_first(x)
        
        out_left = self.conv_1x1_left(out_first)
        out_left = self.conv_3x3_left(out_left)
        
        out_right = self.max_pool_right(out_first)
        
        out_middle = torch.cat((out_left, out_right), 1)
        
        out_last = self.conv_1x1_last(out_middle)
                
        return out_last

In [30]:
class dense_layer(nn.Module):
  def __init__(self, nin, growth_rate, drop_rate=0.2):    
      super(dense_layer, self).__init__()
      
      self.dense_left_way = nn.Sequential()
      
      self.dense_left_way.add_module('conv_1x1', conv_bn_relu(nin=nin, nout=growth_rate*2, kernel_size=1, stride=1, padding=0, bias=False))
      self.dense_left_way.add_module('conv_3x3', conv_bn_relu(nin=growth_rate*2, nout=growth_rate//2, kernel_size=3, stride=1, padding=1, bias=False))
            
      self.dense_right_way = nn.Sequential()
      
      self.dense_right_way.add_module('conv_1x1', conv_bn_relu(nin=nin, nout=growth_rate*2, kernel_size=1, stride=1, padding=0, bias=False))
      self.dense_right_way.add_module('conv_3x3_1', conv_bn_relu(nin=growth_rate*2, nout=growth_rate//2, kernel_size=3, stride=1, padding=1, bias=False))
      self.dense_right_way.add_module('conv_3x3 2', conv_bn_relu(nin=growth_rate//2, nout=growth_rate//2, kernel_size=3, stride=1, padding=1, bias=False))
      
      self.drop_rate = drop_rate
      
  def forward(self, x):
      left_output = self.dense_left_way(x)
      right_output = self.dense_right_way(x)

      if self.drop_rate > 0:
          left_output = F.dropout(left_output, p=self.drop_rate, training=self.training)
          right_output = F.dropout(right_output, p=self.drop_rate, training=self.training)
          
      dense_layer_output = torch.cat((x, left_output, right_output), 1)
            
      return dense_layer_output

In [31]:
class DenseBlock(nn.Sequential):
  def __init__(self, nin, num_dense_layers, growth_rate, drop_rate=0.0):
      super(DenseBlock, self).__init__()
                        
      for i in range(num_dense_layers):
          nin_dense_layer = nin + growth_rate * i
          self.add_module('dense_layer_%d' % i, dense_layer(nin=nin_dense_layer, growth_rate=growth_rate, drop_rate=drop_rate))

In [32]:
class PeleeNet(nn.Module):
    def __init__(self, growth_rate=32, num_dense_layers=[3,4,8,6], theta=1, drop_rate=0.0, num_classes=10):
        super(PeleeNet, self).__init__()
        
        assert len(num_dense_layers) == 4
        
        self.features = nn.Sequential()
        self.features.add_module('StemBlock', StemBlock())
        
        nin_transition_layer = 32
        
        for i in range(len(num_dense_layers)):
            self.features.add_module('DenseBlock_%d' % (i+1), DenseBlock(nin=nin_transition_layer, num_dense_layers=num_dense_layers[i], growth_rate=growth_rate, drop_rate=0.0))
            nin_transition_layer +=  num_dense_layers[i] * growth_rate
            
            if i == len(num_dense_layers) - 1:
                self.features.add_module('Transition_layer_%d' % (i+1), conv_bn_relu(nin=nin_transition_layer, nout=int(nin_transition_layer*theta), kernel_size=1, stride=1, padding=0, bias=False))
            else:
                self.features.add_module('Transition_layer_%d' % (i+1), Transition_layer(nin=nin_transition_layer, theta=1))
        
        self.linear = nn.Linear(nin_transition_layer, num_classes)
        
    def forward(self, x):
        stage_output = self.features(x)
        
        global_avg_pool_output = F.adaptive_avg_pool2d(stage_output, (1, 1))  
        global_avg_pool_output_flat = global_avg_pool_output.view(global_avg_pool_output.size(0), -1)
                
        output = self.linear(global_avg_pool_output_flat)
        
        return output

In [33]:
net = PeleeNet()

In [34]:
net.to(device)

PeleeNet(
  (features): Sequential(
    (StemBlock): StemBlock(
      (conv_3x3_first): conv_bn_relu(
        (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (batch_norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
      )
      (conv_1x1_left): conv_bn_relu(
        (conv): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (batch_norm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
      )
      (conv_3x3_left): conv_bn_relu(
        (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (batch_norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
      )
      (max_pool_right): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (conv_1x1_last): conv_bn_relu(
   

In [35]:
torchsummary.summary(net, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 112, 112]             864
       BatchNorm2d-2         [-1, 32, 112, 112]              64
              ReLU-3         [-1, 32, 112, 112]               0
      conv_bn_relu-4         [-1, 32, 112, 112]               0
            Conv2d-5         [-1, 16, 112, 112]             512
       BatchNorm2d-6         [-1, 16, 112, 112]              32
              ReLU-7         [-1, 16, 112, 112]               0
      conv_bn_relu-8         [-1, 16, 112, 112]               0
            Conv2d-9           [-1, 32, 56, 56]           4,608
      BatchNorm2d-10           [-1, 32, 56, 56]              64
             ReLU-11           [-1, 32, 56, 56]               0
     conv_bn_relu-12           [-1, 32, 56, 56]               0
        MaxPool2d-13           [-1, 32, 56, 56]               0
           Conv2d-14           [-1, 32,

     BatchNorm2d-125           [-1, 16, 28, 28]              32
            ReLU-126           [-1, 16, 28, 28]               0
    conv_bn_relu-127           [-1, 16, 28, 28]               0
     dense_layer-128          [-1, 192, 28, 28]               0
          Conv2d-129           [-1, 64, 28, 28]          12,288
     BatchNorm2d-130           [-1, 64, 28, 28]             128
            ReLU-131           [-1, 64, 28, 28]               0
    conv_bn_relu-132           [-1, 64, 28, 28]               0
          Conv2d-133           [-1, 16, 28, 28]           9,216
     BatchNorm2d-134           [-1, 16, 28, 28]              32
            ReLU-135           [-1, 16, 28, 28]               0
    conv_bn_relu-136           [-1, 16, 28, 28]               0
          Conv2d-137           [-1, 64, 28, 28]          12,288
     BatchNorm2d-138           [-1, 64, 28, 28]             128
            ReLU-139           [-1, 64, 28, 28]               0
    conv_bn_relu-140           [-1, 64, 

            ReLU-253           [-1, 16, 14, 14]               0
    conv_bn_relu-254           [-1, 16, 14, 14]               0
          Conv2d-255           [-1, 16, 14, 14]           2,304
     BatchNorm2d-256           [-1, 16, 14, 14]              32
            ReLU-257           [-1, 16, 14, 14]               0
    conv_bn_relu-258           [-1, 16, 14, 14]               0
     dense_layer-259          [-1, 384, 14, 14]               0
          Conv2d-260           [-1, 64, 14, 14]          24,576
     BatchNorm2d-261           [-1, 64, 14, 14]             128
            ReLU-262           [-1, 64, 14, 14]               0
    conv_bn_relu-263           [-1, 64, 14, 14]               0
          Conv2d-264           [-1, 16, 14, 14]           9,216
     BatchNorm2d-265           [-1, 16, 14, 14]              32
            ReLU-266           [-1, 16, 14, 14]               0
    conv_bn_relu-267           [-1, 16, 14, 14]               0
          Conv2d-268           [-1, 64, 

    conv_bn_relu-381             [-1, 64, 7, 7]               0
          Conv2d-382             [-1, 16, 7, 7]           9,216
     BatchNorm2d-383             [-1, 16, 7, 7]              32
            ReLU-384             [-1, 16, 7, 7]               0
    conv_bn_relu-385             [-1, 16, 7, 7]               0
          Conv2d-386             [-1, 16, 7, 7]           2,304
     BatchNorm2d-387             [-1, 16, 7, 7]              32
            ReLU-388             [-1, 16, 7, 7]               0
    conv_bn_relu-389             [-1, 16, 7, 7]               0
     dense_layer-390            [-1, 576, 7, 7]               0
          Conv2d-391             [-1, 64, 7, 7]          36,864
     BatchNorm2d-392             [-1, 64, 7, 7]             128
            ReLU-393             [-1, 64, 7, 7]               0
    conv_bn_relu-394             [-1, 64, 7, 7]               0
          Conv2d-395             [-1, 16, 7, 7]           9,216
     BatchNorm2d-396             [-1, 16

In [36]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=initial_lr, momentum=0.9)
learning_rate_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epoch)

for epoch in range(num_epoch):  
    learning_rate_scheduler.step()
    running_loss = 0.0
    
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        
        optimizer.step()
        
        running_loss += loss.item()
                
        show_period = 100
        if i % show_period ==  show_period-1:    # print every "show_period" mini-batches
            print('[%d, %5d/50000] loss: %.7f, lr: %.7f' %
                  (epoch + 1, (i + 1)*batch_size, running_loss / show_period, learning_rate_scheduler.get_lr()[0]))
            running_loss = 0.0
            
    # validation part
    correct = 0
    total = 0
    for i, data in enumerate(valid_loader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = net(inputs)
        
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
    print('[%d epoch] Accuracy of the network on the validation images: %d %%' % 
          (epoch + 1, 100 * correct / total)
         )

print('Finished Training')

[1,  6400/50000] loss: 2.0496658, lr: 0.1000000
[1, 12800/50000] loss: 1.8068226, lr: 0.1000000
[1, 19200/50000] loss: 1.6786690, lr: 0.1000000
[1, 25600/50000] loss: 1.5848522, lr: 0.1000000
[1, 32000/50000] loss: 1.5190959, lr: 0.1000000
[1, 38400/50000] loss: 1.4203354, lr: 0.1000000
[1, 44800/50000] loss: 1.3563626, lr: 0.1000000
[1 epoch] Accuracy of the network on the validation images: 51 %
[2,  6400/50000] loss: 1.2689225, lr: 0.0999753
[2, 12800/50000] loss: 1.1836578, lr: 0.0999753
[2, 19200/50000] loss: 1.1128522, lr: 0.0999753
[2, 25600/50000] loss: 1.0888133, lr: 0.0999753
[2, 32000/50000] loss: 1.0325753, lr: 0.0999753
[2, 38400/50000] loss: 0.9965671, lr: 0.0999753
[2, 44800/50000] loss: 0.9671038, lr: 0.0999753
[2 epoch] Accuracy of the network on the validation images: 67 %
[3,  6400/50000] loss: 0.9065573, lr: 0.0999013
[3, 12800/50000] loss: 0.8651605, lr: 0.0999013
[3, 19200/50000] loss: 0.8328243, lr: 0.0999013
[3, 25600/50000] loss: 0.8386372, lr: 0.0999013
[3, 32

[21, 12800/50000] loss: 0.1413518, lr: 0.0904508
[21, 19200/50000] loss: 0.1452274, lr: 0.0904508
[21, 25600/50000] loss: 0.1475782, lr: 0.0904508
[21, 32000/50000] loss: 0.1456229, lr: 0.0904508
[21, 38400/50000] loss: 0.1574638, lr: 0.0904508
[21, 44800/50000] loss: 0.1524778, lr: 0.0904508
[21 epoch] Accuracy of the network on the validation images: 88 %
[22,  6400/50000] loss: 0.1222139, lr: 0.0895078
[22, 12800/50000] loss: 0.1358480, lr: 0.0895078
[22, 19200/50000] loss: 0.1337867, lr: 0.0895078
[22, 25600/50000] loss: 0.1315738, lr: 0.0895078
[22, 32000/50000] loss: 0.1434432, lr: 0.0895078
[22, 38400/50000] loss: 0.1534713, lr: 0.0895078
[22, 44800/50000] loss: 0.1387539, lr: 0.0895078
[22 epoch] Accuracy of the network on the validation images: 89 %
[23,  6400/50000] loss: 0.1525227, lr: 0.0885257
[23, 12800/50000] loss: 0.1344277, lr: 0.0885257
[23, 19200/50000] loss: 0.1197590, lr: 0.0885257
[23, 25600/50000] loss: 0.1392264, lr: 0.0885257
[23, 32000/50000] loss: 0.1197026, 

[41, 12800/50000] loss: 0.0315345, lr: 0.0654508
[41, 19200/50000] loss: 0.0282899, lr: 0.0654508
[41, 25600/50000] loss: 0.0343004, lr: 0.0654508
[41, 32000/50000] loss: 0.0250815, lr: 0.0654508
[41, 38400/50000] loss: 0.0263588, lr: 0.0654508
[41, 44800/50000] loss: 0.0288901, lr: 0.0654508
[41 epoch] Accuracy of the network on the validation images: 90 %
[42,  6400/50000] loss: 0.0240217, lr: 0.0639496
[42, 12800/50000] loss: 0.0222370, lr: 0.0639496
[42, 19200/50000] loss: 0.0224188, lr: 0.0639496
[42, 25600/50000] loss: 0.0290490, lr: 0.0639496
[42, 32000/50000] loss: 0.0304564, lr: 0.0639496
[42, 38400/50000] loss: 0.0314656, lr: 0.0639496
[42, 44800/50000] loss: 0.0291475, lr: 0.0639496
[42 epoch] Accuracy of the network on the validation images: 90 %
[43,  6400/50000] loss: 0.0342173, lr: 0.0624345
[43, 12800/50000] loss: 0.0291759, lr: 0.0624345
[43, 19200/50000] loss: 0.0262700, lr: 0.0624345
[43, 25600/50000] loss: 0.0262725, lr: 0.0624345
[43, 32000/50000] loss: 0.0269235, 

[61, 12800/50000] loss: 0.0046097, lr: 0.0345492
[61, 19200/50000] loss: 0.0058113, lr: 0.0345492
[61, 25600/50000] loss: 0.0059015, lr: 0.0345492
[61, 32000/50000] loss: 0.0028465, lr: 0.0345492
[61, 38400/50000] loss: 0.0056427, lr: 0.0345492
[61, 44800/50000] loss: 0.0050590, lr: 0.0345492
[61 epoch] Accuracy of the network on the validation images: 91 %
[62,  6400/50000] loss: 0.0046700, lr: 0.0330631
[62, 12800/50000] loss: 0.0051078, lr: 0.0330631
[62, 19200/50000] loss: 0.0036939, lr: 0.0330631
[62, 25600/50000] loss: 0.0061979, lr: 0.0330631
[62, 32000/50000] loss: 0.0055176, lr: 0.0330631
[62, 38400/50000] loss: 0.0034038, lr: 0.0330631
[62, 44800/50000] loss: 0.0041574, lr: 0.0330631
[62 epoch] Accuracy of the network on the validation images: 91 %
[63,  6400/50000] loss: 0.0077219, lr: 0.0315938
[63, 12800/50000] loss: 0.0056349, lr: 0.0315938
[63, 19200/50000] loss: 0.0057499, lr: 0.0315938
[63, 25600/50000] loss: 0.0045835, lr: 0.0315938
[63, 32000/50000] loss: 0.0062511, 

[81, 12800/50000] loss: 0.0024244, lr: 0.0095492
[81, 19200/50000] loss: 0.0029955, lr: 0.0095492
[81, 25600/50000] loss: 0.0029514, lr: 0.0095492
[81, 32000/50000] loss: 0.0020355, lr: 0.0095492
[81, 38400/50000] loss: 0.0017162, lr: 0.0095492
[81, 44800/50000] loss: 0.0011705, lr: 0.0095492
[81 epoch] Accuracy of the network on the validation images: 91 %
[82,  6400/50000] loss: 0.0017971, lr: 0.0086460
[82, 12800/50000] loss: 0.0017039, lr: 0.0086460
[82, 19200/50000] loss: 0.0013340, lr: 0.0086460
[82, 25600/50000] loss: 0.0023235, lr: 0.0086460
[82, 32000/50000] loss: 0.0034754, lr: 0.0086460
[82, 38400/50000] loss: 0.0014668, lr: 0.0086460
[82, 44800/50000] loss: 0.0019544, lr: 0.0086460
[82 epoch] Accuracy of the network on the validation images: 91 %
[83,  6400/50000] loss: 0.0016734, lr: 0.0077836
[83, 12800/50000] loss: 0.0012158, lr: 0.0077836
[83, 19200/50000] loss: 0.0022578, lr: 0.0077836
[83, 25600/50000] loss: 0.0025178, lr: 0.0077836
[83, 32000/50000] loss: 0.0011063, 

In [37]:
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))

correct = 0
total = 0

with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
                
        for i in range(labels.shape[0]):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))            
            
for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i])) 

Accuracy of the network on the 10000 test images: 90 %
Accuracy of plane : 92 %
Accuracy of   car : 95 %
Accuracy of  bird : 87 %
Accuracy of   cat : 81 %
Accuracy of  deer : 92 %
Accuracy of   dog : 84 %
Accuracy of  frog : 93 %
Accuracy of horse : 93 %
Accuracy of  ship : 95 %
Accuracy of truck : 93 %
