Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for overload_method #3704

Merged
merged 8 commits into from
Jan 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 0 additions & 7 deletions numba/targets/callconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,6 @@ def decode_arguments(self, builder, argtypes, func):
arginfo = self._get_arg_packer(argtypes)
return arginfo.from_arguments(builder, raw_args)

def _fix_argtypes(self, argtypes):
"""
Fix argument types, removing any omitted arguments.
"""
return tuple(ty for ty in argtypes
if not isinstance(ty, types.Omitted))

def _get_arg_packer(self, argtypes):
"""
Get an argument packer for the given argument types.
Expand Down
39 changes: 38 additions & 1 deletion numba/tests/test_extending.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import numpy as np

from numba import unittest_support as unittest
from numba import jit, types, errors, typing, compiler
from numba import njit, jit, types, errors, typing, compiler
from numba.targets.registry import cpu_target
from numba.compiler import compile_isolated
from .support import (TestCase, captured_stdout, tag, temp_directory,
Expand Down Expand Up @@ -866,6 +866,43 @@ def impl(a, **kws):
self.assertIn("use of VAR_KEYWORD (e.g. **kwargs) is unsupported", msg)
self.assertIn("offending argument name is '**kws'", msg)

def test_overload_method_kwargs(self):
# Issue #3489
@overload_method(types.Array, 'foo')
def fooimpl(arr, a_kwarg=10):
def impl(arr, a_kwarg=10):
return a_kwarg
return impl

@njit
def bar(A):
return A.foo(), A.foo(20), A.foo(a_kwarg=30)

Z = np.arange(5)

self.assertEqual(bar(Z), (10, 20, 30))

def test_overload_method_literal_unpack(self):
# Issue #3683
@overload_method(types.Array, 'litfoo')
def litfoo(arr, val):
# Must be an integer
if isinstance(val, types.Integer):
# Must not be literal
if not isinstance(val, types.Literal):
def impl(arr, val):
return val
return impl

@njit
def bar(A):
return A.litfoo(0xcafe)

A = np.zeros(1)
bar(A)
self.assertEqual(bar(A), 0xcafe)


def _assert_cache_stats(cfunc, expect_hit, expect_misses):
hit = cfunc._cache_hits[cfunc.signatures[0]]
if hit != expect_hit:
Expand Down
9 changes: 8 additions & 1 deletion numba/typing/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ def as_method(self):
return self
sig = signature(self.return_type, *self.args[1:],
recvr=self.args[0])

# Adjust the python signature
params = list(self.pysig.parameters.values())[1:]
sig.pysig = utils.pySignature(
parameters=params,
return_annotation=self.pysig.return_annotation,
)
return sig

def as_function(self):
Expand Down Expand Up @@ -568,7 +575,7 @@ def _get_dispatcher(cls, context, typ, attr, sig_args, sig_kws):
Get the compiled dispatcher implementing the attribute for
the given formal signature.
"""
cache_key = context, typ, attr
cache_key = context, typ, attr, tuple(sig_args), tuple(sig_kws.items())
try:
disp = cls._impl_cache[cache_key]
except KeyError:
Expand Down
2 changes: 2 additions & 0 deletions numba/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,12 @@ def erase_traceback(exc_value):

try:
from inspect import signature as pysignature
from inspect import Signature as pySignature
from inspect import Parameter as pyParameter
except ImportError:
try:
from funcsigs import signature as pysignature
from funcsigs import Signature as pySignature
from funcsigs import Parameter as pyParameter
except ImportError:
raise ImportError("please install the 'funcsigs' package "
Expand Down