In [1]:
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

import matplotlib.pyplot as plt
from PIL import Image
import torch.nn as nn
import numpy as np
import os, json
from glob import glob
from tqdm.notebook import tqdm
from joblib import Parallel, delayed

import torch
from torchvision import models, transforms
from torch.autograd import Variable
import torch.nn.functional as F

import timm

In [2]:
def get_input_transform():
    """Build the preprocessing pipeline applied to each PIL image before the model.

    Steps, in order: resize to exactly 256x256 (aspect ratio is not preserved),
    center-crop to 224x224, convert to a float tensor, then normalize every channel
    with the mean/std constants below (the widely used ImageNet statistics).

    Returns
    -------
    torchvision.transforms.Compose
        The composed transform; call it on a PIL image to get a (C, 224, 224) tensor.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    pipeline_steps = [
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline_steps)

def get_input_tensors(img):
    """Preprocess one PIL image into a model-ready batch of size 1.

    Parameters
    ----------
    img : PIL.Image.Image
        Input image in any PIL mode.

    Returns
    -------
    torch.Tensor
        Output of ``get_input_transform()`` with a leading batch dimension of 1.
    """
    transf = get_input_transform()
    # Normalize uses 3-channel mean/std, so grayscale ("L"), palette ("P") or RGBA
    # images would fail there. Coerce to RGB first; for RGB input this is just a copy.
    return transf(img.convert("RGB")).unsqueeze(0)

In [10]:
ckpt_fn = "models/20220703-182810-resnet50-224/last.pth.tar"
im_path = "data/orig_o"
out_path = "o_o"


def build_class_index(root):
    """Map each class sub-directory name directly under ``root`` to a class index.

    Only directories count as classes; plain files are ignored. Names are ordered
    deterministically with a "natural" sort (digit runs compare as integers, the rest
    case-insensitively, ties broken by the raw name), so indices do not depend on the
    arbitrary order in which the filesystem lists entries.

    NOTE(review): indices must match the label order used at training time. The
    checkpoint path looks like timm's train.py output, and timm folder datasets sort
    class names with this kind of natural key. If the model was trained with
    torchvision's ImageFolder (plain lexicographic ``sorted``), switch the key — confirm.

    Parameters
    ----------
    root : str
        Directory whose immediate sub-directories are the classes.

    Returns
    -------
    dict[str, int]
        ``{class_dir_name: index}``; empty if ``root`` has no sub-directories or
        does not exist.
    """
    import re

    def natural_key(name):
        # re.split with one capture group alternates [text, digits, text, ...], so
        # odd positions are always digit runs and compare numerically ("2" < "10").
        parts = re.split(r"(\d+)", name.lower())
        return [int(part) if pos % 2 else part for pos, part in enumerate(parts)]

    class_dirs = [
        os.path.basename(entry)
        for entry in glob(os.path.join(root, "*"))
        if os.path.isdir(entry)
    ]
    ordered = sorted(class_dirs, key=lambda name: (natural_key(name), name))
    return {name: idx for idx, name in enumerate(ordered)}


fl_dict = build_class_index(im_path)

In [11]:
# Build the ResNet-50 architecture with a 300-way head and load the fine-tuned
# weights from `ckpt_fn`. pretrained=False: timm loads `checkpoint_path` strictly after
# construction, so every weight is overwritten anyway; pretrained=True only added an
# ImageNet-weights download (and a network dependency) with no effect on the result.
model = timm.create_model(
    'resnet50',
    num_classes=300,
    in_chans=3,
    pretrained=False,
    checkpoint_path=ckpt_fn,
)
model.eval()  # inference only: use stored BatchNorm statistics, no dropout

# pytorch_grad_cam takes a *list* of modules. list(model.layer4) is equivalent to
# passing the nn.Sequential itself (the library iterates/indexes it), i.e. one CAM per
# bottleneck block of layer4, averaged. The more common choice is [model.layer4[-1]].
target_layers = list(model.layer4)

# Sorted so the processing order is reproducible regardless of filesystem listing order.
images = sorted(glob(f"{im_path}/**/*.*", recursive=True))

In [None]:
# CAM variants to render, as (output sub-directory name, CAM class).
# Also available in pytorch_grad_cam but unused here: AblationCAM.
#
# Only construct wrappers that are actually used: each CAM object registers hooks on
# the model's target layers when it is created, so extra instances (the former
# standalone `cam` / `eig`) only add per-forward overhead and retain activations.
CAM_METHODS = [
    ("gradcam", GradCAM),
    ("scorecam", ScoreCAM),
    ("gradcampp", GradCAMPlusPlus),
    ("xgradcam", XGradCAM),
    ("eigencam", EigenCAM),
    ("fullgrad", FullGrad),
]

cam_list = [
    (cam_str, cam_cls(model=model, target_layers=target_layers))
    for cam_str, cam_cls in CAM_METHODS
]

def _display_image(img):
    """Return ``img`` resized/cropped exactly like the model input, as float32 RGB in [0, 1].

    The CAM is computed on the output of ``get_input_transform()``, so the overlay must
    go through the same geometric steps (everything except ToTensor/Normalize) for the
    heatmap to line up with the image content.
    """
    geometric_steps = [
        step for step in get_input_transform().transforms
        if not isinstance(step, (transforms.ToTensor, transforms.Normalize))
    ]
    aligned = transforms.Compose(geometric_steps)(img)
    return np.asarray(aligned, dtype=np.float32) / 255.0


def process(im_fn, fl_dict, cam_list):
    """Render and save one CAM overlay per method in ``cam_list`` for a single image.

    Parameters
    ----------
    im_fn : str
        Path to an image; the name of its parent directory is the class label.
    fl_dict : dict[str, int]
        Maps class-directory name to the output index used as the CAM target.
    cam_list : list of (str, cam)
        ``(name, cam_method)`` pairs; each overlay is written to
        ``outs/{out_path}/{name}/{basename of im_fn}``.

    Raises
    ------
    KeyError
        If the image's parent directory name is not a key of ``fl_dict``.
    """
    bn = os.path.basename(im_fn)
    # os.path rather than split("/") so platform-specific separators also work.
    cat = fl_dict[os.path.basename(os.path.dirname(im_fn))]

    # convert() loads the pixels, so the file can be closed immediately. RGB is
    # required because Normalize and show_cam_on_image both expect 3 channels.
    with Image.open(im_fn) as raw:
        img = raw.convert("RGB")
    img_t = get_input_tensors(img)
    test_img = _display_image(img)

    targets = [ClassifierOutputTarget(cat)]

    for cam_str, cam_method in cam_list:
        grayscale_cam = cam_method(input_tensor=img_t, targets=targets)[0, :]
        visualization = show_cam_on_image(test_img, grayscale_cam, use_rgb=True)

        out_dir = f"outs/{out_path}/{cam_str}"
        # Create the directory *before* saving so the first image per method is not
        # dropped; exist_ok because parallel workers may create it concurrently.
        os.makedirs(out_dir, exist_ok=True)

        # NOTE(review): images in different classes sharing a basename overwrite each other.
        fig, ax = plt.subplots()
        ax.imshow(visualization)
        fig.savefig(f"{out_dir}/{bn}")
        plt.close(fig)  # free the figure; each worker renders many of them

# Render overlays for every image across 8 worker processes. With joblib's default
# process-based backend the task arguments (including `cam_list`, which holds the
# model) are serialized to the workers, and the tqdm bar advances as tasks are
# dispatched rather than when they finish. Trailing `;` suppresses the list of
# `None` results that would otherwise be displayed as cell output.
Parallel(n_jobs=8)(
    delayed(process)(im_fn, fl_dict, cam_list)
    for im_fn in tqdm(images, total=len(images), desc="CAM overlays")
);



  0%|          | 0/4577 [00:00<?, ?it/s]

  7%|▋         | 9/128 [01:39<22:50, 11.51s/it]