
Gradient tape vs. adjoint=True #159

Open
johannespitz opened this issue Nov 27, 2023 · 4 comments

@johannespitz

Are there any guidelines on when to use wp.Tape versus the adjoint=True argument to compute gradients?
There are two examples of a torch.autograd.Function in this repository that use different approaches:

class ForwardKinematics(torch.autograd.Function):

class TestFunc(torch.autograd.Function):

I tried to extend test_torch.py to multiple inputs, but I wasn't able to get it to work reliably. Usually I get Warp CUDA error 1: invalid argument (/buildAgent/work/a9ae500d09a78409/warp/native/warp.cu:1891), but sometimes the program segfaults instead.
Is there anything I need to be aware of when using the adjoint=True argument?

Today my colleague experimented with the wp.Tape approach, and we converged on the code pasted below.
We had to add the ctx.x0.grad.zero_() line to make sure that multiple calls to PyTorch's backward() work properly (to satisfy gradcheck()).
Even more importantly, we had to wrap the inputs with .clone() (and requires_grad=True) to ensure that Warp does not write into PyTorch's gradient buffers directly, since doing so results in gradients that are 2x the true gradient when the variable is a leaf in the computation graph!
-> We haven't tested it, but that would suggest there is a bug in the example_sim_fk_grad_torch example, no?

import numpy as np
import torch
from torch.autograd import gradcheck
import warp as wp

wp.init()
device = "cuda"
torch_device = wp.device_to_torch(device)
wp_device = wp.device_from_torch(torch_device)


@wp.kernel
def op_kernel(
    x0: wp.array(dtype=wp.float32),
    x1: wp.array(dtype=wp.float32),
    x2: wp.array(dtype=wp.float32),
    y: wp.array(dtype=wp.float32),
):
    i = wp.tid()
    y[i] = x0[i] ** 2.0 * x1[i] ** 2.0 * x2[i] ** 2.0


class WPGradTape(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x0,
        x1,
        x2,
    ):
        wp.synchronize_device()
        ctx.tape = wp.Tape()
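        # .clone() gives Warp its own buffers, so adjoints are not written directly
        # into the PyTorch tensors' .grad (see the discussion on 2x gradients below).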
        ctx.x0 = wp.from_torch(x0.clone(), requires_grad=True)
        ctx.x1 = wp.from_torch(x1.clone(), requires_grad=True)
        ctx.x2 = wp.from_torch(x2.clone(), requires_grad=True)
        ctx.y = wp.empty(x1.shape[0], dtype=wp.float32, device=wp_device)
        with ctx.tape:
            wp.launch(
                kernel=op_kernel,
                dim=[len(x1)],
                inputs=[
                    ctx.x0,
                    ctx.x1,
                    ctx.x2,
                ],
                outputs=[ctx.y],
                adjoint=False,
                device=wp_device,
            )
        wp.synchronize_device()
        return wp.to_torch(ctx.y)

    @staticmethod
    def backward(ctx, adj_y):
        wp.synchronize_device()
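        # Zero the Warp-side gradients so that repeated backward() calls
        # (e.g. from gradcheck) do not accumulate stale adjoints.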
        ctx.x0.grad.zero_()
        ctx.x1.grad.zero_()
        ctx.x2.grad.zero_()
        ctx.y.grad = wp.from_torch(adj_y.contiguous())  # wrap the (contiguous) incoming adjoint as the output gradient
        ctx.tape.backward()
        wp.synchronize_device()
        return (
            wp.to_torch(ctx.tape.gradients[ctx.x0]),
            wp.to_torch(ctx.tape.gradients[ctx.x1]),
            wp.to_torch(ctx.tape.gradients[ctx.x2]),
        )


torch.manual_seed(42)
n = 3
# input data
x0_base = torch.rand(n, dtype=torch.float32, device=torch_device, requires_grad=True)
x1_base = torch.rand(n, dtype=torch.float32, device=torch_device, requires_grad=True)
x2_base = torch.rand(n, dtype=torch.float32, device=torch_device, requires_grad=True)
x0 = x0_base * 1.0
x1 = x1_base * 1.0
x2 = x2_base * 1.0

# Pure-PyTorch reference; can be passed to gradcheck() below in place of WPGradTape.apply for comparison.
def fun(x0, x1, x2):
    return x0**2 * x1**2 * x2**2


gradcheck(
    WPGradTape.apply,
    # fun,
    (
        x0_base,
        x1_base,
        x2_base,
    ),
    atol=1e-1,
    rtol=1e-1,
)
@shi-eric added the question and autodiff labels on Dec 5, 2023
@mmacklin (Collaborator)

Hi @johannespitz, in general you can use the adjoint=True flag to manually invoke the backward version of a kernel. This is what the wp.Tape object does, but it also takes care of a few other complications like tracking launches, zeroing gradients, etc.
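
For the op_kernel posted above, a minimal sketch of what a backward() using adjoint=True instead of wp.Tape could look like (assuming the adj_inputs/adj_outputs arguments of wp.launch as in the TestFunc example, and that the arrays' .grad buffers exist):

    @staticmethod
    def backward(ctx, adj_y):
        # seed the output adjoint and clear the input adjoints (they accumulate)
        ctx.y.grad = wp.from_torch(adj_y.contiguous())
        ctx.x0.grad.zero_()
        ctx.x1.grad.zero_()
        ctx.x2.grad.zero_()
        # launch the adjoint (backward) version of the kernel directly
        wp.launch(
            kernel=op_kernel,
            dim=[len(ctx.x1)],
            inputs=[ctx.x0, ctx.x1, ctx.x2],
            outputs=[ctx.y],
            adj_inputs=[ctx.x0.grad, ctx.x1.grad, ctx.x2.grad],
            adj_outputs=[ctx.y.grad],
            adjoint=True,
            device=wp_device,
        )
        return (
            wp.to_torch(ctx.x0.grad),
            wp.to_torch(ctx.x1.grad),
            wp.to_torch(ctx.x2.grad),
        )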

When you run a backward pass, gradients are accumulated (always added to the existing arrays). This is similar to PyTorch, but it means that you do indeed need to make sure they are zeroed somewhere between optimization steps.
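
A minimal sketch of that pattern between optimization steps, assuming wp.array inputs x0/x1/x2 and output y created with requires_grad=True, and using Tape.zero() to clear the recorded gradients (num_steps is a placeholder):

for step in range(num_steps):
    tape = wp.Tape()
    with tape:
        wp.launch(kernel=op_kernel, dim=n, inputs=[x0, x1, x2], outputs=[y], device=wp_device)

    y.grad.fill_(1.0)   # seed the output adjoint
    tape.backward()     # accumulates into x0.grad, x1.grad, x2.grad

    # ... apply the gradient update to x0, x1, x2 here ...

    tape.zero()         # zero the accumulated gradients before the next step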

I don't think you should need to call wp.synchronize() explicitly. @nvlukasz can you confirm?

@johannespitz (Author)

Thanks for the reply @mmacklin!
And I would be very interested to hear if/where we really need to call wp.synchronize(), @nvlukasz.

Regarding the accumulation of gradients: when we use wp.from_torch() directly, as in

ctx.joint_q = wp.from_torch(joint_q)

instead of creating a new PyTorch tensor with .clone(), the gradient of leaf nodes in the computation graph will be 2x the true gradient, even when we clear all gradients before the call.
That is because torch expects a torch.autograd.Function to return the gradient rather than write it directly into the buffer. torch therefore adds the returned gradient to the gradient that Warp already wrote into the buffer (for leaves in the computation graph). Note that for intermediate nodes this only works because usually (with retain_graph=False) their gradient buffers are not used at all.
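
A self-contained PyTorch-only sketch of that double counting (writing into the leaf's .grad inside backward() stands in for Warp sharing the gradient buffer via wp.from_torch without .clone()):

import torch

class DoubleCount(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.x = x  # keep a handle on the leaf tensor
        return x * 2.0

    @staticmethod
    def backward(ctx, adj_y):
        grad = adj_y * 2.0
        # write directly into the leaf's .grad buffer ...
        if ctx.x.grad is None:
            ctx.x.grad = torch.zeros_like(ctx.x)
        ctx.x.grad += grad
        # ... and also return the gradient, which autograd adds a second time
        return grad

x = torch.ones(3, requires_grad=True)
DoubleCount.apply(x).sum().backward()
print(x.grad)  # tensor([4., 4., 4.]) -- 2x the true gradient [2., 2., 2.]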

@nvlukasz (Contributor)

CUDA synchronization can be a little tricky, especially when launching work using multiple frameworks that use different scheduling mechanisms under the hood.

Short answer: If you're not explicitly creating and using custom CUDA streams in PyTorch or Warp, and both are targeting the same device, then synchronization is not necessary.

Long answer: By default, PyTorch uses the legacy default stream on each device. This stream is synchronous with respect to other blocking streams on the device, so no explicit synchronization is needed. Warp, by default, uses a blocking stream on each device, so Warp operations will automatically synchronize with PyTorch operations on the same device.

The picture changes if you start using custom streams in PyTorch. Those streams will not automatically synchronize with Warp streams, so manual synchronization will be required. This can be done using wp.synchronize(), wp.synchronize_device(), or wp.synchronize_stream(). These functions synchronize the host with outstanding GPU work, so launching new work will be done after prior work completes. We also support event-based device-side synchronization, which is generally faster because it doesn't sync the host and only ensures that the operations are synchronized on the device. This includes wp.wait_stream() and wp.wait_event(), as well as interop utilities like wp.stream_from_torch() and wp.stream_to_torch().
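
For example, a minimal sketch of the device-side variant with a custom PyTorch stream, using wp.stream_from_torch() and wp.wait_stream() (assuming both frameworks target the same CUDA device):

import torch
import warp as wp

wp.init()

torch_stream = torch.cuda.Stream()  # custom PyTorch stream

with torch.cuda.stream(torch_stream):
    x = torch.rand(1024, device="cuda")
    x = x * 2.0  # queued on the custom PyTorch stream

# Device-side sync: make Warp's current stream wait for the PyTorch stream
# before consuming x (no host synchronization needed).
wp.wait_stream(wp.stream_from_torch(torch_stream))

x_wp = wp.from_torch(x)
# ... safe to wp.launch(...) Warp kernels that read x_wp now ...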

Note that when capturing CUDA graphs using PyTorch, a non-default stream is used, so synchronization becomes important.

Things can get a little complicated with multi-stream usage and graph capture, so we're working on extended documentation in this area! But in your simple example, the explicit synchronization shouldn't be necessary.

@johannespitz (Author)

Thank you for the detailed answer regarding the synchronization! @nvlukasz

Though, could either of you comment on the accumulation of the gradients again? @mmacklin
Am I missing something, or is the example code incorrect at the moment?
(Note: optimization with 2x the gradient will likely work just fine, but anyone extending the code might run into problems.)
