In [4]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

In [5]:
transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

data_folder = "./data" #FIXME

trainset = torchvision.datasets.CIFAR10(root=data_folder, train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR10(root=data_folder, train=False,
                                       download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=4)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz
Files already downloaded and verified


In [6]:
class Conv2d_partial(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, partial=False):
        super(Conv2d_partial, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
     
        self.partial = partial
        
    def forward(self, input):
        if self.partial:
            self.padding = 0

            pad_val = (self.kernel_size[0] - 1) // 2
            if pad_val > 0:
                if (self.kernel_size[0] - self.stride[0]) % 2 == 0:
                    pad_top = pad_val
                    pad_bottom = pad_val
                    pad_left = pad_val
                    pad_right = pad_val
                else:
                    pad_top = pad_val
                    pad_bottom = self.kernel_size[0] - self.stride[0] - pad_top
                    pad_left = pad_val
                    pad_right = self.kernel_size[0] - self.stride[0] - pad_left
                
                p0 = torch.ones_like(input) 
                p0 = p0.sum()
                                
                input = F.pad(input, (pad_left, pad_right, pad_top, pad_bottom) , mode='constant', value=0)
                
                p1 = torch.ones_like(input) 
                p1 = p1.sum()

                ratio = torch.div(p1, p0 + 1e-8) 
                input = torch.mul(input, ratio)  
            
        return F.conv2d(input, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
    
    

In [7]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = Conv2d_partial(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False, partial=True)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = Conv2d_partial(planes, planes, kernel_size=3, stride=1, padding=1, bias=False, partial=True)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                Conv2d_partial(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False, partial=True),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

In [8]:
class ResNet(nn.Module):
#class ResNet(object):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = Conv2d_partial(3, 64, kernel_size=3, stride=1, padding=1, bias=False, partial=True)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

In [9]:
net = ResNet(BasicBlock, [2, 2, 2, 2], 10) #ResNet-18

In [10]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [11]:
net.to(device)

ResNet(
  (conv1): Conv2d_partial(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d_partial(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d_partial(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d_partial(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d_partial(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNor

In [12]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)

In [13]:
for epoch in range(300):  
    if epoch == 150:
        optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    elif epoch == 250:
        optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        
        show_period = 250
        
        if i % show_period == show_period-1:    # print every "show_period" mini-batches
            print('[%d, %5d] loss: %.7f' %
                  (epoch + 1, i + 1, running_loss / show_period))
            running_loss = 0.0

print('Finished Training')

[1,   250] loss: 1.7497658
[2,   250] loss: 1.1681479
[3,   250] loss: 0.8984321
[4,   250] loss: 0.7268849
[5,   250] loss: 0.6202962
[6,   250] loss: 0.5275766
[7,   250] loss: 0.4630039
[8,   250] loss: 0.4190814


KeyboardInterrupt: 

In [69]:
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

Accuracy of the network on the 10000 test images: 93 %


In [70]:
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
                
        for i in range(labels.shape[0]):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

Accuracy of plane : 93 %
Accuracy of   car : 95 %
Accuracy of  bird : 91 %
Accuracy of   cat : 87 %
Accuracy of  deer : 94 %
Accuracy of   dog : 89 %
Accuracy of  frog : 95 %
Accuracy of horse : 96 %
Accuracy of  ship : 94 %
Accuracy of truck : 95 %
