Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions brainpy/analysis/highdim/slow_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,7 @@ def train(idx):
return loss

def batch_train(start_i, n_batch):
f = bm.make_loop(train, dyn_vars=dyn_vars, has_return=True)
return f(bm.arange(start_i, start_i + n_batch))
return bm.for_loop(train, dyn_vars, bm.arange(start_i, start_i + n_batch))

# Run the optimization
if self.verbose:
Expand All @@ -369,7 +368,7 @@ def batch_train(start_i, n_batch):
break
batch_idx_start = oidx * num_batch
start_time = time.time()
(_, train_losses) = batch_train(start_i=batch_idx_start, n_batch=num_batch)
train_losses = batch_train(start_i=batch_idx_start, n_batch=num_batch)
batch_time = time.time() - start_time
opt_losses.append(train_losses)

Expand Down Expand Up @@ -722,8 +721,6 @@ def _generate_ds_cell_function(
shared = DotDict(t=t, dt=dt, i=0)

def f_cell(h: Dict):
target.clear_input()

# update target variables
for k, v in self.target_vars.items():
v.value = (bm.asarray(h[k], dtype=v.dtype)
Expand All @@ -735,6 +732,7 @@ def f_cell(h: Dict):
v.value = self.excluded_data[k]

# add inputs
target.clear_input()
if f_input is not None:
f_input(shared)

Expand All @@ -743,7 +741,7 @@ def f_cell(h: Dict):
target.update(*args)

# get new states
new_h = {k: (v.value if v.batch_axis is None else jnp.squeeze(v.value, axis=v.batch_axis))
new_h = {k: (v.value if (v.batch_axis is None) else jnp.squeeze(v.value, axis=v.batch_axis))
for k, v in self.target_vars.items()}
return new_h

Expand Down
51 changes: 41 additions & 10 deletions brainpy/analysis/lowdim/lowdim_bifurcation.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,17 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,
if with_return:
return final_fps, final_pars, jacobians

def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=False,
plot_style=None, tol=0.001, show=False, dt=None, offset=1.):
def plot_limit_cycle_by_sim(
self,
duration=100,
with_plot: bool = True,
with_return: bool = False,
plot_style: dict = None,
tol: float = 0.001,
show: bool = False,
dt: float = None,
offset: float = 1.
):
global pyplot
if pyplot is None: from matplotlib import pyplot
utils.output('I am plotting the limit cycle ...')
Expand Down Expand Up @@ -400,10 +409,16 @@ def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=Fals
if len(ps_limit_cycle[0]):
for i, var in enumerate(self.target_var_names):
pyplot.figure(var)
pyplot.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['max'],
**plot_style, label='limit cycle (max)')
pyplot.plot(ps_limit_cycle[0], ps_limit_cycle[1], vs_limit_cycle[i]['min'],
**plot_style, label='limit cycle (min)')
pyplot.plot(ps_limit_cycle[0],
ps_limit_cycle[1],
vs_limit_cycle[i]['max'],
**plot_style,
label='limit cycle (max)')
pyplot.plot(ps_limit_cycle[0],
ps_limit_cycle[1],
vs_limit_cycle[i]['min'],
**plot_style,
label='limit cycle (min)')
pyplot.legend()

elif len(self.target_par_names) == 1:
Expand All @@ -427,8 +442,16 @@ def plot_limit_cycle_by_sim(self, duration=100, with_plot=True, with_return=Fals


class FastSlow1D(Bifurcation1D):
def __init__(self, model, fast_vars, slow_vars, fixed_vars=None,
pars_update=None, resolutions=None, options=None):
def __init__(
self,
model,
fast_vars: dict,
slow_vars: dict,
fixed_vars: dict = None,
pars_update: dict = None,
resolutions=None,
options: dict = None
):
super(FastSlow1D, self).__init__(model=model,
target_pars=slow_vars,
target_vars=fast_vars,
Expand Down Expand Up @@ -510,8 +533,16 @@ def plot_trajectory(self, initials, duration, plot_durations=None,


class FastSlow2D(Bifurcation2D):
def __init__(self, model, fast_vars, slow_vars, fixed_vars=None,
pars_update=None, resolutions=0.1, options=None):
def __init__(
self,
model,
fast_vars: dict,
slow_vars: dict,
fixed_vars: dict = None,
pars_update: dict = None,
resolutions=0.1,
options: dict = None
):
super(FastSlow2D, self).__init__(model=model,
target_pars=slow_vars,
target_vars=fast_vars,
Expand Down
7 changes: 4 additions & 3 deletions brainpy/analysis/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,14 @@ def __init__(self, integrals: dict, initial_vars: dict, pars=None, dt=None):

# variables
assert isinstance(initial_vars, dict)
initial_vars = {k: bm.Variable(jnp.asarray(bm.as_device_array(v), dtype=jnp.float_))
initial_vars = {k: bm.Variable(jnp.asarray(bm.as_device_array(v), dtype=bm.dftype()))
for k, v in initial_vars.items()}
self.register_implicit_vars(initial_vars)

# parameters
pars = dict() if pars is None else pars
assert isinstance(pars, dict)
self.pars = [jnp.asarray(bm.as_device_array(v), dtype=jnp.float_)
self.pars = [jnp.asarray(bm.as_device_array(v), dtype=bm.dftype())
for k, v in pars.items()]

# integrals
Expand All @@ -128,7 +128,8 @@ def __init__(self, integrals: dict, initial_vars: dict, pars=None, dt=None):
# runner
self.runner = DSRunner(self,
monitors=list(initial_vars.keys()),
dyn_vars=self.vars().unique(), dt=dt,
dyn_vars=self.vars().unique(),
dt=dt,
progress_bar=False)

def update(self, sha):
Expand Down
7 changes: 5 additions & 2 deletions brainpy/dyn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,7 @@ def offline_fit(self,
raise NoImplementationError('Subclass must implement offline_fit() function when using OfflineTrainer.')

def clear_input(self):
for node in self.nodes(level=1, include_self=False).subset(NeuGroup).unique().values():
node.clear_input()
pass


class Container(DynamicalSystem):
Expand Down Expand Up @@ -430,6 +429,10 @@ def __getattr__(self, item):
else:
return super(Container, self).__getattribute__(item)

def clear_input(self):
for node in self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values():
node.clear_input()


class Sequential(Container):
def __init__(
Expand Down
25 changes: 11 additions & 14 deletions brainpy/dyn/neurons/biological_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,20 +244,17 @@ def __init__(

# variables
self.V = variable(self._V_initializer, mode, self.varshape)
if self._m_initializer is None:
self.m = bm.Variable(self.m_inf(self.V.value))
else:
self.m = variable(self._m_initializer, mode, self.varshape)
if self._h_initializer is None:
self.h = bm.Variable(self.h_inf(self.V.value))
else:
self.h = variable(self._h_initializer, mode, self.varshape)
if self._n_initializer is None:
self.n = bm.Variable(self.n_inf(self.V.value))
else:
self.n = variable(self._n_initializer, mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)
self.m = (bm.Variable(self.m_inf(self.V.value))
if m_initializer is None else
variable(self._m_initializer, mode, self.varshape))
self.h = (bm.Variable(self.h_inf(self.V.value))
if h_initializer is None else
variable(self._h_initializer, mode, self.varshape))
self.n = (bm.Variable(self.n_inf(self.V.value))
if n_initializer is None else
variable(self._n_initializer, mode, self.varshape))
self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
self.input = variable(bm.zeros, mode, self.varshape)

# integral
if self.noise is None:
Expand Down Expand Up @@ -309,7 +306,7 @@ def dV(self, V, t, m, h, n, I_ext):

@property
def derivative(self):
return JointEq([self.dV, self.dm, self.dh, self.dn])
return JointEq(self.dV, self.dm, self.dh, self.dn)

def update(self, tdi, x=None):
t, dt = tdi['t'], tdi['dt']
Expand Down
8 changes: 3 additions & 5 deletions brainpy/dyn/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,7 @@ def f_predict(self, shared_args: Dict = None):

monitor_func = self.build_monitors(self._mon_info[0], self._mon_info[1], shared_args)

def _step_func(inputs):
t, i, x = inputs
def _step_func(t, i, x):
self.target.clear_input()
# input step
shared = DotDict(t=t, i=i, dt=self.dt)
Expand All @@ -586,8 +585,7 @@ def _step_func(inputs):
if self.jit['predict']:
dyn_vars = self.target.vars()
dyn_vars.update(self.dyn_vars)
f = bm.make_loop(_step_func, dyn_vars=dyn_vars.unique(), has_return=True)
run_func = lambda all_inputs: f(all_inputs)[1]
run_func = lambda all_inputs: bm.for_loop(_step_func, dyn_vars.unique(), all_inputs)

else:
def run_func(xs):
Expand All @@ -601,7 +599,7 @@ def run_func(xs):
x = tree_map(lambda x: x[i], xs, is_leaf=lambda x: isinstance(x, bm.JaxArray))

# step at the i
output, mon = _step_func((times[i], indices[i], x))
output, mon = _step_func(times[i], indices[i], x)

# append output and monitor
outputs.append(output)
Expand Down
4 changes: 2 additions & 2 deletions brainpy/inputs/currents.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,9 @@ def ou_process(mean, sigma, tau, duration, dt=None, n=1, t_start=0., t_end=None,

def _f(t):
x.value = x + dt * ((mean - x) / tau) + sigma * dt_sqrt * rng.rand(n)
return x.value

f = bm.make_loop(_f, dyn_vars=[x, rng], out_vars=x)
noises = f(jnp.arange(t_start, t_end, dt))
noises = bm.for_loop(_f, [x, rng], jnp.arange(t_start, t_end, dt))

t_end = duration if t_end is None else t_end
i_start = int(t_start / dt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def f(t):

if show:
fig = plt.figure()
ax = fig.gca(projection='3d')
ax = fig.add_subplot(111, projection='3d')
plt.plot(mon_x, mon_y, mon_z)
ax.set_xlabel('x')
ax.set_xlabel('y')
Expand Down
31 changes: 14 additions & 17 deletions brainpy/integrators/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,16 +217,12 @@ def __init__(

# build the update step
if self.jit['predict']:
_loop_func = bm.make_loop(
self._step,
dyn_vars=self.dyn_vars,
out_vars={k: self.variables[k] for k in self.monitors.keys()},
has_return=True
)
def _loop_func(times):
return bm.for_loop(self._step, self.dyn_vars, times)
else:
def _loop_func(times):
out_vars = {k: [] for k in self.monitors.keys()}
returns = {k: [] for k in self.fun_monitors.keys()}
returns.update({k: [] for k in self.monitors.keys()})
for i in range(len(times)):
_t = times[i]
_dt = self.dt
Expand All @@ -237,9 +233,9 @@ def _loop_func(times):
self._step(_t)
# variable monitors
for k in self.monitors.keys():
out_vars[k].append(bm.as_device_array(self.variables[k]))
out_vars = {k: bm.asarray(out_vars[k]) for k in self.monitors.keys()}
return out_vars, returns
returns[k].append(bm.as_device_array(self.variables[k]))
returns = {k: bm.asarray(returns[k]) for k in returns.keys()}
return returns
self.step_func = _loop_func

def _step(self, t):
Expand All @@ -252,11 +248,6 @@ def _step(self, t):
kwargs.update({k: v[self.idx.value] for k, v in self._dyn_args.items()})
self.idx += 1

# return of function monitors
returns = dict()
for key, func in self.fun_monitors.items():
returns[key] = func(t, self.dt)

# call integrator function
update_values = self.target(**kwargs)
if len(self.target.variables) == 1:
Expand All @@ -268,6 +259,13 @@ def _step(self, t):
# progress bar
if self.progress_bar:
id_tap(lambda *args: self._pbar.update(), ())

# return of function monitors
returns = dict()
for key, func in self.fun_monitors.items():
returns[key] = func(t, self.dt)
for k in self.monitors.keys():
returns[k] = self.variables[k].value
return returns

def run(self, duration, start_t=None, eval_time=False):
Expand Down Expand Up @@ -302,14 +300,13 @@ def run(self, duration, start_t=None, eval_time=False):
refresh=True)
if eval_time:
t0 = time.time()
hists, returns = self.step_func(times)
hists = self.step_func(times)
if eval_time:
running_time = time.time() - t0
if self.progress_bar:
self._pbar.close()

# post-running
hists.update(returns)
times += self.dt
if self.numpy_mon_after_run:
times = np.asarray(times)
Expand Down
2 changes: 1 addition & 1 deletion brainpy/integrators/sde/tests/test_sde_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def lorenz_system(method, **kwargs):
mon3 = bp.math.array(mon3).to_numpy()

fig = plt.figure()
ax = fig.gca(projection='3d')
ax = fig.add_subplot(111, projection='3d')
plt.plot(mon1, mon2, mon3)
ax.set_xlabel('x')
ax.set_xlabel('y')
Expand Down
6 changes: 4 additions & 2 deletions brainpy/math/operators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ def _check_brainpylib(ops_name):
raise PackageMissingError(
f'"{ops_name}" operator need "brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}". \n'
f'Please install it through:\n\n'
f'>>> pip install brainpylib>={_BRAINPYLIB_MINIMAL_VERSION} -U'
f'>>> pip install brainpylib=={_BRAINPYLIB_MINIMAL_VERSION}\n'
f'>>> # or \n'
f'>>> pip install brainpylib -U'
)
else:
raise PackageMissingError(
f'"brainpylib" must be installed when the user '
f'wants to use "{ops_name}" operator. \n'
f'Please install "brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}" through:\n\n'
f'>>> pip install brainpylib>={_BRAINPYLIB_MINIMAL_VERSION}'
f'>>> pip install brainpylib'
)
8 changes: 4 additions & 4 deletions brainpy/measure/correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
]


@jit
# @jit
@partial(vmap, in_axes=(None, 0, 0))
def _cc(states, i, j):
sqrt_ij = jnp.sqrt(jnp.sum(states[i]) * jnp.sum(states[j]))
Expand Down Expand Up @@ -86,7 +86,7 @@ def _var(neu_signal, i):
return jnp.mean(neu_signal * neu_signal) - jnp.mean(neu_signal) ** 2


@jit
# @jit
def voltage_fluctuation(potentials):
r"""Calculate neuronal synchronization via voltage variance.

Expand Down Expand Up @@ -202,7 +202,7 @@ def functional_connectivity(activities):
return np.nan_to_num(fc)


@jit
# @jit
def functional_connectivity_dynamics(activities, window_size=30, step_size=5):
"""Computes functional connectivity dynamics (FCD) matrix.

Expand Down Expand Up @@ -233,7 +233,7 @@ def _weighted_cov(x, y, w):
return jnp.sum(w * (x - _weighted_mean(x, w)) * (y - _weighted_mean(y, w))) / jnp.sum(w)


@jit
# @jit
def weighted_correlation(x, y, w):
"""Weighted Pearson correlation of two data series.

Expand Down
Loading