Skip to content

Commit

Permalink
Support jax==0.4.28
Browse files Browse the repository at this point in the history
This reverts commit 59fb681.
  • Loading branch information
Routhleck committed May 13, 2024
1 parent 59fb681 commit 5253b59
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
12 changes: 5 additions & 7 deletions brainpy/_src/math/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,27 +660,25 @@ def searchsorted(self, v, side='left', sorter=None):
"""
return _return(self.value.searchsorted(v=_as_jax_array_(v), side=side, sorter=sorter))

def sort(self, axis=-1, kind='quicksort', order=None):
def sort(self, axis=-1, stable=True, order=None):
"""Sort an array in-place.
Parameters
----------
axis : int, optional
Axis along which to sort. Default is -1, which means sort along the
last axis.
kind : {'quicksort', 'mergesort', 'heapsort', 'stable'}
Sorting algorithm. The default is 'quicksort'. Note that both 'stable'
and 'mergesort' use timsort under the covers and, in general, the
actual implementation will vary with datatype. The 'mergesort' option
is retained for backwards compatibility.
stable : bool, optional
Whether to use a stable sorting algorithm. The default is True.
order : str or list of str, optional
When `a` is an array with fields defined, this argument specifies
which fields to compare first, second, etc. A single field can
be specified as a string, and not all fields need be specified,
but unspecified fields will still be used, in the order in which
they come up in the dtype, to break ties.
"""
self.value = self.value.sort(axis=axis, kind=kind, order=order)
self.value = self.value.sort(axis=axis, stable=stable, order=order)


def squeeze(self, axis=None):
"""Remove axes of length one from ``a``."""
Expand Down
6 changes: 3 additions & 3 deletions brainpy/_src/math/object_transform/tests/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __call__(self, *args, **kwargs):
def test_jit_with_static(self):
a = bm.Variable(bm.ones(2))

@bm.jit(static_argnums=1)
@bm.jit(static_argnums=0)
def f(b, c):
a.value *= b
a.value /= c
Expand Down Expand Up @@ -104,7 +104,7 @@ def __init__(self):
self.a = bm.zeros(2)
self.b = bm.Variable(bm.ones(2))

self.call1 = bm.jit(self.call, static_argnums=0)
self.call1 = bm.jit(self.call, static_argnums=1)
self.call2 = bm.jit(self.call, static_argnames=['fit'])

def call(self, fit=True):
Expand Down Expand Up @@ -157,7 +157,7 @@ class MyObj:
def __init__(self):
self.a = bm.Variable(bm.ones(2))

@bm.cls_jit(static_argnums=1)
@bm.cls_jit(static_argnums=0)
def f(self, b, c):
self.a.value *= b
self.a.value /= c
Expand Down

0 comments on commit 5253b59

Please sign in to comment.