In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import os
import glob
import PIL
from PIL import Image
from torch.utils import data as D
from torch.utils.data.sampler import SubsetRandomSampler
import random

In [3]:
batch_size = 16
validation_ratio = 0.1
random_seed = 10

In [4]:
transform_train = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomCrop(224, padding=28),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

transform_validation = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])


transform_test = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)

validset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_validation)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)

#trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
#                                          shuffle=True, num_workers=0)

num_train = len(trainset)
indices = list(range(num_train))
split = int(np.floor(validation_ratio * num_train))

np.random.seed(random_seed)
np.random.shuffle(indices)

train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, sampler=train_sampler, num_workers=0
)

valid_loader = torch.utils.data.DataLoader(
    validset, batch_size=batch_size, sampler=valid_sampler, num_workers=0
)

test_loader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=0
)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

initial_lr = 0.045

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [5]:
class depthwise_conv(nn.Module):
    def __init__(self, nin, kernel_size, padding, bias=False, stride=1):
        super(depthwise_conv, self).__init__()
        self.depthwise = nn.Conv2d(nin, nin, kernel_size=kernel_size, stride=stride, padding=padding, groups=nin, bias=bias)

    def forward(self, x):
        out = self.depthwise(x)
        return out

In [6]:
class dw_block(nn.Module):
    def __init__(self, nin, kernel_size, padding=1, bias=False, stride=1):
        super(dw_block, self).__init__()
        self.dw_block = nn.Sequential(
            depthwise_conv(nin, kernel_size, padding, bias, stride),
            nn.BatchNorm2d(nin),
            nn.ReLU(True)
        )
    def forward(self, x):
        out = self.dw_block(x)
        return out

In [7]:
class one_by_one_block(nn.Module):
    def __init__(self, nin, nout, padding=1, bias=False, stride=1):
        super(one_by_one_block, self).__init__()
        self.one_by_one_block = nn.Sequential(
            nn.Conv2d(nin, nout, kernel_size=1, stride=stride, padding=padding, bias=bias),
            nn.BatchNorm2d(nout),
            nn.ReLU(True)
        )
    def forward(self, x):
        out = self.one_by_one_block(x)
        return out

In [8]:
class MobileNet(nn.Module):
    def __init__(self, input_channel, num_classes=10):
        super(MobileNet, self).__init__()
        
        self.network = nn.Sequential(
            nn.Conv2d(input_channel, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            
            dw_block(32, kernel_size=3),
            one_by_one_block(32, 64),
            
            dw_block(64, kernel_size=3, stride=2),
            one_by_one_block(64, 128),
            
            dw_block(128, kernel_size=3),
            one_by_one_block(128, 128),
            
            dw_block(128, kernel_size=3, stride=2),
            one_by_one_block(128, 256),
            
            dw_block(256, kernel_size=3),
            one_by_one_block(256, 256),
            
            dw_block(256, kernel_size=3, stride=2),
            one_by_one_block(256, 512),
            
            # 5 times 
            dw_block(512, kernel_size=3),
            one_by_one_block(512, 512),
            dw_block(512, kernel_size=3),
            one_by_one_block(512, 512),
            dw_block(512, kernel_size=3),
            one_by_one_block(512, 512),
            dw_block(512, kernel_size=3),
            one_by_one_block(512, 512),
            dw_block(512, kernel_size=3),
            one_by_one_block(512, 512),
            
            dw_block(512, kernel_size=3, stride=2),
            one_by_one_block(512, 1024),
            
            dw_block(1024, kernel_size=3, stride=2),
            one_by_one_block(1024, 1024),
        )
                
        self.linear = nn.Linear(1024, num_classes)
        
    def forward(self, x):
        body_output = self.network(x)
        
        avg_pool_output = F.adaptive_avg_pool2d(body_output, (1, 1))
        avg_pool_flat = avg_pool_output.view(avg_pool_output.size(0), -1)

        output = self.linear(avg_pool_flat)
        
        return output

In [9]:
net = MobileNet(3, 10) 

In [10]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [11]:
net.to(device)

MobileNet(
  (network): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): dw_block(
      (dw_block): Sequential(
        (0): depthwise_conv(
          (depthwise): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        )
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
    (4): one_by_one_block(
      (one_by_one_block): Sequential(
        (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace)
      )
    )
    (5): dw_block(
      (dw_block): Sequential(
        (0): depthwise_conv(
          (depthwise): Conv2d(64, 64, kernel_size=(3, 3), stride

In [36]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=initial_lr, momentum=0.9)

for epoch in range(100):  
    if epoch == 0:
        lr = initial_lr
    elif epoch % 2 == 0 and epoch != 0:
        lr *= 0.94
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
    
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        """
        running_loss += loss.item()
        
        show_period = 250
        if i % show_period == show_period-1:    # print every "show_period" mini-batches
            print('[%d, %5d] loss: %.7f' %
                  (epoch + 1, i + 1, running_loss / show_period))
            running_loss = 0.0
        """
        
    total = 0
    correct = 0
    for i, data in enumerate(valid_loader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = net(inputs)
        
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
    print('[%d epoch] Accuracy of the network on the validation images: %d %%' % 
          (epoch, 100 * correct / total)
         )

print('Finished Training')

[0 epoch] Accuracy of the network on the validation images: 60 %
[1 epoch] Accuracy of the network on the validation images: 65 %
[2 epoch] Accuracy of the network on the validation images: 71 %
[3 epoch] Accuracy of the network on the validation images: 72 %
[4 epoch] Accuracy of the network on the validation images: 78 %
[5 epoch] Accuracy of the network on the validation images: 77 %
[6 epoch] Accuracy of the network on the validation images: 79 %
[7 epoch] Accuracy of the network on the validation images: 82 %
[8 epoch] Accuracy of the network on the validation images: 82 %
[9 epoch] Accuracy of the network on the validation images: 83 %
[10 epoch] Accuracy of the network on the validation images: 83 %
[11 epoch] Accuracy of the network on the validation images: 84 %
[12 epoch] Accuracy of the network on the validation images: 85 %
[13 epoch] Accuracy of the network on the validation images: 86 %
[14 epoch] Accuracy of the network on the validation images: 86 %
[15 epoch] Accuracy 

KeyboardInterrupt: 

In [62]:
dataiter = iter(test_loader)
images, labels = dataiter.next()
images, labels = images.to(device), labels.to(device)

# print images
#imshow(torchvision.utils.make_grid(images))
#print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(64)))
outputs = net(images)
_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
                              for j in range(64)))

Predicted:    cat  ship  ship plane  frog  frog   car  frog   cat   car plane truck   dog horse truck  ship   dog horse  ship  frog horse plane  deer truck   dog   cat  deer plane truck  frog  frog   dog  deer   dog truck   cat  deer   car truck   dog  deer  frog   dog  frog plane truck   cat truck horse  frog truck  ship   cat   cat  ship  ship horse   cat   dog  frog horse   dog  frog   cat


In [None]:
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

In [None]:
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
                
        for i in range(labels.shape[0]):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))