Merge 6d924b2 into 9d6bd58
mstimberg committed Mar 17, 2020
2 parents 9d6bd58 + 6d924b2 commit 55ff8e8
Showing 4 changed files with 301 additions and 70 deletions.
198 changes: 151 additions & 47 deletions brian2modelfitting/fitter.py
@@ -17,7 +17,7 @@
from brian2.devices.cpp_standalone.device import CPPStandaloneDevice
from brian2.core.functions import Function
from .simulator import RuntimeSimulator, CPPStandaloneSimulator
from .metric import Metric, SpikeMetric, TraceMetric
from .metric import Metric, SpikeMetric, TraceMetric, MSEMetric
from .optimizer import Optimizer
from .utils import callback_setup, make_dic

@@ -265,7 +265,8 @@ class Fitter(metaclass=abc.ABCMeta):
Dictionary of variables to be initialized with respective values
"""
def __init__(self, dt, model, input, output, input_var, output_var,
n_samples, threshold, reset, refractory, method, param_init):
n_samples, threshold, reset, refractory, method, param_init,
use_units=True):
"""Initialize the fitter."""

if dt is None:
@@ -295,8 +296,14 @@ def __init__(self, dt, model, input, output, input_var, output_var,
self.output = Quantity(output)
self.output_ = array(output)
self.output_var = output_var
if output_var == 'spikes':
self.output_dim = DIMENSIONLESS
else:
self.output_dim = model[output_var].dim
self.model = model

self.use_units = use_units

input_dim = get_dimensions(input)
input_dim = '1' if input_dim is DIMENSIONLESS else repr(input_dim)
input_eqs = "{} = input_var(t, i % n_traces) : {}".format(input_var,
@@ -307,7 +314,8 @@ def __init__(self, dt, model, input, output, input_var, output_var,
self.input_traces = input_traces

# initialization of attributes used later
self.best_params = None
self._best_params = None
self._best_error = None
self.optimizer = None
self.metric = None
if not param_init:
@@ -470,11 +478,10 @@ def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
online_error: bool, optional
Whether to calculate the squared error between target trace and
simulated trace online. Defaults to ``False``.
level : `int`, optional
How much farther to go down in the stack to find the namespace.
**params
Bounds for each parameter.
level : `int`, optional
How much farther to go down in the stack to find the namespace.
Returns
-------
best_results : dict
@@ -517,24 +524,62 @@ def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
level=level+1)

# Run Optimization Loop
error = None
for index in range(n_rounds):
best_params, parameters, errors = self.optimization_iter(optimizer,
metric)

self._best_error = nanmin(self.optimizer.errors)
# create output variables
self.best_params = make_dic(self.parameter_names, best_params)
error = nanmin(self.optimizer.errors)
param_dicts = [{p: v for p, v in zip(self.parameter_names,
self._best_params = make_dic(self.parameter_names, best_params)
if self.use_units:
if self.output_var == 'spikes':
output_dim = DIMENSIONLESS
else:
output_dim = self.output_dim
# Correct the units for the normalization factor
error_dim = self.metric.get_normalized_dimensions(output_dim)
best_error = Quantity(float(self.best_error), dim=error_dim)
errors = Quantity(errors, dim=error_dim)
param_dicts = [{p: Quantity(v, dim=self.model[p].dim)
for p, v in zip(self.parameter_names,
one_param_set)}
for one_param_set in parameters]

if callback(param_dicts, errors, self.best_params, error, index) is True:
for one_param_set in parameters]
else:
param_dicts = [{p: v for p, v in zip(self.parameter_names,
one_param_set)}
for one_param_set in parameters]
best_error = self.best_error

if callback(param_dicts,
errors,
self.best_params,
best_error,
index) is True:
break

return self.best_params, error
return self.best_params, self.best_error

@property
def best_params(self):
if self._best_params is None:
return None
if self.use_units:
params_with_units = {p: Quantity(v, dim=self.model[p].dim)
for p, v in self._best_params.items()}
return params_with_units
else:
return self._best_params

@property
def best_error(self):
if self._best_error is None:
return None
if self.use_units:
error_dim = self.metric.get_dimensions(self.output_dim)
return Quantity(self._best_error, dim=error_dim)
else:
return self._best_error
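For orientation, a minimal sketch of how these unit-aware values look from the caller's side; the model equations, traces, optimizer, and bounds below are illustrative only, not part of this commit:

    import numpy as np
    from brian2 import Equations, mV, ms, nA, nS, uF
    from brian2modelfitting import TraceFitter, NevergradOptimizer, MSEMetric

    # Toy model and data (hypothetical): gl and C are the fitted constants,
    # I is added by the fitter as the input variable
    model = Equations('''dv/dt = (gl*(-70*mV - v) + I)/C : volt
                         gl : siemens (constant)
                         C : farad (constant)''')
    inp = np.zeros((1, 1000)) * nA   # one input trace
    out = np.zeros((1, 1000)) * mV   # one target trace

    fitter = TraceFitter(model=model, input_var='I', input=inp,
                         output_var='v', output=out, dt=0.1*ms,
                         n_samples=30)  # use_units=True is the default
    params, error = fitter.fit(optimizer=NevergradOptimizer(),
                               metric=MSEMetric(), n_rounds=2,
                               gl=[10*nS, 100*nS], C=[0.1*uF, 1*uF])
    # With use_units=True, the values in params are Quantities (siemens,
    # farad) and error is a Quantity (volt**2 for MSEMetric on a voltage
    # trace); the same values are exposed as fitter.best_params/best_error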

def results(self, format='list'):
def results(self, format='list', use_units=None):
"""
Returns all of the gathered results (parameters and errors),
in one of three formats: 'dataframe', 'list', or 'dict'.
@@ -544,6 +589,10 @@ def results(self, format='list'):
format: str
The desired output format. Currently supported: ``dataframe``,
``list``, or ``dict``.
use_units: bool, optional
Whether to use units in the results. If not specified, defaults to
`.TraceFitter.use_units`, i.e. the value that was specified when
the `.TraceFitter` object was created (``True`` by default).

Returns
-------
@@ -552,40 +601,57 @@ def results(self, format='list'):
'list': list of dictionaries
'dict': dictionary of lists
"""
if use_units is None:
use_units = self.use_units
names = list(self.parameter_names)
names.append('errors')

params = array(self.optimizer.tested_parameters)
params = params.reshape(-1, params.shape[-1])

errors = array([array(self.optimizer.errors).flatten()])
data = concatenate((params, errors.transpose()), axis=1)
if use_units:
error_dim = self.metric.get_dimensions(self.output_dim)
errors = Quantity(array(self.optimizer.errors).flatten(),
dim=error_dim)
else:
errors = array(array(self.optimizer.errors).flatten())

dim = self.model.dimensions

if format == 'list':
res_list = []
for j in arange(0, len(params)):
temp_data = data[j]
temp_data = params[j]
res_dict = dict()

for i, n in enumerate(names[:-1]):
res_dict[n] = Quantity(temp_data[i], dim=dim[n])
res_dict[names[-1]] = temp_data[-1]
for i, n in enumerate(names):
if use_units:
res_dict[n] = Quantity(temp_data[i], dim=dim[n])
else:
res_dict[n] = float(temp_data[i])
res_dict['error'] = errors[j]
res_list.append(res_dict)

return res_list

elif format == 'dict':
res_dict = dict()
for i, n in enumerate(names[:-1]):
res_dict[n] = Quantity(data[:, i], dim=dim[n])
for i, n in enumerate(names):
if use_units:
res_dict[n] = Quantity(params[:, i], dim=dim[n])
else:
res_dict[n] = array(params[:, i])

res_dict[names[-1]] = data[:, -1]
res_dict['error'] = errors
return res_dict

elif format == 'dataframe':
from pandas import DataFrame
return DataFrame(data=data, columns=names)
if use_units:
logger.warn('Results in dataframes do not support units. '
'Specify "use_units=False" to avoid this warning.',
name_suffix='dataframe_units')
data = concatenate((params, array(errors)[None, :].transpose()), axis=1)
return DataFrame(data=data, columns=names + ['error'])
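A short sketch of the three output formats, reusing the hypothetical fitter from the earlier sketch:

    res_list = fitter.results(format='list')  # [{'gl': ..., 'C': ..., 'error': ...}, ...]
    res_dict = fitter.results(format='dict')  # {'gl': [...], 'C': [...], 'error': [...]}
    df = fitter.results(format='dataframe',   # plain floats, since dataframes
                        use_units=False)      # cannot carry Brian 2 units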

def generate(self, params=None, output_var=None, param_init=None, level=0):
"""
@@ -640,14 +706,34 @@ def generate(self, params=None, output_var=None, param_init=None, level=0):


class TraceFitter(Fitter):
"""Input and output have to have the same dimensions."""
"""
A `Fitter` for fitting recorded traces (e.g. of the membrane potential).

Parameters
----------
model
input_var
input
output_var
output
dt
n_samples
method
reset
refractory
threshold
param_init
use_units: bool, optional
Whether to use units in all user-facing interfaces, e.g. in the callback
arguments or in the returned parameter dictionary and errors. Defaults
to ``True``.
"""
def __init__(self, model, input_var, input, output_var, output, dt,
n_samples=30, method=None, reset=None, refractory=False,
threshold=None, param_init=None):
"""Initialize the fitter."""
threshold=None, param_init=None, use_units=True):
super().__init__(dt, model, input, output, input_var, output_var,
n_samples, threshold, reset, refractory, method,
param_init)
param_init, use_units=use_units)
# We store the bounds set in TraceFitter.fit, so that TraceFitter.refine
# can reuse them
self.bounds = None
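To make the new flag concrete, a sketch with the same hypothetical names as above: with use_units=False, every user-facing value stays a plain float or array:

    fitter_plain = TraceFitter(model=model, input_var='I', input=inp,
                               output_var='v', output=out, dt=0.1*ms,
                               use_units=False)
    params, error = fitter_plain.fit(optimizer=NevergradOptimizer(),
                                     metric=MSEMetric(), n_rounds=1,
                                     gl=[10*nS, 100*nS], C=[0.1*uF, 1*uF])
    # params now maps names to bare floats (in SI base units) and error
    # is a plain float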
@@ -677,10 +763,11 @@ def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
raise TypeError("You can only use TraceMetric child metric with "
"TraceFitter")
self.bounds = dict(params)
self.best_params, error = super().fit(optimizer, metric, n_rounds,
callback, restart, level=level+1,
**params)
return self.best_params, error
best_params, error = super().fit(optimizer, metric, n_rounds,
callback, restart,
level=level+1,
**params)
return best_params, error

def generate_traces(self, params=None, param_init=None, level=0):
"""Generates traces for best fit of parameters and all inputs"""
@@ -829,13 +916,22 @@ def _calc_gradient(params):
errors = []
def _callback_wrapper(params, iter, resid, *args, **kwds):
error = mean(resid**2)
params = {p: float(val) for p, val in params.items()}
tested_parameters.append(params)
errors.append(error)
if self.use_units:
error_dim = self.output_dim**2 * get_dimensions(normalization)**2
all_errors = Quantity(errors, dim=error_dim)
params = {p: Quantity(val, dim=self.model[p].dim)
for p, val in params.items()}
else:
all_errors = array(errors)
params = {p: float(val) for p, val in params.items()}
tested_parameters.append(params)

best_idx = argmin(errors)
best_error = errors[best_idx]
best_error = all_errors[best_idx]
best_params = tested_parameters[best_idx]
return callback_func(params, array(errors),

return callback_func(params, all_errors,
best_params, best_error, iter)

assert 'Dfun' not in kwds
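A small sketch of the dimensional bookkeeping behind error_dim above: the residuals carry the output dimension (scaled by any normalization factor), so their mean square carries that dimension squared:

    import numpy as np
    from brian2 import mV

    resid = np.array([1.0, -2.0, 0.5]) * mV  # target - simulated residuals
    mse = np.mean(resid**2)                  # Quantity with dimension volt**2
    # i.e. for a voltage trace the refined error is reported in volt**2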
@@ -858,19 +954,26 @@ def _callback_wrapper(params, iter, resid, *args, **kwds):
iter_cb=iter_cb,
**kwds)

return {p: float(val) for p, val in result.params.items()}, result
if self.use_units:
param_dict = {p: Quantity(float(val), dim=self.model[p].dim)
for p, val in result.params.items()}
else:
param_dict = {p: float(val)
for p, val in result.params.items()}

return param_dict, result


class SpikeFitter(Fitter):
def __init__(self, model, input, output, dt, reset, threshold,
input_var='I', refractory=False, n_samples=30,
method=None, param_init=None):
method=None, param_init=None, use_units=True):
"""Initialize the fitter."""
if method is None:
method = 'exponential_euler'
super().__init__(dt, model, input, output, input_var, 'v',
n_samples, threshold, reset, refractory, method,
param_init)
param_init, use_units=use_units)
self.output_var = 'spikes'

if param_init:
Expand All @@ -897,10 +1000,10 @@ def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
if not isinstance(metric, SpikeMetric):
raise TypeError("You can only use SpikeMetric child metric with "
"SpikeFitter")
self.best_params, error = super().fit(optimizer, metric, n_rounds,
callback, restart, level=level+1,
**params)
return self.best_params, error
best_params, error = super().fit(optimizer, metric, n_rounds,
callback, restart, level=level+1,
**params)
return best_params, error

def generate_spikes(self, params=None, param_init=None, level=0):
"""Generates traces for best fit of parameters and all inputs"""
@@ -943,8 +1046,9 @@ def __init__(self, model, input_var, input, output_var, output, dt,

self.simulator = None

def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
def fit(self, optimizer, n_rounds=1, callback='text',
restart=False, level=0, **params):
metric = MSEMetric() # not used, but makes error dimensions correct
return super(OnlineTraceFitter, self).fit(optimizer, metric=metric,
n_rounds=n_rounds,
callback=callback,
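Since OnlineTraceFitter now supplies the MSEMetric itself, a caller no longer passes a metric; a sketch with the same hypothetical names as before:

    from brian2modelfitting import OnlineTraceFitter

    online = OnlineTraceFitter(model=model, input_var='I', input=inp,
                               output_var='v', output=out, dt=0.1*ms)
    params, error = online.fit(optimizer=NevergradOptimizer(), n_rounds=1,
                               gl=[10*nS, 100*nS], C=[0.1*uF, 1*uF])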