# GradCAM

## Imports

In [None]:
# Common imports
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm

# PyTorch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.models import resnet50

# GradCAM
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# My Imports
sys.path.append('../pytorch/')
from configs import Inputs
from train import get_classification_model
#from utils.data import RadiographSexDataset
from utils.data import FullRadiographSexDataset
from utils.augmentations import get_transforms

%load_ext autoreload
%autoreload 2

## Helper functions

In [None]:
def denorm(img, means, stdvs):
    """Undo per-channel normalization and return a channels-last image tensor.

    Inverts ``(x - mean) / std`` channel by channel, i.e. computes
    ``mean + std * img``.

    Args:
        img: normalized image tensor laid out as (C, H, W), or (1, C, H, W)
            for a batch that holds a single image.
        means: per-channel means that were subtracted during normalization.
        stdvs: per-channel standard deviations used during normalization.

    Returns:
        Tensor of shape (H, W, C) on ``img``'s device, clipped to [0, 1].
        ``show_cam_on_image`` expects a float image in that range, and float
        round-off in the inverse transform can push a pixel slightly above 1.
    """
    if img.dim() == 4:
        # Drop only the batch axis; a blanket squeeze() would also collapse a
        # single channel or a 1-pixel height/width.
        img = img.squeeze(0)
    mean = torch.as_tensor(means, dtype=img.dtype, device=img.device)
    std = torch.as_tensor(stdvs, dtype=img.dtype, device=img.device)
    # Channels-last so the per-channel statistics broadcast over the last axis.
    return (mean + std * img.permute(1, 2, 0)).clamp(0.0, 1.0)

## Transforms

In [None]:
from torchvision import transforms as T

# Per-channel normalization statistics (the widely used ImageNet mean / std).
means = [0.485, 0.456, 0.406]
stdvs = [0.229, 0.224, 0.225]

# Resize to 224x224, convert the PIL image to a float tensor, then normalize
# each channel. NOTE: the validation dataset below is built with
# get_transforms(...) rather than with this pipeline.
_transform_steps = [
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(means, stdvs),
]
transform = T.Compose(_transform_steps)

## Load the data

In [None]:
# Model and weights
inputs = Inputs(selected_model='efficientnet-b0')
model = get_classification_model(inputs.model_name, 2)
# map_location='cpu' lets the checkpoint load regardless of the device it was
# saved from; the model is moved to the GPU explicitly in the GradCAM cells.
checkpoint = torch.load(
    '/home/bernardo/github/sex-age-estimation/backup-bia/patch-1/pesos/checkpoint-efficientnet-b0-fold-2-max-acc.pth.tar',
    map_location='cpu',
)
model.load_state_dict(checkpoint['state_dict'])

# Walk the 30 split files and open every radiograph they list. None of the
# visible cells read the parsed fields or the opened image, so in practice this
# checks that each listed file exists and has a readable image header.
PAN_ROOT = '/home/bernardo/datasets/pan-radiographs'
for i in range(1, 31):
    filepath = os.path.join(PAN_ROOT, 'splits', f'{i:02d}.txt')
    with open(filepath) as f:
        for line in f:  # read each line of the split file
            entry = line.strip()  # drop the trailing newline
            if not entry:
                continue  # tolerate blank lines (e.g. a trailing empty line)
            fname = entry.split('/')[2]
            parts = fname.split('-')
            sex = parts[10]
            age = parts[-2][1:]
            months = parts[-1][1:3]

            # Separate 'pan' files (1st set) from the others, e.g. 'panreport' (2nd set).
            subset = '1st-set' if parts[0] == 'pan' else '2nd-set'
            fpath = os.path.join(PAN_ROOT, subset, 'images', fname)
            # Close each file right away instead of leaking one handle per image.
            with Image.open(fpath) as im:
                pass  # opening reads the header; pixel data is not used here

val_dataset = FullRadiographSexDataset(
    root_dir=inputs.DATASET_DIR,
    fold_nums=inputs.val_folds,
    # This did not work at first: passing `transform` from the previous cell
    # failed, and GradCAM only worked with get_transforms.
    transforms=get_transforms(inputs, subset='val'),
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
)


## GradCAM

In [None]:
targets = None  # None: the CAM explains the highest-scoring class

# target_layers = [model.layer4[-1]]  # last layer for a ResNet
# target_layers = [model.classifier[1]]  # last layer for EfficientNet-b0
target_layers = [model.features[-1]]  # last feature block for EfficientNet-b0 and -b1
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)

# Mean CAM over the validation set: (224, 224) for b0, (240, 240) for b1.
attention = torch.zeros((224, 224))
n_batches = len(val_dataloader)
model = model.cuda()
# Wrap the DataLoader itself (not the enumerate object) so tqdm knows the total.
for idx, (img, label) in enumerate(tqdm(val_dataloader)):
    img, label = img.cuda(), label.cuda()
    preds = model(img)
    prediction = torch.argmax(preds).item()
    ground_truth = label.item()

    grayscale_cam = cam(input_tensor=img, targets=targets)

    # The loader uses batch_size=1, so keep the only map in the batch.
    grayscale_cam = grayscale_cam[0, :]

    # Average all maps. Convert explicitly: `tensor += ndarray` is not
    # guaranteed to stay an in-place torch add and can rebind `attention`
    # to a NumPy array.
    attention += torch.from_numpy(grayscale_cam) / n_batches

    # Uncomment to stop after the first image (a quick look at one CAM
    # without scanning every image of the chosen validation fold):
    # break

# To get the mean attention map, run this cell WITHOUT the `break` above so it
# scans every validation image. The overlay is drawn once, after the loop, for
# the last processed image (calling imshow inside the loop stacked one image
# artist per sample on the same axes).
image = denorm(img.cpu(), inputs.MEAN, inputs.STDV).cpu().numpy()
# visualize = show_cam_on_image(image, attention, use_rgb=True)
visualize = show_cam_on_image(image, grayscale_cam, use_rgb=True)
plt.imshow(visualize)

## GradCAM médio

In [None]:
# Overlay the validation-set average GradCAM (`attention`) on `image`, the last
# validation image denormalized by the loop above. np.asarray accepts both a
# CPU torch tensor and an ndarray, so the mask is always the float32 NumPy
# array that show_cam_on_image works with.
visualize = show_cam_on_image(image, np.asarray(attention, dtype=np.float32), use_rgb=True)
plt.imshow(visualize)

## Pontos máximos de atenção do GradCAM

In [None]:
# Maximum-attention points of the averaged map (not tied to a specific image)
import numpy as np

# `attention` starts as a torch tensor, but accumulating NumPy maps into it may
# leave an ndarray; handle both instead of assuming `.cpu()` exists.
if isinstance(attention, torch.Tensor):
    media = attention.detach().cpu().numpy().squeeze()
else:
    media = np.asarray(attention).squeeze()
plt.imshow(media)
np.where(media == media.max())  # every (row, col) that reaches the maximum

In [None]:
# Class predicted by the network for the last validation image: argmax over
# every score in `preds`, which for a single-image batch is the class index.
prediction = int(preds.argmax())
prediction

In [None]:
# Ground-truth label of the last validation image
# Male = 1
label

## Show GradCAM on the last image

In [None]:
# Overlay the CAM of the last processed validation image on that image.
# Denormalize with inputs.MEAN / inputs.STDV, as every other cell does, rather
# than the ImageNet `means` / `stdvs`, which only feed the unused `transform`
# pipeline. NOTE(review): presumably get_transforms(inputs, ...) normalizes
# with inputs.MEAN / inputs.STDV — confirm.
last_image = denorm(img[0].cpu(), inputs.MEAN, inputs.STDV).numpy()
visualize = show_cam_on_image(last_image, grayscale_cam, use_rgb=True)
plt.imshow(visualize)

## Outras análises do GradCAM

In [None]:
def max_attention(att):
    """Return the (x, y) pixel coordinates of the maximum of a 2-D attention map.

    Args:
        att: 2-D attention map as a torch tensor (CPU or CUDA), a NumPy array,
            or any nested sequence NumPy can convert.

    Returns:
        Tuple ``(x_max, y_max)`` = (column, row) of the first maximum in
        row-major order, i.e. the order matplotlib's ``plot`` expects.
    """
    if isinstance(att, torch.Tensor):
        # detach() in case the map tracks gradients; cpu() so CUDA maps work.
        att = att.detach().cpu().numpy()
    att = np.asarray(att)

    # argmax works on the flattened array; unravel it back to (row, col).
    y_max, x_max = np.unravel_index(np.argmax(att), att.shape)

    return x_max, y_max


# Used below to scale plot alpha. NOTE(review): val_dataloader above is built
# with batch_size=1, not 64 — confirm this value is still wanted.
batch_size = 64

In [None]:
# available_cams = [GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad]
available_cams = [GradCAM, GradCAMPlusPlus, XGradCAM, EigenCAM]

attention_maps = []  # one averaged (224, 224) map per CAM method
all_max_attention_points = []  # per CAM method: (x, y) maximum of every processed map
# Use only the first 1/20 of the validation batches to keep this cell fast,
# but at least one batch so a small validation set cannot give size == 0.
size = max(1, len(val_dataloader) // 20)
# Run on whatever device the model currently lives on, and tell the CAM so.
device = next(model.parameters()).device
for selected_cam in available_cams:
    print('.', end='')  # progress: one dot per CAM method
    cam = selected_cam(model=model, target_layers=target_layers, use_cuda=device.type == 'cuda')

    max_attention_points = []
    attention = torch.zeros((224, 224))
    n_maps = 0

    for idx, (imgs, labels) in enumerate(val_dataloader):
        imgs = imgs.to(device)

        # You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
        grayscale_cams = cam(input_tensor=imgs, targets=targets)

        # One map per image in the batch.
        for grayscale_cam in grayscale_cams:
            # max attention points
            max_attention_points.append(max_attention(grayscale_cam))

            # Sum here and divide once below.
            attention += torch.from_numpy(grayscale_cam)
            n_maps += 1

        if idx == size - 1:
            break

    # Average over the maps actually summed. The old `/ size / batch_size`
    # assumed 64 images per batch, but val_dataloader yields one image per batch.
    if n_maps:
        attention /= n_maps

    attention_maps.append(attention)
    all_max_attention_points.append(max_attention_points)

## Pontos de atenção

In [None]:
n = len(attention_maps)
# One square panel per CAM method, side by side. The old figsize=(20, n*20)
# made a single-row figure n times taller than wide. squeeze=False keeps
# `axes` indexable even when n == 1.
fig, axes = plt.subplots(1, n, figsize=(5 * n, 5), squeeze=False)
axes = axes[0]
for i, (attention, max_attention_points) in enumerate(zip(attention_maps, all_max_attention_points)):
    Xs = [x for x, y in max_attention_points]
    Ys = [y for x, y in max_attention_points]
    axes[i].imshow(attention)
    if max_attention_points:
        # Fade the points as their number grows; matplotlib rejects alpha > 1.
        axes[i].plot(Xs, Ys, 'ro', alpha=min(1.0, 5 / len(max_attention_points)))

In [None]:
# Scatter the per-image attention maxima in `max_attention_points` over the
# last validation image, with a 32-pixel grid drawn on top.
image = denorm(img.cpu(), inputs.MEAN, inputs.STDV).cpu().numpy()
height, width = image.shape[:2]

plt.figure(figsize=(10, 10))

plt.xlim(0, width)
plt.ylim(height, 0)  # image convention: y grows downwards

X = [p[0] for p in max_attention_points]
Y = [p[1] for p in max_attention_points]

plt.imshow(image)
# plt.imshow(attention.detach().numpy())

# 32-pixel cells (7 x 7 on a 224-pixel image); presumably matching the
# backbone's final feature-map resolution — TODO confirm.
cell = 32
for xx in range(0, width, cell):
    plt.plot([xx, xx], [0, height], 'g--')

for yy in range(0, height, cell):
    plt.plot([0, width], [yy, yy], 'b--')

if X:
    # Scale by the number of points actually plotted; the old
    # 10/size/batch_size assumed 64 images per batch, making points ~64x too faint.
    plt.plot(X, Y, 'ro', alpha=min(1.0, 10 / len(X)))

## EigenCAM

In [None]:
target_layers = [model.features[-1]]  # last feature block for EfficientNet-b0 and -b1
cam = EigenCAM(model=model, target_layers=target_layers, use_cuda=True)

# Mean CAM over the validation set: (224, 224) for b0, (240, 240) for b1.
attention = torch.zeros((224, 224))
n_batches = len(val_dataloader)
model = model.cuda()
# Wrap the DataLoader itself (not the enumerate object) so tqdm knows the total.
for idx, (img, label) in enumerate(tqdm(val_dataloader)):
    img, label = img.cuda(), label.cuda()
    preds = model(img)
    prediction = torch.argmax(preds).item()
    ground_truth = label.item()

    grayscale_cam = cam(input_tensor=img, targets=None)

    # The loader uses batch_size=1, so keep the only map in the batch.
    grayscale_cam = grayscale_cam[0, :]

    # Average all maps. Convert explicitly: `tensor += ndarray` is not
    # guaranteed to stay an in-place torch add and can rebind `attention`
    # to a NumPy array.
    attention += torch.from_numpy(grayscale_cam) / n_batches

    # Uncomment to stop after the first image (a quick look at one CAM
    # without scanning every image of the chosen validation fold):
    # break

# To get the mean attention map, run this cell WITHOUT the `break` above so it
# scans every validation image. The overlay is drawn once, after the loop, for
# the last processed image (calling imshow inside the loop stacked one image
# artist per sample on the same axes).
image = denorm(img.cpu(), inputs.MEAN, inputs.STDV).cpu().numpy()
# visualize = show_cam_on_image(image, attention, use_rgb=True)
visualize = show_cam_on_image(image, grayscale_cam, use_rgb=True)
plt.imshow(visualize)

## EigenCAM médio

In [None]:
# Overlay the validation-set average EigenCAM (`attention`) on `image`, the
# last validation image denormalized by the loop above. np.asarray accepts both
# a CPU torch tensor and an ndarray, so the mask is always the float32 NumPy
# array that show_cam_on_image works with.
visualize = show_cam_on_image(image, np.asarray(attention, dtype=np.float32), use_rgb=True)
plt.imshow(visualize)

## Pontos máximos de atenção do EigenCAM

In [None]:
# Maximum-attention points of the averaged EigenCAM map (not tied to a specific image)
import numpy as np

# `attention` starts as a torch tensor, but accumulating NumPy maps into it may
# leave an ndarray; handle both instead of assuming `.cpu()` exists.
if isinstance(attention, torch.Tensor):
    media = attention.detach().cpu().numpy().squeeze()
else:
    media = np.asarray(attention).squeeze()
plt.imshow(media)
np.where(media == media.max())  # every (row, col) that reaches the maximum

## GradCAMPlusPlus

In [None]:
target_layers = [model.features[-1]]  # last feature block for EfficientNet-b0 and -b1
cam = GradCAMPlusPlus(model=model, target_layers=target_layers, use_cuda=True)

# Mean CAM over the validation set: (224, 224) for b0, (240, 240) for b1.
attention = torch.zeros((224, 224))
n_batches = len(val_dataloader)
model = model.cuda()
# Wrap the DataLoader itself (not the enumerate object) so tqdm knows the total.
for idx, (img, label) in enumerate(tqdm(val_dataloader)):
    img, label = img.cuda(), label.cuda()
    preds = model(img)
    prediction = torch.argmax(preds).item()
    ground_truth = label.item()

    grayscale_cam = cam(input_tensor=img, targets=None)

    # The loader uses batch_size=1, so keep the only map in the batch.
    grayscale_cam = grayscale_cam[0, :]

    # Average all maps. Convert explicitly: `tensor += ndarray` is not
    # guaranteed to stay an in-place torch add and can rebind `attention`
    # to a NumPy array.
    attention += torch.from_numpy(grayscale_cam) / n_batches

    # Uncomment to stop after the first image (a quick look at one CAM
    # without scanning every image of the chosen validation fold):
    # break

# To get the mean attention map, run this cell WITHOUT the `break` above so it
# scans every validation image. The overlay is drawn once, after the loop, for
# the last processed image (calling imshow inside the loop stacked one image
# artist per sample on the same axes).
image = denorm(img.cpu(), inputs.MEAN, inputs.STDV).cpu().numpy()
# visualize = show_cam_on_image(image, attention, use_rgb=True)
visualize = show_cam_on_image(image, grayscale_cam, use_rgb=True)
plt.imshow(visualize)

## GradCAMPlusPlus médio

In [None]:
# Overlay the validation-set average GradCAM++ (`attention`) on `image`, the
# last validation image denormalized by the loop above. np.asarray accepts both
# a CPU torch tensor and an ndarray, so the mask is always the float32 NumPy
# array that show_cam_on_image works with.
visualize = show_cam_on_image(image, np.asarray(attention, dtype=np.float32), use_rgb=True)
plt.imshow(visualize)

## Pontos máximos de atenção do GradCAMPlusPlus

In [None]:
# Maximum-attention points of the averaged GradCAM++ map (not tied to a specific image)
import numpy as np

# `attention` starts as a torch tensor, but accumulating NumPy maps into it may
# leave an ndarray; handle both instead of assuming `.cpu()` exists.
if isinstance(attention, torch.Tensor):
    media = attention.detach().cpu().numpy().squeeze()
else:
    media = np.asarray(attention).squeeze()
plt.imshow(media)
np.where(media == media.max())  # every (row, col) that reaches the maximum

## Aplicando o retângulo preto

In [None]:
# Channels-last, denormalized copy of the last validation image, presumably
# the base for the black-rectangle experiment below. Use the `inputs` instance
# like every other cell: instance lookup also finds class attributes, so this
# works whether MEAN/STDV are set on the class or per instance.
img_orig = denorm(img[0].cpu(), inputs.MEAN, inputs.STDV)
plt.imshow(img_orig)

In [None]:
#img_orig[100:150, 62:162, :] = 0