In [1]:
import numpy as np
from typing import List, Dict

from rdkit import Chem
from scipy import sparse
import torch
from torch.utils.data import Dataset
from tap import Tap


from conformation.data_pytorch import Data
from conformation.distance_matrix import distmat_to_vec

# noinspection PyUnresolvedReferences
from torch.utils.data import DataLoader

import json
from torch.distributions.multivariate_normal import MultivariateNormal
import torch.nn as nn
from argparse import Namespace
from conformation.flows import NormalizingFlowModel

import torch.nn.functional as F

from conformation.utils import to_undirected

In [43]:
def nets2(input_dim: int, condition_dim: int, hidden_size: int) -> nn.Sequential:
    """
    RealNVP "s" neural network definition.
    :param condition_dim: Dimension of embeddings.
    :param input_dim: Data input dimension.
    :param hidden_size: Neural network hidden size.
    :return: nn.Sequential neural network.
    """
    return nn.Sequential(nn.Linear((condition_dim + 1), hidden_size), nn.LeakyReLU(), nn.Linear(hidden_size,
                                                                                                    hidden_size),
                         nn.LeakyReLU(), nn.Linear(hidden_size, input_dim), nn.Tanh())

def nett2(input_dim: int, condition_dim: int, hidden_size: int) -> nn.Sequential:
    """
    RealNVP "t" neural network definition.
    :param condition_dim: Dimension of embeddings.
    :param input_dim: Data input dimension.
    :param hidden_size: Neural network hidden size.
    :return: nn.Sequential neural network.
    """
    return nn.Sequential(nn.Linear((condition_dim + 1), hidden_size), nn.LeakyReLU(), nn.Linear(hidden_size,
                                                                                                    hidden_size),
                         nn.LeakyReLU(), nn.Linear(hidden_size, input_dim))

In [44]:
class Args(Tap):
    """
    System arguments.
    """
    data_path: str  # Path to metadata file
    num_epochs: int  # Number of training epochs
    batch_size: int = 10  # Batch size
    lr: float = 1e-4  # Learning rate
    hidden_size: int = 256  # Hidden size
    num_layers: int = 10  # Number of layers
    num_edge_features: int = 6  # Number of edge features
    final_linear_size: int = 1024  # Size of last linear layer
    num_vertex_features: int = 118  # Number of vertex features
    cuda: bool = False  # Cuda availability
    checkpoint_path: str = None  # Directory of checkpoint to load saved model
    save_dir: str  # Save directory
    log_frequency: int = 10  # Log frequency

In [45]:
class RelationalNetwork(torch.nn.Module):
    """ Relational network definition """

    def __init__(self, hidden_size=256, num_layers=32, num_edge_features=None, num_vertex_features=None,
                 final_linear_size=1024, cnf=True):
        super(RelationalNetwork, self).__init__()
        self.hidden_size = hidden_size  # Internal feature size
        self.num_layers = num_layers  # Number of relational layers
        self.num_edge_features = num_edge_features  # Number of input edge features
        self.num_vertex_features = num_vertex_features  # Number of input vertex features
        self.final_linear_size = final_linear_size  # Number of nodes in final linear layer
        self.edge_featurize = torch.nn.Linear(self.num_edge_features,
                                              self.hidden_size)  # Initial linear layer for featurization of edge feat.
        self.vertex_featurize = torch.nn.Linear(self.num_vertex_features,
                                                self.hidden_size)  # Initial layer for featurization of vertex features
        self.L_e = torch.nn.ModuleList([torch.nn.Linear(self.hidden_size, self.hidden_size) for _ in
                                        range(self.num_layers)])  # Linear layers for edges
        self.L_v = torch.nn.ModuleList([torch.nn.Linear(self.hidden_size, self.hidden_size) for _ in
                                        range(self.num_layers)])  # Linear layers for vertices
        self.edge_batch_norm = torch.nn.ModuleList(
            [torch.nn.BatchNorm1d(self.hidden_size) for _ in range(self.num_layers)])  # Batch norms for edges (\phi_e)
        self.vertex_batch_norm = torch.nn.ModuleList([torch.nn.BatchNorm1d(self.hidden_size) for _ in
                                                      range(self.num_layers)])  # Batch norms for vertices (\phi_v)
        self.gru = torch.nn.ModuleList(
            [torch.nn.GRU(self.hidden_size, self.hidden_size) for _ in range(self.num_layers)])  # GRU cells
        self.final_linear_layer = torch.nn.Linear(self.hidden_size, self.final_linear_size)  # Final linear layer
        self.output_layer = torch.nn.Linear(self.final_linear_size, 1)  # Output layer
        self.cnf = cnf

    def forward(self, batch):
        """
        Forward pass.
        :param batch: Data batch.
        :return:
        """
        e_ij_in = self.edge_featurize(batch.edge_attr)  # Featurization
        v_i_in = self.vertex_featurize(batch.x)

        for k in range(self.num_layers):
            e_ij = self.L_e[k](e_ij_in)  # Linear layer for edges
            v_i_prime = self.L_v[k](v_i_in)  # Linear layer for vertices
            e_ij_prime = F.relu(self.edge_batch_norm[k](torch.stack(
                [e_ij[edge_num] + v_i_prime[batch.edge_index[0][edge_num]] + v_i_prime[batch.edge_index[1][edge_num]]
                 for edge_num in range(
                    e_ij.size(0))])))  # Add pairwise vertex features to edge features followed by batch norm and ReLU
            undirected_edge_index = to_undirected(batch.edge_index,
                                                  batch.num_nodes)  # Full set of undirected edges for bookkeeping
            # noinspection PyTypeChecker
            v_i_e = torch.stack([torch.max(e_ij_prime[np.array([np.intersect1d(
                np.where(batch.edge_index[0] == min(vertex_num, i)),
                np.where(batch.edge_index[1] == max(vertex_num, i))) for i in np.array(
                undirected_edge_index[1][np.where(undirected_edge_index[0] == vertex_num)])]).flatten()], 0)[0] for
                                 vertex_num in
                                 range(batch.num_nodes)])  # Aggregate edge features
            gru_input = v_i_e.view(1, batch.num_nodes, self.hidden_size)  # Resize GRU input
            gru_hidden = v_i_in.view(1, batch.num_nodes, self.hidden_size)  # Resize GRU hidden
            gru_output, _ = self.gru[k](gru_input, gru_hidden)  # Compute GRU output
            v_i_c = F.relu(self.vertex_batch_norm[k](
                gru_output.view(batch.num_nodes, self.hidden_size)))  # Apply batch norm and ReLU to GRU output
            v_i_in = v_i_c + v_i_in  # Add residual connection to vertex input
            e_ij_in = e_ij_prime + e_ij_in  # Add residual connection to edge input

        e_ij_final = self.final_linear_layer(e_ij_in)  # Compute final linear layer
        preds = self.output_layer(e_ij_final)  # Output layer

        if self.cnf:
            return e_ij_in
        else:
            return preds


In [46]:
state = torch.load("model-99.pt", map_location=lambda storage, loc: storage)
loaded_args = Args().from_dict(state['args'])
loaded_state_dict = state['state_dict']

model = RelationalNetwork(loaded_args.hidden_size, loaded_args.num_layers, loaded_args.num_edge_features,
                          loaded_args.num_vertex_features, loaded_args.final_linear_size, cnf=True)
model.load_state_dict(loaded_state_dict)

<All keys matched successfully>

In [85]:
class CNF(nn.Module):
    """
    Performs a single layer of the RealNVP flow.
    """

    def __init__(self, nets: nn.Sequential, nett: nn.Sequential, mask: torch.Tensor, prior: MultivariateNormal, padding_dim: int) -> None:
        """
        :param nets: "s" neural network definition.
        :param nett: "t" neural network definition.
        :param mask: Mask identifying which components of the vector will be processed together in any given layer.
        :param prior: Base distribution.
        :return: None.
        """
        super(CNF, self).__init__()
        self.prior = prior
        self.mask = nn.Parameter(mask, requires_grad=False)
        self.t = nett
        self.s = nets
        self.padding_dim = padding_dim

    def forward(self, z: torch.Tensor, c: torch.Tensor, num: torch.Tensor) -> torch.Tensor:
        """
        Transform a sample from the base distribution or previous layer.
        :param c: Condition tensor.
        :param z: Sample from the base distribution or previous layer.
        :return: Processed sample (in the direction towards the target distribution).
        """
        if self.mask[0] == 1.0:
            mask = [torch.from_numpy(np.array([j < int(num[i]/2) for j in range(num[i].item())]).astype(np.float32)) for i in range(len(num))]
        else:
            mask = [torch.from_numpy(np.array([j >= int(num[i]/2) for j in range(num[i].item())]).astype(np.float32)) for i in range(len(num))]
        
        for i in range(len(mask)):
                padding = np.zeros(self.padding_dim)
                padding[:mask[i].shape[0]] = mask[i]
                mask[i] = padding
        mask = nn.Parameter(torch.tensor(mask, dtype=torch.float32), requires_grad=False)
        x = z
        x_ = x * mask
        c_ = c*mask.unsqueeze(2).repeat(1, 1, c.shape[2])
        combine = torch.cat((c_, x_.unsqueeze(2)), axis=2)
        combine_ = combine        
#         combine_ = combine.view(combine.shape[0], -1)
        s = self.s(combine_).sum(dim=2) * (1 - mask)
        t = self.t(combine_).sum(dim=2) * (1 - mask)
        x = x_ + (1 - mask) * (x * torch.exp(s) + t)
        return x

    def inverse(self, x: torch.Tensor, c: torch.Tensor, num: torch.Tensor) -> torch.Tensor:
        """
        Compute the inverse of a target sample or a sample from the next layer.
        :param c: Condition tensor.
        :param x: Sample from the target distribution or the next layer.
        :return: Inverse sample (in the direction towards the base distribution).
        """
        if self.mask[0] == 1.0:
            mask = [torch.from_numpy(np.array([j < int(num[i]/2) for j in range(num[i].item())]).astype(np.float32)) for i in range(len(num))]
        else:
            mask = [torch.from_numpy(np.array([j >= int(num[i]/2) for j in range(num[i].item())]).astype(np.float32)) for i in range(len(num))]

        for i in range(len(mask)):
            padding = np.zeros(self.padding_dim)
            padding[:mask[i].shape[0]] = mask[i]
            mask[i] = padding
        
        mask = nn.Parameter(torch.tensor(mask, dtype=torch.float32), requires_grad=False)
        log_det_j, z = x.new_zeros(x.shape[0]), x
        z_ = mask * z
        c_ = c*mask.unsqueeze(2).repeat(1, 1, c.shape[2])
        combine = torch.cat((c_, z_.unsqueeze(2)), axis=2)
        combine_ = combine
#         combine_ = combine.view(combine.shape[0], -1)
        s = self.s(combine_).sum(dim=2) * (1 - mask)
        t = self.t(combine_).sum(dim=2) * (1 - mask)
        z = (1 - mask) * (z - t) * torch.exp(-s) + z_
        return z

    def log_abs_det_jacobian(self, x: torch.Tensor, c: torch.Tensor, num: torch.Tensor) -> torch.Tensor:
        """
        Compute the logarithm of the absolute value of the determinant of the Jacobian for a sample in the forward
        direction.
        :param c: Condition tensor.
        :param x: Sample.
        :return: log abs det jacobian.
        """
        if self.mask[0] == 1.0:
            mask = [torch.from_numpy(np.array([j < int(num[i]/2) for j in range(num[i])]).astype(np.float32)) for i in range(len(num))]
        else:
            mask = [torch.from_numpy(np.array([j >= int(num[i]/2) for j in range(num[i])]).astype(np.float32)) for i in range(len(num))]

        for i in range(len(mask)):
            padding = np.zeros(self.padding_dim)
            padding[:mask[i].shape[0]] = mask[i]
            mask[i] = padding
        
        mask = nn.Parameter(torch.tensor(mask, dtype=torch.float32), requires_grad=False)
        log_det_j, z = x.new_zeros(x.shape[0]), x
        z_ = mask * z
        c_ = c*mask.unsqueeze(2).repeat(1, 1, c.shape[2])
        combine = torch.cat((c_, z_.unsqueeze(2)), axis=2)
        combine_ = combine
#         combine_ = combine.view(combine.shape[0], -1)
        s = self.s(combine_).sum(dim=2) * (1 - mask)
        log_det_j += s.sum(dim=1)
        return log_det_j


In [67]:
class CNFDataset(Dataset):
    """
    Dataset class for loading atomic pairwise distance information for molecules.
    """

    def __init__(self, metadata: List[Dict[str, str]], graph_model: RelationalNetwork, padding_dim: int, atom_types: List[int] = None, bond_types: List[float] = None):
        super(Dataset, self).__init__()
        self.metadata = metadata
        self.graph_model = graph_model
        if bond_types is None:
            self.bond_types = [0., 1., 1.5, 2., 3.]
        if atom_types is None:
            self.atom_types = [1, 6, 7, 8, 9]
        self.padding_dim = padding_dim

    def __len__(self) -> int:
        return len(self.metadata)

    def __getitem__(self, idx: int) -> torch.Tensor:
        _, data = distmat_to_vec(self.metadata[idx]['path'])
        smiles = self.metadata[idx]['smiles']
        dist_vec = torch.from_numpy(data)
        dist_vec = dist_vec.type(torch.float32)        
        
        data = Data()  # Create data object

        # Molecule from SMILES string
        smiles = self.metadata[idx]['smiles']  # Read smiles string
        mol = Chem.MolFromSmiles(smiles)
        mol = Chem.AddHs(mol)
        num_atoms = mol.GetNumAtoms()

        # Compute edge connectivity in COO format corresponding to a complete graph on num_nodes
        complete_graph = np.ones([num_atoms, num_atoms])  # Create an auxiliary complete graph
        complete_graph = np.triu(complete_graph, k=1)  # Compute an upper triangular matrix of the complete graph
        complete_graph = sparse.csc_matrix(complete_graph)  # Compute a csc style sparse matrix from this graph
        row, col = complete_graph.nonzero()  # Extract the row and column indices corresponding to non-zero entries
        row = torch.tensor(row, dtype=torch.long)
        col = torch.tensor(col, dtype=torch.long)
        data.edge_index = torch.stack([row, col])  # Edge connectivity in COO format (all possible edges)

        # Edge features
        # Create one-hot encoding
        one_hot_bond_features = np.zeros((len(self.bond_types), len(self.bond_types)))
        np.fill_diagonal(one_hot_bond_features, 1.)
        bond_to_one_hot = dict()
        for i in range(len(self.bond_types)):
            bond_to_one_hot[self.bond_types[i]] = one_hot_bond_features[i]

        # Extract atom indices participating in bonds and bond types
        bonds = []
        bond_types = []
        for bond in mol.GetBonds():
            bonds.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])
            bond_types.append([bond_to_one_hot[bond.GetBondTypeAsDouble()]])

        # Compute edge attributes: 1 indicates presence of bond, 0 no bond. This is concatenated with one-hot bond feat.
        full_edges = [list(data.edge_index[:, i].numpy()) for i in range(data.edge_index.shape[1])]
        no_bond = np.concatenate([np.array([0]), bond_to_one_hot[0]])
        a = np.array([1])
        edge_attr = [np.concatenate([a, bond_types[bonds.index(full_edges[i])][0]]) if full_edges[i] in bonds else
                     no_bond for i in range(len(full_edges))]
        data.edge_attr = torch.tensor(edge_attr, dtype=torch.float)

        # Vertex features: one-hot representation of atomic number
        # Create one-hot encoding
        one_hot_vertex_features = np.zeros((len(self.atom_types), len(self.atom_types)))
        np.fill_diagonal(one_hot_vertex_features, 1.)
        atom_to_one_hot = dict()
        for i in range(len(self.atom_types)):
            atom_to_one_hot[self.atom_types[i]] = one_hot_vertex_features[i]

        # one_hot_vertex_features = np.zeros((self.max_atomic_num, self.max_atomic_num))
        # np.fill_diagonal(one_hot_vertex_features, 1.)
        one_hot_features = np.array([atom_to_one_hot[atom.GetAtomicNum()] for atom in mol.GetAtoms()])
        data.x = torch.tensor(one_hot_features, dtype=torch.float) 
        
        condition = self.graph_model(data).squeeze(1)
        padding = torch.zeros([self.padding_dim, 256])
        padding[0:condition.shape[0], :] = condition
        condition = padding

        num_dist = torch.tensor(dist_vec.shape[0])
        
        padding = torch.zeros(self.padding_dim)
        padding[:dist_vec.shape[0]] = dist_vec
        dist_vec = padding

        return dist_vec, condition, num_dist 

    def __repr__(self) -> str:
        return '{}({})'.format(self.__class__.__name__, len(self))

In [57]:
state = torch.load("model-99.pt", map_location=lambda storage, loc: storage)
loaded_args = Args().from_dict(state['args'])
loaded_state_dict = state['state_dict']

model = RelationalNetwork(loaded_args.hidden_size, loaded_args.num_layers, loaded_args.num_edge_features,
                          loaded_args.num_vertex_features, loaded_args.final_linear_size)
model.load_state_dict(loaded_state_dict)

<All keys matched successfully>

In [58]:
# Load datasets
metadata = json.load(open("metadata/metadata.json"))
train_data = CNFDataset(metadata, model, 528)
train_data = DataLoader(train_data, 10)

In [59]:
train_data = iter(train_data)

In [60]:
data = next(train_data)
x, c, num = data[0], data[1], data[2]

In [86]:
# Define the base distribution
base_dist = MultivariateNormal(torch.zeros(528), torch.eye(528))
flow = CNF(nets2(528, 256, 256), nett2(528, 256, 256), torch.from_numpy(np.array([j >= int(28/2) for j in
                                                             range(28)]).astype(np.float32)), base_dist, 528)

In [87]:
z = flow(x, c, num)
inv = flow.inverse(z, c, num)
log_abs_det = flow.log_abs_det_jacobian(x, c, num)

In [15]:
x, inv, log_abs_det

(tensor([[1.4769, 1.0819, 1.0874,  ..., 0.0000, 0.0000, 0.0000],
         [1.5086, 2.4952, 2.4569,  ..., 1.7383, 4.4792, 4.6842],
         [1.4329, 1.0972, 1.1054,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [1.5001, 2.5477, 2.4571,  ..., 1.7705, 3.7356, 4.5172],
         [1.4757, 1.0996, 1.1901,  ..., 0.0000, 0.0000, 0.0000],
         [1.5200, 2.4217, 2.5123,  ..., 1.7269, 4.4235, 4.3279]]),
 tensor([[1.4769, 1.0819, 1.0874,  ..., 0.0000, 0.0000, 0.0000],
         [1.5086, 2.4952, 2.4569,  ..., 1.7383, 4.4792, 4.6842],
         [1.4329, 1.0972, 1.1054,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [1.5001, 2.5477, 2.4571,  ..., 1.7705, 3.7356, 4.5172],
         [1.4757, 1.0996, 1.1901,  ..., 0.0000, 0.0000, 0.0000],
         [1.5200, 2.4217, 2.5123,  ..., 1.7269, 4.4235, 4.3279]],
        grad_fn=<AddBackward0>),
 tensor([-1.5182,  7.1061, -1.5184,  7.0972, -1.5195,  7.1051, -1.5184,  7.0776,
         -1.5242,  7.1318], grad_fn=<AddBackward0>))

In [19]:
s = nets2(528, 256, 256)
t = nett2(528, 256, 256)

mask = [torch.from_numpy(np.array([j < int(num[i]/2) for j in range(num[i].item())]).astype(np.float32)) for i in range(len(num))]

for i in range(len(mask)):
        padding = np.zeros(528)
        padding[:mask[i].shape[0]] = mask[i]
        mask[i] = padding
mask = nn.Parameter(torch.tensor(mask, dtype=torch.float32), requires_grad=False)
x = z
x_ = x * mask
c_ = c*mask.unsqueeze(2).repeat(1, 1, c.shape[2])
combine = torch.cat((c_, x_.unsqueeze(2)), axis=2)
combine_ = combine.view(combine.shape[0], -1)
s_ = s(combine_) * (1 - mask)
t_ = t(combine_) * (1 - mask)
x = x_ + (1 - mask) * (x * torch.exp(s_) + t_)
x

tensor([[ 1.7815,  1.1466,  1.2161,  ..., -0.0106,  0.0567, -0.0506],
        [ 1.8438,  1.9414,  2.6195,  ...,  2.0657,  2.5292,  4.2434],
        [ 1.7313,  1.1643,  1.2346,  ..., -0.0105,  0.0567, -0.0507],
        ...,
        [ 1.8328,  1.9849,  2.6165,  ...,  2.0976,  2.0569,  4.0927],
        [ 1.7802,  1.1669,  1.3224,  ..., -0.0108,  0.0569, -0.0506],
        [ 1.8476,  1.8752,  2.6724,  ...,  2.0518,  2.4952,  3.9216]],
       grad_fn=<AddBackward0>)

In [20]:
combine.shape

torch.Size([10, 528, 257])

In [21]:
x.shape

torch.Size([10, 528])

In [26]:
test = torch.nn.Linear(257, 256)
test(combine).sum(dim=2).shape

torch.Size([10, 528])