In [6]:
import numpy as np # type: ignore
import torch # type: ignore
import torch.nn as nn # type: ignore 
import pandas as pd # type: ignore
import random

In [24]:
def unsorted_segment_sum(data, segment_ids, num_segments):
    '''
    Sum the rows of a 2D tensor into buckets.

    data (E x C): rows to accumulate
    segment_ids (E,): bucket index in [0, num_segments) for every row of `data`
    num_segments: number of output rows

    Returns a (num_segments x C) tensor with the same dtype/device as `data`. Row k holds the sum
    of all data rows whose segment id is k; buckets that receive nothing stay 0.
    '''
    num_columns = data.size(1)
    # scatter_add_ needs an index with the same shape as `src`: repeat each id across the columns
    expanded_ids = segment_ids.unsqueeze(-1).expand(-1, num_columns)
    totals = data.new_full((num_segments, num_columns), 0)
    return totals.scatter_add_(dim=0, index=expanded_ids, src=data)

In [30]:
class EGCL(nn.Module):
    '''
    Equivariant Graph Convolutional Layer.

    For node coordinates r (M x 3) and node features h (M x input_node_features):
      m_ij = e_mlp([h_i, h_j, ||r_i - r_j||^2])                                  (edge messages)
      h_i' = h_mlp([h_i, sum_j m_ij])      (+ h_i when include_residual)        (feature update)
      r_i' = r_i + u_mask_i * sum_j (r_i - r_j) / (||r_i - r_j|| + 1) * r_mlp([h_i, h_j, ||r_i - r_j||^2])
    where every sum runs over the edges whose source node is i.

    The three MLPs are built once in __init__, so their parameters are registered with this module.
    As a result they are seen by the optimizer, saved in state_dict and moved by .to(device).

    include_residual=True adds h to the MLP output, which requires output_node_features == input_node_features.
    '''

    def __init__(self, input_node_features, hidden_node_features, output_node_features, activation_function=nn.SiLU(), include_residual=True, use_normalization=True):
        super(EGCL, self).__init__()
        self.radial_num_features = 1
        self.input_node_features = input_node_features
        self.hidden_node_features = hidden_node_features
        self.output_node_features = output_node_features
        self.include_residual = include_residual
        self.use_normalization = use_normalization
        self.activation_function = activation_function

        # [h_source, h_target, squared distance]
        edge_input_features = self.input_node_features * 2 + self.radial_num_features
        # edge model: edge input -> message of size hidden_node_features
        self.e_mlp = self._make_mlp(edge_input_features, self.hidden_node_features)
        # node model: [h, aggregated messages] -> updated node features
        self.h_mlp = self._make_mlp(self.input_node_features + self.hidden_node_features, self.output_node_features)
        # coordinate model: edge input -> one scalar weight per edge
        self.r_mlp = self._make_mlp(edge_input_features, 1)

    def _make_mlp(self, in_features, out_features):
        '''Build Linear -> activation -> Linear, with a BatchNorm1d before each Linear when use_normalization is set.'''
        layers = []
        if self.use_normalization:
            layers.append(nn.BatchNorm1d(in_features))
        layers += [nn.Linear(in_features, self.hidden_node_features), self.activation_function]
        if self.use_normalization:
            layers.append(nn.BatchNorm1d(self.hidden_node_features))
        layers.append(nn.Linear(self.hidden_node_features, out_features))
        return nn.Sequential(*layers)

    def coord_to_radial(self, r, edge_index):
        '''
        INPUT: r (M x 3): coordinate matrix containing 3D coordinates for each of the M atoms in the point cloud
               edge_index (2 x num_edges): source node indices (row 0) and target node indices (row 1)

        OUTPUT: coordinate_differences (num_edges x 3): r[source] - r[target] for every edge
                radials (num_edges x 1): squared distance of every edge
                radials_sqrt (num_edges x 1): Euclidean distance of every edge
        '''
        source_node_indices, target_node_indices = edge_index[0], edge_index[1]
        coordinate_differences = r[source_node_indices] - r[target_node_indices]
        # keepdim keeps the (num_edges x 1) column shape
        radials = torch.sum(coordinate_differences ** 2, dim=1, keepdim=True)
        return coordinate_differences, radials, torch.sqrt(radials)

    def forward(self, r, h, u_mask, edge_index):
        '''
        r (M x 3): coordinate matrix containing coordinates for each of the M atoms of the point cloud
        h (M x input_node_features): feature embeddings for each of the M atoms
        u_mask: multiplies the aggregated coordinate update, so it must broadcast against (M x 3).
                For example, an (M x 1) tensor holding 1 for nodes that may move and 0 for fixed nodes.
                Pass None to update every node.
        edge_index (2 x num_edges): source / target node indices of every edge

        Returns (r_updated (M x 3), h_updated (M x output_node_features)).
        '''
        source_node_indices, target_node_indices = edge_index[0], edge_index[1]
        coordinate_differences, radials, radials_sqrt = self.coord_to_radial(r, edge_index)

        #----------EDGE MODEL----------
        e_mlp_input = torch.cat([h[source_node_indices], h[target_node_indices], radials], dim=1)
        m = self.e_mlp(e_mlp_input)

        #----------NODE FEATURE MODEL----------
        # message passing: sum the messages of all edges leaving each node
        agg = unsorted_segment_sum(data=m, segment_ids=source_node_indices, num_segments=h.size(0))
        # concatenate along the feature axis (one row per node)
        h_updated = self.h_mlp(torch.cat([h, agg], dim=1))
        if self.include_residual:
            h_updated = h + h_updated

        #----------COORDINATE MODEL----------
        # Reuse the pre-update edge input. It depends on r only through squared distances,
        # so the per-edge weights do not change under rotations/translations of r.
        r_weights = self.r_mlp(e_mlp_input)
        vel_agg_components = (coordinate_differences / (radials_sqrt + 1)) * r_weights
        vel_agg = unsorted_segment_sum(data=vel_agg_components, segment_ids=source_node_indices, num_segments=r.size(0))
        if u_mask is not None:
            vel_agg = vel_agg * u_mask
        r = r + vel_agg

        return r, h_updated

        

In [31]:
class EGNN(nn.Module):
    '''
    E(n)-equivariant graph network: input embedding -> num_layers EGCL layers -> output embedding.

    Every EGCL works in hidden_node_features dimensions. The layers are registered as submodules
    named "EGCL0", "EGCL1", ..., and are applied in that order. Coordinates r pass through each
    layer; node features are embedded before the first layer and projected after the last one.
    '''

    def __init__(self, input_node_features, hidden_node_features, output_node_features, num_layers=4, use_residual=True, use_normalization=True):
        super(EGNN, self).__init__()

        self.input_embedding = nn.Linear(input_node_features, hidden_node_features)
        self.output_embedding = nn.Linear(hidden_node_features, output_node_features)
        self.num_layers = num_layers

        # every layer maps hidden -> hidden, so all layers share one configuration
        layer_config = dict(
            input_node_features=hidden_node_features,
            hidden_node_features=hidden_node_features,
            output_node_features=hidden_node_features,
            include_residual=use_residual,
            use_normalization=use_normalization,
        )
        for layer_idx in range(num_layers):
            self.add_module("EGCL" + str(layer_idx), EGCL(**layer_config))

    def forward(self, r, h, u_mask, edge_index):
        '''Return (r, h) after the embedding, all EGCL layers, and the output projection of h.'''
        h = self.input_embedding(h)
        for layer_idx in range(self.num_layers):
            layer = getattr(self, "EGCL" + str(layer_idx))
            r, h = layer(r, h, u_mask, edge_index)
        return r, self.output_embedding(h)


In [None]:
class LinkerSizeGCL(nn.Module):
    '''
    Graph convolutional layer with scalar edge attributes.

      m_ij = e_mlp([h_i, h_j, a_ij])                                      (edge messages)
      h_i' = h_mlp([h_i, sum_j m_ij])   (+ h_i when use_residual)        (feature update)
    where the sum runs over the edges whose source node is i.

    use_residual=True adds h to the MLP output, which requires output_node_features == input_node_features.
    '''

    def __init__(self, input_node_features, hidden_node_features, output_node_features, use_residual=True, use_normalization=True, activation_function=nn.ReLU()):
        super(LinkerSizeGCL, self).__init__()
        self.num_edge_attribute_features = 1
        self.use_residual = use_residual
        self.use_normalization = use_normalization

        # edge model: [h_source, h_target, a] -> message of size hidden_node_features
        self.e_mlp = self._make_mlp(
            2 * input_node_features + self.num_edge_attribute_features,
            hidden_node_features, hidden_node_features,
            use_normalization, activation_function,
        )
        # node model: [h, aggregated messages] -> updated node features
        self.h_mlp = self._make_mlp(
            input_node_features + hidden_node_features,
            hidden_node_features, output_node_features,
            use_normalization, activation_function,
        )

    @staticmethod
    def _make_mlp(in_features, hidden_features, out_features, use_normalization, activation_function):
        '''Build Linear -> activation -> Linear, with a BatchNorm1d before each Linear when use_normalization is set.'''
        layers = []
        if use_normalization:
            layers.append(nn.BatchNorm1d(in_features))
        layers += [nn.Linear(in_features, hidden_features), activation_function]
        if use_normalization:
            layers.append(nn.BatchNorm1d(hidden_features))
        layers.append(nn.Linear(hidden_features, out_features))
        return nn.Sequential(*layers)

    def forward(self, a, h, edge_index):
        '''
        a (num_edges x 1): edge attribute matrix, one column per edge (e.g. the edge distance)
        h (M x input_node_features): node features
        edge_index (2 x num_edges): source / target node indices of every edge

        Returns h_updated (M x output_node_features).
        '''
        source_node_indices, target_node_indices = edge_index[0], edge_index[1]

        e_mlp_input = torch.cat([h[source_node_indices], h[target_node_indices], a], dim=1)
        # the edge model (not the node model) turns edge inputs into messages
        m = self.e_mlp(e_mlp_input)

        # message passing: sum the messages of all edges leaving each node
        agg = unsorted_segment_sum(data=m, segment_ids=source_node_indices, num_segments=h.size(0))
        h_updated = self.h_mlp(torch.cat([h, agg], dim=1))
        if self.use_residual:
            h_updated = h + h_updated

        return h_updated


In [34]:
class LinkerSizeGNN(nn.Module):
    '''
    Stack of LinkerSizeGCL layers: input embedding -> num_layers GCLs -> output embedding.

    The layers are registered as submodules "LinkerSizeGCL0", "LinkerSizeGCL1", ... and are applied
    in that order. Every layer maps hidden_node_features -> hidden_node_features. The edge
    attributes `a` are passed unchanged to every layer.
    '''

    def __init__(self, input_node_features, hidden_node_features, output_node_features, use_normalization=True, use_residual=True, num_layers=4, activation_function=nn.ReLU()):
        super(LinkerSizeGNN, self).__init__()
        self.num_layers = num_layers
        self.input_embedding = nn.Linear(input_node_features, hidden_node_features)
        self.output_embedding = nn.Linear(hidden_node_features, output_node_features)

        gcl_config = dict(
            input_node_features=hidden_node_features,
            hidden_node_features=hidden_node_features,
            output_node_features=hidden_node_features,
            use_residual=use_residual,
            use_normalization=use_normalization,
            activation_function=activation_function,
        )
        for layer_idx in range(num_layers):
            self.add_module("LinkerSizeGCL" + str(layer_idx), LinkerSizeGCL(**gcl_config))

    def forward(self, a, h, edge_index):
        '''Embed h, run it through every GCL layer, and project it to output_node_features.'''
        hidden = self.input_embedding(h)
        for layer_idx in range(self.num_layers):
            gcl = getattr(self, "LinkerSizeGCL" + str(layer_idx))
            hidden = gcl(a, hidden, edge_index)
        return self.output_embedding(hidden)
        

In [40]:
class DiffusionNoiseModel(nn.Module):
    '''
    Wraps an EGNN so it can be applied to batched (batch_size, num_nodes, 3 + F) inputs.

    forward() appends one time column to the node features before the EGNN and drops the last
    EGNN output feature afterwards. Hence input_node_features should be F + 1, and
    output_node_features should be (desired output feature count) + 1.
    '''

    def __init__(self, input_node_features, hidden_node_features, output_node_features, num_layers=8, use_residual=True, use_normalization=True):
        super(DiffusionNoiseModel, self).__init__()
        # rh passed first (concatenation of r and h along the last dim, so rh.shape[-1] = 3 + h.shape[-1])
        self.egnn = EGNN(
            input_node_features=input_node_features,
            hidden_node_features=hidden_node_features,
            output_node_features=output_node_features,
            num_layers=num_layers,
            use_residual=use_residual,
            use_normalization=use_normalization
        )

    def forward(self, rh, u_mask, t, edge_index):
        '''
        rh (batch_size x num_nodes x (3 + F)): coordinates (first 3 columns) and features per node
        u_mask: marks nodes to be generated (nonzero) vs. fixed context nodes (0). The accepted
                shapes are (batch_size x num_nodes), (batch_size x num_nodes x 1) or the same shape
                as rh; a node counts as "to be generated" if any of its mask entries is nonzero.
        t: time value(s). Accepts a scalar, one value per graph ((batch_size,) or (batch_size x 1)),
           or one value per node ((batch_size * num_nodes) x 1).
        edge_index (2 x num_edges): must index the flattened batch_size * num_nodes nodes.

        Returns a (batch_size x num_nodes x (3 + output_node_features - 1)) tensor. Rows of
        context nodes are zero.
        NOTE(review): the first 3 output columns are the EGNN's updated coordinates, not
        r_out - r_in; confirm that this is what the training objective expects.
        '''
        batch_size, num_nodes, input_features = rh.shape[0], rh.shape[1], rh.shape[2]
        num_flat_nodes = batch_size * num_nodes
        rh_flat = rh.reshape(num_flat_nodes, input_features)

        # One mask value per node, shape (B*N, 1). It broadcasts against the EGCL's
        # (B*N, 3) coordinate update and against the output rows.
        node_mask = u_mask.reshape(num_flat_nodes, -1).to(rh.dtype).amax(dim=1, keepdim=True)

        # Repeat the per-graph time value for every node of that graph -> (B*N, 1)
        t = torch.as_tensor(t, dtype=rh.dtype, device=rh.device).reshape(-1, 1)
        if t.size(0) != num_flat_nodes:
            if t.size(0) == 1:
                t = t.expand(batch_size, 1)
            t = t.repeat_interleave(num_nodes, dim=0)

        r, h = rh_flat[:, :3], rh_flat[:, 3:]
        h = torch.cat([h, t], dim=1)

        # run the EGNN
        r, h = self.egnn(r, h, node_mask, edge_index)

        # Discard the time dimension. Context nodes are zeroed (not removed) so the
        # (batch_size, num_nodes, ...) layout is preserved.
        out = torch.cat([r, h[:, :-1]], dim=1) * node_mask
        return out.reshape(batch_size, num_nodes, -1)

In [None]:
def train(X, u_mask, diffusion_model, edge_index, num_epochs=300, learning_rate=2e-5, batch_size=128, s=1e-5, T=500):
    '''
    Train `diffusion_model` to predict the noise added to X (epsilon-prediction objective).

    X (num_examples x num_nodes x num_features): clean training examples (presumably
        [coordinates | features] per node, as DiffusionNoiseModel expects)
    u_mask: marks entries to be generated (1) vs. fixed context (0). Its shape is X's leading
        dims, optionally followed by trailing dims that broadcast against X, e.g.
        (num_examples x num_nodes), (num_examples x num_nodes x 1) or the same shape as X.
    diffusion_model: called as diffusion_model(z_t, u_mask_batch, t / T, edge_index). It must
        return a tensor shaped like the X mini-batch.
    edge_index: passed to the model unchanged.
    num_epochs: number of full passes over X.
    learning_rate: Adam learning rate.
    batch_size: maximum mini-batch size (the last batch may be smaller).
    s: precision constant of the noise schedule.
    T: number of diffusion steps; timesteps are drawn uniformly from {0, ..., T-1}.

    Returns a list with the loss (float) of every optimisation step, in order.
    '''
    num_training_examples = X.shape[0]
    losses = []
    if num_training_examples == 0:
        return losses

    # Shuffle once, then split into mini-batches. The last batch may be smaller than batch_size.
    shuffled_indices = torch.randperm(num_training_examples)
    X_batches = torch.split(X[shuffled_indices], batch_size)
    u_mask_batches = torch.split(u_mask[shuffled_indices], batch_size)

    optimizer = torch.optim.Adam(diffusion_model.parameters(), lr=learning_rate)
    diffusion_model.train()

    for _ in range(num_epochs):
        for curr_X_batch, curr_u_mask_batch in zip(X_batches, u_mask_batches):
            curr_batch_size = curr_X_batch.shape[0]
            mask = curr_u_mask_batch.to(curr_X_batch.dtype)
            while mask.dim() < curr_X_batch.dim():
                mask = mask.unsqueeze(-1)

            # One timestep per example. The model receives it normalised to [0, 1).
            t = torch.randint(low=0, high=T, size=(curr_batch_size, 1), device=curr_X_batch.device)
            t_normalized = t.to(curr_X_batch.dtype) / T

            # Variance-preserving polynomial schedule: alpha_t^2 + sigma_t^2 = 1.
            # Reshape to (B, 1, ..., 1) so it broadcasts over the node/feature axes.
            broadcast_shape = (-1,) + (1,) * (curr_X_batch.dim() - 1)
            alpha_t = ((1 - 2 * s) * (1 - t_normalized ** 2) + s).view(broadcast_shape)
            sigma_t = torch.sqrt(1 - alpha_t ** 2)

            # Only masked (to-be-generated) entries are noised; context entries stay clean.
            e_t = torch.randn_like(curr_X_batch)
            z_t = mask * (alpha_t * curr_X_batch + sigma_t * e_t) + (1 - mask) * curr_X_batch

            # the model sees the noisy sample, not the clean one
            predicted_e_t = diffusion_model(z_t, curr_u_mask_batch, t_normalized, edge_index)

            # mean squared error over the masked entries only
            squared_error = (predicted_e_t - e_t) ** 2 * mask
            loss = squared_error.sum() / mask.expand_as(e_t).sum().clamp(min=1.0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

    return losses
        

        


    