
RuntimeError: 'cannot register a hook on a tensor that doesn't require gradient' #195

Closed
dayunyan opened this issue Dec 21, 2022 · 7 comments
Labels
module: methods Related to torchcam.methods type: improvement New feature or request

@dayunyan

Bug description

When I instantiated XGradCAM twice in a row for the same model, I got the error 'cannot register a hook on a tensor that doesn't require gradient'.

Code snippet to reproduce the bug

import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import resnet18,resnet50
from torchcam.methods import SmoothGradCAMpp,GradCAM,ISCAM,XGradCAM

model = models.resnet50()
model.eval()
model_1 = XGradCAM(model)
model_2 = XGradCAM(model)

Error traceback

RuntimeError                              Traceback (most recent call last)
Cell In[51], line 11
      9 model.eval()
     10 model_1 = XGradCAM(model)
---> 11 model_2 = XGradCAM(model)

File /opt/conda/envs/pytorch/lib/python3.8/site-packages/torchcam/methods/gradient.py:34, in _GradCAM.__init__(self, model, target_layer, input_shape, **kwargs)
     26 def __init__(
     27     self,
     28     model: nn.Module,
   (...)
     31     **kwargs: Any,
     32 ) -> None:
---> 34     super().__init__(model, target_layer, input_shape, **kwargs)
     35     # Ensure ReLU is applied before normalization
     36     self._relu = True

File /opt/conda/envs/pytorch/lib/python3.8/site-packages/torchcam/methods/core.py:53, in _CAM.__init__(self, model, target_layer, input_shape, enable_hooks)
     48     target_names = [
     49         self._resolve_layer_name(layer) if isinstance(layer, nn.Module) else layer for layer in target_layer
     50     ]
     51 elif target_layer is None:
     52     # If the layer is not specified, try automatic resolution
---> 53     target_name = locate_candidate_layer(model, input_shape)
     54     # Warn the user of the choice
     55     if isinstance(target_name, str):

File /opt/conda/envs/pytorch/lib/python3.8/site-packages/torchcam/methods/_utils.py:43, in locate_candidate_layer(mod, input_shape)
     41 # forward empty
     42 with torch.no_grad():
---> 43     _ = mod(torch.zeros((1, *input_shape), device=next(mod.parameters()).data.device))
     45 # Remove all temporary hooks
     46 for handle in hook_handles:

File ~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1148, in Module._call_impl(self, *input, **kwargs)
   1145     bw_hook = hooks.BackwardHook(self, full_backward_hooks)
   1146     input = bw_hook.setup_input_hook(input)
-> 1148 result = forward_call(*input, **kwargs)
   1149 if _global_forward_hooks or self._forward_hooks:
   1150     for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):

File ~/.local/lib/python3.8/site-packages/torchvision/models/resnet.py:285, in ResNet.forward(self, x)
    284 def forward(self, x: Tensor) -> Tensor:
--> 285     return self._forward_impl(x)

File ~/.local/lib/python3.8/site-packages/torchvision/models/resnet.py:276, in ResNet._forward_impl(self, x)
    274 x = self.layer2(x)
    275 x = self.layer3(x)
--> 276 x = self.layer4(x)
    278 x = self.avgpool(x)
    279 x = torch.flatten(x, 1)

File ~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1151, in Module._call_impl(self, *input, **kwargs)
   1149 if _global_forward_hooks or self._forward_hooks:
   1150     for hook in (*_global_forward_hooks.values(), *self._forward_hooks.values()):
-> 1151         hook_result = hook(self, input, result)
   1152         if hook_result is not None:
   1153             result = hook_result

File /opt/conda/envs/pytorch/lib/python3.8/site-packages/torchcam/methods/gradient.py:50, in _GradCAM._hook_g(self, module, input, output, idx)
     48 """Gradient hook"""
     49 if self._hooks_enabled:
---> 50     self.hook_handles.append(output.register_hook(partial(self._store_grad, idx=idx)))

File ~/.local/lib/python3.8/site-packages/torch/_tensor.py:430, in Tensor.register_hook(self, hook)
    428     return handle_torch_function(Tensor.register_hook, (self,), self, hook)
    429 if not self.requires_grad:
--> 430     raise RuntimeError("cannot register a hook on a tensor that "
    431                        "doesn't require gradient")
    432 if self._backward_hooks is None:
    433     self._backward_hooks = OrderedDict()

RuntimeError: cannot register a hook on a tensor that doesn't require gradient

Environment

TorchCAM version: 0.3.2
PyTorch version: 1.12.1+cu102

OS: Ubuntu 18.04.2 LTS

Python version: 3.8.15
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration: GPU 0: Tesla V100-SXM2-32GB
Nvidia driver version: 455.23.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.2
/usr/local/cuda-10.0/targets/x86_64-linux/lib/libcudnn.so.7.6.2

@frgfm frgfm added bug Something isn't working module: methods Related to torchcam.methods labels Dec 21, 2022
@frgfm frgfm added this to the 0.3.3 milestone Dec 21, 2022
@frgfm
Owner

frgfm commented Dec 21, 2022

Hello @dayunyan 👋

Thanks for reporting this!
So this is an advanced topic actually, here's what's happening under the hood:

  • when you create a CAM extractor, two things happen in terms of hooks
    • if no target layer is passed, it will set temp forward hooks & make a forward pass to select one by default
    • once this is resolved, it will set & enable the hooks necessary for the CAM computation. For gradient-based methods, that means forward hooks which register gradient hooks on the target layer's output 😅
  • so now that a first extractor is set
    • by default, its hooks are enabled (waiting for a model forward)
    • creating the second extractor without explicitly passing the target layer will trigger the layer resolution mechanism. However, the forward pass of this resolution is run under torch.no_grad(), so the layer outputs don't require gradient. Because the first extractor and its gradient hooks are already enabled, they try to register a gradient hook on those outputs, and that triggers the error 🙃
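The failure mode can be reproduced with plain PyTorch, without torchcam. This is a minimal sketch under assumptions: the tiny `nn.Sequential` model and the `forward_hook`/`store_grad` helpers are hypothetical stand-ins for what a gradient-based extractor does, not torchcam's actual code.

```python
import torch
import torch.nn as nn

# Hypothetical tiny model standing in for the real network
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.ReLU())
grads = {}


def store_grad(grad: torch.Tensor) -> None:
    # Called during backward with the gradient of the layer output
    grads["out"] = grad


def forward_hook(module, inputs, output):
    # Mimics a gradient-based extractor: attach a gradient hook to the layer output
    output.register_hook(store_grad)


handle = model[0].register_forward_hook(forward_hook)
x = torch.zeros(1, 3, 8, 8)

# Layer-resolution style forward pass: under no_grad the output doesn't require gradient
error_message = None
try:
    with torch.no_grad():
        model(x)
except RuntimeError as exc:
    error_message = str(exc)

# The same hook works fine when gradients are tracked
model(x).sum().backward()
handle.remove()
print(error_message)
```

The only difference between the failing and the working pass is whether autograd is tracking the forward computation.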

Now you have two solutions:

  1. Disable the hooks at the right moment
from torchvision.models import resnet50
from torchcam.methods import XGradCAM

model = resnet50().eval()
# Disable CAM computation hooks
extractor_1 = XGradCAM(model, enable_hooks=False)
extractor_2 = XGradCAM(model)
# Re-enable them
extractor_1._hooks_enabled = True
  2. Pass the target layer explicitly to avoid the problem
from torchvision.models import resnet50
from torchcam.methods import XGradCAM

model = resnet50().eval()
extractor_1 = XGradCAM(model)
# Use the layer resolution of the first extractor to avoid the double resolution
extractor_2 = XGradCAM(model, target_layer=extractor_1.target_names)

I would highly recommend the second one which is more efficient / faster :)

I'm trying to think about a way to prevent this behavior, but that means I should disable hooks automatically before attempting to locate a layer 🤯
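One possible shape for "disable hooks before locating a layer" is a context manager that temporarily switches off the flag of existing extractors and restores it afterwards. This is a hypothetical sketch, not torchcam's API: `ToyGradExtractor` and `hooks_disabled` are illustrative names, and only the `_hooks_enabled` flag name is borrowed from the traceback in the bug report.

```python
import contextlib

import torch
import torch.nn as nn


class ToyGradExtractor:
    """Hypothetical, heavily simplified stand-in for a gradient-based CAM extractor."""

    def __init__(self, layer: nn.Module) -> None:
        self._hooks_enabled = True
        self.grads = []
        self._handle = layer.register_forward_hook(self._hook_g)

    def _hook_g(self, module, inputs, output):
        # Only attach the gradient hook while hooks are enabled
        if self._hooks_enabled:
            output.register_hook(self.grads.append)

    def remove(self) -> None:
        self._handle.remove()


@contextlib.contextmanager
def hooks_disabled(*extractors):
    """Temporarily disable the hooks of existing extractors, then restore their previous state."""
    previous = [ext._hooks_enabled for ext in extractors]
    for ext in extractors:
        ext._hooks_enabled = False
    try:
        yield
    finally:
        for ext, state in zip(extractors, previous):
            ext._hooks_enabled = state


model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.ReLU())
first = ToyGradExtractor(model[0])
x = torch.zeros(1, 3, 8, 8)

# Layer-resolution style pass: this would raise without the context manager
with hooks_disabled(first):
    with torch.no_grad():
        model(x)

# Regular pass with gradients: the first extractor's hook is active again
model(x).sum().backward()
first.remove()
```

Restoring the previous state in a `finally` block keeps an extractor that was deliberately created with its hooks disabled in that state.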

One question though: why would you create two identical CAM extractors for the same model? 🤔

Anyway, I hope this helped!

@frgfm frgfm added type: improvement New feature or request awaiting response and removed bug Something isn't working labels Dec 21, 2022
@dayunyan
Author

Thank you very much! What an excellent answer! @frgfm

As for your question: in my project, I want to extract the CAM during model training and use it as a criterion to compute a loss. My reasoning was that since I update the model parameters with loss.backward() & optimizer.step(), I should create a new extractor for the updated model before each training iteration, which is how I ran into this error.
The code structure looks like the following

# Training loop
for batch, (data, labels) in enumerate(dataloader):
    extractor = XGradCAM(model)  # The error always occurs the second time the extractor is created
    # Input data and extract the CAM
    ...
    # Get the loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Now I wonder whether I only need to create the extractor once outside the loop, and whether the extractor will see the updated parameters as the model is trained.

I'm not sure if you understand what I mean, but anyway, I really appreciate your answer!

@frgfm
Owner

frgfm commented Dec 21, 2022

You're very much welcome 😄
I initially started this project simply to implement research papers about DL interpretability in order to understand them, so the design might differ a bit from other neighbouring efforts!

About your question:

  • every time you create an extractor, it creates hooks (I haven't implemented automatic removal when the extractor is dereferenced; I opened "Remove hooks once we're done with CAM extractors" #197 for that)
  • so I suggest creating the extractor once outside the loop
  • about the CAM computation, two things to consider
    1. the backprop of the loss does populate the gradients. But mathematically, those are the partial derivatives of the loss, which increases with the model's errors.
    2. the CAM computation backprops a synthetic loss if you will (one-hot vector on the class you're interested in multiplied by output logits). This happens when extractor(class_idx, out) is called.
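To make the synthetic loss concrete, here is a Grad-CAM-style sketch in plain PyTorch: a one-hot vector on the classes of interest is multiplied by the logits, backpropagated, and the captured gradients weight the activations of the hooked layer. Assumptions: the toy `model`, the hooked layer `model[0]` and the weighting (spatially averaged gradients) are illustrative choices; torchcam's own implementation may differ in details such as normalization and the weighting used by each method.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical toy classifier standing in for a real CNN
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(4, 5),
)
captured = {}


def store_grad(grad: torch.Tensor) -> None:
    captured["grad"] = grad


def forward_hook(module, inputs, output):
    # Keep the activation and attach a hook that stores its gradient
    captured["act"] = output
    output.register_hook(store_grad)


handle = model[0].register_forward_hook(forward_hook)

x = torch.randn(2, 3, 8, 8)
class_idx = [1, 3]  # one class of interest per sample

scores = model(x)
# Synthetic loss: one-hot vector on the classes of interest multiplied by the logits
one_hot = F.one_hot(torch.tensor(class_idx), num_classes=scores.shape[1]).to(scores.dtype)
synthetic_loss = (one_hot * scores).sum()
synthetic_loss.backward()

# Grad-CAM style map: spatially averaged gradients weight the activations, then a ReLU
weights = captured["grad"].mean(dim=(2, 3), keepdim=True)
cam = torch.relu((weights * captured["act"]).sum(dim=1)).detach()
handle.remove()
```

Because the one-hot mask zeroes every other logit, the synthetic loss is just the sum of the selected logits, so its gradient says how each activation influences the score of the class of interest rather than how it affects the training error.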

Now, here are my suggestions:

  • if you want to use the gradient of the CAM computation to update the parameters, bear in mind that for the optimization to be relevant, the gradient needs to be that of a quantity that increases with errors. So if class_idx is the target, use maximize=True in your optimizer (cf. https://pytorch.org/docs/stable/generated/torch.optim.Adam.html). That will reverse the sign of the parameter update.
from torch.optim import Adam
from torchcam.methods import LayerCAM

model = ...
optimizer = Adam(model.parameters(), maximize=True)
extractor = LayerCAM(model)

# Epoch loop
for x, target in dataloader:
    optimizer.zero_grad()
    out = model(x)
    # Backprop the CAM grads
    extractor(target.numpy().tolist(), out)
    optimizer.step()

Please understand that this is highly experimental; I cannot vouch for the outcome.

  • if you want to use the gradient of the loss for the CAM computation, you'll have to manually modify torchcam since the automatic backprop is enforced. Sorry 🙃
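A quick way to check what maximize=True does: with a single scalar parameter and the objective f(p) = 2p, one Adam step moves p against the gradient by default and along it with maximize=True. This is a toy check (the objective and learning rate are arbitrary choices) and assumes a recent PyTorch version whose Adam accepts the `maximize` argument.

```python
import torch
from torch.optim import Adam


def one_step(maximize: bool) -> float:
    # Single scalar parameter with objective f(p) = 2 * p, so df/dp = 2 > 0
    p = torch.tensor([1.0], requires_grad=True)
    optimizer = Adam([p], lr=0.1, maximize=maximize)
    optimizer.zero_grad()
    (2 * p).sum().backward()
    optimizer.step()
    return p.item()


descended = one_step(maximize=False)  # moves against the gradient: p decreases
ascended = one_step(maximize=True)    # moves along the gradient: p increases
print(descended, ascended)
```

On the first step Adam's bias-corrected update has magnitude close to the learning rate, so p lands near 0.9 and 1.1 respectively.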

Hope this isn't too obscure 😅

@dayunyan
Author

To be honest, it's a little obscure 😂 But I will try to understand.

Thanks very very very very much again!🥰

@frgfm
Owner

frgfm commented Dec 21, 2022

Feel free to ask if you have specific questions later on, I'll try my best to answer :)

@niniack

niniack commented Mar 26, 2023

Hi @frgfm

I'm not sure if this is the best spot to ask this question, but it seems somewhat relevant to this topic.

Somewhere outside the training loop:

from torchvision.models import resnet50
from torchcam.methods import GradCAM

model = resnet50()
cam_extractor = GradCAM(model, target_layer='conv1')

This is a simplified excerpt from my training loop:

optimizer.zero_grad()

# Forward pass
outputs = model(inputs)

# Batch processing
# Note: the extractor runs its own backward pass here, which also accumulates into the parameters' .grad
cams = cam_extractor(outputs.argmax(dim=1).tolist(), outputs, retain_graph=True)

# Grab the first element of the list (we only have one target layer)
cams = cams[0]

# Simplified, but nothing wild happening here
# (use a new name so the custom_loss function isn't overwritten after the first iteration)
cam_loss = custom_loss(cams)

loss = criterion(outputs, labels) + cam_loss
loss.backward()
optimizer.step()

This doesn't work too well for me and I suspect that I'm losing the computational graph somewhere? When I do:

print(cams.requires_grad)

I get a False. So, I'm not really able to go backward (I think?). Setting cams.requires_grad=True doesn't seem like the right answer either.

Am I on the right track? I posted this here because I am also trying to use the CAM to update parameters.

@frgfm
Owner

frgfm commented Mar 26, 2023

Hi @niniack 👋

Actually, I think this is unrelated to this topic, but let's check this:

  • first thing, for me to debug, it's much easier when I have a minimal snippet to reproduce the problem on my end. I understand you might not want to share all of that and that's quite fine :) However "This doesn't work too well for me" is not very helpful haha what do you mean? NaNs? a uniform CAM?
  • I don't have the context here, but without the training loop, I can only try to guess: you are correct that what you pass to a loss function/criterion needs to require gradient. The goal is to compute the derivative of the loss with respect to the parameters. I have no clue what custom_loss does, but CAMs by nature don't need to propagate the gradients (i.e. I didn't ensure that the operations used to compute the CAMs preserve the gradient flow)
  • That means you can easily use them as targets, or as any other tensor that doesn't require gradients. I imagine that will be easier than making CAMs backpropagable
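A minimal sketch of using a CAM as a fixed target instead of backpropagating through it. Assumptions: `cam` is a random stand-in for the extractor output (it doesn't require gradient, like the CAMs discussed above), and `conv` is a hypothetical differentiable head producing a map of the same shape; what the prediction side should actually be depends on your use case.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical differentiable head producing a saliency-like map
conv = nn.Conv2d(3, 1, kernel_size=3)
x = torch.randn(2, 3, 8, 8)

# Stand-in for a CAM returned by an extractor: no gradient flows through it
cam = torch.rand(2, 6, 6)

pred_map = conv(x).squeeze(1)     # differentiable prediction, shape (2, 6, 6)
loss = F.mse_loss(pred_map, cam)  # the CAM acts as a fixed target
loss.backward()                   # gradients reach conv's parameters, not the CAM
```

Here the gradient only needs to flow through the prediction side, so the fact that the CAM itself reports requires_grad=False is not a problem.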

I hope this helps a bit!
Cheers ✌️
