
Infer and propagate constant information earlier in the pipeline. #8021

Open
2 tasks done
stuartarchibald opened this issue May 3, 2022 · 0 comments

@stuartarchibald
Contributor

Reporting a bug

  • I have tried using the latest released version of Numba (the most recent is
    visible in the change log (https://github.com/numba/numba/blob/main/CHANGE_LOG)).
  • I have included a self contained code sample to reproduce the problem.
    i.e. it's possible to run as 'python bug.py'.

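I'm wondering whether code like this could be handled directly by Numba. The reproducer below currently fails to compile because Numba cannot unify the type of agg at the agg.shape call: agg is a 3D array if the branch is not taken and a 2D array if it is. If the following sequence of passes were run: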

  • PartialTypeInference (aggs_and_cols[0] is inferred as a 3D array)
  • RewriteSemanticConstants (aggs_and_cols[0].ndim could be baked into the IR as const(3), though the pass may need some work to do this)
  • DeadBranchPrune (the branch is removed, as 3 == 3 is always True)

it might allow the control flow to be eliminated, so that agg has a single type at the agg.shape call and the function compiles.
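To make the proposed ordering concrete, here is a toy model in plain Python. Everything in it is hypothetical: the tuple-based statements, the function names (which only echo the Numba pass names above) and the separate fold_constants helper are illustrative assumptions, not Numba's IR or compiler API. It shows how knowing the rank of aggs_and_cols[0] lets .ndim become a constant, the comparison fold to True, and the branch collapse into an unconditional jump.

```python
# Toy model of the proposed pass order. Hypothetical throughout: the
# statement tuples and function names only mirror Numba's pass names and
# are NOT Numba's IR or compiler API.
#
# Statement forms (all tuples):
#   ("getitem", target, container, index)  target = container[index]
#   ("getattr", target, var, attr)         target = var.attr
#   ("const", target, value)               target = value
#   ("eq", target, lhs, rhs)               target = lhs == rhs
#   ("branch", cond, true_label, false_label)
#   ("jump", label)


def partial_type_inference(known_ndims):
    """Stand-in for PartialTypeInference: record the array ranks typed so far."""
    return dict(known_ndims)


def rewrite_semantic_constants(block, ndims):
    """Stand-in for RewriteSemanticConstants: `x.ndim` -> const(rank of x)."""
    out = []
    for stmt in block:
        if stmt[0] == "getattr" and stmt[3] == "ndim" and stmt[2] in ndims:
            out.append(("const", stmt[1], ndims[stmt[2]]))
        else:
            out.append(stmt)
    return out


def fold_constants(block):
    """Evaluate `eq` when both operands are constants; return (block, consts)."""
    consts = {}
    out = []
    for stmt in block:
        if stmt[0] == "const":
            consts[stmt[1]] = stmt[2]
            out.append(stmt)
        elif stmt[0] == "eq" and stmt[2] in consts and stmt[3] in consts:
            value = consts[stmt[2]] == consts[stmt[3]]
            consts[stmt[1]] = value
            out.append(("const", stmt[1], value))
        else:
            out.append(stmt)
    return out, consts


def dead_branch_prune(block, consts):
    """Stand-in for DeadBranchPrune: a branch on a known constant becomes a jump."""
    out = []
    for stmt in block:
        if stmt[0] == "branch" and stmt[1] in consts:
            out.append(("jump", stmt[2] if consts[stmt[1]] else stmt[3]))
        else:
            out.append(stmt)
    return out


def run_pipeline(block, known_ndims):
    """Run the passes in the order proposed in the issue."""
    ndims = partial_type_inference(known_ndims)
    block = rewrite_semantic_constants(block, ndims)
    block, consts = fold_constants(block)
    return dead_branch_prune(block, consts)


# Entry block of `foo`, roughly: `if aggs_and_cols[0].ndim == 3: ...`
entry = [
    ("getitem", "$item0", "aggs_and_cols", 0),
    ("getattr", "$ndim", "$item0", "ndim"),
    ("const", "$three", 3),
    ("eq", "$cond", "$ndim", "$three"),
    ("branch", "$cond", "slice_3d", "after_if"),
]

pruned = run_pipeline(entry, {"$item0": 3})
print(pruned[-1])  # ('jump', 'slice_3d')
```

With the rank unknown (run_pipeline(entry, {})), the branch is left untouched, which mirrors the situation in the reproducer where both arms survive and agg has to be unified.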

from numba import njit
import numpy as np

@njit
def foo(*aggs_and_cols):
    i = 0
    agg = aggs_and_cols[0] # This agg is 3D
    if aggs_and_cols[0].ndim == 3: # This condition could be type-inferred, rewritten as a constant and branch-pruned.
        cat_index = aggs_and_cols[1][i]
        agg = agg[:, :, cat_index] # This agg is 2D

    xmax = agg.shape[1] - 1 # What's this agg? 2D or 3D?

args = (np.zeros((3, 4, 5)), (0, 0))

foo(*args)
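For comparison, the same body without @njit runs under CPython, where agg is simply rebound and only the path actually taken matters. The function name foo_py and its return value (the final rank of agg alongside xmax, purely for inspection) are hypothetical additions. For the arguments above, agg is 2D by the time agg.shape is read; the difficulty for Numba is that, without the pruning described above, it must also type the path where the branch is skipped, on which agg would still be 3D.

```python
import numpy as np


def foo_py(*aggs_and_cols):
    # Same body as `foo`, but without @njit, so CPython just rebinds `agg`.
    i = 0
    agg = aggs_and_cols[0]
    if aggs_and_cols[0].ndim == 3:
        cat_index = aggs_and_cols[1][i]
        agg = agg[:, :, cat_index]  # slicing out the last axis leaves a 2D array
    xmax = agg.shape[1] - 1
    return agg.ndim, xmax


args = (np.zeros((3, 4, 5)), (0, 0))
print(foo_py(*args))  # (2, 3)
```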