Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ publishment.md
.vscode


brainpy/base/tests/io_test_tmp*

development

examples/simulation/data
Expand Down Expand Up @@ -53,7 +55,6 @@ develop/benchmark/CUBA/annarchy*
develop/benchmark/CUBA/brian2*



*~
\#*\#
*.pyc
Expand Down
2 changes: 1 addition & 1 deletion brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.1.11"
__version__ = "2.1.12"


try:
Expand Down
31 changes: 16 additions & 15 deletions brainpy/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,49 +208,50 @@ def unique_name(self, name=None, type_=None):
naming.check_name_uniqueness(name=name, obj=self)
return name

def load_states(self, filename, verbose=False, check_missing=False):
def load_states(self, filename, verbose=False):
"""Load the model states.

Parameters
----------
filename : str
The filename which stores the model states.
verbose: bool
check_missing: bool
Whether report the load progress.
"""
if not os.path.exists(filename):
raise errors.BrainPyError(f'Cannot find the file path: {filename}')
elif filename.endswith('.hdf5') or filename.endswith('.h5'):
io.load_h5(filename, target=self, verbose=verbose, check=check_missing)
io.load_by_h5(filename, target=self, verbose=verbose)
elif filename.endswith('.pkl'):
io.load_pkl(filename, target=self, verbose=verbose, check=check_missing)
io.load_by_pkl(filename, target=self, verbose=verbose)
elif filename.endswith('.npz'):
io.load_npz(filename, target=self, verbose=verbose, check=check_missing)
io.load_by_npz(filename, target=self, verbose=verbose)
elif filename.endswith('.mat'):
io.load_mat(filename, target=self, verbose=verbose, check=check_missing)
io.load_by_mat(filename, target=self, verbose=verbose)
else:
raise errors.BrainPyError(f'Unknown file format: {filename}. We only supports {io.SUPPORTED_FORMATS}')

def save_states(self, filename, all_vars=None, **setting):
def save_states(self, filename, variables=None, **setting):
"""Save the model states.

Parameters
----------
filename : str
The file name which to store the model states.
all_vars: optional, dict, TensorCollector
variables: optional, dict, TensorCollector
The variables to save. If not provided, all variables retrieved by ``~.vars()`` will be used.
"""
if all_vars is None:
all_vars = self.vars(method='relative').unique()
if variables is None:
variables = self.vars(method='absolute', level=-1)

if filename.endswith('.hdf5') or filename.endswith('.h5'):
io.save_h5(filename, all_vars=all_vars)
elif filename.endswith('.pkl'):
io.save_pkl(filename, all_vars=all_vars)
io.save_as_h5(filename, variables=variables)
elif filename.endswith('.pkl') or filename.endswith('.pickle'):
io.save_as_pkl(filename, variables=variables)
elif filename.endswith('.npz'):
io.save_npz(filename, all_vars=all_vars, **setting)
io.save_as_npz(filename, variables=variables, **setting)
elif filename.endswith('.mat'):
io.save_mat(filename, all_vars=all_vars)
io.save_as_mat(filename, variables=variables)
else:
raise errors.BrainPyError(f'Unknown file format: {filename}. We only supports {io.SUPPORTED_FORMATS}')

Expand Down
24 changes: 24 additions & 0 deletions brainpy/base/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,35 @@ def update(self, other, **kwargs):
self[key] = value

def __add__(self, other):
"""Merging two dicts.

Parameters
----------
other: dict
The other dict instance.

Returns
-------
gather: Collector
The new collector.
"""
gather = type(self)(self)
gather.update(other)
return gather

def __sub__(self, other):
"""Remove other item in the collector.

Parameters
----------
other: dict
The items to remove.

Returns
-------
gather: Collector
The new collector.
"""
if not isinstance(other, dict):
raise ValueError(f'Only support dict, but we got {type(other)}.')
gather = type(self)()
Expand Down
Loading