Skip to content

Commit

Permalink
Merge pull request #5630 from sklam/fix/iss5570
Browse files Browse the repository at this point in the history
Fix #5570. Incorrect race variable detection due to SSA naming.
  • Loading branch information
sklam committed Apr 30, 2020
2 parents 684d420 + 765298a commit 47cdbe5
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
31 changes: 30 additions & 1 deletion numba/parfors/parfor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3170,14 +3170,43 @@ def get_parfor_params(blocks, options_fusion, fusion_info):
dummy_block.body = block.body[:i]
before_defs = compute_use_defs({0: dummy_block}).defmap[0]
pre_defs |= before_defs
parfor.params = get_parfor_params_inner(parfor, pre_defs, options_fusion, fusion_info) | parfor.races
params = get_parfor_params_inner(
parfor, pre_defs, options_fusion, fusion_info,
)
parfor.params, parfor.races = _combine_params_races_for_ssa_names(
block.scope, params, parfor.races,
)
parfor_ids.add(parfor.id)
parfors.append(parfor)

pre_defs |= all_defs[label]
return parfor_ids, parfors


def _combine_params_races_for_ssa_names(scope, params, races):
"""Returns `(params|races1, races1)`, where `races1` contains all variables
in `races` are NOT referring to the same unversioned (SSA) variables in
`params`.
"""
def unversion(k):
try:
return scope.get_exact(k).unversioned_name
except ir.NotDefinedError:
# XXX: it's a bug that something references an undefined name
return k

races1 = set(races)
unver_params = list(map(unversion, params))

for rv in races:
if any(unversion(rv) == pv for pv in unver_params):
races1.discard(rv)
else:
break

return params | races1, races1


def get_parfor_params_inner(parfor, pre_defs, options_fusion, fusion_info):
blocks = wrap_parfor_blocks(parfor)
cfg = compute_cfg_from_blocks(blocks)
Expand Down
20 changes: 20 additions & 0 deletions numba/tests/test_parfors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3428,6 +3428,26 @@ def parallel_test(size, arr):

self.check(parallel_test, size, arr)

@skip_parfors_unsupported
def test_issue5570_ssa_races(self):
@njit(parallel=True)
def foo(src, method, out):
for i in prange(1):
for j in range(1):
out[i, j] = 1
if method:
out += 1
return out

src = np.zeros((5,5))
method = 57
out = np.zeros((2, 2))

self.assertPreciseEqual(
foo(src, method, out),
foo.py_func(src, method, out)
)


@skip_parfors_unsupported
class TestParforsDiagnostics(TestParforsBase):
Expand Down

0 comments on commit 47cdbe5

Please sign in to comment.