Gradient tape vs. adjoint=True #159

Comments
Hi @johannespitz, in general you can use the … When you run a backward pass it does accumulate gradients (it always adds to the existing arrays); this is similar to PyTorch, but it means that you do indeed need to make sure they are zeroed somewhere between optimization steps. I don't think you should need to call …
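A minimal, self-contained sketch of the zeroing requirement described above, assuming a recent Warp where wp.Tape.reset() and per-array .grad are available (the kernel and loop are illustrative, not code from the thread):

```python
import numpy as np
import warp as wp

wp.init()

@wp.kernel
def sum_of_squares(x: wp.array(dtype=float), loss: wp.array(dtype=float)):
    tid = wp.tid()
    wp.atomic_add(loss, 0, x[tid] * x[tid])

x = wp.array(np.arange(4, dtype=np.float32), requires_grad=True)
loss = wp.zeros(1, dtype=float, requires_grad=True)
tape = wp.Tape()

for step in range(3):
    tape.reset()    # clear recorded launches and zero gradients between optimization steps
    loss.zero_()
    with tape:
        wp.launch(sum_of_squares, dim=len(x), inputs=[x, loss])
    tape.backward(loss=loss)
    print(step, x.grad.numpy())  # stays at 2*x each step; without reset() it would keep growing
```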
Thanks for the reply @mmacklin! Regarding the accumulation of gradients: when we use … instead of creating a new PyTorch tensor with .clone(), the gradient of leaf nodes in the computation graph will be 2x the true gradient, even when we clear all gradients before the call. That is because torch expects a torch.autograd.Function to return the gradient rather than write it directly into the buffer. Torch therefore adds the returned gradient to the gradient that Warp has already written into the buffer (for leaves in the computation graph). Note that for intermediate nodes this only works because usually (if retain_graph=False) the gradient buffers of those tensors are not used at all.
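The doubling mechanism can be reproduced in pure PyTorch with a toy autograd.Function whose backward both writes into the leaf's .grad buffer (standing in for a library that shares the gradient memory) and returns the gradient, so autograd adds the returned value on top; this is an illustrative sketch, not code from the thread:

```python
import torch

class LeakySquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        grad = 2.0 * x * grad_out
        # Simulate an external library writing the adjoint straight into the
        # leaf's .grad buffer, in addition to returning it to autograd.
        if x.grad is None:
            x.grad = torch.zeros_like(x)
        x.grad += grad
        return grad

x = torch.tensor([3.0], requires_grad=True)
LeakySquare.apply(x).sum().backward()
print(x.grad)  # tensor([12.]) -- 2x the true gradient tensor([6.])
```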
CUDA synchronization can be a little tricky, especially when launching work using multiple frameworks that use different scheduling mechanisms under the hood.

Short answer: if you're not explicitly creating and using custom CUDA streams in PyTorch or Warp, and both are targeting the same device, then synchronization is not necessary.

Long answer: by default, PyTorch uses the legacy default stream on each device. This stream is synchronous with respect to other blocking streams on the device, so no explicit synchronization is needed. Warp, by default, uses a blocking stream on each device, so Warp operations will automatically synchronize with PyTorch operations on the same device.

The picture changes if you start using custom streams in PyTorch. Those streams will not automatically synchronize with Warp streams, so manual synchronization will be required. This can be done using …

Note that when capturing CUDA graphs using PyTorch, a non-default stream is used, so synchronization becomes important. Things can get a little complicated with multi-stream usage and graph capture, so we're working on extended documentation in this area! But in your simple example, the explicit synchronization shouldn't be necessary.
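A small sketch of the custom-stream case, using standard PyTorch stream APIs (torch.cuda.Stream, Stream.synchronize) and Warp's wp.synchronize_device; the kernel and sizes are illustrative:

```python
import torch
import warp as wp

wp.init()
device = "cuda:0"

@wp.kernel
def scale(x: wp.array(dtype=float), s: float):
    tid = wp.tid()
    x[tid] = x[tid] * s

t = torch.ones(1024, device=device)

# Work submitted on a custom (non-default) PyTorch stream does not
# automatically order against Warp's stream...
s = torch.cuda.Stream(device=device)
with torch.cuda.stream(s):
    t.mul_(2.0)

# ...so synchronize explicitly before Warp touches the same memory.
s.synchronize()

wp.launch(scale, dim=t.shape[0], inputs=[wp.from_torch(t), 0.5], device=device)

# Not strictly required here (Warp's default stream is blocking), but a
# device-wide sync makes the ordering explicit before PyTorch reads back.
wp.synchronize_device(device)
print(t[:4])  # tensor([1., 1., 1., 1.], device='cuda:0')
```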
Thank you for the detailed answer regarding the synchronization, @nvlukasz! Though, can either of you comment on the accumulation of the gradients again? @mmacklin
Are there any guidelines on when to use the wp.tape or the adjoint=True argument to compute gradients? There are two examples of a torch.autograd.Function in this repository using different approaches:

- warp/examples/example_sim_fk_grad_torch.py (Line 29 in 79a56a9)
- warp/warp/tests/test_torch.py (Line 416 in 79a56a9)
I tried to expand test_torch.py with multiple inputs, but I wasn't able to get it to work (reliably!). Usually I get

Warp CUDA error 1: invalid argument (/buildAgent/work/a9ae500d09a78409/warp/native/warp.cu:1891)

but sometimes the program segfaults instead. Is there anything I would need to be aware of when using the adjoint=True argument?
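For context, a minimal sketch of what the adjoint=True path looks like for a single kernel, assuming a recent Warp where wp.launch accepts adj_inputs/adjoint and arrays created with requires_grad=True expose a .grad array (illustrative, not the code from the referenced test):

```python
import numpy as np
import warp as wp

wp.init()

@wp.kernel
def square(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    y[tid] = x[tid] * x[tid]

n = 4
x = wp.array(np.arange(n, dtype=np.float32), requires_grad=True)
y = wp.zeros(n, dtype=float, requires_grad=True)

# forward launch
wp.launch(square, dim=n, inputs=[x, y])

# seed dL/dy = 1, then replay the kernel in adjoint mode; adj_inputs mirrors
# inputs, supplying the gradient array for each argument
y.grad.fill_(1.0)
wp.launch(square, dim=n, inputs=[x, y], adj_inputs=[x.grad, y.grad], adjoint=True)

print(x.grad.numpy())  # expected: 2 * x
```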
Today my colleague experimented with the wp.tape approach, and we converged on the code pasted below. We had to add the ctx.x0.grad.zero_() line to make sure that multiple calls to PyTorch's backward() work properly (to please gradcheck()). But even more importantly, we had to add .clone(), requires_grad=True) to ensure that Warp does not write into PyTorch's gradient buffers directly, since that results in gradients that are 2x the true gradient if the variable is a leaf in the computation graph!

-> We haven't tested it, but that would suggest there is a bug in the example_sim_fk_grad_torch example, no?
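The snippet the issue refers to is not reproduced above; a hypothetical reconstruction of the described pattern looks roughly like this: a torch.autograd.Function built on wp.Tape, with inputs cloned so Warp writes into its own buffers and gradients zeroed between backward() calls. Names such as WarpSquare and square_kernel are illustrative, and it assumes a recent Warp where wp.from_torch(..., requires_grad=True) and per-array .grad are available:

```python
import torch
import warp as wp

wp.init()

@wp.kernel
def square_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    y[tid] = x[tid] * x[tid]

class WarpSquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x0):
        # clone so Warp gets its own buffer and never aliases x0 or x0.grad
        ctx.x0 = wp.from_torch(x0.detach().clone(), requires_grad=True)
        ctx.y = wp.zeros_like(ctx.x0, requires_grad=True)
        ctx.tape = wp.Tape()
        with ctx.tape:
            wp.launch(square_kernel, dim=len(ctx.x0), inputs=[ctx.x0, ctx.y], device=ctx.x0.device)
        return wp.to_torch(ctx.y)

    @staticmethod
    def backward(ctx, adj_y):
        # zero any gradients left over from a previous backward() call
        ctx.x0.grad.zero_()
        ctx.tape.backward(grads={ctx.y: wp.from_torch(adj_y.contiguous())})
        # return the gradient; PyTorch accumulates it into the leaf's .grad itself
        return wp.to_torch(ctx.x0.grad).clone()

x = torch.arange(4, dtype=torch.float32, requires_grad=True)
WarpSquare.apply(x).sum().backward()
print(x.grad)  # expected: 2 * x
```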