In [1]:
import os
import numpy as np
import random

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from resnet_cifar.resnet import *

## Define hyperparameters

In [2]:
class ResNet_agrument:
    """Plain settings container shared by the training and inference cells.

    Calling it without arguments yields the notebook defaults. Any default can
    be replaced by keyword, e.g. ``ResNet_agrument(epochs=5, dataset='cifar10')``.

    Raises:
        TypeError: If a keyword does not name one of the settings below.
    """

    def __init__(self, **overrides):
        self.workers = 4                     # DataLoader worker processes
        self.epochs = 200                    # number of epochs run by run_train
        self.batch_size = 128                # DataLoader batch size
        self.lr = 0.1                        # initial SGD learning rate
        self.resume = ''                     # checkpoint path to resume from; '' disables resuming
        self.cpu = False                     # True forces CPU even when CUDA is available
        self.save_dir = 'weights/resnet18'   # root folder for saved checkpoints
        self.dataset = 'cifar100'  # Choice : 'cifar10' and 'cifar100'
        self.block = 'RESNET'                # block type forwarded to ResNet18(block=...)
        self.checkpoint = None               # checkpoint path loaded by run_inference

        # Reject misspelled names instead of silently creating a new attribute.
        for key, value in overrides.items():
            if key not in vars(self):
                raise TypeError("unknown hyperparameter '%s'" % key)
            setattr(self, key, value)

args = ResNet_agrument()

In [3]:
# (block identifier passed to ResNet18, label used in logs and plot legends).
# Keeping each pair on one row prevents the two lists from drifting apart.
_MODEL_VARIANTS = (
    ('RESNET', 'ResNet (base)'),
    ('SE_SA_1', 'SE (residuel) + SA'),
    ('SEC_SA_1', 'SE + SA'),
    ('CBAM_1', 'CBAM'),
    ('NEW_1', 'Our model'),
)
block_list = [block_id for block_id, _ in _MODEL_VARIANTS]
name_list = [label for _, label in _MODEL_VARIANTS]

## Set the random seed

In [4]:
seed = 42


def _seed_everything(value):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) with ``value``."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(value)
    # Ask cuDNN for deterministic behaviour and switch off its benchmark mode.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


_seed_everything(seed)

# 1. Training

## 1.1 Define functions

In [5]:
def train(epoch, trainloader, net, criterion, optimizer, device):
    """Run one training epoch and print the mean batch loss and the accuracy.

    Args:
        epoch: Epoch index; only used in the progress banner.
        trainloader: Iterable yielding ``(inputs, targets)`` batches.
        net: Model to optimise; it is switched to training mode.
        criterion: Loss callable applied as ``criterion(outputs, targets)``.
        optimizer: Optimizer; zeroed and stepped once per batch.
        device: Device each batch is moved to before the forward pass.

    Returns:
        None. Results are reported on stdout only.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches += 1

    # An empty loader would otherwise fail on an unbound loop index and a
    # division by zero in the summary line.
    if num_batches == 0 or total == 0:
        print(' - Train : no samples were processed')
        return

    print(' - Train : Loss: %.3f | Acc: %.3f%% (=%d/%d)' % (train_loss/num_batches, 100.*correct/total, correct, total))


def test(epoch, testloader, net, criterion, best_acc, device):
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
    print(' - Test : Loss: %.3f | Acc: %.3f%% (=%d/%d)' % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
            
    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        save_path = os.path.join(args.save_dir, args.dataset, args.block)
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        torch.save(state, os.path.join(save_path, 'checkpoint_{}.pth'.format(epoch)))
        best_acc = acc

    return best_acc, acc

In [6]:
def run_train(args):
    """Train a ``ResNet18`` variant on CIFAR-10 or CIFAR-100.

    Args:
        args: Settings object providing ``dataset``, ``save_dir``,
            ``batch_size``, ``workers``, ``cpu``, ``block``, ``resume``,
            ``lr`` and ``epochs``.

    Returns:
        List with the test accuracy (in percent) measured after every epoch.

    Raises:
        ValueError: If ``args.dataset`` is neither ``'cifar10'`` nor
            ``'cifar100'``.
    """
    print("dataset :", args.dataset)
    print("weight folder :", args.save_dir)

    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    # Pick the dataset class and the classifier width together so they cannot
    # disagree for an unexpected dataset name.
    if args.dataset == 'cifar10':
        dataset_cls, num_classes = datasets.CIFAR10, 10
    elif args.dataset == 'cifar100':
        dataset_cls, num_classes = datasets.CIFAR100, 100
    else:
        raise ValueError("unsupported dataset '{}' (expected 'cifar10' or 'cifar100')".format(args.dataset))

    # Data
    print('==> Preparing data..')
    # NOTE(review): the same channel statistics are applied to both datasets;
    # confirm that is intended for CIFAR-100.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = dataset_cls(
        root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)

    # The training split above already triggered the download.
    testset = dataset_cls(
        root='./data', train=False, download=False, transform=transform_test)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    device = 'cuda' if torch.cuda.is_available() and not args.cpu else 'cpu'

    # Model
    print('==> Building model..')
    net = ResNet18(block=args.block, num_classes=num_classes)
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)

    if args.resume:
        # Load checkpoint.
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # map_location lets a checkpoint written on GPU be resumed on CPU.
            checkpoint = torch.load(args.resume, map_location=device)
            state_dict = checkpoint['net']
            # DataParallel prefixes every key with 'module.'; align the keys
            # with however `net` is wrapped in this run.
            wrapped = isinstance(net, torch.nn.DataParallel)
            prefixed = all(key.startswith('module.') for key in state_dict)
            if prefixed and not wrapped:
                state_dict = {key[len('module.'):]: value for key, value in state_dict.items()}
            elif wrapped and not prefixed:
                state_dict = {'module.' + key: value for key, value in state_dict.items()}
            net.load_state_dict(state_dict)
            best_acc = checkpoint['acc']
            # The stored epoch was already trained and evaluated before the
            # checkpoint was written, so continue with the following one.
            start_epoch = checkpoint['epoch'] + 1
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=args.lr,
                          momentum=0.9, weight_decay=5e-4)
    # Anneal over the configured number of epochs rather than a fixed 200.
    # Checkpoints hold no optimizer/scheduler state, so the schedule restarts
    # from the beginning after a resume.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    epoch_accuracys = []
    for idx in range(start_epoch, start_epoch+args.epochs):
        train(idx, trainloader, net, criterion, optimizer, device)
        best_acc, cur_acc = test(idx, testloader, net, criterion, best_acc, device)
        scheduler.step()
        epoch_accuracys.append(cur_acc)

    return epoch_accuracys

## 1.2 Model training

In [7]:
# One matplotlib colour code per model, in the same order as `name_list`.
color_list = [
    'r',  # red
    'y',  # yellow
    'b',  # blue
    'm',  # magenta
    'g',  # green
]

In [None]:
# Train every model variant in turn and keep its per-epoch test accuracies.
separator = '########################################################################################'
accuracy_list = []
for block_type, display_name in zip(block_list, name_list):
    print()
    print(separator)
    print('Training of "%s"' % display_name)
    # run_train reads the block type from the shared settings object.
    args.block = block_type
    accuracy_list.append(run_train(args))
    print(separator)


########################################################################################
Training of "ResNet (base)"
dataset : cifar100
weight folder : weights/resnet18
==> Preparing data..
Files already downloaded and verified
==> Building model..

Epoch: 0
 - Train : Loss: 3.964 | Acc: 8.654% (=4327/50000)
 - Test : Loss: 3.624 | Acc: 12.820% (=1282/10000)
Saving..

Epoch: 1
 - Train : Loss: 3.393 | Acc: 17.150% (=8575/50000)
 - Test : Loss: 3.284 | Acc: 18.830% (=1883/10000)
Saving..

Epoch: 2
 - Train : Loss: 2.954 | Acc: 25.438% (=12719/50000)
 - Test : Loss: 2.828 | Acc: 28.990% (=2899/10000)
Saving..

Epoch: 3
 - Train : Loss: 2.526 | Acc: 33.904% (=16952/50000)
 - Test : Loss: 2.500 | Acc: 35.040% (=3504/10000)
Saving..

Epoch: 4
 - Train : Loss: 2.167 | Acc: 41.428% (=20714/50000)
 - Test : Loss: 2.207 | Acc: 41.280% (=4128/10000)
Saving..

Epoch: 5
 - Train : Loss: 1.937 | Acc: 46.834% (=23417/50000)
 - Test : Loss: 2.197 | Acc: 42.720% (=4272/10000)
Saving..

Epoch: 6
 - Tr

 - Test : Loss: 1.479 | Acc: 61.050% (=6105/10000)

Epoch: 68
 - Train : Loss: 0.788 | Acc: 76.238% (=38119/50000)
 - Test : Loss: 1.442 | Acc: 62.620% (=6262/10000)

Epoch: 69
 - Train : Loss: 0.784 | Acc: 76.270% (=38135/50000)
 - Test : Loss: 1.441 | Acc: 61.480% (=6148/10000)

Epoch: 70
 - Train : Loss: 0.778 | Acc: 76.434% (=38217/50000)
 - Test : Loss: 1.417 | Acc: 62.340% (=6234/10000)

Epoch: 71
 - Train : Loss: 0.765 | Acc: 76.832% (=38416/50000)
 - Test : Loss: 1.472 | Acc: 61.540% (=6154/10000)

Epoch: 72
 - Train : Loss: 0.757 | Acc: 77.134% (=38567/50000)
 - Test : Loss: 1.552 | Acc: 60.260% (=6026/10000)

Epoch: 73
 - Train : Loss: 0.760 | Acc: 77.230% (=38615/50000)
 - Test : Loss: 1.494 | Acc: 61.620% (=6162/10000)

Epoch: 74
 - Train : Loss: 0.745 | Acc: 77.502% (=38751/50000)
 - Test : Loss: 1.455 | Acc: 62.130% (=6213/10000)

Epoch: 75
 - Train : Loss: 0.736 | Acc: 77.746% (=38873/50000)
 - Test : Loss: 1.399 | Acc: 63.960% (=6396/10000)
Saving..

Epoch: 76
 - Train 

 - Test : Loss: 1.236 | Acc: 70.900% (=7090/10000)
Saving..

Epoch: 138
 - Train : Loss: 0.135 | Acc: 96.216% (=48108/50000)
 - Test : Loss: 1.191 | Acc: 72.080% (=7208/10000)
Saving..

Epoch: 139
 - Train : Loss: 0.138 | Acc: 96.160% (=48080/50000)
 - Test : Loss: 1.254 | Acc: 70.620% (=7062/10000)

Epoch: 140
 - Train : Loss: 0.119 | Acc: 96.720% (=48360/50000)
 - Test : Loss: 1.235 | Acc: 70.520% (=7052/10000)

Epoch: 141
 - Train : Loss: 0.116 | Acc: 96.910% (=48455/50000)
 - Test : Loss: 1.172 | Acc: 72.350% (=7235/10000)
Saving..

Epoch: 142
 - Train : Loss: 0.102 | Acc: 97.412% (=48706/50000)
 - Test : Loss: 1.249 | Acc: 70.710% (=7071/10000)

Epoch: 143
 - Train : Loss: 0.099 | Acc: 97.500% (=48750/50000)
 - Test : Loss: 1.218 | Acc: 71.320% (=7132/10000)

Epoch: 144
 - Train : Loss: 0.089 | Acc: 97.796% (=48898/50000)
 - Test : Loss: 1.176 | Acc: 72.460% (=7246/10000)
Saving..

Epoch: 145
 - Train : Loss: 0.079 | Acc: 98.146% (=49073/50000)
 - Test : Loss: 1.174 | Acc: 72.290%

 - Train : Loss: 2.612 | Acc: 31.980% (=15990/50000)
 - Test : Loss: 2.625 | Acc: 32.780% (=3278/10000)
Saving..

Epoch: 4
 - Train : Loss: 2.271 | Acc: 39.240% (=19620/50000)
 - Test : Loss: 2.277 | Acc: 39.490% (=3949/10000)
Saving..

Epoch: 5
 - Train : Loss: 2.030 | Acc: 44.548% (=22274/50000)
 - Test : Loss: 2.350 | Acc: 38.290% (=3829/10000)

Epoch: 6
 - Train : Loss: 1.852 | Acc: 48.884% (=24442/50000)
 - Test : Loss: 2.074 | Acc: 44.350% (=4435/10000)
Saving..

Epoch: 7
 - Train : Loss: 1.720 | Acc: 52.098% (=26049/50000)
 - Test : Loss: 2.182 | Acc: 43.630% (=4363/10000)

Epoch: 8
 - Train : Loss: 1.618 | Acc: 54.704% (=27352/50000)
 - Test : Loss: 1.880 | Acc: 49.220% (=4922/10000)
Saving..

Epoch: 9
 - Train : Loss: 1.540 | Acc: 56.506% (=28253/50000)
 - Test : Loss: 1.870 | Acc: 50.140% (=5014/10000)
Saving..

Epoch: 10
 - Train : Loss: 1.475 | Acc: 57.980% (=28990/50000)
 - Test : Loss: 1.928 | Acc: 50.150% (=5015/10000)
Saving..

Epoch: 11
 - Train : Loss: 1.423 | Acc: 59

 - Train : Loss: 0.744 | Acc: 77.496% (=38748/50000)
 - Test : Loss: 1.397 | Acc: 63.810% (=6381/10000)

Epoch: 74
 - Train : Loss: 0.734 | Acc: 77.504% (=38752/50000)
 - Test : Loss: 1.427 | Acc: 62.990% (=6299/10000)

Epoch: 75
 - Train : Loss: 0.721 | Acc: 78.152% (=39076/50000)
 - Test : Loss: 1.628 | Acc: 59.640% (=5964/10000)

Epoch: 76
 - Train : Loss: 0.707 | Acc: 78.486% (=39243/50000)
 - Test : Loss: 1.582 | Acc: 60.470% (=6047/10000)

Epoch: 77
 - Train : Loss: 0.703 | Acc: 78.488% (=39244/50000)
 - Test : Loss: 1.629 | Acc: 59.960% (=5996/10000)

Epoch: 78
 - Train : Loss: 0.704 | Acc: 78.720% (=39360/50000)
 - Test : Loss: 1.466 | Acc: 63.130% (=6313/10000)

Epoch: 79
 - Train : Loss: 0.691 | Acc: 78.928% (=39464/50000)
 - Test : Loss: 1.695 | Acc: 59.050% (=5905/10000)

Epoch: 80
 - Train : Loss: 0.691 | Acc: 78.952% (=39476/50000)
 - Test : Loss: 1.566 | Acc: 60.810% (=6081/10000)

Epoch: 81
 - Train : Loss: 0.673 | Acc: 79.332% (=39666/50000)
 - Test : Loss: 1.649 | Acc

## Display the accuracy of models

In [None]:
# Draw every model's per-epoch test accuracy on one shared set of axes.
fig = plt.figure()
axes = fig.gca()
axes.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))  # integer ticks for epochs
for history, label, colour in zip(accuracy_list, name_list, color_list):
    axes.plot(history, label=label, color=colour, marker='o')
# Optional manual axis limits:
# axes.set_xlim([0, 5])     # x-axis range: [xmin, xmax]
# axes.set_ylim([0, 20])    # y-axis range: [ymin, ymax]
axes.set_title('Accuracy of image models')
axes.set_xlabel("Epoch")
axes.set_ylabel("Accuracy (%)")
axes.legend()
plt.show()

# 2. Inference

In [None]:
import cv2

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, deprocess_image

## 2.1 Set class names of datasets

In [None]:
# CIFAR-10: class index -> label name.
# https://www.cs.toronto.edu/~kriz/cifar.html
class_cifar10 = dict(enumerate([
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]))

# Label and its index for CIFAR100 (fine labels)
# https://huggingface.co/datasets/cifar100
# Index 26 is 'crab'; it used to be written as the truncated 'cra'.
class_cifar100 = {0: 'apple', 1: 'aquarium_fish', 2: 'baby', 3: 'bear', 4: 'beaver', 5: 'bed', 6: 'bee', 7: 'beetle',
                  8: 'bicycle', 9: 'bottle', 10: 'bowl', 11: 'boy', 12: 'bridge', 13: 'bus', 14: 'butterfly',
                  15: 'camel', 16: 'can', 17: 'castle', 18: 'caterpillar', 19: 'cattle', 20: 'chair', 21: 'chimpanzee',
                  22: 'clock', 23: 'cloud', 24: 'cockroach', 25: 'couch', 26: 'crab', 27: 'crocodile', 28: 'cup',
                  29: 'dinosaur', 30: 'dolphin', 31: 'elephant', 32: 'flatfish', 33: 'forest', 34: 'fox', 35: 'girl',
                  36: 'hamster', 37: 'house', 38: 'kangaroo', 39: 'keyboard', 40: 'lamp', 41: 'lawn_mower',
                  42: 'leopard', 43: 'lion', 44: 'lizard', 45: 'lobster', 46: 'man', 47: 'maple_tree', 48: 'motorcycle',
                  49: 'mountain', 50: 'mouse', 51: 'mushroom', 52: 'oak_tree', 53: 'orange', 54: 'orchid', 55: 'otter',
                  56: 'palm_tree', 57: 'pear', 58: 'pickup_truck', 59: 'pine_tree', 60: 'plain', 61: 'plate',
                  62: 'poppy', 63: 'porcupine', 64: 'possum', 65: 'rabbit', 66: 'raccoon', 67: 'ray', 68: 'road',
                  69: 'rocket', 70: 'rose', 71: 'sea', 72: 'seal', 73: 'shark', 74: 'shrew', 75: 'skunk',
                  76: 'skyscraper', 77: 'snail', 78: 'snake', 79: 'spider', 80: 'squirrel', 81: 'streetcar',
                  82: 'sunflower', 83: 'sweet_pepper', 84: 'table', 85: 'tank', 86: 'telephone', 87: 'television',
                  88: 'tiger', 89: 'tractor', 90: 'train', 91: 'trout', 92: 'tulip', 93: 'turtle', 94: 'wardrobe',
                  95: 'whale', 96: 'willow_tree', 97: 'wolf', 98: 'woman', 99: 'worm'}

## 2.2 Define functions

In [None]:
def run_inference(args):
    """Predict the first test batch with a saved model and show Grad-CAM overlays.

    The function loads the CIFAR test split, restores ``args.checkpoint`` into
    a ``ResNet18`` built with ``args.block``, prints ground-truth and predicted
    labels for the first batch, writes that batch to
    ``gradCAM_seed<seed>_input.jpg`` and displays the input images above their
    Grad-CAM overlays.

    Relies on the notebook-level names ``seed``, ``class_cifar10`` and
    ``class_cifar100``.

    Args:
        args: Settings object providing ``dataset``, ``batch_size``,
            ``workers``, ``checkpoint``, ``block`` and ``cpu``.

    Raises:
        FileNotFoundError: If ``args.checkpoint`` is empty or not a file.
    """
    #############################################
    # Load dataset
    #############################################
    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    if args.dataset == "cifar100":
        num_classes, classes, dataset_cls = 100, class_cifar100, datasets.CIFAR100
    else:  # default dataset is CIFAR10
        num_classes, classes, dataset_cls = 10, class_cifar10, datasets.CIFAR10
    val_loader = torch.utils.data.DataLoader(
        dataset_cls(root='./data', train=False, transform=eval_transform),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    print("dataset :", args.dataset)
    print("checkpoint :", args.checkpoint)

    # Fail with a clear message when no checkpoint was found for this model.
    if not args.checkpoint or not os.path.isfile(args.checkpoint):
        raise FileNotFoundError("no checkpoint found at '{}'".format(args.checkpoint))

    #############################################
    # Load model
    #############################################
    model = ResNet18(block=args.block, num_classes=num_classes)

    # Taken before any DataParallel wrapping, where `layer4` is still a direct attribute.
    cam_layers = [model.layer4]

    device = 'cuda' if torch.cuda.is_available() and not args.cpu else 'cpu'

    model = model.to(device)
    if device == 'cuda':
        model = torch.nn.DataParallel(model)

    # map_location lets a checkpoint written on GPU be loaded on CPU.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    state_dict = checkpoint['net']
    # DataParallel prefixes every key with 'module.'; align the keys with
    # however `model` is wrapped in this run.
    wrapped = isinstance(model, torch.nn.DataParallel)
    prefixed = all(key.startswith('module.') for key in state_dict)
    if prefixed and not wrapped:
        state_dict = {key[len('module.'):]: value for key, value in state_dict.items()}
    elif wrapped and not prefixed:
        state_dict = {'module.' + key: value for key, value in state_dict.items()}
    model.load_state_dict(state_dict)
    # Evaluation mode, so train-time layer behaviour does not affect the predictions.
    model.eval()

    #############################################
    # Evaluate model
    #############################################
    images, labels = next(iter(val_loader))
    print("images shape : ", images.shape)
    # NOTE(review): newer torchvision releases name this keyword `value_range`;
    # confirm against the installed version.
    torchvision.utils.save_image(images, "gradCAM_seed%d_input.jpg" % seed, nrow=4, normalize=True, range=(-1, 1))
    print("input gt labels : ")
    np_labels = labels.detach().cpu()
    # Iterate over the batch itself; it may hold fewer than batch_size items.
    print([classes[int(label)] for label in np_labels])
    with torch.no_grad():
        output = model(images.to(device))
    maxk = 1
    pred = output.topk(maxk, 1, True, True)
    print("pred labels : ")
    np_indices = pred.indices.detach().cpu()
    print([classes[int(row[0])] for row in np_indices])

    #############################################
    # Create CAM
    #############################################
    # NOTE(review): `use_cuda` may not be accepted by newer pytorch_grad_cam
    # releases; confirm against the installed version.
    cam = GradCAM(model=model, target_layers=cam_layers, use_cuda=device != 'cpu')

    grayscale_cams = cam(input_tensor=images)

    originals = []
    overlays = []
    for tensor_img, grayscale_cam in zip(images, grayscale_cams):
        rgb_img = deprocess_image(tensor_img.permute(1, 2, 0).numpy()) / 255.0
        originals.append(rgb_img)
        # Keep the overlay in RGB: it is displayed with matplotlib, which
        # expects RGB. Converting to BGR (only needed for cv2.imwrite) would
        # swap the red and blue channels on screen.
        overlays.append(show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True))

    # Lay the batch out side by side, inputs above overlays.
    original_img = cv2.hconcat(originals)
    final_cam = cv2.hconcat(overlays)

    fig = plt.figure()
    fig.add_subplot(2, 1, 1)
    plt.imshow(original_img)
    plt.title("Original Image")

    fig.add_subplot(2, 1, 2)
    plt.imshow(final_cam)
    plt.title("GradCam")

    plt.show()

In [None]:
def best_checkpoint(checkpoint_path):
    """Return the path of the newest ``checkpoint_<epoch>`` file in a folder.

    ``test`` only writes a checkpoint when accuracy improves, so the file with
    the highest epoch number is also the most accurate one. Epochs are compared
    as integers; plain string sorting would rank ``checkpoint_99`` above
    ``checkpoint_144``.

    Args:
        checkpoint_path: Directory to search.

    Returns:
        Path of the selected file, or ``''`` when the directory does not exist
        or holds no file starting with ``checkpoint_``. Files whose suffix is
        not a number are only chosen when no numbered checkpoint exists.
    """
    if not os.path.isdir(checkpoint_path):
        return ''

    prefix = 'checkpoint_'

    def epoch_rank(file_name):
        # Numbered checkpoints outrank everything else and sort by epoch.
        stem = os.path.splitext(file_name)[0][len(prefix):]
        if stem.isdecimal():
            return (1, int(stem), file_name)
        return (0, 0, file_name)

    candidates = [name for name in os.listdir(checkpoint_path) if name.startswith(prefix)]
    if not candidates:
        return ''
    return os.path.join(checkpoint_path, max(candidates, key=epoch_rank))

## 2.3 Analysis of inference

In [None]:
def _checkpoint_for(block_type):
    """Look up the saved checkpoint of one model variant ('' when none exists)."""
    block_dir = '%s/%s/%s' % (args.save_dir, args.dataset, block_type)
    return best_checkpoint(block_dir)


# One checkpoint path per entry of `block_list`, in the same order.
weight_list = list(map(_checkpoint_for, block_list))
weight_list

In [None]:
# Sample images for inferencing
args.batch_size = 4

# Run the Grad-CAM inference once per model variant with its own checkpoint.
separator = '########################################################################################'
for block_type, display_name, weight_path in zip(block_list, name_list, weight_list):
    print()
    print(separator)
    print('Inference of "%s"' % display_name)
    args.block, args.checkpoint = block_type, weight_path
    run_inference(args)
    print(separator)