In [1]:
%load_ext autoreload
%autoreload 2

In [2]:

import sys
import os
base_path = '../../'
sys.path.append(base_path)
from src.dgl_flow_field_dataset import DGLSurfaceFlowFieldDataset
from src.pyvista_flow_field_dataset import PyvistaFlowFieldDataset
from modulus.models.meshgraphnet import MeshGraphNet

In [3]:
ds_pv = PyvistaFlowFieldDataset(os.path.join(base_path,'datasets/Example cgns volume + surface'))
ds_dgl = DGLSurfaceFlowFieldDataset(os.path.join(base_path,'datasets/dgl_surface'),ds_pv)

In [4]:
ds_dgl[0]

Graph(num_nodes=21293, num_edges=83778,
      ndata_schemes={'BodyID': Scheme(shape=(), dtype=torch.int32), 'CellArea': Scheme(shape=(), dtype=torch.float32), 'Normal': Scheme(shape=(3,), dtype=torch.float32), 'ShearStress': Scheme(shape=(3,), dtype=torch.float32), 'Position': Scheme(shape=(3,), dtype=torch.float32), 'Temperature': Scheme(shape=(), dtype=torch.float32), 'Pressure': Scheme(shape=(), dtype=torch.float32)}
      edata_schemes={'dx': Scheme(shape=(3,), dtype=torch.float32)})

In [5]:
from dgl.dataloading import GraphDataLoader
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataloader = GraphDataLoader(ds_dgl, batch_size=1, shuffle=True)
for batch in dataloader:
    print('New batch')
    print(batch)
    

New batch
Graph(num_nodes=22337, num_edges=87616,
      ndata_schemes={'BodyID': Scheme(shape=(), dtype=torch.int32), 'CellArea': Scheme(shape=(), dtype=torch.float32), 'Normal': Scheme(shape=(3,), dtype=torch.float32), 'ShearStress': Scheme(shape=(3,), dtype=torch.float32), 'Position': Scheme(shape=(3,), dtype=torch.float32), 'Temperature': Scheme(shape=(), dtype=torch.float32), 'Pressure': Scheme(shape=(), dtype=torch.float32)}
      edata_schemes={'dx': Scheme(shape=(3,), dtype=torch.float32)})
New batch
Graph(num_nodes=22292, num_edges=87534,
      ndata_schemes={'BodyID': Scheme(shape=(), dtype=torch.int32), 'CellArea': Scheme(shape=(), dtype=torch.float32), 'Normal': Scheme(shape=(3,), dtype=torch.float32), 'ShearStress': Scheme(shape=(3,), dtype=torch.float32), 'Position': Scheme(shape=(3,), dtype=torch.float32), 'Temperature': Scheme(shape=(), dtype=torch.float32), 'Pressure': Scheme(shape=(), dtype=torch.float32)}
      edata_schemes={'dx': Scheme(shape=(3,), dtype=torch.float

In [21]:
import dgl


def get_node_edge_X(graph: dgl.DGLGraph):
    node_X = torch.cat([graph.ndata["Position"], graph.ndata["Normal"]],dim=1)
    edge_X = torch.cat([graph.edata["dx"]],dim=1)
    return node_X, edge_X

def get_node_Y(graph: dgl.DGLGraph):
    return torch.cat([graph.ndata["Pressure"].unsqueeze(1),graph.ndata["Temperature"].unsqueeze(1),graph.ndata['ShearStress']],dim=1)

def set_graph_features(graph: dgl.DGLGraph, node_X, edge_X, node_Y):
    graph.ndata["Position"] = node_X[:,:3]
    graph.ndata["Normal"] = node_X[:,3:]
    graph.edata["dx"] = edge_X
    graph.ndata["Pressure"] = node_Y[:,0]
    graph.ndata["Temperature"] = node_Y[:,1]
    graph.ndata["ShearStress"] = node_Y[:,2:]
g=ds_dgl[0]
g_cp=g.clone()
ndx, edx = get_node_edge_X(g)
ndy = get_node_Y(g)
set_graph_features(g_cp, ndx, edx, ndy)
assert torch.allclose(g_cp.ndata["Position"], g.ndata["Position"])
assert torch.allclose(g_cp.ndata["Normal"], g.ndata["Normal"])
assert torch.allclose(g_cp.edata["dx"], g.edata["dx"])
assert torch.allclose(g_cp.ndata["Pressure"], g.ndata["Pressure"])
assert torch.allclose(g_cp.ndata["Temperature"], g.ndata["Temperature"])
assert torch.allclose(g_cp.ndata["ShearStress"], g.ndata["ShearStress"])
num_node_features = ndx.shape[1]
num_edge_features = edx.shape[1]
num_node_labels = ndy.shape[1]
print("Node X: ",ndx, ndx.shape)
print("Edge X: ",edx, edx.shape)
print("Node Y: ",ndy, ndy.shape)

Node X:  tensor([[ 1.2582e+00, -1.7498e+00, -1.0662e+00, -2.0974e-03, -3.8745e-03,
         -1.0238e+00],
        [ 1.2585e+00, -1.6781e+00, -1.0662e+00, -2.0974e-03, -3.8745e-03,
         -1.0238e+00],
        [ 1.2408e+00, -1.7498e+00, -1.0662e+00, -2.0974e-03, -3.8745e-03,
         -1.0238e+00],
        ...,
        [ 5.1143e-01, -3.1371e-02, -2.3643e-01, -2.0974e-03,  4.0907e+00,
          1.5715e-02],
        [ 4.9444e-01, -1.3117e-02, -6.9371e-03, -2.0974e-03, -3.8745e-03,
         -1.0238e+00],
        [ 4.9444e-01, -3.1371e-02, -6.8725e-02, -2.0974e-03,  4.0907e+00,
          1.5715e-02]]) torch.Size([21293, 6])
Edge X:  tensor([[ 1.9371e-02,  1.4985e+00, -5.3246e-10],
        [-1.4207e+00,  6.2623e-11, -5.3246e-10],
        [ 1.4207e+00,  6.2623e-11, -5.3246e-10],
        ...,
        [ 1.3838e+00,  6.2623e-11, -5.3246e-10],
        [ 6.4081e-10,  6.2623e-11, -4.8453e+00],
        [-1.3838e+00,  6.2623e-11, -5.3246e-10]]) torch.Size([83778, 3])
Node Y:  tensor([[-0.5311,  0.42

In [7]:
model = MeshGraphNet(
    input_dim_nodes=num_node_features,
    input_dim_edges=num_edge_features,
    output_dim=num_node_labels,
    aggregation='sum',
    hidden_dim_edge_encoder=64,
    hidden_dim_node_encoder=64,
    hidden_dim_processor=64,
    hidden_dim_node_decoder=64
)
model
model
model(ndx,edx,ds_dgl[0])

tensor([[ 0.1367,  0.2204, -0.3147, -0.1320,  0.0310],
        [ 0.0825,  0.2339, -0.3496, -0.0940,  0.0605],
        [ 0.1362,  0.2204, -0.3151, -0.1312,  0.0303],
        ...,
        [-0.2571, -0.0057, -0.0835, -0.0238,  0.0883],
        [-0.0122,  0.1083, -0.1325, -0.0696, -0.1092],
        [-0.2686, -0.0280, -0.1199, -0.0435,  0.0595]],
       grad_fn=<AddmmBackward0>)

In [8]:
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.amp import GradScaler
optimizer = Adam(model.parameters(), lr=1e-4)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.99985 ** epoch)
scaler = GradScaler()

In [None]:

from modulus.launch.utils import save_checkpoint, load_checkpoint
checkpoint_path = 'checkpoints'
os.makedirs(checkpoint_path,exist_ok=True)
epoch_init = load_checkpoint(checkpoint_path,model,optimizer,scheduler,scaler,device=device)
num_epochs = 200
model.to(device)
for epoch in range(epoch_init,num_epochs):
    total_loss = 0
    for batch in dataloader:
        batch = batch.to(device)
        node_X, edge_X = get_node_edge_X(batch)
        node_Y = get_node_Y(batch)
        node_Y_pred = model(node_X,edge_X,batch)
        loss = torch.nn.functional.mse_loss(node_Y_pred,node_Y)
        total_loss += loss.item()
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    print(f'Epoch {epoch} loss: {total_loss}')
    scheduler.step()
    if epoch % 10 == 0:
        save_checkpoint(checkpoint_path,model,optimizer,scheduler,scaler,epoch)
save_checkpoint(checkpoint_path,model,optimizer,scheduler,scaler,epoch)

Epoch 99 loss: 1.3939292132854462
Epoch 100 loss: 1.3848245590925217
Epoch 101 loss: 1.365472212433815
Epoch 102 loss: 1.3364130407571793
Epoch 103 loss: 1.299078643321991
Epoch 104 loss: 1.2961855828762054
Epoch 105 loss: 1.2715221047401428
Epoch 106 loss: 1.2447074800729752
Epoch 107 loss: 1.2232197970151901
Epoch 108 loss: 1.231067880988121
Epoch 109 loss: 1.2336003929376602
Epoch 110 loss: 1.2381902188062668
Epoch 111 loss: 1.2243427485227585
Epoch 112 loss: 1.220887079834938
Epoch 113 loss: 1.2255081236362457
Epoch 114 loss: 1.1968663036823273
Epoch 115 loss: 1.1802183240652084
Epoch 116 loss: 1.1788102835416794
Epoch 117 loss: 1.1712919622659683
Epoch 118 loss: 1.1729488223791122
Epoch 119 loss: 1.163592666387558
Epoch 120 loss: 1.1399722695350647
Epoch 121 loss: 1.14089697599411
Epoch 122 loss: 1.1419382244348526
Epoch 123 loss: 1.2568971514701843
Epoch 124 loss: 1.3026102036237717
Epoch 125 loss: 1.2419874966144562
Epoch 126 loss: 1.2218234837055206
Epoch 127 loss: 1.2293851077

In [30]:
device = "cpu"
model.to(device)
g=ds_dgl[0].to(device)
g_pred = g.clone().to(device)
ndx, edx = get_node_edge_X(g)
y_pred = model(ndx,edx,g)
set_graph_features(g_pred, ndx, edx, y_pred)
ds_dgl.plot_surface(g_pred,"Pressure")

Widget(value='<iframe src="http://localhost:40415/index.html?ui=P_0x7fd8c02b7290_0&reconnect=auto" class="pyvi…

In [32]:
ds_dgl.plot_surface(g,"Pressure")

Widget(value='<iframe src="http://localhost:40415/index.html?ui=P_0x7fd8a42de4d0_2&reconnect=auto" class="pyvi…

In [28]:
g.ndata.detach()

AttributeError: 'HeteroNodeDataView' object has no attribute 'detach'