Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@
synapses, # synaptic dynamics
synouts, # synaptic output
synplast, # synaptic plasticity
experimental, # experimental model
syn,
)
from brainpy._src.dyn.base import not_pass_shargs
from brainpy._src.dyn.base import (DynamicalSystem as DynamicalSystem,
Module as Module,
from brainpy._src.dyn.base import not_pass_sha
from brainpy._src.dyn.base import (DynamicalSystem,
DynamicalSystemNS,
Container as Container,
Sequential as Sequential,
Network as Network,
Expand All @@ -77,6 +77,8 @@
from brainpy._src.dyn.transform import (NoSharedArg as NoSharedArg, # transformations
LoopOverTime as LoopOverTime,)
from brainpy._src.dyn.runners import (DSRunner as DSRunner) # runner
from brainpy._src.dyn.context import share
from brainpy._src.dyn.delay import Delay


# Part 4: Training #
Expand Down Expand Up @@ -240,3 +242,7 @@
dyn.__dict__['NMDA'] = compat.NMDA
del compat


from brainpy._src import checking
tools.__dict__['checking'] = checking
del checking
2 changes: 1 addition & 1 deletion brainpy/_src/analysis/highdim/slow_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ def f_cell(h: Dict):

# call update functions
args = (shared,) + self.args
target.update(*args)
target(*args)

# get new states
new_h = {k: (v.value if (v.batch_axis is None) else jnp.squeeze(v.value, axis=v.batch_axis))
Expand Down
6 changes: 3 additions & 3 deletions brainpy/_src/analysis/lowdim/lowdim_bifurcation.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,10 @@ def plot_bifurcation(self, with_plot=True, show=False, with_return=False,

if self._can_convert_to_one_eq():
if self.convert_type() == C.x_by_y:
X = self.resolutions[self.y_var].value
X = bm.as_jax(self.resolutions[self.y_var])
else:
X = self.resolutions[self.x_var].value
pars = tuple(self.resolutions[p].value for p in self.target_par_names)
X = bm.as_jax(self.resolutions[self.x_var])
pars = tuple(bm.as_jax(self.resolutions[p]) for p in self.target_par_names)
mesh_values = jnp.meshgrid(*((X,) + pars))
mesh_values = tuple(jnp.moveaxis(v, 0, 1).flatten() for v in mesh_values)
candidates = mesh_values[0]
Expand Down
4 changes: 2 additions & 2 deletions brainpy/_src/analysis/lowdim/lowdim_phase_plane.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,9 @@ def plot_fixed_point(self, with_plot=True, with_return=False, show=False,

if self._can_convert_to_one_eq():
if self.convert_type() == C.x_by_y:
candidates = self.resolutions[self.y_var].value
candidates = bm.as_jax(self.resolutions[self.y_var])
else:
candidates = self.resolutions[self.x_var].value
candidates = bm.as_jax(self.resolutions[self.x_var])
else:
if select_candidates == 'fx-nullcline':
candidates = [self.analyzed_results[key][0] for key in self.analyzed_results.keys()
Expand Down
File renamed without changes.
29 changes: 22 additions & 7 deletions brainpy/_src/checkpoints/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,13 @@
get_tensorstore_spec = None

from brainpy._src.math.ndarray import Array
from brainpy._src.math.object_transform.base import Collector
from brainpy.errors import (AlreadyExistsError,
MPACheckpointingRequiredError,
MPARestoreTargetRequiredError,
MPARestoreDataCorruptedError,
MPARestoreTypeNotMatchError,
InvalidCheckpointPath,
InvalidCheckpointError)
from brainpy.tools import DotDict
from brainpy.types import PyTree

__all__ = [
Expand Down Expand Up @@ -154,17 +152,27 @@ def from_state_dict(target, state: Dict[str, Any], name: str = '.'):
A copy of the object with the restored state.
"""
ty = _NamedTuple if _is_namedtuple(target) else type(target)
if ty not in _STATE_DICT_REGISTRY:
for t in _STATE_DICT_REGISTRY.keys():
if issubclass(ty, t):
ty = t
break
else:
return state
ty_from_state_dict = _STATE_DICT_REGISTRY[ty][1]
with _record_path(name):
return ty_from_state_dict(target, state)



def to_state_dict(target) -> Dict[str, Any]:
"""Returns a dictionary with the state of the given target."""
ty = _NamedTuple if _is_namedtuple(target) else type(target)
if ty not in _STATE_DICT_REGISTRY:

for t in _STATE_DICT_REGISTRY.keys():
if issubclass(ty, t):
ty = t
break
else:
return target

ty_to_state_dict = _STATE_DICT_REGISTRY[ty][0]
Expand Down Expand Up @@ -269,8 +277,9 @@ def _restore_namedtuple(xs, state_dict: Dict[str, Any]):

register_serialization_state(Array, _array_dict_state, _restore_array)
register_serialization_state(dict, _dict_state_dict, _restore_dict)
register_serialization_state(DotDict, _dict_state_dict, _restore_dict)
register_serialization_state(Collector, _dict_state_dict, _restore_dict)
# register_serialization_state(DotDict, _dict_state_dict, _restore_dict)
# register_serialization_state(Collector, _dict_state_dict, _restore_dict)
# register_serialization_state(ArrayCollector, _dict_state_dict, _restore_dict)
register_serialization_state(list, _list_state_dict, _restore_list)
register_serialization_state(tuple,
_list_state_dict,
Expand Down Expand Up @@ -1221,8 +1230,9 @@ def _save_main_ckpt_file2(target: bytes,
def save_pytree(
filename: str,
target: PyTree,
overwrite: bool = False,
overwrite: bool = True,
async_manager: Optional[AsyncManager] = None,
verbose: bool = True,
) -> None:
"""Save a checkpoint of the model. Suitable for single-host.

Expand Down Expand Up @@ -1250,12 +1260,16 @@ def save_pytree(
if defined, the save will run without blocking the main
thread. Only works for single host. Note that an ongoing save will still
block subsequent saves, to make sure overwrite/keep logic works correctly.
verbose: bool
Whether output the print information.

Returns
-------
out: str
Filename of saved checkpoint.
"""
if verbose:
print(f'Saving checkpoint into {filename}')
start_time = time.time()
# Make sure all saves are finished before the logic of checking and removing
# outdated checkpoints happens.
Expand Down Expand Up @@ -1284,6 +1298,7 @@ def save_main_ckpt_task():
end_time - start_time)



def multiprocess_save(
ckpt_dir: Union[str, os.PathLike],
target: PyTree,
Expand Down
1 change: 0 additions & 1 deletion brainpy/_src/dyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
channels, neurons, rates, # neuron related
synapses, synouts, synplast, # synapse related
networks,
layers, # ANN related
runners,
transform,
)
Expand Down
Loading