In [2]:
import jax.numpy as jnp
import jax
import numpy as np
import equinox

In [6]:
# An empty shape tuple `()` produces a 0-d (scalar) array.
jnp.zeros(())

Array(0., dtype=float32)

In [25]:
from dataclasses import dataclass
import jax
import jax.numpy as jnp

@dataclass(frozen=True, eq=False)
class MyModule:
    """Callable that multiplies its input by a fixed ``scale`` array.

    Instances are passed directly to ``jax.jit`` (``jax.jit(module)(x)``),
    which requires the wrapped callable to be hashable. ``ndarray`` fields
    are unhashable, so an identity-based ``__hash__`` is defined explicitly.

    ``eq=False`` makes equality identity-based as well, consistent with that
    hash. With the dataclass-generated ``__eq__`` (the default ``eq=True``),
    comparing two instances compares their ``scale`` arrays. That raises
    ``ValueError`` (ambiguous truth value of an array) for distinct arrays.
    Two instances sharing one array object would also compare equal while
    hashing differently.

    Attributes:
        scale: Multiplier broadcast against ``x`` in ``__call__``. It is
            captured as a constant when the call is traced, so treat it as
            immutable; an in-place mutation after a jitted call has been
            compiled is likely not reflected.
    """

    scale: np.ndarray

    def __hash__(self):
        # Identity hash: a field-based hash is impossible with an ndarray field.
        return id(self)

    def __call__(self, x):
        """Return ``x * self.scale`` (standard broadcasting applies)."""
        return x * self.scale

# Example: jit the module object itself, then apply it to a JAX array.
x = jnp.array([1., 2., 3.])
module = MyModule(scale=np.array([2., 3., 4.]))
out = jax.jit(module)(x)

In [None]:
from functools import partial


@partial(jax.jit, static_argnames=("scale",))
def func(x: jnp.ndarray, scale: np.ndarray) -> jnp.ndarray:
    """Return ``x * scale``, treating ``scale`` as a jit-static argument.

    ``("scale")`` without a trailing comma is just the string ``"scale"``.
    The one-element tuple ``("scale",)`` states the intent explicitly.

    Static arguments must be hashable, so ``scale`` has to be something like a
    Python float or a tuple. Passing an ``np.ndarray`` (as the annotation
    suggests) is rejected by ``jax.jit`` because arrays are unhashable.
    """
    return x * scale

<class '__main__.MyModule'>


In [26]:
import functools
from typing import Callable


class custom_jit:
    """``jax.jit`` variant that closes over selected keyword arguments.

    ``jax.jit(..., static_argnames=...)`` requires static values to be
    hashable, which rules out arrays. Here, the keyword arguments named in
    ``static_kwargnames`` are removed from each call and captured in a
    closure, so they are baked into the trace as constants. All other
    arguments are traced as usual.

    A separate jitted function is built and cached for every distinct
    combination of static values, keyed on object identity (``id``). Passing
    a different object for a static kwarg therefore triggers a new trace
    instead of silently reusing the values from the first call.

    Caveats:
        * Identity keys mean a new object triggers a retrace even if its
          contents are equal to an earlier one. An in-place mutation of an
          already-seen object is not detected.
        * Cache entries hold references to the static objects. This
          guarantees their ``id`` cannot be reused by another object, but
          memory grows with the number of distinct static objects passed.
        * Only keyword arguments are considered static. A value passed
          positionally is traced like any other argument.

    Args:
        func: Function to wrap.
        static_kwargnames: Iterable of keyword-argument names to close over.
            Names absent from a particular call are skipped, so ``func``'s
            own defaults apply.
    """

    def __init__(self, func, static_kwargnames=None):
        if static_kwargnames is None:
            static_kwargnames = []
        self.static_kwargnames = static_kwargnames
        self.func = func
        # Jitted function used by the most recent call (None before the first call).
        self.jitted_func: Callable | None = None
        # ((name, id(value)), ...) -> (static kwargs dict, jitted function).
        # Storing the dict keeps the static objects alive, so their ids stay unique.
        self._jit_cache: dict = {}
        functools.update_wrapper(self, func)

    def __call__(self, *args, **kwargs):
        # kwargs is a fresh per-call dict, so popping from it is safe.
        static_kwargs = {
            name: kwargs.pop(name)
            for name in self.static_kwargnames
            if name in kwargs
        }
        key = tuple((name, id(value)) for name, value in static_kwargs.items())
        entry = self._jit_cache.get(key)
        if entry is None:
            entry = (static_kwargs, jax.jit(self._bind_static(static_kwargs)))
            self._jit_cache[key] = entry
        self.jitted_func = entry[1]
        return self.jitted_func(*args, **kwargs)

    def _bind_static(self, static_kwargs):
        """Return a function calling ``self.func`` with ``static_kwargs`` appended."""
        func = self.func

        def _func(*args, **kwargs):
            return func(*args, **kwargs, **static_kwargs)

        return _func

In [55]:
def func(x: jnp.ndarray, scale: np.ndarray) -> jnp.ndarray:
    """Multiply ``x`` by ``scale``; ``scale`` is closed over by ``custom_jit``.

    The ``print`` is a plain Python side effect. Under jit it runs while JAX
    traces the function, not on every call, so it acts as a trace marker.
    """
    print("jitting")
    product = x * scale
    return product


# Same as decorating with functools.partial(custom_jit, static_kwargnames=("scale",)).
# NOTE: this rebinds the name `func` that an earlier cell also defines.
func = custom_jit(func, static_kwargnames=("scale",))

In [None]:
# `arr_2` goes in as the closed-over ("static") `scale` kwarg of the custom_jit-wrapped func.
arr_1, arr_2 = np.array([3., 2., 3.]), jnp.array([7., 2., 3.])

func(arr_1, scale=arr_2)

np.ndarray

Array([21.,  4.,  9.], dtype=float32)