Skip to content

Commit

Permalink
[FEAT][Main class]
Browse files Browse the repository at this point in the history
  • Loading branch information
kyegomez committed May 9, 2024
1 parent 96ebc48 commit c4ed607
Show file tree
Hide file tree
Showing 8 changed files with 188 additions and 138 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ Need review but basically it operates on atomic coordinates.

```python
import torch
from alphafold3.diffusion import GeneticDiffusionModule
from alphafold3.diffusion import GeneticDiffusionModuleBlock

# Create an instance of the GeneticDiffusionModule
model = GeneticDiffusionModule(channels=3, training=True)
# Create an instance of the GeneticDiffusionModuleBlock
model = GeneticDiffusionModuleBlock(channels=3, training=True)

# Generate random input coordinates
input_coords = torch.randn(10, 100, 100, 3)
Expand Down
4 changes: 2 additions & 2 deletions alphafold3/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from alphafold3.diffusion import GeneticDiffusionModule
from alphafold3.diffusion import GeneticDiffusionModuleBlock

__all__ = [
"GeneticDiffusionModule",
"GeneticDiffusionModuleBlock",
]
67 changes: 64 additions & 3 deletions alphafold3/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.nn.functional as F


class GeneticDiffusionModule(nn.Module):
class GeneticDiffusionModuleBlock(nn.Module):
"""
Diffusion Module from AlphaFold 3.
Expand All @@ -22,6 +22,7 @@ def __init__(
channels: int,
num_diffusion_steps: int = 1000,
training: bool = False,
depth: int = 30,
):
"""
Initializes the DiffusionModule with the specified number of channels and diffusion steps.
Expand All @@ -30,10 +31,11 @@ def __init__(
channels (int): Number of feature channels for the input.
num_diffusion_steps (int): Number of diffusion steps (time steps in the diffusion process).
"""
super(GeneticDiffusionModule, self).__init__()
super(GeneticDiffusionModuleBlock, self).__init__()
self.channels = channels
self.num_diffusion_steps = num_diffusion_steps
self.training = training
self.depth = depth
self.noise_scale = nn.Parameter(
torch.linspace(1.0, 0.01, num_diffusion_steps)
)
Expand Down Expand Up @@ -77,9 +79,68 @@ def forward(self, x: Tensor = None, ground_truth: Tensor = None):
return noisy_x


class GeneticDiffusion(nn.Module):
"""
GeneticDiffusion module for performing genetic diffusion.
Args:
channels (int): Number of input channels.
num_diffusion_steps (int): Number of diffusion steps to perform.
training (bool): Whether the module is in training mode or not.
depth (int): Number of diffusion module blocks to stack.
Attributes:
channels (int): Number of input channels.
num_diffusion_steps (int): Number of diffusion steps to perform.
training (bool): Whether the module is in training mode or not.
depth (int): Number of diffusion module blocks to stack.
layers (nn.ModuleList): List of GeneticDiffusionModuleBlock instances.
"""

def __init__(
self,
channels: int,
num_diffusion_steps: int = 1000,
training: bool = False,
depth: int = 30,
):
super(GeneticDiffusion, self).__init__()
self.channels = channels
self.num_diffusion_steps = num_diffusion_steps
self.training = training
self.depth = depth

# Layers
self.layers = nn.ModuleList(
[
GeneticDiffusionModuleBlock(
channels, num_diffusion_steps, training, depth
)
for _ in range(depth)
]
)

def forward(self, x: Tensor = None, ground_truth: Tensor = None):
"""
Forward pass of the GeneticDiffusion module.
Args:
x (Tensor): Input tensor.
ground_truth (Tensor): Ground truth tensor.
Returns:
Tuple[Tensor, Tensor]: Output tensor and loss tensor.
"""
for layer in self.layers:
x, loss = layer(x, ground_truth)
return x, loss


# # Example usage
# if __name__ == "__main__":
# model = GeneticDiffusionModule(
# model = GeneticDiffusionModuleBlock(
# channels=3, training=True
# ) # Assuming 3D coordinates
# input_coords = torch.randn(
Expand Down
99 changes: 0 additions & 99 deletions alphafold3/pair_transition.py

This file was deleted.

2 changes: 1 addition & 1 deletion alphafold3/pairformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.utils.checkpoint import checkpoint_sequential
from typing import Tuple, Optional
import torch
from alphafold3.model import (
from model import (
FeedForward,
AxialAttention,
TriangleMultiplicativeModule,
Expand Down
6 changes: 3 additions & 3 deletions diffusion_example.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
from alphafold3.diffusion import GeneticDiffusionModule
from alphafold3.diffusion import GeneticDiffusionModuleBlock

# Create an instance of the GeneticDiffusionModule
model = GeneticDiffusionModule(channels=3, training=True)
# Create an instance of the GeneticDiffusionModuleBlock
model = GeneticDiffusionModuleBlock(channels=3, training=True)

# Generate random input coordinates
input_coords = torch.randn(10, 100, 100, 3)
Expand Down
Loading

0 comments on commit c4ed607

Please sign in to comment.