Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for 3D Conv-Net #466

Merged
merged 5 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 15 additions & 5 deletions pytorch_grad_cam/base_cam.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Callable, List, Tuple

import numpy as np
import torch
import ttach as tta
from typing import Callable, List, Tuple

from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients
from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection
from pytorch_grad_cam.utils.image import scale_cam_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection


class BaseCAM:
Expand Down Expand Up @@ -47,12 +49,14 @@ def get_cam_image(self,
grads: torch.Tensor,
eigen_smooth: bool = False) -> np.ndarray:


weights = self.get_cam_weights(input_tensor,
target_layer,
targets,
activations,
grads)
weighted_activations = weights[:, :, None, None] * activations
w_shape = (slice(None), slice(None)) + (None,) * (len(activations.shape)-2)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is a bit less straightforward to understand.
Can you please explain what's going on here?
Do you think there is a way to rewrite it to be more clear ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That line does exactly the same thing as

# 2D conv
if len(activations.shape) == 4:
  weighted_activations = weights[:, :, None, None] * activations

# 3D conv
elif len(activations.shape) == 5:   
  weighted_activations = weights[:, :, None, None, None] * activations

But I think you are right: it does lack some readability.
I will rewrite the code here.

weighted_activations = weights[w_shape] * activations
if eigen_smooth:
cam = get_2d_projection(weighted_activations)
else:
Expand Down Expand Up @@ -99,8 +103,14 @@ def forward(self,

def get_target_width_height(self,
                            input_tensor: torch.Tensor) -> Tuple[int, ...]:
    """Return the spatial size of the input batch, innermost axis first.

    Args:
        input_tensor: A 4D batch (N, C, H, W) for 2D images, or a 5D
            batch (N, C, D, H, W) for 3D volumes.
            (Assumes channels-first layout — TODO confirm against callers.)

    Returns:
        ``(width, height)`` for 4D input, or ``(width, height, depth)``
        for 5D input — i.e. the trailing spatial axes reversed, which is
        the order ``scale_cam_image`` reverses back when resizing.

    Raises:
        ValueError: If the tensor is neither 4D nor 5D.
    """
    ndim = len(input_tensor.shape)
    if ndim == 4:
        return input_tensor.size(-1), input_tensor.size(-2)
    elif ndim == 5:
        # size(-1) is the innermost (width-like) axis and size(-3) the
        # outermost spatial (depth-like) axis; the previous revision
        # bound these to misleadingly swapped names ("depth" for
        # size(-1)).  The returned values are unchanged.
        width, height, depth = (input_tensor.size(-1),
                                input_tensor.size(-2),
                                input_tensor.size(-3))
        return width, height, depth
    else:
        raise ValueError("Invalid input_tensor shape. Only 2D or 3D images are supported.")

def compute_cam_per_layer(
self,
Expand Down
13 changes: 12 additions & 1 deletion pytorch_grad_cam/grad_cam.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np

from pytorch_grad_cam.base_cam import BaseCAM


Expand All @@ -19,4 +20,14 @@ def get_cam_weights(self,
target_category,
activations,
grads):
return np.mean(grads, axis=(2, 3))
# 2D image
if len(grads.shape) == 4:
return np.mean(grads, axis=(2, 3))

# 3D image
elif len(grads.shape) == 5:
return np.mean(grads, axis=(2, 3, 4))

else:
raise ValueError("Invalid grads shape."
"Shape of grads should be 4 (2D image) or 5 (3D image).")
14 changes: 8 additions & 6 deletions pytorch_grad_cam/utils/image.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import matplotlib
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
import math
from typing import Dict, List

import cv2
import matplotlib
import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from scipy.ndimage import zoom
from torchvision.transforms import Compose, Normalize, ToTensor
from typing import List, Dict
import math


def preprocess_image(
Expand Down Expand Up @@ -163,7 +165,7 @@ def scale_cam_image(cam, target_size=None):
img = img - np.min(img)
img = img / (1e-7 + np.max(img))
if target_size is not None:
img = cv2.resize(img, target_size)
img = zoom(img, [(t_s/i_s) for i_s, t_s in zip(img.shape, target_size[::-1])])
result.append(img)
result = np.float32(result)

Expand Down