Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 0 additions & 19 deletions .github/workflows/auto-changelog.yml

This file was deleted.

22 changes: 0 additions & 22 deletions .github/workflows/contributors.yml

This file was deleted.

19 changes: 0 additions & 19 deletions .github/workflows/generate_changelog.yml

This file was deleted.

15 changes: 9 additions & 6 deletions brainpy/math/jaxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,17 +872,20 @@ def view(self, dtype=None, *args, **kwargs):
# NumPy support
# ------------------

def numpy(self):
def numpy(self, dtype=None):
"""Convert to numpy.ndarray."""
return np.asarray(self.value)
return np.asarray(self.value, dtype=dtype)

def to_numpy(self):
def to_numpy(self, dtype=None):
"""Convert to numpy.ndarray."""
return np.asarray(self.value)
return np.asarray(self.value, dtype=dtype)

def to_jax(self):
def to_jax(self, dtype=None):
"""Convert to jax.numpy.ndarray."""
return self.value
if dtype is None:
return self.value
else:
return jnp.asarray(self.value, dtype=dtype)

def __array__(self, dtype=None):
"""Support ``numpy.array()`` and ``numpy.asarray()`` functions."""
Expand Down
100 changes: 90 additions & 10 deletions brainpy/math/numpy_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,33 +109,96 @@


def remove_diag(arr):
"""Remove the diagonal of the matrix.

Parameters
----------
arr: JaxArray, jnp.ndarray
The matrix with the shape of `(M, N)`.

Returns
-------
arr: JaxArray
The matrix without diagonal which has the shape of `(M, N-1)`.
"""
if arr.ndim != 2:
raise ValueError(f'Only support 2D matrix, while we got a {arr.ndim}D array.')
eyes = ones(arr.shape, dtype=bool)
fill_diagonal(eyes, False)
return reshape(arr[eyes.value], (arr.shape[0], arr.shape[1] - 1))


def as_device_array(tensor):
def as_device_array(tensor, dtype=None):
"""Convert the input to a ``jax.numpy.DeviceArray``.

Parameters
----------
tensor: array_like
Input data, in any form that can be converted to an array. This
includes lists, lists of tuples, tuples, tuples of tuples, tuples
of lists, numpy.ndarray, JaxArray, jax.numpy.ndarray.
dtype: data-type, optional
By default, the data-type is inferred from the input data.

Returns
-------
out : ndarray
Array interpretation of `tensor`. No copy is performed if the input
is already an ndarray with matching dtype.
"""
if isinstance(tensor, JaxArray):
return tensor.value
return tensor.to_jax(dtype)
elif isinstance(tensor, jnp.ndarray):
return tensor
return tensor if (dtype is None) else jnp.asarray(tensor, dtype=dtype)
elif isinstance(tensor, np.ndarray):
return jnp.asarray(tensor)
return jnp.asarray(tensor, dtype=dtype)
else:
return jnp.asarray(tensor)
return jnp.asarray(tensor, dtype=dtype)


def as_numpy(tensor):
def as_numpy(tensor, dtype=None):
"""Convert the input to a ``numpy.ndarray``.

Parameters
----------
tensor: array_like
Input data, in any form that can be converted to an array. This
includes lists, lists of tuples, tuples, tuples of tuples, tuples
of lists, numpy.ndarray, JaxArray, jax.numpy.ndarray.
dtype: data-type, optional
By default, the data-type is inferred from the input data.

Returns
-------
out : ndarray
Array interpretation of `tensor`. No copy is performed if the input
is already an ndarray with matching dtype.
"""
if isinstance(tensor, JaxArray):
return tensor.numpy()
return tensor.numpy(dtype=dtype)
else:
return np.asarray(tensor)
return np.asarray(tensor, dtype=dtype)


def as_variable(tensor):
return Variable(asarray(tensor))
def as_variable(tensor, dtype=None):
"""Convert the input to a ``brainpy.math.Variable``.

Parameters
----------
tensor: array_like
Input data, in any form that can be converted to an array. This
includes lists, lists of tuples, tuples, tuples of tuples, tuples
of lists, numpy.ndarray, JaxArray, jax.numpy.ndarray.
dtype: data-type, optional
By default, the data-type is inferred from the input data.

Returns
-------
out : ndarray
Array interpretation of `tensor`. No copy is performed if the input
is already an ndarray with matching dtype.
"""
return Variable(asarray(tensor, dtype=dtype))


def _remove_jaxarray(obj):
Expand Down Expand Up @@ -1704,6 +1767,23 @@ def array(a, dtype=None, copy=True, order="K", ndmin=0):

@wraps(jnp.asarray)
def asarray(a, dtype=None, order=None):
"""Convert the input to a ``brainpy.math.JaxArray``.

Parameters
----------
a: array_like
Input data, in any form that can be converted to an array. This
includes lists, lists of tuples, tuples, tuples of tuples, tuples
of lists, numpy.ndarray, JaxArray, jax.numpy.ndarray.
dtype: data-type, optional
By default, the data-type is inferred from the input data.

Returns
-------
out : ndarray
Array interpretation of `a`. No copy is performed if the input
is already an ndarray with matching dtype.
"""
a = _remove_jaxarray(a)
try:
res = jnp.asarray(a=a, dtype=dtype, order=order)
Expand Down