diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 183e1b19e..eb31a5f3a 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- -__version__ = "2.2.3.5" - +__version__ = "2.2.3.6" try: import jaxlib + del jaxlib except ModuleNotFoundError: raise ModuleNotFoundError( @@ -34,21 +34,17 @@ ''') from None - # fundamental modules from . import errors, tools, check, modes - # "base" module from . import base from .base.base import Base from .base.collector import Collector, TensorCollector - # math foundation from . import math - # toolboxes from . import ( connect, # synaptic connection @@ -61,7 +57,6 @@ algorithms, # online or offline training algorithms ) - # numerical integrators from . import integrators from .integrators import ode @@ -72,7 +67,6 @@ from .integrators.fde import fdeint from .integrators.joint_eq import JointEq - # dynamics simulation from . import dyn from .dyn import ( @@ -82,10 +76,10 @@ neurons, # neuron groups rates, # rate models synapses, # synaptic dynamics - synouts, # synaptic output + synouts, # synaptic output synplast, # synaptic plasticity ) -from brainpy.dyn.base import ( +from .dyn.base import ( DynamicalSystem, Container, Sequential, @@ -101,23 +95,33 @@ ) from .dyn.runners import * - # dynamics training from . import train - +from .train import ( + DSTrainer, + OnlineTrainer, ForceTrainer, + OfflineTrainer, RidgeTrainer, + BPFF, + BPTT, + OnlineBPTT, +) # automatic dynamics analysis from . import analysis - +from .analysis import ( + DSAnalyzer, + PhasePlane1D, PhasePlane2D, + Bifurcation1D, Bifurcation2D, + FastSlow1D, FastSlow2D, + SlowPointFinder, +) # running from . import running - # "visualization" module, will be removed soon from .visualization import visualize - # convenient access conn = connect init = initialize diff --git a/brainpy/analysis/lowdim/lowdim_analyzer.py b/brainpy/analysis/lowdim/lowdim_analyzer.py index 5ec76ca30..3adff83ef 100644 --- a/brainpy/analysis/lowdim/lowdim_analyzer.py +++ b/brainpy/analysis/lowdim/lowdim_analyzer.py @@ -137,12 +137,12 @@ def __init__( target_pars = dict() if not isinstance(target_pars, dict): raise errors.AnalyzerError('"target_pars" must be a dict with the format of {"par1": (val1, val2)}.') - for key in target_pars.keys(): + for key, value in target_pars.items(): if key not in self.model.parameters: raise errors.AnalyzerError(f'"{key}" is not a valid parameter in "{self.model}" model.') - value = self.target_vars[key] if value[0] > value[1]: - raise errors.AnalyzerError(f'The range of parameter {key} is reversed, which means {value[0]} should be smaller than {value[1]}.') + raise errors.AnalyzerError( + f'The range of parameter {key} is reversed, which means {value[0]} should be smaller than {value[1]}.') self.target_pars = Collector(target_pars) self.target_par_names = list(self.target_pars.keys()) # list of target_pars diff --git a/brainpy/analysis/lowdim/lowdim_bifurcation.py b/brainpy/analysis/lowdim/lowdim_bifurcation.py index e8431d1bd..abe78d917 100644 --- a/brainpy/analysis/lowdim/lowdim_bifurcation.py +++ b/brainpy/analysis/lowdim/lowdim_bifurcation.py @@ -5,6 +5,7 @@ import jax.numpy as jnp from jax import vmap import numpy as np +from copy import deepcopy import brainpy.math as bm from brainpy import errors @@ -79,7 +80,7 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False, pyplot.figure(self.x_var) for fp_type, points in container.items(): if len(points['x']): - plot_style = plotstyle.plot_schema[fp_type] + plot_style = deepcopy(plotstyle.plot_schema[fp_type]) pyplot.plot(points['p'], points['x'], **plot_style, label=fp_type) pyplot.xlabel(self.target_par_names[0]) pyplot.ylabel(self.x_var) @@ -107,11 +108,12 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False, ax = fig.add_subplot(projection='3d') for fp_type, points in container.items(): if len(points['x']): - plot_style = plotstyle.plot_schema[fp_type] + plot_style = deepcopy(plotstyle.plot_schema[fp_type]) xs = points['p0'] ys = points['p1'] zs = points['x'] plot_style.pop('linestyle') + plot_style['s'] = plot_style.pop('markersize', None) ax.scatter(xs, ys, zs, **plot_style, label=fp_type) ax.set_xlabel(self.target_par_names[0]) @@ -299,7 +301,7 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False, pyplot.figure(var) for fp_type, points in container.items(): if len(points['p']): - plot_style = plotstyle.plot_schema[fp_type] + plot_style = deepcopy(plotstyle.plot_schema[fp_type]) pyplot.plot(points['p'], points[var], **plot_style, label=fp_type) pyplot.xlabel(self.target_par_names[0]) pyplot.ylabel(var) @@ -331,11 +333,12 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False, ax = fig.add_subplot(projection='3d') for fp_type, points in container.items(): if len(points['p0']): - plot_style = plotstyle.plot_schema[fp_type] + plot_style = deepcopy(plotstyle.plot_schema[fp_type]) xs = points['p0'] ys = points['p1'] zs = points[var] plot_style.pop('linestyle') + plot_style['s'] = plot_style.pop('markersize', None) ax.scatter(xs, ys, zs, **plot_style, label=fp_type) ax.set_xlabel(self.target_par_names[0]) diff --git a/brainpy/analysis/lowdim/lowdim_phase_plane.py b/brainpy/analysis/lowdim/lowdim_phase_plane.py index 05259b4e6..7995301a0 100644 --- a/brainpy/analysis/lowdim/lowdim_phase_plane.py +++ b/brainpy/analysis/lowdim/lowdim_phase_plane.py @@ -4,6 +4,7 @@ import numpy as np from jax import vmap +from copy import deepcopy import brainpy.math as bm from brainpy import errors, math from brainpy.analysis import stability, plotstyle, constants as C, utils @@ -107,7 +108,7 @@ def plot_fixed_point(self, show=False, with_plot=True, with_return=False): if with_plot: for fp_type, points in container.items(): if len(points): - plot_style = plotstyle.plot_schema[fp_type] + plot_style = deepcopy(plotstyle.plot_schema[fp_type]) pyplot.plot(points, [0] * len(points), **plot_style, label=fp_type) pyplot.legend() if show: @@ -349,7 +350,7 @@ def plot_fixed_point(self, with_plot=True, with_return=False, show=False, if with_plot: for fp_type, points in container.items(): if len(points['x']): - plot_style = plotstyle.plot_schema[fp_type] + plot_style = deepcopy(plotstyle.plot_schema[fp_type]) pyplot.plot(points['x'], points['y'], **plot_style, label=fp_type) pyplot.legend() if show: diff --git a/brainpy/analysis/lowdim/tests/test_bifurcation.py b/brainpy/analysis/lowdim/tests/test_bifurcation.py new file mode 100644 index 000000000..061038e02 --- /dev/null +++ b/brainpy/analysis/lowdim/tests/test_bifurcation.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- + + +import pytest +pytest.skip('Test cannot pass in github action.', allow_module_level=True) +import unittest + +import brainpy as bp +import brainpy.math as bm +import matplotlib.pyplot as plt + +block = False + + +class FitzHughNagumoModel(bp.dyn.DynamicalSystem): + def __init__(self, method='exp_auto'): + super(FitzHughNagumoModel, self).__init__() + + # parameters + self.a = 0.7 + self.b = 0.8 + self.tau = 12.5 + + # variables + self.V = bm.Variable(bm.zeros(1)) + self.w = bm.Variable(bm.zeros(1)) + self.Iext = bm.Variable(bm.zeros(1)) + + # functions + def dV(V, t, w, Iext=0.): + dV = V - V * V * V / 3 - w + Iext + return dV + + def dw(w, t, V, a=0.7, b=0.8): + dw = (V + a - b * w) / self.tau + return dw + + self.int_V = bp.odeint(dV, method=method) + self.int_w = bp.odeint(dw, method=method) + + def update(self, tdi): + t, dt = tdi['t'], tdi['dt'] + self.V.value = self.int_V(self.V, t, self.w, self.Iext, dt) + self.w.value = self.int_w(self.w, t, self.V, self.a, self.b, dt) + self.Iext[:] = 0. + + +class TestBifurcation1D(unittest.TestCase): + def test_bifurcation_1d(self): + bp.math.enable_x64() + + @bp.odeint + def int_x(x, t, a=1., b=1.): + return bp.math.sin(a * x) + bp.math.cos(b * x) + + pp = bp.analysis.PhasePlane1D( + model=int_x, + target_vars={'x': [-bp.math.pi, bp.math.pi]}, + resolutions=0.1 + ) + pp.plot_vector_field() + pp.plot_fixed_point(show=True) + + bf = bp.analysis.Bifurcation1D( + model=int_x, + target_vars={'x': [-bp.math.pi, bp.math.pi]}, + target_pars={'a': [0.5, 1.5], 'b': [0.5, 1.5]}, + resolutions={'a': 0.1, 'b': 0.1} + ) + bf.plot_bifurcation(show=False) + plt.show(block=block) + plt.close() + bp.math.disable_x64() + + def test_bifurcation_2d(self): + bp.math.enable_x64() + + model = FitzHughNagumoModel() + bif = bp.analysis.Bifurcation2D( + model=model, + target_vars={'V': [-3., 3.], 'w': [-1, 3.]}, + target_pars={'Iext': [0., 1.]}, + resolutions={'Iext': 0.1} + ) + bif.plot_bifurcation() + bif.plot_limit_cycle_by_sim() + plt.show(block=block) + + # bp.math.disable_x64() diff --git a/brainpy/analysis/lowdim/tests/test_phase_plane.py b/brainpy/analysis/lowdim/tests/test_phase_plane.py index f93c0bc4d..8534085c8 100644 --- a/brainpy/analysis/lowdim/tests/test_phase_plane.py +++ b/brainpy/analysis/lowdim/tests/test_phase_plane.py @@ -3,13 +3,13 @@ import unittest import brainpy as bp +import matplotlib.pyplot as plt block = False class TestPhasePlane(unittest.TestCase): def test_1d(self): - import matplotlib.pyplot as plt bp.math.enable_x64() @bp.odeint @@ -30,8 +30,6 @@ def int_x(x, t, Iext): bp.math.disable_x64() def test_2d_decision_making_model(self): - import matplotlib.pyplot as plt - bp.math.enable_x64() gamma = 0.641 # Saturation factor for gating variable tau = 0.06 # Synaptic time constant [sec] diff --git a/brainpy/analysis/plotstyle.py b/brainpy/analysis/plotstyle.py index 50c568a15..3a81735c1 100644 --- a/brainpy/analysis/plotstyle.py +++ b/brainpy/analysis/plotstyle.py @@ -16,7 +16,7 @@ UNSTABLE_FOCUS_3D, UNSTABLE_CENTER_3D, UNKNOWN_3D) -_markersize = 20 +_markersize = 10 plot_schema = {} diff --git a/brainpy/dyn/base.py b/brainpy/dyn/base.py index ceeda4bd9..afc352173 100644 --- a/brainpy/dyn/base.py +++ b/brainpy/dyn/base.py @@ -182,7 +182,8 @@ def register_delay( elif delay.num_delay_step - 1 < max_delay_step: self.global_delay_data[identifier][0].reset(delay_target, max_delay_step, initial_delay_data) else: - self.global_delay_data[identifier] = (None, delay_target) + if identifier not in self.global_delay_data: + self.global_delay_data[identifier] = (None, delay_target) self.register_implicit_nodes(self.local_delay_vars) return delay_step diff --git a/brainpy/dyn/tests/test_base_classes.py b/brainpy/dyn/tests/test_base_classes.py new file mode 100644 index 000000000..9c095a30e --- /dev/null +++ b/brainpy/dyn/tests/test_base_classes.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- + +import unittest + +import brainpy as bp + + +class TestDynamicalSystem(unittest.TestCase): + def test_delay(self): + A = bp.neurons.LIF(1) + B = bp.neurons.LIF(1) + C = bp.neurons.LIF(1) + A2B = bp.synapses.Exponential(A, B, bp.conn.All2All(), delay_step=1) + A2C = bp.synapses.Exponential(A, C, bp.conn.All2All(), delay_step=None) + net = bp.Network(A, B, C, A2B, A2C) + + runner = bp.DSRunner(net,) + runner.run(10.) + +