In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import Adam, SGD
import torchvision
import torchvision.transforms as transforms

import sys, os, math
import argparse

In [2]:
# ---- Experiment configuration ----
data = 'cifar10'    # dataset key: 'cifar10', 'cifar100' or 'stl10'
model = 'vgg'       # architecture key; only 'vgg' is handled by the model cell
root = './data/'    # torchvision download / cache directory
lr = 0.01           # base learning rate; the training loop divides it by 10 per stage
resume = False      # if True, load weights from `model_out` before training

# Derive the checkpoint path from the dataset/model keys so that switching
# `data` cannot silently overwrite another experiment's checkpoint.
model_out = f'./checkpoint/{data}_{model}_ReLU.pth'

In [3]:
# Build transforms, datasets and loaders for the configured `data`.
# Defines: nclass, img_width, transform_train, transform_test,
#          trainset, testset, trainloader, testloader.
SUPPORTED_DATASETS = ('cifar10', 'cifar100', 'stl10')
if data not in SUPPORTED_DATASETS:
    # Fail here with a clear message instead of a NameError on `nclass` later.
    raise ValueError(f'Unknown dataset {data!r}; expected one of {SUPPORTED_DATASETS}')

TRAIN_BATCH_SIZE = 256
TEST_BATCH_SIZE = 100
NUM_WORKERS = 8

nclass = 100 if data == 'cifar100' else 10
img_width = 32  # same input width for every supported dataset


def _base_transforms():
    """Return a fresh transform list: resize to img_width (STL-10 only), then ToTensor.

    Augmentation (RandomCrop(32, padding=4), RandomHorizontalFlip) was commented
    out in the original experiment and is left disabled here.
    NOTE(review): no per-channel normalization is applied -- confirm this is intended.
    """
    steps = [transforms.Resize((img_width, img_width))] if data == 'stl10' else []
    return steps + [transforms.ToTensor()]


transform_train = transforms.Compose(_base_transforms())
transform_test = transforms.Compose(_base_transforms())

if data == 'stl10':
    # STL10 selects the subset with `split=` rather than `train=`.
    trainset = torchvision.datasets.STL10(root=root, split='train', transform=transform_train, target_transform=None, download=True)
    testset = torchvision.datasets.STL10(root=root, split='test', transform=transform_test, target_transform=None, download=True)
else:
    dataset_cls = torchvision.datasets.CIFAR10 if data == 'cifar10' else torchvision.datasets.CIFAR100
    trainset = dataset_cls(root=root, train=True, download=True, transform=transform_train)
    testset = dataset_cls(root=root, train=False, download=True, transform=transform_test)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
testloader = torch.utils.data.DataLoader(testset, batch_size=TEST_BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Build the network on the GPU and wrap it in DataParallel.
if model == 'vgg':
    from models.vgg import VGG_ReLU  # project-local module
    net = nn.DataParallel(VGG_ReLU('VGG16', nclass, img_width=img_width).cuda())
else:
    # Without this, `net` would be undefined and later cells would fail with NameError.
    raise ValueError(f"Unsupported model {model!r}; only 'vgg' is implemented")

net

DataParallel(
  (module): VGG_ReLU(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace)
      (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (9): ReLU(inplace)
      (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (12): ReLU(inplace)
      (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, c

In [5]:
# Optionally warm-start the network from a previously saved state_dict.
# NOTE(review): only model weights are restored -- the optimizer, the current
# learning-rate stage and the epoch counter all start from scratch.
if resume:
    ckpt_path = model_out
    print(f'==> Resuming from {ckpt_path}')
    saved_state = torch.load(ckpt_path)  # local checkpoint; torch.load unpickles, don't point at untrusted files
    net.load_state_dict(saved_state)

In [6]:
# Let cuDNN benchmark convolution algorithms and reuse the fastest one per input
# shape. Trades run-to-run determinism for speed.
torch.backends.cudnn.benchmark = True

In [7]:
# Classification loss on the network's raw class scores, averaged over the batch.
criterion = nn.CrossEntropyLoss(reduction='mean')

In [8]:
def train(epoch, model=None, loader=None, opt=None, loss_fn=None):
    """Run one training epoch, print accuracy and mean loss, and return them.

    Args:
        epoch: Epoch index; used only in the log line.
        model: Network to train; defaults to the global ``net``. Its forward must
            return a tuple whose first element is the class scores fed to the loss.
        loader: Iterable of ``(inputs, targets)`` batches; defaults to ``trainloader``.
        opt: Optimizer to step; defaults to the global ``optimizer`` (created by the
            training-loop cell).
        loss_fn: Loss callable; defaults to the global ``criterion``.

    Returns:
        ``(mean_loss, accuracy_percent)`` for the epoch; ``(0.0, 0.0)`` when the
        loader yields no batches.
    """
    # Globals are looked up only when the caller does not supply a value.
    model = net if model is None else model
    loader = trainloader if loader is None else loader
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn

    # Send batches to wherever the model's weights live (the GPU for the
    # `.cuda()` model built in this notebook).
    device = next(model.parameters()).device

    print('Epoch: %d' % epoch)
    model.train()
    loss_sum = 0.0
    n_batches = 0
    correct = 0
    total = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        opt.zero_grad()
        outputs, _ = model(inputs)  # second element of the tuple is unused here
        loss = loss_fn(outputs, targets)
        loss.backward()
        opt.step()

        loss_sum += loss.item()
        n_batches += 1
        pred = outputs.max(dim=1)[1]
        correct += pred.eq(targets).sum().item()
        total += targets.numel()

    mean_loss = loss_sum / n_batches if n_batches else 0.0
    acc = 100. * correct / total if total else 0.0
    print(f'[TRAIN] Acc: {acc:.3f} Loss: {mean_loss:.4f}')
    return mean_loss, acc

In [9]:
def test(epoch):
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs, _ = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        print(f'[TEST] Acc: {100.*correct/total:.3f}')

    # Save checkpoint after each epoch
    torch.save(net.state_dict(), model_out)

In [10]:
# Length (in epochs) of each learning-rate stage used by the training loop.
# Every supported dataset currently uses the same 4 x 50-epoch schedule.
EPOCHS_BY_DATASET = {
    'cifar10': [50, 50, 50, 50],
    'cifar100': [50, 50, 50, 50],
    'stl10': [50, 50, 50, 50],
}
if data not in EPOCHS_BY_DATASET:
    # Fail here instead of with a NameError on `epochs` in the training loop.
    raise ValueError(f'No epoch schedule for dataset {data!r}; expected one of {sorted(EPOCHS_BY_DATASET)}')
epochs = list(EPOCHS_BY_DATASET[data])

In [11]:
# Running epoch counter passed to train()/test() for logging; the training-loop
# cell increments it, so it keeps growing if that cell is re-run.
count: int = 0

In [12]:
# Staged schedule: each entry of `epochs` is one stage; stage k trains with a fresh
# Adam optimizer at lr / 10**k. Creating a new Adam per stage also resets its
# moment estimates (presumably intended as a simple step-decay schedule).
# The base `lr` is NOT modified, so re-running this cell restarts at the configured
# rate instead of continuing from the last, 10**len(epochs)-times smaller one.
for stage, n_epochs in enumerate(epochs):
    stage_lr = lr / 10 ** stage
    # `optimizer` stays a module-level name because train() reads it as a global.
    optimizer = Adam(net.parameters(), lr=stage_lr)
    for _ in range(n_epochs):
        train(count)
        test(count)
        count += 1

Epoch: 0
[TRAIN] Acc: 15.818
[TEST] Acc: 19.290
Epoch: 1
[TRAIN] Acc: 31.294
[TEST] Acc: 32.880
Epoch: 2
[TRAIN] Acc: 52.844
[TEST] Acc: 52.390
Epoch: 3
[TRAIN] Acc: 66.700
[TEST] Acc: 60.730
Epoch: 4
[TRAIN] Acc: 74.344
[TEST] Acc: 56.990
Epoch: 5
[TRAIN] Acc: 79.222
[TEST] Acc: 56.180
Epoch: 6
[TRAIN] Acc: 82.560
[TEST] Acc: 76.050
Epoch: 7
[TRAIN] Acc: 85.278
[TEST] Acc: 77.140
Epoch: 8
[TRAIN] Acc: 87.706
[TEST] Acc: 78.840
Epoch: 9
[TRAIN] Acc: 89.770
[TEST] Acc: 79.420
Epoch: 10
[TRAIN] Acc: 91.344
[TEST] Acc: 75.410
Epoch: 11
[TRAIN] Acc: 92.440
[TEST] Acc: 78.890
Epoch: 12
[TRAIN] Acc: 93.806
[TEST] Acc: 76.490
Epoch: 13
[TRAIN] Acc: 94.748
[TEST] Acc: 80.820
Epoch: 14
[TRAIN] Acc: 95.596
[TEST] Acc: 78.200
Epoch: 15
[TRAIN] Acc: 95.768
[TEST] Acc: 82.000
Epoch: 16
[TRAIN] Acc: 96.464
[TEST] Acc: 81.310
Epoch: 17
[TRAIN] Acc: 96.730
[TEST] Acc: 82.620
Epoch: 18
[TRAIN] Acc: 97.566
[TEST] Acc: 83.100
Epoch: 19
[TRAIN] Acc: 97.548
[TEST] Acc: 81.720
Epoch: 20
[TRAIN] Acc: 97.830


[TRAIN] Acc: 100.000
[TEST] Acc: 86.510
Epoch: 166
[TRAIN] Acc: 100.000
[TEST] Acc: 86.520
Epoch: 167
[TRAIN] Acc: 100.000
[TEST] Acc: 86.410
Epoch: 168
[TRAIN] Acc: 100.000
[TEST] Acc: 86.340
Epoch: 169
[TRAIN] Acc: 100.000
[TEST] Acc: 86.500
Epoch: 170
[TRAIN] Acc: 100.000
[TEST] Acc: 86.440
Epoch: 171
[TRAIN] Acc: 100.000
[TEST] Acc: 86.550
Epoch: 172
[TRAIN] Acc: 100.000
[TEST] Acc: 86.530
Epoch: 173
[TRAIN] Acc: 100.000
[TEST] Acc: 86.630
Epoch: 174
[TRAIN] Acc: 100.000
[TEST] Acc: 86.570
Epoch: 175
[TRAIN] Acc: 100.000
[TEST] Acc: 86.370
Epoch: 176
[TRAIN] Acc: 100.000
[TEST] Acc: 86.620
Epoch: 177
[TRAIN] Acc: 100.000
[TEST] Acc: 86.580
Epoch: 178
[TRAIN] Acc: 100.000
[TEST] Acc: 86.460
Epoch: 179
[TRAIN] Acc: 100.000
[TEST] Acc: 86.480
Epoch: 180
[TRAIN] Acc: 100.000
[TEST] Acc: 86.560
Epoch: 181
[TRAIN] Acc: 100.000
[TEST] Acc: 86.530
Epoch: 182
[TRAIN] Acc: 100.000
[TEST] Acc: 86.590
Epoch: 183
[TRAIN] Acc: 100.000
[TEST] Acc: 86.670
Epoch: 184
[TRAIN] Acc: 100.000
[TEST] Acc

In [2]:
# ---- Experiment configuration ----
data = 'stl10'      # dataset key: 'cifar10', 'cifar100' or 'stl10'
model = 'vgg'       # architecture key; only 'vgg' is handled by the model cell
root = './data/'    # torchvision download / cache directory
lr = 0.01           # base learning rate; the training loop divides it by 10 per stage
resume = False      # if True, load weights from `model_out` before training

# Derive the checkpoint path from the dataset/model keys so that switching
# `data` cannot silently overwrite another experiment's checkpoint.
model_out = f'./checkpoint/{data}_{model}_ReLU.pth'

In [12]:
# Staged schedule: each entry of `epochs` is one stage; stage k trains with a fresh
# Adam optimizer at lr / 10**k. Creating a new Adam per stage also resets its
# moment estimates (presumably intended as a simple step-decay schedule).
# The base `lr` is NOT modified, so re-running this cell restarts at the configured
# rate instead of continuing from the last, 10**len(epochs)-times smaller one.
for stage, n_epochs in enumerate(epochs):
    stage_lr = lr / 10 ** stage
    # `optimizer` stays a module-level name because train() reads it as a global.
    optimizer = Adam(net.parameters(), lr=stage_lr)
    for _ in range(n_epochs):
        train(count)
        test(count)
        count += 1

Epoch: 0
[TRAIN] Acc: 10.520
[TEST] Acc: 12.537
Epoch: 1
[TRAIN] Acc: 11.360
[TEST] Acc: 10.725
Epoch: 2
[TRAIN] Acc: 11.120
[TEST] Acc: 12.338
Epoch: 3
[TRAIN] Acc: 12.340
[TEST] Acc: 13.550
Epoch: 4
[TRAIN] Acc: 15.680
[TEST] Acc: 16.250
Epoch: 5
[TRAIN] Acc: 18.020
[TEST] Acc: 16.225
Epoch: 6
[TRAIN] Acc: 22.100
[TEST] Acc: 15.400
Epoch: 7
[TRAIN] Acc: 23.760
[TEST] Acc: 23.137
Epoch: 8
[TRAIN] Acc: 25.460
[TEST] Acc: 17.062
Epoch: 9
[TRAIN] Acc: 25.860
[TEST] Acc: 17.038
Epoch: 10
[TRAIN] Acc: 27.720
[TEST] Acc: 22.488
Epoch: 11
[TRAIN] Acc: 28.540
[TEST] Acc: 19.038
Epoch: 12
[TRAIN] Acc: 28.640
[TEST] Acc: 21.575
Epoch: 13
[TRAIN] Acc: 30.200
[TEST] Acc: 28.875
Epoch: 14
[TRAIN] Acc: 31.180
[TEST] Acc: 31.887
Epoch: 15
[TRAIN] Acc: 34.300
[TEST] Acc: 26.600
Epoch: 16
[TRAIN] Acc: 36.260
[TEST] Acc: 36.513
Epoch: 17
[TRAIN] Acc: 38.140
[TEST] Acc: 35.388
Epoch: 18
[TRAIN] Acc: 39.220
[TEST] Acc: 33.462
Epoch: 19
[TRAIN] Acc: 42.460
[TEST] Acc: 36.725
Epoch: 20
[TRAIN] Acc: 44.300


[TEST] Acc: 60.550
Epoch: 165
[TRAIN] Acc: 100.000
[TEST] Acc: 60.600
Epoch: 166
[TRAIN] Acc: 100.000
[TEST] Acc: 60.675
Epoch: 167
[TRAIN] Acc: 100.000
[TEST] Acc: 60.562
Epoch: 168
[TRAIN] Acc: 100.000
[TEST] Acc: 60.712
Epoch: 169
[TRAIN] Acc: 100.000
[TEST] Acc: 60.638
Epoch: 170
[TRAIN] Acc: 100.000
[TEST] Acc: 60.750
Epoch: 171
[TRAIN] Acc: 100.000
[TEST] Acc: 60.825
Epoch: 172
[TRAIN] Acc: 100.000
[TEST] Acc: 60.850
Epoch: 173
[TRAIN] Acc: 100.000
[TEST] Acc: 60.700
Epoch: 174
[TRAIN] Acc: 100.000
[TEST] Acc: 60.700
Epoch: 175
[TRAIN] Acc: 100.000
[TEST] Acc: 60.663
Epoch: 176
[TRAIN] Acc: 100.000
[TEST] Acc: 60.737
Epoch: 177
[TRAIN] Acc: 100.000
[TEST] Acc: 60.487
Epoch: 178
[TRAIN] Acc: 100.000
[TEST] Acc: 60.737
Epoch: 179
[TRAIN] Acc: 100.000
[TEST] Acc: 60.763
Epoch: 180
[TRAIN] Acc: 100.000
[TEST] Acc: 60.750
Epoch: 181
[TRAIN] Acc: 100.000
[TEST] Acc: 60.700
Epoch: 182
[TRAIN] Acc: 100.000
[TEST] Acc: 60.600
Epoch: 183
[TRAIN] Acc: 100.000
[TEST] Acc: 60.737
Epoch: 184
[

In [2]:
# ---- Experiment configuration ----
data = 'cifar100'   # dataset key: 'cifar10', 'cifar100' or 'stl10'
model = 'vgg'       # architecture key; only 'vgg' is handled by the model cell
root = './data/'    # torchvision download / cache directory
lr = 0.01           # base learning rate; the training loop divides it by 10 per stage
resume = False      # if True, load weights from `model_out` before training

# Derive the checkpoint path from the dataset/model keys so that switching
# `data` cannot silently overwrite another experiment's checkpoint.
model_out = f'./checkpoint/{data}_{model}_ReLU.pth'

In [12]:
# Staged schedule: each entry of `epochs` is one stage; stage k trains with a fresh
# Adam optimizer at lr / 10**k. Creating a new Adam per stage also resets its
# moment estimates (presumably intended as a simple step-decay schedule).
# The base `lr` is NOT modified, so re-running this cell restarts at the configured
# rate instead of continuing from the last, 10**len(epochs)-times smaller one.
for stage, n_epochs in enumerate(epochs):
    stage_lr = lr / 10 ** stage
    # `optimizer` stays a module-level name because train() reads it as a global.
    optimizer = Adam(net.parameters(), lr=stage_lr)
    for _ in range(n_epochs):
        train(count)
        test(count)
        count += 1

Epoch: 0
[TRAIN] Acc: 1.368
[TEST] Acc: 2.140
Epoch: 1
[TRAIN] Acc: 2.760
[TEST] Acc: 4.120
Epoch: 2
[TRAIN] Acc: 5.608
[TEST] Acc: 5.680
Epoch: 3
[TRAIN] Acc: 8.578
[TEST] Acc: 6.810
Epoch: 4
[TRAIN] Acc: 12.042
[TEST] Acc: 12.300
Epoch: 5
[TRAIN] Acc: 16.670
[TEST] Acc: 15.630
Epoch: 6
[TRAIN] Acc: 23.162
[TEST] Acc: 19.630
Epoch: 7
[TRAIN] Acc: 28.002
[TEST] Acc: 27.200
Epoch: 8
[TRAIN] Acc: 33.262
[TEST] Acc: 27.840
Epoch: 9
[TRAIN] Acc: 37.922
[TEST] Acc: 30.170
Epoch: 10
[TRAIN] Acc: 42.210
[TEST] Acc: 32.640
Epoch: 11
[TRAIN] Acc: 46.446
[TEST] Acc: 35.670
Epoch: 12
[TRAIN] Acc: 50.478
[TEST] Acc: 38.810
Epoch: 13
[TRAIN] Acc: 54.634
[TEST] Acc: 38.920
Epoch: 14
[TRAIN] Acc: 58.306
[TEST] Acc: 39.060
Epoch: 15
[TRAIN] Acc: 61.852
[TEST] Acc: 41.480
Epoch: 16
[TRAIN] Acc: 65.646
[TEST] Acc: 42.990
Epoch: 17
[TRAIN] Acc: 69.024
[TEST] Acc: 41.180
Epoch: 18
[TRAIN] Acc: 72.144
[TEST] Acc: 44.390
Epoch: 19
[TRAIN] Acc: 75.098
[TEST] Acc: 44.820
Epoch: 20
[TRAIN] Acc: 78.286
[TEST] A

[TRAIN] Acc: 99.984
[TEST] Acc: 51.500
Epoch: 167
[TRAIN] Acc: 99.984
[TEST] Acc: 51.160
Epoch: 168
[TRAIN] Acc: 99.984
[TEST] Acc: 51.310
Epoch: 169
[TRAIN] Acc: 99.980
[TEST] Acc: 51.300
Epoch: 170
[TRAIN] Acc: 99.982
[TEST] Acc: 51.410
Epoch: 171
[TRAIN] Acc: 99.984
[TEST] Acc: 51.350
Epoch: 172
[TRAIN] Acc: 99.978
[TEST] Acc: 51.240
Epoch: 173
[TRAIN] Acc: 99.988
[TEST] Acc: 51.270
Epoch: 174
[TRAIN] Acc: 99.978
[TEST] Acc: 51.310
Epoch: 175
[TRAIN] Acc: 99.982
[TEST] Acc: 51.350
Epoch: 176
[TRAIN] Acc: 99.974
[TEST] Acc: 51.310
Epoch: 177
[TRAIN] Acc: 99.978
[TEST] Acc: 51.400
Epoch: 178
[TRAIN] Acc: 99.982
[TEST] Acc: 51.120
Epoch: 179
[TRAIN] Acc: 99.970
[TEST] Acc: 51.280
Epoch: 180
[TRAIN] Acc: 99.988
[TEST] Acc: 51.110
Epoch: 181
[TRAIN] Acc: 99.980
[TEST] Acc: 51.390
Epoch: 182
[TRAIN] Acc: 99.980
[TEST] Acc: 51.380
Epoch: 183
[TRAIN] Acc: 99.976
[TEST] Acc: 51.250
Epoch: 184
[TRAIN] Acc: 99.982
[TEST] Acc: 51.360
Epoch: 185
[TRAIN] Acc: 99.974
[TEST] Acc: 51.350
Epoch: 186
