In [2]:
# Load Dataset

import torch
# Load the pre-generated simulation graphs. The result is iterated further
# down as a collection of per-simulation graph objects.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code embedded in the file -- only use it on
# files you produced yourself or otherwise trust.
dataset = torch.load("../torchfem_dataset/simple_beam/combined.pt",weights_only=False)
#dataset_2 = torch.load("../torchfem_dataset/tube_2/tube_combined.pt",weights_only=False)

In [3]:
## Process dataset for training
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HGTConv, Linear
from torch_geometric.loader import DataLoader
from torch_geometric.data import HeteroData

def build_laststep_features(data, dtype=torch.float32):
    """Collapse one simulation graph to a single load step, in place.

    Despite the name, the selected step is not the last one: it is index
    ``n_steps // 2 - 1`` of the stored time series (the original author's
    comment calls this the peak-load step -- presumably the load is ramped up
    over the first half of the steps; confirm against the data generator).

    The graph is modified in place and the raw time series are removed, so
    the function cannot be applied twice to the same object.

    Args:
        data: heterogeneous graph with a ``'nodes'`` store (``bc``, ``f_ext``,
            ``u_ts``, ``f_ts``) and an ``'elements'`` store (``material``,
            ``s_ts``). Every time series is indexed by step on its first axis.
        dtype: floating dtype used for all produced features and targets.

    Returns:
        The same ``data`` object with these fields added:
          * ``nodes.x``      -- ``[bc, f_ext[t]]`` concatenated on the last axis
          * ``nodes.fext``   -- ``f_ext[t]`` on its own
          * ``nodes.y_u``    -- target ``u_ts[t]``
          * ``nodes.y_fint`` -- target ``f_ts[t]``
          * ``elements.x``   -- ``material`` cast to ``dtype``
          * ``elements.y_s`` -- target ``s_ts[t]`` reshaped to ``(-1, 9)``
    """
    nodes, elems = data['nodes'], data['elements']

    # Step to keep. Integer arithmetic replaces int(len/2 - 1), which also
    # shadowed the builtin `len`; the clamp keeps a one-step series at index 0
    # exactly as the truncating float version did.
    n_steps = nodes.f_ext.shape[0]
    t = max(n_steps // 2 - 1, 0)

    def _at_step(series):
        # copy=True detaches the slice from the full time series. Without it
        # the slice is a view that keeps the whole series alive, so the
        # deletions at the end would not release any memory.
        return torch.as_tensor(series[t]).to(dtype, copy=True)

    # Node inputs: boundary-condition flags followed by the external force.
    # Displacements are deliberately not part of x (they are the target).
    f_ext = _at_step(nodes.f_ext)
    nodes.fext = f_ext
    nodes.x = torch.cat([nodes.bc.to(dtype), f_ext], dim=-1)

    # Element inputs: material only, cast to the requested dtype so that node
    # and element features always share one dtype.
    elems.x = torch.as_tensor(elems.material).to(dtype)

    # Node targets.
    nodes.y_u = _at_step(nodes.u_ts)
    nodes.y_fint = _at_step(nodes.f_ts)

    # Element target, flattened to 9 values per element.
    elems.y_s = _at_step(elems.s_ts).reshape(-1, 9)

    # Drop the raw fields that are no longer needed to save RAM/VRAM.
    # Fields that a particular dataset does not carry are simply skipped.
    unused = (
        (nodes, ('pos', 'bc', 'u_ts', 'f_ext', 'f_int', 'f_ts')),
        (elems, ('s_ts', 'd_ts', 'material')),
    )
    for store, names in unused:
        for name in names:
            if hasattr(store, name):
                delattr(store, name)

    return data

class HeteroStandardScaler:
    """Per-feature standardisation (zero mean, unit std) for the FEM graphs.

    Statistics are pooled over every graph passed to :meth:`fit` and stored in

    * ``node_stats`` -- keys ``'nodes_x'`` (columns ``num_raw_cols:`` of
      ``nodes.x``), ``'nodes_f'`` (``nodes.y_fint``), ``'nodes_u'``
      (``nodes.y_u``) and ``'elem_s'`` (``elements.y_s``);
    * ``edge_stats`` -- one entry per edge type that carries ``edge_attr``.

    Each entry is ``{"mean": (1, F) tensor, "std": (1, F) tensor}`` where the
    std already includes ``eps``.

    Args:
        num_raw_cols: number of leading columns of ``nodes.x`` that are left
            untouched (the boundary-condition flags). Defaults to 3, the value
            that used to be hard-coded.
        eps: added to every std to avoid division by zero for constant
            features.
    """

    def __init__(self, num_raw_cols=3, eps=1e-8):
        self.num_raw_cols = num_raw_cols
        self.eps = eps
        self.node_stats = {}
        self.edge_stats = {}

    def _mean_std(self, values):
        """Column-wise mean and (eps-padded) std of a 2-D tensor."""
        values = values.float()
        return {"mean": values.mean(dim=0, keepdim=True),
                "std":  values.std(dim=0, keepdim=True) + self.eps}

    @staticmethod
    def _moments(stats, ref):
        """Return ``(mean, std)`` on the same device as tensor ``ref``."""
        return stats["mean"].to(ref.device), stats["std"].to(ref.device)

    def fit(self, dataset):
        """Compute pooled statistics over all graphs in ``dataset``.

        Edge attributes are accumulated across graphs before the statistics
        are taken; previously each graph overwrote the entry of the one before
        it, so only the last graph determined the edge normalisation.

        Raises:
            ValueError: if ``dataset`` contains no graphs.

        Returns:
            ``self``, to allow ``scaler.fit(ds).transform(g)``.
        """
        c = self.num_raw_cols
        node_x, node_f, node_u, elem_s = [], [], [], []
        edge_acc = {}
        for data in dataset:
            node_x.append(data['nodes'].x[:, c:].float())
            node_f.append(data['nodes'].y_fint.float())
            node_u.append(data['nodes'].y_u.float())
            elem_s.append(data['elements'].y_s.float())
            for etype in data.edge_types:
                if "edge_attr" in data[etype]:
                    edge_acc.setdefault(etype, []).append(
                        data[etype].edge_attr.float())

        if not node_x:
            raise ValueError("HeteroStandardScaler.fit() needs at least one graph")

        # Start from a clean slate so a refit never keeps stale entries.
        self.node_stats = {
            'nodes_x': self._mean_std(torch.cat(node_x, dim=0)),
            'nodes_f': self._mean_std(torch.cat(node_f, dim=0)),
            'nodes_u': self._mean_std(torch.cat(node_u, dim=0)),
            'elem_s':  self._mean_std(torch.cat(elem_s, dim=0)),
        }
        self.edge_stats = {
            etype: self._mean_std(torch.cat(mats, dim=0))
            for etype, mats in edge_acc.items()
        }
        return self

    def transform(self, data: "HeteroData"):
        """Standardise ``data`` in place and return it.

        The graph is modified, not copied: calling this twice on the same
        object normalises it twice.
        """
        c = self.num_raw_cols
        nodes, elems = data['nodes'], data['elements']

        m, s = self._moments(self.node_stats['nodes_x'], nodes.x)
        nodes.x[:, c:] = (nodes.x[:, c:].float() - m) / s

        m, s = self._moments(self.node_stats['nodes_f'], nodes.y_fint)
        nodes.y_fint = (nodes.y_fint.float() - m) / s

        m, s = self._moments(self.node_stats['nodes_u'], nodes.y_u)
        nodes.y_u = (nodes.y_u.float() - m) / s

        m, s = self._moments(self.node_stats['elem_s'], elems.y_s)
        elems.y_s = (elems.y_s.float() - m) / s

        # Edge attributes: only types that were seen during fit().
        for etype in data.edge_types:
            if etype in self.edge_stats and "edge_attr" in data[etype]:
                e = data[etype].edge_attr.float()
                m, s = self._moments(self.edge_stats[etype], e)
                data[etype].edge_attr = (e - m) / s

        return data

    def inverse_transform(self, data: "HeteroData", fields=("nodes_x","nodes_f","nodes_u","elem_s")):
        """Undo :meth:`transform` in place for the selected ``fields``.

        ``fields`` selects which node/element quantities are mapped back to
        physical units; edge attributes are always mapped back. Returns
        ``data``.
        """
        c = self.num_raw_cols
        nodes, elems = data["nodes"], data["elements"]

        if "nodes_x" in fields:
            m, s = self._moments(self.node_stats["nodes_x"], nodes.x)
            nodes.x[:, c:] = nodes.x[:, c:] * s + m
        if "nodes_f" in fields:
            m, s = self._moments(self.node_stats["nodes_f"], nodes.y_fint)
            nodes.y_fint = nodes.y_fint * s + m
        if "nodes_u" in fields:
            m, s = self._moments(self.node_stats["nodes_u"], nodes.y_u)
            nodes.y_u = nodes.y_u * s + m
        if "elem_s" in fields:
            m, s = self._moments(self.node_stats["elem_s"], elems.y_s)
            elems.y_s = elems.y_s * s + m

        # Edges
        for etype in data.edge_types:
            if etype in self.edge_stats and "edge_attr" in data[etype]:
                e = data[etype].edge_attr
                m, s = self._moments(self.edge_stats[etype], e)
                data[etype].edge_attr = e * s + m
        return data

# Single scaler instance shared by the whole notebook; fitted just below.
scaler = HeteroStandardScaler()

## DataLoader with train/val split
# build_laststep_features edits every graph in place and deletes the raw time
# series, so this cell cannot be re-run without reloading `dataset` first.
dataset_p = [build_laststep_features(d) for d in dataset]
# NOTE(review): the statistics are fitted on all graphs before the train/val
# split, so validation graphs influence the normalisation. Consider fitting
# on the training subset only.
scaler.fit(dataset_p)
# transform() returns the graph it was given after modifying it in place, so
# dataset_t contains the very same objects as dataset_p (now normalised).
dataset_t = [scaler.transform(d) for d in dataset_p]

def split_dataset(dataset, val_ratio=0.1, shuffle=True, seed=None):
    """Split an indexable dataset into a training and a validation list.

    Args:
        dataset: any sequence supporting ``len()`` and integer indexing.
        val_ratio: fraction of samples assigned to validation. At least one
            sample is always placed in the validation set (when the dataset is
            non-empty), even for ``val_ratio=0``.
        shuffle: draw a random permutation before splitting; with ``False``
            the first samples form the validation set.
        seed: optional integer seed for the permutation. ``None`` (default)
            keeps the previous behaviour of using torch's global RNG, so the
            split differs between runs; pass a value for a reproducible split.

    Returns:
        ``(train_set, val_set)`` as two plain lists of dataset items.
    """
    n = len(dataset)
    if shuffle:
        generator = None
        if seed is not None:
            # Private generator: reproducible without touching global RNG state.
            generator = torch.Generator().manual_seed(seed)
        order = torch.randperm(n, generator=generator)
    else:
        order = torch.arange(n)

    n_val = max(1, int(n * val_ratio))
    val_idx = order[:n_val].tolist()
    train_idx = order[n_val:].tolist()

    train_set = [dataset[i] for i in train_idx]
    val_set = [dataset[i] for i in val_idx]
    return train_set, val_set

# preprocess first
# Hold out a validation subset of the normalised graphs (val_ratio=0.1, at
# least one graph). The split is random and unseeded, so it changes between runs.
train_set, val_set = split_dataset(dataset_t, val_ratio=0.1)


In [4]:
# NOTE(review): this draws a fresh random split and replaces the one created
# at the end of the previous cell; drop one of the two calls if that is not
# intended.
train_set, val_set = split_dataset(dataset_t, val_ratio=0.1)
# Show the structure of one training graph (index 10 needs >= 11 training graphs).
print(train_set[10])

HeteroData(
  nodes={
    fext=[37500, 3],
    x=[37500, 6],
    y_u=[37500, 3],
    y_fint=[37500, 3],
  },
  elements={
    num_nodes=24552,
    x=[24552, 1],
    y_s=[24552, 9],
  },
  (elements, contributes, nodes)={
    edge_index=[2, 196416],
    edge_attr=[196416, 4],
  },
  (nodes, belongs_to, elements)={
    edge_index=[2, 196416],
    edge_attr=[196416, 4],
  },
  (nodes, adjacent, nodes)={
    edge_index=[2, 99325],
    edge_attr=[99325, 4],
  },
  (nodes, adjacent_rev, nodes)={
    edge_index=[2, 99325],
    edge_attr=[99325, 4],
  }
)


In [3]:
import os.path as osp
from typing import Dict, List, Union

import torch
import torch.nn.functional as F
from torch import nn

import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv, to_hetero

# Compute device. The second GPU (index 1) is preferred, as in the original
# setup, but 'cuda:1' is an invalid ordinal on machines with a single GPU, so
# fall back to GPU 0 there and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

class GNN(torch.nn.Module):
    """Two-layer GraphSAGE network with a ReLU between the layers.

    Written for a homogeneous graph; it is meant to be converted with
    ``to_hetero`` afterwards. Both layers are created with ``(-1, -1)`` input
    sizes, which in PyTorch Geometric defers the input dimensions until the
    first forward pass.

    Args:
        hidden_channels: output width of the first SAGEConv layer.
        out_channels: output width of the second SAGEConv layer.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        lazy_in = (-1, -1)  # (source, target) input sizes, resolved lazily
        self.conv1 = SAGEConv(lazy_in, hidden_channels)
        self.conv2 = SAGEConv(lazy_in, out_channels)

    def forward(self,x,edge_index):
        """Return ``conv2(relu(conv1(x)))`` evaluated on ``edge_index``."""
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)
    
# Homogeneous template network: 64 hidden channels, 3 outputs per node.
model = GNN(hidden_channels=64,out_channels=3)
# Duplicate the layers per edge type (messages arriving at one node type from
# different edge types are summed) and move the result to the compute device.
# The training loop moves every batch to `device`, so the model has to live
# there as well.
model = to_hetero(model,dataset_t[0].metadata(),aggr='sum').to(device)

# The SAGEConv layers use lazy (-1, -1) input sizes. Run one forward pass so
# that all parameters exist with their final shapes before an optimizer is
# built on top of them. A clone is used so the graph stored in dataset_t
# stays on its original device.
with torch.no_grad():
    _init_graph = dataset_t[0].clone().to(device)
    model(_init_graph.x_dict, _init_graph.edge_index_dict)
del _init_graph

In [None]:
# AdamW over all model parameters (lr 1e-3, weight decay 1e-4).
# NOTE(review): the SAGEConv layers are created with lazy (-1, -1) input
# sizes; confirm they have been initialised (e.g. by one forward pass) before
# this optimizer is constructed.
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
# Halve the learning rate when the value passed to scheduler.step() -- the
# validation loss in the training loop -- stops decreasing (patience=10).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=10)

## Losses
# Weights of the auxiliary terms; the displacement term always has weight 1.
ALPHA_FINT = 0.3   # node f_int aux
BETA_S     = 0.0   # element stress aux
#GAMMA_D    = 0.0   # element damage aux (used only if available)
LAMBDA_EQ  = 0.0   # equilibrium regularizer on nodes

def compute_losses(batch, pred):
    """Weighted training loss for one batch.

    ``pred`` is a dict of predictions. Two layouts are accepted:

    * multi-head: ``'u'`` (node displacement) plus optional ``'fint'`` (node
      internal force) and ``'s'`` (element stress);
    * the output of a ``to_hetero`` model, keyed by node type, in which case
      ``pred['nodes']`` is used as the displacement prediction.

    A term whose prediction is missing contributes exactly zero, so the
    function works with a displacement-only model. Note that with such a
    model the ``ALPHA_FINT`` and ``BETA_S`` weights have no effect.

    Terms:
        L_u    -- MSE(pred u, nodes.y_u)
        L_fint -- MSE(pred fint, nodes.y_fint)
        L_eq   -- squared sum of ``(1 - bc) * (fint - f_ext)``, where ``bc``
                  and ``f_ext`` are read from columns 0:3 and 3:6 of ``nodes.x``
        L_s    -- MSE(pred s, elements.y_s)

    Returns:
        ``(loss, parts)`` where ``loss`` is the scalar tensor
        ``L_u + ALPHA_FINT*L_fint + LAMBDA_EQ*L_eq + BETA_S*L_s`` and ``parts``
        maps ``'L_u'``, ``'L_fint'``, ``'L_eq'``, ``'L_s'`` to Python floats.
    """
    nodes = batch['nodes']

    # Displacement (always required).
    pu = pred['u'] if 'u' in pred else pred['nodes']
    L_u = F.mse_loss(pu, nodes.y_u)
    zero = L_u.new_zeros(())  # placeholder for unavailable terms

    # Internal force and equilibrium residual (need a 'fint' head).
    pf = pred.get('fint')
    if pf is not None:
        L_fint = F.mse_loss(pf, nodes.y_fint)
        bc = nodes.x[:, :3]
        fext = nodes.x[:, 3:6]
        # Net force imbalance over the entries where the bc flag is 0.
        residual = ((torch.ones_like(bc) - bc) * (pf - fext)).sum()
        L_eq = F.mse_loss(residual, torch.zeros((), device=bc.device))
    else:
        L_fint = zero
        L_eq = zero

    # Element stress (needs an 's' head).
    ps = pred.get('s')
    L_s = F.mse_loss(ps, batch['elements'].y_s) if ps is not None else zero

    loss = L_u + ALPHA_FINT * L_fint + LAMBDA_EQ * L_eq + BETA_S * L_s #+ GAMMA_D * L_d
    return loss, {'L_u': L_u.item(), 'L_fint': L_fint.item(), 'L_eq': L_eq.item(),
                  'L_s': L_s.item()}#, 'L_d': L_d.item()}


## Training Loop

def run_epoch(loader, train=True):
    """Run one pass over ``loader`` and return the mean losses.

    Uses the notebook-level ``model``, ``opt``, ``device`` and
    ``compute_losses``.

    Args:
        loader: iterable of batches exposing ``.to(device)``, ``x_dict`` and
            ``edge_index_dict``.
        train: if True the model is put in training mode and the parameters
            are updated (with gradient-norm clipping at 1.0); otherwise the
            model is put in eval mode and no gradients are computed.

    Returns:
        ``(mean_loss, mean_parts)``: the total loss averaged over batches and
        a dict with each loss component averaged over batches. For an empty
        loader this is ``(0.0, {})``.
    """
    model.train(train)
    total = 0.0
    n = 0
    part_sums = {}
    # Skip building the autograd graph entirely during evaluation.
    with torch.set_grad_enabled(train):
        for batch in loader:
            batch = batch.to(device)
            pred = model(batch.x_dict, batch.edge_index_dict)
            loss, loss_dict = compute_losses(batch, pred)
            if train:
                opt.zero_grad(set_to_none=True)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                opt.step()
            total += loss.item()
            n += 1
            # Accumulate components so the report covers the whole epoch,
            # not only the last batch.
            for key, value in loss_dict.items():
                part_sums[key] = part_sums.get(key, 0.0) + value
    denom = max(1, n)
    return total / denom, {key: value / denom for key, value in part_sums.items()}


import pandas as pd  # used for the loss log below; not imported anywhere else in the notebook

EPOCHS = 100
# NOTE(review): the loaders were never created in the notebook. One graph per
# batch is the conservative choice for meshes of this size; raise it if
# memory allows.
BATCH_SIZE = 1

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)

best_val = float('inf')
best_state = None

# One record per epoch: totals plus the individual loss components of both
# the training and the validation pass.
losses = []

for epoch in range(1, EPOCHS+1):
    tr, tr_parts = run_epoch(train_loader, train=True)
    # run_epoch(train=False) already switches the model to eval mode.
    with torch.no_grad():
        va, va_parts = run_epoch(val_loader, train=False)
    scheduler.step(va)

    record = {'epoch': epoch, 'train_loss': tr, 'val_loss': va}
    record.update({f'train_{k}': v for k, v in tr_parts.items()})
    record.update({f'val_{k}': v for k, v in va_parts.items()})
    losses.append(record)

    if va < best_val:
        best_val = va
        # Keep a CPU copy of the best weights seen so far.
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    if epoch % 5 == 0 or epoch == 1:
        print(f"Epoch {epoch:03d} | train {tr:.6f} | val {va:.6f}")

# Restore and save the best checkpoint rather than the last epoch.
if best_state is not None:
    model.load_state_dict(best_state)
torch.save(model.state_dict(), "1123_sequential.pt")
print("Best val:", best_val)

loss_df = pd.DataFrame(losses)
loss_df.to_csv('loss_1123_sequential.csv', index=False)


In [9]:
# Display the converted heterogeneous module (one SAGEConv per edge type in each layer).
model

GraphModule(
  (conv1): ModuleDict(
    (elements__contributes__nodes): SAGEConv((-1, -1), 64, aggr=mean)
    (nodes__belongs_to__elements): SAGEConv((-1, -1), 64, aggr=mean)
    (nodes__adjacent__nodes): SAGEConv((-1, -1), 64, aggr=mean)
    (nodes__adjacent_rev__nodes): SAGEConv((-1, -1), 64, aggr=mean)
  )
  (conv2): ModuleDict(
    (elements__contributes__nodes): SAGEConv((-1, -1), 3, aggr=mean)
    (nodes__belongs_to__elements): SAGEConv((-1, -1), 3, aggr=mean)
    (nodes__adjacent__nodes): SAGEConv((-1, -1), 3, aggr=mean)
    (nodes__adjacent_rev__nodes): SAGEConv((-1, -1), 3, aggr=mean)
  )
)

In [None]:
for t in train_set:
    