SSA pass causes invalid compilation (array write elided). #5623

Closed
stuartarchibald opened this issue Apr 24, 2020 · 6 comments · Fixed by #5627
Labels
bug SSA Problem due to SSA (or lack of)

Comments

@stuartarchibald
Contributor

The reproducer below comes from pymatting/pymatting#15. It is long, but it compiles incorrectly when the SSA passes are enabled. The problem line is marked in the source with a PROBLEM IS HERE comment.

import numpy as np
import llvmlite
import numba
import platform
from numba import njit

print("System:", platform.uname())
print("llvmlite version:", llvmlite.__version__)
print("numba version:", numba.__version__)
print("numpy version:", np.__version__)

# If this should crash, try without this line
@njit("i8(i8[:], i8[:], i8[:], i8[:], i8[:], f4[:, :, :], f4[:], f4[:, :], i8[:], i8)",
      boundscheck=True, debug=True)
def _make_tree(
    i0_inds,
    i1_inds,
    less_inds,
    more_inds,
    split_dims,
    bounds,
    split_values,
    points,
    indices,
    min_leaf_size,
):
    dimension = points.shape[1]
    # Expect log2(len(points) / min_leaf_size) depth, 1000 should be plenty
    stack = np.ones(1000, np.int64) * -99999999999999
    stack_size = 0
    n_nodes = 0
    # min_leaf_size <= leaf_node_size < max_leaf_size
    max_leaf_size = 2 * min_leaf_size

    # Push i0, i1, i_node
    stack[stack_size] = 0
    stack_size += 1
    stack[stack_size] = points.shape[0]
    stack_size += 1
    stack[stack_size] = n_nodes
    n_nodes += 1
    stack_size += 1

    # While there are more tree nodes to process recursively
    while stack_size > 0:
        print("start", stack[:20])
        stack_size -= 1
        i_node = stack[stack_size]
        stack_size -= 1
        i1 = stack[stack_size]
        stack_size -= 1
        i0 = stack[stack_size]

        lo = bounds[i_node, 0]
        hi = bounds[i_node, 1]

        for d in range(dimension):
            lo[d] = points[i0, d]
            hi[d] = points[i0, d]

        # Find lower and upper bounds of points for each dimension
        for i in range(i0 + 1, i1):
            for d in range(dimension):
                lo[d] = min(lo[d], points[i, d])
                hi[d] = max(hi[d], points[i, d])

        # Done if node is small
        if i1 - i0 <= max_leaf_size:
            i0_inds[i_node] = i0
            i1_inds[i_node] = i1
            less_inds[i_node] = -1
            more_inds[i_node] = -1
            split_dims[i_node] = -1
            split_values[i_node] = np.float32(0.0)
        else:
            # Split on largest dimension
            lengths = hi - lo
            split_dim = np.argmax(lengths)
            split_value = lo[split_dim] + np.float32(0.5) * lengths[split_dim]

            # Partition i0:i1 range into points where points[i, split_dim] < split_value
            i = i0
            j = i1 - 1

            while i < j:
                while i < j and points[i, split_dim] < split_value:
                    i += 1
                while i < j and points[j, split_dim] >= split_value:
                    j -= 1

                # Swap points
                if i < j:
                    for d in range(dimension):
                        temp = points[i, d]
                        points[i, d] = points[j, d]
                        points[j, d] = temp

                    temp_i_node = indices[i]
                    indices[i] = indices[j]
                    indices[j] = temp_i_node

            if points[i, split_dim] < split_value:
                i += 1

            i_split = i

            # Now it holds that:
            # for i in range(i0, i_split): assert(points[i, split_dim] < split_value)
            # for i in range(i_split, i1): assert(points[i, split_dim] >= split_value)

            # Ensure that each node has at least min_leaf_size children
            i_split = max(i0 + min_leaf_size, min(i1 - min_leaf_size, i_split))
            print(i_split)

            less = n_nodes
            n_nodes += 1
            more = n_nodes
            n_nodes += 1

            # push i0, i_split, less
            print("i0", i0, "to", stack_size)
            stack[stack_size] = i0
            stack_size += 1
            print("i_split", i_split, "to", stack_size)
            stack[stack_size] = i_split
            stack_size += 1
            print("less", less, "to", stack_size)
            stack[stack_size] = less
            stack_size += 1

            # push i_split, i1, more
            print("i_split", i_split, "to", stack_size)
            stack[stack_size] = i_split # <----- PROBLEM IS HERE, THIS DOESN'T GET WRITTEN
            #stack[stack_size] = np.sqrt(i_split) # DO THIS AND IT WORKS WITH SSA!
            stack_size += 1
            print("i1", i1, "to", stack_size)
            stack[stack_size] = i1
            stack_size += 1
            print("more", more, "to", stack_size)
            stack[stack_size] = more
            stack_size += 1

            i0_inds[i_node] = i0
            i1_inds[i_node] = i1
            less_inds[i_node] = less
            more_inds[i_node] = more
            split_dims[i_node] = split_dim
            split_values[i_node] = split_value
            print("end", stack_size, "stack", stack[:20])
    return n_nodes

def main():
    k = 20
    n_data = 100_000
    n_query = n_data
    dimension = 5
    np.random.seed(0)
    data_points = np.random.rand(n_data, dimension).astype(np.float32)
    min_leaf_size = 8

    n_data, dimension = data_points.shape

    max_nodes = 2 * ((n_data + min_leaf_size - 1) // min_leaf_size)

    i0_inds = np.empty(max_nodes, np.int64)
    i1_inds = np.empty(max_nodes, np.int64)
    less_inds = np.empty(max_nodes, np.int64)
    more_inds = np.empty(max_nodes, np.int64)
    split_dims = np.empty(max_nodes, np.int64)
    bounds = np.empty((max_nodes, 2, dimension), np.float32)
    split_values = np.empty(max_nodes, np.float32)
    shuffled_data_points = data_points.copy()
    shuffled_indices = np.arange(n_data).astype(np.int64)

    _make_tree(
        i0_inds,
        i1_inds,
        less_inds,
        more_inds,
        split_dims,
        bounds,
        split_values,
        shuffled_data_points,
        shuffled_indices,
        min_leaf_size,
    )

main()
@stuartarchibald stuartarchibald added bug SSA Problem due to SSA (or lack of) labels Apr 24, 2020
@stuartarchibald stuartarchibald added this to the 0.49.1 milestone Apr 24, 2020
@stuartarchibald
Contributor Author

CC @99991
CC @sklam

@stuartarchibald
Contributor Author

This is the IR for the block starting at i_split = i, with the print statements elided. The problem statement is marked with PROBLEM IS HERE (it is past the right-hand edge of the text box, so scroll right to see it).

label 668:
    i_split = i.10                           ['i.10', 'i_split']
    $672load_global.1 = global(max: <built-in function max>) ['$672load_global.1']
    $678binary_add.4 = i0 + min_leaf_size    ['$678binary_add.4', 'i0', 'min_leaf_size']
    $680load_global.5 = global(min: <built-in function min>) ['$680load_global.5']
    $686binary_subtract.8 = i1 - min_leaf_size ['$686binary_subtract.8', 'i1', 'min_leaf_size']
    $690call_function.10 = call $680load_global.5($686binary_subtract.8, i_split, func=$680load_global.5, args=[Var($686binary_subtract.8, pymatting1.py:112), Var(i_split, pymatting1.py:105)], kws=(), vararg=None) ['$680load_global.5', '$686binary_subtract.8', '$690call_function.10', 'i_split']
    $692call_function.11 = call $672load_global.1($678binary_add.4, $690call_function.10, func=$672load_global.1, args=[Var($678binary_add.4, pymatting1.py:112), Var($690call_function.10, pymatting1.py:112)], kws=(), vararg=None) ['$672load_global.1', '$678binary_add.4', '$690call_function.10', '$692call_function.11']
    i_split.1 = $692call_function.11         ['$692call_function.11', 'i_split.1']
    $696load_global.12 = global(print: <built-in function print>) ['$696load_global.12']
    print(i_split.1)                         ['i_split.1']
    $700call_function.14 = const(NoneType, None) ['$700call_function.14']
    less = n_nodes.1.3                       ['less', 'n_nodes.1.3']
    $const710.17 = const(int, 1)             ['$const710.17']
    $712inplace_add.18 = inplace_binop(fn=<built-in function iadd>, immutable_fn=<built-in function add>, lhs=n_nodes.1.3, rhs=$const710.17, static_lhs=Undefined, static_rhs=Undefined) ['$712inplace_add.18', '$const710.17', 'n_nodes.1.3']
    n_nodes.1.1 = $712inplace_add.18         ['$712inplace_add.18', 'n_nodes.1.1']
    more = n_nodes.1.1                       ['more', 'n_nodes.1.1']
    $const722.21 = const(int, 1)             ['$const722.21']
    $724inplace_add.22 = inplace_binop(fn=<built-in function iadd>, immutable_fn=<built-in function add>, lhs=n_nodes.1.1, rhs=$const722.21, static_lhs=Undefined, static_rhs=Undefined) ['$724inplace_add.22', '$const722.21', 'n_nodes.1.1']
    n_nodes.1.2 = $724inplace_add.22         ['$724inplace_add.22', 'n_nodes.1.2']
    n_nodes.1.8 = n_nodes.1.2                ['n_nodes.1.2', 'n_nodes.1.8']
    stack[stack_size.3.15] = i0              ['i0', 'stack', 'stack_size.3.15']
    $const738.27 = const(int, 1)             ['$const738.27']
    $740inplace_add.28 = inplace_binop(fn=<built-in function iadd>, immutable_fn=<built-in function add>, lhs=stack_size.3.15, rhs=$const738.27, static_lhs=Undefined, static_rhs=Undefined) ['$740inplace_add.28', '$const738.27', 'stack_size.3.15']
    stack_size.3.4 = $740inplace_add.28      ['$740inplace_add.28', 'stack_size.3.4']
    stack[stack_size.3.4] = i_split.1        ['i_split.1', 'stack', 'stack_size.3.4']
    $const754.33 = const(int, 1)             ['$const754.33']
    $756inplace_add.34 = inplace_binop(fn=<built-in function iadd>, immutable_fn=<built-in function add>, lhs=stack_size.3.4, rhs=$const754.33, static_lhs=Undefined, static_rhs=Undefined) ['$756inplace_add.34', '$const754.33', 'stack_size.3.4']
    stack_size.3.5 = $756inplace_add.34      ['$756inplace_add.34', 'stack_size.3.5']
    stack[stack_size.3.5] = less             ['less', 'stack', 'stack_size.3.5']
    $const770.39 = const(int, 1)             ['$const770.39']
    $772inplace_add.40 = inplace_binop(fn=<built-in function iadd>, immutable_fn=<built-in function add>, lhs=stack_size.3.5, rhs=$const770.39, static_lhs=Undefined, static_rhs=Undefined) ['$772inplace_add.40', '$const770.39', 'stack_size.3.5']
    stack_size.3.6 = $772inplace_add.40      ['$772inplace_add.40', 'stack_size.3.6']
    stack[stack_size.3.4] = i_split.1        ['i_split.1', 'stack', 'stack_size.3.4']  # <----- PROBLEM IS HERE, WHY IS IT REUSING stack_size.3.4 AND NOT stack_size.3.6?
    $const786.45 = const(int, 1)             ['$const786.45']
    $788inplace_add.46 = inplace_binop(fn=<built-in function iadd>, immutable_fn=<built-in function add>, lhs=stack_size.3.6, rhs=$const786.45, static_lhs=Undefined, static_rhs=Undefined) ['$788inplace_add.46', '$const786.45', 'stack_size.3.6']
    stack_size.3.7 = $788inplace_add.46      ['$788inplace_add.46', 'stack_size.3.7']
    stack[stack_size.3.7] = i1               ['i1', 'stack', 'stack_size.3.7']
    $const802.51 = const(int, 1)             ['$const802.51']
    $804inplace_add.52 = inplace_binop(fn=<built-in function iadd>, immutable_fn=<built-in function add>, lhs=stack_size.3.7, rhs=$const802.51, static_lhs=Undefined, static_rhs=Undefined) ['$804inplace_add.52', '$const802.51', 'stack_size.3.7']
    stack_size.3.8 = $804inplace_add.52      ['$804inplace_add.52', 'stack_size.3.8']
    stack[stack_size.3.8] = more             ['more', 'stack', 'stack_size.3.8']
    $const818.57 = const(int, 1)             ['$const818.57']
    $820inplace_add.58 = inplace_binop(fn=<built-in function iadd>, immutable_fn=<built-in function add>, lhs=stack_size.3.8, rhs=$const818.57, static_lhs=Undefined, static_rhs=Undefined) ['$820inplace_add.58', '$const818.57', 'stack_size.3.8']
    stack_size.3.9 = $820inplace_add.58      ['$820inplace_add.58', 'stack_size.3.9']
    stack_size.3.11 = stack_size.3.9         ['stack_size.3.11', 'stack_size.3.9']
    i0_inds[i_node] = i0                     ['i0', 'i0_inds', 'i_node']
    i1_inds[i_node] = i1                     ['i1', 'i1_inds', 'i_node']
    less_inds[i_node] = less                 ['i_node', 'less', 'less_inds']
    more_inds[i_node] = more                 ['i_node', 'more', 'more_inds']
    split_dims[i_node] = split_dim           ['i_node', 'split_dim', 'split_dims']
    split_values[i_node] = split_value       ['i_node', 'split_value', 'split_values']
    jump 875                                 []

@stuartarchibald
Contributor Author

With thanks to @99991 for posting https://gitter.im/numba/numba?at=5ea2c3601e3d5e20633f3c51, here is a smaller reproducer based on similar code:

import numpy as np
from numba import njit

def foo(pred, stack):
    i = 0
    c = 1

    if pred is True:
        stack[i] = c
        i += 1
        stack[i] = c
        i += 1

stack = np.array([0, 666])
foo(True, stack)
print("Python:", stack)

stack = np.array([0, 666])
njit(boundscheck=True, debug=True)(foo)(True, stack)
print("Numba:", stack)

IR after nopython_backend (with all optimisation/feature passes disabled):

-------------------------------------__main__.foo: nopython: AFTER nopython_backend-------------------------------------
label 0:
    pred = arg(0, name=pred)                 ['pred']
    stack = arg(1, name=stack)               ['stack']
    $const2.0 = const(int, 0)                ['$const2.0']
    i = $const2.0                            ['$const2.0', 'i']
    del $const2.0                            []
    $const6.1 = const(int, 1)                ['$const6.1']
    c = $const6.1                            ['$const6.1', 'c']
    del $const6.1                            []
    $const12.3 = const(bool, True)           ['$const12.3']
    $14compare_op.4 = pred is $const12.3     ['$14compare_op.4', '$const12.3', 'pred']
    del pred                                 []
    del $const12.3                           []
    branch $14compare_op.4, 18, 50           ['$14compare_op.4']
label 18:
    del $14compare_op.4                      []
    stack[i] = c                             ['c', 'i', 'stack']
    $const28.4 = const(int, 1)               ['$const28.4']
    $30inplace_add.5 = inplace_binop(fn=<built-in function iadd>, immutable_fn=<built-in function add>, lhs=i, rhs=$const28.4, static_lhs=Undefined, static_rhs=Undefined) ['$30inplace_add.5', '$const28.4', 'i']
    del $const28.4                           []
    i.1 = $30inplace_add.5                   ['$30inplace_add.5', 'i.1']
    del $30inplace_add.5                     []
    stack[i] = c                             ['c', 'i', 'stack']
    del stack                                []
    del i                                    []
    del c                                    []
    $const44.10 = const(int, 1)              ['$const44.10']
    $46inplace_add.11 = inplace_binop(fn=<built-in function iadd>, immutable_fn=<built-in function add>, lhs=i.1, rhs=$const44.10, static_lhs=Undefined, static_rhs=Undefined) ['$46inplace_add.11', '$const44.10', 'i.1']
    del i.1                                  []
    del $const44.10                          []
    i.2 = $46inplace_add.11                  ['$46inplace_add.11', 'i.2']
    del i.2                                  []
    del $46inplace_add.11                    []
    jump 50                                  []
label 50:
    del stack                                []
    del i                                    []
    del c                                    []
    del $14compare_op.4                      []
    $const50.0 = const(NoneType, None)       ['$const50.0']
    $52return_value.1 = cast(value=$const50.0) ['$52return_value.1', '$const50.0']
    del $const50.0                           []
    return $52return_value.1                 ['$52return_value.1']

Note that i.1 is created but is not used as the index in the subsequent stack[i] = c.
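To make the effect concrete, the IR for label 18 can be transcribed by hand into plain Python. This is only a sketch: the function name `foo_as_compiled` and the variables `i_1` and `i_2` (standing in for `i.1` and `i.2`) are made up for illustration, and a plain list stands in for the NumPy array.

```python
def foo_as_compiled(stack):
    """Hand transcription of the IR for label 18; not generated by Numba."""
    i = 0
    c = 1
    stack[i] = c      # first store, index i == 0
    i_1 = i + 1       # i.1 is defined here ...
    stack[i] = c      # ... but the second store still indexes with the old i
    i_2 = i_1 + 1     # i.2 is defined and never used
    return stack


result = foo_as_compiled([0, 666])
print(result)
```

Under this transcription the second element is never written, which is the elided array write from the issue title.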

@stuartarchibald
Contributor Author

stuartarchibald commented Apr 24, 2020

I think the problem is this:

import numpy as np
from numba import njit
from numba.core import compiler
import logging

def foo(pred, stack):
    i = 0
    c = 1

    if pred is True:
        stack[i] = c
        i += 1
        stack[i] = c
        i += 1

stack = np.array([0, 666])
foo(True, stack)
print("Python:", stack)

stack = np.array([0, 666])
njit(boundscheck=True, debug=True)(foo)(True, stack)
print("Numba:", stack)


ir = compiler.run_frontend(foo)
blk2 = [_ for _ in ir.blocks.keys()][1]
block = ir.blocks[blk2]
block.dump()
s1 = block.body[0]
s2 = block.body[4]
print("s1 %s" % s1)
print("s2 %s" % s2)
print(s1 == s2)

gives this:

Python: [1 1]
Numba: [1 666]
    stack[i] = c                             ['c', 'i', 'stack']
    $const28.4 = const(int, 1)               ['$const28.4']
    $30inplace_add.5 = inplace_binop(fn=<built-in function iadd>, immutable_fn=<built-in function add>, lhs=i, rhs=$const28.4, static_lhs=Undefined, static_rhs=Undefined) ['$30inplace_add.5', '$const28.4', 'i']
    i = $30inplace_add.5                     ['$30inplace_add.5', 'i']
    stack[i] = c                             ['c', 'i', 'stack']
    $const44.10 = const(int, 1)              ['$const44.10']
    $46inplace_add.11 = inplace_binop(fn=<built-in function iadd>, immutable_fn=<built-in function add>, lhs=i, rhs=$const44.10, static_lhs=Undefined, static_rhs=Undefined) ['$46inplace_add.11', '$const44.10', 'i']
    i = $46inplace_add.11                    ['$46inplace_add.11', 'i']
    jump 50                                  []
s1 stack[i] = c
s2 stack[i] = c
True

The knock-on effect is that here:

numba/numba/core/ssa.py

Lines 432 to 441 in 1dabbf6

def _stmt_index(self, defstmt, block, stop=-1):
    """Find the postitional index of the statement at ``block``.
    Assumptions:
    - no two statements can point to the same object.
    """
    try:
        return block.body.index(defstmt, 0, stop)
    except ValueError:
        return len(block.body)

when SSA checks whether there is a local redefinition that needs to be inserted, here:

numba/numba/core/ssa.py

Lines 355 to 365 in 1dabbf6

cur_pos = self._stmt_index(stmt, block)
for defstmt in reversed(local_defs):
    # Phi nodes have no index
    def_pos = self._stmt_index(defstmt, block, stop=cur_pos)
    if def_pos < cur_pos:
        selected_def = defstmt
        break
    # Maybe it's a PHI
    elif defstmt in local_phis:
        selected_def = local_phis[-1]
        break

the position cur_pos of the statement stmt for stack[i] = c aliases, so that 0 (the position of the first such statement) is returned. As a result the updated versions of i are not used, which produces nonsense.
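The aliasing can be reproduced without Numba. The following is a minimal sketch, assuming a hypothetical `Stmt` class that stands in for an IR statement whose `__eq__` ignores the source line; it is not Numba's class. `list.index` returns the position of the first element that compares equal, so looking up the second store returns the position of the first one.

```python
class Stmt:
    """Hypothetical stand-in for an IR statement; not Numba's class."""

    def __init__(self, text, line):
        self.text = text
        self.line = line

    def __eq__(self, other):
        # Equality deliberately ignores the source line.
        return type(self) is type(other) and self.text == other.text


body = [
    Stmt("stack[i] = c", line=6),
    Stmt("i = i + 1", line=7),
    Stmt("stack[i] = c", line=8),
    Stmt("i = i + 1", line=9),
]

second_store = body[2]
cur_pos = body.index(second_store)  # first *equal* element, not this object
print("cur_pos:", cur_pos)
```

With `cur_pos` equal to 0 there is no earlier position left in the block, so a search for a redefinition of `i` before the statement cannot find the one at position 1.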

@stuartarchibald
Copy link
Contributor Author

The equality above holds because IR nodes are compared with their locations stripped:

numba/numba/core/ir.py

Lines 234 to 246 in 1dabbf6

def __eq__(self, other):
    if type(self) is type(other):
        def fixup(adict):
            bad = ('loc', 'scope')
            d = dict(adict)
            for x in bad:
                d.pop(x, None)
            return d
        d1 = fixup(self.__dict__)
        d2 = fixup(other.__dict__)
        if d1 == d2:
            return True
    return False
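As an illustration of what this comparison implies, here is a small sketch with a hypothetical `Node` class (not part of Numba) whose `__eq__` drops the `loc` and `scope` attributes in the same way. Two nodes created for different source lines compare equal even though they are distinct objects.

```python
class Node:
    """Hypothetical node type that mimics the equality rule quoted above."""

    def __init__(self, target, index, value, loc):
        self.target = target
        self.index = index
        self.value = value
        self.loc = loc

    def __eq__(self, other):
        if type(self) is type(other):
            ignored = ("loc", "scope")
            d1 = {k: v for k, v in self.__dict__.items() if k not in ignored}
            d2 = {k: v for k, v in other.__dict__.items() if k not in ignored}
            return d1 == d2
        return False


first = Node("stack", "i", "c", loc="line 6")
second = Node("stack", "i", "c", loc="line 8")

equal = first == second      # compares everything except loc/scope
same_object = first is second
print(equal, same_object)
```

Any code that relies on `==` (including `list.index` and the `in` operator) therefore cannot tell such statements apart; only an identity check with `is` or `id()` can.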

@stuartarchibald
Contributor Author

Something like this might work:

diff --git a/numba/core/ssa.py b/numba/core/ssa.py
index d38918f8c..375dbb7d7 100644
--- a/numba/core/ssa.py
+++ b/numba/core/ssa.py
@@ -435,8 +435,9 @@ class _FixSSAVars(_BaseHandler):
         Assumptions:
         - no two statements can point to the same object.
         """
+        stmtids = [id(x) for x in block.body]
         try:
-            return block.body.index(defstmt, 0, stop)
+            return stmtids.index(id(defstmt), 0, stop)
         except ValueError:
             return len(block.body)
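The idea behind the patch, searching by object identity rather than by equality, can be sketched in isolation. The helper `index_by_identity` and the class `AlwaysEqual` below are hypothetical names for illustration only. Unlike the patched method, the helper treats `stop=None` as "search the whole list", which is an assumption of this sketch.

```python
def index_by_identity(body, stmt, start=0, stop=None):
    """Return the position of the exact object ``stmt`` in ``body``.

    Falls back to ``len(body)`` when the object is not found, mirroring
    the fallback used in the snippet above.
    """
    if stop is None:
        stop = len(body)
    for pos in range(start, min(stop, len(body))):
        if body[pos] is stmt:  # identity, not equality
            return pos
    return len(body)


class AlwaysEqual:
    """Worst case for equality-based lookup: every instance compares equal."""

    def __eq__(self, other):
        return True


body = [AlwaysEqual(), AlwaysEqual(), AlwaysEqual()]

by_equality = body.index(body[2])               # first equal element
by_identity = index_by_identity(body, body[2])  # the object itself
missing = index_by_identity(body, AlwaysEqual())
print(by_equality, by_identity, missing)
```

An explicit loop avoids building a list of ids on every call; whether that matters in practice depends on block sizes and was not measured here.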

stuartarchibald added a commit to stuartarchibald/numba that referenced this issue Apr 24, 2020
sklam added a commit that referenced this issue Apr 27, 2020
Fixes #5623, SSA local def scan based on invalid equality assumption.
sklam added a commit that referenced this issue Apr 30, 2020
Fixes #5623, SSA local def scan based on invalid equality assumption.