
Consider using autograd.make_vjp() instead of ceviche.jacobian() #1

Closed
ianwilliamson opened this issue Feb 5, 2023 · 4 comments
Labels
enhancement New feature or request

Comments

@ianwilliamson

Hey there, cool package! I wanted to suggest an alternative implementation that may be more efficient.

It looks like javiche is currently using ceviche's jacobian() method, which (I think) may be less efficient if @javiche.jaxit is ever applied to a function whose output is not scalar-valued. The reason is that ceviche's jacobian() method loops over the output basis vectors to construct the Jacobian, and you typically don't need the explicit Jacobian when calculating VJPs. An example of where this would matter is if you just applied @jaxit to a function that returns the field distribution, but performed the loss function calculation in terms of the field distribution in JAX, calling jax.grad or jax.value_and_grad on the combination.
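To make the cost difference concrete, here is a small JAX-only toy (a made-up `field` function, not ceviche's solver): `jax.vjp` applies the transposed Jacobian to a cotangent vector in a single backward pass, whereas materializing the full Jacobian with `jax.jacrev` effectively requires one VJP per output basis vector.

```python
import jax
import jax.numpy as jnp


def field(x):
  # Toy stand-in for a simulation returning a non-scalar "field": R^3 -> R^9.
  return jnp.sin(jnp.outer(x, x)).ravel()


x = jnp.arange(3.0)
g = jnp.ones(9)  # cotangent vector coming from a downstream loss

# Full Jacobian: a 9 x 3 matrix, built from one VJP per output basis vector.
J = jax.jacrev(field)(x)

# Single VJP: the same g @ J result, computed in one backward pass
# without ever forming J explicitly.
y, vjp_fn = jax.vjp(field, x)
(gx,) = vjp_fn(g)

assert jnp.allclose(gx, g @ J)
```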

Below is a sketch of a more direct approach that maps autograd's make_vjp() function to JAX's VJP mechanism.

Given an autograd function, f_ag(*args) -> np.ndarray, we can wrap it into a function, f(*args) -> jnp.ndarray. This is not ceviche-specific; it can be used for any autograd function with multiple inputs (*args) and a single array output, though it could be generalized to support multiple array outputs as well.

import jax
import jax.numpy as jnp
import numpy as np
import autograd


def as_numpy(x):
  # Convert any JAX arrays in a pytree to NumPy arrays for autograd.
  def as_numpy_map(a):
    if isinstance(a, jnp.ndarray):
      return np.asarray(a)
    else:
      return a
  return jax.tree_util.tree_map(as_numpy_map, x)


def as_jax(x):
  # Convert any NumPy arrays in a pytree back to JAX arrays.
  def as_jax_map(a):
    if isinstance(a, np.ndarray):
      return jnp.asarray(a)
    else:
      return a
  return jax.tree_util.tree_map(as_jax_map, x)


@jax.custom_vjp
def f(*args):
  # Primal: call the autograd function on NumPy inputs, return a JAX output.
  return as_jax(f_ag(*as_numpy(args)))


def f_fwd(*args):
  args = as_numpy(args)
  # Differentiate with respect to every positional argument.
  argnums = tuple(i for i, _ in enumerate(args))

  def f_ag_tupled(*args):
    # autograd needs its own tuple type to trace tuple-valued outputs.
    ans = f_ag(*args)
    if isinstance(ans, tuple):
      return autograd.builtins.tuple(ans)
    else:
      return ans

  # make_vjp returns the VJP closure and the primal output in one pass.
  vjp_f, ans = autograd.make_vjp(f_ag_tupled, argnums)(*args)
  # Partial makes the closure a pytree leaf that JAX can carry as a residual.
  return as_jax(ans), jax.tree_util.Partial(vjp_f)


def f_rev(vjp_f, g):
  # Backward: pass the JAX cotangent through autograd's VJP and convert back.
  g = as_numpy(g)
  return as_jax(vjp_f(g))


f.defvjp(f_fwd, f_rev)

I was thinking about adding something like this to ceviche_challenges, but had not gotten around to it.

@jan-david-fischbach
Owner

Thanks for the suggestion!!
I'll have a deeper look soon.
Do you want to open a pull request for proper attribution?
Or would you rather I integrate it into jaxit directly?
Regards, JD

@ianwilliamson
Author

I do not really have time to implement this and open a pull request right now. You're welcome to adapt the sketch above if you have time.

@jan-david-fischbach
Owner

Great, I'll do that!

@jan-david-fischbach
Owner

As the proposed change has been merged to master, I'll close the issue.
