In [2]:
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

from pathlib import Path

from sklearn.preprocessing import StandardScaler, PowerTransformer, \
    RobustScaler, OneHotEncoder
from tqdm import tqdm
import joblib

pd.set_option('display.float_format', '{:.8f}'.format)

In [15]:
DATASET_PATH = Path(r"E:\gnn_data\pyg_data_v2")
DATASET_PATH_SCALED = Path(r"E:\gnn_data\pyg_data_v2_scaled")
DATASET_PATH_SCALED.mkdir(exist_ok=True)

In [16]:
# Concat tensor to scale features
all_node_features = []
all_edge_features = []
all_global_features = []
all_files = list(Path(DATASET_PATH).glob("*.pt"))
for file in tqdm(all_files):
    try:
        data = torch.load(file, weights_only=False)
        if data.x is not None:
            all_node_features.append(data.x.numpy())
        if data.edge_attr is not None:
            all_edge_features.append(data.edge_attr.numpy())
        if data.global_features is not None:
            all_global_features.append(data.global_features.numpy())
    except Exception as e:
        print(f"Error loading {file}: {e}")
        continue

100%|██████████| 62198/62198 [00:19<00:00, 3232.10it/s]


In [17]:
concatenated_node_features = np.vstack(all_node_features)
concatenated_edge_features = np.vstack(all_edge_features)
concatenated_global_features = np.vstack(all_global_features)
edge_df = pd.DataFrame(concatenated_edge_features)
node_df = pd.DataFrame(concatenated_node_features)
global_df = pd.DataFrame(concatenated_global_features)
len(concatenated_node_features), len(concatenated_edge_features), len(
    concatenated_global_features)

(7207663, 35437906, 62198)

In [18]:
node_columns = [
    "item_id", 'node_degree', 'degree_centrality', 'average_neighbor_degree',
    'triangles', 'page_rank',
    'betweenness_centrality', 'closeness_centrality',
    'clustering_coefficient',
    'area', 'perimeter', 'edge_count', 'vertex_count',
    'compactness', 'u_span', 'v_span', 'mean_curvature',
    'orientation', 'surface_type',
]

node_columns = node_columns[:5] + node_columns[9:17] + node_columns[5:9] + node_columns[17:]

node_df.columns = node_columns
node_df.describe()

Unnamed: 0,item_id,node_degree,degree_centrality,average_neighbor_degree,triangles,area,perimeter,edge_count,vertex_count,compactness,u_span,v_span,mean_curvature,page_rank,betweenness_centrality,closeness_centrality,clustering_coefficient,orientation,surface_type
count,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0
mean,73581.3984375,4.91670036,0.04149254,40.93860245,3.43363309,826.40307617,93.14465332,5.09715176,4.98209858,0.46362433,12.68271065,18.15581131,-0.23450254,0.00862942,0.0202569,0.26754567,0.40480793,0.55952114,1.49751234
std,29072.43945312,11.80577755,0.08558629,94.1300354,10.6616497,14092.12109375,331.69641113,12.01191139,12.02399921,0.30184901,57.45454788,79.25156403,0.68544483,0.01742876,0.06342044,0.13665757,0.26747343,0.51759177,1.97586739
min,23481.0,0.0,0.0,0.0,0.0,0.0,0.00011846,1.0,1.0,0.0,0.0,0.0,-10.0,6.381e-05,0.0,0.0,0.0,0.0,0.0
25%,48870.0,4.0,0.00535332,4.75,1.0,2.6964817,9.3349781,4.0,4.0,0.20388903,1.0,1.0,-0.22222222,0.00110429,0.00017593,0.17065556,0.16666667,0.0,0.0
50%,71507.0,4.0,0.01339286,8.75,3.0,20.02379036,27.22178268,4.0,4.0,0.47086021,3.0,3.0,-0.01811594,0.00279895,0.00148221,0.25078699,0.33333334,1.0,1.0
75%,98395.0,4.0,0.03846154,28.75,4.0,112.31193542,66.0,4.0,4.0,0.70534748,6.28318548,11.62790966,0.0,0.00817791,0.00870273,0.35738832,0.66666669,1.0,2.0
max,125515.0,1825.0,2.0,1182.33337402,1824.0,2886913.0,50747.9921875,1825.0,1825.0,16.05916786,4446.96679688,15581.2265625,10.0,0.57446861,1.0,1.0,1.0,1.0,9.0


In [31]:
node_df.iloc[:, 13:18]

Unnamed: 0,page_rank,betweenness_centrality,closeness_centrality,clustering_coefficient,orientation
0,0.00328070,0.00001904,0.46943232,0.66666669,0.00000000
1,0.00328070,0.00001904,0.46943232,0.66666669,0.00000000
2,0.00328070,0.00001904,0.46943232,0.66666669,0.00000000
3,0.00328070,0.00001904,0.46943232,0.66666669,0.00000000
4,0.00328070,0.00001904,0.46943232,0.66666669,0.00000000
...,...,...,...,...,...
7207658,0.01612903,0.44808742,0.05695612,0.00000000,0.00000000
7207659,0.01612903,0.45901638,0.05803996,0.00000000,0.00000000
7207660,0.01612903,0.46885246,0.05905131,0.00000000,0.00000000
7207661,0.01612903,0.47759563,0.05998033,0.00000000,0.00000000


In [61]:
to_remove_id = node_df[(node_df.mean_curvature < -10) | (node_df.mean_curvature > 10)].item_id.values
to_remove_id = set([str(int(item)) for item in to_remove_id])
len(to_remove_id)

0

In [62]:
step_to_check = []
step_broken_dir = Path(r"E:\gnn_data\step_broken")
file_moved = 0
for file in Path(r"E:\gnn_data\step_files").glob("*.*"):
    if file.stem.split("_")[0] in to_remove_id:
        file.rename(step_broken_dir / file.name)
        file_moved += 1
print(f"Moved {file_moved} files to {step_broken_dir}")

Moved 0 files to E:\gnn_data\step_broken


In [19]:
edge_columns = [
    "shared_face_count", "curve_type", "length", "chord_length", "is_closed", "orientation",

]
edge_df.columns = edge_columns
edge_df.describe()

Unnamed: 0,shared_face_count,curve_type,length,chord_length,is_closed,orientation
count,35437906.0,35437906.0,35437906.0,35437906.0,35437906.0,35437906.0
mean,2.0003674,1.44088829,18.45617104,14.39870262,0.46545297,0.04799934
std,0.97307545,2.04126692,78.5568161,62.20270538,0.48653147,0.20872761
min,2.0,0.0,0.0,0.0,0.0,0.0
25%,2.0,0.0,1.14261234,0.77115077,0.0,0.0
50%,2.0,1.0,4.00014973,3.0,0.0,0.0
75%,2.0,1.0,13.13338947,9.53872967,1.0,0.0
max,4.0,6.0,24738.07617188,7083.93994141,1.0,1.0


In [28]:
edge_df.iloc[:, 2:4]

Unnamed: 0,length,chord_length
0,13.70161915,13.69951153
1,13.70161915,13.69951153
2,0.31999999,0.31999999
3,0.31999999,0.31999999
4,13.70161915,13.69951153
...,...,...
35437901,0.13346587,0.13344844
35437902,5.34398365,3.40208554
35437903,5.34398365,3.40208554
35437904,0.58311963,0.58311963


In [20]:
global_columns = [
    "faces", "edges", "vertices", "quantity",
    "height", "width", "depth", "volume", "area",
    "bbox_height", "bbox_width", "bbox_depth", "bbox_volume",
    "bbox_area",
]
global_df.columns = global_columns
global_df.describe()

Unnamed: 0,faces,edges,vertices,quantity,height,width,depth,volume,area,bbox_height,bbox_width,bbox_depth,bbox_volume,bbox_area
count,62198.0,62198.0,62198.0,62198.0,62198.0,62198.0,62198.0,62198.0,62198.0,62198.0,62198.0,62198.0,62198.0,62198.0
mean,115.87760162,303.86050415,193.14466858,50.61309052,90.4812851,120.32733917,112.4785614,504756.40625,92888.3125,94.39996338,167.32948303,59.50634766,2964243.25,59.50634766
std,219.37226868,577.70739746,365.66427612,333.13510132,158.79827881,207.23075867,186.56289673,2791833.5,354723.125,141.77568054,262.35568237,93.30776215,42696920.0,93.30776215
min,2.0,3.0,2.0,1.0,0.1,0.1,0.07,8.384e-05,0.06177394,0.07,0.1,0.1,0.004942,0.1
25%,23.0,57.0,36.0,1.0,20.0,23.60000038,22.0,9383.05688477,6022.57519531,21.9734993,37.95422745,15.0,26250.0,15.0
50%,44.0,115.0,74.0,2.0,40.0,60.00001144,53.0308075,37802.359375,16901.42773438,50.0,90.0,29.49941254,128697.296875,29.49941254
75%,106.0,280.0,179.0,6.0,99.0487442,134.91200256,123.59999847,202855.609375,61179.8984375,105.0,180.0,68.50379944,671994.078125,68.50379944
max,2900.0,7778.0,5114.0,9600.0,3104.0,5110.0,5000.0,99667912.0,27595844.0,2782.265625,6081.52636719,2369.0,5282931712.0,2369.0


In [21]:
pt_scaler = PowerTransformer()

pt_node_df = pd.DataFrame(pt_scaler.fit_transform(node_df.iloc[:, 1:13]))
pt_node_df.describe()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11
count,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0,7207663.0
mean,-2e-08,-1e-08,-0.0,4e-08,0.0,-1e-08,0.0,-1e-08,1e-08,1e-08,1e-08,-0.0
std,1.00189829,0.99553233,0.99545538,0.99967074,0.99632168,0.99554378,0.99952495,0.99901342,0.99546027,0.99488014,0.99424034,0.98431462
min,-12.80566025,-1.1993978,-4.6236062,-1.60452652,-1.61600995,-2.59231758,-6.9834342,-5.10011673,-1.77486718,-1.91054368,-1.79620552,-8.46036148
25%,0.01074957,-0.83189958,-0.82164377,-0.64952034,-0.83312541,-0.67004597,-0.09646273,0.0230778,-0.86197013,-0.74014175,-0.85472673,-0.06836542
50%,0.01074957,-0.35716188,-0.15229301,0.29607743,0.04835479,0.05807765,-0.09646273,0.0230778,0.12488431,0.12775521,-0.10508639,0.27965906
75%,0.01074957,0.67770714,0.83036113,0.59850842,0.75640059,0.64201963,-0.09646273,0.0230778,0.85474926,0.69502139,0.8168065,0.3120499
max,4.35718393,2.1535871,2.04333687,8.25516605,3.09980631,4.05105877,3.75859046,3.86764359,12.33006287,2.49697042,2.63085961,40.06744003


In [22]:
edge_pt_scaler = PowerTransformer()
edge_one_hot = OneHotEncoder(sparse_output=False)
edge_df_pt_scaled = pd.DataFrame(edge_pt_scaler.fit_transform(edge_df.iloc[:, 2:5]))
edge_df_one_hot = pd.DataFrame(edge_one_hot.fit_transform(edge_df.iloc[:, 0:2]))

Unnamed: 0,0,1,2,0.1,is_closed,orientation
0,0.81823951,1.00712299,1.07165515,"(0, 0)\t1.0\n (0, 5)\t1.0",1.00000000,0.00000000
1,0.81823951,1.00712299,1.07165515,"(0, 0)\t1.0\n (0, 5)\t1.0",1.00000000,0.00000000
2,-1.42752099,-1.24405205,1.07165515,"(0, 0)\t1.0\n (0, 3)\t1.0",1.00000000,0.00000000
3,-1.42752099,-1.24405205,1.07165515,"(0, 0)\t1.0\n (0, 3)\t1.0",1.00000000,0.00000000
4,0.81823951,1.00712299,-0.93313593,"(0, 0)\t1.0\n (0, 5)\t1.0",0.00000000,0.00000000
...,...,...,...,...,...,...
35437901,-1.62497962,-1.45303595,-0.93313593,"(0, 0)\t1.0\n (0, 8)\t1.0",0.00000000,0.00000000
35437902,0.19521198,0.09461578,1.07165515,"(0, 0)\t1.0\n (0, 4)\t1.0",1.00000000,0.00000000
35437903,0.19521198,0.09461578,1.07165515,"(0, 0)\t1.0\n (0, 4)\t1.0",1.00000000,0.00000000
35437904,-1.20237553,-1.00767088,-0.93313593,"(0, 0)\t1.0\n (0, 3)\t1.0",0.00000000,0.00000000


In [23]:
edge_one_hot = OneHotEncoder(sparse_output=False)
edge_df_one_hot = pd.DataFrame(edge_one_hot.fit_transform(edge_df.iloc[:, 0:2]),
                              columns=edge_one_hot.get_feature_names_out(edge_df.columns[:2]))
edge_df_one_hot


Unnamed: 0,shared_face_count_2.0,shared_face_count_3.0,shared_face_count_4.0,curve_type_0.0,curve_type_1.0,curve_type_2.0,curve_type_3.0,curve_type_4.0,curve_type_6.0
0,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000
1,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000
2,1.00000000,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000
3,1.00000000,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000
4,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000
...,...,...,...,...,...,...,...,...,...
35437901,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,1.00000000
35437902,1.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000
35437903,1.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000
35437904,1.00000000,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000


In [24]:
edge_scaled_df = pd.concat([edge_df_pt_scaled, edge_df_one_hot, edge_df.iloc[:, 4:]], axis=1)
edge_scaled_df

Unnamed: 0,0,1,2,shared_face_count_2.0,shared_face_count_3.0,shared_face_count_4.0,curve_type_0.0,curve_type_1.0,curve_type_2.0,curve_type_3.0,curve_type_4.0,curve_type_6.0,is_closed,orientation
0,0.81823951,1.00712299,1.07165515,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000
1,0.81823951,1.00712299,1.07165515,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000
2,-1.42752099,-1.24405205,1.07165515,1.00000000,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000
3,-1.42752099,-1.24405205,1.07165515,1.00000000,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000
4,0.81823951,1.00712299,-0.93313593,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
35437901,-1.62497962,-1.45303595,-0.93313593,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000
35437902,0.19521198,0.09461578,1.07165515,1.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000
35437903,0.19521198,0.09461578,1.07165515,1.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000
35437904,-1.20237553,-1.00767088,-0.93313593,1.00000000,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000


In [26]:
global_features_df = pd.DataFrame(np.log1p(global_df), columns=global_df.columns)
global_features_df.describe()

Unnamed: 0,faces,edges,vertices,quantity,height,width,depth,volume,area,bbox_height,bbox_width,bbox_depth,bbox_volume,bbox_area
count,62198.0,62198.0,62198.0,62198.0,62198.0,62198.0,62198.0,62198.0,62198.0,62198.0,62198.0,62198.0,62198.0,62198.0
mean,3.96517348,4.86377573,4.43062305,1.60885537,3.7769146,4.06244373,3.99959993,10.57654381,9.79553223,3.92598963,4.41646433,3.46486783,11.67086697,3.46486783
std,1.16525197,1.24741495,1.23061299,1.43991017,1.20675457,1.23305404,1.22032368,2.45987391,1.86330795,1.12492836,1.23818839,1.12769842,2.61950541,1.12769842
min,1.09861231,1.38629436,1.09861231,0.69314718,0.09531018,0.09531018,0.06765865,8.383e-05,0.05994104,0.06765865,0.09531018,0.09531018,0.00492983,0.09531018
25%,3.17805386,4.06044292,3.61091781,0.69314718,3.04452252,3.20274639,3.13549423,9.14676738,8.7034359,3.13434124,3.66238737,2.77258873,10.17545891,2.77258873
50%,3.80666256,4.75359011,4.31748819,1.09861231,3.71357203,4.11087418,3.98955441,10.5401535,9.73521233,3.93182564,4.51085949,3.41770744,11.76522636,3.41770744
75%,4.67282867,5.63835478,5.19295692,1.9459101,4.60565758,4.91200781,4.82510853,12.2202549,11.02159023,4.66343927,5.19849682,4.24138165,13.41800666,4.24138165
max,7.97281075,8.95918274,8.53993225,9.16962242,8.04076862,8.53915024,8.51739311,18.41735458,17.1331749,7.93138027,8.71317577,7.77064514,22.38774681,7.77064514


In [53]:
edge_df.describe()

Unnamed: 0,shared_face_count,length,chord_length,is_closed,orientation,curve_type
count,35437906.0,35437906.0,35437906.0,35437906.0,35437906.0,35437906.0
mean,2.0003674,18.45617104,14.39870262,0.46545297,0.04799934,1.44088829
std,0.97307545,78.5568161,62.20270538,0.48653147,0.20872761,2.04126692
min,2.0,0.0,0.0,0.0,0.0,0.0
25%,2.0,1.14261234,0.77115077,0.0,0.0,0.0
50%,2.0,4.00014973,3.0,0.0,0.0,1.0
75%,2.0,13.13338947,9.53872967,1.0,0.0,1.0
max,4.0,24738.07617188,7083.93994141,1.0,1.0,6.0


In [28]:
import os

with open(r"E:\gnn_data\pyg_data_v2\dataset_mapping.pkl", "rb") as f:
    dataset_mapping = joblib.load(f)

In [31]:
dataset_mapping["processed_files"] = [
    item for item in dataset_mapping["processed_files"]
    if "Outlet_Coolant" in item["original_path"]
]

{'original_path': 'E:\\gnn_data\\graphml_files\\95332_EM50_15S7P_Outlet_Coolant_Cup.graphml', 'processed_path': 'E:\\gnn_data\\pyg_data_v2\\95332_EM50_15S7P_Outlet_Coolant_Cup.pt', 'label': 1, 'index': 60185}
{'original_path': 'E:\\gnn_data\\graphml_files\\95335_EM50_15S7P_Outlet_Coolant_Cup.graphml', 'processed_path': 'E:\\gnn_data\\pyg_data_v2\\95335_EM50_15S7P_Outlet_Coolant_Cup.pt', 'label': 1, 'index': 60188}


In [11]:
unscaled = torch.load(r"E:\gnn_data\pyg_data_v2\100062_CEL_03_0089_V4_2025_01_21.pt", weights_only=False)
scaled = torch.load(r"E:\gnn_data\pyg_data_v2_scaled\100062_CEL_03_0089_V4_2025_01_21.pt", weights_only=False)

In [9]:
unscaled["y"]

tensor([1])

In [5]:
pd.DataFrame(unscaled.x)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18
0,100062.00000000,2.00000000,0.00448431,3.00000000,0.00000000,5.61216450,9.58100414,4.00000000,3.00000000,0.76827699,3.14159274,2.04160833,-0.29430747,0.00142426,0.00001041,0.14513505,0.00000000,1.00000000,2.00000000
1,100062.00000000,4.00000000,0.00896861,4.00000000,2.00000000,54.61496353,31.00917244,6.00000000,6.00000000,0.71374124,3.14159274,10.00000000,-0.28571430,0.00220838,0.00448750,0.16958176,0.33333334,1.00000000,1.00000000
2,100062.00000000,2.00000000,0.00448431,2.50000000,0.00000000,5.61216450,9.58100414,4.00000000,3.00000000,0.76827699,3.14159274,2.04160833,-0.29430747,0.00145827,0.00000504,0.14504065,0.00000000,1.00000000,2.00000000
3,100062.00000000,3.00000000,0.00672646,4.00000000,1.00000000,54.97787094,30.99557495,4.00000000,4.00000000,0.71911448,3.14159274,10.00000000,-0.28571430,0.00180408,0.00446381,0.16945289,0.33333334,1.00000000,1.00000000
4,100062.00000000,4.00000000,0.00896861,72.75000000,4.00000000,2.74889350,6.64159250,4.00000000,4.00000000,0.78311032,1.57079637,1.75000000,-0.50000000,0.00155689,0.00057012,0.33184522,0.66666669,0.00000000,1.00000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
442,100062.00000000,4.00000000,0.00896861,16.00000000,4.00000000,9.42477798,15.14159298,4.00000000,4.00000000,0.51657993,1.57079637,6.00000000,-0.50000000,0.00167501,0.00015388,0.20281947,0.66666669,0.00000000,1.00000000
443,100062.00000000,4.00000000,0.00896861,16.00000000,4.00000000,9.42477798,15.14159298,4.00000000,4.00000000,0.51657993,1.57079637,6.00000000,-0.50000000,0.00167501,0.00015388,0.20281947,0.66666669,0.00000000,1.00000000
444,100062.00000000,4.00000000,0.00896861,16.50000000,4.00000000,9.42477798,15.14159298,4.00000000,4.00000000,0.51657993,1.57079637,6.00000000,-0.50000000,0.00170022,0.00016463,0.20318906,0.66666669,0.00000000,1.00000000
445,100062.00000000,185.00000000,0.41479820,4.76216221,185.00000000,4909.73925781,8318.54003906,185.00000000,185.00000000,0.00089161,154.00000000,204.00000000,0.00000000,0.05984786,0.46682021,0.45837617,0.01086957,0.00000000,0.00000000


In [6]:
pd.DataFrame(scaled.x)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,17,18,19,20,21,22,23,24,25,26
0,-2.36917782,-0.88856727,-1.38376331,-1.60452652,-0.51935393,-0.65235418,-0.09646273,-0.85859513,1.03361976,0.16484381,...,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000
1,0.01074957,-0.60774773,-1.02726340,-0.09524558,0.47311378,0.14551882,1.11332548,1.05263269,0.87898809,0.16484381,...,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000
2,-2.36917782,-0.88856727,-1.61497760,-1.60452652,-0.51935393,-0.65235418,-0.09646273,-0.85859513,1.03361976,0.16484381,...,0.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000
3,-0.90337020,-0.74459845,-1.02726340,-0.64952034,0.47580427,0.14522472,-0.09646273,0.02307780,0.89439261,0.16484381,...,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000
4,0.01074957,-0.60774773,1.33431304,0.59850842,-0.82529002,-0.89990401,-0.09646273,0.02307780,1.07494235,-0.39543951,...,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
442,0.01074957,-0.60774773,0.39916083,0.59850842,-0.28728825,-0.34004918,-0.09646273,0.02307780,0.27586567,-0.39543951,...,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000
443,0.01074957,-0.60774773,0.39916083,0.59850842,-0.28728825,-0.34004918,-0.09646273,0.02307780,0.27586567,-0.39543951,...,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000
444,0.01074957,-0.60774773,0.42428070,0.59850842,-0.28728825,-0.34004918,-0.09646273,0.02307780,0.27586567,-0.39543951,...,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000
445,4.18402863,2.15181351,-0.81863403,5.36783838,1.93807256,3.28344464,3.69601321,3.74662662,-1.77048659,2.10372281,...,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000


In [89]:
pd.DataFrame(unscaled.edge_attr, columns=edge_df.columns)

Unnamed: 0,shared_face_count,curve_type,length,chord_length,is_closed,orientation
0,2.00000000,1.00000000,5.49778700,3.50000000,0.00000000,0.00000000
1,2.00000000,1.00000000,5.49778700,3.50000000,0.00000000,0.00000000
2,2.00000000,0.00000000,2.04160833,2.04160833,1.00000000,0.00000000
3,2.00000000,0.00000000,2.04160833,2.04160833,1.00000000,0.00000000
4,2.00000000,0.00000000,10.00000000,10.00000000,0.00000000,0.00000000
...,...,...,...,...,...,...
2647,2.00000000,1.00000000,3.14159274,2.00000000,1.00000000,0.00000000
2648,2.00000000,1.00000000,5.49778700,3.50000000,1.00000000,0.00000000
2649,2.00000000,1.00000000,5.49778700,3.50000000,1.00000000,0.00000000
2650,2.00000000,1.00000000,5.49778700,3.50000000,1.00000000,0.00000000


In [87]:
pd.DataFrame(scaled.edge_attr)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11
0,0.21500720,0.11460023,2.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000
1,0.21500720,0.11460023,2.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000
2,-0.47929421,-0.26307076,2.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000
3,-0.47929421,-0.26307076,2.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000
4,0.61907071,0.81861818,2.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000
...,...,...,...,...,...,...,...,...,...,...,...,...
2647,-0.17958987,-0.27722991,2.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000
2648,0.21500720,0.11460023,2.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000
2649,0.21500720,0.11460023,2.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000
2650,0.21500720,0.11460023,2.00000000,0.00000000,1.00000000,0.00000000,0.00000000,0.00000000,0.00000000,0.00000000,1.00000000,0.00000000


In [90]:
pd.DataFrame(unscaled.global_features)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13
0,447.0,1368.0,904.0,10.0,230.00056458,175.00056458,9.50056076,130892.1015625,83646.1015625,9.50056076,175.00056458,230.00056458,382399.71875,230.00056458


In [91]:
pd.DataFrame(scaled.global_features)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13
0,1.50939655,1.84195864,1.94404793,-0.12192649,0.87862372,0.26382789,-0.55197787,-0.13390332,-0.02605509,-0.59879136,0.02923957,1.82732224,-0.06047219,1.82732224


In [5]:
sync_dataset = pd.read_csv(r".\data\synced_dataset_final.csv")

In [23]:
import shutil
old_pt_dir = Path(r"E:\gnn_data\pyg_data_v2_scaled")
for index, row in tqdm(sync_dataset.iterrows()):
    file_path = Path(row["step_file"])
    fold_id = int(row["binary_fold"])
    new_pt_dir = old_pt_dir / f"fold_{str(fold_id).zfill(2)}"
    new_pt_dir.mkdir(exist_ok=True, parents=True)
    old_pt_path = old_pt_dir / file_path.with_suffix(".pt").name
    new_pt_path = new_pt_dir / file_path.with_suffix(".pt").name
    shutil.copy(old_pt_path, new_pt_path)

62198it [09:40, 107.13it/s]


{'node_power_transformer': PowerTransformer(),
 'node_minmax_scaler': MinMaxScaler(),
 'node_onehot_encoder': OneHotEncoder(categories=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], sparse_output=False),
 'edge_onehot_encoder_0': OneHotEncoder(categories=[[2, 3, 4]], sparse_output=False),
 'edge_onehot_encoder_1': OneHotEncoder(categories=[[0, 1, 2, 3, 4, 5, 6]], sparse_output=False),
 'edge_minmax_scaler': MinMaxScaler(),
 'global_minmax_scaler': MinMaxScaler()}

In [3]:
UNSCALED_DATA_DIR = Path(r"E:\gnn_data\pyg_data_v2")
for vlad_fold in range(1, 10):
    SCALED_DATA_DIR = Path(rf"E:\gnn_data\pyg_data_v2_scaled_validation_fold_{str(vlad_fold).zfill(2)}")
    SCALED_DATA_DIR.mkdir(exist_ok=True, parents=True)
    with open(fr"E:\gnn_data\pyg_data_v2\scalers_fold_{str(vlad_fold).zfill(2)}.pkl", "rb") as f:
            scalers = joblib.load(f)
    for fold_id in range(10):
        fold_dir = UNSCALED_DATA_DIR / f"fold_{str(fold_id).zfill(2)}"
        for file in tqdm(list(fold_dir.glob("*.pt")), desc=f"Processing fold {fold_id}"):
            data = torch.load(file, weights_only=False)
            if data.x is not None:
                x = data.x.numpy()

                # Discard x[:, 0]
                # Apply PowerTransformer + MinMaxScaler on x[:, 1:13]
                x_power_part = x[:, 1:13]
                x_power_transformed = scalers["node_power_transformer"].transform(
                    x_power_part)
                x_power_scaled = scalers["node_minmax_scaler"].transform(
                    x_power_transformed)

                # Keep x[:, 13:17] the same (assuming you meant 13:17 based on your indexing)
                x_unchanged = x[:, 13:18]

                # One-hot encode x[:, 17] (assuming 0-based indexing, so column 17 is what you called 18)
                x_onehot_part = x[:, 18:19]  # Keep as 2D
                x_onehot_encoded = scalers["node_onehot_encoder"].transform(
                    x_onehot_part)

                # Concatenate all parts
                new_x = np.concatenate([
                    x_power_scaled,  # columns 1:13 -> power + minmax scaled
                    x_unchanged,  # columns 13:17 -> unchanged
                    x_onehot_encoded
                    # column 17 -> one-hot encoded (10 categories)
                ], axis=1)

                data.x = torch.tensor(new_x, dtype=torch.float32)

            # Process edge features (Data.edge_attr)
            if data.edge_attr is not None:
                edge_attr = data.edge_attr.numpy()

                # One-hot encode edge_attr[:, 0] with 3 categories (2, 3, 4)
                edge_onehot_0 = scalers["edge_onehot_encoder_0"].transform(
                    edge_attr[:, 0:1])

                # One-hot encode edge_attr[:, 1] with 7 categories
                edge_onehot_1 = scalers["edge_onehot_encoder_1"].transform(
                    edge_attr[:, 1:2])

                # Apply PowerTransformer + MinMaxScaler on edge_attr[:, 2:4]
                edge_power_part = edge_attr[:, 2:4]
                edge_log_transformed = np.log1p(edge_power_part)
                edge_scaled = scalers["edge_minmax_scaler"].transform(
                    edge_log_transformed)

                # Keep edge_attr[:, 4:6] the same
                edge_unchanged = edge_attr[:, 4:6]

                # Concatenate all parts
                new_edge_attr = np.concatenate([
                    edge_onehot_0, # column 0 -> one-hot encoded (3 categories)
                    edge_onehot_1, # column 1 -> one-hot encoded (7 categories)
                    edge_scaled,  # columns 2:4 -> power + minmax scaled
                    edge_unchanged  # columns 4:6 -> unchanged
                ], axis=1)

                data.edge_attr = torch.tensor(new_edge_attr,
                                              dtype=torch.float32)

            # Process global features (Data.global_features)
            if data.global_features is not None:
                global_features = data.global_features.numpy()

                # Apply PowerTransformer + MinMaxScaler on all global features
                global_log_transformed = np.log1p(global_features)
                global_scaled = scalers["global_minmax_scaler"].transform(
                    global_log_transformed)

                data.global_features = torch.tensor(global_scaled,
                                                    dtype=torch.float32)
            new_pt_path = SCALED_DATA_DIR / f"fold_{str(fold_id).zfill(2)}" / file.name
            new_pt_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(data, new_pt_path)

Processing fold 0: 100%|██████████| 6220/6220 [01:00<00:00, 102.62it/s]
Processing fold 1: 100%|██████████| 6220/6220 [00:58<00:00, 105.46it/s]
Processing fold 2: 100%|██████████| 6220/6220 [00:59<00:00, 105.21it/s]
Processing fold 3: 100%|██████████| 6220/6220 [00:58<00:00, 105.46it/s]
Processing fold 4: 100%|██████████| 6220/6220 [00:59<00:00, 104.51it/s]
Processing fold 5: 100%|██████████| 6220/6220 [00:58<00:00, 105.93it/s]
Processing fold 6: 100%|██████████| 6220/6220 [00:58<00:00, 106.60it/s]
Processing fold 7: 100%|██████████| 6220/6220 [00:59<00:00, 104.64it/s]
Processing fold 8: 100%|██████████| 6219/6219 [00:59<00:00, 104.69it/s]
Processing fold 9: 100%|██████████| 6219/6219 [00:59<00:00, 105.28it/s]
Processing fold 0: 100%|██████████| 6220/6220 [00:17<00:00, 354.62it/s]
Processing fold 1: 100%|██████████| 6220/6220 [00:17<00:00, 358.81it/s]
Processing fold 2: 100%|██████████| 6220/6220 [00:17<00:00, 360.28it/s]
Processing fold 3: 100%|██████████| 6220/6220 [00:17<00:00, 360.