diff --git a/scripts/cam_example.py b/scripts/cam_example.py index 19550005..792cb77d 100644 --- a/scripts/cam_example.py +++ b/scripts/cam_example.py @@ -61,8 +61,13 @@ def main(args): ScoreCAM(model, conv_layer, input_layer), SSCAM(model, conv_layer, input_layer), ISSCAM(model, conv_layer, input_layer)] + # Don't trigger all hooks + for extractor in cam_extractors: + extractor._hooks_enabled = False + fig, axes = plt.subplots(1, len(cam_extractors), figsize=(7, 2)) for idx, extractor in enumerate(cam_extractors): + extractor._hooks_enabled = True model.zero_grad() scores = model(img_tensor.unsqueeze(0)) @@ -73,6 +78,7 @@ def main(args): activation_map = extractor(class_idx, scores).cpu() # Clean data extractor.clear_hooks() + extractor._hooks_enabled = False # Convert it to PIL image # The indexing below means first image in batch heatmap = to_pil_image(activation_map, mode='F') diff --git a/torchcam/cams/cam.py b/torchcam/cams/cam.py index 43d7ce23..0f6af454 100644 --- a/torchcam/cams/cam.py +++ b/torchcam/cams/cam.py @@ -16,9 +16,6 @@ class _CAM: conv_layer: name of the last convolutional layer """ - hook_a: Optional[Tensor] = None - hook_handles: List[torch.utils.hooks.RemovableHandle] = [] - def __init__( self, model: nn.Module, @@ -28,6 +25,9 @@ def __init__( if not hasattr(model, conv_layer): raise ValueError(f"Unable to find submodule {conv_layer} in the model") self.model = model + # Init hooks + self.hook_a: Optional[Tensor] = None + self.hook_handles: List[torch.utils.hooks.RemovableHandle] = [] # Forward hook self.hook_handles.append(self.model._modules.get(conv_layer).register_forward_hook(self._hook_a)) # Enable hooks @@ -150,7 +150,7 @@ def __init__( model: nn.Module, conv_layer: str, fc_layer: str - ): + ) -> None: super().__init__(model, conv_layer) # Softmax weight @@ -219,7 +219,7 @@ def __init__( # Ensure ReLU is applied to CAM before normalization self._relu = True - def _store_input(self, module: nn.Module, input: Tensor): + def _store_input(self, module: nn.Module, input: Tensor) -> None: """Store model input tensor""" if self._hooks_enabled: diff --git a/torchcam/cams/gradcam.py b/torchcam/cams/gradcam.py index db89fda0..0d44100f 100644 --- a/torchcam/cams/gradcam.py +++ b/torchcam/cams/gradcam.py @@ -15,8 +15,6 @@ class _GradCAM(_CAM): conv_layer: name of the last convolutional layer """ - hook_g: Optional[Tensor] = None - def __init__( self, model: torch.nn.Module, @@ -24,6 +22,8 @@ def __init__( ) -> None: super().__init__(model, conv_layer) + # Init hook + self.hook_g: Optional[Tensor] = None # Ensure ReLU is applied before normalization self._relu = True # Model output is used by the extractor