In [1]:
%load_ext autoreload
%autoreload 2

from pathlib import Path

import numpy as np

from tqdm.notebook import tqdm
import logging
import pickle

from itertools import count
from graphorge.gnn_base_model.data.graph_data import GraphData
from graphorge.gnn_base_model.data.graph_dataset import GNNGraphDataset
from graphorge.gnn_base_model.train.training import train_model

import torch





from utils import *

# Configure the root logger so that INFO-level (and more severe) records
# are not filtered out by the logger itself.
logger = logging.getLogger(name=None)
logger.setLevel(level=logging.INFO)

# Step 1: Raw data

The dataset is available from [DeepMind's MeshGraphNets repository](https://github.com/google-deepmind/deepmind-research/tree/master/meshgraphnets) and was introduced in [^1]. It describes the turbulent flow of water around a cylindrical obstacle. Each sample is simulated with COMSOL on an irregular 2D triangular mesh over 600 time steps, with a time step size of 0.01 seconds.

[^1]: Learning Mesh-Based Simulation with Graph Networks, https://doi.org/10.48550/arXiv.2010.03409

In [None]:
# Root URL of the cylinder flow dataset on Google Cloud Storage
base_url = "https://storage.googleapis.com/dm-meshgraphnets/cylinder_flow"

# Fetch the dataset description and the validation split into `data`.
# NOTE(review): only the validation split is fetched here, while the graph
# generation cell also looks for `train` and `test` raw samples - confirm
# how those splits are expected to be obtained.
for file_name in ("meta.json", "valid.tfrecord"):
    download_file(file=file_name, base_url=base_url, dest_path="data")

In [3]:
# Parse the downloaded validation TFRecord found in `data`.
# NOTE(review): presumably this writes the per-sample pickle files
# (`data/valid/valid_sample_*.pkl`) read by the graph generation cell -
# confirm against `utils.parse_tensorflow_dataset`.
parse_tensorflow_dataset(
    dataset_name="valid",
    dataset_directory="data",
)

2025-04-07 19:41:13.495567: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-07 19:41:13.549141: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-04-07 19:41:15.557545: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


FileNotFoundError: Dataset 'valid' not found in /home/guillaume/Documents/code/graphorge_fork/benchmarks/cfd/data.

# Step 2: Graph dataset

In [None]:
# Output directory of each dataset split (created if missing)
dataset_directories = dict(
    train="1_training_dataset",
    valid="2_validation_dataset",
    test="5_testing_id_dataset",
)

# Feature and target names/shapes stored in the metadata of every graph.
# They are identical for all graphs, hence defined once up front.
edge_feature_names = ("distance", "distance_norm")
edge_feature_shapes = ((2,), (1,))
node_feature_names = ("velocity", "node_type_one_hot")
node_feature_shapes = ((2,), (4,))
node_target_names = ("velocity_update", "pressure")
node_target_shapes = ((2,), (1,))

# Largest node index that can be stored once edge indexes are cast to int16
max_int16_index = torch.iinfo(torch.int16).max

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Iterate over the datasets
for dataset_name, dataset_directory in dataset_directories.items():

    # Prepare a directory to store the dataset
    dataset_directory = Path(dataset_directory)
    dataset_directory.mkdir(parents=True, exist_ok=True)

    # Initialize a list to store the graph file paths
    graph_file_paths = []

    # Locate the raw data downloaded from Google's server and parsed
    raw_data_directory = Path("data") / dataset_name

    # Search the raw data directory for sample files. The paths are sorted by
    # sample id so that the order of the graphs in the dataset does not depend
    # on the (arbitrary) order in which the file system lists the files.
    sample_paths = sorted(
        raw_data_directory.glob(f"{dataset_name}_sample_*.pkl"),
        key=lambda path: int(path.stem.split("_")[-1]),
    )

    # Do not build an empty dataset when the raw data of a split is missing
    if not sample_paths:
        logger.warning(
            "No raw samples found in '%s': skipping the '%s' dataset.",
            raw_data_directory,
            dataset_name,
        )
        continue

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Iterate over the sample files
    for sample_path in tqdm(sample_paths, desc="Generating graphs: "):

        # Extract the sample id from the file name
        sample_id = int(sample_path.stem.split("_")[-1])

        # Load the sample.
        # NOTE: pickle can execute arbitrary code on load; only load sample
        # files generated locally by the parsing step of this notebook.
        with open(sample_path, "rb") as file:
            sample = pickle.load(file)

        # The mesh is fixed over time: node coordinates are shared by all
        # the graphs generated from this sample
        mesh_pos = sample["mesh_pos"]

        # Edge indexes are later stored as int16: fail loudly instead of
        # silently wrapping around for meshes with too many nodes
        if len(mesh_pos) - 1 > max_int16_index:
            raise ValueError(
                f"Sample {sample_id} of dataset '{dataset_name}' has "
                f"{len(mesh_pos)} nodes, which cannot be indexed with int16."
            )

        # Initialize a list to store edge indexes
        edge_indexes = []

        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Iterate over the cells defined in the sample
        for cell in sample["cells"]:

            # Generate pairs of consecutive vertex indexes which define the
            # edges of the cell
            edge_indexes.extend(zip(cell[:-1], cell[1:]))

            # Add the pair of vertices closing the cell
            edge_indexes.append(sorted((cell[-1], cell[0])))

        # Cast as a numpy array as required by the GraphData class
        edge_indexes = np.asarray(edge_indexes, dtype=int)

        # Remove duplicates and ensure edges are undirected
        edge_indexes = GraphData.get_undirected_unique_edges(
            edges_indexes=edge_indexes
        )

        # Since the mesh is fixed, we can extract geometry features
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Get the node type values. We follow the original meshgraphnet
        # implementation and encode the node types as one-hot vectors.
        # Node types take 0, 4, 5 and 6 values. They should be postprocessed to
        # take 0, 1, 2 and 3 values for compatibility with the one_hot encoding
        node_types = np.maximum(sample["node_type"] - 3, 0)

        # Encode the node types as one-hot vectors (one row per node)
        node_type_one_hot = torch.nn.functional.one_hot(
            torch.as_tensor(node_types.reshape(-1), dtype=torch.long),
            num_classes=4,
        ).numpy()

        # Calculate the edge length vectors
        distance_vector = (
            mesh_pos[edge_indexes[:, 0]] - mesh_pos[edge_indexes[:, 1]]
        )

        # Calculate the edge length (Euclidean) norms
        distance_norm = np.linalg.norm(distance_vector, axis=1, keepdims=True)

        # Prepare the edge features
        edge_features = np.hstack((distance_vector, distance_norm))

        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Iterate over the time steps. The time steps are not explicitly stored
        # in the sample, instead the first data dimension is the time.
        # The meshgraphnet implementation predicts a velocity update based on
        # the velocity at the previous time step. Hence, a prediction cannot be
        # made for the first time step. The total number of graphs is the
        # number of time steps minus one, and time step ids start at 1.
        velocity_pairs = zip(
            sample["velocity"][:-1],
            sample["velocity"][1:],
            sample["pressure"][1:],
        )
        for time_step_id, (
            initial_velocity,
            updated_velocity,
            updated_pressure,
        ) in enumerate(velocity_pairs, start=1):

            # Initialize the graph data
            graph_data = GraphData(n_dim=2, nodes_coords=mesh_pos)

            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Prepare the node features
            node_features = np.hstack((initial_velocity, node_type_one_hot))

            # Set node features matrix
            graph_data.set_node_features_matrix(
                node_features_matrix=node_features
            )

            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Set graph edges, uniqueness has already been checked
            graph_data.set_graph_edges_indexes(
                edges_indexes_mesh=edge_indexes, is_unique=False
            )

            # Set edge features
            graph_data.set_edge_features_matrix(
                edge_features_matrix=edge_features
            )

            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Compute the velocity update
            velocity_update = updated_velocity - initial_velocity

            # Prepare the node targets
            node_targets = np.hstack((velocity_update, updated_pressure))

            # Set node targets matrix
            graph_data.set_node_targets_matrix(
                node_targets_matrix=node_targets
            )

            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Prepare the graph metadata
            metadata_dict = dict(
                dataset_name=dataset_name,
                sample_id=sample_id,
                time_step_id=time_step_id,
                edge_features=edge_feature_names,
                edge_features_shapes=edge_feature_shapes,
                node_features=node_feature_names,
                node_features_shapes=node_feature_shapes,
                node_targets=node_target_names,
                node_targets_shapes=node_target_shapes,
            )

            # Set the graph metadata
            graph_data.set_metadata(metadata=metadata_dict)

            # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            # Extract the pyg graph data object
            pyg_graph = graph_data.get_torch_data_object()

            # Cast edge indexes to int16, save memory.
            # NOTE(review): message passing usually expects int64 edge
            # indexes - confirm they are cast back before the model uses them.
            pyg_graph.edge_index = pyg_graph.edge_index.to(dtype=torch.int16)

            # Generate the graph file path
            graph_file_path = (
                dataset_directory
                / f"{dataset_name}_graph_{sample_id}_{time_step_id}.pt"
            )

            # Save the graph to file
            torch.save(pyg_graph, graph_file_path)

            # Append the graph file path to the list of graph file paths
            graph_file_paths.append(graph_file_path)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Generate the GNN-based data set, use `is_store_dataset=False` as the
    # graphs are already stored as individual files
    dataset = GNNGraphDataset(
        dataset_directory=dataset_directory,
        dataset_sample_files=graph_file_paths,
        dataset_basename=f"meshgraphnet_{dataset_name}",
        is_store_dataset=False,
    )

    # Save the dataset to file
    _ = dataset.save_dataset(is_append_n_sample=True)

In [None]:
# Load the graph of the last time step (599) of validation sample 0.
# The file stores a full pickled graph data object rather than a plain state
# dict, so `weights_only=False` is required: since PyTorch 2.6 the default
# (`weights_only=True`) refuses to unpickle arbitrary objects. This is only
# safe because the file was generated locally by this notebook.
graph = torch.load(
    "2_validation_dataset/valid_graph_0_599.pt", weights_only=False
)

# Convert the graph to a PyVista mesh and display the velocity field
mesh = graph_to_pyvista_mesh(graph)
mesh.plot(
    jupyter_backend="html",
    show_edges=True,
    cpos="xy",
    preference="point",
    scalars="velocity",
)

EmbeddableWidget(value='<iframe srcdoc="<!DOCTYPE html>\n<html>\n  <head>\n    <meta http-equiv=&quot;Content-…

In [None]:
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Locate the dataset files generated in the graph dataset step. The file
# names carry the number of samples as a suffix, hence the glob pattern.
train_dataset_directory = Path("1_training_dataset")
train_dataset_file_paths = sorted(
    train_dataset_directory.glob("meshgraphnet_train_*.pkl")
)
if not train_dataset_file_paths:
    raise FileNotFoundError(
        f"No training dataset file found in '{train_dataset_directory}'. "
        "Generate the graph datasets first."
    )
train_dataset_file_path = train_dataset_file_paths[0]

validation_dataset_directory = Path("2_validation_dataset")
validation_dataset_file_paths = sorted(
    validation_dataset_directory.glob("meshgraphnet_valid_*.pkl")
)
if not validation_dataset_file_paths:
    raise FileNotFoundError(
        "No validation dataset file found in "
        f"'{validation_dataset_directory}'. Generate the graph datasets "
        "first."
    )
validation_dataset_file_path = validation_dataset_file_paths[0]

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Load training and validation datasets
train_dataset = GNNGraphDataset.load_dataset(train_dataset_file_path)
validation_dataset = GNNGraphDataset.load_dataset(validation_dataset_file_path)


# Set the GNN model parameters
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
gnn_architecture_parameters = model_init_args = dict(
    # Set number of node input and output features. They must match the
    # generated graphs: inputs are velocity (2) + one-hot node type (4),
    # targets are velocity update (2) + pressure (1).
    n_node_in=6,
    n_node_out=3,
    n_time_node=0,
    # Set number of edge input and output features. Inputs are the edge
    # distance vector (2) + its norm (1); there are no edge targets.
    n_edge_in=3,
    n_edge_out=0,
    n_time_edge=0,
    # Set number of global input and output features (the generated graphs
    # carry neither global features nor global targets)
    n_global_in=0,
    n_global_out=0,
    n_time_global=0,
    # Set number of message-passing steps (number of processor layers)
    n_message_steps=2,
    # Set number of FNN/RNN hidden layers
    enc_n_hidden_layers=3,
    pro_n_hidden_layers=5,
    dec_n_hidden_layers=3,
    # Set hidden layer size
    hidden_layer_size=128,
    # Set model directory
    model_directory="3_model",
    model_name="meshgraphnet",
    # Set model input and output features normalization
    is_model_in_normalized=True,
    is_model_out_normalized=True,
    # Set aggregation schemes
    pro_edge_to_node_aggr="add",
    pro_node_to_global_aggr="mean",
    # Set activation functions
    enc_node_hidden_activ_type="tanh",
    enc_node_output_activ_type="identity",
    enc_edge_hidden_activ_type="tanh",
    enc_edge_output_activ_type="identity",
    pro_node_hidden_activ_type="tanh",
    pro_node_output_activ_type="identity",
    pro_edge_hidden_activ_type="tanh",
    pro_edge_output_activ_type="identity",
    dec_node_hidden_activ_type="tanh",
    dec_node_output_activ_type="identity",
    # Set device
    device_type="cuda" if torch.cuda.is_available() else "cpu",
)

# Make sure the model directory exists before the model is initialized
# (harmless if the library creates it itself)
Path(model_init_args["model_directory"]).mkdir(parents=True, exist_ok=True)

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Compute exponential decay (learning rate scheduler) such that the learning
# rate goes from `lr_init` to `lr_end` over `n_max_epochs` epochs
n_max_epochs = 250
lr_init = 1.0e-03
lr_end = 1.0e-5
gamma = (lr_end / lr_init) ** (1 / n_max_epochs)

training_parameters = dict(
    # Set number of epochs
    n_max_epochs=n_max_epochs,
    # Set batch size
    batch_size=16,
    # Set optimizer
    opt_algorithm="adam",
    # Set learning rate
    lr_init=lr_init,
    # Set learning rate scheduler
    lr_scheduler_type="explr",
    lr_scheduler_kwargs=dict(gamma=gamma),
    # Set loss function: the targets stored in the graphs are node targets
    loss_nature="node_features_out",
    loss_type="mse",
    loss_kwargs=dict(),
    # Set data shuffling
    is_sampler_shuffle=False,
    # Set early stopping
    is_early_stopping=True,
    # Set early stopping parameters
    early_stopping_kwargs=dict(
        validation_dataset=validation_dataset,
        validation_frequency=1,
        trigger_tolerance=20,
        improvement_tolerance=1e-2,
    ),
    # Set seed
    seed=42,
)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Set model state loading (None: train from scratch)
load_model_state = None
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Training of GNN-based model. The model architecture parameters must be
# handed over explicitly: `train_model` builds the model from them.
model, _, _ = train_model(
    dataset=train_dataset,
    model_init_args=model_init_args,
    load_model_state=load_model_state,
    device_type=model_init_args["device_type"],
    **training_parameters,
)