Support refine for multiple variables
mstimberg committed Oct 12, 2020
1 parent 733878b commit aaefa48
Showing 2 changed files with 51 additions and 32 deletions.
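In brief: `refine` previously rejected more than one output variable. The changes below stack the residuals of every output variable, each optionally scaled by a per-variable weight, into the single flat vector that lmfit minimizes, and extend gradient recording to all output variables. A toy NumPy sketch of that stacking (shapes, names and weights are made up, not code from the repository):

import numpy as np

rng = np.random.default_rng(0)
# Two output variables ('v' and 'm'), 3 traces, 100 samples each (toy stand-ins)
sim = {'v': rng.normal(size=(3, 100)), 'm': rng.normal(size=(3, 100))}
target = {'v': rng.normal(size=(3, 100)), 'm': rng.normal(size=(3, 100))}
weights = {'v': 1.0, 'm': 0.5}  # hypothetical per-variable weights

one_residual = []
for var in ['v', 'm']:
    # lmfit minimizes sum(residual**2), so scaling a residual by sqrt(w)
    # weights that variable's squared error by w
    one_residual.append(np.sqrt(weights[var]) * (sim[var] - target[var]))
residual = np.array(one_residual).flatten()
print(residual.shape)  # (600,) -- one flat vector combining both variables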
59 changes: 38 additions & 21 deletions brian2modelfitting/fitter.py
@@ -348,6 +348,7 @@ def __init__(self, dt, model, input, output, input_var, output_var,
         self._best_error = None
         self.optimizer = None
         self.metric = None
+        self.metric_weights = None
         if not param_init:
             param_init = {}
         for param, val in param_init.items():
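The new `metric_weights` attribute stores optional per-output-variable weights used by `refine` further down; left as `None`, all output variables are weighted equally. A toy illustration of that fallback (variable names are examples):

from numpy import ones

metric_weights = None            # as initialized above
output_var = ['v', 'm']          # example output variables
if metric_weights is not None:
    weights = metric_weights
else:
    weights = ones(len(output_var))
print(weights)                   # [1. 1.] -> both variables count equally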
@@ -389,11 +390,9 @@ def setup_simulator(self, network_name, n_neurons, output_var, param_init,

         record_vars = [v for v in output_var if v != 'spikes']
         if calc_gradient:
-            if not len(output_var) == 1:
-                raise AssertionError('Cannot calculate gradient with multiple '
-                                     'output variables.')
-            record_vars.extend([f'S_{output_var[0]}_{p}'
-                                for p in self.parameter_names])
+            record_vars.extend([f'S_{out_var}_{p}'
+                                for p in self.parameter_names
+                                for out_var in self.output_var])
         if len(record_vars):
             network.add(StateMonitor(neurons, record_vars, record=True,
                                      name='statemonitor', dt=self.dt))
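With the single-output assertion gone, the monitor records one sensitivity variable per combination of output variable and parameter, named `S_<output_var>_<parameter>`. A small sketch of the resulting names (the variable and parameter names are hypothetical examples):

# Mirrors the list comprehension above; 'v'/'m' and 'gl'/'g_na' are examples
output_var = ['v', 'm']
parameter_names = ['gl', 'g_na']
record_vars = [v for v in output_var if v != 'spikes']
record_vars.extend(f'S_{out_var}_{p}'
                   for p in parameter_names
                   for out_var in output_var)
print(record_vars)  # ['v', 'm', 'S_v_gl', 'S_m_gl', 'S_v_g_na', 'S_m_g_na']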
@@ -1028,8 +1027,6 @@ def refine(self, params=None, t_start=None, t_weights=None, normalization=None,
             import lmfit
         except ImportError:
             raise ImportError('Refinement needs the "lmfit" package.')
-        if len(self.output_var) > 1:
-            raise NotImplementedError('refine currently requires a single output variable.')
         if params is None:
             if self.best_params is None:
                 raise TypeError('You need to either specify parameters or run '
@@ -1091,23 +1088,40 @@ def _calc_error(params):
             self.simulator.run(self.duration, param_dic,
                                self.parameter_names, iteration=iteration,
                                name='refine')
-            trace = getattr(self.simulator.statemonitor, self.output_var[0]+'_')
-            if t_weights is None:
-                residual = trace[:, t_start_steps:] - self.output_[0][:, t_start_steps:]
+            one_residual = []
+            if self.metric_weights is not None:
+                metric_weights = self.metric_weights
             else:
-                residual = (trace - self.output_[0]) * sqrt(t_weights)
-            return residual.flatten() * normalization
+                metric_weights = ones(len(self.output_var))
+            for out_var, out, metric_weight in zip(self.output_var,
+                                                   self.output_,
+                                                   metric_weights):
+                trace = getattr(self.simulator.statemonitor, out_var+'_')
+                if t_weights is None:
+                    residual = trace[:, t_start_steps:] - out[:, t_start_steps:]
+                else:
+                    residual = (trace - out) * sqrt(t_weights)
+                one_residual.append(sqrt(metric_weight)*residual)
+            return array(one_residual).flatten() * normalization

         def _calc_gradient(params):
             residuals = []
+            if self.metric_weights is not None:
+                metric_weights = self.metric_weights
+            else:
+                metric_weights = ones(len(self.output_var))
             for name in self.parameter_names:
-                trace = getattr(self.simulator.statemonitor,
-                                f'S_{self.output_var[0]}_{name}_')
-                if t_weights is None:
-                    residual = trace[:, t_start_steps:]
-                else:
-                    residual = trace * sqrt(t_weights)
-                residuals.append(residual.flatten() * normalization)
+                one_residual = []
+                for out_var, metric_weight in zip(self.output_var,
+                                                  metric_weights):
+                    trace = getattr(self.simulator.statemonitor,
+                                    f'S_{out_var}_{name}_')
+                    if t_weights is None:
+                        residual = trace[:, t_start_steps:]
+                    else:
+                        residual = trace * sqrt(t_weights)
+                    one_residual.append(sqrt(metric_weight)*residual)
+                residuals.append(array(one_residual).flatten() * normalization)
             gradient = array(residuals)
             return gradient.T
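For each parameter, `_calc_gradient` now concatenates the weighted sensitivity traces of all output variables, so every row of `residuals` lines up entry-for-entry with the error vector from `_calc_error`; transposing then gives one row per residual entry and one column per parameter, the Jacobian layout least-squares solvers expect. A shape-only NumPy sketch with toy sizes:

import numpy as np

n_params, n_out, n_traces, n_t = 4, 2, 3, 100   # toy sizes
weights = np.ones(n_out)                        # equal-weighting fallback
residuals = []
for _ in range(n_params):
    one_residual = []
    for w in weights:
        sens = np.zeros((n_traces, n_t))        # stand-in for an S_<var>_<param> trace
        one_residual.append(np.sqrt(w) * sens)
    residuals.append(np.array(one_residual).flatten())
gradient = np.array(residuals)
print(gradient.T.shape)  # (600, 4): rows match residual entries, columns parameters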

@@ -1117,8 +1131,11 @@ def _callback_wrapper(params, iter, resid, *args, **kwds):
             error = mean(resid**2)
             errors.append(error)
             if self.use_units:
-                error_dim = self.output_dim[0]**2 * get_dimensions(normalization)**2
-                all_errors = Quantity(errors, dim=error_dim)
+                if len(self.output_var) == 1:
+                    error_dim = self.output_dim[0]**2 * get_dimensions(normalization)**2
+                    all_errors = Quantity(errors, dim=error_dim)
+                else:
+                    all_errors = array(errors)
                 params = {p: Quantity(val, dim=self.model[p].dim)
                           for p, val in params.items()}
             else:
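The new branch exists because a combined error only has a well-defined physical dimension for a single output variable; with several, the stacked residual mixes dimensions (volts and a dimensionless gating variable in the example below), so no single `Quantity` can represent it and the errors stay a plain array. A sketch of the single-variable case, assuming a dimensionless `normalization`:

from brian2 import Quantity, get_dimensions, mV

# Single output variable 'v': mean squared residuals carry volt**2
error_dim = get_dimensions(mV)**2
all_errors = Quantity([1e-6, 5e-7], dim=error_dim)
print(all_errors)  # carries dimension V^2; impossible if dimensionless 'm' errors were mixed in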
24 changes: 13 additions & 11 deletions examples/multiobjective.py
@@ -2,10 +2,9 @@
 import numpy as np
 from brian2 import *
 from brian2modelfitting import *

+# set_device('cpp_standalone')  # recommend for speed
 dt = 0.01*ms
 defaultclock.dt = dt
-prefs.codegen.target = 'numpy'

 # Generate ground truth data
 area = 20000*umetre**2
@@ -50,7 +49,7 @@
 fitter = TraceFitter(model=eqs, input_var='I', output_var=['v', 'm'],
                      input=inp_ar.T, output=[ground_truth_v,
                                              ground_truth_m],
-                     dt=dt, n_samples=30, param_init={'v': 'VT'},
+                     dt=dt, n_samples=60, param_init={'v': 'El'},
                      method='exponential_euler')

 res, error = fitter.fit(n_rounds=20,
@@ -63,18 +62,21 @@
                         g_kd=[6e-07*siemens, 6e-05*siemens],
                         Cm=[0.1*ufarad*cm**-2 * area, 2*ufarad*cm**-2 * area])

-## Show results
-all_output = fitter.results(format='dataframe')
-print(all_output)
+refined_params, _ = fitter.refine(calc_gradient=True)

 ## Visualization of the results
 start_scope()
 fits = fitter.generate_traces(params=None, param_init={'v': -65*mV})
+refined_fits = fitter.generate_traces(params=refined_params, param_init={'v': -65*mV})

 fig, ax = plt.subplots(2, ncols=5, figsize=(20, 5), sharex=True, sharey='row')
 for idx in range(5):
-    ax[0][idx].plot(ground_truth_v[idx]/mV)
-    ax[0][idx].plot(fits['v'][idx].transpose()/mV)
-    ax[1][idx].plot(ground_truth_m[idx])
-    ax[1][idx].plot(fits['m'][idx].transpose())
+    ax[0][idx].plot(ground_truth_v[idx]/mV, 'k:', alpha=0.75,
+                    label='ground truth')
+    ax[0][idx].plot(fits['v'][idx].transpose()/mV, alpha=0.75, label='fit')
+    ax[0][idx].plot(refined_fits['v'][idx].transpose() / mV, alpha=0.75,
+                    label='refined')
+    ax[1][idx].plot(ground_truth_m[idx], 'k:', alpha=0.75)
+    ax[1][idx].plot(fits['m'][idx].transpose(), alpha=0.75)
+    ax[1][idx].plot(refined_fits['m'][idx].transpose(), alpha=0.75)
+ax[0][0].legend(loc='best')
 plt.show()
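A quick way to quantify what the plot shows, if appended to the script above (a hypothetical check, not part of the commit; it assumes `fits['v']` and `ground_truth_v` have matching shapes): per-variable mean squared error for the initial and refined fits.

# Hypothetical post-hoc check; relies on names defined in the script above
for label, f in [('fit', fits), ('refined', refined_fits)]:
    mse_v = np.mean((np.asarray(f['v'] / mV) - np.asarray(ground_truth_v / mV))**2)
    mse_m = np.mean((np.asarray(f['m']) - np.asarray(ground_truth_m))**2)
    print(f'{label}: MSE(v) = {mse_v:.3g} mV^2, MSE(m) = {mse_m:.3g}')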
