
Constant-folding pass needed to permit more "static"* expression rewrites. #2518

Open
bmerry opened this issue Aug 23, 2017 · 9 comments


@bmerry
Contributor

bmerry commented Aug 23, 2017

I'm trying to write some generic code that I want to compile with numba several times, providing a different value for a constant each time (constant for each compilation, similar to an int template parameter in C++). In particular, numba only allows constant slices for tuples, and I'd like to use a different slice in each compilation.

I've tried several approaches so far.

  1. Encode an integer as a size-1 array with that many dimensions, so that the signature is different, triggering a new JIT compilation. Here's a contrived example:
#!/usr/bin/env python
import numba
import numpy as np

@numba.njit
def f(a, b):
    new_shape = a.shape[:b.ndim] + a.shape[b.ndim + 1:]
    return np.ones(new_shape)

# Should return array with shape (5, 4, 2)
print(f(np.zeros((5, 4, 3, 2)), np.zeros((1, 1))).shape)

With numba 0.34, numpy 1.13.1, Python 2.7.12, Ubuntu 16.04 I get this error:

Failed at nopython (nopython frontend)
Invalid usage of getitem with parameters ((int64 x 3), slice<a:b>)
 * parameterized
File "axis.py", line 8
[1] During: typing of intrinsic-call at axis.py (8)
  2. Compiling a closure, where the constant comes from the outer scope, e.g.:
#!/usr/bin/env python
import numba
import numpy as np

def make_f(d):
    @numba.njit
    def f(a):
        new_shape = a.shape[:d] + a.shape[d + 1:]
        return np.ones(new_shape)
    return f

print(make_f(2)(np.zeros((5, 4, 3, 2))).shape)

This also fails:

Failed at nopython (nopython frontend)
Invalid usage of getitem with parameters ((int64 x 4), slice<a:b>)
 * parameterized
File "axis2.py", line 8
[1] During: typing of intrinsic-call at axis2.py (8)
  3. Combining the above two approaches using generated_jit, where the constant n is encoded as an n-dimensional array and the dispatcher function extracts it and uses it in the returned closure. This also fails.
@bmerry
Contributor Author

bmerry commented Aug 24, 2017

For anyone finding this bug from an internet search: this article may prove helpful for writing a workaround.

@rk-roman

Having the same issue (python 2.7, numba==0.37.0):

import numpy as np
import numba as nb

@nb.njit()
def fixed():
    return np.zeros((3,), dtype=np.float64)

@nb.njit()
def dynamic():
    a = np.array([1,2,3])
    return np.zeros(a, dtype=np.float64)

print(fixed())

Gives:

[0. 0. 0.]

while

print(dynamic())

Fails with:

Failed at nopython (nopython frontend)
Invalid usage of Function(<built-in function zeros>) with parameters (array(int64, 1d, C), dtype=class(float64))
 * parameterized

Any thoughts?

@stuartarchibald
Contributor

@bmerry Thanks for the report. This can be done with @generated_jit to create a specialization at compile time:

from numba import generated_jit
import numpy as np

@generated_jit
def f(a, b):
    A = b.ndim
    B = A + 1
    def specialize(a, b):
        new_shape = a.shape[:A] + a.shape[B:]
        return np.ones(new_shape)
    return specialize

# Should return array with shape (5, 4, 2)
print(f(np.zeros((5, 4, 3, 2)), np.zeros((1, 1))).shape)

@stuartarchibald
Contributor

@rk-roman I'm not sure what the intention of your example is. It seems like you want to dynamically create an array based on the values of another array, but that array is also static? Is this what you really want to do, or in reality is the array a, for example, an argument to the function?

@bmerry
Contributor Author

bmerry commented Jan 2, 2020

Thanks, I don't think I'd tried lifting the expressions from specialize out to the containing function. It would still be more user-friendly if captured variables could be used in constant expressions e.g. if one could modify your example to:

from numba import generated_jit
import numpy as np

@generated_jit
def f(a, b):
    A = b.ndim
    def specialize(a, b):
        new_shape = a.shape[:A] + a.shape[A + 1:]
        return np.ones(new_shape)
    return specialize

# Should return array with shape (5, 4, 2)
print(f(np.zeros((5, 4, 3, 2)), np.zeros((1, 1))).shape)

(i.e. replace B with A + 1, which currently fails). If that's not easy to support, it would be nice to have the limitation mentioned in the documentation of @generated_jit.

@stuartarchibald
Copy link
Contributor

@bmerry no problem. FWIW right now there's a pass that propagates constants from the typing domain like ndarray.ndim into the IR which helps a bit, but I've just noticed there's a bug in the pass ordering logic which means this doesn't always work when it should:

from numba import generated_jit
import numpy as np

@generated_jit(nopython=True)
def f(a, b):
    def specialize(a, b):
        return a.shape[:b.ndim] # <-- b.ndim gets rewritten to "2" but it's too late as the generic const-expr rewrite has already happened
    return specialize

print(f(np.zeros((5, 4, 3, 2)), np.zeros((1, 1))))
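To make the ordering problem concrete, here is a toy model in plain Python. It is not Numba's IR or pass manager: the tuple encoding and the pass names `propagate_ndim` and `const_expr_rewrite` are hypothetical. A rewrite that only accepts literal slice bounds misses `b.ndim` if it runs before the pass that substitutes the type-level constant:

```python
from functools import partial

# Toy model only: these tuples stand in for IR nodes and are NOT Numba's IR.
#   ("attr", var, name)          -> an attribute load such as b.ndim
#   ("const", value)             -> a literal constant
#   ("slice_hi", obj, bound)     -> obj[:bound], bound not yet known to be static
#   ("static_slice_hi", obj, n)  -> obj[:n] with a compile-time constant n

def propagate_ndim(expr, ndims):
    """Hypothetical pass: replace var.ndim with a constant when the type fixes it."""
    if expr[0] == "attr" and expr[2] == "ndim" and expr[1] in ndims:
        return ("const", ndims[expr[1]])
    if expr[0] == "slice_hi":
        return ("slice_hi", expr[1], propagate_ndim(expr[2], ndims))
    return expr

def const_expr_rewrite(expr):
    """Hypothetical pass: accept a tuple slice only if its bound is already a literal."""
    if expr[0] == "slice_hi" and expr[2][0] == "const":
        return ("static_slice_hi", expr[1], expr[2][1])
    return expr

def run(passes, expr):
    # Apply each pass once, in the given order.
    for p in passes:
        expr = p(expr)
    return expr

expr = ("slice_hi", "a.shape", ("attr", "b", "ndim"))  # models a.shape[:b.ndim]
propagate = partial(propagate_ndim, ndims={"b": 2})    # b is a 2-d array

too_late = run([const_expr_rewrite, propagate], expr)
in_order = run([propagate, const_expr_rewrite], expr)
print(too_late)  # ('slice_hi', 'a.shape', ('const', 2))
print(in_order)  # ('static_slice_hi', 'a.shape', 2)
```

In this model, swapping the two passes is enough for a bare b.ndim bound; a bound such as b.ndim + 1 would still reach the rewrite as an addition rather than a literal.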

More generally, I think what you are asking for is compile-time constant folding, something which is not yet implemented. I think a prerequisite for this is constant propagation, and a prerequisite for that is having Numba's IR in SSA form; SSA is in progress!
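A minimal sketch of what constant propagation plus folding could look like, assuming straight-line code already in SSA form (every name assigned exactly once). This is a toy in plain Python; the instruction format and `fold_constants` are hypothetical and not part of Numba. Because each name has a single definition, one forward pass with a single name-to-value table is enough to turn `A + 1` into `3` once `b.ndim` is known to be `2`:

```python
import operator

# Toy SSA-style IR (hypothetical, not Numba's): each target is assigned once.
# An instruction is (target, op, args); args are int literals or earlier names.
OPS = {"add": operator.add, "sub": operator.sub, "mul": operator.mul}

def fold_constants(instrs, known):
    """One forward pass of constant propagation plus folding.

    known maps names whose values are fixed at compile time (e.g. b.ndim taken
    from the argument type) to those values. Returns every name proven constant.
    """
    env = dict(known)

    def lookup(arg):
        if isinstance(arg, int):
            return arg
        return env.get(arg)  # None means "runtime value"

    for target, op, args in instrs:
        vals = [lookup(a) for a in args]
        if any(v is None for v in vals):
            continue  # depends on a runtime value, so it cannot be folded
        env[target] = vals[0] if op == "copy" else OPS[op](*vals)
    return env

# Models:  A = b.ndim;  t = A + 1;  n = x + t   (x is only known at run time)
instrs = [
    ("A", "copy", ["b_ndim"]),
    ("t", "add", ["A", 1]),
    ("n", "add", ["x", "t"]),
]
env = fold_constants(instrs, known={"b_ndim": 2})
print(env)  # {'b_ndim': 2, 'A': 2, 't': 3}
```

Without SSA, a name such as A could be reassigned on different control-flow paths, and one table entry per name would no longer be sound, which is presumably why SSA comes first in the chain above.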

RE documentation, pull requests are welcomed.

Thanks!

@bmerry
Contributor Author

bmerry commented Jan 2, 2020

> @bmerry no problem. FWIW right now there's a pass that propagates constants from the typing domain like ndarray.ndim into the IR which helps a bit, but I've just noticed there's a bug in the pass ordering logic which means this doesn't always work when it should:

Ah, that makes more sense now. I was rather perplexed that 1+1 worked but A+1 did not since they're both expressions formed from constants. I'll edit the bug description to reflect the ordering bug.

Is the ordering bug also responsible for the example in my initial bug report not working?

> More generally I think what you are asking for is compile time constant folding, something which is not yet implemented.

I was only really expecting constants like ndim to be usable to the same extent as constant constants like 1. I wasn't expecting constant-ness to be tracked through assignments, although obviously more features are always nice to have.

> RE documentation, pull requests are welcomed.

In this case I don't have a solid understanding of what does and doesn't work, which makes it tricky to document. Also, if this is a quick bug to fix then there probably isn't too much point documenting the current behaviour.

@bmerry bmerry changed the title Allow more things as "constants" in tuple slicing Constant-folding pass misses constants like ndim from the type system Jan 2, 2020
@stuartarchibald stuartarchibald changed the title Constant-folding pass misses constants like ndim from the type system Constant-folding pass needed to permit more "static"* expression rewrites. Jan 2, 2020
@stuartarchibald
Contributor

> > @bmerry no problem. FWIW right now there's a pass that propagates constants from the typing domain like ndarray.ndim into the IR which helps a bit, but I've just noticed there's a bug in the pass ordering logic which means this doesn't always work when it should:

> Ah, that makes more sense now. I was rather perplexed that 1+1 worked but A+1 did not since they're both expressions formed from constants. I'll edit the bug description to reflect the ordering bug.

> Is the ordering bug also responsible for the example in my initial bug report not working?

Yes and no. Yes in that a.shape[:b.ndim] should have worked, but no in that a.shape[b.ndim + 1:] needs constant folding to happen ahead of the const-expr rewrite, and constant folding is something that's not yet implemented.

> > More generally I think what you are asking for is compile time constant folding, something which is not yet implemented.

> I was only really expecting constants like ndim to be usable to the same extent as constant constants like 1. I wasn't expecting constant-ness to be tracked through assignments, although obviously more features are always nice to have.

Constant-ness is tracked through assignment, and there's some degree of propagation via type inference. For example:

from numba import njit
import numpy as np

@njit
def foo(x):
    a = x.ndim # const
    b = a + 1 # not const
    c = a # tracked + propagated as it's a literal
    return a + b + c

foo(np.zeros((4,3,2,1)))
foo.inspect_types()

gives:

# File: issue2518_b.py
# --- LINE 4 --- 

@njit

# --- LINE 5 --- 

def foo(x):

    # --- LINE 6 --- 
    # label 0
    #   x = arg(0, name=x)  :: array(float64, 4d, C)
    #   del x
    #   $4load_attr.1 = const(int, 4)  :: Literal[int](4)
    #   a = $4load_attr.1  :: Literal[int](4)
    #   del $4load_attr.1

    a = x.ndim # const

    # --- LINE 7 --- 
    #   $const10.3 = const(int, 1)  :: Literal[int](1)
    #   $12binary_add.4 = a + $const10.3  :: int64
    #   del $const10.3
    #   b = $12binary_add.4  :: int64
    #   del $12binary_add.4

    b = a + 1 # not const

    # --- LINE 8 --- 
    #   c = a  :: Literal[int](4)

    c = a # tracked + propagated as it's a literal

    # --- LINE 9 --- 
    #   $24binary_add.8 = a + b  :: int64
    #   del b
    #   del a
    #   $28binary_add.10 = $24binary_add.8 + c  :: int64
    #   del c
    #   del $24binary_add.8
    #   $30return_value.11 = cast(value=$28binary_add.10)  :: int64
    #   del $28binary_add.10
    #   return $30return_value.11

    return a + b + c
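The typing behaviour in the annotated output above can be mimicked with a toy type model in plain Python. This is an illustration only: `Literal`, `type_of_copy` and `type_of_add` are hypothetical stand-ins for Numba's typing machinery, assuming (as the output shows) that a plain copy keeps a literal type while `a + 1` is widened to `int64`. The `fold=True` path sketches what a constant-folding step would add:

```python
from dataclasses import dataclass

# Toy stand-ins for Numba's types (hypothetical, for illustration only).
@dataclass(frozen=True)
class Literal:
    value: int

    def __str__(self):
        return f"Literal[int]({self.value})"

INT64 = "int64"

def type_of_copy(src):
    # A plain assignment keeps the source type, literal or not.
    return src

def type_of_add(lhs, rhs, fold=False):
    # Without folding, adding literals yields the non-literal int64.
    # With fold=True, two literal operands produce a new literal.
    if fold and isinstance(lhs, Literal) and isinstance(rhs, Literal):
        return Literal(lhs.value + rhs.value)
    return INT64

types = {}
types["a"] = Literal(4)                           # a = x.ndim, x is 4-d
types["b"] = type_of_add(types["a"], Literal(1))  # b = a + 1
types["c"] = type_of_copy(types["a"])             # c = a
for name, t in types.items():
    print(name, t)
# a Literal[int](4)
# b int64
# c Literal[int](4)

print(type_of_add(types["a"], Literal(1), fold=True))  # Literal[int](5)
```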

> > RE documentation, pull requests are welcomed.

> In this case I don't have a solid understanding of what does and doesn't work, which makes it tricky to document. Also, if this is a quick bug to fix then there probably isn't too much point documenting the current behaviour.

No problem. I think making ndim-as-a-const work in a getitem is a question of swapping two lines, but there might be some knock-on effect I'm not anticipating. I've opened #5015 to track this and will try to fix it for 0.48. Constant folding is a load harder, but I'll have a think about whether there's something simple that could be done as a start.

@stuartarchibald stuartarchibald added this to the Numba 0.48 RC milestone Jan 2, 2020
@stuartarchibald
Contributor

@bmerry #5024 is an attempt at making the a.shape[:b.ndim] case work.
