add ability to use EGNN for refinement
lucidrains committed May 16, 2021
1 parent 63d410b commit 07c852a
Showing 4 changed files with 46 additions and 5 deletions.
8 changes: 5 additions & 3 deletions README.md
@@ -85,7 +85,7 @@ distogram, theta, phi, omega = model(

Fabian's <a href="https://arxiv.org/abs/2102.13419">recent paper</a> suggests that iteratively feeding the coordinates back into the SE3 Transformer, with shared weights, may work. I have decided to build on this idea, even though it is still up in the air how well it actually works.
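
A minimal sketch of that recycling idea, assuming a hypothetical weight-shared `refine_block` (any equivariant layer mapping features and coordinates to updated coordinates) — not the actual structure module in this repository:

```python
from torch import nn

class IterativeRefinement(nn.Module):
    def __init__(self, refine_block, iters = 2):
        super().__init__()
        self.refine_block = refine_block  # a single equivariant block, reused every iteration (weight sharing)
        self.iters = iters

    def forward(self, feats, coords):
        for _ in range(self.iters):
            # coordinates predicted by one pass are fed straight back in
            coords = self.refine_block(feats, coords)
        return coords
```

In this repository the number of such passes appears to be exposed as `structure_module_refinement_iters` on Alphafold2 initialization.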

You can also use <a href="https://github.com/lucidrains/En-transformer">E(n)-Transformer</a> for the refinement, if you are in the experimental mood (paper just came out a week ago).
You can also use <a href="https://github.com/lucidrains/En-transformer">E(n)-Transformer</a> or <a href="https://github.com/lucidrains/egnn-pytorch">EGNN</a> for structural refinement.

```python
import torch
@@ -422,12 +422,14 @@ distogram = model(
There are two equivariant self-attention libraries that I have prepared for the purposes of replication. One is the implementation by Fabian Fuchs as detailed in a <a href="https://fabianfuchsml.github.io/alphafold2/">speculative blog post</a>. The other is from a recent paper from DeepMind, which claims their approach is better than using irreducible representations.

- <a href="https://github.com/lucidrains/se3-transformer-pytorch">SE3 Transformer</a>
- <a href="https://github.com/lucidrains/lie-transformer-pytorch">Lie Transformer</a>
- <a href="https://github.com/lucidrains/egnn-pytorch">Lie Transformer</a>

A new paper from Welling uses invariant features for E(n) equivariance, reaching SOTA and outperforming SE3 Transformer at a number of benchmarks, while being much faster. You can use this by simply setting `structure_module_type = "en"` on Alphafold2 initialization.
A <a href="https://arxiv.org/abs/2102.09844">new paper</a> from Welling uses invariant features for E(n) equivariance, reaching SOTA and outperforming SE3 Transformer on a number of benchmarks, while being much faster. You can use this by simply setting `structure_module_type = "egnn"` or `structure_module_type = "en"` on Alphafold2 initialization (see the usage sketch after the list below).

- <a href="https://github.com/lucidrains/En-transformer">E(n)-Transformer</a>

- <a href="https://github.com/lucidrains/En-transformer">EGNN</a>
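
A usage sketch for the EGNN path, mirroring the test added in this commit (all sizes are illustrative):

```python
import torch
from alphafold2_pytorch import Alphafold2

model = Alphafold2(
    dim = 256,
    depth = 2,
    heads = 2,
    dim_head = 32,
    predict_coords = True,
    structure_module_type = "egnn",  # routes refinement through EGNN_Network
    num_backbone_atoms = 3
)

seq = torch.randint(0, 21, (2, 16))
mask = torch.ones_like(seq).bool()

msa = torch.randint(0, 21, (2, 5, 32))
msa_mask = torch.ones_like(msa).bool()

coords = model(
    seq,
    msa,
    mask = mask,
    msa_mask = msa_mask
)  # predicted backbone coordinates
```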

## Testing

```bash
13 changes: 12 additions & 1 deletion alphafold2_pytorch/alphafold2.py
@@ -20,6 +20,7 @@
from se3_transformer_pytorch import SE3Transformer
from se3_transformer_pytorch.utils import torch_default_dtype, fourier_encode
from en_transformer import EnTransformer
from egnn_pytorch import EGNN_Network

from performer_pytorch import FastAttention, ProjectionUpdater

@@ -820,6 +821,7 @@ def __init__(
structure_module_refinement_iters = 2,
structure_module_knn = 2,
structure_module_adj_neighbors = 2,
structure_module_adj_dim = 4,
cross_attn_linear = False,
cross_attn_linear_projection_update_every = 1000,
cross_attn_kron_primary = False,
@@ -1048,7 +1050,7 @@ def __init__(
attend_sparse_neighbors = True,
edge_dim = edge_dim,
num_adj_degrees = structure_module_adj_neighbors,
adj_dim = 4,
adj_dim = structure_module_adj_dim,
global_feats_dim = global_feats_dim,
tie_key_values = True,
one_headed_key_values = True,
@@ -1066,6 +1068,15 @@ def __init__(
num_adj_degrees = structure_module_adj_neighbors,
adj_dim = 4
)
elif structure_module_type == 'egnn':
self.structure_module = EGNN_Network(
dim = structure_module_dim,
depth = structure_module_depth,
num_positions = max_seq_len * 14, # hard-coded as 14 since the residue-to-atom mapping is not flexible atm
edge_dim = edge_dim,
num_adj_degrees = structure_module_adj_neighbors,
adj_dim = structure_module_adj_dim,
)
else:
raise ValueError('structure module must be either "se3", "en", or "egnn" for SE3 Transformers, E(n)-Transformers, or EGNN respectively')
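
For orientation, here is a minimal standalone sketch of how `EGNN_Network` from egnn-pytorch is typically driven — the constructor arguments echo the instantiation above, but the token/position counts, the chain adjacency matrix, and the forward signature are assumptions taken from that library's documentation rather than from this commit:

```python
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,       # residue types
    num_positions = 64,    # maximum sequence length, for absolute positional embeddings
    dim = 32,
    depth = 3,
    num_adj_degrees = 2,   # embed adjacency out to 2 hops
    adj_dim = 4            # dimension of the adjacency-degree embedding
)

feats = torch.randint(0, 21, (1, 64))  # (batch, seq) token ids
coors = torch.randn(1, 64, 3)          # (batch, seq, 3) coordinates
mask = torch.ones_like(feats).bool()

# naive chain adjacency - each position connected to its immediate neighbors
idx = torch.arange(64)
adj_mat = (idx[:, None] - idx[None, :]).abs() <= 1

feats_out, coors_out = net(feats, coors, adj_mat = adj_mat, mask = mask)
```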

3 changes: 2 additions & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'alphafold2-pytorch',
packages = find_packages(),
version = '0.0.104',
version = '0.1.0',
license='MIT',
description = 'AlphaFold2 - Pytorch',
author = 'Phil Wang, Eric Alcaide',
@@ -16,6 +16,7 @@
],
install_requires=[
'einops>=0.3',
'egnn-pytorch>=0.1.10',
'En-transformer>=0.2.3',
'mdtraj>=1.8',
'numpy',
27 changes: 27 additions & 0 deletions tests/test_attention.py
@@ -471,6 +471,33 @@ def test_coords_En_backwards():
coords.sum().backward()
assert True, 'must be able to go backwards through MDS and center distogram'

def test_coords_egnn_backwards():
model = Alphafold2(
dim = 256,
depth = 2,
heads = 2,
dim_head = 32,
structure_module_type = "egnn",
predict_coords = True,
num_backbone_atoms = 3
)

seq = torch.randint(0, 21, (2, 16))
mask = torch.ones_like(seq).bool()

msa = torch.randint(0, 21, (2, 5, 32))
msa_mask = torch.ones_like(msa).bool()

coords = model(
seq,
msa,
mask = mask,
msa_mask = msa_mask
)

coords.sum().backward()
assert True, 'must be able to go backwards through MDS and center distogram'


def test_confidence_En():
model = Alphafold2(
