Speed up solver instantiation #5

Open
hgrecco opened this issue Oct 26, 2020 · 6 comments

@hgrecco
Owner

hgrecco commented Oct 26, 2020

Instantiating the solver takes quite a long time due to compilation. It would be interesting to improve the code (without losing clarity) so that Numba is able to use the compiled code cache.
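To make the compilation cost concrete, the sketch below times the first and second call of a jitted function within one process; with Numba installed, the first call includes compilation. The `timed` helper and the fallback `njit` decorator (used only when Numba is missing, so the script still runs) are illustrative assumptions, not part of the solver code.

```python
import time

try:
    from numba import njit
except ImportError:
    # Assumption: fall back to a no-op decorator so the sketch runs without Numba.
    def njit(func):
        return func


@njit
def func(t, y):
    return y


def timed(call):
    """Return the result of call() and the elapsed wall-clock time in seconds."""
    start = time.perf_counter()
    result = call()
    return result, time.perf_counter() - start


# With Numba, the first call triggers compilation; the second reuses the compiled code.
first, t_first = timed(lambda: func(4.0, 2.0))
second, t_second = timed(lambda: func(4.0, 2.0))

print(f"first call:  {t_first:.6f} s")
print(f"second call: {t_second:.6f} s")
```

Without an on-disk cache, the first-call cost is paid again in every new Python process, which is what makes `cache=True` attractive here.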

@hgrecco
Owner Author

hgrecco commented Nov 4, 2020

With current Numba (0.51.2) the following code:

import numba as nb

@nb.njit(cache=True)
def func(t, y):
    return y

@nb.njit(cache=True)
def stepper(f, a, b):
    return 3 * f(a, b)

print(stepper(func, 4., 2.))

fails with the following error:

Traceback (most recent call last):
[...]
TypeError: cannot pickle 'weakref' object

due to cache=True.

If this works in the near future, instantiation time can drop dramatically as compilation will be done once. This might require some code reorganization.

For example, right now the stepper is organized in two layers:

import numba


def step_builder(*outer_args):
    """Build a stepper.

    This outer function should only contain attributes
    associated with the solver class, not with the solver instance.
    """

    @numba.njit
    def _step(*inner_args):
        """Perform a single step.

        This inner function should only contain attributes
        associated with the solver instance, not with the solver class.
        """

        # code to step

    return _step

The step_builder is called at class instantiation, _step at each step. This separation makes the code very clear and organized. But it might not work with caching, in which case the code will need to be flattened to something like:

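@numba.njit
def _step(inner_args, outer_args):
    """Perform a single step.

    This flat function receives both the attributes associated with
    the solver instance (inner_args) and those associated with the
    solver class (outer_args) as explicit arguments. They are passed
    as two tuples because a signature cannot have two *args parameters.
    """

    # code to step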
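As a rough illustration of the two layouts, the sketch below implements the same step both ways and checks that they agree. The explicit Euler update for y' = -y and the names `h` and `flat_step` are hypothetical stand-ins chosen for brevity, not the solver's actual stepping code; the fallback `njit` decorator is only there so the sketch runs without Numba.

```python
try:
    from numba import njit
except ImportError:
    # Assumption: fall back to a no-op decorator so the sketch runs without Numba.
    def njit(func):
        return func


def step_builder(h):
    """Two-layer layout: the step size h is captured in a closure."""

    @njit
    def _step(t, y):
        # Hypothetical stepping code: explicit Euler for y' = -y.
        return t + h, y + h * (-y)

    return _step


@njit
def flat_step(t, y, h):
    """Flat layout: h is passed explicitly on every call."""
    return t + h, y + h * (-y)


closure_result = step_builder(0.1)(0.0, 1.0)
flat_result = flat_step(0.0, 1.0, 0.1)
print(closure_result, flat_result)
```

In the closure layout every call to `step_builder` defines a new jitted function, whereas the flat layout defines a single module-level function whose behavior depends only on its arguments.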

@Illviljan

The failing example appears to work now:

import numba as nb

@nb.njit(cache=True)
def func(t, y):
    return y

@nb.njit(cache=True)
def stepper(f, a, b):
    return 3 * f(a, b)

print(f"{nb.__version__ = }")
print(stepper(func, 4., 2.))

Output:

nb.__version__ = '0.56.4'
6.0

@maurosilber
Collaborator

But it is still producing a different cached version of stepper on each run:

> python cache.py && ls __pycache__/cache.stepper*.nbc | wc -l
1
> python cache.py && ls __pycache__/cache.stepper*.nbc | wc -l
2

Environment:

libllvm11                 11.1.0               hfa12f05_5    conda-forge
llvmlite                  0.39.1          py310h1e34944_1    conda-forge
numba                     0.56.4          py310h3124f1e_0    conda-forge
python                    3.10.8          h3ba56d0_0_cpython    conda-forge
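The shell check above can also be written as a small Python helper, which may be handy in a test. This is a minimal sketch: the function name `count_cached_variants` is hypothetical, and the glob pattern simply mirrors the `cache.stepper*.nbc` pattern used in the shell session.

```python
from pathlib import Path


def count_cached_variants(cache_dir, module="cache", func="stepper"):
    """Count the Numba cache data files (*.nbc) stored for one function.

    Mirrors `ls __pycache__/cache.stepper*.nbc | wc -l`.
    """
    pattern = f"{module}.{func}*.nbc"
    return len(list(Path(cache_dir).glob(pattern)))


if __name__ == "__main__":
    # Run this after each `python cache.py`; a growing count means the
    # previously cached version of the function was not reused.
    print(count_cached_variants("__pycache__"))
```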
