In [15]:
import pandas as pd
from tqdm import tqdm
import numpy as np
import os
from scipy.spatial import distance_matrix

import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.utils import shuffle
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

FOLDER_DIR = "./preprocessed/"

### Step 1 : create dataset

The goal is to create 5 numpy arrays : nodes_train, in_edges_train, out_edges_train, nodes_test, in_edges_test

In [4]:
train_df = pd.read_csv(os.path.join(FOLDER_DIR,'train_df.csv'))
test_df = pd.read_csv(os.path.join(FOLDER_DIR,'test_df.csv'))

train_structures_df = pd.read_csv(os.path.join(FOLDER_DIR,'train_structures_df.csv'))
test_structures_df = pd.read_csv(os.path.join(FOLDER_DIR,'test_structures_df.csv'))

# train_bonds and test_bonds come from BondFeatures.ipynb
train_bonds = pd.read_csv(os.path.join(FOLDER_DIR,'train_bonds.csv'))
test_bonds = pd.read_csv(os.path.join(FOLDER_DIR,'test_bonds.csv'))

In [5]:
train_df["molecule_index"] = pd.factorize(train_df["molecule_name"])[0]
test_df["molecule_index"] = pd.factorize(test_df["molecule_name"])[0]
train_df

Unnamed: 0,id,molecule_name,atom_index_0,atom_index_1,scalar_coupling_constant,x0,y0,z0,x1,y1,...,dist_z,1JHC,1JHN,2JHC,2JHH,2JHN,3JHC,3JHH,3JHN,molecule_index
0,0,dsgdb9nsd_000001,1,0,84.80760,0.002150,-0.006031,0.001976,-0.012698,1.085804,...,0.006025,True,False,False,False,False,False,False,False,0
1,1,dsgdb9nsd_000001,1,2,-11.25700,0.002150,-0.006031,0.001976,1.011731,1.463751,...,0.001700,False,False,False,True,False,False,False,False,0
2,2,dsgdb9nsd_000001,1,3,-11.25480,0.002150,-0.006031,0.001976,-0.540815,1.447527,...,0.878620,False,False,False,True,False,False,False,False,0
3,3,dsgdb9nsd_000001,1,4,-11.25430,0.002150,-0.006031,0.001976,-0.523814,1.437933,...,0.904421,False,False,False,True,False,False,False,False,0
4,4,dsgdb9nsd_000001,2,0,84.80740,1.011731,1.463751,0.000277,-0.012698,1.085804,...,0.007724,True,False,False,False,False,False,False,False,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3724006,4658993,dsgdb9nsd_133883,16,4,6.18566,-0.254709,0.115179,-2.320941,1.396720,-0.764049,...,2.766766,False,False,False,False,False,True,False,False,68008
3724007,4658994,dsgdb9nsd_133883,16,5,5.27455,-0.254709,0.115179,-2.320941,0.709493,0.478235,...,3.282113,False,False,False,False,False,True,False,False,68008
3724008,4658995,dsgdb9nsd_133883,16,6,1.52689,-0.254709,0.115179,-2.320941,-0.302582,0.693610,...,2.139460,False,False,True,False,False,False,False,False,68008
3724009,4658996,dsgdb9nsd_133883,16,7,92.46210,-0.254709,0.115179,-2.320941,0.402259,0.359544,...,0.838362,True,False,False,False,False,False,False,False,68008


In [6]:
train_structures_df["molecule_index"] = pd.factorize(train_structures_df["molecule_name"])[0]
test_structures_df["molecule_index"] = pd.factorize(test_structures_df["molecule_name"])[0]
train_structures_df

Unnamed: 0,molecule_name,atom_index,x,y,z,C,F,H,N,O,molecule_index
0,dsgdb9nsd_000001,0,-0.012698,1.085804,0.008001,True,False,False,False,False,0
1,dsgdb9nsd_000001,1,0.002150,-0.006031,0.001976,False,False,True,False,False,0
2,dsgdb9nsd_000001,2,1.011731,1.463751,0.000277,False,False,True,False,False,0
3,dsgdb9nsd_000001,3,-0.540815,1.447527,-0.876644,False,False,True,False,False,0
4,dsgdb9nsd_000001,4,-0.523814,1.437933,0.906397,False,False,True,False,False,0
...,...,...,...,...,...,...,...,...,...,...,...
1226160,dsgdb9nsd_133883,12,0.167157,-2.642346,0.003546,False,False,True,False,False,68008
1226161,dsgdb9nsd_133883,13,2.336668,-1.165247,0.799579,False,False,True,False,False,68008
1226162,dsgdb9nsd_133883,14,1.287517,1.303344,1.376396,False,False,True,False,False,68008
1226163,dsgdb9nsd_133883,15,1.160599,1.078773,-1.801647,False,False,True,False,False,68008


In [7]:
train_bonds["molecule_index"] = pd.factorize(train_bonds["molecule_name"])[0]
test_bonds["molecule_index"] = pd.factorize(test_bonds["molecule_name"])[0]
train_bonds[['nbond_1', 'nbond_1.5', 'nbond_2', 'nbond_3']] = pd.get_dummies(train_bonds['nbond'])
test_bonds[['nbond_1', 'nbond_1.5', 'nbond_2', 'nbond_3']] = pd.get_dummies(test_bonds['nbond'])
train_bonds

Unnamed: 0,molecule_name,atom_index_0,atom_index_1,nbond,L2dist,error,bond_type,molecule_index,nbond_1,nbond_1.5,nbond_2,nbond_3
0,dsgdb9nsd_000001,0,1,1.0,1.091953,0,1.0CH,0,True,False,False,False
1,dsgdb9nsd_000001,0,2,1.0,1.091952,0,1.0CH,0,True,False,False,False
2,dsgdb9nsd_000001,0,3,1.0,1.091946,0,1.0CH,0,True,False,False,False
3,dsgdb9nsd_000001,0,4,1.0,1.091948,0,1.0CH,0,True,False,False,False
4,dsgdb9nsd_000002,0,1,1.0,1.017190,0,1.0HN,1,True,False,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...
1268463,dsgdb9nsd_133883,2,6,1.0,1.541542,0,1.0CC,68008,True,False,False,False
1268464,dsgdb9nsd_133883,3,4,1.0,1.482432,0,1.0CC,68008,True,False,False,False
1268465,dsgdb9nsd_133883,4,5,1.0,1.510342,0,1.0CC,68008,True,False,False,False
1268466,dsgdb9nsd_133883,5,6,1.0,1.541538,0,1.0CC,68008,True,False,False,False


First thing is to find the size of the ajdacency matrix which is the biggest number of atoms in a molecule in the dataset (size of the biggest molecule).

In [8]:
max_size_train = max(train_df.groupby('molecule_name')['atom_index_0'].max())
max_size_test = max(test_df.groupby('molecule_name')['atom_index_0'].max())

max_size = max(max_size_train, max_size_test) + 1 # We are given indexes so that goes from 0 to max_size_train or max_size_test
print(max_size)

29


This means that :

nodes_train.size = [nb_molecule_train, max_size, nb_features_nodes] = [68009, 29, 8]

nodes_test.size = [nb_molecule_test, max_size, nb_features_nodes] = [17003, 29, 8]

in_edges_train.size = [nb_molecule_train, max_size, max_size, nb_features_edges] = [68009, 29, 29, 16]

in_edges_test.size = [nb_molecule_test, max_size, max_size, nb_features_edges] = [17003, 29, 29, 16]

out_edges_train.size = [nb_molecule_train, max_size, max_size, 1] = [68009, 29, 29, 1]

Because the features for the nodes are : the atome, its position (x,y,z).
And the features for the edges are : the distance, dist_x, dist_y, dist_z, the type of the coupling. 

In [9]:
n_train = train_df['molecule_name'].nunique()
n_test = test_df['molecule_name'].nunique()
n_train, n_test

(68009, 17003)

In [46]:
def make_nodes(train_structures_df, test_structures_df):
    nodes_train = np.zeros((n_train, max_size, 8), dtype=np.float32)
    nodes_test = np.zeros((n_test, max_size, 8), dtype=np.float32)

    for df, nodes in zip([train_structures_df, test_structures_df], [nodes_train, nodes_test]):
        molecule_indices = df["molecule_index"].values
        atom_indices = df["atom_index"].values
        features = df[["x", "y", "z", "C", "F", "H", "N", "O"]].values

        nodes[molecule_indices, atom_indices] = features
 
    return nodes_train, nodes_test
    
def make_in_edges(train_df, test_df, train_structures_df, test_structures_df, train_bonds, test_bonds):
    in_edges_train = np.zeros((n_train, max_size, max_size, 16), dtype=np.float32)
    in_edges_test = np.zeros((n_test, max_size, max_size, 16), dtype=np.float32)

    # First, iterate through train_df and test_df
    for df, in_edges in zip([train_df, test_df], [in_edges_train, in_edges_test]):
        molecule_indices = df["molecule_index"].values
        atom_indices_0 = df["atom_index_0"].values
        atom_indices_1 = df["atom_index_1"].values
        features = df[["dist", "dist_x", "dist_y", "dist_z", '1JHC', '1JHN', '2JHC', '2JHH', '2JHN', '3JHC', '3JHH', '3JHN']].values

        in_edges[molecule_indices, atom_indices_0, atom_indices_1,:12] = features
        in_edges[molecule_indices, atom_indices_1, atom_indices_0,:12] = features

    # Then, iterate through train_structures_df and test_structures_df to complete the adjency matrix
    for df, in_edges in zip([train_structures_df, test_structures_df], [in_edges_train, in_edges_test]):
        for molecule_index, molecule_df in df.groupby("molecule_index"):
            features = np.zeros((max_size, max_size, 4))

            for i, coords_df in enumerate([molecule_df[["x", "y", "z"]], molecule_df["x"], molecule_df["y"], molecule_df["z"]]):
                coords = coords_df.values

                if i != 0:
                    coords = coords.reshape((len(coords), 1)) # Converts the 1D array into a 2D matrix

                dist = distance_matrix(coords, coords)
                features[:dist.shape[0], :dist.shape[1], i] = dist

            in_edges[molecule_index, :, :, :4] = features

    # Finally, add the bond features 
    for df, in_edges in zip([train_bonds, test_bonds], [in_edges_train, in_edges_test]):
        molecule_indices = df["molecule_index"].values
        atom_indices_0 = df["atom_index_0"].values
        atom_indices_1 = df["atom_index_1"].values
        features = df[['nbond_1', 'nbond_1.5', 'nbond_2', 'nbond_3']].values

        in_edges[molecule_indices, atom_indices_0, atom_indices_1,12:] = features
        in_edges[molecule_indices, atom_indices_1, atom_indices_0,12:] = features

    return in_edges_train, in_edges_test
    
def make_out_edges(train_df):

    out_edges_train = np.zeros((n_train, max_size, max_size), dtype=np.float32)

    molecule_indices = train_df["molecule_index"].values
    atom_indices_0 = train_df["atom_index_0"].values
    atom_indices_1 = train_df["atom_index_1"].values
    scc_values = train_df["scalar_coupling_constant"].values

    out_edges_train[molecule_indices, atom_indices_0, atom_indices_1] = scc_values
    out_edges_train[molecule_indices, atom_indices_1, atom_indices_0] = scc_values

    return out_edges_train


In [47]:
nodes_train, nodes_test = make_nodes(train_structures_df, test_structures_df)
nodes_train[3,1]

array([ 0.00231072, -0.01915859,  0.00192873,  0.        ,  0.        ,
        0.        ,  1.        ,  0.        ], dtype=float32)

In [48]:
in_edges_train, in_edges_test = make_in_edges(train_df, test_df, train_structures_df, test_structures_df, train_bonds, test_bonds)
in_edges_train[0,0,1]

array([1.091953  , 0.01484855, 1.0918355 , 0.00602488, 1.        ,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.        , 1.        , 0.        , 0.        ,
       0.        ], dtype=float32)

In [49]:
in_edges_train[0,2,1], in_edges_train[0,0,5], 

(array([1.7831198e+00, 1.0095804e+00, 1.4697825e+00, 1.6995457e-03,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 1.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
        0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
       dtype=float32),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       dtype=float32))

In [50]:
out_edges_train = make_out_edges(train_df)
out_edges_train[-1,0,9], out_edges_train[0,1,0], out_edges_train[0,0,2], out_edges_train[0,3,0], out_edges_train[0,1,4], out_edges_train[0,3,4]

(105.769, 84.8076, 84.8074, 84.8093, -11.2543, -11.2543)

### Step 2 : train MPNN

In [51]:
print(out_edges_train.shape)
print(in_edges_train.shape)
print(nodes_train.shape)
print(nodes_test.shape)
print(in_edges_test.shape)

(68009, 29, 29)
(68009, 29, 29, 16)
(68009, 29, 8)
(17003, 29, 8)
(17003, 29, 29, 16)


In [52]:
out_edges_train = out_edges_train.reshape(-1,out_edges_train.shape[1]*out_edges_train.shape[2],1)
in_edges_train = in_edges_train.reshape(-1,in_edges_train.shape[1]*in_edges_train.shape[2],in_edges_train.shape[3])
in_edges_test  = in_edges_test.reshape(-1,in_edges_test.shape[1]*in_edges_test.shape[2],in_edges_test.shape[3])

nodes_train, in_edges_train, out_labels = shuffle(nodes_train, in_edges_train, out_edges_train)

In [53]:
print(nodes_train.shape)
print(in_edges_train.shape)
print(out_labels.shape)
print(nodes_test.shape)
print(in_edges_test.shape)

(68009, 29, 8)
(68009, 841, 16)
(68009, 841, 1)
(17003, 29, 8)
(17003, 841, 16)


In [54]:
class Message_Passer_NNM(nn.Module):
    def __init__(self, node_dim, nb_features_edge):
        super(Message_Passer_NNM, self).__init__()
        self.node_dim = node_dim
        self.linear = nn.Sequential(nn.Linear(nb_features_edge, self.node_dim * self.node_dim), nn.ReLU())
      
    def forward(self, node_j, edge_ij):

        #print("    ===== MESSAGE PASSER=========")
        #print('    node_j', node_j.shape)
        #print('    edge_ij', edge_ij.shape)
        A = self.linear(edge_ij)

        #print('    edge_ij embedded', A.shape)

        A = A.view(-1, self.node_dim, self.node_dim)

        #print('    reshape A so that matrix can be done', A.shape)
        
        node_j = node_j.view(-1, self.node_dim, 1)

        #print('    reshape node_j so that matrix can be done', node_j.shape)

        messages = torch.matmul(A, node_j)

        #print('    messages', messages.shape)
        messages = messages.view(-1, edge_ij.size(1), self.node_dim)

        #print('    messages reshaped', messages.shape)

        return messages
    

class Message_Agg(nn.Module):
    def __init__(self):
        super(Message_Agg, self).__init__()

    def forward(self, messages):
        #print('    ==========MESSAGE AGREG==========')
        s = torch.sum(messages, 2)
        #print('    sum after agregation of messages', s.shape)
        return s

class Update_Func_GRU(nn.Module):
    def __init__(self, state_dim):
        super(Update_Func_GRU, self).__init__()
        self.GRU = nn.GRU(state_dim, state_dim, batch_first=True)
        
    def forward(self, old_state, agg_messages):

        #print('    ==========UPDATE GRU==========')

        #print('    old state', old_state.shape)
        #print('    agg_messages', agg_messages.shape)
        
        n_nodes, node_dim = old_state.size(1), old_state.size(2)

        #print('    n_nodes', n_nodes)
        #print('    node_dim', node_dim)
        
        old_state = old_state.view(-1, 1, old_state.size(-1))

        #print('    old_state after reshape', old_state.shape)
        agg_messages = agg_messages.view(-1, 1, agg_messages.size(-1))
        #print('    agg_messages after reshape', agg_messages.shape)
        concat = torch.cat((old_state, agg_messages), dim=1)
        #print('    concat', concat.shape)
        
        activation, _ = self.GRU(concat)
        activation = activation[:,-1,:]
        #print('    activation', activation.shape)
        activation = activation.view(-1, n_nodes, node_dim)
        #print('    activation after reshape', activation.shape)
        
        return activation
    
class Edge_Regressor(nn.Module):
    def __init__(self, state_dim, nb_features_edge, intermediate_dim):
        super(Edge_Regressor, self).__init__()
        self.hidden_layer_1 = nn.Sequential(nn.Linear(2*state_dim + nb_features_edge, intermediate_dim), nn.ReLU())
        self.hidden_layer_2 = nn.Sequential(nn.Linear(intermediate_dim, intermediate_dim), nn.ReLU())
        self.output_layer = nn.Linear(intermediate_dim, 1)

        #print(intermediate_dim)
        
    def forward(self, nodes, edges):
        #print('  ============EDGE REGRESSOR===========')
        
        # Remember node dims
        n_nodes, node_dim = nodes.size(1), nodes.size(2)
        
        # Tile and reshape to match edges
        state_i = nodes.repeat(1, 1, n_nodes).view(-1, n_nodes * n_nodes, node_dim)
        state_j = nodes.repeat(1, n_nodes, 1)

        #print('state_i', state_i.shape)
        #print('state_j', state_j.shape)
        
        # Concatenate edges and nodes and apply MLP
        concat = torch.cat((state_i, edges, state_j), dim=2)

        #print('concat', concat.shape)
        
        activation_1 = self.hidden_layer_1(concat)  

        #print('activation_1', activation_1.shape)
        
        activation_2 = self.hidden_layer_2(activation_1)

        #print('activation_2', activation_2.shape)

        return self.output_layer(activation_2)

class MP_Layer(nn.Module):
    def __init__(self, state_dim, nb_features_edge):
        super(MP_Layer, self).__init__()
        self.message_passers = Message_Passer_NNM(node_dim=state_dim, nb_features_edge=nb_features_edge)
        self.message_aggs = Message_Agg()
        self.update_functions = Update_Func_GRU(state_dim=state_dim)
        self.state_dim = state_dim

    def forward(self, nodes, edges, mask):

        #print('  ============MP LAYER===========')
        
        n_nodes, node_dim = nodes.size(1), nodes.size(2)
        state_j = nodes.repeat(1, n_nodes, 1)

        #print('  n_nodes', n_nodes)
        #print('  node_dim', node_dim)
        #print('  state_j', state_j.shape)

        messages = self.message_passers(state_j, edges)

        # Multiply messages by the mask to ignore messages from non-existent nodes
        masked = messages * mask

        #print('  masked', masked.shape)

        masked = masked.view(messages.size(0), n_nodes, n_nodes, node_dim)

        #print('  masked after reshape', masked.shape)

        agg_m = self.message_aggs(masked)

        #print('  agg messages', agg_m.shape)

        updated_nodes = self.update_functions(nodes, agg_m)

        #print('  updated_nodes are nodes out', updated_nodes.shape)

        nodes_out = updated_nodes

        return nodes_out

class MPNN(nn.Module):
    def __init__(self, nb_features_node, nb_features_edge, out_int_dim, state_dim, T):
        super(MPNN, self).__init__()
        self.T = T
        self.embed = nn.Sequential(nn.Linear(nb_features_node, state_dim), nn.ReLU())
        self.MP = MP_Layer(state_dim, nb_features_edge)
        self.edge_regressor = Edge_Regressor(state_dim, nb_features_edge, out_int_dim)

    def forward(self, adj_input, nod_input):

        #print("=========MPNN=========")
        nodes = nod_input
        edges = adj_input

        #print('nodes input : ', nodes.shape)
        #print('edges input : ', edges.shape)

        # Get distances, and create a mask wherever 0 (i.e., non-existent nodes)
        # This also masks node self-interactions...
        # This assumes distance is last
        len_edges = edges.shape[-1]

        _, x = torch.split(edges, [len_edges - 1, 1], dim=2)

        #print('x', x.shape)
        
        mask = torch.where(x == 0, x, torch.ones_like(x))

        #print('mask', mask.shape)

        # Embed nodes to the chosen node dimension
        print(type(nodes))
        nodes = self.embed(nodes)

        #print('nodes before the MP layers', nodes.shape)
        #print('edges before the MP layers', edges.shape)

        # Run the T message-passing steps
        for mp in range(self.T):
            nodes = self.MP(nodes, edges, mask)

        #print('nodes after the MP layers', nodes.shape)

        # Regress the output values
        con_edges = self.edge_regressor(nodes, edges)

        return con_edges

In [55]:
mpnn = MPNN(nb_features_node = 8, nb_features_edge = 16, out_int_dim = 512, state_dim = 128, T = 4)

In [56]:
def log_mae(orig, preds):
    # Mask values for which no scalar coupling exists
    mask = orig != 0
    nums = orig[mask]
    preds = preds[mask]

    reconstruction_error = torch.log(torch.mean(torch.abs(nums - preds)))

    return reconstruction_error

In [57]:
class Set(Dataset):
    def __init__(self, in_nodes, in_edges, out_edges):
        self.nodes = in_nodes
        self.in_edges = in_edges
        self.out_edges = out_edges
    def __len__(self):
        return len(self.nodes)
    def __getitem__(self, idx):
        s1 = self.nodes[idx]
        s2 = self.in_edges[idx]
        s3 = self.out_edges[idx]
        return s1, s2, s3

train_set = Set(in_edges_train[:50000], nodes_train[:50000], out_edges_train[:60000]) 
val_set = Set(in_edges_train[50000:], nodes_train[50000:], out_edges_train[60000:]) 

train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
val_loader = DataLoader(train_set, batch_size=16, shuffle=True)

In [58]:
n_epochs = 10
learning_rate = 0.001
optimizer = torch.optim.Adam(params = mpnn.parameters(), lr=learning_rate)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

mpnn = mpnn.to(device)

for i in range(n_epochs):
    mpnn.train()
    total_loss = 0
    for batch in tqdm(train_loader):
        nodes, in_edges, out_edges = batch
        nodes, in_edges, out_edges = nodes.to(device), in_edges.to(device), out_edges.to(device)
        out = mpnn(nodes, in_edges)
        loss = log_mae(out_edges, out)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    average_loss = total_loss / len(train_loader)
    print("average train loss over an epoch :", average_loss)

    mpnn.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_loader):
            nodes, in_edges, out_edges = batch
            nodes, in_edges, out_edges = nodes.to(device), in_edges.to(device), out_edges.to(device)
            out = mpnn(nodes, in_edges)
            loss = log_mae(out_edges, out)
            total_loss += loss.item()
    average_loss = total_loss / len(train_loader)
    print("average val loss", val_loss)

  0%|          | 0/3125 [00:00<?, ?it/s]

<class 'torch.Tensor'>


  0%|          | 1/3125 [00:05<4:29:17,  5.17s/it]

<class 'torch.Tensor'>


  0%|          | 2/3125 [00:08<3:32:53,  4.09s/it]

<class 'torch.Tensor'>


  0%|          | 2/3125 [00:09<4:11:54,  4.84s/it]


KeyboardInterrupt: 