Skip to content

Commit

Permalink
Merge pull request #4474 from ehsantn/fix_parfor_setitem_rm_dead
Browse files Browse the repository at this point in the history
Fix liveness for remove dead of parfors (and other IR extensions)
  • Loading branch information
seibert committed Aug 29, 2019
2 parents 196fb2c + 57aea3b commit fee9b90
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 1 deletion.
3 changes: 2 additions & 1 deletion numba/ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,8 @@ def remove_dead_block(block, lives, call_table, arg_aliases, alias_map,
# let external calls handle stmt if type matches
if type(stmt) in remove_dead_extensions:
f = remove_dead_extensions[type(stmt)]
stmt = f(stmt, lives, arg_aliases, alias_map, func_ir, typemap)
stmt = f(stmt, lives_n_aliases, arg_aliases, alias_map, func_ir,
typemap)
if stmt is None:
removed = True
continue
Expand Down
54 changes: 54 additions & 0 deletions numba/tests/test_remove_dead.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,5 +213,59 @@ def func(A, i):
# recover global state
numba.ir_utils.alias_func_extensions = old_ext_handlers

@skip_parfors_unsupported
def test_alias_parfor_extension(self):
"""Make sure aliases are considered in remove dead extension for
parfors.
"""
def func():
n = 11
numba.parfor.init_prange()
A = np.empty(n)
B = A # create alias to A
for i in numba.prange(n):
A[i] = i

return B

class TestPipeline(numba.compiler.Pipeline):
"""Test pipeline that just converts prange() to parfor and calls
remove_dead(). Copy propagation can replace B in the example code
which this pipeline avoids.
"""
def define_pipelines(self, pm):
pm.create_pipeline("test parfor aliasing")
self.add_preprocessing_stage(pm)
self.add_pre_typing_stage(pm)
self.add_typing_stage(pm)
pm.add_stage(
self.stage_limited_parfor, "just prange and rm dead")
self.add_lowering_stage(pm)
self.add_cleanup_stage(pm)

def stage_limited_parfor(self):
parfor_pass = numba.parfor.ParforPass(
self.func_ir,
self.type_annotation.typemap,
self.type_annotation.calltypes,
self.return_type,
self.typingctx,
self.flags.auto_parallel,
self.flags,
self.parfor_diagnostics
)
remove_dels(self.func_ir.blocks)
parfor_pass.array_analysis.run(self.func_ir.blocks)
parfor_pass._convert_loop(self.func_ir.blocks)
remove_dead(self.func_ir.blocks, self.func_ir.arg_names,
self.func_ir, self.type_annotation.typemap)
numba.parfor.get_parfor_params(self.func_ir.blocks,
parfor_pass.options.fusion, parfor_pass.nested_fusion_info)

test_res = numba.jit(pipeline_class=TestPipeline)(func)()
py_res = func()
np.testing.assert_array_equal(test_res, py_res)


if __name__ == "__main__":
unittest.main()

0 comments on commit fee9b90

Please sign in to comment.