In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import trimesh
from torch_geometric.loader import DataLoader

from nets import EncodeProcessDecode
from utils import visualize


# Train

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Seed torch's RNG so weight initialization (and the shuffled training loader)
# is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# NOTE: weights_only=False unpickles arbitrary Python objects (presumably
# required because the file stores PyG graph objects, not plain tensors);
# only load trusted files this way.
graphs = torch.load("data/cantilever_10.pt", weights_only=False)

# Stack node features / targets of every graph to get per-column statistics
# for z-score normalization of inputs and targets.
train_x = torch.cat([g["node"].x for g in graphs], dim=0)
train_y = torch.cat([g["node"].y for g in graphs], dim=0)

# Lower bound for the standard deviations: a column that is constant over the
# whole dataset has std == 0, and dividing by it would produce NaN/inf. With
# the floor, such a column simply normalizes to 0.
STD_EPS = 1e-8

stats = {
    "x_mean": train_x.mean(dim=0),
    "x_std": train_x.std(dim=0).clamp_min(STD_EPS),
    "y_mean": train_y.mean(dim=0),
    "y_std": train_y.std(dim=0).clamp_min(STD_EPS),
}


def normalize_graph(graph, stats):
    """Return a z-score-normalized copy of ``graph``.

    Node features ``x`` and targets ``y`` of the ``"node"`` type are shifted by
    the matching ``*_mean`` entry of ``stats`` and divided by the ``*_std``
    entry. The work happens on ``graph.clone()``, so the caller's graph keeps
    its raw values.

    Args:
        graph: graph exposing ``clone()`` and a ``"node"`` store with ``x``/``y``.
        stats: dict with keys ``x_mean``, ``x_std``, ``y_mean``, ``y_std``.

    Returns:
        The normalized clone.
    """
    normalized = graph.clone()
    node_store = normalized["node"]
    node_store.x = (node_store.x - stats["x_mean"]) / stats["x_std"]
    node_store.y = (node_store.y - stats["y_mean"]) / stats["y_std"]
    return normalized


# Normalize every graph once up front. Reshuffle each epoch: with batch_size=1
# the optimizer steps once per graph, and a fixed order would apply those
# updates in the same sequence every epoch.
loader = DataLoader(
    [normalize_graph(g, stats) for g in graphs], batch_size=1, shuffle=True
)

Using device: cpu


In [None]:
# Read feature sizes off the first graph (assumes every graph shares the same
# node/edge feature layout -- TODO confirm for new datasets).
g0 = graphs[0]
node_dim = g0.num_node_features["node"]
mesh_edge_dim, contact_edge_dim = (
    g0.num_edge_features[("node", relation, "node")]
    for relation in ("mesh", "contact")
)
latent_dim = 128
output_dim = 1  # one regression target per node (von Mises stress)

for dim_label, dim_value in (
    ("Node feature dim:", node_dim),
    ("Mesh edge feature dim:", mesh_edge_dim),
    ("Contact edge feature dim:", contact_edge_dim),
):
    print(dim_label, dim_value)

# Constructor arguments kept in one dict so the architecture is visible at a glance.
model_kwargs = dict(
    node_dim=node_dim,
    mesh_edge_dim=mesh_edge_dim,
    contact_edge_dim=contact_edge_dim,
    output_dim=output_dim,
    latent_dim=latent_dim,
    message_passing_steps=20,
    use_layer_norm=True,
)
model = EncodeProcessDecode(**model_kwargs).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Mean per-node training loss (in normalized target units), one entry per epoch.
loss_history = []
num_epochs = 50
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    total_nodes = 0

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        y_pred = model(batch)
        y_true = batch["node"].y

        # Align the prediction with the target's shape before the loss. If the
        # target is stored as (N,) while the model emits (N, output_dim=1)
        # (assumed -- confirm against EncodeProcessDecode), mse_loss would
        # silently broadcast to an (N, N) comparison. reshape_as is a no-op when
        # the shapes already agree and raises if the element counts differ.
        loss = F.mse_loss(y_pred.reshape_as(y_true), y_true)
        loss.backward()
        optimizer.step()

        # Weight each batch's mean loss by its node count so the epoch average
        # is a per-node mean independent of how nodes are split into batches.
        total_loss += loss.item() * batch.num_nodes
        total_nodes += batch.num_nodes
    avg_loss = total_loss / total_nodes
    loss_history.append(avg_loss)
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.6f}")

# Save the weights together with the normalization statistics, so a later
# inference run can normalize inputs and de-normalize predictions with the
# exact same values.
checkpoint_path = Path("models") / "model.pth"
# torch.save does not create missing directories; make sure models/ exists.
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "stats": stats,
    },
    checkpoint_path,
)

Node feature dim: 9
Mesh edge feature dim: 4
Contact edge feature dim: 7


# Training Loss

In [None]:
# Per-epoch mean training MSE on z-scored targets. Epochs are numbered from 1
# to match the progress messages printed by the training loop.
fig, ax = plt.subplots()
ax.plot(range(1, len(loss_history) + 1), loss_history)
ax.set_xlabel("Epoch")
ax.set_ylabel("MSE loss (normalized targets)")
ax.set_title("Training Loss")
plt.show()

# Visualize Results

In [None]:
# Show the stored target values for one sample on the cantilever mesh.
# NOTE(review): the normalization stats and the training loader are built from
# *all* of `graphs`, so graphs[1] was seen during training -- this is a
# training-set comparison, not a held-out evaluation.
sample_idx = 1
g = graphs[sample_idx]
mesh = trimesh.load_mesh("cantilever.stl")
visualize(mesh, g, jupyter_backend="html")


In [None]:
# Predict the target for the same sample. The model was trained on z-scored
# inputs and targets, so the input graph must be normalized with the same
# stats. The prediction is then mapped back to the original target scale
# before it is shown next to the ground truth.
model.eval()
with torch.no_grad():
    # normalize_graph works on a clone, so moving it to `device` leaves the
    # original `g` (an entry of `graphs`) untouched.
    g_norm = normalize_graph(g, stats).to(device)
    y_pred_norm = model(g_norm).reshape_as(g_norm["node"].y).cpu()
stress_pred = y_pred_norm * stats["y_std"] + stats["y_mean"]

g_pred = g.clone()
g_pred["node"].y = stress_pred
visualize(mesh, g_pred, jupyter_backend="html")