From ce57d51f5e90e194e695ff7bf12926324e7b0e13 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 15 May 2022 13:48:00 +0800 Subject: [PATCH 1/3] support `dtype` setting in array interchange functions --- brainpy/math/jaxarray.py | 15 +++--- brainpy/math/numpy_ops.py | 100 ++++++++++++++++++++++++++++++++++---- 2 files changed, 99 insertions(+), 16 deletions(-) diff --git a/brainpy/math/jaxarray.py b/brainpy/math/jaxarray.py index 453793455..20c0fb6a1 100644 --- a/brainpy/math/jaxarray.py +++ b/brainpy/math/jaxarray.py @@ -872,17 +872,20 @@ def view(self, dtype=None, *args, **kwargs): # NumPy support # ------------------ - def numpy(self): + def numpy(self, dtype=None): """Convert to numpy.ndarray.""" - return np.asarray(self.value) + return np.asarray(self.value, dtype=dtype) - def to_numpy(self): + def to_numpy(self, dtype=None): """Convert to numpy.ndarray.""" - return np.asarray(self.value) + return np.asarray(self.value, dtype=dtype) - def to_jax(self): + def to_jax(self, dtype=None): """Convert to jax.numpy.ndarray.""" - return self.value + if dtype is None: + return self.value + else: + return jnp.asarray(self.value, dtype=dtype) def __array__(self, dtype=None): """Support ``numpy.array()`` and ``numpy.asarray()`` functions.""" diff --git a/brainpy/math/numpy_ops.py b/brainpy/math/numpy_ops.py index 3c6ced954..6f945b4c1 100644 --- a/brainpy/math/numpy_ops.py +++ b/brainpy/math/numpy_ops.py @@ -109,6 +109,18 @@ def remove_diag(arr): + """Remove the diagonal of the matrix. + + Parameters + ---------- + arr: JaxArray, jnp.ndarray + The matrix with the shape of `(M, N)`. + + Returns + ------- + arr: JaxArray + The matrix without diagonal which has the shape of `(M, N-1)`. + """ if arr.ndim != 2: raise ValueError(f'Only support 2D matrix, while we got a {arr.ndim}D array.') eyes = ones(arr.shape, dtype=bool) @@ -116,26 +128,77 @@ def remove_diag(arr): return reshape(arr[eyes.value], (arr.shape[0], arr.shape[1] - 1)) -def as_device_array(tensor): +def as_device_array(tensor, dtype=None): + """Convert the input to a ``jax.numpy.DeviceArray``. + + Parameters + ---------- + tensor: array_like + Input data, in any form that can be converted to an array. This + includes lists, lists of tuples, tuples, tuples of tuples, tuples + of lists, numpy.ndarray, JaxArray, jax.numpy.ndarray. + dtype: data-type, optional + By default, the data-type is inferred from the input data. + + Returns + ------- + out : ndarray + Array interpretation of `tensor`. No copy is performed if the input + is already an ndarray with matching dtype. + """ if isinstance(tensor, JaxArray): - return tensor.value + return tensor.to_jax(dtype) elif isinstance(tensor, jnp.ndarray): - return tensor + return tensor if (dtype is None) else jnp.asarray(tensor, dtype=dtype) elif isinstance(tensor, np.ndarray): - return jnp.asarray(tensor) + return jnp.asarray(tensor, dtype=dtype) else: - return jnp.asarray(tensor) + return jnp.asarray(tensor, dtype=dtype) -def as_numpy(tensor): +def as_numpy(tensor, dtype=None): + """Convert the input to a ``numpy.ndarray``. + + Parameters + ---------- + tensor: array_like + Input data, in any form that can be converted to an array. This + includes lists, lists of tuples, tuples, tuples of tuples, tuples + of lists, numpy.ndarray, JaxArray, jax.numpy.ndarray. + dtype: data-type, optional + By default, the data-type is inferred from the input data. + + Returns + ------- + out : ndarray + Array interpretation of `tensor`. No copy is performed if the input + is already an ndarray with matching dtype. + """ if isinstance(tensor, JaxArray): - return tensor.numpy() + return tensor.numpy(dtype=dtype) else: - return np.asarray(tensor) + return np.asarray(tensor, dtype=dtype) -def as_variable(tensor): - return Variable(asarray(tensor)) +def as_variable(tensor, dtype=None): + """Convert the input to a ``brainpy.math.Variable``. + + Parameters + ---------- + tensor: array_like + Input data, in any form that can be converted to an array. This + includes lists, lists of tuples, tuples, tuples of tuples, tuples + of lists, numpy.ndarray, JaxArray, jax.numpy.ndarray. + dtype: data-type, optional + By default, the data-type is inferred from the input data. + + Returns + ------- + out : ndarray + Array interpretation of `tensor`. No copy is performed if the input + is already an ndarray with matching dtype. + """ + return Variable(asarray(tensor, dtype=dtype)) def _remove_jaxarray(obj): @@ -1704,6 +1767,23 @@ def array(a, dtype=None, copy=True, order="K", ndmin=0): @wraps(jnp.asarray) def asarray(a, dtype=None, order=None): + """Convert the input to a ``brainpy.math.JaxArray``. + + Parameters + ---------- + a: array_like + Input data, in any form that can be converted to an array. This + includes lists, lists of tuples, tuples, tuples of tuples, tuples + of lists, numpy.ndarray, JaxArray, jax.numpy.ndarray. + dtype: data-type, optional + By default, the data-type is inferred from the input data. + + Returns + ------- + out : ndarray + Array interpretation of `a`. No copy is performed if the input + is already an ndarray with matching dtype. + """ a = _remove_jaxarray(a) try: res = jnp.asarray(a=a, dtype=dtype, order=order) From 5ed17cd356942182051d6709ef9f29a80856b71f Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 15 May 2022 13:49:23 +0800 Subject: [PATCH 2/3] remove changelog action --- .github/workflows/auto-changelog.yml | 19 ------------------- .github/workflows/generate_changelog.yml | 19 ------------------- 2 files changed, 38 deletions(-) delete mode 100644 .github/workflows/auto-changelog.yml delete mode 100644 .github/workflows/generate_changelog.yml diff --git a/.github/workflows/auto-changelog.yml b/.github/workflows/auto-changelog.yml deleted file mode 100644 index 6d3ccf302..000000000 --- a/.github/workflows/auto-changelog.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Generate changelog -on: - push: - branches: [master] - -jobs: - generate-changelog: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - uses: BobAnkh/auto-generate-changelog@master - with: - REPO_NAME: 'ztqakita/BrainPy' - ACCESS_TOKEN: ${{secrets.GITHUB_TOKEN}} - PATH: 'changelog.rst' - COMMIT_MESSAGE: 'docs(changelog): update release notes' - TYPE: 'feat:Feature,fix:Bug Fixes,docs:Documentation,refactor:Refactor,perf:Performance Improvements' diff --git a/.github/workflows/generate_changelog.yml b/.github/workflows/generate_changelog.yml deleted file mode 100644 index 61fe5b76e..000000000 --- a/.github/workflows/generate_changelog.yml +++ /dev/null @@ -1,19 +0,0 @@ -name: Generate changelog -on: - release: - types: [created, edited] - -jobs: - generate-changelog: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - uses: BobAnkh/auto-generate-changelog@master - with: - REPO_NAME: 'PKU-NIP-Lab/BrainPy' - ACCESS_TOKEN: ${{secrets.GITHUB_TOKEN}} - PATH: 'CHANGELOG.rst' - COMMIT_MESSAGE: 'docs(CHANGELOG): update release notes' - TYPE: 'feat:Feature,fix:Bug Fixes,docs:Documentation,refactor:Refactor,perf:Performance Improvements' \ No newline at end of file From 472a183c121a387514cd471fd07898b227e52f92 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 15 May 2022 13:50:23 +0800 Subject: [PATCH 3/3] remove contributor github action --- .github/workflows/contributors.yml | 22 ---------------------- 1 file changed, 22 deletions(-) delete mode 100644 .github/workflows/contributors.yml diff --git a/.github/workflows/contributors.yml b/.github/workflows/contributors.yml deleted file mode 100644 index 5cbd61459..000000000 --- a/.github/workflows/contributors.yml +++ /dev/null @@ -1,22 +0,0 @@ -name: Add contributors -on: - schedule: - - cron: '20 20 * * *' - push: - branches: [ master ] - -jobs: - add-contributors: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - uses: BobAnkh/add-contributors@V2.1.0 - with: - CONTRIBUTOR: '# Contributors' - COLUMN_PER_ROW: '6' - ACCESS_TOKEN: ${{secrets.GITHUB_TOKEN}} - IMG_WIDTH: '100' - FONT_SIZE: '14' - PATH: '/README.md' - COMMIT_MESSAGE: 'docs(README): update contributors' - AVATAR_SHAPE: 'round'