In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from models.resnet import ResNet18
from utils.util import progress_bar

In [12]:
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
    cudnn.benchmark = True
    print("GPU is available and trained on GPU")

GPU is available and trained on GPU


In [13]:
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

In [19]:
# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=256, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [20]:
# Model
print('==> Building model..')
net = ResNet18()
net = net.to(device)
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    net = torch.nn.DataParallel(net)

==> Building model..
Let's use 2 GPUs!


In [21]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1,
                      momentum=0.9, weight_decay=5e-4)
step_lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[100, 150], gamma=0.1)

In [22]:
# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        print("Outside: input size", inputs.size(),
          "output_size", outputs.size())
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))


def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc

In [23]:
for epoch in range(start_epoch, start_epoch+2):
    train(epoch)
    test(epoch)
    step_lr_scheduler.step()


Epoch: 0
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10])
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10])| Tot: 2ms | Loss: 2.405 | Acc: 9.766% (25/25 1/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 97ms | Loss: 2.704 | Acc: 8.984% (46/51 2/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 186ms | Loss: 3.003 | Acc: 10.938% (84/76 3/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 263ms | Loss: 3.199 | Acc: 11.914% (122/102 4/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 337ms | Loss: 3.363 | Acc: 12.734% (163/128 5/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 415ms | Loss: 3.406 | Acc: 12.044% (185/153 6/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) 

Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 4s409ms | Loss: 2.178 | Acc: 22.501% (3341/1484 58/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 4s483ms | Loss: 2.171 | Acc: 22.709% (3430/1510 59/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 4s558ms | Loss: 2.165 | Acc: 22.832% (3507/1536 60/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 4s633ms | Loss: 2.160 | Acc: 22.964% (3586/1561 61/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 4s707ms | Loss: 2.154 | Acc: 23.072% (3662/1587 62/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 4s783ms | Loss: 2.149 | Acc: 23.165% (3736/1612 63/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 4s855ms | Loss: 2.142 | Acc: 23.285% (3815/1638 

Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 8s679ms | Loss: 1.956 | Acc: 28.670% (8367/2918 114/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 8s753ms | Loss: 1.954 | Acc: 28.750% (8464/2944 115/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 8s827ms | Loss: 1.952 | Acc: 28.832% (8562/2969 116/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 8s902ms | Loss: 1.949 | Acc: 28.966% (8676/2995 117/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 8s976ms | Loss: 1.946 | Acc: 29.092% (8788/3020 118/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 9s54ms | Loss: 1.944 | Acc: 29.175% (8888/3046 119/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 9s128ms | Loss: 1.941 | Acc: 29.255% (8987/

Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 12s909ms | Loss: 1.844 | Acc: 32.491% (14140/4352 170/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 12s987ms | Loss: 1.842 | Acc: 32.545% (14247/4377 171/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 13s63ms | Loss: 1.840 | Acc: 32.604% (14356/4403 172/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 13s137ms | Loss: 1.838 | Acc: 32.668% (14468/4428 173/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 13s212ms | Loss: 1.837 | Acc: 32.707% (14569/4454 174/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 13s287ms | Loss: 1.836 | Acc: 32.757% (14675/4480 175/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 13s362ms | Loss: 1.834 | Acc: 3

Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 2s296ms | Loss: 1.498 | Acc: 44.504% (3304/742 29/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 2s378ms | Loss: 1.499 | Acc: 44.609% (3426/768 30/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 2s454ms | Loss: 1.496 | Acc: 44.708% (3548/793 31/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 2s528ms | Loss: 1.494 | Acc: 44.824% (3672/819 32/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 2s602ms | Loss: 1.493 | Acc: 44.827% (3787/844 33/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 2s676ms | Loss: 1.493 | Acc: 44.841% (3903/870 34/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 2s750ms | Loss: 1.490 | Acc: 45.067% (4038/896 35/196 

Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 6s738ms | Loss: 1.434 | Acc: 47.252% (10403/2201 86/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 6s814ms | Loss: 1.432 | Acc: 47.311% (10537/22 87/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 6s890ms | Loss: 1.431 | Acc: 47.359% (10669/2252 88/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 6s965ms | Loss: 1.430 | Acc: 47.389% (10797/22 89/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 7s40ms | Loss: 1.429 | Acc: 47.378% (10916/2304 90/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 7s114ms | Loss: 1.428 | Acc: 47.420% (11047/2329 91/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 7s189ms | Loss: 1.429 | Acc: 47.401% (11164/235

Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 11s102ms | Loss: 1.382 | Acc: 49.133% (17861/3635 142/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 11s181ms | Loss: 1.382 | Acc: 49.148% (17992/3660 143/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 11s255ms | Loss: 1.381 | Acc: 49.186% (18132/3686 144/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 11s334ms | Loss: 1.380 | Acc: 49.221% (18271/3712 145/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 11s407ms | Loss: 1.380 | Acc: 49.192% (18386/3737 146/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 11s488ms | Loss: 1.379 | Acc: 49.221% (18523/3763 147/196 
Outside: input size torch.Size([256, 3, 32, 32]) output_size torch.Size([256, 10]) Tot: 11s562ms | Loss: 1.378 | Acc: 

Saving..
