Skip to content

Commit

Permalink
fix broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jun 27, 2024
1 parent 3823083 commit b42751f
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions brainunit/math/_fun_keep_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def _broadcast_fun(func, *args, **kwargs):
args = [asarray(x) for x in args]
args, treedef = jax.tree.flatten(args)
r = func(*args, **kwargs)
r = treedef.unflatten(r)
r = treedef.unflatten([r] if isinstance(r, jax.Array) else r)
if len(r) == 1:
return r[0]
return r
Expand Down Expand Up @@ -485,7 +485,7 @@ def atleast_1d(
Parameters
----------
*args : array_like, Quantity
*arys : array_like, Quantity
One or more input arrays or quantities.
Returns
Expand All @@ -505,7 +505,7 @@ def atleast_2d(
Parameters
----------
*args : array_like, Quantity
*arys : array_like, Quantity
One or more input arrays or quantities.
Returns
Expand All @@ -525,7 +525,7 @@ def atleast_3d(
Parameters
----------
*args : array_like, Quantity
*arys : array_like, Quantity
One or more input arrays or quantities.
Returns
Expand Down

0 comments on commit b42751f

Please sign in to comment.