Commit

WIP: refactor simulation, make refine+generate work with standalone
mstimberg committed Mar 9, 2020
1 parent 561019b commit 648735f
Showing 2 changed files with 79 additions and 107 deletions.
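For orientation, a minimal sketch of the end-to-end workflow this commit is working toward: global fitting, gradient-based refinement and trace generation in a single C++ standalone script, without resetting the device or starting a new script between steps. It assumes the public TraceFitter API (fit, refine, generate) plus the NevergradOptimizer and MSEMetric helpers from brian2modelfitting; the model equations, toy data, parameter bounds and method choice are made up for illustration and are not taken from the diff.

# Sketch only: hypothetical model and data, API as documented for brian2modelfitting
import numpy as np
from brian2 import set_device, nA, mV, nS, nF, ms
from brian2modelfitting import TraceFitter, NevergradOptimizer, MSEMetric

set_device('cpp_standalone')  # the mode this commit targets

# Toy data: 5 input current traces and matching voltage responses
inp_traces = np.random.randn(5, 1000)
out_traces = np.random.randn(5, 1000)

model = '''
g : siemens (constant)
dv/dt = (g*(-70*mV - v) + I)/(10*nF) : volt
'''

fitter = TraceFitter(model=model, input_var='I', output_var='v',
                     input=inp_traces*nA, output=out_traces*mV, dt=0.1*ms,
                     n_samples=30, method='exponential_euler',
                     param_init={'v': -70*mV})

# Global optimization, then gradient-based refinement and trace generation.
# Before this refactor, standalone mode required a fresh script for each step.
best, error = fitter.fit(n_rounds=2, optimizer=NevergradOptimizer(),
                         metric=MSEMetric(), g=[1*nS, 100*nS])
refined, info = fitter.refine()
traces = fitter.generate(output_var='v')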
183 changes: 78 additions & 105 deletions brian2modelfitting/fitter.py
@@ -24,6 +24,7 @@

logger = get_logger(__name__)


def get_param_dic(params, param_names, n_traces, n_samples):
"""Transform parameters into a dictionary of appropiate size"""
params = array(params)
@@ -77,7 +78,10 @@ def setup_fit():
'CPPStandaloneDevice': CPPStandaloneSimulator(),
'RuntimeDevice': RuntimeSimulator()
}

if isinstance(get_device(), CPPStandaloneDevice):
if device.has_been_run is True:
get_device().reinit()
get_device().activate()
return simulators[get_device().__class__.__name__]


@@ -259,10 +263,6 @@ def __init__(self, dt, model, input, output, input_var, output_var,
n_samples, threshold, reset, refractory, method, param_init):
"""Initialize the fitter."""

if isinstance(get_device(), CPPStandaloneDevice):
if device.has_been_run is True:
raise Exception("To run another fitter in standalone mode you "
"need to create new script")
if dt is None:
raise ValueError("dt-sampling frequency of the input must be set")

@@ -271,10 +271,9 @@ def __init__(self, dt, model, input, output, input_var, output_var,
if input_var not in model.identifiers:
raise NameError("%s is not an identifier in the model" % input_var)

defaultclock.dt = dt
self.dt = dt

self.simulator = setup_fit()
self.simulator = None

self.parameter_names = model.parameter_names
self.n_traces, n_steps = input.shape
@@ -314,6 +313,35 @@ def __init__(self, dt, model, input, output, input_var, output_var,
"parameter in the model" % param)
self.param_init = param_init

def setup_simulator(self, network_name, n_neurons, output_var, param_init,
calc_gradient=False, optimize=True, level=1):
simulator = setup_fit()

namespace = get_full_namespace({'input_var': self.input_traces,
'n_traces': self.n_traces},
level=level+1)
if network_name != 'generate':
namespace['output_var'] = TimedArray(self.output.transpose(),
dt=self.dt)
neurons = self.setup_neuron_group(n_neurons, namespace,
calc_gradient=calc_gradient,
optimize=optimize)

if output_var == 'spikes':
monitor = SpikeMonitor(neurons, name='monitor')
else:
record_vars = [output_var]
if calc_gradient:
record_vars.extend([f'S_{output_var}_{p}'
for p in self.parameter_names])
monitor = StateMonitor(neurons, record_vars, record=True,
name='monitor', dt=self.dt)

network = Network(neurons, monitor)

simulator.initialize(network, param_init, name=network_name)
return simulator

def setup_neuron_group(self, n_neurons, namespace, calc_gradient=False,
optimize=True, name='neurons'):
"""
@@ -336,7 +364,7 @@ def setup_neuron_group(self, n_neurons, namespace, calc_gradient=False,
neurons = NeuronGroup(n_neurons, self.model, method=self.method,
threshold=self.threshold, reset=self.reset,
refractory=self.refractory, name=name,
namespace=namespace)
namespace=namespace, dt=self.dt)
if calc_gradient:
sensitivity_eqs = get_sensitivity_equations(neurons,
parameters=self.parameter_names,
@@ -346,7 +374,7 @@
method=self.method,
threshold=self.threshold, reset=self.reset,
refractory=self.refractory, name=name,
namespace=namespace)
namespace=namespace, dt=self.dt)
return neurons

@abc.abstractmethod
@@ -382,6 +410,7 @@ def optimization_iter(self, optimizer, metric):
d_param = get_param_dic(parameters, self.parameter_names,
self.n_traces, self.n_samples)
self.simulator.run(self.duration, d_param, self.parameter_names)

errors = self.calc_errors(metric)

optimizer.tell(parameters, errors)
@@ -449,6 +478,15 @@ def fit(self, optimizer, metric=None, n_rounds=1, callback='text',

callback = callback_setup(callback, n_rounds)

# Check whether we can reuse the current simulator or whether we have
# to create a new one (only relevant for standalone, but does not hurt
# for runtime)
if self.simulator is None or self.simulator.current_net != 'fit':
self.simulator = self.setup_simulator('fit', self.n_neurons,
output_var=self.output_var,
param_init=self.param_init,
level=2)

# Run Optimization Loop
error = None
for index in range(n_rounds):
@@ -538,52 +576,28 @@ def generate(self, params=None, output_var=None, param_init=None, level=0):
"""
if params is None:
params = self.best_params

needs_device_reset = False
if isinstance(get_device(), CPPStandaloneDevice):
set_device('runtime')
simulator = RuntimeSimulator()
needs_device_reset = True
else:
simulator = self.simulator

defaultclock.dt = self.dt
Ntraces, Nsteps = self.input.shape

# Setup NeuronGroup
namespace = get_full_namespace({'input_var': self.input_traces,
'n_traces': Ntraces,
'output_var': output_var},
level=level+1)

self.neurons = self.setup_neuron_group(Ntraces, namespace,
name='neurons')

neurons = self.setup_neuron_group(Ntraces, namespace, name='neurons')
neurons.namespace['input_var'] = self.input_traces
neurons.namespace['n_traces'] = Ntraces
neurons.namespace['output_var'] = output_var
if output_var == 'spikes':
monitor = SpikeMonitor(neurons, record=True, name='monitor')
else:
monitor = StateMonitor(neurons, output_var, record=True, name='monitor')
network = Network(neurons, monitor)

if param_init:
simulator.initialize(network, param_init, name='generate')
if param_init is None:
param_init = self.param_init
else:
simulator.initialize(network, self.param_init, name='generate')

simulator.run(self.duration, params, self.parameter_names, name='generate')
param_init = dict(self.param_init)
self.param_init.update(param_init)
if output_var is None:
output_var = self.output_var

self.simulator = self.setup_simulator('generate', self.n_traces,
output_var=output_var,
param_init=param_init,
level=level+1)
param_dic = get_param_dic([params[p] for p in self.parameter_names],
self.parameter_names, self.n_traces, 1)
self.simulator.run(self.duration, param_dic, self.parameter_names,
name='generate')

if output_var == 'spikes':
fits = get_spikes(simulator.monitor,
fits = get_spikes(self.simulator.monitor,
1, self.n_traces)[0] # a single "sample"
else:
fits = getattr(simulator.monitor, output_var)

if needs_device_reset:
reset_device()
fits = getattr(self.simulator.monitor, output_var)

return fits

@@ -592,35 +606,20 @@ class TraceFitter(Fitter):
"""Input nad output have to have the same dimensions."""
def __init__(self, model, input_var, input, output_var, output, dt,
n_samples=30, method=None, reset=None, refractory=False,
threshold=None, level=0, param_init=None):
threshold=None, param_init=None):
"""Initialize the fitter."""
super().__init__(dt, model, input, output, input_var, output_var,
n_samples, threshold, reset, refractory, method,
param_init)
# We store the bounds set in TraceFitter.fit, so that TraceFitter.refine
# can reuse them
self.bounds = None

if output_var not in self.model.names:
raise NameError("%s is not a model variable" % output_var)
if output.shape != input.shape:
raise ValueError("Input and output must have the same size")

output_traces = TimedArray(output.transpose(), dt=dt)

# Setup NeuronGroup
namespace = get_full_namespace({'input_var': self.input_traces,
'n_traces': self.n_traces,
'output_var': output_traces},
level=level+1)
neurons = self.setup_neuron_group(self.n_neurons, namespace)

monitor = StateMonitor(neurons, output_var, record=True,
name='monitor')
network = Network(neurons, monitor)

self.simulator.initialize(network, self.param_init)

def calc_errors(self, metric):
"""
Returns errors after simulation with StateMonitor.
@@ -760,51 +759,28 @@ def refine(self, params=None, t_start=None, normalization=None,
parameters.add(param_name, value=array(params[param_name]),
min=array(min_bound), max=array(max_bound))

needs_device_reset = False
if isinstance(get_device(), CPPStandaloneDevice):
set_device('runtime')
simulator = RuntimeSimulator()
needs_device_reset = True
else:
simulator = self.simulator

namespace = get_full_namespace({'input_var': self.input_traces,
'n_traces': self.n_traces,
'output_var': self.output_var},
level=level+1)
neurons = self.setup_neuron_group(self.n_traces, namespace,
calc_gradient=calc_gradient,
optimize=optimize,
name='neurons')
monitored_variables = [self.output_var]
param_init = dict(self.param_init)
if calc_gradient:
monitored_variables += [f'S_{self.output_var}_{p}'
for p in self.parameter_names]
param_init.update(get_sensitivity_init(neurons,
self.parameter_names,
param_init))
monitor = StateMonitor(neurons, monitored_variables, record=True,
name='monitor')
network = Network(neurons, monitor)

simulator.initialize(network, param_init, name='refine')
self.simulator = self.setup_simulator('refine', self.n_traces,
output_var=self.output_var,
param_init=self.param_init,
calc_gradient=calc_gradient,
optimize=optimize,
level=level+1)

t_start_steps = int(round(t_start / self.dt))

def _calc_error(params):
simulator.run(self.duration, {p: float(val)
for p, val in params.items()},
self.parameter_names, name='refine')
trace = getattr(simulator.networks['refine']['monitor'],
self.output_var+'_')
param_dic = get_param_dic([params[p] for p in self.parameter_names],
self.parameter_names, self.n_traces, 1)
self.simulator.run(self.duration, param_dic,
self.parameter_names, name='refine')
trace = getattr(self.simulator.monitor, self.output_var+'_')
residual = trace[:, t_start_steps:] - self.output[:, t_start_steps:]
return residual.flatten() * normalization

def _calc_gradient(params):
residuals = []
for name in self.parameter_names:
trace = getattr(simulator.networks['refine']['monitor'],
trace = getattr(self.simulator.monitor,
f'S_{self.output_var}_{name}_')
residual = trace[:, t_start_steps:]
residuals.append(residual.flatten() * normalization)
@@ -843,9 +819,6 @@ def _callback_wrapper(params, iter, resid, *args, **kwds):
iter_cb=iter_cb,
**kwds)

if needs_device_reset:
reset_device()

return {p: float(val) for p, val in result.params.items()}, result


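The standalone branch added to setup_fit() above follows Brian 2's documented recipe for running more than one standalone simulation in the same process: once the device has_been_run, reinitialize and reactivate it before building the next project. A minimal, self-contained illustration of that pattern, independent of the fitting code (the toy NeuronGroup is made up):

from brian2 import NeuronGroup, run, set_device, get_device, device, ms

set_device('cpp_standalone')

group = NeuronGroup(1, 'dv/dt = -v/(10*ms) : 1')
run(10*ms)  # builds and executes the first standalone project

# Same check that the commit adds to setup_fit()
if device.has_been_run:
    get_device().reinit()
    # activate() should be given the same options as the original
    # set_device() call if any non-default ones were used
    get_device().activate()

group = NeuronGroup(1, 'dv/dt = -v/(10*ms) : 1')
run(10*ms)  # second project, built and run in the same script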
3 changes: 1 addition & 2 deletions brian2modelfitting/simulator.py
@@ -127,7 +127,6 @@ def __init__(self):
super(CPPStandaloneSimulator, self).__init__()
self.params_init = None


def run(self, duration, params, params_names, name='fit'):
"""
Simulation has to be run in two stages in order to initialize the
@@ -142,7 +141,7 @@ def run(self, duration, params, params_names, name='fit'):
for k, v in self.var_init.items():
self.neurons.__setattr__(k, v)

network.run(duration, namespace={}, report='text')
network.run(duration, namespace={})
else:
set_states(self.params_init, params)
run_again()

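The refine() code in fitter.py delegates the actual minimization to lmfit; the Parameters, iter_cb and result.params calls in the diff are lmfit API. For readers unfamiliar with that library, a self-contained toy example of the same pattern, with a made-up exponential model standing in for the simulated traces:

import numpy as np
from lmfit import Parameters, minimize

# Made-up data: a decaying exponential plus noise
x = np.linspace(0, 10, 200)
data = 3.0 * np.exp(-x / 2.5) + np.random.normal(0, 0.05, x.size)

# Bounded parameters, analogous to the bounds stored by TraceFitter.fit
params = Parameters()
params.add('amp', value=1.0, min=0.0, max=10.0)
params.add('tau', value=1.0, min=0.1, max=10.0)

def residual(params, x, data):
    # refine() instead runs the Brian simulation here and returns
    # (trace - target) * normalization
    model = params['amp'].value * np.exp(-x / params['tau'].value)
    return model - data

def iter_cb(params, iter_num, resid, *args, **kwargs):
    # per-iteration callback, the role played by _callback_wrapper in the diff
    pass

result = minimize(residual, params, args=(x, data),
                  method='least_squares', iter_cb=iter_cb)
best = {name: float(p.value) for name, p in result.params.items()}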