Skip to content

Commit

Permalink
Cleaned up CPPStandaloneDevice.build and partial support for group
Browse files Browse the repository at this point in the history
variable initialisations
  • Loading branch information
thesamovar committed Sep 24, 2013
1 parent 9e72b25 commit 8a81ce8
Show file tree
Hide file tree
Showing 12 changed files with 233 additions and 64 deletions.
2 changes: 1 addition & 1 deletion brian2/devices/cpp_standalone/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from brian2.core.preferences import brian_prefs
from .codeobject import CPPStandaloneCodeObject
from .device import cpp_standalone_device, build
from .device import cpp_standalone_device, build, Network, run, reinit, stop
8 changes: 8 additions & 0 deletions brian2/devices/cpp_standalone/brianlib/common_math.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef _BRIAN_COMMON_MATH_H
#define _BRIAN_COMMON_MATH_H

#include<limits>

#define inf (std::numeric_limits<double>::infinity())

#endif
3 changes: 2 additions & 1 deletion brian2/devices/cpp_standalone/codeobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from brian2.codegen.codeobject import CodeObject
from brian2.codegen.templates import Templater
from brian2.codegen.languages.cpp_lang import CPPLanguage
from brian2.devices.device import get_device

__all__ = ['CPPStandaloneCodeObject']

Expand All @@ -26,4 +27,4 @@ def variables_to_namespace(self):
self.namespace[varname] = var.get_value()

def run(self):
raise RuntimeError("Cannot run in C++ standalone mode")
get_device().main_queue.append(('run_code_object', (self,)))
166 changes: 120 additions & 46 deletions brian2/devices/cpp_standalone/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,24 @@
from collections import defaultdict

from brian2.core.clocks import defaultclock
from brian2.core.network import Network as OrigNetwork
from brian2.core.namespace import get_local_namespace
from brian2.devices.device import Device, all_devices
from brian2.core.preferences import brian_prefs
from brian2.core.variables import *
from brian2.utils.filetools import copy_directory
from brian2.utils.stringtools import word_substitute
from brian2.codegen.languages.cpp_lang import c_data_type
from brian2.codegen.codeobject import CodeObjectUpdater
from brian2.units.fundamentalunits import (Quantity, Unit, is_scalar_type,
fail_for_dimension_mismatch,
have_same_dimensions,
)
from brian2.units import second

from .codeobject import CPPStandaloneCodeObject

__all__ = ['build']
__all__ = ['build', 'Network', 'run', 'reinit', 'stop']

def freeze(code, ns):
# this is a bit of a hack, it should be passed to the template somehow
Expand All @@ -23,16 +30,47 @@ def freeze(code, ns):
code = word_substitute(code, {k: repr(v)})
return code

class StandaloneVariableView():
class StandaloneVariableView(object):
'''
Will store information about how the variable was set in the original
`ArrayVariable` object.
'''
def __init__(self, variable):
def __init__(self, name, variable, group, template,
unit=None, level=0):
self.name = name
self.variable = variable
self.group = group
self.template = template
self.unit = unit
self.level = level

def __setitem__(self, key, value):
self.variable.assignments.append((key, value))
# def __setitem__(self, key, value):
# variable = self.variable
# self.variable.assignments.append((key, value))
def __setitem__(self, i, value):
variable = self.variable
if variable.scalar:
if not (i == slice(None) or i == 0 or (hasattr(i, '__len__') and len(i) == 0)):
raise IndexError('Variable is a scalar variable.')
indices = np.array([0])
else:
indices = self.group.indices[self.group.variable_indices[self.name]][i]
try:
iter(value)
sequence = True
except TypeError:
sequence = False
if not isinstance(value, basestring) and not sequence:
if not self.unit is None:
fail_for_dimension_mismatch(value, self.unit)
value = repr(value)
if isinstance(value, basestring):
check_units = self.unit is not None
self.group._set_with_code(variable, indices, value,
template=self.template,
check_units=check_units, level=self.level + 1)
else:
raise NotImplementedError("Setting variables with sequences not supported on devices.")

def __getitem__(self, item):
raise NotImplementedError()
Expand Down Expand Up @@ -60,11 +98,23 @@ def set_value(self, value, index=None):
index = slice(None)
self.assignments.append((index, value))

# def get_addressable_value(self, group, level=0):
# return StandaloneVariableView(self)
#
# def get_addressable_value_with_unit(self, group, level=0):
# return StandaloneVariableView(self)

def get_addressable_value(self, group, level=0):
return StandaloneVariableView(self)
template = getattr(group, '_set_with_code_template',
'group_variable_set')
return StandaloneVariableView(self.name, self, group, template=template,
unit=None, level=level)

def get_addressable_value_with_unit(self, group, level=0):
return StandaloneVariableView(self)
template = getattr(group, '_set_with_code_template',
'group_variable_set')
return StandaloneVariableView(self.name, self, group, template=template,
unit=self.unit, level=level)


class StandaloneDynamicArrayVariable(StandaloneArrayVariable):
Expand All @@ -80,6 +130,7 @@ def __init__(self):
self.array_specs = []
self.dynamic_array_specs = []
self.code_objects = {}
self.main_queue = []

def array(self, owner, name, size, unit, dtype=None, constant=False,
is_bool=False):
Expand Down Expand Up @@ -126,22 +177,7 @@ def code_object(self, owner, name, abstract_code, namespace, variables, template
self.code_objects[codeobj.name] = codeobj
return codeobj

def build(self, net):
# Extract all the CodeObjects
# Note that since we ran the Network object, these CodeObjects will be sorted into the right
# running order, assuming that there is only one clock
updaters = []
for obj in net.objects:
for updater in obj.updaters:
updaters.append(updater)

# Extract the arrays information
vars = {}
for obj in net.objects:
if hasattr(obj, 'variables'):
for k, v in obj.variables.iteritems():
vars[(obj, k)] = v

def build(self):
if not os.path.exists('output'):
os.mkdir('output')

Expand All @@ -151,6 +187,39 @@ def build(self, net):
open('output/arrays.cpp', 'w').write(arr_tmp.cpp_file)
open('output/arrays.h', 'w').write(arr_tmp.h_file)

main_lines = []
for func, args in self.main_queue:
if func=='run_code_object':
codeobj, = args
main_lines.append('_run_%s(t);' % codeobj.name)
elif func=='run_network':
net, duration, namespace = args
net._prepare_for_device(namespace)
# Extract all the CodeObjects
# Note that since we ran the Network object, these CodeObjects will be sorted into the right
# running order, assuming that there is only one clock
updaters = []
for obj in net.objects:
for updater in obj.updaters:
updaters.append(updater)

# Generate the updaters
run_lines = []
for updater in updaters:
cls = updater.__class__
if cls is CodeObjectUpdater:
codeobj = updater.owner
run_lines.append('_run_%s(t);' % codeobj.name)
else:
raise NotImplementedError("C++ standalone device has not implemented "+cls.__name__)

# Generate the main lines
num_steps = int(duration/defaultclock.dt)
netcode = CPPStandaloneCodeObject.templater.network(None, run_lines=run_lines, num_steps=num_steps)
main_lines.extend(netcode.split('\n'))
else:
raise NotImplementedError("Unknown main queue function type "+func)

# Generate data for non-constant values
code_object_defs = defaultdict(list)
for codeobj in self.code_objects.itervalues():
Expand All @@ -169,33 +238,24 @@ def build(self, net):
'&(_dynamic{arrayname}[0]);').format(c_type=c_type,
arrayname=v.arrayname)
code_object_defs[codeobj.name].append(code)

# Generate the updaters
run_lines = []
for updater in updaters:
cls = updater.__class__
if cls is CodeObjectUpdater:
codeobj = updater.owner
ns = codeobj.namespace
# TODO: fix these freeze/CONSTANTS hacks somehow - they work but not elegant.
code = freeze(codeobj.code.cpp_file, ns)
code = code.replace('%CONSTANTS%', '\n'.join(code_object_defs[codeobj.name]))
code = '#include "arrays.h"\n'+code

open('output/'+codeobj.name+'.cpp', 'w').write(code)
open('output/'+codeobj.name+'.h', 'w').write(codeobj.code.h_file)

run_lines.append('_run_%s(t);' % codeobj.name)
else:
raise NotImplementedError("C++ standalone device has not implemented "+cls.__name__)

# Generate the code objects
for codeobj in self.code_objects.itervalues():
ns = codeobj.namespace
# TODO: fix these freeze/CONSTANTS hacks somehow - they work but not elegant.
code = freeze(codeobj.code.cpp_file, ns)
code = code.replace('%CONSTANTS%', '\n'.join(code_object_defs[codeobj.name]))
code = '#include "arrays.h"\n'+code

open('output/'+codeobj.name+'.cpp', 'w').write(code)
open('output/'+codeobj.name+'.h', 'w').write(codeobj.code.h_file)

# The code_objects are passed in the right order to run them because they were
# sorted by the Network object. To support multiple clocks we'll need to be
# smarter about that.
main_tmp = CPPStandaloneCodeObject.templater.main(None,
run_lines=run_lines,
main_lines=main_lines,
code_objects=self.code_objects.values(),
num_steps=1000,
dt=float(defaultclock.dt),
)
open('output/main.cpp', 'w').write(main_tmp)
Expand All @@ -210,6 +270,20 @@ def build(self, net):

all_devices['cpp_standalone'] = cpp_standalone_device

def build(net):
cpp_standalone_device.build(net)
build = cpp_standalone_device.build


class Network(OrigNetwork):
def run(self, duration, report=None, report_period=60*second,
namespace=None, level=0):
if namespace is None:
namespace = get_local_namespace(1 + level)
cpp_standalone_device.main_queue.append(('run_network', (self, duration, namespace)))

def _prepare_for_device(self, namespace):
OrigNetwork.run(self, 0*second, namespace=namespace)

def run(*args, **kwds):
raise NotImplementedError("Magic networks not implemented for C++ standalone")
stop = run
reinit = run
59 changes: 59 additions & 0 deletions brian2/devices/cpp_standalone/templates/group_variable_set.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
////////////////////////////////////////////////////////////////////////////
//// MAIN CODE /////////////////////////////////////////////////////////////

{% macro cpp_file() %}

#include "{{codeobj_name}}.h"
#include<math.h>
#include "brianlib/common_math.h"
#include<stdint.h>

////// SUPPORT CODE ///////
namespace {
{% for line in support_code_lines %}
{{line}}
{% endfor %}
}

////// HASH DEFINES ///////
{% for line in hashdefine_lines %}
{{line}}
{% endfor %}

void _run_{{codeobj_name}}(double t)
{
///// CONSTANTS ///////////
%CONSTANTS%
///// POINTERS ////////////
{% for line in pointers_lines %}
{{line}}
{% endfor %}

//// MAIN CODE ////////////
// TODO: this hack only works when writing G.V = str, not e.g. G.v[str] = str.
const int _num_group_idx = _num_idx;
for(int _idx_group_idx=0; _idx_group_idx<_num_group_idx; _idx_group_idx++)
{
//const int _idx = _group_idx[_idx_group_idx];
const int _idx = _idx_group_idx;
const int _vectorisation_idx = _idx;
{% for line in code_lines %}
{{line}}
{% endfor %}
}
}
{% endmacro %}

////////////////////////////////////////////////////////////////////////////
//// HEADER FILE ///////////////////////////////////////////////////////////

{% macro h_file() %}
#ifndef _INCLUDED_{{codeobj_name}}
#define _INCLUDED_{{codeobj_name}}

#include "arrays.h"

void _run_{{codeobj_name}}(double t);

#endif
{% endmacro %}
15 changes: 8 additions & 7 deletions brian2/devices/cpp_standalone/templates/main.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "arrays.h"
#include<ctime>

{% for codeobj in code_objects %}
#include "{{codeobj.name}}.h"
Expand All @@ -9,16 +10,16 @@ using namespace std;

int main(void)
{
clock_t start = clock();
_init_arrays();
const double dt = {{dt}};
for(int i=0; i<{{num_steps}}; i++)
{
double t = i*dt;
{% for run_line in run_lines %}
{{run_line}}
{% endfor %}
}
double t = 0.0;
{% for main_line in main_lines %}
{{ main_line }}
{% endfor %}
cout << "Num spikes: " << _dynamic_array_spikemonitor__i.size() << endl;
double duration = (clock()-start)/(double)CLOCKS_PER_SEC;
cout << "Time: " << duration << endl;
_dealloc_arrays();
return 0;
}
7 changes: 7 additions & 0 deletions brian2/devices/cpp_standalone/templates/network.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
for(int i=0; i<{{num_steps}}; i++)
{
t = i*dt;
{% for run_line in run_lines %}
{{run_line}}
{% endfor %}
}
1 change: 1 addition & 0 deletions brian2/devices/cpp_standalone/templates/reset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "{{codeobj_name}}.h"
#include<math.h>
#include<stdint.h>
#include "brianlib/common_math.h"

////// SUPPORT CODE ///////
namespace {
Expand Down
1 change: 1 addition & 0 deletions brian2/devices/cpp_standalone/templates/spikemonitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "{{codeobj_name}}.h"
#include<math.h>
#include<stdint.h>
#include "brianlib/common_math.h"

////// SUPPORT CODE ///////
namespace {
Expand Down
1 change: 1 addition & 0 deletions brian2/devices/cpp_standalone/templates/stateupdate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "{{codeobj_name}}.h"
#include<math.h>
#include<stdint.h>
#include "brianlib/common_math.h"

////// SUPPORT CODE ///////
namespace {
Expand Down
Loading

0 comments on commit 8a81ce8

Please sign in to comment.