
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: #5

Closed
Ema1997 opened this issue Nov 6, 2020 · 11 comments


Ema1997 commented Nov 6, 2020

I encountered a runtime error when I tried to search for an architecture using your code.

/opt/conda/conda-bld/pytorch_1565272279342/work/torch/csrc/autograd/python_anomaly_mode.cpp:57: UserWarning: Traceback of forward call that caused the error:
  File "tools/train.py", line 300, in <module>
    main()
  File "tools/train.py", line 259, in main
    est=model_est, local_rank=args.local_rank)
  File "/opt/tiger/cream/lib/core/train.py", line 55, in train_epoch
    output = model(input, random_cand)
  File "/home/tiger/.conda/envs/Cream/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tiger/.conda/envs/Cream/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 442, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/home/tiger/.conda/envs/Cream/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/tiger/cream/lib/models/structures/supernet.py", line 121, in forward
    x = self.forward_features(x, architecture)
  File "/opt/tiger/cream/lib/models/structures/supernet.py", line 113, in forward_features
    x = blocks[arch](x)
  File "/home/tiger/.conda/envs/Cream/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tiger/.conda/envs/Cream/lib/python3.6/site-packages/timm/models/efficientnet_blocks.py", line 133, in forward
    x = self.bn1(x)
  File "/home/tiger/.conda/envs/Cream/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/tiger/.conda/envs/Cream/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py", line 81, in forward
    exponential_average_factor, self.eps)
  File "/home/tiger/.conda/envs/Cream/lib/python3.6/site-packages/torch/nn/functional.py", line 1656, in batch_norm
    training, momentum, eps, torch.backends.cudnn.enabled

Traceback (most recent call last):
  File "tools/train.py", line 300, in <module>
    main()
  File "tools/train.py", line 259, in main
    est=model_est, local_rank=args.local_rank)
  File "/opt/tiger/cream/lib/core/train.py", line 67, in train_epoch
    loss.backward()
  File "/home/tiger/.conda/envs/Cream/lib/python3.6/site-packages/torch/tensor.py", line 118, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/tiger/.conda/envs/Cream/lib/python3.6/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [320]] is at version 2507; expected version 2506 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

I tried to locate the source of the error, and I found that it appears whenever the code updates the meta network or adds the kd_loss to the final loss.
How can I fix this problem?
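For context, this class of error can be reproduced with a few lines of plain PyTorch (a minimal sketch, not taken from the Cream code): `optimizer.step()` updates parameters in place, so backpropagating a graph that still references the pre-step parameter values fails with exactly this version-counter message.

```python
import torch

# Minimal sketch of the failure mode (illustrative; not from the Cream repo).
# optimizer.step() mutates parameters in place, bumping their version
# counters; backpropagating a graph that saved the old parameter values
# then raises the "modified by an inplace operation" RuntimeError.
w = torch.nn.Parameter(torch.ones(3))
opt = torch.optim.SGD([w], lr=0.1)

loss1 = (w * w).sum()
loss1.backward()        # populate w.grad so opt.step() updates w
loss2 = (w * w).sum()   # second graph saves w at its current version
opt.step()              # in-place update of w -> version counter bump

try:
    loss2.backward()    # fails: w is newer than the version the graph saved
    stale_graph_failed = False
except RuntimeError:
    stale_graph_failed = True

assert stale_graph_failed  # the stale graph cannot be backpropagated
```

The same pattern arises whenever two forward passes share parameters and a step happens between building a graph and calling `backward()` on it.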

Z7zuqer (Collaborator) commented Nov 6, 2020

Hi,

Thanks for your kind question!

Could you share your environment with us? We haven't encountered such a problem with this codebase.

We would really appreciate your reply.

Best,
Hao.

Ema1997 (Author) commented Nov 6, 2020

Package Version
future 0.18.2
numpy 1.17.0
opencv-python 4.0.1.24
Pillow 6.1.0
ptflops 0.6.2
tensorboard 2.3.0
tensorboard-plugin-wit 1.7.0
tensorboardX 1.2
thop 0.0.31.post2005241907
timm 0.1.20
torch 1.2.0
torchvision 0.2.1
yacs 0.1.8

Thank you very much.


macn3388 commented Nov 9, 2020

Same problem here.

Z7zuqer (Collaborator) commented Nov 25, 2020

Hi,

We have carefully checked the source code and environments; this bug comes from torch.distributed. We previously thought apex was not required. However, due to the DDP implementation in torch, we could not train the supernet with the SPOS mechanism.

Thus, to solve this bug, it is necessary to run with the apex package. You should install apex before supernet training. We will fix the installation steps in README.md.

Thanks.
Hao.

@Z7zuqer Z7zuqer closed this as completed Nov 25, 2020
penghouwen (Member) commented

@Ema1997 @macn3388 Would you check whether the issue has been solved? Thanks.

@cswaynecool

The same error occurs even when using apex.

@cswaynecool

Adding `for name, param in model.named_parameters(recurse=True): param.grad = None` at the beginning of `update_student_weights_only` solves my problem. The error is caused by `optimizer.step()`, which changes the parameters of the meta network in place.
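The workaround above can be sketched as a small helper (a hedged sketch: `clear_grads` is a hypothetical name, and `update_student_weights_only` belongs to the Cream training code, which is not shown here):

```python
import torch

def clear_grads(model: torch.nn.Module) -> None:
    # Setting .grad to None (rather than zeroing it) drops stale accumulated
    # gradients, mirroring the quoted fix placed at the start of
    # update_student_weights_only in the Cream training loop.
    for _, param in model.named_parameters(recurse=True):
        param.grad = None

# Usage sketch on a stand-in model (the real model is the Cream supernet):
model = torch.nn.Linear(4, 2)
model(torch.randn(1, 4)).sum().backward()
assert model.weight.grad is not None  # gradients have accumulated
clear_grads(model)
assert model.weight.grad is None      # cleared before the meta update
```

With the gradients set to None, a subsequent `optimizer.step()` skips those parameters entirely, so it cannot mutate tensors that an older autograd graph still depends on.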

penghouwen (Member) commented

> Adding `for name, param in model.named_parameters(recurse=True): param.grad = None` at the beginning of `update_student_weights_only` solves my problem. The error is caused by `optimizer.step()`, which changes the parameters of the meta network in place.

In our experience, if the installation strictly follows the README, this issue should not occur.

Z7zuqer (Collaborator) commented Dec 18, 2020

> Adding `for name, param in model.named_parameters(recurse=True): param.grad = None` at the beginning of `update_student_weights_only` solves my problem. The error is caused by `optimizer.step()`, which changes the parameters of the meta network in place.

Hi,

Could you share your environment variables with us?

We have tested the code. When using apex (installed following the README), this error should not occur.

Best,
Hao.


jonsnows commented Aug 5, 2021

> Adding `for name, param in model.named_parameters(recurse=True): param.grad = None` at the beginning of `update_student_weights_only` solves my problem. The error is caused by `optimizer.step()`, which changes the parameters of the meta network in place.

Hello, I want to ask where you added that code? I encounter the same problem after installing apex using pip.

Z7zuqer (Collaborator) commented Aug 5, 2021

> > Adding `for name, param in model.named_parameters(recurse=True): param.grad = None` at the beginning of `update_student_weights_only` solves my problem. The error is caused by `optimizer.step()`, which changes the parameters of the meta network in place.
>
> Hello, I want to ask where you added that code? I encounter the same problem after installing apex using pip.

Hi,

You should install apex with the cpp extension and cuda extension, as indicated in this URL:

`python ./apex/setup.py install --cpp_ext --cuda_ext`

Or you could add the above code, as SPOS does: set the grad to None in each training iteration.

Best,
Hao.
