
grad returns complex conjugate of the gradient #9110

@campsd

Description


I have noticed that automatic differentiation with JAX's grad function returns the complex conjugate of the gradient compared to PyTorch's autograd. This is illustrated by running the example from the documentation:

import jax.numpy as jnp
from jax import grad

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return x**2 + y**2

z = 3. + 4j
grad(f)(z)
# >>> DeviceArray(6.-8.j, dtype=complex64)

Running the same test in PyTorch gives:

import torch
z = torch.tensor([3. + 4j], requires_grad=True)
f = torch.real(z)**2 + torch.imag(z)**2
f.backward()
z.grad
# >>> tensor([6.+8.j])

I couldn't determine, based on my initial reading of the documentation, whether this is actually the intended behavior. However, to get a gradient-descent optimizer working for complex data in FunFact with JAX as the computational backend, we have to explicitly take the conjugate of the gradient (see this PR, and the sketch below). At the very least, this behavior seems somewhat counterintuitive.
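To make the workaround concrete, here is a minimal, self-contained sketch (not the FunFact code; the step and lr names are just illustrative) of a gradient-descent step on a real-valued loss of a complex parameter, conjugating what jax.grad returns so the iterate actually descends:

import jax.numpy as jnp
from jax import grad

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return x**2 + y**2  # |z|**2, minimized at z = 0

df = grad(f)

# Sanity check: conjugating JAX's output reproduces the PyTorch value.
jnp.conj(df(3. + 4j))
# >>> DeviceArray(6.+8.j, dtype=complex64)

def step(z, lr=0.1):
  # Descend along the conjugate of what jax.grad returns; using df(z)
  # directly would move in the wrong direction along the imaginary axis.
  return z - lr * jnp.conj(df(z))

z = jnp.asarray(3. + 4j)
for _ in range(50):
  z = step(z)
# z is now close to 0+0j, as expected for |z|**2

With the conjugate applied, the iterates converge to z = 0 for this toy loss, i.e. the update behaves like the PyTorch-convention gradient.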

I tested this with jax[cpu] version 0.2.26.
