Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

literal_unroll fails with jitclass #5321

Open
anelsene opened this issue Feb 27, 2020 · 3 comments
Open

literal_unroll fails with jitclass #5321

anelsene opened this issue Feb 27, 2020 · 3 comments
Labels

Comments

@anelsene
Copy link

When iterating over a tuple of indices and using those indices to index into an array stored in a jitclass, compilation fails if the loop uses `literal_unroll`.

from numba import njit, literal_unroll, jitclass, types
import numpy as np

# Minimal jitclass wrapping a single array field. Its only role in this
# reproducer is to make `arr.data[idx]` go through an attribute access on a
# jitclass instance, which is the pattern the literal_unroll pass trips on.
@jitclass(spec=[('data', types.float64[:])])
class COOArray:
    # Fields (declared in the spec above):
    #   data -- 1-D float64 array (types.float64[:]).
    def __init__(self, data):
        # Store the given array as-is; no copy or validation is done here.
        self.data = data


@njit
def foo_unrolled(arr):
    """Sum ``arr.data`` at indices 2 and 3, iterating with ``literal_unroll``.

    This is the failing case of the report: compiling it raises
    ``KeyError: 'args'`` in the "handles literal_unroll" pipeline step.

    Parameters
    ----------
    arr : COOArray
        jitclass instance whose ``data`` field is indexed.
    """
    idxs = (2, 3)  # tuple of indices to iterate over
    a = 0  # accumulator; starts as int 0 and receives float64 elements
    for idx in literal_unroll(idxs):
        # The getitem target here is `arr.data`, i.e. a getattr on the
        # jitclass rather than a plain array variable; per the maintainer's
        # diagnosis, the unroll implementation does not account for this.
        a += arr.data[idx]
    return a

@njit
def foo(arr):
    """Sum ``arr.data`` at indices 2 and 3 using a plain ``for`` loop.

    Control case for ``foo_unrolled``: identical except that the tuple is
    iterated directly instead of through ``literal_unroll``. Reported to
    compile and run without error.

    Parameters
    ----------
    arr : COOArray
        jitclass instance whose ``data`` field is indexed.
    """
    idxs = (2, 3)  # same indices as foo_unrolled
    a = 0  # accumulator; starts as int 0 and receives float64 elements
    for idx in idxs:
        a += arr.data[idx]
    return a

# data is five ones, so both functions are expected to return 1.0 + 1.0 == 2.0.
arr = COOArray(np.ones(5))
foo(arr)
# passes (plain loop over the tuple)
foo_unrolled(arr)
# fails: KeyError 'args' raised while compiling the literal_unroll loop

Error message:
Failed in nopython mode pipeline (step: handles literal_unroll)\n"Failed in literal_unroll_subpipeline mode pipeline (step: performs mixed container unroll)\n'args'"

Full error message:
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-15-466268e2e215> in <module>
      1 arr = COOArray(np.ones(5))
----> 2 foo(arr)

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\dispatcher.py in _compile_for_args(self, *args, **kws)
    418                     e.patch_message('\n'.join((str(e).rstrip(), help_msg)))
    419             # ignore the FULL_TRACEBACKS config, this needs reporting!
--> 420             raise e
    421 
    422     def inspect_llvm(self, signature=None):

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\dispatcher.py in _compile_for_args(self, *args, **kws)
    351                 argtypes.append(self.typeof_pyval(a))
    352         try:
--> 353             return self.compile(tuple(argtypes))
    354         except errors.ForceLiteralArg as e:
    355             # Received request for compiler re-entry with the list of arguments

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
     30         def _acquire_compile_lock(*args, **kwargs):
     31             with self:
---> 32                 return func(*args, **kwargs)
     33         return _acquire_compile_lock
     34 

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\dispatcher.py in compile(self, sig)
    766             self._cache_misses[sig] += 1
    767             try:
--> 768                 cres = self._compiler.compile(args, return_type)
    769             except errors.ForceLiteralArg as e:
    770                 def folded(args, kws):

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\dispatcher.py in compile(self, args, return_type)
     75 
     76     def compile(self, args, return_type):
---> 77         status, retval = self._compile_cached(args, return_type)
     78         if status:
     79             return retval

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\dispatcher.py in _compile_cached(self, args, return_type)
     89 
     90         try:
---> 91             retval = self._compile_core(args, return_type)
     92         except errors.TypingError as e:
     93             self._failed_cache[key] = e

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\dispatcher.py in _compile_core(self, args, return_type)
    107                                       args=args, return_type=return_type,
    108                                       flags=flags, locals=self.locals,
--> 109                                       pipeline_class=self.pipeline_class)
    110         # Check typing error if object mode is used
    111         if cres.typing_error is not None and not flags.enable_pyobject:

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\compiler.py in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
    549     pipeline = pipeline_class(typingctx, targetctx, library,
    550                               args, return_type, flags, locals)
--> 551     return pipeline.compile_extra(func)
    552 
    553 

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\compiler.py in compile_extra(self, func)
    329         self.state.lifted = ()
    330         self.state.lifted_from = None
--> 331         return self._compile_bytecode()
    332 
    333     def compile_ir(self, func_ir, lifted=(), lifted_from=None):

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\compiler.py in _compile_bytecode(self)
    391         """
    392         assert self.state.func_ir is None
--> 393         return self._compile_core()
    394 
    395     def _compile_ir(self):

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\compiler.py in _compile_core(self)
    371                 self.state.status.fail_reason = e
    372                 if is_final_pipeline:
--> 373                     raise e
    374         else:
    375             raise CompilerError("All available pipelines exhausted")

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\compiler.py in _compile_core(self)
    362             res = None
    363             try:
--> 364                 pm.run(self.state)
    365                 if self.state.cr is not None:
    366                     break

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\compiler_machinery.py in run(self, state)
    345                     (self.pipeline_name, pass_desc)
    346                 patched_exception = self._patch_error(msg, e)
--> 347                 raise patched_exception
    348 
    349     def dependency_analysis(self):

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\compiler_machinery.py in run(self, state)
    336                 pass_inst = _pass_registry.get(pss).pass_inst
    337                 if isinstance(pass_inst, CompilerPass):
--> 338                     self._runPass(idx, pass_inst, state)
    339                 else:
    340                     raise BaseException("Legacy pass in use")

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
     30         def _acquire_compile_lock(*args, **kwargs):
     31             with self:
---> 32                 return func(*args, **kwargs)
     33         return _acquire_compile_lock
     34 

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\compiler_machinery.py in _runPass(self, index, pss, internal_state)
    300             mutated |= check(pss.run_initialization, internal_state)
    301         with SimpleTimer() as pass_time:
--> 302             mutated |= check(pss.run_pass, internal_state)
    303         with SimpleTimer() as finalize_time:
    304             mutated |= check(pss.run_finalizer, internal_state)

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\compiler_machinery.py in check(func, compiler_state)
    273 
    274         def check(func, compiler_state):
--> 275             mangled = func(compiler_state)
    276             if mangled not in (True, False):
    277                 msg = ("CompilerPass implementations should return True/False. "

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\untyped_passes.py in run_pass(self, state)
   1440         pm.add_pass(MixedContainerUnroller, "performs mixed container unroll")
   1441         pm.finalize()
-> 1442         pm.run(state)
   1443         return True
   1444 

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\compiler_machinery.py in run(self, state)
    345                     (self.pipeline_name, pass_desc)
    346                 patched_exception = self._patch_error(msg, e)
--> 347                 raise patched_exception
    348 
    349     def dependency_analysis(self):

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\compiler_machinery.py in run(self, state)
    336                 pass_inst = _pass_registry.get(pss).pass_inst
    337                 if isinstance(pass_inst, CompilerPass):
--> 338                     self._runPass(idx, pass_inst, state)
    339                 else:
    340                     raise BaseException("Legacy pass in use")

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\compiler_lock.py in _acquire_compile_lock(*args, **kwargs)
     30         def _acquire_compile_lock(*args, **kwargs):
     31             with self:
---> 32                 return func(*args, **kwargs)
     33         return _acquire_compile_lock
     34 

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\compiler_machinery.py in _runPass(self, index, pss, internal_state)
    300             mutated |= check(pss.run_initialization, internal_state)
    301         with SimpleTimer() as pass_time:
--> 302             mutated |= check(pss.run_pass, internal_state)
    303         with SimpleTimer() as finalize_time:
    304             mutated |= check(pss.run_finalizer, internal_state)

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\compiler_machinery.py in check(func, compiler_state)
    273 
    274         def check(func, compiler_state):
--> 275             mangled = func(compiler_state)
    276             if mangled not in (True, False):
    277                 msg = ("CompilerPass implementations should return True/False. "

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\untyped_passes.py in run_pass(self, state)
   1217         # keep running the transform loop until it reports no more changes
   1218         while(True):
-> 1219             stat = self.apply_transform(state)
   1220             mutated |= stat
   1221             if not stat:

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\untyped_passes.py in apply_transform(self, state)
   1084         # 2. Do the unroll, get a loop and process it!
   1085         lbl, info = literal_unroll_info.popitem()
-> 1086         self.unroll_loop(state, info)
   1087 
   1088         # 3. Rebuild the state, the IR has taken a hammering

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\untyped_passes.py in unroll_loop(self, state, loop_info)
   1122                         if stmt.value.value != getitem_target:
   1123                             dfn = func_ir.get_definition(stmt.value.value)
-> 1124                             args = getattr(dfn, 'args', False)
   1125                             if not args:
   1126                                 continue

~\miniconda3\envs\litt-conda3\lib\site-packages\numba\ir.py in __getattr__(self, name)
    380         if name.startswith('_'):
    381             return Inst.__getattr__(self, name)
--> 382         return self._kws[name]
    383 
    384     def __setattr__(self, name, value):

KeyError: 'Failed in nopython mode pipeline (step: handles literal_unroll)\n"Failed in literal_unroll_subpipeline mode pipeline (step: performs mixed container unroll)\\n\'args\'"'
@stuartarchibald
Copy link
Contributor

Thanks for the report, I can reproduce this. I'm fairly sure the cause is that jitclass isn't considered in the implementation. The above fails because the getitem target resolves to a getattr on the jitclass instance rather than to the data array itself. This should be fixable.

@luk-f-a
Copy link
Contributor

luk-f-a commented Feb 27, 2020

@stuartarchibald, this is the bug I mentioned yesterday.

@stuartarchibald
Copy link
Contributor

@luk-f-a great, thanks for confirming.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

No branches or pull requests

3 participants