-
Notifications
You must be signed in to change notification settings - Fork 16
Closed
Description
I am not sure how to define a JIT decorator that supports multiple libraries. Maybe something like this?
from collections.abc import Mapping, Sequence
from functools import cache, wraps
from types import ModuleType
from typing import Any, Callable, ParamSpec, TypeVar
from array_api_compat import array_namespace
from frozendict import frozendict
P = ParamSpec("P")
T = TypeVar("T")
def get_jit_decorator(
    module: ModuleType,
    /,
    *,
    args: Mapping[ModuleType, Sequence[Any]] | None = None,
    kwargs: Mapping[ModuleType, Mapping[str, Any]] | None = None,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Return a decorator that JIT-compiles a function for array namespace *module*.

    The compiler is chosen from the root package of ``module.__name__`` (after
    stripping an ``array_api_compat.`` prefix):

    * ``numpy``  -> ``numba.jit``
    * ``torch``  -> ``torch.jit.script``
    * ``jax``    -> ``jax.jit`` (covers ``jax.numpy`` and other ``jax.*`` namespaces)
    * otherwise  -> ``module.jit`` if the namespace defines one, else the identity.

    Args:
        module: The array namespace, e.g. as returned by ``array_namespace``.
        args: Optional per-namespace extra positional arguments, passed to the
            compiler after the function being compiled.
        kwargs: Optional per-namespace extra keyword arguments, passed to the
            compiler.

    Returns:
        A decorator mapping a function to its compiled version.
    """
    args = args or {}
    kwargs = kwargs or {}

    # Compare the package root rather than using a substring test:
    # "numpy" in "jax.numpy" is True, which would route JAX arrays to numba.
    root = module.__name__.removeprefix("array_api_compat.").split(".", 1)[0]
    if root == "numpy":
        import numba

        compiler = numba.jit
    elif root == "torch":
        import torch

        compiler = torch.jit.script
    elif root == "jax":
        import jax

        compiler = jax.jit
    else:
        # Unknown namespace: use its own `jit` if it has one, else no-op.
        compiler = getattr(module, "jit", lambda x: x)

    @wraps(compiler)
    def inner(f: Callable[P, T]) -> Callable[P, T]:
        return compiler(f, *args.get(module, ()), **kwargs.get(module, {}))

    return inner  # type: ignore[return-value]
def jit(
    f: Callable[P, T],
    /,
    *,
    args: Mapping[ModuleType, Sequence[Any]] | None = None,
    kwargs: Mapping[ModuleType, Mapping[str, Any]] | None = None,
) -> Callable[P, T]:
    """Wrap *f* so it is JIT-compiled for the array library of its inputs.

    On each call, the array namespace is inferred from the positional arguments
    that are array-API arrays. The first call with a given namespace compiles
    *f* via ``get_jit_decorator``; later calls with that namespace reuse the
    compiled function. If no positional argument is an array, *f* runs
    uncompiled.

    Args:
        f: The function to compile lazily.
        args: Optional per-namespace extra positional arguments for the compiler.
        kwargs: Optional per-namespace extra keyword arguments for the compiler.

    Returns:
        A wrapper with the same signature as *f*.

    Raises:
        TypeError: propagated from ``array_namespace`` when the array arguments
            belong to incompatible namespaces.
    """
    # Local import: only `array_namespace` is imported at module level.
    from array_api_compat import is_array_api_obj

    # Snapshot the options so later mutation by the caller has no effect.
    jit_args = frozendict(args or {})
    jit_kwargs = frozendict(kwargs or {})
    # Compile once per namespace and reuse the result. Re-applying the compiler
    # on every call would rebuild (and, for numba/torch, recompile) f each time.
    compiled: dict[ModuleType, Callable[P, T]] = {}

    @wraps(f)
    def inner(*args_inner: P.args, **kwargs_inner: P.kwargs) -> T:
        # Ignore non-array arguments (Python scalars, None, ...) when inferring
        # the namespace, instead of matching on a TypeError message.
        arrays = [a for a in args_inner if is_array_api_obj(a)]
        if not arrays:
            # Nothing to specialise on: call f directly.
            return f(*args_inner, **kwargs_inner)
        xp = array_namespace(*arrays)
        fn = compiled.get(xp)
        if fn is None:
            fn = get_jit_decorator(xp, args=jit_args, kwargs=jit_kwargs)(f)
            compiled[xp] = fn
        return fn(*args_inner, **kwargs_inner)

    return inner
Metadata
Assignees
Labels
No labels