Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make IR comparable and legalize it. #4162

Merged
merged 6 commits into from
Aug 13, 2019
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion numba/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from numba.parfor import PreParforPass, ParforPass, Parfor, ParforDiagnostics
from numba.inline_closurecall import InlineClosureCallPass
from numba.errors import CompilerError
from numba.ir_utils import raise_on_unsupported_feature, warn_deprecated
from numba.ir_utils import (raise_on_unsupported_feature, warn_deprecated,
check_and_legalize_ir)
from numba.compiler_lock import global_compiler_lock
from numba.analysis import dead_branch_prune

Expand Down Expand Up @@ -766,6 +767,8 @@ def stage_compile_interp_mode(self):
def stage_ir_legalization(self):
raise_on_unsupported_feature(self.func_ir, self.typemap)
warn_deprecated(self.func_ir, self.typemap)
# NOTE: this function call must go last, it checks and fixes invalid IR!
check_and_legalize_ir(self.func_ir)

def stage_cleanup(self):
"""
Expand Down
10 changes: 5 additions & 5 deletions numba/inline_closurecall.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,8 +592,8 @@ def _inline_arraycall(func_ir, cfg, visited, loop, swapped, enable_prange=False,
if isinstance(func_def, ir.Expr) and func_def.op == 'getattr' \
and func_def.attr == 'append':
list_def = get_definition(func_ir, func_def.value)
debug_print("list_def = ", list_def, list_def == list_var_def)
if list_def == list_var_def:
debug_print("list_def = ", list_def, list_def is list_var_def)
if list_def is list_var_def:
# found matching append call
list_append_stmts.append((label, block, stmt))

Expand Down Expand Up @@ -648,7 +648,7 @@ def is_removed(val, removed):
# Skip list construction and skip terminator, add the rest to stmts
for i in range(len(loop_entry.body) - 1):
stmt = loop_entry.body[i]
if isinstance(stmt, ir.Assign) and (stmt.value == list_def or is_removed(stmt.value, removed)):
if isinstance(stmt, ir.Assign) and (stmt.value is list_def or is_removed(stmt.value, removed)):
removed.append(stmt.target)
else:
stmts.append(stmt)
Expand Down Expand Up @@ -770,7 +770,7 @@ def is_removed(val, removed):

# In append_block, change list_append into array assign
for i in range(len(append_block.body)):
if append_block.body[i] == append_stmt:
if append_block.body[i] is append_stmt:
debug_print("Replace append with SetItem")
append_block.body[i] = ir.SetItem(target=array_var, index=index_var,
value=append_stmt.value.args[0], loc=append_stmt.loc)
Expand Down Expand Up @@ -833,7 +833,7 @@ def fix_dependencies(expr, varlist):
inst = body[i]
if isinstance(inst, ir.Assign):
defined.add(inst.target.name)
if inst.value == expr:
if inst.value is expr:
new_varlist = []
for var in varlist:
# var must be defined before this inst, or live
Expand Down
175 changes: 158 additions & 17 deletions numba/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
import pprint
import re
import sys
import warnings
import operator
from types import FunctionType, BuiltinFunctionType
from functools import total_ordering

from numba import config, errors
from .utils import BINOPS_TO_OPERATORS, INPLACE_BINOPS_TO_OPERATORS, UNARY_BUITINS_TO_OPERATORS, OPERATORS_TO_BUILTINS
from .errors import (NotDefinedError, RedefinedError, VerificationError,
ConstantInferenceError)
from .six import StringIO

# terminal color markup
_termcolor = errors.termcolor()
Expand All @@ -40,6 +41,17 @@ def __init__(self, filename, line, col=None, maybe_decorator=False):
self.lines = None # the source lines from the linecache
self.maybe_decorator = maybe_decorator

def __eq__(self, other):
# equivalence is solely based on filename, line and col
if type(self) is not type(other): return False
if self.filename != other.filename: return False
if self.line != other.line: return False
if self.col != other.col: return False
return True

def __ne__(self, other):
return not self.__eq__(other)

@classmethod
def from_function_id(cls, func_id):
return cls(func_id.filename, func_id.firstlineno, maybe_decorator=True)
Expand Down Expand Up @@ -186,6 +198,54 @@ def with_lineno(self, line, col=None):
unknown_loc = Loc("unknown location", 0, 0)


@total_ordering
class SlotEqualityCheckMixin(object):
# some ir nodes are __dict__ free using __slots__ instead, this mixin
# should not trigger the unintended creation of __dict__.
__slots__ = tuple()

def __eq__(self, other):
if type(self) == type(other):
sklam marked this conversation as resolved.
Show resolved Hide resolved
for name in self.__slots__:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that this will not support inheritance. An instance's __slots__ can only see the version defined in its type but not it's parent's.

if getattr(self, name) != getattr(other, name):
return False
else:
return True
return False

def __le__(self, other):
return str(self) < str(other)
sklam marked this conversation as resolved.
Show resolved Hide resolved

def __hash__(self):
return id(self)


@total_ordering
class EqualityCheckMixin(object):
""" Mixin for basic equality checking """

def __eq__(self, other):
if type(self) == type(other):
sklam marked this conversation as resolved.
Show resolved Hide resolved
def fixup(adict):
bad = ('loc', 'scope')
d = dict(adict)
for x in bad:
if x in d:
d.pop(x)
sklam marked this conversation as resolved.
Show resolved Hide resolved
return d
d1 = fixup(self.__dict__)
d2 = fixup(other.__dict__)
if d1 == d2:
return True
return False

def __le__(self, other):
return str(self) < str(other)

def __hash__(self):
return id(self)


class VarMap(object):
def __init__(self):
self._con = {}
Expand Down Expand Up @@ -217,14 +277,23 @@ def __hash__(self):
def __iter__(self):
return self._con.iterkeys()

def __eq__(self, other):
if type(self) == type(other):
sklam marked this conversation as resolved.
Show resolved Hide resolved
# check keys only, else __eq__ ref cycles, scope -> varmap -> var
return self._con.keys() == other._con.keys()
return False

def __ne__(self, other):
return not self.__eq__(other)


class AbstractRHS(object):
"""Abstract base class for anything that can be the RHS of an assignment.
This class **does not** define any methods.
"""


class Inst(AbstractRHS):
class Inst(AbstractRHS, EqualityCheckMixin):
"""
Base class for all IR instructions.
"""
Expand Down Expand Up @@ -750,7 +819,7 @@ def __init__(self, contextmanager, begin, end, loc):
contextmanager : IR value
begin, end : int
The beginning and the ending offset of the with-body.
loc : int
loc : ir.Loc instance
Source location
"""
assert isinstance(contextmanager, Var)
Expand All @@ -767,7 +836,7 @@ def list_vars(self):
return [self.contextmanager]


class Arg(AbstractRHS):
class Arg(AbstractRHS, EqualityCheckMixin):
sklam marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, name, index, loc):
assert isinstance(name, str)
assert isinstance(index, int)
Expand All @@ -783,7 +852,7 @@ def infer_constant(self):
raise ConstantInferenceError('%s' % self, loc=self.loc)


class Const(AbstractRHS):
class Const(AbstractRHS, EqualityCheckMixin):
def __init__(self, value, loc, use_literal_type=True):
assert isinstance(loc, Loc)
self.value = value
Expand All @@ -797,7 +866,8 @@ def __repr__(self):
def infer_constant(self):
return self.value

class Global(AbstractRHS):

class Global(AbstractRHS, EqualityCheckMixin):
def __init__(self, name, value, loc):
assert isinstance(loc, Loc)
self.name = name
Expand All @@ -816,7 +886,7 @@ def __deepcopy__(self, memo):
return Global(self.name, self.value, copy.deepcopy(self.loc))


class FreeVar(AbstractRHS):
class FreeVar(AbstractRHS, EqualityCheckMixin):
"""
A freevar, as loaded by LOAD_DECREF.
(i.e. a variable defined in an enclosing non-global scope)
Expand All @@ -841,7 +911,7 @@ def infer_constant(self):
return self.value


class Var(AbstractRHS):
class Var(AbstractRHS, EqualityCheckMixin):
"""
Attributes
-----------
Expand Down Expand Up @@ -873,7 +943,7 @@ def is_temp(self):
return self.name.startswith("$")


class Intrinsic(object):
class Intrinsic(EqualityCheckMixin):
"""
A low-level "intrinsic" function. Suitable as the callable of a "call"
expression.
Expand All @@ -883,10 +953,10 @@ class Intrinsic(object):
The *type* is the equivalent Numba signature of calling the intrinsic.
"""

def __init__(self, name, type, args):
def __init__(self, name, type, args, loc=None):
self.name = name
self.type = type
self.loc = None
self.loc = loc
self.args = args

def __repr__(self):
Expand All @@ -896,7 +966,7 @@ def __str__(self):
return self.name


class Scope(object):
class Scope(EqualityCheckMixin):
"""
Attributes
-----------
Expand Down Expand Up @@ -952,7 +1022,6 @@ def get_or_define(self, name, loc):
if name in self.redefined:
name = "%s.%d" % (name, self.redefined[name])

v = Var(scope=self, name=name, loc=loc)
if name not in self.localvars:
return self.define(name, loc)
else:
Expand Down Expand Up @@ -990,7 +1059,7 @@ def __repr__(self):
self.loc)


class Block(object):
class Block(EqualityCheckMixin):
"""A code block

"""
Expand Down Expand Up @@ -1094,7 +1163,7 @@ def __repr__(self):
return "<ir.Block at %s>" % (self.loc,)


class Loop(object):
class Loop(SlotEqualityCheckMixin):
"""Describes a loop-block
"""
__slots__ = "entry", "exit"
Expand All @@ -1108,7 +1177,7 @@ def __repr__(self):
return "Loop(entry=%s, exit=%s)" % args


class With(object):
class With(SlotEqualityCheckMixin):
"""Describes a with-block
"""
__slots__ = "entry", "exit"
Expand Down Expand Up @@ -1137,6 +1206,77 @@ def __init__(self, blocks, is_generator, func_id, loc,

self._reset_analysis_variables()

def equal_ir(self, other):
""" Checks that IR (and solely IR) is equal """
if type(self) is type(other):
return self.blocks == other.blocks
return False

def diff_str(self, other):
"""
Compute a human readable difference in the IR, returns a formatted
string ready for printing.
"""
msg = []
for label, block in self.blocks.items():
other_blk = other.blocks.get(label, None)
if other_blk is not None:
if block != other_blk:
msg.append(("Block %s differs" % label).center(80, '-'))
# see if the instructions are just a permutation
block_del = [x for x in block.body if isinstance(x, Del)]
oth_del = [x for x in other_blk.body if isinstance(x, Del)]
if block_del != oth_del:
# this is a common issue, dels are all present, but
# order shuffled.
if sorted(block_del) == sorted(oth_del):
msg.append(("Block %s contains the same dels but "
"their order is different") % label)
if len(block.body) > len(other_blk.body):
msg.append("This block contains more statements")
elif len(block.body) < len(other_blk.body):
msg.append("Other block contains more statements")

# find the indexes where they don't match
tmp = []
for idx, stmts in enumerate(zip(block.body,
other_blk.body)):
b_s, o_s = stmts
if b_s != o_s:
tmp.append(idx)

def get_pad(ablock, l):
pointer = '-> '
sp = len(pointer) * ' '
pad = []
nstmt = len(ablock)
for i in range(nstmt):
if i in tmp:
item = pointer
elif i >= l:
item = pointer
else:
item = sp
pad.append(item)
return pad

min_stmt_len = min(len(block.body), len(other_blk.body))

with StringIO() as buf:
it = [("self", block), ("other", other_blk)]
for name, _block in it:
buf.truncate(0)
_block.dump(file=buf)
stmts = buf.getvalue().splitlines()
pad = get_pad(_block.body, min_stmt_len)
title = ("%s: block %s" % (name, label))
msg.append(title.center(80, '-'))
msg.extend(["{0}{1}".format(a, b) for a, b in
zip(pad, stmts)])
if msg == []:
msg.append("IR is considered equivalent.")
return '\n'.join(msg)

def _reset_analysis_variables(self):
from . import consts

Expand Down Expand Up @@ -1243,7 +1383,7 @@ def dump_generator_info(self, file=None):


# A stub for undefined global reference
class UndefinedType(object):
class UndefinedType(EqualityCheckMixin):

_singleton = None

Expand All @@ -1259,4 +1399,5 @@ def __new__(cls):
def __repr__(self):
return "Undefined"


UNDEFINED = UndefinedType()