Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce redundant module linking #3228

Merged
merged 6 commits into from
Oct 17, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 18 additions & 16 deletions numba/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,14 +1021,15 @@ def native_lowering_stage(targetctx, library, interp, typemap, restype,
interp, typemap, restype, calltypes, mangler=targetctx.mangler,
inline=flags.forceinline, noalias=flags.noalias)

lower = lowering.Lower(targetctx, library, fndesc, interp)
lower.lower()
if not flags.no_cpython_wrapper:
lower.create_cpython_wrapper(flags.release_gil)
env = lower.env
call_helper = lower.call_helper
has_dynamic_globals = lower.has_dynamic_globals
del lower
with targetctx.push_code_library(library):
lower = lowering.Lower(targetctx, library, fndesc, interp)
lower.lower()
if not flags.no_cpython_wrapper:
lower.create_cpython_wrapper(flags.release_gil)
env = lower.env
call_helper = lower.call_helper
has_dynamic_globals = lower.has_dynamic_globals
del lower

if flags.no_compile:
return _LowerResult(fndesc, call_helper, cfunc=None, env=env,
Expand All @@ -1047,14 +1048,15 @@ def py_lowering_stage(targetctx, library, interp, flags):
fndesc = funcdesc.PythonFunctionDescriptor.from_object_mode_function(
interp
)
lower = pylowering.PyLower(targetctx, library, fndesc, interp)
lower.lower()
if not flags.no_cpython_wrapper:
lower.create_cpython_wrapper()
env = lower.env
call_helper = lower.call_helper
has_dynamic_globals = lower.has_dynamic_globals
del lower
with targetctx.push_code_library(library):
lower = pylowering.PyLower(targetctx, library, fndesc, interp)
lower.lower()
if not flags.no_cpython_wrapper:
lower.create_cpython_wrapper()
env = lower.env
call_helper = lower.call_helper
has_dynamic_globals = lower.has_dynamic_globals
del lower

if flags.no_compile:
return _LowerResult(fndesc, call_helper, cfunc=None, env=env,
Expand Down
3 changes: 2 additions & 1 deletion numba/npyufunc/array_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,8 @@ def _lower_array_expr(lowerer, expr):
# division (issue #1223).
flags = compiler.Flags()
flags.set('error_model', 'numpy')
cres = context.compile_subroutine_no_cache(builder, impl, inner_sig, flags=flags)
cres = context.compile_subroutine(builder, impl, inner_sig, flags=flags,
caching=False)

# Create kernel subclass calling our native function
from ..targets import npyimpl
Expand Down
54 changes: 43 additions & 11 deletions numba/targets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import sys
from itertools import permutations, takewhile
from contextlib import contextmanager

import numpy as np

Expand Down Expand Up @@ -241,6 +242,7 @@ def __init__(self, typing_context):
self.special_ops = {}
self.cached_internal_func = {}
self._pid = None
self._codelib_stack = []

self.data_model_manager = datamodel.default_manager

Expand Down Expand Up @@ -776,7 +778,8 @@ def get_dummy_value(self):
def get_dummy_type(self):
return GENERIC_POINTER

def compile_subroutine_no_cache(self, builder, impl, sig, locals={}, flags=None):
def _compile_subroutine_no_cache(self, builder, impl, sig, locals={},
flags=None):
"""
Invoke the compiler to compile a function to be used inside a
nopython function, but without generating code to call that
Expand Down Expand Up @@ -804,23 +807,36 @@ def compile_subroutine_no_cache(self, builder, impl, sig, locals={}, flags=None)
codegen.add_linking_library(cres.library)
return cres

def compile_subroutine(self, builder, impl, sig, locals={}):
def compile_subroutine(self, builder, impl, sig, locals={}, flags=None,
caching=True):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps add what caching does to the docstring for this function?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added docstring

"""
Compile the function *impl* for the given *sig* (in nopython mode).
Return a placeholder object that's callable from another Numba
function.

If *caching* evaluates True, the function keeps the compiled function
for reuse in *.cached_internal_func*.
"""
cache_key = (impl.__code__, sig, type(self.error_model))
if impl.__closure__:
# XXX This obviously won't work if a cell's value is
# unhashable.
cache_key += tuple(c.cell_contents for c in impl.__closure__)
ty = self.cached_internal_func.get(cache_key)
if ty is None:
cres = self.compile_subroutine_no_cache(builder, impl, sig,
locals=locals)
if not caching:
cached = None
else:
if impl.__closure__:
# XXX This obviously won't work if a cell's value is
# unhashable.
cache_key += tuple(c.cell_contents for c in impl.__closure__)
cached = self.cached_internal_func.get(cache_key)
if cached is None:
cres = self._compile_subroutine_no_cache(builder, impl, sig,
locals=locals,
flags=flags)
lib = cres.library
ty = types.NumbaFunction(cres.fndesc, sig)
self.cached_internal_func[cache_key] = ty
self.cached_internal_func[cache_key] = ty, lib

ty, lib = self.cached_internal_func[cache_key]
# Allow inlining the function inside callers.
self.active_code_library.add_linking_library(lib)
return ty

def compile_internal(self, builder, impl, sig, args, locals={}):
Expand Down Expand Up @@ -1064,6 +1080,22 @@ def create_module(self, name):
"""
return lc.Module(name)

@property
def active_code_library(self):
"""Get the active code library
"""
return self._codelib_stack[-1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could someone end up calling this before the _codelib_stack contains at least one entry? Worth catching it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If that happens, it is an internal error.


@contextmanager
def push_code_library(self, lib):
"""Push the active code library for the context
"""
self._codelib_stack.append(lib)
try:
yield
finally:
self._codelib_stack.pop()


class _wrap_impl(object):
"""
Expand Down
161 changes: 83 additions & 78 deletions numba/tests/test_compile_cache.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from __future__ import division

import numba.unittest_support as unittest
from contextlib import contextmanager

import llvmlite.llvmpy.core as lc

import numpy as np

from numba import compiler, types, typing
from numba import types, typing
from numba.targets import callconv, cpu


Expand All @@ -16,21 +15,24 @@ class TestCompileCache(unittest.TestCase):
checking the state of the cache when it is used by the CPUContext.
'''

@contextmanager
def _context_builder_sig_args(self):
typing_context = typing.Context()
context = cpu.CPUContext(typing_context)
module = lc.Module("test_module")
lib = context.codegen().create_library('testing')
with context.push_code_library(lib):
module = lc.Module("test_module")

sig = typing.signature(types.int32, types.int32)
llvm_fnty = context.call_conv.get_function_type(sig.return_type,
sig.args)
function = module.get_or_insert_function(llvm_fnty, name='test_fn')
args = context.call_conv.get_arguments(function)
assert function.is_declaration
entry_block = function.append_basic_block('entry')
builder = lc.Builder(entry_block)
sig = typing.signature(types.int32, types.int32)
llvm_fnty = context.call_conv.get_function_type(sig.return_type,
sig.args)
function = module.get_or_insert_function(llvm_fnty, name='test_fn')
args = context.call_conv.get_arguments(function)
assert function.is_declaration
entry_block = function.append_basic_block('entry')
builder = lc.Builder(entry_block)

return context, builder, sig, args
yield context, builder, sig, args

def test_cache(self):
def times2(i):
Expand All @@ -39,39 +41,40 @@ def times2(i):
def times3(i):
return i*3

context, builder, sig, args = self._context_builder_sig_args()

# Ensure the cache is empty to begin with
self.assertEqual(0, len(context.cached_internal_func))

# After one compile, it should contain one entry
context.compile_internal(builder, times2, sig, args)
self.assertEqual(1, len(context.cached_internal_func))

# After a second compilation of the same thing, it should still contain
# one entry
context.compile_internal(builder, times2, sig, args)
self.assertEqual(1, len(context.cached_internal_func))

# After compilation of another function, the cache should have grown by
# one more.
context.compile_internal(builder, times3, sig, args)
self.assertEqual(2, len(context.cached_internal_func))

sig2 = typing.signature(types.float64, types.float64)
llvm_fnty2 = context.call_conv.get_function_type(sig2.return_type,
sig2.args)
function2 = builder.module.get_or_insert_function(llvm_fnty2,
name='test_fn_2')
args2 = context.call_conv.get_arguments(function2)
assert function2.is_declaration
entry_block2 = function2.append_basic_block('entry')
builder2 = lc.Builder(entry_block2)

# Ensure that the same function with a different signature does not
# reuse an entry from the cache in error
context.compile_internal(builder2, times3, sig2, args2)
self.assertEqual(3, len(context.cached_internal_func))
with self._context_builder_sig_args() as (
context, builder, sig, args,
):
# Ensure the cache is empty to begin with
self.assertEqual(0, len(context.cached_internal_func))

# After one compile, it should contain one entry
context.compile_internal(builder, times2, sig, args)
self.assertEqual(1, len(context.cached_internal_func))

# After a second compilation of the same thing, it should still contain
# one entry
context.compile_internal(builder, times2, sig, args)
self.assertEqual(1, len(context.cached_internal_func))

# After compilation of another function, the cache should have grown by
# one more.
context.compile_internal(builder, times3, sig, args)
self.assertEqual(2, len(context.cached_internal_func))

sig2 = typing.signature(types.float64, types.float64)
llvm_fnty2 = context.call_conv.get_function_type(sig2.return_type,
sig2.args)
function2 = builder.module.get_or_insert_function(llvm_fnty2,
name='test_fn_2')
args2 = context.call_conv.get_arguments(function2)
assert function2.is_declaration
entry_block2 = function2.append_basic_block('entry')
builder2 = lc.Builder(entry_block2)

# Ensure that the same function with a different signature does not
# reuse an entry from the cache in error
context.compile_internal(builder2, times3, sig2, args2)
self.assertEqual(3, len(context.cached_internal_func))

def test_closures(self):
"""
Expand All @@ -82,19 +85,20 @@ def f(z):
return y + z
return f

context, builder, sig, args = self._context_builder_sig_args()

# Closures with distinct cell contents must each be compiled.
clo11 = make_closure(1, 1)
clo12 = make_closure(1, 2)
clo22 = make_closure(2, 2)
res1 = context.compile_internal(builder, clo11, sig, args)
self.assertEqual(1, len(context.cached_internal_func))
res2 = context.compile_internal(builder, clo12, sig, args)
self.assertEqual(2, len(context.cached_internal_func))
# Same cell contents as above (first parameter isn't captured)
res3 = context.compile_internal(builder, clo22, sig, args)
self.assertEqual(2, len(context.cached_internal_func))
with self._context_builder_sig_args() as (
context, builder, sig, args,
):
# Closures with distinct cell contents must each be compiled.
clo11 = make_closure(1, 1)
clo12 = make_closure(1, 2)
clo22 = make_closure(2, 2)
res1 = context.compile_internal(builder, clo11, sig, args)
self.assertEqual(1, len(context.cached_internal_func))
res2 = context.compile_internal(builder, clo12, sig, args)
self.assertEqual(2, len(context.cached_internal_func))
# Same cell contents as above (first parameter isn't captured)
res3 = context.compile_internal(builder, clo22, sig, args)
self.assertEqual(2, len(context.cached_internal_func))

def test_error_model(self):
"""
Expand All @@ -108,25 +112,26 @@ def inv(x):
def compile_inv(context):
return context.compile_subroutine(builder, inv, inv_sig)

context, builder, sig, args = self._context_builder_sig_args()

py_error_model = callconv.create_error_model('python', context)
np_error_model = callconv.create_error_model('numpy', context)

py_context1 = context.subtarget(error_model=py_error_model)
py_context2 = context.subtarget(error_model=py_error_model)
np_context = context.subtarget(error_model=np_error_model)

# Note the parent context's cache is shared by subtargets
self.assertEqual(0, len(context.cached_internal_func))
# Compiling with the same error model reuses the same cache slot
compile_inv(py_context1)
self.assertEqual(1, len(context.cached_internal_func))
compile_inv(py_context2)
self.assertEqual(1, len(context.cached_internal_func))
# Compiling with another error model creates a new cache slot
compile_inv(np_context)
self.assertEqual(2, len(context.cached_internal_func))
with self._context_builder_sig_args() as (
context, builder, sig, args,
):
py_error_model = callconv.create_error_model('python', context)
np_error_model = callconv.create_error_model('numpy', context)

py_context1 = context.subtarget(error_model=py_error_model)
py_context2 = context.subtarget(error_model=py_error_model)
np_context = context.subtarget(error_model=np_error_model)

# Note the parent context's cache is shared by subtargets
self.assertEqual(0, len(context.cached_internal_func))
# Compiling with the same error model reuses the same cache slot
compile_inv(py_context1)
self.assertEqual(1, len(context.cached_internal_func))
compile_inv(py_context2)
self.assertEqual(1, len(context.cached_internal_func))
# Compiling with another error model creates a new cache slot
compile_inv(np_context)
self.assertEqual(2, len(context.cached_internal_func))


if __name__ == '__main__':
Expand Down