From dcb429b93d985354f570517a5cbb1453ebdce25a Mon Sep 17 00:00:00 2001 From: Breeze-P Date: Mon, 20 Feb 2023 14:35:44 +0800 Subject: [PATCH 1/2] feat: pytorch math-operations-a --- brainpy/_src/math/compat_pytorch.py | 144 +++++++++++++++++- .../_src/math/tests/test_compat_pytorch.py | 31 ++++ 2 files changed, 174 insertions(+), 1 deletion(-) diff --git a/brainpy/_src/math/compat_pytorch.py b/brainpy/_src/math/compat_pytorch.py index 65bcee0aa..798593efc 100644 --- a/brainpy/_src/math/compat_pytorch.py +++ b/brainpy/_src/math/compat_pytorch.py @@ -4,7 +4,7 @@ import jax.numpy as jnp import numpy as np -from .ndarray import Array, _as_jax_array_ +from .ndarray import Array, _as_jax_array_, _return, _check_out from .compat_numpy import ( concatenate, shape ) @@ -86,3 +86,145 @@ def unsqueeze(input: Union[jax.Array, Array], dim: int) -> Array: """ input = _as_jax_array_(input) return Array(jnp.expand_dims(input, dim)) + + +# Math operations +def abs(input: Union[jax.Array, Array], + *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: + input = _as_jax_array_(input) + r = jnp.abs(input) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + +absolute = abs + +def acos(input: Union[jax.Array, Array], + *, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]: + input = _as_jax_array_(input) + r = jnp.arccos(input) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + +arccos = acos + +def acosh(input: Union[jax.Array, Array], + *, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]: + input = _as_jax_array_(input) + r = jnp.arccosh(input) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + +arccosh = acosh + +def add(input: Union[jax.Array, Array, jnp.number], + other: Union[jax.Array, Array, jnp.number], + *, alpha: Optional[jnp.number] = 1, + out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Array: + input = _as_jax_array_(input) + other = _as_jax_array_(other) + other = jnp.multiply(alpha, other) + r = jnp.add(input, other) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + +def addcdiv(input: Union[jax.Array, Array, jnp.number], + tensor1: Union[jax.Array, Array, jnp.number], + tensor2: Union[jax.Array, Array, jnp.number], + *, value: jnp.number = 1, + out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Array: + tensor1 = _as_jax_array_(tensor1) + tensor2 = _as_jax_array_(tensor2) + other = jnp.divide(tensor1, tensor2) + return add(input, other, alpha=value, out=out) + +def addcmul(input: Union[jax.Array, Array, jnp.number], + tensor1: Union[jax.Array, Array, jnp.number], + tensor2: Union[jax.Array, Array, jnp.number], + *, value: jnp.number = 1, + out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Array: + tensor1 = _as_jax_array_(tensor1) + tensor2 = _as_jax_array_(tensor2) + other = jnp.multiply(tensor1, tensor2) + return add(input, other, alpha=value, out=out) + +def angle(input: Union[jax.Array, Array, jnp.number], + *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Array: + input = _as_jax_array_(input) + r = jnp.angle(input) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + +def asin(input: Union[jax.Array, Array], + *, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]: + input = _as_jax_array_(input) + r = jnp.arcsin(input) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + +arcsin = asin + +def asinh(input: Union[jax.Array, Array], + *, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]: + input = _as_jax_array_(input) + r = jnp.arcsinh(input) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + +arcsinh = asinh + +def atan(input: Union[jax.Array, Array], + *, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]: + input = _as_jax_array_(input) + r = jnp.arctan(input) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + +arctan = atan + +def atanh(input: Union[jax.Array, Array], + *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: + input = _as_jax_array_(input) + r = jnp.arctanh(input) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + +arctanh = atanh + +def atan2(input: Union[jax.Array, Array], + *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: + input = _as_jax_array_(input) + r = jnp.arctan2(input) + if out is None: + return _return(r) + else: + _check_out(out) + out.value = r + +arctan2 = atan2 \ No newline at end of file diff --git a/brainpy/_src/math/tests/test_compat_pytorch.py b/brainpy/_src/math/tests/test_compat_pytorch.py index 5b44cb009..ce3dd161b 100644 --- a/brainpy/_src/math/tests/test_compat_pytorch.py +++ b/brainpy/_src/math/tests/test_compat_pytorch.py @@ -7,6 +7,7 @@ import unittest import brainpy.math as bm from brainpy._src.math import compat_pytorch +import brainpy._src.math.compat_pytorch as torch from absl .testing import parameterized @@ -45,3 +46,33 @@ def test1(self): a = a.expand(1, 6, 4, -1) self.assertTrue(a.shape == (1, 6, 4, 5)) +class TestMathOperators(unittest.TestCase): + def test_abs(self): + arr = compat_pytorch.Tensor([-1, -2, 3]) + a = compat_pytorch.abs(arr) + res = compat_pytorch.Tensor([1, 2, 3]) + b = compat_pytorch.absolute(arr) + self.assertTrue(bm.array_equal(a, res)) + self.assertTrue(bm.array_equal(b, res)) + + def test_add(self): + a = compat_pytorch.Tensor([0.0202, 1.0985, 1.3506, -0.6056]) + a = compat_pytorch.add(a, 20) + res = compat_pytorch.Tensor([20.0202, 21.0985, 21.3506, 19.3944]) + self.assertTrue(bm.array_equal(a, res)) + b = compat_pytorch.Tensor([-0.9732, -0.3497, 0.6245, 0.4022]) + c = compat_pytorch.Tensor([[0.3743], [-1.7724], [-0.5811], [-0.8017]]) + b = compat_pytorch.add(b, c, alpha=10) + self.assertTrue(b.shape == (4, 4)) + print("b:", b) + + def test_addcdiv(self): + rng = bm.random.default_rng(999) + t = rng.rand(1, 3) + t1 = rng.randn(3, 1) + rng = bm.random.default_rng(199) + t2 = rng.randn(1, 3) + res = torch.addcdiv(t, t1, t2, value=0.1) + print("t + t1/t2 * value:", res) + res = torch.addcmul(t, t1, t2, value=0.1) + print("t + t1*t2 * value:", res) From d7e504a50dcfe1ccd16b7f82a72045974c151384 Mon Sep 17 00:00:00 2001 From: Breeze-P Date: Wed, 22 Feb 2023 17:29:30 +0800 Subject: [PATCH 2/2] fix: pytorch math-operations-a problems resolved. --- brainpy/_src/math/compat_pytorch.py | 8 ++++---- brainpy/math/compat_pytorch.py | 20 +++++++++++++++++++- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/brainpy/_src/math/compat_pytorch.py b/brainpy/_src/math/compat_pytorch.py index 798593efc..82f4b99c4 100644 --- a/brainpy/_src/math/compat_pytorch.py +++ b/brainpy/_src/math/compat_pytorch.py @@ -128,7 +128,7 @@ def acosh(input: Union[jax.Array, Array], def add(input: Union[jax.Array, Array, jnp.number], other: Union[jax.Array, Array, jnp.number], *, alpha: Optional[jnp.number] = 1, - out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Array: + out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: input = _as_jax_array_(input) other = _as_jax_array_(other) other = jnp.multiply(alpha, other) @@ -143,7 +143,7 @@ def addcdiv(input: Union[jax.Array, Array, jnp.number], tensor1: Union[jax.Array, Array, jnp.number], tensor2: Union[jax.Array, Array, jnp.number], *, value: jnp.number = 1, - out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Array: + out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: tensor1 = _as_jax_array_(tensor1) tensor2 = _as_jax_array_(tensor2) other = jnp.divide(tensor1, tensor2) @@ -153,14 +153,14 @@ def addcmul(input: Union[jax.Array, Array, jnp.number], tensor1: Union[jax.Array, Array, jnp.number], tensor2: Union[jax.Array, Array, jnp.number], *, value: jnp.number = 1, - out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Array: + out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: tensor1 = _as_jax_array_(tensor1) tensor2 = _as_jax_array_(tensor2) other = jnp.multiply(tensor1, tensor2) return add(input, other, alpha=value, out=out) def angle(input: Union[jax.Array, Array, jnp.number], - *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Array: + *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: input = _as_jax_array_(input) r = jnp.angle(input) if out is None: diff --git a/brainpy/math/compat_pytorch.py b/brainpy/math/compat_pytorch.py index c3a7021f4..919134aac 100644 --- a/brainpy/math/compat_pytorch.py +++ b/brainpy/math/compat_pytorch.py @@ -4,5 +4,23 @@ flatten as flatten, cat as cat, - + unsqueeze as unsqueeze, + abs as abs, + absolute as absolute, + acos as acos, + arccos as arccos, + acosh as acosh, + arccosh as arccosh, + add as add, + addcdiv as addcdiv, + addcmul as addcmul, + angle as angle, + asin as asin, + arcsin as arcsin, + asinh as asinh, + arcsin as arcsin, + atan as atan, + arctan as arctan, + atan2 as atan2, + atanh as atanh, )