Support static arguments for any function transformation #16071

Open
danijar opened this issue May 19, 2023 · 2 comments
Labels
enhancement New feature or request

Comments

danijar commented May 19, 2023

Following up on #15504, I ended up needing a more general solution because I kept running into the same issue with other function transformations, many of which don't support static arguments. My solution ended up being this snippet, which adds static argument support to any transformation. A meta transformation! :D

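# Assumption: `bind` in the original snippet refers to functools.partial.
from functools import partial as bind

from jax import jit
from jax.experimental.checkify import checkify

def static_support(transform):
  def new_transform(fun, *args, static=(), **kwargs):
    assert isinstance(static, (list, tuple)), static
    cache = {}  # One transformed function per combination of static values.
    def new_function(*args2, **kwargs2):
      # Static arguments must be passed by keyword; everything else is dynamic.
      sta = {k: v for k, v in kwargs2.items() if k in static}
      dyn = {k: v for k, v in kwargs2.items() if k not in static}
      # Use the tuple itself as the key so that two different static values
      # with colliding hashes cannot share a cache entry.
      key = tuple(sta.get(n, '_default') for n in static)
      if key not in cache:
        specialized = bind(fun, **sta)
        specialized.__name__ = fun.__name__
        cache[key] = transform(specialized, *args, **kwargs)
      return cache[key](*args2, **dyn)
    return new_function
  return new_transform

fun = static_support(jit)(fun, static=['static_arg'])
fun = static_support(checkify)(fun, static=['static_arg'])
# ...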

A simplifying design choice here is that static arguments must be passed as keyword arguments and cannot be passed positionally, which I find much easier to specify. When the transformed function is already the output of some other transformation, working out argument positions can be painful.
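To make the caching behaviour concrete, here is a minimal sketch that runs without JAX. It repeats the helper with `functools.partial` standing in for `bind` (an assumption), and uses a hypothetical toy transformation `logged` in place of `jit`; `applied`, `logged`, `power`, and `fast_power` are names made up for this illustration.

```python
from functools import partial

def static_support(transform):
  # Same structure as the helper above; `partial` is assumed to be `bind`.
  def new_transform(fun, *args, static=(), **kwargs):
    cache = {}
    def new_function(*args2, **kwargs2):
      sta = {k: v for k, v in kwargs2.items() if k in static}
      dyn = {k: v for k, v in kwargs2.items() if k not in static}
      key = tuple(sta.get(n, '_default') for n in static)
      if key not in cache:
        specialized = partial(fun, **sta)
        specialized.__name__ = fun.__name__
        cache[key] = transform(specialized, *args, **kwargs)
      return cache[key](*args2, **dyn)
    return new_function
  return new_transform

applied = []  # One entry each time the toy transformation is applied.

def logged(fun):
  # Hypothetical stand-in for a real transformation such as jit.
  applied.append(fun.__name__)
  def wrapper(*args, **kwargs):
    return fun(*args, **kwargs)
  return wrapper

def power(x, exponent=2):
  return x ** exponent

fast_power = static_support(logged)(power, static=['exponent'])

a = fast_power(3, exponent=2)  # Applies `logged` for exponent=2.
b = fast_power(4, exponent=2)  # Reuses the cached function for exponent=2.
c = fast_power(2, exponent=3)  # New static value, so `logged` is applied again.
d = fast_power(5)              # Static argument omitted: cached under '_default'.
```

The transformation is applied once per distinct combination of static values, and an omitted static argument falls back to the function's own default. Static values must be hashable, since they form the cache key.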

I'm wondering whether including a helper like the one above would be a useful addition, simplification, and unification for JAX? It might simplify the existing JAX transformations, provide a unified API to users, and allow users to use third-party transformations that may not support static arguments. Longer term, users could even be asked to apply the meta transformation themselves instead of repeating the API across the provided transformations.

danijar added the enhancement label May 19, 2023
danijar changed the title from "Functionality to support static arguments in any function transformation" to "Support static arguments for any function transformation" May 19, 2023
patrick-kidger (Collaborator) commented

You might find Equinox interesting, which does something similar: it has equinox.filter_{jit, grad, ...}, which automatically determines dynamic vs static arguments based on their type. After all, it mostly only makes sense to JIT wrt arrays, to grad wrt floating-point arrays, etc.
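As a rough illustration of the type-based idea only (this is not Equinox's implementation or API), the sketch below treats keyword arguments that are arrays as dynamic and everything else as static, then caches one jitted function per set of static values. `is_array`, `filter_jit_sketch`, and `scale` are hypothetical names, and the sketch assumes keyword-only calls with hashable static values.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

def is_array(x):
  # Assumption: "dynamic" means a JAX or NumPy array; anything else is static.
  return isinstance(x, (jax.Array, np.ndarray))

def filter_jit_sketch(fun):
  cache = {}  # One jitted function per combination of static values.
  def wrapper(**kwargs):
    dyn = {k: v for k, v in kwargs.items() if is_array(v)}
    sta = {k: v for k, v in kwargs.items() if not is_array(v)}
    key = tuple(sorted(sta.items()))
    if key not in cache:
      cache[key] = jax.jit(functools.partial(fun, **sta))
    return cache[key](**dyn)
  return wrapper

def scale(x, factor, mode):
  # `mode` is a Python string, so it can only ever be a static argument.
  if mode == 'double':
    return 2 * factor * x
  return factor * x

scale_filtered = filter_jit_sketch(scale)
doubled = scale_filtered(x=jnp.arange(3.0), factor=2.0, mode='double')
single = scale_filtered(x=jnp.arange(3.0), factor=2.0, mode='single')
```

Here the caller never lists static argument names: `x` is traced because it is an array, while `factor` and `mode` are baked into the jitted function because they are not.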

danijar commented May 23, 2023

Thanks Patrick, I've heard of Equinox and think it has a lot of interesting ideas. I use Ninjax.
