-
Notifications
You must be signed in to change notification settings - Fork 6
/
fitter.py
1091 lines (951 loc) · 45.7 KB
/
fitter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import abc
import numbers
import sympy
from numpy import ones, array, arange, concatenate, mean, argmin, nanmin, reshape, zeros
from brian2.parsing.sympytools import sympy_to_str, str_to_sympy
from brian2.units.fundamentalunits import DIMENSIONLESS, get_dimensions, fail_for_dimension_mismatch
from brian2.utils.stringtools import get_identifiers
from brian2 import (NeuronGroup, defaultclock, get_device, Network,
StateMonitor, SpikeMonitor, second, get_local_namespace,
Quantity, get_logger)
from brian2.input import TimedArray
from brian2.equations.equations import Equations, SUBEXPRESSION
from brian2.devices import set_device, reset_device, device
from brian2.devices.cpp_standalone.device import CPPStandaloneDevice
from brian2.core.functions import Function
from .simulator import RuntimeSimulator, CPPStandaloneSimulator
from .metric import Metric, SpikeMetric, TraceMetric, MSEMetric
from .optimizer import Optimizer
from .utils import callback_setup, make_dic
logger = get_logger(__name__)
def get_param_dic(params, param_names, n_traces, n_samples):
    """Transform parameters into a dictionary of appropriate size.

    Parameters
    ----------
    params : array-like
        2D array-like of shape ``(n_samples, n_params)``, one row per
        drawn parameter sample.
    param_names : list of str
        Names of the parameters, in the same order as the columns of
        ``params``.
    n_traces : int
        Number of input traces; each sample's value is repeated once per
        trace.
    n_samples : int
        Number of parameter samples.

    Returns
    -------
    dict
        Maps each parameter name to a flat array of length
        ``n_traces * n_samples`` in which the ``n_traces`` values belonging
        to one sample are consecutive.
    """
    params = array(params)
    # Broadcast each sample's value over all traces, then flatten so that
    # the values for one sample end up contiguous
    return {name: (ones((n_traces, n_samples)) * value).T.flatten()
            for name, value in zip(param_names, params.T)}
def get_spikes(monitor, n_samples, n_traces):
    """
    Get spikes from spike monitor change format from dict to a list,
    remove units.
    """
    spike_trains = monitor.spike_trains()
    assert len(spike_trains) == n_samples*n_traces
    spikes = []
    for sample in range(n_samples):
        # Neurons of one sample occupy a contiguous index range
        offset = sample * n_traces
        per_trace = [array(spike_trains[offset + trace], copy=False)
                     for trace in range(n_traces)]
        spikes.append(per_trace)
    return spikes
def get_full_namespace(additional_namespace, level=0):
    """Return the local namespace restricted to fitting-relevant values,
    merged with ``additional_namespace``."""
    # Get the local namespace with all the values that could be relevant
    # in principle -- by filtering things out, we avoid circular loops.
    # Note: the call must stay in this frame so that level+1 points at the
    # caller's frame, as before.
    relevant_types = (Quantity, numbers.Number, Function)
    namespace = {}
    for key, value in get_local_namespace(level=level + 1).items():
        if isinstance(value, relevant_types):
            namespace[key] = value
    namespace.update(additional_namespace)
    return namespace
def setup_fit():
    """
    Function sets up simulator in one of the two available modes: runtime
    or standalone. The `.Simulator` that will be used depends on the currently
    set `.Device`. In the case of `.CPPStandaloneDevice`, the device will also
    be reset if it has already run a simulation.

    Returns
    -------
    simulator : `.Simulator`
    """
    # Both simulators are instantiated up front, exactly as before; the
    # active device's class name selects which one is returned
    simulators = {
        'CPPStandaloneDevice': CPPStandaloneSimulator(),
        'RuntimeDevice': RuntimeSimulator()
    }
    current_device = get_device()
    if isinstance(current_device, CPPStandaloneDevice):
        # A standalone device that has already run has to be reset before
        # it can build/run again
        if device.has_been_run is True:
            current_device.reinit()
            current_device.activate()
    return simulators[type(current_device).__name__]
def get_sensitivity_equations(group, parameters, namespace=None, level=1,
                              optimize=True):
    """
    Get equations for sensitivity variables.

    Parameters
    ----------
    group : `NeuronGroup`
        The group of neurons that will be simulated.
    parameters : list of str
        Names of the parameters that are fit.
    namespace : dict, optional
        The namespace to use.
    level : `int`, optional
        How much farther to go down in the stack to find the namespace.
    optimize : bool, optional
        Whether to remove sensitivity variables from the equations that do
        not evolve if initialized to zero (e.g. ``dS_x_y/dt = -S_x_y/tau``
        would be removed). This avoids unnecessary computation but will fail
        in the rare case that such a sensitivity variable needs to be
        initialized to a non-zero value. Defaults to ``True``.

    Returns
    -------
    sensitivity_eqs : `Equations`
        The equations for the sensitivity variables.
    """
    if namespace is None:
        namespace = get_local_namespace(level)
    namespace.update(group.namespace)
    eqs = group.equations
    # Substitute away subexpressions so each RHS only refers to state
    # variables and parameters
    diff_eqs = eqs.get_substituted_expressions(group.variables)
    diff_eq_names = [name for name, _ in diff_eqs]

    system = sympy.Matrix([str_to_sympy(diff_eq[1].code)
                           for diff_eq in diff_eqs])
    # Jacobian of the ODE system with respect to its state variables
    J = system.jacobian([str_to_sympy(d) for d in diff_eq_names])

    sensitivity = []
    sensitivity_names = []
    for parameter in parameters:
        # Sensitivity dynamics: dS/dt = J*S + F, with F the partial
        # derivative of the system w.r.t. this parameter
        F = system.jacobian([str_to_sympy(parameter)])
        names = [str_to_sympy(f'S_{diff_eq_name}_{parameter}')
                 for diff_eq_name in diff_eq_names]
        sensitivity.append(J * sympy.Matrix(names) + F)
        sensitivity_names.append(names)

    new_eqs = []
    for names, sensitivity_eqs, param in zip(sensitivity_names, sensitivity, parameters):
        for name, eq, orig_var in zip(names, sensitivity_eqs, diff_eq_names):
            # S_x_p carries the dimensions of x divided by those of p;
            # the parameter may live in the namespace or in the group
            if param in namespace:
                unit = eqs[orig_var].dim / namespace[param].dim
            elif param in group.variables:
                unit = eqs[orig_var].dim / group.variables[param].dim
            else:
                raise AssertionError(f'Parameter {param} neither in namespace nor variables')
            unit = repr(unit) if not unit.is_dimensionless else '1'
            if optimize:
                # Check if the equation stays at zero if initialized at zero
                zeroed = eq.subs(name, sympy.S.Zero)
                if zeroed == sympy.S.Zero:
                    # No need to include equation as differential equation;
                    # emit a constant-zero subexpression instead
                    if unit == '1':
                        new_eqs.append(f'{sympy_to_str(name)} = 0 : {unit}')
                    else:
                        new_eqs.append(f'{sympy_to_str(name)} = 0*{unit} : {unit}')
                    continue
            rhs = sympy_to_str(eq)
            if rhs == '0':  # avoid unit mismatch
                rhs = f'0*{unit}/second'
            new_eqs.append('d{lhs}/dt = {rhs} : {unit}'.format(lhs=sympy_to_str(name),
                                                               rhs=rhs,
                                                               unit=unit))
    new_eqs = Equations('\n'.join(new_eqs))
    return new_eqs
def get_sensitivity_init(group, parameters, param_init):
    """
    Calculate the initial values for the sensitivity parameters (necessary if
    initial values are functions of parameters).

    Parameters
    ----------
    group : `NeuronGroup`
        The group of neurons that will be simulated.
    parameters : list of str
        Names of the parameters that are fit.
    param_init : dict
        The dictionary with expressions to initialize the model variables.

    Returns
    -------
    sensitivity_init : dict
        Dictionary of expressions to initialize the sensitivity
        parameters.
    """
    sensitivity_dict = {}
    for var_name, expr in param_init.items():
        # Only string expressions can depend on the fitted parameters;
        # numeric initializations have zero sensitivity
        if not isinstance(expr, str):
            continue
        identifiers = get_identifiers(expr)
        for identifier in identifiers:
            if (identifier in group.variables
                    and getattr(group.variables[identifier],
                                'type', None) == SUBEXPRESSION):
                raise NotImplementedError('Initializations that refer to a '
                                          'subexpression are currently not '
                                          'supported')
        sympy_expr = str_to_sympy(expr)
        for parameter in parameters:
            # d(init_expr)/d(parameter) gives the initial value of the
            # corresponding sensitivity variable
            diffed = sympy_expr.diff(str_to_sympy(parameter))
            if diffed != sympy.S.Zero:
                # NOTE(review): this inspects the type of the *parameter*
                # variable, while the error text talks about the sensitivity
                # variable S_{var_name}_{parameter} being optimized away --
                # presumably the sensitivity variable's type was meant;
                # confirm against get_sensitivity_equations(optimize=True).
                if getattr(group.variables[parameter],
                           'type', None) == SUBEXPRESSION:
                    raise NotImplementedError('Sensitivity '
                                              f'S_{var_name}_{parameter} '
                                              'is initialized to a non-zero '
                                              'value, but it has been '
                                              'removed from the equations. '
                                              'Set optimize=False to avoid '
                                              'this.')
                init_expr = sympy_to_str(diffed)
                sensitivity_dict[f'S_{var_name}_{parameter}'] = init_expr
    return sensitivity_dict
class Fitter(metaclass=abc.ABCMeta):
    """
    Base Fitter class for model fitting applications.

    Creates an interface for model fitting of traces with parameters drawn by
    gradient-free algorithms (through ask/tell interfaces).

    Initiates n_neurons = num input traces * num samples, to which drawn
    parameters get assigned and evaluates them in parallel.

    Parameters
    ----------
    dt : `~brian2.units.fundamentalunits.Quantity`
        The size of the time step.
    model : `~brian2.equations.equations.Equations` or str
        The equations describing the model.
    input : `~numpy.ndarray` or `~brian2.units.fundamentalunits.Quantity`
        A 2D array of shape ``(n_traces, time steps)`` given the input that will
        be fed into the model.
    output : `~brian2.units.fundamentalunits.Quantity` or list
        Recorded output of the model that the model should reproduce. Should
        be a 2D array of the same shape as the input when fitting traces with
        `TraceFitter`, a list of spike times when fitting spike trains with
        `SpikeFitter`.
    input_var : str
        The name of the input variable in the model. Note that this variable
        should be *used* in the model (e.g. a variable ``I`` that is added as
        a current in the membrane potential equation), but not *defined*.
    output_var : str
        The name of the output variable in the model. Only needed when fitting
        traces with `.TraceFitter`.
    n_samples: int
        Number of parameter samples to be optimized over in a single iteration.
    threshold: `str`, optional
        The condition which produces spikes. Should be a boolean expression as
        a string.
    reset: `str`, optional
        The (possibly multi-line) string with the code to execute on reset.
    refractory: `str` or `~brian2.units.fundamentalunits.Quantity`, optional
        Either the length of the refractory period (e.g. 2*ms), a string
        expression that evaluates to the length of the refractory period after
        each spike (e.g. '(1 + rand())*ms'), or a string expression evaluating
        to a boolean value, given the condition under which the neuron stays
        refractory after a spike (e.g. 'v > -20*mV')
    method: `str`, optional
        Integration method
    param_init: `dict`, optional
        Dictionary of variables to be initialized with respective values
    """
    def __init__(self, dt, model, input, output, input_var, output_var,
                 n_samples, threshold, reset, refractory, method, param_init,
                 use_units=True):
        """Initialize the fitter."""
        if dt is None:
            raise ValueError("dt-sampling frequency of the input must be set")
        if isinstance(model, str):
            model = Equations(model)
        if input_var not in model.identifiers:
            raise NameError("%s is not an identifier in the model" % input_var)

        self.dt = dt
        self.simulator = None

        self.parameter_names = model.parameter_names
        self.n_traces, n_steps = input.shape
        self.duration = n_steps * dt
        self.n_neurons = self.n_traces * n_samples

        self.n_samples = n_samples
        self.method = method
        self.threshold = threshold
        self.reset = reset
        self.refractory = refractory

        self.input = input
        self.output_var = output_var
        if output_var == 'spikes':
            self.output_dim = DIMENSIONLESS
        else:
            self.output_dim = model[output_var].dim
            fail_for_dimension_mismatch(output, self.output_dim,
                                        'The provided target values '
                                        '("output") need to have the same '
                                        'units as the variable '
                                        '{}'.format(output_var))
        self.model = model
        self.use_units = use_units

        # The input is fed in via a TimedArray indexed by trace number
        input_dim = get_dimensions(input)
        input_dim = '1' if input_dim is DIMENSIONLESS else repr(input_dim)
        input_eqs = "{} = input_var(t, i % n_traces) : {}".format(input_var,
                                                                  input_dim)
        self.model += input_eqs

        if output_var != 'spikes':
            # For approaches that couple the system to the target values,
            # provide a convenient variable
            output_expr = 'output_var(t, i % n_traces)'
            output_dim = ('1' if self.output_dim is DIMENSIONLESS
                          else repr(self.output_dim))
            output_eqs = "{}_target = {} : {}".format(output_var,
                                                      output_expr,
                                                      output_dim)
            self.model += output_eqs

        input_traces = TimedArray(input.transpose(), dt=dt)
        self.input_traces = input_traces

        # initialization of attributes used later
        self._best_params = None
        self._best_error = None
        self.optimizer = None
        self.metric = None
        if not param_init:
            param_init = {}
        for param, val in param_init.items():
            if not (param in self.model.diff_eq_names or
                    param in self.model.parameter_names):
                raise ValueError("%s is not a model variable or a "
                                 "parameter in the model" % param)
        self.param_init = param_init

    def setup_simulator(self, network_name, n_neurons, output_var, param_init,
                        calc_gradient=False, optimize=True, online_error=False,
                        level=1):
        """Create a `.Simulator` with a network of ``n_neurons`` neurons and a
        monitor recording ``output_var`` (or spikes), initialized with
        ``param_init``."""
        simulator = setup_fit()

        namespace = get_full_namespace({'input_var': self.input_traces,
                                        'n_traces': self.n_traces},
                                       level=level+1)
        if hasattr(self, 't_start'):  # OnlineTraceFitter
            namespace['t_start'] = self.t_start
        if self.output_var != 'spikes':
            # Provide the target trace for output-coupled approaches
            namespace['output_var'] = TimedArray(self.output.transpose(),
                                                 dt=self.dt)
        neurons = self.setup_neuron_group(n_neurons, namespace,
                                          calc_gradient=calc_gradient,
                                          optimize=optimize,
                                          online_error=online_error)

        if output_var == 'spikes':
            monitor = SpikeMonitor(neurons, name='monitor')
        else:
            record_vars = [output_var]
            if calc_gradient:
                # Also record the sensitivity traces for gradient calculation
                record_vars.extend([f'S_{output_var}_{p}'
                                    for p in self.parameter_names])
            monitor = StateMonitor(neurons, record_vars, record=True,
                                   name='monitor', dt=self.dt)

        network = Network(neurons, monitor)
        if calc_gradient:
            # Sensitivity variables may need non-zero initial values if the
            # state initializations depend on the fitted parameters
            param_init = dict(param_init)
            param_init.update(get_sensitivity_init(neurons, self.parameter_names,
                                                   param_init))
        simulator.initialize(network, param_init, name=network_name)
        return simulator

    def setup_neuron_group(self, n_neurons, namespace, calc_gradient=False,
                           optimize=True, online_error=False, name='neurons'):
        """
        Setup neuron group, initialize required number of neurons, create
        namespace and initialize the parameters.

        Parameters
        ----------
        n_neurons: int
            number of required neurons
        **namespace
            arguments to be added to NeuronGroup namespace

        Returns
        -------
        neurons : ~brian2.groups.neurongroup.NeuronGroup
            group of neurons
        """
        # We only want to specify the method argument if it is not None –
        # otherwise it should use NeuronGroup's default value
        kwds = {}
        if self.method is not None:
            kwds['method'] = self.method
        neurons = NeuronGroup(n_neurons, self.model,
                              threshold=self.threshold, reset=self.reset,
                              refractory=self.refractory, name=name,
                              namespace=namespace, dt=self.dt, **kwds)
        if calc_gradient:
            # Deriving the sensitivity equations needs an existing group, so
            # the group is created a second time with the extended equations
            sensitivity_eqs = get_sensitivity_equations(neurons,
                                                        parameters=self.parameter_names,
                                                        optimize=optimize,
                                                        namespace=namespace)
            neurons = NeuronGroup(n_neurons, self.model + sensitivity_eqs,
                                  threshold=self.threshold, reset=self.reset,
                                  refractory=self.refractory, name=name,
                                  namespace=namespace, dt=self.dt, **kwds)
        if online_error:
            # Accumulate the squared error during the simulation itself
            neurons.run_regularly('total_error += ({} - {}_target)**2 * '
                                  'int(t>=t_start)'.format(self.output_var,
                                                           self.output_var),
                                  when='end')

        return neurons

    @abc.abstractmethod
    def calc_errors(self, metric):
        """
        Abstract method required in all Fitter classes, used for
        calculating errors

        Parameters
        ----------
        metric: `~.Metric` children
            Child of Metric class, specifies optimization metric
        """
        pass

    def optimization_iter(self, optimizer, metric):
        """
        Function performs all operations required for one iteration of
        optimization. Drawing parameters, setting them to simulator and
        calculating the error.

        Returns
        -------
        results : list
            recommended parameters
        parameters: list of list
            drawn parameters
        errors: list
            calculated errors
        """
        parameters = optimizer.ask(n_samples=self.n_samples)

        d_param = get_param_dic(parameters, self.parameter_names,
                                self.n_traces, self.n_samples)
        self.simulator.run(self.duration, d_param, self.parameter_names)
        errors = self.calc_errors(metric)

        optimizer.tell(parameters, errors)
        results = optimizer.recommend()

        return results, parameters, errors

    def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
            restart=False, online_error=False, level=0, **params):
        """
        Run the optimization algorithm for given amount of rounds with given
        number of samples drawn. Return best set of parameters and
        corresponding error.

        Parameters
        ----------
        optimizer: `~.Optimizer` children
            Child of Optimizer class, specific for each library.
        metric: `~.Metric` children
            Child of Metric class, specifies optimization metric
        n_rounds: int
            Number of rounds to optimize over (feedback provided over each
            round).
        callback: `str` or `~typing.Callable`
            Either the name of a provided callback function (``text`` or
            ``progressbar``), or a custom feedback function
            ``func(parameters, errors, best_parameters, best_error, index)``.
            If this function returns ``True`` the fitting execution is
            interrupted.
        restart: bool
            Flag that reinitializes the Fitter to reset the optimization.
            With restart True user is allowed to change optimizer/metric.
        online_error: bool, optional
            Whether to calculate the squared error between target trace and
            simulated trace online. Defaults to ``False``.
        level : `int`, optional
            How much farther to go down in the stack to find the namespace.
        **params
            bounds for each parameter

        Returns
        -------
        best_results : dict
            dictionary with best parameter set
        error: float
            error value for best parameter set
        """
        if not (isinstance(metric, Metric) or metric is None):
            raise TypeError("metric has to be a child of class Metric or None "
                            "for OnlineTraceFitter")
        # isinstance(None, Optimizer) is False, so this also rejects None
        if not isinstance(optimizer, Optimizer):
            raise TypeError("optimizer has to be a child of class Optimizer")
        if self.metric is not None and restart is False:
            if metric is not self.metric:
                raise Exception("You can not change the metric between fits")
        if self.optimizer is not None and restart is False:
            if optimizer is not self.optimizer:
                raise Exception("You can not change the optimizer between fits")
        if self.optimizer is None or restart is True:
            optimizer.initialize(self.parameter_names, popsize=self.n_samples,
                                 **params)

        self.optimizer = optimizer
        self.metric = metric

        callback = callback_setup(callback, n_rounds)

        # Check whether we can reuse the current simulator or whether we have
        # to create a new one (only relevant for standalone, but does not hurt
        # for runtime)
        if self.simulator is None or self.simulator.current_net != 'fit':
            self.simulator = self.setup_simulator('fit', self.n_neurons,
                                                  output_var=self.output_var,
                                                  online_error=online_error,
                                                  param_init=self.param_init,
                                                  level=level+1)

        # Run Optimization Loop
        for index in range(n_rounds):
            best_params, parameters, errors = self.optimization_iter(optimizer,
                                                                     metric)
            self._best_error = nanmin(self.optimizer.errors)
            # create output variables
            self._best_params = make_dic(self.parameter_names, best_params)
            if self.use_units:
                if self.output_var == 'spikes':
                    output_dim = DIMENSIONLESS
                else:
                    output_dim = self.output_dim
                # Correct the units for the normalization factor
                error_dim = self.metric.get_normalized_dimensions(output_dim)
                best_error = Quantity(float(self.best_error), dim=error_dim)
                errors = Quantity(errors, dim=error_dim)
                param_dicts = [{p: Quantity(v, dim=self.model[p].dim)
                                for p, v in zip(self.parameter_names,
                                                one_param_set)}
                               for one_param_set in parameters]
            else:
                param_dicts = [{p: v for p, v in zip(self.parameter_names,
                                                     one_param_set)}
                               for one_param_set in parameters]
                best_error = self.best_error
            # A truthy callback return value interrupts the optimization
            if callback(param_dicts,
                        errors,
                        self.best_params,
                        best_error,
                        index) is True:
                break

        return self.best_params, self.best_error

    @property
    def best_params(self):
        # Best parameters found so far (with units if use_units is set),
        # or None before the first fit
        if self._best_params is None:
            return None
        if self.use_units:
            params_with_units = {p: Quantity(v, dim=self.model[p].dim)
                                 for p, v in self._best_params.items()}
            return params_with_units
        else:
            return self._best_params

    @property
    def best_error(self):
        # Best (lowest) error found so far, or None before the first fit
        if self._best_error is None:
            return None
        if self.use_units:
            error_dim = self.metric.get_dimensions(self.output_dim)
            return Quantity(self._best_error, dim=error_dim)
        else:
            return self._best_error

    def results(self, format='list', use_units=None):
        """
        Returns all of the gathered results (parameters and errors).
        In one of the 3 formats: 'dataframe', 'list', 'dict'.

        Parameters
        ----------
        format: str
            The desired output format. Currently supported: ``dataframe``,
            ``list``, or ``dict``.
        use_units: bool, optional
            Whether to use units in the results. If not specified, defaults to
            `.TraceFitter.use_units`, i.e. the value that was specified when
            the `.TraceFitter` object was created (``True`` by default).

        Returns
        -------
        object
            'dataframe': returns pandas `~pandas.DataFrame` without units
            'list': list of dictionaries
            'dict': dictionary of lists
        """
        if use_units is None:
            use_units = self.use_units
        names = list(self.parameter_names)

        params = array(self.optimizer.tested_parameters)
        params = params.reshape(-1, params.shape[-1])

        if use_units:
            error_dim = self.metric.get_dimensions(self.output_dim)
            errors = Quantity(array(self.optimizer.errors).flatten(),
                              dim=error_dim)
        else:
            errors = array(array(self.optimizer.errors).flatten())
        dim = self.model.dimensions

        if format == 'list':
            res_list = []
            for j in arange(0, len(params)):
                temp_data = params[j]
                res_dict = dict()

                for i, n in enumerate(names):
                    if use_units:
                        res_dict[n] = Quantity(temp_data[i], dim=dim[n])
                    else:
                        res_dict[n] = float(temp_data[i])
                res_dict['error'] = errors[j]
                res_list.append(res_dict)

            return res_list

        elif format == 'dict':
            res_dict = dict()
            for i, n in enumerate(names):
                if use_units:
                    res_dict[n] = Quantity(params[:, i], dim=dim[n])
                else:
                    res_dict[n] = array(params[:, i])

            res_dict['error'] = errors
            return res_dict

        elif format == 'dataframe':
            from pandas import DataFrame
            if use_units:
                logger.warn('Results in dataframes do not support units. '
                            'Specify "use_units=False" to avoid this warning.',
                            name_suffix='dataframe_units')
            data = concatenate((params, array(errors)[None, :].transpose()), axis=1)
            return DataFrame(data=data, columns=names + ['error'])
        else:
            # Previously an unknown format silently returned None
            raise ValueError("The 'format' argument has to be one of "
                             "'list', 'dict' or 'dataframe'.")

    def generate(self, params=None, output_var=None, param_init=None, level=0):
        """
        Generates traces for best fit of parameters and all inputs.
        If provided with other parameters provides those.

        Parameters
        ----------
        params: dict
            Dictionary of parameters to generate fits for.
        output_var: str
            Name of the output variable to be monitored, or the special name
            ``spikes`` to record spikes.
        param_init: dict
            Dictionary of initial values for the model.
        level : `int`, optional
            How much farther to go down in the stack to find the namespace.

        Returns
        -------
        fit
            Either a 2D `.Quantity` with the recorded output variable over time,
            with shape <number of input traces> × <number of time steps>, or
            a list of spike times for each input trace.
        """
        if params is None:
            params = self.best_params
        if param_init is None:
            param_init = self.param_init
        else:
            # Merge the user-provided initializations on top of the stored
            # ones, without mutating self.param_init. (The previous code
            # rebound the 'param_init' name to a copy of self.param_init
            # *before* using it in the update, discarding the user's values.)
            merged_init = dict(self.param_init)
            merged_init.update(param_init)
            param_init = merged_init
        if output_var is None:
            output_var = self.output_var

        self.simulator = self.setup_simulator('generate', self.n_traces,
                                              output_var=output_var,
                                              param_init=param_init,
                                              level=level+1)
        param_dic = get_param_dic([params[p] for p in self.parameter_names],
                                  self.parameter_names, self.n_traces, 1)
        self.simulator.run(self.duration, param_dic, self.parameter_names,
                           name='generate')

        if output_var == 'spikes':
            fits = get_spikes(self.simulator.monitor,
                              1, self.n_traces)[0]  # a single "sample"
        else:
            fits = getattr(self.simulator.monitor, output_var)[:]

        return fits
class TraceFitter(Fitter):
    """
    A `Fitter` for fitting recorded traces (e.g. of the membrane potential).

    Parameters
    ----------
    model
    input_var
    input
    output_var
    output
    dt
    n_samples
    method
    reset
    refractory
    threshold
    param_init
    use_units: bool, optional
        Whether to use units in all user-facing interfaces, e.g. in the callback
        arguments or in the returned parameter dictionary and errors. Defaults
        to ``True``.
    """
    def __init__(self, model, input_var, input, output_var, output, dt,
                 n_samples=30, method=None, reset=None, refractory=False,
                 threshold=None, param_init=None, use_units=True):
        super().__init__(dt, model, input, output, input_var, output_var,
                         n_samples, threshold, reset, refractory, method,
                         param_init, use_units=use_units)
        self.output = Quantity(output)
        self.output_ = array(output)

        # We store the bounds set in TraceFitter.fit, so that TraceFitter.refine
        # can reuse them
        self.bounds = None

        if output_var not in self.model.names:
            raise NameError("%s is not a model variable" % output_var)
        if output.shape != input.shape:
            raise ValueError("Input and output must have the same size")

    def calc_errors(self, metric):
        """
        Returns errors after simulation with StateMonitor.
        To be used inside `optim_iter`.
        """
        traces = getattr(self.simulator.networks['fit']['monitor'],
                         self.output_var+'_')
        # Reshape traces for easier calculation of error
        traces = reshape(traces, (traces.shape[0]//self.n_traces,
                                  self.n_traces,
                                  -1))
        errors = metric.calc(traces, self.output_, self.dt)
        return errors

    def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
            restart=False, level=0, **params):
        """Run the fit; see `Fitter.fit`. ``metric`` must be a `TraceMetric`;
        the parameter bounds in ``params`` are stored for later use by
        `~.TraceFitter.refine`."""
        if not isinstance(metric, TraceMetric):
            raise TypeError("You can only use TraceMetric child metric with "
                            "TraceFitter")
        self.bounds = dict(params)
        best_params, error = super().fit(optimizer, metric, n_rounds,
                                         callback, restart,
                                         level=level+1,
                                         **params)
        return best_params, error

    def generate_traces(self, params=None, param_init=None, level=0):
        """Generates traces for best fit of parameters and all inputs"""
        fits = self.generate(params=params, output_var=self.output_var,
                             param_init=param_init, level=level+1)
        return fits

    def refine(self, params=None, t_start=None, normalization=None,
               callback='text', calc_gradient=False, optimize=True,
               level=0, **kwds):
        """
        Refine the fitting results with a sequentially operating minimization
        algorithm. Uses the `lmfit <https://lmfit.github.io/lmfit-py/>`_
        package which itself makes use of
        `scipy.optimize <https://docs.scipy.org/doc/scipy/reference/optimize.html>`_.
        Has to be called after `~.TraceFitter.fit`, but a call with
        ``n_rounds=0`` is enough.

        Parameters
        ----------
        params : dict, optional
            A dictionary with the parameters to use as a starting point for the
            refinement. If not given, the best parameters found so far by
            `~.TraceFitter.fit` will be used.
        t_start : `~brian2.units.fundamentalunits.Quantity`, optional
            Initial simulation/model time that should be ignored for the error
            calculation. If not set, will reuse the `t_start` value from the
            previously used metric.
        normalization : float, optional
            A normalization term that will be used rescale results before
            handing them to the optimization algorithm. Can be useful if the
            algorithm makes assumptions about the scale of errors, e.g. if the
            size of steps in the parameter space depends on the absolute value
            of the error. The difference between simulated and target traces
            will be divided by this value. If not set, will reuse the
            `normalization` value from the previously used metric.
        callback: `str` or `~typing.Callable`
            Either the name of a provided callback function (``text`` or
            ``progressbar``), or a custom feedback function
            ``func(parameters, errors, best_parameters, best_error, index)``.
            If this function returns ``True`` the fitting execution is
            interrupted.
        calc_gradient: bool, optional
            Whether to add "sensitivity variables" to the equation that track
            the sensitivity of the equation variables to the parameters. This
            information will be used to pass the local gradient of the error
            with respect to the parameters to the optimization function. This
            can lead to much faster convergence than with an estimated gradient
            but comes at the expense of additional computation. Defaults to
            ``False``.
        optimize : bool, optional
            Whether to remove sensitivity variables from the equations that do
            not evolve if initialized to zero (e.g. ``dS_x_y/dt = -S_x_y/tau``
            would be removed). This avoids unnecessary computation but will fail
            in the rare case that such a sensitivity variable needs to be
            initialized to a non-zero value. Only taken into account if
            ``calc_gradient`` is ``True``. Defaults to ``True``.
        level : int, optional
            How much farther to go down in the stack to find the namespace.
        kwds
            Additional arguments can overwrite the bounds for individual
            parameters (if not given, the bounds previously specified in the
            call to `~.TraceFitter.fit` will be used). All other arguments will
            be passed on to `.lmfit.minimize` and can be used to e.g. change the
            method, or to specify method-specific arguments.

        Returns
        -------
        parameters : dict
            The parameters at the end of the optimization process as a
            dictionary.
        result : `.lmfit.MinimizerResult`
            The result of the optimization process.

        Notes
        -----
        The default method used by `lmfit` is least-squares minimization using
        a Levenberg-Marquardt method. Note that there is no support for
        specifying a `Metric`, the given output trace(s) will be subtracted
        from the simulated trace(s) and passed on to the minimization algorithm.

        This method always uses the runtime mode, independent of the selection
        of the current device.
        """
        try:
            import lmfit
        except ImportError:
            raise ImportError('Refinement needs the "lmfit" package.')
        if params is None:
            if self.best_params is None:
                raise TypeError('You need to either specify parameters or run '
                                'the fit function first.')
            params = self.best_params
        if t_start is None:
            t_start = getattr(self.metric, 't_start', 0*second)
        if normalization is None:
            # Metrics already store the reciprocal of the user-specified
            # normalization, so it can be reused directly
            normalization = getattr(self.metric, 'normalization', 1.)
        else:
            normalization = 1/normalization

        callback_func = callback_setup(callback, None)

        # Set up Parameter objects
        parameters = lmfit.Parameters()
        for param_name in self.parameter_names:
            if param_name not in kwds:
                if self.bounds is None:
                    raise TypeError('You need to either specify bounds for all '
                                    'parameters or run the fit function first.')
                min_bound, max_bound = self.bounds[param_name]
            else:
                min_bound, max_bound = kwds.pop(param_name)
            parameters.add(param_name, value=array(params[param_name]),
                           min=array(min_bound), max=array(max_bound))

        self.simulator = self.setup_simulator('refine', self.n_traces,
                                              output_var=self.output_var,
                                              param_init=self.param_init,
                                              calc_gradient=calc_gradient,
                                              optimize=optimize,
                                              level=level+1)

        t_start_steps = int(round(t_start / self.dt))

        def _calc_error(params):
            # Simulate with the candidate parameters and return the
            # normalized residual between simulated and target traces
            param_dic = get_param_dic([params[p] for p in self.parameter_names],
                                      self.parameter_names, self.n_traces, 1)
            self.simulator.run(self.duration, param_dic,
                               self.parameter_names, name='refine')
            trace = getattr(self.simulator.monitor, self.output_var+'_')
            residual = trace[:, t_start_steps:] - self.output_[:, t_start_steps:]
            return residual.flatten() * normalization

        def _calc_gradient(params):
            # Build the Jacobian of the residual from the recorded
            # sensitivity traces (one column per fitted parameter)
            residuals = []
            for name in self.parameter_names:
                trace = getattr(self.simulator.monitor,
                                f'S_{self.output_var}_{name}_')
                residual = trace[:, t_start_steps:]
                residuals.append(residual.flatten() * normalization)
            gradient = array(residuals)
            return gradient.T

        tested_parameters = []
        errors = []

        def _callback_wrapper(params, iteration, resid, *args, **kwds):
            # Adapt lmfit's iteration callback to the Fitter callback
            # convention func(parameters, errors, best_parameters,
            # best_error, index). lmfit passes the arguments positionally,
            # so the parameter names here are free to differ.
            error = mean(resid**2)
            errors.append(error)
            if self.use_units:
                error_dim = self.output_dim**2 * get_dimensions(normalization)**2
                all_errors = Quantity(errors, dim=error_dim)
                params = {p: Quantity(val, dim=self.model[p].dim)
                          for p, val in params.items()}
            else:
                all_errors = array(errors)
                params = {p: float(val) for p, val in params.items()}
            tested_parameters.append(params)
            best_idx = argmin(errors)
            best_error = all_errors[best_idx]
            best_params = tested_parameters[best_idx]
            return callback_func(params, all_errors,
                                 best_params, best_error, iteration)

        # 'Dfun' is reserved for the analytical gradient; an assert would be
        # silently stripped when running Python with -O, so raise instead
        if 'Dfun' in kwds:
            raise ValueError("The 'Dfun' keyword is used internally; set "
                             "calc_gradient=True to use an analytical "
                             "gradient.")
        if calc_gradient:
            kwds.update({'Dfun': _calc_gradient})
        if 'iter_cb' in kwds:
            # Use the given callback but raise a warning if callback is not
            # set to None
            if callback is not None:
                logger.warn('The iter_cb keyword has been specified together '
                            f'with callback={callback!r}. Only the iter_cb '
                            'callback will be used. Use the standard '
                            'callback mechanism or set callback=None to '
                            'remove this warning.',
                            name_suffix='iter_cb_callback')
            iter_cb = kwds.pop('iter_cb')
        else:
            iter_cb = _callback_wrapper

        result = lmfit.minimize(_calc_error, parameters,
                                iter_cb=iter_cb,
                                **kwds)

        if self.use_units:
            param_dict = {p: Quantity(float(val), dim=self.model[p].dim)
                          for p, val in result.params.items()}
        else:
            param_dict = {p: float(val)
                          for p, val in result.params.items()}

        return param_dict, result
class SpikeFitter(Fitter):
def __init__(self, model, input, output, dt, reset, threshold,
input_var='I', refractory=False, n_samples=30,
method=None, param_init=None, use_units=True):
"""Initialize the fitter."""
if method is None:
method = 'exponential_euler'
super().__init__(dt, model, input, output, input_var, 'spikes',
n_samples, threshold, reset, refractory, method,
param_init, use_units=use_units)
self.output = [Quantity(o) for o in output]
self.output_ = [array(o) for o in output]
if param_init:
for param, val in param_init.items():
if not (param in self.model.identifiers or param in self.model.names):
raise ValueError("%s is not a model variable or an "