Skip to content

Commit

Permalink
Merge pull request #284 from pv/fused-fixes
Browse files Browse the repository at this point in the history
fused types: specialize each base type only once, also for compound type args
  • Loading branch information
scoder committed Jun 22, 2014
2 parents bc04ebe + 95d76de commit 4228dfd
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 8 deletions.
35 changes: 28 additions & 7 deletions Cython/Compiler/FusedNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def copy_def(self, env):
"""
fused_compound_types = PyrexTypes.unique(
[arg.type for arg in self.node.args if arg.type.is_fused])
permutations = PyrexTypes.get_all_specialized_permutations(fused_compound_types)
fused_types = self._get_fused_base_types(fused_compound_types)
permutations = PyrexTypes.get_all_specialized_permutations(fused_types)

self.fused_compound_types = fused_compound_types

Expand Down Expand Up @@ -183,6 +184,17 @@ def copy_cdef(self, env):
else:
self.py_func = orig_py_func

def _get_fused_base_types(self, fused_compound_types):
"""
Get a list of unique basic fused types, from a list of
(possibly) compound fused types.
"""
base_types = []
seen = set()
for fused_type in fused_compound_types:
fused_type.get_fused_types(result=base_types, seen=seen)
return base_types

def _specialize_function_args(self, args, fused_to_specific):
for arg in args:
if arg.type.is_fused:
Expand All @@ -207,9 +219,10 @@ def create_new_local_scope(self, node, env, f2s):
node.has_fused_arguments = False
self.nodes.append(node)

def specialize_copied_def(self, node, cname, py_entry, f2s, fused_types):
def specialize_copied_def(self, node, cname, py_entry, f2s, fused_compound_types):
"""Specialize the copy of a DefNode given the copied node,
the specialization cname and the original DefNode entry"""
fused_types = self._get_fused_base_types(fused_compound_types)
type_strings = [
PyrexTypes.specialization_signature_string(fused_type, f2s)
for fused_type in fused_types
Expand Down Expand Up @@ -522,13 +535,13 @@ def make_fused_cpdef(self, orig_py_func, env, is_def):
"""
from . import TreeFragment, Code, UtilityCode

# { (arg_pos, FusedType) : specialized_type }
seen_fused_types = set()
fused_types = self._get_fused_base_types([
arg.type for arg in self.node.args if arg.type.is_fused])

context = {
'memviewslice_cname': MemoryView.memviewslice_cname,
'func_args': self.node.args,
'n_fused': len([arg for arg in self.node.args]),
'n_fused': len(fused_types),
'name': orig_py_func.entry.name,
}

Expand Down Expand Up @@ -560,9 +573,17 @@ def __pyx_fused_cpdef(signatures, args, kwargs, defaults):
fused_index = 0
default_idx = 0
all_buffer_types = set()
seen_fused_types = set()
for i, arg in enumerate(self.node.args):
if arg.type.is_fused and arg.type not in seen_fused_types:
seen_fused_types.add(arg.type)
if arg.type.is_fused:
arg_fused_types = arg.type.get_fused_types()
if len(arg_fused_types) > 1:
raise NotImplementedError("Determination of more than one fused base "
"type per argument is not implemented.")
fused_type = arg_fused_types[0]

if arg.type.is_fused and fused_type not in seen_fused_types:
seen_fused_types.add(fused_type)

context.update(
arg_tuple_idx=i,
Expand Down
44 changes: 43 additions & 1 deletion tests/run/fused_types.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ ctypedef fused_type1 *composed_t
other_t = cython.fused_type(int, double)
ctypedef double *p_double
ctypedef int *p_int

fused_type3 = cython.fused_type(int, double)

def test_pure():
"""
Expand Down Expand Up @@ -268,6 +268,12 @@ def get_array(itemsize, format):
result[6] = 6.0
return result

def get_intc_array():
result = array((10,), sizeof(int), 'i')
result[5] = 5.0
result[6] = 6.0
return result

def test_fused_memslice_dtype(cython.floating[:] array):
"""
Note: the np.ndarray dtype test is in numpy_test
Expand All @@ -285,6 +291,42 @@ def test_fused_memslice_dtype(cython.floating[:] array):
print cython.typeof(array), cython.typeof(otherarray), \
array[5], otherarray[6]

def test_fused_memslice_dtype_repeated(cython.floating[:] array1, cython.floating[:] array2):
"""
Note: the np.ndarray dtype test is in numpy_test
>>> import cython
>>> sorted(test_fused_memslice_dtype_repeated.__signatures__)
['double', 'float']
>>> test_fused_memslice_dtype_repeated(get_array(8, 'd'), get_array(8, 'd'))
double[:] double[:]
>>> test_fused_memslice_dtype_repeated(get_array(4, 'f'), get_array(4, 'f'))
float[:] float[:]
>>> test_fused_memslice_dtype_repeated(get_array(8, 'd'), get_array(4, 'f'))
Traceback (most recent call last):
ValueError: Buffer dtype mismatch, expected 'double' but got 'float'
"""
print cython.typeof(array1), cython.typeof(array2)

def test_fused_memslice_dtype_repeated_2(cython.floating[:] array1, cython.floating[:] array2,
fused_type3[:] array3):
"""
Note: the np.ndarray dtype test is in numpy_test
>>> import cython
>>> sorted(test_fused_memslice_dtype_repeated_2.__signatures__)
['double|double', 'double|int', 'float|double', 'float|int']
>>> test_fused_memslice_dtype_repeated_2(get_array(8, 'd'), get_array(8, 'd'), get_array(8, 'd'))
double[:] double[:] double[:]
>>> test_fused_memslice_dtype_repeated_2(get_array(8, 'd'), get_array(8, 'd'), get_intc_array())
double[:] double[:] int[:]
>>> test_fused_memslice_dtype_repeated_2(get_array(4, 'f'), get_array(4, 'f'), get_intc_array())
float[:] float[:] int[:]
"""
print cython.typeof(array1), cython.typeof(array2), cython.typeof(array3)

def test_cython_numeric(cython.numeric arg):
"""
Test to see whether complex numbers have their utility code declared
Expand Down

0 comments on commit 4228dfd

Please sign in to comment.