This project follows the method presented in *A. Sanchez-Gonzalez, J. Godwin, T. Pfaff, R. Ying, J. Leskovec, and P. W. Battaglia. Learning to simulate complex physics with graph networks. In Proceedings of the 37th International Conference on Machine Learning, ICML 2020, 13–18 July 2020, Virtual Event, volume 119 of Proceedings of Machine Learning Research, pages 8459–8468. PMLR, 2020. URL http://proceedings.mlr.press/v119/sanchez-gonzalez20a.html.*

In [5]:
import torch.nn as nn
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch_geometric as pyg

In [6]:
class InteractionNetwork(pyg.nn.MessagePassing):
    """One Interaction Network layer: an edge update followed by a node update.

    For every edge, an MLP maps the concatenation of the two endpoint node
    features and the edge's own feature to a message. Messages are summed per
    receiving node, and a second MLP maps [node feature, summed messages] to a
    node update. Both node and edge outputs are residual (input + update).

    All methods live in this single class definition. If they are defined as
    free functions in separate notebook cells, they are never attached to the
    class and PyG falls back to its base implementations.

    Args:
        hidden_size: width of the node and edge feature vectors. Inputs are
            assumed to already be embedded to this width (the MLP input sizes
            below are multiples of it).
        layers: number of layers passed to each ``MLP``.

    NOTE(review): ``MLP`` is not defined in this section; it is assumed to be
    ``MLP(in_size, out_size, layers)`` defined in another cell — confirm.
    """

    def __init__(self, hidden_size, layers=3):
        super().__init__()
        # Edge MLP input is [x_i, x_j, edge_feature] -> 3 * hidden_size wide.
        self.lin_edge = MLP(hidden_size * 3, hidden_size, layers)
        # Node MLP input is [x, aggregated messages] -> 2 * hidden_size wide.
        # The three sizes must be separate arguments, not one tuple.
        self.lin_node = MLP(hidden_size * 2, hidden_size, layers)

    def forward(self, x, edge_index, edge_feature):
        """Update node and edge features with one round of message passing.

        Args:
            x: node features; presumably (num_nodes, hidden_size) — confirm.
            edge_index: PyG-style edge connectivity passed to ``propagate``.
            edge_feature: per-edge features; presumably
                (num_edges, hidden_size) — confirm.

        Returns:
            ``(node_out, edge_out)``: residually updated node and edge features.
        """
        # With PyG's default identity ``update``, propagate returns what
        # ``aggregate`` returns: the (per-edge messages, per-node sums) pair.
        edge_out, aggr = self.propagate(edge_index, x=(x, x), edge_feature=edge_feature)
        node_out = self.lin_node(torch.cat((x, aggr), dim=-1))
        edge_out = edge_feature + edge_out
        node_out = x + node_out
        return node_out, edge_out

    def message(self, x_i, x_j, edge_feature):
        """Build one message per edge from both endpoints and the edge feature."""
        x = torch.cat((x_i, x_j, edge_feature), dim=-1)
        return self.lin_edge(x)

    def aggregate(self, inputs, index, dim_size=None):
        """Sum edge messages into their target nodes.

        Args:
            inputs: per-edge messages produced by ``message``.
            index: target-node index of each message along ``self.node_dim``.
            dim_size: number of target nodes, supplied by ``propagate``. It is
                required so that nodes with no incoming edges (including the
                highest-numbered ones) still get a zero row; otherwise the
                concatenation in ``forward`` would have mismatched lengths.

        Returns:
            ``(inputs, out)``: the unchanged messages (reused as the edge
            update) and the per-node sums.
        """
        if dim_size is None:
            dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
        out_shape = list(inputs.shape)
        out_shape[self.node_dim] = dim_size
        # Pure-torch scatter-sum; avoids an un-imported torch_scatter dependency.
        out = inputs.new_zeros(out_shape).index_add_(self.node_dim, index, inputs)
        return (inputs, out)

In [None]:
class LearnedSimulator(torch.nn.module):
    def __init__(self, 
                 hidden_size=128,
                 n_mp_layers=10, # GNN layers
                 node_feature_dim=30,
                 edge_feature_dim==3,
                 dim=2, # 2d or 3d space
                 ):
        super().__init__()
        self.node_in = MLP(node_feature_dim, hidden_size, 3)
        self.edge_in = MLP(edge_feature_dim, hidden_size, 3)
        self.node_out = MLP(hidden_size, dim, 3)