Skip to content

Commit

Permalink
Merge pull request #428 from brian-team/function_improvements
Browse files Browse the repository at this point in the history
Function improvements + PoissonInput
  • Loading branch information
thesamovar committed Mar 25, 2015
2 parents 53c2894 + 6baca99 commit 7830100
Show file tree
Hide file tree
Showing 19 changed files with 971 additions and 230 deletions.
92 changes: 64 additions & 28 deletions brian2/codegen/generators/cpp_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,56 @@ def denormals_to_zero_code(self):
else:
return ''

def _add_user_function(self, varname, variable):
impl = variable.implementations[self.codeobj_class]
support_code = []
hash_defines = []
pointers = []
user_functions = [(varname, variable)]
funccode = impl.get_code(self.owner)
if isinstance(funccode, basestring):
funccode = {'support_code': funccode}
if funccode is not None:
# To make namespace variables available to functions, we
# create global variables and assign to them in the main
# code
func_namespace = impl.get_namespace(self.owner) or {}
for ns_key, ns_value in func_namespace.iteritems():
if hasattr(ns_value, 'dtype'):
if ns_value.shape == ():
raise NotImplementedError((
'Directly replace scalar values in the function '
'instead of providing them via the namespace'))
type_str = c_data_type(ns_value.dtype) + '*'
else: # e.g. a function
type_str = 'py::object'
support_code.append('static {0} _namespace{1};'.format(type_str,
ns_key))
pointers.append('_namespace{0} = {1};'.format(ns_key, ns_key))
support_code.append(deindent(funccode.get('support_code', '')))
hash_defines.append(deindent(funccode.get('hashdefine_code', '')))

dep_hash_defines = []
dep_pointers = []
dep_support_code = []
if impl.dependencies is not None:
for dep_name, dep in impl.dependencies.iteritems():
self.variables[dep_name] = dep
hd, ps, sc, uf = self._add_user_function(dep_name, dep)
dep_hash_defines.extend(hd)
dep_pointers.extend(ps)
dep_support_code.extend(sc)
user_functions.extend(uf)

return (dep_hash_defines + hash_defines,
dep_pointers + pointers,
dep_support_code + support_code,
user_functions)

def determine_keywords(self):
# set up the restricted pointers, these are used so that the compiler
# knows there is no aliasing in the pointers, for optimisation
lines = []
pointers = []
# It is possible that several different variable names refer to the
# same array. E.g. in gapjunction code, v_pre and v_post refer to the
# same array if a group is connected to itself
Expand All @@ -256,35 +302,25 @@ def determine_keywords(self):
continue
if getattr(var, 'dimensions', 1) > 1:
continue # multidimensional (dynamic) arrays have to be treated differently
line = self.c_data_type(var.dtype) + ' * ' + self.restrict + pointer_name + ' = ' + array_name + ';'
lines.append(line)
line = '{0}* {1} {2} = {3};'.format(self.c_data_type(var.dtype),
self.restrict,
pointer_name,
array_name)
pointers.append(line)
handled_pointers.add(pointer_name)

pointers = '\n'.join(lines)

# set up the functions
user_functions = []
support_code = ''
hash_defines = ''
support_code = []
hash_defines = []
for varname, variable in self.variables.items():
if isinstance(variable, Function):
user_functions.append((varname, variable))
funccode = variable.implementations[self.codeobj_class].get_code(self.owner)
if isinstance(funccode, basestring):
funccode = {'support_code': funccode}
if funccode is not None:
support_code += '\n' + deindent(funccode.get('support_code', ''))
hash_defines += '\n' + deindent(funccode.get('hashdefine_code', ''))
# add the Python function with a leading '_python', if it
# exists. This allows the function to make use of the Python
# function via weave if necessary (e.g. in the case of randn)
if not variable.pyfunc is None:
pyfunc_name = '_python_' + varname
if pyfunc_name in self.variables:
logger.warn(('Namespace already contains function %s, '
'not replacing it') % pyfunc_name)
else:
self.variables[pyfunc_name] = variable.pyfunc
hd, ps, sc, uf = self._add_user_function(varname, variable)
user_functions.extend(uf)
support_code.extend(sc)
pointers.extend(ps)
hash_defines.extend(hd)


# delete the user-defined functions from the namespace and add the
# function namespaces (if any)
Expand All @@ -294,10 +330,10 @@ def determine_keywords(self):
if func_namespace is not None:
self.variables.update(func_namespace)

keywords = {'pointers_lines': stripped_deindented_lines(pointers),
'support_code_lines': stripped_deindented_lines(support_code),
'hashdefine_lines': stripped_deindented_lines(hash_defines),
'denormals_code_lines': stripped_deindented_lines(self.denormals_to_zero_code()),
keywords = {'pointers_lines': stripped_deindented_lines('\n'.join(pointers)),
'support_code_lines': stripped_deindented_lines('\n'.join(support_code)),
'hashdefine_lines': stripped_deindented_lines('\n'.join(hash_defines)),
'denormals_code_lines': stripped_deindented_lines('\n'.join(self.denormals_to_zero_code())),
}
keywords.update(template_kwds)
return keywords
Expand Down
119 changes: 94 additions & 25 deletions brian2/codegen/generators/cython_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,88 @@ def translate_one_statement_sequence(self, statements):

return lines

def _add_user_function(self, varname, var):
user_functions = []
load_namespace = []
support_code = []
impl = var.implementations[self.codeobj_class]
func_code= impl.get_code(self.owner)
# Implementation can be None if the function is already
# available in Cython (possibly under a different name)
if func_code is not None:
if isinstance(func_code, basestring):
# Function is provided as Cython code
# To make namespace variables available to functions, we
# create global variables and assign to them in the main
# code
user_functions.append((varname, var))
func_namespace = impl.get_namespace(self.owner) or {}
for ns_key, ns_value in func_namespace.iteritems():
load_namespace.append(
'# namespace for function %s' % varname)
if hasattr(ns_value, 'dtype'):
if ns_value.shape == ():
raise NotImplementedError((
'Directly replace scalar values in the function '
'instead of providing them via the namespace'))
newlines = [
"global _namespace{var_name}",
"global _namespace_num{var_name}",
"cdef _numpy.ndarray[{cpp_dtype}, ndim=1, mode='c'] _buf_{var_name} = _namespace['{var_name}'].view(dtype=_numpy.{numpy_dtype})",
"_namespace{var_name} = <{cpp_dtype} *> _buf_{var_name}.data",
"_namespace_num{var_name} = len(_namespace['{var_name}'])"
]
support_code.append(
"cdef {cpp_dtype} *_namespace{var_name}".format(
cpp_dtype=get_cpp_dtype(ns_value.dtype),
var_name=ns_key))

else: # e.g. a function
newlines = [
"_namespace{var_name} = namespace['{var_name}']"
]
for line in newlines:
load_namespace.append(
line.format(cpp_dtype=get_cpp_dtype(ns_value.dtype),
numpy_dtype=get_numpy_dtype(
ns_value.dtype),
var_name=ns_key))
support_code.append(deindent(func_code))
elif callable(func_code):
self.variables[varname] = func_code
line = '{0}} = _namespace["{1}}"]'.format(varname, varname)
load_namespace.append(line)
else:
raise TypeError(('Provided function implementation '
'for function %s is neither a string '
'nor callable (is type %s instead)') % (
varname,
type(func_code)))

dep_support_code = []
dep_load_namespace = []
dep_user_functions = []
if impl.dependencies is not None:
for dep_name, dep in impl.dependencies.iteritems():
self.variables[dep_name] = dep
sc, ln, uf = self._add_user_function(dep_name, dep)
dep_support_code.extend(sc)
dep_load_namespace.extend(ln)
dep_user_functions.extend(uf)

return (support_code + dep_support_code,
dep_load_namespace + load_namespace,
dep_user_functions + user_functions)

def determine_keywords(self):
from brian2.devices.device import get_device
device = get_device()
# load variables from namespace
load_namespace = []
support_code = []
handled_pointers = set()
for varname, var in self.variables.iteritems():
user_functions = []
for varname, var in self.variables.items():
if isinstance(var, AuxiliaryVariable):
line = "cdef {dtype} {varname}".format(
dtype=get_cpp_dtype(var.dtype),
Expand Down Expand Up @@ -142,8 +216,8 @@ def determine_keywords(self):
load_namespace.append(line)
elif isinstance(var, Variable):
if var.dynamic:
load_namespace.append('%s = _namespace["%s"]' % (self.get_array_name(var, False),
self.get_array_name(var, False)))
load_namespace.append('{0} = _namespace["{1}"]'.format(self.get_array_name(var, False),
self.get_array_name(var, False)))

# This is the "true" array name, not the restricted pointer.
array_name = device.get_array_name(var)
Expand Down Expand Up @@ -172,32 +246,27 @@ def determine_keywords(self):
handled_pointers.add(pointer_name)

elif isinstance(var, Function):
func_impl = var.implementations[self.codeobj_class].get_code(self.owner)
# Implementation can be None if the function is already
# available in Cython (possibly under a different name)
if func_impl is not None:
if isinstance(func_impl, basestring):
# Function is provided as Cython code
support_code.append(deindent(func_impl))
elif callable(func_impl):
self.variables[varname] = func_impl
line = '%s = _namespace["%s"]' % (varname, varname)
load_namespace.append(line)
else:
raise TypeError(('Provided function implementation '
'for function %s is neither a string '
'nor callable') % varname)
sc, ln, uf = self._add_user_function(varname, var)
support_code.extend(sc)
load_namespace.extend(ln)
user_functions.extend(uf)
else:
# fallback to Python object
print var
for k, v in var.__dict__.iteritems():
print ' ', k, v
load_namespace.append('%s = _namespace["%s"]' % (varname, varname))
load_namespace.append('{0} = _namespace["{1}"]'.format(varname, varname))

load_namespace = '\n'.join(load_namespace)
support_code = '\n'.join(support_code)
# delete the user-defined functions from the namespace and add the
# function namespaces (if any)
for funcname, func in user_functions:
del self.variables[funcname]
func_namespace = func.implementations[self.codeobj_class].get_namespace(self.owner)
if func_namespace is not None:
self.variables.update(func_namespace)

return {'load_namespace': load_namespace, 'support_code': support_code}
return {'load_namespace': '\n'.join(load_namespace),
'support_code': '\n'.join(support_code)}

###############################################################################
# Implement functions
Expand All @@ -220,7 +289,7 @@ def determine_keywords(self):
cdef int _rand_buffer_size = 1024
cdef double[:] _rand_buf = _numpy.zeros(_rand_buffer_size, dtype=_numpy.float64)
cdef int _cur_rand_buf = 0
cdef double rand(int _idx):
cdef double _rand(int _idx):
global _cur_rand_buf
global _rand_buf
if _cur_rand_buf==0:
Expand All @@ -234,11 +303,11 @@ def determine_keywords(self):

DEFAULT_FUNCTIONS['rand'].implementations.add_implementation(CythonCodeGenerator,
code=rand_code,
name='rand')
name='_rand')

DEFAULT_FUNCTIONS['randn'].implementations.add_implementation(CythonCodeGenerator,
code=randn_code,
name='randn')
name='_randn')

int_code = '''
ctypedef fused _to_int:
Expand Down
23 changes: 6 additions & 17 deletions brian2/codegen/generators/numpy_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,24 +88,13 @@ def translate_one_statement_sequence(self, statements):
for varname in write:
var = variables[varname]
index_var = variable_indices[varname]
# check if all operations were inplace and we're operating on the
# whole vector, if so we don't need to write the array back
if not index_var in self.iterate_all:
all_inplace = False
line = self.get_array_name(var)
if index_var in self.iterate_all:
line = line + '[:]'
else:
all_inplace = True
for stmt in statements:
if stmt.var == varname and not stmt.inplace:
all_inplace = False
break
if not all_inplace:
line = self.get_array_name(var)
if index_var in self.iterate_all:
line = line + '[:]'
else:
line = line + '[' + index_var + ']'
line = line + ' = ' + varname
lines.append(line)
line = line + '[' + index_var + ']'
line = line + ' = ' + varname
lines.append(line)
# if index_var in iterate_all:
# line = '{array_name}[:] = {varname}'
# else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ cdef void _flush_buffer(buf, dynarr, int N):
# add to buffer
if _cond:
if _p!=1.0:
if rand(_vectorisation_idx)>=_p:
if _rand(_vectorisation_idx)>=_p:
continue
for _repetition in range(_n):
{{N_outgoing}}[_pre_idx] += 1
Expand Down

0 comments on commit 7830100

Please sign in to comment.