In [2]:
import jax
import jax.numpy as jnp

In [42]:
class compiledmethod:

    def __init__(self, wrapped_method):
        # self : compiledmethod
        # wrapped_method : the class method being decorated
        print(f"__init__ {wrapped_method=} {self=}")
        self.wrapped_method = wrapped_method
        self.docstring = wrapped_method.__doc__

    def __set_name__(self, owner, name):
        # self : compiledmethod
        # owner : parent class that will have `self` as a member
        # name : the name of the attribute that `self` will be
        print(f"set_name {owner=} {name=}")
        self.public_name = name
        self.private_name = '_precompiled_' + name
        if not hasattr(owner, '_compiledmethods'):
            owner._compiledmethods = []
        owner._compiledmethods.append(name)

    def __get__(self, obj, objtype=None):
        # self : compiledmethod
        # obj : instance of parent class that has `self` as a member
        # objtype : class of `obj`
        print(f"get {self.public_name=} {self=} {obj=}")
        result = getattr(obj, self.private_name, None)
        if result is None:

            @jax.jit
            def func(*args, **kwargs):
                return self.wrapped_method(obj, *args, **kwargs)

            print(f"recompile {self.public_name=} {self=} {obj=}")
            result = func
            result.__doc__ = self.docstring
            setattr(obj, self.private_name, result)
        return result

    def __set__(self, obj, value):
        # self : compiledmethod
        # obj : instance of parent class that has `self` as a member
        # value : the new value that is trying to be assigned
        raise AttributeError(f"can't set {self.public_name}")

class Thing:

    def __init__(self):
        self.a = jnp.arange(10)

    @compiledmethod
    def multy(self, y):
        """
        Do a thing.

        Parameters
        ----------
        y : array-like or scalar
        """
        sa = jnp.array(self.a)
        aa = jnp.zeros(12)
        aa = aa.at[:10].add(self.a)
        return aa * y


__init__ wrapped_method=<function Thing.multy at 0x136755040> self=<__main__.compiledmethod object at 0x1367512b0>
set_name owner=<class '__main__.Thing'> name='multy'


In [43]:
t = Thing()

In [44]:
t.multy(100)

get self.public_name='multy' self=<__main__.compiledmethod object at 0x1367512b0> obj=<__main__.Thing object at 0x130f93fa0>
recompile self.public_name='multy' self=<__main__.compiledmethod object at 0x1367512b0> obj=<__main__.Thing object at 0x130f93fa0>


DeviceArray([   0.,  100.,  200.,  300.,  400.,  500.,  600.,  700.,  800.,  900.,    0.,    0.], dtype=float32)

In [45]:
t.multy(100)

get self.public_name='multy' self=<__main__.compiledmethod object at 0x1367512b0> obj=<__main__.Thing object at 0x130f93fa0>


DeviceArray([   0.,  100.,  200.,  300.,  400.,  500.,  600.,  700.,  800.,  900.,    0.,    0.], dtype=float32)

In [46]:
t.multy(10)

get self.public_name='multy' self=<__main__.compiledmethod object at 0x1367512b0> obj=<__main__.Thing object at 0x130f93fa0>


DeviceArray([  0.,  10.,  20.,  30.,  40.,  50.,  60.,  70.,  80.,  90.,   0.,   0.], dtype=float32)

In [47]:
t.multy(jnp.arange(12)+1)

get self.public_name='multy' self=<__main__.compiledmethod object at 0x1367512b0> obj=<__main__.Thing object at 0x130f93fa0>


DeviceArray([  0.,   2.,   6.,  12.,  20.,  30.,  42.,  56.,  72.,  90.,   0.,   0.], dtype=float32)

In [48]:
t.multy?

get self.public_name='multy' self=<__main__.compiledmethod object at 0x1367512b0> obj=None
recompile self.public_name='multy' self=<__main__.compiledmethod object at 0x1367512b0> obj=None
get self.public_name='multy' self=<__main__.compiledmethod object at 0x1367512b0> obj=<__main__.Thing object at 0x130f93fa0>
