diff --git a/.github/workflows/Linux_CI.yml b/.github/workflows/Linux_CI.yml index 0c8984036..dfe658f99 100644 --- a/.github/workflows/Linux_CI.yml +++ b/.github/workflows/Linux_CI.yml @@ -28,7 +28,6 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install flake8 pytest - # python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.14.tar.gz if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi python setup.py install - name: Lint with flake8 @@ -36,7 +35,7 @@ jobs: # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - # flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest run: | pytest brainpy/ diff --git a/.github/workflows/MacOS_CI.yml b/.github/workflows/MacOS_CI.yml index 70db5de77..debd1a539 100644 --- a/.github/workflows/MacOS_CI.yml +++ b/.github/workflows/MacOS_CI.yml @@ -28,8 +28,6 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install flake8 pytest - python -m pip install jax==0.3.14 - python -m pip install jaxlib==0.3.14 if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi python setup.py install - name: Lint with flake8 @@ -37,7 +35,7 @@ jobs: # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - # flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest run: | pytest brainpy/ diff --git a/.github/workflows/Windows_CI.yml b/.github/workflows/Windows_CI.yml index 9043c2ff0..29e3f7ae5 100644 --- a/.github/workflows/Windows_CI.yml +++ b/.github/workflows/Windows_CI.yml @@ -39,7 +39,7 @@ jobs: # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - # flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest run: | pytest brainpy/ diff --git a/brainpy/math/delayvars.py b/brainpy/math/delayvars.py index c709c4804..339e483d8 100644 --- a/brainpy/math/delayvars.py +++ b/brainpy/math/delayvars.py @@ -280,12 +280,24 @@ class LengthDelay(AbstractDelay): It can also be arrays. Or a callable function or instance of ``Connector``. Note that ``initial_delay_data`` should be arranged as the following way:: - delay = delay_len [ data - delay = delay_len-1 data + delay = 1 [ data + delay = 2 data ... .... ... .... - delay = 2 data - delay = 1 data ] + delay = delay_len-1 data + delay = delay_len data ] + + .. versionchanged:: 2.2.3.2 + + The data in the previous version of ``LengthDelay`` is:: + + delay = delay_len [ data + delay = delay_len-1 data + ... .... + ... .... + delay = 2 data + delay = 1 data ] + name: str The delay object name. @@ -368,13 +380,13 @@ def reset( dtype=delay_target.dtype) # update delay data - self.data[-1] = delay_target + self.data[0] = delay_target if initial_delay_data is None: pass elif isinstance(initial_delay_data, (ndarray, jnp.ndarray, float, int, bool)): - self.data[:-1] = initial_delay_data + self.data[1:] = initial_delay_data elif callable(initial_delay_data): - self.data[:-1] = initial_delay_data((delay_len,) + delay_target.shape, + self.data[1:] = initial_delay_data((delay_len,) + delay_target.shape, dtype=delay_target.dtype) else: raise ValueError(f'"delay_data" does not support {type(initial_delay_data)}') @@ -406,20 +418,22 @@ def retrieve(self, delay_len, *indices): check_error_in_jit(bm.any(delay_len >= self.num_delay_step), self._check_delay, delay_len) if self.update_method == ROTATION_UPDATING: - # the delay length - delay_idx = (self.idx[0] - delay_len - 1) % self.num_delay_step + delay_idx = (self.idx[0] + delay_len) % self.num_delay_step delay_idx = stop_gradient(delay_idx) - if not jnp.issubdtype(delay_idx.dtype, jnp.integer): - raise ValueError(f'"delay_len" must be integer, but we got {delay_len}') elif self.update_method == CONCAT_UPDATING: - delay_idx = self.num_delay_step - 1 - delay_len + delay_idx = delay_len else: raise ValueError(f'Unknown updating method "{self.update_method}"') - # the delay data + # the delay index + if isinstance(delay_idx, int): + pass + elif hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer): + raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}') indices = (delay_idx,) + tuple(indices) + # the delay data return self.data[indices] def update(self, value: Union[float, int, bool, JaxArray, jnp.DeviceArray]): @@ -435,7 +449,10 @@ def update(self, value: Union[float, int, bool, JaxArray, jnp.DeviceArray]): self.idx.value = stop_gradient((self.idx + 1) % self.num_delay_step) elif self.update_method == CONCAT_UPDATING: - self.data.value = bm.vstack([self.data[1:], bm.broadcast_to(value,self.data.shape[1:])]) + if self.num_delay_step >= 2: + self.data.value = bm.vstack([bm.broadcast_to(value, self.data.shape[1:]), self.data[1:]]) + else: + self.data[:] = value else: raise ValueError(f'Unknown updating method "{self.update_method}"') diff --git a/brainpy/math/tests/test_delay_vars.py b/brainpy/math/tests/test_delay_vars.py index cc2757ffc..d1573a595 100644 --- a/brainpy/math/tests/test_delay_vars.py +++ b/brainpy/math/tests/test_delay_vars.py @@ -94,7 +94,8 @@ def test2(self): dim = 3 for update_method in [ROTATION_UPDATING, CONCAT_UPDATING]: delay = bm.LengthDelay(jnp.zeros(dim), 10, - initial_delay_data=jnp.arange(1, 11).reshape((10, 1)), + # initial_delay_data=jnp.arange(1, 11).reshape((10, 1)), + initial_delay_data=jnp.arange(10, 0, -1).reshape((10, 1)), update_method=update_method) print(delay(0)) self.assertTrue(jnp.array_equal(delay(0), jnp.zeros(dim))) @@ -111,7 +112,8 @@ def test3(self): dim = 3 for update_method in [ROTATION_UPDATING, CONCAT_UPDATING]: delay = bm.LengthDelay(jnp.zeros(dim), 10, - initial_delay_data=jnp.arange(1, 11).reshape((10, 1)), + # initial_delay_data=jnp.arange(1, 11).reshape((10, 1)), + initial_delay_data=jnp.arange(10, 0, -1).reshape((10, 1)), update_method=update_method) print(delay(jnp.asarray([1, 2, 3]), jnp.arange(3)))