RFC: xpx.compile as an abstraction for torch.compile and jax.jit decorators? #523

@ogrisel

Has anybody investigated an array-agnostic way to leverage the torch.compile and jax.jit decorators in array-api-extra?

This might be useful for array API consuming libraries such as SciPy or scikit-learn. For array API namespaces without JIT compiler support, xpx.compile would simply act as a no-op decorator. For torch and JAX, dispatching to an actual JIT compiler could unlock significant speed-ups and memory usage improvements.

However, those decorators accept many kwargs, with seemingly very little overlap between the two.

Maybe xpx.compile could accept arbitrary kwargs scoped by the name of the underlying namespace, without attempting to map the compilers' semantics onto a common set of options:

@xpx.compile(
   torch=dict(options={"triton.cudagraphs": True}, fullgraph=True),
   jax=dict(static_argnames=['n']),
)
def some_array_function(array, n):
   ...

I don't have enough experience to tell whether calling those decorators with their default arguments is useful in practice.
