Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid temp variable assignments #6575

Merged
merged 25 commits into from
Feb 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
ac44f59
Avoid temp variable assignments
ehsantn Dec 13, 2020
ee30e18
Fix definitions after removing temp vars
ehsantn Dec 14, 2020
80b6aaa
Fix typeof_global() to support user provided type for same variable
ehsantn Dec 14, 2020
9b1ed7d
Update parfor reduction handling to not assume an extra temp assignment
ehsantn Dec 14, 2020
c53de53
Merge branch 'master' into ehsan/avoid_tmp_vars
ehsantn Jan 14, 2021
77b66d7
Handle replaced temp vars in later assignments
ehsantn Jan 14, 2021
74ddbc8
Handle chained unpack
ehsantn Jan 14, 2021
01ecb02
Avoid creating new inplace binop cases
ehsantn Jan 15, 2021
008efe8
avoid replacing lhs of inplace_binop
ehsantn Jan 15, 2021
2b5597c
avoid removing extra assign in replace_returns since needed for reduc…
ehsantn Jan 15, 2021
e5a9772
create tmp var for parfor lowering array value replacement
ehsantn Jan 15, 2021
10700a8
proper check for call/binop reductions
ehsantn Jan 15, 2021
e8b6382
handle both cases are reduction detection
ehsantn Jan 15, 2021
63bc44d
handle SetItem/SetAttr case in chained assignment
ehsantn Jan 15, 2021
b958d8b
fix copy propagation test
ehsantn Jan 15, 2021
a182952
fix dead code elimination tests
ehsantn Jan 15, 2021
d3db219
fix ir_utils test function match
ehsantn Jan 15, 2021
92f0071
Merge branch 'master' into ehsan/avoid_tmp_vars
ehsantn Jan 27, 2021
e5d537e
Update numba/core/typeinfer.py
ehsantn Feb 1, 2021
2bd6ce4
remove unrelated file
ehsantn Feb 1, 2021
34016d4
add suggested change in array expr
ehsantn Feb 1, 2021
b763436
Fix typo in type check
ehsantn Feb 1, 2021
ca58e53
Test for temp assignment removal in Interpreter
ehsantn Feb 1, 2021
f9f9716
fix flake8
ehsantn Feb 1, 2021
4aec70c
Tests for chained assignment corner cases
ehsantn Feb 2, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
63 changes: 59 additions & 4 deletions numba/core/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,14 +540,69 @@ def _remove_unused_temporaries(self):
current block.
"""
new_body = []
replaced_var = {}
for inst in self.current_block.body:
if (isinstance(inst, ir.Assign)
and inst.target.is_temp
and inst.target.name in self.assigner.unused_dests):
continue
# the same temporary is assigned to multiple variables in cases
# like a = b[i] = 1, so need to handle replaced temporaries in
# later setitem/setattr nodes
if (isinstance(inst, (ir.SetItem, ir.SetAttr))
and inst.value.name in replaced_var):
inst.value = replaced_var[inst.value.name]
stuartarchibald marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(inst, ir.Assign):
if (inst.target.is_temp
and inst.target.name in self.assigner.unused_dests):
continue
# the same temporary is assigned to multiple variables in cases
# like a = b = 1, so need to handle replaced temporaries in
# later assignments
if (isinstance(inst.value, ir.Var)
and inst.value.name in replaced_var):
inst.value = replaced_var[inst.value.name]
new_body.append(inst)
continue
# chained unpack cases may reuse temporary
# e.g. a = (b, c) = (x, y)
if (isinstance(inst.value, ir.Expr)
and inst.value.op == "exhaust_iter"
and inst.value.value.name in replaced_var):
inst.value.value = replaced_var[inst.value.value.name]
new_body.append(inst)
continue
# eliminate temporary variables that are assigned to user
# variables right after creation. E.g.:
# $1 = f(); a = $1 -> a = f()
# the temporary variable is not reused elsewhere since CPython
# bytecode is stack-based and this pattern corresponds to a pop
if (isinstance(inst.value, ir.Var) and inst.value.is_temp
and new_body and isinstance(new_body[-1], ir.Assign)):
prev_assign = new_body[-1]
# _var_used_in_binop check makes sure we don't create a new
# inplace binop operation which can fail
# (see TestFunctionType.test_in_iter_func_call)
if (prev_assign.target.name == inst.value.name
and not self._var_used_in_binop(
inst.target.name, prev_assign.value)):
replaced_var[inst.value.name] = inst.target
prev_assign.target = inst.target
# replace temp var definition in target with proper defs
self.definitions[inst.target.name].remove(inst.value)
self.definitions[inst.target.name].extend(
self.definitions.pop(inst.value.name)
)
continue

new_body.append(inst)

self.current_block.body = new_body

def _var_used_in_binop(self, varname, expr):
    """Tell whether 'varname' is an operand of the binary expression 'expr'.

    The result is True only when 'expr' is an ir.Expr whose op is "binop"
    or "inplace_binop" and the name of its lhs or rhs equals 'varname'.
    Any other kind of 'expr' yields False.
    """
    # anything that is not an IR expression cannot be a binary operation
    if not isinstance(expr, ir.Expr):
        return False
    # only regular and in-place binary operations are of interest
    if expr.op not in ("binop", "inplace_binop"):
        return False
    # lhs is checked first; rhs is only inspected when lhs does not match
    if varname == expr.lhs.name:
        return True
    return varname == expr.rhs.name

def _insert_outgoing_phis(self):
"""
Add assignments to forward requested outgoing values
Expand Down
26 changes: 14 additions & 12 deletions numba/core/ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1739,23 +1739,25 @@ def replace_arg_nodes(block, args):
stmt.value = args[idx]
return


def replace_returns(blocks, target, return_label):
"""
Replace return statement by assigning directly to target, and a jump.
"""
for block in blocks.values():
casts = []
for i, stmt in enumerate(block.body):
if isinstance(stmt, ir.Return):
assert(i + 1 == len(block.body))
block.body[i] = ir.Assign(stmt.value, target, stmt.loc)
block.body.append(ir.Jump(return_label, stmt.loc))
# remove cast of the returned value
for cast in casts:
if cast.target.name == stmt.value.name:
cast.value = cast.value.value
elif isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Expr) and stmt.value.op == 'cast':
casts.append(stmt)
# some blocks may be empty during transformations
if not block.body:
continue
stmt = block.terminator
if isinstance(stmt, ir.Return):
block.body.pop() # remove return
cast_stmt = block.body.pop()
assert (isinstance(cast_stmt, ir.Assign)
and isinstance(cast_stmt.value, ir.Expr)
and cast_stmt.value.op == 'cast'), "invalid return cast"
block.body.append(ir.Assign(cast_stmt.value.value, target, stmt.loc))
block.body.append(ir.Jump(return_label, stmt.loc))


def gen_np_call(func_as_str, func, lhs, args, typingctx, typemap, calltypes):
scope = args[0].scope
Expand Down
11 changes: 10 additions & 1 deletion numba/core/typeinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1601,7 +1601,16 @@ def typeof_global(self, inst, target, gvar):
# Setting literal_value for globals because they are handled
# like const value in numba
lit = types.maybe_literal(gvar.value)
self.lock_type(target.name, lit or typ, loc=inst.loc)
# The user may have provided the type for this variable already.
# In this case, call add_type() to make sure the value type is
# consistent. See numba.tests.test_array_reductions
# TestArrayReductions.test_array_cumsum for examples.
# Variable type locked by using the locals dict.
tv = self.typevars[target.name]
if tv.locked:
tv.add_type(lit or typ, loc=inst.loc)
else:
self.lock_type(target.name, lit or typ, loc=inst.loc)
self.assumed_immutables.add(inst)

def typeof_expr(self, inst, target, expr):
Expand Down
2 changes: 1 addition & 1 deletion numba/np/ufunc/array_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def _handle_matches(self):
self.array_assigns[instr.target.name] = new_instr
for operand in self._get_operands(expr):
operand_name = operand.name
if operand_name in self.array_assigns:
if operand.is_temp and operand_name in self.array_assigns:
child_assign = self.array_assigns[operand_name]
child_expr = child_assign.value
child_operands = child_expr.list_vars()
Expand Down
28 changes: 18 additions & 10 deletions numba/parfors/parfor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3522,13 +3522,17 @@ def check_conflicting_reduction_operators(param, nodes):
def get_reduction_init(nodes):
"""
Get initial value for known reductions.
Currently, only += and *= are supported. We assume the inplace_binop node
is followed by an assignment.
Currently, only += and *= are supported.
"""
require(len(nodes) >=2)
require(isinstance(nodes[-1].value, ir.Var))
require(nodes[-2].target.name == nodes[-1].value.name)
acc_expr = nodes[-2].value
require(len(nodes) >=1)
# there could be an extra assignment after the reduce node
# See: test_reduction_var_reuse
if isinstance(nodes[-1].value, ir.Var):
require(len(nodes) >=2)
require(nodes[-2].target.name == nodes[-1].value.name)
acc_expr = nodes[-2].value
else:
acc_expr = nodes[-1].value
require(isinstance(acc_expr, ir.Expr) and acc_expr.op=='inplace_binop')
if acc_expr.fn == operator.iadd or acc_expr.fn == operator.isub:
return 0, acc_expr.fn
Expand Down Expand Up @@ -3573,9 +3577,13 @@ def lookup(var, varonly=True):
if isinstance(rhs, ir.Expr):
in_vars = set(lookup(v, True).name for v in rhs.list_vars())
if name in in_vars:
next_node = nodes[i+1]
target_name = next_node.target.unversioned_name
if not (isinstance(next_node, ir.Assign) and target_name == unversioned_name):
# reductions like sum have an assignment afterwards
# e.g. $2 = a + $1; a = $2
# reductions that are functions calls like max() don't have an
# extra assignment afterwards
if (not (i+1 < len(nodes) and isinstance(nodes[i+1], ir.Assign)
and nodes[i+1].target.unversioned_name == unversioned_name)
and lhs.unversioned_name != unversioned_name):
raise ValueError(
f"Use of reduction variable {unversioned_name!r} other "
"than in a supported reduction function is not "
Expand All @@ -3593,7 +3601,7 @@ def lookup(var, varonly=True):
replace_dict[non_red_args[0]] = ir.Var(lhs.scope, name+"#init", lhs.loc)
replace_vars_inner(rhs, replace_dict)
reduce_nodes = nodes[i:]
break;
break
return reduce_nodes

def get_expr_args(expr):
Expand Down
5 changes: 4 additions & 1 deletion numba/parfors/parfor_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1385,7 +1385,10 @@ def replace_var_with_array_in_block(vars, block, typemap, calltypes):
const_assign = ir.Assign(const_node, const_var, inst.loc)
new_block.append(const_assign)

setitem_node = ir.SetItem(inst.target, const_var, inst.value, inst.loc)
val_var = ir.Var(inst.target.scope, mk_unique_var("$val"), inst.loc)
typemap[val_var.name] = typemap[inst.target.name]
new_block.append(ir.Assign(inst.value, val_var, inst.loc))
setitem_node = ir.SetItem(inst.target, const_var, val_var, inst.loc)
calltypes[setitem_node] = signature(
types.none, types.npytypes.Array(typemap[inst.target.name], 1, "C"), types.intp, typemap[inst.target.name])
new_block.append(setitem_node)
Expand Down
71 changes: 63 additions & 8 deletions numba/tests/test_copy_propagate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,29 @@
# SPDX-License-Identifier: BSD-2-Clause
#

from numba import jit, njit
from numba.core import types, typing, ir, config, compiler, cpu
from numba.core.registry import cpu_target
from numba.core.annotations import type_annotations
from numba.core.ir_utils import (copy_propagate, apply_copy_propagate,
get_name_var_table)
from numba.core.typed_passes import type_inference_stage
from numba.tests.test_ir_inlining import InlineTestPipeline
import numpy as np
import unittest


# Plain-Python fixture: TestCopyPropagate.test1 passes this function to
# compiler.run_frontend() and runs copy propagation over the resulting IR.
def test_will_propagate(b, z, w):
    x = 3
    # `x1` is a direct copy of `x`
    x1 = x
    # `y` is assigned on both branches but never read again in this function
    if b > 0:
        y = z + w
    else:
        y = 0
    # NOTE(review): `a` is assigned twice in a row, so the first value is
    # immediately overwritten. The `2 * x` line may be diff-view residue of
    # the statement this change removed -- confirm against the real file.
    a = 2 * x
    a = 2 * x1
    return a < b


def test_wont_propagate(b, z, w):
x = 3
if b > 0:
Expand All @@ -30,15 +36,18 @@ def test_wont_propagate(b, z, w):
a = 2 * x
return a < b


def null_func(a, b, c, d):
    """No-op stub: accepts four arguments, ignores them and returns None."""
    return None


def inListVar(list_var, var):
    """Return True if any element of 'list_var' has a `.name` equal to 'var'.

    Elements are examined in order and the scan stops at the first match.
    An empty 'list_var' gives False.
    """
    return any(candidate.name == var for candidate in list_var)


def findAssign(func_ir, var):
for label, block in func_ir.blocks.items():
for i, inst in enumerate(block.body):
Expand All @@ -49,20 +58,17 @@ def findAssign(func_ir, var):

return False


class TestCopyPropagate(unittest.TestCase):
def test1(self):
typingctx = typing.Context()
targetctx = cpu.CPUContext(typingctx)
test_ir = compiler.run_frontend(test_will_propagate)
#print("Num blocks = ", len(test_ir.blocks))
#print(test_ir.dump())
with cpu_target.nested_context(typingctx, targetctx):
typingctx.refresh()
targetctx.refresh()
args = (types.int64, types.int64, types.int64)
typemap, return_type, calltypes, _ = type_inference_stage(typingctx, test_ir, args, None)
#print("typemap = ", typemap)
#print("return_type = ", return_type)
type_annotation = type_annotations.TypeAnnotation(
func_ir=test_ir,
typemap=typemap,
Expand All @@ -75,14 +81,12 @@ def test1(self):
in_cps, out_cps = copy_propagate(test_ir.blocks, typemap)
apply_copy_propagate(test_ir.blocks, in_cps, get_name_var_table(test_ir.blocks), typemap, calltypes)

self.assertFalse(findAssign(test_ir, "x"))
self.assertFalse(findAssign(test_ir, "x1"))

def test2(self):
typingctx = typing.Context()
targetctx = cpu.CPUContext(typingctx)
test_ir = compiler.run_frontend(test_wont_propagate)
#print("Num blocks = ", len(test_ir.blocks))
#print(test_ir.dump())
with cpu_target.nested_context(typingctx, targetctx):
typingctx.refresh()
targetctx.refresh()
Expand All @@ -102,5 +106,56 @@ def test2(self):

self.assertTrue(findAssign(test_ir, "x"))

def test_input_ir_extra_copies(self):
    """make sure Interpreter._remove_unused_temporaries() has removed extra
    copies in the IR in simple cases so copy propagation is faster

    The function `b = a + 3` is compiled with InlineTestPipeline and the
    IR stored under the 'preserved_ir' metadata key is inspected: the
    assignment whose target is `b` must hold the binop expression itself,
    with `a` as its left operand.
    """
    def test_impl(a):
        b = a + 3
        return b

    j_func = njit(pipeline_class=InlineTestPipeline)(test_impl)
    self.assertEqual(test_impl(5), j_func(5))

    # make sure b is the target of the expression assignment, not a temporary
    fir = j_func.overloads[j_func.signatures[0]].metadata['preserved_ir']
    self.assertEqual(len(fir.blocks), 1)
    block = next(iter(fir.blocks.values()))
    b_found = False
    for stmt in block.body:
        if isinstance(stmt, ir.Assign) and stmt.target.name == "b":
            b_found = True
            # one assertion per property so a failure reports which of the
            # three conditions did not hold
            self.assertIsInstance(stmt.value, ir.Expr)
            self.assertEqual(stmt.value.op, "binop")
            self.assertEqual(stmt.value.lhs.name, "a")

    self.assertTrue(b_found, "no assignment to 'b' found in the IR block")

def test_input_ir_copy_remove_transform(self):
    """Chained assignments must still produce valid code after
    Interpreter._remove_unused_temporaries() has rewritten the IR.

    Three shapes are run both as plain Python and through the JIT, and the
    results are compared: a plain chained assignment, one that includes a
    setitem target, and one that includes a setattr target.
    """
    class C:
        pass

    # regular chained assignment
    def impl1(a):
        b = c = a + 1
        return (b, c)

    # chained assignment with setitem
    def impl2(A, i, a):
        b = A[i] = a + 1
        return b, A[i] + 2

    # chained assignment with setattr
    def impl3(A, a):
        b = A.a = a + 1
        return b, A.a + 2

    # plain chain, nopython mode
    expected = impl1(5)
    compiled = njit(impl1)
    self.assertEqual(expected, compiled(5))

    # setitem chain, nopython mode; each call gets its own fresh array
    expected = impl2(np.ones(3), 0, 5)
    compiled = njit(impl2)
    self.assertEqual(expected, compiled(np.ones(3), 0, 5))

    # setattr chain on a plain Python object, so object mode is forced
    expected = impl3(C(), 5)
    compiled = jit(forceobj=True)(impl3)
    self.assertEqual(expected, compiled(C(), 5))


if __name__ == "__main__":
unittest.main()
5 changes: 5 additions & 0 deletions numba/tests/test_function_type.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import unittest
import types as pytypes
from numba import jit, njit, cfunc, types, int64, float64, float32, errors
from numba import literal_unroll
Expand Down Expand Up @@ -1252,3 +1253,7 @@ def bar(fcs, ffs):
got = bar(tup, tup_bar)
expected = foo1(a) + foo2(a) + bar1(a) + bar2(a)
self.assertEqual(got, expected)


if __name__ == '__main__':
unittest.main()
10 changes: 3 additions & 7 deletions numba/tests/test_ir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_func():
typing_res = type_inference_stage(
typingctx, test_ir, (), None)
matched_call = ir_utils.find_callname(
test_ir, test_ir.blocks[0].body[8].value, typing_res.typemap)
test_ir, test_ir.blocks[0].body[7].value, typing_res.typemap)
self.assertTrue(isinstance(matched_call, tuple) and
len(matched_call) == 2 and
matched_call[0] == 'append')
Expand Down Expand Up @@ -89,7 +89,7 @@ def check_initial_ir(the_ir):
# an assign of above to variable `dead`
# a const int above 0xdeaddead
# an assign of said int to variable `deaddead`
# this is 4 things to remove
# this is 2 statements to remove

self.assertEqual(len(the_ir.blocks), 1)
block = the_ir.blocks[0]
Expand All @@ -99,18 +99,14 @@ def check_initial_ir(the_ir):
if 'dead' in getattr(x.target, 'name', ''):
deads.append(x)

expect_removed = []
self.assertEqual(len(deads), 2)
expect_removed.extend(deads)
for d in deads:
# check the ir.Const is the definition and the value is expected
const_val = the_ir.get_definition(d.value)
self.assertTrue(int('0x%s' % d.target.name, 16),
const_val.value)
expect_removed.append(const_val)

self.assertEqual(len(expect_removed), 4)
return expect_removed
return deads

def check_dce_ir(the_ir):
self.assertEqual(len(the_ir.blocks), 1)
Expand Down