In [1]:
import os
import numpy as np
import random
import time

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from vgg_cifar import vgg

## Define hyperparameters

In [2]:
class VGG_agrument:
    """Settings container used in place of command-line arguments.

    Each option becomes a plain instance attribute, so callers read and
    overwrite them directly (e.g. ``args.block = ...``).
    """

    def __init__(self):
        defaults = {
            # Choice : 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn',
            #          'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn'
            'arch': 'vgg19_bn',
            'workers': 4,            # DataLoader worker processes
            'epochs': 300,           # last epoch index (exclusive) of the training loop
            'start_epoch': 0,        # first epoch index of the training loop
            'batch_size': 128,
            'lr': 0.05,              # base learning rate fed to the step schedule
            'momentum': 0.9,
            'weight_decay': 5e-4,
            'print_freq': -1,        # negative: log only the first and last batch
            'resume': '',            # checkpoint path to resume from; empty disables
            'evaluate': False,       # True: run validation once and stop
            'pretrained': False,     # not read by the code visible in this section
            'half': False,           # True: cast model, loss and inputs to fp16
            'cpu': False,            # True: keep everything on the CPU
            'save_dir': 'weights/vgg19',
            'dataset': 'cifar10',    # Choice : 'cifar10' and 'cifar100'
            'block': 'VGG19',        # block variant handed to the model constructor
            'checkpoint': None,      # not read in the visible code; presumably used for inference later - TODO confirm
        }
        for option, value in defaults.items():
            setattr(self, option, value)

# Single shared settings object: it is passed to run_train(), and the training
# loop overwrites `args.block` before each model variant is trained.
args = VGG_agrument()

In [3]:
# Each model variant as (block key passed to the model, label shown in logs and plots).
# The two lists below are parallel and must stay in the same order.
_variants = [
    ('VGG19', 'VGG19 (base)'),
    ('SE_SA_1', 'SE (residuel) + SA'),
    ('SEC_SA_1', 'SE + SA'),
    ('CBAM_1', 'CBAM'),
    ('NEW_1', 'Our model'),
]
block_list = [key for key, _ in _variants]
name_list = [label for _, label in _variants]

## Set the random seed

In [4]:
# Seed every random number generator the notebook relies on (Python, NumPy,
# PyTorch CPU and all CUDA devices) with the same value.
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)

# Ask cuDNN for deterministic kernels and switch off its auto-tuner.
cudnn.deterministic = True
cudnn.benchmark = False

# 1. Training

## 1.1 Define functions

In [5]:
def train(train_loader, model, criterion, optimizer, epoch, is_cpu, is_half, print_freq):
    """
    Run one training epoch over ``train_loader``.

    Args:
        train_loader: sized iterable yielding ``(input, target)`` batches.
        model: network being trained; it is switched to train mode here.
        criterion: loss callable, invoked as ``criterion(output, target)``.
        optimizer: optimizer stepped once per batch.
        epoch: epoch index, used only in the progress line.
        is_cpu: when falsy, every batch is moved to the GPU with ``.cuda()``.
        is_half: when truthy, inputs are cast with ``.half()`` before the
            forward pass.
        print_freq: print a progress line every ``print_freq`` batches. A
            value <= 0 means "first and last batch only".

    Returns:
        None. Progress is reported through ``print``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to train mode
    model.train()

    last_batch = len(train_loader) - 1
    if print_freq <= 0:
        # Clamp to 1: a single-batch loader would otherwise give a modulus of
        # zero and crash on ``i % print_freq`` below.
        print_freq = max(last_batch, 1)

    end = time.time()
    for i, (images, target) in enumerate(train_loader):

        # measure data loading time
        data_time.update(time.time() - end)

        if not is_cpu:
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
        if is_half:
            images = images.half()

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Cast back to fp32 so the metrics are accumulated in full precision
        # even when the model runs in half precision.
        output = output.float()
        loss = loss.float()

        # measure accuracy and record loss (weighted by the batch size)
        prec1 = accuracy(output.detach(), target)[0]
        losses.update(loss.item(), images.size(0))
        top1.update(prec1.item(), images.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch, i, last_batch, batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1))


def validate(val_loader, model, criterion, is_cpu, is_half, print_freq):
    """
    Evaluate ``model`` on ``val_loader`` and report top-1 precision.

    Args:
        val_loader: sized iterable yielding ``(input, target)`` batches.
        model: network to evaluate; it is switched to eval mode here.
        criterion: loss callable, invoked as ``criterion(output, target)``.
        is_cpu: when falsy, every batch is moved to the GPU with ``.cuda()``.
        is_half: when truthy, inputs are cast with ``.half()`` before the
            forward pass.
        print_freq: print a progress line every ``print_freq`` batches. A
            value <= 0 means "first and last batch only".

    Returns:
        The average Prec@1 (in percent) over all samples seen, weighted by
        batch size.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    last_batch = len(val_loader) - 1
    if print_freq <= 0:
        # Clamp to 1: a single-batch loader would otherwise give a modulus of
        # zero and crash on ``i % print_freq`` below.
        print_freq = max(last_batch, 1)

    end = time.time()
    for i, (images, target) in enumerate(val_loader):
        if not is_cpu:
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        if is_half:
            images = images.half()

        # compute output without building the autograd graph
        with torch.no_grad():
            output = model(images)
            loss = criterion(output, target)

        # Cast back to fp32 so the metrics are accumulated in full precision.
        output = output.float()
        loss = loss.float()

        # measure accuracy and record loss (weighted by the batch size)
        prec1 = accuracy(output.detach(), target)[0]
        losses.update(loss.item(), images.size(0))
        top1.update(prec1.item(), images.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      i, last_batch, batch_time=batch_time, loss=losses,
                      top1=top1))

    print(' * Prec@1 {top1.avg:.3f}'
          .format(top1=top1))

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """
    Serialize a training checkpoint with ``torch.save``.

    Args:
        state: object to serialize.
        is_best: accepted for call-site compatibility; not used here.
        filename: destination handed to ``torch.save``.
    """
    torch.save(obj=state, f=filename)


class AverageMeter(object):
    """Tracks the latest value and the running weighted average of a metric.

    Attributes (all reset to 0 by ``reset``):
        val:   value passed to the most recent ``update`` call.
        sum:   weighted sum of all values seen (``val * n`` accumulated).
        count: total weight seen (sum of ``n``).
        avg:   ``sum / count``, or 0 while ``count`` is still 0.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed with weight ``n`` (e.g. the batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division: a zero-weight update on an empty meter (for
        # example an empty batch) used to raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0


def adjust_learning_rate(lr, optimizer, epoch):
    """Step schedule: halve the base rate ``lr`` once per completed block of 30 epochs.

    The resulting rate is written into every parameter group of ``optimizer``.
    """
    halvings = epoch // 30
    decayed_lr = lr * (0.5 ** halvings)
    for group in optimizer.param_groups:
        group['lr'] = decayed_lr


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: score tensor; the top-k is taken along dimension 1.
        target: tensor of ground-truth class indices, one per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 0-dim tensor per entry of ``topk``, each holding the
        percentage of samples whose target is among the k highest scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk best-scoring classes per sample, transposed so that
    # row r holds every sample's rank-r prediction.
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: the slice of the transposed comparison is not
        # contiguous for k > 1, and view(-1) raises RuntimeError on it.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

In [6]:
def _build_cifar_loaders(dataset_name, batch_size, workers):
    """Create the (train, validation) DataLoaders for 'cifar10' or 'cifar100'."""
    dataset_cls = {'cifar10': datasets.CIFAR10, 'cifar100': datasets.CIFAR100}[dataset_name]

    # NOTE(review): these look like the ImageNet channel statistics rather
    # than CIFAR's own - confirm this is intended.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_loader = torch.utils.data.DataLoader(
        dataset_cls(root='./data', train=True, transform=train_transform, download=True),
        batch_size=batch_size, shuffle=True,
        num_workers=workers, pin_memory=True)
    # No download here: the training split above has already been requested
    # with download=True by the time this dataset is built.
    val_loader = torch.utils.data.DataLoader(
        dataset_cls(root='./data', train=False, transform=val_transform),
        batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True)
    return train_loader, val_loader


def run_train(args):
    """Train (or only evaluate) one model variant described by ``args``.

    Builds the model selected by ``args.arch`` / ``args.block``, optionally
    restores a checkpoint from ``args.resume``, then trains from the start
    epoch up to ``args.epochs``. A checkpoint is written under
    ``<save_dir>/<dataset>/<block>/`` each time the validation precision
    improves.

    Args:
        args: settings object exposing the attributes of ``VGG_agrument``.
            It is not modified by this function.

    Returns:
        A list with the validation Prec@1 of every epoch that was run, or
        ``None`` when ``args.evaluate`` is set (validation only).

    Raises:
        ValueError: if ``args.dataset`` is neither 'cifar10' nor 'cifar100'.
    """
    class_counts = {'cifar10': 10, 'cifar100': 100}
    if args.dataset not in class_counts:
        # Previously this fell through and failed later with an unbound
        # local variable; fail early with a clear message instead.
        raise ValueError("unsupported dataset '{}': expected one of {}".format(
            args.dataset, sorted(class_counts)))
    num_classes = class_counts[args.dataset]
    print("dataset : ", args.dataset)
    print("num classes : ", num_classes)

    # Make sure the checkpoint directory exists.
    save_path = os.path.join(args.save_dir, args.dataset, args.block)
    os.makedirs(save_path, exist_ok=True)

    model = vgg.__dict__[args.arch](num_classes, args.block)

    model.features = torch.nn.DataParallel(model.features)
    if args.cpu:
        model.cpu()
    else:
        model.cuda()

    # Kept local so the shared ``args`` object is not altered between the
    # successive runs of the different model variants.
    start_epoch = args.start_epoch
    best_prec1 = 0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # On CPU-only runs, remap tensors that were saved from a GPU.
            checkpoint = torch.load(args.resume,
                                    map_location='cpu' if args.cpu else None)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    train_loader, val_loader = _build_cifar_loaders(
        args.dataset, args.batch_size, args.workers)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    if args.cpu:
        criterion = criterion.cpu()
    else:
        criterion = criterion.cuda()

    if args.half:
        model.half()
        criterion.half()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.evaluate:
        validate(val_loader, model, criterion, args.cpu, args.half, args.print_freq)
        return

    epoch_accuracys = []
    for epoch in range(start_epoch, args.epochs):
        adjust_learning_rate(args.lr, optimizer, epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args.cpu, args.half, args.print_freq)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion, args.cpu, args.half, args.print_freq)

        # remember best prec@1 and save a checkpoint whenever it improves
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        if is_best:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }, is_best, filename=os.path.join(save_path, 'checkpoint_{}.tar'.format(epoch)))

        # Accumulated precisions
        epoch_accuracys.append(prec1)

    return epoch_accuracys

## 1.2 Model training

In [7]:
# One matplotlib colour code per model variant, in the same order as `name_list`.
color_list = [
    'r',  # VGG19 (base)
    'y',  # SE (residuel) + SA
    'b',  # SE + SA
    'm',  # CBAM
    'g',  # Our model
]

In [None]:
# Train every model variant in turn with the shared settings; only
# `args.block` differs between runs. Each run contributes its per-epoch
# validation accuracies as one entry of `accuracy_list`.
banner = '########################################################################################'
accuracy_list = []
for variant_key, variant_label in zip(block_list, name_list):
    print()
    print(banner)
    print('Training of "%s"' % variant_label)
    args.block = variant_key
    accuracy_list.append(run_train(args))
    print(banner)


########################################################################################
Training of "VGG19 (base)"
dataset :  cifar10
num classes :  10
Files already downloaded and verified
Epoch: [0][0/390]	Time 11.676 (11.676)	Data 8.716 (8.716)	Loss 2.3046 (2.3046)	Prec@1 14.062 (14.062)
Epoch: [0][390/390]	Time 0.092 (0.088)	Data 0.000 (0.023)	Loss 1.5114 (1.7785)	Prec@1 38.750 (31.120)
Test: [0/78]	Time 6.144 (6.144)	Loss 1.5055 (1.5055)	Prec@1 48.438 (48.438)
Test: [78/78]	Time 0.033 (0.119)	Loss 1.6387 (1.6267)	Prec@1 43.750 (40.680)
 * Prec@1 40.680
Epoch: [1][0/390]	Time 7.596 (7.596)	Data 7.477 (7.477)	Loss 1.4055 (1.4055)	Prec@1 44.531 (44.531)
Epoch: [1][390/390]	Time 0.043 (0.060)	Data 0.000 (0.019)	Loss 1.0492 (1.3557)	Prec@1 67.500 (51.228)
Test: [0/78]	Time 6.135 (6.135)	Loss 1.4818 (1.4818)	Prec@1 49.219 (49.219)
Test: [78/78]	Time 0.016 (0.103)	Loss 1.6218 (1.5587)	Prec@1 50.000 (46.970)
 * Prec@1 46.970
Epoch: [2][0/390]	Time 7.238 (7.238)	Data 7.179 (7.179)	Loss 1

Test: [0/78]	Time 6.587 (6.587)	Loss 0.4055 (0.4055)	Prec@1 85.156 (85.156)
Test: [78/78]	Time 0.007 (0.110)	Loss 0.4481 (0.5050)	Prec@1 81.250 (84.050)
 * Prec@1 84.050
Epoch: [22][0/390]	Time 7.814 (7.814)	Data 7.741 (7.741)	Loss 0.3231 (0.3231)	Prec@1 90.625 (90.625)
Epoch: [22][390/390]	Time 0.025 (0.060)	Data 0.000 (0.020)	Loss 0.3760 (0.4264)	Prec@1 87.500 (86.794)
Test: [0/78]	Time 6.543 (6.543)	Loss 0.3927 (0.3927)	Prec@1 89.844 (89.844)
Test: [78/78]	Time 0.030 (0.124)	Loss 0.3083 (0.4955)	Prec@1 87.500 (83.810)
 * Prec@1 83.810
Epoch: [23][0/390]	Time 8.021 (8.021)	Data 7.926 (7.926)	Loss 0.2974 (0.2974)	Prec@1 89.062 (89.062)
Epoch: [23][390/390]	Time 0.065 (0.136)	Data 0.000 (0.021)	Loss 0.5505 (0.4061)	Prec@1 87.500 (87.372)
Test: [0/78]	Time 6.663 (6.663)	Loss 0.5622 (0.5622)	Prec@1 81.250 (81.250)
Test: [78/78]	Time 0.006 (0.096)	Loss 0.8456 (0.5983)	Prec@1 62.500 (81.170)
 * Prec@1 81.170
Epoch: [24][0/390]	Time 7.901 (7.901)	Data 7.686 (7.686)	Loss 0.3986 (0.3986)	Prec

Test: [0/78]	Time 6.639 (6.639)	Loss 0.4509 (0.4509)	Prec@1 86.719 (86.719)
Test: [78/78]	Time 0.003 (0.111)	Loss 0.3962 (0.4744)	Prec@1 87.500 (86.680)
 * Prec@1 86.680
Epoch: [44][0/390]	Time 7.727 (7.727)	Data 7.639 (7.639)	Loss 0.1708 (0.1708)	Prec@1 93.750 (93.750)
Epoch: [44][390/390]	Time 0.077 (0.099)	Data 0.000 (0.020)	Loss 0.2108 (0.2263)	Prec@1 92.500 (93.016)
Test: [0/78]	Time 6.385 (6.385)	Loss 0.3384 (0.3384)	Prec@1 89.062 (89.062)
Test: [78/78]	Time 0.011 (0.109)	Loss 0.3493 (0.4070)	Prec@1 87.500 (87.710)
 * Prec@1 87.710
Epoch: [45][0/390]	Time 7.698 (7.698)	Data 7.613 (7.613)	Loss 0.1837 (0.1837)	Prec@1 92.969 (92.969)
Epoch: [45][390/390]	Time 0.044 (0.100)	Data 0.000 (0.020)	Loss 0.2670 (0.2206)	Prec@1 91.250 (93.066)
Test: [0/78]	Time 6.959 (6.959)	Loss 0.3039 (0.3039)	Prec@1 89.844 (89.844)
Test: [78/78]	Time 0.031 (0.133)	Loss 0.6868 (0.4231)	Prec@1 87.500 (87.670)
 * Prec@1 87.670
Epoch: [46][0/390]	Time 8.324 (8.324)	Data 8.184 (8.184)	Loss 0.2124 (0.2124)	Prec

Test: [0/78]	Time 6.489 (6.489)	Loss 0.4354 (0.4354)	Prec@1 89.062 (89.062)
Test: [78/78]	Time 0.027 (0.111)	Loss 0.6006 (0.3831)	Prec@1 87.500 (89.860)
 * Prec@1 89.860
Epoch: [66][0/390]	Time 7.710 (7.710)	Data 7.639 (7.639)	Loss 0.0699 (0.0699)	Prec@1 97.656 (97.656)
Epoch: [66][390/390]	Time 0.039 (0.071)	Data 0.000 (0.020)	Loss 0.0442 (0.0884)	Prec@1 100.000 (97.156)
Test: [0/78]	Time 6.514 (6.514)	Loss 0.3072 (0.3072)	Prec@1 89.062 (89.062)
Test: [78/78]	Time 0.010 (0.099)	Loss 0.7390 (0.4393)	Prec@1 87.500 (88.970)
 * Prec@1 88.970
Epoch: [67][0/390]	Time 8.371 (8.371)	Data 8.210 (8.210)	Loss 0.1177 (0.1177)	Prec@1 96.094 (96.094)
Epoch: [67][390/390]	Time 0.095 (0.159)	Data 0.000 (0.022)	Loss 0.1007 (0.0906)	Prec@1 96.250 (97.222)
Test: [0/78]	Time 6.670 (6.670)	Loss 0.3460 (0.3460)	Prec@1 89.062 (89.062)
Test: [78/78]	Time 0.003 (0.112)	Loss 0.7366 (0.3556)	Prec@1 87.500 (90.280)
 * Prec@1 90.280
Epoch: [68][0/390]	Time 7.726 (7.726)	Data 7.665 (7.665)	Loss 0.1318 (0.1318)	Pre

Test: [0/78]	Time 6.887 (6.887)	Loss 0.1584 (0.1584)	Prec@1 95.312 (95.312)
Test: [78/78]	Time 0.013 (0.106)	Loss 0.5321 (0.3484)	Prec@1 75.000 (90.390)
 * Prec@1 90.390
Epoch: [88][0/390]	Time 8.232 (8.232)	Data 8.075 (8.075)	Loss 0.1002 (0.1002)	Prec@1 96.875 (96.875)
Epoch: [88][390/390]	Time 0.044 (0.130)	Data 0.000 (0.021)	Loss 0.0413 (0.1015)	Prec@1 98.750 (96.856)
Test: [0/78]	Time 6.476 (6.476)	Loss 0.2403 (0.2403)	Prec@1 95.312 (95.312)
Test: [78/78]	Time 0.012 (0.110)	Loss 0.6940 (0.3660)	Prec@1 81.250 (90.050)
 * Prec@1 90.050
Epoch: [89][0/390]	Time 7.802 (7.802)	Data 7.703 (7.703)	Loss 0.1067 (0.1067)	Prec@1 97.656 (97.656)
Epoch: [89][390/390]	Time 0.121 (0.122)	Data 0.000 (0.020)	Loss 0.0991 (0.1022)	Prec@1 96.250 (96.734)
Test: [0/78]	Time 6.754 (6.754)	Loss 0.5718 (0.5718)	Prec@1 88.281 (88.281)
Test: [78/78]	Time 0.005 (0.131)	Loss 0.2887 (0.4027)	Prec@1 93.750 (89.280)
 * Prec@1 89.280
Epoch: [90][0/390]	Time 8.242 (8.242)	Data 8.100 (8.100)	Loss 0.1159 (0.1159)	Prec

Test: [0/78]	Time 6.766 (6.766)	Loss 0.3396 (0.3396)	Prec@1 92.188 (92.188)
Test: [78/78]	Time 0.006 (0.130)	Loss 0.5087 (0.3857)	Prec@1 87.500 (91.480)
 * Prec@1 91.480
Epoch: [110][0/390]	Time 8.849 (8.849)	Data 8.746 (8.746)	Loss 0.0970 (0.0970)	Prec@1 96.875 (96.875)
Epoch: [110][390/390]	Time 0.042 (0.120)	Data 0.000 (0.023)	Loss 0.0032 (0.0388)	Prec@1 100.000 (98.790)
Test: [0/78]	Time 7.202 (7.202)	Loss 0.2795 (0.2795)	Prec@1 92.969 (92.969)
Test: [78/78]	Time 0.012 (0.109)	Loss 0.7522 (0.3647)	Prec@1 93.750 (91.630)
 * Prec@1 91.630
Epoch: [111][0/390]	Time 8.375 (8.375)	Data 8.274 (8.274)	Loss 0.0282 (0.0282)	Prec@1 98.438 (98.438)
Epoch: [111][390/390]	Time 0.129 (0.139)	Data 0.000 (0.022)	Loss 0.0704 (0.0425)	Prec@1 97.500 (98.730)
Test: [0/78]	Time 6.878 (6.878)	Loss 0.2916 (0.2916)	Prec@1 93.750 (93.750)
Test: [78/78]	Time 0.025 (0.132)	Loss 0.5944 (0.3552)	Prec@1 87.500 (91.660)
 * Prec@1 91.660
Epoch: [112][0/390]	Time 8.490 (8.490)	Data 8.452 (8.452)	Loss 0.0108 (0.0108

Epoch: [131][390/390]	Time 0.068 (0.076)	Data 0.000 (0.020)	Loss 0.0132 (0.0111)	Prec@1 98.750 (99.660)
Test: [0/78]	Time 6.533 (6.533)	Loss 0.3002 (0.3002)	Prec@1 95.312 (95.312)
Test: [78/78]	Time 0.038 (0.114)	Loss 0.5502 (0.3533)	Prec@1 93.750 (93.030)
 * Prec@1 93.030
Epoch: [132][0/390]	Time 7.898 (7.898)	Data 7.810 (7.810)	Loss 0.0020 (0.0020)	Prec@1 100.000 (100.000)
Epoch: [132][390/390]	Time 0.027 (0.075)	Data 0.000 (0.020)	Loss 0.0019 (0.0113)	Prec@1 100.000 (99.664)
Test: [0/78]	Time 6.573 (6.573)	Loss 0.2437 (0.2437)	Prec@1 94.531 (94.531)
Test: [78/78]	Time 0.005 (0.114)	Loss 0.5493 (0.3840)	Prec@1 93.750 (92.530)
 * Prec@1 92.530
Epoch: [133][0/390]	Time 7.950 (7.950)	Data 7.875 (7.875)	Loss 0.0534 (0.0534)	Prec@1 98.438 (98.438)
Epoch: [133][390/390]	Time 0.030 (0.101)	Data 0.000 (0.021)	Loss 0.0024 (0.0120)	Prec@1 100.000 (99.674)
Test: [0/78]	Time 6.639 (6.639)	Loss 0.2469 (0.2469)	Prec@1 96.094 (96.094)
Test: [78/78]	Time 0.006 (0.105)	Loss 0.7911 (0.3674)	Prec@1 81.

 * Prec@1 92.530
Epoch: [153][0/390]	Time 7.686 (7.686)	Data 7.576 (7.576)	Loss 0.0018 (0.0018)	Prec@1 100.000 (100.000)
Epoch: [153][390/390]	Time 0.027 (0.085)	Data 0.000 (0.020)	Loss 0.0026 (0.0041)	Prec@1 100.000 (99.916)
Test: [0/78]	Time 6.520 (6.520)	Loss 0.2142 (0.2142)	Prec@1 96.094 (96.094)
Test: [78/78]	Time 0.003 (0.095)	Loss 0.5894 (0.3784)	Prec@1 93.750 (92.910)
 * Prec@1 92.910
Epoch: [154][0/390]	Time 7.878 (7.878)	Data 7.817 (7.817)	Loss 0.0009 (0.0009)	Prec@1 100.000 (100.000)
Epoch: [154][390/390]	Time 0.063 (0.094)	Data 0.000 (0.020)	Loss 0.0008 (0.0038)	Prec@1 100.000 (99.910)
Test: [0/78]	Time 6.467 (6.467)	Loss 0.2170 (0.2170)	Prec@1 96.875 (96.875)
Test: [78/78]	Time 0.004 (0.094)	Loss 0.5941 (0.3629)	Prec@1 93.750 (93.030)
 * Prec@1 93.030
Epoch: [155][0/390]	Time 7.942 (7.942)	Data 7.850 (7.850)	Loss 0.0012 (0.0012)	Prec@1 100.000 (100.000)
Epoch: [155][390/390]	Time 0.059 (0.088)	Data 0.000 (0.021)	Loss 0.0064 (0.0037)	Prec@1 100.000 (99.930)
Test: [0/78]	Tim

Test: [78/78]	Time 0.003 (0.107)	Loss 0.6271 (0.4079)	Prec@1 93.750 (93.090)
 * Prec@1 93.090
Epoch: [175][0/390]	Time 8.124 (8.124)	Data 7.865 (7.865)	Loss 0.0025 (0.0025)	Prec@1 100.000 (100.000)
Epoch: [175][390/390]	Time 0.062 (0.103)	Data 0.000 (0.021)	Loss 0.0095 (0.0047)	Prec@1 100.000 (99.890)
Test: [0/78]	Time 6.680 (6.680)	Loss 0.2091 (0.2091)	Prec@1 96.875 (96.875)
Test: [78/78]	Time 0.004 (0.115)	Loss 0.6069 (0.4060)	Prec@1 93.750 (93.080)
 * Prec@1 93.080
Epoch: [176][0/390]	Time 7.792 (7.792)	Data 7.731 (7.731)	Loss 0.0007 (0.0007)	Prec@1 100.000 (100.000)
Epoch: [176][390/390]	Time 0.072 (0.076)	Data 0.000 (0.020)	Loss 0.0017 (0.0047)	Prec@1 100.000 (99.890)
Test: [0/78]	Time 6.618 (6.618)	Loss 0.1920 (0.1920)	Prec@1 96.875 (96.875)
Test: [78/78]	Time 0.021 (0.115)	Loss 0.6181 (0.4069)	Prec@1 93.750 (92.750)
 * Prec@1 92.750
Epoch: [177][0/390]	Time 7.835 (7.835)	Data 7.746 (7.746)	Loss 0.0260 (0.0260)	Prec@1 99.219 (99.219)
Epoch: [177][390/390]	Time 0.028 (0.073)	Data 

Test: [0/78]	Time 6.488 (6.488)	Loss 0.2415 (0.2415)	Prec@1 96.094 (96.094)
Test: [78/78]	Time 0.014 (0.110)	Loss 0.6246 (0.3974)	Prec@1 93.750 (93.420)
 * Prec@1 93.420
Epoch: [197][0/390]	Time 7.717 (7.717)	Data 7.633 (7.633)	Loss 0.0008 (0.0008)	Prec@1 100.000 (100.000)
Epoch: [197][390/390]	Time 0.033 (0.068)	Data 0.000 (0.020)	Loss 0.0094 (0.0015)	Prec@1 100.000 (99.982)
Test: [0/78]	Time 6.450 (6.450)	Loss 0.3075 (0.3075)	Prec@1 94.531 (94.531)
Test: [78/78]	Time 0.008 (0.113)	Loss 0.6159 (0.4008)	Prec@1 93.750 (93.260)
 * Prec@1 93.260
Epoch: [198][0/390]	Time 7.782 (7.782)	Data 7.681 (7.681)	Loss 0.0007 (0.0007)	Prec@1 100.000 (100.000)
Epoch: [198][390/390]	Time 0.030 (0.090)	Data 0.000 (0.020)	Loss 0.0010 (0.0013)	Prec@1 100.000 (99.986)
Test: [0/78]	Time 6.822 (6.822)	Loss 0.2941 (0.2941)	Prec@1 95.312 (95.312)
Test: [78/78]	Time 0.004 (0.101)	Loss 0.6190 (0.3966)	Prec@1 93.750 (93.370)
 * Prec@1 93.370
Epoch: [199][0/390]	Time 8.010 (8.010)	Data 7.926 (7.926)	Loss 0.0012 (0

Epoch: [218][390/390]	Time 0.025 (0.075)	Data 0.001 (0.020)	Loss 0.0013 (0.0010)	Prec@1 100.000 (99.996)
Test: [0/78]	Time 6.543 (6.543)	Loss 0.2612 (0.2612)	Prec@1 96.875 (96.875)
Test: [78/78]	Time 0.023 (0.113)	Loss 0.6297 (0.4038)	Prec@1 93.750 (93.510)
 * Prec@1 93.510
Epoch: [219][0/390]	Time 7.706 (7.706)	Data 7.602 (7.602)	Loss 0.0008 (0.0008)	Prec@1 100.000 (100.000)
Epoch: [219][390/390]	Time 0.024 (0.099)	Data 0.000 (0.020)	Loss 0.0008 (0.0012)	Prec@1 100.000 (99.988)
Test: [0/78]	Time 6.636 (6.636)	Loss 0.2495 (0.2495)	Prec@1 96.875 (96.875)
Test: [78/78]	Time 0.004 (0.106)	Loss 0.6311 (0.4047)	Prec@1 93.750 (93.570)
 * Prec@1 93.570
Epoch: [220][0/390]	Time 8.176 (8.176)	Data 8.112 (8.112)	Loss 0.0008 (0.0008)	Prec@1 100.000 (100.000)
Epoch: [220][390/390]	Time 0.063 (0.103)	Data 0.000 (0.021)	Loss 0.0007 (0.0009)	Prec@1 100.000 (99.996)
Test: [0/78]	Time 6.800 (6.800)	Loss 0.2533 (0.2533)	Prec@1 96.875 (96.875)
Test: [78/78]	Time 0.014 (0.118)	Loss 0.6359 (0.4053)	Prec@1 

 * Prec@1 93.410
Epoch: [240][0/390]	Time 7.914 (7.914)	Data 7.871 (7.871)	Loss 0.0012 (0.0012)	Prec@1 100.000 (100.000)
Epoch: [240][390/390]	Time 0.051 (0.101)	Data 0.000 (0.021)	Loss 0.0007 (0.0009)	Prec@1 100.000 (99.998)
Test: [0/78]	Time 6.396 (6.396)	Loss 0.2462 (0.2462)	Prec@1 96.875 (96.875)
Test: [78/78]	Time 0.013 (0.113)	Loss 0.6177 (0.4079)	Prec@1 93.750 (93.370)
 * Prec@1 93.370
Epoch: [241][0/390]	Time 7.715 (7.715)	Data 7.671 (7.671)	Loss 0.0007 (0.0007)	Prec@1 100.000 (100.000)
Epoch: [241][390/390]	Time 0.062 (0.082)	Data 0.000 (0.020)	Loss 0.0008 (0.0009)	Prec@1 100.000 (99.996)
Test: [0/78]	Time 6.280 (6.280)	Loss 0.2442 (0.2442)	Prec@1 96.875 (96.875)
Test: [78/78]	Time 0.007 (0.110)	Loss 0.6358 (0.4074)	Prec@1 93.750 (93.430)
 * Prec@1 93.430
Epoch: [242][0/390]	Time 7.657 (7.657)	Data 7.592 (7.592)	Loss 0.0008 (0.0008)	Prec@1 100.000 (100.000)
Epoch: [242][390/390]	Time 0.031 (0.073)	Data 0.000 (0.020)	Loss 0.0006 (0.0011)	Prec@1 100.000 (99.992)
Test: [0/78]	Tim

Test: [78/78]	Time 0.024 (0.108)	Loss 0.6395 (0.4085)	Prec@1 93.750 (93.520)
 * Prec@1 93.520
Epoch: [262][0/390]	Time 7.867 (7.867)	Data 7.766 (7.766)	Loss 0.0007 (0.0007)	Prec@1 100.000 (100.000)
Epoch: [262][390/390]	Time 0.094 (0.131)	Data 0.000 (0.020)	Loss 0.0011 (0.0009)	Prec@1 100.000 (99.998)
Test: [0/78]	Time 6.700 (6.700)	Loss 0.2346 (0.2346)	Prec@1 96.875 (96.875)
Test: [78/78]	Time 0.009 (0.101)	Loss 0.6529 (0.4073)	Prec@1 93.750 (93.560)
 * Prec@1 93.560
Epoch: [263][0/390]	Time 7.940 (7.940)	Data 7.890 (7.890)	Loss 0.0012 (0.0012)	Prec@1 100.000 (100.000)
Epoch: [263][390/390]	Time 0.058 (0.109)	Data 0.000 (0.021)	Loss 0.0008 (0.0009)	Prec@1 100.000 (99.998)
Test: [0/78]	Time 6.719 (6.719)	Loss 0.2397 (0.2397)	Prec@1 96.875 (96.875)
Test: [78/78]	Time 0.015 (0.125)	Loss 0.6365 (0.4058)	Prec@1 93.750 (93.520)
 * Prec@1 93.520
Epoch: [264][0/390]	Time 7.891 (7.891)	Data 7.775 (7.775)	Loss 0.0008 (0.0008)	Prec@1 100.000 (100.000)
Epoch: [264][390/390]	Time 0.040 (0.086)	Dat

Test: [0/78]	Time 7.014 (7.014)	Loss 0.2492 (0.2492)	Prec@1 96.875 (96.875)
Test: [78/78]	Time 0.048 (0.122)	Loss 0.6372 (0.4069)	Prec@1 93.750 (93.550)
 * Prec@1 93.550
Epoch: [284][0/390]	Time 7.808 (7.808)	Data 7.685 (7.685)	Loss 0.0007 (0.0007)	Prec@1 100.000 (100.000)
Epoch: [284][390/390]	Time 0.100 (0.152)	Data 0.000 (0.020)	Loss 0.0007 (0.0009)	Prec@1 100.000 (99.996)
Test: [0/78]	Time 6.661 (6.661)	Loss 0.2439 (0.2439)	Prec@1 96.875 (96.875)
Test: [78/78]	Time 0.003 (0.095)	Loss 0.6397 (0.4063)	Prec@1 93.750 (93.620)
 * Prec@1 93.620
Epoch: [285][0/390]	Time 8.099 (8.099)	Data 7.988 (7.988)	Loss 0.0009 (0.0009)	Prec@1 100.000 (100.000)
Epoch: [285][390/390]	Time 0.111 (0.136)	Data 0.000 (0.021)	Loss 0.0014 (0.0009)	Prec@1 100.000 (99.994)
Test: [0/78]	Time 6.732 (6.732)	Loss 0.2433 (0.2433)	Prec@1 96.875 (96.875)
Test: [78/78]	Time 0.026 (0.136)	Loss 0.6349 (0.4095)	Prec@1 93.750 (93.520)
 * Prec@1 93.520
Epoch: [286][0/390]	Time 8.159 (8.159)	Data 8.047 (8.047)	Loss 0.0006 (0

Test: [0/78]	Time 7.041 (7.041)	Loss 0.7904 (0.7904)	Prec@1 75.000 (75.000)
Test: [78/78]	Time 0.006 (0.119)	Loss 1.0450 (0.8661)	Prec@1 75.000 (71.840)
 * Prec@1 71.840
Epoch: [5][0/390]	Time 8.475 (8.475)	Data 8.333 (8.333)	Loss 0.6993 (0.6993)	Prec@1 75.781 (75.781)
Epoch: [5][390/390]	Time 0.110 (0.166)	Data 0.000 (0.022)	Loss 0.8891 (0.7237)	Prec@1 70.000 (76.898)
Test: [0/78]	Time 6.743 (6.743)	Loss 0.8113 (0.8113)	Prec@1 72.656 (72.656)
Test: [78/78]	Time 0.006 (0.131)	Loss 0.6773 (0.7618)	Prec@1 81.250 (76.070)
 * Prec@1 76.070
Epoch: [6][0/390]	Time 8.336 (8.336)	Data 8.258 (8.258)	Loss 0.7647 (0.7647)	Prec@1 77.344 (77.344)
Epoch: [6][390/390]	Time 0.096 (0.118)	Data 0.000 (0.022)	Loss 0.7124 (0.6631)	Prec@1 75.000 (78.706)
Test: [0/78]	Time 6.991 (6.991)	Loss 0.8550 (0.8550)	Prec@1 73.438 (73.438)
Test: [78/78]	Time 0.045 (0.139)	Loss 0.7173 (0.8015)	Prec@1 68.750 (73.660)
 * Prec@1 73.660
Epoch: [7][0/390]	Time 8.287 (8.287)	Data 8.148 (8.148)	Loss 0.7607 (0.7607)	Prec@1 75

Test: [0/78]	Time 6.668 (6.668)	Loss 0.4644 (0.4644)	Prec@1 86.719 (86.719)
Test: [78/78]	Time 0.046 (0.139)	Loss 0.3130 (0.5189)	Prec@1 87.500 (84.660)
 * Prec@1 84.660
Epoch: [27][0/390]	Time 7.735 (7.735)	Data 7.676 (7.676)	Loss 0.4318 (0.4318)	Prec@1 87.500 (87.500)
Epoch: [27][390/390]	Time 0.104 (0.138)	Data 0.000 (0.020)	Loss 0.5855 (0.3881)	Prec@1 82.500 (87.942)
Test: [0/78]	Time 6.797 (6.797)	Loss 0.4631 (0.4631)	Prec@1 82.031 (82.031)
Test: [78/78]	Time 0.006 (0.108)	Loss 0.3462 (0.5350)	Prec@1 93.750 (83.110)
 * Prec@1 83.110
Epoch: [28][0/390]	Time 8.499 (8.499)	Data 8.415 (8.415)	Loss 0.3261 (0.3261)	Prec@1 90.625 (90.625)
Epoch: [28][390/390]	Time 0.101 (0.137)	Data 0.000 (0.022)	Loss 0.4530 (0.3815)	Prec@1 85.000 (88.138)
Test: [0/78]	Time 6.619 (6.619)	Loss 0.4520 (0.4520)	Prec@1 85.156 (85.156)
Test: [78/78]	Time 0.012 (0.126)	Loss 0.6437 (0.5885)	Prec@1 81.250 (81.720)
 * Prec@1 81.720
Epoch: [29][0/390]	Time 7.825 (7.825)	Data 7.781 (7.781)	Loss 0.2888 (0.2888)	Prec

Test: [0/78]	Time 6.447 (6.447)	Loss 0.3402 (0.3402)	Prec@1 91.406 (91.406)
Test: [78/78]	Time 0.011 (0.110)	Loss 0.2816 (0.4019)	Prec@1 81.250 (87.980)
 * Prec@1 87.980
Epoch: [49][0/390]	Time 7.662 (7.662)	Data 7.581 (7.581)	Loss 0.2096 (0.2096)	Prec@1 94.531 (94.531)
Epoch: [49][390/390]	Time 0.051 (0.078)	Data 0.000 (0.020)	Loss 0.2035 (0.2230)	Prec@1 92.500 (93.038)
Test: [0/78]	Time 6.879 (6.879)	Loss 0.3382 (0.3382)	Prec@1 87.500 (87.500)
Test: [78/78]	Time 0.013 (0.101)	Loss 0.4700 (0.3974)	Prec@1 87.500 (88.020)
 * Prec@1 88.020
Epoch: [50][0/390]	Time 7.966 (7.966)	Data 7.865 (7.865)	Loss 0.1137 (0.1137)	Prec@1 96.875 (96.875)
Epoch: [50][390/390]	Time 0.096 (0.135)	Data 0.000 (0.021)	Loss 0.3426 (0.2071)	Prec@1 88.750 (93.458)
Test: [0/78]	Time 7.109 (7.109)	Loss 0.2603 (0.2603)	Prec@1 92.188 (92.188)
Test: [78/78]	Time 0.009 (0.112)	Loss 0.6069 (0.4542)	Prec@1 87.500 (86.940)
 * Prec@1 86.940
Epoch: [51][0/390]	Time 8.607 (8.607)	Data 8.539 (8.539)	Loss 0.1915 (0.1915)	Prec

Test: [0/78]	Time 6.795 (6.795)	Loss 0.1977 (0.1977)	Prec@1 92.188 (92.188)
Test: [78/78]	Time 0.011 (0.128)	Loss 0.6797 (0.3842)	Prec@1 87.500 (89.460)
 * Prec@1 89.460
Epoch: [71][0/390]	Time 8.018 (8.018)	Data 7.826 (7.826)	Loss 0.0998 (0.0998)	Prec@1 96.094 (96.094)
Epoch: [71][390/390]	Time 0.047 (0.183)	Data 0.000 (0.020)	Loss 0.1079 (0.0969)	Prec@1 97.500 (97.098)
Test: [0/78]	Time 6.789 (6.789)	Loss 0.2302 (0.2302)	Prec@1 91.406 (91.406)
Test: [78/78]	Time 0.009 (0.113)	Loss 0.3865 (0.3484)	Prec@1 93.750 (91.040)
 * Prec@1 91.040
Epoch: [72][0/390]	Time 8.290 (8.290)	Data 8.157 (8.157)	Loss 0.0510 (0.0510)	Prec@1 96.875 (96.875)
Epoch: [72][390/390]	Time 0.100 (0.172)	Data 0.000 (0.021)	Loss 0.0959 (0.0948)	Prec@1 97.500 (97.018)
Test: [0/78]	Time 6.707 (6.707)	Loss 0.1395 (0.1395)	Prec@1 93.750 (93.750)
Test: [78/78]	Time 0.018 (0.135)	Loss 0.8178 (0.3803)	Prec@1 81.250 (89.740)
 * Prec@1 89.740
Epoch: [73][0/390]	Time 7.910 (7.910)	Data 7.844 (7.844)	Loss 0.0506 (0.0506)	Prec

Test: [0/78]	Time 5.248 (5.248)	Loss 0.3533 (0.3533)	Prec@1 92.969 (92.969)
Test: [78/78]	Time 0.015 (0.089)	Loss 0.6532 (0.3385)	Prec@1 93.750 (91.990)
 * Prec@1 91.990
Epoch: [93][0/390]	Time 6.425 (6.425)	Data 6.363 (6.363)	Loss 0.0073 (0.0073)	Prec@1 100.000 (100.000)
Epoch: [93][390/390]	Time 0.022 (0.058)	Data 0.000 (0.017)	Loss 0.0080 (0.0330)	Prec@1 100.000 (98.974)
Test: [0/78]	Time 5.288 (5.288)	Loss 0.2869 (0.2869)	Prec@1 93.750 (93.750)
Test: [78/78]	Time 0.009 (0.090)	Loss 0.6381 (0.3604)	Prec@1 93.750 (92.160)
 * Prec@1 92.160
Epoch: [94][0/390]	Time 6.390 (6.390)	Data 6.327 (6.327)	Loss 0.0897 (0.0897)	Prec@1 96.875 (96.875)
Epoch: [94][390/390]	Time 0.023 (0.080)	Data 0.001 (0.016)	Loss 0.0076 (0.0326)	Prec@1 100.000 (98.984)
Test: [0/78]	Time 5.276 (5.276)	Loss 0.3945 (0.3945)	Prec@1 92.188 (92.188)
Test: [78/78]	Time 0.009 (0.084)	Loss 0.6821 (0.3666)	Prec@1 87.500 (91.950)
 * Prec@1 91.950
Epoch: [95][0/390]	Time 6.746 (6.746)	Data 6.483 (6.483)	Loss 0.0057 (0.0057)	

Test: [0/78]	Time 5.362 (5.362)	Loss 0.2671 (0.2671)	Prec@1 92.188 (92.188)
Test: [78/78]	Time 0.009 (0.086)	Loss 0.3453 (0.4221)	Prec@1 93.750 (90.550)
 * Prec@1 90.550
Epoch: [115][0/390]	Time 6.756 (6.756)	Data 6.410 (6.410)	Loss 0.0162 (0.0162)	Prec@1 99.219 (99.219)
Epoch: [115][390/390]	Time 0.047 (0.083)	Data 0.001 (0.017)	Loss 0.0737 (0.0400)	Prec@1 97.500 (98.772)
Test: [0/78]	Time 5.156 (5.156)	Loss 0.3826 (0.3826)	Prec@1 89.844 (89.844)
Test: [78/78]	Time 0.010 (0.088)	Loss 0.6899 (0.3811)	Prec@1 93.750 (91.490)
 * Prec@1 91.490
Epoch: [116][0/390]	Time 6.692 (6.692)	Data 6.521 (6.521)	Loss 0.0170 (0.0170)	Prec@1 100.000 (100.000)
Epoch: [116][390/390]	Time 0.052 (0.062)	Data 0.000 (0.017)	Loss 0.0037 (0.0476)	Prec@1 100.000 (98.562)
Test: [0/78]	Time 5.273 (5.273)	Loss 0.2122 (0.2122)	Prec@1 93.750 (93.750)
Test: [78/78]	Time 0.004 (0.089)	Loss 0.7768 (0.3471)	Prec@1 81.250 (91.880)
 * Prec@1 91.880
Epoch: [117][0/390]	Time 6.377 (6.377)	Data 6.311 (6.311)	Loss 0.0235 (0.02

Epoch: [136][390/390]	Time 0.054 (0.057)	Data 0.001 (0.018)	Loss 0.0400 (0.0116)	Prec@1 98.750 (99.678)
Test: [0/78]	Time 5.410 (5.410)	Loss 0.2776 (0.2776)	Prec@1 93.750 (93.750)
Test: [78/78]	Time 0.005 (0.091)	Loss 0.7276 (0.3913)	Prec@1 87.500 (92.200)
 * Prec@1 92.200
Epoch: [137][0/390]	Time 9.214 (9.214)	Data 9.152 (9.152)	Loss 0.0026 (0.0026)	Prec@1 100.000 (100.000)
Epoch: [137][390/390]	Time 0.024 (0.066)	Data 0.000 (0.024)	Loss 0.0428 (0.0139)	Prec@1 97.500 (99.596)
Test: [0/78]	Time 5.318 (5.318)	Loss 0.1736 (0.1736)	Prec@1 93.750 (93.750)
Test: [78/78]	Time 0.011 (0.090)	Loss 0.6286 (0.3565)	Prec@1 93.750 (92.600)
 * Prec@1 92.600
Epoch: [138][0/390]	Time 6.475 (6.475)	Data 6.409 (6.409)	Loss 0.0080 (0.0080)	Prec@1 99.219 (99.219)
Epoch: [138][390/390]	Time 0.042 (0.084)	Data 0.000 (0.017)	Loss 0.0009 (0.0135)	Prec@1 100.000 (99.624)
Test: [0/78]	Time 5.627 (5.627)	Loss 0.2701 (0.2701)	Prec@1 92.188 (92.188)
Test: [78/78]	Time 0.010 (0.089)	Loss 0.6376 (0.3386)	Prec@1 93.7

## 1.3 Display the accuracy of models

In [None]:
# Validation accuracy per epoch, one curve per trained model variant.
# The explicit figure/axes pair replaces the earlier `ax = plt.figure()`,
# where `ax` actually held a Figure.
fig, ax = plt.subplots()
ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))  # integer ticks on the epoch axis
for ba, bn, bc in zip(accuracy_list, name_list, color_list):
    ax.plot(ba, label=bn, color=bc, marker='o')
# ax.set_xlim([0, 5])      # x-axis range: [xmin, xmax]
# ax.set_ylim([0, 20])     # y-axis range: [ymin, ymax]
ax.set_title('Accuracy of image models')
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy (%)")
ax.legend()
plt.show()

In [None]:
# Same accuracy curves as the previous figure, drawn with matplotlib's
# default x-axis tick placement.
plt.figure()
axes = plt.gca()
for accuracies, model_name, line_color in zip(accuracy_list, name_list, color_list):
    axes.plot(accuracies, marker='o', color=line_color, label=model_name)
# Optional fixed axis ranges:
# axes.set_xlim([0, 5])      # x-axis range: [xmin, xmax]
# axes.set_ylim([0, 20])     # y-axis range: [ymin, ymax]
axes.set_title('Accuracy of image models')
axes.set_xlabel("Epoch")
axes.set_ylabel("Accuracy (%)")
axes.legend()
plt.show()

# 2. Inference

In [None]:
import cv2

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, deprocess_image

## 2.1 Set class names of datasets

In [None]:
# Class index -> human-readable label for CIFAR10.
# https://www.cs.toronto.edu/~kriz/cifar.html
class_cifar10 = {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer',
                 5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}

# Class index -> human-readable label for CIFAR100 (fine labels).
# https://huggingface.co/datasets/cifar100
# Index 26 previously read 'cra'; the dataset's label is 'crab'.
class_cifar100 = {0: 'apple', 1: 'aquarium_fish', 2: 'baby', 3: 'bear', 4: 'beaver', 5: 'bed', 6: 'bee', 7: 'beetle',
                  8: 'bicycle', 9: 'bottle', 10: 'bowl', 11: 'boy', 12: 'bridge', 13: 'bus', 14: 'butterfly',
                  15: 'camel', 16: 'can', 17: 'castle', 18: 'caterpillar', 19: 'cattle', 20: 'chair', 21: 'chimpanzee',
                  22: 'clock', 23: 'cloud', 24: 'cockroach', 25: 'couch', 26: 'crab', 27: 'crocodile', 28: 'cup',
                  29: 'dinosaur', 30: 'dolphin', 31: 'elephant', 32: 'flatfish', 33: 'forest', 34: 'fox', 35: 'girl',
                  36: 'hamster', 37: 'house', 38: 'kangaroo', 39: 'keyboard', 40: 'lamp', 41: 'lawn_mower',
                  42: 'leopard', 43: 'lion', 44: 'lizard', 45: 'lobster', 46: 'man', 47: 'maple_tree', 48: 'motorcycle',
                  49: 'mountain', 50: 'mouse', 51: 'mushroom', 52: 'oak_tree', 53: 'orange', 54: 'orchid', 55: 'otter',
                  56: 'palm_tree', 57: 'pear', 58: 'pickup_truck', 59: 'pine_tree', 60: 'plain', 61: 'plate',
                  62: 'poppy', 63: 'porcupine', 64: 'possum', 65: 'rabbit', 66: 'raccoon', 67: 'ray', 68: 'road',
                  69: 'rocket', 70: 'rose', 71: 'sea', 72: 'seal', 73: 'shark', 74: 'shrew', 75: 'skunk',
                  76: 'skyscraper', 77: 'snail', 78: 'snake', 79: 'spider', 80: 'squirrel', 81: 'streetcar',
                  82: 'sunflower', 83: 'sweet_pepper', 84: 'table', 85: 'tank', 86: 'telephone', 87: 'television',
                  88: 'tiger', 89: 'tractor', 90: 'train', 91: 'trout', 92: 'tulip', 93: 'turtle', 94: 'wardrobe',
                  95: 'whale', 96: 'willow_tree', 97: 'wolf', 98: 'woman', 99: 'worm'}

## 2.2 Define functions

In [None]:
def run_inference(args):
    """Classify one validation batch with a trained model and plot its Grad-CAM maps.

    The first batch of the CIFAR test split is loaded, the ground-truth and
    predicted class names are printed, and a two-row figure is shown: the
    input images on top and the Grad-CAM overlays below.

    Args:
        args: configuration object. The attributes read here are
            ``dataset`` ("cifar100" selects CIFAR100, anything else CIFAR10),
            ``batch_size``, ``workers``, ``arch``, ``block``, ``cpu`` and
            ``checkpoint`` (path of a file whose loaded object has a
            ``'state_dict'`` entry).

    Returns:
        None. Results are printed and displayed with matplotlib.
    """
    #############################################
    # Load dataset
    #############################################
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    if args.dataset == "cifar100":
        num_classes = 100
        classes = class_cifar100
        dataset_cls = datasets.CIFAR100
    else:  # default dataset is CIFAR10
        num_classes = 10
        classes = class_cifar10
        dataset_cls = datasets.CIFAR10

    val_loader = torch.utils.data.DataLoader(
        dataset_cls(root='./data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    print("dataset :", args.dataset)
    print("checkpoint :", args.checkpoint)

    #############################################
    # Load model
    #############################################
    model = vgg.__dict__[args.arch](num_classes, args.block)

    # Grad-CAM target layer. It is picked before `features` is wrapped in
    # DataParallel, because the wrapper cannot be indexed with an integer.
    # NOTE(review): index 52 is presumably the last convolutional stage of
    # the default architecture - confirm before using another `arch`/`block`.
    cam_layers = [model.features[52]]

    model.features = torch.nn.DataParallel(model.features)
    if args.cpu:
        model.cpu()
    else:
        model.cuda()
    # Remap stored tensors to the CPU when running without CUDA; otherwise keep
    # torch.load's default placement.
    checkpoint = torch.load(args.checkpoint, map_location='cpu' if args.cpu else None)
    model.load_state_dict(checkpoint['state_dict'])
    # Inference mode: batch-norm/dropout layers must use their trained
    # statistics rather than those of the tiny sample batch.
    model.eval()

    #############################################
    # Evaluate model
    #############################################
    images, labels = next(iter(val_loader))
    print("images shape : ", images.shape)
    print("input gt labels : ")
    # Iterate over the labels actually returned instead of assuming the batch
    # holds exactly `args.batch_size` samples.
    print([classes[int(label)] for label in labels])
    with torch.no_grad():
        output = model(images)
    pred = output.topk(1, 1, True, True)
    print("pred labels : ")
    print([classes[int(index)] for index in pred.indices[:, 0].cpu()])

    #############################################
    # Create CAM
    #############################################
    cam = GradCAM(model=model, target_layers=cam_layers, use_cuda=not args.cpu)
    grayscale_cams = cam(input_tensor=images)

    input_tiles = []
    cam_tiles = []
    for idx, grayscale_cam in enumerate(grayscale_cams):
        # CHW tensor -> HWC array scaled to [0, 1], as show_cam_on_image expects.
        rgb_img = deprocess_image(images[idx].permute(1, 2, 0).numpy()) / 255.0
        # Keep the overlay in RGB: it is displayed with matplotlib, which
        # interprets arrays as RGB. Converting to BGR (only needed for
        # cv2.imwrite) would swap the red and blue channels of the heat map.
        cam_tiles.append(show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True))
        input_tiles.append(rgb_img)

    # Lay the samples out side by side, one strip per row of the figure.
    original_img = cv2.hconcat(input_tiles)
    final_cam = cv2.hconcat(cam_tiles)

    fig, (ax_input, ax_cam) = plt.subplots(2, 1)
    ax_input.imshow(original_img)
    ax_input.set_title("Original Image")

    ax_cam.imshow(final_cam)
    ax_cam.set_title("GradCam")

    plt.show()

In [None]:
def best_checkpoint(checkpoint_path):
    """Return the path of the highest-ranked ``checkpoint_*`` file in a directory.

    File names are ranked in natural order, i.e. embedded numbers are compared
    by value, so ``checkpoint_138.tar`` ranks above ``checkpoint_99.tar``
    (a plain string sort would rank them the other way round).
    NOTE(review): this assumes a larger number means a later/better
    checkpoint - confirm against the code that writes these files.

    Args:
        checkpoint_path: directory to search.

    Returns:
        ``checkpoint_path + '/' + file_name`` for the selected file, or ``''``
        when the directory does not exist or holds no ``checkpoint_*`` file.
    """
    import re  # local import: `re` is not among the notebook's top-level imports

    if not os.path.isdir(checkpoint_path):
        return ''

    def natural_key(file_name):
        # re.split with a capturing group alternates text and digit runs, so
        # odd positions are always numbers and same-position items compare
        # without mixing str and int.
        parts = re.split(r'(\d+)', file_name)
        return [int(part) if pos % 2 else part for pos, part in enumerate(parts)]

    candidates = [name for name in os.listdir(checkpoint_path)
                  if name.startswith('checkpoint_')]
    if not candidates:
        return ''

    # The raw name breaks ties (e.g. 'checkpoint_1' vs 'checkpoint_01') so the
    # result does not depend on the order returned by os.listdir.
    selected = max(candidates, key=lambda name: (natural_key(name), name))
    return checkpoint_path + '/' + selected

## 2.3 Analysis of inference

In [None]:
# For every entry of block_list, look up the checkpoint stored under
# <save_dir>/<dataset>/<block>. The bare name on the last line displays the list.
weight_list = [
    best_checkpoint(f'{args.save_dir}/{args.dataset}/{block_name}')
    for block_name in block_list
]
weight_list

In [None]:
args.batch_size = 4  # Sample images for inferencing

banner = '########################################################################################'

# Run Grad-CAM inference once per block, using the checkpoint resolved for it.
for bt, bn, bw in zip(block_list, name_list, weight_list):
    print()
    print(banner)
    print('Inference of "%s"' %bn)
    if bw:
        args.block = bt
        args.checkpoint = bw
        run_inference(args)
    else:
        # best_checkpoint() yields '' when no checkpoint file was found; skip
        # the block instead of failing inside torch.load with an empty path.
        print('No checkpoint found for block "%s"; skipping.' % bt)
    print(banner)