Skip to content

Commit

Permalink
fix template embedder
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 16, 2021
1 parent 54840c1 commit 63d410b
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 4 deletions.
4 changes: 2 additions & 2 deletions alphafold2_pytorch/alphafold2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,14 +1171,14 @@ def forward(
# todo (make efficient)

if exists(templates_sidechains):
if self.use_se3_transformer:
if self.template_embedder_type == 'se3':
t_seq = self.template_sidechain_emb(
t_seq,
templates_sidechains,
templates_coors,
mask = templates_mask
)
else:
elif self.template_embedder_type == 'en':
shape = t_seq.shape
t_seq = rearrange(t_seq, 'b t n d -> (b t) n d')
templates_coors = rearrange(templates_coors, 'b t n c -> (b t) n c')
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'alphafold2-pytorch',
packages = find_packages(),
version = '0.0.103',
version = '0.0.104',
license='MIT',
description = 'AlphaFold2 - Pytorch',
author = 'Phil Wang, Eric Alcaide',
Expand Down
35 changes: 34 additions & 1 deletion tests/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,13 @@ def test_kron_cross_attn():
)
assert True

def test_templates():
def test_templates_se3():
model = Alphafold2(
dim = 32,
depth = 2,
heads = 2,
dim_head = 32,
template_embedder_type = 'se3',
attn_types = ('full', 'intra_attn', 'seq_only')
)

Expand All @@ -138,6 +139,38 @@ def test_templates():
templates_coors = templates_coors,
templates_mask = templates_mask
)
assert True

def test_templates_en():
    """Smoke test: run a full Alphafold2 forward pass using the EGNN
    ('en') template embedder and make sure it executes without error."""
    model = Alphafold2(
        dim = 32,
        depth = 2,
        heads = 2,
        dim_head = 32,
        template_embedder_type = 'en',
        attn_types = ('full', 'intra_attn', 'seq_only')
    )

    # primary sequence (batch 2, length 16) with an all-ones mask
    primary = torch.randint(0, 21, (2, 16))
    primary_mask = torch.ones_like(primary).bool()

    # multiple sequence alignment (batch 2, 5 sequences) with mask
    alignment = torch.randint(0, 21, (2, 5, 32))
    alignment_mask = torch.ones_like(alignment).bool()

    # two templates per batch element: residue ids, 3d coordinates, mask
    tmpl_seq = torch.randint(0, 21, (2, 2, 16))
    tmpl_coors = torch.randn(2, 2, 16, 3)
    tmpl_mask = torch.ones_like(tmpl_seq).bool()

    # forward pass; only checking that it completes without raising
    distogram = model(
        primary,
        alignment,
        mask = primary_mask,
        msa_mask = alignment_mask,
        templates_seq = tmpl_seq,
        templates_coors = tmpl_coors,
        templates_mask = tmpl_mask
    )
    assert True

def test_embeddings():
model = Alphafold2(
Expand Down

0 comments on commit 63d410b

Please sign in to comment.