diff --git a/README.md b/README.md
index 6a964524e..e525dfb7c 100644
--- a/README.md
+++ b/README.md
@@ -57,6 +57,10 @@ The aim is also to serve as a benchmark of algorithms and metrics for research o
| -----------------|-----------------------|
|
|
|
+| Semantic Segmentation (3D) |
+| -------------------------- |
+|
|
+
## Explaining similarity to other images / embeddings
diff --git a/examples/multiorgan_segmentation.gif b/examples/multiorgan_segmentation.gif
new file mode 100644
index 000000000..66b510dbd
Binary files /dev/null and b/examples/multiorgan_segmentation.gif differ
diff --git a/pytorch_grad_cam/base_cam.py b/pytorch_grad_cam/base_cam.py
index f08cd12bb..4b2850a98 100644
--- a/pytorch_grad_cam/base_cam.py
+++ b/pytorch_grad_cam/base_cam.py
@@ -1,21 +1,25 @@
+from typing import Callable, List, Optional, Tuple
+
import numpy as np
import torch
import ttach as tta
-from typing import Callable, List, Tuple, Optional
+
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
-from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
from pytorch_grad_cam.utils.image import scale_cam_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
class BaseCAM:
- def __init__(self,
- model: torch.nn.Module,
- target_layers: List[torch.nn.Module],
- reshape_transform: Callable = None,
- compute_input_gradient: bool = False,
- uses_gradients: bool = True,
- tta_transforms: Optional[tta.Compose] = None) -> None:
+ def __init__(
+ self,
+ model: torch.nn.Module,
+ target_layers: List[torch.nn.Module],
+ reshape_transform: Callable = None,
+ compute_input_gradient: bool = False,
+ uses_gradients: bool = True,
+ tta_transforms: Optional[tta.Compose] = None,
+ ) -> None:
self.model = model.eval()
self.target_layers = target_layers
@@ -34,63 +38,64 @@ def __init__(self,
else:
self.tta_transforms = tta_transforms
- self.activations_and_grads = ActivationsAndGradients(
- self.model, target_layers, reshape_transform)
+ self.activations_and_grads = ActivationsAndGradients(self.model, target_layers, reshape_transform)
""" Get a vector of weights for every channel in the target layer.
Methods that return weights channels,
will typically need to only implement this function. """
- def get_cam_weights(self,
- input_tensor: torch.Tensor,
- target_layers: List[torch.nn.Module],
- targets: List[torch.nn.Module],
- activations: torch.Tensor,
- grads: torch.Tensor) -> np.ndarray:
+ def get_cam_weights(
+ self,
+ input_tensor: torch.Tensor,
+ target_layers: List[torch.nn.Module],
+ targets: List[torch.nn.Module],
+ activations: torch.Tensor,
+ grads: torch.Tensor,
+ ) -> np.ndarray:
raise Exception("Not Implemented")
- def get_cam_image(self,
- input_tensor: torch.Tensor,
- target_layer: torch.nn.Module,
- targets: List[torch.nn.Module],
- activations: torch.Tensor,
- grads: torch.Tensor,
- eigen_smooth: bool = False) -> np.ndarray:
-
- weights = self.get_cam_weights(input_tensor,
- target_layer,
- targets,
- activations,
- grads)
- weighted_activations = weights[:, :, None, None] * activations
+ def get_cam_image(
+ self,
+ input_tensor: torch.Tensor,
+ target_layer: torch.nn.Module,
+ targets: List[torch.nn.Module],
+ activations: torch.Tensor,
+ grads: torch.Tensor,
+ eigen_smooth: bool = False,
+ ) -> np.ndarray:
+ weights = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads)
+ # 2D conv
+ if len(activations.shape) == 4:
+ weighted_activations = weights[:, :, None, None] * activations
+ # 3D conv
+ elif len(activations.shape) == 5:
+ weighted_activations = weights[:, :, None, None, None] * activations
+ else:
+ raise ValueError(f"Invalid activation shape. Get {len(activations.shape)}.")
+
if eigen_smooth:
cam = get_2d_projection(weighted_activations)
else:
cam = weighted_activations.sum(axis=1)
return cam
- def forward(self,
- input_tensor: torch.Tensor,
- targets: List[torch.nn.Module],
- eigen_smooth: bool = False) -> np.ndarray:
-
+ def forward(
+ self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False
+ ) -> np.ndarray:
input_tensor = input_tensor.to(self.device)
if self.compute_input_gradient:
- input_tensor = torch.autograd.Variable(input_tensor,
- requires_grad=True)
+ input_tensor = torch.autograd.Variable(input_tensor, requires_grad=True)
self.outputs = outputs = self.activations_and_grads(input_tensor)
if targets is None:
target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
- targets = [ClassifierOutputTarget(
- category) for category in target_categories]
+ targets = [ClassifierOutputTarget(category) for category in target_categories]
if self.uses_gradients:
self.model.zero_grad()
- loss = sum([target(output)
- for target, output in zip(targets, outputs)])
+ loss = sum([target(output) for target, output in zip(targets, outputs)])
loss.backward(retain_graph=True)
# In most of the saliency attribution papers, the saliency is
@@ -102,25 +107,24 @@ def forward(self,
# This gives you more flexibility in case you just want to
# use all conv layers for example, all Batchnorm layers,
# or something else.
- cam_per_layer = self.compute_cam_per_layer(input_tensor,
- targets,
- eigen_smooth)
+ cam_per_layer = self.compute_cam_per_layer(input_tensor, targets, eigen_smooth)
return self.aggregate_multi_layers(cam_per_layer)
- def get_target_width_height(self,
- input_tensor: torch.Tensor) -> Tuple[int, int]:
- width, height = input_tensor.size(-1), input_tensor.size(-2)
- return width, height
+ def get_target_width_height(self, input_tensor: torch.Tensor) -> Tuple[int, int]:
+ if len(input_tensor.shape) == 4:
+ width, height = input_tensor.size(-1), input_tensor.size(-2)
+ return width, height
+ elif len(input_tensor.shape) == 5:
+ depth, width, height = input_tensor.size(-1), input_tensor.size(-2), input_tensor.size(-3)
+ return depth, width, height
+ else:
+ raise ValueError("Invalid input_tensor shape. Only 2D or 3D images are supported.")
def compute_cam_per_layer(
- self,
- input_tensor: torch.Tensor,
- targets: List[torch.nn.Module],
- eigen_smooth: bool) -> np.ndarray:
- activations_list = [a.cpu().data.numpy()
- for a in self.activations_and_grads.activations]
- grads_list = [g.cpu().data.numpy()
- for g in self.activations_and_grads.gradients]
+ self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool
+ ) -> np.ndarray:
+ activations_list = [a.cpu().data.numpy() for a in self.activations_and_grads.activations]
+ grads_list = [g.cpu().data.numpy() for g in self.activations_and_grads.gradients]
target_size = self.get_target_width_height(input_tensor)
cam_per_target_layer = []
@@ -134,36 +138,26 @@ def compute_cam_per_layer(
if i < len(grads_list):
layer_grads = grads_list[i]
- cam = self.get_cam_image(input_tensor,
- target_layer,
- targets,
- layer_activations,
- layer_grads,
- eigen_smooth)
+ cam = self.get_cam_image(input_tensor, target_layer, targets, layer_activations, layer_grads, eigen_smooth)
cam = np.maximum(cam, 0)
scaled = scale_cam_image(cam, target_size)
cam_per_target_layer.append(scaled[:, None, :])
return cam_per_target_layer
- def aggregate_multi_layers(
- self,
- cam_per_target_layer: np.ndarray) -> np.ndarray:
+ def aggregate_multi_layers(self, cam_per_target_layer: np.ndarray) -> np.ndarray:
cam_per_target_layer = np.concatenate(cam_per_target_layer, axis=1)
cam_per_target_layer = np.maximum(cam_per_target_layer, 0)
result = np.mean(cam_per_target_layer, axis=1)
return scale_cam_image(result)
- def forward_augmentation_smoothing(self,
- input_tensor: torch.Tensor,
- targets: List[torch.nn.Module],
- eigen_smooth: bool = False) -> np.ndarray:
+ def forward_augmentation_smoothing(
+ self, input_tensor: torch.Tensor, targets: List[torch.nn.Module], eigen_smooth: bool = False
+ ) -> np.ndarray:
cams = []
for transform in self.tta_transforms:
augmented_tensor = transform.augment_image(input_tensor)
- cam = self.forward(augmented_tensor,
- targets,
- eigen_smooth)
+ cam = self.forward(augmented_tensor, targets, eigen_smooth)
# The ttach library expects a tensor of size BxCxHxW
cam = cam[:, None, :, :]
@@ -178,19 +172,18 @@ def forward_augmentation_smoothing(self,
cam = np.mean(np.float32(cams), axis=0)
return cam
- def __call__(self,
- input_tensor: torch.Tensor,
- targets: List[torch.nn.Module] = None,
- aug_smooth: bool = False,
- eigen_smooth: bool = False) -> np.ndarray:
-
+ def __call__(
+ self,
+ input_tensor: torch.Tensor,
+ targets: List[torch.nn.Module] = None,
+ aug_smooth: bool = False,
+ eigen_smooth: bool = False,
+ ) -> np.ndarray:
# Smooth the CAM result with test time augmentation
if aug_smooth is True:
- return self.forward_augmentation_smoothing(
- input_tensor, targets, eigen_smooth)
+ return self.forward_augmentation_smoothing(input_tensor, targets, eigen_smooth)
- return self.forward(input_tensor,
- targets, eigen_smooth)
+ return self.forward(input_tensor, targets, eigen_smooth)
def __del__(self):
self.activations_and_grads.release()
@@ -202,6 +195,5 @@ def __exit__(self, exc_type, exc_value, exc_tb):
self.activations_and_grads.release()
if isinstance(exc_value, IndexError):
# Handle IndexError here...
- print(
- f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
+ print(f"An exception occurred in CAM with block: {exc_type}. Message: {exc_value}")
return True
diff --git a/pytorch_grad_cam/grad_cam.py b/pytorch_grad_cam/grad_cam.py
index b03f69e16..17eb6ac57 100644
--- a/pytorch_grad_cam/grad_cam.py
+++ b/pytorch_grad_cam/grad_cam.py
@@ -1,4 +1,5 @@
import numpy as np
+
from pytorch_grad_cam.base_cam import BaseCAM
@@ -18,4 +19,14 @@ def get_cam_weights(self,
target_category,
activations,
grads):
- return np.mean(grads, axis=(2, 3))
+ # 2D image
+ if len(grads.shape) == 4:
+ return np.mean(grads, axis=(2, 3))
+
+ # 3D image
+ elif len(grads.shape) == 5:
+ return np.mean(grads, axis=(2, 3, 4))
+
+ else:
+ raise ValueError("Invalid grads shape."
+ "Shape of grads should be 4 (2D image) or 5 (3D image).")
\ No newline at end of file
diff --git a/pytorch_grad_cam/utils/image.py b/pytorch_grad_cam/utils/image.py
index 187b37e51..372de93de 100644
--- a/pytorch_grad_cam/utils/image.py
+++ b/pytorch_grad_cam/utils/image.py
@@ -1,12 +1,14 @@
-import matplotlib
-from matplotlib import pyplot as plt
-from matplotlib.lines import Line2D
+import math
+from typing import Dict, List
+
import cv2
+import matplotlib
import numpy as np
import torch
+from matplotlib import pyplot as plt
+from matplotlib.lines import Line2D
+from scipy.ndimage import zoom
from torchvision.transforms import Compose, Normalize, ToTensor
-from typing import List, Dict
-import math
def preprocess_image(
@@ -163,7 +165,11 @@ def scale_cam_image(cam, target_size=None):
img = img - np.min(img)
img = img / (1e-7 + np.max(img))
if target_size is not None:
+ if len(img.shape) > 3:
+ img = zoom(np.float32(img), [(t_s/i_s) for i_s, t_s in zip(img.shape, target_size[::-1])])
+ else:
img = cv2.resize(np.float32(img), target_size)
+
result.append(img)
result = np.float32(result)