Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions brainpy/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,4 @@
from .lowdim.lowdim_bifurcation import *

from .constants import *
from . import constants as C
from . import stability
from . import utils
from . import constants as C, stability, plotstyle, utils
16 changes: 9 additions & 7 deletions brainpy/analysis/lowdim/lowdim_bifurcation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import brainpy.math as bm
from brainpy import errors
from brainpy.analysis import stability, utils, constants as C
from brainpy.analysis import stability, plotstyle, utils, constants as C
from brainpy.analysis.lowdim.lowdim_analyzer import *

pyplot = None
Expand Down Expand Up @@ -79,8 +79,8 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
pyplot.figure(self.x_var)
for fp_type, points in container.items():
if len(points['x']):
plot_style = stability.plot_scheme[fp_type]
pyplot.plot(points['p'], points['x'], '.', **plot_style, label=fp_type)
plot_style = plotstyle.plot_schema[fp_type]
pyplot.plot(points['p'], points['x'], **plot_style, label=fp_type)
pyplot.xlabel(self.target_par_names[0])
pyplot.ylabel(self.x_var)

Expand All @@ -107,10 +107,11 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
ax = fig.add_subplot(projection='3d')
for fp_type, points in container.items():
if len(points['x']):
plot_style = stability.plot_scheme[fp_type]
plot_style = plotstyle.plot_schema[fp_type]
xs = points['p0']
ys = points['p1']
zs = points['x']
plot_style.pop('linestyle')
ax.scatter(xs, ys, zs, **plot_style, label=fp_type)

ax.set_xlabel(self.target_par_names[0])
Expand Down Expand Up @@ -298,8 +299,8 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
pyplot.figure(var)
for fp_type, points in container.items():
if len(points['p']):
plot_style = stability.plot_scheme[fp_type]
pyplot.plot(points['p'], points[var], '.', **plot_style, label=fp_type)
plot_style = plotstyle.plot_schema[fp_type]
pyplot.plot(points['p'], points[var], **plot_style, label=fp_type)
pyplot.xlabel(self.target_par_names[0])
pyplot.ylabel(var)

Expand Down Expand Up @@ -330,10 +331,11 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
ax = fig.add_subplot(projection='3d')
for fp_type, points in container.items():
if len(points['p0']):
plot_style = stability.plot_scheme[fp_type]
plot_style = plotstyle.plot_schema[fp_type]
xs = points['p0']
ys = points['p1']
zs = points[var]
plot_style.pop('linestyle')
ax.scatter(xs, ys, zs, **plot_style, label=fp_type)

ax.set_xlabel(self.target_par_names[0])
Expand Down
22 changes: 11 additions & 11 deletions brainpy/analysis/lowdim/lowdim_phase_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import brainpy.math as bm
from brainpy import errors, math
from brainpy.analysis import stability, constants as C, utils
from brainpy.analysis import stability, plotstyle, constants as C, utils
from brainpy.analysis.lowdim.lowdim_analyzer import *

pyplot = None
Expand Down Expand Up @@ -107,8 +107,8 @@ def plot_fixed_point(self, show=False, with_plot=True, with_return=False):
if with_plot:
for fp_type, points in container.items():
if len(points):
plot_style = stability.plot_scheme[fp_type]
pyplot.plot(points, [0] * len(points), '.', markersize=20, **plot_style, label=fp_type)
plot_style = plotstyle.plot_schema[fp_type]
pyplot.plot(points, [0] * len(points), **plot_style, label=fp_type)
pyplot.legend()
if show:
pyplot.show()
Expand Down Expand Up @@ -248,9 +248,9 @@ def plot_nullcline(self, with_plot=True, with_return=False,

if with_plot:
if x_style is None:
x_style = dict(color='cornflowerblue', alpha=.7, )
fmt = x_style.pop('fmt', '.')
pyplot.plot(x_values_in_fx, y_values_in_fx, fmt, **x_style, label=f"{self.x_var} nullcline")
x_style = dict(color='cornflowerblue', alpha=.7, fmt='.')
line_args = (x_style.pop('fmt'), ) if 'fmt' in x_style else tuple()
pyplot.plot(x_values_in_fx, y_values_in_fx, *line_args, **x_style, label=f"{self.x_var} nullcline")

# Nullcline of the y variable
utils.output('I am computing fy-nullcline ...')
Expand All @@ -260,9 +260,9 @@ def plot_nullcline(self, with_plot=True, with_return=False,

if with_plot:
if y_style is None:
y_style = dict(color='lightcoral', alpha=.7, )
fmt = y_style.pop('fmt', '.')
pyplot.plot(x_values_in_fy, y_values_in_fy, fmt, **y_style, label=f"{self.y_var} nullcline")
y_style = dict(color='lightcoral', alpha=.7, fmt='.')
line_args = (y_style.pop('fmt'), ) if 'fmt' in y_style else tuple()
pyplot.plot(x_values_in_fy, y_values_in_fy, *line_args, **y_style, label=f"{self.y_var} nullcline")

if with_plot:
pyplot.xlabel(self.x_var)
Expand Down Expand Up @@ -349,8 +349,8 @@ def plot_fixed_point(self, with_plot=True, with_return=False, show=False,
if with_plot:
for fp_type, points in container.items():
if len(points['x']):
plot_style = stability.plot_scheme[fp_type]
pyplot.plot(points['x'], points['y'], '.', markersize=20, **plot_style, label=fp_type)
plot_style = plotstyle.plot_schema[fp_type]
pyplot.plot(points['x'], points['y'], **plot_style, label=fp_type)
pyplot.legend()
if show:
pyplot.show()
Expand Down
72 changes: 72 additions & 0 deletions brainpy/analysis/plotstyle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-


__all__ = [
'plot_schema',
'set_plot_schema',
]

from .stability import (CENTER_MANIFOLD, SADDLE_NODE, STABLE_POINT_1D,
UNSTABLE_POINT_1D, CENTER_2D, STABLE_NODE_2D,
STABLE_FOCUS_2D, STABLE_STAR_2D, STABLE_DEGENERATE_2D,
UNSTABLE_NODE_2D, UNSTABLE_FOCUS_2D, UNSTABLE_STAR_2D,
UNSTABLE_DEGENERATE_2D, UNSTABLE_LINE_2D,
STABLE_POINT_3D, UNSTABLE_POINT_3D, STABLE_NODE_3D,
UNSTABLE_SADDLE_3D, UNSTABLE_NODE_3D, STABLE_FOCUS_3D,
UNSTABLE_FOCUS_3D, UNSTABLE_CENTER_3D, UNKNOWN_3D)


_markersize = 20

plot_schema = {}

plot_schema[CENTER_MANIFOLD] = {'color': 'orangered', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}
plot_schema[SADDLE_NODE] = {"color": 'tab:blue', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}

plot_schema[STABLE_POINT_1D] = {"color": 'tab:red', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}
plot_schema[UNSTABLE_POINT_1D] = {"color": 'tab:olive', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'}

plot_schema.update({
CENTER_2D: {'color': 'lime', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
STABLE_NODE_2D: {"color": 'tab:red', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
STABLE_FOCUS_2D: {"color": 'tab:purple', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
STABLE_STAR_2D: {'color': 'tab:olive', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
STABLE_DEGENERATE_2D: {'color': 'blueviolet', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_NODE_2D: {"color": 'tab:orange', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_FOCUS_2D: {"color": 'tab:cyan', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_STAR_2D: {'color': 'green', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_DEGENERATE_2D: {'color': 'springgreen', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_LINE_2D: {'color': 'dodgerblue', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
})


plot_schema.update({
STABLE_POINT_3D: {'color': 'tab:gray', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_POINT_3D: {'color': 'tab:purple', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
STABLE_NODE_3D: {'color': 'tab:green', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_SADDLE_3D: {'color': 'tab:red', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_FOCUS_3D: {'color': 'tab:pink', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
STABLE_FOCUS_3D: {'color': 'tab:purple', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_NODE_3D: {'color': 'tab:orange', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNSTABLE_CENTER_3D: {'color': 'tab:olive', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
UNKNOWN_3D: {'color': 'tab:cyan', 'markersize': _markersize, 'linestyle': 'None', 'marker': '.'},
})


def set_plot_schema(fixed_point: str, **schema):
if not isinstance(fixed_point, str):
raise TypeError(f'Must instance of string, but we got {type(fixed_point)}: {fixed_point}')
if fixed_point not in plot_schema:
raise KeyError(f'Fixed point type {fixed_point} does not found in the built-in types. ')
plot_schema[fixed_point].update(**schema)


def set_markersize(markersize):
if not isinstance(markersize, int):
raise TypeError(f"Must be an integer, but got {type(markersize)}: {markersize}")
global _markersize
__markersize = markersize
for key in tuple(plot_schema.keys()):
plot_schema[key]['markersize'] = markersize


32 changes: 3 additions & 29 deletions brainpy/analysis/stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
'get_1d_stability_types',
'get_2d_stability_types',
'get_3d_stability_types',
'plot_scheme',


'stability_analysis',

Expand All @@ -27,17 +27,13 @@
'UNSTABLE_LINE_2D',
]

plot_scheme = {}


SADDLE_NODE = 'saddle node'
CENTER_MANIFOLD = 'center manifold'
plot_scheme[CENTER_MANIFOLD] = {'color': 'orangered'}
plot_scheme[SADDLE_NODE] = {"color": 'tab:blue'}

STABLE_POINT_1D = 'stable point'
UNSTABLE_POINT_1D = 'unstable point'
plot_scheme[STABLE_POINT_1D] = {"color": 'tab:red'}
plot_scheme[UNSTABLE_POINT_1D] = {"color": 'tab:olive'}

CENTER_2D = 'center'
STABLE_NODE_2D = 'stable node'
Expand All @@ -49,18 +45,7 @@
UNSTABLE_STAR_2D = 'unstable star'
UNSTABLE_DEGENERATE_2D = 'unstable degenerate'
UNSTABLE_LINE_2D = 'unstable line'
plot_scheme.update({
CENTER_2D: {'color': 'lime'},
STABLE_NODE_2D: {"color": 'tab:red'},
STABLE_FOCUS_2D: {"color": 'tab:purple'},
STABLE_STAR_2D: {'color': 'tab:olive'},
STABLE_DEGENERATE_2D: {'color': 'blueviolet'},
UNSTABLE_NODE_2D: {"color": 'tab:orange'},
UNSTABLE_FOCUS_2D: {"color": 'tab:cyan'},
UNSTABLE_STAR_2D: {'color': 'green'},
UNSTABLE_DEGENERATE_2D: {'color': 'springgreen'},
UNSTABLE_LINE_2D: {'color': 'dodgerblue'},
})


STABLE_POINT_3D = 'unclassified stable point'
UNSTABLE_POINT_3D = 'unclassified unstable point'
Expand All @@ -71,17 +56,6 @@
UNSTABLE_FOCUS_3D = 'unstable focus'
UNSTABLE_CENTER_3D = 'unstable center'
UNKNOWN_3D = 'unknown 3d'
plot_scheme.update({
STABLE_POINT_3D: {'color': 'tab:gray'},
UNSTABLE_POINT_3D: {'color': 'tab:purple'},
STABLE_NODE_3D: {'color': 'tab:green'},
UNSTABLE_SADDLE_3D: {'color': 'tab:red'},
UNSTABLE_FOCUS_3D: {'color': 'tab:pink'},
STABLE_FOCUS_3D: {'color': 'tab:purple'},
UNSTABLE_NODE_3D: {'color': 'tab:orange'},
UNSTABLE_CENTER_3D: {'color': 'tab:olive'},
UNKNOWN_3D: {'color': 'tab:cyan'},
})


def get_1d_stability_types():
Expand Down
24 changes: 12 additions & 12 deletions brainpy/dyn/rates/populations.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,8 @@ def update(self, tdi, x=None):
self.y.value = y

def clear_input(self):
self.input[:] = 0.
self.input_y[:] = 0.
self.input.value = bm.zeros_like(self.input)
self.input_y.value = bm.zeros_like(self.input_y)


class FeedbackFHN(RateModel):
Expand Down Expand Up @@ -375,8 +375,8 @@ def update(self, tdi, x=None):
self.y.value = y

def clear_input(self):
self.input[:] = 0.
self.input_y[:] = 0.
self.input.value = bm.zeros_like(self.input)
self.input_y.value = bm.zeros_like(self.input_y)


class QIF(RateModel):
Expand Down Expand Up @@ -558,8 +558,8 @@ def update(self, tdi, x=None):
self.y.value = y

def clear_input(self):
self.input[:] = 0.
self.input_y[:] = 0.
self.input.value = bm.zeros_like(self.input)
self.input_y.value = bm.zeros_like(self.input_y)


class StuartLandauOscillator(RateModel):
Expand Down Expand Up @@ -700,8 +700,8 @@ def update(self, tdi, x=None):
self.y.value = y

def clear_input(self):
self.input[:] = 0.
self.input_y[:] = 0.
self.input.value = bm.zeros_like(self.input)
self.input_y.value = bm.zeros_like(self.input_y)


class WilsonCowanModel(RateModel):
Expand Down Expand Up @@ -857,8 +857,8 @@ def update(self, tdi, x=None):
self.y.value = y

def clear_input(self):
self.input[:] = 0.
self.input_y[:] = 0.
self.input.value = bm.zeros_like(self.input)
self.input_y.value = bm.zeros_like(self.input_y)


class JansenRitModel(RateModel):
Expand Down Expand Up @@ -976,5 +976,5 @@ def update(self, tdi, x=None):
self.i.value = bm.maximum(self.i + di * dt, 0.)

def clear_input(self):
self.Ie[:] = 0.
self.Ii[:] = 0.
self.Ie.value = bm.zeros_like(self.Ie)
self.Ii.value = bm.zeros_like(self.Ii)
2 changes: 1 addition & 1 deletion brainpy/math/delayvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def update(self, value: Union[float, int, bool, JaxArray, jnp.DeviceArray]):
self.idx.value = stop_gradient((self.idx + 1) % self.num_delay_step)

elif self.update_method == CONCAT_UPDATING:
self.data.value = bm.concatenate([self.data[1:], bm.broadcast_to(value, self.delay_target_shape)], axis=0)
self.data.value = bm.vstack([self.data[1:], bm.broadcast_to(value,self.data.shape[1:])])

else:
raise ValueError(f'Unknown updating method "{self.update_method}"')
Expand Down
Loading