Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 143 additions & 1 deletion brainpy/_src/math/compat_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp
import numpy as np

from .ndarray import Array, _as_jax_array_
from .ndarray import Array, _as_jax_array_, _return, _check_out
from .compat_numpy import (
concatenate, shape
)
Expand Down Expand Up @@ -86,3 +86,145 @@ def unsqueeze(input: Union[jax.Array, Array], dim: int) -> Array:
"""
input = _as_jax_array_(input)
return Array(jnp.expand_dims(input, dim))


# Math operations
def abs(input: Union[jax.Array, Array],
*, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
input = _as_jax_array_(input)
r = jnp.abs(input)
if out is None:
return _return(r)
else:
_check_out(out)
out.value = r

absolute = abs

def acos(input: Union[jax.Array, Array],
*, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]:
input = _as_jax_array_(input)
r = jnp.arccos(input)
if out is None:
return _return(r)
else:
_check_out(out)
out.value = r

arccos = acos

def acosh(input: Union[jax.Array, Array],
*, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]:
input = _as_jax_array_(input)
r = jnp.arccosh(input)
if out is None:
return _return(r)
else:
_check_out(out)
out.value = r

arccosh = acosh

def add(input: Union[jax.Array, Array, jnp.number],
other: Union[jax.Array, Array, jnp.number],
*, alpha: Optional[jnp.number] = 1,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where is alpha used?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Pytorch, the add func is not only "add". It also means "Adds other, scaled by alpha, to input.", and alpha is used in line 134.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great!

out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
input = _as_jax_array_(input)
other = _as_jax_array_(other)
other = jnp.multiply(alpha, other)
r = jnp.add(input, other)
if out is None:
return _return(r)
else:
_check_out(out)
out.value = r

def addcdiv(input: Union[jax.Array, Array, jnp.number],
tensor1: Union[jax.Array, Array, jnp.number],
tensor2: Union[jax.Array, Array, jnp.number],
*, value: jnp.number = 1,
out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
tensor1 = _as_jax_array_(tensor1)
tensor2 = _as_jax_array_(tensor2)
other = jnp.divide(tensor1, tensor2)
return add(input, other, alpha=value, out=out)

def addcmul(input: Union[jax.Array, Array, jnp.number],
tensor1: Union[jax.Array, Array, jnp.number],
tensor2: Union[jax.Array, Array, jnp.number],
*, value: jnp.number = 1,
out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
tensor1 = _as_jax_array_(tensor1)
tensor2 = _as_jax_array_(tensor2)
other = jnp.multiply(tensor1, tensor2)
return add(input, other, alpha=value, out=out)

def angle(input: Union[jax.Array, Array, jnp.number],
*, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
input = _as_jax_array_(input)
r = jnp.angle(input)
if out is None:
return _return(r)
else:
_check_out(out)
out.value = r

def asin(input: Union[jax.Array, Array],
*, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]:
input = _as_jax_array_(input)
r = jnp.arcsin(input)
if out is None:
return _return(r)
else:
_check_out(out)
out.value = r

arcsin = asin

def asinh(input: Union[jax.Array, Array],
*, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]:
input = _as_jax_array_(input)
r = jnp.arcsinh(input)
if out is None:
return _return(r)
else:
_check_out(out)
out.value = r

arcsinh = asinh

def atan(input: Union[jax.Array, Array],
*, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]:
input = _as_jax_array_(input)
r = jnp.arctan(input)
if out is None:
return _return(r)
else:
_check_out(out)
out.value = r

arctan = atan

def atanh(input: Union[jax.Array, Array],
*, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
input = _as_jax_array_(input)
r = jnp.arctanh(input)
if out is None:
return _return(r)
else:
_check_out(out)
out.value = r

arctanh = atanh

def atan2(input: Union[jax.Array, Array],
*, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
input = _as_jax_array_(input)
r = jnp.arctan2(input)
if out is None:
return _return(r)
else:
_check_out(out)
out.value = r

arctan2 = atan2
31 changes: 31 additions & 0 deletions brainpy/_src/math/tests/test_compat_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import unittest
import brainpy.math as bm
from brainpy._src.math import compat_pytorch
import brainpy._src.math.compat_pytorch as torch

from absl .testing import parameterized

Expand Down Expand Up @@ -45,3 +46,33 @@ def test1(self):
a = a.expand(1, 6, 4, -1)
self.assertTrue(a.shape == (1, 6, 4, 5))

class TestMathOperators(unittest.TestCase):
def test_abs(self):
arr = compat_pytorch.Tensor([-1, -2, 3])
a = compat_pytorch.abs(arr)
res = compat_pytorch.Tensor([1, 2, 3])
b = compat_pytorch.absolute(arr)
self.assertTrue(bm.array_equal(a, res))
self.assertTrue(bm.array_equal(b, res))

def test_add(self):
a = compat_pytorch.Tensor([0.0202, 1.0985, 1.3506, -0.6056])
a = compat_pytorch.add(a, 20)
res = compat_pytorch.Tensor([20.0202, 21.0985, 21.3506, 19.3944])
self.assertTrue(bm.array_equal(a, res))
b = compat_pytorch.Tensor([-0.9732, -0.3497, 0.6245, 0.4022])
c = compat_pytorch.Tensor([[0.3743], [-1.7724], [-0.5811], [-0.8017]])
b = compat_pytorch.add(b, c, alpha=10)
self.assertTrue(b.shape == (4, 4))
print("b:", b)

def test_addcdiv(self):
rng = bm.random.default_rng(999)
t = rng.rand(1, 3)
t1 = rng.randn(3, 1)
rng = bm.random.default_rng(199)
t2 = rng.randn(1, 3)
res = torch.addcdiv(t, t1, t2, value=0.1)
print("t + t1/t2 * value:", res)
res = torch.addcmul(t, t1, t2, value=0.1)
print("t + t1*t2 * value:", res)
20 changes: 19 additions & 1 deletion brainpy/math/compat_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,23 @@

flatten as flatten,
cat as cat,

unsqueeze as unsqueeze,
abs as abs,
absolute as absolute,
acos as acos,
arccos as arccos,
acosh as acosh,
arccosh as arccosh,
add as add,
addcdiv as addcdiv,
addcmul as addcmul,
angle as angle,
asin as asin,
arcsin as arcsin,
asinh as asinh,
arcsin as arcsin,
atan as atan,
arctan as arctan,
atan2 as atan2,
atanh as atanh,
)