In [1]:
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

import matplotlib.pyplot as plt
from PIL import Image
import torch.nn as nn
import numpy as np
import os, json
from glob import glob
from tqdm.notebook import tqdm
from joblib import Parallel, delayed

import torch
from torchvision import models, transforms
from torch.autograd import Variable
import torch.nn.functional as F

import timm


In [2]:
def get_input_transform():
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])       
    transf = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])    

    return transf

def get_input_tensors(img):
    transf = get_input_transform()
    return transf(img).unsqueeze(0)

In [None]:
ckpt_fn = "models/orig_20220627-235055-seresnext50_32x4d-224/last.pth.tar"
im_path = "data/orig_o"

fl_dict = {}
for idx, fl in enumerate(glob(f"{im_path}/*")):
    bn = os.path.basename(fl)
    fl_dict[bn] = idx

In [3]:
ckpt_fn = "models/planted_20220627-235814-seresnext50_32x4d-224/last.pth.tar"
im_path = "data/pl_o"
out_path = "pl_pl"

fl_dict = {}
for idx, fl in enumerate(glob(f"{im_path}/*")):
    bn = os.path.basename(fl)
    fl_dict[bn] = idx

In [10]:
ckpt_fn = "models/bl_20220703-220635-resnet50-224/last.pth.tar"
im_path = "data/bl"
out_path = "bl_bl"

fl_dict = {}
for idx, fl in enumerate(glob(f"{im_path}/*")):
    bn = os.path.basename(fl)
    fl_dict[bn] = idx

In [9]:
model = timm.create_model(
        'resnet50',
        num_classes=300,
        in_chans=3,
        pretrained=True,
        checkpoint_path=ckpt_fn)

target_layers = model.layer4

images = glob(f"{im_path}/**/*.*", recursive=True)

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth" to /Users/dax/.cache/torch/hub/checkpoints/resnet50_ram-a26f946b.pth


In [12]:
cam = GradCAM(model=model, target_layers=target_layers)
eig = EigenCAM(model=model, target_layers=target_layers)
def process(im_fn, fl_dict):
    bn = os.path.basename(im_fn)
    cat = fl_dict[im_fn.split("/")[-2]]
    
    img = Image.open(im_fn)
    img_t = get_input_tensors(img)
    test_img = np.array(img.resize((224, 224)), dtype='float32')
    test_img /= 255

    targets = [ClassifierOutputTarget(cat)]

    grayscale_cam = cam(input_tensor=img_t, targets=targets)
    grayscale_cam = grayscale_cam[0, :]

    visualization = show_cam_on_image(test_img, grayscale_cam, use_rgb=True)
    plt.imshow(visualization)
    try:
        plt.savefig(f"outs/{out_path}/gradcam/{bn}")
    except FileNotFoundError:
        os.makedirs(f"outs/{out_path}/gradcam/")
    plt.clf()
    
    grayscale_eig = eig(input_tensor=img_t, targets=targets)
    grayscale_eig = grayscale_eig[0, :]

    visualization = show_cam_on_image(test_img, grayscale_eig, use_rgb=True)
    plt.imshow(visualization)
    try:
        plt.savefig(f"outs/{out_path}/eigencam/{bn}")
    except FileNotFoundError:
        os.makedirs(f"outs/{out_path}/eigencam/")
    plt.clf()
    
Parallel(n_jobs=8)(delayed(process)(im_fn, fl_dict) for im_fn in tqdm(images, total=len(list(images))))

  0%|          | 0/4230 [00:00<?, ?it/s]

[None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,

In [6]:
abl = AblationCAM(model=model, target_layers=target_layers)
for im_fn in tqdm(images, total=len(list(images))):
    bn = os.path.basename(im_fn)
    cat = fl_dict[im_fn.split("/")[-2]]
    
    img = Image.open(im_fn)
    img_t = get_input_tensors(img)
    test_img = np.array(img.resize((224, 224)), dtype='float32')
    test_img /= 255

    targets = [ClassifierOutputTarget(cat)]
    
    grayscale_abl = abl(input_tensor=img_t, targets=targets)
    grayscale_abl = grayscale_abl[0, :]

    visualization = show_cam_on_image(test_img, grayscale_abl, use_rgb=True)
    plt.imshow(visualization)
    try:
        plt.savefig(f"outs/{out_path}/ablationcam/{bn}")
    except FileNotFoundError:
        os.makedirs(f"outs/{out_path}/ablationcam/")
    plt.clf()

  0%|          | 0/7859 [00:00<?, ?it/s]


  0%|                                                    | 0/64 [00:00<?, ?it/s][A
  2%|▋                                           | 1/64 [00:05<05:17,  5.04s/it][A
  3%|█▍                                          | 2/64 [00:10<05:10,  5.00s/it][A
  5%|██                                          | 3/64 [00:15<05:05,  5.01s/it][A
  6%|██▊                                         | 4/64 [00:20<05:01,  5.02s/it][A
  8%|███▍                                        | 5/64 [00:25<05:07,  5.21s/it][A
  9%|████▏                                       | 6/64 [00:31<05:07,  5.31s/it][A
 11%|████▊                                       | 7/64 [00:36<05:12,  5.49s/it][A
 12%|█████▌                                      | 8/64 [00:42<05:11,  5.57s/it][A
 14%|██████▏                                     | 9/64 [00:48<05:08,  5.60s/it][A
 16%|██████▋                                    | 10/64 [00:53<04:56,  5.48s/it][A
 17%|███████▍                                   | 11/64 [00:58<04:47,  5.43

100%|███████████████████████████████████████████| 64/64 [05:50<00:00,  5.48s/it][A

  0%|                                                    | 0/64 [00:00<?, ?it/s][A
  2%|▋                                           | 1/64 [00:04<05:14,  5.00s/it][A
  3%|█▍                                          | 2/64 [00:09<05:06,  4.94s/it][A
  5%|██                                          | 3/64 [00:14<04:59,  4.92s/it][A
  6%|██▊                                         | 4/64 [00:19<04:54,  4.90s/it][A
  8%|███▍                                        | 5/64 [00:24<04:50,  4.92s/it][A
  9%|████▏                                       | 6/64 [00:29<04:44,  4.91s/it][A
 11%|████▊                                       | 7/64 [00:34<04:39,  4.90s/it][A
 12%|█████▌                                      | 8/64 [00:39<04:34,  4.90s/it][A
 14%|██████▏                                     | 9/64 [00:44<04:29,  4.90s/it][A
 16%|██████▋                                    | 10/64 [00:49<04:25,  4.91

 98%|██████████████████████████████████████████▎| 63/64 [05:42<00:05,  5.33s/it][A
100%|███████████████████████████████████████████| 64/64 [05:47<00:00,  5.43s/it][A

  0%|                                                    | 0/64 [00:00<?, ?it/s][A
  2%|▋                                           | 1/64 [00:05<05:36,  5.34s/it][A
  3%|█▍                                          | 2/64 [00:10<05:20,  5.17s/it][A
  5%|██                                          | 3/64 [00:15<05:11,  5.11s/it][A
  6%|██▊                                         | 4/64 [00:20<05:06,  5.10s/it][A
  8%|███▍                                        | 5/64 [00:25<05:00,  5.09s/it][A
  9%|████▏                                       | 6/64 [00:30<04:54,  5.08s/it][A
 11%|████▊                                       | 7/64 [00:35<04:47,  5.05s/it][A
 12%|█████▌                                      | 8/64 [00:40<04:44,  5.08s/it][A
 14%|██████▏                                     | 9/64 [00:45<04:40,  5.10

 97%|█████████████████████████████████████████▋ | 62/64 [05:34<00:10,  5.48s/it][A
 98%|██████████████████████████████████████████▎| 63/64 [05:40<00:05,  5.47s/it][A
100%|███████████████████████████████████████████| 64/64 [05:45<00:00,  5.40s/it][A

  0%|                                                    | 0/64 [00:00<?, ?it/s][A
  2%|▋                                           | 1/64 [00:05<05:32,  5.28s/it][A
  3%|█▍                                          | 2/64 [00:10<05:26,  5.27s/it][A
  5%|██                                          | 3/64 [00:15<05:22,  5.28s/it][A
  6%|██▊                                         | 4/64 [00:21<05:16,  5.27s/it][A
  8%|███▍                                        | 5/64 [00:26<05:10,  5.26s/it][A
  9%|████▏                                       | 6/64 [00:31<05:05,  5.26s/it][A
 11%|████▊                                       | 7/64 [00:36<04:59,  5.26s/it][A
 12%|█████▌                                      | 8/64 [00:42<04:54,  5.26

 95%|████████████████████████████████████████▉  | 61/64 [05:37<00:16,  5.54s/it][A
 97%|█████████████████████████████████████████▋ | 62/64 [05:43<00:11,  5.54s/it][A
 98%|██████████████████████████████████████████▎| 63/64 [05:48<00:05,  5.53s/it][A
100%|███████████████████████████████████████████| 64/64 [05:54<00:00,  5.53s/it][A

  0%|                                                    | 0/64 [00:00<?, ?it/s][A
  2%|▋                                           | 1/64 [00:05<05:36,  5.34s/it][A
  3%|█▍                                          | 2/64 [00:10<05:29,  5.31s/it][A
  5%|██                                          | 3/64 [00:15<05:22,  5.29s/it][A
  6%|██▊                                         | 4/64 [00:21<05:17,  5.29s/it][A
  8%|███▍                                        | 5/64 [00:26<05:11,  5.28s/it][A
  9%|████▏                                       | 6/64 [00:31<05:06,  5.28s/it][A
 11%|████▊                                       | 7/64 [00:37<05:01,  5.28

 94%|████████████████████████████████████████▎  | 60/64 [05:31<00:22,  5.52s/it][A
 95%|████████████████████████████████████████▉  | 61/64 [05:36<00:16,  5.53s/it][A
 97%|█████████████████████████████████████████▋ | 62/64 [05:42<00:11,  5.52s/it][A
 98%|██████████████████████████████████████████▎| 63/64 [05:47<00:05,  5.51s/it][A
100%|███████████████████████████████████████████| 64/64 [05:53<00:00,  5.52s/it][A

  0%|                                                    | 0/64 [00:00<?, ?it/s][A
  2%|▋                                           | 1/64 [00:05<05:34,  5.30s/it][A
  3%|█▍                                          | 2/64 [00:10<05:28,  5.29s/it][A
  5%|██                                          | 3/64 [00:15<05:23,  5.30s/it][A
  6%|██▊                                         | 4/64 [00:21<05:17,  5.29s/it][A
  8%|███▍                                        | 5/64 [00:26<05:11,  5.28s/it][A
  9%|████▏                                       | 6/64 [00:31<05:06,  5.28

 92%|███████████████████████████████████████▋   | 59/64 [05:25<00:27,  5.55s/it][A
 94%|████████████████████████████████████████▎  | 60/64 [05:30<00:22,  5.53s/it][A
 95%|████████████████████████████████████████▉  | 61/64 [05:36<00:16,  5.53s/it][A
 97%|█████████████████████████████████████████▋ | 62/64 [05:41<00:11,  5.52s/it][A
 98%|██████████████████████████████████████████▎| 63/64 [05:47<00:05,  5.52s/it][A
100%|███████████████████████████████████████████| 64/64 [05:52<00:00,  5.52s/it][A

  0%|                                                    | 0/64 [00:00<?, ?it/s][A
  2%|▋                                           | 1/64 [00:05<05:33,  5.29s/it][A
  3%|█▍                                          | 2/64 [00:10<05:26,  5.27s/it][A
  5%|██                                          | 3/64 [00:15<05:21,  5.27s/it][A
  6%|██▊                                         | 4/64 [00:21<05:17,  5.29s/it][A
  8%|███▍                                        | 5/64 [00:26<05:11,  5.28

 91%|██████████████████████████████████████▉    | 58/64 [05:19<00:32,  5.50s/it][A
 92%|███████████████████████████████████████▋   | 59/64 [05:24<00:27,  5.49s/it][A
 94%|████████████████████████████████████████▎  | 60/64 [05:30<00:21,  5.49s/it][A
 95%|████████████████████████████████████████▉  | 61/64 [05:35<00:16,  5.49s/it][A
 97%|█████████████████████████████████████████▋ | 62/64 [05:41<00:10,  5.49s/it][A
 98%|██████████████████████████████████████████▎| 63/64 [05:46<00:05,  5.49s/it][A
100%|███████████████████████████████████████████| 64/64 [05:52<00:00,  5.50s/it][A

  0%|                                                    | 0/64 [00:00<?, ?it/s][A
  2%|▋                                           | 1/64 [00:05<05:32,  5.28s/it][A
  3%|█▍                                          | 2/64 [00:10<05:26,  5.27s/it][A
  5%|██                                          | 3/64 [00:15<05:22,  5.28s/it][A
  6%|██▊                                         | 4/64 [00:21<05:16,  5.27

 89%|██████████████████████████████████████▎    | 57/64 [05:14<00:38,  5.53s/it][A
 91%|██████████████████████████████████████▉    | 58/64 [05:19<00:33,  5.52s/it][A
 92%|███████████████████████████████████████▋   | 59/64 [05:25<00:27,  5.51s/it][A
 94%|████████████████████████████████████████▎  | 60/64 [05:30<00:22,  5.51s/it][A
 95%|████████████████████████████████████████▉  | 61/64 [05:36<00:16,  5.55s/it][A
 97%|█████████████████████████████████████████▋ | 62/64 [05:41<00:11,  5.54s/it][A
 98%|██████████████████████████████████████████▎| 63/64 [05:47<00:05,  5.53s/it][A
100%|███████████████████████████████████████████| 64/64 [05:52<00:00,  5.51s/it][A

  0%|                                                    | 0/64 [00:00<?, ?it/s][A
  2%|▋                                           | 1/64 [00:05<05:32,  5.28s/it][A
  3%|█▍                                          | 2/64 [00:10<05:27,  5.28s/it][A
  5%|██                                          | 3/64 [00:15<05:21,  5.27

 88%|█████████████████████████████████████▋     | 56/64 [05:08<00:44,  5.51s/it][A
 89%|██████████████████████████████████████▎    | 57/64 [05:13<00:38,  5.51s/it][A
 91%|██████████████████████████████████████▉    | 58/64 [05:19<00:33,  5.50s/it][A
 92%|███████████████████████████████████████▋   | 59/64 [05:24<00:27,  5.50s/it][A
 94%|████████████████████████████████████████▎  | 60/64 [05:30<00:22,  5.50s/it][A
 95%|████████████████████████████████████████▉  | 61/64 [05:35<00:16,  5.50s/it][A
 97%|█████████████████████████████████████████▋ | 62/64 [05:41<00:11,  5.51s/it][A
 98%|██████████████████████████████████████████▎| 63/64 [05:46<00:05,  5.50s/it][A
100%|███████████████████████████████████████████| 64/64 [05:52<00:00,  5.50s/it][A

  0%|                                                    | 0/64 [00:00<?, ?it/s][A
  2%|▋                                           | 1/64 [00:05<05:32,  5.27s/it][A
  3%|█▍                                          | 2/64 [00:10<05:26,  5.26

 86%|████████████████████████████████████▉      | 55/64 [05:04<00:49,  5.51s/it][A
 88%|█████████████████████████████████████▋     | 56/64 [05:09<00:44,  5.51s/it][A
 89%|██████████████████████████████████████▎    | 57/64 [05:15<00:38,  5.51s/it][A
 91%|██████████████████████████████████████▉    | 58/64 [05:20<00:33,  5.53s/it][A
 92%|███████████████████████████████████████▋   | 59/64 [05:26<00:27,  5.53s/it][A
 94%|████████████████████████████████████████▎  | 60/64 [05:31<00:22,  5.53s/it][A
 95%|████████████████████████████████████████▉  | 61/64 [05:37<00:16,  5.54s/it][A
 97%|█████████████████████████████████████████▋ | 62/64 [05:43<00:11,  5.56s/it][A
 98%|██████████████████████████████████████████▎| 63/64 [05:48<00:05,  5.56s/it][A
100%|███████████████████████████████████████████| 64/64 [05:54<00:00,  5.53s/it][A

  0%|                                                    | 0/64 [00:00<?, ?it/s][A
  2%|▋                                           | 1/64 [00:05<05:35,  5.33

 84%|████████████████████████████████████▎      | 54/64 [04:58<00:55,  5.52s/it][A
 86%|████████████████████████████████████▉      | 55/64 [05:04<00:49,  5.54s/it][A
 88%|█████████████████████████████████████▋     | 56/64 [05:09<00:44,  5.53s/it][A
 89%|██████████████████████████████████████▎    | 57/64 [05:15<00:38,  5.53s/it][A
 91%|██████████████████████████████████████▉    | 58/64 [05:20<00:33,  5.53s/it][A
 92%|███████████████████████████████████████▋   | 59/64 [05:26<00:27,  5.53s/it][A
 94%|████████████████████████████████████████▎  | 60/64 [05:32<00:22,  5.53s/it][A
 95%|████████████████████████████████████████▉  | 61/64 [05:37<00:16,  5.53s/it][A
 97%|█████████████████████████████████████████▋ | 62/64 [05:43<00:11,  5.52s/it][A
 98%|██████████████████████████████████████████▎| 63/64 [05:48<00:05,  5.52s/it][A
100%|███████████████████████████████████████████| 64/64 [05:54<00:00,  5.53s/it][A

  0%|                                                    | 0/64 [00:00<?, ?

 83%|███████████████████████████████████▌       | 53/64 [04:53<01:00,  5.54s/it][A
 84%|████████████████████████████████████▎      | 54/64 [04:59<00:55,  5.54s/it][A
 86%|████████████████████████████████████▉      | 55/64 [05:04<00:49,  5.54s/it][A
 88%|█████████████████████████████████████▋     | 56/64 [05:10<00:44,  5.54s/it][A
 89%|██████████████████████████████████████▎    | 57/64 [05:15<00:38,  5.54s/it][A
 91%|██████████████████████████████████████▉    | 58/64 [05:21<00:33,  5.54s/it][A
 92%|███████████████████████████████████████▋   | 59/64 [05:26<00:27,  5.54s/it][A
 94%|████████████████████████████████████████▎  | 60/64 [05:32<00:22,  5.54s/it][A
 95%|████████████████████████████████████████▉  | 61/64 [05:37<00:16,  5.54s/it][A
 97%|█████████████████████████████████████████▋ | 62/64 [05:43<00:11,  5.55s/it][A
 98%|██████████████████████████████████████████▎| 63/64 [05:49<00:05,  5.61s/it][A
100%|███████████████████████████████████████████| 64/64 [05:54<00:00,  5.54s

<Figure size 432x288 with 0 Axes>