RuntimeError: 'cannot register a hook on a tensor that doesn't require gradient' #195
Hello @dayunyan 👋 Thanks for reporting this!
Now you have two solutions:
```python
from torchvision.models import resnet50
from torchcam.methods import XGradCAM

model = resnet50().eval()
# Create the first extractor with its CAM computation hooks disabled
extractor_1 = XGradCAM(model, enable_hooks=False)
extractor_2 = XGradCAM(model)
# Re-enable the hooks on the first extractor
extractor_1._hooks_enabled = True
```
```python
from torchvision.models import resnet50
from torchcam.methods import XGradCAM

model = resnet50().eval()
extractor_1 = XGradCAM(model)
# Reuse the layer resolution of the first extractor to avoid resolving it twice
extractor_2 = XGradCAM(model, target_layer=extractor_1.target_names)
```

I would highly recommend the second one, which is more efficient / faster :) I'm trying to think of a way to prevent this behavior, but that would mean disabling hooks automatically before attempting to locate a layer 🤯 One question though: why would you create two identical CAM extractors for the same model? 🤔 Anyway, I hope this helped!
Thank you very much! What an excellent answer! @frgfm As for your question: in my project, I want to extract the CAM during model training and use it as a reference to compute a loss. In my mind, since I updated the model's parameters with loss.backward() and optimizer.step(), I thought I should create a new extractor for the updated model before the next training iteration, which is how I ran into this error.
Now I wonder whether I only need to create the extractor once, outside the loop: as the model's parameters are updated, does the extractor see the updated parameters as well? I'm not sure if you understand what I mean, but anyway, I really appreciate your answer!
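A quick way to convince yourself that a single extractor keeps tracking the updated model is that hooks are registered on module objects, not on a snapshot of their weights. A plain-PyTorch sketch (no torchcam; the `nn.Linear` model and `hook` function here are stand-ins for illustration):

```python
import torch
from torch import nn

# A forward hook is attached to the module itself, so the same hook keeps
# seeing the module's current parameters after every optimizer step.
model = nn.Linear(4, 2, bias=False)
captured = {}

def hook(module, inputs, output):
    # Store the latest activation produced by this module
    captured["out"] = output.detach().clone()

handle = model.register_forward_hook(hook)

x = torch.ones(1, 4)
model(x)
before = captured["out"]

with torch.no_grad():
    model.weight.add_(1.0)  # stand-in for optimizer.step()

model(x)
after = captured["out"]
print(torch.equal(before, after))  # False: the hook saw the updated weights
handle.remove()
```

So creating the extractor once, outside the training loop, should be enough.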
You're very much welcome 😄 About your question:
Now, here are my suggestions:
```python
from torch.optim import Adam
from torchcam.methods import LayerCAM

model = ...
optimizer = Adam(model.parameters(), maximize=True)
extractor = LayerCAM(model)

# Epoch loop
for x, target in dataloader:
    optimizer.zero_grad()
    out = model(x)
    # Backprop the CAM grads
    extractor(target.numpy().tolist(), out)
    optimizer.step()
```

Please understand that this sounds highly experimental; I cannot vouch for the outcome.
Hope this isn't too obscure 😅
To be honest, it's a little obscure 😂 But I will try to understand. Thanks very very very very much again! 🥰
Feel free to ask if you have specific questions later on, I'll try my best to answer :)
Hi @frgfm I'm not sure if this is the best spot to ask this question, but it seems somewhat relevant to this topic. Somewhere outside the training loop:

```python
model = resnet50()
cam_extractor = GradCAM(model, target_layer='conv1')
```

This is a simplified excerpt from my training loop:

```python
optimizer.zero_grad()
# Forward pass
outputs = model(inputs)
# Batch processing
cams = cam_extractor(outputs.argmax(dim=1).tolist(), outputs, retain_graph=True)
# Grab the first from the list (we only have one target layer)
cams = cams[0]
# Simplified, but nothing wild happening here
custom_loss = custom_loss(cams)
loss = criterion(outputs, labels)
loss += custom_loss
loss.backward()
optimizer.step()
```

This doesn't work too well for me, and I suspect that I'm losing the computational graph somewhere. When I do `print(cams.requires_grad)` I get `False`. Am I on the right track? I posted this here because I am also trying to use the CAM for updating parameters.
Hi @niniack 👋 Actually, I think this is unrelated to this topic, but let's check this:
I hope this helps a bit!
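A minimal sketch of the `requires_grad` behavior discussed above (plain PyTorch, independent of torchcam): a tensor produced inside `torch.no_grad()` or via `.detach()` is cut off from the autograd graph, so a loss built from it cannot backpropagate into the model.

```python
import torch

w = torch.randn(3, requires_grad=True)

attached = w * 2                 # recorded on the graph
print(attached.requires_grad)    # True

with torch.no_grad():
    detached = w * 2             # not recorded on the graph
print(detached.requires_grad)    # False
```

If `cams.requires_grad` is `False`, the CAM tensor was produced somewhere along such a detached path, and adding it to the loss has no effect on the gradients.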
Bug description
When I loaded XGradCAM for the same model twice in a row, I got the error 'cannot register a hook on a tensor that doesn't require gradient'.
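For reference, the RuntimeError itself comes straight from PyTorch: `Tensor.register_hook` refuses tensors with `requires_grad=False`. A minimal, torchcam-free sketch reproducing the message (the assumption being that torchcam hits this code path when registering its hooks):

```python
import torch

t = torch.zeros(3)  # requires_grad is False by default
message = ""
try:
    t.register_hook(lambda grad: grad)
except RuntimeError as err:
    message = str(err)
print(message)  # cannot register a hook on a tensor that doesn't require gradient
```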
Here is my code example
Code snippet to reproduce the bug
Error traceback
Environment
- TorchCAM version: 0.3.2
- PyTorch version: 1.12.1+cu102
- OS: Ubuntu 18.04.2 LTS
- Python version: 3.8.15
- Is CUDA available: Yes
- CUDA runtime version: 10.0.130
- GPU models and configuration: GPU 0: Tesla V100-SXM2-32GB
- Nvidia driver version: 455.23.05
- cuDNN version: Probably one of the following:
  - /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.2
  - /usr/local/cuda-10.0/targets/x86_64-linux/lib/libcudnn.so.7.6.2