# Diffusion Model for Predicting Molecular Coordinates


## Import the necessary libraries

In [44]:
import dgl
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from dgl.data.utils import split_dataset
from dgl.data import QM9EdgeDataset
import dgl.nn.pytorch as dglnn
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

from e3nn import o3
from e3nn.o3 import FullyConnectedTensorProduct
from e3nn.nn import Gate


## Load Data
Load the data from DGL. We use `QM9EdgeDataset` because we only need each molecule's 3D atom coordinates (`pos`) and its per-atom attributes (`attr`).

- TODO: split the data into train, validation, and test sets.

In [35]:
# Load the QM9 (edge) dataset, then show molecule 0's node data and the molecule count,
# one per line.
qm9_data = QM9EdgeDataset()
print(qm9_data[0][0].ndata, len(qm9_data), sep="\n")


Done loading data from cached files.
{'pos': tensor([[-1.2700e-02,  1.0858e+00,  8.0000e-03],
        [ 2.2000e-03, -6.0000e-03,  2.0000e-03],
        [ 1.0117e+00,  1.4638e+00,  3.0000e-04],
        [-5.4080e-01,  1.4475e+00, -8.7660e-01],
        [-5.2380e-01,  1.4379e+00,  9.0640e-01]]), 'attr': tensor([[0., 1., 0., 0., 0., 6., 0., 0., 0., 0., 4.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])}
130831


In [37]:
# Preprocess dataset to get graphs, node features, edge features, and 3D coordinates
def process_qm9_data(data):
    graphs, node_features, coordinates = [], [], []
    for i in range(len(data)):
        g = data[i][0]  # DGLGraph for the molecule
        coords = g.ndata['pos']  # 3D coordinates of atoms
        
        # Append graph and features
        graphs.append(g)
        node_features.append(g.ndata['attr'])
        coordinates.append(coords)
    
    return graphs, node_features, coordinates

graphs, node_features, coordinates = process_qm9_data(qm9_data)


In [38]:
print(node_features[0])

tensor([[0., 1., 0., 0., 0., 6., 0., 0., 0., 0., 4.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])


In [39]:

# Custom collate function to handle DGLGraphs in the DataLoader
def collate_fn(batch):
    graphs, node_features, coordinates = map(list, zip(*batch))
    
    # Batch graphs
    batched_graph = dgl.batch(graphs)
    
    # Concatenate node and edge features along the batch dimension
    batched_node_features = torch.cat(node_features, dim=0)
    batched_coordinates = torch.cat(coordinates, dim=0)
    
    return batched_graph, batched_node_features, batched_coordinates


In [40]:
class QM9Dataset(Dataset):
    def __init__(self, graphs, node_features, coordinates):
        self.graphs = graphs
        self.node_features = node_features
        self.coordinates = coordinates
    
    def __len__(self):
        return len(self.graphs)
    
    def __getitem__(self, idx):
        return (self.graphs[idx], self.node_features[idx], self.coordinates[idx])

# Initialize QM9 dataset and DataLoader
qm9_dataset = QM9Dataset(graphs, node_features, coordinates)


In [45]:
class SE3EquivariantLayer(nn.Module):
    def __init__(self, input_irreps, output_irreps, hidden_dim=128):
        """
        Initialize an SE(3)-equivariant layer.
        
        Args:
            input_irreps (o3.Irreps): Irreducible representations of input features.
            output_irreps (o3.Irreps): Irreducible representations of output features.
            hidden_dim (int): Hidden dimension of fully connected layers.
        """
        super().__init__()

        # Define equivariant fully connected layer
        self.fc_tp = FullyConnectedTensorProduct(input_irreps, o3.Irreps(f"{hidden_dim}x0e"), output_irreps)
        
        # Nonlinear activation - gate structure to control outputs
        scalar_irreps = o3.Irreps("0e")  # Scalar irreps
        self.gate = Gate(scalar_irreps, [nn.SiLU()], output_irreps)

    def forward(self, features, edge_vectors):
        """
        Forward pass for SE(3)-equivariant layer.
        
        Args:
            features (torch.Tensor): Input features (node embeddings).
            edge_vectors (torch.Tensor): Edge vectors in 3D space, for relative distances.

        Returns:
            torch.Tensor: SE(3)-equivariant output features.
        """
        # Step 1: Fully connected tensor product
        x = self.fc_tp(features, edge_vectors)
        
        # Step 2: Nonlinearity (gating mechanism)
        x = self.gate(x)
        
        return x


In [48]:
class EquivariantDiffusionProcess(nn.Module):
    def __init__(self, timesteps):
        super().__init__()
        self.timesteps = timesteps

    def forward_diffusion(self, coords, t):
        """
        Forward diffusion process to add noise to 3D coordinates.
        
        Args:
            coords (torch.Tensor): Original coordinates.
            t (int): Current timestep.

        Returns:
            torch.Tensor: Noisy coordinates.
        """
        noise = torch.randn_like(coords)
        alpha_t = torch.exp(-0.5 * t / self.timesteps)  # Scaling factor
        noisy_coords = alpha_t * coords + (1 - alpha_t**2).sqrt() * noise
        return noisy_coords

    def reverse_denoising(self, noisy_coords, predicted_noise, t):
        """
        Reverse diffusion to progressively denoise coordinates.
        
        Args:
            noisy_coords (torch.Tensor): Noisy coordinates at timestep t.
            predicted_noise (torch.Tensor): Noise predicted by the model.
            t (int): Current timestep.

        Returns:
            torch.Tensor: Partially denoised coordinates.
        """
        alpha_t = torch.exp(-0.5 * t / self.timesteps)
        return (noisy_coords - (1 - alpha_t**2).sqrt() * predicted_noise) / alpha_t


class EquivariantDiffusionModel(nn.Module):
    def __init__(self, node_feat_dim, hidden_dim, num_layers=4):
        super().__init__()
        self.num_layers = num_layers

        # Define irreps for input and output features
        self.input_irreps = o3.Irreps(f"{node_feat_dim}x0e")  # Scalars (atomic numbers, etc.)
        print(self.input_irreps)
        self.output_irreps = o3.Irreps("1x1o")  # Output irreps as vectors for 3D coordinates
        self.hidden_irreps = o3.Irreps(f"{hidden_dim}x0e + {hidden_dim}x1o")  # Scalars + vectors

        # Create equivariant layers
        self.layers = nn.ModuleList([
            SE3EquivariantLayer(self.input_irreps, self.hidden_irreps)
            for _ in range(self.num_layers)
        ])

        # Final layer to predict noise in coordinates
        self.final_layer = SE3EquivariantLayer(self.hidden_irreps, self.output_irreps)

    def forward(self, g, node_features, noisy_coords):
        """
        Forward pass for the diffusion model.

        Args:
            g (DGLGraph): Graph structure.
            node_features (torch.Tensor): Node features for each atom.
            noisy_coords (torch.Tensor): Noisy 3D coordinates.

        Returns:
            torch.Tensor: Predicted noise for denoising.
        """
        x = node_features
        edge_vectors = noisy_coords[g.edges()[1]] - noisy_coords[g.edges()[0]]  # Relative positions
        
        # Pass through equivariant layers
        for layer in self.layers:
            x = layer(x, edge_vectors)

        # Final prediction for noise
        predicted_noise = self.final_layer(x, edge_vectors)
        return predicted_noise


In [42]:
# Hyperparameters
epochs = 100        # Number of epochs
batch_size = 16     # Batch size
learning_rate = 1e-4  # Learning rate

dataset_size = len(qm9_dataset)
train_size = int(0.8 * dataset_size)  # 80% for training
val_size = int(0.1 * dataset_size)    # 10% for validation
test_size = dataset_size - train_size - val_size  # Remaining 10% for test

train_dataset, val_dataset, test_dataset = random_split(qm9_dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

In [49]:
# Instantiate the model and diffusion process
node_feat_dim = 11  # Adjust based on dataset
hidden_dim = 128
timesteps = 1000
model = EquivariantDiffusionModel(node_feat_dim=node_feat_dim, hidden_dim=hidden_dim)
diffusion_process = EquivariantDiffusionProcess(timesteps=timesteps)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

11x0e


TypeError: __init__() missing 2 required positional arguments: 'act_gates' and 'irreps_gated'

In [43]:
# Training loop
for epoch in range(epochs):
    print(f"Train epoch {epoch}:")
    model.train()
    epoch_loss = 0
    
    for param in model.parameters():
        param.requires_grad = True
    
    for batch in train_loader:
        g, node_features, true_coords = batch  # Unpack batch (graph, features, coordinates)
        
        # Ensure node features and true_coords have requires_grad=True
        node_features = node_features.requires_grad_(True)
        true_coords = true_coords.requires_grad_(True)
        # Sample a random timestep t
        t = torch.randint(0, diffusion_process.timesteps, (1,)).item()

        # Add noise to the true coordinates (forward diffusion process)
        noisy_coords = diffusion_process.forward_diffusion(true_coords, t)
        noisy_coords = noisy_coords.requires_grad_(True)
        
        # Predict the denoised coordinates from the noisy ones
        denoised_features, denoised_coords = model(g, node_features, noisy_coords)
        
        # Calculate the loss as the mean squared error between predicted and actual noise
        noise = noisy_coords - true_coords  # Calculate actual noise added
        loss = F.mse_loss(denoised_coords, noise)  # Denoising loss
        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    # Print epoch summary
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss / len(train_loader)}")
    
    # validation step
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            g, node_features, true_coords = batch
            
            # Sample a random timestep t for validation
            t = torch.randint(0, diffusion_process.timesteps, (1,)).item()

            noisy_coords = diffusion_process.forward_diffusion(true_coords, t)
            denoised_features, denoised_coords = model(g, node_features, noisy_coords)

            noise = noisy_coords - true_coords
            val_loss += F.mse_loss(denoised_coords, noise).item()

    print(f"Epoch {epoch + 1}/{epochs}, Validation Loss: {val_loss / len(val_loader)}")
print("Training complete.")


Train epoch 0:
Epoch 1/100, Loss: 2.970868894974415
Epoch 1/100, Validation Loss: 2.9605988322376913
Train epoch 1:
Epoch 2/100, Loss: 2.9705692175439795
Epoch 2/100, Validation Loss: 2.9605988322376913
Train epoch 2:
Epoch 3/100, Loss: 2.9706233970909737
Epoch 3/100, Validation Loss: 2.9605988322376913
Train epoch 3:
Epoch 4/100, Loss: 2.970562763477968
Epoch 4/100, Validation Loss: 2.9605988322376913
Train epoch 4:
Epoch 5/100, Loss: 2.9706947670111195
Epoch 5/100, Validation Loss: 2.9605988322376913
Train epoch 5:
Epoch 6/100, Loss: 2.9707352902611546
Epoch 6/100, Validation Loss: 2.9605988322376913
Train epoch 6:
Epoch 7/100, Loss: 2.9707057800033487
Epoch 7/100, Validation Loss: 2.9605988322376913
Train epoch 7:
Epoch 8/100, Loss: 2.970980544332261
Epoch 8/100, Validation Loss: 2.9605988322376913
Train epoch 8:
Epoch 9/100, Loss: 2.970757808535152
Epoch 9/100, Validation Loss: 2.9605988322376913
Train epoch 9:
Epoch 10/100, Loss: 2.97076145066868
Epoch 10/100, Validation Loss: 2.9

KeyboardInterrupt: 

In [41]:
# # Define the equivariant layer (you might need a custom SE(3)-equivariant layer or use an existing library for SE(3) equivariance)
# class EquivariantLayer(nn.Module):
#     def __init__(self, in_dim, out_dim):
#         super(EquivariantLayer, self).__init__()
#         self.layer = dglnn.GraphConv(in_dim, out_dim)  # Placeholder, should be SE(3) equivariant
        
#     def forward(self, g, features, coords):
#         h = self.layer(g, features)  # apply graph convolution
#         # Apply transformations on coordinates for equivariance if needed
#         return h, coords

# # Define the Denoising Model
# class EquivariantDiffusionModel(nn.Module):
#     def __init__(self, node_feat_dim, hidden_dim):
#         super(EquivariantDiffusionModel, self).__init__()
#         self.node_encoder = nn.Linear(node_feat_dim, hidden_dim)
#         self.equiv_layers = nn.ModuleList([EquivariantLayer(hidden_dim, hidden_dim) for _ in range(3)])
#         self.node_decoder = nn.Linear(hidden_dim, node_feat_dim)
    
#     def forward(self, g, node_features, coords):
#         # Encode node and edge features
#         h = F.relu(self.node_encoder(node_features))
        
#         for equiv_layer in self.equiv_layers:
#             h, coords = equiv_layer(g, h, coords)
        
#         # Decode to original feature dimension
#         denoised_features = self.node_decoder(h)
#         return denoised_features, coords

# # Sample Diffusion Process for 3D Denoising
# class DiffusionProcess:
#     def __init__(self, timesteps, beta_start=0.0001, beta_end=0.02):
#         self.timesteps = timesteps
#         self.beta = torch.linspace(beta_start, beta_end, timesteps)
    
#     def forward_diffusion(self, coords, t):
#         noise = torch.randn_like(coords)
#         return coords * (1 - self.beta[t]).sqrt() + noise * self.beta[t].sqrt()
    
#     def reverse_denoising(self, model, g, features, coords, t):
#         denoised_features, denoised_coords = model(g, features, coords)
#         # Implement denoising step based on learned prediction and noise schedule
#         return denoised_coords

