In [None]:
import os
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
from PIL import Image
import glob
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pathlib import Path
import math
from ipywidgets import interact

In [None]:
class _BaseWrapper:
    """Base wrapper that runs a classifier and back-propagates a one-hot signal.

    `forward` caches the logits of `model`; `backward` then sends a gradient of 1
    into the logit of one chosen class per sample, so hooks installed by
    subclasses (whose handles are kept in `self.handlers`) can observe the
    resulting activations/gradients. Subclasses implement `generate`.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.handlers = []  # hook handles registered by subclasses; see remove_hook()

    def forward(self, images):
        """Run `model` on `images` and return softmax probabilities sorted along dim 1.

        Returns the `(values, indices)` result of `Tensor.sort(dim=1, descending=True)`.
        Also stores `image_shape` (every dim of `images` after the first two),
        `logits` and `probs` on the instance for later use.
        """
        self.image_shape = images.shape[2:]
        self.logits = self.model(images)
        self.probs = F.softmax(self.logits, dim=1)
        return self.probs.sort(dim=1, descending=True)

    def backward(self, ids):
        """Back-propagate from the logit of class `ids[i]` for every sample `i`.

        Only the selected logits receive gradient, so the influence of the
        feature maps on the class of interest can be inspected.

        Args:
            ids: int64 tensor with one class index per sample, shaped (B,) or (B, 1).
        """
        num_classes = self.logits.shape[-1]
        # reshape(-1) lets (B,) and (B, 1) ids both give a (B, num_classes) one-hot,
        # including the B == 1 case.
        one_hot = F.one_hot(ids.reshape(-1), num_classes)
        # F.one_hot returns int64; match the logits' dtype and device explicitly.
        one_hot = one_hot.to(self.logits)
        self.model.zero_grad()
        # retain_graph keeps the graph so backward can be called again (e.g. another class).
        self.logits.backward(gradient=one_hot, retain_graph=True)

    def remove_hook(self):
        """Remove every hook handle stored in `self.handlers` and clear the list."""
        for handle in self.handlers:
            handle.remove()
        self.handlers.clear()

    def generate(self):
        raise NotImplementedError


class GradCAM(_BaseWrapper):
    """Grad-CAM: feature maps weighted by their spatially averaged gradients.

    On construction a forward hook (caching the module output) and a backward
    hook (caching `grad_out[0]`) are registered on every named module of
    `model`, or only on modules whose name is in `layers` when it is given.
    Usage: `forward(images)`, then `backward(ids)`, then `generate(layer_name)`.
    """

    def __init__(self, model, layers=None):
        super().__init__(model)
        self.feature_map = {}  # module name -> detached module output from the last forward
        self.grad_map = {}  # module name -> detached grad_out[0] from the last backward
        self.layers = layers

        def save_fmaps(key):
            def forward_hook(module, input, output):
                self.feature_map[key] = output.detach()

            return forward_hook

        def save_grads(key):
            def backward_hook(module, grad_in, grad_out):
                self.grad_map[key] = grad_out[0].detach()

            return backward_hook

        for name, module in self.model.named_modules():
            if self.layers is None or name in self.layers:
                self.handlers.append(module.register_forward_hook(save_fmaps(name)))
                # NOTE(review): register_backward_hook is deprecated in favour of
                # register_full_backward_hook. The full variant can fail on in-place
                # modules (e.g. ReLU(inplace=True)), which are hooked when `layers`
                # is None, so the legacy hook is kept; verify before switching.
                self.handlers.append(module.register_backward_hook(save_grads(name)))

    def findLayers(self, layers, target_layer):
        """Return `layers[target_layer]`; raise ValueError if that layer was never recorded."""
        if target_layer in layers:
            return layers[target_layer]
        raise ValueError(f"{target_layer} not exists")

    def generate(self, target_layer, eps=1e-8):
        """Build min-max normalised Grad-CAM maps for `target_layer`.

        Args:
            target_layer: name of a hooked module (as given by `named_modules()`).
            eps: lower bound for the normalisation denominator, so constant maps
                become all zeros instead of NaN.

        Returns:
            Tensor shaped (B, 1, *image_shape) with values in [0, 1].

        Raises:
            ValueError: if no activation/gradient was recorded for `target_layer`.
        """
        feature_maps = self.findLayers(self.feature_map, target_layer)
        grad_maps = self.findLayers(self.grad_map, target_layer)
        # One weight per channel: the gradient averaged over the spatial dims.
        weights = F.adaptive_avg_pool2d(grad_maps, 1)

        grad_cam = (feature_maps * weights).sum(dim=1, keepdim=True)
        grad_cam = F.relu(grad_cam)
        # Upsample to the input resolution recorded by forward().
        grad_cam = F.interpolate(grad_cam, self.image_shape, mode="bilinear", align_corners=False)
        B, C, H, W = grad_cam.shape

        # Per-sample min-max normalisation. Clamping the denominator avoids 0/0 = NaN
        # when a map is constant (e.g. everything was zeroed by the ReLU above).
        flat = grad_cam.reshape(B, -1)
        flat = flat - flat.min(dim=1, keepdim=True)[0]
        flat = flat / flat.max(dim=1, keepdim=True)[0].clamp(min=eps)
        return flat.view(B, C, H, W)

In [None]:
def readImages(img_path):
    """Return the sorted paths of all entries in `img_path` matching `*.*`.

    `glob` yields entries in arbitrary, filesystem-dependent order. Sorting keeps
    batch index `i` mapped to the same file on every run, so the interactive
    viewer's `idx` stays stable.
    """
    return sorted(glob.glob(os.path.join(img_path, "*.*")))
    

def load_images(img_path, transform):
    """Load every image found by `readImages(img_path)` and transform it.

    Args:
        img_path: directory to scan for images.
        transform: callable applied to each HxWx3 uint8 numpy image. It presumably
            returns equally-shaped tensors, since the results are stacked. When None,
            `transforms.ToTensor()` is used.

    Returns:
        (images, raw_images): `images` is the transformed images stacked along a new
        dim 0 (a batch of one for a single file); `raw_images` is the list of
        untransformed numpy arrays in the same order.

    Raises:
        ValueError: if `img_path` contains no matching files.
    """
    img_list = readImages(img_path)
    if not img_list:
        raise ValueError(f"no images found in {img_path!r}")
    if transform is None:
        transform = transforms.ToTensor()

    images = []
    raw_images = []
    for fname in img_list:
        # The context manager closes the file. convert("RGB") gives grayscale, palette
        # and RGBA files the 3 channels that normalisation and the viewer expect.
        with Image.open(fname) as pil_img:
            img = np.array(pil_img.convert("RGB"))
        images.append(transform(img))
        raw_images.append(img)

    # torch.stack also handles a single image, producing a batch of size 1.
    return torch.stack(images), raw_images

def gradCam(model, img_path, transform=None, device="cuda", target_class=0, target_layers :list =[]):
    """Compute Grad-CAM maps for every image in `img_path`.

    Args:
        model: classifier to explain; moved to `device` and put in eval mode.
        img_path: folder containing the images.
        transform: per-image transform forwarded to `load_images`.
        device: device passed to `.to()` for the model and the image batch.
        target_class: class index whose logit is back-propagated for every image.
        target_layers: names (from `model.named_modules()`) of layers to visualise.

    Returns:
        (check, raw_images): `check` holds one Grad-CAM tensor of shape
        (B, 1, H, W) per entry of `target_layers`, in the same order;
        `raw_images` is the list of untransformed numpy images.
    """
    model.to(device)
    model.eval()

    images, raw_images = load_images(img_path, transform)
    # images: batch tensor stacked along dim 0; raw_images: list of numpy image arrays
    images = images.to(device)

    # Hook only the requested layers, and always remove the hooks afterwards, so
    # repeated calls (e.g. re-running a cell) do not keep stacking hooks on `model`.
    grad_cam = GradCAM(model, layers=target_layers)
    try:
        grad_cam.forward(images)
        target_ids = torch.full((len(images), 1), target_class, dtype=torch.long, device=device)
        grad_cam.backward(target_ids)

        # One map per requested layer.
        check = []
        for target_layer in target_layers:
            print(f"generating Grad-CAM : {target_layer}")
            check.append(grad_cam.generate(target_layer))
    finally:
        for handle in grad_cam.handlers:
            handle.remove()

    return check, raw_images

In [None]:
def gauss(x,a,b,c):
    """Element-wise scaled Gaussian bump: a * exp(-(x - b)^2 / (2 * c^2)) for tensor `x`."""
    spread = 2 * c * c
    return a * torch.exp(-((x - b) ** 2) / spread)

def colorize(x):
    """Turn a single-channel heat-map into an RGB image using Gaussian colour ramps.

    Each colour channel is built from `gauss` bumps over the input value (blue
    centred at 0.2, green at 0.5, red at 0.6 and 0.8). Every channel is capped at 1.

    Args:
        x: tensor shaped (H, W), (1, H, W) or (N, 1, H, W); values presumably in [0, 1].

    Returns:
        CPU tensor (default dtype) shaped (3, H, W) for 2-D/3-D input, or
        (N, 3, H, W) for 4-D input.

    Raises:
        ValueError: if `x` is not 2-, 3- or 4-D, or a 4-D `x` has more than one channel.
    """
    if x.dim() == 2:
        x = torch.unsqueeze(x, 0)
    if x.dim() == 3:
        cl = torch.zeros([3, x.size(1), x.size(2)])
        cl[0] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3)
        cl[1] = gauss(x, 1, .5, .3)
        cl[2] = gauss(x, 1, .2, .3)
    elif x.dim() == 4:
        if x.size(1) != 1:
            raise ValueError(f"colorize expects an (N, 1, H, W) tensor, got shape {tuple(x.shape)}")
        # Drop the singleton channel so each (N, H, W) result fits cl[:, k];
        # assigning an (N, 1, H, W) tensor there only broadcasts when N == 1.
        x = x[:, 0]
        cl = torch.zeros([x.size(0), 3, x.size(1), x.size(2)])
        cl[:, 0] = gauss(x, .5, .6, .2) + gauss(x, 1, .8, .3)
        cl[:, 1] = gauss(x, 1, .5, .3)
        cl[:, 2] = gauss(x, 1, .2, .3)
    else:
        raise ValueError(f"colorize expects a 2-D, 3-D or 4-D tensor, got {x.dim()}-D")
    # Red is the sum of two bumps and can exceed 1; cap every channel in both branches.
    cl[cl.gt(1)] = 1
    return cl

In [None]:
# Template: replace every placeholder below before running this cell.
model = "모델명"                          # model to inspect (must support .to(), .eval(), named_modules())
src_path = "이미지가 있는 폴더명"           # folder containing the images
transform_ = "이미지 transform 메소드"      # transform applied to each image
target_layers = ["확인하고 싶은 layer"]     # module names to visualise
target_class = "확인하고 싶은 class"        # class index to explain
device = "cpu 또는 cuda"                   # "cpu" or "cuda"

check, original = gradCam(
    model,
    src_path,
    transform=transform_,
    device=device,
    target_class=target_class,
    target_layers=target_layers,
)
# check (list): one Grad-CAM map tensor per target layer
# original (list): the raw images as numpy arrays

In [None]:
# No need to modify this cell.
@interact(idx = (0, check[0].shape[0]-1))
def showImg(idx):
    """Plot image `idx`, the Grad-CAM of the first target layer, and their overlay."""
    heatmap = colorize(check[0][idx].squeeze().cpu().detach())
    height, width = heatmap.shape[-2], heatmap.shape[-1]
    heatmap_hwc = heatmap.permute(1, 2, 0)

    # Bring the raw HxWxC image to the CAM resolution so the overlay lines up.
    raw_chw = torch.tensor(original[idx]).permute(2, 0, 1)
    original_img = transforms.Resize((height, width))(raw_chw).permute(1, 2, 0)

    fig, axes = plt.subplots(1, 3, figsize=(18,6))
    fig.set_facecolor("black")
    plt.suptitle(f"GradCAM class : {target_class}", c="white", fontsize=30, y=1.1)

    ax_orig, ax_cam, ax_mix = axes
    ax_orig.imshow(original_img)
    ax_cam.imshow(heatmap_hwc)
    ax_mix.imshow(original_img)
    ax_mix.imshow(heatmap_hwc, alpha=0.5)

    titles = ("original_image", f'{target_layers[0]}', "mixed_image")
    for ax, title in zip(axes, titles):
        ax.set_title(title, c="white")
    plt.grid(False)

    # Hide ticks and all four spines on every panel.
    for ax in axes:
        ax.set_xticks([])
        ax.set_yticks([])
        for loc in ["top","bottom","left", "right"]:
            ax.spines[loc].set_visible(False)

    plt.show()

# Example code

In [None]:
# Colab only: mount Google Drive at /content/gdrive so the dataset path below resolves.
import google.colab.drive as drive

drive.mount("/content/gdrive")

In [None]:
model = torchvision.models.resnet18(pretrained=True)
src_path = "/content/gdrive/MyDrive/test/dataset/test"
# torchvision's ImageNet-pretrained classifiers expect inputs normalised with the
# ImageNet mean/std below; other statistics shift the inputs the weights were trained on.
transform_ = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((512,512)),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
target_layers = ["layer4"]
target_class = 282  # ImageNet class index
device = "cpu"

# Pass the variables above (not literals) so the viewer's titles match what was computed.
check, original = gradCam(model, src_path, transform_, device, target_class, target_layers=target_layers)


In [None]:
# No need to modify this cell.
@interact(idx = (0, check[0].shape[0]-1))
def showImg(idx):
    """Plot image `idx`, the Grad-CAM of the first target layer, and their overlay."""
    heatmap = colorize(check[0][idx].squeeze().cpu().detach())
    height, width = heatmap.shape[-2], heatmap.shape[-1]
    heatmap_hwc = heatmap.permute(1, 2, 0)

    # Bring the raw HxWxC image to the CAM resolution so the overlay lines up.
    raw_chw = torch.tensor(original[idx]).permute(2, 0, 1)
    original_img = transforms.Resize((height, width))(raw_chw).permute(1, 2, 0)

    fig, axes = plt.subplots(1, 3, figsize=(18,6))
    fig.set_facecolor("black")
    plt.suptitle(f"GradCAM class : {target_class}", c="white", fontsize=30, y=1.1)

    ax_orig, ax_cam, ax_mix = axes
    ax_orig.imshow(original_img)
    ax_cam.imshow(heatmap_hwc)
    ax_mix.imshow(original_img)
    ax_mix.imshow(heatmap_hwc, alpha=0.5)

    titles = ("original_image", f'{target_layers[0]}', "mixed_image")
    for ax, title in zip(axes, titles):
        ax.set_title(title, c="white")
    plt.grid(False)

    # Hide ticks and all four spines on every panel.
    for ax in axes:
        ax.set_xticks([])
        ax.set_yticks([])
        for loc in ["top","bottom","left", "right"]:
            ax.spines[loc].set_visible(False)

    plt.show()