Skip to content

Commit

Permalink
Merge pull request #3166 from sklam/enh/withobjmode
Browse files Browse the repository at this point in the history
Objmode with-block
  • Loading branch information
seibert committed Sep 10, 2018
2 parents d0e149b + eb8aa0c commit 35d9cdc
Show file tree
Hide file tree
Showing 20 changed files with 986 additions and 74 deletions.
1 change: 1 addition & 0 deletions docs/source/user/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ User Manual
pycc.rst
parallel.rst
stencil.rst
withobjmode.rst
performance-tips.rst
threading-layer.rst
troubleshoot.rst
Expand Down
34 changes: 34 additions & 0 deletions docs/source/user/withobjmode.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
============================================================
Callback into the Python Interpreter from within JIT'ed code
============================================================

There are rare but real cases when a nopython-mode function needs to callback
into the Python interpreter to invoke code that cannot be compiled by Numba.
Such cases include:

- logging progress for long running JIT'ed functions;
- use data structures that are not currently supported by Numba;
- debugging inside JIT'ed code using the Python debugger.

When Numba callbacks into the Python interpreter, the following has to happen:

- acquire the GIL;
- convert values in native representation back into Python objects;
- call-back into the Python interpreter;
- convert returned values from the Python-code into native representation;
- release the GIL.

These steps can be expensive. Users **should not** rely on the feature
described here on performance-critical paths.


.. _with_objmode:

The ``objmode`` context-manager
===============================

.. warning:: This feature can be easily mis-used. Users should first consider
alternative approaches to achieve their intended goal before using
this feature.

.. autofunction:: numba.objmode
2 changes: 2 additions & 0 deletions numba/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

# Initialize withcontexts
import numba.withcontexts
from numba.withcontexts import objmode_context as objmode

# Keep this for backward compatibility.
test = runtests.main
Expand All @@ -53,6 +54,7 @@
prange
stencil
vectorize
objmode
""".split() + types.__all__ + errors.__all__


Expand Down
11 changes: 7 additions & 4 deletions numba/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .tracing import event

from numba import (bytecode, interpreter, funcdesc, postproc,
typing, typeinfer, lowering, objmode, utils, config,
typing, typeinfer, lowering, pylowering, utils, config,
errors, types, ir, rewrites, transforms)
from numba.targets import cpu, callconv
from numba.annotations import type_annotations
Expand Down Expand Up @@ -428,7 +428,7 @@ def frontend_looplift(self):
lifted=tuple(loops), lifted_from=None)
return cres

def frontend_withlift(self):
def stage_frontend_withlift(self):
"""
Extract with-contexts
"""
Expand Down Expand Up @@ -466,7 +466,6 @@ def stage_nopython_frontend(self):
"""
Type inference and legalization
"""
self.frontend_withlift()
with self.fallback_context('Function "%s" failed type inference'
% (self.func_id.func_name,)):
# Type inference
Expand Down Expand Up @@ -773,11 +772,15 @@ def add_cleanup_stage(self, pm):
"""
pm.add_stage(self.stage_cleanup, "cleanup intermediate results")

def add_with_handling_stage(self, pm):
pm.add_stage(self.stage_frontend_withlift, "Handle with contexts")

def define_nopython_pipeline(self, pm, name='nopython'):
"""Add the nopython-mode pipeline to the pipeline manager
"""
pm.create_pipeline(name)
self.add_preprocessing_stage(pm)
self.add_with_handling_stage(pm)
self.add_pre_typing_stage(pm)
self.add_typing_stage(pm)
self.add_optimization_stage(pm)
Expand Down Expand Up @@ -1042,7 +1045,7 @@ def py_lowering_stage(targetctx, library, interp, flags):
fndesc = funcdesc.PythonFunctionDescriptor.from_object_mode_function(
interp
)
lower = objmode.PyLower(targetctx, library, fndesc, interp)
lower = pylowering.PyLower(targetctx, library, fndesc, interp)
lower.lower()
if not flags.no_cpython_wrapper:
lower.create_cpython_wrapper()
Expand Down
4 changes: 4 additions & 0 deletions numba/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,9 +686,13 @@ def op_WITH_CLEANUP(self, info, inst):
"""
Note: py2 only opcode
"""
# TOS is the return value of __exit__()
info.pop()
info.append(inst)

def op_WITH_CLEANUP_START(self, info, inst):
# TOS is the return value of __exit__()
info.pop()
info.append(inst)

def op_WITH_CLEANUP_FINISH(self, info, inst):
Expand Down
1 change: 1 addition & 0 deletions numba/datamodel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ def __init__(self, dmm, fe_type):
@register_default(types.Phantom)
@register_default(types.ContextManager)
@register_default(types.Dispatcher)
@register_default(types.ObjModeDispatcher)
@register_default(types.ExceptionClass)
@register_default(types.Dummy)
@register_default(types.ExceptionInstance)
Expand Down
53 changes: 53 additions & 0 deletions numba/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,59 @@ def get_call_template(self, args, kws):
return call_template, pysig, args, kws


class ObjModeLiftedWith(LiftedWith):
def __init__(self, *args, **kwargs):
self.output_types = kwargs.pop('output_types', None)
super(LiftedWith, self).__init__(*args, **kwargs)
if not self.flags.enable_pyobject:
raise ValueError("expecting `flags.enable_pyobject`")
if self.output_types is None:
raise TypeError('`output_types` must be provided')

@property
def _numba_type_(self):
return types.ObjModeDispatcher(self)

def get_call_template(self, args, kws):
"""
Get a typing.ConcreteTemplate for this dispatcher and the given
*args* and *kws* types. This enables the resolving of the return type.
A (template, pysig, args, kws) tuple is returned.
"""
assert not kws
self._legalize_arg_types(args)
# Coerce to object mode
args = [types.ffi_forced_object] * len(args)

if self._can_compile:
self.compile(tuple(args))

signatures = [typing.signature(self.output_types, *args)]
pysig = None
func_name = self.py_func.__name__
name = "CallTemplate({0})".format(func_name)
call_template = typing.make_concrete_template(
name, key=func_name, signatures=signatures)

return call_template, pysig, args, kws

def _legalize_arg_types(self, args):
for i, a in enumerate(args, start=1):
if isinstance(a, types.List):
msg = (
'Does not support list type inputs into '
'with-context for arg {}'
)
raise errors.TypingError(msg.format(i))
elif isinstance(a, types.Dispatcher):
msg = (
'Does not support function type inputs into '
'with-context for arg {}'
)
raise errors.TypingError(msg.format(i))


# Initialize typeof machinery
_dispatcher.typeof_init(
OmittedArg,
Expand Down
82 changes: 78 additions & 4 deletions numba/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,15 +659,77 @@ def lower_call(self, resty, expr):

if isinstance(expr.func, ir.Intrinsic):
fnty = expr.func.name
argvals = expr.func.args
else:
fnty = self.typeof(expr.func.name)
argvals = self.fold_call_args(fnty, signature,
expr.args, expr.vararg, expr.kws)

if isinstance(fnty, types.ExternalFunction):
if isinstance(fnty, types.ObjModeDispatcher):
self.init_pyapi()
# Acquire the GIL
gil_state = self.pyapi.gil_ensure()
# Fix types
argnames = [a.name for a in expr.args]
argtypes = [self.typeof(a) for a in argnames]
argvalues = [self.loadvar(a) for a in argnames]
for v, ty in zip(argvalues, argtypes):
# Because .from_native_value steal the reference
self.incref(ty, v)

argobjs = [self.pyapi.from_native_value(atyp, aval,
self.env_manager)
for atyp, aval in zip(argtypes, argvalues)]
# Make Call
entry_pt = fnty.dispatcher.compile(tuple(argtypes))
callee = self.context.add_dynamic_addr(self.builder,
id(entry_pt),
info="with_objectmode")
ret_obj = self.pyapi.call_function_objargs(callee, argobjs)
has_exception = cgutils.is_null(self.builder, ret_obj)
with self. builder.if_else(has_exception) as (then, orelse):
# Handles exception
# This branch must exit the function
with then:
# Clean arg
for obj in argobjs:
self.pyapi.decref(obj)

# Release the GIL
self.pyapi.gil_release(gil_state)

# Return and signal exception
self.call_conv.return_exc(self.builder)

# Handles normal return
with orelse:
# Fix output value
native = self.pyapi.to_native_value(
fnty.dispatcher.output_types,
ret_obj,
)
output = native.value

# Release objs
self.pyapi.decref(ret_obj)
for obj in argobjs:
self.pyapi.decref(obj)

# cleanup output
if callable(native.cleanup):
native.cleanup()

# Release the GIL
self.pyapi.gil_release(gil_state)

# Error during unboxing
with self.builder.if_then(native.is_error):
self.call_conv.return_exc(self.builder)

res = output

elif isinstance(fnty, types.ExternalFunction):
# Handle a named external function
self.debug_print("# external function")
argvals = self.fold_call_args(fnty, signature,
expr.args, expr.vararg, expr.kws)
fndesc = funcdesc.ExternalFunctionDescriptor(
fnty.symbol, fnty.sig.return_type, fnty.sig.args)
func = self.context.declare_external_function(self.builder.module,
Expand All @@ -678,11 +740,15 @@ def lower_call(self, resty, expr):
elif isinstance(fnty, types.NumbaFunction):
# Handle a compiled Numba function
self.debug_print("# calling numba function")
argvals = self.fold_call_args(fnty, signature,
expr.args, expr.vararg, expr.kws)
res = self.context.call_internal(self.builder, fnty.fndesc,
fnty.sig, argvals)

elif isinstance(fnty, types.ExternalFunctionPointer):
self.debug_print("# calling external function pointer")
argvals = self.fold_call_args(fnty, signature,
expr.args, expr.vararg, expr.kws)
# Handle a C function pointer
pointer = self.loadvar(expr.func.name)
# If the external function pointer uses libpython
Expand Down Expand Up @@ -721,6 +787,8 @@ def lower_call(self, resty, expr):

elif isinstance(fnty, types.RecursiveCall):
# Recursive call
argvals = self.fold_call_args(fnty, signature,
expr.args, expr.vararg, expr.kws)
qualprefix = fnty.overloads[signature.args]
mangler = self.context.mangler or default_mangler
mangled_name = mangler(qualprefix, signature.args)
Expand All @@ -736,6 +804,12 @@ def lower_call(self, resty, expr):
# Normal function resolution
self.debug_print("# calling normal function: {0}".format(fnty))
self.debug_print("# signature: {0}".format(signature))
if (isinstance(expr.func, ir.Intrinsic) or
isinstance(fnty, types.ObjModeDispatcher)):
argvals = expr.func.args
else:
argvals = self.fold_call_args(fnty, signature,
expr.args, expr.vararg, expr.kws)
impl = self.context.get_function(fnty, signature)
if signature.recvr:
# The "self" object is passed as the function object
Expand Down
File renamed without changes.
3 changes: 2 additions & 1 deletion numba/runtests.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def main(self, argv, kwds):
print('Flags', flags)
result = run_tests([prog] + flags + tests, **kwds)
# Save failed
self.save_failed_tests(result, all_tests)
if not self.last_failed:
self.save_failed_tests(result, all_tests)
return result.wasSuccessful()

def save_failed_tests(self, result, all_tests):
Expand Down
22 changes: 20 additions & 2 deletions numba/targets/boxing.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,26 @@ def unbox_array(typ, obj, c):
errcode = c.pyapi.nrt_adapt_ndarray_from_python(obj, ptr)
else:
errcode = c.pyapi.numba_array_adaptor(obj, ptr)
failed = cgutils.is_not_null(c.builder, errcode)

# TODO: here we have minimal typechecking by the itemsize.
# need to do better
try:
expected_itemsize = numpy_support.as_dtype(typ.dtype).itemsize
except NotImplementedError:
# Don't check types that can't be `as_dtype()`-ed
itemsize_mismatch = cgutils.false_bit
else:
expected_itemsize = nativeary.itemsize.type(expected_itemsize)
itemsize_mismatch = c.builder.icmp_unsigned(
'!=',
nativeary.itemsize,
expected_itemsize,
)

failed = c.builder.or_(
cgutils.is_not_null(c.builder, errcode),
itemsize_mismatch,
)
# Handle error
with c.builder.if_then(failed, likely=False):
c.pyapi.err_set_string("PyExc_TypeError",
Expand Down Expand Up @@ -1015,4 +1034,3 @@ def box_unsupported(typ, val, c):
c.pyapi.err_set_string("PyExc_TypeError", msg)
res = c.pyapi.get_null_object()
return res

2 changes: 1 addition & 1 deletion numba/tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def pyfunc(a, i):
types.int32))

cfunc = cres.entry_point
a = np.empty(2)
a = np.empty(2, dtype=np.int32)

self.assertEqual(cfunc(a, 0), pyfunc(a, 0))

Expand Down
2 changes: 1 addition & 1 deletion numba/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ def make_dest():

# Mismatching input size and slice length
with self.assertRaises(ValueError):
cfunc(np.zeros_like(arg), arg, 0, 0, 1)
cfunc(np.zeros_like(arg, dtype=np.int32), arg, 0, 0, 1)

def check_1d_slicing_set_sequence(self, flags, seqty, seq):
"""
Expand Down
2 changes: 1 addition & 1 deletion numba/tests/test_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def check_values(typ, values):
cr = compile_isolated(pyfunc, (arraytype,), flags=enable_pyobj_flags)
cfunc = cr.entry_point
with captured_stdout():
cfunc(np.arange(10))
cfunc(np.arange(10, dtype=np.int32))
self.assertEqual(sys.stdout.getvalue(),
'[0 1 2 3 4 5 6 7 8 9]\n')

Expand Down
2 changes: 1 addition & 1 deletion numba/tests/test_unpack_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_unpack_shape(self, flags=force_pyobj_flags):
layout='C')],
flags=flags)
cfunc = cr.entry_point
a = np.zeros(shape=(1, 2, 3))
a = np.zeros(shape=(1, 2, 3)).astype(np.int32)
self.assertPreciseEqual(cfunc(a), pyfunc(a))

@tag('important')
Expand Down

0 comments on commit 35d9cdc

Please sign in to comment.