Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Const type #3414

Merged
merged 34 commits into from Nov 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
cfbb1ea
Make explicit Literal classes for value type
sklam Oct 30, 2018
15bb60b
Adopt new Literal* types.
sklam Oct 30, 2018
3464b84
Rename RequireConstValue to RequireLiteralValue
sklam Oct 30, 2018
1c1cb34
Adjust to use Literal types
sklam Oct 31, 2018
969cdeb
Re-enable test
sklam Nov 6, 2018
c1d518f
Merge branch 'master' into enh/consttypeinfer
sklam Nov 6, 2018
9b40950
Fixes due to merge master
sklam Nov 6, 2018
819cb1c
Fix index unhashable error
sklam Nov 8, 2018
2b59a9c
Make LiteralInt a subclass of Integer
sklam Nov 8, 2018
18bef15
Make LiteralSlice a SliceType subclass
sklam Nov 8, 2018
ff5c865
Rename to literal_value.
sklam Nov 8, 2018
1701964
Cleanup literal type
sklam Nov 9, 2018
127cc48
Removing unneeded unliteral
sklam Nov 9, 2018
fb439eb
All pass
sklam Nov 9, 2018
dbb7b58
Clean up
sklam Nov 9, 2018
5242e78
Using typing context to determine the getitem signature properly
sklam Nov 9, 2018
b3c0ab8
Clean up
sklam Nov 9, 2018
2424c11
More clean up
sklam Nov 12, 2018
0477614
Move around Literal utils
sklam Nov 12, 2018
d0e6f4b
Merge branch 'master' into enh/consttypeinfer
sklam Nov 12, 2018
c115321
Fixes to add string literal support to the unicode support
sklam Nov 12, 2018
fd96472
Make non-allocating casting from literal to unicode
sklam Nov 13, 2018
42a07c8
Fix incorrect merge conflict resolution
sklam Nov 13, 2018
650ff71
Add more literal tests
sklam Nov 13, 2018
8cf41c1
Fix and test getattr on Literal types
sklam Nov 13, 2018
6a6ba06
Remove slow test marker
sklam Nov 13, 2018
3a1a786
Fix build
sklam Nov 13, 2018
a0c3835
Fixes according to comments
sklam Nov 14, 2018
7ac6803
More code comments
sklam Nov 14, 2018
4bc1125
Make typing template operating on numerical domains deterministic.
stuartarchibald Nov 16, 2018
ee5d1a8
Simple fixes from PR feedback.
stuartarchibald Nov 16, 2018
6b9ccc5
Add literal use to template rejection string context.
stuartarchibald Nov 16, 2018
7a98129
Add test for enumerate invalid typing
stuartarchibald Nov 16, 2018
68fa78e
Fix up failing test on python 2.7
stuartarchibald Nov 16, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions numba/_helpermod.c
Expand Up @@ -117,6 +117,9 @@ build_c_helpers_dict(void)
declmethod(rnd_init);
declmethod(poisson_ptrs);

/* Unicode string support */
declmethod(extract_unicode);

#define MATH_UNARY(F, R, A) declmethod(F);
#define MATH_BINARY(F, R, A, B) declmethod(F);
#include "mathnames.h"
Expand Down
5 changes: 3 additions & 2 deletions numba/_runtests.py
Expand Up @@ -68,12 +68,13 @@ def main(self, argv, kwds):
print("Running {} tests".format(len(tests)))
print('Flags', flags)
result = run_tests([prog] + flags + tests, **kwds)
# Save failed
if not self.last_failed:
# Update failed tests records
if not self.last_failed or len(tests) == len(failed_tests):
self.save_failed_tests(result, all_tests)
return result.wasSuccessful()

def save_failed_tests(self, result, all_tests):
print("Saving failed tests to {}".format(self.cache_filename))
cache = []
# Find failed tests
failed = set()
Expand Down
4 changes: 2 additions & 2 deletions numba/array_analysis.py
Expand Up @@ -110,7 +110,7 @@ def codegen(context, builder, sig, args):
assert(len(args) == 1) # it is a vararg tuple
tup = cgutils.unpack_tuple(builder, args[0])
tup_type = sig.args[0]
msg = sig.args[0][0].value
msg = sig.args[0][0].literal_value

def unpack_shapes(a, aty):
if isinstance(aty, types.ArrayCompatible):
Expand Down Expand Up @@ -1786,7 +1786,7 @@ def _make_assert_equiv(self, scope, loc, equiv_set, _args, names=None):

msg = "Sizes of {} do not match on {}".format(', '.join(arg_names), loc)
msg_val = ir.Const(msg, loc)
msg_typ = types.Const(msg)
msg_typ = types.StringLiteral(msg)
msg_var = ir.Var(scope, mk_unique_var("msg"), loc)
self.typemap[msg_var.name] = msg_typ
argtyps = tuple([msg_typ] + [self.typemap[x.name] for x in args])
Expand Down
5 changes: 3 additions & 2 deletions numba/cuda/printimpl.py
Expand Up @@ -26,6 +26,7 @@ def print_item(ty, context, builder, val):


@print_item.register(types.Integer)
@print_item.register(types.IntegerLiteral)
def int_print_impl(ty, context, builder, val):
if ty in types.unsigned_domain:
rawfmt = "%llu"
Expand All @@ -42,9 +43,9 @@ def real_print_impl(ty, context, builder, val):
lld = context.cast(builder, val, ty, types.float64)
return "%f", [lld]

@print_item.register(types.Const)
@print_item.register(types.StringLiteral)
def const_print_impl(ty, context, builder, sigval):
pyval = ty.value
pyval = ty.literal_value
assert isinstance(pyval, str) # Ensured by lowering
rawfmt = "%s"
val = context.insert_string_const_addrspace(builder, pyval)
Expand Down
3 changes: 2 additions & 1 deletion numba/datamodel/models.py
Expand Up @@ -281,7 +281,7 @@ def __init__(self, dmm, fe_type):
@register_default(types.PyObject)
@register_default(types.RawPointer)
@register_default(types.NoneType)
@register_default(types.Const)
@register_default(types.StringLiteral)
@register_default(types.EllipsisType)
@register_default(types.Function)
@register_default(types.Type)
Expand Down Expand Up @@ -328,6 +328,7 @@ def get_nrt_meminfo(self, builder, value):


@register_default(types.Integer)
@register_default(types.IntegerLiteral)
class IntegerModel(PrimitiveModel):
def __init__(self, dmm, fe_type):
be_type = ir.IntType(fe_type.bitwidth)
Expand Down
14 changes: 11 additions & 3 deletions numba/errors.py
Expand Up @@ -562,9 +562,17 @@ def __init__(self, exception):
self.old_exception = exception


class RequireConstValue(TypingError):
"""For signaling a function typing require constant value for some of
its arguments.
class RequireLiteralValue(TypingError):
"""
For signalling that a function's typing requires a constant value for
some of its arguments.
"""
pass


class LiteralTypingError(TypingError):
"""
Failure in typing a Literal type
"""
pass

Expand Down
4 changes: 3 additions & 1 deletion numba/ir.py
Expand Up @@ -653,9 +653,11 @@ def infer_constant(self):


class Const(object):
def __init__(self, value, loc):
def __init__(self, value, loc, use_literal_type=True):
self.value = value
self.loc = loc
# Note: need better way to tell if this is a literal or not.
self.use_literal_type = use_literal_type

def __repr__(self):
return 'const(%s, %s)' % (type(self.value).__name__, self.value)
Expand Down
40 changes: 30 additions & 10 deletions numba/lowering.py
Expand Up @@ -11,7 +11,8 @@

from . import (_dynfunc, cgutils, config, funcdesc, generators, ir, types,
typing, utils)
from .errors import LoweringError, new_error_context, TypingError
from .errors import (LoweringError, new_error_context, TypingError,
LiteralTypingError)
from .targets import removerefctpass
from .funcdesc import default_mangler
from . import debuginfo
Expand Down Expand Up @@ -547,18 +548,24 @@ def try_static_impl(tys, args):
except NotImplementedError:
return None

res = try_static_impl((types.Const(static_lhs), types.Const(static_rhs)),
(static_lhs, static_rhs))
res = try_static_impl(
(_lit_or_omitted(static_lhs), _lit_or_omitted(static_rhs)),
(static_lhs, static_rhs),
)
if res is not None:
return cast_result(res)

res = try_static_impl((types.Const(static_lhs), rty),
(static_lhs, rhs))
res = try_static_impl(
(_lit_or_omitted(static_lhs), rty),
(static_lhs, rhs),
)
if res is not None:
return cast_result(res)

res = try_static_impl((lty, types.Const(static_rhs)),
(lhs, static_rhs))
res = try_static_impl(
(lty, _lit_or_omitted(static_rhs)),
(lhs, static_rhs),
)
if res is not None:
return cast_result(res)

Expand Down Expand Up @@ -662,7 +669,7 @@ def lower_print(self, inst):
if i in inst.consts:
pyval = inst.consts[i]
if isinstance(pyval, str):
pos_tys[i] = types.Const(pyval)
pos_tys[i] = types.literal(pyval)

fixed_sig = typing.signature(sig.return_type, *pos_tys)
fixed_sig.pysig = sig.pysig
Expand Down Expand Up @@ -977,8 +984,11 @@ def lower_expr(self, resty, expr):
return res

elif expr.op == "static_getitem":
signature = typing.signature(resty, self.typeof(expr.value.name),
types.Const(expr.index))
signature = typing.signature(
resty,
self.typeof(expr.value.name),
_lit_or_omitted(expr.index),
)
try:
# Both get_function() and the returned implementation can
# raise NotImplementedError if the types aren't supported
Expand Down Expand Up @@ -1126,3 +1136,13 @@ def decref(self, typ, val):
return

self.context.nrt.decref(self.builder, typ, val)


def _lit_or_omitted(value):
"""Returns a Literal instance if the type of value is supported;
otherwise, return `Omitted(value)`.
"""
try:
return types.literal(value)
except LiteralTypingError:
return types.Omitted(value)
6 changes: 3 additions & 3 deletions numba/npyufunc/parfor.py
Expand Up @@ -291,7 +291,7 @@ def _lower_parfor_parallel(lowerer, parfor):

if config.DEBUG_ARRAY_OPT_RUNTIME:
res_print_str = "res_print"
strconsttyp = types.Const(res_print_str)
strconsttyp = types.StringLiteral(res_print_str)
lhs = ir.Var(scope, mk_unique_var("str_const"), loc)
assign_lhs = ir.Assign(value=ir.Const(value=res_print_str, loc=loc),
target=lhs, loc=loc)
Expand Down Expand Up @@ -330,7 +330,7 @@ def _lower_parfor_parallel(lowerer, parfor):

if config.DEBUG_ARRAY_OPT_RUNTIME:
res_print_str = "one_res_print"
strconsttyp = types.Const(res_print_str)
strconsttyp = types.StringLiteral(res_print_str)
lhs = ir.Var(scope, mk_unique_var("str_const"), loc)
assign_lhs = ir.Assign(value=ir.Const(value=res_print_str, loc=loc),
target=lhs, loc=loc)
Expand Down Expand Up @@ -898,7 +898,7 @@ def _create_gufunc_for_parfor_body(

# Make constant string
strval = "{} =".format(inst.target.name)
strconsttyp = types.Const(strval)
strconsttyp = types.StringLiteral(strval)

lhs = ir.Var(scope, mk_unique_var("str_const"), loc)
assign_lhs = ir.Assign(value=ir.Const(value=strval, loc=loc),
Expand Down
1 change: 1 addition & 0 deletions numba/numpy_support.py
Expand Up @@ -115,6 +115,7 @@ def as_dtype(nbtype):
Return a numpy dtype instance corresponding to the given Numba type.
NotImplementedError is if no correspondence is known.
"""
nbtype = types.unliteral(nbtype)
if isinstance(nbtype, (types.Complex, types.Integer, types.Float)):
return np.dtype(str(nbtype))
if nbtype is types.bool_:
Expand Down
10 changes: 7 additions & 3 deletions numba/parfor.py
Expand Up @@ -1432,8 +1432,7 @@ def _setitem_to_parfor(self, equiv_set, loc, target, index, value, shape=None):
getitem_call = ir.Expr.getitem(target, index, loc)
subarr_typ = typing.arraydecl.get_array_index_type( arr_typ, index_typ).result
self.typemap[subarr_var.name] = subarr_typ
self.calltypes[getitem_call] = signature(subarr_typ, arr_typ,
index_typ)
self.calltypes[getitem_call] = self._type_getitem((arr_typ, index_typ))
init_block.append(ir.Assign(getitem_call, subarr_var, loc))
target = subarr_var
else:
Expand Down Expand Up @@ -1503,6 +1502,10 @@ def _setitem_to_parfor(self, equiv_set, loc, target, index, value, shape=None):
parfor.dump()
return parfor

def _type_getitem(self, args):
fnty = operator.getitem
return self.typingctx.resolve_function_type(fnty, tuple(args), {})

def _is_supported_npycall(self, expr):
"""check if we support parfor translation for
this Numpy call.
Expand Down Expand Up @@ -3131,7 +3134,8 @@ def parfor_typeinfer(parfor, typeinferer):
# assigned to a tuple from individual indices
first_block = min(blocks.keys())
loc = blocks[first_block].loc
index_assigns = [ir.Assign(ir.Const(1, loc), v, loc) for v in index_vars]
# XXX
index_assigns = [ir.Assign(ir.Const(1, loc=loc, use_literal_type=False), v, loc) for v in index_vars]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The XXX comment is for the use_literal_type=False. I'll open an issue about it.

save_first_block_body = blocks[first_block].body
blocks[first_block].body = index_assigns + blocks[first_block].body
typeinferer.blocks = blocks
Expand Down
18 changes: 9 additions & 9 deletions numba/targets/arraymath.py
Expand Up @@ -25,7 +25,7 @@
from .linalg import ensure_blas

from numba.extending import intrinsic
from numba.errors import RequireConstValue, TypingError
from numba.errors import RequireLiteralValue, TypingError

def _check_blas():
# Checks if a BLAS is available so e.g. dot will work
Expand Down Expand Up @@ -86,10 +86,10 @@ def _gen_index_tuple(tyctx, shape_tuple, value, axis):
in the axis dimension and 'axis' is that dimension. For this to work,
axis has to be a const.
"""
if not isinstance(axis, types.Const):
raise RequireConstValue('axis argument must be a constant')
if not isinstance(axis, types.Literal):
raise RequireLiteralValue('axis argument must be a constant')
# Get the value of the axis constant.
axis_value = axis.value
axis_value = axis.literal_value
# The length of the indexing tuple to be output.
nd = len(shape_tuple)

Expand Down Expand Up @@ -167,9 +167,9 @@ def array_sum_impl(arr):
return impl_ret_borrowed(context, builder, sig.return_type, res)

@lower_builtin(np.sum, types.Array, types.intp)
@lower_builtin(np.sum, types.Array, types.Const)
@lower_builtin(np.sum, types.Array, types.IntegerLiteral)
@lower_builtin("array.sum", types.Array, types.intp)
@lower_builtin("array.sum", types.Array, types.Const)
@lower_builtin("array.sum", types.Array, types.IntegerLiteral)
def array_sum_axis(context, builder, sig, args):
"""
The third parameter to gen_index_tuple that generates the indexing
Expand All @@ -185,9 +185,9 @@ def array_sum_axis(context, builder, sig, args):
[ty_array, ty_axis] = sig.args
is_axis_const = False
const_axis_val = 0
if isinstance(ty_axis, types.Const):
if isinstance(ty_axis, types.Literal):
# this special-cases for constant axis
const_axis_val = ty_axis.value
const_axis_val = ty_axis.literal_value
# fix negative axis
if const_axis_val < 0:
const_axis_val = ty_array.ndim + const_axis_val
Expand Down Expand Up @@ -2021,7 +2021,7 @@ def searchsorted_inner(a, v):

@overload(np.searchsorted)
def searchsorted(a, v, side='left'):
side_val = getattr(side, 'value', side)
side_val = getattr(side, 'literal_value', side)
if side_val == 'left':
loop_impl = _searchsorted_left
elif side_val == 'right':
Expand Down
12 changes: 6 additions & 6 deletions numba/targets/arrayobj.py
Expand Up @@ -2182,7 +2182,7 @@ def array_record_getattr(context, builder, typ, value, attr):
res = rary._getvalue()
return impl_ret_borrowed(context, builder, resty, res)

@lower_builtin('static_getitem', types.Array, types.Const)
@lower_builtin('static_getitem', types.Array, types.StringLiteral)
def array_record_getitem(context, builder, sig, args):
index = args[1]
if not isinstance(index, str):
Expand Down Expand Up @@ -2251,15 +2251,15 @@ def record_setattr(context, builder, sig, args, attr):
context.pack_value(builder, elemty, val, dptr, align=align)


@lower_builtin('static_getitem', types.Record, types.Const)
@lower_builtin('static_getitem', types.Record, types.StringLiteral)
def record_getitem(context, builder, sig, args):
"""
Record.__getitem__ redirects to getattr()
"""
impl = context.get_getattr(sig.args[0], args[1])
return impl(context, builder, sig.args[0], args[0], args[1])

@lower_builtin('static_setitem', types.Record, types.Const, types.Any)
@lower_builtin('static_setitem', types.Record, types.StringLiteral, types.Any)
def record_setitem(context, builder, sig, args):
"""
Record.__setitem__ redirects to setattr()
Expand Down Expand Up @@ -4648,11 +4648,11 @@ def np_sort_impl(a):

return context.compile_internal(builder, np_sort_impl, sig, args)

@lower_builtin("array.argsort", types.Array, types.Const)
@lower_builtin(np.argsort, types.Array, types.Const)
@lower_builtin("array.argsort", types.Array, types.StringLiteral)
@lower_builtin(np.argsort, types.Array, types.StringLiteral)
def array_argsort(context, builder, sig, args):
arytype, kind = sig.args
sort_func = get_sort_func(kind=kind.value,
sort_func = get_sort_func(kind=kind.literal_value,
is_float=isinstance(arytype.dtype, types.Float),
is_argsort=True)

Expand Down
11 changes: 11 additions & 0 deletions numba/targets/base.py
Expand Up @@ -432,6 +432,16 @@ def insert_const_string(self, mod, string):
gv = self.insert_unique_const(mod, name, text)
return Constant.bitcast(gv, stringtype)

def insert_const_bytes(self, mod, bytes, name=None):
"""
Insert constant *byte* (a `bytes` object) into module *mod*.
"""
stringtype = GENERIC_POINTER
name = ".bytes.%s" % (name or hash(bytes))
text = cgutils.make_bytearray(bytes)
gv = self.insert_unique_const(mod, name, text)
return Constant.bitcast(gv, stringtype)

def insert_unique_const(self, mod, name, val):
"""
Insert a unique internal constant named *name*, with LLVM value
Expand Down Expand Up @@ -512,6 +522,7 @@ def get_function(self, fn, sig, _firstcall=True):
Return the implementation of function *fn* for signature *sig*.
The return value is a callable with the signature (builder, args).
"""
assert sig is not None
sig = sig.as_function()
if isinstance(fn, (types.Function, types.BoundFunction,
types.Dispatcher)):
Expand Down