Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
63 changed files
with
2,706 additions
and
401 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ syntax: glob | |
Cython/Compiler/Lexicon.pickle | ||
BUILD/ | ||
build/ | ||
dist/ | ||
.coverage | ||
*~ | ||
*.orig | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
print "Warning: Using prototype cython.inline code..." | ||
|
||
import tempfile | ||
import sys, os, re, inspect | ||
|
||
try: | ||
import hashlib | ||
except ImportError: | ||
import md5 as hashlib | ||
|
||
from distutils.dist import Distribution | ||
from Cython.Distutils.extension import Extension | ||
from Cython.Distutils import build_ext | ||
|
||
from Cython.Compiler.Main import Context, CompilationOptions, default_options | ||
|
||
from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclarations, AnalyseDeclarationsTransform | ||
from Cython.Compiler.TreeFragment import parse_from_strings | ||
|
||
_code_cache = {} | ||
|
||
|
||
class AllSymbols(CythonTransform, SkipDeclarations): | ||
def __init__(self): | ||
CythonTransform.__init__(self, None) | ||
self.names = set() | ||
def visit_NameNode(self, node): | ||
self.names.add(node.name) | ||
|
||
def unbound_symbols(code, context=None): | ||
if context is None: | ||
context = Context([], default_options) | ||
from Cython.Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform | ||
if isinstance(code, str): | ||
code = code.decode('ascii') | ||
tree = parse_from_strings('(tree fragment)', code) | ||
for phase in context.create_pipeline(pxd=False): | ||
if phase is None: | ||
continue | ||
tree = phase(tree) | ||
if isinstance(phase, AnalyseDeclarationsTransform): | ||
break | ||
symbol_collector = AllSymbols() | ||
symbol_collector(tree) | ||
unbound = [] | ||
import __builtin__ | ||
for name in symbol_collector.names: | ||
if not tree.scope.lookup(name) and not hasattr(__builtin__, name): | ||
unbound.append(name) | ||
return unbound | ||
|
||
|
||
def get_type(arg, context=None): | ||
py_type = type(arg) | ||
if py_type in [list, tuple, dict, str]: | ||
return py_type.__name__ | ||
elif py_type is float: | ||
return 'double' | ||
elif py_type is bool: | ||
return 'bint' | ||
elif py_type is int: | ||
return 'long' | ||
elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray): | ||
return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim) | ||
else: | ||
for base_type in py_type.mro(): | ||
if base_type.__module__ == '__builtin__': | ||
return 'object' | ||
module = context.find_module(base_type.__module__, need_pxd=False) | ||
if module: | ||
entry = module.lookup(base_type.__name__) | ||
if entry.is_type: | ||
return '%s.%s' % (base_type.__module__, base_type.__name__) | ||
return 'object' | ||
|
||
# TODO: use locals/globals for unbound variables | ||
def cython_inline(code, | ||
types='aggressive', | ||
lib_dir=os.path.expanduser('~/.cython/inline'), | ||
include_dirs=['.'], | ||
locals=None, | ||
globals=None, | ||
**kwds): | ||
code = strip_common_indent(code) | ||
ctx = Context(include_dirs, default_options) | ||
if locals is None: | ||
locals = inspect.currentframe().f_back.f_back.f_locals | ||
if globals is None: | ||
globals = inspect.currentframe().f_back.f_back.f_globals | ||
try: | ||
for symbol in unbound_symbols(code): | ||
if symbol in kwds: | ||
continue | ||
elif symbol in locals: | ||
kwds[symbol] = locals[symbol] | ||
elif symbol in globals: | ||
kwds[symbol] = globals[symbol] | ||
else: | ||
print "Couldn't find ", symbol | ||
except AssertionError: | ||
# Parsing from strings not fully supported (e.g. cimports). | ||
print "Could not parse code as a string (to extract unbound symbols)." | ||
arg_names = kwds.keys() | ||
arg_names.sort() | ||
arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names]) | ||
key = code, arg_sigs | ||
module = _code_cache.get(key) | ||
if not module: | ||
cimports = [] | ||
qualified = re.compile(r'([.\w]+)[.]') | ||
for type, _ in arg_sigs: | ||
m = qualified.match(type) | ||
if m: | ||
cimports.append('\ncimport %s' % m.groups()[0]) | ||
module_body, func_body = extract_func_code(code) | ||
params = ', '.join(['%s %s' % a for a in arg_sigs]) | ||
module_code = """ | ||
%(cimports)s | ||
%(module_body)s | ||
def __invoke(%(params)s): | ||
%(func_body)s | ||
""" % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body } | ||
# print module_code | ||
_, pyx_file = tempfile.mkstemp('.pyx') | ||
open(pyx_file, 'w').write(module_code) | ||
module = "_" + hashlib.md5(code + str(arg_sigs)).hexdigest() | ||
extension = Extension( | ||
name = module, | ||
sources = [pyx_file], | ||
pyrex_include_dirs = include_dirs) | ||
build_extension = build_ext(Distribution()) | ||
build_extension.finalize_options() | ||
build_extension.extensions = [extension] | ||
build_extension.build_temp = os.path.dirname(pyx_file) | ||
if lib_dir not in sys.path: | ||
sys.path.append(lib_dir) | ||
build_extension.build_lib = lib_dir | ||
build_extension.run() | ||
_code_cache[key] = module | ||
arg_list = [kwds[arg] for arg in arg_names] | ||
return __import__(module).__invoke(*arg_list) | ||
|
||
non_space = re.compile('[^ ]') | ||
def strip_common_indent(code): | ||
min_indent = None | ||
lines = code.split('\n') | ||
for line in lines: | ||
match = non_space.search(line) | ||
if not match: | ||
continue # blank | ||
indent = match.start() | ||
if line[indent] == '#': | ||
continue # comment | ||
elif min_indent is None or min_indent > indent: | ||
min_indent = indent | ||
for ix, line in enumerate(lines): | ||
match = non_space.search(line) | ||
if not match or line[indent] == '#': | ||
continue | ||
else: | ||
lines[ix] = line[min_indent:] | ||
return '\n'.join(lines) | ||
|
||
module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimport)|(from .+ import +[*]))') | ||
def extract_func_code(code): | ||
module = [] | ||
function = [] | ||
# TODO: string literals, backslash | ||
current = function | ||
code = code.replace('\t', ' ') | ||
lines = code.split('\n') | ||
for line in lines: | ||
if not line.startswith(' '): | ||
if module_statement.match(line): | ||
current = module | ||
else: | ||
current = function | ||
current.append(line) | ||
return '\n'.join(module), ' ' + '\n '.join(function) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.