v1.4.0 no longer seems to support `backward()` with the `inputs` parameter referencing a sub-module's parameters #233
I am playing around with the DomainBed repository. I noticed that for the implementation of Fishr, they specifically install version `1.3.0`, and I was wondering why. After a bit of experimentation, it seems that it is no longer possible to use `backward(inputs=...)` where `inputs` references a sub-module's parameters. I adjusted the example from your documentation to replicate the issue.
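The adjusted snippet is not preserved in this copy of the issue; the following is a hypothetical reconstruction, assuming the MNIST `BatchGrad` example from the BackPACK documentation with the backward pass restricted to one layer's parameters:

```python
# Hypothetical reconstruction of the adjusted documentation example
# (the original snippet did not survive in this transcript).
from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential

from backpack import backpack, extend
from backpack.extensions import BatchGrad
from backpack.utils.examples import load_one_batch_mnist

X, y = load_one_batch_mnist(batch_size=512)

submodule = Linear(784, 10)
model = extend(Sequential(Flatten(), submodule))
lossfunc = extend(CrossEntropyLoss())

loss = lossfunc(model(X), y)

with backpack(BatchGrad()):
    # Restrict the backward pass to the sub-module's parameters.
    loss.backward(inputs=list(submodule.parameters()))

for name, param in submodule.named_parameters():
    # On 1.3.0 this prints the individual-gradient shapes;
    # on 1.4.0 grad_batch is reportedly missing.
    print(name, param.grad_batch.shape)
```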
With `backpack-for-pytorch==1.4.0` this no longer gives the expected output, while with `backpack-for-pytorch==1.3.0` it does. I tried going through the git history of this repository to identify what changed between these two versions, but I have not managed to pin down the change that caused this. I was wondering whether this is intentional or a bug.
Comments

Hi Cas, thanks for your detailed description and the code snippet. One main difference between ...

Best,
Felix

As BackPACK was originally designed to work with ...
Hi Felix,

Thank you for the quick response and the proposed workaround. First of all: I actually think I made an error somewhere when I tried changing ... I am using ...

With ..., the hooks do not seem to be called at all for the module in this case. If I change ...

You mentioned that BackPACK was originally designed to work with ...
Hi,

There's no special treatment of the first hierarchy when extending a model. From the ..., I would recommend trying the above workaround. Let me know if it works.
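For illustration, a small sketch of that behavior, assuming `extend` recurses into container modules such as `Sequential` (synthetic data stands in for MNIST here):

```python
import torch
from torch.nn import CrossEntropyLoss, Flatten, Linear, ReLU, Sequential

from backpack import backpack, extend
from backpack.extensions import BatchGrad

# `inner` sits one level below the top of the module hierarchy.
inner = Sequential(Linear(784, 128), ReLU())
model = extend(Sequential(Flatten(), inner, Linear(128, 10)))
lossfunc = extend(CrossEntropyLoss())

X = torch.rand(8, 1, 28, 28)
y = torch.randint(0, 10, (8,))
loss = lossfunc(model(X), y)

with backpack(BatchGrad()):
    loss.backward()

# Individual gradients are available for the nested module's parameters too.
for name, param in inner.named_parameters():
    print(name, param.grad_batch.shape)
```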
The above workaround doesn't seem to work.

Hi @ngonthier, can you describe in more detail how/why the workaround does not seem to work?

Hi, ...

Hi, that indeed sounds like unintended behavior.
Hi @ngonthier and @f-dangel,

Sorry for not replying any sooner. I believe @ngonthier's observation is correct. See the minimal working example below:

```python
from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential

from backpack import backpack, extend
from backpack.extensions import BatchGrad
from backpack.utils.examples import load_one_batch_mnist

X, y = load_one_batch_mnist(batch_size=512)

l1 = Linear(784, 128)
l1.requires_grad = False  # intended to freeze the first layer
l2 = Linear(128, 10)

model = Sequential(Flatten(), l1, l2)
lossfunc = CrossEntropyLoss()

model = extend(model, debug=True)
lossfunc = extend(lossfunc, debug=True)

loss = lossfunc(model(X), y)

with backpack(BatchGrad(), debug=True):
    loss.backward()

# This should fail for the first layer, right? It doesn't!
for name, param in model.named_parameters():
    print(name)
    print(".grad.shape: ", param.grad.shape)
    print(".grad_batch.shape: ", param.grad_batch.shape)
```

This is the ...
Hi, thanks for providing a script to reproduce the issue. I think you're incorrectly setting `requires_grad`. The correct way to disable gradients is

```python
for p in l1.parameters():
    p.requires_grad = False
```

instead of

```python
l1.requires_grad = False
```
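The difference is easy to verify in plain PyTorch; a quick sketch (nothing BackPACK-specific is involved):

```python
from torch.nn import Linear

l1 = Linear(784, 128)

# Assigning to the module sets a plain Python attribute; the parameters keep
# requires_grad=True.
l1.requires_grad = False
print(all(p.requires_grad for p in l1.parameters()))  # True

# Per-parameter assignment (or the built-in Module.requires_grad_) does recurse.
for p in l1.parameters():
    p.requires_grad = False
print(any(p.requires_grad for p in l1.parameters()))  # False
```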
You're right! I was under the impression that this would recursively disable grad for all parameters... 👀 With the suggested change it does work. I also checked whether the output is the same for these two methods in version 1.3.0 once seeded, and that is indeed the case.
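A seeded check along these lines might look as follows (a sketch, assuming the "two methods" are `backward(inputs=...)` versus disabling `requires_grad` per parameter, and assuming version 1.3.0 where both work):

```python
import torch
from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential

from backpack import backpack, extend
from backpack.extensions import BatchGrad
from backpack.utils.examples import load_one_batch_mnist


def l2_grad_batch(freeze_l1: bool) -> torch.Tensor:
    """Individual gradients of l2's weight, with a fixed seed for comparability."""
    torch.manual_seed(0)
    X, y = load_one_batch_mnist(batch_size=512)
    l1, l2 = Linear(784, 128), Linear(128, 10)
    if freeze_l1:
        for p in l1.parameters():  # method 1: disable gradients per parameter
            p.requires_grad = False
    model = extend(Sequential(Flatten(), l1, l2))
    lossfunc = extend(CrossEntropyLoss())
    loss = lossfunc(model(X), y)
    with backpack(BatchGrad()):
        if freeze_l1:
            loss.backward()
        else:  # method 2: restrict the backward pass via `inputs`
            loss.backward(inputs=list(l2.parameters()))
    return l2.weight.grad_batch.clone()


print(torch.allclose(l2_grad_batch(True), l2_grad_batch(False)))  # True on 1.3.0
```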
I think that leaves me with one last question before closing the issue: should there be a more informative error / warning on BackPACK's side when using the ...
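One conceivable user-side check, sketched below (not an existing BackPACK feature; the helper name is made up):

```python
import warnings

from torch.nn import Module


def warn_missing_backpack_attr(model: Module, attr: str = "grad_batch") -> None:
    """Warn about trainable parameters that did not receive the BackPACK quantity."""
    for name, param in model.named_parameters():
        if param.requires_grad and not hasattr(param, attr):
            warnings.warn(
                f"Parameter '{name}' has no '{attr}' attribute; "
                "its module may have been skipped during the backward pass."
            )


# Usage, right after `loss.backward()` inside the `backpack(...)` context:
# warn_missing_backpack_attr(model, "grad_batch")
```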
I'm not sure how one would detect that ...

No, I'm not sure. I'm not familiar enough with the BackPACK codebase, I'm afraid... I'll close this issue then. Thank you for thinking along these last couple of weeks. 🙂 👍