In [17]:
import copy
import glob
import os
import shutil

import numpy as np
import torch

from rdkit import Chem
from rdkit.Chem import AllChem
from tensorboardX import SummaryWriter
from test_tube import Experiment

ModuleNotFoundError: No module named 'rdkit'

In [2]:
# TODO: fix masking (true mask vs. virtual-node mask).
# TODO: use consistent variable naming throughout (e.g. original_nodes_embed vs. nodes_embed).

In [11]:
# 3D Coordinates autoencoder model
class CoordAE(torch.nn.Module):
    """Conditional VAE that predicts 3D atom coordinates X from a molecular graph G.

    Sub-networks (each built from EdgeNN -> MPNN -> LatentNN):
      * p(Z|G)      -- prior over per-node latents Z, from the graph only;
      * q(Z|R(X),G) -- posterior over Z, additionally fed the true pairwise
                       distances (``use_R``) and/or coordinates (``use_X``);
      * p(X|Z,G)    -- decoder from a posterior sample to coordinates;
      * pred X      -- decoder from a prior sample (used at test time).

    Args:
        n_max: number of (padded) atoms per molecule.
        dim_node: size of the node feature vectors.
        dim_edge: size of the edge feature vectors.
        hidden_node_dim: size of the node embeddings and of the latent Z.
        dim_f: hidden width of the LatentNN readout MLPs.
        batch_size: batch size passed on to every sub-network.
        mpnn_steps: number of message-passing iterations.
        alignment_type, tol, refine_steps, refine_mom: stored; not used by ``forward``.
        use_X: feed the true coordinates to the posterior network.
        use_R: feed the true distance matrix to the posterior network.
        virtual_node: stored flag (virtual-node masking is not implemented yet).
        seed: seed for the numpy and torch global RNGs.
    """

    def __init__(self, n_max, dim_node, dim_edge, hidden_node_dim, dim_f,
                 batch_size,
                 mpnn_steps=5, alignment_type='default', tol=1e-5,
                 use_X=True, use_R=True, virtual_node=False, seed=0,
                 refine_steps=0, refine_mom=0.99):

        super(CoordAE, self).__init__()

        # Seed both global RNGs with `seed` (torch was previously always seeded with 0).
        np.random.seed(seed)
        torch.manual_seed(seed)

        self.mpnn_steps = mpnn_steps
        self.n_max = n_max
        self.dim_node = dim_node
        self.dim_edge = dim_edge
        self.hidden_node_dim = hidden_node_dim
        self.dim_f = dim_f
        self.batch_size = batch_size
        self.alignment_type = alignment_type
        self.tol = tol
        self.virtual_node = virtual_node
        self.refine_steps = refine_steps
        self.refine_mom = refine_mom
        self.use_X = use_X
        self.use_R = use_R

        hidden = hidden_node_dim
        # MPNN feeds the aggregated messages (size `hidden`) into its GRU cell,
        # so the message size must equal the node hidden size.
        message_size = hidden_node_dim

        self.embed_nodes = EmbedNode(batch_size, n_max, dim_node, hidden)

        # p(Z|G): edge features plus the molecule's atom count (+1).
        self.edge_nn_prior_z = EdgeNN(batch_size, n_max, dim_edge + 1, hidden)
        self.mpnn_prior_z = MPNN(batch_size, n_max, hidden, message_size, mpnn_steps)
        self.latent_nn_prior_z = LatentNN(batch_size, n_max, hidden, dim_f, 2 * hidden)

        # q(Z|R(X),G): with use_R the pairwise distance is one more edge feature.
        post_edge_dim = dim_edge + 2 if use_R else dim_edge + 1
        self.edge_nn_post_z = EdgeNN(batch_size, n_max, post_edge_dim, hidden)
        if use_X:
            # Node features concatenated with the 3D coordinates.
            self.embed_nodes_pos = EmbedNode(batch_size, n_max, dim_node + 3, hidden)
        self.mpnn_post_z = MPNN(batch_size, n_max, hidden, message_size, mpnn_steps)
        self.latent_nn_post_z = LatentNN(batch_size, n_max, hidden, dim_f, 2 * hidden)

        # p(X|Z,G) from a posterior sample.
        self.edge_nn_post_x = EdgeNN(batch_size, n_max, dim_edge + 1, hidden)
        self.mpnn_post_x = MPNN(batch_size, n_max, hidden, message_size, mpnn_steps)
        self.latent_nn_post_x = LatentNN(batch_size, n_max, hidden, dim_f, 3)

        # Deterministic variant (posterior mean instead of a sample), intended for
        # iterative refinement; not used by forward() yet.
        self.edge_nn_post_x_det = EdgeNN(batch_size, n_max, dim_edge + 1, hidden)
        self.mpnn_post_x_det = MPNN(batch_size, n_max, hidden, message_size, mpnn_steps)
        self.latent_nn_post_x_det = LatentNN(batch_size, n_max, hidden, dim_f, 3)

        # X predicted from a prior sample (test time).
        self.edge_nn_pred_x = EdgeNN(batch_size, n_max, dim_edge + 1, hidden)
        self.mpnn_pred_x = MPNN(batch_size, n_max, hidden, message_size, mpnn_steps)
        self.latent_nn_pred_x = LatentNN(batch_size, n_max, hidden, dim_f, 3)

    def forward(self, nodes, edges, mask, pos, proximity):
        """Run the prior, the posterior and both coordinate decoders.

        Args:
            nodes : Tensor(batch_size, n_max, dim_node)
            edges : Tensor(batch_size, n_max, n_max, dim_edge)
            mask : Tensor(batch_size, n_max, 1), multiplicative node mask
            pos : Tensor(batch_size, n_max, 3), true coordinates (read only if use_X)
            proximity : Tensor(batch_size, n_max, n_max), distance matrix (read only if use_R)
        Returns:
            postZ_mu, postZ_lsgms : Tensor(batch_size, n_max, hidden_node_dim), posterior mean / log-variance
            priorZ_mu, priorZ_lsgms : Tensor(batch_size, n_max, hidden_node_dim), prior mean / log-variance
            X_pred : Tensor(batch_size, n_max, 3), decoded from a posterior sample
            PX_pred : Tensor(batch_size, n_max, 3), decoded from a prior sample
        """
        # TODO(virtual_node): with a virtual node the caller should pass the mask of
        # real atoms; the previous code read undefined self.mask / self.true_masks.
        hidden = self.hidden_node_dim

        nodes_embed = self.embed_nodes(nodes, mask)  # (batch_size, n_max, hidden)

        # Number of real atoms per molecule, appended to every edge feature vector.
        n_atom = mask.sum(dim=1)  # (batch_size, 1)
        tiled_n_atom = n_atom.view(-1, 1, 1, 1).expand(-1, edges.shape[1], edges.shape[2], 1)
        edge_2 = torch.cat([edges, tiled_n_atom], 3)  # (batch_size, n_max, n_max, dim_edge + 1)

        # p(Z|G) -- prior of Z
        priorZ_edge_wgt = self.edge_nn_prior_z(edge_2)
        priorZ_hidden = self.mpnn_prior_z(priorZ_edge_wgt, nodes_embed, mask)
        priorZ_out = self.latent_nn_prior_z(priorZ_hidden, nodes_embed, mask)  # (batch, n_max, 2*hidden)
        priorZ_mu, priorZ_lsgms = priorZ_out.split([hidden, hidden], 2)
        priorZ_sample = self._draw_sample(priorZ_mu, priorZ_lsgms, mask)

        # q(Z|R(X),G) -- posterior of Z (R is used instead of X as edge input for simplicity)
        if self.use_R:
            post_edges = torch.cat([edge_2, proximity.unsqueeze(3)], 3)  # (..., dim_edge + 2)
        else:
            post_edges = edge_2
        postZ_edge_wgt = self.edge_nn_post_z(post_edges)

        if self.use_X:
            post_nodes = self.embed_nodes_pos(torch.cat([nodes, pos], 2), mask)
        else:
            post_nodes = nodes_embed
        postZ_hidden = self.mpnn_post_z(postZ_edge_wgt, post_nodes, mask)
        # Use the posterior readout (the prior readout was used here by mistake).
        postZ_out = self.latent_nn_post_z(postZ_hidden, nodes_embed, mask)
        postZ_mu, postZ_lsgms = postZ_out.split([hidden, hidden], 2)
        postZ_sample = self._draw_sample(postZ_mu, postZ_lsgms, mask)

        # p(X|Z,G) -- coordinates from a posterior sample
        X_edge_wgt = self.edge_nn_post_x(edge_2)
        X_hidden = self.mpnn_post_x(X_edge_wgt, postZ_sample + nodes_embed, mask)
        X_pred = self.latent_nn_post_x(X_hidden, nodes_embed, mask)

        # Prediction of X from p(Z|G), used in the test phase
        PX_edge_wgt = self.edge_nn_pred_x(edge_2)
        PX_hidden = self.mpnn_pred_x(PX_edge_wgt, priorZ_sample + nodes_embed, mask)
        PX_pred = self.latent_nn_pred_x(PX_hidden, nodes_embed, mask)

        return postZ_mu, postZ_lsgms, priorZ_mu, priorZ_lsgms, X_pred, PX_pred

    def _draw_sample(self, mu, lsgms, mask):
        """Reparameterised sample mu + exp(lsgms / 2) * eps with eps ~ N(0, I), multiplied by mask."""
        epsilon = torch.randn_like(lsgms)
        return (mu + torch.exp(0.5 * lsgms) * epsilon) * mask

In [4]:
# TODO: EmbedNode and EdgeNN share the same two-layer MLP structure (sigmoid then tanh) and could probably be merged.

In [5]:
class EmbedNode(torch.nn.Module):
    """Two-layer MLP (sigmoid, then tanh) that embeds per-node features into the hidden space.

    Padded nodes are zeroed by multiplying the output with ``mask``.
    """

    def __init__(self, batch_size, n_max, node_dim, hidden_dim):

        super(EmbedNode, self).__init__()

        self.batch_size = batch_size
        self.n_max = n_max
        self.node_dim = node_dim
        self.hidden_dim = hidden_dim

        # `nn` is not imported in this notebook; use the fully-qualified name.
        self.FC_hidden = torch.nn.Linear(node_dim, hidden_dim)
        self.FC_output = torch.nn.Linear(hidden_dim, hidden_dim)

    def forward(self, nodes, mask):
        """
            Args :
                nodes : Tensor(batch, n_nodes, node_dim)
                mask : Tensor(batch, n_nodes, 1), multiplicative node mask
            Returns :
                nodes_embed : Tensor(batch, n_nodes, hidden_dim)
        """
        # Derive sizes from the input so batches smaller than `batch_size` also work.
        batch, n_nodes = nodes.shape[0], nodes.shape[1]
        nodes_flat = nodes.reshape(batch * n_nodes, nodes.shape[2])

        emb1 = torch.sigmoid(self.FC_hidden(nodes_flat))
        emb2 = torch.tanh(self.FC_output(emb1))

        nodes_embed = emb2.view(batch, n_nodes, self.hidden_dim)
        return torch.mul(nodes_embed, mask)

In [6]:
class EdgeNN(torch.nn.Module):
    """Maps each edge feature vector to a (hidden_dim x hidden_dim) message weight matrix.

    Two-layer MLP (sigmoid, then tanh) whose output is reshaped into a square matrix
    per (i, j) node pair; consumed by MPNN.compute_messages.
    """

    def __init__(self, batch_size, n_max, edge_dim, hidden_dim):

        super(EdgeNN, self).__init__()

        self.batch_size = batch_size
        self.n_max = n_max
        self.edge_dim = edge_dim
        self.hidden_dim = hidden_dim

        # `nn` is not imported in this notebook; use the fully-qualified name.
        self.FC_hidden = torch.nn.Linear(edge_dim, 2 * hidden_dim)
        self.FC_output = torch.nn.Linear(2 * hidden_dim, hidden_dim * hidden_dim)

    def forward(self, edges):
        """
            Args :
                edges : Tensor(batch, n_nodes, n_nodes, edge_dim)
            Returns :
                edges_embed : Tensor(batch, n_nodes, n_nodes, hidden_dim, hidden_dim)
        """
        batch, n_src, n_dst = edges.shape[0], edges.shape[1], edges.shape[2]
        # Previously read `nodes.shape[3]` (undefined name); the feature axis is edges' last axis.
        edges_flat = edges.reshape(batch * n_src * n_dst, edges.shape[3])

        emb1 = torch.sigmoid(self.FC_hidden(edges_flat))
        emb2 = torch.tanh(self.FC_output(emb1))

        return emb2.view(batch, n_src, n_dst, self.hidden_dim, self.hidden_dim)

In [7]:
class MPNN(torch.nn.Module):
    """Message-passing network with edge-conditioned linear messages and a GRU node update.

    Each step: m_j = mean_i(W_ij h_i) / n_nodes, then h_j <- GRU(m_j, h_j) * mask_j.
    """

    def __init__(self, batch_size, n_max, hidden_dim, message_size, mpnn_steps):

        super(MPNN, self).__init__()

        self.batch_size = batch_size
        self.n_max = n_max
        self.hidden_dim = hidden_dim
        self.message_size = message_size
        self.mpnn_steps = mpnn_steps

        # Messages produced by compute_messages have size hidden_dim, so message_size
        # must equal hidden_dim for the GRU input to line up.
        self.gru = torch.nn.GRUCell(input_size=message_size, hidden_size=hidden_dim)

    def forward(self, edge_wgt, nodes_embed, mask):
        """
            Args :
                edge_wgt : Tensor(batch, n_nodes, n_nodes, hidden_dim, hidden_dim)
                nodes_embed : Tensor(batch, n_nodes, hidden_dim)
                mask : Tensor(batch, n_nodes, 1)
            Returns :
                nodes_embed : Tensor(batch, n_nodes, hidden_dim)
        """
        # TODO(virtual_node): a separate "true" mask for the last step was sketched here
        # but referenced an undefined name; the node mask is applied after every step.
        for _ in range(self.mpnn_steps):
            messages = self.compute_messages(edge_wgt, nodes_embed)  # (batch, n_nodes, hidden_dim)
            nodes_embed = self.update_GRU(messages, nodes_embed, mask)

        return nodes_embed

    def update_GRU(self, messages, nodes, mask):
        """
            Args :
                messages : Tensor(batch, n_nodes, hidden_dim)
                nodes : Tensor(batch, n_nodes, hidden_dim)
                mask : Tensor(batch, n_nodes, 1)
            Returns :
                nodes_next : Tensor(batch, n_nodes, hidden_dim)
        """
        batch, n_nodes = nodes.shape[0], nodes.shape[1]
        # GRUCell takes 2-D (N, features) inputs: every node is one GRU "sample".
        messages_flat = messages.reshape(batch * n_nodes, messages.shape[2])
        nodes_flat = nodes.reshape(batch * n_nodes, self.hidden_dim)

        nodes_next = self.gru(messages_flat, nodes_flat)

        nodes_next = nodes_next.view(batch, n_nodes, self.hidden_dim)
        return torch.mul(nodes_next, mask)

    def compute_messages(self, edge_wgt, nodes):
        """
            Args :
                edge_wgt : Tensor(batch, n_nodes, n_nodes, hidden_dim, hidden_dim)
                nodes : Tensor(batch, n_nodes, hidden_dim)
            Returns :
                messages : Tensor(batch, n_nodes, hidden_dim); messages[b, j] is
                    sum_i W[b, i, j] @ nodes[b, i] / n_nodes**2
        """
        batch, n_nodes = nodes.shape[0], nodes.shape[1]
        h = self.hidden_dim

        weights = edge_wgt.reshape(batch * n_nodes, n_nodes * h, h)
        nodes_col = nodes.reshape(batch * n_nodes, h, 1)

        messages = torch.matmul(weights, nodes_col)  # W_ij @ h_i for every (i, j)
        messages = messages.view(batch, n_nodes, n_nodes, h)  # [b, i (source), j (target), :]
        messages = messages.permute(0, 2, 3, 1)  # [b, j, :, i]
        # Mean over sources, then an extra 1/n_nodes normalisation (kept from the original).
        return messages.mean(3) / n_nodes

In [8]:
class LatentNN(torch.nn.Module):
    """Per-node readout MLP applied to [nodes_embed, original_nodes_embed].

    dropout -> Linear(2*hidden_dim, dim_f) -> sigmoid -> dropout -> Linear(dim_f, outdim),
    then multiplied by the node mask.
    """

    def __init__(self, batch_size, n_max, hidden_dim, dim_f, outdim):

        super(LatentNN, self).__init__()

        self.batch_size = batch_size
        self.n_max = n_max
        self.hidden_dim = hidden_dim
        self.dim_f = dim_f
        self.outdim = outdim

        self.dropout1 = torch.nn.Dropout(0.2)
        # Input is the concatenation of two hidden_dim embeddings.
        self.FC_hidden = torch.nn.Linear(2 * hidden_dim, dim_f)
        self.dropout2 = torch.nn.Dropout(0.2)
        self.FC_output = torch.nn.Linear(dim_f, outdim)

    def forward(self, nodes_embed, original_nodes_embed, mask):
        """
            Args :
                nodes_embed : Tensor(batch, n_nodes, hidden_dim)
                original_nodes_embed : Tensor(batch, n_nodes, hidden_dim)
                mask : Tensor(batch, n_nodes, 1)
            Returns :
                nodes_out : Tensor(batch, n_nodes, outdim)
        """
        batch, n_nodes = nodes_embed.shape[0], nodes_embed.shape[1]

        nodes_cat = torch.cat([nodes_embed, original_nodes_embed], 2)
        nodes_cat = nodes_cat.reshape(batch * n_nodes, nodes_cat.shape[2])

        nodes_cat = self.dropout1(nodes_cat)
        nodes_cat = torch.sigmoid(self.FC_hidden(nodes_cat))
        nodes_cat = self.dropout2(nodes_cat)
        nodes_cat = self.FC_output(nodes_cat)

        nodes_cat = nodes_cat.view(batch, n_nodes, self.outdim)
        return torch.mul(nodes_cat, mask)

In [160]:
# TODO: prefix the internal helper methods below with an underscore.

class MSDScorer(object):
    """Alignment-invariant coordinate losses between predicted and reference conformers.

    Args:
        alignment_type: 'default' -> per-molecule MSD after optimal rotation (quaternion
            eigenvalue method); 'kabsch' -> per-molecule RMSD after Kabsch rotation;
            'linear' -> scalar MSE after the best affine (least-squares) transform.
        tol: small constant added to masked coordinates in the Kabsch path.
    Raises:
        ValueError: if ``alignment_type`` is not one of the above.
    """

    def __init__(self, alignment_type='linear', tol=1e-5):
        self.alignment_type = alignment_type
        self.tol = tol

        dispatch = {
            'linear': self.linear_transform_msd,
            'kabsch': self.kabsch_msd,
            'default': self.mol_msd,
        }
        if alignment_type not in dispatch:
            raise ValueError('unknown alignment_type: {!r}'.format(alignment_type))
        self.msd_func = dispatch[alignment_type]

    def score(self, X_pred, pos, mask=None):
        """Score X_pred (batch, n_max, 3) against pos (batch, n_max, 3).

        mask: Tensor(batch, n_max, 1); when None every atom is used.
        """
        if mask is None:
            mask = torch.ones(pos.shape[0], pos.shape[1], 1, dtype=pos.dtype, device=pos.device)
        return self.msd_func(X_pred, pos, mask)

    def kabsch_msd(self, frames, targets, masks):
        """Per-molecule RMSD after centering and Kabsch rotation; returns Tensor(batch,)."""
        losses = []
        for frame, target, mask in zip(frames, targets, masks):
            target_cent = target - self.torch_centroid_masked(target, mask)
            frame_cent = frame - self.torch_centroid_masked(frame, mask)
            losses.append(self.torch_kabsch_rmsd_masked(target_cent.detach(), frame_cent, mask))
        return torch.stack(losses)

    def optimal_rotational_quaternion(self, r):
        """Just need the largest eigenvalue of this to minimize RMSD over rotations

        References
        ----------
        [1] http://dx.doi.org/10.1002/jcc.20110
        """
        return [
            [r[0][0] + r[1][1] + r[2][2], r[1][2] - r[2][1], r[2][0] - r[0][2], r[0][1] - r[1][0]],
            [r[1][2] - r[2][1], r[0][0] - r[1][1] - r[2][2], r[0][1] + r[1][0], r[0][2] + r[2][0]],
            [r[2][0] - r[0][2], r[0][1] + r[1][0], -r[0][0] + r[1][1] - r[2][2], r[1][2] + r[2][1]],
            [r[0][1] - r[1][0], r[0][2] + r[2][0], r[1][2] + r[2][1], -r[0][0] - r[1][1] + r[2][2]],
        ]

    def squared_deviation(self, frame, target):
        """Squared deviation (n_atoms * RMSD^2) between two centered frames after optimal rotation.

        R is the cross-correlation of xyz coordinates; the largest eigenvalue of the
        4x4 quaternion matrix F built from R gives the MSD after optimal rotation.

        Parameters
        ----------
        frame, target : Tensor, shape=(n_atoms, 3)

        Returns
        -------
        sd : 0-d Tensor; divide by the number of atoms and take the square root for RMSD
        """
        R = torch.matmul(frame.T, target)
        F_parts = self.optimal_rotational_quaternion(R)
        # torch.stack (not torch.Tensor(...)) keeps F in the autograd graph.
        F = torch.stack([torch.stack(row) for row in F_parts])
        # eigvalsh returns eigenvalues in ascending order, so the last one is the largest.
        lmax = torch.linalg.eigvalsh(F)[-1]
        return torch.sum(frame ** 2 + target ** 2) - 2 * lmax

    # https://towardsdatascience.com/tensorflow-rmsd-using-tensorflow-for-things-it-was-not-designed-to-do-ada4c9aa0ea2
    # https://github.com/mdtraj/tftraj
    def mol_msd(self, frames, targets, masks):
        """Per-molecule MSD after optimal rotation; returns Tensor(batch,)."""
        n_atoms = masks.sum(1, keepdim=True)  # (batch, 1, 1)
        # Center on the centroid of the real atoms only. Out-of-place, so the caller's
        # tensors are not modified (the previous `-=` mutated X_pred and pos).
        frames = frames - (frames * masks).sum(1, keepdim=True) / n_atoms
        targets = targets - (targets * masks).sum(1, keepdim=True) / n_atoms

        sds = [self.squared_deviation(self.do_mask(frame, mask), self.do_mask(target, mask))
               for frame, target, mask in zip(frames, targets, masks)]
        return torch.stack(sds, 0) / masks.sum((1, 2))

    def linear_transform_msd(self, frames, targets, masks):
        """MSE between targets and the least-squares affine transform of frames (scalar)."""

        def linearly_transform_frames(padded_frames, padded_targets):
            # Best affine map A minimising ||F A - T||: A = pinv(F) T; return F A.
            # (The previous T pinv(F) F equals T whenever F has full column rank.)
            u, s, vh = torch.linalg.svd(padded_frames, full_matrices=False)
            keep = s > s.max() * 1e-7
            safe_s = torch.where(keep, s, torch.ones_like(s))
            s_inv = torch.where(keep, 1.0 / safe_s, torch.zeros_like(s))
            pseudo_inverse = vh.transpose(-2, -1) @ torch.diag(s_inv) @ u.transpose(-2, -1)
            weight_matrix = pseudo_inverse @ padded_targets
            return padded_frames @ weight_matrix

        # Homogeneous coordinate (constant 1 column) so the fit includes a translation.
        padded_frames = torch.nn.functional.pad(frames, (0, 1), 'constant', 1)
        padded_targets = torch.nn.functional.pad(targets, (0, 1), 'constant', 1)

        # Zero out padded atoms (row-wise scaling, same as diag(mask) @ X).
        row_mask = masks.reshape(masks.shape[0], -1, 1)
        masked_frames = padded_frames * row_mask
        masked_targets = padded_targets * row_mask

        transformed_frames = torch.stack([
            linearly_transform_frames(frame, target)
            for frame, target in zip(masked_frames, masked_targets)
        ])
        return torch.nn.MSELoss()(transformed_frames, masked_targets)

    def torch_kabsch(self, P, Q):
        """Optimal rotation matrix U such that P @ U best matches Q (Kabsch algorithm)."""
        C = torch.matmul(P.T, Q)  # covariance matrix
        V, S, W = torch.linalg.svd(C)  # C = V diag(S) W
        # Flip the last column of V on a reflection so that U is a proper rotation.
        if torch.det(V) * torch.det(W) < 0:
            flip = torch.ones(V.shape[-1], dtype=V.dtype, device=V.device)
            flip[-1] = -1.0
            V = V * flip
        return torch.matmul(V, W)

    def torch_rmsd_masked(self, V, W, N=None):
        """
        RMSD between coordinate matrices V and W, normalised by N atoms.
        Args :
            V : Tensor(n_max, 3)
            W : Tensor(n_max, 3)
            N : int or 0-d Tensor; defaults to V.shape[0]
        Returns :
            RMSD : 0-d Tensor
        """
        if N is None:
            N = V.shape[0]
        mse = ((V - W) ** 2).sum() / N
        return torch.sqrt(mse)

    def torch_kabsch_rotate(self, P, Q):
        """Rotate P onto Q with the Kabsch rotation."""
        U = self.torch_kabsch(P, Q)
        return torch.matmul(P, U)

    def torch_kabsch_rmsd_masked(self, P, Q, mask=None):
        """RMSD between P and Q after Kabsch rotation, using only atoms with mask 1."""
        N = None
        if mask is not None:
            N = mask.sum()
            row_mask = mask.reshape(-1, 1)
            # tol keeps padded rows away from exact zeros.
            P = P * row_mask + self.tol
            Q = Q * row_mask + self.tol
        P_transformed = self.torch_kabsch_rotate(P, Q)
        return self.torch_rmsd_masked(P_transformed, Q, N)

    def torch_centroid_masked(self, P, mask=None):
        """Centroid (1, 3) of the rows of P selected by mask (all rows when mask is None)."""
        if mask is None:
            return P.mean(0, keepdim=True)
        row_mask = mask.reshape(-1, 1)
        return (P * row_mask).sum(0, keepdim=True) / row_mask.sum()

    def do_mask(self, vec, mask):
        """Select the rows of vec whose mask value exceeds 0.5."""
        return vec[torch.gt(mask, 0.5).view((mask.shape[0],))]

In [None]:
def KLD(mu0, lsgm0, mu1, lsgm1, mask):
    """
    KL(N(mu0, exp(lsgm0)) || N(mu1, exp(lsgm1))) for diagonal Gaussians, summed over
    the last axis and multiplied by the node mask. lsgm* are log variances; 1e-5 is
    added inside the ratios and the log for numerical stability.

    Args :
        mu0 : Tensor(batch_size, n_max, dim_h)
        lsgm0 : Tensor(batch_size, n_max, dim_h)
        mu1 : Tensor(batch_size, n_max, dim_h)
        lsgm1 : Tensor(batch_size, n_max, dim_h)
        mask : Tensor(batch_size, n_max, 1)

    Returns :
        kld : Tensor(batch_size, n_max, 1)
    """
    eps = 1e-5
    var_p = lsgm0.exp() + eps
    var_q = lsgm1.exp() + eps

    ratio_term = var_p / var_q
    mean_term = (mu1 - mu0).square() / var_q
    log_term = (var_q / var_p + eps).log()

    per_dim = ratio_term + mean_term - 1 + log_term
    return 0.5 * per_dim.sum(dim=2, keepdim=True) * mask


def KLD_zero(mu0, lsgm0, mask):
    """
    KL(N(mu0, exp(lsgm0)) || N(0, I)), summed over the last axis and multiplied by
    the node mask. lsgm0 is a log variance.

    Args :
        mu0 : Tensor(batch_size, n_max, dim_h)
        lsgm0 : Tensor(batch_size, n_max, dim_h)
        mask : Tensor(batch_size, n_max, 1)

    Returns :
        kld : Tensor(batch_size, n_max, 1)
    """
    per_dim = (lsgm0.exp() + mu0.square()) - (1 + lsgm0)
    return 0.5 * per_dim.sum(dim=2, keepdim=True) * mask

In [None]:
# NOTE(review): leftover template code from an MNIST example. `MNIST`, `DataLoader`,
# `dataset_path`, `mnist_transform` and `kwargs` are not defined or imported in the
# visible cells, so this cell raises NameError as written.
# TODO: replace with molecular datasets whose batches unpack as
# (nodes, masks, edges, proximity, pos) -- the order used by the training loop.
# The model classes store a fixed batch_size, so the loaders should presumably use
# drop_last=True -- TODO confirm.
train_dataset = MNIST(dataset_path, transform=mnist_transform, train=True, download=True)
test_dataset  = MNIST(dataset_path, transform=mnist_transform, train=False, download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
test_loader  = DataLoader(dataset=test_dataset,  batch_size=batch_size, shuffle=True,  **kwargs)

In [None]:
# Training loop for CoordAE.
# Expects from earlier cells: CoordAE, MSDScorer, KLD, KLD_zero, `train_loader`
# and `test_loader` (batches unpack as (nodes, masks, edges, proximity, pos)),
# and `test(...)`. NOTE(review): `test` is currently defined in a later cell, so
# that cell must be run before this one.

batch_size = 20
save_path = None
train_event_path = None
valid_event_path = None
log_train_steps = 100
w_reg = 1e-3  # weight of KL(prior || N(0, I)) in the objective
debug = False
exp = None  # optional test_tube Experiment
num_epochs = 2500

# Sub-networks store batch_size, so build the model with the same batch_size used
# for the data (test() also asserts batch_size % val_num_samples == 0, i.e. % 10).
model = CoordAE(n_max=50, dim_node=7, dim_edge=4, hidden_node_dim=15, dim_f=50,
                batch_size=batch_size)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
msd_scorer = MSDScorer('default')

if exp is not None:
    data_path = exp.get_data_path(exp.name, exp.version)
    save_path = os.path.join(data_path, 'checkpoints/model.ckpt')
    event_path = os.path.join(data_path, 'event/')
    # event_path was previously computed but never used by the summary writers.
    train_event_path = os.path.join(event_path, 'train')
    valid_event_path = os.path.join(event_path, 'valid')
    print(save_path, flush=True)
    print(event_path, flush=True)

if not debug:
    train_summary_writer = SummaryWriter(train_event_path)
    valid_summary_writer = SummaryWriter(valid_event_path)

np.set_printoptions(precision=5, suppress=True)

print('::: start training')
valaggr_mean = np.zeros(num_epochs)
valaggr_std = np.zeros(num_epochs)

for epoch in range(num_epochs):

    # test() puts the model in eval mode; switch back so dropout is active while training.
    model.train()
    trnscores = np.zeros((len(train_loader), 4))

    for batch_idx, batch in enumerate(train_loader):
        nodes, masks, edges, proximity, pos = batch

        optimizer.zero_grad()
        postZ_mu, postZ_lsgms, priorZ_mu, priorZ_lsgms, X_pred, PX_pred = model(nodes, edges, masks, pos, proximity)

        cost_KLDZ = torch.mean(torch.sum(KLD(postZ_mu, postZ_lsgms, priorZ_mu, priorZ_lsgms, masks), (1, 2)))  # posterior | prior
        cost_KLD0 = torch.mean(torch.sum(KLD_zero(priorZ_mu, priorZ_lsgms, masks), (1, 2)))  # prior | N(0,1)
        cost_X = torch.mean(msd_scorer.score(X_pred, pos, masks))
        cost_op = cost_X + cost_KLDZ + w_reg * cost_KLD0

        cost_op.backward()
        optimizer.step()

        # Same column order as the TensorBoard tags below.
        trnresult = np.array([cost_op.item(), cost_X.item(), cost_KLDZ.item(), cost_KLD0.item()])

        if debug:
            print(batch_idx, len(train_loader))
            print(trnresult, flush=True)

        curr_iter = epoch * len(train_loader) + batch_idx
        if not debug and curr_iter % log_train_steps == 0:
            train_summary_writer.add_scalar("train/cost_op", trnresult[0], curr_iter)
            train_summary_writer.add_scalar("train/cost_X", trnresult[1], curr_iter)
            train_summary_writer.add_scalar("train/cost_KLDZ", trnresult[2], curr_iter)
            train_summary_writer.add_scalar("train/cost_KLD0", trnresult[3], curr_iter)

        assert np.sum(np.isnan(trnresult)) == 0
        trnscores[batch_idx, :] = trnresult

    print(np.mean(trnscores, 0), flush=True)

    exp_dict = {}
    if exp is not None:
        exp_dict['training epoch id'] = epoch
        exp_dict['train_score'] = np.mean(trnscores, 0)

    # NOTE(review): test() presumably returns scalar (mean, std) validation scores -- confirm.
    valscores_mean, valscores_std = test(test_loader, debug=debug)

    valaggr_mean[epoch] = valscores_mean
    valaggr_std[epoch] = valscores_std
    best_epoch = int(np.argmin(valaggr_mean[0:epoch + 1]))
    best_mean = valaggr_mean[best_epoch]
    best_std = valaggr_std[best_epoch]  # std at the epoch with the best mean

    if not debug:
        valid_summary_writer.add_scalar("val/valscores_mean", valscores_mean, epoch)
        valid_summary_writer.add_scalar("val/min_valscores_mean", best_mean, epoch)
        valid_summary_writer.add_scalar("val/valscores_std", valscores_std, epoch)
        valid_summary_writer.add_scalar("val/min_valscores_std", np.min(valaggr_std[0:epoch + 1]), epoch)

    print('::: training epoch id {} :: --- val mean={} , std={} ; --- best val mean={} , std={} '.format(
        epoch, valscores_mean, valscores_std, best_mean, best_std))

    if exp is not None:
        exp_dict['val mean'] = valscores_mean
        exp_dict['std'] = valscores_std
        exp_dict['best val mean'] = best_mean
        exp_dict['std of best val mean'] = best_std
        exp.log(exp_dict)
        exp.save()

    # Save the current model; when it is the best so far, also copy it to *_best.*.
    if save_path is not None and not debug:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(model.state_dict(), save_path)
        if valaggr_mean[epoch] == best_mean:
            for ckpt_f in glob.glob(save_path + '*'):
                model_path, model_name = os.path.split(ckpt_f)
                stem, _, ext = model_name.partition('.')
                best_model_name = stem + '_best.' + ext
                shutil.copyfile(ckpt_f, os.path.join(model_path, best_model_name))


In [None]:
def pos_to_proximity(pos, mask):
    """Masked pairwise distance matrix.

    Args:
        pos : Tensor(batch_size, n_max, 3) coordinates
        mask : Tensor(batch_size, n_max, 1) multiplicative atom mask
    Returns:
        proximity : Tensor(batch_size, n_max, n_max) with entry (b, i, j) equal to
            sqrt(|pos_i - pos_j|^2 + 1e-5) * mask_i * mask_j, and 0 on the diagonal.
    """
    delta = pos[:, :, None, :] - pos[:, None, :, :]  # (batch, n_max, n_max, 3)
    dist = (delta.square().sum(dim=3) + 1e-5).sqrt()

    # Zero rows and columns of masked atoms.
    dist = dist * mask * mask.transpose(1, 2)

    # The +1e-5 makes self-distances non-zero; force the diagonal back to 0.
    diag = torch.eye(dist.shape[1], dtype=torch.bool, device=dist.device)
    return dist.masked_fill(diag, 0.0)

In [None]:
def getRMSD(reference_mol, positions, useFF=False):
    """
    RMSD between the reference conformer and predicted positions after optimal alignment.

    Args :
        reference_mol : RDKit molecule holding the reference conformer (not modified)
        positions : Tensor/array(n_atom, 3) predicted coordinates, same atom order as
            reference_mol (only the first GetNumAtoms() rows are read)
        useFF : if True, first relax the prediction with MMFF; if that fails, the raw
            prediction is aligned instead
    Returns :
        float : RMSD returned by AllChem.AlignMol
    """

    def optimizeWithFF(mol):
        # MMFF needs explicit hydrogens; add them with coordinates, relax, remove again.
        mol = Chem.AddHs(mol, addCoords=True)
        AllChem.MMFFOptimizeMolecule(mol)
        return Chem.RemoveHs(mol)

    n_atom = reference_mol.GetNumAtoms()

    test_cf = Chem.rdchem.Conformer(n_atom)
    for k in range(n_atom):
        test_cf.SetAtomPosition(k, positions[k].tolist())

    # Chem.Mol(mol) makes an independent copy (no `copy` module needed).
    test_mol = Chem.Mol(reference_mol)
    test_mol.RemoveAllConformers()
    test_mol.AddConformer(test_cf, assignId=True)

    if useFF:
        try:
            test_mol = optimizeWithFF(test_mol)
        except Exception:
            # Force-field relaxation is best effort; fall back to the raw prediction.
            pass

    # AlignMol(probe, reference) moves the probe. The previous argument order moved
    # the caller's reference_mol; the RMSD value is the same either way.
    return AllChem.AlignMol(test_mol, reference_mol)

In [None]:
def test(test_loader, debug=False, savepred_path=None, savepermol=False, useFF=False):
    """Evaluate the global `model` by sampling conformers and scoring them by RMSD.

    Every molecule of a batch is repeated `val_num_samples` times, the model predicts one
    conformer per copy, the prediction is optionally refined for `model.refine_steps`
    steps, and each conformer is scored against its reference molecule with `getRMSD`.

    Relies on notebook globals: `model`, `batch_size`, `D1_v` (used for its length: one
    entry per evaluated molecule), `MS_v` (reference RDKit molecules, in loader order),
    `getRMSD` and `pos_to_proximity`.
    NOTE(review): `D1_v` / `MS_v` are not defined in the visible notebook -- presumably
    D1_val/molsup_val or D1_tst/molsup_tst; confirm.

    Args:
        test_loader: iterable of (nodes, masks, edges, proximity, pos) batches, assumed to
            yield `batch_size // val_num_samples` molecules per batch in the same order as
            `MS_v` -- TODO confirm.
        debug: print the batch index on every batch.
        savepred_path: if not None, predictions are saved here: one pickle file when
            `savepermol` is False, or a directory of per-molecule pickles when True.
        savepermol: save {'rmsd', 'pred'} per molecule instead of a single array.
        useFF: forwarded to getRMSD (MMFF-optimise predictions before aligning).

    Returns:
        (mean over molecules of the per-molecule mean RMSD,
         mean over molecules of the per-molecule RMSD std)

    Raises:
        ValueError: if `savepermol` is True and `savepred_path` is None.
    """
    import pickle as pkl

    val_num_samples = 10  # number of conformers to draw per molecule

    # The eval batch holds val_batch_size molecules x val_num_samples conformers so that
    # its total size equals the training batch size.
    assert batch_size % val_num_samples == 0
    val_batch_size = batch_size // val_num_samples  # e.g. 2
    val_size = len(D1_v)
    assert val_size % val_batch_size == 0

    if savepermol and savepred_path is None:
        raise ValueError("savepermol=True requires a savepred_path directory")

    valscores_mean = np.zeros(val_size)
    valscores_std = np.zeros(val_size)

    save_single_file = savepred_path is not None and not savepermol
    if save_single_file:
        pred_v = np.zeros((val_size, val_num_samples, model.n_max, 3))

    print ("testing model...")
    model.eval()

    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            start_ = batch_idx * val_batch_size
            end_ = start_ + val_batch_size

            nodes, masks, edges, proximity, pos = batch  # D1 D2 D3 D4 D5

            # Repeat every molecule val_num_samples times along the batch dimension. pos is
            # repeated too so that all model inputs share the same batch dimension.
            nodes, masks, edges, proximity, pos = (
                torch.repeat_interleave(t, val_num_samples, dim=0)
                for t in (nodes, masks, edges, proximity, pos)
            )

            if debug:
                print (batch_idx, len(test_loader))

            _, _, _, _, _, PX_pred = model(nodes, edges, masks, pos, proximity)

            if save_single_file:
                # NOTE(review): this path stores the prediction *before* refinement while
                # the RMSD and the per-molecule files use the refined one -- confirm intended.
                pred_v[start_:end_] = PX_pred.reshape(
                    val_batch_size, val_num_samples, model.n_max, 3).cpu().numpy()

            # Optional refinement: feed the current prediction back into the model and
            # blend the new prediction in with momentum model.refine_mom.
            X_pred = PX_pred
            for _ in range(model.refine_steps):
                if model.use_X:
                    pos = X_pred
                if model.use_R:
                    proximity = pos_to_proximity(X_pred, masks)
                _, _, _, _, last_X_pred, _ = model(nodes, edges, masks, pos, proximity)
                X_pred = model.refine_mom * X_pred + (1 - model.refine_mom) * last_X_pred

            # Sample j belongs to molecule start_ + j // val_num_samples.
            valrmsd = np.array([
                getRMSD(MS_v[start_ + j // val_num_samples], X_pred[j], useFF)
                for j in range(X_pred.shape[0])
            ]).reshape(val_batch_size, val_num_samples)

            valscores_mean[start_:end_] = np.mean(valrmsd, axis=1)
            valscores_std[start_:end_] = np.std(valrmsd, axis=1)

            # save results per molecule if requested (as NumPy arrays, like pred_v)
            if savepermol:
                pred_curr = X_pred.reshape(
                    val_batch_size, val_num_samples, model.n_max, 3).cpu().numpy()
                for tt in range(val_batch_size):
                    save_dict_tt = {'rmsd': valrmsd[tt], 'pred': pred_curr[tt]}
                    out_path = os.path.join(savepred_path, 'mol_{}_neuralnet.p'.format(tt + start_))
                    with open(out_path, 'wb') as fh:
                        pkl.dump(save_dict_tt, fh)

    print ("val scores: mean is {} , std is {}".format(np.mean(valscores_mean), np.mean(valscores_std)))
    if save_single_file:
        print ("saving neural net predictions into {}".format(savepred_path))
        with open(savepred_path, 'wb') as fh:
            pkl.dump(pred_v, fh)

    return np.mean(valscores_mean), np.mean(valscores_std)

In [None]:
# Script-only option. In the .py version this is
#     parser.add_argument('--loaddir', type=str, default=None)
# but the notebook has no argparse `parser` (that line raised NameError here),
# so the option becomes a plain variable with the same default.
loaddir = None

In [14]:
# Configuration. The .py script version reads these from argparse; in the notebook they
# are plain variables (there is no `args` namespace here).
import pickle as pkl

n_max = 50
dim_node = 35
dim_edge = 10
nval = 3000
ntst = 3000
hidden_node_dim = 50
dim_f = 100
batch_size = 20
val_num_samples = 10
model_name = 'dl4chem'
savepermol = True
savepreddir = 'savepreddir'
use_val = True
mpnn_steps = 5
alignment_type = 'kabsch'
tol = 1e-5
use_X = False
use_R = True
seed = 1334
refine_steps = 0
refine_mom = 0.99
debug = False
useFF = False
w_reg = 1e-5
log_train_steps = 100

# Data location: set DL4CHEM_DATA_DIR to override the hard-coded cluster path.
data_dir = os.environ.get('DL4CHEM_DATA_DIR', '/home/bb596/rds/hpc-work/dl4chem/')
dataset = 'COD'
COD_molset_50_path = os.path.join(data_dir, 'COD_molset_50.p')
COD_molset_all_path = os.path.join(data_dir, 'COD_molset_all.p')
COD_molvec_50_path = os.path.join(data_dir, 'COD_molvec_50.p')

# create directories to store results
ckptdir = './checkpoints/'
eventdir = './events/'
train_eventdir = os.path.join(eventdir, 'train', '')  # './events/train/'
valid_eventdir = os.path.join(eventdir, 'valid', '')  # './events/valid/'
for dir_ in (ckptdir, train_eventdir, valid_eventdir):
    os.makedirs(dir_, exist_ok=True)

save_path = os.path.join(ckptdir, model_name + '_model.ckpt')

molvec_fname = os.path.join(data_dir, dataset + '_molvec_' + str(n_max) + '.p')
molset_fname = os.path.join(data_dir, dataset + '_molset_' + str(n_max) + '.p')

# load data (pickle can execute arbitrary code: only load trusted files)
with open(molvec_fname, 'rb') as fh:
    D1, D2, D3, D4, D5 = pkl.load(fh)
# D1-D3 are densified with .todense() (presumably scipy.sparse on disk -- confirm).
D1 = D1.todense()
D2 = D2.todense()
D3 = D3.todense()

ntrn = len(D5) - nval - ntst
assert ntrn > 0, "dataset too small for nval + ntst held-out molecules"

with open(molset_fname, 'rb') as fh:
    molsup, molsmi = pkl.load(fh)


def split_trn_val_tst(data, n_trn, n_val, n_tst):
    """Split `data` along its first axis into consecutive train / val / test slices."""
    return (data[:n_trn],
            data[n_trn:n_trn + n_val],
            data[n_trn + n_val:n_trn + n_val + n_tst])


D1_trn, D1_val, D1_tst = split_trn_val_tst(D1, ntrn, nval, ntst)
D2_trn, D2_val, D2_tst = split_trn_val_tst(D2, ntrn, nval, ntst)
D3_trn, D3_val, D3_tst = split_trn_val_tst(D3, ntrn, nval, ntst)
D4_trn, D4_val, D4_tst = split_trn_val_tst(D4, ntrn, nval, ntst)
D5_trn, D5_val, D5_tst = split_trn_val_tst(D5, ntrn, nval, ntst)
molsup_trn, molsup_val, molsup_tst = split_trn_val_tst(molsup, ntrn, nval, ntst)
print ('::: num train samples is ')
print(D1_trn.shape, D3_trn.shape)

tm_trn, tm_val, tm_tst = None, None, None

del D1, D2, D3, D4, D5, molsup

if savepermol:
    savepreddir = os.path.join(savepreddir, dataset, "_val_" if use_val else "_test_")
    os.makedirs(savepreddir, exist_ok=True)

# val_num_samples is not a CoordAE constructor argument: passing it positionally bound it
# to `mpnn_steps` and then collided with the `mpnn_steps=` keyword (TypeError).
model = CoordAE(n_max, dim_node, dim_edge, hidden_node_dim, dim_f, batch_size,
                mpnn_steps=mpnn_steps, alignment_type=alignment_type, tol=tol,
                use_X=use_X, use_R=use_R, seed=seed,
                refine_steps=refine_steps, refine_mom=refine_mom)


NameError: name 'os' is not defined

In [None]:
# Training call carried over from the script version (argument names suggest: train on the
# *_trn split, validate on the *_val split).
# NOTE(review): `args` is never defined in this notebook -- the config cell holds the same
# options as plain variables (train_eventdir, valid_eventdir, log_train_steps, w_reg, debug),
# and the script's --loaddir defaults to None. `exp` (presumably a test_tube Experiment)
# is also undefined. The signature of a `CoordAE.train(D1_trn, ...)` method is not visible;
# if CoordAE does not override it, torch.nn.Module.train(mode) is what gets called -- confirm.
model.train(D1_trn, D2_trn, D3_trn, D4_trn, D5_trn, molsup_trn, \
                D1_val, D2_val, D3_val, D4_val, D5_val, molsup_val, \
                load_path=args.loaddir, save_path=save_path, \
                train_event_path=args.train_eventdir, valid_event_path=args.valid_eventdir, \
                log_train_steps=args.log_train_steps, tm_trn=tm_trn, tm_val=tm_val, \
                w_reg=args.w_reg, \
                debug=args.debug, exp=exp)

In [None]:
# Evaluation on the validation split, carried over from the script version.
# NOTE(review): `args` is never defined in this notebook (the config cell uses plain
# variables: savepreddir, savepermol, useFF, debug; --loaddir defaults to None in the
# script). This calls a `model.test(D1_v, ..., MS_v, ...)` method whose signature is not
# visible, while the free `test(test_loader, ...)` function in this notebook expects a
# data loader instead -- confirm which one is intended.
model.test(D1_val, D2_val, D3_val, D4_val, D5_val, molsup_val, \
                    load_path=args.loaddir, tm_v=tm_val, debug=args.debug, \
                    savepred_path=args.savepreddir, savepermol=args.savepermol, useFF=args.useFF)

In [None]:
# Evaluation on the held-out test split, carried over from the script version.
# NOTE(review): same issues as the validation call -- `args` is undefined in the notebook
# and `model.test(...)`'s signature is not visible here, while the free
# `test(test_loader, ...)` function expects a data loader -- confirm which one is intended.
model.test(D1_tst, D2_tst, D3_tst, D4_tst, D5_tst, molsup_tst, \
                    load_path=args.loaddir, tm_v=tm_tst, debug=args.debug, \
                    savepred_path=args.savepreddir, savepermol=args.savepermol, useFF=args.useFF)