SSA pass causes invalid compilation (array write elided). #5623
Comments
This is the block from
With thanks to @99991 for posting https://gitter.im/numba/numba?at=5ea2c3601e3d5e20633f3c51, a smaller reproducer based on similar:

```python
import numpy as np
from numba import njit

def foo(pred, stack):
    i = 0
    c = 1

    if pred is True:
        stack[i] = c
        i += 1
        stack[i] = c
        i += 1

stack = np.array([0, 666])
foo(True, stack)
print("Python:", stack)

stack = np.array([0, 666])
njit(boundscheck=True, debug=True)(foo)(True, stack)
print("Numba:", stack)
```

IR post
note that an |
I think the problem is this:

```python
import numpy as np
from numba import njit
from numba.core import compiler
import logging

def foo(pred, stack):
    i = 0
    c = 1

    if pred is True:
        stack[i] = c
        i += 1
        stack[i] = c
        i += 1

stack = np.array([0, 666])
foo(True, stack)
print("Python:", stack)

stack = np.array([0, 666])
njit(boundscheck=True, debug=True)(foo)(True, stack)
print("Numba:", stack)

ir = compiler.run_frontend(foo)
blk2 = [_ for _ in ir.blocks.keys()][1]
block = ir.blocks[blk2]
block.dump()
s1 = block.body[0]
s2 = block.body[4]
print("s1 %s" % s1)
print("s2 %s" % s2)
print(s1 == s2)
```

gives this:
The knock-on effect shows up here (Lines 432 to 441 in 1dabbf6), when SSA is checking whether there is a local redefinition that needs inserting here (Lines 355 to 365 in 1dabbf6): the position `cur_pos` of the statement `stmt` for `stack[i] = const` aliases, so 0 (the position of the first such statement) is returned. As a result, the updated versions of `i` are not used and the output is nonsense.
And the above equality is because IR nodes are compared with their locations stripped. Lines 234 to 246 in 1dabbf6
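The aliasing can be reproduced without Numba. The sketch below uses a hypothetical `Stmt` class (a stand-in, not a Numba type) whose `__eq__` ignores the source location, as described above for IR nodes, and shows that `list.index` then reports the position of the first equal statement rather than the one that was passed in.

```python
class Stmt:
    """Hypothetical stand-in for an IR statement; not a Numba class."""

    def __init__(self, text, loc):
        self.text = text  # e.g. "stack[i] = c"
        self.loc = loc    # source line number

    def __eq__(self, other):
        # Equality deliberately ignores `loc`, mirroring the behaviour
        # described above for IR nodes.
        return isinstance(other, Stmt) and self.text == other.text

    def __hash__(self):
        return hash(self.text)


body = [
    Stmt("stack[i] = c", loc=8),
    Stmt("i += 1", loc=9),
    Stmt("stack[i] = c", loc=10),
    Stmt("i += 1", loc=11),
]

second_write = body[2]

# list.index compares with ==, so it finds the first equal statement.
pos_by_equality = body.index(second_write)

# Comparing by identity finds the statement that was actually passed in.
pos_by_identity = next(i for i, s in enumerate(body) if s is second_write)

print(pos_by_equality, pos_by_identity)
```

The two positions differ, which is the same mismatch that makes the second `stack[i] = c` look as if it sits at position 0.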
Something like this might work:

```diff
diff --git a/numba/core/ssa.py b/numba/core/ssa.py
index d38918f8c..375dbb7d7 100644
--- a/numba/core/ssa.py
+++ b/numba/core/ssa.py
@@ -435,8 +435,9 @@ class _FixSSAVars(_BaseHandler):
         Assumptions:
         - no two statements can point to the same object.
         """
+        stmtids = [id(x) for x in block.body]
         try:
-            return block.body.index(defstmt, 0, stop)
+            return stmtids.index(id(defstmt), 0, stop)
         except ValueError:
             return len(block.body)
```
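As a rough standalone illustration of the idea in the patch, the helper below looks a statement up by identity within `body[:stop]` and falls back to `len(body)` when it is not found. The function name `find_def_position` and the use of plain lists as stand-in statements are assumptions made for this sketch; it is not Numba's implementation.

```python
def find_def_position(body, defstmt, stop):
    """Return the index of `defstmt` in body[:stop], compared by identity.

    Hypothetical helper for illustration only. Falls back to len(body)
    when the statement is not found before `stop`.
    """
    # id() is unique among live objects; `body` keeps every statement alive.
    stmtids = [id(x) for x in body]
    try:
        return stmtids.index(id(defstmt), 0, stop)
    except ValueError:
        return len(body)


# Stand-in "statements": lists that compare equal but are distinct objects.
body = [["setitem"], ["inc"], ["setitem"]]

by_equality = body.index(body[2], 0, 3)          # first equal element
by_identity = find_def_position(body, body[2], 3)  # the object itself
not_in_range = find_def_position(body, body[2], 2)  # outside [0, stop)

print(by_equality, by_identity, not_in_range)
```

Building the list of ids costs one pass over the block per lookup, which seems acceptable for a sketch; a real fix might cache it if the scan turns out to be hot.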
…tion. As title with test. Fixes numba#5623
Fixes #5623, SSA local def scan based on invalid equality assumption.
From pymatting/pymatting#15: whilst long, this compiles incorrectly if the SSA passes are enabled. The problem is marked in the source by a `PROBLEM IS HERE` comment.