Skip to content

Commit

Permalink
Merge db6d345 into f026464
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed May 24, 2019
2 parents f026464 + db6d345 commit 2f0b2f9
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 173 deletions.
11 changes: 7 additions & 4 deletions brian2/codegen/generators/cython_generator.py
Expand Up @@ -237,12 +237,15 @@ def _add_user_function(self, varname, var):
"global _namespace_num{var_name}",
"cdef _numpy.ndarray[{cpp_dtype}, ndim=1, mode='c'] _buf_{var_name} = _namespace['{var_name}']",
"_namespace{var_name} = <{cpp_dtype} *> _buf_{var_name}.data",
"_namespace_num{var_name} = len(_namespace['{var_name}'])"
"_namespace_num{var_name} = _buf_{var_name}.shape[0]"
]
support_code.append(
support_code.extend([
"cdef {cpp_dtype} *_namespace{var_name}".format(
cpp_dtype=get_cpp_dtype(ns_value.dtype),
var_name=ns_key))
var_name=ns_key),
"cdef int _namespace_num{var_name}".format(
var_name=ns_key)
])

else: # e.g. a function
newlines = [
Expand Down Expand Up @@ -327,7 +330,7 @@ def determine_keywords(self):
"cdef {cpp_dtype} * {array_name} = <{cpp_dtype} *> _buf_{array_name}.data"]

if not var.scalar:
newlines += ["cdef int _num{array_name} = len(_namespace['{array_name}'])"]
newlines += ["cdef int _num{array_name} = _buf_{array_name}.shape[0]"]

if var.scalar and var.constant:
newlines += ['cdef {cpp_dtype} {varname} = _namespace["{varname}"]']
Expand Down
12 changes: 10 additions & 2 deletions brian2/codegen/runtime/cython_rt/cython_rt.py
Expand Up @@ -110,6 +110,9 @@ def __init__(self, owner, code, variables, variable_indices,
self.libraries = (list(prefs['codegen.cpp.libraries']) +
compiler_kwds.get('libraries', []))
self.sources = compiler_kwds.get('sources', [])
self.compiled_code = None
self.build_process = None
self.module_name = None

@classmethod
def is_available(cls):
Expand All @@ -135,7 +138,7 @@ def main():
return False

def compile(self):
self.compiled_code = cython_extension_manager.create_extension(
self.module_name, self.build_process = cython_extension_manager.create_extension(
self.code,
define_macros=self.define_macros,
libraries=self.libraries,
Expand All @@ -146,9 +149,14 @@ def compile(self):
compiler=self.compiler,
owner_name=self.owner.name+'_'+self.template_name,
sources=self.sources
)
)

def run(self):
if self.compiled_code is None:
if self.build_process is not None:
self.build_process.join()
self.build_process = None
self.compiled_code = cython_extension_manager.get_module(self.module_name)
return self.compiled_code.main(self.namespace)

# the following are copied from WeaveCodeObject
Expand Down

0 comments on commit 2f0b2f9

Please sign in to comment.