Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better handling of prange with multiple reductions on the same variable. #4935

Merged
merged 5 commits into from
Dec 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
37 changes: 32 additions & 5 deletions numba/npyufunc/parfor.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ def _lower_parfor_parallel(lowerer, parfor):
name = parfor_redvars[i]
redarr = redarrs[name]
redvar_typ = lowerer.fndesc.typemap[name]
if config.DEBUG_ARRAY_OPT:
print("post-gufunc reduction:", name, redarr, redvar_typ)

if config.DEBUG_ARRAY_OPT_RUNTIME:
res_print_str = "res_print"
Expand Down Expand Up @@ -337,22 +339,23 @@ def _lower_parfor_parallel(lowerer, parfor):
lowerer.lower_inst(init_assign)

if config.DEBUG_ARRAY_OPT_RUNTIME:
res_print_str = "one_res_print"
res_print_str = "res_print1 for thread " + str(j) + ":"
strconsttyp = types.StringLiteral(res_print_str)
lhs = ir.Var(scope, mk_unique_var("str_const"), loc)
assign_lhs = ir.Assign(value=ir.Const(value=res_print_str, loc=loc),
target=lhs, loc=loc)
typemap[lhs.name] = strconsttyp
lowerer.lower_inst(assign_lhs)

res_print = ir.Print(args=[lhs, index_var, oneelem, init_var],
res_print = ir.Print(args=[lhs, index_var, oneelem, init_var, ir.Var(scope, name, loc)],
vararg=None, loc=loc)
lowerer.fndesc.calltypes[res_print] = signature(types.none,
typemap[lhs.name],
typemap[index_var.name],
typemap[oneelem.name],
typemap[init_var.name])
print("res_print", res_print)
typemap[init_var.name],
typemap[name])
print("res_print1", res_print)
lowerer.lower_inst(res_print)

# generate code for combining reduction variable with thread output
Expand All @@ -368,7 +371,8 @@ def _lower_parfor_parallel(lowerer, parfor):
rhs = inst.value
# We probably need to generalize this since it only does substitutions in
# inplace_binops.
if isinstance(rhs, ir.Expr) and rhs.op == 'inplace_binop' and rhs.rhs.name == init_var.name:
if (isinstance(rhs, ir.Expr) and rhs.op == 'inplace_binop' and
rhs.rhs.name == init_var.name):
if config.DEBUG_ARRAY_OPT:
print("Adding call to reduction", rhs)
if rhs.fn == operator.isub:
Expand All @@ -391,6 +395,29 @@ def _lower_parfor_parallel(lowerer, parfor):
# Add calltype back in for the expr with updated signature.
lowerer.fndesc.calltypes[rhs] = ct
lowerer.lower_inst(inst)
if isinstance(inst, ir.Assign) and name == inst.target.name:
break

if config.DEBUG_ARRAY_OPT_RUNTIME:
res_print_str = "res_print2 for thread " + str(j) + ":"
strconsttyp = types.StringLiteral(res_print_str)
lhs = ir.Var(scope, mk_unique_var("str_const"), loc)
assign_lhs = ir.Assign(value=ir.Const(value=res_print_str, loc=loc),
target=lhs, loc=loc)
typemap[lhs.name] = strconsttyp
lowerer.lower_inst(assign_lhs)

res_print = ir.Print(args=[lhs, index_var, oneelem, init_var, ir.Var(scope, name, loc)],
vararg=None, loc=loc)
lowerer.fndesc.calltypes[res_print] = signature(types.none,
typemap[lhs.name],
typemap[index_var.name],
typemap[oneelem.name],
typemap[init_var.name],
typemap[name])
print("res_print2", res_print)
lowerer.lower_inst(res_print)


# Cleanup reduction variable
for v in redarrs.values():
Expand Down
24 changes: 23 additions & 1 deletion numba/parfor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2966,7 +2966,7 @@ def get_parfor_reductions(func_ir, parfor, parfor_params, calltypes, reductions=
for stmt in reversed(parfor.loop_body[label].body):
if (isinstance(stmt, ir.Assign)
and (stmt.target.name in parfor_params
or stmt.target.name in var_to_param)):
or stmt.target.name in var_to_param)):
lhs = stmt.target.name
rhs = stmt.value
cur_param = lhs if lhs in parfor_params else var_to_param[lhs]
Expand All @@ -2987,6 +2987,7 @@ def get_parfor_reductions(func_ir, parfor, parfor_params, calltypes, reductions=
# recursive parfors can have reductions like test_prange8
get_parfor_reductions(func_ir, stmt, parfor_params, calltypes,
reductions, reduce_varnames, param_uses, param_nodes, var_to_param)

for param, used_vars in param_uses.items():
# a parameter is a reduction variable if its value is used to update it
# check reduce_varnames since recursive parfors might have processed
Expand All @@ -2995,15 +2996,36 @@ def get_parfor_reductions(func_ir, parfor, parfor_params, calltypes, reductions=
reduce_varnames.append(param)
param_nodes[param].reverse()
reduce_nodes = get_reduce_nodes(param, param_nodes[param], func_ir)
check_conflicting_reduction_operators(param, reduce_nodes)
gri_out = guard(get_reduction_init, reduce_nodes)
if gri_out is not None:
init_val, redop = gri_out
else:
init_val = None
redop = None
reductions[param] = (init_val, reduce_nodes, redop)

return reduce_varnames, reductions

def check_conflicting_reduction_operators(param, nodes):
"""In prange, a user could theoretically specify conflicting
reduction operators. For example, in one spot it is += and
another spot *=. Here, we raise an exception if multiple
different reduction operators are used in one prange.
"""
first_red_func = None
for node in nodes:
if (isinstance(node, ir.Assign) and
isinstance(node.value, ir.Expr) and
node.value.op=='inplace_binop'):
if first_red_func is None:
first_red_func = node.value.fn
else:
if first_red_func != node.value.fn:
msg = ("Reduction variable %s has multiple conflicting "
"reduction operators." % param)
raise errors.UnsupportedError(msg, node.loc)

def get_reduction_init(nodes):
"""
Get initial value for known reductions.
Expand Down
28 changes: 28 additions & 0 deletions numba/tests/test_parfors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2068,6 +2068,34 @@ def test_impl(A):
self.prange_tester(test_impl, A, scheduler_type='unsigned',
check_fastmath=True, check_fastmath_result=True)

@skip_unsupported
def test_prange_two_instances_same_reduction_var(self):
# issue4922 - multiple uses of same reduction variable
def test_impl(n):
c = 0
for i in range(n):
c += 1
if i > 10:
c += 1
return c
self.prange_tester(test_impl, 9)

@skip_unsupported
def test_prange_conflicting_reduction_ops(self):
def test_impl(n):
c = 0
for i in range(n):
c += 1
if i > 10:
c *= 1
return c

with self.assertRaises(errors.UnsupportedError) as raises:
self.prange_tester(test_impl, 9)
msg = ('Reduction variable c has multiple conflicting reduction '
'operators.')
self.assertIn(msg, str(raises.exception))

# @skip_unsupported
@test_disabled
def test_check_error_model(self):
Expand Down