
Commit

make sure to mask out padding tokens in mlp attention, at the attention processing stage
lucidrains committed Aug 27, 2023
1 parent 5dfc465 commit dd057f3
Showing 2 changed files with 18 additions and 4 deletions.
20 changes: 17 additions & 3 deletions equiformer_pytorch/equiformer_pytorch.py
@@ -610,6 +610,9 @@ def forward(
if exists(neighbor_mask):
neighbor_mask = rearrange(neighbor_mask, 'b i j -> b 1 i j')

if self.attend_self:
neighbor_mask = F.pad(neighbor_mask, (1, 0), value = True)

features = self.prenorm(features)

queries = self.to_q(features)
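
Note: the hunk above moves the attend_self handling to where the neighbor mask is first reshaped — a column of True is padded onto the left of the mask so the self token, prepended as an extra neighbor, is never treated as padding. A minimal sketch of that padding step, with made-up shapes and a boolean mask where True marks a real neighbor (not the library's code):

import torch
import torch.nn.functional as F
from einops import rearrange

batch, n, neighbors = 2, 4, 3
neighbor_mask = torch.rand(batch, n, neighbors) > 0.3          # True = real neighbor, False = padding

neighbor_mask = rearrange(neighbor_mask, 'b i j -> b 1 i j')   # add a head dim for broadcasting
attend_self = True

if attend_self:
    # the self token is prepended as neighbor 0, so prepend an always-True column
    neighbor_mask = F.pad(neighbor_mask, (1, 0), value = True)

print(neighbor_mask.shape)   # torch.Size([2, 1, 4, 4]) - last dim grew by one
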
@@ -651,9 +654,6 @@ def forward(
if not is_degree_zero:
sim = sim.sum(dim = -1)

if exists(neighbor_mask):
left_pad_needed = int(self.attend_self)
padded_neighbor_mask = F.pad(neighbor_mask, (left_pad_needed, 0), value = True)
sim = sim.masked_fill(~padded_neighbor_mask, -torch.finfo(sim.dtype).max)

attn = sim.softmax(dim = -1)
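
Note: the lines removed above did the left-padding locally; per the commit title and the first hunk, the mask is now padded at the attention processing stage, so the local pad is no longer needed. The masking pattern itself is the usual one — fill masked logits with the most negative finite value so they vanish after the softmax. A small sketch with illustrative shapes (not the module's real ones):

import torch

sim = torch.randn(2, 1, 4, 4)                                         # (batch, heads, queries, neighbors)
neighbor_mask = torch.tensor([True, True, True, False]).view(1, 1, 1, 4)

sim = sim.masked_fill(~neighbor_mask, -torch.finfo(sim.dtype).max)
attn = sim.softmax(dim = -1)

print(attn[..., -1].max())   # the padded neighbor gets ~zero attention weight
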
@@ -698,6 +698,8 @@ def __init__(
self.single_headed_kv = single_headed_kv
value_hidden_fiber = hidden_fiber if not single_headed_kv else dim_head

self.attend_self = attend_self

self.scale = tuple(dim ** -0.5 for dim in dim_head)
self.heads = heads

@@ -766,6 +768,14 @@ def forward(
):
one_headed_kv = self.single_headed_kv

_, neighbor_mask, _ = edge_info

if exists(neighbor_mask):
if self.attend_self:
neighbor_mask = F.pad(neighbor_mask, (1, 0), value = True)

neighbor_mask = rearrange(neighbor_mask, '... -> ... 1')

features = self.prenorm(features)

intermediate = self.to_attn_and_v(
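
Note: for the MLP attention branch, the new lines above pull the neighbor mask out of edge_info, apply the same left pad when attending to self, and append a trailing singleton axis so the mask can broadcast over the heads dimension of the logits (see the (batch, source, target, heads) comment in the next hunk). A sketch of that reshaping with placeholder shapes — the edge_info tuple here is a stand-in, only the middle (mask) slot matters for this step:

import torch
import torch.nn.functional as F
from einops import rearrange

batch, n, neighbors = 2, 4, 3
edge_info = (None, torch.rand(batch, n, neighbors) > 0.3, None)      # stand-in tuple, mask in the middle slot

_, neighbor_mask, _ = edge_info
attend_self = True

if neighbor_mask is not None:
    if attend_self:
        neighbor_mask = F.pad(neighbor_mask, (1, 0), value = True)   # self slot is always attendable

    neighbor_mask = rearrange(neighbor_mask, '... -> ... 1')         # (b, i, j) -> (b, i, j, 1)

print(neighbor_mask.shape)   # torch.Size([2, 4, 4, 1])
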
@@ -788,6 +798,10 @@ def forward(
attn_intermediate = rearrange(attn_intermediate, '... 1 -> ...')
attn_logits = fn(attn_intermediate)
attn_logits = attn_logits * scale

if exists(neighbor_mask):
attn_logits = attn_logits.masked_fill(~neighbor_mask, -torch.finfo(attn_logits.dtype).max)

attn = attn_logits.softmax(dim = -2) # (batch, source, target, heads)
attentions.append(attn)

2 changes: 1 addition & 1 deletion equiformer_pytorch/version.py
@@ -1 +1 @@
__version__ = '0.3.7'
__version__ = '0.3.8'
