diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index f3133c828c8..0a3d4ed3723 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -3322,7 +3322,6 @@ def analyse_as_pyobject(self, env, is_slice, getting, setting): if base_type in (list_type, tuple_type, dict_type): # do the None check explicitly (not in a helper) to allow optimising it away self.base = self.base.as_none_safe_node("'NoneType' object is not subscriptable") - self.wrap_in_nonecheck_node(env, getting) return self @@ -3366,7 +3365,7 @@ def analyse_as_c_function(self, env): if base_type.is_fused: self.parse_indexed_fused_cdef(env) else: - self.type_indices = self.parse_index_as_types(env) + self.type_indices = self.parse_index_as_template_parameters(env) self.index = None # FIXME: use a dedicated Node class instead of generic IndexNode if base_type.templates is None: error(self.pos, "Can only parameterize template functions.") @@ -3439,18 +3438,25 @@ def wrap_in_nonecheck_node(self, env, getting): return self.base = self.base.as_none_safe_node("'NoneType' object is not subscriptable") - def parse_index_as_types(self, env, required=True): + def parse_index_as_template_parameters(self, env, required=True): if isinstance(self.index, TupleNode): indices = self.index.args else: indices = [self.index] type_indices = [] for index in indices: - type_indices.append(index.analyse_as_type(env)) - if type_indices[-1] is None: - if required: - error(index.pos, "not parsable as a type") + type_index = index.analyse_as_type(env) + if type_index is None and self.base.type.templates: + # Handle the case that this is a template specialization + # that uses a non-type value. + if index.type is None: + index.analyse_types(env) + if index.type.is_int or index.type.is_enum or index.type.is_ptr or isinstance(index.type, PyrexTypes.CFuncType): + type_index = index + if required and type_index is None: + error(index.pos, "not parsable as a type") return None + type_indices.append(type_index) return type_indices def parse_indexed_fused_cdef(self, env): @@ -3474,7 +3480,7 @@ def parse_indexed_fused_cdef(self, env): elif isinstance(self.index, TupleNode): for arg in self.index.args: positions.append(arg.pos) - specific_types = self.parse_index_as_types(env, required=False) + specific_types = self.parse_index_as_template_parameters(env, required=False) if specific_types is None: self.index = self.index.analyse_types(env) @@ -3557,9 +3563,19 @@ def calculate_result_code(self): else: assert False, "unexpected base type in indexing: %s" % self.base.type elif self.base.type.is_cfunction: + template_indices = [] + for param in self.type_indices: + if isinstance(param, PyrexTypes.CType): + template_indices.append(param.empty_declaration_code()) + elif param.type.is_int and not param.type.is_const: + template_indices.append(param.value_as_c_integer_string()) + elif param.type.is_int or param.type.is_enum or param.type.is_ptr or isinstance(param.type, PyrexTypes.CFuncType): + template_indices.append(param.result()) + else: + error(self.pos, "Invalid or unsupported template parameter.") return "%s<%s>" % ( self.base.result(), - ",".join([param.empty_declaration_code() for param in self.type_indices])) + ",".join(template_indices)) elif self.base.type.is_ctuple: index = self.index.constant_result if index < 0: diff --git a/tests/run/cpp_non_type_template_parameters.srctree b/tests/run/cpp_non_type_template_parameters.srctree new file mode 100644 index 00000000000..258d6380a29 --- /dev/null +++ b/tests/run/cpp_non_type_template_parameters.srctree @@ -0,0 +1,47 @@ +# mode: run +# tag: cpp + +""" +PYTHON setup.py build_ext --inplace +PYTHON -c "from assignment_overload import test; test()" +""" + +######## setup.py ######## + +from distutils.core import setup +from Cython.Build import cythonize +setup(ext_modules=cythonize("*.pyx", language='c++')) + + +######## non_type_templates.hpp ######## + +#pragma once + +template +int myfunc1(int a){ + return a + N; +} + +template +int myfunc2(int a){ + return p(a); +} + +const int c = 2; + +######## non_type_params.pyx ######## + +cdef extern from "non_type_templates.hpp" nogil: + int myfunc1[T](int) + int myfunc2[T](int) + const int c + +cdef int p(int a) nogil: + return a + 2 + +def test(): + cdef int a = 0 + assert myfunc1[2](a) == 2 + assert myfunc1[c](a) == 2 + assert myfunc2[&p](a) == 2 + assert myfunc2[p](a) == 2