Can we cache inner functions for speed improvement? #4
This script does a nice job highlighting the issue, and is one I'm working with to find a potential solution. I modified the `Any` calls, arriving at the following:

```python
import numpy as onp
import jax.numpy as np
from tfc import utfc
from tfc.utils import egrad
from jax import jit

mine = utfc(3,0,5,basis="CP",xf=1.)
H = mine.H
x = mine.x
xi = np.zeros(H(x).shape[1])
dH = egrad(H)

f = lambda xi: np.dot(H(x),xi)
df = lambda xi: np.dot(dH(x),xi)

fj = jit(f,static_argnums=(0,))
dfj = jit(df)

# Unsafe hack: an ndarray subclass that is hashable, so it can be used as a
# static argument to jit.
class hasharray(onp.ndarray):
    def __hash__(self):
        return hash(self.tobytes())
    def __eq__(self,other):
        return other.__hash__() == self.__hash__()

x2 = onp.array(x.tolist()).view(hasharray)
#x2 = hasharray(x)

f2 = lambda x2,xi: np.dot(H(x2),xi)
jf2 = jit(f2, static_argnums=(0,))
```

At the end of this script, I created a hack (this hack is generally unsafe, so do not use it outside of this script) to make arrays hashable, so that `x2` can be passed as a static argument. However, JAX still executes the `H(x2)` operations every time the jitted function is evaluated, rather than caching the result.
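To make the behavior concrete (a quick illustrative check of my own, assuming the script above has been run and a reasonably recent JAX): printing the jaxpr shows that the operations performed inside `H` are still staged into the traced computation even though `x2` is static, so they are re-executed on every call instead of being replaced by a cached value.

```python
from jax import make_jaxpr

# Even with x2 marked static, the operations that make up H(x2) appear in the
# traced jaxpr, so they run again on every call of the jitted function.
print(make_jaxpr(f2, static_argnums=(0,))(x2, xi))
```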
Tried using this JAX comment. This allows the expensive pieces to be evaluated at compile time when you can edit the function body by hand, but I think we need a way to make this happen automatically.
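A minimal sketch of that compile-time-evaluation approach (this is my reading of the suggestion, not code from the linked comment; it assumes `H`, `x`, and `xi` from the script above):

```python
import jax.numpy as np
from jax import jit, ensure_compile_time_eval

@jit
def f3(xi):
    # Inside this context manager, operations on concrete closed-over values are
    # evaluated eagerly at trace time, so H(x) is baked in as a constant.
    with ensure_compile_time_eval():
        Hx = H(x)
    return np.dot(Hx, xi)

f3(xi)  # H(x) is evaluated once, during tracing, not on every call
```

This works when the function body can be edited directly, but not when the `H(x)` calls are hidden inside `egrad`-generated code, which is the case the rest of this thread is after.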
Opened up a discussion for this on the JAX GitHub here.
The above discussion proved ultra useful! Posting this here for now. It is a really rough draft. I plan to make this code cleaner and the API simpler before adding it as a tool to TFC. Ultimately, it would be nice to offer this as a standalone utility, and to offer it as an option within TFC itself. Rough draft follows:

```python
from functools import partial
import jax.numpy as np
from tfc import utfc
from tfc.utils import egrad
from jax import jit, ensure_compile_time_eval
from jax.interpreters import partial_eval as pe
from jax import linear_util as lu
from jax._src.api_util import flatten_fun, tree_flatten
from jax.interpreters import ad
from jax.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
                      raise_to_shaped)
from jax import core
from copy import copy

mine = utfc(3,0,5,basis="CP",xf=1.)
H = mine.H
x = mine.x
xi = np.ones(H(x).shape[1])

f = lambda x,xi: np.dot(H(x),xi)
df = egrad(f)

def new_jit(*args, remove_args=(), remove_arg_nums=(), **kwargs):
    dark = list(args)
    def wrapper(f):
        if len(remove_args) > 0:
            assert len(remove_args) == len(remove_arg_nums)

            # Removed args become "known" partial values; the rest stay "unknown"
            def get_arg(a, known):
                if known:
                    return pe.PartialVal.known(a)
                else:
                    return pe.PartialVal.unknown(get_aval(a).at_least_vspace())

            # Pair each removed value with the argument position it replaces
            for i, k in enumerate(remove_arg_nums):
                dark[k] = remove_args[i]
            part_args = tuple(get_arg(a, k in remove_arg_nums) for k, a in enumerate(dark))

            # Partially evaluate f: everything that depends only on the known
            # args is computed now and stored in const
            wrap = lu.wrap_init(f)
            _, in_tree = tree_flatten((args, {}))
            wrap_flat, out_tree = flatten_fun(wrap, in_tree)
            jaxpr, _, const = pe.trace_to_jaxpr(wrap_flat, part_args)

            f_removed = lambda *args: core.eval_jaxpr(jaxpr, const, *args)
            return jit(f_removed)
        else:
            return jit(f)
    return wrapper

def L(x, xi):
    return x*df(x, xi)

Lf = new_jit(x, xi, remove_args=(x,), remove_arg_nums=(0,))(L)
```
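To spell out what the draft does (my summary): the arguments passed through `remove_args` are wrapped as `pe.PartialVal.known`, so during tracing every operation that depends only on them, including the `H(x)` and `egrad`-generated pieces inside `df`, is evaluated immediately and stored in `const`; only the operations touching the remaining arguments survive into the jaxpr. Assuming the internal JAX APIs used above match the installed JAX version, usage looks like:

```python
# x has been partially evaluated away, so the returned function only needs xi.
out = Lf(xi)
print(out)
```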
This is more representative of the features I want. There are still some kinks I need to work out with the output shape:

```python
def jit2(*args, constant_arg_nums=(), **kwargs):
    # Reorder to put knowns first, then unknowns
    order = [k for k in range(len(args))]
    for k in constant_arg_nums:
        order.insert(0, order.pop(k))
    dark = tuple(args[k] for k in order)

    # Store the number of removed args for later
    num_args_remove = len(constant_arg_nums)

    def wrapper(f_orig):
        if len(constant_arg_nums) > 0:
            # Reordering args so the ones to remove are given first.
            # This will allow us to return a function that has completely removed those args.
            # Moreover, we do it here so this reordering will be optimized by the compiler.
            def f(*args):
                new_args = tuple(args[k] for k in order)
                return f_orig(*new_args)

            # Create the partial args needed by pe.trace_to_jaxpr
            def get_arg(a, unknown):
                if unknown:
                    return pe.PartialVal.unknown(get_aval(a).at_least_vspace())
                else:
                    return pe.PartialVal.known(a)
            part_args = tuple(get_arg(a, k >= num_args_remove) for k, a in enumerate(dark))

            # Create jaxpr
            wrap = lu.wrap_init(f)
            _, in_tree = tree_flatten((args, {}))
            wrap_flat, out_tree = flatten_fun(wrap, in_tree)
            jaxpr, _, const = pe.trace_to_jaxpr(wrap_flat, part_args)

            f_removed = lambda *args: core.eval_jaxpr(jaxpr, const, *dark[0:num_args_remove], *args)
            return jit(f_removed)
        else:
            return jit(f_orig)
    return wrapper

def L(xi, x):
    return H(x).T + df(x, xi)

Lf = jit2(xi, x, constant_arg_nums=(1,))(L)
Lf(xi)
print("If working, no words beyond this point.")
Lf(xi)
```
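Once the output-shape kinks are resolved, one way to confirm the constant folding (a hedged check on my part, again assuming a JAX version compatible with the internals used above) is to inspect the jaxpr of the returned function: the operations that built `H(x)` should no longer appear, only constants derived from them.

```python
from jax import make_jaxpr

# The jaxpr of Lf should only contain operations that depend on xi; everything
# derived from the constant x (including H(x)) is folded into constants.
print(make_jaxpr(Lf)(xi))
```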
This is the functionality defined in #4. This capability has been added.
Closing as this feature has been added.
In TFC, we often have functions like `f = lambda x,xi: np.dot(H(x),xi)`, where the output we care about only depends on `xi` and we want to use the same `x` value every time. The function `H(x)` represents a potentially expensive function here. It would be nice to cache the result of `H(x)` and use that, rather than having to call the function each time. Of course, for a simple case like this we can precompute `H(x)` ourselves and close over the result (see the sketch below), but in general these `H(x)` bits may be generated by `egrad` transformations, and so we do not have an explicit view of them. Is there a way we can cache everything except the operations that explicitly depend on `xi`?
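Here is a minimal sketch of the simple-case workaround mentioned above (the `H` below is a hypothetical stand-in for the expensive basis evaluation, since the real one comes from `utfc`):

```python
import jax.numpy as np
from jax import jit

# Hypothetical stand-in for an expensive basis evaluation like tfc's H(x).
def H(x):
    return np.stack([x**k for k in range(5)], axis=-1)

x = np.linspace(0., 1., 10)
xi = np.ones(5)

# Manual caching: evaluate H(x) once and close over the result, so the jitted
# function contains only the operations that depend on xi.
Hx = H(x)
f_cached = jit(lambda xi: np.dot(Hx, xi))
f_cached(xi)
```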