Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/Linux_CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,14 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
# python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.14.tar.gz
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
python setup.py install
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest brainpy/
4 changes: 1 addition & 3 deletions .github/workflows/MacOS_CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,14 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install jax==0.3.14
python -m pip install jaxlib==0.3.14
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
python setup.py install
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest brainpy/
2 changes: 1 addition & 1 deletion .github/workflows/Windows_CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest brainpy/
45 changes: 31 additions & 14 deletions brainpy/math/delayvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,12 +280,24 @@ class LengthDelay(AbstractDelay):
It can also be arrays. Or a callable function or instance of ``Connector``.
Note that ``initial_delay_data`` should be arranged as the following way::

delay = delay_len [ data
delay = delay_len-1 data
delay = 1 [ data
delay = 2 data
... ....
... ....
delay = 2 data
delay = 1 data ]
delay = delay_len-1 data
delay = delay_len data ]

.. versionchanged:: 2.2.3.2

The data in the previous version of ``LengthDelay`` is::

delay = delay_len [ data
delay = delay_len-1 data
... ....
... ....
delay = 2 data
delay = 1 data ]


name: str
The delay object name.
Expand Down Expand Up @@ -368,13 +380,13 @@ def reset(
dtype=delay_target.dtype)

# update delay data
self.data[-1] = delay_target
self.data[0] = delay_target
if initial_delay_data is None:
pass
elif isinstance(initial_delay_data, (ndarray, jnp.ndarray, float, int, bool)):
self.data[:-1] = initial_delay_data
self.data[1:] = initial_delay_data
elif callable(initial_delay_data):
self.data[:-1] = initial_delay_data((delay_len,) + delay_target.shape,
self.data[1:] = initial_delay_data((delay_len,) + delay_target.shape,
dtype=delay_target.dtype)
else:
raise ValueError(f'"delay_data" does not support {type(initial_delay_data)}')
Expand Down Expand Up @@ -406,20 +418,22 @@ def retrieve(self, delay_len, *indices):
check_error_in_jit(bm.any(delay_len >= self.num_delay_step), self._check_delay, delay_len)

if self.update_method == ROTATION_UPDATING:
# the delay length
delay_idx = (self.idx[0] - delay_len - 1) % self.num_delay_step
delay_idx = (self.idx[0] + delay_len) % self.num_delay_step
delay_idx = stop_gradient(delay_idx)
if not jnp.issubdtype(delay_idx.dtype, jnp.integer):
raise ValueError(f'"delay_len" must be integer, but we got {delay_len}')

elif self.update_method == CONCAT_UPDATING:
delay_idx = self.num_delay_step - 1 - delay_len
delay_idx = delay_len

else:
raise ValueError(f'Unknown updating method "{self.update_method}"')

# the delay data
# the delay index
if isinstance(delay_idx, int):
pass
elif hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}')
indices = (delay_idx,) + tuple(indices)
# the delay data
return self.data[indices]

def update(self, value: Union[float, int, bool, JaxArray, jnp.DeviceArray]):
Expand All @@ -435,7 +449,10 @@ def update(self, value: Union[float, int, bool, JaxArray, jnp.DeviceArray]):
self.idx.value = stop_gradient((self.idx + 1) % self.num_delay_step)

elif self.update_method == CONCAT_UPDATING:
self.data.value = bm.vstack([self.data[1:], bm.broadcast_to(value,self.data.shape[1:])])
if self.num_delay_step >= 2:
self.data.value = bm.vstack([bm.broadcast_to(value, self.data.shape[1:]), self.data[1:]])
else:
self.data[:] = value

else:
raise ValueError(f'Unknown updating method "{self.update_method}"')
Expand Down
6 changes: 4 additions & 2 deletions brainpy/math/tests/test_delay_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ def test2(self):
dim = 3
for update_method in [ROTATION_UPDATING, CONCAT_UPDATING]:
delay = bm.LengthDelay(jnp.zeros(dim), 10,
initial_delay_data=jnp.arange(1, 11).reshape((10, 1)),
# initial_delay_data=jnp.arange(1, 11).reshape((10, 1)),
initial_delay_data=jnp.arange(10, 0, -1).reshape((10, 1)),
update_method=update_method)
print(delay(0))
self.assertTrue(jnp.array_equal(delay(0), jnp.zeros(dim)))
Expand All @@ -111,7 +112,8 @@ def test3(self):
dim = 3
for update_method in [ROTATION_UPDATING, CONCAT_UPDATING]:
delay = bm.LengthDelay(jnp.zeros(dim), 10,
initial_delay_data=jnp.arange(1, 11).reshape((10, 1)),
# initial_delay_data=jnp.arange(1, 11).reshape((10, 1)),
initial_delay_data=jnp.arange(10, 0, -1).reshape((10, 1)),
update_method=update_method)
print(delay(jnp.asarray([1, 2, 3]),
jnp.arange(3)))
Expand Down