Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue with SSA not minimal #5686

Merged
merged 5 commits into from
May 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
98 changes: 39 additions & 59 deletions numba/core/ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,9 @@ def _fix_ssa_vars(blocks, varname, defmap, cfg, df_plus):
states['varname'] = varname
states['defmap'] = defmap
states['phimap'] = phimap = defaultdict(list)
states['cfg'] = cfg
states['df+'] = df_plus
states['cfg'] = cfg = compute_cfg_from_blocks(blocks)
states['phi_locations'] = _compute_phi_locations(cfg, defmap)
newblocks = _run_block_rewrite(blocks, states, _FixSSAVars())
# check for unneeded phi nodes
_remove_unneeded_phis(phimap)
# insert phi nodes
for label, philist in phimap.items():
curblk = newblocks[label]
Expand All @@ -92,36 +90,6 @@ def _fix_ssa_vars(blocks, varname, defmap, cfg, df_plus):
return newblocks


def _remove_unneeded_phis(phimap):
"""Remove unneeded PHIs from the phimap
"""
all_phis = []
for philist in phimap.values():
all_phis.extend(philist)
unneeded_phis = set()
# Find unneeded PHIs.
for phi in all_phis:
ivs = phi.value.incoming_values
# It's unneeded if the incomings are either undefined or
# the PHI node target is itself
if all(iv is ir.UNDEFINED or iv == phi.target for iv in ivs):
unneeded_phis.add(phi)
# Fix up references to unneeded PHIs
for phi in all_phis:
for unneed in unneeded_phis:
if unneed is not phi:
# If the unneeded PHI is in the current phi's incoming values
if unneed.target in phi.value.incoming_values:
# Replace the unneeded PHI with an UNDEFINED
idx = phi.value.incoming_values.index(unneed.target)
phi.value.incoming_values[idx] = ir.UNDEFINED
# Remove unneeded phis
for philist in phimap.values():
for unneeded in unneeded_phis:
if unneeded in philist:
philist.remove(unneeded)


def _iterated_domfronts(cfg):
"""Compute the iterated dominance frontiers (DF+ in literatures).

Expand All @@ -140,6 +108,19 @@ def _iterated_domfronts(cfg):
return domfronts


def _compute_phi_locations(cfg, defmap):
# See basic algorithm in Ch 4.1 in Inria SSA Book
# Compute DF+
iterated_df = _iterated_domfronts(cfg)
# Compute DF+(defs)
# DF of all DFs is the union of all DFs
phi_locations = set()
for deflabel, defstmts in defmap.items():
if defstmts:
phi_locations |= iterated_df[deflabel]
return phi_locations


def _fresh_vars(blocks, varname):
"""Rewrite to put fresh variable names
"""
Expand Down Expand Up @@ -390,32 +371,31 @@ def _find_def_from_top(self, states, label, loc):
cfg = states['cfg']
defmap = states['defmap']
phimap = states['phimap']
domfronts = states['df+']
for deflabel, defstmt in defmap.items():
df = domfronts[deflabel]
if label in df:
scope = states['scope']
loc = states['block'].loc
# fresh variable
freshvar = scope.redefine(states['varname'], loc=loc)
# insert phi
phinode = ir.Assign(
target=freshvar,
value=ir.Expr.phi(loc=loc),
loc=loc,
phi_locations = states['phi_locations']

if label in phi_locations:
scope = states['scope']
Copy link
Member Author

@sklam sklam May 8, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

code in this if block is unchanged, just unindented

loc = states['block'].loc
# fresh variable
freshvar = scope.redefine(states['varname'], loc=loc)
# insert phi
phinode = ir.Assign(
target=freshvar,
value=ir.Expr.phi(loc=loc),
loc=loc,
)
_logger.debug("insert phi node %s at %s", phinode, label)
defmap[label].insert(0, phinode)
phimap[label].append(phinode)
# Find incoming values for the Phi node
for pred, _ in cfg.predecessors(label):
incoming_def = self._find_def_from_bottom(
states, pred, loc=loc,
)
_logger.debug("insert phi node %s at %s", phinode, label)
defmap[label].insert(0, phinode)
phimap[label].append(phinode)
# Find incoming values for the Phi node
for pred, _ in cfg.predecessors(label):
incoming_def = self._find_def_from_bottom(
states, pred, loc=loc,
)
_logger.debug("incoming_def %s", incoming_def)
phinode.value.incoming_values.append(incoming_def.target)
phinode.value.incoming_blocks.append(pred)
return phinode
_logger.debug("incoming_def %s", incoming_def)
phinode.value.incoming_values.append(incoming_def.target)
phinode.value.incoming_blocks.append(pred)
return phinode
else:
idom = cfg.immediate_dominators()[label]
if idom == label:
Expand Down
49 changes: 49 additions & 0 deletions numba/tests/test_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,3 +444,52 @@ def foo(pred, stack):

np.testing.assert_array_equal(python, expect)
np.testing.assert_array_equal(nb, expect)

def test_issue5678_non_minimal_phi(self):
# There should be only one phi for variable "i"

from numba.core.compiler import CompilerBase, DefaultPassBuilder
from numba.core.untyped_passes import (
ReconstructSSA, FunctionPass, register_pass,
)

phi_counter = []

@register_pass(mutates_CFG=False, analysis_only=True)
class CheckSSAMinimal(FunctionPass):
# A custom pass to count the number of phis

_name = self.__class__.__qualname__ + ".CheckSSAMinimal"

def __init__(self):
super().__init__(self)

def run_pass(self, state):
ct = 0
for blk in state.func_ir.blocks.values():
ct += len(list(blk.find_exprs('phi')))
phi_counter.append(ct)
return True

class CustomPipeline(CompilerBase):
def define_pipelines(self):
pm = DefaultPassBuilder.define_nopython_pipeline(self.state)
pm.add_pass_after(CheckSSAMinimal, ReconstructSSA)
pm.finalize()
return [pm]

@njit(pipeline_class=CustomPipeline)
def while_for(n, max_iter=1):
a = np.empty((n,n))
i = 0
while i <= max_iter:
for j in range(len(a)):
for k in range(len(a)):
a[j,k] = j + k
i += 1
return a

# Runs fine?
self.assertPreciseEqual(while_for(10), while_for.py_func(10))
# One phi?
self.assertEqual(phi_counter, [1])