Skip to content

Commit

Permalink
fix: Fixed hooking mechanism (#23)
Browse files Browse the repository at this point in the history
* fix: Fixed hook inheritance

* refactor: Optimized memory usage and hooks

* style: Fixed typing

* refactor: Removed unused import
  • Loading branch information
frgfm committed Nov 12, 2020
1 parent 9a31c57 commit 5bd8a8c
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
6 changes: 6 additions & 0 deletions scripts/cam_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,13 @@ def main(args):
ScoreCAM(model, conv_layer, input_layer), SSCAM(model, conv_layer, input_layer),
ISSCAM(model, conv_layer, input_layer)]

# Don't trigger all hooks
for extractor in cam_extractors:
extractor._hooks_enabled = False

fig, axes = plt.subplots(1, len(cam_extractors), figsize=(7, 2))
for idx, extractor in enumerate(cam_extractors):
extractor._hooks_enabled = True
model.zero_grad()
scores = model(img_tensor.unsqueeze(0))

Expand All @@ -73,6 +78,7 @@ def main(args):
activation_map = extractor(class_idx, scores).cpu()
# Clean data
extractor.clear_hooks()
extractor._hooks_enabled = False
# Convert it to PIL image
# The indexing below means first image in batch
heatmap = to_pil_image(activation_map, mode='F')
Expand Down
10 changes: 5 additions & 5 deletions torchcam/cams/cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ class _CAM:
conv_layer: name of the last convolutional layer
"""

hook_a: Optional[Tensor] = None
hook_handles: List[torch.utils.hooks.RemovableHandle] = []

def __init__(
self,
model: nn.Module,
Expand All @@ -28,6 +25,9 @@ def __init__(
if not hasattr(model, conv_layer):
raise ValueError(f"Unable to find submodule {conv_layer} in the model")
self.model = model
# Init hooks
self.hook_a: Optional[Tensor] = None
self.hook_handles: List[torch.utils.hooks.RemovableHandle] = []
# Forward hook
self.hook_handles.append(self.model._modules.get(conv_layer).register_forward_hook(self._hook_a))
# Enable hooks
Expand Down Expand Up @@ -150,7 +150,7 @@ def __init__(
model: nn.Module,
conv_layer: str,
fc_layer: str
):
) -> None:

super().__init__(model, conv_layer)
# Softmax weight
Expand Down Expand Up @@ -219,7 +219,7 @@ def __init__(
# Ensure ReLU is applied to CAM before normalization
self._relu = True

def _store_input(self, module: nn.Module, input: Tensor):
def _store_input(self, module: nn.Module, input: Tensor) -> None:
"""Store model input tensor"""

if self._hooks_enabled:
Expand Down
4 changes: 2 additions & 2 deletions torchcam/cams/gradcam.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ class _GradCAM(_CAM):
conv_layer: name of the last convolutional layer
"""

hook_g: Optional[Tensor] = None

def __init__(
self,
model: torch.nn.Module,
conv_layer: str
) -> None:

super().__init__(model, conv_layer)
# Init hook
self.hook_g: Optional[Tensor] = None
# Ensure ReLU is applied before normalization
self._relu = True
# Model output is used by the extractor
Expand Down

0 comments on commit 5bd8a8c

Please sign in to comment.