Skip to content

Commit

Permalink
Test for temp assignment removal in Interpreter
Browse files Browse the repository at this point in the history
  • Loading branch information
ehsantn committed Feb 1, 2021
1 parent b763436 commit ca58e53
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions numba/tests/test_copy_propagate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
# SPDX-License-Identifier: BSD-2-Clause
#

from numba import njit
from numba.core import types, typing, ir, config, compiler, cpu
from numba.core.registry import cpu_target
from numba.core.annotations import type_annotations
from numba.core.ir_utils import (copy_propagate, apply_copy_propagate,
get_name_var_table)
from numba.core.typed_passes import type_inference_stage
from numba.tests.test_ir_inlining import InlineTestPipeline
import unittest


def test_will_propagate(b, z, w):
x = 3
x1 = x
Expand All @@ -21,6 +24,7 @@ def test_will_propagate(b, z, w):
a = 2 * x1
return a < b


def test_wont_propagate(b, z, w):
x = 3
if b > 0:
Expand All @@ -31,15 +35,18 @@ def test_wont_propagate(b, z, w):
a = 2 * x
return a < b


def null_func(a,b,c,d):
False


def inListVar(list_var, var):
for i in list_var:
if i.name == var:
return True
return False


def findAssign(func_ir, var):
for label, block in func_ir.blocks.items():
for i, inst in enumerate(block.body):
Expand All @@ -50,6 +57,7 @@ def findAssign(func_ir, var):

return False


class TestCopyPropagate(unittest.TestCase):
def test1(self):
typingctx = typing.Context()
Expand Down Expand Up @@ -97,6 +105,30 @@ def test2(self):

self.assertTrue(findAssign(test_ir, "x"))

def test_input_ir_extra_copies(self):
"""make sure Interpreter._remove_unused_temporaries() has removed extra copies
in the IR in simple cases so copy propagation is faster
"""
def test_impl(a):
b = a + 3
return b

j_func = njit(pipeline_class=InlineTestPipeline)(test_impl)
self.assertEqual(test_impl(5), j_func(5))

# make sure b is the target of the expression assignment, not a temporary
fir = j_func.overloads[j_func.signatures[0]].metadata['preserved_ir']
self.assertTrue(len(fir.blocks) == 1)
block = next(iter(fir.blocks.values()))
b_found = False
for stmt in block.body:
if isinstance(stmt, ir.Assign) and stmt.target.name == "b":
b_found = True
self.assertTrue(isinstance(stmt.value, ir.Expr)
and stmt.value.op == "binop" and stmt.value.lhs.name == "a")

self.assertTrue(b_found)


if __name__ == "__main__":
unittest.main()

0 comments on commit ca58e53

Please sign in to comment.