Skip to content

Commit

Permalink
Test hack to fix old-style implementation's exception propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
sklam committed Dec 11, 2019
1 parent 38d1a7a commit f127335
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 3 deletions.
36 changes: 35 additions & 1 deletion numba/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from llvmlite.llvmpy.core import Constant, Type, Builder

from . import (_dynfunc, cgutils, config, funcdesc, generators, ir, types,
typing, utils)
typing, utils, ir_utils)
from .errors import (LoweringError, new_error_context, TypingError,
LiteralTypingError)
from .targets import removerefctpass
Expand Down Expand Up @@ -159,6 +159,11 @@ def pre_block(self, block):
Called before lowering a block.
"""

def post_block(self, block):
"""
Called after lowering a block.
"""

def return_exception(self, exc_class, exc_args=None, loc=None):
"""Propagate exception to the caller.
"""
Expand Down Expand Up @@ -269,6 +274,7 @@ def lower_block(self, block):
with new_error_context('lowering "{inst}" at {loc}', inst=inst,
loc=self.loc, errcls_=defaulterrcls):
self.lower_inst(inst)
self.post_block(block)

def create_cpython_wrapper(self, release_gil=False):
"""
Expand Down Expand Up @@ -305,6 +311,34 @@ def debug_print(self, msg):
class Lower(BaseLower):
GeneratorLower = generators.GeneratorLower

def pre_block(self, block):
from numba.unsafe import eh

super(Lower, self).pre_block(block)

# Detect if we are in a TRY block by looking for a call to
# `eh.exception_check`.
for call in block.find_exprs(op='call'):
defn = ir_utils.guard(
ir_utils.get_definition, self.func_ir, call.func,
)
if defn is not None and isinstance(defn, ir.Global):
if defn.value is eh.exception_check:
if isinstance(block.terminator, ir.Branch):
targetblk = self.blkmap[block.terminator.truebr]
# NOTE: This hacks in an attribute for call_conv to
# pick up. This hack is no longer needed when
# all old-style implementation are gone.
self.builder._in_try_block = {'target': targetblk}
break

def post_block(self, block):
# Cleaup
try:
del self.builder._in_try_block
except AttributeError:
pass

def lower_inst(self, inst):
# Set debug location for all subsequent LL instructions
self.debuginfo.mark_location(self.builder, self.loc)
Expand Down
13 changes: 12 additions & 1 deletion numba/targets/callconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,17 @@ def set_static_user_exc(self, builder, exc, exc_args=None, loc=None,

def return_user_exc(self, builder, exc, exc_args=None, loc=None,
func_name=None):
try_info = getattr(builder, '_in_try_block', False)
self.set_static_user_exc(builder, exc, exc_args=exc_args,
loc=loc, func_name=func_name)
trystatus = self.check_try_status(builder)
self._return_errcode_raw(builder, RETCODE_USEREXC)
if try_info:
# This is a hack for old-style impl.
# We will branch directly to the exception handler.
builder.branch(try_info['target'])
else:
# Return from the current function
self._return_errcode_raw(builder, RETCODE_USEREXC)

def _get_try_state(self, builder):
try:
Expand Down Expand Up @@ -428,6 +435,10 @@ def unset_try_status(self, builder):
null = cgutils.get_null_value(excinfoptr.type.pointee)
builder.store(null, excinfoptr)

def get_try_block_status(self, builder):
trystatus = self.check_try_status(builder)
return trystatus.in_try

def set_exception(self, builder, exc):
self.set_static_user_exc(builder, exc=Exception)

Expand Down
16 changes: 15 additions & 1 deletion numba/tests/test_try_except.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,6 @@ def udt(x):
str(raises.exception)
)

# runtime error no active exception to reraise
def test_try_except_reraise(self):
@njit
def udt():
Expand Down Expand Up @@ -495,6 +494,21 @@ def udt():
str(raises.exception),
)

def test_division_operator(self):
# This test that old-style implementation propagate exception
# to the exception handler properly.
@njit
def udt(y):
try:
1 / y
except Exception:
return 0xdead
else:
return 1 / y

self.assertEqual(udt(0), 0xdead)
self.assertEqual(udt(2), 0.5)


@skip_tryexcept_unsupported
class TestTryExceptNested(TestCase):
Expand Down

0 comments on commit f127335

Please sign in to comment.