Can we cache inner functions for speed improvement? #4

Closed
leakec opened this issue Mar 6, 2022 · 7 comments
Labels: enhancement (New feature or request), question (Further information is requested)

leakec commented Mar 6, 2022

In TFC, we often have functions like the following:

f = lambda xi: np.dot(H(x),xi)

where the function depends only on xi and we always use the same x value. Here, H(x) represents a potentially expensive function. It would be nice to cache the result of H(x) and use that, rather than having to call the function each time.

Of course, for a simple case like this we can do,

A = H(x)
f = lambda xi: np.dot(A,xi)

but in general, these H(x) pieces may be generated by egrad transformations, so we do not have an explicit handle on them. Is there a way we can cache everything except the operations that explicitly depend on xi?
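To make the general case concrete, here is an illustrative sketch (using the same utfc setup as the script in the next comment); once a transformation like egrad is applied to the full function, the H(x)-type terms live inside the transformed function and there is no intermediate value to hoist out by hand:

import jax.numpy as np
from tfc import utfc
from tfc.utils import egrad

mine = utfc(3, 0, 5, basis="CP", xf=1.)
H, x = mine.H, mine.x
xi = np.zeros(H(x).shape[1])

# Simple case: H(x) can be hoisted out by hand.
A = H(x)
f = lambda xi: np.dot(A, xi)

# General case: egrad builds the derivative of H internally, so the
# expensive evaluation at x is hidden inside df and cannot be
# precomputed the same way.
df = egrad(lambda x, xi: np.dot(H(x), xi))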

leakec added the enhancement and question labels on Mar 6, 2022

leakec commented Mar 6, 2022

This script does a nice job of highlighting the issue, and it is the one I'm working with to find a potential solution. I modified the CP::Hint function to print a statement each time it runs, so we can track whether JAX is calling it.

Any call to f, fj, df, or dfj results in a call to CP::Hint.

import numpy as onp
import jax.numpy as np
from tfc import utfc
from tfc.utils import egrad
from jax import jit

mine = utfc(3, 0, 5, basis="CP", xf=1.)

H = mine.H
x = mine.x

xi = np.zeros(H(x).shape[1])

dH = egrad(H)
f = lambda xi: np.dot(H(x), xi)
df = lambda xi: np.dot(dH(x), xi)
fj = jit(f, static_argnums=(0,))
dfj = jit(df)

# Unsafe hack to make arrays hashable so they can be passed as static
# arguments: hash the raw bytes and define equality via the hash.
class hasharray(onp.ndarray):
    def __hash__(self):
        return hash(self.tobytes())
    def __eq__(self, other):
        return other.__hash__() == self.__hash__()

x2 = onp.array(x.tolist()).view(hasharray)
#x2 = hasharray(x)
f2 = lambda x2, xi: np.dot(H(x2), xi)
jf2 = jit(f2, static_argnums=(0,))

At the end of the script, I created a hack (this hack is generally unsafe, so do not use it outside of this script) to make arrays hashable so they can be passed as static arguments. However, JAX still calls CP::Hint in this case.

To try:

  • Maybe if I make the return of the function hashable JAX will hash it?

leakec commented Mar 6, 2022

Tried using this JAX comment. This allows x to be used as a static argument, i.e., it does not re-JIT; however, it still runs all of the H(x) computations :(

I think we need a way to make the output of H(x) itself hashable. I suppose static_args are likely only used as "compile-time constants" so that we can do if statements and the like with them, i.e., control flow is allowed on these values; however, functions computed from them are not necessarily treated as compile-time constants. We want jit to recognize the output of H(x) as a compile-time constant and cache it.
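As a sketch of the desired behavior (assuming a JAX version that provides jax.ensure_compile_time_eval), wrapping the H(x) evaluation in that context manager forces it to run eagerly at trace time, so jit bakes the result into the compiled program as a constant:

from jax import jit, ensure_compile_time_eval

@jit
def f(xi):
    # Runs once, while tracing: x is a concrete closed-over array, so
    # H(x) is evaluated eagerly and embedded as a compile-time constant
    # instead of being recomputed on every call.
    with ensure_compile_time_eval():
        A = H(x)
    return np.dot(A, xi)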

leakec commented Mar 6, 2022

Opened up a discussion for this on the JAX GitHub here.

leakec commented Mar 7, 2022

The above discussion proved ultra useful!

Posting this here for now. It is a really rough draft. I plan to make this code cleaner and the API simpler before adding it as a tool to TFC. Ultimately, it would be nice to offer this as a standalone tool, and as an option for the nlls functions.

Rough draft follows:

import jax.numpy as np
from tfc import utfc
from tfc.utils import egrad
from jax import jit, core
from jax import linear_util as lu
from jax.core import get_aval
from jax.interpreters import partial_eval as pe
from jax._src.api_util import flatten_fun, tree_flatten

mine = utfc(3, 0, 5, basis="CP", xf=1.)

H = mine.H
x = mine.x
xi = np.ones(H(x).shape[1])

f = lambda x, xi: np.dot(H(x), xi)
df = egrad(f)

def new_jit(*args, remove_args=(), remove_arg_nums=(), **kwargs):
    dark = list(args)
    def wrapper(f):
        if len(remove_args) > 0:
            assert len(remove_args) == len(remove_arg_nums)

            # Args listed in remove_arg_nums become known compile-time
            # constants; the rest stay unknown.
            def get_arg(a, known):
                if known:
                    return pe.PartialVal.known(a)
                else:
                    return pe.PartialVal.unknown(get_aval(a).at_least_vspace())
            for i, k in enumerate(remove_arg_nums):
                dark[k] = remove_args[i]
            part_args = tuple(get_arg(a, k in remove_arg_nums)
                              for k, a in enumerate(dark))

            # Create the jaxpr. Everything that depends only on the known
            # args is computed here and stored in const.
            wrap = lu.wrap_init(f)
            _, in_tree = tree_flatten((args, {}))
            wrap_flat, out_tree = flatten_fun(wrap, in_tree)
            jaxpr, _, const = pe.trace_to_jaxpr(wrap_flat, part_args)

            # The jaxpr keeps (unused) invars for the known args, so splice
            # their stored values back into position at call time.
            def f_removed(*args):
                it = iter(args)
                full = [dark[k] if k in remove_arg_nums else next(it)
                        for k in range(len(dark))]
                return core.eval_jaxpr(jaxpr, const, *full)
            return jit(f_removed)
        else:
            return jit(f)
    return wrapper


def L(x, xi):
    return x * df(x, xi)

Lf = new_jit(x, xi, remove_args=(x,), remove_arg_nums=(0,))(L)
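If the partial evaluation behaves as intended, the returned function no longer takes x at all; a hypothetical call (not in the original draft) would be:

Lf(xi)  # x*df(x, xi), with everything that depends only on x precomputed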

leakec commented Mar 7, 2022

This is more representative of the features I want. There are still some kinks to work out with the output shape (flagged by a comment in the code below):

def jit2(*args, constant_arg_nums=(), **kwargs):

    # Reorder to put the known (constant) args first, then the unknowns
    order = list(constant_arg_nums) + [k for k in range(len(args))
                                       if k not in constant_arg_nums]
    dark = tuple(args[k] for k in order)

    # Inverse permutation: where each original arg sits in the reordering
    inv = [0] * len(order)
    for i, k in enumerate(order):
        inv[k] = i

    # Store the number of removed args for later
    num_args_remove = len(constant_arg_nums)

    def wrapper(f_orig):
        if num_args_remove > 0:
            # Accept args in the reordered layout and map them back to the
            # original order. This lets us return a function that has
            # completely removed the constant args. Moreover, doing it here
            # means the reordering gets optimized by the compiler.
            def f(*args):
                new_args = tuple(args[inv[k]] for k in range(len(args)))
                return f_orig(*new_args)

            # Create the partial vals needed by pe.trace_to_jaxpr: the
            # first num_args_remove args are known constants, the rest are
            # unknown.
            def get_arg(a, unknown):
                if unknown:
                    return pe.PartialVal.unknown(get_aval(a).at_least_vspace())
                else:
                    return pe.PartialVal.known(a)
            part_args = tuple(get_arg(a, k >= num_args_remove)
                              for k, a in enumerate(dark))

            # Create the jaxpr. Everything that depends only on the known
            # args is computed here and stored in const.
            wrap = lu.wrap_init(f)
            _, in_tree = tree_flatten((args, {}))
            wrap_flat, out_tree = flatten_fun(wrap, in_tree)
            jaxpr, _, const = pe.trace_to_jaxpr(wrap_flat, part_args)

            # The jaxpr keeps (unused) invars for the known args, so pass
            # their stored values along with the runtime args.
            # Known kink: eval_jaxpr returns a flat list of outputs; the
            # original output structure still needs to be restored via
            # out_tree.
            f_removed = lambda *args: core.eval_jaxpr(
                jaxpr, const, *dark[0:num_args_remove], *args)
            return jit(f_removed)
        else:
            return jit(f_orig)
    return wrapper


def L(xi, x):
    return H(x).T + df(x, xi)

Lf = jit2(xi, x, constant_arg_nums=(1,))(L)

Lf(xi)
print("If working, no words beyond this point.")
Lf(xi)

leakec added a commit that referenced this issue Mar 12, 2022
* This is the functionality defined in #4.

leakec commented Mar 12, 2022

This capability has been added as pe and pejit. I have also added it to NLLS and NLLSClass.
Remaining to-dos:

  • Modify pe and pejit to work with general pytrees.
  • Add options for constant_arg_nums to LS and LSClass.
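For reference, a minimal usage sketch, assuming pejit is exposed via tfc.utils and mirrors the jit2 draft above (sample arguments plus constant_arg_nums, returning a decorator):

from tfc.utils import pejit

# Mark x constant: everything that depends only on x, such as H(x),
# is evaluated once up front and cached in the compiled function.
Lf = pejit(xi, x, constant_arg_nums=(1,))(L)
Lf(xi)  # x has been removed from the signature entirely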

leakec commented Mar 12, 2022

Closing as this feature has been added.
