Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable inlining of first-class function when it is statically known to be a dispatcher #9077

Merged
merged 15 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions docs/upcoming_changes/9077.improvement.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
First-class function improvements
---------------------------------

Passing a jit function as an argument to another jit function that accepts it as
a ``FunctionType`` now benefits from two improvements.

First, the compiler can now inline a jit function that is passed as a non-local
variable (like a global variable) to another jit function. Previously, the
interpreter had to introspect the function address for first-class function
calls, which prevented inlining. With this improvement, the compiler can
statically determine the referenced jit function and link in the corresponding
LLVM module for optimization, bypassing the need for the GIL entirely.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And it's so much faster!


Second, jit functions used as first-class functions can now raise exceptions.
Before this improvement, they were subject to the same restrictions as
``@cfunc`` decorated functions, where any exceptions raised were ignored.
8 changes: 5 additions & 3 deletions numba/core/callconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,9 +876,11 @@ def call_function(self, builder, callee, resty, argtys, args,
LLVM style string: "noinline fast"
Equivalent iterable: ("noinline", "fast")
"""
# XXX better fix for callees that are not function values
# (pointers to function; thus have no `.args` attribute)
retty = self._get_return_argument(callee.function_type).pointee
retty = self.get_return_type(resty).pointee
actual_retty = self._get_return_argument(callee.function_type).pointee
if retty != actual_retty:
m = f"Function type returns {actual_retty} but resty={retty}"
raise ValueError(m)
stuartarchibald marked this conversation as resolved.
Show resolved Hide resolved

retvaltmp = cgutils.alloca_once(builder, retty)
# initialize return value to zeros
Expand Down
68 changes: 63 additions & 5 deletions numba/core/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,15 +1091,73 @@ def _lower_call_FunctionType(self, fnty, expr, signature):
raise UnsupportedError(
f'mismatch of function types:'
f' expected {fnty} but got {types.FunctionType(sig)}')
ftype = fnty.ftype
argvals = self.fold_call_args(
fnty, sig, expr.args, expr.vararg, expr.kws,
)
func_ptr = self.__get_function_pointer(ftype, expr.func.name, sig=sig)
res = self.builder.call(func_ptr, argvals, cconv=fnty.cconv)
return res
return self.__call_first_class_function_pointer(
fnty.ftype, expr.func.name, sig, argvals,
)

def __call_first_class_function_pointer(self, ftype, fname, sig, argvals):
    """
    Emit a call through a first-class function value.

    The function struct bound to the variable ``fname`` is loaded and its
    ``jit_addr`` member is tested at runtime:

    * if ``jit_addr`` is null, the pointer returned by
      ``__get_first_class_function_pointer`` is called directly with
      ``argvals``;
    * otherwise ``jit_addr`` is bitcast to the pointer type that the
      context's calling convention gives for ``sig`` and is invoked through
      ``call_conv.call_function``. An error status from that call is
      propagated to the caller.

    Args:
        ftype: The type of the function.
        fname: The name of the variable holding the function struct.
        sig: The signature of the function.
        argvals: The argument values to pass to the function.

    Returns:
        The call result, loaded from a stack slot that both branches
        store into.
    """
    context = self.context
    builder = self.builder

    # Pull the jit address out of the first-class function struct.
    loaded_struct = self.loadvar(fname)
    proxy_cls = cgutils.create_struct_proxy(self.typeof(fname))
    struct = proxy_cls(context, builder, value=loaded_struct)
    jit_addr = struct.jit_addr
    jit_addr.name = f'jit_addr_of_{fname}'

    # Both branches write their result here so it can be read after the
    # if/else has been closed.
    ret_llty = context.get_value_type(sig.return_type)
    res_slot = cgutils.alloca_once(builder, ret_llty)

    addr_is_null = cgutils.is_null(builder, jit_addr)
    with builder.if_else(addr_is_null, likely=False) as (then, orelse):
        with then:
            # No jit address available: call through the pointer obtained
            # from __get_first_class_function_pointer.
            fallback_ptr = self.__get_first_class_function_pointer(
                ftype, fname, sig)
            builder.store(builder.call(fallback_ptr, argvals), res_slot)

        with orelse:
            # jit address available: call it using the context's calling
            # convention so that the returned status can be inspected.
            call_conv = context.call_conv
            jit_fnptr_ty = call_conv.get_function_type(
                sig.return_type,
                sig.args
            ).as_pointer()
            jit_ptr = builder.bitcast(jit_addr, jit_fnptr_ty)
            status, jit_res = call_conv.call_function(
                builder, jit_ptr, sig.return_type, sig.args, argvals
            )
            with cgutils.if_unlikely(builder, status.is_error):
                call_conv.return_status_propagate(builder, status)
            builder.store(jit_res, res_slot)

    return builder.load(res_slot)

def __get_function_pointer(self, ftype, fname, sig=None):
def __get_first_class_function_pointer(self, ftype, fname, sig):
from numba.experimental.function_type import lower_get_wrapper_address

llty = self.context.get_value_type(ftype)
Expand Down
139 changes: 103 additions & 36 deletions numba/experimental/function_type.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Provides Numba type, FunctionType, that makes functions as
instances of a first-class function type.
"""
from functools import partial

from numba.extending import typeof_impl
from numba.extending import models, register_model
Expand All @@ -9,7 +10,7 @@
from numba.core.ccallback import CFunc
from numba.core import cgutils
from llvmlite import ir
from numba.core import types
from numba.core import types, errors
from numba.core.types import (FunctionType, UndefinedFunctionType,
FunctionPrototype, WrapperAddressProtocol)
from numba.core.dispatcher import Dispatcher
Expand Down Expand Up @@ -52,11 +53,16 @@ class FunctionModel(models.StructModel):
"""
def __init__(self, dmm, fe_type):
members = [
# address of cfunc wrapper function:
('addr', types.voidptr),
# address of PyObject* referencing the Python function
# Address of cfunc wrapper function.
# This uses a C callconv and does not support exceptions.
('c_addr', types.voidptr),
# Address of PyObject* referencing the Python function
# object:
('pyaddr', types.voidptr),
('py_addr', types.voidptr),
# Address of the underlying function object.
# Calling through this function pointer supports all features of a
# regular Numba function as it follows the same Numba callconv.
('jit_addr', types.voidptr),
]
super(FunctionModel, self).__init__(dmm, fe_type, members)

Expand All @@ -74,26 +80,26 @@ def lower_constant_function_type(context, builder, typ, pyval):
if isinstance(pyval, CFunc):
addr = pyval._wrapper_address
sfunc = cgutils.create_struct_proxy(typ)(context, builder)
sfunc.addr = context.add_dynamic_addr(builder, addr,
info=str(typ))
sfunc.pyaddr = context.add_dynamic_addr(builder, id(pyval),
info=type(pyval).__name__)
sfunc.c_addr = context.add_dynamic_addr(builder, addr,
info=str(typ))
sfunc.py_addr = context.add_dynamic_addr(builder, id(pyval),
info=type(pyval).__name__)
return sfunc._getvalue()

if isinstance(pyval, Dispatcher):
sfunc = cgutils.create_struct_proxy(typ)(context, builder)
sfunc.pyaddr = context.add_dynamic_addr(builder, id(pyval),
info=type(pyval).__name__)
sfunc.py_addr = context.add_dynamic_addr(builder, id(pyval),
info=type(pyval).__name__)
return sfunc._getvalue()

if isinstance(pyval, WrapperAddressProtocol):
addr = pyval.__wrapper_address__()
assert typ.check_signature(pyval.signature())
sfunc = cgutils.create_struct_proxy(typ)(context, builder)
sfunc.addr = context.add_dynamic_addr(builder, addr,
info=str(typ))
sfunc.pyaddr = context.add_dynamic_addr(builder, id(pyval),
info=type(pyval).__name__)
sfunc.c_addr = context.add_dynamic_addr(builder, addr,
info=str(typ))
sfunc.py_addr = context.add_dynamic_addr(builder, id(pyval),
info=type(pyval).__name__)
return sfunc._getvalue()

# TODO: implement support for pytypes.FunctionType, ctypes.CFUNCTYPE
Expand Down Expand Up @@ -166,9 +172,24 @@ def _get_wrapper_address(func, sig):
return addr


def lower_get_wrapper_address(context, builder, func, sig,
failure_mode='return_exc'):
"""Low-level call to _get_wrapper_address(func, sig).
def _get_jit_address(func, sig):
    """Return the jit address of *func* for signature *sig* as an integer.

    Counterpart of ``_get_wrapper_address()`` that yields the value stored
    in the ``.jit_addr`` member instead. For a ``Dispatcher`` the address is
    looked up in the library of the compile result for *sig*; for any other
    object the result is ``0``.

    Raises ``TypeError`` if the looked-up address is not an ``int``.
    """
    if not isinstance(func, Dispatcher):
        addr = 0
    else:
        compiled = func.get_compile_result(sig)
        symbol = compiled.fndesc.llvm_func_name
        addr = compiled.library.get_pointer_to_function(symbol)

    if isinstance(addr, int):
        return addr
    raise TypeError(
        f'jit address must be integer, got {type(addr)} instance')


def _lower_get_address(context, builder, func, sig, failure_mode,
*, function_name):
"""Low-level call to <function_name>(func, sig).

When calling this function, GIL must be acquired.
"""
Expand All @@ -182,16 +203,15 @@ def lower_get_wrapper_address(context, builder, func, sig,

modname = context.insert_const_string(builder.module, __name__)
numba_mod = pyapi.import_module_noblock(modname)
numba_func = pyapi.object_getattr_string(
numba_mod, '_get_wrapper_address')
numba_func = pyapi.object_getattr_string(numba_mod, function_name)
pyapi.decref(numba_mod)
sig_obj = pyapi.unserialize(pyapi.serialize_object(sig))

addr = pyapi.call_function_objargs(numba_func, (func, sig_obj))

if failure_mode != 'ignore':
with builder.if_then(cgutils.is_null(builder, addr), likely=False):
# _get_wrapper_address has raised an exception, propagate it
# *function_name* has raised an exception, propagate it
# to the caller.
if failure_mode == 'return_exc':
context.call_conv.return_exc(builder)
Expand All @@ -200,10 +220,21 @@ def lower_get_wrapper_address(context, builder, func, sig,
else:
raise NotImplementedError(failure_mode)
# else the caller will handle addr == NULL

return addr # new reference or NULL


def lower_get_wrapper_address(context, builder, func, sig,
                              failure_mode='return_exc'):
    """Low-level call to _get_wrapper_address(func, sig).

    When calling this function, GIL must be acquired.

    ``failure_mode`` keeps its historical default of ``'return_exc'`` and
    may be passed positionally or by keyword. It is forwarded unchanged to
    ``_lower_get_address``.

    Returns whatever ``_lower_get_address`` returns (a new reference or
    NULL).
    """
    return _lower_get_address(
        context, builder, func, sig, failure_mode,
        function_name="_get_wrapper_address",
    )


def lower_get_jit_address(context, builder, func, sig,
                          failure_mode='return_exc'):
    """Low-level call to _get_jit_address(func, sig).

    When calling this function, GIL must be acquired.

    Same signature and default as ``lower_get_wrapper_address`` so the two
    entry points can be used interchangeably.

    Returns whatever ``_lower_get_address`` returns (a new reference or
    NULL).
    """
    return _lower_get_address(
        context, builder, func, sig, failure_mode,
        function_name="_get_jit_address",
    )


@unbox(FunctionType)
def unbox_function_type(typ, obj, c):
typ = typ.get_precise()
Expand All @@ -212,11 +243,16 @@ def unbox_function_type(typ, obj, c):

addr = lower_get_wrapper_address(
c.context, c.builder, obj, typ.signature, failure_mode='return_null')
sfunc.addr = c.pyapi.long_as_voidptr(addr)
sfunc.c_addr = c.pyapi.long_as_voidptr(addr)
c.pyapi.decref(addr)

llty = c.context.get_value_type(types.voidptr)
sfunc.pyaddr = c.builder.ptrtoint(obj, llty)
sfunc.py_addr = c.builder.ptrtoint(obj, llty)

addr = lower_get_jit_address(
c.context, c.builder, obj, typ.signature, failure_mode='return_null')
sfunc.jit_addr = c.pyapi.long_as_voidptr(addr)
c.pyapi.decref(addr)

return NativeValue(sfunc._getvalue())

Expand All @@ -227,7 +263,7 @@ def box_function_type(typ, val, c):

sfunc = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
pyaddr_ptr = cgutils.alloca_once(c.builder, c.pyapi.pyobj)
raw_ptr = c.builder.inttoptr(sfunc.pyaddr, c.pyapi.pyobj)
raw_ptr = c.builder.inttoptr(sfunc.py_addr, c.pyapi.pyobj)
with c.builder.if_then(cgutils.is_null(c.builder, raw_ptr),
likely=False):
cstr = f"first-class function {typ} parent object not set"
Expand All @@ -249,17 +285,48 @@ def lower_cast_function_type_to_function_type(
def lower_cast_dispatcher_to_function_type(context, builder, fromty, toty, val):
    """Lower a cast from a Dispatcher value to a first-class function struct.

    The resulting struct always carries ``py_addr`` (the dispatcher object
    pointer as an integer). The remaining members depend on whether the
    dispatcher can be compiled for the target signature at lowering time:

    * compilation succeeds: ``c_addr`` and ``jit_addr`` are taken from
      functions declared in the current module, and the dispatcher's
      library is added as a linking library. No GIL is needed at runtime.
    * compilation raises ``errors.NumbaError``: ``c_addr`` is obtained at
      runtime via ``lower_get_wrapper_address`` while holding the GIL;
      ``jit_addr`` is not assigned on this path.

    Returns the lowered struct value.
    """
    toty = toty.get_precise()

    sig = toty.signature
    dispatcher = fromty.dispatcher
    llvoidptr = context.get_value_type(types.voidptr)
    sfunc = cgutils.create_struct_proxy(toty)(context, builder)
    # Always store the python function
    sfunc.py_addr = builder.ptrtoint(val, llvoidptr)

    # Attempt to compile the Dispatcher to the expected function type
    try:
        cres = dispatcher.get_compile_result(sig)
    except errors.NumbaError:
        cres = None

    # If compilation is successful, we can bypass using the GIL to get the
    # cfunc
    if cres is not None:
        # Declare cfunc in the current module
        wrapper_name = cres.fndesc.llvm_cfunc_wrapper_name
        llfnptr = context.get_value_type(toty.ftype)
        llfnty = llfnptr.pointee
        fn = cgutils.get_or_insert_function(
            builder.module, llfnty, wrapper_name,
        )
        # Store the cfunc
        sfunc.c_addr = builder.bitcast(fn, llvoidptr)
        # Store the jit func
        fn = context.declare_function(builder.module, cres.fndesc)
        sfunc.jit_addr = builder.bitcast(fn, llvoidptr)
        # Link-in the dispatcher library
        context.active_code_library.add_linking_library(cres.library)

    else:
        # Use lower_get_wrapper_address() to get the cfunc at runtime.
        # jit_addr is left unassigned here.
        # NOTE(review): presumably the struct proxy zero-initialises it so
        # that the null check at the call site takes the cfunc path --
        # confirm.
        pyapi = context.get_python_api(builder)

        gil_state = pyapi.gil_ensure()
        addr = lower_get_wrapper_address(
            context, builder, val, sig,
            failure_mode='return_exc')
        sfunc.c_addr = pyapi.long_as_voidptr(addr)
        pyapi.decref(addr)
        pyapi.gil_release(gil_state)

    return sfunc._getvalue()
7 changes: 7 additions & 0 deletions numba/tests/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,13 @@ def redirect_c_stdout():
return redirect_fd(fd)


def redirect_c_stderr():
    """Redirect C stderr.

    Passes the file descriptor of ``sys.__stderr__`` to ``redirect_fd`` and
    returns its result.
    """
    return redirect_fd(sys.__stderr__.fileno())


def run_in_new_process_caching(func, cache_dir_prefix=__name__, verbose=True):
"""Spawn a new process to run `func` with a temporary cache directory.

Expand Down
Loading
Loading