
Commit

make sure outermean can be masked for padding in msa rows, and also fix a bug with mask
lucidrains committed Aug 23, 2021
1 parent db44ed6 commit dda22b5
Showing 2 changed files with 19 additions and 8 deletions.
25 changes: 18 additions & 7 deletions alphafold2_pytorch/alphafold2.py
@@ -316,22 +316,32 @@ class OuterMean(nn.Module):
     def __init__(
         self,
         dim,
-        hidden_dim = None
+        hidden_dim = None,
+        eps = 1e-5
     ):
         super().__init__()
+        self.eps = eps
         self.norm = nn.LayerNorm(dim)
         hidden_dim = default(hidden_dim, dim)

         self.left_proj = nn.Linear(dim, hidden_dim)
         self.right_proj = nn.Linear(dim, hidden_dim)
         self.proj_out = nn.Linear(hidden_dim, dim)

-    def forward(self, x):
+    def forward(self, x, mask = None):
         x = self.norm(x)
         left = self.left_proj(x)
         right = self.right_proj(x)
         outer = rearrange(left, 'b m i d -> b m i () d') * rearrange(right, 'b m j d -> b m () j d')
-        outer = outer.mean(dim = 1)
+
+        if exists(mask):
+            # masked mean, if there is padding in the rows of the MSA
+            mask = rearrange(mask, 'b m i -> b m i () ()') * rearrange(mask, 'b m j -> b m () j ()')
+            outer = outer.masked_fill(~mask, 0.)
+            outer = outer.mean(dim = 1) / (mask.sum(dim = 1) + self.eps)
+        else:
+            outer = outer.mean(dim = 1)

         return self.proj_out(outer)

 class PairwiseAttentionBlock(nn.Module):
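
As a usage note (not part of the commit): the masked outer mean above can be exercised directly. The sketch below assumes OuterMean is importable from alphafold2_pytorch.alphafold2 at this version, that the mask follows the 'b m i' layout from the diff with True marking real residues, and the tensor sizes are made up.

import torch
from alphafold2_pytorch.alphafold2 import OuterMean

outer_mean = OuterMean(dim = 32)

msa_repr = torch.randn(1, 5, 64, 32)    # (batch, msa rows, seq len, dim)
msa_mask = torch.ones(1, 5, 64).bool()  # True = real residue, False = padding
msa_mask[:, 3:] = False                 # pretend the last two MSA rows are padding

pairwise_update = outer_mean(msa_repr, mask = msa_mask)
print(pairwise_update.shape)            # torch.Size([1, 64, 64, 32]); padded rows contribute nothing
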
@@ -356,10 +366,11 @@ def forward(
         self,
         x,
         mask = None,
-        msa_repr = None
+        msa_repr = None,
+        msa_mask = None
     ):
         if exists(msa_repr):
-            x = x + self.outer_mean(msa_repr)
+            x = x + self.outer_mean(msa_repr, mask = msa_mask)

         x = self.triangle_multiply_outgoing(x, mask = mask) + x
         x = self.triangle_multiply_ingoing(x, mask = mask) + x
@@ -423,7 +434,7 @@ def forward(self, inputs):

         # pairwise attention and transition

-        x = attn(x, mask = mask)
+        x = attn(x, mask = mask, msa_repr = m, msa_mask = msa_mask)
         x = ff(x) + x

         return x, m, mask, msa_mask
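
Taken together, the changes in this file thread the new msa_mask keyword from the trunk's forward pass down into OuterMean. A hedged end-to-end sketch in the style of the project README follows; the constructor hyperparameters and tensor sizes are illustrative, not prescribed by this commit.

import torch
from alphafold2_pytorch import Alphafold2

model = Alphafold2(
    dim = 256,
    depth = 2,
    heads = 8,
    dim_head = 64
)

seq = torch.randint(0, 21, (1, 128))     # primary sequence tokens
msa = torch.randint(0, 21, (1, 5, 128))  # 5 MSA rows, same length as the sequence here
mask = torch.ones_like(seq).bool()
msa_mask = torch.ones_like(msa).bool()
msa_mask[:, 3:] = False                  # rows 3 and 4 are padding and now drop out of the outer mean

distogram = model(seq, msa = msa, mask = mask, msa_mask = msa_mask)
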
@@ -695,7 +706,7 @@ def forward(

         x_left, x_right = self.to_pairwise_repr(x).chunk(2, dim = -1)
         x = rearrange(x_left, 'b i d -> b i () d') + rearrange(x_right, 'b j d-> b () j d') # create pair-wise residue embeds
-        x_mask = rearrange(mask, 'b i -> b i ()') + rearrange(mask, 'b j -> b () j') if exists(mask) else None
+        x_mask = rearrange(mask, 'b i -> b i ()') * rearrange(mask, 'b j -> b () j') if exists(mask) else None

         # add relative positional embedding

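The "bug with mask" in the commit title appears to be the hunk above: on boolean tensors, + acts as a logical OR, so a residue pair was treated as valid whenever either residue was real, while * acts as a logical AND and keeps only pairs of real residues. A small sketch of the difference, using a made-up three-residue mask:

import torch
from einops import rearrange

mask = torch.tensor([[True, True, False]])  # last residue is padding

# buggy: bool addition behaves like OR, so pairs touching the padded residue stay True
or_mask = rearrange(mask, 'b i -> b i ()') + rearrange(mask, 'b j -> b () j')

# fixed: bool multiplication behaves like AND, only pairs of real residues stay True
and_mask = rearrange(mask, 'b i -> b i ()') * rearrange(mask, 'b j -> b () j')

print(or_mask[0].int())   # all ones except the bottom-right corner
print(and_mask[0].int())  # ones only in the top-left 2x2 block
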
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'alphafold2-pytorch',
   packages = find_packages(),
-  version = '0.4.25',
+  version = '0.4.26',
   license='MIT',
   description = 'AlphaFold2 - Pytorch',
   author = 'Phil Wang, Eric Alcaide',
