Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fused types: specialize each base type only once, also for compound type args #284

Merged
merged 1 commit into from
Jun 22, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 28 additions & 7 deletions Cython/Compiler/FusedNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def copy_def(self, env):
"""
fused_compound_types = PyrexTypes.unique(
[arg.type for arg in self.node.args if arg.type.is_fused])
permutations = PyrexTypes.get_all_specialized_permutations(fused_compound_types)
fused_types = self._get_fused_base_types(fused_compound_types)
permutations = PyrexTypes.get_all_specialized_permutations(fused_types)

self.fused_compound_types = fused_compound_types

Expand Down Expand Up @@ -179,6 +180,17 @@ def copy_cdef(self, env):
else:
self.py_func = orig_py_func

def _get_fused_base_types(self, fused_compound_types):
"""
Get a list of unique basic fused types, from a list of
(possibly) compound fused types.
"""
base_types = []
seen = set()
for fused_type in fused_compound_types:
fused_type.get_fused_types(result=base_types, seen=seen)
return base_types

def _specialize_function_args(self, args, fused_to_specific):
for arg in args:
if arg.type.is_fused:
Expand All @@ -203,9 +215,10 @@ def create_new_local_scope(self, node, env, f2s):
node.has_fused_arguments = False
self.nodes.append(node)

def specialize_copied_def(self, node, cname, py_entry, f2s, fused_types):
def specialize_copied_def(self, node, cname, py_entry, f2s, fused_compound_types):
"""Specialize the copy of a DefNode given the copied node,
the specialization cname and the original DefNode entry"""
fused_types = self._get_fused_base_types(fused_compound_types)
type_strings = [
PyrexTypes.specialization_signature_string(fused_type, f2s)
for fused_type in fused_types
Expand Down Expand Up @@ -519,13 +532,13 @@ def make_fused_cpdef(self, orig_py_func, env, is_def):
"""
from Cython.Compiler import TreeFragment, Code, MemoryView, UtilityCode

# { (arg_pos, FusedType) : specialized_type }
seen_fused_types = set()
fused_types = self._get_fused_base_types([
arg.type for arg in self.node.args if arg.type.is_fused])

context = {
'memviewslice_cname': MemoryView.memviewslice_cname,
'func_args': self.node.args,
'n_fused': len([arg for arg in self.node.args]),
'n_fused': len(fused_types),
'name': orig_py_func.entry.name,
}

Expand Down Expand Up @@ -557,9 +570,17 @@ def __pyx_fused_cpdef(signatures, args, kwargs, defaults):
fused_index = 0
default_idx = 0
all_buffer_types = set()
seen_fused_types = set()
for i, arg in enumerate(self.node.args):
if arg.type.is_fused and arg.type not in seen_fused_types:
seen_fused_types.add(arg.type)
if arg.type.is_fused:
arg_fused_types = arg.type.get_fused_types()
if len(arg_fused_types) > 1:
raise NotImplementedError("Determination of more than one fused base "
"type per argument is not implemented.")
fused_type = arg_fused_types[0]

if arg.type.is_fused and fused_type not in seen_fused_types:
seen_fused_types.add(fused_type)

context.update(
arg_tuple_idx=i,
Expand Down
44 changes: 43 additions & 1 deletion tests/run/fused_types.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ ctypedef fused_type1 *composed_t
other_t = cython.fused_type(int, double)
ctypedef double *p_double
ctypedef int *p_int

fused_type3 = cython.fused_type(int, double)

def test_pure():
"""
Expand Down Expand Up @@ -268,6 +268,12 @@ def get_array(itemsize, format):
result[6] = 6.0
return result

def get_intc_array():
result = array((10,), sizeof(int), 'i')
result[5] = 5.0
result[6] = 6.0
return result

def test_fused_memslice_dtype(cython.floating[:] array):
"""
Note: the np.ndarray dtype test is in numpy_test
Expand All @@ -285,6 +291,42 @@ def test_fused_memslice_dtype(cython.floating[:] array):
print cython.typeof(array), cython.typeof(otherarray), \
array[5], otherarray[6]

def test_fused_memslice_dtype_repeated(cython.floating[:] array1, cython.floating[:] array2):
"""
Note: the np.ndarray dtype test is in numpy_test

>>> import cython
>>> sorted(test_fused_memslice_dtype_repeated.__signatures__)
['double', 'float']

>>> test_fused_memslice_dtype_repeated(get_array(8, 'd'), get_array(8, 'd'))
double[:] double[:]
>>> test_fused_memslice_dtype_repeated(get_array(4, 'f'), get_array(4, 'f'))
float[:] float[:]
>>> test_fused_memslice_dtype_repeated(get_array(8, 'd'), get_array(4, 'f'))
Traceback (most recent call last):
ValueError: Buffer dtype mismatch, expected 'double' but got 'float'
"""
print cython.typeof(array1), cython.typeof(array2)

def test_fused_memslice_dtype_repeated_2(cython.floating[:] array1, cython.floating[:] array2,
fused_type3[:] array3):
"""
Note: the np.ndarray dtype test is in numpy_test

>>> import cython
>>> sorted(test_fused_memslice_dtype_repeated_2.__signatures__)
['double|double', 'double|int', 'float|double', 'float|int']

>>> test_fused_memslice_dtype_repeated_2(get_array(8, 'd'), get_array(8, 'd'), get_array(8, 'd'))
double[:] double[:] double[:]
>>> test_fused_memslice_dtype_repeated_2(get_array(8, 'd'), get_array(8, 'd'), get_intc_array())
double[:] double[:] int[:]
>>> test_fused_memslice_dtype_repeated_2(get_array(4, 'f'), get_array(4, 'f'), get_intc_array())
float[:] float[:] int[:]
"""
print cython.typeof(array1), cython.typeof(array2), cython.typeof(array3)

def test_cython_numeric(cython.numeric arg):
"""
Test to see whether complex numbers have their utility code declared
Expand Down