Description
I have noticed that automatic differentiation with JAX's grad function returns the complex conjugate of the gradient as compared to autograd in PyTorch. This is illustrated by running the example on the webpage:
import jax.numpy as jnp
from jax import grad

def f(z):
    x, y = jnp.real(z), jnp.imag(z)
    return x**2 + y**2

z = 3. + 4j
grad(f)(z)
# >>> DeviceArray(6.-8.j, dtype=complex64)

Running the same test in PyTorch gives:
import torch
z = torch.tensor([3. + 4j], requires_grad=True)
f = torch.real(z)**2 + torch.imag(z)**2
f.backward()
z.grad
# >>> tensor([6.+8.j])

From my initial reading of the documentation, I couldn't determine whether this is the intended behavior. However, to get a gradient descent optimizer working for complex data in FunFact with JAX as the computational backend, we have to explicitly take the conjugate of the gradient (see this PR). At the very least, this behavior seems counterintuitive.
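As a quick check (my own snippet, not from the linked PR), explicitly conjugating the value returned by jax.grad reproduces the PyTorch result above:

import jax.numpy as jnp
from jax import grad

def f(z):
    # same f as in the example above
    x, y = jnp.real(z), jnp.imag(z)
    return x**2 + y**2

print(jnp.conj(grad(f)(3. + 4j)))  # (6+8j), matching PyTorch's z.grad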
I tested this with jax[cpu] version 0.2.26.
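For reference, here is a minimal sketch of the workaround described above: conjugate the output of jax.grad before applying the gradient-descent update. The quadratic loss, learning rate, and step count are just illustrative choices for this sketch, not our actual FunFact code:

import jax.numpy as jnp
from jax import grad

def loss(w):
    # real-valued loss of a complex parameter, minimized at w = 1 + 2j
    target = 1.0 + 2.0j
    return jnp.real((w - target) * jnp.conj(w - target))

w = 0.0 + 0.0j
lr = 0.1
for _ in range(200):
    # jax.grad returns the conjugate of the (x, y) gradient written as a
    # complex number, so conjugate it to get the steepest-descent direction
    w = w - lr * jnp.conj(grad(loss)(w))

print(w)  # converges to (1+2j)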