Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jun 27, 2024
1 parent 419c104 commit 543d243
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 50 deletions.
6 changes: 2 additions & 4 deletions brainunit/math/_fun_keep_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import jax
import jax.numpy as jnp
import numpy as np

from .._base import Quantity, fail_for_dimension_mismatch, DIMENSIONLESS
from .._misc import set_module_as
Expand Down Expand Up @@ -411,7 +412,7 @@ def hsplit(
def vsplit(
a: Union[jax.Array, Quantity],
indices_or_sections: Union[int, Sequence[int]]
) -> Union[List[jax.Array], List[Quantity]]:
) -> Union[Sequence[jax.Array | Quantity]]:
"""
Split a quantity or an array into multiple sub-arrays vertically (row-wise).
Expand Down Expand Up @@ -1023,8 +1024,6 @@ def ravel(
return _fun_keep_unit_unary(jnp.ravel, a, order=order)




@set_module_as('brainunit.math')
def flatten(
x: jax.typing.ArrayLike | Quantity,
Expand Down Expand Up @@ -1130,7 +1129,6 @@ def remove_diag(x: jax.typing.ArrayLike | Quantity) -> jax.Array | Quantity:
return x



# ---------- selection


Expand Down
91 changes: 45 additions & 46 deletions brainunit/math/_fun_remove_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@
# ==============================================================================
from __future__ import annotations

from typing import (Union, Optional)
from typing import (Union, Optional, Sequence)

import jax
import jax.numpy as jnp
from jax import Array

from .._base import Quantity, fail_for_dimension_mismatch, DIMENSIONLESS
from .._misc import set_module_as
Expand Down Expand Up @@ -53,9 +52,9 @@ def _fun_remove_unit_unary(func, x, *args, **kwargs):

@set_module_as('brainunit.math')
def heaviside(
x1: Union[Quantity, jax.Array],
x1: Union[Quantity, jax.jax.Array],
x2: jax.typing.ArrayLike
) -> Union[Quantity, jax.Array]:
) -> Union[Quantity, jax.jax.Array]:
"""
Compute the Heaviside step function.
Expand All @@ -78,7 +77,7 @@ def heaviside(


@set_module_as('brainunit.math')
def signbit(x: Union[Array, Quantity]) -> Array:
def signbit(x: Union[jax.Array, Quantity]) -> jax.Array:
"""
Returns element-wise True where signbit is set (less than zero).
Expand All @@ -97,7 +96,7 @@ def signbit(x: Union[Array, Quantity]) -> Array:


@set_module_as('brainunit.math')
def sign(x: Union[Array, Quantity]) -> Array:
def sign(x: Union[jax.Array, Quantity]) -> jax.Array:
"""
Returns the sign of each element in the input array.
Expand All @@ -117,12 +116,12 @@ def sign(x: Union[Array, Quantity]) -> Array:

@set_module_as('brainunit.math')
def bincount(
x: Union[Array, Quantity],
x: Union[jax.Array, Quantity],
weights: Optional[jax.typing.ArrayLike] = None,
minlength: int = 0,
*,
length: Optional[int] = None
) -> Array:
) -> jax.Array:
"""
Count number of occurrences of each value in array of non-negative ints.
Expand Down Expand Up @@ -155,10 +154,10 @@ def bincount(

@set_module_as('brainunit.math')
def digitize(
x: Union[Array, Quantity],
bins: Union[Array, Quantity],
x: Union[jax.Array, Quantity],
bins: Union[jax.Array, Quantity],
right: bool = False
) -> Array:
) -> jax.Array:
"""
Return the indices of the bins to which each value in input array belongs.
Expand Down Expand Up @@ -217,8 +216,8 @@ def all(
x: Union[Quantity, jax.typing.ArrayLike],
axis: Optional[int] = None,
keepdims: bool = False,
where: Optional[Array] = None
) -> Union[bool, Array]:
where: Optional[jax.Array] = None
) -> Union[bool, jax.Array]:
"""
Test whether all array elements along a given axis evaluate to True.
Expand Down Expand Up @@ -261,8 +260,8 @@ def any(
x: Union[Quantity, jax.typing.ArrayLike],
axis: Optional[int] = None,
keepdims: bool = False,
where: Optional[Array] = None
) -> Union[bool, Array]:
where: Optional[jax.Array] = None
) -> Union[bool, jax.Array]:
"""
Test whether any array element along a given axis evaluates to True.
Expand Down Expand Up @@ -303,7 +302,7 @@ def any(
@set_module_as('brainunit.math')
def logical_not(
x: Union[Quantity, jax.typing.ArrayLike],
) -> Union[bool, Array]:
) -> Union[bool, jax.Array]:
"""
Compute the truth value of NOT x element-wise.
Expand Down Expand Up @@ -349,7 +348,7 @@ def equal(
y: Union[Quantity, jax.typing.ArrayLike],
*args,
**kwargs
) -> Union[bool, Array]:
) -> Union[bool, jax.Array]:
"""
equal(x, y, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Expand Down Expand Up @@ -393,7 +392,7 @@ def not_equal(
y: Union[Quantity, jax.typing.ArrayLike],
*args,
**kwargs
) -> Union[bool, Array]:
) -> Union[bool, jax.Array]:
"""
not_equal(x, y, /, out=None, *, where=True, casting='same_kind',
order='K', dtype=None, subok=True[, signature, extobj])
Expand Down Expand Up @@ -438,7 +437,7 @@ def greater(
y: Union[Quantity, jax.typing.ArrayLike],
*args,
**kwargs
) -> Union[bool, Array]:
) -> Union[bool, jax.Array]:
"""
greater(x, y, /, out=None, *, where=True, casting='same_kind',
order='K', dtype=None, subok=True[, signature, extobj])
Expand Down Expand Up @@ -484,7 +483,7 @@ def greater_equal(
*args,
**kwargs
) -> Union[
bool, Array]:
bool, jax.Array]:
"""
greater_equal(x, y, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Expand Down Expand Up @@ -528,7 +527,7 @@ def less(
y: Union[Quantity, jax.typing.ArrayLike],
*args,
**kwargs
) -> Union[bool, Array]:
) -> Union[bool, jax.Array]:
"""
less(x, y, /, out=None, *, where=True, casting='same_kind',
order='K', dtype=None, subok=True[, signature, extobj])
Expand Down Expand Up @@ -574,7 +573,7 @@ def less_equal(
*args,
**kwargs
) -> Union[
bool, Array]:
bool, jax.Array]:
"""
less_equal(x, y, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Expand Down Expand Up @@ -619,7 +618,7 @@ def array_equal(
*args,
**kwargs
) -> Union[
bool, Array]:
bool, jax.Array]:
"""
True if two arrays have the same shape and elements, False otherwise.
Expand Down Expand Up @@ -647,7 +646,7 @@ def isclose(
rtol: float | Quantity = 1e-05,
atol: float | Quantity = 1e-08,
equal_nan: bool = False
) -> Union[bool, Array]:
) -> Union[bool, jax.Array]:
"""
Returns a boolean array where two arrays are element-wise equal within a
tolerance.
Expand Down Expand Up @@ -698,7 +697,7 @@ def allclose(
rtol: float | Quantity = 1e-05,
atol: float | Quantity = 1e-08,
equal_nan: bool = False
) -> Union[bool, Array]:
) -> Union[bool, jax.Array]:
"""
Returns True if two arrays are element-wise equal within a tolerance.
Expand Down Expand Up @@ -750,7 +749,7 @@ def logical_and(
y: Union[Quantity, jax.typing.ArrayLike],
*args,
**kwargs
) -> Union[bool, Array]:
) -> Union[bool, jax.Array]:
"""
logical_and(x, y, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Expand Down Expand Up @@ -794,7 +793,7 @@ def logical_or(
y: Union[Quantity, jax.typing.ArrayLike],
*args,
**kwargs
) -> Union[bool, Array]:
) -> Union[bool, jax.Array]:
"""
logical_or(x, y, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Expand Down Expand Up @@ -838,7 +837,7 @@ def logical_xor(
y: Union[Quantity, jax.typing.ArrayLike],
*args,
**kwargs
) -> Union[bool, Array]:
) -> Union[bool, jax.Array]:
"""
logical_xor(x, y, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
Expand Down Expand Up @@ -883,14 +882,14 @@ def logical_xor(

@set_module_as('brainunit.math')
def argsort(
a: Union[Array, Quantity],
a: Union[jax.Array, Quantity],
axis: Optional[int] = -1,
*,
kind: None = None,
order: None = None,
stable: bool = True,
descending: bool = False,
) -> Array:
) -> jax.Array:
"""
Returns the indices that would sort an array or a quantity.
Expand Down Expand Up @@ -927,10 +926,10 @@ def argsort(

@set_module_as('brainunit.math')
def argmax(
a: Union[Array, Quantity],
a: Union[jax.Array, Quantity],
axis: Optional[int] = None,
keepdims: Optional[bool] = None
) -> Array:
) -> jax.Array:
"""
Returns indices of the max value along an axis.
Expand All @@ -954,10 +953,10 @@ def argmax(

@set_module_as('brainunit.math')
def argmin(
a: Union[Array, Quantity],
a: Union[jax.Array, Quantity],
axis: Optional[int] = None,
keepdims: Optional[bool] = None
) -> Array:
) -> jax.Array:
"""
Returns indices of the min value along an axis.
Expand Down Expand Up @@ -1050,11 +1049,11 @@ def nanargmin(

@set_module_as('brainunit.math')
def argwhere(
a: Union[Array, Quantity],
a: Union[jax.Array, Quantity],
*,
size: Optional[int] = None,
fill_value: Optional[jax.typing.ArrayLike] = None,
) -> Array:
) -> jax.Array:
"""
Find the indices of array elements that are non-zero, grouped by element.
Expand All @@ -1078,11 +1077,11 @@ def argwhere(

@set_module_as('brainunit.math')
def nonzero(
a: Union[Array, Quantity],
a: Union[jax.Array, Quantity],
*,
size: Optional[int] = None,
fill_value: Optional[jax.typing.ArrayLike] = None,
) -> Tuple[Array, ...]:
) -> Sequence[jax.Array]:
"""
Return the indices of the elements that are non-zero.
Expand All @@ -1107,11 +1106,11 @@ def nonzero(

@set_module_as('brainunit.math')
def flatnonzero(
a: Union[Array, Quantity],
a: Union[jax.Array, Quantity],
*,
size: Optional[int] = None,
fill_value: Optional[jax.typing.ArrayLike] = None,
) -> Array:
) -> jax.Array:
"""
Return indices that are non-zero in the flattened version of the input quantity or array.
Expand All @@ -1135,10 +1134,10 @@ def flatnonzero(

@set_module_as('brainunit.math')
def count_nonzero(
a: Union[Array, Quantity],
a: Union[jax.Array, Quantity],
axis: Optional[int] = None,
keepdims: Optional[bool] = None
) -> Array:
) -> jax.Array:
"""
Count the number of non-zero values in the quantity or array `a`.
Expand All @@ -1162,13 +1161,13 @@ def count_nonzero(

@set_module_as('brainunit.math')
def searchsorted(
a: Union[Array, Quantity],
v: Union[Array, Quantity],
a: Union[jax.Array, Quantity],
v: Union[jax.Array, Quantity],
side: str = 'left',
sorter: Optional[Array] = None,
sorter: Optional[jax.Array] = None,
*,
method: Optional[str] = 'scan'
) -> Array:
) -> jax.Array:
"""
Find indices where elements should be inserted to maintain order.
Expand Down

0 comments on commit 543d243

Please sign in to comment.