
Commit 85c40d5

add an assert to make sure total sequence length of MSA or AA does not exceed the set length of the sparse attention kernel
lucidrains committed Feb 11, 2021
1 parent b0d3ab6 commit 85c40d5
Showing 2 changed files with 4 additions and 2 deletions.
alphafold2_pytorch/alphafold2.py (4 changes: 3 additions & 1 deletion)
@@ -174,13 +174,15 @@ def __init__(
 
     def forward(self, x, mask = None):
         device, h = x.device, self.heads
-        n = x.shape[1]
+        b, n = x.shape[:2]
+        assert n <= self.seq_len, f'either the AA sequence length {n} or the total MSA length {n} exceeds the allowable sequence length {self.seq_len} for sparse attention, set by `max_seq_len`'
 
         remainder = x.shape[1] % self.block_size
 
         if remainder > 0:
             padding = self.block_size - remainder
             x = F.pad(x, (0, 0, 0, padding), value = 0)
+            mask = torch.ones(b, n, device = device).bool()
             mask = F.pad(mask, (0, padding), value = False)
 
         q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
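For context, a minimal sketch of the check this commit adds, lifted out of the class for illustration (`seq_len` here stands in for the module's `self.seq_len`, and the shapes are invented):

```python
import torch

seq_len = 2048                   # length the sparse attention kernel was set up for, via `max_seq_len`
x = torch.randn(1, 4096, 256)    # (batch, total AA or MSA sequence length, dim) -- deliberately too long

b, n = x.shape[:2]
# fail loudly up front rather than erroring deep inside the sparse attention kernel
assert n <= seq_len, f'sequence length {n} exceeds the allowable sequence length {seq_len}'
```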

An inline comment on the added `mask = torch.ones(b, n, device = device).bool()` line:

nilbot (Contributor) commented on Feb 11, 2021:

    Is this overwriting mask if it's not None?
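On the question above: as committed, the `torch.ones` assignment does replace any caller-supplied mask inside the padding branch. A minimal sketch of one way to preserve it, factored into a standalone helper (the `pad_to_block` name and the free-function form are illustrative, not the repository's actual fix):

```python
import torch
import torch.nn.functional as F

def pad_to_block(x, mask, block_size):
    # pad the sequence dimension of x up to a multiple of block_size,
    # extending the boolean mask so padded positions read as invalid
    b, n = x.shape[:2]
    remainder = n % block_size
    if remainder > 0:
        padding = block_size - remainder
        x = F.pad(x, (0, 0, 0, padding), value = 0)
        if mask is None:  # only synthesize a mask when the caller gave none
            mask = torch.ones(b, n, device = x.device).bool()
        mask = F.pad(mask, (0, padding), value = False)
    return x, mask

x = torch.randn(2, 70, 64)
x, mask = pad_to_block(x, mask = None, block_size = 16)
print(x.shape, mask.shape)  # torch.Size([2, 80, 64]) torch.Size([2, 80])
```

With a caller-supplied mask, the helper only pads it instead of regenerating it.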
setup.py (2 changes: 1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'alphafold2-pytorch',
   packages = find_packages(),
-  version = '0.0.17',
+  version = '0.0.18',
   license='MIT',
   description = 'AlphaFold2 - Pytorch',
   author = 'Phil Wang, Eric Alcaide',