In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import pandas as pd
from skimage import io, transform
import PIL
import numpy as np
import matplotlib.pyplot as plt
import os

In [2]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.2, 0.2, 0.2))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.2, 0.2, 0.2))
])

trainset = torchvision.datasets.CIFAR10(root='./image_files', train=True,
                                        download=False, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                          shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./image_files', train=False,
                                       download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=200,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


In [3]:
def imshow(img):
    unnormalized_image = img / 2 + 0.5     # unnormalize
    npimg = unnormalized_image.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
    

In [4]:
class PreactBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(PreactBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=(1,1))
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=(1,1))
        
    def forward(self, x):
        if self.out_channels > self.in_channels:
            shortcut = F.pad(x, (0, 0, 0, 0, 0, (self.out_channels - self.in_channels)))
        else:
            shortcut = x
        res = self.bn1(x)
        res = F.relu(res)
        res = self.conv1(res)
        res = self.bn2(res)
        res = F.relu(res)
        res = self.conv2(res)
        out = res + shortcut
        return out
        

In [5]:
class PreactGroup(nn.Module):
    def __init__(self, in_channels, out_channels, num_blocks):
        super(PreactGroup, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_blocks = num_blocks
        self.blocks = nn.ModuleList([PreactBlock(in_channels, out_channels)])
        for _ in range(1, num_blocks):
            self.blocks.append(PreactBlock(out_channels, out_channels))
            
    def forward(self, x):
        out = x
        for block in self.blocks:
            out = block(out)
        return out

In [6]:
class PreactResNet(nn.Module):
    def __init__(self, start_size, blocks_per_group):
        super(PreactResNet, self).__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv1 = nn.Conv2d(3, start_size, 3, padding=(1,1))
        self.block1 = PreactGroup(start_size, start_size, blocks_per_group)
        self.block2 = PreactGroup(start_size, start_size*2, blocks_per_group)
        self.block3 = PreactGroup(start_size*2, start_size*4, blocks_per_group)
        self.bn = nn.BatchNorm2d(start_size*4)
        self.fc1 = nn.Linear(start_size*4, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(self.block1(x))
        x = self.pool(self.block2(x))
        x = self.pool(self.block3(x))
        x = F.avg_pool2d(x, x.size(-1))
        x = self.bn(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x


In [11]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = PreactResNet(start_size=32, blocks_per_group=4)
net.to(device)
criterion = nn.CrossEntropyLoss()


In [12]:
for epoch in range(20):  # loop over the dataset multiple times
    lr = 1e-2 * (1.1 ** (-epoch))
    print('Epoch {0}, learning rate {1}'.format(epoch + 1, lr))
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs_cpu, labels_cpu = data
        inputs, labels = inputs_cpu.to(device), labels_cpu.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 100 == 99:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0


Epoch 1, learning rate 0.01
[1,   100] loss: 1.909
[1,   200] loss: 1.692
[1,   300] loss: 1.526
[1,   400] loss: 1.401
[1,   500] loss: 1.319
[1,   600] loss: 1.297
[1,   700] loss: 1.209
[1,   800] loss: 1.158
[1,   900] loss: 1.141
[1,  1000] loss: 1.046
[1,  1100] loss: 0.986
[1,  1200] loss: 1.044
[1,  1300] loss: 0.970
[1,  1400] loss: 0.966
[1,  1500] loss: 0.929
Epoch 2, learning rate 0.00909090909090909
[2,   100] loss: 0.827
[2,   200] loss: 0.880
[2,   300] loss: 0.809
[2,   400] loss: 0.858
[2,   500] loss: 0.844
[2,   600] loss: 0.751
[2,   700] loss: 0.785
[2,   800] loss: 0.801
[2,   900] loss: 0.738
[2,  1000] loss: 0.770
[2,  1100] loss: 0.724
[2,  1200] loss: 0.762
[2,  1300] loss: 0.732
[2,  1400] loss: 0.717
[2,  1500] loss: 0.710
Epoch 3, learning rate 0.008264462809917354
[3,   100] loss: 0.678
[3,   200] loss: 0.678
[3,   300] loss: 0.691
[3,   400] loss: 0.701
[3,   500] loss: 0.633
[3,   600] loss: 0.611
[3,   700] loss: 0.623
[3,   800] loss: 0.643
[3,   900] 

Traceback (most recent call last):
  File "/opt/anaconda3/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/opt/anaconda3/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/opt/anaconda3/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
Traceback (most recent call last):
  File "/opt/anaconda3/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/opt/anaconda3/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/opt/anaconda3/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
  File "/opt/anaconda3/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/opt/anaconda3/lib/python3.7/

KeyboardInterrupt: 

In [13]:
net.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images_cpu, labels_cpu = data
        images, labels = images_cpu.to(device), labels_cpu.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

Accuracy of the network on the 10000 test images: 86 %


In [14]:
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images_cpu, labels_cpu = data
        images, labels = images_cpu.to(device), labels_cpu.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))


Accuracy of plane : 77 %
Accuracy of   car : 100 %
Accuracy of  bird : 78 %
Accuracy of   cat : 73 %
Accuracy of  deer : 81 %
Accuracy of   dog : 80 %
Accuracy of  frog : 86 %
Accuracy of horse : 100 %
Accuracy of  ship : 88 %
Accuracy of truck : 100 %
