In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse


In [4]:
import pickle
def unpickle(file):
    """Load and return the object stored in a pickle file.

    Args:
        file: Path of the pickle file; it is opened in binary mode.

    Returns:
        The unpickled object. ``encoding='bytes'`` makes 8-bit strings
        written by Python 2 come back as ``bytes``, which is why the keys
        of a loaded dict can be ``bytes`` rather than ``str``.

    Note:
        ``pickle.load`` can execute arbitrary code while loading; only use
        this on files from a trusted source.
    """
    # Use a name that does not shadow the built-in ``dict`` type.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch

In [24]:
# Path of the first CIFAR-10 training batch file (a pickle). It must name a
# file, not the containing directory, because it is opened with open(..., 'rb').
cifar_path = './data/cifar-10-batches-py/data_batch_1'

In [25]:
# Load the pickle at cifar_path into memory for inspection.
data_1 = unpickle(cifar_path)

In [26]:
# Check which container type the pickle holds.
type(data_1)

dict

In [27]:
# List the fields stored in the loaded batch.
data_1.keys()

dict_keys([b'batch_label', b'labels', b'data', b'filenames'])

In [28]:
# Show the batch label together with the raw data array (keys are bytes).
data_1[b'batch_label'],data_1[b'data']

(b'training batch 1 of 5', array([[ 59,  43,  50, ..., 140,  84,  72],
        [154, 126, 105, ..., 139, 142, 144],
        [255, 253, 253, ...,  83,  83,  84],
        ...,
        [ 71,  60,  74, ...,  68,  69,  68],
        [250, 254, 211, ..., 215, 255, 254],
        [ 62,  61,  60, ..., 130, 130, 131]], dtype=uint8))

In [29]:
# Preprocessing pipelines for the two dataset splits.

# Per-channel constants handed to Normalize as (mean, std); both pipelines
# share the same values.
norm_mean = (0.4914, 0.4822, 0.4465)
norm_std = (0.2023, 0.1994, 0.2010)

# Training: random augmentation first, then tensor conversion + normalisation.
train_steps = [
    transforms.RandomCrop(32, padding=4),  # size of 32*32, padding 4 px
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
transform_train = transforms.Compose(train_steps)

# Evaluation: no augmentation, only tensor conversion + normalisation.
test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
transform_test = transforms.Compose(test_steps)

In [46]:
# Directory passed to torchvision as the dataset root; nothing is fetched
# here because download=False.
data_path = './data'

# Training split: augmented samples, reshuffled each epoch, batches of 128.
trainset = torchvision.datasets.CIFAR10(
    root=data_path,
    train=True,
    download=False,
    transform=transform_train,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

# Test split: fixed order, batches of 100.
testset = torchvision.datasets.CIFAR10(
    root=data_path,
    train=False,
    download=False,
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=100,
    shuffle=False,
    num_workers=2,
)

In [36]:
# res net model
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages plus a shortcut.

    The shortcut output is added to the second batch-norm output and the sum
    is passed through ReLU. The shortcut is an empty ``nn.Sequential`` (which
    returns its input unchanged) unless the stride or the channel count
    changes, in which case it is a 1x1 convolution followed by batch-norm.

    Args:
        in_planes: Number of input channels.
        planes: Number of channels produced by both 3x3 convolutions.
        stride: Stride of the first 3x3 convolution and of the 1x1 shortcut
            convolution.
    """

    # Multiplier applied to ``planes`` for the shortcut's output channels.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        shortcut_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        shape_changes = stride != 1 or in_planes != shortcut_planes
        if shape_changes:
            # Project the input so it can be added to the main path.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, shortcut_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(shortcut_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        main = F.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        return F.relu(main + self.shortcut(x))

class ResNet(nn.Module):
    """Stem convolution, four residual stages, average pooling and a linear head.

    Args:
        block: Block class used for every stage. It is called as
            ``block(in_planes, planes, stride)`` and must expose an
            ``expansion`` class attribute.
        num_blocks: Indexable of four ints, the block count of stages 1-4.
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running channel count; ``_make_layer`` updates it after each block.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Stages 2-4 use stride 2 in their first block.
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block receives ``stride``."""
        per_block_strides = [stride]
        per_block_strides.extend(1 for _ in range(num_blocks - 1))

        stage = []
        for block_stride in per_block_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produced.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the linear layer's output for the input batch ``x``."""
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        # Fixed 4x4 pooling window, then flatten everything but the batch axis.
        pooled = F.avg_pool2d(features, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)

In [37]:
# ResNet-18: BasicBlock with two blocks in each of the four stages.
# ResNet's first positional argument is the block class, so it must be passed.
net = ResNet(BasicBlock, [2, 2, 2, 2])

In [44]:
# Run on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
lr = 0.1
net = net.to(device)
if device == 'cuda':
    # Wrap only once: re-running this cell would otherwise nest DataParallel
    # wrappers and add another 'module.' prefix to every state_dict key.
    if not isinstance(net, torch.nn.DataParallel):
        net = torch.nn.DataParallel(net)
    cudnn.benchmark = True
criterion = nn.CrossEntropyLoss()
# SGD with momentum and weight decay over all model parameters.
optimizer = optim.SGD(net.parameters(), lr=lr,
                      momentum=0.9, weight_decay=5e-4)


# Training
def train(epoch):
    """Run one optimisation pass over ``trainloader``.

    Uses the notebook-level ``net``, ``criterion``, ``optimizer`` and
    ``device``. After every batch it prints the batch index, the number of
    batches, the mean loss so far and the running accuracy.

    Args:
        epoch: Epoch number; only used in the header line that is printed.
    """
    print('\nEpoch: %d' % epoch)
    net.train()

    running_loss = 0
    n_correct = 0
    n_seen = 0
    for step, batch in enumerate(trainloader):
        images, labels = batch
        images = images.to(device)
        labels = labels.to(device)

        # One SGD step on this batch.
        optimizer.zero_grad()
        logits = net(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Update the running statistics with this batch.
        running_loss += batch_loss.item()
        predicted = logits.max(1)[1]
        n_seen += labels.size(0)
        n_correct += predicted.eq(labels).sum().item()

        mean_loss = running_loss/(step+1)
        accuracy = 100.*n_correct/n_seen
        print(step, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (mean_loss, accuracy, n_correct, n_seen))

def test(epoch):
    """Evaluate ``net`` on ``testloader`` and checkpoint it on improvement.

    Uses the notebook-level ``net``, ``criterion`` and ``device``. The loop
    runs under ``torch.no_grad()`` with the model in eval mode and prints the
    mean loss and running accuracy after every batch. When the final accuracy
    exceeds the global ``best_acc``, the model state, accuracy and epoch are
    written to ``./checkpoint/ckpt.pth`` and ``best_acc`` is updated.

    Args:
        epoch: Epoch number stored in the checkpoint.
    """
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            print(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # Create the directory if needed without a separate existence check.
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc

In [45]:
# Alternate one training pass and one evaluation pass per epoch.
num_epochs = 200
last_epoch = start_epoch + num_epochs
for epoch in range(start_epoch, last_epoch):
    train(epoch)
    test(epoch)


Epoch: 0
0 391 Loss: 2.805 | Acc: 14.844% (19/128)
1 391 Loss: 3.090 | Acc: 15.625% (40/256)
2 391 Loss: 3.424 | Acc: 15.365% (59/384)
3 391 Loss: 3.842 | Acc: 15.625% (80/512)
4 391 Loss: 4.140 | Acc: 13.906% (89/640)
5 391 Loss: 4.702 | Acc: 13.411% (103/768)
6 391 Loss: 4.913 | Acc: 13.504% (121/896)
7 391 Loss: 4.864 | Acc: 12.988% (133/1024)
8 391 Loss: 4.831 | Acc: 12.413% (143/1152)
9 391 Loss: 4.895 | Acc: 12.188% (156/1280)
10 391 Loss: 4.834 | Acc: 11.719% (165/1408)
11 391 Loss: 4.874 | Acc: 11.654% (179/1536)
12 391 Loss: 4.770 | Acc: 11.418% (190/1664)
13 391 Loss: 4.920 | Acc: 11.272% (202/1792)
14 391 Loss: 4.863 | Acc: 11.406% (219/1920)
15 391 Loss: 4.954 | Acc: 11.279% (231/2048)
16 391 Loss: 4.909 | Acc: 11.213% (244/2176)
17 391 Loss: 4.829 | Acc: 11.328% (261/2304)
18 391 Loss: 4.710 | Acc: 11.883% (289/2432)
19 391 Loss: 4.613 | Acc: 12.031% (308/2560)
20 391 Loss: 4.516 | Acc: 11.979% (322/2688)
21 391 Loss: 4.428 | Acc: 11.896% (335/2816)
22 391 Loss: 4.351 | A

176 391 Loss: 2.432 | Acc: 20.401% (4622/22656)
177 391 Loss: 2.429 | Acc: 20.462% (4662/22784)
178 391 Loss: 2.426 | Acc: 20.539% (4706/22912)
179 391 Loss: 2.425 | Acc: 20.556% (4736/23040)
180 391 Loss: 2.422 | Acc: 20.584% (4769/23168)
181 391 Loss: 2.419 | Acc: 20.639% (4808/23296)
182 391 Loss: 2.417 | Acc: 20.654% (4838/23424)
183 391 Loss: 2.414 | Acc: 20.724% (4881/23552)
184 391 Loss: 2.411 | Acc: 20.764% (4917/23680)
185 391 Loss: 2.409 | Acc: 20.775% (4946/23808)
186 391 Loss: 2.407 | Acc: 20.776% (4973/23936)
187 391 Loss: 2.404 | Acc: 20.844% (5016/24064)
188 391 Loss: 2.402 | Acc: 20.883% (5052/24192)
189 391 Loss: 2.399 | Acc: 20.913% (5086/24320)
190 391 Loss: 2.396 | Acc: 20.963% (5125/24448)
191 391 Loss: 2.394 | Acc: 20.996% (5160/24576)
192 391 Loss: 2.391 | Acc: 21.041% (5198/24704)
193 391 Loss: 2.388 | Acc: 21.102% (5240/24832)
194 391 Loss: 2.386 | Acc: 21.114% (5270/24960)
195 391 Loss: 2.384 | Acc: 21.122% (5299/25088)
196 391 Loss: 2.381 | Acc: 21.153% (5334

347 391 Loss: 2.142 | Acc: 25.507% (11362/44544)
348 391 Loss: 2.141 | Acc: 25.537% (11408/44672)
349 391 Loss: 2.139 | Acc: 25.567% (11454/44800)
350 391 Loss: 2.138 | Acc: 25.588% (11496/44928)
351 391 Loss: 2.138 | Acc: 25.617% (11542/45056)
352 391 Loss: 2.137 | Acc: 25.622% (11577/45184)
353 391 Loss: 2.136 | Acc: 25.631% (11614/45312)
354 391 Loss: 2.135 | Acc: 25.651% (11656/45440)
355 391 Loss: 2.133 | Acc: 25.682% (11703/45568)
356 391 Loss: 2.133 | Acc: 25.698% (11743/45696)
357 391 Loss: 2.132 | Acc: 25.729% (11790/45824)
358 391 Loss: 2.130 | Acc: 25.755% (11835/45952)
359 391 Loss: 2.129 | Acc: 25.766% (11873/46080)
360 391 Loss: 2.128 | Acc: 25.779% (11912/46208)
361 391 Loss: 2.128 | Acc: 25.805% (11957/46336)
362 391 Loss: 2.127 | Acc: 25.826% (12000/46464)
363 391 Loss: 2.126 | Acc: 25.854% (12046/46592)
364 391 Loss: 2.125 | Acc: 25.878% (12090/46720)
365 391 Loss: 2.124 | Acc: 25.911% (12139/46848)
366 391 Loss: 2.123 | Acc: 25.928% (12180/46976)
367 391 Loss: 2.123 

Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/Cellar/python/3.7.7/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/usr/local/Cellar/python/3.7.7/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/local/Cellar/python/3.7.7/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/local/Cellar/python/3.7.7/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
  File "/usr/local/Cellar/python/3.7.7/Frameworks/Python.framework/Versions/3.7/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/usr/l

KeyboardInterrupt: 