### Raw STL Conversion
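Data preprocessing for use with MeshGraphNets: converts each design's node and element CSV files into a PyTorch Geometric `Data` graph labelled with that design's quality score.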


In [1]:
!pip install trimesh

Collecting trimesh
  Downloading trimesh-4.5.2-py3-none-any.whl.metadata (18 kB)
Downloading trimesh-4.5.2-py3-none-any.whl (704 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m704.4/704.4 kB[0m [31m15.9 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: trimesh
Successfully installed trimesh-4.5.2


In [3]:
import os
import itertools
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data
import torch.nn.functional as nn
import trimesh
from trimesh import load as load_stl

In [4]:
def load_stl(file_path):
    """Load a mesh file with trimesh and return its geometry as tensors.

    Args:
        file_path: Path handed straight to ``trimesh.load``.

    Returns:
        A ``(vertices, faces)`` tuple: the loaded mesh's ``vertices`` as a
        ``torch.float`` tensor and its ``faces`` as a ``torch.long`` tensor.
    """
    loaded = trimesh.load(file_path)
    # Node positions.
    node_positions = torch.tensor(loaded.vertices, dtype=torch.float)
    # Face (triangle) indices.
    triangle_indices = torch.tensor(loaded.faces, dtype=torch.long)
    return node_positions, triangle_indices

In [None]:
def _tetrahedra_to_triangles(tet_nodes):
    """Expand four-node elements into their triangular faces.

    Args:
        tet_nodes: Array-like of shape ``(n_elements, 4)`` holding the node ids
            of each element.

    Returns:
        ``int32`` array of shape ``(4 * n_elements, 3)``. Rows are grouped per
        element and follow ``itertools.combinations`` order:
        (0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3).
    """
    corner_triples = np.array(list(itertools.combinations(range(4), 3)))
    # Fancy indexing yields (n_elements, 4, 3); flatten to one triangle per row.
    return np.asarray(tet_nodes)[:, corner_triples].reshape(-1, 3).astype('int32')


def _build_graph(df_nodes, df_elements, target, num_node_types):
    """Assemble one ``Data`` graph from a design's node and element tables."""
    # The last four columns of the elements table are the element's node ids.
    cells = torch.tensor(_tetrahedra_to_triangles(df_elements.iloc[:, -4:].to_numpy()))
    mesh_pos = torch.tensor(df_nodes[["x", "y", "z"]].to_numpy().astype('float32'))

    # Edge connectivity comes from the external helper `triangles_to_edges`.
    edges = triangles_to_edges(cells)
    edge_index = torch.stack(
        (torch.tensor(edges[0].numpy()), torch.tensor(edges[1].numpy())), dim=0
    ).type(torch.long)

    # Edge features: relative position between the two end nodes and its L2 norm.
    rel_pos = mesh_pos[edge_index[0]] - mesh_pos[edge_index[1]]
    rel_dist = torch.norm(rel_pos, p=2, dim=1, keepdim=True)
    edge_attr = torch.cat((rel_pos, rel_dist), dim=-1).type(torch.float)

    # Node features: position concatenated with the one-hot encoded node type.
    node_type_ids = torch.tensor(df_nodes["nodetype"].to_numpy().astype('int32')).to(torch.long)
    node_type = nn.one_hot(node_type_ids, num_classes=num_node_types)
    x = torch.cat((mesh_pos, node_type), dim=-1).type(torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=target,
                mesh_pos=mesh_pos, cells=cells)


def create_dataset(data_file, quality_scores, num_node_types=4):
    """Build a graph dataset from per-design CSV files and save it with ``torch.save``.

    For every ``<design>_nodes.csv`` in ``data_file`` that has a matching
    ``<design>_elements.csv`` and a quality score, one ``Data`` object is built
    with node features ``x`` (position + one-hot node type), ``edge_index``,
    ``edge_attr`` (relative position + distance), target ``y`` (the score),
    ``mesh_pos`` and ``cells`` (triangles derived from the four-node elements).

    Requires a ``triangles_to_edges`` function to be available in the kernel;
    it is not defined in this cell. It is called with the triangle tensor and
    its first two items are used as sender/receiver indices (each must expose
    ``.numpy()``).

    Args:
        data_file: Directory containing the ``*_nodes.csv`` (columns ``x``,
            ``y``, ``z``, ``nodetype``) and ``*_elements.csv`` files (node ids
            in the last four columns).
        quality_scores: Mapping from ``"<design>.stl"`` to a score convertible
            with ``float``. If no exact key matches, a key whose basename is
            ``"<design>.stl"`` (e.g. a full path) is accepted as well.
        num_node_types: Number of classes used for the one-hot node-type
            encoding. Defaults to 4.

    Returns:
        Path of the saved dataset:
        ``<data_file>/data_preprocessed/3D_data_sample.pt``.

    Raises:
        FileNotFoundError: If ``data_file`` does not exist.
    """
    nodes_suffix = "_nodes.csv"
    elements_suffix = "_elements.csv"

    # Discover designs from their nodes file only. Unrelated entries (hidden
    # files, files without the suffix, the output folder created below) are
    # ignored, so the function can be re-run on the same directory. Sorting
    # makes the dataset order independent of the OS listing order.
    design_names = sorted(
        entry[:-len(nodes_suffix)]
        for entry in os.listdir(data_file)
        if entry.endswith(nodes_suffix)
        and len(entry) > len(nodes_suffix)
        and not entry.startswith(".")
    )

    # Fallback lookup so that scores keyed by a full path still match.
    scores_by_basename = {os.path.basename(str(key)): value
                          for key, value in quality_scores.items()}

    data_list = []
    for j, design_name in enumerate(design_names, start=1):
        print(f"Processing file {j}: {design_name}")

        # Resolve the target first so no work is wasted on unlabeled designs.
        score_key = f"{design_name}.stl"
        if score_key in quality_scores:
            raw_score = quality_scores[score_key]
        elif score_key in scores_by_basename:
            raw_score = scores_by_basename[score_key]
        else:
            print(f"Warning: Quality score not found for {design_name}")
            continue

        design_node = os.path.join(data_file, f"{design_name}{nodes_suffix}")
        design_element = os.path.join(data_file, f"{design_name}{elements_suffix}")
        if not os.path.isfile(design_element):
            print(f"Warning: Elements file not found for {design_name}")
            continue

        target = torch.tensor([float(raw_score)], dtype=torch.float)
        data_list.append(
            _build_graph(pd.read_csv(design_node), pd.read_csv(design_element),
                         target, num_node_types)
        )

    print("Done collecting data!")

    # Define the save path
    save_dir = os.path.join(data_file, 'data_preprocessed')
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, '3D_data_sample.pt')

    # Save dataset
    torch.save(data_list, save_path)
    print(f"Dataset saved to: {save_path}")
    return save_path

FileNotFoundError: [Errno 2] No such file or directory: '/path/to/your/data'

### Possible Usage

In [None]:
# Usage example
# Directory holding the "<design>_nodes.csv" / "<design>_elements.csv" files.
data_file = "/path/to/your/data"  # TODO: replace with a real directory before running

# Keys are "<design>.stl" file names (no directory part): create_dataset looks
# each score up as f"{design_name}.stl", where design_name is the CSV prefix.
quality_scores = {
    "file1.stl": "1.5",
    "file2.stl": "2.0",
    # Add file names and corresponding quality scores here
}

# create_dataset already prints the save location; keep the path for later cells.
save_path = create_dataset(data_file, quality_scores)
save_path