Skip to content

Commit

Permalink
[BUFG][GeneticDiffusionModule]
Browse files Browse the repository at this point in the history
  • Loading branch information
kyegomez committed May 8, 2024
1 parent 8d1d8cb commit 6f90817
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions alphafold3/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
self.channels = channels
self.num_diffusion_steps = num_diffusion_steps
self.training = training
self.noise_scale = nn.Parametr(
self.noise_scale = nn.Parameter(
torch.linspace(1.0, 0.01, num_diffusion_steps)
)
self.prediction_network = nn.Sequential(
Expand All @@ -56,19 +56,14 @@ def forward(self, x: Tensor = None, ground_truth: Tensor = None):
torch.Tensor: Output tensor of shape (batch_size, num_atoms, channels) with
denoised atom coordinates.
"""
batch_size, num_atoms, channels = x.shape
batch_size, num_nodes, num_nodes_two, num_features = x.shape
noisy_x = x.clone()

for step in range(self.num_diffusion_steps):
# Generate noise scaled by the noise level for the current step
noise_level = self.noise_scale[step]
noise = (
torch.randn(
batch_size, num_atoms, channels, device=x.device
)
* noise_level
)

noise = torch.randn_like(x) * noise_level

# Add noise to the input
noisy_x = x + noise

Expand All @@ -85,10 +80,14 @@ def forward(self, x: Tensor = None, ground_truth: Tensor = None):
# Example usage
if __name__ == "__main__":
model = GeneticDiffusionModule(
channels=3
channels=3, training=True
) # Assuming 3D coordinates
input_coords = torch.randn(
10, 100, 3
10, 100, 100, 3
) # Example with batch size 10 and 100 atoms
ground_truth = torch.randn(
10, 100, 100, 3
) # Example with batch size 10 and 100 atoms
output_coords = model(input_coords)
print(output_coords.shape) # Should be (10, 100, 3)
output_coords, loss = model(input_coords, ground_truth)
print(output_coords) # Should be (10, 100, 3)
print(loss) # Should be a scalar (MSE loss value

0 comments on commit 6f90817

Please sign in to comment.