Skip to content

Commit

Permalink
small cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 22, 2023
1 parent cab5dcf commit 3c913c7
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions equiformer_pytorch/equiformer_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
pad_for_centering_y_to_x
)

from einops import rearrange, repeat, einsum, pack, unpack
from einops import rearrange, repeat, reduce, einsum, pack, unpack
from einops.layers.torch import Rearrange

# constants
Expand Down Expand Up @@ -1039,7 +1039,7 @@ def forward(
degree = ind + 2

next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool()
next_degree_mask = next_degree_adj_mat & ~adj_mat
adj_indices = adj_indices.masked_fill(next_degree_mask, degree)
adj_mat = next_degree_adj_mat.clone()

Expand All @@ -1056,7 +1056,7 @@ def forward(
adj_mat = remove_self(adj_mat)

adj_mat_values = adj_mat.float()
adj_mat_max_neighbors = adj_mat_values.sum(dim = -1).max().item()
adj_mat_max_neighbors = reduce(adj_mat_values, '... i j -> ... i', 'sum').amax().item()

if max_sparse_neighbors < adj_mat_max_neighbors:
eps = 1e-2
Expand Down
2 changes: 1 addition & 1 deletion equiformer_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.3.3'
__version__ = '0.3.4'

0 comments on commit 3c913c7

Please sign in to comment.