diff --git a/brainunit/math/_fun_keep_unit.py b/brainunit/math/_fun_keep_unit.py index 277486e..87247f1 100644 --- a/brainunit/math/_fun_keep_unit.py +++ b/brainunit/math/_fun_keep_unit.py @@ -443,7 +443,7 @@ def _broadcast_fun(func, *args, **kwargs): args = [asarray(x) for x in args] args, treedef = jax.tree.flatten(args) r = func(*args, **kwargs) - r = treedef.unflatten(r) + r = treedef.unflatten([r] if isinstance(r, jax.Array) else r) if len(r) == 1: return r[0] return r @@ -485,7 +485,7 @@ def atleast_1d( Parameters ---------- - *args : array_like, Quantity + *arys : array_like, Quantity One or more input arrays or quantities. Returns @@ -505,7 +505,7 @@ def atleast_2d( Parameters ---------- - *args : array_like, Quantity + *arys : array_like, Quantity One or more input arrays or quantities. Returns @@ -525,7 +525,7 @@ def atleast_3d( Parameters ---------- - *args : array_like, Quantity + *arys : array_like, Quantity One or more input arrays or quantities. Returns