Skip to content

Commit

Permalink
fix broadcast functions
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jun 27, 2024
1 parent 5ec009d commit a54759b
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions brainunit/math/_fun_keep_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import jax.numpy as jnp
import numpy as np

from ._fun_array_creation import asarray
from .._base import Quantity, fail_for_dimension_mismatch, DIMENSIONLESS
from .._misc import set_module_as

Expand Down Expand Up @@ -439,6 +440,7 @@ def vsplit(


def _broadcast_fun(func, *args, **kwargs):
args = [asarray(x) for x in args]
args, treedef = jax.tree.flatten(args)
r = func(*args, **kwargs)
r = treedef.unflatten(r)
Expand Down Expand Up @@ -471,9 +473,7 @@ def broadcast_arrays(
``writable`` flag True, writing to a single output value may end up
changing more than one location in the output array.
"""
leaves, tree = jax.tree.flatten(args)
leaves = jnp.broadcast_arrays(*leaves)
return jax.tree.unflatten(tree, leaves)
return _broadcast_fun(jnp.broadcast_arrays, *args)


@set_module_as('brainunit.math')
Expand Down

0 comments on commit a54759b

Please sign in to comment.