Skip to content

Commit

Permalink
[CLEANUP]
Browse files Browse the repository at this point in the history
  • Loading branch information
kyegomez committed May 9, 2024
1 parent c4ed607 commit b7c4f34
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 19 deletions.
2 changes: 2 additions & 0 deletions alphafold3/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from alphafold3.diffusion import GeneticDiffusionModuleBlock
from alphafold3.model import AlphaFold3

__all__ = [
"GeneticDiffusionModuleBlock",
"AlphaFold3",
]
19 changes: 1 addition & 18 deletions model.py → alphafold3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,21 +629,4 @@ def forward(
if return_embeddings is True:
return x


x = torch.randn(1, 5, 5, 64)
y = torch.randn(1, 5, 64)

model = AlphaFold3(
dim=64,
seq_len=5,
heads=8,
dim_head=64,
attn_dropout=0.0,
ff_dropout=0.0,
global_column_attn=False,
pair_former_depth=48,
num_diffusion_steps=1000,
diffusion_depth=30,
)
output = model(x, y)
print(output.shape)

2 changes: 1 addition & 1 deletion alphafold3/pairformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.utils.checkpoint import checkpoint_sequential
from typing import Tuple, Optional
import torch
from model import (
from alphafold3.model import (
FeedForward,
AxialAttention,
TriangleMultiplicativeModule,
Expand Down
20 changes: 20 additions & 0 deletions model_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch
from alphafold3 import AlphaFold3

x = torch.randn(1, 5, 5, 64)
y = torch.randn(1, 5, 64)

model = AlphaFold3(
dim=64,
seq_len=5,
heads=8,
dim_head=64,
attn_dropout=0.0,
ff_dropout=0.0,
global_column_attn=False,
pair_former_depth=48,
num_diffusion_steps=1000,
diffusion_depth=30,
)
output = model(x, y)
print(output.shape)

0 comments on commit b7c4f34

Please sign in to comment.