Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Better fused buffer runtime dispatch + dispatch restructuring + PyxCo…

…deWriter
  • Loading branch information...
commit 39c966d2e2a7ec860f1c57e509763e85bc6506c4 1 parent a3230e4
@markflorisson markflorisson authored
View
60 Cython/Compiler/Code.py
@@ -14,6 +14,7 @@
import sys
from string import Template
import operator
+import textwrap
import Naming
import Options
@@ -376,7 +377,7 @@ def put_code(self, output):
self.cleanup(writer, output.module_pos)
-def sub_tempita(s, context, file, name):
+def sub_tempita(s, context, file=None, name=None):
"Run tempita on string s with given context."
if not s:
return None
@@ -1940,6 +1941,63 @@ def indent(self):
def dedent(self):
self.level -= 1
+class PyxCodeWriter(object):
+ """
+ Can be used for writing out some Cython code.
+ """
+
+ def __init__(self, buffer=None, indent_level=0, context=None):
+ self.buffer = buffer or StringIOTree()
+ self.level = indent_level
+ self.context = context
+ self.encoding = 'ascii'
+
+ def indent(self, levels=1):
+ self.level += levels
+
+ def dedent(self, levels=1):
+ self.level -= levels
+
+ def indenter(self, line):
+ """
+ with pyx_code.indenter("for i in range(10):"):
+ pyx_code.putln("print i")
+ """
+ self.putln(line)
+ return self
+
+ def getvalue(self):
+ return unicode(self.buffer.getvalue(), self.encoding)
+
+ def putln(self, line, context=None):
+ context = context or self.context
+ if context:
+ line = sub_tempita(line, context)
+ self._putln(line)
+
+ def _putln(self, line):
+ self.buffer.write("%s%s\n" % (self.level * " ", line))
+
+ def put_chunk(self, chunk, context=None):
+ context = context or self.context
+ if context:
+ chunk = sub_tempita(chunk, context)
+
+ chunk = textwrap.dedent(chunk)
+ for line in chunk.splitlines():
+ self._putln(line)
+
+ def insertion_point(self):
+ return PyxCodeWriter(self.buffer.insertion_point(), self.level,
+ self.context)
+
+ def named_insertion_point(self, name):
+ setattr(self, name, self.insertion_point())
+
+ __enter__ = indent
+
+ def __exit__(self, exc_value, exc_type, exc_tb):
+ self.dedent()
class ClosureTempAllocator(object):
def __init__(self, klass):
View
546 Cython/Compiler/Nodes.py
@@ -2277,24 +2277,11 @@ def __init__(self, node, env):
assert n.type.op_arg_struct
node.entry.fused_cfunction = self
-
- if self.py_func:
- self.py_func.entry.fused_cfunction = self
- for node in self.nodes:
- if is_def:
- node.fused_py_func = self.py_func
- else:
- node.py_func.fused_py_func = self.py_func
- node.entry.as_variable = self.py_func.entry
# Copy the nodes as AnalyseDeclarationsTransform will prepend
# self.py_func to self.stats, as we only want specialized
# CFuncDefNodes in self.nodes
self.stats = self.nodes[:]
- if self.py_func:
- self.synthesize_defnodes()
- self.stats.append(self.__signatures__)
-
def copy_def(self, env):
"""
Create a copy of the original def or lambda function for specialized
@@ -2326,6 +2313,7 @@ def copy_def(self, env):
if not self.replace_fused_typechecks(copied_node):
break
+ self.orig_py_func = self.node
self.py_func = self.make_fused_cpdef(self.node, env, is_def=True)
def copy_cdef(self, env):
@@ -2342,7 +2330,7 @@ def copy_cdef(self, env):
env.cfunc_entries.remove(self.node.entry)
# Prevent copying of the python function
- orig_py_func = self.node.py_func
+ self.orig_py_func = orig_py_func = self.node.py_func
self.node.py_func = None
if orig_py_func:
env.pyfunc_entries.remove(orig_py_func.entry)
@@ -2459,179 +2447,437 @@ def replace_fused_typechecks(self, copied_node):
return True
- def make_fused_cpdef(self, orig_py_func, env, is_def):
+ def _fused_instance_checks(self, normal_types, pyx_code, env):
"""
- This creates the function that is indexable from Python and does
- runtime dispatch based on the argument types. The function gets the
- arg tuple and kwargs dict (or None) as arugments from the Binding
- Fused Function's tp_call.
+ Genereate Cython code for instance checks, matching an object to
+ specialized types.
"""
- from Cython.Compiler import TreeFragment
- from Cython.Compiler import ParseTreeTransforms
+ if_ = 'if'
+ for specialized_type in normal_types:
+ # all_numeric = all_numeric and specialized_type.is_numeric
+ py_type_name = specialized_type.py_type_name()
+
+ # in the case of long, unicode or bytes we need to instance
+ # check for long_, unicode_, bytes_ (long = long is no longer
+ # valid code with control flow analysis)
+ specialized_check_name = py_type_name
+ if py_type_name in ('long', 'unicode', 'bytes'):
+ specialized_check_name += '_'
+
+ specialized_type_name = specialized_type.specialization_string
+ pyx_code.context.update(locals())
+ pyx_code.put_chunk(
+ u"""
+ {{if_}} isinstance(arg, {{specialized_check_name}}):
+ dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'
+ """)
+ if_ = 'elif'
+
+ if not normal_types:
+ # we need an 'if' to match the following 'else'
+ pyx_code.putln("if 0: pass")
+
+ def _dtype_name(self, dtype):
+ if dtype.is_typedef:
+ return '___pyx_%s' % dtype
+ return str(dtype).replace(' ', '_')
+
+ def _dtype_type(self, dtype):
+ if dtype.is_typedef:
+ return self._dtype_name(dtype)
+ return str(dtype)
+
+ def _sizeof_dtype(self, dtype):
+ if dtype.is_pyobject:
+ return 'sizeof(void *)'
+ else:
+ return "sizeof(%s)" % self._dtype_type(dtype)
- # { (arg_pos, FusedType) : specialized_type }
- seen_fused_types = set()
+ def _buffer_check_numpy_dtype_setup_cases(self, pyx_code):
+ "Setup some common cases to match dtypes against specializations"
+ with pyx_code.indenter("if dtype.kind in ('i', 'u'):"):
+ pyx_code.putln("pass")
+ pyx_code.named_insertion_point("dtype_int")
- # list of statements that do the instance checks
- body_stmts = []
+ with pyx_code.indenter("elif dtype.kind == 'f':"):
+ pyx_code.putln("pass")
+ pyx_code.named_insertion_point("dtype_float")
- args = self.node.args
- for i, arg in enumerate(args):
- arg_type = arg.type
- if arg_type.is_fused and arg_type not in seen_fused_types:
- seen_fused_types.add(arg_type)
+ with pyx_code.indenter("elif dtype.kind == 'c':"):
+ pyx_code.putln("pass")
+ pyx_code.named_insertion_point("dtype_complex")
- specialized_types = PyrexTypes.get_specialized_types(arg_type)
- # Prefer long over int, etc
- # specialized_types.sort()
+ with pyx_code.indenter("elif dtype.kind == 'O':"):
+ pyx_code.putln("pass")
+ pyx_code.named_insertion_point("dtype_object")
- seen_py_type_names = set()
- first_check = True
+ match = "dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'"
+ no_match = "dest_sig[{{dest_sig_idx}}] = None"
+ def _buffer_check_numpy_dtype(self, pyx_code, specialized_buffer_types):
+ """
+ Match a numpy dtype object to the individual specializations.
+ """
+ self._buffer_check_numpy_dtype_setup_cases(pyx_code)
+
+ for specialized_type in specialized_buffer_types:
+ dtype = specialized_type.dtype
+ pyx_code.context.update(
+ itemsize_match=self._sizeof_dtype(dtype) + " == itemsize",
+ signed_match="not (%s_is_signed ^ dtype_signed)" % self._dtype_name(dtype),
+ dtype=dtype,
+ specialized_type_name=specialized_type.specialization_string)
+
+ dtypes = [
+ (dtype.is_int, pyx_code.dtype_int),
+ (dtype.is_float, pyx_code.dtype_float),
+ (dtype.is_complex, pyx_code.dtype_complex)
+ ]
+
+ for dtype_category, codewriter in dtypes:
+ if dtype_category:
+ cond = '{{itemsize_match}}'
+ if dtype.is_int:
+ cond += ' and {{signed_match}}'
+
+ with codewriter.indenter("if %s:" % cond):
+ # codewriter.putln("print 'buffer match found based on numpy dtype'")
+ codewriter.putln(self.match)
+ codewriter.putln("break")
+
+ def _buffer_parse_format_string_check(self, pyx_code, decl_code,
+ specialized_type, env):
+ """
+ For each specialized type, try to coerce the object to a memoryview
+ slice of that type. This means obtaining a buffer and parsing the
+ format string.
+ TODO: separate buffer acquisition from format parsing
+ """
+ dtype = specialized_type.dtype
+ if specialized_type.is_buffer:
+ axes = [('direct', 'strided')] * specialized_type.ndim
+ else:
+ axes = specialized_type.axes
+
+ memslice_type = PyrexTypes.MemoryViewSliceType(dtype, axes)
+ memslice_type.create_from_py_utility_code(env)
+ pyx_code.context.update(
+ coerce_from_py_func=memslice_type.from_py_function,
+ dtype=dtype)
+ decl_code.putln(
+ "{{memviewslice_cname}} {{coerce_from_py_func}}(object)")
+
+ pyx_code.context.update(
+ specialized_type_name=specialized_type.specialization_string,
+ sizeof_dtype=self._sizeof_dtype(dtype))
+
+ pyx_code.put_chunk(
+ u"""
+ # try {{dtype}}
+ if itemsize == -1 or itemsize == {{sizeof_dtype}}:
+ memslice = {{coerce_from_py_func}}(arg)
+ if memslice.memview:
+ __PYX_XDEC_MEMVIEW(&memslice, 1)
+ # print "found a match for the buffer through format parsing"
+ %s
+ break
+ else:
+ PyErr_Clear()
+ """ % self.match)
- body_stmts.append(u"""
- if nargs >= %(nextidx)d or '%(argname)s' in kwargs:
- if nargs >= %(nextidx)d:
- arg = args[%(idx)d]
+ def _buffer_checks(self, buffer_types, pyx_code, decl_code, env):
+ """
+ Generate Cython code to match objects to buffer specializations.
+ First try to get a numpy dtype object and match it against the individual
+ specializations. If that fails, try naively to coerce the object
+ to each specialization, which obtains the buffer each time and tries
+ to match the format string.
+ """
+ from Cython.Compiler import ExprNodes
+ if buffer_types:
+ with pyx_code.indenter(u"else:"):
+ # The first thing to find a match in this loop breaks out of the loop
+ with pyx_code.indenter(u"while 1:"):
+ pyx_code.put_chunk(
+ u"""
+ if numpy is not None:
+ if isinstance(arg, numpy.ndarray):
+ dtype = arg.dtype
+ elif (__pyx_memoryview_check(arg) and
+ isinstance(arg.object, numpy.ndarray)):
+ dtype = arg.object.dtype
+ else:
+ dtype = None
+
+ itemsize = -1
+ if dtype is not None:
+ itemsize = dtype.itemsize
+ kind = ord(dtype.kind)
+ dtype_signed = kind == ord('u')
+ """)
+ pyx_code.indent(2)
+ pyx_code.named_insertion_point("numpy_dtype_checks")
+ self._buffer_check_numpy_dtype(pyx_code, buffer_types)
+ pyx_code.dedent(2)
+
+ for specialized_type in buffer_types:
+ self._buffer_parse_format_string_check(
+ pyx_code, decl_code, specialized_type, env)
+
+ pyx_code.putln(self.no_match)
+ pyx_code.putln("break")
else:
- arg = kwargs['%(argname)s']
-""" % {'idx': i, 'nextidx': i + 1, 'argname': arg.name})
+ pyx_code.putln("else: %s" % self.no_match)
- all_numeric = True
- for specialized_type in specialized_types:
- py_type_name = specialized_type.py_type_name()
+ def _buffer_declarations(self, pyx_code, decl_code, all_buffer_types):
+ """
+ If we have any buffer specializations, write out some variable
+ declarations and imports.
+ """
+ decl_code.put_chunk(
+ u"""
+ ctypedef struct {{memviewslice_cname}}:
+ void *memview
+
+ void __PYX_XDEC_MEMVIEW({{memviewslice_cname}} *, int have_gil)
+ bint __pyx_memoryview_check(object)
+ """)
+
+ pyx_code.local_variable_declarations.put_chunk(
+ u"""
+ cdef {{memviewslice_cname}} memslice
+ cdef Py_ssize_t itemsize
+ cdef bint dtype_signed
+ cdef char kind
+
+ itemsize = -1
+ """)
+
+ pyx_code.imports.put_chunk(
+ u"""
+ try:
+ import numpy
+ except ImportError:
+ numpy = None
+ """)
+
+ seen_int_dtypes = set()
+ for buffer_type in all_buffer_types:
+ dtype = buffer_type.dtype
+ if dtype.is_typedef:
+ #decl_code.putln("ctypedef %s %s" % (dtype.resolve(),
+ # self._dtype_name(dtype)))
+ decl_code.putln('ctypedef %s %s "%s"' % (dtype.resolve(),
+ self._dtype_name(dtype),
+ dtype.declaration_code("")))
+
+ if buffer_type.dtype.is_int:
+ if str(dtype) not in seen_int_dtypes:
+ seen_int_dtypes.add(str(dtype))
+ pyx_code.context.update(dtype_name=self._dtype_name(dtype),
+ dtype_type=self._dtype_type(dtype))
+ pyx_code.local_variable_declarations.put_chunk(
+ u"""
+ cdef bint {{dtype_name}}_is_signed
+ {{dtype_name}}_is_signed = <{{dtype_type}}> -1 < 0
+ """)
+
+ def _split_fused_types(self, arg):
+ """
+ Specialize fused types and split into normal types and buffer types.
+ """
+ specialized_types = PyrexTypes.get_specialized_types(arg.type)
+ # Prefer long over int, etc
+ # specialized_types.sort()
+ seen_py_type_names = set()
+ normal_types, buffer_types = [], []
+ for specialized_type in specialized_types:
+ py_type_name = specialized_type.py_type_name()
+ if py_type_name:
+ if py_type_name in seen_py_type_names:
+ continue
+ seen_py_type_names.add(py_type_name)
+ normal_types.append(specialized_type)
+ elif specialized_type.is_buffer or specialized_type.is_memoryviewslice:
+ buffer_types.append(specialized_type)
+
+ return normal_types, buffer_types
+
+ def _unpack_argument(self, pyx_code):
+ pyx_code.put_chunk(
+ u"""
+ # PROCESSING ARGUMENT {{arg_tuple_idx}}
+ if {{arg_tuple_idx}} < len(args):
+ arg = args[{{arg_tuple_idx}}]
+ elif '{{arg.name}}' in kwargs:
+ arg = kwargs['{{arg.name}}']
+ else:
+ {{if arg.default:}}
+ arg = defaults[{{default_idx}}]
+ {{else}}
+ raise TypeError("Expected at least %d arguments" % len(args))
+ {{endif}}
+ """)
- if not py_type_name or py_type_name in seen_py_type_names:
- continue
+ def make_fused_cpdef(self, orig_py_func, env, is_def):
+ """
+ This creates the function that is indexable from Python and does
+ runtime dispatch based on the argument types. The function gets the
+ arg tuple and kwargs dict (or None) and the defaults tuple
+ as arguments from the Binding Fused Function's tp_call.
+ """
+ from Cython.Compiler import TreeFragment, Code, MemoryView, UtilityCode
- seen_py_type_names.add(py_type_name)
+ # { (arg_pos, FusedType) : specialized_type }
+ seen_fused_types = set()
- all_numeric = all_numeric and specialized_type.is_numeric
+ context = {
+ 'memviewslice_cname': MemoryView.memviewslice_cname,
+ 'func_args': self.node.args,
+ 'n_fused': len([arg for arg in self.node.args]),
+ 'name': orig_py_func.entry.name,
+ }
- if first_check:
- if_ = 'if'
- first_check = False
+ pyx_code = Code.PyxCodeWriter(context=context)
+ decl_code = Code.PyxCodeWriter(context=context)
+ decl_code.put_chunk(
+ u"""
+ cdef extern from *:
+ void PyErr_Clear()
+ """)
+ decl_code.indent()
+
+ pyx_code.put_chunk(
+ u"""
+ def __pyx_fused_cpdef(signatures, args, kwargs, defaults):
+ import sys
+ if sys.version_info >= (3, 0):
+ long_ = int
+ unicode_ = str
+ bytes_ = bytes
else:
- if_ = 'elif'
-
- # in the case of long, unicode or bytes we need to instance
- # check for long_, unicode_, bytes_ (long = long is no longer
- # valid code with control flow analysis)
- instance_check_py_type_name = py_type_name
- if py_type_name in ('long', 'unicode', 'bytes'):
- instance_check_py_type_name += '_'
-
- tup = (if_, instance_check_py_type_name,
- len(seen_fused_types) - 1,
- specialized_type.typeof_name())
- body_stmts.append(
- " %s isinstance(arg, %s): "
- "dest_sig[%d] = '%s'" % tup)
-
- if arg.default and all_numeric:
- arg.default.analyse_types(env)
+ long_ = long
+ unicode_ = unicode
+ bytes_ = str
- ts = specialized_types
- if arg.default.type.is_complex:
- typelist = [t for t in ts if t.is_complex]
- elif arg.default.type.is_float:
- typelist = [t for t in ts if t.is_float]
- else:
- typelist = [t for t in ts if t.is_int]
+ dest_sig = [None] * {{n_fused}}
- if typelist:
- body_stmts.append(u"""\
- else:
- dest_sig[%d] = '%s'
-""" % (i, typelist[0].typeof_name()))
+ if kwargs is None:
+ kwargs = {}
- fmt_dict = {
- 'body': '\n'.join(body_stmts),
- 'nargs': len(args),
- 'name': orig_py_func.entry.name,
- }
+ cdef Py_ssize_t i
- fragment_code = u"""
-def __pyx_fused_cpdef(signatures, args, kwargs):
- #if len(args) < %(nargs)d:
- # raise TypeError("Invalid number of arguments, expected %(nargs)d, "
- # "got %%d" %% len(args))
- cdef int nargs
- nargs = len(args)
-
- import sys
- if sys.version_info >= (3, 0):
- long_ = int
- unicode_ = str
- bytes_ = bytes
- else:
- long_ = long
- unicode_ = unicode
- bytes_ = str
+ # instance check body
+ """)
+ pyx_code.indent() # indent following code to function body
+ pyx_code.named_insertion_point("imports")
+ pyx_code.named_insertion_point("local_variable_declarations")
- dest_sig = [None] * %(nargs)d
+ fused_index = 0
+ default_idx = 0
+ all_buffer_types = set()
+ for i, arg in enumerate(self.node.args):
+ if arg.type.is_fused and arg.type not in seen_fused_types:
+ seen_fused_types.add(arg.type)
- if kwargs is None:
- kwargs = {}
+ context.update(
+ arg_tuple_idx=i,
+ arg=arg,
+ dest_sig_idx=fused_index,
+ default_idx=default_idx,
+ )
- # instance check body
-%(body)s
- candidates = []
- for sig in signatures:
- match_found = [x for x in dest_sig if x]
- for src_type, dst_type in zip(sig.strip('()').split(', '), dest_sig):
- if dst_type is not None and match_found:
- match_found = src_type == dst_type
+ normal_types, buffer_types = self._split_fused_types(arg)
+ self._unpack_argument(pyx_code)
+ self._fused_instance_checks(normal_types, pyx_code, env)
+ self._buffer_checks(buffer_types, pyx_code, decl_code, env)
+ fused_index += 1
- if match_found:
- candidates.append(sig)
+ all_buffer_types.update(buffer_types)
- if not candidates:
- raise TypeError("No matching signature found")
- elif len(candidates) > 1:
- raise TypeError("Function call with ambiguous argument types")
- else:
- return signatures[candidates[0]]
-""" % fmt_dict
+ if arg.default:
+ default_idx += 1
+
+ if all_buffer_types:
+ self._buffer_declarations(pyx_code, decl_code, all_buffer_types)
+
+ pyx_code.put_chunk(
+ u"""
+ candidates = []
+ for sig in signatures:
+ match_found = True
+ for src_type, dst_type in zip(sig.strip('()').split(', '), dest_sig):
+ if dst_type is not None and match_found:
+ match_found = src_type == dst_type
+
+ if match_found:
+ candidates.append(sig)
+
+ if not candidates:
+ raise TypeError("No matching signature found")
+ elif len(candidates) > 1:
+ raise TypeError("Function call with ambiguous argument types")
+ else:
+ return signatures[candidates[0]]
+ """)
+
+ fragment_code = pyx_code.getvalue()
+ # print decl_code.getvalue()
+ # print fragment_code
+ fragment = TreeFragment.TreeFragment(fragment_code.decode('ascii'),
+ level='module')
+ ast = TreeFragment.SetPosTransform(self.node.pos)(fragment.root)
+ UtilityCode.declare_declarations_in_scope(decl_code.getvalue(), env)
+ ast.scope = env
+ ast.analyse_declarations(env)
+ py_func = ast.stats[-1] # the DefNode
+ self.fragment_scope = ast.scope
+
+ if isinstance(self.node, DefNode):
+ py_func.specialized_cpdefs = self.nodes[:]
+ else:
+ py_func.specialized_cpdefs = [n.py_func for n in self.nodes]
- fragment = TreeFragment.TreeFragment(fragment_code, level='module')
+ return py_func
- # analyse the declarations of our fragment ...
- py_func, = fragment.substitute(pos=self.node.pos).stats
- # Analyse the function object ...
- py_func.analyse_declarations(env)
- # ... and its body
- py_func.scope = env
+ def update_fused_defnode_entry(self, env):
+ import ExprNodes
- # Will be analysed later by underlying AnalyseDeclarationsTransform
- #ParseTreeTransforms.AnalyseDeclarationsTransform(None)(py_func)
+ copy_attributes = (
+ 'name', 'pos', 'cname', 'func_cname', 'pyfunc_cname',
+ 'pymethdef_cname', 'doc', 'doc_cname', 'is_member',
+ 'scope'
+ )
- e, orig_e = py_func.entry, orig_py_func.entry
+ entry = self.py_func.entry
- # Update the new entry ...
- py_func.name = e.name = orig_e.name
- e.cname, e.func_cname = orig_e.cname, orig_e.func_cname
- e.pymethdef_cname = orig_e.pymethdef_cname
- e.doc, e.doc_cname = orig_e.doc, orig_e.doc_cname
- # e.signature = TypeSlots.binaryfunc
+ for attr in copy_attributes:
+ setattr(entry, attr,
+ getattr(self.orig_py_func.entry, attr))
- py_func.doc = orig_py_func.doc
+ self.py_func.name = self.orig_py_func.name
+ self.py_func.doc = self.orig_py_func.doc
- # ... and the symbol table
env.entries.pop('__pyx_fused_cpdef', None)
- if is_def:
- env.entries[e.name] = e
+ if isinstance(self.node, DefNode):
+ env.entries[entry.name] = entry
else:
- env.entries[e.name].as_variable = e
+ env.entries[entry.name].as_variable = entry
- env.pyfunc_entries.append(e)
+ env.pyfunc_entries.append(entry)
- if is_def:
- py_func.specialized_cpdefs = self.nodes[:]
- else:
- py_func.specialized_cpdefs = [n.py_func for n in self.nodes]
+ self.py_func.entry.fused_cfunction = self
+ for node in self.nodes:
+ if isinstance(self.node, DefNode):
+ node.fused_py_func = self.py_func
+ else:
+ node.py_func.fused_py_func = self.py_func
+ node.entry.as_variable = entry
- return py_func
+ self.synthesize_defnodes()
+ self.stats.append(self.__signatures__)
+
+ env.use_utility_code(ExprNodes.import_utility_code)
def analyse_expressions(self, env):
"""
View
1  Cython/Compiler/ParseTreeTransforms.py
@@ -1498,6 +1498,7 @@ def visit_FuncDefNode(self, node):
# Create PyCFunction nodes for each specialization
node.stats.insert(0, node.py_func)
node.py_func = self.visit(node.py_func)
+ node.update_fused_defnode_entry(env)
pycfunc = ExprNodes.PyCFunctionNode.from_defnode(node.py_func,
True)
pycfunc = ExprNodes.ProxyNode(pycfunc.coerce_to_temp(env))
View
6 Cython/Compiler/Symtab.py
@@ -362,11 +362,11 @@ def builtin_scope(self):
# Return the module-level scope containing this scope.
return self.outer_scope.builtin_scope()
- def declare(self, name, cname, type, pos, visibility, shadow = 0):
+ def declare(self, name, cname, type, pos, visibility, shadow = 0, is_type = 0):
# Create new entry, and add to dictionary if
# name is not None. Reports a warning if already
# declared.
- if type.is_buffer and not isinstance(self, LocalScope):
+ if type.is_buffer and not isinstance(self, LocalScope) and not is_type:
error(pos, 'Buffer types only allowed as function local variables')
if not self.in_cinclude and cname and re.match("^_[_A-Z]+$", cname):
# See http://www.gnu.org/software/libc/manual/html_node/Reserved-Names.html#Reserved-Names
@@ -417,7 +417,7 @@ def declare_type(self, name, type, pos,
# Add an entry for a type definition.
if not cname:
cname = name
- entry = self.declare(name, cname, type, pos, visibility, shadow)
+ entry = self.declare(name, cname, type, pos, visibility, shadow, True)
entry.is_type = 1
entry.api = api
if defining:
View
10 Cython/Compiler/TreeFragment.py
@@ -231,6 +231,12 @@ def substitute(self, nodes={}, temps=[], pos = None):
substitutions = nodes,
temps = self.temps + temps, pos = pos)
+class SetPosTransform(VisitorTransform):
+ def __init__(self, pos):
+ super(SetPosTransform, self).__init__()
+ self.pos = pos
-
-
+ def visit_Node(self, node):
+ node.pos = self.pos
+ self.visitchildren(node)
+ return node
View
8 Cython/Compiler/UtilityCode.py
@@ -167,3 +167,11 @@ def declare_in_scope(self, dest_scope, used=False, cython_scope=None,
dep.declare_in_scope(dest_scope)
return original_scope
+
+def declare_declarations_in_scope(declaration_string, env, private_type=True,
+ *args, **kwargs):
+ """
+ Declare some declarations given as Cython code in declaration_string
+ in scope env.
+ """
+ CythonUtilityCode(declaration_string, *args, **kwargs).declare_in_scope(env)
View
7 Cython/Utility/CythonFunction.c
@@ -740,8 +740,6 @@ __pyx_FusedFunction_callfunction(PyObject *func, PyObject *args, PyObject *kw)
int static_specialized = (cyfunc->flags & __Pyx_CYFUNCTION_STATICMETHOD &&
!((__pyx_FusedFunctionObject *) func)->__signatures__);
- //PyObject_Print(args, stdout, Py_PRINT_RAW);
-
if (cyfunc->flags & __Pyx_CYFUNCTION_CCLASS && !static_specialized) {
Py_ssize_t argc;
PyObject *new_args;
@@ -827,8 +825,9 @@ __pyx_FusedFunction_call(PyObject *func, PyObject *args, PyObject *kw)
}
if (binding_func->__signatures__) {
- PyObject *tup = PyTuple_Pack(3, binding_func->__signatures__, args,
- kw == NULL ? Py_None : kw);
+ PyObject *tup = PyTuple_Pack(4, binding_func->__signatures__, args,
+ kw == NULL ? Py_None : kw,
+ binding_func->func.defaults_tuple);
if (!tup)
goto __pyx_err;
View
246 tests/run/numpy_test.pyx
@@ -4,6 +4,8 @@
cimport numpy as np
cimport cython
+from libc.stdlib cimport malloc
+
def little_endian():
cdef int endian_detector = 1
return (<char*>&endian_detector)[0] != 0
@@ -503,19 +505,28 @@ def test_point_record():
test[i].y = -i
print repr(test).replace('<', '!').replace('>', '!')
-def test_fused_ndarray_dtype(np.ndarray[cython.floating, ndim=1] a):
+# Test fused np.ndarray dtypes and runtime dispatch
+def test_fused_ndarray_floating_dtype(np.ndarray[cython.floating, ndim=1] a):
"""
>>> import cython
- >>> sorted(test_fused_ndarray_dtype.__signatures__)
+ >>> sorted(test_fused_ndarray_floating_dtype.__signatures__)
['double', 'float']
- >>> test_fused_ndarray_dtype[cython.double](np.arange(10, dtype=np.float64))
+
+
+ >>> test_fused_ndarray_floating_dtype[cython.double](np.arange(10, dtype=np.float64))
ndarray[double,ndim=1] ndarray[double,ndim=1] 5.0 6.0
- >>> test_fused_ndarray_dtype[cython.float](np.arange(10, dtype=np.float32))
+ >>> test_fused_ndarray_floating_dtype(np.arange(10, dtype=np.float64))
+ ndarray[double,ndim=1] ndarray[double,ndim=1] 5.0 6.0
+
+ >>> test_fused_ndarray_floating_dtype[cython.float](np.arange(10, dtype=np.float32))
+ ndarray[float,ndim=1] ndarray[float,ndim=1] 5.0 6.0
+ >>> test_fused_ndarray_floating_dtype(np.arange(10, dtype=np.float32))
ndarray[float,ndim=1] ndarray[float,ndim=1] 5.0 6.0
"""
cdef np.ndarray[cython.floating, ndim=1] b = a
print cython.typeof(a), cython.typeof(b), a[5], b[6]
+
double_array = np.linspace(0, 1, 100)
int32_array = np.arange(100, dtype=np.int32)
@@ -568,4 +579,231 @@ def test_fused_cpdef_buffers():
cdef np.ndarray[np.int32_t] typed_array = int32_array
_fused_cpdef_buffers(typed_array)
+def test_fused_ndarray_integral_dtype(np.ndarray[cython.integral, ndim=1] a):
+ """
+ >>> import cython
+ >>> sorted(test_fused_ndarray_integral_dtype.__signatures__)
+ ['int', 'long', 'short']
+
+ >>> test_fused_ndarray_integral_dtype[cython.int](np.arange(10, dtype=np.dtype('i')))
+ ndarray[int,ndim=1] ndarray[int,ndim=1] 5 6
+ >>> test_fused_ndarray_integral_dtype(np.arange(10, dtype=np.dtype('i')))
+ ndarray[int,ndim=1] ndarray[int,ndim=1] 5 6
+
+ >>> test_fused_ndarray_integral_dtype[cython.long](np.arange(10, dtype=np.long))
+ ndarray[long,ndim=1] ndarray[long,ndim=1] 5 6
+ >>> test_fused_ndarray_integral_dtype(np.arange(10, dtype=np.long))
+ ndarray[long,ndim=1] ndarray[long,ndim=1] 5 6
+ """
+ cdef np.ndarray[cython.integral, ndim=1] b = a
+ print cython.typeof(a), cython.typeof(b), a[5], b[6]
+
+cdef fused fused_dtype:
+ float complex
+ double complex
+ object
+
+def test_fused_ndarray_other_dtypes(np.ndarray[fused_dtype, ndim=1] a):
+ """
+ >>> import cython
+ >>> sorted(test_fused_ndarray_other_dtypes.__signatures__)
+ ['double complex', 'float complex', 'object']
+ >>> test_fused_ndarray_other_dtypes(np.arange(10, dtype=np.complex64))
+ ndarray[float complex,ndim=1] ndarray[float complex,ndim=1] (5+0j) (6+0j)
+ >>> test_fused_ndarray_other_dtypes(np.arange(10, dtype=np.complex128))
+ ndarray[double complex,ndim=1] ndarray[double complex,ndim=1] (5+0j) (6+0j)
+ >>> test_fused_ndarray_other_dtypes(np.arange(10, dtype=np.object))
+ ndarray[Python object,ndim=1] ndarray[Python object,ndim=1] 5 6
+ """
+ cdef np.ndarray[fused_dtype, ndim=1] b = a
+ print cython.typeof(a), cython.typeof(b), a[5], b[6]
+
+
+# Test fusing the array types together and runtime dispatch
+cdef struct Foo:
+ int a
+ float b
+
+cdef fused fused_FooArray:
+ np.ndarray[Foo, ndim=1]
+
+cdef fused fused_ndarray:
+ np.ndarray[float, ndim=1]
+ np.ndarray[double, ndim=1]
+ np.ndarray[Foo, ndim=1]
+
+def get_Foo_array():
+ cdef Foo[:] result = <Foo[:10]> malloc(sizeof(Foo) * 10)
+ result[5].b = 9.0
+ return np.asarray(result)
+
+def test_fused_ndarray(fused_ndarray a):
+ """
+ >>> import cython
+ >>> sorted(test_fused_ndarray.__signatures__)
+ ['ndarray[Foo,ndim=1]', 'ndarray[double,ndim=1]', 'ndarray[float,ndim=1]']
+
+ >>> test_fused_ndarray(get_Foo_array())
+ ndarray[Foo,ndim=1] ndarray[Foo,ndim=1]
+ 9.0
+ >>> test_fused_ndarray(np.arange(10, dtype=np.float64))
+ ndarray[double,ndim=1] ndarray[double,ndim=1]
+ 5.0
+ >>> test_fused_ndarray(np.arange(10, dtype=np.float32))
+ ndarray[float,ndim=1] ndarray[float,ndim=1]
+ 5.0
+ """
+ cdef fused_ndarray b = a
+ print cython.typeof(a), cython.typeof(b)
+
+ if fused_ndarray in fused_FooArray:
+ print b[5].b
+ else:
+ print b[5]
+
+cpdef test_fused_cpdef_ndarray(fused_ndarray a):
+ """
+ >>> import cython
+ >>> sorted(test_fused_cpdef_ndarray.__signatures__)
+ ['ndarray[Foo,ndim=1]', 'ndarray[double,ndim=1]', 'ndarray[float,ndim=1]']
+
+ >>> test_fused_cpdef_ndarray(get_Foo_array())
+ ndarray[Foo,ndim=1] ndarray[Foo,ndim=1]
+ 9.0
+ >>> test_fused_cpdef_ndarray(np.arange(10, dtype=np.float64))
+ ndarray[double,ndim=1] ndarray[double,ndim=1]
+ 5.0
+ >>> test_fused_cpdef_ndarray(np.arange(10, dtype=np.float32))
+ ndarray[float,ndim=1] ndarray[float,ndim=1]
+ 5.0
+ """
+ cdef fused_ndarray b = a
+ print cython.typeof(a), cython.typeof(b)
+
+ if fused_ndarray in fused_FooArray:
+ print b[5].b
+ else:
+ print b[5]
+
+def test_fused_cpdef_ndarray_cdef_call():
+ """
+ >>> test_fused_cpdef_ndarray_cdef_call()
+ ndarray[Foo,ndim=1] ndarray[Foo,ndim=1]
+ 9.0
+ """
+ cdef np.ndarray[Foo, ndim=1] foo_array = get_Foo_array()
+ test_fused_cpdef_ndarray(foo_array)
+
+cdef fused int_type:
+ np.int32_t
+ np.int64_t
+
+float64_array = np.arange(10, dtype=np.float64)
+float32_array = np.arange(10, dtype=np.float32)
+int32_array = np.arange(10, dtype=np.int32)
+int64_array = np.arange(10, dtype=np.int64)
+
+def test_dispatch_non_clashing_declarations_repeating_types(np.ndarray[cython.floating] a1,
+ np.ndarray[int_type] a2,
+ np.ndarray[cython.floating] a3,
+ np.ndarray[int_type] a4):
+ """
+ >>> test_dispatch_non_clashing_declarations_repeating_types(float64_array, int32_array, float64_array, int32_array)
+ 1.0 2 3.0 4
+ >>> test_dispatch_non_clashing_declarations_repeating_types(float64_array, int64_array, float64_array, int64_array)
+ 1.0 2 3.0 4
+ >>> test_dispatch_non_clashing_declarations_repeating_types(float64_array, int32_array, float64_array, int64_array)
+ Traceback (most recent call last):
+ ...
+ TypeError: No matching signature found
+ """
+ print a1[1], a2[2], a3[3], a4[4]
+
+ctypedef np.int32_t typedeffed_type
+
+cdef fused typedeffed_fused_type:
+ typedeffed_type
+ int
+ long
+
+def test_dispatch_typedef(np.ndarray[typedeffed_fused_type] a):
+ """
+ >>> test_dispatch_typedef(int32_array)
+ 5
+ """
+ print a[5]
+
+
+cdef extern from "types.h":
+ ctypedef unsigned char actually_long_t
+
+cdef fused confusing_fused_typedef:
+ actually_long_t
+ unsigned char
+ signed char
+
+def test_dispatch_external_typedef(np.ndarray[confusing_fused_typedef] a):
+ """
+ >>> test_dispatch_external_typedef(np.arange(10, dtype=np.long))
+ 5
+ """
+ print a[5]
+
+# test fused memoryview slices
+cdef fused memslice_fused_dtype:
+ float
+ double
+ int
+ long
+ float complex
+ double complex
+ object
+
+def test_fused_memslice_other_dtypes(memslice_fused_dtype[:] a):
+ """
+ >>> import cython
+ >>> sorted(test_fused_memslice_other_dtypes.__signatures__)
+ ['double', 'double complex', 'float', 'float complex', 'int', 'long', 'object']
+ >>> test_fused_memslice_other_dtypes(np.arange(10, dtype=np.complex64))
+ float complex[:] float complex[:] (5+0j) (6+0j)
+ >>> test_fused_memslice_other_dtypes(np.arange(10, dtype=np.complex128))
+ double complex[:] double complex[:] (5+0j) (6+0j)
+ >>> test_fused_memslice_other_dtypes(np.arange(10, dtype=np.float32))
+ float[:] float[:] 5.0 6.0
+ >>> test_fused_memslice_other_dtypes(np.arange(10, dtype=np.dtype('i')))
+ int[:] int[:] 5 6
+ >>> test_fused_memslice_other_dtypes(np.arange(10, dtype=np.object))
+ object[:] object[:] 5 6
+ """
+ cdef memslice_fused_dtype[:] b = a
+ print cython.typeof(a), cython.typeof(b), a[5], b[6]
+
+cdef fused memslice_fused:
+ float[:]
+ double[:]
+ int[:]
+ long[:]
+ float complex[:]
+ double complex[:]
+ object[:]
+
+def test_fused_memslice_fused(memslice_fused a):
+ """
+ >>> import cython
+ >>> sorted(test_fused_memslice_fused.__signatures__)
+ ['double complex[:]', 'double[:]', 'float complex[:]', 'float[:]', 'int[:]', 'long[:]', 'object[:]']
+ >>> test_fused_memslice_fused(np.arange(10, dtype=np.complex64))
+ float complex[:] float complex[:] (5+0j) (6+0j)
+ >>> test_fused_memslice_fused(np.arange(10, dtype=np.complex128))
+ double complex[:] double complex[:] (5+0j) (6+0j)
+ >>> test_fused_memslice_fused(np.arange(10, dtype=np.float32))
+ float[:] float[:] 5.0 6.0
+ >>> test_fused_memslice_fused(np.arange(10, dtype=np.dtype('i')))
+ int[:] int[:] 5 6
+ >>> test_fused_memslice_fused(np.arange(10, dtype=np.object))
+ object[:] object[:] 5 6
+ """
+ cdef memslice_fused b = a
+ print cython.typeof(a), cython.typeof(b), a[5], b[6]
+
include "numpy_common.pxi"
Please sign in to comment.
Something went wrong with that request. Please try again.