Skip to content

Commit

Permalink
each traces_times value has to be a Quantity now, fixed time array in…
Browse files Browse the repository at this point in the history
… eFEL calc + docs
  • Loading branch information
alTeska committed Sep 13, 2019
1 parent ba7c3c7 commit 2c91e89
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 13 deletions.
16 changes: 11 additions & 5 deletions brian2modelfitting/modelfitting/metric.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import abc
import efel
from itertools import repeat
from brian2 import Hz, second, Quantity
from brian2.units.fundamentalunits import check_units
from brian2 import Hz, second, Quantity, ms, us
from brian2.units.fundamentalunits import check_units, in_unit
from numpy import (array, sum, square, reshape, abs, amin, digitize,
rint, arange, atleast_2d, NaN, float64, split, shape,)

Expand Down Expand Up @@ -73,7 +73,7 @@ def get_gamma_factor(model, data, delta, time, dt):
def calc_eFEL(traces, inp_times, feat_list, dt):
out_traces = []
for i, trace in enumerate(traces):
time = arange(0, len(trace)*dt, dt)
time = arange(0, len(trace))*dt/ms
temp_trace = {}
temp_trace['T'] = time
temp_trace['V'] = trace
Expand All @@ -82,7 +82,6 @@ def calc_eFEL(traces, inp_times, feat_list, dt):
out_traces.append(temp_trace)

results = efel.getFeatureValues(out_traces, feat_list)

return results


Expand Down Expand Up @@ -257,6 +256,13 @@ def feat_to_err(self, d1, d2):
return err

def get_features(self, traces, output, n_traces, dt):
if self.traces_times[0][0] is Quantity:
for n, trace in enumerate(self.traces_times):
t_start, t_end = trace[0], trace[1]
t_start = t_start / ms
t_end = t_end / ms
self.traces_times[n] = [t_start, t_end]

n_times = shape(self.traces_times)[0]

if (n_times != (n_traces)):
Expand All @@ -279,7 +285,7 @@ def get_features(self, traces, output, n_traces, dt):
for ii in arange(sl):
temp_trace = temp_traces[ii]
temp_feat = calc_eFEL(temp_trace, self.traces_times,
self.feat_list)
self.feat_list, dt)
self.check_values(temp_feat)
features.append(temp_feat)

Expand Down
6 changes: 3 additions & 3 deletions docs_sphinx/metric/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,22 +87,22 @@ To get all of the eFEL features you can run the following code:

To define the :py:class:`~brian2modelfitting.modelfitting.metric.FeatureMetric`, user has to define following input parameters:

- ``traces_times`` - list of times indicating start and end of input current, has to be specified for each of input traces
- ``traces_times`` - list of times indicating start and end of input current, has to be specified for each of input traces, each value has to be a :py:class:`~brian2.units.fundamentalunits.Quantity`
- ``feat_list`` - list of strings with names of features to be used
- ``combine`` - function to be used to compare features between output and simulated traces, (for `combine=None`, subtracts the features)

Example code usage:

.. code:: python
traces_times = [[50, 100], [50, 100], [50, 100], [50, 100]]
traces_times = [[50*ms, 100*ms], [50*ms, 100*ms], [50*ms, 100*ms], [50, 100*ms]]
feat_list = ['voltage_base', 'time_to_first_spike', 'Spikecount']
metric = FeatureMetric(traces_times, feat_list, combine=None)
.. note::

If times of stimulation are same for all of the traces, user can specify a single list that will be replicated for
``eFEL`` library: ``traces_times = [[50, 100]]``.
``eFEL`` library: ``traces_times = [[50*ms, 100*ms]]``.



Expand Down
13 changes: 8 additions & 5 deletions examples/hh_nevergrad_eFEL.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@
## Optimization and Metric Choice
n_opt = NevergradOptimizer()

# traces_times = [[50, 100], [50, 100], [50, 100], [50, 100]]
traces_times = [[50, 100]]
# traces_times = [[5*ms, 10*ms]]
traces_times = [[0.005*second, 0.010*second]]
feat_list = ['voltage_base', 'time_to_first_spike', 'Spikecount', ]
metric = FeatureMetric(traces_times, feat_list)

Expand All @@ -54,7 +54,7 @@
param_init={'v': -65*mV},
method='exponential_euler',)

res, error = fitter.fit(n_rounds=5,
res, error = fitter.fit(n_rounds=1,
optimizer=n_opt, metric=metric,
callback='progressbar',
gl = [1e-09 *siemens, 1e-07 *siemens],
Expand All @@ -70,10 +70,13 @@
## Visualization of the results
start_scope()
fits = fitter.generate_traces(params=None, param_init={'v': -65*mV})
trace = out_traces[0]
time = arange(0, len(trace)*dt/ms, dt/ms)
# print(time[-1])

fig, ax = plt.subplots(ncols=4, figsize=(20,5))
ax[0].plot(out_traces[0].transpose())
ax[0].plot(fits[0].transpose()/mV)
ax[0].plot(time, out_traces[0].transpose())
ax[0].plot(time,fits[0].transpose()/mV)

ax[1].plot(out_traces[1].transpose())
ax[1].plot(fits[1].transpose()/mV)
Expand Down

0 comments on commit 2c91e89

Please sign in to comment.