Skip to content

Commit

Permalink
jax.numpy.ndarray -> jax.Array
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Jan 22, 2024
1 parent a031682 commit ef33ae9
Show file tree
Hide file tree
Showing 11 changed files with 213 additions and 213 deletions.
4 changes: 2 additions & 2 deletions e3nn_jax/_src/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def from_chunks(
Args:
irreps (Irreps): irreps
chunks (list of optional `jax.numpy.ndarray`): list of arrays
chunks (list of optional `jax.Array`): list of arrays
leading_shape (tuple of int): leading shape of the arrays (without the irreps)
Returns:
Expand Down Expand Up @@ -82,7 +82,7 @@ def as_irreps_array(array: Union[jax.Array, e3nn.IrrepsArray], *, backend=None):
"""Convert an array to an IrrepsArray.
Args:
array (jax.numpy.ndarray or IrrepsArray): array to convert
array (jax.Array or IrrepsArray): array to convert
Returns:
IrrepsArray
Expand Down
50 changes: 25 additions & 25 deletions e3nn_jax/_src/irreps.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,12 @@ def D_from_log_coordinates(self, log_coordinates, k=0):
(matrix) Representation of :math:`O(3)`. :math:`D` is the representation of :math:`SO(3)`.
Args:
log_coordinates (`jax.numpy.ndarray`): of shape :math:`(..., 3)`
k (optional `jax.numpy.ndarray`): of shape :math:`(...)`
log_coordinates (`jax.Array`): of shape :math:`(..., 3)`
k (optional `jax.Array`): of shape :math:`(...)`
How many times the parity is applied.
Returns:
`jax.numpy.ndarray`: of shape :math:`(..., 2l+1, 2l+1)`
`jax.Array`: of shape :math:`(..., 2l+1, 2l+1)`
See Also:
Irreps.D_from_log_coordinates
Expand All @@ -144,17 +144,17 @@ def D_from_angles(self, alpha, beta, gamma, k=0):
(matrix) Representation of :math:`O(3)`. :math:`D` is the representation of :math:`SO(3)`.
Args:
alpha (`jax.numpy.ndarray`): of shape :math:`(...)`
alpha (`jax.Array`): of shape :math:`(...)`
Rotation :math:`\alpha` around Y axis, applied third.
beta (`jax.numpy.ndarray`): of shape :math:`(...)`
beta (`jax.Array`): of shape :math:`(...)`
Rotation :math:`\beta` around X axis, applied second.
gamma (`jax.numpy.ndarray`): of shape :math:`(...)`
gamma (`jax.Array`): of shape :math:`(...)`
Rotation :math:`\gamma` around Y axis, applied first.
k (optional `jax.numpy.ndarray`): of shape :math:`(...)`
k (optional `jax.Array`): of shape :math:`(...)`
How many times the parity is applied.
Returns:
`jax.numpy.ndarray`: of shape :math:`(..., 2l+1, 2l+1)`
`jax.Array`: of shape :math:`(..., 2l+1, 2l+1)`
See Also:
Irreps.D_from_angles
Expand Down Expand Up @@ -196,23 +196,23 @@ def D_from_quaternion(self, q, k=0):
r"""Matrix of the representation, see `Irrep.D_from_angles`.
Args:
q (`jax.numpy.ndarray`): shape :math:`(..., 4)`
k (optional `jax.numpy.ndarray`): shape :math:`(...)`
q (`jax.Array`): shape :math:`(..., 4)`
k (optional `jax.Array`): shape :math:`(...)`
Returns:
`jax.numpy.ndarray`: shape :math:`(..., 2l+1, 2l+1)`
`jax.Array`: shape :math:`(..., 2l+1, 2l+1)`
"""
return self.D_from_angles(*quaternion_to_angles(q), k)

def D_from_matrix(self, R):
r"""Matrix of the representation.
Args:
R (`jax.numpy.ndarray`): array of shape :math:`(..., 3, 3)`
k (`jax.numpy.ndarray`, optional): array of shape :math:`(...)`
R (`jax.Array`): array of shape :math:`(..., 3, 3)`
k (`jax.Array`, optional): array of shape :math:`(...)`
Returns:
`jax.numpy.ndarray`: array of shape :math:`(..., 2l+1, 2l+1)`
`jax.Array`: array of shape :math:`(..., 2l+1, 2l+1)`
Examples:
>>> m = Irrep(1, -1).D_from_matrix(-jnp.eye(3))
Expand All @@ -238,7 +238,7 @@ def generators(self):
r"""Generators of the representation of :math:`SO(3)`.
Returns:
`jax.numpy.ndarray`: array of shape :math:`(3, 2l+1, 2l+1)`
`jax.Array`: array of shape :math:`(3, 2l+1, 2l+1)`
See Also:
`generators`
Expand Down Expand Up @@ -868,11 +868,11 @@ def D_from_log_coordinates(self, log_coordinates, k=0):
r"""Matrix of the representation.
Args:
log_coordinates (`jax.numpy.ndarray`): array of shape :math:`(..., 3)`
k (`jax.numpy.ndarray`, optional): array of shape :math:`(...)`
log_coordinates (`jax.Array`): array of shape :math:`(..., 3)`
k (`jax.Array`, optional): array of shape :math:`(...)`
Returns:
`jax.numpy.ndarray`: array of shape :math:`(..., \mathrm{dim}, \mathrm{dim})`
`jax.Array`: array of shape :math:`(..., \mathrm{dim}, \mathrm{dim})`
"""
return jax.scipy.linalg.block_diag(
*[
Expand All @@ -892,7 +892,7 @@ def D_from_angles(self, alpha, beta, gamma, k=0):
k (int): parity operation
Returns:
`jax.numpy.ndarray`: array of shape :math:`(..., \mathrm{dim}, \mathrm{dim})`
`jax.Array`: array of shape :math:`(..., \mathrm{dim}, \mathrm{dim})`
"""
return jax.scipy.linalg.block_diag(
*[
Expand All @@ -906,22 +906,22 @@ def D_from_quaternion(self, q, k=0):
r"""Matrix of the representation.
Args:
q (`jax.numpy.ndarray`): array of shape :math:`(..., 4)`
k (`jax.numpy.ndarray`, optional): array of shape :math:`(...)`
q (`jax.Array`): array of shape :math:`(..., 4)`
k (`jax.Array`, optional): array of shape :math:`(...)`
Returns:
`jax.numpy.ndarray`: array of shape :math:`(..., \mathrm{dim}, \mathrm{dim})`
`jax.Array`: array of shape :math:`(..., \mathrm{dim}, \mathrm{dim})`
"""
return self.D_from_angles(*quaternion_to_angles(q), k)

def D_from_matrix(self, R):
r"""Matrix of the representation.
Args:
R (`jax.numpy.ndarray`): array of shape :math:`(..., 3, 3)`
R (`jax.Array`): array of shape :math:`(..., 3, 3)`
Returns:
`jax.numpy.ndarray`: array of shape :math:`(..., \mathrm{dim}, \mathrm{dim})`
`jax.Array`: array of shape :math:`(..., \mathrm{dim}, \mathrm{dim})`
"""
d = jnp.sign(jnp.linalg.det(R))
R = d[..., None, None] * R
Expand All @@ -937,7 +937,7 @@ def generators(self) -> jax.Array:
r"""Generators of the representation.
Returns:
`jax.numpy.ndarray`: array of shape :math:`(3, \mathrm{dim}, \mathrm{dim})`
`jax.Array`: array of shape :math:`(3, \mathrm{dim}, \mathrm{dim})`
"""
return jax.vmap(jax.scipy.linalg.block_diag)(
*[ir.generators() for mul, ir in self for _ in range(mul)]
Expand Down
10 changes: 5 additions & 5 deletions e3nn_jax/_src/irreps_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class IrrepsArray:
Args:
irreps (Irreps): representation of the data
array (`jax.numpy.ndarray`): the data, an array of shape ``(..., irreps.dim)``
array (`jax.Array`): the data, an array of shape ``(..., irreps.dim)``
zero_flags (tuple of bool, optional): whether each chunk of the data is zero
Examples:
Expand Down Expand Up @@ -962,7 +962,7 @@ def transform_by_log_coordinates(
r"""Rotate data by a rotation given by log coordinates.
Args:
log_coordinates (`jax.numpy.ndarray`): log coordinates
log_coordinates (`jax.Array`): log coordinates
k (int): parity operation
Returns:
Expand Down Expand Up @@ -1039,7 +1039,7 @@ def transform_by_quaternion(self, q: jax.Array, k: int = 0) -> "IrrepsArray":
r"""Rotate data by a rotation given by a quaternion.
Args:
q (`jax.numpy.ndarray`): quaternion
q (`jax.Array`): quaternion
k (int): parity operation
Returns:
Expand All @@ -1055,7 +1055,7 @@ def transform_by_axis_angle(
r"""Rotate data by a rotation given by an axis and an angle.
Args:
axis (`jax.numpy.ndarray`): axis
axis (`jax.Array`): axis
angle (float): angle (in radians)
k (int): parity operation
Expand All @@ -1070,7 +1070,7 @@ def transform_by_matrix(self, R: jax.Array) -> "IrrepsArray":
r"""Rotate data by a rotation given by a matrix.
Args:
R (`jax.numpy.ndarray`): rotation matrix
R (`jax.Array`): rotation matrix
Returns:
`IrrepsArray`: rotated data
Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/_src/mlp_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __call__(
) -> Union[jax.Array, e3nn.IrrepsArray]:
"""Evaluate the MLP
Input and output are either `jax.numpy.ndarray` or `IrrepsArray`.
Input and output are either `jax.Array` or `IrrepsArray`.
If the input is a `IrrepsArray`, it must contain only scalars.
Args:
Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/_src/mlp_haiku.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __call__(
) -> Union[jax.Array, e3nn.IrrepsArray]:
"""Evaluate the MLP
Input and output are either `jax.numpy.ndarray` or `IrrepsArray`.
Input and output are either `jax.Array` or `IrrepsArray`.
If the input is a `IrrepsArray`, it must contain only scalars.
Args:
Expand Down
8 changes: 4 additions & 4 deletions e3nn_jax/_src/radius_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@ def radius_graph(
r"""Try to use ``matscipy.neighbours.neighbour_list`` instead.
Args:
pos (`jax.numpy.ndarray`): array of shape ``(n, 3)``
pos (`jax.Array`): array of shape ``(n, 3)``
r_max (float):
batch (`jax.numpy.ndarray`): indices
batch (`jax.Array`): indices
size (int): size of the output
loop (bool): whether to include self-loops
Returns:
(tuple): tuple containing:
jax.numpy.ndarray: source indices
jax.numpy.ndarray: destination indices
jax.Array: source indices
jax.Array: destination indices
Examples:
>>> key = jax.random.PRNGKey(0)
Expand Down
Loading

0 comments on commit ef33ae9

Please sign in to comment.