Commit

Merge pull request #42
Fix: Fitter exception message #41
mstimberg authored Apr 19, 2021
2 parents a6a41e8 + 0a0be47 commit 8a9d06e
Showing 3 changed files with 99 additions and 9 deletions.
19 changes: 14 additions & 5 deletions brian2modelfitting/fitter.py
@@ -10,15 +10,15 @@

 from brian2 import (NeuronGroup, defaultclock, get_device, Network,
                     StateMonitor, SpikeMonitor, second, get_local_namespace,
-                    Quantity, get_logger)
+                    Quantity, get_logger, ms)
 from brian2.input import TimedArray
 from brian2.equations.equations import Equations, SUBEXPRESSION
 from brian2.devices import set_device, reset_device, device
 from brian2.devices.cpp_standalone.device import CPPStandaloneDevice
 from brian2.core.functions import Function
 from .simulator import RuntimeSimulator, CPPStandaloneSimulator
-from .metric import Metric, SpikeMetric, TraceMetric, MSEMetric, normalize_weights
-from .optimizer import Optimizer
+from .metric import Metric, SpikeMetric, TraceMetric, MSEMetric, GammaFactor, normalize_weights
+from .optimizer import Optimizer, NevergradOptimizer
 from .utils import callback_setup, make_dic


@@ -560,8 +560,8 @@ def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
             raise TypeError("metric has to be a child of class Metric or None "
                             "for OnlineTraceFitter")

-        if not (isinstance(optimizer, Optimizer)) or optimizer is None:
-            raise TypeError("metric has to be a child of class Optimizer")
+        if not (isinstance(optimizer, Optimizer) or optimizer is None):
+            raise TypeError("optimizer has to be a child of class Optimizer or None")

         if self.metric is not None and restart is False:
             if metric is not self.metric:
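The substance of the fix is operator scope: in the old condition, not negated only the isinstance() call, so the trailing "or optimizer is None" added a reason to raise and a perfectly valid optimizer=None always triggered the (mislabeled) TypeError. A minimal, self-contained sketch of the two conditions, using a stub Optimizer class purely for illustration:

# Stub standing in for brian2modelfitting.optimizer.Optimizer (illustration only).
class Optimizer:
    pass

for optimizer in (None, Optimizer()):
    # Old: "not" binds only to isinstance(), so the "or ... is None" clause
    # adds a reason to raise -- optimizer=None was wrongly rejected.
    old_raises = not (isinstance(optimizer, Optimizer)) or optimizer is None
    # New: the whole "acceptable input" condition is negated first, so both
    # None and Optimizer instances are accepted.
    new_raises = not (isinstance(optimizer, Optimizer) or optimizer is None)
    print(type(optimizer).__name__, old_raises, new_raises)
# NoneType True False
# Optimizer False False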
@@ -577,6 +577,9 @@
         if penalty is None:
             penalty = self.penalty

+        if optimizer is None:
+            optimizer = NevergradOptimizer()
+
         if self.optimizer is None or restart:
             if start_iteration is None:
                 self.iteration = 0
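Together with the relaxed type check above, this makes optimizer=None a supported argument: fit() silently substitutes a NevergradOptimizer() with default settings. A usage sketch, assuming a TraceFitter tf constructed as in the test fixtures below (the g bounds mirror the tests):

from brian2 import nS
from brian2modelfitting import MSEMetric, NevergradOptimizer

# optimizer=None now means "use a default NevergradOptimizer()" instead of
# raising a TypeError.
results, errors = tf.fit(n_rounds=2,
                         optimizer=None,
                         metric=MSEMetric(),
                         g=[1*nS, 30*nS])
assert isinstance(tf.optimizer, NevergradOptimizer)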
@@ -862,6 +865,9 @@ def calc_errors(self, metric):
     def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
             restart=False, start_iteration=None, penalty=None,
             level=0, **params):
+        if metric is None:
+            metric = MSEMetric()
+
         if not isinstance(metric, TraceMetric):
             raise TypeError("You can only use TraceMetric child metric with "
                             "TraceFitter")
@@ -1149,6 +1155,9 @@ def calc_errors(self, metric):
     def fit(self, optimizer, metric=None, n_rounds=1, callback='text',
             restart=False, start_iteration=None, penalty=None,
             level=0, **params):
+        if metric is None:
+            metric = GammaFactor(delta=2*ms, time=self.duration)
+
         if not isinstance(metric, SpikeMetric):
             raise TypeError("You can only use SpikeMetric child metric with "
                             "SpikeFitter")
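SpikeFitter's counterpart defaults to a GammaFactor with a hard-coded coincidence window of delta=2*ms (the reason ms was added to the brian2 imports above) and time set to the fitter's input duration. A sketch, mirroring the new spike test below and assuming its sf and n_opt fixtures:

# metric=None on a SpikeFitter now falls back to
# GammaFactor(delta=2*ms, time=self.duration).
results, errors = sf.fit(n_rounds=2,
                         optimizer=n_opt,
                         metric=None,
                         gL=[20*nS, 40*nS],
                         C=[0.5*nF, 1.5*nF])
assert isinstance(sf.metric, GammaFactor)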
45 changes: 42 additions & 3 deletions brian2modelfitting/tests/test_modelfitting_spikefitter.py
@@ -144,19 +144,58 @@ def test_spikefitter_fit(setup):

 def test_spikefitter_fit_errors(setup):
     dt, sf = setup
+    class NaiveOptimizer:
+        def __init__(self):
+            self.best = []
+    naive_opt = NaiveOptimizer()
     with pytest.raises(TypeError):
         results, errors = sf.fit(n_rounds=2,
                                  optimizer=n_opt,
-                                 metric=MSEMetric(),
+                                 metric=MSEMetric(),  # testing a wrong metric
                                  gL=[20*nS, 40*nS],
                                  C=[0.5*nF, 1.5*nF])
     with pytest.raises(TypeError):
         results, errors = sf.fit(n_rounds=2,
-                                 optimizer=None,
-                                 metric=MSEMetric(),
+                                 optimizer=naive_opt,  # testing a non-Optimizer child
+                                 metric=metric,
                                  gL=[20*nS, 40*nS],
                                  C=[0.5*nF, 1.5*nF])
+
+
+def test_fitter_fit_default_optimizer(setup):
+    dt, sf = setup
+    results, errors = sf.fit(n_rounds=2,
+                             optimizer=None,
+                             metric=metric,
+                             gL=[20*nS, 40*nS],
+                             C=[0.5*nF, 1.5*nF])
+    assert sf.simulator.neurons.iteration == 1
+    attr_fit = ['optimizer', 'metric', 'best_params']
+    for attr in attr_fit:
+        assert hasattr(sf, attr)
+
+    assert isinstance(sf.optimizer, NevergradOptimizer)  # default optimizer
+    assert isinstance(sf.simulator, Simulator)
+
+    assert_equal(results, sf.best_params)
+    assert_equal(errors, sf.best_error)
+
+
+def test_spikefitter_fit_default_metric(setup):
+    dt, sf = setup
+    results, errors = sf.fit(n_rounds=2,
+                             optimizer=n_opt,
+                             metric=None,
+                             gL=[20*nS, 40*nS],
+                             C=[0.5*nF, 1.5*nF])
+    assert sf.simulator.neurons.iteration == 1
+    attr_fit = ['optimizer', 'metric', 'best_params']
+    for attr in attr_fit:
+        assert hasattr(sf, attr)
+    assert isinstance(sf.metric, GammaFactor)  # default spike metric
+    assert isinstance(sf.simulator, Simulator)
+
+    assert_equal(results, sf.best_params)
+    assert_equal(errors, sf.best_error)
+
+
 def test_spikefitter_param_init(setup):
44 changes: 43 additions & 1 deletion brian2modelfitting/tests/test_modelfitting_tracefitter.py
@@ -260,6 +260,25 @@ def test_tracefitter_init_errors(setup):
                        output=np.array(output_traces),  # no units
                        n_samples=2)

+
+def test_tracefitter_fit_default_metric(setup):
+    dt, tf = setup
+    results, errors = tf.fit(n_rounds=2,
+                             optimizer=n_opt,
+                             metric=None,
+                             g=[1*nS, 30*nS],
+                             callback=None)
+    assert tf.simulator.neurons.iteration == 1
+    attr_fit = ['optimizer', 'metric', 'best_params']
+    for attr in attr_fit:
+        assert hasattr(tf, attr)
+    assert isinstance(tf.metric, MSEMetric)  # default trace metric
+    assert isinstance(tf.simulator, Simulator)
+
+    assert_equal(results, tf.best_params)
+    assert_equal(errors, tf.best_error)
+
+
 from nevergrad.optimization import registry as nevergrad_registry
 @pytest.mark.parametrize('method', sorted(nevergrad_registry.keys()))
 def test_fitter_fit_methods(method):
@@ -341,6 +360,25 @@ def test_fitter_fit_no_units(setup_no_units):
     assert_equal(errors, tf.best_error)


+def test_fitter_fit_default_optimizer(setup):
+    dt, tf = setup
+    results, errors = tf.fit(n_rounds=2,
+                             optimizer=None,
+                             metric=metric,
+                             g=[1*nS, 30*nS],
+                             callback=None)
+    assert tf.simulator.neurons.iteration == 1
+    attr_fit = ['optimizer', 'metric', 'best_params']
+    for attr in attr_fit:
+        assert hasattr(tf, attr)
+
+    assert isinstance(tf.optimizer, NevergradOptimizer)  # default optimizer
+    assert isinstance(tf.simulator, Simulator)
+
+    assert_equal(results, tf.best_params)
+    assert_equal(errors, tf.best_error)
+
+
 def test_fitter_fit_callback(setup):
     dt, tf = setup
@@ -381,9 +419,13 @@ def our_callback(params, errors, best_params, best_error, index):

 def test_fitter_fit_errors(setup):
     dt, tf = setup
+    class NaiveOptimizer:
+        def __init__(self):
+            self.best = []
+    opt = NaiveOptimizer()
     with pytest.raises(TypeError):
         tf.fit(n_rounds=2,
-               optimizer=None,
+               optimizer=opt,  # testing a non-Optimizer child
                metric=metric,
                g=[1*nS, 30*nS])
