In [158]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.utils import shuffle
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [19]:
datadir = "champs-scalar-coupling/"
graph_dir = datadir + "champs-basic-graph/"


def _load_arr0(filename):
    """Return the array stored under key 'arr_0' in the archive `graph_dir + filename`.

    np.load on an .npz archive returns a lazy NpzFile that keeps the underlying
    file open; using it as a context manager closes the handle as soon as the
    array has been read, instead of leaking one descriptor per load.
    """
    with np.load(graph_dir + filename) as archive:
        return archive['arr_0']


# Training arrays: node features, input edge features and target edge values.
nodes_train     = _load_arr0("nodes_train.npz")
in_edges_train  = _load_arr0("in_edges_train.npz")
out_edges_train = _load_arr0("out_edges_train.npz")

# Test arrays: only node features and input edge features are loaded.
nodes_test     = _load_arr0("nodes_test.npz")
in_edges_test  = _load_arr0("in_edges_test.npz")

In [20]:
# Print the shape of every loaded array, one per line, train arrays first.
for _arr in (nodes_train, in_edges_train, out_edges_train, nodes_test, in_edges_test):
    print(_arr.shape)

(85003, 29, 5)
(85003, 29, 29, 16)
(85003, 29, 29, 1)
(45772, 29, 5)
(45772, 29, 29, 16)


In [21]:
# Fixed seed so the shuffled order -- and therefore any later positional
# train/validation slicing of these arrays -- is the same on every run.
SHUFFLE_SEED = 0

# This cell overwrites in_edges_train / in_edges_test / nodes_train in place, so it
# must run exactly once after the loading cell. Fail with a clear message rather
# than the IndexError that `.shape[3]` would raise on an already-flattened array.
assert in_edges_train.ndim == 4 and in_edges_test.ndim == 4, (
    "edge arrays are already flattened; re-run the data loading cell first"
)

# Merge axes 1 and 2 into a single axis: (N, a, b, F) -> (N, a*b, F).
out_labels = out_edges_train.reshape(-1, out_edges_train.shape[1] * out_edges_train.shape[2], 1)
in_edges_train = in_edges_train.reshape(-1, in_edges_train.shape[1] * in_edges_train.shape[2], in_edges_train.shape[3])
in_edges_test  = in_edges_test.reshape(-1, in_edges_test.shape[1] * in_edges_test.shape[2], in_edges_test.shape[3])

# Shuffle the three training arrays with one shared permutation so rows stay aligned.
nodes_train, in_edges_train, out_labels = shuffle(
    nodes_train, in_edges_train, out_labels, random_state=SHUFFLE_SEED
)

In [23]:
# Print the array shapes again after flattening, one per line, train arrays first.
for _arr in (nodes_train, in_edges_train, out_labels, nodes_test, in_edges_test):
    print(_arr.shape)

(85003, 29, 5)
(85003, 841, 16)
(85003, 841, 1)
(45772, 29, 5)
(45772, 841, 16)


In [164]:
class Message_Passer_NNM(nn.Module):
    """Edge-conditioned message function.

    Each edge feature vector is mapped by a Linear + ReLU layer to a
    (node_dim x node_dim) matrix, which is then multiplied with the matching
    node state vector to produce one message per edge.

    Args:
        node_dim: size of a node state vector.
        nb_features_edge: size of an edge feature vector (input of the linear layer).
    """

    def __init__(self, node_dim, nb_features_edge):
        super().__init__()
        self.node_dim = node_dim
        self.linear = nn.Sequential(
            nn.Linear(nb_features_edge, self.node_dim * self.node_dim),
            nn.ReLU(),
        )

    def forward(self, node_j, edge_ij):
        """Compute one message per edge.

        Args:
            node_j: node states, viewed here as a stack of (node_dim, 1) columns;
                it must therefore hold one node_dim-sized state per edge.
            edge_ij: edge features; dimension 1 is taken as the number of edges
                per sample and the last dimension must be nb_features_edge.

        Returns:
            Tensor viewed as (-1, edge_ij.size(1), node_dim).
        """
        dim = self.node_dim
        edges_per_sample = edge_ij.size(1)

        # One (dim x dim) matrix per edge, stacked along the leading axis.
        edge_matrices = self.linear(edge_ij).view(-1, dim, dim)
        # Matching column vectors so the product is a batched mat-vec.
        state_columns = node_j.view(-1, dim, 1)

        per_edge = edge_matrices @ state_columns
        return per_edge.view(-1, edges_per_sample, dim)

In [165]:
class Message_Agg(nn.Module):
    """Parameter-free aggregation step: sums the messages over dimension 2."""

    def __init__(self):
        super().__init__()

    def forward(self, messages):
        """Return `messages` reduced by summation along dimension 2."""
        return messages.sum(dim=2)

In [166]:
class Update_Func_GRU(nn.Module):
    """Node update step backed by a single GRU.

    For every node, the previous state and the aggregated message are fed to
    the GRU as a two-step sequence (state first, message second); the GRU
    output at the last step becomes the new node state.

    Args:
        state_dim: size of a node state; used as both GRU input and hidden size.
    """

    def __init__(self, state_dim):
        super().__init__()
        self.GRU = nn.GRU(state_dim, state_dim, batch_first=True)

    def forward(self, old_state, agg_messages):
        """Return updated node states with the same (-1, n_nodes, node_dim) layout as `old_state`.

        Args:
            old_state: 3-D tensor; dimensions 1 and 2 are read as n_nodes and node_dim.
            agg_messages: aggregated messages, one vector per node.
        """
        n_nodes = old_state.size(1)
        node_dim = old_state.size(2)

        # Flatten batch and node axes so every node is its own length-2 sequence.
        sequence = torch.cat(
            (
                old_state.view(-1, 1, old_state.size(-1)),
                agg_messages.view(-1, 1, agg_messages.size(-1)),
            ),
            dim=1,
        )

        outputs, _ = self.GRU(sequence)
        # Keep only the output of the final step and restore the node axis.
        return outputs[:, -1, :].view(-1, n_nodes, node_dim)

In [167]:
class Edge_Regressor(nn.Module):
    """Three-layer MLP that predicts one scalar per node pair.

    The input for pair (i, j) is the concatenation
    [state of node i, edge features, state of node j].

    Args:
        state_dim: size of a node state.
        nb_features_edge: size of an edge feature vector.
        intermediate_dim: width of the two hidden layers.
    """

    def __init__(self, state_dim, nb_features_edge, intermediate_dim):
        super().__init__()
        pair_features = 2 * state_dim + nb_features_edge
        self.hidden_layer_1 = nn.Sequential(nn.Linear(pair_features, intermediate_dim), nn.ReLU())
        self.hidden_layer_2 = nn.Sequential(nn.Linear(intermediate_dim, intermediate_dim), nn.ReLU())
        self.output_layer = nn.Linear(intermediate_dim, 1)

    def forward(self, nodes, edges):
        """Regress one value per entry along dimension 1 of `edges`.

        Args:
            nodes: 3-D tensor; dimensions 1 and 2 are read as n_nodes and node_dim.
            edges: edge features with n_nodes * n_nodes entries along dimension 1,
                so that they can be concatenated with the tiled node states.

        Returns:
            Output of the final linear layer (last dimension of size 1).
        """
        n_nodes = nodes.size(1)
        node_dim = nodes.size(2)

        # Entry i * n_nodes + j holds node i here ...
        tiled_i = nodes.repeat(1, 1, n_nodes).view(-1, n_nodes * n_nodes, node_dim)
        # ... and node j here, matching the flattened pair layout.
        tiled_j = nodes.repeat(1, n_nodes, 1)

        pair_input = torch.cat((tiled_i, edges, tiled_j), dim=2)
        hidden = self.hidden_layer_2(self.hidden_layer_1(pair_input))
        return self.output_layer(hidden)

In [168]:
class MP_Layer(nn.Module):
    """One message-passing step: compute messages, mask, aggregate, update.

    Args:
        state_dim: size of a node state.
        nb_features_edge: size of an edge feature vector.
    """

    def __init__(self, state_dim, nb_features_edge):
        super().__init__()
        self.message_passers = Message_Passer_NNM(node_dim=state_dim, nb_features_edge=nb_features_edge)
        self.message_aggs = Message_Agg()
        self.update_functions = Update_Func_GRU(state_dim=state_dim)
        self.state_dim = state_dim

    def forward(self, nodes, edges, mask):
        """Return the node states after one round of message passing.

        Args:
            nodes: 3-D tensor; dimensions 1 and 2 are read as n_nodes and node_dim.
            edges: edge features handed to the message function unchanged.
            mask: multiplied element-wise with the messages; zeros drop a message.
        """
        n_nodes = nodes.size(1)
        node_dim = nodes.size(2)

        # Tile the node states along dimension 1 so there is one state per edge entry.
        tiled_states = nodes.repeat(1, n_nodes, 1)
        messages = self.message_passers(tiled_states, edges)

        # Zero out masked messages, then unflatten the pair axis into two node axes.
        pairwise = (messages * mask).view(messages.size(0), n_nodes, n_nodes, node_dim)

        aggregated = self.message_aggs(pairwise)
        return self.update_functions(nodes, aggregated)

In [183]:
class MPNN(nn.Module):
    """Message-passing network that regresses one value per edge entry.

    Node features are embedded, the same MP_Layer is applied T times, and an
    Edge_Regressor produces the final per-edge output.

    Args:
        nb_features_node: size of a raw node feature vector.
        nb_features_edge: size of an edge feature vector.
        out_int_dim: hidden width of the edge regressor.
        state_dim: size of the embedded node state.
        T: number of message-passing steps (all share one MP_Layer's weights).
    """

    def __init__(self, nb_features_node, nb_features_edge, out_int_dim, state_dim, T):
        super().__init__()
        self.T = T
        self.embed = nn.Sequential(nn.Linear(nb_features_node, state_dim), nn.ReLU())
        self.MP = MP_Layer(state_dim, nb_features_edge)
        self.edge_regressor = Edge_Regressor(state_dim, nb_features_edge, out_int_dim)

    def forward(self, adj_input, nod_input):
        """Run the network.

        Args:
            adj_input: edge features (note: edges come FIRST in the signature).
            nod_input: raw node features.

        Returns:
            The edge regressor output computed from the final node states.
        """
        edges = adj_input

        # The last edge feature is assumed to be the distance; a value of 0 there
        # marks a non-existent node (this also masks node self-interactions).
        last_feature = edges[:, :, -1:]
        mask = torch.where(last_feature == 0, last_feature, torch.ones_like(last_feature))

        # Embed the raw node features to the chosen state dimension.
        states = self.embed(nod_input)

        # Apply the shared message-passing layer T times.
        for _ in range(self.T):
            states = self.MP(states, edges, mask)

        return self.edge_regressor(states, edges)

In [184]:
# First ten samples, used as a tiny batch for a forward-pass smoke test.
input_edges, input_nodes = in_edges_train[:10], nodes_train[:10]

In [185]:
# Instantiate the model; feature sizes match the last axis of the node (5) and edge (16) arrays.
mpnn = MPNN(
    nb_features_node=5,
    nb_features_edge=16,
    out_int_dim=512,
    state_dim=128,
    T=4,
)

In [186]:
# Forward pass on the small batch; MPNN.forward takes the edges first, then the nodes.
out = mpnn(
    torch.Tensor(input_edges),
    torch.Tensor(input_nodes),
)

In [187]:
# Display the output shape of the smoke-test forward pass.
out.size()

torch.Size([10, 841, 1])

In [188]:
def log_mae(orig, preds):
    """Log of the mean absolute error over the entries where `orig` is non-zero.

    Entries of `orig` equal to 0 are treated as "no target" and excluded
    from the error, together with the corresponding predictions.

    Args:
        orig: target tensor.
        preds: prediction tensor, indexable by a boolean mask shaped like `orig`.

    Returns:
        Scalar tensor log(mean(|orig - preds|)) over the kept entries.
    """
    has_target = orig != 0
    abs_error = (orig[has_target] - preds[has_target]).abs()
    return abs_error.mean().log()

In [189]:
class Set(Dataset):
    """Map-style dataset bundling three index-aligned containers.

    Item `idx` is the triple (first[idx], second[idx], third[idx]) in the
    order the containers were given to the constructor; the length is that
    of the first container.
    """

    def __init__(self, in_nodes, in_edges, out_edges):
        self.nodes, self.in_edges, self.out_edges = in_nodes, in_edges, out_edges

    def __len__(self):
        return len(self.nodes)

    def __getitem__(self, idx):
        return self.nodes[idx], self.in_edges[idx], self.out_edges[idx]

# Samples [0, N_TRAIN) are used for training, the remainder for validation.
N_TRAIN = 60000

# NOTE: the arrays are passed as (edges, nodes, labels), so every batch yields them
# in that order even though Set names its first two parameters in_nodes / in_edges.
# The training loop forwards the first two batch items positionally to
# MPNN.forward(adj_input, nod_input), which also takes edges first -- keep these in sync.
train_set = Set(in_edges_train[:N_TRAIN], nodes_train[:N_TRAIN], out_labels[:N_TRAIN])
val_set = Set(in_edges_train[N_TRAIN:], nodes_train[N_TRAIN:], out_labels[N_TRAIN:])

train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
# Validate on the held-out split (previously this wrapped train_set, so the
# "validation" loss was measured on training data). No shuffling is needed for evaluation.
val_loader = DataLoader(val_set, batch_size=16, shuffle=False)

In [190]:
n_epochs = 10
learning_rate = 0.001
optimizer = torch.optim.Adam(params=mpnn.parameters(), lr=learning_rate)

for epoch in range(n_epochs):
    # ---- training pass ----
    mpnn.train()
    total_loss = 0
    for batch in tqdm(train_loader):
        # The datasets are built as Set(edges, nodes, labels), so batches arrive
        # edges-first, which is also the argument order of MPNN.forward.
        edges, nodes, targets = batch
        # Clear gradients left over from the previous step; without this they
        # accumulate across batches and every update uses a stale sum.
        optimizer.zero_grad()
        out = mpnn(edges, nodes)
        loss = log_mae(targets, out)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    average_loss = total_loss / len(train_loader)
    print("average train loss over an epoch :", average_loss)

    # ---- validation pass (no gradient tracking) ----
    mpnn.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_loader):
            edges, nodes, targets = batch
            out = mpnn(edges, nodes)
            # Accumulate into val_loss (not total_loss) so the reported number is
            # the validation loss rather than always 0.
            val_loss += log_mae(targets, out).item()
    # Average over the number of validation batches, not training batches.
    average_val_loss = val_loss / len(val_loader)
    print("average val loss", average_val_loss)

  0%|                                        | 4/3750 [00:13<3:29:49,  3.36s/it]


KeyboardInterrupt: 