In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os

In [29]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

In [30]:
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

In [31]:
def ResNet18():
    return ResNet(BasicBlock, [2,2,2,2])

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

In [3]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

In [4]:
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [36]:
net = ResNet18()

In [37]:
net = net.to(device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

In [38]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

In [39]:
# Training
def train(epoch):
    print('\nTrain Epoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))



In [40]:
def test(epoch):
    print('\nTest Epoch: %d' % epoch)
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
            
            # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc


In [41]:
for epoch in range(start_epoch, start_epoch+10):
    train(epoch)
    test(epoch)


Train Epoch: 0
Loss: 2.427 | Acc: 3.125% (4/128)
Loss: 2.760 | Acc: 7.812% (20/256)
Loss: 3.075 | Acc: 9.115% (35/384)
Loss: 3.431 | Acc: 10.547% (54/512)
Loss: 3.811 | Acc: 10.625% (68/640)
Loss: 3.864 | Acc: 11.719% (90/768)
Loss: 3.738 | Acc: 12.054% (108/896)
Loss: 3.929 | Acc: 12.012% (123/1024)
Loss: 3.911 | Acc: 12.240% (141/1152)
Loss: 3.887 | Acc: 11.953% (153/1280)
Loss: 3.888 | Acc: 12.997% (183/1408)
Loss: 3.837 | Acc: 12.695% (195/1536)
Loss: 3.790 | Acc: 12.560% (209/1664)
Loss: 3.760 | Acc: 12.891% (231/1792)
Loss: 3.726 | Acc: 12.917% (248/1920)
Loss: 3.688 | Acc: 12.842% (263/2048)
Loss: 3.637 | Acc: 13.051% (284/2176)
Loss: 3.593 | Acc: 13.108% (302/2304)
Loss: 3.543 | Acc: 13.199% (321/2432)
Loss: 3.504 | Acc: 13.477% (345/2560)
Loss: 3.480 | Acc: 13.430% (361/2688)
Loss: 3.456 | Acc: 13.459% (379/2816)
Loss: 3.447 | Acc: 13.553% (399/2944)
Loss: 3.406 | Acc: 13.477% (414/3072)
Loss: 3.386 | Acc: 13.469% (431/3200)
Loss: 3.348 | Acc: 13.522% (450/3328)
Loss: 3.324 |

Loss: 2.156 | Acc: 25.606% (6850/26752)
Loss: 2.154 | Acc: 25.625% (6888/26880)
Loss: 2.152 | Acc: 25.644% (6926/27008)
Loss: 2.151 | Acc: 25.634% (6956/27136)
Loss: 2.149 | Acc: 25.682% (7002/27264)
Loss: 2.148 | Acc: 25.683% (7035/27392)
Loss: 2.146 | Acc: 25.687% (7069/27520)
Loss: 2.145 | Acc: 25.720% (7111/27648)
Loss: 2.143 | Acc: 25.760% (7155/27776)
Loss: 2.141 | Acc: 25.796% (7198/27904)
Loss: 2.139 | Acc: 25.853% (7247/28032)
Loss: 2.138 | Acc: 25.902% (7294/28160)
Loss: 2.136 | Acc: 25.930% (7335/28288)
Loss: 2.134 | Acc: 25.996% (7387/28416)
Loss: 2.132 | Acc: 26.026% (7429/28544)
Loss: 2.130 | Acc: 26.081% (7478/28672)
Loss: 2.128 | Acc: 26.149% (7531/28800)
Loss: 2.127 | Acc: 26.179% (7573/28928)
Loss: 2.125 | Acc: 26.194% (7611/29056)
Loss: 2.124 | Acc: 26.216% (7651/29184)
Loss: 2.123 | Acc: 26.262% (7698/29312)
Loss: 2.121 | Acc: 26.315% (7747/29440)
Loss: 2.119 | Acc: 26.336% (7787/29568)
Loss: 2.118 | Acc: 26.360% (7828/29696)
Loss: 2.116 | Acc: 26.378% (7867/29824)


Loss: 1.540 | Acc: 43.760% (1094/2500)
Loss: 1.550 | Acc: 43.115% (1121/2600)
Loss: 1.550 | Acc: 43.074% (1163/2700)
Loss: 1.551 | Acc: 42.929% (1202/2800)
Loss: 1.557 | Acc: 42.931% (1245/2900)
Loss: 1.553 | Acc: 42.767% (1283/3000)
Loss: 1.550 | Acc: 42.935% (1331/3100)
Loss: 1.549 | Acc: 43.031% (1377/3200)
Loss: 1.548 | Acc: 43.121% (1423/3300)
Loss: 1.549 | Acc: 43.118% (1466/3400)
Loss: 1.550 | Acc: 43.143% (1510/3500)
Loss: 1.547 | Acc: 43.056% (1550/3600)
Loss: 1.548 | Acc: 43.054% (1593/3700)
Loss: 1.548 | Acc: 43.026% (1635/3800)
Loss: 1.545 | Acc: 43.154% (1683/3900)
Loss: 1.546 | Acc: 43.175% (1727/4000)
Loss: 1.545 | Acc: 43.268% (1774/4100)
Loss: 1.545 | Acc: 43.214% (1815/4200)
Loss: 1.544 | Acc: 43.163% (1856/4300)
Loss: 1.542 | Acc: 43.182% (1900/4400)
Loss: 1.539 | Acc: 43.311% (1949/4500)
Loss: 1.538 | Acc: 43.239% (1989/4600)
Loss: 1.535 | Acc: 43.319% (2036/4700)
Loss: 1.534 | Acc: 43.479% (2087/4800)
Loss: 1.534 | Acc: 43.449% (2129/4900)
Loss: 1.532 | Acc: 43.560

Loss: 1.554 | Acc: 42.367% (7321/17280)
Loss: 1.554 | Acc: 42.360% (7374/17408)
Loss: 1.554 | Acc: 42.376% (7431/17536)
Loss: 1.553 | Acc: 42.408% (7491/17664)
Loss: 1.553 | Acc: 42.395% (7543/17792)
Loss: 1.553 | Acc: 42.394% (7597/17920)
Loss: 1.553 | Acc: 42.387% (7650/18048)
Loss: 1.552 | Acc: 42.430% (7712/18176)
Loss: 1.552 | Acc: 42.411% (7763/18304)
Loss: 1.552 | Acc: 42.405% (7816/18432)
Loss: 1.552 | Acc: 42.446% (7878/18560)
Loss: 1.550 | Acc: 42.514% (7945/18688)
Loss: 1.550 | Acc: 42.501% (7997/18816)
Loss: 1.551 | Acc: 42.478% (8047/18944)
Loss: 1.551 | Acc: 42.455% (8097/19072)
Loss: 1.551 | Acc: 42.443% (8149/19200)
Loss: 1.551 | Acc: 42.441% (8203/19328)
Loss: 1.552 | Acc: 42.439% (8257/19456)
Loss: 1.551 | Acc: 42.453% (8314/19584)
Loss: 1.551 | Acc: 42.482% (8374/19712)
Loss: 1.551 | Acc: 42.500% (8432/19840)
Loss: 1.551 | Acc: 42.513% (8489/19968)
Loss: 1.550 | Acc: 42.516% (8544/20096)
Loss: 1.549 | Acc: 42.539% (8603/20224)
Loss: 1.549 | Acc: 42.527% (8655/20352)


Loss: 1.486 | Acc: 45.250% (19519/43136)
Loss: 1.486 | Acc: 45.269% (19585/43264)
Loss: 1.486 | Acc: 45.278% (19647/43392)
Loss: 1.486 | Acc: 45.283% (19707/43520)
Loss: 1.485 | Acc: 45.322% (19782/43648)
Loss: 1.484 | Acc: 45.340% (19848/43776)
Loss: 1.484 | Acc: 45.356% (19913/43904)
Loss: 1.483 | Acc: 45.365% (19975/44032)
Loss: 1.483 | Acc: 45.383% (20041/44160)
Loss: 1.483 | Acc: 45.385% (20100/44288)
Loss: 1.482 | Acc: 45.391% (20161/44416)
Loss: 1.482 | Acc: 45.414% (20229/44544)
Loss: 1.482 | Acc: 45.433% (20296/44672)
Loss: 1.481 | Acc: 45.449% (20361/44800)
Loss: 1.480 | Acc: 45.473% (20430/44928)
Loss: 1.480 | Acc: 45.479% (20491/45056)
Loss: 1.479 | Acc: 45.492% (20555/45184)
Loss: 1.479 | Acc: 45.516% (20624/45312)
Loss: 1.479 | Acc: 45.506% (20678/45440)
Loss: 1.479 | Acc: 45.521% (20743/45568)
Loss: 1.479 | Acc: 45.531% (20806/45696)
Loss: 1.478 | Acc: 45.550% (20873/45824)
Loss: 1.478 | Acc: 45.548% (20930/45952)
Loss: 1.478 | Acc: 45.556% (20992/46080)
Loss: 1.478 | Ac

Loss: 1.323 | Acc: 51.449% (3622/7040)
Loss: 1.325 | Acc: 51.395% (3684/7168)
Loss: 1.326 | Acc: 51.288% (3742/7296)
Loss: 1.325 | Acc: 51.387% (3815/7424)
Loss: 1.326 | Acc: 51.390% (3881/7552)
Loss: 1.324 | Acc: 51.458% (3952/7680)
Loss: 1.323 | Acc: 51.562% (4026/7808)
Loss: 1.323 | Acc: 51.537% (4090/7936)
Loss: 1.322 | Acc: 51.600% (4161/8064)
Loss: 1.324 | Acc: 51.562% (4224/8192)
Loss: 1.324 | Acc: 51.575% (4291/8320)
Loss: 1.324 | Acc: 51.574% (4357/8448)
Loss: 1.322 | Acc: 51.726% (4436/8576)
Loss: 1.324 | Acc: 51.643% (4495/8704)
Loss: 1.324 | Acc: 51.687% (4565/8832)
Loss: 1.323 | Acc: 51.763% (4638/8960)
Loss: 1.324 | Acc: 51.750% (4703/9088)
Loss: 1.322 | Acc: 51.780% (4772/9216)
Loss: 1.322 | Acc: 51.787% (4839/9344)
Loss: 1.322 | Acc: 51.858% (4912/9472)
Loss: 1.321 | Acc: 51.938% (4986/9600)
Loss: 1.320 | Acc: 51.963% (5055/9728)
Loss: 1.318 | Acc: 52.100% (5135/9856)
Loss: 1.318 | Acc: 52.113% (5203/9984)
Loss: 1.316 | Acc: 52.116% (5270/10112)
Loss: 1.315 | Acc: 52.10

Loss: 1.265 | Acc: 54.307% (18004/33152)
Loss: 1.265 | Acc: 54.327% (18080/33280)
Loss: 1.264 | Acc: 54.352% (18158/33408)
Loss: 1.264 | Acc: 54.359% (18230/33536)
Loss: 1.263 | Acc: 54.370% (18303/33664)
Loss: 1.263 | Acc: 54.365% (18371/33792)
Loss: 1.263 | Acc: 54.369% (18442/33920)
Loss: 1.263 | Acc: 54.376% (18514/34048)
Loss: 1.263 | Acc: 54.407% (18594/34176)
Loss: 1.262 | Acc: 54.437% (18674/34304)
Loss: 1.261 | Acc: 54.467% (18754/34432)
Loss: 1.261 | Acc: 54.476% (18827/34560)
Loss: 1.261 | Acc: 54.460% (18891/34688)
Loss: 1.261 | Acc: 54.463% (18962/34816)
Loss: 1.260 | Acc: 54.490% (19041/34944)
Loss: 1.260 | Acc: 54.522% (19122/35072)
Loss: 1.260 | Acc: 54.511% (19188/35200)
Loss: 1.260 | Acc: 54.532% (19265/35328)
Loss: 1.260 | Acc: 54.549% (19341/35456)
Loss: 1.259 | Acc: 54.567% (19417/35584)
Loss: 1.258 | Acc: 54.606% (19501/35712)
Loss: 1.258 | Acc: 54.621% (19576/35840)
Loss: 1.258 | Acc: 54.615% (19644/35968)
Loss: 1.258 | Acc: 54.635% (19721/36096)
Loss: 1.258 | Ac

Loss: 1.242 | Acc: 57.014% (4162/7300)
Loss: 1.238 | Acc: 57.027% (4220/7400)
Loss: 1.241 | Acc: 56.880% (4266/7500)
Loss: 1.238 | Acc: 56.947% (4328/7600)
Loss: 1.238 | Acc: 56.974% (4387/7700)
Loss: 1.239 | Acc: 56.949% (4442/7800)
Loss: 1.239 | Acc: 57.025% (4505/7900)
Loss: 1.239 | Acc: 57.013% (4561/8000)
Loss: 1.241 | Acc: 56.975% (4615/8100)
Loss: 1.242 | Acc: 56.976% (4672/8200)
Loss: 1.243 | Acc: 56.952% (4727/8300)
Loss: 1.246 | Acc: 56.810% (4772/8400)
Loss: 1.246 | Acc: 56.753% (4824/8500)
Loss: 1.246 | Acc: 56.756% (4881/8600)
Loss: 1.247 | Acc: 56.713% (4934/8700)
Loss: 1.248 | Acc: 56.682% (4988/8800)
Loss: 1.249 | Acc: 56.640% (5041/8900)
Loss: 1.250 | Acc: 56.622% (5096/9000)
Loss: 1.250 | Acc: 56.549% (5146/9100)
Loss: 1.251 | Acc: 56.598% (5207/9200)
Loss: 1.250 | Acc: 56.591% (5263/9300)
Loss: 1.251 | Acc: 56.564% (5317/9400)
Loss: 1.251 | Acc: 56.621% (5379/9500)
Loss: 1.251 | Acc: 56.625% (5436/9600)
Loss: 1.250 | Acc: 56.619% (5492/9700)
Loss: 1.252 | Acc: 56.551

Loss: 1.071 | Acc: 61.900% (14341/23168)
Loss: 1.071 | Acc: 61.890% (14418/23296)
Loss: 1.071 | Acc: 61.915% (14503/23424)
Loss: 1.070 | Acc: 61.918% (14583/23552)
Loss: 1.070 | Acc: 61.930% (14665/23680)
Loss: 1.070 | Acc: 61.925% (14743/23808)
Loss: 1.069 | Acc: 61.953% (14829/23936)
Loss: 1.070 | Acc: 61.956% (14909/24064)
Loss: 1.070 | Acc: 61.946% (14986/24192)
Loss: 1.069 | Acc: 61.961% (15069/24320)
Loss: 1.069 | Acc: 61.981% (15153/24448)
Loss: 1.069 | Acc: 61.979% (15232/24576)
Loss: 1.069 | Acc: 61.974% (15310/24704)
Loss: 1.070 | Acc: 61.960% (15386/24832)
Loss: 1.070 | Acc: 61.947% (15462/24960)
Loss: 1.070 | Acc: 61.922% (15535/25088)
Loss: 1.071 | Acc: 61.933% (15617/25216)
Loss: 1.070 | Acc: 61.936% (15697/25344)
Loss: 1.070 | Acc: 61.954% (15781/25472)
Loss: 1.069 | Acc: 61.969% (15864/25600)
Loss: 1.070 | Acc: 61.975% (15945/25728)
Loss: 1.069 | Acc: 61.974% (16024/25856)
Loss: 1.069 | Acc: 61.957% (16099/25984)
Loss: 1.069 | Acc: 61.956% (16178/26112)
Loss: 1.068 | Ac

Loss: 1.029 | Acc: 63.482% (30959/48768)
Loss: 1.028 | Acc: 63.494% (31046/48896)
Loss: 1.028 | Acc: 63.504% (31132/49024)
Loss: 1.028 | Acc: 63.517% (31220/49152)
Loss: 1.027 | Acc: 63.531% (31308/49280)
Loss: 1.027 | Acc: 63.542% (31395/49408)
Loss: 1.027 | Acc: 63.548% (31479/49536)
Loss: 1.027 | Acc: 63.547% (31560/49664)
Loss: 1.027 | Acc: 63.548% (31642/49792)
Loss: 1.027 | Acc: 63.554% (31726/49920)
Loss: 1.027 | Acc: 63.570% (31785/50000)

Test Epoch: 3
Loss: 0.963 | Acc: 66.000% (66/100)
Loss: 0.962 | Acc: 64.500% (129/200)
Loss: 0.990 | Acc: 64.000% (192/300)
Loss: 0.969 | Acc: 64.500% (258/400)
Loss: 0.986 | Acc: 63.200% (316/500)
Loss: 0.955 | Acc: 64.333% (386/600)
Loss: 0.983 | Acc: 63.857% (447/700)
Loss: 1.017 | Acc: 62.625% (501/800)
Loss: 1.017 | Acc: 62.444% (562/900)
Loss: 1.002 | Acc: 63.400% (634/1000)
Loss: 0.999 | Acc: 63.636% (700/1100)
Loss: 0.999 | Acc: 63.917% (767/1200)
Loss: 0.997 | Acc: 64.077% (833/1300)
Loss: 1.006 | Acc: 63.429% (888/1400)
Loss: 1.003 

Loss: 0.926 | Acc: 67.713% (8754/12928)
Loss: 0.925 | Acc: 67.678% (8836/13056)
Loss: 0.923 | Acc: 67.734% (8930/13184)
Loss: 0.923 | Acc: 67.758% (9020/13312)
Loss: 0.924 | Acc: 67.731% (9103/13440)
Loss: 0.924 | Acc: 67.733% (9190/13568)
Loss: 0.925 | Acc: 67.655% (9266/13696)
Loss: 0.926 | Acc: 67.636% (9350/13824)
Loss: 0.924 | Acc: 67.646% (9438/13952)
Loss: 0.925 | Acc: 67.578% (9515/14080)
Loss: 0.926 | Acc: 67.546% (9597/14208)
Loss: 0.926 | Acc: 67.557% (9685/14336)
Loss: 0.926 | Acc: 67.575% (9774/14464)
Loss: 0.925 | Acc: 67.612% (9866/14592)
Loss: 0.926 | Acc: 67.602% (9951/14720)
Loss: 0.925 | Acc: 67.645% (10044/14848)
Loss: 0.926 | Acc: 67.615% (10126/14976)
Loss: 0.927 | Acc: 67.565% (10205/15104)
Loss: 0.927 | Acc: 67.535% (10287/15232)
Loss: 0.927 | Acc: 67.546% (10375/15360)
Loss: 0.927 | Acc: 67.530% (10459/15488)
Loss: 0.927 | Acc: 67.476% (10537/15616)
Loss: 0.926 | Acc: 67.511% (10629/15744)
Loss: 0.925 | Acc: 67.566% (10724/15872)
Loss: 0.925 | Acc: 67.569% (108

Loss: 0.894 | Acc: 68.435% (26542/38784)
Loss: 0.893 | Acc: 68.460% (26639/38912)
Loss: 0.894 | Acc: 68.471% (26731/39040)
Loss: 0.893 | Acc: 68.474% (26820/39168)
Loss: 0.893 | Acc: 68.485% (26912/39296)
Loss: 0.893 | Acc: 68.486% (27000/39424)
Loss: 0.893 | Acc: 68.495% (27091/39552)
Loss: 0.892 | Acc: 68.533% (27194/39680)
Loss: 0.892 | Acc: 68.544% (27286/39808)
Loss: 0.892 | Acc: 68.560% (27380/39936)
Loss: 0.892 | Acc: 68.560% (27468/40064)
Loss: 0.892 | Acc: 68.553% (27553/40192)
Loss: 0.892 | Acc: 68.534% (27633/40320)
Loss: 0.892 | Acc: 68.537% (27722/40448)
Loss: 0.892 | Acc: 68.543% (27812/40576)
Loss: 0.891 | Acc: 68.556% (27905/40704)
Loss: 0.891 | Acc: 68.583% (28004/40832)
Loss: 0.891 | Acc: 68.586% (28093/40960)
Loss: 0.890 | Acc: 68.582% (28179/41088)
Loss: 0.891 | Acc: 68.575% (28264/41216)
Loss: 0.891 | Acc: 68.576% (28352/41344)
Loss: 0.890 | Acc: 68.596% (28448/41472)
Loss: 0.890 | Acc: 68.596% (28536/41600)
Loss: 0.890 | Acc: 68.618% (28633/41728)
Loss: 0.890 | Ac

Loss: 0.840 | Acc: 70.477% (1714/2432)
Loss: 0.836 | Acc: 70.469% (1804/2560)
Loss: 0.835 | Acc: 70.424% (1893/2688)
Loss: 0.839 | Acc: 70.312% (1980/2816)
Loss: 0.833 | Acc: 70.652% (2080/2944)
Loss: 0.832 | Acc: 70.833% (2176/3072)
Loss: 0.833 | Acc: 70.812% (2266/3200)
Loss: 0.831 | Acc: 70.944% (2361/3328)
Loss: 0.831 | Acc: 71.007% (2454/3456)
Loss: 0.839 | Acc: 70.843% (2539/3584)
Loss: 0.838 | Acc: 70.851% (2630/3712)
Loss: 0.833 | Acc: 70.885% (2722/3840)
Loss: 0.836 | Acc: 70.892% (2813/3968)
Loss: 0.833 | Acc: 70.996% (2908/4096)
Loss: 0.833 | Acc: 70.999% (2999/4224)
Loss: 0.830 | Acc: 71.048% (3092/4352)
Loss: 0.832 | Acc: 70.982% (3180/4480)
Loss: 0.831 | Acc: 70.833% (3264/4608)
Loss: 0.830 | Acc: 70.840% (3355/4736)
Loss: 0.826 | Acc: 70.929% (3450/4864)
Loss: 0.827 | Acc: 70.933% (3541/4992)
Loss: 0.831 | Acc: 70.742% (3622/5120)
Loss: 0.827 | Acc: 70.884% (3720/5248)
Loss: 0.826 | Acc: 70.852% (3809/5376)
Loss: 0.825 | Acc: 70.821% (3898/5504)
Loss: 0.827 | Acc: 70.739

Loss: 0.804 | Acc: 71.486% (20405/28544)
Loss: 0.804 | Acc: 71.495% (20499/28672)
Loss: 0.804 | Acc: 71.524% (20599/28800)
Loss: 0.804 | Acc: 71.519% (20689/28928)
Loss: 0.803 | Acc: 71.558% (20792/29056)
Loss: 0.803 | Acc: 71.553% (20882/29184)
Loss: 0.804 | Acc: 71.544% (20971/29312)
Loss: 0.804 | Acc: 71.515% (21054/29440)
Loss: 0.804 | Acc: 71.517% (21146/29568)
Loss: 0.804 | Acc: 71.525% (21240/29696)
Loss: 0.804 | Acc: 71.526% (21332/29824)
Loss: 0.804 | Acc: 71.528% (21424/29952)
Loss: 0.804 | Acc: 71.549% (21522/30080)
Loss: 0.803 | Acc: 71.551% (21614/30208)
Loss: 0.803 | Acc: 71.552% (21706/30336)
Loss: 0.803 | Acc: 71.563% (21801/30464)
Loss: 0.803 | Acc: 71.591% (21901/30592)
Loss: 0.803 | Acc: 71.582% (21990/30720)
Loss: 0.803 | Acc: 71.580% (22081/30848)
Loss: 0.803 | Acc: 71.594% (22177/30976)
Loss: 0.803 | Acc: 71.579% (22264/31104)
Loss: 0.803 | Acc: 71.590% (22359/31232)
Loss: 0.803 | Acc: 71.598% (22453/31360)
Loss: 0.802 | Acc: 71.624% (22553/31488)
Loss: 0.802 | Ac

Loss: 0.871 | Acc: 70.553% (2681/3800)
Loss: 0.869 | Acc: 70.615% (2754/3900)
Loss: 0.866 | Acc: 70.800% (2832/4000)
Loss: 0.864 | Acc: 70.805% (2903/4100)
Loss: 0.867 | Acc: 70.762% (2972/4200)
Loss: 0.865 | Acc: 70.860% (3047/4300)
Loss: 0.866 | Acc: 70.841% (3117/4400)
Loss: 0.864 | Acc: 70.889% (3190/4500)
Loss: 0.866 | Acc: 70.848% (3259/4600)
Loss: 0.866 | Acc: 70.787% (3327/4700)
Loss: 0.869 | Acc: 70.688% (3393/4800)
Loss: 0.865 | Acc: 70.857% (3472/4900)
Loss: 0.867 | Acc: 70.700% (3535/5000)
Loss: 0.864 | Acc: 70.882% (3615/5100)
Loss: 0.863 | Acc: 70.846% (3684/5200)
Loss: 0.859 | Acc: 70.962% (3761/5300)
Loss: 0.859 | Acc: 71.000% (3834/5400)
Loss: 0.859 | Acc: 70.982% (3904/5500)
Loss: 0.860 | Acc: 71.036% (3978/5600)
Loss: 0.860 | Acc: 71.000% (4047/5700)
Loss: 0.856 | Acc: 71.086% (4123/5800)
Loss: 0.859 | Acc: 71.017% (4190/5900)
Loss: 0.859 | Acc: 70.933% (4256/6000)
Loss: 0.857 | Acc: 70.934% (4327/6100)
Loss: 0.858 | Acc: 70.839% (4392/6200)
Loss: 0.859 | Acc: 70.825

Loss: 0.683 | Acc: 76.244% (14346/18816)
Loss: 0.684 | Acc: 76.235% (14442/18944)
Loss: 0.685 | Acc: 76.227% (14538/19072)
Loss: 0.684 | Acc: 76.255% (14641/19200)
Loss: 0.684 | Acc: 76.283% (14744/19328)
Loss: 0.684 | Acc: 76.295% (14844/19456)
Loss: 0.683 | Acc: 76.292% (14941/19584)
Loss: 0.683 | Acc: 76.278% (15036/19712)
Loss: 0.682 | Acc: 76.316% (15141/19840)
Loss: 0.682 | Acc: 76.332% (15242/19968)
Loss: 0.681 | Acc: 76.344% (15342/20096)
Loss: 0.680 | Acc: 76.380% (15447/20224)
Loss: 0.680 | Acc: 76.361% (15541/20352)
Loss: 0.681 | Acc: 76.333% (15633/20480)
Loss: 0.680 | Acc: 76.359% (15736/20608)
Loss: 0.681 | Acc: 76.292% (15820/20736)
Loss: 0.681 | Acc: 76.294% (15918/20864)
Loss: 0.682 | Acc: 76.296% (16016/20992)
Loss: 0.683 | Acc: 76.255% (16105/21120)
Loss: 0.683 | Acc: 76.275% (16207/21248)
Loss: 0.683 | Acc: 76.296% (16309/21376)
Loss: 0.683 | Acc: 76.288% (16405/21504)
Loss: 0.683 | Acc: 76.294% (16504/21632)
Loss: 0.682 | Acc: 76.324% (16608/21760)
Loss: 0.682 | Ac

Loss: 0.685 | Acc: 76.162% (33828/44416)
Loss: 0.686 | Acc: 76.163% (33926/44544)
Loss: 0.685 | Acc: 76.173% (34028/44672)
Loss: 0.685 | Acc: 76.181% (34129/44800)
Loss: 0.685 | Acc: 76.175% (34224/44928)
Loss: 0.685 | Acc: 76.176% (34322/45056)
Loss: 0.685 | Acc: 76.202% (34431/45184)
Loss: 0.685 | Acc: 76.203% (34529/45312)
Loss: 0.685 | Acc: 76.202% (34626/45440)
Loss: 0.685 | Acc: 76.209% (34727/45568)
Loss: 0.685 | Acc: 76.197% (34819/45696)
Loss: 0.685 | Acc: 76.205% (34920/45824)
Loss: 0.685 | Acc: 76.206% (35018/45952)
Loss: 0.685 | Acc: 76.220% (35122/46080)
Loss: 0.685 | Acc: 76.210% (35215/46208)
Loss: 0.685 | Acc: 76.211% (35313/46336)
Loss: 0.685 | Acc: 76.220% (35415/46464)
Loss: 0.684 | Acc: 76.243% (35523/46592)
Loss: 0.684 | Acc: 76.259% (35628/46720)
Loss: 0.684 | Acc: 76.266% (35729/46848)
Loss: 0.683 | Acc: 76.271% (35829/46976)
Loss: 0.683 | Acc: 76.274% (35928/47104)
Loss: 0.683 | Acc: 76.281% (36029/47232)
Loss: 0.683 | Acc: 76.299% (36135/47360)
Loss: 0.683 | Ac

Loss: 0.635 | Acc: 78.173% (6504/8320)
Loss: 0.634 | Acc: 78.208% (6607/8448)
Loss: 0.636 | Acc: 78.137% (6701/8576)
Loss: 0.635 | Acc: 78.159% (6803/8704)
Loss: 0.634 | Acc: 78.159% (6903/8832)
Loss: 0.633 | Acc: 78.225% (7009/8960)
Loss: 0.633 | Acc: 78.235% (7110/9088)
Loss: 0.631 | Acc: 78.277% (7214/9216)
Loss: 0.631 | Acc: 78.243% (7311/9344)
Loss: 0.629 | Acc: 78.315% (7418/9472)
Loss: 0.631 | Acc: 78.229% (7510/9600)
Loss: 0.629 | Acc: 78.269% (7614/9728)
Loss: 0.628 | Acc: 78.348% (7722/9856)
Loss: 0.628 | Acc: 78.365% (7824/9984)
Loss: 0.629 | Acc: 78.323% (7920/10112)
Loss: 0.631 | Acc: 78.271% (8015/10240)
Loss: 0.631 | Acc: 78.289% (8117/10368)
Loss: 0.630 | Acc: 78.296% (8218/10496)
Loss: 0.629 | Acc: 78.351% (8324/10624)
Loss: 0.629 | Acc: 78.320% (8421/10752)
Loss: 0.631 | Acc: 78.235% (8512/10880)
Loss: 0.630 | Acc: 78.307% (8620/11008)
Loss: 0.629 | Acc: 78.412% (8732/11136)
Loss: 0.631 | Acc: 78.347% (8825/11264)
Loss: 0.631 | Acc: 78.283% (8918/11392)
Loss: 0.630 | 

Loss: 0.627 | Acc: 78.415% (26799/34176)
Loss: 0.627 | Acc: 78.425% (26903/34304)
Loss: 0.627 | Acc: 78.410% (26998/34432)
Loss: 0.628 | Acc: 78.403% (27096/34560)
Loss: 0.627 | Acc: 78.413% (27200/34688)
Loss: 0.627 | Acc: 78.421% (27303/34816)
Loss: 0.628 | Acc: 78.403% (27397/34944)
Loss: 0.627 | Acc: 78.407% (27499/35072)
Loss: 0.627 | Acc: 78.403% (27598/35200)
Loss: 0.627 | Acc: 78.394% (27695/35328)
Loss: 0.627 | Acc: 78.387% (27793/35456)
Loss: 0.628 | Acc: 78.364% (27885/35584)
Loss: 0.628 | Acc: 78.369% (27987/35712)
Loss: 0.628 | Acc: 78.379% (28091/35840)
Loss: 0.628 | Acc: 78.378% (28191/35968)
Loss: 0.628 | Acc: 78.377% (28291/36096)
Loss: 0.628 | Acc: 78.371% (28389/36224)
Loss: 0.627 | Acc: 78.389% (28496/36352)
Loss: 0.627 | Acc: 78.385% (28595/36480)
Loss: 0.627 | Acc: 78.387% (28696/36608)
Loss: 0.627 | Acc: 78.394% (28799/36736)
Loss: 0.627 | Acc: 78.377% (28893/36864)
Loss: 0.628 | Acc: 78.352% (28984/36992)
Loss: 0.628 | Acc: 78.349% (29083/37120)
Loss: 0.628 | Ac

Loss: 0.692 | Acc: 76.691% (6212/8100)
Loss: 0.691 | Acc: 76.683% (6288/8200)
Loss: 0.692 | Acc: 76.627% (6360/8300)
Loss: 0.693 | Acc: 76.560% (6431/8400)
Loss: 0.693 | Acc: 76.565% (6508/8500)
Loss: 0.694 | Acc: 76.512% (6580/8600)
Loss: 0.694 | Acc: 76.460% (6652/8700)
Loss: 0.695 | Acc: 76.455% (6728/8800)
Loss: 0.695 | Acc: 76.449% (6804/8900)
Loss: 0.696 | Acc: 76.444% (6880/9000)
Loss: 0.696 | Acc: 76.451% (6957/9100)
Loss: 0.694 | Acc: 76.533% (7041/9200)
Loss: 0.697 | Acc: 76.419% (7107/9300)
Loss: 0.697 | Acc: 76.372% (7179/9400)
Loss: 0.697 | Acc: 76.379% (7256/9500)
Loss: 0.697 | Acc: 76.365% (7331/9600)
Loss: 0.696 | Acc: 76.412% (7412/9700)
Loss: 0.697 | Acc: 76.337% (7481/9800)
Loss: 0.698 | Acc: 76.293% (7553/9900)
Loss: 0.698 | Acc: 76.250% (7625/10000)

Train Epoch: 8
Loss: 0.472 | Acc: 83.594% (107/128)
Loss: 0.450 | Acc: 83.594% (214/256)
Loss: 0.492 | Acc: 81.510% (313/384)
Loss: 0.500 | Acc: 81.445% (417/512)
Loss: 0.523 | Acc: 80.469% (515/640)
Loss: 0.524 | Acc:

Loss: 0.585 | Acc: 79.880% (19120/23936)
Loss: 0.584 | Acc: 79.908% (19229/24064)
Loss: 0.584 | Acc: 79.919% (19334/24192)
Loss: 0.585 | Acc: 79.905% (19433/24320)
Loss: 0.585 | Acc: 79.896% (19533/24448)
Loss: 0.585 | Acc: 79.895% (19635/24576)
Loss: 0.584 | Acc: 79.906% (19740/24704)
Loss: 0.584 | Acc: 79.905% (19842/24832)
Loss: 0.584 | Acc: 79.904% (19944/24960)
Loss: 0.584 | Acc: 79.891% (20043/25088)
Loss: 0.584 | Acc: 79.886% (20144/25216)
Loss: 0.585 | Acc: 79.869% (20242/25344)
Loss: 0.585 | Acc: 79.868% (20344/25472)
Loss: 0.585 | Acc: 79.867% (20446/25600)
Loss: 0.586 | Acc: 79.839% (20541/25728)
Loss: 0.586 | Acc: 79.846% (20645/25856)
Loss: 0.585 | Acc: 79.861% (20751/25984)
Loss: 0.585 | Acc: 79.848% (20850/26112)
Loss: 0.586 | Acc: 79.851% (20953/26240)
Loss: 0.586 | Acc: 79.839% (21052/26368)
Loss: 0.585 | Acc: 79.842% (21155/26496)
Loss: 0.586 | Acc: 79.826% (21253/26624)
Loss: 0.586 | Acc: 79.822% (21354/26752)
Loss: 0.586 | Acc: 79.792% (21448/26880)
Loss: 0.586 | Ac

Loss: 0.584 | Acc: 79.932% (39595/49536)
Loss: 0.585 | Acc: 79.925% (39694/49664)
Loss: 0.585 | Acc: 79.924% (39796/49792)
Loss: 0.584 | Acc: 79.930% (39901/49920)
Loss: 0.584 | Acc: 79.938% (39969/50000)

Test Epoch: 8
Loss: 0.721 | Acc: 74.000% (74/100)
Loss: 0.658 | Acc: 75.500% (151/200)
Loss: 0.645 | Acc: 76.333% (229/300)
Loss: 0.632 | Acc: 77.500% (310/400)
Loss: 0.632 | Acc: 77.200% (386/500)
Loss: 0.601 | Acc: 77.500% (465/600)
Loss: 0.624 | Acc: 76.429% (535/700)
Loss: 0.676 | Acc: 75.125% (601/800)
Loss: 0.689 | Acc: 75.444% (679/900)
Loss: 0.674 | Acc: 76.000% (760/1000)
Loss: 0.672 | Acc: 76.091% (837/1100)
Loss: 0.672 | Acc: 75.917% (911/1200)
Loss: 0.667 | Acc: 76.077% (989/1300)
Loss: 0.664 | Acc: 76.286% (1068/1400)
Loss: 0.659 | Acc: 76.400% (1146/1500)
Loss: 0.662 | Acc: 76.438% (1223/1600)
Loss: 0.660 | Acc: 76.765% (1305/1700)
Loss: 0.665 | Acc: 76.556% (1378/1800)
Loss: 0.668 | Acc: 76.632% (1456/1900)
Loss: 0.676 | Acc: 76.300% (1526/2000)
Loss: 0.679 | Acc: 76.1

Loss: 0.542 | Acc: 80.980% (11091/13696)
Loss: 0.542 | Acc: 81.004% (11198/13824)
Loss: 0.541 | Acc: 81.021% (11304/13952)
Loss: 0.542 | Acc: 81.009% (11406/14080)
Loss: 0.542 | Acc: 81.011% (11510/14208)
Loss: 0.541 | Acc: 81.027% (11616/14336)
Loss: 0.541 | Acc: 81.029% (11720/14464)
Loss: 0.542 | Acc: 81.010% (11821/14592)
Loss: 0.542 | Acc: 81.012% (11925/14720)
Loss: 0.541 | Acc: 81.008% (12028/14848)
Loss: 0.542 | Acc: 80.976% (12127/14976)
Loss: 0.542 | Acc: 80.959% (12228/15104)
Loss: 0.542 | Acc: 80.948% (12330/15232)
Loss: 0.542 | Acc: 80.983% (12439/15360)
Loss: 0.542 | Acc: 80.992% (12544/15488)
Loss: 0.542 | Acc: 80.975% (12645/15616)
Loss: 0.542 | Acc: 80.971% (12748/15744)
Loss: 0.543 | Acc: 80.916% (12843/15872)
Loss: 0.544 | Acc: 80.894% (12943/16000)
Loss: 0.545 | Acc: 80.872% (13043/16128)
Loss: 0.545 | Acc: 80.893% (13150/16256)
Loss: 0.544 | Acc: 80.908% (13256/16384)
Loss: 0.545 | Acc: 80.862% (13352/16512)
Loss: 0.546 | Acc: 80.811% (13447/16640)
Loss: 0.545 | Ac

Loss: 0.554 | Acc: 80.850% (31771/39296)
Loss: 0.554 | Acc: 80.849% (31874/39424)
Loss: 0.555 | Acc: 80.848% (31977/39552)
Loss: 0.555 | Acc: 80.854% (32083/39680)
Loss: 0.555 | Acc: 80.861% (32189/39808)
Loss: 0.555 | Acc: 80.872% (32297/39936)
Loss: 0.555 | Acc: 80.878% (32403/40064)
Loss: 0.555 | Acc: 80.872% (32504/40192)
Loss: 0.554 | Acc: 80.883% (32612/40320)
Loss: 0.554 | Acc: 80.892% (32719/40448)
Loss: 0.554 | Acc: 80.895% (32824/40576)
Loss: 0.554 | Acc: 80.909% (32933/40704)
Loss: 0.553 | Acc: 80.919% (33041/40832)
Loss: 0.553 | Acc: 80.930% (33149/40960)
Loss: 0.553 | Acc: 80.938% (33256/41088)
Loss: 0.553 | Acc: 80.952% (33365/41216)
Loss: 0.553 | Acc: 80.960% (33472/41344)
Loss: 0.553 | Acc: 80.949% (33571/41472)
Loss: 0.553 | Acc: 80.950% (33675/41600)
Loss: 0.553 | Acc: 80.953% (33780/41728)
Loss: 0.553 | Acc: 80.954% (33884/41856)
Loss: 0.553 | Acc: 80.964% (33992/41984)
Loss: 0.553 | Acc: 80.951% (34090/42112)
Loss: 0.553 | Acc: 80.956% (34196/42240)
Loss: 0.553 | Ac

In [42]:
print(best_acc)

76.89


# Vanilla ResNet18

In [18]:
transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=16, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [19]:
from torchvision import models
res_net = models.resnet50(pretrained=True)
#for param in res_net.parameters():
    #param.requires_grad = False
num_ftrs = res_net.fc.in_features
res_net.fc = nn.Linear(num_ftrs, len(classes))

res_net = res_net.to(device)

criterion = nn.CrossEntropyLoss().to(device)

# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(res_net.fc.parameters(), lr=0.01, momentum=0.9, weight_decay = 1e-4)

from torch.optim import lr_scheduler
# Decay LR by a factor of 0.1 every 20 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=20, gamma=0.1)

In [20]:
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import copy

'''
data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
'''
dataloaders = {'train': trainloader, 'val': testloader}
dataset_sizes = {'train': len(trainset), 'val': len(testset)}

In [21]:


def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [22]:
model_conv = train_model(res_net, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=20)

Epoch 0/19
----------
train Loss: 1.1761 Acc: 0.6894
val Loss: 1.0712 Acc: 0.7396

Epoch 1/19
----------
train Loss: 1.0553 Acc: 0.7313
val Loss: 1.1787 Acc: 0.7201

Epoch 2/19
----------
train Loss: 1.0392 Acc: 0.7381
val Loss: 0.8413 Acc: 0.7867

Epoch 3/19
----------
train Loss: 1.0118 Acc: 0.7440
val Loss: 0.8195 Acc: 0.7842

Epoch 4/19
----------
train Loss: 0.9916 Acc: 0.7489
val Loss: 0.8519 Acc: 0.7770

Epoch 5/19
----------
train Loss: 0.9643 Acc: 0.7528
val Loss: 0.7477 Acc: 0.7944

Epoch 6/19
----------
train Loss: 0.9643 Acc: 0.7548
val Loss: 1.2058 Acc: 0.7299

Epoch 7/19
----------
train Loss: 0.9360 Acc: 0.7623
val Loss: 0.7980 Acc: 0.7929

Epoch 8/19
----------
train Loss: 0.9115 Acc: 0.7648
val Loss: 0.7101 Acc: 0.8032

Epoch 9/19
----------
train Loss: 0.9006 Acc: 0.7658
val Loss: 0.7995 Acc: 0.7906

Epoch 10/19
----------
train Loss: 0.8844 Acc: 0.7683
val Loss: 0.7553 Acc: 0.7923

Epoch 11/19
----------
train Loss: 0.8723 Acc: 0.7701
val Loss: 0.8466 Acc: 0.7814

Ep