diff --git a/README.md b/README.md
index 6a964524e..e525dfb7c 100644
--- a/README.md
+++ b/README.md
@@ -57,6 +57,10 @@ The aim is also to serve as a benchmark of algorithms and metrics for research o
 | -----------------|-----------------------|
 | | |
 
+| Semantic Segmentation (3D) |
+| -------------------------- |
+| ![Multiorgan 3D segmentation](./examples/multiorgan_segmentation.gif) |
+
 ## Explaining similarity to other images / embeddings
 
diff --git a/examples/multiorgan_segmentation.gif b/examples/multiorgan_segmentation.gif
new file mode 100644
index 000000000..66b510dbd
Binary files /dev/null and b/examples/multiorgan_segmentation.gif differ
diff --git a/pytorch_grad_cam/base_cam.py b/pytorch_grad_cam/base_cam.py
index f08cd12bb..4b2850a98 100644
--- a/pytorch_grad_cam/base_cam.py
+++ b/pytorch_grad_cam/base_cam.py
@@ -1,21 +1,25 @@
+from typing import Callable, List, Optional, Tuple
+
 import numpy as np
 import torch
 import ttach as tta
-from typing import Callable, List, Tuple, Optional
+
 from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
-from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
 from pytorch_grad_cam.utils.image import scale_cam_image
 from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
 
 
 class BaseCAM:
-    def __init__(self,
-                 model: torch.nn.Module,
-                 target_layers: List[torch.nn.Module],
-                 reshape_transform: Callable = None,
-                 compute_input_gradient: bool = False,
-                 uses_gradients: bool = True,
-                 tta_transforms: Optional[tta.Compose] = None) -> None:
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        target_layers: List[torch.nn.Module],
+        reshape_transform: Callable = None,
+        compute_input_gradient: bool = False,
+        uses_gradients: bool = True,
+        tta_transforms: Optional[tta.Compose] = None,
+    ) -> None:
         self.model = model.eval()
         self.target_layers = target_layers
 
@@ -34,63 +38,64 @@ def __init__(self,
         else:
             self.tta_transforms = tta_transforms
 
-        self.activations_and_grads = ActivationsAndGradients(
-            self.model, target_layers, reshape_transform)
+        self.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform)
 
     """ Get a vector of weights for every channel in the target layer.
         Methods that return weights channels,
         will typically need to only implement this function. """
""" - def get_cam_weights(self, - input_tensor: torch.Tensor, - target_layers: List[torch.nn.Module], - targets: List[torch.nn.Module], - activations: torch.Tensor, - grads: torch.Tensor) -> np.ndarray: + def get_cam_weights( + self, + input_tensor: torch.Tensor, + target_layers: List[torch.nn.Module], + targets: List[torch.nn.Module], + activations: torch.Tensor, + grads: torch.Tensor, + ) -> np.ndarray: raise Exception("Not Implemented") - def get_cam_image(self, - input_tensor: torch.Tensor, - target_layer: torch.nn.Module, - targets: List[torch.nn.Module], - activations: torch.Tensor, - grads: torch.Tensor, - eigen_smooth: bool = False) -> np.ndarray: - - weights = self.get_cam_weights(input_tensor, - target_layer, - targets, - activations, - grads) - weighted_activations = weights[:, :, None, None] * activations + def get_cam_image( + self, + input_tensor: torch.Tensor, + target_layer: torch.nn.Module, + targets: List[torch.nn.Module], + activations: torch.Tensor, + grads: torch.Tensor, + eigen_smooth: bool = False, + ) -> np.ndarray: + weights = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads) + # 2D conv + if len(activations.shape) == 4: + weighted_activations = weights[:, :, None, None] * activations + # 3D conv + elif len(activations.shape) == 5: + weighted_activations = weights[:, :, None, None, None] * activations + else: + raise ValueError(f"Invalid activation shape. Get {len(activations.shape)}.") + if eigen_smooth: cam = get_2d_projection(weighted_activations) else: cam = weighted_activations.sum(axis=1) return cam - def forward(self, - input_tensor: torch.Tensor, - targets: List[torch.nn.Module], - eigen_smooth: bool = False) -> np.ndarray: - + def forward( + self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False + ) -> np.ndarray: input_tensor = input_tensor.to(self.device) if self.compute_input_gradient: - input_tensor = torch.autograd.Variable(input_tensor, - requires_grad=True) + input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True) self.outputs = outputs = self.activations_and_grads(input_tensor) if targets is None: target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1) - targets = [ClassifierOutputTarget( - category) for category in target_categories] + targets = [ClassifierOutputTarget(category) for category in target_categories] if self.uses_gradients: self.model.zero_grad() - loss = sum([target(output) - for target, output in zip(targets, outputs)]) + loss = sum([target(output) for target, output in zip(targets, outputs)]) loss.backward(retain_graph=True) # In most of the saliency attribution papers, the saliency is @@ -102,25 +107,24 @@ def forward(self, # This gives you more flexibility in case you just want to # use all conv layers for example, all Batchnorm layers, # or something else. 
-        cam_per_layer = self.compute_cam_per_layer(input_tensor,
-                                                   targets,
-                                                   eigen_smooth)
+        cam_per_layer = self.compute_cam_per_layer(input_tensor, targets, eigen_smooth)
         return self.aggregate_multi_layers(cam_per_layer)
 
-    def get_target_width_height(self,
-                                input_tensor: torch.Tensor) -> Tuple[int, int]:
-        width, height = input_tensor.size(-1), input_tensor.size(-2)
-        return width, height
+    def get_target_width_height(self, input_tensor: torch.Tensor) -> Tuple[int, ...]:
+        if len(input_tensor.shape) == 4:
+            width, height = input_tensor.size(-1), input_tensor.size(-2)
+            return width, height
+        elif len(input_tensor.shape) == 5:
+            width, height, depth = input_tensor.size(-1), input_tensor.size(-2), input_tensor.size(-3)
+            return width, height, depth
+        else:
+            raise ValueError("Invalid input_tensor shape. Only 2D or 3D images are supported.")
 
     def compute_cam_per_layer(
-            self,
-            input_tensor: torch.Tensor,
-            targets: List[torch.nn.Module],
-            eigen_smooth: bool) -> np.ndarray:
-        activations_list = [a.cpu().data.numpy()
-                            for a in self.activations_and_grads.activations]
-        grads_list = [g.cpu().data.numpy()
-                      for g in self.activations_and_grads.gradients]
+        self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool
+    ) -> np.ndarray:
+        activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations]
+        grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients]
         target_size = self.get_target_width_height(input_tensor)
 
         cam_per_target_layer = []
@@ -134,36 +138,26 @@ def compute_cam_per_layer(
             if i < len(grads_list):
                 layer_grads = grads_list[i]
 
-            cam = self.get_cam_image(input_tensor,
-                                     target_layer,
-                                     targets,
-                                     layer_activations,
-                                     layer_grads,
-                                     eigen_smooth)
+            cam = self.get_cam_image(input_tensor, target_layer, targets, layer_activations, layer_grads, eigen_smooth)
             cam = np.maximum(cam, 0)
             scaled = scale_cam_image(cam, target_size)
             cam_per_target_layer.append(scaled[:, None, :])
 
         return cam_per_target_layer
 
-    def aggregate_multi_layers(
-            self,
-            cam_per_target_layer: np.ndarray) -> np.ndarray:
+    def aggregate_multi_layers(self, cam_per_target_layer: np.ndarray) -> np.ndarray:
         cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
         cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
         result = np.mean(cam_per_target_layer, axis=1)
         return scale_cam_image(result)
 
-    def forward_augmentation_smoothing(self,
-                                       input_tensor: torch.Tensor,
-                                       targets: List[torch.nn.Module],
-                                       eigen_smooth: bool = False) -> np.ndarray:
+    def forward_augmentation_smoothing(
+        self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False
+    ) -> np.ndarray:
         cams = []
         for transform in self.tta_transforms:
             augmented_tensor = transform.augment_image(input_tensor)
-            cam = self.forward(augmented_tensor,
-                               targets,
-                               eigen_smooth)
+            cam = self.forward(augmented_tensor, targets, eigen_smooth)
 
             # The ttach library expects a tensor of size BxCxHxW
             cam = cam[:, None, :, :]
@@ -178,19 +172,18 @@ def forward_augmentation_smoothing(self,
             cam = np.mean(np.float32(cams), axis=0)
         return cam
 
-    def __call__(self,
-                 input_tensor: torch.Tensor,
-                 targets: List[torch.nn.Module] = None,
-                 aug_smooth: bool = False,
-                 eigen_smooth: bool = False) -> np.ndarray:
-
+    def __call__(
+        self,
+        input_tensor: torch.Tensor,
+        targets: List[torch.nn.Module] = None,
+        aug_smooth: bool = False,
+        eigen_smooth: bool = False,
+    ) -> np.ndarray:
        # Smooth the CAM result with test time augmentation
        if aug_smooth is True:
-            return self.forward_augmentation_smoothing(
-                input_tensor, targets, eigen_smooth)
+            return self.forward_augmentation_smoothing(input_tensor, targets, eigen_smooth)
 
-        return self.forward(input_tensor,
-                            targets, eigen_smooth)
+        return self.forward(input_tensor, targets, eigen_smooth)
 
     def __del__(self):
         self.activations_and_grads.release()
@@ -202,6 +195,5 @@ def __exit__(self, exc_type, exc_value, exc_tb):
         self.activations_and_grads.release()
         if isinstance(exc_value, IndexError):
             # Handle IndexError here...
-            print(
-                f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
+            print(f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
         return True
diff --git a/pytorch_grad_cam/grad_cam.py b/pytorch_grad_cam/grad_cam.py
index b03f69e16..17eb6ac57 100644
--- a/pytorch_grad_cam/grad_cam.py
+++ b/pytorch_grad_cam/grad_cam.py
@@ -1,4 +1,5 @@
 import numpy as np
+
 from pytorch_grad_cam.base_cam import BaseCAM
 
 
@@ -18,4 +19,14 @@ def get_cam_weights(self,
                         target_category,
                         activations,
                         grads):
-        return np.mean(grads, axis=(2, 3))
+        # 2D image
+        if len(grads.shape) == 4:
+            return np.mean(grads, axis=(2, 3))
+
+        # 3D image
+        elif len(grads.shape) == 5:
+            return np.mean(grads, axis=(2, 3, 4))
+
+        else:
+            raise ValueError("Invalid grads shape. "
+                             "grads should have 4 dimensions (2D image) or 5 dimensions (3D image).")
\ No newline at end of file
diff --git a/pytorch_grad_cam/utils/image.py b/pytorch_grad_cam/utils/image.py
index 187b37e51..372de93de 100644
--- a/pytorch_grad_cam/utils/image.py
+++ b/pytorch_grad_cam/utils/image.py
@@ -1,12 +1,14 @@
-import matplotlib
-from matplotlib import pyplot as plt
-from matplotlib.lines import Line2D
+import math
+from typing import Dict, List
+
 import cv2
+import matplotlib
 import numpy as np
 import torch
+from matplotlib import pyplot as plt
+from matplotlib.lines import Line2D
+from scipy.ndimage import zoom
 from torchvision.transforms import Compose, Normalize, ToTensor
-from typing import List, Dict
-import math
 
 
 def preprocess_image(
@@ -163,7 +165,12 @@ def scale_cam_image(cam, target_size=None):
         img = img - np.min(img)
         img = img / (1e-7 + np.max(img))
         if target_size is not None:
-            img = cv2.resize(np.float32(img), target_size)
+            if len(img.shape) > 2:
+                # cv2.resize only handles 2D images; rescale 3D CAM volumes axis by axis instead.
+                img = zoom(np.float32(img), [(t_s / i_s) for i_s, t_s in zip(img.shape, target_size[::-1])])
+            else:
+                img = cv2.resize(np.float32(img), target_size)
+
         result.append(img)
 
     result = np.float32(result)
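
For reviewers, a minimal end-to-end sketch of the new 3D code path. The toy `Conv3d` classifier, the chosen target layer, and the input shape below are illustrative assumptions for this note, not part of the change; only `GradCAM` and `ClassifierOutputTarget` are the library's actual API.

```python
import torch

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# Hypothetical 3D classifier, defined here only to exercise the new path.
model = torch.nn.Sequential(
    torch.nn.Conv3d(1, 8, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool3d(1),
    torch.nn.Flatten(),
    torch.nn.Linear(8, 2),
)

# A 5D input: batch x channels x depth x height x width.
input_tensor = torch.randn(1, 1, 16, 64, 64)

# Targeting the Conv3d layer yields 5D activations and gradients, so
# GradCAM.get_cam_weights averages over axes (2, 3, 4) and
# BaseCAM.get_cam_image broadcasts the weights over three trailing axes.
cam = GradCAM(model=model, target_layers=[model[0]])
grayscale_cam = cam(input_tensor=input_tensor, targets=[ClassifierOutputTarget(1)])

print(grayscale_cam.shape)  # one CAM volume per batch item: (1, 16, 64, 64)
```

Existing 2D callers are unaffected: 4D inputs still take the `cv2.resize` branch and the `(2, 3)` gradient mean.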
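
The rescaling added to `scale_cam_image` relies on an axis-order contract: `get_target_width_height` returns `(W, H, D)` while the CAM volume is laid out `(D, H, W)`, so the target size is reversed before computing per-axis zoom factors. A small standalone check of that arithmetic (all shapes here are made up for illustration):

```python
import numpy as np
from scipy.ndimage import zoom

cam_volume = np.random.rand(16, 64, 64).astype(np.float32)  # (D, H, W)
target_size = (128, 128, 32)  # (W, H, D), the order get_target_width_height returns

# Reversing target_size aligns it with cam_volume's (D, H, W) axes.
factors = [t_s / i_s for i_s, t_s in zip(cam_volume.shape, target_size[::-1])]
resized = zoom(cam_volume, factors)
print(resized.shape)  # (32, 128, 128): depth 16 -> 32, height 64 -> 128, width 64 -> 128
```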