In [2]:
# Create a GNN model for 1D data
#pip install torch_geometric --user --quiet
#pip install torch --user --quiet

In [3]:
# Importing libraries
import torch
import torch.nn.functional as F
import torch_geometric as pyg
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [4]:
# Create GNN model

class InteractionNetwork(pyg.nn.MessagePassing):
    """Interaction Network message-passing layer.

    For every edge, ``lin_edge`` maps ``[x_i | x_j | edge_feature]`` to a message.
    The messages are summed per node (by the ``index`` PyG passes to ``aggregate``),
    concatenated with the node features and mapped through ``lin_node``.
    Both node and edge features are updated residually.

    NOTE: a later cell of this notebook redefines ``InteractionNetwork``. Whichever
    definition was executed last is the one ``LearnedSimulator`` instantiates.

    Args:
        hidden_size: width of the node and edge embeddings (input and output).
        layers: number of Linear layers in each internal MLP.
    """

    def __init__(self, hidden_size, layers=3):
        super().__init__()
        # MLP takes (input_size, hidden_size, output_size, layers); the output
        # width equals hidden_size so the residual additions in forward line up.
        self.lin_edge = MLP(hidden_size * 3, hidden_size, hidden_size, layers)
        self.lin_node = MLP(hidden_size * 2, hidden_size, hidden_size, layers)

    def forward(self, x, edge_index, edge_feature):
        """Run one round of message passing.

        Returns:
            ``(node_out, edge_out)``: residually updated node and edge features.
        """
        # aggregate() returns (messages, summed messages), and propagate passes
        # that pair through unchanged.
        edge_out, aggr = self.propagate(edge_index, x=(x, x), edge_feature=edge_feature)
        node_out = self.lin_node(torch.cat((x, aggr), dim=-1))
        edge_out = edge_feature + edge_out
        node_out = x + node_out
        return node_out, edge_out

    def message(self, x_i, x_j, edge_feature):
        """Compute one message per edge from both endpoint features and the edge feature."""
        x = torch.cat((x_i, x_j, edge_feature), dim=-1)
        return self.lin_edge(x)

    def aggregate(self, inputs, index, dim_size=None):
        """Sum messages per node; also return the raw messages for the edge update.

        Args:
            inputs: per-edge messages.
            index: node index of each message along ``self.node_dim``.
            dim_size: number of output nodes. If None, ``index.max() + 1`` is used
                (0 for an empty index).

        Returns:
            ``(inputs, out)``, where ``out`` holds the summed messages.
        """
        dim = self.node_dim % inputs.dim()  # node_dim may be negative (PyG default -2)
        if dim_size is None:
            dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
        shape = list(inputs.shape)
        shape[dim] = dim_size
        # Scatter-sum with plain torch; nodes that receive no message stay zero.
        out = inputs.new_zeros(shape).index_add_(dim, index, inputs)
        return (inputs, out)

In [5]:
def message(self, x_i, x_j, edge_feature):
    """Build a per-edge message from endpoint and edge features.

    This is defined at module level, so it is not bound to any class. ``self`` must
    provide ``lin_edge``, a callable applied to ``[x_i | x_j | edge_feature]``
    concatenated along the last axis.
    """
    edge_input = torch.cat([x_i, x_j, edge_feature], dim=-1)
    return self.lin_edge(edge_input)

In [6]:
def aggregate(self, inputs, index, dim_size=None):
    """Sum messages per node and also return the raw messages.

    This is defined at module level, so it is not bound to any class. ``self`` must
    provide ``node_dim``, the axis of ``inputs`` that ``index`` addresses.

    Args:
        inputs: per-edge messages.
        index: 1-D node index of each message along ``self.node_dim``.
        dim_size: number of output nodes. If None, ``index.max() + 1`` is used
            (0 for an empty index).

    Returns:
        ``(inputs, out)``, where ``out`` holds the summed messages.
    """
    dim = self.node_dim % inputs.dim()  # node_dim may be negative (e.g. -2)
    if dim_size is None:
        dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
    shape = list(inputs.shape)
    shape[dim] = dim_size
    # Plain-torch scatter-sum (no torch_scatter dependency); nodes that receive
    # no message stay zero.
    out = inputs.new_zeros(shape).index_add_(dim, index, inputs)
    return (inputs, out)

In [7]:
def forward(self, x, edge_index, edge_feature):
    """Apply one residual message-passing step to node and edge features.

    This is defined at module level, so it is not bound to any class. ``self`` must
    provide ``propagate`` (returning a ``(messages, aggregated)`` pair) and
    ``lin_node``.

    Returns:
        ``(node_out, edge_out)``: ``x`` plus the node update, and ``edge_feature``
        plus the messages.
    """
    messages, aggregated = self.propagate(edge_index, x=(x, x), edge_feature=edge_feature)
    node_update = self.lin_node(torch.cat((x, aggregated), dim=-1))
    # Residual connections on both nodes and edges.
    return x + node_update, edge_feature + messages

In [8]:
class LearnedSimulator(torch.nn.Module):
    """Graph Network-based Simulator (GNS): encoder -> processor -> decoder.

    Args:
        hidden_size: width of the latent node and edge embeddings.
        n_mp_layers: number of InteractionNetwork (message-passing) layers.
        node_feature_dim: size of each raw input node feature vector.
        edge_feature_dim: size of each raw input edge feature vector.
        dim: number of values predicted per node (dimension of the world).
    """

    def __init__(
        self,
        hidden_size=128,
        n_mp_layers=2,  # number of GNN layers
        node_feature_dim=30,
        edge_feature_dim=3,
        dim=1,  # dimension of the world (1D here; GNS is typically used in 2D or 3D)
    ):
        super().__init__()
        # MLP takes (input_size, hidden_size, output_size, layers).
        self.node_in = MLP(node_feature_dim, hidden_size, hidden_size, 3)
        self.edge_in = MLP(edge_feature_dim, hidden_size, hidden_size, 3)
        # The decoder has no LayerNorm. A LayerNorm over a width-1 output always
        # returns its bias, and in general it would force zero-mean/unit-variance
        # predictions.
        self.node_out = MLP(hidden_size, hidden_size, dim, 3, layernorm=False)
        self.layers = torch.nn.ModuleList(
            [InteractionNetwork(hidden_size, 3) for _ in range(n_mp_layers)]
        )

    def forward(self, edge_index, node_feature, edge_feature):
        """Predict per-node outputs.

        Args:
            edge_index: graph connectivity, passed unchanged to each InteractionNetwork.
            node_feature: raw node features (last axis ``node_feature_dim``).
            edge_feature: raw edge features (last axis ``edge_feature_dim``).

        Returns:
            Per-node predictions whose last axis has size ``dim``.
        """
        # encoder
        node_feature = self.node_in(node_feature)
        edge_feature = self.edge_in(edge_feature)
        # processor
        for layer in self.layers:
            node_feature, edge_feature = layer(node_feature, edge_index, edge_feature=edge_feature)
        # decoder
        return self.node_out(node_feature)

In [10]:
class MLP(torch.nn.Module):
    """Multi-layer perceptron.

    The network is ``layers`` Linear layers with ReLU between consecutive ones.
    It is optionally followed by a LayerNorm over the output features.

    Args:
        input_size: number of input features.
        hidden_size: width of every intermediate layer.
        output_size: number of output features.
        layers: number of Linear layers; must be >= 1.
        layernorm: if True, append ``LayerNorm(output_size)`` after the last Linear.

    Raises:
        ValueError: if ``layers < 1``.
    """

    def __init__(self, input_size, hidden_size, output_size, layers, layernorm=True):
        super().__init__()
        if layers < 1:
            raise ValueError(f"MLP needs at least one Linear layer, got layers={layers}")
        self.layers = torch.nn.ModuleList()
        for i in range(layers):
            is_last = i == layers - 1
            self.layers.append(torch.nn.Linear(
                input_size if i == 0 else hidden_size,
                output_size if is_last else hidden_size,
            ))
            if not is_last:
                self.layers.append(torch.nn.ReLU())
        if layernorm:
            self.layers.append(torch.nn.LayerNorm(output_size))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize every Linear layer.

        Weights are drawn from a normal distribution with mean 0 and std
        1/sqrt(in_features); biases are set to 0.
        """
        with torch.no_grad():
            for layer in self.layers:
                if isinstance(layer, torch.nn.Linear):
                    # `** -0.5` avoids needing `math`, which this notebook never imports.
                    layer.weight.normal_(0.0, layer.in_features ** -0.5)
                    layer.bias.zero_()

    def forward(self, x):
        """Apply the layers in order; the last axis of ``x`` must be ``input_size``."""
        for layer in self.layers:
            x = layer(x)
        return x

In [12]:
class InteractionNetwork(pyg.nn.MessagePassing):
    """Interaction Network as proposed in this paper:
    https://proceedings.neurips.cc/paper/2016/hash/3147da8ab4a0437c15ef51a5cc7f2dc4-Abstract.html

    For every edge, ``lin_edge`` maps ``[x_i | x_j | edge_feature]`` to a message.
    The messages are summed per node (by the ``index`` PyG passes to ``aggregate``),
    concatenated with the node features and mapped through ``lin_node``.
    Both node and edge features are updated residually.

    Args:
        hidden_size: width of the node and edge embeddings (input and output).
        layers: number of Linear layers in each internal MLP.
    """

    def __init__(self, hidden_size, layers=3):
        super().__init__()
        self.lin_edge = MLP(hidden_size * 3, hidden_size, hidden_size, layers)
        self.lin_node = MLP(hidden_size * 2, hidden_size, hidden_size, layers)

    def forward(self, x, edge_index, edge_feature):
        """Run one round of message passing.

        Returns:
            ``(node_out, edge_out)``: residually updated node and edge features.
        """
        # aggregate() returns (messages, summed messages), and propagate passes
        # that pair through unchanged.
        edge_out, aggr = self.propagate(edge_index, x=(x, x), edge_feature=edge_feature)
        node_out = self.lin_node(torch.cat((x, aggr), dim=-1))
        edge_out = edge_feature + edge_out
        node_out = x + node_out
        return node_out, edge_out

    def message(self, x_i, x_j, edge_feature):
        """Compute one message per edge from both endpoint features and the edge feature."""
        x = torch.cat((x_i, x_j, edge_feature), dim=-1)
        x = self.lin_edge(x)
        return x

    def aggregate(self, inputs, index, dim_size=None):
        """Sum messages per node; also return the raw messages for the edge update.

        Args:
            inputs: per-edge messages.
            index: node index of each message along ``self.node_dim``.
            dim_size: number of output nodes. If None, ``index.max() + 1`` is used
                (0 for an empty index).

        Returns:
            ``(inputs, out)``, where ``out`` holds the summed messages.
        """
        dim = self.node_dim % inputs.dim()  # node_dim may be negative (PyG default -2)
        if dim_size is None:
            dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
        shape = list(inputs.shape)
        shape[dim] = dim_size
        # Scatter-sum with plain torch (torch_scatter is never imported here);
        # nodes that receive no message stay zero.
        out = inputs.new_zeros(shape).index_add_(dim, index, inputs)
        return (inputs, out)