# Graph Neural Network


**Slides:** [![View Slides](https://img.shields.io/badge/View-Presentation-yellow?style=flat-square&logo=googleslides&logoColor=white)](https://docs.google.com/presentation/d/1hDcMFqbOXPehlxjaCBAPCxdmiMJSzLj5FYeuXAo6Dg4/edit?usp=sharing)

## Graph Neural Networks (GNNs) and message passing

Traditional physics simulators are powerful, but they have two important drawbacks: (1) producing high-quality results for large-scale simulations is expensive and time-consuming; (2) setting up a simulation requires full knowledge of the physical parameters of the object and its environment, which in some cases are extremely hard to obtain.

A Graph Neural Network (GNN) is a type of neural network that takes less structured data, such as a graph, as input, whereas architectures such as Convolutional Neural Networks (CNNs) and Transformers are typically applied to more structured data (e.g., grids and sequences). By “less structured”, we mean that the input can have arbitrary shape and size and can have complex topological relations.

In particle-based physics simulation, we have the unstructured position information of all the particles as the input, which inspires the idea of using a GNN.

*Permutation equivariance*

One key characteristic that distinguishes GNNs from other neural networks is permutation equivariance. The nodes in a graph have no canonical order, so how we “order” the nodes does not change the result a GNN produces for each node; the outputs are simply reordered in the same way.

Since particles of an object are “identical” in the particle-based simulation, they are permutation-equivariant when applying physics laws on them. Therefore, a permutation-equivariant model such as a GNN is suitable to simulate the interactions between particles.
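
The property can be checked numerically. The sketch below is a minimal illustration in NumPy, not the GNN used later: `sum_message_passing` is a hypothetical one-step update in which each node adds the features of its senders. Relabelling the nodes and the edges consistently should only reorder the output rows.

```python
import numpy as np

def sum_message_passing(node_feats, edges):
    """One round of message passing: each node adds the features of its senders."""
    out = node_feats.copy()
    for receiver, sender in edges:
        out[receiver] += node_feats[sender]
    return out

# Three nodes with 2D features and a few directed (receiver, sender) edges.
feats = np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 1.0]])
edges = [(0, 1), (0, 2), (1, 2), (2, 0)]

# Relabel the nodes: new node k is old node perm[k].
perm = np.array([2, 0, 1])
inverse = np.argsort(perm)  # maps an old index to its new index
feats_perm = feats[perm]
edges_perm = [(inverse[r], inverse[s]) for r, s in edges]

out = sum_message_passing(feats, edges)
out_perm = sum_message_passing(feats_perm, edges_perm)

# Equivariance: permuting the input permutes the output in the same way.
equivariant = np.allclose(out_perm, out[perm])
print(equivariant)
```

The per-node results are identical in both orderings; only the row in which each result appears changes.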

### Graphs
Graphs are a powerful means of representing interactions between physical systems.
A granular medium can be represented as a graph $G=\left(V,E\right)$ consisting of a set of vertices ($\mathbf{v}_i \in V$) representing the soil grains and edges ($\mathbf{e}_{i,j} \in E$) connecting a pair of vertices ($\mathbf{v}_i$ and $\mathbf{v}_j$) representing the interaction relationship between grains.
We describe how graphs work by showing a simple example involving interaction between balls in a box (Figure 1a).
The state of the physical system (Figure 1a and 1d) can be encoded as a graph (Figure 1b and 1c). The vertices describe the balls, and the edges describe the directional interaction between them, shown as arrows in Figure 1b and 1c.
The state of ball $i$ is represented as a vertex feature vector $\mathbf{v}_i$. The feature vector includes properties such as velocity, mass, and distance to the boundary.
The edge feature vector $\mathbf{e}_{i,j}$ includes the information about the interaction between balls $i$ and $j$ such as the relative distance between the balls.

Graphs offer a permutation invariant form of encoding data, where the interaction between vertices is independent of the order of vertices or their position in Euclidean space.
Rather, graphs represent the interactions through the edge connections, which are not affected by the permutation of the vertices.
Therefore, graphs can efficiently represent the physical state of granular flow, where numerous orderless particles interact, by using vertices to represent particles and edges to represent their interactions.
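
As a concrete illustration, the balls-in-a-box state can be written down as arrays. This is a minimal sketch with made-up numbers: the positions, velocities, masses, and the unit box are hypothetical, and the feature layout is one possible choice rather than the one used later in this tutorial.

```python
import numpy as np

# Hypothetical state of three balls in a unit box: positions, velocities, masses.
positions = np.array([[0.2, 0.3], [0.7, 0.4], [0.5, 0.8]])
velocities = np.array([[0.1, 0.0], [-0.2, 0.1], [0.0, -0.3]])
masses = np.array([1.0, 1.0, 2.0])

# Vertex features: velocity, mass, and distance to the four walls of the box.
dist_to_walls = np.concatenate([positions, 1.0 - positions], axis=1)
vertex_features = np.concatenate([velocities, masses[:, None], dist_to_walls], axis=1)

# Directed edges between every ordered pair of distinct balls: (receiver, sender).
n = len(positions)
edges = [(i, j) for i in range(n) for j in range(n) if i != j]

# Edge features: distance between the receiver and the sender.
edge_features = np.array(
    [[np.linalg.norm(positions[i] - positions[j])] for i, j in edges]
)

print(vertex_features.shape, len(edges), edge_features.shape)
```

Three balls give three vertices and six directed edges, matching the graph drawn in Figure 1b.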

### Graph neural networks (GNNs)
GNNs are a state-of-the-art deep learning architecture that can operate on a graph and learn the local interactions.
GNNs take a graph $G=\left(\mathbf{V},\mathbf{E}\right)$ at time $t$ as an input, compute properties and propagate information through the network (a process termed message passing), and output an updated graph $G^\prime=\left(\mathbf{V}^\prime,\mathbf{E}^\prime\right)$ with an identical structure, where $\mathbf{V}^\prime$ and $\mathbf{E}^\prime$ are the sets of updated vertex and edge features ($\mathbf{v}_i^\prime$ and $\mathbf{e}_{i,\ j}^\prime$).
In the balls-in-a-box example, the GNN first takes the original graph $G=\left(\mathbf{V},\mathbf{E}\right)$  (Figure 1b) that describes the current state of the physical system ($\mathbf{X}^t$).
The GNN then updates the state of the physical system through message passing, which models the exchange of energy and momentum between the balls communicating through the edges, and returns an updated graph $G^\prime=\left(\mathbf{V}^\prime,\mathbf{E}^\prime\right)$ (Figure 1c).
After the GNN computation, we may decode $G^\prime$ to extract useful information related to the future state of the physical system ($\mathbf{X}^{t+1}$) such as the next position or acceleration of the balls (Figure 1d).

![balls-in-a-box](https://github.com/chishiki-ai/sciml/blob/main/docs/04-gns/figs/balls-in-a-box.svg?raw=1)
*Figure 1. An example of a graph and a graph neural network (GNN) that processes the graph (modified from Battaglia et al. (2018)):
(a) A state of the current physical system ($\mathbf{X}^t$) where the balls are bouncing in a box boundary;
(b) Graph representation of the physical system ($G$).
There are three vertices representing balls and six edges representing their directional interaction shown as arrows;
(c) The updated graph ($G^\prime$) that GNN outputs through message passing; (d) The predicted future state of the physical system ($\mathbf{X}^{t+1}$) (i.e., the positions of the balls at the next timestep) decoded from the updated graph.*

### Message passing
Message passing consists of three operations: message construction (Eq. 1), message aggregation (Eq. 2), and the vertex update function (Eq. 3).

$$
\begin{equation}
    \mathbf{e}_{i,j}^\prime=\phi_{\mathbf{\Theta}_\phi}\left(\mathbf{v}_i,\mathbf{v}_j,\mathbf{e}_{i,\ j}\right)
\end{equation}
$$

$$
\begin{equation}
    {\bar{\mathbf{v}}}_i=\Sigma_{j \in N\left(i\right)}\ \mathbf{e}_{i,j}^\prime
\end{equation}
$$

$$
\begin{equation}
    \mathbf{v}_i^\prime=\gamma_{\mathbf{\Theta}_\gamma}\left(\mathbf{v}_i,{\bar{\mathbf{v}}}_i\right)
\end{equation}
$$

The subscripts $\mathbf{\Theta}_\phi$ and $\mathbf{\Theta}_\gamma$ represent the sets of learnable parameters in each computation.
The message construction function $\phi_{\Theta_{\phi}}$ (Eq. 1) takes the feature vector of the receiver and sender vertices ($\mathbf{v}_i$ and $\mathbf{v}_j$) and the feature vector of the edge connecting them ($\mathbf{e}_{i,\ j}$) and returns an updated edge feature vector $\mathbf{e}_{i,j}^\prime$ as the output.
$\phi_{\Theta_{\phi}}$ is a matrix operation including the learnable parameter $\mathbf{\Theta}_\phi$.
The updated edge feature vector $\mathbf{e}_{i,j}^\prime$ is the message sent from vertex $j$ to $i$.
Figure 2a shows an example of constructing messages on edges directing to vertex 0 originating from vertices 1, 2, and 3 ($\mathbf{e}_{0,1}^\prime, \mathbf{e}_{0,2}^\prime, \mathbf{e}_{0,3}^\prime$).
Here, we define the message construction function $\phi_{\Theta_{\phi}}$ as $\left(\left(\mathbf{v}_i+\mathbf{v}_j\right)\times\mathbf{e}_{i,j}\right)\times\mathbf{\Theta}_\phi$.
The updated feature vector $\mathbf{e}_{0,\ 1}^\prime$ is computed as $\left(\left(\mathbf{v}_0+\mathbf{v}_1\right)\times\mathbf{e}_{0,1}\right)\times\mathbf{\Theta}_\phi$, where $\mathbf{v}_0$ and $\mathbf{v}_1$ are the receiver and sender vertex feature vectors, and $\mathbf{e}_{0,1}$ is their edge feature vector.
If we assume that all values of $\mathbf{\Theta}_\phi$ are 1.0 for simplicity, we obtain $\mathbf{e}_{0,\ 1}^\prime=\left(\left(\left[1,\ 0,\ 2\right]+\left[1,\ 3,\ 2\right]\right)\times\left[2,\ 1,\ 0\right]\right)\times1=\left[4,\ 3,\ 0\right]$, where $\times$ acts element-wise.
Similarly, we compute the messages $\mathbf{e}_{0,\ 2}^\prime=\left[0,\ 3,\ 9\right]$ and $\mathbf{e}_{0,\ 3}^\prime=\left[3,\ 4,\ 9\right]$.

The next step in message passing is the message aggregation $\Sigma_{j \in N\left(i\right)}$ (Eq. 2), where $N\left(i\right)$ is the set of sender vertices $j$ related to vertex $i$.
It collects all the messages directed to vertex $i$ and aggregates them into a single vector, the aggregated message (${\bar{\mathbf{v}}}_i$), which has the same dimension as each individual message.
The aggregation rule can be element-wise vector summation or averaging, hence it is a permutation invariant computation.
In Figure 2a, the aggregated message ${\bar{\mathbf{v}}}_0=\left[7,\ 10,\ 18\right]$ is the element-wise summation of the messages directed to vertex 0: ${\bar{\mathbf{v}}}_0=\mathbf{e}_{0,\ 1}^\prime+\mathbf{e}_{0,\ 2}^\prime+\mathbf{e}_{0,\ 3}^\prime$.

The final step of the message passing is updating vertex features using Eq. 3.
It takes the aggregated message (${\bar{\mathbf{v}}}_i$) and the current vertex feature vector $\mathbf{v}_i$, and returns an updated vertex feature vector $\mathbf{v}_i^\prime$, using predefined vector operations including the learnable parameter $\mathbf{\Theta}_\gamma$. Figure 2b shows an example of the update at vertex 0.
Here, we define the update function $\gamma_{\Theta_{\gamma}}$ as $\mathbf{\Theta}_\gamma\left(\mathbf{v}_i+{\bar{\mathbf{v}}}_i\right)$.
The updated feature vector $\mathbf{v}_0^\prime$ is computed as $\mathbf{\Theta}_\gamma\left(\mathbf{v}_0+{\bar{\mathbf{v}}}_0\right)$.
Assuming all parameters in $\mathbf{\Theta}_\gamma$ are 1.0 for simplicity, we obtain $\mathbf{v}_0^\prime=\left[1,\ 0,\ 2\right]+\left[7,\ 10,\ 18\right]=\left[8,\ 10,\ 20\right]$. Similarly, we update the other vertex features $(\mathbf{v}_1^\prime, \mathbf{v}_2^\prime, \mathbf{v}_3^\prime)$.
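
The arithmetic of this worked example can be replayed in a few lines of NumPy. The sketch uses only the numbers quoted in the text; the features of vertices 2 and 3 are not given, so their messages are entered directly rather than recomputed. The variable names are illustrative.

```python
import numpy as np

# Values taken from the worked example; all learnable parameters are 1.0.
v0 = np.array([1, 0, 2])   # receiver vertex feature
v1 = np.array([1, 3, 2])   # sender vertex feature
e01 = np.array([2, 1, 0])  # feature of the edge from vertex 1 to vertex 0
theta_phi = 1.0
theta_gamma = 1.0

# Eq. 1: message construction, using element-wise multiplication.
msg_01 = ((v0 + v1) * e01) * theta_phi

# The other two messages are quoted from the text rather than recomputed.
msg_02 = np.array([0, 3, 9])
msg_03 = np.array([3, 4, 9])

# Eq. 2: aggregation by element-wise summation.
v0_bar = msg_01 + msg_02 + msg_03

# Eq. 3: vertex update.
v0_new = theta_gamma * (v0 + v0_bar)

print(msg_01, v0_bar, v0_new)
```

The three printed vectors are $[4, 3, 0]$, $[7, 10, 18]$, and $[8, 10, 20]$, the same values as in the text.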

At the end of the message passing, the graph vertex and edge features ($\mathbf{v}_i$ and $\mathbf{e}_{i,\ j}$) are updated to $\mathbf{v}_i^\prime$ and $\mathbf{e}_{i,\ j}^\prime$.
The GNN may include multiple message passing steps to propagate the information further through the network.

![message_construction](https://github.com/chishiki-ai/sciml/blob/main/docs/04-gns/figs/message_construction.svg?raw=1)
(a)
![update](https://github.com/chishiki-ai/sciml/blob/main/docs/04-gns/figs/message_aggregate.svg?raw=1)
(b)

*Figure 2. An example of message passing on a graph:
(a) message construction directing to receiver vertex 0 $(\mathbf{e}_{0,\ 1}^\prime, \mathbf{e}_{0,\ 2}^\prime, \mathbf{e}_{0,\ 3}^\prime)$ and the resultant aggregated message $({\bar{\mathbf{v}}}_0)$;
(b) feature update at vertex 0 using ${\bar{\mathbf{v}}}_0$. Note that we assume $\mathbf{\Theta}_\phi$ and $\mathbf{\Theta}_\gamma$ are 1.0 for the convenience of calculation.*

In contrast to the example above, where the learnable parameters are fixed at 1.0, supervised training uses an optimization algorithm to find the values of the learnable parameters ($\mathbf{\Theta}_\phi, \mathbf{\Theta}_\gamma$) in the message passing operation that best fit the training data.

## Graph Neural Network-based Simulator (GNS)

In this study, we use GNN as a surrogate simulator to model granular flow behavior.
Figure 3 shows an overview of the general concepts and structure of the GNN-based simulator (GNS).
Consider a granular flow domain represented as particles (Figure 3a).
In GNS, we represent the physical state of the granular domain at time $t$ with a set of $\mathbf{x}_i^t$ describing the state and properties of each particle.
The GNS takes the current state of the granular flow $\mathbf{x}_i^t \in \mathbf{X}^t$ and predicts its next state $\mathbf{x}_i^{t+1} \in \mathbf{X}^{t+1}$ (Figure 3a).
The GNS consists of two components: a parameterized function approximator $d_\theta$ and an updater function (Figure 3b).
The approximator $d_\theta$ takes $\mathbf{X}^t$ as an input and outputs dynamics information $\mathbf{y}_i^t \in \mathbf{Y}^t$.
The updater then computes $\mathbf{X}^{t+1}$ using $\mathbf{Y}^t$ and $\mathbf{X}^t$.
Figure 3c shows the details of $d_\theta$, which consists of an encoder, a processor, and a decoder.
The encoder (Figure 3c-1) takes the state of the system $\mathbf{X}^t$ and embeds it into a latent graph $G_0=\left(\mathbf{V}_0,\ \mathbf{E}_0\right)$ to represent the relationship between particles, where the vertices $\mathbf{v}_i^t \in \mathbf{V}_0$ contain latent information of the current particle state, and the edges $\mathbf{e}_{i,j}^t \in \mathbf{E}_0$ contain latent information of the pair-wise relationship between particles.
Next, the processor (Figure 3c-2) converts $G_0$ to $G_M$ with $M$ stacks of message-passing GNN ($G_0\rightarrow\ G_1\rightarrow\cdots\rightarrow\ G_M$) to compute the interaction between particles.
Finally, the decoder (Figure 3c-3) extracts dynamics of the particles ($\mathbf{Y}^t$) from $G_M$, such as the acceleration of the physical system.
The entire simulation (Figure 3a) involves running the GNS surrogate model through $K$ timesteps, predicting from the initial state $\mathbf{X}^0$ to $\mathbf{X}^K$ ($\mathbf{X}^0,\ \mathbf{X}^1,\ \ldots,\ \mathbf{X}^K$) and updating at each step ($\mathbf{X}^t\rightarrow\mathbf{X}^{t+1}$).

![GNS](https://github.com/chishiki-ai/sciml/blob/main/docs/04-gns/figs/gns_structure.svg?raw=1)
*Figure 3. The structure of the graph neural network (GNN)-based physics simulator (GNS) for granular flow (modified from Sanchez-Gonzalez et al. (2020)):
(a) The entire simulation procedure using the GNS,
(b) The computation procedure of GNS and its composition, (c) The computation procedure of the parameterized function approximator $d_\theta$ and its composition.*
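
The control flow of Figure 3a and 3b can be sketched as a loop. This is only a structural illustration: `d_theta` here is a hypothetical stand-in that returns constant gravity instead of a trained encoder-processor-decoder, and the state layout and the timestep value are assumptions made for the example.

```python
import numpy as np

def d_theta(state):
    """Stand-in for the learned approximator: one acceleration per particle.

    Hypothetical placeholder that applies constant gravity, not a trained GNN.
    """
    num_particles = state["position"].shape[0]
    return np.tile(np.array([0.0, -9.8]), (num_particles, 1))

def updater(state, dynamics, dt):
    """Advance the state by one step from the predicted dynamics."""
    velocity = state["velocity"] + dynamics * dt
    position = state["position"] + velocity * dt
    return {"position": position, "velocity": velocity}

def simulate(initial_state, num_steps, dt):
    """Apply the approximator and the updater K times, keeping every state."""
    states = [initial_state]
    for _ in range(num_steps):
        dynamics = d_theta(states[-1])
        states.append(updater(states[-1], dynamics, dt))
    return states

x0 = {"position": np.zeros((3, 2)), "velocity": np.zeros((3, 2))}
trajectory = simulate(x0, num_steps=5, dt=0.1)
print(len(trajectory))
```

Each new state is computed from the previous predicted state, so a simulation of $K$ steps yields $K+1$ states including the initial one.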

### Input
The input to the GNS, $\mathbf{x}_i^t \in \mathbf{X}^t$, is a vector consisting of the current particle position $\mathbf{p}_i^t$, the particle velocity context ${\dot{\mathbf{p}}}_i^{\le t}$, information on boundaries $\mathbf{b}_i^t$, and particle type embedding ${\mathbf{f}}$ (Eq. 4).
$\mathbf{x}_i^t$ will be used to construct vertex feature ($\mathbf{v}_i^t$) (Eq. 6).

$$
\begin{equation}
    \mathbf{x}_i^t=\left[\mathbf{p}_i^t,{\dot{\mathbf{p}}}_i^{\le t},\mathbf{b}_i^t,\mathbf{f}\right]
\end{equation}
$$

The velocity context ${\dot{\mathbf{p}}}_i^{\le t}$ includes the current and previous particle velocities for $n$ timesteps $\left[{\dot{\mathbf{p}}}_i^{t-n},\cdots,\ {\dot{\mathbf{p}}}_i^t\right]$.
We use $n=4$ to include sufficient velocity context in the vertex feature $\mathbf{x}_i^t$.
Sanchez-Gonzalez et al. (2020) show that having $n>1$ significantly improves the model performance.
The velocities are computed using the finite difference of the position sequence (i.e.,  ${\dot{\mathbf{p}}}_i^t=\left(\mathbf{p}_i^t-\mathbf{p}_i^{t-1}\right)/\Delta t$).
For a 2D problem, $\mathbf{b}_i^t$ has four components, each indicating the distance from the particle to one of the four walls.
We normalize $\mathbf{b}_i^t$ by the connectivity radius, which is explained in the next section, and restrict it to the range from -1.0 to 1.0. $\mathbf{b}_i^t$ is used to evaluate boundary interaction for a particle.
${\mathbf{f}}$ is a vector embedding describing a particle type.
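
The velocity context and boundary terms of Eq. 4 can be assembled as follows. This is a minimal NumPy sketch under stated assumptions: `build_node_input` is a hypothetical helper, the particle type embedding $\mathbf{f}$ is left out, and the positions, box, and radius are made-up values.

```python
import numpy as np

def build_node_input(position_seq, bounds, radius, dt=1.0):
    """Velocity context and clipped boundary distances for each particle.

    position_seq: (num_particles, num_frames, dim) array of recent positions.
    bounds: (dim, 2) array with the lower and upper wall coordinate per axis.
    """
    # Finite-difference velocities between consecutive frames.
    velocity_seq = (position_seq[:, 1:] - position_seq[:, :-1]) / dt
    current = position_seq[:, -1]
    # Distances to the lower and upper walls, scaled by the connectivity radius.
    to_lower = current - bounds[:, 0]
    to_upper = bounds[:, 1] - current
    boundary = np.concatenate([to_lower, to_upper], axis=1) / radius
    boundary = np.clip(boundary, -1.0, 1.0)
    velocity_context = velocity_seq.reshape(len(position_seq), -1)
    return np.concatenate([velocity_context, boundary], axis=1)

# Two particles, six frames, 2D: each particle moves 0.01 per frame along x.
frames = np.arange(6)[None, :, None] * np.array([0.01, 0.0])
start = np.array([[0.10, 0.50], [0.90, 0.50]])[:, None, :]
position_seq = start + frames
bounds = np.array([[0.0, 1.0], [0.0, 1.0]])

features = build_node_input(position_seq, bounds, radius=0.2)
print(features.shape)
```

Six frames give five velocities per particle, so in 2D each row holds 10 velocity components followed by 4 boundary components. Walls farther away than one connectivity radius all map to 1.0 after clipping, so only nearby walls carry information.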

In addition to $\mathbf{x}_i^t$, we define the interaction relationship between particles $i$ and $j$ as $\mathbf{r}_{i,\ j}^t$ using the distance and displacement of the particles in the current timestep (see Eq. 5).
The former reflects the level of interaction, and the latter reflects its spatial direction.
$\mathbf{r}_{i,\ j}^t$ will be used to construct edge features ($\mathbf{e}_{i,j}^t$).

$$
\begin{equation}
    \mathbf{r}_{i,j}^t=\left[(\mathbf{p}_i^t-\mathbf{p}_j^t),||\mathbf{p}_i^t-\mathbf{p}_j^t||\right]
\end{equation}
$$
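
Eq. 5 for a single pair of particles can be written as a short function. `relative_features` is a hypothetical helper name used only for this illustration.

```python
import numpy as np

def relative_features(p_i, p_j):
    """Displacement and distance between particles i and j (Eq. 5)."""
    displacement = p_i - p_j
    distance = np.linalg.norm(displacement)
    return np.concatenate([displacement, [distance]])

r_ij = relative_features(np.array([0.3, 0.4]), np.array([0.0, 0.0]))
print(r_ij)
```

In 2D the result has three components (two for the displacement, one for the distance), which is consistent with `edge_feature_dim=3` in the model code later in this notebook.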

### Encoder
The vertex and edge encoders ($\varepsilon_\Theta^v$ and $\varepsilon_\Theta^e$) convert $\mathbf{x}_i^t$ and $\mathbf{r}_{i,\ j}^t$ into the vertex and edge feature vectors ($\mathbf{v}_i^t$ and $\mathbf{e}_{i,j}^t$) (Eq. 6) and embed them into a latent graph $G_0=\left(\mathbf{V}_0, \mathbf{E}_0\right)$,  $\mathbf{v}_i^t \in \mathbf{V}_0$, $\mathbf{e}_{i,j}^t \in \mathbf{E}_0$.


$$
\begin{equation}
    \mathbf{v}_i^t=\varepsilon_\Theta^v\left(\mathbf{x}_i^t\right),\ \ \mathbf{e}_{i,j}^t=\varepsilon_\Theta^e\left(\mathbf{r}_{i,j}^t\right)
\end{equation}
$$

We use a two-layered 128-dimensional multi-layer perceptron (MLP) for the $\varepsilon_\Theta^v$ and $\varepsilon_\Theta^e$.
The MLP and optimization algorithm search for the best candidate for the parameter set $\Theta$ that estimates a proper way of representing the physical state of the particles and their relationship which will be embedded into $G_0$.

The vertex encoder $\varepsilon_\Theta^v$ uses $\mathbf{x}_i^t$ (Eq. 4) without the current position of the particle ($\mathbf{p}_i^t$), but still with its velocities (${\dot{\mathbf{p}}}_i^{\le t}$), since velocity governs the momentum, and the interaction dynamics are independent of the absolute position of the particles.
Rubanova et al. (2022) confirmed that including position causes poorer model performance.
We only use $\mathbf{p}_i^t$ to predict the next position $\mathbf{p}_i^{t+1}$ based on the predicted velocity ${\dot{\mathbf{p}}}_i^{t+1}$ (Eq. 9).

We consider the interaction between two particles by constructing the edges between them only if vertices are located within a certain distance called connectivity radius $R$ (see the shaded circular area in Figure 3b).
The connectivity radius is a critical hyperparameter that governs how effectively the model learns the local interaction.
$R$ should be large enough to include the local interactions between particles as edges and to capture the global dynamics of the simulation domain.
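
The effect of $R$ on the graph can be seen with a brute-force neighbor search. This is a sketch for illustration only: `radius_edges` is a hypothetical helper, it excludes self-loops, and it scales quadratically with the number of particles. The preprocessing code later in this notebook uses PyG's `radius_graph` instead.

```python
import numpy as np

def radius_edges(positions, radius):
    """Return (receiver, sender) pairs for particles closer than `radius`.

    Brute-force O(N^2) search for illustration; self-loops are excluded here.
    """
    diff = positions[:, None, :] - positions[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    mask = (dist < radius) & ~np.eye(len(positions), dtype=bool)
    receivers, senders = np.nonzero(mask)
    return list(zip(receivers.tolist(), senders.tolist()))

# Three particles on a line at x = 0.0, 0.1, and 0.5.
positions = np.array([[0.0, 0.0], [0.1, 0.0], [0.5, 0.0]])
small = radius_edges(positions, radius=0.15)
large = radius_edges(positions, radius=0.45)
print(small, large)
```

With the smaller radius only the two closest particles are connected and the third is isolated; the larger radius also links the second and third particles, so information can reach every particle through the graph.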

### Processor
The processor performs message passing (Eqs. 1-3) on the initial latent graph ($G_0$) from the encoder $M$ times ($G_0\rightarrow\ G_1\rightarrow\cdots\rightarrow\ G_M$) and returns a final updated graph $G_M$.
We use two-layered 128-dimensional MLPs for both the message construction function $\phi_{\mathbf{\Theta}_\phi}$ and the vertex update function $\gamma_{\mathbf{\Theta}_\gamma}$, and element-wise summation for the message aggregation function $\mathbf{\Sigma}_{j \in N\left(i\right)}$ in Eqs. 1-3.
We set $M$=10 to ensure sufficient message propagation through the network.
These stacked message passing steps model the propagation of information through the network of particles.
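
Why the number of steps matters can be shown on a toy chain of particles. In this sketch (an illustration with a hypothetical parameter-free `propagate` function, not the learned processor), a signal placed on one particle spreads by one hop per message passing step.

```python
import numpy as np

def propagate(signal, edges, steps):
    """Repeat sum-aggregation message passing, keeping each node's own value."""
    for _ in range(steps):
        incoming = np.zeros_like(signal)
        for receiver, sender in edges:
            incoming[receiver] += signal[sender]
        signal = signal + incoming
    return signal

# A chain of five particles, 0 - 1 - 2 - 3 - 4, with edges in both directions.
edges = [(i, i + 1) for i in range(4)] + [(i + 1, i) for i in range(4)]
signal = np.array([1.0, 0.0, 0.0, 0.0, 0.0])  # only particle 0 carries information

# Number of particles that have received any information after m steps.
reached = [int(np.count_nonzero(propagate(signal, edges, m))) for m in range(5)]
print(reached)
```

The count grows by one particle per step, so a particle can only be influenced by particles within $M$ hops of it in the graph.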

### Decoder
The decoder $\delta_\Theta^v$ extracts the dynamics $\mathbf{y}_i^t \in \mathbf{Y}^t$ of the particles from the vertices $\mathbf{v}_i^t$ (Eq. 7) using the final graph $G_M$.
We use a two-layered 128-dimensional MLP for $\delta_\Theta^v$ which learns to extract the relevant particle dynamics from $G_M$.


$$
\begin{equation}
\mathbf{y}_i^t=\delta_\Theta^v\left(\mathbf{v}_i^t\right)
\end{equation}
$$


### Updater
We use the dynamics $\mathbf{y}_i^t$ to predict the velocity and position of the particles at the next timestep (${\dot{\mathbf{p}}}_i^{t+1}$ and  $\mathbf{p}_i^{t+1}$) based on Euler integration (Eq. 8 and Eq. 9), which makes $\mathbf{y}_i^t$ analogous to acceleration  ${\ddot{\mathbf{p}}}_i^t$.

$$
\begin{equation}
{\dot{\mathbf{p}}}_i^{t+1}={\dot{\mathbf{p}}}_i^t+\mathbf{y}_i^t\Delta t
\end{equation}
$$

$$
\begin{equation}
\mathbf{p}_i^{t+1}=\mathbf{p}_i^t+{\dot{\mathbf{p}}}_i^{t+1}\Delta t
\end{equation}
$$

Based on the new particle position and velocity, we update $\mathbf{x}_i^t \in \mathbf{X}^t$ (Eq. 4) to $\mathbf{x}_i^{t+1} \in \mathbf{X}^{t+1}$. The updated physical state $\mathbf{X}^{t+1}$ is then used to predict the position and velocity for the next timestep.

The updater imposes inductive biases on GNS to improve learning efficiency.
GNS does not directly predict the next position from the current position and velocity (i.e., $\mathbf{p}_i^{t+1}=GNS\left(\mathbf{p}_i^t,\ {\dot{\mathbf{p}}}_i^t\right)$), which would require it to learn static and inertial motion as well.
Instead, it uses (1) the inertial prior (Eq. 8), where the prediction of the next velocity ${\dot{\mathbf{p}}}_i^{t+1}$ should be based on the current velocity ${\dot{\mathbf{p}}}_i^t$, and (2) the static prior (Eq. 9), where the prediction of the next position $\mathbf{p}_i^{t+1}$ should be based on the current position $\mathbf{p}_i^t$.
With these priors, GNS does not need to learn static and inertial motion, which are already known, and can focus on learning the dynamics, which are uncertain.
In addition, since the dynamics of particles are not controlled by their absolute position, GNS prediction can be generalizable to other geometric conditions.
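
The two priors can be read directly off Eq. 8 and Eq. 9. The sketch below uses a hypothetical `euler_update` helper and a unit timestep, which is an assumption made here for simplicity; the numbers are invented.

```python
import numpy as np

def euler_update(position, velocity, acceleration, dt=1.0):
    """Eq. 8 then Eq. 9: the new velocity is used to advance the position."""
    next_velocity = velocity + acceleration * dt
    next_position = position + next_velocity * dt
    return next_position, next_velocity

p = np.array([0.5, 0.5])
v = np.array([0.02, 0.0])
zero = np.zeros(2)

# Inertial prior: zero predicted acceleration keeps the velocity unchanged.
p_inertial, v_inertial = euler_update(p, v, zero)
# Static prior: a particle at rest with zero acceleration stays where it is.
p_static, v_static = euler_update(p, zero, zero)
# A nonzero prediction only has to account for the change in velocity.
p_next, v_next = euler_update(p, v, np.array([0.0, -0.01]))
print(p_inertial, p_static, p_next)
```

A model that outputs zero already reproduces uniform motion and rest, so the learned part only has to supply the correction $\mathbf{y}_i^t$.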

In [None]:
import os
import json
import torch
import numpy as np
import torch_geometric as pyg
import torch_scatter  # used by InteractionNetwork.aggregate
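
The classes in this section call a helper named `MLP` that is not defined here. The following is a minimal sketch, assuming the signature `MLP(input_size, output_size, layers)` inferred from those calls; the `hidden_size` argument and the choice of ReLU activations are assumptions, and the helper actually used in the tutorial may differ (for example by adding layer normalization).

```python
import torch

class MLP(torch.nn.Module):
    """Hypothetical MLP matching the call pattern MLP(input_size, output_size, layers)."""

    def __init__(self, input_size, output_size, layers, hidden_size=None):
        super().__init__()
        # Assumption: hidden layers have the same width as the output.
        hidden_size = hidden_size or output_size
        modules = []
        for i in range(layers):
            in_dim = input_size if i == 0 else hidden_size
            out_dim = output_size if i == layers - 1 else hidden_size
            modules.append(torch.nn.Linear(in_dim, out_dim))
            # No activation after the last linear layer.
            if i != layers - 1:
                modules.append(torch.nn.ReLU())
        self.net = torch.nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x)

mlp = MLP(30, 128, 3)
out = mlp(torch.zeros(5, 30))
print(out.shape)
```

With this sketch, a batch of five 30-dimensional node inputs is mapped to five 128-dimensional latent vectors.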

GNN layers such as an Interaction Network (IN) layer can be easily implemented in PyTorch Geometric (PyG). In PyG, a GNN layer is generally implemented as a subclass of the `MessagePassing` class. We follow this convention and define the `InteractionNetwork` class as follows:

In [None]:
class InteractionNetwork(pyg.nn.MessagePassing):
   def __init__(self, hidden_size, layers=3):
       super().__init__()
       self.lin_edge = MLP(hidden_size * 3, hidden_size, layers)
       self.lin_node = MLP(hidden_size * 2, hidden_size, layers)

(1) Construct a message for each edge of the graph. The message is generated by concatenating the features of the edge’s two nodes and the feature of the edge itself, and transforming the concatenated vector with an MLP.

In [None]:
def message(self, x_i, x_j, edge_feature):
    x = torch.cat((x_i, x_j, edge_feature), dim=-1)
    x = self.lin_edge(x)
    return x

(2) Aggregate (sum up) the messages of all the incoming edges for each node.

In [None]:
def aggregate(self, inputs, index):
    out = torch_scatter.scatter(inputs, index, dim=self.node_dim, reduce="sum")
    return (inputs, out)
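
To see what the scatter-sum in `aggregate` computes, here is a NumPy analogue on a tiny example. It is an illustration only, not the PyG code path, and the message values and receiver indices are made up.

```python
import numpy as np

# Four messages (one per edge) with 2 features each, and the receiver of each edge.
messages = np.array([[1.0, 0.0], [2.0, 1.0], [0.5, 0.5], [3.0, 3.0]])
receiver_index = np.array([0, 0, 2, 1])
num_nodes = 3

# Sum every message into the row of its receiver node.
aggregated = np.zeros((num_nodes, messages.shape[1]))
np.add.at(aggregated, receiver_index, messages)
print(aggregated)
```

Node 0 receives two messages and gets their sum; nodes 1 and 2 each receive a single message, which passes through unchanged.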

(3) Update node features and edge features. Each edge’s new feature is the sum of its old feature and the message on the edge. Each node’s new feature is determined by its old feature and the aggregation of messages.

In [None]:
def forward(self, x, edge_index, edge_feature):
    edge_out, aggr = self.propagate(edge_index, x=(x, x), edge_feature=edge_feature)
    node_out = self.lin_node(torch.cat((x, aggr), dim=-1))
    edge_out = edge_feature + edge_out
    node_out = x + node_out
    return node_out, edge_out

Let’s put the encoder, the processor, and the decoder together! Before the GNN layers, an MLP transforms the input features, which improves the expressiveness of the model without adding GNN layers. After the GNN layers, the final outputs (accelerations of particles in our case) are extracted from the features produced by the GNN layers to meet the requirement of the task.

In [None]:
class LearnedSimulator(torch.nn.Module):
   """Graph Network-based Simulators(GNS)"""
   def __init__(
       self,
       hidden_size=128,
       n_mp_layers=10, # number of GNN layers
       node_feature_dim=30,
       edge_feature_dim=3,
       dim=2, # dimension of the world, typically 2D or 3D
   ):
       super().__init__()
       self.node_in = MLP(node_feature_dim, hidden_size, 3)
       self.edge_in = MLP(edge_feature_dim, hidden_size, 3)
       self.node_out = MLP(hidden_size, dim, 3)
       self.layers = torch.nn.ModuleList([InteractionNetwork(hidden_size, 3) for _ in range(n_mp_layers)])

   def forward(self, edge_index, node_feature, edge_feature):
       # encoder
       node_feature = self.node_in(node_feature)
       edge_feature = self.edge_in(edge_feature)
       # processor
       for layer in self.layers:
           node_feature, edge_feature = layer(node_feature, edge_index, edge_feature=edge_feature)
       # decoder
       out = self.node_out(node_feature)
       return out

## Overview

**Before we get started:**

- This notebook includes a concise PyG implementation of the paper *Learning to Simulate Complex Physics with Graph Networks*. We adapted our code from the open-source TensorFlow implementation by DeepMind.
    - Link to the pdf of this paper: https://arxiv.org/abs/2002.09405
    - Link to DeepMind's implementation: https://github.com/deepmind/deepmind-research/tree/master/learning_to_simulate
    - Link to the video site by DeepMind: https://sites.google.com/view/learning-to-simulate
- Make sure to **sequentially run all the cells in each section**, so that the intermediate variables / packages will carry over to the next cell.
- Feel free to make a copy to your own drive to play around with it! Have fun with this tutorial!

## Dataset

The dataset `WaterDropSample` contains simulations of water dropping to the ground, generated with a particle-based physics simulator. The cell below sets the dataset name and the folder it is stored in (`./WaterDropSample`). You can inspect the downloaded files on the **Files** menu on the left of this notebook.

The `metadata.json` file in the dataset includes the following information:
1. The sequence length of each video data point
2. The dimensionality, 2D or 3D
3. The box bounds, which specify the bounding box for the scene
4. The default connectivity radius, which defines the size of each particle's neighborhood
5. The statistics for normalization, such as the mean and standard deviation of the velocity and acceleration of particles
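
The sketch below shows how such statistics are typically applied. The key names are the ones read by the preprocessing code later in this notebook; every numeric value is hypothetical, and `normalize` is an illustrative helper rather than part of the tutorial code.

```python
import math

# Hypothetical values; only the key names are taken from the code in this tutorial.
metadata = {
    "bounds": [[0.1, 0.9], [0.1, 0.9]],
    "default_connectivity_radius": 0.015,
    "vel_mean": [0.0, 0.0],
    "vel_std": [0.003, 0.004],
    "acc_mean": [0.0, 0.0],
    "acc_std": [0.0003, 0.0004],
}

def normalize(value, mean, std, noise_std=0.0):
    """Standardize per axis, widening the std by the training noise level."""
    return [
        (v - m) / math.sqrt(s ** 2 + noise_std ** 2)
        for v, m, s in zip(value, mean, std)
    ]

raw_velocity = [0.003, -0.008]
normalized = normalize(raw_velocity, metadata["vel_mean"], metadata["vel_std"])
print(normalized)
```

A velocity of one standard deviation along x and minus two along y maps to roughly 1 and -2. Passing a nonzero `noise_std` shrinks the normalized values, because the added training noise increases the effective spread of the data.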


Each data point in the dataset includes the following information:
1. The type of the particles, such as water
2. The particle positions at each frame in the video

In [None]:
DATASET_NAME = "WaterDropSample"
OUTPUT_DIR = "./WaterDropSample"

## Data Preprocessing

Since we cannot apply the raw data in the dataset to train the GNN model directly, we need to go through the following steps to convert the raw data into graphs with descriptive node features and edge features:
1. Apply noise to the trajectory to have more diverse training examples
1. Construct the graph based on the distance between particles
1. Extract node-level features: particle velocities and their distance to the boundary
1. Extract edge-level features: displacement and distance between particles

If you are not interested in the data pipeline, you can skip to the end of this section. There is a detailed explanation and visualization of one data point.

In [None]:
import json
import numpy as np
import torch
import torch_geometric as pyg

def generate_noise(position_seq, noise_std):
    """Generate noise for a trajectory"""
    velocity_seq = position_seq[:, 1:] - position_seq[:, :-1]
    time_steps = velocity_seq.size(1)
    velocity_noise = torch.randn_like(velocity_seq) * (noise_std / time_steps ** 0.5)
    velocity_noise = velocity_noise.cumsum(dim=1)
    position_noise = velocity_noise.cumsum(dim=1)
    position_noise = torch.cat((torch.zeros_like(position_noise)[:, 0:1], position_noise), dim=1)
    return position_noise


def preprocess(particle_type, position_seq, target_position, metadata, noise_std):
    """Preprocess a trajectory and construct the graph"""
    # apply noise to the trajectory
    position_noise = generate_noise(position_seq, noise_std)
    position_seq = position_seq + position_noise

    # calculate the velocities of particles
    recent_position = position_seq[:, -1]
    velocity_seq = position_seq[:, 1:] - position_seq[:, :-1]

    # construct the graph based on the distances between particles
    n_particle = recent_position.size(0)
    edge_index = pyg.nn.radius_graph(recent_position, metadata["default_connectivity_radius"], loop=True, max_num_neighbors=n_particle)

    # node-level features: velocity, distance to the boundary
    normal_velocity_seq = (velocity_seq - torch.tensor(metadata["vel_mean"])) / torch.sqrt(torch.tensor(metadata["vel_std"]) ** 2 + noise_std ** 2)
    boundary = torch.tensor(metadata["bounds"])
    distance_to_lower_boundary = recent_position - boundary[:, 0]
    distance_to_upper_boundary = boundary[:, 1] - recent_position
    distance_to_boundary = torch.cat((distance_to_lower_boundary, distance_to_upper_boundary), dim=-1)
    distance_to_boundary = torch.clip(distance_to_boundary / metadata["default_connectivity_radius"], -1.0, 1.0)

    # edge-level features: displacement, distance
    dim = recent_position.size(-1)
    edge_displacement = (torch.gather(recent_position, dim=0, index=edge_index[0].unsqueeze(-1).expand(-1, dim)) -
                   torch.gather(recent_position, dim=0, index=edge_index[1].unsqueeze(-1).expand(-1, dim)))
    edge_displacement /= metadata["default_connectivity_radius"]
    edge_distance = torch.norm(edge_displacement, dim=-1, keepdim=True)

    # ground truth for training
    if target_position is not None:
        last_velocity = velocity_seq[:, -1]
        next_velocity = target_position + position_noise[:, -1] - recent_position
        acceleration = next_velocity - last_velocity
        acceleration = (acceleration - torch.tensor(metadata["acc_mean"])) / torch.sqrt(torch.tensor(metadata["acc_std"]) ** 2 + noise_std ** 2)
    else:
        acceleration = None

    # return the graph with features
    graph = pyg.data.Data(
        x=particle_type,
        edge_index=edge_index,
        edge_attr=torch.cat((edge_displacement, edge_distance), dim=-1),
        y=acceleration,
        pos=torch.cat((velocity_seq.reshape(velocity_seq.size(0), -1), distance_to_boundary), dim=-1)
    )
    return graph
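
The scaling by `time_steps ** 0.5` in `generate_noise` can be checked empirically. The sketch below repeats the same random-walk construction in NumPy for a single coordinate; the noise level, sample count, and seed are arbitrary choices for the illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
noise_std = 3e-4
num_particles, time_steps = 20000, 5

# Per-step velocity noise is scaled by 1/sqrt(time_steps) and accumulated over time.
step_noise = rng.standard_normal((num_particles, time_steps))
step_noise *= noise_std / time_steps ** 0.5
velocity_noise = step_noise.cumsum(axis=1)

# Position noise integrates the velocity noise; the first frame is left clean.
position_noise = np.concatenate(
    [np.zeros((num_particles, 1)), velocity_noise.cumsum(axis=1)], axis=1
)

# Ratio of the measured std of the last velocity noise to the requested std.
ratio = velocity_noise[:, -1].std() / noise_std
print(ratio)
```

Summing `time_steps` independent terms, each with variance `noise_std ** 2 / time_steps`, gives a variance of `noise_std ** 2`, so the ratio should be close to 1: the velocity noise at the last input step has approximately the requested standard deviation.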

## Operation Modes of GNS

The GNS works in two modes: one-step mode and rollout mode. In one-step mode, the GNS always makes predictions with ground-truth inputs. In rollout mode, the GNS predicts positions of particles in the next step based on its own predictions in the previous step. As a result, errors accumulate over time for rollout mode.

![gns-modes](https://github.com/chishiki-ai/sciml/blob/main/docs/04-gns/figs/gns-modes.webp?raw=1)
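
The difference between the two modes shows up even in a one-dimensional toy problem. In this sketch, `biased_model` is a hypothetical imperfect predictor with a constant error of 0.1 per step; it stands in for a trained GNS purely to illustrate how errors behave.

```python
def biased_model(x):
    """Hypothetical imperfect model: the true dynamics is x + 1, the model adds 1.1."""
    return x + 1.1

truth = [float(t) for t in range(6)]  # ground-truth trajectory 0, 1, 2, ...

# One-step mode: every prediction starts from the ground-truth previous state.
one_step_errors = [abs(biased_model(truth[t]) - truth[t + 1]) for t in range(5)]

# Rollout mode: every prediction starts from the model's own previous output.
rollout_errors = []
state = truth[0]
for t in range(5):
    state = biased_model(state)
    rollout_errors.append(abs(state - truth[t + 1]))

print(one_step_errors, rollout_errors)
```

The one-step error stays at about 0.1 for every step, while the rollout error grows by about 0.1 per step because each prediction inherits the error of the previous one.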

### One Step Dataset

Each datapoint in this dataset contains trajectories sliced into short time windows. We use this dataset in the training phase because the history of the particles' states is necessary for the model to make predictions. At the same time, since long-horizon prediction is usually inaccurate and time-consuming, we slice the trajectories into short time windows to improve the performance of the model.

In [None]:
class OneStepDataset(pyg.data.Dataset):
    def __init__(self, data_path, split, window_length=7, noise_std=0.0, return_pos=False):
        super().__init__()

        # load dataset from the disk
        with open(os.path.join(data_path, "metadata.json")) as f:
            self.metadata = json.load(f)
        with open(os.path.join(data_path, f"{split}_offset.json")) as f:
            self.offset = json.load(f)
        self.offset = {int(k): v for k, v in self.offset.items()}
        self.window_length = window_length
        self.noise_std = noise_std
        self.return_pos = return_pos

        self.particle_type = np.memmap(os.path.join(data_path, f"{split}_particle_type.dat"), dtype=np.int64, mode="r")
        self.position = np.memmap(os.path.join(data_path, f"{split}_position.dat"), dtype=np.float32, mode="r")

        for traj in self.offset.values():
            self.dim = traj["position"]["shape"][2]
            break

        # cut particle trajectories according to time slices
        self.windows = []
        for traj in self.offset.values():
            size = traj["position"]["shape"][1]
            length = traj["position"]["shape"][0] - window_length + 1
            for i in range(length):
                desc = {
                    "size": size,
                    "type": traj["particle_type"]["offset"],
                    "pos": traj["position"]["offset"] + i * size * self.dim,
                }
                self.windows.append(desc)

    def len(self):
        return len(self.windows)

    def get(self, idx):
        # load corresponding data for this time slice
        window = self.windows[idx]
        size = window["size"]
        particle_type = self.particle_type[window["type"]: window["type"] + size].copy()
        particle_type = torch.from_numpy(particle_type)
        position_seq = self.position[window["pos"]: window["pos"] + self.window_length * size * self.dim].copy()
        position_seq.resize(self.window_length, size, self.dim)
        position_seq = position_seq.transpose(1, 0, 2)
        target_position = position_seq[:, -1]
        position_seq = position_seq[:, :-1]
        target_position = torch.from_numpy(target_position)
        position_seq = torch.from_numpy(position_seq)

        # construct the graph
        with torch.no_grad():
            graph = preprocess(particle_type, position_seq, target_position, self.metadata, self.noise_std)
        if self.return_pos:
            return graph, position_seq[:, -1]
        return graph
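
The window bookkeeping in `__init__` can be illustrated in isolation. The sketch below is a dependency-free stand-in, not the notebook's code: `window_offsets` is a hypothetical helper that reproduces the `pos` offset formula for a single trajectory stored as a flat array with logical shape `(num_steps, num_particles, dim)`.

```python
def window_offsets(num_steps, num_particles, dim, window_length, base_offset=0):
    """Start offsets (counted in array elements) of every sliding window
    in one trajectory. Hypothetical helper, for illustration only."""
    num_windows = num_steps - window_length + 1
    # Moving the window forward by one frame skips num_particles * dim elements.
    return [base_offset + i * num_particles * dim for i in range(num_windows)]


# A toy trajectory: 10 frames, 3 particles, 2-D positions, windows of 7 frames.
offsets = window_offsets(num_steps=10, num_particles=3, dim=2, window_length=7)
window_span = 7 * 3 * 2  # number of elements read for one window
print(offsets)  # [0, 6, 12, 18]
```

Each window reads `window_length * num_particles * dim` elements, so the last window (starting at 18 and spanning 42 elements) ends exactly at element 60, the end of the toy trajectory.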

### Rollout Dataset

Each data point in this dataset contains the trajectories of particles over 1000 time frames. This dataset is used in the evaluation phase to measure the model's ability to make long-horizon predictions.

In [None]:
class RolloutDataset(pyg.data.Dataset):
    def __init__(self, data_path, split, window_length=7):
        super().__init__()

        # load data from the disk
        with open(os.path.join(data_path, "metadata.json")) as f:
            self.metadata = json.load(f)
        with open(os.path.join(data_path, f"{split}_offset.json")) as f:
            self.offset = json.load(f)
        self.offset = {int(k): v for k, v in self.offset.items()}
        self.window_length = window_length

        self.particle_type = np.memmap(os.path.join(data_path, f"{split}_particle_type.dat"), dtype=np.int64, mode="r")
        self.position = np.memmap(os.path.join(data_path, f"{split}_position.dat"), dtype=np.float32, mode="r")

        for traj in self.offset.values():
            self.dim = traj["position"]["shape"][2]
            break

    def len(self):
        return len(self.offset)

    def get(self, idx):
        traj = self.offset[idx]
        size = traj["position"]["shape"][1]
        time_step = traj["position"]["shape"][0]
        particle_type = self.particle_type[traj["particle_type"]["offset"]: traj["particle_type"]["offset"] + size].copy()
        particle_type = torch.from_numpy(particle_type)
        position = self.position[traj["position"]["offset"]: traj["position"]["offset"] + time_step * size * self.dim].copy()
        position.resize(traj["position"]["shape"])
        position = torch.from_numpy(position)
        data = {"particle_type": particle_type, "position": position}
        return data

### Visualize a graph in the dataset

Each data point in the dataset is a `pyg.data.Data` object which describes a graph. We explain the contents of the first data point, and visualize the graph.

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import networkx as nx

dataset_sample = OneStepDataset(OUTPUT_DIR, "valid", return_pos=True)
graph, position = dataset_sample[0]

print(f"The first item in the valid set is a graph: {graph}")
print(f"This graph has {graph.num_nodes} nodes and {graph.num_edges} edges.")
print(f"Each node is a particle and each edge is the interaction between two particles.")
print(f"Each node has {graph.num_node_features} categorical feature (Data.x), which represents the type of the node.")
print(f"Each node has a {graph.pos.size(1)}-dim feature vector (Data.pos), which represents the positions and velocities of the particle (node) in several frames.")
print(f"Each edge has a {graph.num_edge_features}-dim feature vector (Data.edge_attr), which represents the relative distance and displacement between particles.")
print(f"The model is expected to predict a {graph.y.size(1)}-dim vector for each node (Data.y), which represents the acceleration of the particle.")

# remove directions of edges, because it is a symmetric directed graph.
nx_graph = pyg.utils.to_networkx(graph).to_undirected()
# remove self loops, because every node has a self loop.
nx_graph.remove_edges_from(nx.selfloop_edges(nx_graph))
plt.figure(figsize=(7, 7))
nx.draw(nx_graph, pos={i: tuple(v) for i, v in enumerate(position)}, node_size=50)
plt.show()

## GNN Model

We will walk through the implementation of the GNN model in this section!

### Helper class

We first define a class for the Multi-Layer Perceptron (MLP). This class builds an MLP from a given width and depth. Because MLPs are used in several places in the GNN, this helper class keeps the code cleaner.

In [None]:
import math
import torch_scatter

class MLP(torch.nn.Module):
    """Multi-Layer perceptron"""
    def __init__(self, input_size, hidden_size, output_size, layers, layernorm=True):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        for i in range(layers):
            self.layers.append(torch.nn.Linear(
                input_size if i == 0 else hidden_size,
                output_size if i == layers - 1 else hidden_size,
            ))
            if i != layers - 1:
                self.layers.append(torch.nn.ReLU())
        if layernorm:
            self.layers.append(torch.nn.LayerNorm(output_size))
        self.reset_parameters()

    def reset_parameters(self):
        for layer in self.layers:
            if isinstance(layer, torch.nn.Linear):
                layer.weight.data.normal_(0, 1 / math.sqrt(layer.in_features))
                layer.bias.data.fill_(0)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
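
To make the layer layout concrete, the following sketch lists the `(in_features, out_features)` pair of each `Linear` layer the constructor would create. `mlp_layer_shapes` is a hypothetical helper written for illustration; it mirrors the size selection in the loop above and does not build any modules.

```python
import math


def mlp_layer_shapes(input_size, hidden_size, output_size, layers):
    """Return the (in_features, out_features) of each Linear layer.
    Hypothetical helper that mirrors the size logic of MLP.__init__."""
    shapes = []
    for i in range(layers):
        in_features = input_size if i == 0 else hidden_size
        out_features = output_size if i == layers - 1 else hidden_size
        shapes.append((in_features, out_features))
    return shapes


shapes = mlp_layer_shapes(input_size=3, hidden_size=128, output_size=2, layers=3)
print(shapes)  # [(3, 128), (128, 128), (128, 2)]

# reset_parameters draws each weight with standard deviation 1 / sqrt(in_features)
init_std = [1 / math.sqrt(in_features) for in_features, _ in shapes]
```

A `ReLU` follows every `Linear` layer except the last one, and the optional `LayerNorm` acts on the final `output_size` features.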

### GNN layers

In the following code block, we implement one type of GNN layer named `InteractionNetwork` (IN), which was proposed in the paper *Interaction Networks for Learning about Objects, Relations and Physics*.

For a graph $G$, let the feature of node $i$ be $v_i$, and the feature of edge $(i, j)$ be $e_{i, j}$. There are three stages for IN to generate new features of nodes and edges.

1. **Message generation.** If there is an edge pointing from node $i$ to node $j$, node $i$ sends a message to node $j$. The message carries the information of the edge and its two nodes, so it is generated by the following equation $\mathrm{Msg}_{i,j} = \mathrm{MLP}(v_i, v_j, e_{i,j})$.

1. **Message aggregation.** In this stage, each node of the graph aggregates all the messages that it received into a fixed-size representation. In the IN, aggregation means summing all the incoming messages up, i.e., $\mathrm{Agg}_i=\sum_{(j,i)\in G}\mathrm{Msg}_{j,i}$.

1. **Update.** Finally, we update the features of nodes and edges with the results of the previous stages. For each edge, its new feature is simply the sum of its old feature and the corresponding message, i.e., $e'_{i,j}=e_{i,j}+\mathrm{Msg}_{i,j}$. For each node, the new feature is determined by its old feature and the aggregated message, i.e., $v'_i=v_i+\mathrm{MLP}(v_i, \mathrm{Agg}_i)$.
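
The three stages above can be sketched in plain Python on a tiny graph with scalar features. This is an illustration under simplifying assumptions, not the PyG implementation: `interaction_step` is a hypothetical function, and the two MLPs are replaced by the toy functions `edge_fn` and `node_fn` so that the arithmetic can be followed by hand.

```python
def interaction_step(v, edges, e, edge_fn, node_fn):
    """One Interaction Network step on scalar features.
    v: node features, edges: list of (source, target), e: edge features."""
    # 1. Message generation: one message per directed edge.
    msgs = [edge_fn(v[src], v[dst], e_k) for (src, dst), e_k in zip(edges, e)]
    # 2. Message aggregation: sum the messages arriving at each target node.
    agg = [0.0] * len(v)
    for (_, dst), msg in zip(edges, msgs):
        agg[dst] += msg
    # 3. Update: residual connections on both edges and nodes.
    new_e = [e_k + msg for e_k, msg in zip(e, msgs)]
    new_v = [v_i + node_fn(v_i, agg_i) for v_i, agg_i in zip(v, agg)]
    return new_v, new_e


# Toy stand-ins for the two MLPs (assumptions made for illustration only).
edge_fn = lambda v_src, v_dst, e_k: v_src + v_dst + e_k
node_fn = lambda v_i, agg_i: agg_i

v = [1.0, 2.0, 3.0]
edges = [(0, 1), (2, 1), (1, 0)]  # node 2 receives no message
e = [0.5, 0.5, 1.0]
new_v, new_e = interaction_step(v, edges, e, edge_fn, node_fn)
print(new_v)  # [5.0, 11.0, 3.0]
print(new_e)  # [4.0, 6.0, 5.0]
```

Node 2 has no incoming edge, so its aggregated message is zero and, with this toy `node_fn`, its feature is left unchanged.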

In PyG, GNN layers are implemented as subclasses of `MessagePassing`. We need to override three critical functions to implement our `InteractionNetwork` GNN layer. Each function corresponds to one stage of the GNN layer.

1. `message()` -> message generation

  This function controls how a message is generated on each edge of the graph. It takes three arguments: (1) `x_i`, features of the target nodes, i.e., the nodes that receive the messages; (2) `x_j`, features of the source nodes, i.e., the nodes that send them (this is PyG's naming convention under the default `source_to_target` flow); and (3) `edge_feature`, features of the edges themselves. In the IN, we simply concatenate all these features and generate the messages with an MLP.

1. `aggregate()` -> message aggregation

  This function aggregates messages for nodes. It depends on two arguments: (1) `inputs`, the messages; and (2) `index`, the graph structure. We hand the task of message aggregation over to the function `torch_scatter.scatter` and specify in the argument `reduce` that we want to sum the messages up. Because we want to retain the messages themselves to update edge features, we return both the messages and the aggregated messages.

1. `forward()` -> update

  This function puts everything together. `x` is the node features, `edge_index` is the graph structure, and `edge_feature` is the edge features. The function `MessagePassing.propagate` invokes the functions `message` and `aggregate` for us. Then, we update the node features and edge features and return them.

In [None]:
class InteractionNetwork(pyg.nn.MessagePassing):
    """Interaction Network as proposed in this paper:
    https://proceedings.neurips.cc/paper/2016/hash/3147da8ab4a0437c15ef51a5cc7f2dc4-Abstract.html"""
    def __init__(self, hidden_size, layers):
        super().__init__()
        self.lin_edge = MLP(hidden_size * 3, hidden_size, hidden_size, layers)
        self.lin_node = MLP(hidden_size * 2, hidden_size, hidden_size, layers)

    def forward(self, x, edge_index, edge_feature):
        edge_out, aggr = self.propagate(edge_index, x=(x, x), edge_feature=edge_feature)
        node_out = self.lin_node(torch.cat((x, aggr), dim=-1))
        edge_out = edge_feature + edge_out
        node_out = x + node_out
        return node_out, edge_out

    def message(self, x_i, x_j, edge_feature):
        x = torch.cat((x_i, x_j, edge_feature), dim=-1)
        x = self.lin_edge(x)
        return x

    def aggregate(self, inputs, index, dim_size=None):
        out = torch_scatter.scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce="sum")
        return (inputs, out)
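
For readers unfamiliar with scatter operations, the reduction used in `aggregate` can be mimicked without any library. The function `scatter_sum` below is a hypothetical plain-Python stand-in for the sum reduction, written under the assumption that messages are rows and `index` gives the target node of each row; it is not the `torch_scatter` implementation.

```python
def scatter_sum(messages, index, dim_size):
    """Sum the rows of `messages` into `dim_size` output rows;
    row k is added to output row index[k]."""
    width = len(messages[0])
    out = [[0.0] * width for _ in range(dim_size)]
    for message, target in zip(messages, index):
        for k, value in enumerate(message):
            out[target][k] += value
    return out


messages = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # one row per edge
index = [1, 1, 0]                                 # target node of each edge
aggregated = scatter_sum(messages, index, dim_size=3)
print(aggregated)  # [[5.0, 6.0], [4.0, 6.0], [0.0, 0.0]]
```

Passing `dim_size` explicitly keeps one output row per node, so a node without incoming edges (node 2 here) still gets an all-zero row instead of being dropped.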

### The GNN

Now it's time to stack GNN layers into a GNN. Besides the GNN layers, the GNN contains pre-processing and post-processing blocks. Before the GNN layers, input features are transformed by MLPs so that the expressiveness of the GNN is improved without adding more GNN layers. After the GNN layers, the final outputs (accelerations of particles in our case) are extracted from the features generated by the GNN layers to meet the requirement of the task.

In [None]:
class LearnedSimulator(torch.nn.Module):
    """Graph Network-based Simulators(GNS)"""
    def __init__(
        self,
        hidden_size=128,
        n_mp_layers=10, # number of GNN layers
        num_particle_types=9,
        particle_type_dim=16, # embedding dimension of particle types
        dim=2, # dimension of the world, typically 2D or 3D
        window_size=5, # the model looks into W frames before the frame to be predicted
    ):
        super().__init__()
        self.window_size = window_size
        self.embed_type = torch.nn.Embedding(num_particle_types, particle_type_dim)
        self.node_in = MLP(particle_type_dim + dim * (window_size + 2), hidden_size, hidden_size, 3)
        self.edge_in = MLP(dim + 1, hidden_size, hidden_size, 3)
        self.node_out = MLP(hidden_size, hidden_size, dim, 3, layernorm=False)
        self.n_mp_layers = n_mp_layers
        self.layers = torch.nn.ModuleList([InteractionNetwork(
            hidden_size, 3
        ) for _ in range(n_mp_layers)])

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.embed_type.weight)

    def forward(self, data):
        # pre-processing
        # node feature: combine categorical feature data.x and continuous feature data.pos.
        node_feature = torch.cat((self.embed_type(data.x), data.pos), dim=-1)
        node_feature = self.node_in(node_feature)
        edge_feature = self.edge_in(data.edge_attr)
        # stack of GNN layers
        for i in range(self.n_mp_layers):
            node_feature, edge_feature = self.layers[i](node_feature, data.edge_index, edge_feature=edge_feature)
        # post-processing
        out = self.node_out(node_feature)
        return out
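
The input widths of the two encoders follow directly from the constructor arguments. The sketch below only repeats that arithmetic; `encoder_input_sizes` is a hypothetical helper, and reading `dim + 1` as a displacement vector plus a scalar distance is an interpretation based on the edge-feature description printed in the visualization cell.

```python
def encoder_input_sizes(particle_type_dim=16, dim=2, window_size=5):
    """Input widths of node_in and edge_in, as computed in
    LearnedSimulator.__init__. Hypothetical helper, for illustration only."""
    node_in = particle_type_dim + dim * (window_size + 2)
    edge_in = dim + 1  # assumed: displacement (dim values) + distance (1 value)
    return node_in, edge_in


sizes_2d = encoder_input_sizes()
sizes_3d = encoder_input_sizes(dim=3)
print(sizes_2d)  # (30, 3)
print(sizes_3d)  # (37, 4)
```

With the default arguments, `data.pos` therefore needs 14 columns per node and `data.edge_attr` needs 3 columns per edge for the forward pass to run.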

## Training

Before we start training the model, let's configure the hyperparameters! Since the accessible computation power is limited in a notebook, we will only run 1 epoch of training, which takes about 1.5 hours. Consequently, we won't be able to produce results as accurate as those shown in the original paper in this notebook. Alternatively, we provide a checkpoint from training the model on the entire WaterDrop dataset for 5 epochs, which takes about 14 hours with a GeForce RTX 3080 Ti.

In [None]:
data_path = OUTPUT_DIR
model_path = os.path.join("temp", "models", DATASET_NAME)
rollout_path = os.path.join("temp", "rollouts", DATASET_NAME)

!mkdir -p "$model_path"
!mkdir -p "$rollout_path"

params = {
    "epoch": 1,
    "batch_size": 4,
    "lr": 1e-4,
    "noise": 3e-4,
    "save_interval": 1000,
    "eval_interval": 1000,
    "rollout_interval": 200000,
}

Below are some helper functions for evaluation.

In [None]:
def rollout(model, data, metadata, noise_std):
    device = next(model.parameters()).device
    model.eval()
    window_size = model.window_size + 1
    total_time = data["position"].size(0)
    traj = data["position"][:window_size]
    traj = traj.permute(1, 0, 2)
    particle_type = data["particle_type"]

    for time in range(total_time - window_size):
        with torch.no_grad():
            graph = preprocess(particle_type, traj[:, -window_size:], None, metadata, 0.0)
            graph = graph.to(device)
            acceleration = model(graph).cpu()
            acceleration = acceleration * torch.sqrt(torch.tensor(metadata["acc_std"]) ** 2 + noise_std ** 2) + torch.tensor(metadata["acc_mean"])

            recent_position = traj[:, -1]
            recent_velocity = recent_position - traj[:, -2]
            new_velocity = recent_velocity + acceleration
            new_position = recent_position + new_velocity
            traj = torch.cat((traj, new_position.unsqueeze(1)), dim=1)

    return traj


def oneStepMSE(simulator, dataloader, metadata, noise):
    """Returns two values, loss and MSE"""
    total_loss = 0.0
    total_mse = 0.0
    batch_count = 0
    simulator.eval()
    with torch.no_grad():
        scale = torch.sqrt(torch.tensor(metadata["acc_std"]) ** 2 + noise ** 2).cuda()
        for data in dataloader:  # iterate over the loader passed in, not the global valid_loader
            data = data.cuda()
            pred = simulator(data)
            mse = ((pred - data.y) * scale) ** 2
            mse = mse.sum(dim=-1).mean()
            loss = ((pred - data.y) ** 2).mean()
            total_mse += mse.item()
            total_loss += loss.item()
            batch_count += 1
    return total_loss / batch_count, total_mse / batch_count


def rolloutMSE(simulator, dataset, noise):
    total_loss = 0.0
    batch_count = 0
    simulator.eval()
    with torch.no_grad():
        for rollout_data in dataset:
            rollout_out = rollout(simulator, rollout_data, dataset.metadata, noise)
            rollout_out = rollout_out.permute(1, 0, 2)
            loss = (rollout_out - rollout_data["position"]) ** 2
            loss = loss.sum(dim=-1).mean()
            total_loss += loss.item()
            batch_count += 1
    return total_loss / batch_count
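
Each iteration of `rollout` is a de-normalization of the predicted acceleration followed by one integration step with a unit time step. The scalar sketch below isolates that arithmetic for a single coordinate of one particle; `integrate_step` and the numbers passed to it are made up for illustration.

```python
import math


def integrate_step(prev_pos, pos, norm_acc, acc_mean, acc_std, noise_std):
    """Advance one coordinate by one frame from a normalized acceleration.
    Hypothetical scalar version of the update inside rollout()."""
    # Undo the normalization applied to the training targets.
    scale = math.sqrt(acc_std ** 2 + noise_std ** 2)
    acc = norm_acc * scale + acc_mean
    # Velocity is the finite difference of the last two positions.
    vel = pos - prev_pos
    new_vel = vel + acc       # update the velocity first ...
    return pos + new_vel      # ... then the position with the new velocity


new_pos = integrate_step(prev_pos=0.0, pos=1.0, norm_acc=2.0,
                         acc_mean=0.5, acc_std=3.0, noise_std=4.0)
print(new_pos)  # 12.5
```

Because every new position is fed back as input for the next frame, prediction errors accumulate over the rollout, which is why the rollout MSE is tracked separately from the one-step MSE.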

Here is the main training loop!

In [None]:
from tqdm import tqdm

def train(params, simulator, train_loader, valid_loader, valid_rollout_dataset):
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(simulator.parameters(), lr=params["lr"])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1 ** (1 / 5e6))

    # recording loss curve
    train_loss_list = []
    eval_loss_list = []
    onestep_mse_list = []
    rollout_mse_list = []
    total_step = 0

    for i in range(params["epoch"]):
        simulator.train()
        progress_bar = tqdm(train_loader, desc=f"Epoch {i}")
        total_loss = 0
        batch_count = 0
        for data in progress_bar:
            optimizer.zero_grad()
            data = data.cuda()
            pred = simulator(data)
            loss = loss_fn(pred, data.y)
            loss.backward()
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
            batch_count += 1
            progress_bar.set_postfix({"loss": loss.item(), "avg_loss": total_loss / batch_count, "lr": optimizer.param_groups[0]["lr"]})
            total_step += 1
            train_loss_list.append((total_step, loss.item()))

            # evaluation
            if total_step % params["eval_interval"] == 0:
                simulator.eval()
                eval_loss, onestep_mse = oneStepMSE(simulator, valid_loader, valid_loader.dataset.metadata, params["noise"])
                eval_loss_list.append((total_step, eval_loss))
                onestep_mse_list.append((total_step, onestep_mse))
                tqdm.write(f"\nEval: Loss: {eval_loss}, One Step MSE: {onestep_mse}")
                simulator.train()

            # do rollout on valid set
            if total_step % params["rollout_interval"] == 0:
                simulator.eval()
                rollout_mse = rolloutMSE(simulator, valid_rollout_dataset, params["noise"])
                rollout_mse_list.append((total_step, rollout_mse))
                tqdm.write(f"\nEval: Rollout MSE: {rollout_mse}")
                simulator.train()

            # save model
            if total_step % params["save_interval"] == 0:
                torch.save(
                    {
                        "model": simulator.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "scheduler": scheduler.state_dict(),
                    },
                    os.path.join(model_path, f"checkpoint_{total_step}.pt")
                )
    return train_loss_list, eval_loss_list, onestep_mse_list, rollout_mse_list
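
The scheduler's `gamma` is easier to interpret with numbers. Because `scheduler.step()` is called once per batch, the learning rate after `n` steps is `lr * gamma ** n`; with `gamma = 0.1 ** (1 / 5e6)` it shrinks by a factor of 10 every 5 million steps. The helper `lr_at_step` below is hypothetical and only evaluates that closed form.

```python
def lr_at_step(step, base_lr=1e-4, decay_steps=5e6, decay_factor=0.1):
    """Learning rate after `step` optimizer steps of exponential decay.
    Hypothetical helper that evaluates base_lr * gamma ** step."""
    gamma = decay_factor ** (1 / decay_steps)
    return base_lr * gamma ** step


for step in (0, 100_000, 5_000_000):
    print(step, lr_at_step(step))
```

After 100,000 steps the rate is still roughly 95% of its initial value, so the decay mainly matters for runs much longer than the single epoch configured above.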

Finally, let's load the dataset and train the model! It takes roughly 1.5 hours to run this block of the notebook with the default parameters. **If you are impatient, we highly recommend skipping the next 2 blocks and loading the checkpoint we provide to save some time; otherwise, make a cup of tea/coffee and come back later to see the results of training!**

In [None]:
# Training the model is time-consuming. We highly recommend skipping this block and loading the checkpoint in the next block.

# load dataset
train_dataset = OneStepDataset(data_path, "train", noise_std=params["noise"])
valid_dataset = OneStepDataset(data_path, "valid", noise_std=params["noise"])
train_loader = pyg.loader.DataLoader(train_dataset, batch_size=params["batch_size"], shuffle=True, pin_memory=True, num_workers=2)
valid_loader = pyg.loader.DataLoader(valid_dataset, batch_size=params["batch_size"], shuffle=False, pin_memory=True, num_workers=2)
valid_rollout_dataset = RolloutDataset(data_path, "valid")

# build model
simulator = LearnedSimulator()
simulator = simulator.cuda()

# train the model
train_loss_list, eval_loss_list, onestep_mse_list, rollout_mse_list = train(params, simulator, train_loader, valid_loader, valid_rollout_dataset)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

# visualize the loss curve
plt.figure()
plt.plot(*zip(*train_loss_list), label="train")
plt.plot(*zip(*eval_loss_list), label="valid")
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.title('Loss')
plt.legend()
plt.show()

In [None]:
!mkdir -p temp/models/WaterDropSample

Load the checkpoint trained by us. Do **not** run this block if you have trained your model in the previous block.

In [None]:
simulator = LearnedSimulator()
simulator = simulator.cuda()

checkpoint = torch.load("WaterDropSample/checkpoint_100000.pt")
simulator.load_state_dict(checkpoint["model"])

## Visualization

Since the video is 1000 frames long, the rollout might take a few minutes.

In [None]:
rollout_dataset = RolloutDataset(data_path, "valid")
simulator.eval()
rollout_data = rollout_dataset[0]
rollout_out = rollout(simulator, rollout_data, rollout_dataset.metadata, params["noise"])
rollout_out = rollout_out.permute(1, 0, 2)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML

TYPE_TO_COLOR = {
    3: "black",
    0: "green",
    7: "magenta",
    6: "gold",
    5: "blue",
}


def visualize_prepare(ax, particle_type, position, metadata):
    bounds = metadata["bounds"]
    ax.set_xlim(bounds[0][0], bounds[0][1])
    ax.set_ylim(bounds[1][0], bounds[1][1])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect(1.0)
    points = {type_: ax.plot([], [], "o", ms=2, color=color)[0] for type_, color in TYPE_TO_COLOR.items()}
    return ax, position, points


def visualize_pair(particle_type, position_pred, position_gt, metadata):
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    plot_info = [
        visualize_prepare(axes[0], particle_type, position_gt, metadata),
        visualize_prepare(axes[1], particle_type, position_pred, metadata),
    ]
    axes[0].set_title("Ground truth")
    axes[1].set_title("Prediction")

    plt.close()

    def update(step_i):
        outputs = []
        for _, position, points in plot_info:
            for type_, line in points.items():
                mask = particle_type == type_
                line.set_data(position[step_i, mask, 0], position[step_i, mask, 1])
                outputs.append(line)  # collect every updated artist, not only the last one per axis
        return outputs

    return animation.FuncAnimation(fig, update, frames=np.arange(0, position_gt.size(0)), interval=10, blit=True)

anim = visualize_pair(rollout_data["particle_type"], rollout_out, rollout_data["position"], rollout_dataset.metadata)
HTML(anim.to_jshtml())