Skip to content

Commit

Permalink
[cleanup]
Browse files Browse the repository at this point in the history
  • Loading branch information
kyegomez committed May 8, 2024
1 parent 6f90817 commit dd97234
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 15 deletions.
5 changes: 5 additions & 0 deletions alphafold3/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from alphafold3.diffusion import GeneticDiffusionModule

__all__ = [
"GeneticDiffusionModule",
]
30 changes: 15 additions & 15 deletions alphafold3/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def forward(self, x: Tensor = None, ground_truth: Tensor = None):
# Generate noise scaled by the noise level for the current step
noise_level = self.noise_scale[step]
noise = torch.randn_like(x) * noise_level

# Add noise to the input
noisy_x = x + noise

Expand All @@ -77,17 +77,17 @@ def forward(self, x: Tensor = None, ground_truth: Tensor = None):
return noisy_x


# Example usage
if __name__ == "__main__":
model = GeneticDiffusionModule(
channels=3, training=True
) # Assuming 3D coordinates
input_coords = torch.randn(
10, 100, 100, 3
) # Example with batch size 10 and 100 atoms
ground_truth = torch.randn(
10, 100, 100, 3
) # Example with batch size 10 and 100 atoms
output_coords, loss = model(input_coords, ground_truth)
print(output_coords) # Should be (10, 100, 3)
print(loss) # Should be a scalar (MSE loss value
# # Example usage
# if __name__ == "__main__":
# model = GeneticDiffusionModule(
# channels=3, training=True
# ) # Assuming 3D coordinates
# input_coords = torch.randn(
# 10, 100, 100, 3
# ) # Example with batch size 10 and 100 atoms
# ground_truth = torch.randn(
# 10, 100, 100, 3
# ) # Example with batch size 10 and 100 atoms
# output_coords, loss = model(input_coords, ground_truth)
# print(output_coords) # Should be (10, 100, 3)
# print(loss) # Should be a scalar (MSE loss value

0 comments on commit dd97234

Please sign in to comment.