Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

Better fused buffer runtime dispatch + dispatch restructuring + PyxCo…

…deWriter
  • Loading branch information...
commit 39c966d2e2a7ec860f1c57e509763e85bc6506c4 1 parent a3230e4
Mark Florisson markflorisson88 authored
60 Cython/Compiler/Code.py
@@ -14,6 +14,7 @@
14 14 import sys
15 15 from string import Template
16 16 import operator
  17 +import textwrap
17 18
18 19 import Naming
19 20 import Options
@@ -376,7 +377,7 @@ def put_code(self, output):
376 377 self.cleanup(writer, output.module_pos)
377 378
378 379
379   -def sub_tempita(s, context, file, name):
  380 +def sub_tempita(s, context, file=None, name=None):
380 381 "Run tempita on string s with given context."
381 382 if not s:
382 383 return None
@@ -1940,6 +1941,63 @@ def indent(self):
1940 1941 def dedent(self):
1941 1942 self.level -= 1
1942 1943
  1944 +class PyxCodeWriter(object):
  1945 + """
  1946 + Can be used for writing out some Cython code.
  1947 + """
  1948 +
  1949 + def __init__(self, buffer=None, indent_level=0, context=None):
  1950 + self.buffer = buffer or StringIOTree()
  1951 + self.level = indent_level
  1952 + self.context = context
  1953 + self.encoding = 'ascii'
  1954 +
  1955 + def indent(self, levels=1):
  1956 + self.level += levels
  1957 +
  1958 + def dedent(self, levels=1):
  1959 + self.level -= levels
  1960 +
  1961 + def indenter(self, line):
  1962 + """
  1963 + with pyx_code.indenter("for i in range(10):"):
  1964 + pyx_code.putln("print i")
  1965 + """
  1966 + self.putln(line)
  1967 + return self
  1968 +
  1969 + def getvalue(self):
  1970 + return unicode(self.buffer.getvalue(), self.encoding)
  1971 +
  1972 + def putln(self, line, context=None):
  1973 + context = context or self.context
  1974 + if context:
  1975 + line = sub_tempita(line, context)
  1976 + self._putln(line)
  1977 +
  1978 + def _putln(self, line):
  1979 + self.buffer.write("%s%s\n" % (self.level * " ", line))
  1980 +
  1981 + def put_chunk(self, chunk, context=None):
  1982 + context = context or self.context
  1983 + if context:
  1984 + chunk = sub_tempita(chunk, context)
  1985 +
  1986 + chunk = textwrap.dedent(chunk)
  1987 + for line in chunk.splitlines():
  1988 + self._putln(line)
  1989 +
  1990 + def insertion_point(self):
  1991 + return PyxCodeWriter(self.buffer.insertion_point(), self.level,
  1992 + self.context)
  1993 +
  1994 + def named_insertion_point(self, name):
  1995 + setattr(self, name, self.insertion_point())
  1996 +
  1997 + __enter__ = indent
  1998 +
  1999 + def __exit__(self, exc_value, exc_type, exc_tb):
  2000 + self.dedent()
1943 2001
1944 2002 class ClosureTempAllocator(object):
1945 2003 def __init__(self, klass):
546 Cython/Compiler/Nodes.py
@@ -2277,24 +2277,11 @@ def __init__(self, node, env):
2277 2277 assert n.type.op_arg_struct
2278 2278
2279 2279 node.entry.fused_cfunction = self
2280   -
2281   - if self.py_func:
2282   - self.py_func.entry.fused_cfunction = self
2283   - for node in self.nodes:
2284   - if is_def:
2285   - node.fused_py_func = self.py_func
2286   - else:
2287   - node.py_func.fused_py_func = self.py_func
2288   - node.entry.as_variable = self.py_func.entry
2289 2280 # Copy the nodes as AnalyseDeclarationsTransform will prepend
2290 2281 # self.py_func to self.stats, as we only want specialized
2291 2282 # CFuncDefNodes in self.nodes
2292 2283 self.stats = self.nodes[:]
2293 2284
2294   - if self.py_func:
2295   - self.synthesize_defnodes()
2296   - self.stats.append(self.__signatures__)
2297   -
2298 2285 def copy_def(self, env):
2299 2286 """
2300 2287 Create a copy of the original def or lambda function for specialized
@@ -2326,6 +2313,7 @@ def copy_def(self, env):
2326 2313 if not self.replace_fused_typechecks(copied_node):
2327 2314 break
2328 2315
  2316 + self.orig_py_func = self.node
2329 2317 self.py_func = self.make_fused_cpdef(self.node, env, is_def=True)
2330 2318
2331 2319 def copy_cdef(self, env):
@@ -2342,7 +2330,7 @@ def copy_cdef(self, env):
2342 2330 env.cfunc_entries.remove(self.node.entry)
2343 2331
2344 2332 # Prevent copying of the python function
2345   - orig_py_func = self.node.py_func
  2333 + self.orig_py_func = orig_py_func = self.node.py_func
2346 2334 self.node.py_func = None
2347 2335 if orig_py_func:
2348 2336 env.pyfunc_entries.remove(orig_py_func.entry)
@@ -2459,179 +2447,437 @@ def replace_fused_typechecks(self, copied_node):
2459 2447
2460 2448 return True
2461 2449
2462   - def make_fused_cpdef(self, orig_py_func, env, is_def):
  2450 + def _fused_instance_checks(self, normal_types, pyx_code, env):
2463 2451 """
2464   - This creates the function that is indexable from Python and does
2465   - runtime dispatch based on the argument types. The function gets the
2466   - arg tuple and kwargs dict (or None) as arugments from the Binding
2467   - Fused Function's tp_call.
  2452 + Genereate Cython code for instance checks, matching an object to
  2453 + specialized types.
2468 2454 """
2469   - from Cython.Compiler import TreeFragment
2470   - from Cython.Compiler import ParseTreeTransforms
  2455 + if_ = 'if'
  2456 + for specialized_type in normal_types:
  2457 + # all_numeric = all_numeric and specialized_type.is_numeric
  2458 + py_type_name = specialized_type.py_type_name()
  2459 +
  2460 + # in the case of long, unicode or bytes we need to instance
  2461 + # check for long_, unicode_, bytes_ (long = long is no longer
  2462 + # valid code with control flow analysis)
  2463 + specialized_check_name = py_type_name
  2464 + if py_type_name in ('long', 'unicode', 'bytes'):
  2465 + specialized_check_name += '_'
  2466 +
  2467 + specialized_type_name = specialized_type.specialization_string
  2468 + pyx_code.context.update(locals())
  2469 + pyx_code.put_chunk(
  2470 + u"""
  2471 + {{if_}} isinstance(arg, {{specialized_check_name}}):
  2472 + dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'
  2473 + """)
  2474 + if_ = 'elif'
  2475 +
  2476 + if not normal_types:
  2477 + # we need an 'if' to match the following 'else'
  2478 + pyx_code.putln("if 0: pass")
  2479 +
  2480 + def _dtype_name(self, dtype):
  2481 + if dtype.is_typedef:
  2482 + return '___pyx_%s' % dtype
  2483 + return str(dtype).replace(' ', '_')
  2484 +
  2485 + def _dtype_type(self, dtype):
  2486 + if dtype.is_typedef:
  2487 + return self._dtype_name(dtype)
  2488 + return str(dtype)
  2489 +
  2490 + def _sizeof_dtype(self, dtype):
  2491 + if dtype.is_pyobject:
  2492 + return 'sizeof(void *)'
  2493 + else:
  2494 + return "sizeof(%s)" % self._dtype_type(dtype)
2471 2495
2472   - # { (arg_pos, FusedType) : specialized_type }
2473   - seen_fused_types = set()
  2496 + def _buffer_check_numpy_dtype_setup_cases(self, pyx_code):
  2497 + "Setup some common cases to match dtypes against specializations"
  2498 + with pyx_code.indenter("if dtype.kind in ('i', 'u'):"):
  2499 + pyx_code.putln("pass")
  2500 + pyx_code.named_insertion_point("dtype_int")
2474 2501
2475   - # list of statements that do the instance checks
2476   - body_stmts = []
  2502 + with pyx_code.indenter("elif dtype.kind == 'f':"):
  2503 + pyx_code.putln("pass")
  2504 + pyx_code.named_insertion_point("dtype_float")
2477 2505
2478   - args = self.node.args
2479   - for i, arg in enumerate(args):
2480   - arg_type = arg.type
2481   - if arg_type.is_fused and arg_type not in seen_fused_types:
2482   - seen_fused_types.add(arg_type)
  2506 + with pyx_code.indenter("elif dtype.kind == 'c':"):
  2507 + pyx_code.putln("pass")
  2508 + pyx_code.named_insertion_point("dtype_complex")
2483 2509
2484   - specialized_types = PyrexTypes.get_specialized_types(arg_type)
2485   - # Prefer long over int, etc
2486   - # specialized_types.sort()
  2510 + with pyx_code.indenter("elif dtype.kind == 'O':"):
  2511 + pyx_code.putln("pass")
  2512 + pyx_code.named_insertion_point("dtype_object")
2487 2513
2488   - seen_py_type_names = set()
2489   - first_check = True
  2514 + match = "dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'"
  2515 + no_match = "dest_sig[{{dest_sig_idx}}] = None"
  2516 + def _buffer_check_numpy_dtype(self, pyx_code, specialized_buffer_types):
  2517 + """
  2518 + Match a numpy dtype object to the individual specializations.
  2519 + """
  2520 + self._buffer_check_numpy_dtype_setup_cases(pyx_code)
  2521 +
  2522 + for specialized_type in specialized_buffer_types:
  2523 + dtype = specialized_type.dtype
  2524 + pyx_code.context.update(
  2525 + itemsize_match=self._sizeof_dtype(dtype) + " == itemsize",
  2526 + signed_match="not (%s_is_signed ^ dtype_signed)" % self._dtype_name(dtype),
  2527 + dtype=dtype,
  2528 + specialized_type_name=specialized_type.specialization_string)
  2529 +
  2530 + dtypes = [
  2531 + (dtype.is_int, pyx_code.dtype_int),
  2532 + (dtype.is_float, pyx_code.dtype_float),
  2533 + (dtype.is_complex, pyx_code.dtype_complex)
  2534 + ]
  2535 +
  2536 + for dtype_category, codewriter in dtypes:
  2537 + if dtype_category:
  2538 + cond = '{{itemsize_match}}'
  2539 + if dtype.is_int:
  2540 + cond += ' and {{signed_match}}'
  2541 +
  2542 + with codewriter.indenter("if %s:" % cond):
  2543 + # codewriter.putln("print 'buffer match found based on numpy dtype'")
  2544 + codewriter.putln(self.match)
  2545 + codewriter.putln("break")
  2546 +
  2547 + def _buffer_parse_format_string_check(self, pyx_code, decl_code,
  2548 + specialized_type, env):
  2549 + """
  2550 + For each specialized type, try to coerce the object to a memoryview
  2551 + slice of that type. This means obtaining a buffer and parsing the
  2552 + format string.
  2553 + TODO: separate buffer acquisition from format parsing
  2554 + """
  2555 + dtype = specialized_type.dtype
  2556 + if specialized_type.is_buffer:
  2557 + axes = [('direct', 'strided')] * specialized_type.ndim
  2558 + else:
  2559 + axes = specialized_type.axes
  2560 +
  2561 + memslice_type = PyrexTypes.MemoryViewSliceType(dtype, axes)
  2562 + memslice_type.create_from_py_utility_code(env)
  2563 + pyx_code.context.update(
  2564 + coerce_from_py_func=memslice_type.from_py_function,
  2565 + dtype=dtype)
  2566 + decl_code.putln(
  2567 + "{{memviewslice_cname}} {{coerce_from_py_func}}(object)")
  2568 +
  2569 + pyx_code.context.update(
  2570 + specialized_type_name=specialized_type.specialization_string,
  2571 + sizeof_dtype=self._sizeof_dtype(dtype))
  2572 +
  2573 + pyx_code.put_chunk(
  2574 + u"""
  2575 + # try {{dtype}}
  2576 + if itemsize == -1 or itemsize == {{sizeof_dtype}}:
  2577 + memslice = {{coerce_from_py_func}}(arg)
  2578 + if memslice.memview:
  2579 + __PYX_XDEC_MEMVIEW(&memslice, 1)
  2580 + # print "found a match for the buffer through format parsing"
  2581 + %s
  2582 + break
  2583 + else:
  2584 + PyErr_Clear()
  2585 + """ % self.match)
2490 2586
2491   - body_stmts.append(u"""
2492   - if nargs >= %(nextidx)d or '%(argname)s' in kwargs:
2493   - if nargs >= %(nextidx)d:
2494   - arg = args[%(idx)d]
  2587 + def _buffer_checks(self, buffer_types, pyx_code, decl_code, env):
  2588 + """
  2589 + Generate Cython code to match objects to buffer specializations.
  2590 + First try to get a numpy dtype object and match it against the individual
  2591 + specializations. If that fails, try naively to coerce the object
  2592 + to each specialization, which obtains the buffer each time and tries
  2593 + to match the format string.
  2594 + """
  2595 + from Cython.Compiler import ExprNodes
  2596 + if buffer_types:
  2597 + with pyx_code.indenter(u"else:"):
  2598 + # The first thing to find a match in this loop breaks out of the loop
  2599 + with pyx_code.indenter(u"while 1:"):
  2600 + pyx_code.put_chunk(
  2601 + u"""
  2602 + if numpy is not None:
  2603 + if isinstance(arg, numpy.ndarray):
  2604 + dtype = arg.dtype
  2605 + elif (__pyx_memoryview_check(arg) and
  2606 + isinstance(arg.object, numpy.ndarray)):
  2607 + dtype = arg.object.dtype
  2608 + else:
  2609 + dtype = None
  2610 +
  2611 + itemsize = -1
  2612 + if dtype is not None:
  2613 + itemsize = dtype.itemsize
  2614 + kind = ord(dtype.kind)
  2615 + dtype_signed = kind == ord('u')
  2616 + """)
  2617 + pyx_code.indent(2)
  2618 + pyx_code.named_insertion_point("numpy_dtype_checks")
  2619 + self._buffer_check_numpy_dtype(pyx_code, buffer_types)
  2620 + pyx_code.dedent(2)
  2621 +
  2622 + for specialized_type in buffer_types:
  2623 + self._buffer_parse_format_string_check(
  2624 + pyx_code, decl_code, specialized_type, env)
  2625 +
  2626 + pyx_code.putln(self.no_match)
  2627 + pyx_code.putln("break")
2495 2628 else:
2496   - arg = kwargs['%(argname)s']
2497   -""" % {'idx': i, 'nextidx': i + 1, 'argname': arg.name})
  2629 + pyx_code.putln("else: %s" % self.no_match)
2498 2630
2499   - all_numeric = True
2500   - for specialized_type in specialized_types:
2501   - py_type_name = specialized_type.py_type_name()
  2631 + def _buffer_declarations(self, pyx_code, decl_code, all_buffer_types):
  2632 + """
  2633 + If we have any buffer specializations, write out some variable
  2634 + declarations and imports.
  2635 + """
  2636 + decl_code.put_chunk(
  2637 + u"""
  2638 + ctypedef struct {{memviewslice_cname}}:
  2639 + void *memview
  2640 +
  2641 + void __PYX_XDEC_MEMVIEW({{memviewslice_cname}} *, int have_gil)
  2642 + bint __pyx_memoryview_check(object)
  2643 + """)
  2644 +
  2645 + pyx_code.local_variable_declarations.put_chunk(
  2646 + u"""
  2647 + cdef {{memviewslice_cname}} memslice
  2648 + cdef Py_ssize_t itemsize
  2649 + cdef bint dtype_signed
  2650 + cdef char kind
  2651 +
  2652 + itemsize = -1
  2653 + """)
  2654 +
  2655 + pyx_code.imports.put_chunk(
  2656 + u"""
  2657 + try:
  2658 + import numpy
  2659 + except ImportError:
  2660 + numpy = None
  2661 + """)
  2662 +
  2663 + seen_int_dtypes = set()
  2664 + for buffer_type in all_buffer_types:
  2665 + dtype = buffer_type.dtype
  2666 + if dtype.is_typedef:
  2667 + #decl_code.putln("ctypedef %s %s" % (dtype.resolve(),
  2668 + # self._dtype_name(dtype)))
  2669 + decl_code.putln('ctypedef %s %s "%s"' % (dtype.resolve(),
  2670 + self._dtype_name(dtype),
  2671 + dtype.declaration_code("")))
  2672 +
  2673 + if buffer_type.dtype.is_int:
  2674 + if str(dtype) not in seen_int_dtypes:
  2675 + seen_int_dtypes.add(str(dtype))
  2676 + pyx_code.context.update(dtype_name=self._dtype_name(dtype),
  2677 + dtype_type=self._dtype_type(dtype))
  2678 + pyx_code.local_variable_declarations.put_chunk(
  2679 + u"""
  2680 + cdef bint {{dtype_name}}_is_signed
  2681 + {{dtype_name}}_is_signed = <{{dtype_type}}> -1 < 0
  2682 + """)
  2683 +
  2684 + def _split_fused_types(self, arg):
  2685 + """
  2686 + Specialize fused types and split into normal types and buffer types.
  2687 + """
  2688 + specialized_types = PyrexTypes.get_specialized_types(arg.type)
  2689 + # Prefer long over int, etc
  2690 + # specialized_types.sort()
  2691 + seen_py_type_names = set()
  2692 + normal_types, buffer_types = [], []
  2693 + for specialized_type in specialized_types:
  2694 + py_type_name = specialized_type.py_type_name()
  2695 + if py_type_name:
  2696 + if py_type_name in seen_py_type_names:
  2697 + continue
  2698 + seen_py_type_names.add(py_type_name)
  2699 + normal_types.append(specialized_type)
  2700 + elif specialized_type.is_buffer or specialized_type.is_memoryviewslice:
  2701 + buffer_types.append(specialized_type)
  2702 +
  2703 + return normal_types, buffer_types
  2704 +
  2705 + def _unpack_argument(self, pyx_code):
  2706 + pyx_code.put_chunk(
  2707 + u"""
  2708 + # PROCESSING ARGUMENT {{arg_tuple_idx}}
  2709 + if {{arg_tuple_idx}} < len(args):
  2710 + arg = args[{{arg_tuple_idx}}]
  2711 + elif '{{arg.name}}' in kwargs:
  2712 + arg = kwargs['{{arg.name}}']
  2713 + else:
  2714 + {{if arg.default:}}
  2715 + arg = defaults[{{default_idx}}]
  2716 + {{else}}
  2717 + raise TypeError("Expected at least %d arguments" % len(args))
  2718 + {{endif}}
  2719 + """)
2502 2720
2503   - if not py_type_name or py_type_name in seen_py_type_names:
2504   - continue
  2721 + def make_fused_cpdef(self, orig_py_func, env, is_def):
  2722 + """
  2723 + This creates the function that is indexable from Python and does
  2724 + runtime dispatch based on the argument types. The function gets the
  2725 + arg tuple and kwargs dict (or None) and the defaults tuple
  2726 + as arguments from the Binding Fused Function's tp_call.
  2727 + """
  2728 + from Cython.Compiler import TreeFragment, Code, MemoryView, UtilityCode
2505 2729
2506   - seen_py_type_names.add(py_type_name)
  2730 + # { (arg_pos, FusedType) : specialized_type }
  2731 + seen_fused_types = set()
2507 2732
2508   - all_numeric = all_numeric and specialized_type.is_numeric
  2733 + context = {
  2734 + 'memviewslice_cname': MemoryView.memviewslice_cname,
  2735 + 'func_args': self.node.args,
  2736 + 'n_fused': len([arg for arg in self.node.args]),
  2737 + 'name': orig_py_func.entry.name,
  2738 + }
2509 2739
2510   - if first_check:
2511   - if_ = 'if'
2512   - first_check = False
  2740 + pyx_code = Code.PyxCodeWriter(context=context)
  2741 + decl_code = Code.PyxCodeWriter(context=context)
  2742 + decl_code.put_chunk(
  2743 + u"""
  2744 + cdef extern from *:
  2745 + void PyErr_Clear()
  2746 + """)
  2747 + decl_code.indent()
  2748 +
  2749 + pyx_code.put_chunk(
  2750 + u"""
  2751 + def __pyx_fused_cpdef(signatures, args, kwargs, defaults):
  2752 + import sys
  2753 + if sys.version_info >= (3, 0):
  2754 + long_ = int
  2755 + unicode_ = str
  2756 + bytes_ = bytes
2513 2757 else:
2514   - if_ = 'elif'
2515   -
2516   - # in the case of long, unicode or bytes we need to instance
2517   - # check for long_, unicode_, bytes_ (long = long is no longer
2518   - # valid code with control flow analysis)
2519   - instance_check_py_type_name = py_type_name
2520   - if py_type_name in ('long', 'unicode', 'bytes'):
2521   - instance_check_py_type_name += '_'
2522   -
2523   - tup = (if_, instance_check_py_type_name,
2524   - len(seen_fused_types) - 1,
2525   - specialized_type.typeof_name())
2526   - body_stmts.append(
2527   - " %s isinstance(arg, %s): "
2528   - "dest_sig[%d] = '%s'" % tup)
2529   -
2530   - if arg.default and all_numeric:
2531   - arg.default.analyse_types(env)
  2758 + long_ = long
  2759 + unicode_ = unicode
  2760 + bytes_ = str
2532 2761
2533   - ts = specialized_types
2534   - if arg.default.type.is_complex:
2535   - typelist = [t for t in ts if t.is_complex]
2536   - elif arg.default.type.is_float:
2537   - typelist = [t for t in ts if t.is_float]
2538   - else:
2539   - typelist = [t for t in ts if t.is_int]
  2762 + dest_sig = [None] * {{n_fused}}
2540 2763
2541   - if typelist:
2542   - body_stmts.append(u"""\
2543   - else:
2544   - dest_sig[%d] = '%s'
2545   -""" % (i, typelist[0].typeof_name()))
  2764 + if kwargs is None:
  2765 + kwargs = {}
2546 2766
2547   - fmt_dict = {
2548   - 'body': '\n'.join(body_stmts),
2549   - 'nargs': len(args),
2550   - 'name': orig_py_func.entry.name,
2551   - }
  2767 + cdef Py_ssize_t i
2552 2768
2553   - fragment_code = u"""
2554   -def __pyx_fused_cpdef(signatures, args, kwargs):
2555   - #if len(args) < %(nargs)d:
2556   - # raise TypeError("Invalid number of arguments, expected %(nargs)d, "
2557   - # "got %%d" %% len(args))
2558   - cdef int nargs
2559   - nargs = len(args)
2560   -
2561   - import sys
2562   - if sys.version_info >= (3, 0):
2563   - long_ = int
2564   - unicode_ = str
2565   - bytes_ = bytes
2566   - else:
2567   - long_ = long
2568   - unicode_ = unicode
2569   - bytes_ = str
  2769 + # instance check body
  2770 + """)
  2771 + pyx_code.indent() # indent following code to function body
  2772 + pyx_code.named_insertion_point("imports")
  2773 + pyx_code.named_insertion_point("local_variable_declarations")
2570 2774
2571   - dest_sig = [None] * %(nargs)d
  2775 + fused_index = 0
  2776 + default_idx = 0
  2777 + all_buffer_types = set()
  2778 + for i, arg in enumerate(self.node.args):
  2779 + if arg.type.is_fused and arg.type not in seen_fused_types:
  2780 + seen_fused_types.add(arg.type)
2572 2781
2573   - if kwargs is None:
2574   - kwargs = {}
  2782 + context.update(
  2783 + arg_tuple_idx=i,
  2784 + arg=arg,
  2785 + dest_sig_idx=fused_index,
  2786 + default_idx=default_idx,
  2787 + )
2575 2788
2576   - # instance check body
2577   -%(body)s
2578 2789
2579   - candidates = []
2580   - for sig in signatures:
2581   - match_found = [x for x in dest_sig if x]
2582   - for src_type, dst_type in zip(sig.strip('()').split(', '), dest_sig):
2583   - if dst_type is not None and match_found:
2584   - match_found = src_type == dst_type
  2790 + normal_types, buffer_types = self._split_fused_types(arg)
  2791 + self._unpack_argument(pyx_code)
  2792 + self._fused_instance_checks(normal_types, pyx_code, env)
  2793 + self._buffer_checks(buffer_types, pyx_code, decl_code, env)
  2794 + fused_index += 1
2585 2795
2586   - if match_found:
2587   - candidates.append(sig)
  2796 + all_buffer_types.update(buffer_types)
2588 2797
2589   - if not candidates:
2590   - raise TypeError("No matching signature found")
2591   - elif len(candidates) > 1:
2592   - raise TypeError("Function call with ambiguous argument types")
2593   - else:
2594   - return signatures[candidates[0]]
2595   -""" % fmt_dict
  2798 + if arg.default:
  2799 + default_idx += 1
  2800 +
  2801 + if all_buffer_types:
  2802 + self._buffer_declarations(pyx_code, decl_code, all_buffer_types)
  2803 +
  2804 + pyx_code.put_chunk(
  2805 + u"""
  2806 + candidates = []
  2807 + for sig in signatures:
  2808 + match_found = True
  2809 + for src_type, dst_type in zip(sig.strip('()').split(', '), dest_sig):
  2810 + if dst_type is not None and match_found:
  2811 + match_found = src_type == dst_type
  2812 +
  2813 + if match_found:
  2814 + candidates.append(sig)
  2815 +
  2816 + if not candidates:
  2817 + raise TypeError("No matching signature found")
  2818 + elif len(candidates) > 1:
  2819 + raise TypeError("Function call with ambiguous argument types")
  2820 + else:
  2821 + return signatures[candidates[0]]
  2822 + """)
  2823 +
  2824 + fragment_code = pyx_code.getvalue()
  2825 + # print decl_code.getvalue()
  2826 + # print fragment_code
  2827 + fragment = TreeFragment.TreeFragment(fragment_code.decode('ascii'),
  2828 + level='module')
  2829 + ast = TreeFragment.SetPosTransform(self.node.pos)(fragment.root)
  2830 + UtilityCode.declare_declarations_in_scope(decl_code.getvalue(), env)
  2831 + ast.scope = env
  2832 + ast.analyse_declarations(env)
  2833 + py_func = ast.stats[-1] # the DefNode
  2834 + self.fragment_scope = ast.scope
  2835 +
  2836 + if isinstance(self.node, DefNode):
  2837 + py_func.specialized_cpdefs = self.nodes[:]
  2838 + else:
  2839 + py_func.specialized_cpdefs = [n.py_func for n in self.nodes]
2596 2840
2597   - fragment = TreeFragment.TreeFragment(fragment_code, level='module')
  2841 + return py_func
2598 2842
2599   - # analyse the declarations of our fragment ...
2600   - py_func, = fragment.substitute(pos=self.node.pos).stats
2601   - # Analyse the function object ...
2602   - py_func.analyse_declarations(env)
2603   - # ... and its body
2604   - py_func.scope = env
  2843 + def update_fused_defnode_entry(self, env):
  2844 + import ExprNodes
2605 2845
2606   - # Will be analysed later by underlying AnalyseDeclarationsTransform
2607   - #ParseTreeTransforms.AnalyseDeclarationsTransform(None)(py_func)
  2846 + copy_attributes = (
  2847 + 'name', 'pos', 'cname', 'func_cname', 'pyfunc_cname',
  2848 + 'pymethdef_cname', 'doc', 'doc_cname', 'is_member',
  2849 + 'scope'
  2850 + )
2608 2851
2609   - e, orig_e = py_func.entry, orig_py_func.entry
  2852 + entry = self.py_func.entry
2610 2853
2611   - # Update the new entry ...
2612   - py_func.name = e.name = orig_e.name
2613   - e.cname, e.func_cname = orig_e.cname, orig_e.func_cname
2614   - e.pymethdef_cname = orig_e.pymethdef_cname
2615   - e.doc, e.doc_cname = orig_e.doc, orig_e.doc_cname
2616   - # e.signature = TypeSlots.binaryfunc
  2854 + for attr in copy_attributes:
  2855 + setattr(entry, attr,
  2856 + getattr(self.orig_py_func.entry, attr))
2617 2857
2618   - py_func.doc = orig_py_func.doc
  2858 + self.py_func.name = self.orig_py_func.name
  2859 + self.py_func.doc = self.orig_py_func.doc
2619 2860
2620   - # ... and the symbol table
2621 2861 env.entries.pop('__pyx_fused_cpdef', None)
2622   - if is_def:
2623   - env.entries[e.name] = e
  2862 + if isinstance(self.node, DefNode):
  2863 + env.entries[entry.name] = entry
2624 2864 else:
2625   - env.entries[e.name].as_variable = e
  2865 + env.entries[entry.name].as_variable = entry
2626 2866
2627   - env.pyfunc_entries.append(e)
  2867 + env.pyfunc_entries.append(entry)
2628 2868
2629   - if is_def:
2630   - py_func.specialized_cpdefs = self.nodes[:]
2631   - else:
2632   - py_func.specialized_cpdefs = [n.py_func for n in self.nodes]
  2869 + self.py_func.entry.fused_cfunction = self
  2870 + for node in self.nodes:
  2871 + if isinstance(self.node, DefNode):
  2872 + node.fused_py_func = self.py_func
  2873 + else:
  2874 + node.py_func.fused_py_func = self.py_func
  2875 + node.entry.as_variable = entry
2633 2876
2634   - return py_func
  2877 + self.synthesize_defnodes()
  2878 + self.stats.append(self.__signatures__)
  2879 +
  2880 + env.use_utility_code(ExprNodes.import_utility_code)
2635 2881
2636 2882 def analyse_expressions(self, env):
2637 2883 """
1  Cython/Compiler/ParseTreeTransforms.py
@@ -1498,6 +1498,7 @@ def visit_FuncDefNode(self, node):
1498 1498 # Create PyCFunction nodes for each specialization
1499 1499 node.stats.insert(0, node.py_func)
1500 1500 node.py_func = self.visit(node.py_func)
  1501 + node.update_fused_defnode_entry(env)
1501 1502 pycfunc = ExprNodes.PyCFunctionNode.from_defnode(node.py_func,
1502 1503 True)
1503 1504 pycfunc = ExprNodes.ProxyNode(pycfunc.coerce_to_temp(env))
6 Cython/Compiler/Symtab.py
@@ -362,11 +362,11 @@ def builtin_scope(self):
362 362 # Return the module-level scope containing this scope.
363 363 return self.outer_scope.builtin_scope()
364 364
365   - def declare(self, name, cname, type, pos, visibility, shadow = 0):
  365 + def declare(self, name, cname, type, pos, visibility, shadow = 0, is_type = 0):
366 366 # Create new entry, and add to dictionary if
367 367 # name is not None. Reports a warning if already
368 368 # declared.
369   - if type.is_buffer and not isinstance(self, LocalScope):
  369 + if type.is_buffer and not isinstance(self, LocalScope) and not is_type:
370 370 error(pos, 'Buffer types only allowed as function local variables')
371 371 if not self.in_cinclude and cname and re.match("^_[_A-Z]+$", cname):
372 372 # See http://www.gnu.org/software/libc/manual/html_node/Reserved-Names.html#Reserved-Names
@@ -417,7 +417,7 @@ def declare_type(self, name, type, pos,
417 417 # Add an entry for a type definition.
418 418 if not cname:
419 419 cname = name
420   - entry = self.declare(name, cname, type, pos, visibility, shadow)
  420 + entry = self.declare(name, cname, type, pos, visibility, shadow, True)
421 421 entry.is_type = 1
422 422 entry.api = api
423 423 if defining:
10 Cython/Compiler/TreeFragment.py
@@ -231,6 +231,12 @@ def substitute(self, nodes={}, temps=[], pos = None):
231 231 substitutions = nodes,
232 232 temps = self.temps + temps, pos = pos)
233 233
  234 +class SetPosTransform(VisitorTransform):
  235 + def __init__(self, pos):
  236 + super(SetPosTransform, self).__init__()
  237 + self.pos = pos
234 238
235   -
236   -
  239 + def visit_Node(self, node):
  240 + node.pos = self.pos
  241 + self.visitchildren(node)
  242 + return node
8 Cython/Compiler/UtilityCode.py
@@ -167,3 +167,11 @@ def declare_in_scope(self, dest_scope, used=False, cython_scope=None,
167 167 dep.declare_in_scope(dest_scope)
168 168
169 169 return original_scope
  170 +
  171 +def declare_declarations_in_scope(declaration_string, env, private_type=True,
  172 + *args, **kwargs):
  173 + """
  174 + Declare some declarations given as Cython code in declaration_string
  175 + in scope env.
  176 + """
  177 + CythonUtilityCode(declaration_string, *args, **kwargs).declare_in_scope(env)
7 Cython/Utility/CythonFunction.c
@@ -740,8 +740,6 @@ __pyx_FusedFunction_callfunction(PyObject *func, PyObject *args, PyObject *kw)
740 740 int static_specialized = (cyfunc->flags & __Pyx_CYFUNCTION_STATICMETHOD &&
741 741 !((__pyx_FusedFunctionObject *) func)->__signatures__);
742 742
743   - //PyObject_Print(args, stdout, Py_PRINT_RAW);
744   -
745 743 if (cyfunc->flags & __Pyx_CYFUNCTION_CCLASS && !static_specialized) {
746 744 Py_ssize_t argc;
747 745 PyObject *new_args;
@@ -827,8 +825,9 @@ __pyx_FusedFunction_call(PyObject *func, PyObject *args, PyObject *kw)
827 825 }
828 826
829 827 if (binding_func->__signatures__) {
830   - PyObject *tup = PyTuple_Pack(3, binding_func->__signatures__, args,
831   - kw == NULL ? Py_None : kw);
  828 + PyObject *tup = PyTuple_Pack(4, binding_func->__signatures__, args,
  829 + kw == NULL ? Py_None : kw,
  830 + binding_func->func.defaults_tuple);
832 831 if (!tup)
833 832 goto __pyx_err;
834 833
246 tests/run/numpy_test.pyx
@@ -4,6 +4,8 @@
4 4 cimport numpy as np
5 5 cimport cython
6 6
  7 +from libc.stdlib cimport malloc
  8 +
7 9 def little_endian():
8 10 cdef int endian_detector = 1
9 11 return (<char*>&endian_detector)[0] != 0
@@ -503,19 +505,28 @@ def test_point_record():
503 505 test[i].y = -i
504 506 print repr(test).replace('<', '!').replace('>', '!')
505 507
506   -def test_fused_ndarray_dtype(np.ndarray[cython.floating, ndim=1] a):
  508 +# Test fused np.ndarray dtypes and runtime dispatch
  509 +def test_fused_ndarray_floating_dtype(np.ndarray[cython.floating, ndim=1] a):
507 510 """
508 511 >>> import cython
509   - >>> sorted(test_fused_ndarray_dtype.__signatures__)
  512 + >>> sorted(test_fused_ndarray_floating_dtype.__signatures__)
510 513 ['double', 'float']
511   - >>> test_fused_ndarray_dtype[cython.double](np.arange(10, dtype=np.float64))
  514 +
  515 +
  516 + >>> test_fused_ndarray_floating_dtype[cython.double](np.arange(10, dtype=np.float64))
512 517 ndarray[double,ndim=1] ndarray[double,ndim=1] 5.0 6.0
513   - >>> test_fused_ndarray_dtype[cython.float](np.arange(10, dtype=np.float32))
  518 + >>> test_fused_ndarray_floating_dtype(np.arange(10, dtype=np.float64))
  519 + ndarray[double,ndim=1] ndarray[double,ndim=1] 5.0 6.0
  520 +
  521 + >>> test_fused_ndarray_floating_dtype[cython.float](np.arange(10, dtype=np.float32))
  522 + ndarray[float,ndim=1] ndarray[float,ndim=1] 5.0 6.0
  523 + >>> test_fused_ndarray_floating_dtype(np.arange(10, dtype=np.float32))
514 524 ndarray[float,ndim=1] ndarray[float,ndim=1] 5.0 6.0
515 525 """
516 526 cdef np.ndarray[cython.floating, ndim=1] b = a
517 527 print cython.typeof(a), cython.typeof(b), a[5], b[6]
518 528
  529 +
519 530 double_array = np.linspace(0, 1, 100)
520 531 int32_array = np.arange(100, dtype=np.int32)
521 532
@@ -568,4 +579,231 @@ def test_fused_cpdef_buffers():
568 579 cdef np.ndarray[np.int32_t] typed_array = int32_array
569 580 _fused_cpdef_buffers(typed_array)
570 581
  582 +def test_fused_ndarray_integral_dtype(np.ndarray[cython.integral, ndim=1] a):
  583 + """
  584 + >>> import cython
  585 + >>> sorted(test_fused_ndarray_integral_dtype.__signatures__)
  586 + ['int', 'long', 'short']
  587 +
  588 + >>> test_fused_ndarray_integral_dtype[cython.int](np.arange(10, dtype=np.dtype('i')))
  589 + ndarray[int,ndim=1] ndarray[int,ndim=1] 5 6
  590 + >>> test_fused_ndarray_integral_dtype(np.arange(10, dtype=np.dtype('i')))
  591 + ndarray[int,ndim=1] ndarray[int,ndim=1] 5 6
  592 +
  593 + >>> test_fused_ndarray_integral_dtype[cython.long](np.arange(10, dtype=np.long))
  594 + ndarray[long,ndim=1] ndarray[long,ndim=1] 5 6
  595 + >>> test_fused_ndarray_integral_dtype(np.arange(10, dtype=np.long))
  596 + ndarray[long,ndim=1] ndarray[long,ndim=1] 5 6
  597 + """
  598 + cdef np.ndarray[cython.integral, ndim=1] b = a
  599 + print cython.typeof(a), cython.typeof(b), a[5], b[6]
  600 +
  601 +cdef fused fused_dtype:
  602 + float complex
  603 + double complex
  604 + object
  605 +
  606 +def test_fused_ndarray_other_dtypes(np.ndarray[fused_dtype, ndim=1] a):
  607 + """
  608 + >>> import cython
  609 + >>> sorted(test_fused_ndarray_other_dtypes.__signatures__)
  610 + ['double complex', 'float complex', 'object']
  611 + >>> test_fused_ndarray_other_dtypes(np.arange(10, dtype=np.complex64))
  612 + ndarray[float complex,ndim=1] ndarray[float complex,ndim=1] (5+0j) (6+0j)
  613 + >>> test_fused_ndarray_other_dtypes(np.arange(10, dtype=np.complex128))
  614 + ndarray[double complex,ndim=1] ndarray[double complex,ndim=1] (5+0j) (6+0j)
  615 + >>> test_fused_ndarray_other_dtypes(np.arange(10, dtype=np.object))
  616 + ndarray[Python object,ndim=1] ndarray[Python object,ndim=1] 5 6
  617 + """
  618 + cdef np.ndarray[fused_dtype, ndim=1] b = a
  619 + print cython.typeof(a), cython.typeof(b), a[5], b[6]
  620 +
  621 +
  622 +# Test fusing the array types together and runtime dispatch
  623 +cdef struct Foo:
  624 + int a
  625 + float b
  626 +
  627 +cdef fused fused_FooArray:
  628 + np.ndarray[Foo, ndim=1]
  629 +
  630 +cdef fused fused_ndarray:
  631 + np.ndarray[float, ndim=1]
  632 + np.ndarray[double, ndim=1]
  633 + np.ndarray[Foo, ndim=1]
  634 +
  635 +def get_Foo_array():
  636 + cdef Foo[:] result = <Foo[:10]> malloc(sizeof(Foo) * 10)
  637 + result[5].b = 9.0
  638 + return np.asarray(result)
  639 +
  640 +def test_fused_ndarray(fused_ndarray a):
  641 + """
  642 + >>> import cython
  643 + >>> sorted(test_fused_ndarray.__signatures__)
  644 + ['ndarray[Foo,ndim=1]', 'ndarray[double,ndim=1]', 'ndarray[float,ndim=1]']
  645 +
  646 + >>> test_fused_ndarray(get_Foo_array())
  647 + ndarray[Foo,ndim=1] ndarray[Foo,ndim=1]
  648 + 9.0
  649 + >>> test_fused_ndarray(np.arange(10, dtype=np.float64))
  650 + ndarray[double,ndim=1] ndarray[double,ndim=1]
  651 + 5.0
  652 + >>> test_fused_ndarray(np.arange(10, dtype=np.float32))
  653 + ndarray[float,ndim=1] ndarray[float,ndim=1]
  654 + 5.0
  655 + """
  656 + cdef fused_ndarray b = a
  657 + print cython.typeof(a), cython.typeof(b)
  658 +
  659 + if fused_ndarray in fused_FooArray:
  660 + print b[5].b
  661 + else:
  662 + print b[5]
  663 +
  664 +cpdef test_fused_cpdef_ndarray(fused_ndarray a):
  665 + """
  666 + >>> import cython
  667 + >>> sorted(test_fused_cpdef_ndarray.__signatures__)
  668 + ['ndarray[Foo,ndim=1]', 'ndarray[double,ndim=1]', 'ndarray[float,ndim=1]']
  669 +
  670 + >>> test_fused_cpdef_ndarray(get_Foo_array())
  671 + ndarray[Foo,ndim=1] ndarray[Foo,ndim=1]
  672 + 9.0
  673 + >>> test_fused_cpdef_ndarray(np.arange(10, dtype=np.float64))
  674 + ndarray[double,ndim=1] ndarray[double,ndim=1]
  675 + 5.0
  676 + >>> test_fused_cpdef_ndarray(np.arange(10, dtype=np.float32))
  677 + ndarray[float,ndim=1] ndarray[float,ndim=1]
  678 + 5.0
  679 + """
  680 + cdef fused_ndarray b = a
  681 + print cython.typeof(a), cython.typeof(b)
  682 +
  683 + if fused_ndarray in fused_FooArray:
  684 + print b[5].b
  685 + else:
  686 + print b[5]
  687 +
  688 +def test_fused_cpdef_ndarray_cdef_call():
  689 + """
  690 + >>> test_fused_cpdef_ndarray_cdef_call()
  691 + ndarray[Foo,ndim=1] ndarray[Foo,ndim=1]
  692 + 9.0
  693 + """
  694 + cdef np.ndarray[Foo, ndim=1] foo_array = get_Foo_array()
  695 + test_fused_cpdef_ndarray(foo_array)
  696 +
  697 +cdef fused int_type:
  698 + np.int32_t
  699 + np.int64_t
  700 +
  701 +float64_array = np.arange(10, dtype=np.float64)
  702 +float32_array = np.arange(10, dtype=np.float32)
  703 +int32_array = np.arange(10, dtype=np.int32)
  704 +int64_array = np.arange(10, dtype=np.int64)
  705 +
  706 +def test_dispatch_non_clashing_declarations_repeating_types(np.ndarray[cython.floating] a1,
  707 + np.ndarray[int_type] a2,
  708 + np.ndarray[cython.floating] a3,
  709 + np.ndarray[int_type] a4):
  710 + """
  711 + >>> test_dispatch_non_clashing_declarations_repeating_types(float64_array, int32_array, float64_array, int32_array)
  712 + 1.0 2 3.0 4
  713 + >>> test_dispatch_non_clashing_declarations_repeating_types(float64_array, int64_array, float64_array, int64_array)
  714 + 1.0 2 3.0 4
  715 + >>> test_dispatch_non_clashing_declarations_repeating_types(float64_array, int32_array, float64_array, int64_array)
  716 + Traceback (most recent call last):
  717 + ...
  718 + TypeError: No matching signature found
  719 + """
  720 + print a1[1], a2[2], a3[3], a4[4]
  721 +
  722 +ctypedef np.int32_t typedeffed_type
  723 +
  724 +cdef fused typedeffed_fused_type:
  725 + typedeffed_type
  726 + int
  727 + long
  728 +
  729 +def test_dispatch_typedef(np.ndarray[typedeffed_fused_type] a):
  730 + """
  731 + >>> test_dispatch_typedef(int32_array)
  732 + 5
  733 + """
  734 + print a[5]
  735 +
  736 +
  737 +cdef extern from "types.h":
  738 + ctypedef unsigned char actually_long_t
  739 +
  740 +cdef fused confusing_fused_typedef:
  741 + actually_long_t
  742 + unsigned char
  743 + signed char
  744 +
  745 +def test_dispatch_external_typedef(np.ndarray[confusing_fused_typedef] a):
  746 + """
  747 + >>> test_dispatch_external_typedef(np.arange(10, dtype=np.long))
  748 + 5
  749 + """
  750 + print a[5]
  751 +
  752 +# test fused memoryview slices
  753 +cdef fused memslice_fused_dtype:
  754 + float
  755 + double
  756 + int
  757 + long
  758 + float complex
  759 + double complex
  760 + object
  761 +
  762 +def test_fused_memslice_other_dtypes(memslice_fused_dtype[:] a):
  763 + """
  764 + >>> import cython
  765 + >>> sorted(test_fused_memslice_other_dtypes.__signatures__)
  766 + ['double', 'double complex', 'float', 'float complex', 'int', 'long', 'object']
  767 + >>> test_fused_memslice_other_dtypes(np.arange(10, dtype=np.complex64))
  768 + float complex[:] float complex[:] (5+0j) (6+0j)
  769 + >>> test_fused_memslice_other_dtypes(np.arange(10, dtype=np.complex128))
  770 + double complex[:] double complex[:] (5+0j) (6+0j)
  771 + >>> test_fused_memslice_other_dtypes(np.arange(10, dtype=np.float32))
  772 + float[:] float[:] 5.0 6.0
  773 + >>> test_fused_memslice_other_dtypes(np.arange(10, dtype=np.dtype('i')))
  774 + int[:] int[:] 5 6
  775 + >>> test_fused_memslice_other_dtypes(np.arange(10, dtype=np.object))
  776 + object[:] object[:] 5 6
  777 + """
  778 + cdef memslice_fused_dtype[:] b = a
  779 + print cython.typeof(a), cython.typeof(b), a[5], b[6]
  780 +
  781 +cdef fused memslice_fused:
  782 + float[:]
  783 + double[:]
  784 + int[:]
  785 + long[:]
  786 + float complex[:]
  787 + double complex[:]
  788 + object[:]
  789 +
  790 +def test_fused_memslice_fused(memslice_fused a):
  791 + """
  792 + >>> import cython
  793 + >>> sorted(test_fused_memslice_fused.__signatures__)
  794 + ['double complex[:]', 'double[:]', 'float complex[:]', 'float[:]', 'int[:]', 'long[:]', 'object[:]']
  795 + >>> test_fused_memslice_fused(np.arange(10, dtype=np.complex64))
  796 + float complex[:] float complex[:] (5+0j) (6+0j)
  797 + >>> test_fused_memslice_fused(np.arange(10, dtype=np.complex128))
  798 + double complex[:] double complex[:] (5+0j) (6+0j)
  799 + >>> test_fused_memslice_fused(np.arange(10, dtype=np.float32))
  800 + float[:] float[:] 5.0 6.0
  801 + >>> test_fused_memslice_fused(np.arange(10, dtype=np.dtype('i')))
  802 + int[:] int[:] 5 6
  803 + >>> test_fused_memslice_fused(np.arange(10, dtype=np.object))
  804 + object[:] object[:] 5 6
  805 + """
  806 + cdef memslice_fused b = a
  807 + print cython.typeof(a), cython.typeof(b), a[5], b[6]
  808 +
571 809 include "numpy_common.pxi"

0 comments on commit 39c966d

Please sign in to comment.
Something went wrong with that request. Please try again.