Skip to content

Commit

Permalink
Merge pull request #4070 from stuartarchibald/fix/raise_write_to_glob…
Browse files Browse the repository at this point in the history
…al_dict

Catch writes to global typed.Dict and raise.
  • Loading branch information
seibert committed May 13, 2019
2 parents 6eadd0a + 2e00c1f commit 92dbdb9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 12 deletions.
9 changes: 5 additions & 4 deletions numba/ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1903,10 +1903,11 @@ def raise_on_unsupported_feature(func_ir, typemap):
# checks for globals that are also reflected
if isinstance(stmt.value, ir.Global):
ty = typemap[stmt.target.name]
if getattr(ty, 'reflected', False):
msg = ("Writing to a %s defined in globals is not "
"supported as globals are considered compile-time "
"constants.")
msg = ("Writing to a %s defined in globals is not "
"supported as globals are considered compile-time "
"constants.")
if (getattr(ty, 'reflected', False) or
isinstance(ty, types.DictType)):
raise TypingError(msg % ty, loc=stmt.loc)

# There is more than one call to function gdb/gdb_init
Expand Down
28 changes: 20 additions & 8 deletions numba/tests/test_errorhandling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
"""
from __future__ import division

from numba import jit, njit
from numba import jit, njit, typed, int64
from numba import unittest_support as unittest
from numba import errors, utils
import numpy as np

# used in TestMiscErrorHandling::test_handling_of_write_to_global
# used in TestMiscErrorHandling::test_handling_of_write_to_*_global
_global_list = [1, 2, 3, 4]
_global_dict = typed.Dict.empty(int64, int64)

class TestErrorHandlingBeforeLowering(unittest.TestCase):

Expand Down Expand Up @@ -108,19 +109,30 @@ def f(a):
expected = 'File "unknown location", line 0:'
self.assertIn(expected, str(raises.exception))

def test_handling_of_write_to_global(self):
@njit
def foo():
_global_list[0] = 10

def check_write_to_globals(self, func):
with self.assertRaises(errors.TypingError) as raises:
foo()
func()

expected = ["Writing to a", "defined in globals is not supported"]
for ex in expected:
self.assertIn(ex, str(raises.exception))


def test_handling_of_write_to_reflected_global(self):
@njit
def foo():
_global_list[0] = 10

self.check_write_to_globals(foo)

def test_handling_of_write_to_typed_dict_global(self):
@njit
def foo():
_global_dict[0] = 10

self.check_write_to_globals(foo)


class TestConstantInferenceErrorHandling(unittest.TestCase):

def test_basic_error(self):
Expand Down

0 comments on commit 92dbdb9

Please sign in to comment.