In [1]:
import jax.numpy as jnp
import numpy as np
from functools import partial, lru_cache
from jax import jit, tree_util, block_until_ready
import time
from jax.scipy.special import factorial

In [8]:
@tree_util.register_pytree_node_class
class TestClass:
    def __init__(self, a):
        self.a = a
        self.b = a @ a @ a @ a @ a
    
    @jit
    def b_eval(self):
        a = self.a
        return a @ a @ a @ a @ a

    @lru_cache(maxsize=None)
    @jit
    def use_b_non_cache(self):
        return self.b_eval() + 1

    def use_b_cache(self):
        return self.b + 1

    ''' JAX prereqs '''
    
    def tree_flatten(self):
        children = (
            self.a,
            self.b,
        )
        aux_data = {}
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(a=children[0])

In [9]:
a_test = np.random.random((5000, 5000))
t = TestClass(a_test)

In [10]:
# Timing repeated b_eval()
for i in range(5):
    time1 = time.time()
    b_test = block_until_ready(t.use_b_non_cache())
    time2 = time.time()
    print('time', time2-time1, 'b[0,0]', b_test[0, 0])
    

time 4.4910888671875 b[0,0] 19741354000000.0
time 1.430511474609375e-05 b[0,0] 19741354000000.0
time 8.821487426757812e-06 b[0,0] 19741354000000.0
time 7.152557373046875e-06 b[0,0] 19741354000000.0
time 6.4373016357421875e-06 b[0,0] 19741354000000.0


In [17]:
# Testing repeated t.b (cache reading)
for i in range(5):
    time1 = time.time()
    b_test = block_until_ready(t.use_b_cache())
    time2 = time.time()
    print('time', time2-time1, 'b[0,0]', b_test[0, 0])
    

time 0.36299991607666016 b[0,0] 19354532862541.535
time 0.16297221183776855 b[0,0] 19354532862541.535
time 0.25149011611938477 b[0,0] 19354532862541.535
time 0.3062417507171631 b[0,0] 19354532862541.535
time 0.32625699043273926 b[0,0] 19354532862541.535


In [None]:
# Testing repeated use of t.b