
Commit

update black
mariogeiger committed Jan 26, 2024
1 parent cc739f4 commit 8c502c1
Showing 14 changed files with 119 additions and 71 deletions.
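
Note: all 14 files change formatting only, with no change in behavior. The new layout matches the rules introduced by Black's 2024 stable style, so this presumably upgrades the formatter to Black 24.x (24.1.0 was current at the commit date). Two rules account for every hunk below: a multiline conditional expression is now wrapped in parentheses, and a blank line is inserted between a class or module docstring and the code that follows. A minimal before/after sketch, using hypothetical helper names:

    # Black 23.x: a conditional expression that has to split stays bare.
    values = [
        default(x)          # `default` and `transform` are hypothetical
        if x is None
        else transform(x)
        for x in data
    ]

    # Black 24.x: the same expression is wrapped in parentheses.
    values = [
        (
            default(x)
            if x is None
            else transform(x)
        )
        for x in data
    ]

    class Example:
        """Docstring."""

        field: int = 0      # Black 24.x adds the blank line above the field
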
56 changes: 33 additions & 23 deletions e3nn_jax/_src/basic.py
@@ -63,9 +63,11 @@ def from_chunks(
     if irreps.dim > 0:
         array = jnp.concatenate(
             [
-                jnp.zeros(leading_shape + (mul_ir.dim,), dtype)
-                if x is None
-                else x.reshape(leading_shape + (mul_ir.dim,))
+                (
+                    jnp.zeros(leading_shape + (mul_ir.dim,), dtype)
+                    if x is None
+                    else x.reshape(leading_shape + (mul_ir.dim,))
+                )
                 for mul_ir, x in zip(irreps, chunks)
             ],
             axis=-1,
@@ -306,16 +308,20 @@ def concatenate(arrays: List[e3nn.IrrepsArray], axis: int = -1) -> e3nn.IrrepsArray:
 
     zero_flags = [all(x) for x in zip(*[x.zero_flags for x in arrays])]
     chunks = [
-        None
-        if z
-        else jnp.concatenate(
-            [
-                jnp.zeros(x.shape[:-1] + (mul, ir.dim), dtype=x.dtype)
-                if x.chunks[i] is None
-                else x.chunks[i]
-                for x in arrays
-            ],
-            axis=axis,
+        (
+            None
+            if z
+            else jnp.concatenate(
+                [
+                    (
+                        jnp.zeros(x.shape[:-1] + (mul, ir.dim), dtype=x.dtype)
+                        if x.chunks[i] is None
+                        else x.chunks[i]
+                    )
+                    for x in arrays
+                ],
+                axis=axis,
+            )
         )
         for i, ((mul, ir), z) in enumerate(zip(irreps, zero_flags))
     ]
@@ -374,16 +380,20 @@ def stack(arrays: List[e3nn.IrrepsArray], axis=0) -> e3nn.IrrepsArray:
 
     zero_flags = [all(x) for x in zip(*[x.zero_flags for x in arrays])]
     chunks = [
-        None
-        if z
-        else jnp.stack(
-            [
-                jnp.zeros(x.shape[:-1] + (mul, ir.dim), dtype=x.dtype)
-                if x.chunks[i] is None
-                else x.chunks[i]
-                for x in arrays
-            ],
-            axis=axis,
+        (
+            None
+            if z
+            else jnp.stack(
+                [
+                    (
+                        jnp.zeros(x.shape[:-1] + (mul, ir.dim), dtype=x.dtype)
+                        if x.chunks[i] is None
+                        else x.chunks[i]
+                    )
+                    for x in arrays
+                ],
+                axis=axis,
+            )
         )
         for i, ((mul, ir), z) in enumerate(zip(irreps, zero_flags))
     ]
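
The None entries threaded through this file follow e3nn-jax's chunk convention: a chunk that is None stands for an all-zero block, and from_chunks materializes it with jnp.zeros when assembling the flat array. A minimal usage sketch consistent with the hunk above (irreps and shapes chosen for illustration):

    import e3nn_jax as e3nn
    import jax.numpy as jnp

    # One scalar (0e) chunk of shape (mul=1, dim=1); the vector (1o) chunk
    # is None, i.e. implicitly zero.
    x = e3nn.from_chunks("1x0e + 1x1o", [jnp.ones((1, 1)), None], ())
    # x.array has shape (4,): [1.0, 0.0, 0.0, 0.0]
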
2 changes: 2 additions & 0 deletions e3nn_jax/_src/irreps.py
@@ -58,6 +58,7 @@ class Irrep:
     >>> Irrep("1o") + Irrep("2o")
     1x1o+1x2o
     """
+
     l: int
     p: int
 
@@ -308,6 +309,7 @@ def __eq__(self, other: object) -> bool:
 @dataclasses.dataclass(init=False, frozen=True)
 class MulIrrep:
     r"""An Irrep with a multiplicity."""
+
     mul: int
     ir: Irrep
 
32 changes: 21 additions & 11 deletions e3nn_jax/_src/irreps_array.py
@@ -192,9 +192,13 @@ def chunks(self) -> List[Optional[jax.Array]]:
             chunks = [jnp.reshape(self.array, leading_shape + (mul, ir.dim))]
         else:
             chunks = [
-                None
-                if zero
-                else jnp.reshape(self.array[..., i], leading_shape + (mul, ir.dim))
+                (
+                    None
+                    if zero
+                    else jnp.reshape(
+                        self.array[..., i], leading_shape + (mul, ir.dim)
+                    )
+                )
                 for zero, i, (mul, ir) in zip(
                     zeros, self.irreps.slices(), self.irreps
                 )
@@ -974,11 +978,14 @@ def transform_by_log_coordinates(
             for ir in {ir for _, ir in self.irreps}
         }
         new_list = [
-            jnp.reshape(
-                jnp.einsum("ij,...uj->...ui", D[ir], x), self.shape[:-1] + (mul, ir.dim)
+            (
+                jnp.reshape(
+                    jnp.einsum("ij,...uj->...ui", D[ir], x),
+                    self.shape[:-1] + (mul, ir.dim),
+                )
+                if x is not None
+                else None
             )
-            if x is not None
-            else None
             for (mul, ir), x in zip(self.irreps, self.chunks)
         ]
         return e3nn.from_chunks(self.irreps, new_list, self.shape[:-1], self.dtype)
@@ -1026,11 +1033,14 @@ def transform_by_angles(
         if inverse:
             D = {ir: jnp.swapaxes(D[ir], -2, -1) for ir in D}
         new_chunks = [
-            jnp.reshape(
-                jnp.einsum("ij,...uj->...ui", D[ir], x), self.shape[:-1] + (mul, ir.dim)
+            (
+                jnp.reshape(
+                    jnp.einsum("ij,...uj->...ui", D[ir], x),
+                    self.shape[:-1] + (mul, ir.dim),
+                )
+                if x is not None
+                else None
             )
-            if x is not None
-            else None
             for (mul, ir), x in zip(self.irreps, self.chunks)
         ]
         return e3nn.from_chunks(self.irreps, new_chunks, self.shape[:-1], self.dtype)
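
Both transform_by_* methods above rotate each chunk with the Wigner D-matrix of its irrep through the same einsum contraction. A minimal sketch of that contraction with hypothetical shapes (the identity stands in for a real D-matrix):

    import jax.numpy as jnp

    D = jnp.eye(3)            # stand-in for the 3x3 D-matrix of a "1o" irrep
    chunk = jnp.ones((2, 3))  # (mul=2, ir.dim=3)
    rotated = jnp.einsum("ij,...uj->...ui", D, chunk)  # shape stays (2, 3)
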
17 changes: 10 additions & 7 deletions e3nn_jax/_src/legacy/core_tensor_product.py
@@ -133,6 +133,7 @@ class FunctionalTensorProduct:
         gradient_normalization (str or float): Normalization of the gradients, ``element`` or ``path``.
             0/1 corresponds to a normalization where each element/path has an equal contribution to the learning.
     """
+
     irreps_in1: e3nn.Irreps
     irreps_in2: e3nn.Irreps
     irreps_out: e3nn.Irreps
@@ -215,14 +216,16 @@ def __init__(
         if self.irreps_out.dim > 0:
             self.output_mask = jnp.concatenate(
                 [
-                    jnp.ones(mul_ir.dim, dtype=bool)
-                    if any(
-                        (ins.i_out == i_out)
-                        and (ins.path_weight != 0)
-                        and (0 not in ins.path_shape)
-                        for ins in self.instructions
+                    (
+                        jnp.ones(mul_ir.dim, dtype=bool)
+                        if any(
+                            (ins.i_out == i_out)
+                            and (ins.path_weight != 0)
+                            and (0 not in ins.path_shape)
+                            for ins in self.instructions
+                        )
+                        else jnp.zeros(mul_ir.dim, dtype=bool)
                     )
-                    else jnp.zeros(mul_ir.dim, dtype=bool)
                     for i_out, mul_ir in enumerate(self.irreps_out)
                 ]
             )
60 changes: 36 additions & 24 deletions e3nn_jax/_src/linear.py
@@ -112,12 +112,14 @@ def alpha(this):
     if irreps_out.dim > 0:
         output_mask = jnp.concatenate(
             [
-                jnp.ones(mul_ir.dim, bool)
-                if any(
-                    (ins.i_out == i_out) and (0 not in ins.path_shape)
-                    for ins in instructions
+                (
+                    jnp.ones(mul_ir.dim, bool)
+                    if any(
+                        (ins.i_out == i_out) and (0 not in ins.path_shape)
+                        for ins in instructions
+                    )
+                    else jnp.zeros(mul_ir.dim, bool)
                 )
-                else jnp.zeros(mul_ir.dim, bool)
                 for i_out, mul_ir in enumerate(irreps_out)
             ]
         )
@@ -175,13 +177,15 @@ def __call__(
         ws = self.split_weights(ws)
 
         paths = [
-            ins.path_weight * w
-            if ins.i_in == -1
-            else (
-                None
-                if input.chunks[ins.i_in] is None
-                else ins.path_weight
-                * jnp.einsum("uw,ui->wi", w, input.chunks[ins.i_in])
+            (
+                ins.path_weight * w
+                if ins.i_in == -1
+                else (
+                    None
+                    if input.chunks[ins.i_in] is None
+                    else ins.path_weight
+                    * jnp.einsum("uw,ui->wi", w, input.chunks[ins.i_in])
+                )
             )
             for ins, w in zip(self.instructions, ws)
         ]
@@ -227,9 +231,11 @@ def linear_vanilla(
     """Vanilla linear layer."""
     w = [
         get_parameter(
-            f"b[{ins.i_out}] {linear.irreps_out[ins.i_out]}"
-            if ins.i_in == -1
-            else f"w[{ins.i_in},{ins.i_out}] {linear.irreps_in[ins.i_in]},{linear.irreps_out[ins.i_out]}",
+            (
+                f"b[{ins.i_out}] {linear.irreps_out[ins.i_out]}"
+                if ins.i_in == -1
+                else f"w[{ins.i_in},{ins.i_out}] {linear.irreps_in[ins.i_in]},{linear.irreps_out[ins.i_out]}"
+            ),
             ins.path_shape,
             ins.weight_std,
             input.dtype,
@@ -260,9 +266,11 @@ def linear_indexed(
 
     w = [
         get_parameter(
-            f"b[{ins.i_out}] {lin.irreps_out[ins.i_out]}"
-            if ins.i_in == -1
-            else f"w[{ins.i_in},{ins.i_out}] {lin.irreps_in[ins.i_in]},{lin.irreps_out[ins.i_out]}",
+            (
+                f"b[{ins.i_out}] {lin.irreps_out[ins.i_out]}"
+                if ins.i_in == -1
+                else f"w[{ins.i_in},{ins.i_out}] {lin.irreps_in[ins.i_in]},{lin.irreps_out[ins.i_out]}"
+            ),
             (num_indexed_weights,) + ins.path_shape,
             ins.weight_std,
             input.dtype,
@@ -299,9 +307,11 @@ def linear_mixed(
 
     w = [
         get_parameter(
-            f"b[{ins.i_out}] {lin.irreps_out[ins.i_out]}"
-            if ins.i_in == -1
-            else f"w[{ins.i_in},{ins.i_out}] {lin.irreps_in[ins.i_in]},{lin.irreps_out[ins.i_out]}",
+            (
+                f"b[{ins.i_out}] {lin.irreps_out[ins.i_out]}"
+                if ins.i_in == -1
+                else f"w[{ins.i_in},{ins.i_out}] {lin.irreps_in[ins.i_in]},{lin.irreps_out[ins.i_out]}"
+            ),
             (d,) + ins.path_shape,
             stddev * ins.weight_std,
             input.dtype,
@@ -341,9 +351,11 @@ def linear_mixed_per_channel(
 
     w = [
         get_parameter(
-            f"b[{ins.i_out}] {lin.irreps_out[ins.i_out]}"
-            if ins.i_in == -1
-            else f"w[{ins.i_in},{ins.i_out}] {lin.irreps_in[ins.i_in]},{lin.irreps_out[ins.i_out]}",
+            (
+                f"b[{ins.i_out}] {lin.irreps_out[ins.i_out]}"
+                if ins.i_in == -1
+                else f"w[{ins.i_in},{ins.i_out}] {lin.irreps_in[ins.i_in]},{lin.irreps_out[ins.i_out]}"
+            ),
             (d, nc) + ins.path_shape,
             stddev * ins.weight_std,
             input.dtype,
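
The output_mask built in this file (and in core_tensor_product.py above) marks which entries of the flattened output can receive a nonzero value from at least one instruction. A distilled sketch of the idea, with hypothetical dimensions:

    import jax.numpy as jnp

    dims = [1, 3]             # hypothetical output irrep dimensions
    has_path = [False, True]  # does any instruction write to this output?

    output_mask = jnp.concatenate(
        [
            jnp.ones(d, bool) if used else jnp.zeros(d, bool)
            for d, used in zip(dims, has_path)
        ]
    )
    # output_mask == [False, True, True, True]
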
1 change: 1 addition & 0 deletions e3nn_jax/_src/linear_equinox.py
@@ -101,6 +101,7 @@ class Linear(eqx.Module):
     >>> linear(i, x).shape
     (5,)
     """
+
     irreps_out: e3nn.Irreps
     irreps_in: e3nn.Irreps
     channel_out: int
1 change: 1 addition & 0 deletions e3nn_jax/_src/linear_flax.py
@@ -64,6 +64,7 @@ class Linear(flax.linen.Module):
     >>> linear.apply(w, i, x).shape
     (5,)
     """
+
     irreps_out: e3nn.Irreps
     irreps_in: Optional[e3nn.Irreps] = None
     channel_out: Optional[int] = None
1 change: 1 addition & 0 deletions e3nn_jax/_src/mlp_flax.py
@@ -21,6 +21,7 @@ class MultiLayerPerceptron(flax.linen.Module):
         - "path" (default): normalization done explicitly in the forward pass,
           gives the same importance to every layer independently of the number of neurons
     """
+
     list_neurons: Tuple[int, ...]
     act: Optional[Callable] = None
     gradient_normalization: Union[str, float] = None
8 changes: 5 additions & 3 deletions e3nn_jax/_src/radial.py
@@ -255,9 +255,11 @@ def u(p: int, x: jax.Array) -> jax.Array:
 
 def _constraint(x: float, derivative: int, degree: int):
     return [
-        0
-        if derivative > N
-        else factorial(N) // factorial(N - derivative) * x ** (N - derivative)
+        (
+            0
+            if derivative > N
+            else factorial(N) // factorial(N - derivative) * x ** (N - derivative)
+        )
         for N in range(degree)
     ]
 
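
For reference, _constraint evaluates the derivative-of-monomial coefficients d^k/dx^k x^N = N!/(N-k)! * x^(N-k), returning 0 whenever k > N; e.g. degree=3, derivative=1, x=1 yields [0, 1, 2].
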
1 change: 1 addition & 0 deletions e3nn_jax/_src/reduced_tensor_product.py
@@ -5,6 +5,7 @@
 - Ported in `e3nn-jax` by Mario Geiger
 - Optimized the symmetric case by Ameya Daigavane and Mario Geiger
 """
+
 import functools
 import itertools
 from math import prod
1 change: 1 addition & 0 deletions e3nn_jax/_src/s2grid.py
@@ -105,6 +105,7 @@ def f(coords):
     import plotly.graph_objects as go
     go.Figure([go.Surface(signal.plotly_surface())])
     """
+
     grid_values: jax.Array
     quadrature: str
     p_val: int
1 change: 1 addition & 0 deletions e3nn_jax/_src/symmetric_tensor_product_haiku.py
@@ -2,6 +2,7 @@
 Implementation from MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields
 Ilyes Batatia, Dávid Péter Kovács, Gregor N. C. Simm, Christoph Ortner and Gábor Csányi
 """
+
 from typing import Any, Callable, Optional, Set, Tuple
 
 import haiku as hk
8 changes: 5 additions & 3 deletions e3nn_jax/_src/utils/optimize_jaxpr.py
@@ -174,9 +174,11 @@ def param_key(p: Any):
         eqns=[
             eqn.replace(
                 params={
-                    k: v.replace(jaxpr=remove_duplicate_equations(v.jaxpr))
-                    if type(v) is ClosedJaxpr
-                    else v
+                    k: (
+                        v.replace(jaxpr=remove_duplicate_equations(v.jaxpr))
+                        if type(v) is ClosedJaxpr
+                        else v
+                    )
                     for k, v in eqn.params.items()
                 }
             )
1 change: 1 addition & 0 deletions e3nn_jax/experimental/linear_shtp.py
@@ -4,6 +4,7 @@
 - Added the support of inversion symmetry. (Mario Geiger)
 """
+
 from typing import Sequence
 
 import flax
