Skip to content

Commit

Permalink
Fix broken OnlineTraceFitter (regression)
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Mar 13, 2020
1 parent cbe7e9a commit 9d6bd58
Showing 1 changed file with 31 additions and 7 deletions.
38 changes: 31 additions & 7 deletions brian2modelfitting/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,8 @@ def __init__(self, dt, model, input, output, input_var, output_var,
self.refractory = refractory

self.input = input
self.output = array(output)
self.output = Quantity(output)
self.output_ = array(output)
self.output_var = output_var
self.model = model

Expand All @@ -319,18 +320,22 @@ def __init__(self, dt, model, input, output, input_var, output_var,
self.param_init = param_init

def setup_simulator(self, network_name, n_neurons, output_var, param_init,
calc_gradient=False, optimize=True, level=1):
calc_gradient=False, optimize=True, online_error=False,
level=1):
simulator = setup_fit()

namespace = get_full_namespace({'input_var': self.input_traces,
'n_traces': self.n_traces},
level=level+1)
if hasattr(self, 't_start'): # OnlineTraceFitter
namespace['t_start'] = self.t_start
if network_name != 'generate':
namespace['output_var'] = TimedArray(self.output.transpose(),
dt=self.dt)
neurons = self.setup_neuron_group(n_neurons, namespace,
calc_gradient=calc_gradient,
optimize=optimize)
optimize=optimize,
online_error=online_error)

if output_var == 'spikes':
monitor = SpikeMonitor(neurons, name='monitor')
Expand All @@ -352,7 +357,7 @@ def setup_simulator(self, network_name, n_neurons, output_var, param_init,
return simulator

def setup_neuron_group(self, n_neurons, namespace, calc_gradient=False,
optimize=True, name='neurons'):
optimize=True, online_error=False, name='neurons'):
"""
Setup neuron group, initialize required number of neurons, create
namespace and initialize the parameters.
Expand Down Expand Up @@ -388,6 +393,11 @@ def setup_neuron_group(self, n_neurons, namespace, calc_gradient=False,
threshold=self.threshold, reset=self.reset,
refractory=self.refractory, name=name,
namespace=namespace, dt=self.dt, **kwds)
if online_error:
neurons.run_regularly('total_error += (' + self.output_var +
'-output_var(t,i % n_traces))**2 * int(t>=t_start)',
when='end')

return neurons

@abc.abstractmethod
Expand Down Expand Up @@ -433,7 +443,7 @@ def optimization_iter(self, optimizer, metric):
return results, parameters, errors

def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
restart=False, level=0, **params):
restart=False, online_error=False, level=0, **params):
"""
Run the optimization algorithm for given amount of rounds with given
number of samples drawn. Return best set of parameters and
Expand All @@ -457,6 +467,9 @@ def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
restart: bool
Flag that reinitializes the Fitter to reset the optimization.
With restart True user is allowed to change optimizer/metric.
online_error: bool, optional
Whether to calculate the squared error between target trace and
simulated trace online. Defaults to ``False``.
**params
bounds for each parameter
level : `int`, optional
Expand Down Expand Up @@ -499,6 +512,7 @@ def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
if self.simulator is None or self.simulator.current_net != 'fit':
self.simulator = self.setup_simulator('fit', self.n_neurons,
output_var=self.output_var,
online_error=online_error,
param_init=self.param_init,
level=level+1)

Expand Down Expand Up @@ -654,7 +668,7 @@ def calc_errors(self, metric):
traces = reshape(traces, (traces.shape[0]//self.n_traces,
self.n_traces,
-1))
errors = metric.calc(traces, self.output, self.dt)
errors = metric.calc(traces, self.output_, self.dt)
return errors

def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
Expand Down Expand Up @@ -798,7 +812,7 @@ def _calc_error(params):
self.simulator.run(self.duration, param_dic,
self.parameter_names, name='refine')
trace = getattr(self.simulator.monitor, self.output_var+'_')
residual = trace[:, t_start_steps:] - self.output[:, t_start_steps:]
residual = trace[:, t_start_steps:] - self.output_[:, t_start_steps:]
return residual.flatten() * normalization

def _calc_gradient(params):
Expand Down Expand Up @@ -929,6 +943,16 @@ def __init__(self, model, input_var, input, output_var, output, dt,

self.simulator = None

def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
        restart=False, level=0, **params):
    """
    Run the fitting procedure (see the base class ``fit`` for details),
    always enabling online error calculation, which this fitter requires.

    The stack ``level`` is incremented by one to account for this extra
    wrapper frame when the parent implementation resolves the namespace.
    """
    # Forward everything to the parent, forcing online_error=True.
    forwarded = dict(metric=metric,
                     n_rounds=n_rounds,
                     callback=callback,
                     restart=restart,
                     online_error=True,
                     level=level + 1)
    return super().fit(optimizer, **forwarded, **params)

def calc_errors(self, metric=None):
"""Calculates error in online fashion.To be used inside optim_iter."""
errors = self.simulator.neurons.total_error/int((self.duration-self.t_start)/defaultclock.dt)
Expand Down

0 comments on commit 9d6bd58

Please sign in to comment.