In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

import sys
from pytorch_grad_cam import GradCAM, EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from PIL import Image

import numpy as np

import torchvision.models as models
import os

from torchvision.models import resnet18

import tqdm

In [None]:
# Use the first GPU visible to PyTorch when there is one; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def get_all_png(img_path):
    """Recursively collect every ``.png`` file below a directory.

    Args:
        img_path: Root directory to walk. A path that does not exist yields an
            empty list, because ``os.walk`` then produces nothing.

    Returns:
        Sorted list of absolute paths of files whose name ends with ``".png"``.
        The suffix match is case-sensitive, so ``"X.PNG"`` is not included.
    """
    return sorted(
        os.path.abspath(os.path.join(root, name))
        for root, _subdirs, names in os.walk(img_path)
        for name in names
        if name.endswith(".png")
    )

In [None]:
def do_gradcam(img_original, model, layers):
    """Compute a Grad-CAM heatmap for one image and blend it over that image.

    Args:
        img_original: PIL image to explain. It is converted to RGB first, so
            grayscale, palette or RGBA images also yield the 3-channel input
            the network and ``show_cam_on_image`` expect.
        model: Network handed to ``GradCAM``. Callers are expected to have put
            it in eval mode.
        layers: List of sub-modules of ``model`` used as Grad-CAM target layers.

    Returns:
        PIL image of the colour-mapped CAM overlaid on the RGB input
        (presumably the same size as the input, since the CAM is upsampled to
        the input-tensor resolution by the library; verify if relied upon).
    """
    img_rgb = img_original.convert("RGB")

    # Only let pytorch_grad_cam move things to CUDA when the module-level
    # ``device`` really is a GPU. A hard-coded use_cuda=True crashes on
    # CPU-only machines, even though ``device`` falls back to the CPU there.
    cam = GradCAM(model=model, target_layers=layers, use_cuda=(device.type == "cuda"))

    # ToTensor gives a float (C, H, W) tensor in [0, 1]; add the batch dimension.
    # NOTE(review): no ImageNet mean/std normalization is applied, although
    # torchvision's pretrained weights expect it. Confirm this is intended.
    input_tensor = transforms.ToTensor()(img_rgb).unsqueeze(0).to(device)

    # No ``targets`` are passed, so the library picks the target class itself
    # (presumably the top-scoring class; see pytorch_grad_cam docs).
    # Keep the CAM of the single batch element.
    grayscale_cam = cam(input_tensor=input_tensor)[0, :]

    # show_cam_on_image expects an RGB float image scaled to [0, 1].
    rgb_float = np.asarray(img_rgb, dtype=np.float32) / 255
    overlay = show_cam_on_image(rgb_float, grayscale_cam, use_rgb=True)

    return Image.fromarray(overlay)

In [None]:
def do_grad_cam(model, layers, inTensor, outPath, show=False):
    """Save the original image and its Grad-CAM overlay side by side.

    Args:
        model: Network to explain. It is switched to eval mode here.
        layers: Grad-CAM target layers, forwarded to ``do_gradcam``.
        inTensor: Despite its name, a PIL image. It is passed to
            ``Image.resize`` and to ``do_gradcam``.
        outPath: Destination file for the 300x600 composite (original on the
            left, overlay on the right). Missing parent directories are created.
        show: When True, open the composite in an external viewer and block on
            ``input()`` until Enter is pressed. Defaults to False so that
            looping over a whole dataset is not stalled once per image.
    """
    model.eval()

    img_grad = do_gradcam(inTensor, model, layers)

    # Upscale both panels to the same size; force RGB so both have 3 channels
    # and can be stacked horizontally (axis=1 is the width axis of an HxWxC array).
    panel_size = (300, 300)
    left = np.asarray(inTensor.convert("RGB").resize(panel_size))
    right = np.asarray(img_grad.convert("RGB").resize(panel_size))
    final_image = Image.fromarray(np.concatenate((left, right), axis=1))

    if show:
        final_image.show()
        input()

    # Image.save does not create directories; make sure the target folder exists.
    out_dir = os.path.dirname(outPath)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    final_image.save(outPath)

In [None]:
# ImageNet-pretrained ResNet-18. Place it on the same device the Grad-CAM input
# tensors are sent to, so the model and inputs never end up on different
# devices, and put it in inference mode.
model = resnet18(pretrained=True).to(device)
model.eval()

# Grad-CAM target layers: the last block of each of the four residual stages.
# nn.Module.to() moves parameters in place, so these references stay valid.
layers = [
    model.layer1[-1],
    model.layer2[-1],
    model.layer3[-1],
    model.layer4[-1],
]

In [None]:
# CIFAR-10 test split, converted to float tensors in [0, 1] (downloaded to ./data if absent).
transform = transforms.Compose([transforms.ToTensor()])
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# One image per batch, in fixed order, so output file i corresponds to test image i.
testLoader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
)

In [None]:
# Write one side-by-side Grad-CAM figure per CIFAR-10 test image.
OUTPUT_DIR = "./output"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ToPILImage holds no per-image state, so build it once instead of on every iteration.
transform_PIL = transforms.ToPILImage()

for i, (images, _labels) in enumerate(tqdm.tqdm(testLoader)):
    # testLoader uses batch_size=1: take the single (C, H, W) image out of the batch.
    do_grad_cam(model, layers, transform_PIL(images[0]), os.path.join(OUTPUT_DIR, str(i) + ".png"))