Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement SegGradCAM based on pytorch-grad-cam #94

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
157 changes: 157 additions & 0 deletions pytorch_grad_cam/utils/roi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import numpy as np
import torch
import matplotlib.pyplot as plt
from skimage import measure
import collections

def gui_get_point(image, i=None, j=None):
    """Show ``image`` and let the user pick one pixel with a mouse click.

    Args:
        image: Array-like exposing ``shape``. Inputs with three dimensions
            are drawn with ``imshow``; anything else with ``matshow``.
        i: Row of the initial marker; also returned unchanged when nothing
            is picked.
        j: Column of the initial marker; also returned unchanged when
            nothing is picked.

    Returns:
        ``(i, j)``: the clicked row and column converted with ``int``, or the
        values passed in when ``ginput`` reports no click.
    """
    fig = plt.figure('Input Pick An Point')
    scale = np.mean(image.shape[:2])
    if len(image.shape) == 3:
        plt.imshow(image)
    else:
        # matshow opens a *new* figure unless told which one to draw into;
        # bind it to ``fig`` so plt.close(fig) below closes the picker window
        # instead of leaving it open behind an empty, closed figure.
        plt.matshow(image, fignum=fig.number)

    # Only draw the initial marker when a starting position was supplied.
    marker = None
    if i is not None and j is not None:
        marker = plt.scatter(j, i, c='r', s=scale, marker='x')

    picked = plt.ginput(1)
    if picked:  # covers both ``None`` and an empty list
        # ginput reports (x, y), i.e. (column, row).
        j, i = (int(coord) for coord in picked[0])
        if marker is not None:
            marker.remove()
        plt.scatter(j, i, c='r', s=scale, marker='x')
    plt.close(fig)
    return i, j

class BaseROI:
    """Region of interest expressed as a mask multiplied into a network output.

    The default mask is a one-element tensor holding 1, so ``apply_roi``
    leaves the output values unchanged (every pixel is kept).
    """

    def __init__(self, image = None):
        self.image = image
        # One-element mask: broadcasting keeps every output value as-is.
        self.roi = torch.Tensor([1])
        self.fullroi = None
        self.i, self.j = None, None

    def setROIij(self):
        """Store the row and column indices of mask entries equal to 1.

        Needs ``self.roi`` to have at least two dimensions; with the default
        one-element mask the column lookup raises ``IndexError``.
        """
        print(f'Shape of ROI:{self.roi.shape}')
        selected = self.roi == 1
        self.i = np.where(selected)[0]
        self.j = np.where(selected)[1]
        print(f'Lengths of i and j index lists: {len(self.i)}, {len(self.j)}')

    def meshgrid(self):
        """Return ``np.meshgrid`` coordinate grids spanning the stored image.

        The x axis uses ``image.shape[1]`` samples in ``[0, image.shape[1]]``
        and the y axis ``image.shape[0]`` samples in ``[0, image.shape[0]]``.
        """
        n_rows = self.image.shape[0]
        n_cols = self.image.shape[1]
        return np.meshgrid(np.linspace(0, n_cols, n_cols),
                           np.linspace(0, n_rows, n_rows))

    def apply_roi(self, output):
        """Multiply ``output`` by the mask, moved to ``output``'s device."""
        mask = self.roi.to(output.device)
        return mask * output


class PixelROI(BaseROI):
    """ROI that selects exactly one pixel.

    The mask is sized from ``image.shape[-3]`` x ``image.shape[-2]``; the demo
    script passes an (H, W, 3) array, for which that is (H, W).
    """

    def __init__(self, i, j, image):
        """Create a mask that is 1 at row ``i``, column ``j`` and 0 elsewhere."""
        # Run the base initialiser so every attribute BaseROI defines
        # (notably ``fullroi``) also exists on PixelROI instances.
        super().__init__(image)
        self.roi = torch.zeros((image.shape[-3], image.shape[-2]))
        self.roi[i, j] = 1
        self.i = i
        self.j = j

    def pickPixel(self):
        """Let the user pick the pixel interactively and rebuild the mask.

        The current ``(i, j)`` is shown as the starting marker and is kept
        when the user does not click.
        """
        self.i, self.j = gui_get_point(self.image, self.i, self.j)
        # Reset in place so objects already holding this tensor see the update.
        self.roi.zero_()
        self.roi[self.i, self.j] = 1
        print(f'ROI Point: {self.i},{self.j}')

def filter_connected_components(values, counts, exclude):
    """Drop every entry whose value equals ``exclude``.

    Args:
        values: Sequence of labels.
        counts: Sequence parallel to ``values`` (one count per label).
        exclude: Label to leave out.

    Returns:
        Three lists for the retained entries: their positions in ``values``
        (each wrapped in a one-element list), their values, and their counts.
    """
    kept = [pos for pos, label in enumerate(values) if label != exclude]
    return (
        [[pos] for pos in kept],
        [values[pos] for pos in kept],
        [counts[pos] for pos in kept],
    )

class ClassROI(BaseROI):
    """ROI covering all pixels of one class, or one connected component.

    ``pred`` is a label map compared element-wise against a class id; the
    resulting mask is reshaped to ``image.shape[-3]`` x ``image.shape[-2]``.
    """

    def __init__(self, image, pred, cls, background=0):
        """Build the mask of pixels where ``pred == cls``.

        Args:
            image: Image the ROI refers to; only its shape is used here.
            pred: Label map (ground truth or prediction).
            cls: Class id to select.
            background: Label skipped by the largest/smallest component
                searches.
        """
        # Run the base initialiser so ``fullroi``, ``i`` and ``j`` exist here too.
        super().__init__(image)
        self.pred = pred
        self.cls = cls
        self.background = background
        self.roi = self._class_mask(cls)
        print(f'Valid ROI pixels: {torch.sum(self.roi).numpy()} of class {self.cls}')

    def _class_mask(self, cls):
        # Boolean mask of the pixels whose label equals ``cls``.
        return (self.pred == cls).reshape(self.image.shape[-3], self.image.shape[-2])

    def connectedComponents(self):
        """Label the connected components of ``pred``.

        Returns:
            ``(all_labels, values, counts)``: the component label image and
            the unique component labels with their pixel counts.
        """
        # NOTE(review): measure.label is called with its default background,
        # while ``self.background`` is later compared against *component*
        # labels -- confirm this is intended when background != 0.
        all_labels = measure.label(self.pred)
        (values, counts) = np.unique(all_labels, return_counts=True)
        print("connectedComponents values, counts: ", values, counts)
        return all_labels, values, counts

    def _select_component(self, chooser, message):
        # Shared body of largestComponent / smallestComponent: ``chooser``
        # (np.argmax or np.argmin) picks among the non-background counts.
        all_labels, values, counts = self.connectedComponents()
        selected_indices, selected_values, selected_counts = filter_connected_components(
            values, counts, self.background)
        if not selected_counts:
            raise ValueError('No connected component left after excluding the background')
        position = chooser(selected_counts)
        ind = selected_indices[position]
        print(message, ind)
        # Compare against the component's label *value*. Its index into
        # ``values`` only coincides with the label when label 0 is present,
        # so indexing by position would select the wrong component otherwise.
        self.roi = torch.Tensor(all_labels == selected_values[position])
        print(f'Valid ROI pixels: {torch.sum(self.roi).numpy()} of class {values[ind]}')

    def largestComponent(self):
        """Set the ROI to the non-background component with the most pixels.

        Raises:
            ValueError: If no component remains after excluding the background.
        """
        self._select_component(np.argmax, "largestComponent argmax: ")

    def smallestComponent(self):
        """Set the ROI to the non-background component with the fewest pixels.

        Raises:
            ValueError: If no component remains after excluding the background.
        """
        self._select_component(np.argmin, "smallestComponent argmin: ")

    def pickClass(self):
        """Let the user click a pixel and select every pixel of its class."""
        i, j = gui_get_point(self.pred)
        self.cls = self.pred[i, j]
        self.roi = self._class_mask(self.cls)
        print(f'Valid ROI pixels: {torch.sum(self.roi).numpy()} of class {self.cls}')

    def pickComponentClass(self):
        """Let the user click a pixel and select its connected component.

        Afterwards ``self.cls`` holds the component label, not a class id.
        """
        i, j = gui_get_point(self.pred)
        all_labels, values, counts = self.connectedComponents()
        self.cls = all_labels[i, j]
        self.roi = torch.Tensor(all_labels == self.cls).reshape(
            self.image.shape[-3], self.image.shape[-2])
        print(f'Valid ROI pixels: {torch.sum(self.roi).numpy()} of class {self.cls}')

# Get tensor from output of network. Some segmentation networks return more than one tensor.
def get_output_tensor(output, verbose=False):
    """Return a single tensor from a model output.

    Args:
        output: A ``torch.Tensor`` (returned as is), a dict (its first value
            in iteration order is returned), or a list/tuple (its first
            element is returned).
        verbose: When true, print which entry of a container was selected.

    Returns:
        The selected entry of ``output``.

    Raises:
        RuntimeError: If ``output`` is none of the supported types.
    """
    if isinstance(output, torch.Tensor):
        return output
    # ``dict`` also covers OrderedDict; plain dicts are accepted as well.
    if isinstance(output, dict):
        k = next(iter(output))
        if verbose:
            print(f'Select "{k}" from dict {output.keys()}')
        return output[k]
    # Tuples are handled the same way as lists.
    if isinstance(output, (list, tuple)):
        if verbose:
            print(f'Select "[0]" from list(n={len(output)})')
        return output[0]
    raise RuntimeError(f'Unknown type {type(output)}')

class SegModel(torch.nn.Module):
    """Wrap a segmentation model so it yields one score per channel.

    The wrapped model's output map is optionally weighted by an ROI and then
    summed over its last two dimensions.
    """

    def __init__(self, model, roi=None):
        """
        Args:
            model: Wrapped network; its output is passed through
                ``get_output_tensor``.
            roi: Optional object exposing ``apply_roi(output)``; ``None``
                skips ROI weighting.
        """
        super().__init__()
        self.model = model
        self.roi = roi

    def forward(self, x):
        """Return the ROI-weighted output summed over the last two dims."""
        # The model may return several tensors; keep a single one.
        output = get_output_tensor(self.model(x))

        # Single-channel (binary, sigmoid-style) output: expand it to two
        # channels. The maps have to be concatenated into one tensor first,
        # because torch.log_softmax does not accept a Python list.
        if output.shape[-3] == 1:
            two_channel = torch.cat([-output, output], dim=-3)
            output = torch.log_softmax(two_channel, dim=-3)

        if self.roi is not None:
            output = self.roi.apply_roi(output)
        return torch.sum(output, dim=(-2, -1))

6 changes: 4 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
dataclasses==0.8
dataclasses
dicom-factory==0.0.3
numpy==1.19.5
Pillow==8.1.1
torch==1.7.1
torchvision==0.8.2
typing-extensions==3.7.4.3
scikit-image
ttach
tqdm
tqdm
opencv-python
156 changes: 156 additions & 0 deletions segcam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import argparse
import cv2
import numpy as np
import torch
import torch.nn
from torchvision import models
import matplotlib.pyplot as plt

from pytorch_grad_cam import GradCAM, \
ScoreCAM, \
GradCAMPlusPlus, \
AblationCAM, \
XGradCAM, \
EigenCAM, \
EigenGradCAM

from pytorch_grad_cam.utils.roi import BaseROI, \
PixelROI, \
ClassROI, \
get_output_tensor, \
SegModel

from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, \
deprocess_image, \
preprocess_image


def get_args():
    """Parse the command-line options of the SegGradCAM demo.

    Returns:
        The parsed ``argparse.Namespace``. ``use_cuda`` is forced to ``False``
        when CUDA is not available.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--use-cuda', action='store_true', default=False,
                        help='Use NVIDIA GPU acceleration')
    parser.add_argument('--image-path', type=str, default='./examples/both.png',
                        help='Input image path')
    parser.add_argument('--aug_smooth', action='store_true',
                        help='Apply test time augmentation to smooth the CAM')
    parser.add_argument('--eigen_smooth', action='store_true',
                        help='Reduce noise by taking the first principle componenet'
                        'of cam_weights*activations')
    parser.add_argument('--method', type=str, default='gradcam',
                        choices=['gradcam', 'gradcam++', 'scorecam', 'xgradcam',
                                 'ablationcam', 'eigencam', 'eigengradcam'],
                        help='Can be gradcam/gradcam++/scorecam/xgradcam'
                             '/ablationcam/eigencam/eigengradcam')
    # Restrict to the modes the script implements so that an unsupported
    # value is rejected here rather than failing later in the script.
    parser.add_argument('--roimode', type=int, default=0, choices=[0, 1, 2, 3],
                        help='ROI selection: 0=all pixels, 1=fixed pixel, '
                             '2=user-picked pixel, 3=class/component')
    args = parser.parse_args()
    args.use_cuda = args.use_cuda and torch.cuda.is_available()
    if args.use_cuda:
        print('Using GPU for acceleration')
    else:
        print('Using CPU for computation')

    return args


if __name__ == '__main__':
    """ python segcam.py --image-path <path_to_image>
    Example usage of loading an image, and computing:
        1. CAM
        2. Guided Back Propagation
        3. Combining both
    """

    args = get_args()
    methods = \
        {"gradcam": GradCAM,
         "scorecam": ScoreCAM,
         "gradcam++": GradCAMPlusPlus,
         "ablationcam": AblationCAM,
         "xgradcam": XGradCAM,
         "eigencam": EigenCAM,
         "eigengradcam": EigenGradCAM}

    model = models.segmentation.fcn_resnet50(pretrained=True)
    model.eval()
    # Choose the target layer you want to compute the visualization for.
    # Usually this will be the last convolutional layer in the model.
    # Some common choices can be:
    # Resnet18 and 50: model.layer4[-1]
    # VGG, densenet161: model.features[-1]
    # mnasnet1_0: model.layers[-1]
    # You can print the model to help chose the layer
    target_layer = model.backbone.layer4[-1]

    # cv2.imread returns None instead of raising when the file cannot be
    # read; fail with a clear message rather than a TypeError on slicing.
    bgr_img = cv2.imread(args.image_path, 1)
    if bgr_img is None:
        raise FileNotFoundError(f'Could not read image: {args.image_path}')
    rgb_img = bgr_img[:, :, ::-1]  # BGR -> RGB
    rgb_img = np.float32(rgb_img) / 255
    input_tensor = preprocess_image(rgb_img, mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

    ROIMode = args.roimode
    if ROIMode == 0:
        ## All pixels
        segmodel = SegModel(model, roi=BaseROI(rgb_img))
    elif ROIMode == 1:
        ## Single code assigned roi
        roi = PixelROI(50, 130, rgb_img)
        segmodel = SegModel(model, roi=roi)
    elif ROIMode == 2:
        ## User pick a pixel
        roi = PixelROI(50, 130, rgb_img)
        ## Picking before or after passing the ROI to the model both work
        segmodel = SegModel(model, roi=roi)
        roi.pickPixel()
    elif ROIMode == 3:
        ## Of specific class (GT or predict, depending on what user passes)
        # Only the arg-max labels are needed here, so skip gradient tracking.
        with torch.no_grad():
            pred = torch.argmax(get_output_tensor(model(input_tensor)), -3).squeeze(0)
        roi = ClassROI(rgb_img, pred, 12)
        # Alternatives: roi.largestComponent(), roi.smallestComponent(),
        # roi.pickClass()
        roi.pickComponentClass()
        segmodel = SegModel(model, roi=roi)
    else:
        # Without this branch ``segmodel`` would be undefined further down.
        raise ValueError(f'Unsupported ROI mode: {ROIMode}')

    cam = methods[args.method](model=segmodel,
                               target_layer=target_layer,
                               use_cuda=args.use_cuda)

    # If None, returns the map for the highest scoring category.
    # Otherwise, targets the requested category.
    target_category = None

    # AblationCAM and ScoreCAM have batched implementations.
    # You can override the internal batch size for faster computation.
    cam.batch_size = 32

    grayscale_cam = cam(input_tensor=input_tensor,
                        target_category=target_category,
                        aug_smooth=args.aug_smooth,
                        eigen_smooth=args.eigen_smooth)

    # Here grayscale_cam has only one image in the batch
    grayscale_cam = grayscale_cam[0, :]

    cam_image = show_cam_on_image(rgb_img, grayscale_cam)

    gb_model = GuidedBackpropReLUModel(model=segmodel, use_cuda=args.use_cuda)
    gb = gb_model(input_tensor, target_category=target_category)

    cam_mask = cv2.merge([grayscale_cam, grayscale_cam, grayscale_cam])
    cam_gb = deprocess_image(cam_mask * gb)
    gb = deprocess_image(gb)

    # Results are shown interactively; flip the condition to write them to
    # disk instead.
    if True:
        plt.figure()
        plt.imshow(cam_image)
        # plt.figure()
        # plt.imshow(gb)
        plt.figure()
        plt.imshow(cam_gb)
        plt.show()
    else:
        cv2.imwrite(f'{args.method}_cam.jpg', cam_image)
        cv2.imwrite(f'{args.method}_gb.jpg', gb)
        cv2.imwrite(f'{args.method}_cam_gb.jpg', cam_gb)