# 1.Import and Configuration(导入与配置)

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from torch_geometric.nn import GATConv
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    roc_auc_score, accuracy_score, f1_score, recall_score, precision_score, confusion_matrix
)
import random
from kan import KAN  # pip install kan 
import matplotlib
matplotlib.rcParams["font.sans-serif"] = ["SimHei"]
matplotlib.rcParams["axes.unicode_minus"] = False
import matplotlib.pyplot as plt
from collections import deque, defaultdict
import plotly.graph_objs as go
from glob import glob

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

FEATURE_COLUMNS = [
    'ITG', 'BTW', 'CTR', 'ETR', 'S', 'L', 'WA', 'WFR', 'WWR', 'DN',
    'WN', 'DS', 'WS', 'H', 'V', 'a', 'b', 'OH', 'IH', 'CN', 'CA'
]
INPUT_FEATURES = FEATURE_COLUMNS
#FEATURES_LAYER1 = ['ITG', 'BTW', 'CTR', 'ETR', 'S', 'H', 'V', 'a', 'b']
#FEATURES_LAYER2 = ['L', 'WA', 'WFR', 'WWR', 'DN', 'WN', 'DS', 'WS', 'OH', 'IH', 'CN', 'CA']
FEATURES_LAYER1 = ['ITG', 'BTW', 'CTR', 'ETR', 'WA', 'WFR', 'WWR', 'WS' , 'CA']
FEATURES_LAYER2 = ['S', 'H', 'V', 'a', 'b' , 'L', 'DN', 'WN', 'DS',  'OH', 'IH', 'CN']


Using device: cuda:0


# 2.Data processing & related functions(数据处理相关函数)

In [4]:
class GroupedStandardScaler:
    def __init__(self):
        self.group_scalers = {}
    def fit(self, X, groups):
        self.group_scalers = {}
        groups = pd.Series(groups)
        for g in groups.unique():
            scaler = StandardScaler()
            mask = (groups == g)
            scaler.fit(X.loc[mask])
            self.group_scalers[g] = scaler
        return self
    def transform(self, X, groups):
        X_out = X.copy()
        groups = pd.Series(groups)
        for g, scaler in self.group_scalers.items():
            mask = (groups == g)
            X_out.loc[mask] = scaler.transform(X.loc[mask])
        return X_out

def read_data(folder_path):
    data_list = []
    file_names = [f for f in os.listdir(folder_path) if f.endswith('.xlsx')]
    if not file_names:
        raise FileNotFoundError(f"No .xlsx files in {folder_path}")
    for file_name in file_names:
        file_path = os.path.join(folder_path, file_name)
        xls = pd.ExcelFile(file_path)
        for sheet_name in xls.sheet_names:
            if sheet_name == 'data':
                continue
            df = pd.read_excel(file_path, sheet_name=sheet_name)
            df['source'] = f"{file_name}_{sheet_name}"
            data_list.append(df)
    if not data_list:
        raise ValueError("No valid sheets loaded in read_data")
    return pd.concat(data_list, ignore_index=True)

def read_edges(folder_path):
    edge_list = []
    file_names = [f for f in os.listdir(folder_path) if f.endswith('.xlsx')]
    for file_name in file_names:
        file_path = os.path.join(folder_path, file_name)
        xls = pd.ExcelFile(file_path)
        for sheet_name in xls.sheet_names:
            if sheet_name == 'data':
                continue
            df = pd.read_excel(file_path, sheet_name=sheet_name)
            df['source'] = f"{file_name}_{sheet_name}"
            edge_list.append(df)
    return pd.concat(edge_list, ignore_index=True)

def get_root_layer_nodes_for_sheet(node_data, node_ids):
    df = node_data[node_data['global_id'].isin(node_ids)]
    for tag in ['EN', 'DT', 'LT']:
        lst = df[df['Class'] == tag]['global_id'].tolist()
        if lst:
            return lst
    return [node_ids[0]]

def process_data(node_data, edge_data):
    node_data = node_data.copy()
    node_data['global_id'] = range(len(node_data))
    id_map = {(row['source'], row['ID']): row['global_id'] for _, row in node_data.iterrows()}
    all_input_cols = INPUT_FEATURES
    idx_layer1 = [all_input_cols.index(col) for col in FEATURES_LAYER1]
    idx_layer2 = [all_input_cols.index(col) for col in FEATURES_LAYER2]
    group_col = 'Class'
    group_series = node_data[group_col].astype(str)
    group_features = ['BTW', 'CTR']
    other_features = [col for col in INPUT_FEATURES if col not in group_features]
    input_scaler = StandardScaler()
    X_other = input_scaler.fit_transform(node_data[other_features].fillna(0))
    group_scaler_BTW = GroupedStandardScaler()
    group_scaler_BTW.fit(node_data[['BTW']].fillna(0), group_series)
    group_scaler_CTR = GroupedStandardScaler()
    group_scaler_CTR.fit(node_data[['CTR']].fillna(0), group_series)
    X_BTW = group_scaler_BTW.transform(node_data[['BTW']].fillna(0), group_series).values
    X_CTR = group_scaler_CTR.transform(node_data[['CTR']].fillna(0), group_series).values
    X = np.zeros((len(node_data), len(all_input_cols)))
    for i, col in enumerate(all_input_cols):
        if col == 'BTW':
            X[:, i] = X_BTW[:, 0]
        elif col == 'CTR':
            X[:, i] = X_CTR[:, 0]
        else:
            X[:, i] = X_other[:, other_features.index(col)]
    edge_list = []
    for _, row in edge_data.iterrows():
        src_key = (row['source'], row['StartPointID'])
        dst_key = (row['source'], row['EndPointID'])
        if src_key in id_map and dst_key in id_map:
            edge_list.append([id_map[src_key], id_map[dst_key]])
    edge_index = torch.LongTensor(edge_list).t().contiguous()
    return {
        'x': torch.FloatTensor(X),
        'edge_index': edge_index,
        'idx_layer1': idx_layer1,
        'idx_layer2': idx_layer2,
        'node_data': node_data,
        'id_map': id_map
    }


# 3.Model structure definition(模型结构定义)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DualMLP_EdgeClassifier(nn.Module):
    def __init__(
        self,
        input_dim,
        idx_layer1,
        idx_layer2,
        mlp_dim=32,
        hidden_dim=64,
        dropout=0.2,
        heads=2,    
        concat=True,
        residual=True
    ):
        super().__init__()
        self.idx_layer1 = idx_layer1
        self.idx_layer2 = idx_layer2

        # Projection
        self.proj1 = nn.Linear(len(idx_layer1), mlp_dim)
        self.proj2 = nn.Linear(len(idx_layer2), mlp_dim)

        # Group MLPs (for each feature group independently)
        self.mlp1 = nn.Sequential(
            nn.Linear(mlp_dim, mlp_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim * 2, mlp_dim)
        )
        self.mlp2 = nn.Sequential(
            nn.Linear(mlp_dim, mlp_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim * 2, mlp_dim)
        )

        out_dim = hidden_dim * heads if concat else hidden_dim  

        # Dim align
        self.mlp_align1 = nn.Linear(mlp_dim, out_dim)
        self.mlp_align2 = nn.Linear(mlp_dim, out_dim)
        
        # Feature fusion gating
        self.fusion_gate = nn.Sequential(
            nn.Linear(out_dim * 4, out_dim * 4),
            nn.Sigmoid()
        )
        self.dropout = nn.Dropout(dropout)
        self.residual = residual
        self.classifier = None
        self.out1_dim = out_dim
        self.out2_dim = out_dim

    def build_classifier(self, feats_dim, device):
        self.classifier = nn.Sequential(
            nn.Linear(feats_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.18),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Linear(32, 2)
        ).to(device)

    def forward(self, x, edge_index, edge_pairs):
        device = x.device
        idx_layer1 = self.idx_layer1
        idx_layer2 = self.idx_layer2
        x1 = x[:, idx_layer1]
        x2 = x[:, idx_layer2]

        # Two sets of feature independent MLP transformation
        x1_proj = self.proj1(x1)
        h1_mlp = self.mlp1(x1_proj)
        h1_aligned = self.mlp_align1(h1_mlp)
        h1_mlp = self.dropout(h1_aligned)

        x2_proj = self.proj2(x2)
        h2_mlp = self.mlp2(x2_proj)
        h2_aligned = self.mlp_align2(h2_mlp)
        h2_mlp = self.dropout(h2_aligned)

        feats_all = torch.cat([h1_mlp, h1_mlp, h2_mlp, h2_mlp], dim=1)  # [N, out_dim*4]
        gate = self.fusion_gate(feats_all)
        gate1, gate2, gate3, gate4 = torch.chunk(gate, 4, dim=1)
        fused = (gate1 * h1_mlp) + (gate2 * h1_mlp) + (gate3 * h2_mlp) + (gate4 * h2_mlp)

        # Take the concatenation of the two endpoints in the participating edge pairs as the representation of each "edge"
        edge_feat_i = fused[edge_pairs[0]]  # [n_edges, out_dim]
        edge_feat_j = fused[edge_pairs[1]]  # [n_edges, out_dim]
        edge_feat = torch.cat([edge_feat_i, edge_feat_j], dim=1) # [n_edges, out_dim * 2]
        
        if self.classifier is None:
            self.build_classifier(edge_feat.size(1), device)

        logits = self.classifier(edge_feat)
        return logits  # [n_edges, 2]



# 4.Evaluation, auxiliary, and visualization functions(评估、辅助和可视化函数)

In [6]:
def bfs_tree_layers_priority(edges, node_ids, node_data):
    root_layer_nodes = get_root_layer_nodes_for_sheet(node_data, node_ids)
    edge_set = set((min(a, b), max(a, b)) for a, b in edges)
    layers = {nid: 0 for nid in root_layer_nodes}
    q = deque(list(root_layer_nodes))
    used = set(root_layer_nodes)
    while q:
        u = q.popleft()
        for v in node_ids:
            if v not in used and (min(u, v), max(u, v)) in edge_set:
                layers[v] = layers[u] + 1
                used.add(v)
                q.append(v)
    return layers

def edge_evaluate(logits, labels):
    prob_edge = F.softmax(logits, dim=1)[:, 1]
    pred_label = (prob_edge > 0.5).cpu().numpy().astype(np.int64)
    y_true = labels.cpu().numpy()
    acc = accuracy_score(y_true, pred_label)
    f1 = f1_score(y_true, pred_label)
    recall = recall_score(y_true, pred_label)
    precision = precision_score(y_true, pred_label)
    roc_auc = roc_auc_score(y_true, prob_edge.detach().cpu().numpy()) if np.sum(y_true) > 0 and np.sum(1 - y_true) > 0 else 0.0
    loss = F.cross_entropy(logits, labels.long()).item()
    cm = confusion_matrix(y_true, pred_label)
    tn, fp, fn, tp = (cm.ravel().tolist() if cm.size == 4 else (0, 0, 0, 0))
    return {
        'loss': loss, 'acc': acc, 'f1': f1, 'recall': recall, 'precision': precision, 'roc_auc': roc_auc,
        'tn': tn, 'fp': fp, 'fn': fn, 'tp': tp,
        'prob': prob_edge.detach().cpu().numpy(), 'pred_label': pred_label, 'y_true': y_true
    }

def plot_metrics_curve(history, tag, outdir):
    os.makedirs(outdir, exist_ok=True)
    epochs = np.arange(len(history['train']))
    for metric in ['loss', 'f1', 'roc_auc', 'recall', 'acc']:
        plt.figure(figsize=(8, 5))
        for phase in ['train', 'val', 'test']:
            plt.plot(epochs, [h[metric] for h in history[phase]], label=phase)
        plt.xlabel('Epoch')
        plt.ylabel(metric)
        plt.title(f'{tag} {metric} Curve')
        plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(outdir, f"{tag}_curve_{metric}.png"), dpi=600)
        plt.close()

def infer_sheet_prob_matrix(model, features, edge_index, node_ids, batchsize=2048):
    pairs = []
    for i in node_ids:
        for j in node_ids:
            if i < j:
                pairs.append((i, j))
    pairs = np.array(pairs).T
    with torch.no_grad():
        probs = []
        for start in range(0, pairs.shape[1], batchsize):
            batch_pairs = pairs[:, start:start + batchsize]
            logits = model(features, edge_index, torch.LongTensor(batch_pairs).to(features.device))
            proba = F.softmax(logits, dim=1)[:, 1].cpu().numpy()
            probs.append(proba)
        all_probs = np.concatenate(probs)
    N = len(node_ids)
    mat = np.zeros((N, N))
    idxmap = {nid: idx for idx, nid in enumerate(node_ids)}
    k = 0
    for a, b in zip(pairs[0], pairs[1]):
        ia, ib = idxmap[a], idxmap[b]
        mat[ia, ib] = all_probs[k]
        mat[ib, ia] = all_probs[k]
        k += 1
    np.fill_diagonal(mat, 0)
    return mat, idxmap

def plot_depth_graph(df, edge_list, node_depths, id_col, name_col, save_path, color='#f7941d', title="结构图"):
    nodes = list(df['global_id'].values)
    names = {row['global_id']: str(row[name_col]) for _, row in df.iterrows()}
    depths = [node_depths[n] for n in nodes if n in node_depths]
    max_depth = max(depths) if depths else 1
    depth_layers = {}
    for n in nodes:
        d = node_depths.get(n, 0)
        depth_layers.setdefault(d, []).append(n)
    ygap = 2.7 if len(depth_layers) <= 10 else 1.5
    pos = {}
    for d, layer in depth_layers.items():
        x_gap = 1.1 if len(layer) <= 10 else 0.65
        for i, nid in enumerate(sorted(layer, key=lambda v: names[v])):
            pos[nid] = (i * x_gap, d * ygap)
    plt.figure(figsize=(max(12, 0.55 * len(nodes)), (max_depth + 3) * 1.3), dpi=600)
    for n in nodes:
        x, y = pos[n]
        plt.scatter(x, y, s=1000, color='deepskyblue', edgecolors='k', zorder=20)
        plt.text(x, y, names[n], fontsize=10, ha='center', va='center', color='black', weight='bold', zorder=25)
    for (a, b) in edge_list:
        x1, y1 = pos[a]; x2, y2 = pos[b]
        plt.plot([x1, x2], [y1, y2], c=color, lw=3, alpha=0.76, zorder=10)
    for d in range(max_depth + 1):
        plt.axhline(y=d * ygap, color='gray', linewidth=0.7, linestyle='dashed', alpha=0.17)
        plt.text(-2, d * ygap, f'Depth: {d}', fontsize=13, color='k', weight='bold', va='center')
    plt.axis('off')
    plt.title(title)
    plt.tight_layout()
    plt.savefig(save_path, dpi=600)
    plt.close()

def get_true_layers(node_ids, edges, node_data):
    roots = get_root_layer_nodes_for_sheet(node_data, node_ids)
    adj = defaultdict(list)
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    depths = {n: 0 for n in roots}
    q = deque(roots)
    visited = set(roots)
    while q:
        u = q.popleft()
        for v in adj[u]:
            if v not in visited:
                depths[v] = depths[u] + 1
                q.append(v)
                visited.add(v)
    max_depth = max(depths.values()) if depths else 0
    layers = [[] for _ in range(max_depth + 1)]
    for n, d in depths.items():
        layers[d].append(n)
    return layers

def pred_structure_deep_layers(prob_mat, node_ids, node_data, layer_sizes, root_nodes):
    N = len(node_ids)
    idxmap = {nid: i for i, nid in enumerate(node_ids)}
    used = set(root_nodes)
    alloc = []
    alloc.append(list(root_nodes))
    left = set(node_ids) - set(root_nodes)
    cur_layer = list(root_nodes)
    depth = 1
    while depth < len(layer_sizes):
        scores = []
        for nid in left:
            s = np.mean([prob_mat[idxmap[nid], idxmap[u]] for u in cur_layer])
            scores.append((s, nid))
        scores.sort(reverse=True)
        need_num = max(1, len(layer_sizes[depth]))
        layer_nodes = [nid for _, nid in scores[:need_num]]
        alloc.append(layer_nodes)
        left -= set(layer_nodes)
        cur_layer = layer_nodes
        depth += 1
    if left:
        alloc[-1].extend(left)
    return alloc

def make_pred_edges_by_layers(prob_mat, alloc_layers, pred_prob_thresh=0.5):
    layers = alloc_layers
    idxmap = {}
    node_ids = []
    for layer in layers:
        node_ids.extend(layer)
    for idx, nid in enumerate(node_ids):
        idxmap[nid] = idx
    pred_edges = []
    for dep in range(1, len(layers)):
        for v in layers[dep]:
            bestu = None
            bestscore = -1
            for u in layers[dep - 1]:
                s = prob_mat[idxmap[u], idxmap[v]]
                if s > bestscore:
                    bestscore = s
                    bestu = u
            if bestu is not None:
                pred_edges.append(tuple(sorted([v, bestu])))
    return pred_edges

def structure_layer_loss(pred_edges, node_ids, node_data, real_layers):
    pred_layers = bfs_tree_layers_priority(pred_edges, node_ids, node_data)
    real_layers_map = {}
    for d, l in enumerate(real_layers):
        for n in l:
            real_layers_map[n] = d
    loss = 0
    for n in node_ids:
        loss += (pred_layers.get(n, 999) - real_layers_map.get(n, 999)) ** 2
    return loss

def depth_sa_optimize(prob_mat, node_ids, node_data, real_layers, pred_edges):
    best_edges = pred_edges[:]
    best_score = structure_layer_loss(best_edges, node_ids, node_data, real_layers)
    cur_edges = pred_edges[:]
    for epoch in range(3000):
        nonroot = []
        for i, layer in enumerate(real_layers):
            if i == 0: continue
            nonroot.extend(layer)
        if not nonroot: break
        tgt = random.choice(nonroot)
        candidate_p = []
        prev_layer = []
        for li, ly in enumerate(real_layers):
            if tgt in ly and li > 0:
                prev_layer = real_layers[li - 1]
                break
        for u in prev_layer:
            if tuple(sorted([u, tgt])) not in cur_edges:
                candidate_p.append(u)
        if candidate_p:
            u_new = random.choice(candidate_p)
            edges_new = [e for e in cur_edges if tgt not in e] + [tuple(sorted([u_new, tgt]))]
        else:
            continue
        score_new = structure_layer_loss(edges_new, node_ids, node_data, real_layers)
        dE = best_score - score_new
        if score_new < best_score or random.random() < np.exp(dE / 2.4):
            cur_edges = edges_new
            if score_new < best_score:
                best_edges = edges_new
                best_score = score_new
    return best_edges

def save_edges(edge_list, nd, outcsv):
    id2name = dict(zip(nd['global_id'], nd['Name']))
    edge_rows = []
    for a, b in edge_list:
        edge_rows.append({'src_id': a, 'src_name': id2name.get(a, ''),'dst_id': b, 'dst_name': id2name.get(b, '')})
    pd.DataFrame(edge_rows).to_csv(outcsv, index=False)



# 5.Data loading and feature processing(数据加载与特征处理)

In [None]:
# Note the folder path
node_data = read_data('nodedata')
edge_data = read_edges('edgedata')
data_info = process_data(node_data, edge_data)
features = data_info['x'].to(device)
edge_index = data_info['edge_index'].to(device)
idx_layer1 = data_info['idx_layer1']
idx_layer2 = data_info['idx_layer2']
node_data_pd = data_info['node_data']
id_map = data_info['id_map']
num_nodes = features.shape[0]
print("节点数:", features.shape[0], "特征数:", features.shape[1])
print("边索引 shape:", edge_index.shape)


节点数: 2722 特征数: 21
边索引 shape: torch.Size([2, 3023])


# 6.Construction and division of positive and negative samples(正负样本构建与划分)

In [8]:
pos_edges = edge_index.cpu().numpy().T
exists = set(tuple(sorted([i, j])) for i, j in pos_edges)
neg_edges = []
while len(neg_edges) < len(pos_edges):
    i, j = np.random.randint(0, num_nodes, size=2)
    if i == j: continue
    if tuple(sorted([i, j])) in exists: continue
    neg_edges.append([i, j])
    exists.add(tuple(sorted([i, j])))
neg_edges = np.array(neg_edges)
all_edges = np.concatenate([pos_edges, neg_edges], axis=0)
labels = np.array([1] * len(pos_edges) + [0] * len(neg_edges))
idx = np.random.permutation(len(labels))
all_edges = all_edges[idx]; labels = labels[idx]
n = len(labels)
n_train = int(0.6 * n); n_val = int(0.2 * n)
train_edges = all_edges[:n_train].T; train_labels = labels[:n_train]
val_edges = all_edges[n_train:n_train + n_val].T; val_labels = labels[n_train:n_train + n_val]
test_edges = all_edges[n_train + n_val:].T; test_labels = labels[n_train + n_val:]
print("Train:", train_edges.shape, train_labels.shape)
print("Val:", val_edges.shape, val_labels.shape)
print("Test:", test_edges.shape, test_labels.shape)


Train: (2, 3627) (3627,)
Val: (2, 1209) (1209,)
Test: (2, 1210) (1210,)


# 7.Training main loop(训练主循环)

In [None]:
import torch
import torch.nn.functional as F
import random
import numpy as np

# Set the random seed to ensure the experiment is reproducible
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = DualMLP_EdgeClassifier(
    input_dim=features.shape[1],
    idx_layer1=idx_layer1,
    idx_layer2=idx_layer2,
    mlp_dim=32,
    hidden_dim=64,
    dropout=0.2,
    heads=2,
    concat=True,
    residual=True
).to(device)

max_epochs = 200
patience = 50    
wait = 0
best_auc = -np.inf
best_model_state = None
history = {'train': [], 'val': [], 'test': []}

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=3e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', patience=5, factor=0.2, verbose=True, min_lr=1e-6
)

print("\n==== Start Model Training with Dynamic LR (ReduceLROnPlateau) ====")
for epoch in range(max_epochs):
    model.train()
    optimizer.zero_grad()
    logits = model(features, edge_index, torch.LongTensor(train_edges).to(device))
    train_labels_tensor = torch.LongTensor(train_labels).to(device)
    loss = F.cross_entropy(logits, train_labels_tensor)
    loss.backward()
    optimizer.step()

    # Evaluate train/val/test sets
    for phase, edge_set, lab_set in zip(
        ['train', 'val', 'test'],
        [train_edges, val_edges, test_edges],
        [train_labels, val_labels, test_labels]
    ):
        logits_eval = model(features, edge_index, torch.LongTensor(edge_set).to(device))
        labtensor_eval = torch.LongTensor(lab_set).to(device)
        res = edge_evaluate(logits_eval, labtensor_eval)
        history[phase].append(res)

    print(
        f"\nEpoch {epoch + 1:03}/{max_epochs} |"
        f" Loss: {history['train'][-1]['loss']:.4f}/{history['val'][-1]['loss']:.4f}/{history['test'][-1]['loss']:.4f}"
        f" | EdgeF1: {history['train'][-1]['f1']:.3f}/{history['val'][-1]['f1']:.3f}/{history['test'][-1]['f1']:.3f}"
        f" | EdgeAP: {history['train'][-1]['acc']:.3f}/{history['val'][-1]['acc']:.3f}/{history['test'][-1]['acc']:.3f}"
        f" | Recall: {history['train'][-1]['recall']:.3f}/{history['val'][-1]['recall']:.3f}/{history['test'][-1]['recall']:.3f}"
        f" | ROC:    {history['train'][-1]['roc_auc']:.3f}/{history['val'][-1]['roc_auc']:.3f}/{history['test'][-1]['roc_auc']:.3f}"
    )

    # Dynamically adjust the learning rate (for val set AUC)
    scheduler.step(history['val'][-1]['roc_auc'])

    if history['val'][-1]['roc_auc'] > best_auc:
        best_auc = history['val'][-1]['roc_auc']
        wait = 0
        best_model_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    else:
        wait += 1
        if wait >= patience:
            print(f"Early Stopping at epoch {epoch + 1}")
            break


# Restore optimal weights
model.load_state_dict(best_model_state)
print("训练完成，最佳模型加载完毕！")

plot_metrics_curve(history, "DualGAT-MLP", "outputs_DualGAT_MLP")
print("指标曲线已保存。")



==== Start Model Training with Dynamic LR (ReduceLROnPlateau) ====

Epoch 001/200 | Loss: 0.6917/0.6923/0.6915 | EdgeF1: 0.564/0.561/0.584 | EdgeAP: 0.547/0.529/0.573 | Recall: 0.594/0.577/0.599 | ROC:    0.574/0.551/0.590

Epoch 002/200 | Loss: 0.6897/0.6905/0.6897 | EdgeF1: 0.630/0.618/0.642 | EdgeAP: 0.620/0.586/0.624 | Recall: 0.657/0.643/0.673 | ROC:    0.662/0.625/0.666

Epoch 003/200 | Loss: 0.6881/0.6884/0.6888 | EdgeF1: 0.663/0.652/0.644 | EdgeAP: 0.654/0.632/0.629 | Recall: 0.690/0.659/0.670 | ROC:    0.706/0.686/0.680

Epoch 004/200 | Loss: 0.6864/0.6878/0.6863 | EdgeF1: 0.686/0.668/0.679 | EdgeAP: 0.680/0.651/0.669 | Recall: 0.709/0.674/0.700 | ROC:    0.727/0.690/0.735

Epoch 005/200 | Loss: 0.6837/0.6851/0.6843 | EdgeF1: 0.699/0.696/0.696 | EdgeAP: 0.698/0.676/0.689 | Recall: 0.713/0.710/0.711 | ROC:    0.761/0.727/0.745

Epoch 006/200 | Loss: 0.6819/0.6833/0.6829 | EdgeF1: 0.701/0.702/0.705 | EdgeAP: 0.700/0.687/0.691 | Recall: 0.713/0.708/0.738 | ROC:    0.760/0.734/0.

# 8.Sheet结构预测与可视化生成

In [None]:
# When running Jupyter, ensure that the memory is released after each graph is drawn, or reduce the dpi (600→300)
import matplotlib.pyplot as plt  
from tqdm import tqdm            # Used to display a progress bar, optional
import gc                        # Clear Memory

# Make sure the output directory exists
output_dir = 'outputs_GSL_KAN_ADV'
output_dir_depth = os.path.join(output_dir, 'sheet_depthvis')
os.makedirs(output_dir_depth, exist_ok=True)

output_dir_depth_pred = os.path.join(output_dir, 'sheet_depthvis_pred')
os.makedirs(output_dir_depth_pred, exist_ok=True)

output_dir_depth_sa = os.path.join(output_dir, "sheet_depthvis_sa")
os.makedirs(output_dir_depth_sa, exist_ok=True)

os.makedirs('output_result/edge_pred', exist_ok=True)
os.makedirs('output_result/edge_sa', exist_ok=True)

# Make sure to show progress each time through the loop
all_sheets = node_data_pd['source'].unique()

# Traverse all sheets
for sheet_idx, sheet in enumerate(tqdm(all_sheets, desc="Processing Sheets")):
    print(f"[INFO] 开始处理 Sheet: {sheet} ({sheet_idx + 1}/{len(all_sheets)})")

    # Get the nodes and corresponding edges of the current sheet
    nd = node_data_pd[node_data_pd['source'] == sheet]
    node_ids = nd['global_id'].tolist()

    real_sheet_edges = []
    for a in node_ids:
        for b in node_ids:
            if a < b and tuple(sorted((a, b))) in all_real_edges:
                real_sheet_edges.append([a, b])
    
    # Actual levels and depth
    real_layers = get_true_layers(node_ids, real_sheet_edges, nd)
    real_depths = {n: d for d, layer in enumerate(real_layers) for n in layer}

    # 1. Save the real structure diagram
    try:
        plot_depth_graph(
            nd, real_sheet_edges, real_depths,
            id_col='ID', name_col='Name',
            save_path=os.path.join(output_dir_depth, f'{sheet.replace(".", "_").replace("/", "_")}_real.png'),
            color='#f7941d', title=f'{sheet} 实际结构'
        )
        print(f" - [INFO] 真实结构图已保存: {sheet}_real.png")
    except Exception as e:
        print(f"[ERROR] 生成真实结构图时出错: {e}")

    # Inference prediction
    prob_mat, idxmap = infer_sheet_prob_matrix(model, features, edge_index, node_ids)

    # Get the root node
    root_nodes = get_root_layer_nodes_for_sheet(nd, node_ids)

    # Generate prediction levels and depth
    pred_layers = pred_structure_deep_layers(prob_mat, node_ids, nd, real_layers, root_nodes)
    pred_depths = {n: d for d, layer in enumerate(pred_layers) for n in layer}

    # Prediction Edge
    pred_edges = make_pred_edges_by_layers(prob_mat, pred_layers, pred_prob_thresh=0.5)

    # 2. Save the prediction structure diagram
    try:
        plot_depth_graph(
            nd, pred_edges, pred_depths,
            id_col='ID', name_col='Name',
            save_path=os.path.join(output_dir_depth_pred, f'{sheet.replace(".", "_").replace("/", "_")}_pred.png'),
            color='green', title=f'{sheet} 预测深度结构'
        )
        print(f" - [INFO] 预测深度结构图已保存: {sheet}_pred.png")
    except Exception as e:
        print(f"[ERROR] 生成预测深度结构图时出错: {e}")

    # Annealing optimization
    sa_edges = depth_sa_optimize(prob_mat, node_ids, nd, real_layers, pred_edges)
    sa_pred_layers = bfs_tree_layers_priority(sa_edges, node_ids, nd)
    sa_depths = sa_pred_layers.copy()

    # 3. Save the structure diagram after annealing optimization
    try:
        plot_depth_graph(
            nd, sa_edges, sa_depths,
            id_col='ID', name_col='Name',
            save_path=os.path.join(output_dir_depth_sa, f'{sheet.replace(".", "_").replace("/", "_")}_pred_sa.png'),
            color='purple', title=f'{sheet} 退火优化结构'
        )
        print(f" - [INFO] 退火优化深度结构图已保存: {sheet}_pred_sa.png")
    except Exception as e:
        print(f"[ERROR] 生成退火优化深度结构图时出错: {e}")

    # Save edge data to CSV
    try:
        save_edges(pred_edges, nd, f'output_result/edge_pred/{sheet}.csv')
        save_edges(sa_edges, nd, f'output_result/edge_sa/{sheet}.csv')
        print(f" - [INFO] 边数据已保存: {sheet}.csv")
    except Exception as e:
        print(f"[ERROR] 保存边数据时出错: {e}")

    plt.close('all')  
    gc.collect()      

print("[OK] 所有 sheet 的真实/多层预测/分层退火优化结构图与指标曲线已生成")
print("[OK] edge_pred、edge_sa 的 CSV 文件已保存到 output_result")


NameError: name 'node_data_pd' is not defined

# 9.Visualizing graph structure and feature prediction (interactive)

In [None]:
from collections import deque, defaultdict
import plotly.graph_objs as go
import pandas as pd
import os
from glob import glob


def plotly_depth_graph_like_matplotlib_with_roots(
    df, edge_list, node_data,
    id_col='id', name_col='Name',
    feature_names=None,
    save_path="depth_pred_interactive.html",
    node_color='#21cbbb',
    edge_color='gray',
    title='预测结构交互图'
):
    def get_root_layer_nodes_for_sheet(node_data, node_ids):
        df_filtered = node_data[node_data[id_col].isin(node_ids)]
        for tag in ['EN', 'LT', 'DT']:
            filtered = df_filtered[df_filtered[name_col] == tag][id_col].tolist()
            if filtered:
                return filtered
        return [node_ids[0]]

    def bfs_tree_layers_priority(edges, node_ids, root_nodes):
        edge_set = set((min(a, b), max(a, b)) for a, b in edges)
        layers = {nid: 0 for nid in root_nodes}
        queue = deque(root_nodes)
        visited = set(root_nodes)

        while queue:
            current = queue.popleft()
            for neighbor in node_ids:
                if neighbor not in visited and (min(current, neighbor), max(current, neighbor)) in edge_set:
                    layers[neighbor] = layers[current] + 1
                    visited.add(neighbor)
                    queue.append(neighbor)

        return layers

    node_ids = list(df[id_col].values)
    root_nodes = get_root_layer_nodes_for_sheet(node_data, node_ids)
    node_depths = bfs_tree_layers_priority(edge_list, node_ids, root_nodes)

    names = dict(zip(df[id_col], df[name_col]))
    depth_layers = defaultdict(list)
    for node, depth in node_depths.items():
        depth_layers[depth].append(node)

    max_x_gap = 1.1
    max_layer_width = max(len(nodes) for nodes in depth_layers.values()) * max_x_gap

    y_gap = 2.7 if len(depth_layers) <= 10 else 1.5
    positions = {}
    for depth, layer_nodes in depth_layers.items():
        for i, node_id in enumerate(sorted(layer_nodes, key=lambda x: names[x])):
            positions[node_id] = (i * max_x_gap, depth * y_gap)

    # Create node hover information
    def format_float_safe(value):
        try:
            if pd.isna(value):
                return ''
            return f"{float(value):.2f}"
        except:
            return str(value)

    hover_texts = {}
    ZERO_EPS = 1e-6
    def safe_parse(x):
        try:
            return float(x)
        except Exception:
            return np.nan   # or None

    for _, row in df.iterrows():
        node_id = row[id_col]

        column_labels = ["Feature", "Real", "Pred", "Error"]   
        actual_values = [format_float_safe(row[f]) for f in feature_names]
        predicted_values = [format_float_safe(row[f"pred_{f}"]) for f in feature_names]

        error_values = []
        for f in feature_names:
            v_true = safe_parse(row[f])
            v_pred = safe_parse(row[f"pred_{f}"])
            if pd.isna(v_true) or pd.isna(v_pred):
                error_values.append('')
            elif abs(v_true) < ZERO_EPS:
                error_values.append(f"{abs(v_pred - v_true):.4f}*")
            else:
                error_values.append(f"{abs((v_pred - v_true) / v_true * 100):.1f}%")

        data_rows = list(zip(feature_names, actual_values, predicted_values, error_values))
        table_cols = list(zip(*([column_labels] + data_rows)))  
        col_widths = [max(len(str(item)) for item in col) for col in table_cols]

        def format_row(row):
            return " | ".join(str(val).ljust(width) for val, width in zip(row, col_widths))

        table = "<br>".join(format_row(row) for row in [column_labels] + data_rows)
        n_name = str(row[name_col])
        hover_texts[node_id] = f"<b>{n_name}</b><br><span style='font-family:monospace;white-space:pre'>{table}</span>"

    edge_traces = []
    for src, dst in edge_list:
        if src in positions and dst in positions:
            x0, y0 = positions[src]
            x1, y1 = positions[dst]
            edge_traces.append(go.Scatter(
                x=[x0, x1], y=[y0, y1],
                mode='lines',
                line=dict(color=edge_color, width=2),
                hoverinfo='skip',
                showlegend=False
            ))

    node_traces = []
    for node_id, (x, y) in positions.items():
        node_label = names[node_id]
        hover_info = hover_texts.get(node_id, node_label)
        node_traces.append(go.Scatter(
            x=[x], y=[y],
            mode='markers+text',
            text=node_label,
            textposition='top center',
            hovertext=hover_info,
            marker=dict(size=35, color=node_color, line=dict(width=2, color='#0c515b')),
            showlegend=False
        ))

    depth_lines = []
    for depth in range(len(depth_layers)):
        y = depth * y_gap
        x_min = -1
        x_max = max_layer_width + 1
        depth_lines.append(go.Scatter(
            x=[x_min, x_max], y=[y, y],
            mode='lines',
            line=dict(color='lightgray', dash='dot'),
            hoverinfo='skip',
            showlegend=False
        ))
        depth_lines.append(go.Scatter(
            x=[x_min - 1], y=[y],
            mode='text',
            text=[f"Depth: {depth}"],
            textfont=dict(size=12, color='gray'),
            hoverinfo='skip',
            showlegend=False
        ))

    fig = go.Figure(data=depth_lines + edge_traces + node_traces)
    fig.update_layout(
        title=title,
        xaxis=dict(visible=False, zeroline=False, showgrid=False),
        yaxis=dict(visible=False, zeroline=False, showgrid=False),
        plot_bgcolor='white',
        margin=dict(l=20, r=20, t=50, b=20)
    )

    fig.write_html(save_path)
    print(f"[OK] 绘图已保存为 {save_path}")

# The main code traverses the node and edge data and draws
node_dir = "output_result/node_pred"  # Node file path
edge_dir = "output_result/edge_pred"  # Edge file path
save_dir = "output_result/vis_html_pred"  # Save the HTML file path of the drawing

os.makedirs(save_dir, exist_ok=True)

node_files = sorted(glob(os.path.join(node_dir, "*.csv")))  # Node file
edge_files = sorted(glob(os.path.join(edge_dir, "*.csv")))  # Edge file

# Make sure the number of files matches
if len(node_files) != len(edge_files):
    raise ValueError(f"[Error] 节点文件数 ({len(node_files)}) 和边文件数 ({len(edge_files)}) 不匹配！")

for node_file, edge_file in zip(node_files, edge_files):
    sheet_name = os.path.splitext(os.path.basename(node_file))[0]  # The name of the current sheet
    node_data = pd.read_csv(node_file)
    edge_data = pd.read_csv(edge_file)

    if not set(['src_id', 'dst_id']).issubset(edge_data.columns):
        raise ValueError(f"[Error] 边文件 {edge_file} 缺少必要列 'src_id', 'dst_id'")
    edge_list = list(zip(edge_data['src_id'], edge_data['dst_id']))

    # Automatically filter all feature names
    feature_names = [col for col in node_data.columns if col in [
        'ITG', 'BTW', 'CTR', 'ETR', 'S', 'L', 'WA', 'WFR', 'WWR', 'DN',
        'WN', 'DS', 'WS', 'H', 'V', 'a', 'b', 'OH', 'IH', 'CN', 'CA'
    ]]
    
    save_path = os.path.join(save_dir, f"{sheet_name}_pred_plotly.html")
    plotly_depth_graph_like_matplotlib_with_roots(
        df=node_data,
        edge_list=edge_list,
        node_data=node_data,
        id_col='id',         
        name_col='name',
        feature_names=feature_names,
        save_path=save_path,
        title=f'{sheet_name} 图结构及特征预测'
    )

print(f"[OK] 所有绘图已生成并保存在：{save_dir}")


[OK] 绘图已保存为 output_result/vis_html_pred\化工学院.xlsx_1f_pred_plotly.html
[OK] 绘图已保存为 output_result/vis_html_pred\化工学院.xlsx_2f_pred_plotly.html
[OK] 绘图已保存为 output_result/vis_html_pred\化工学院.xlsx_3f_pred_plotly.html
[OK] 绘图已保存为 output_result/vis_html_pred\化工学院.xlsx_4f_pred_plotly.html
[OK] 绘图已保存为 output_result/vis_html_pred\化工学院.xlsx_5f_pred_plotly.html
[OK] 绘图已保存为 output_result/vis_html_pred\土木学院.xlsx_1f_pred_plotly.html
[OK] 绘图已保存为 output_result/vis_html_pred\土木学院.xlsx_2f_pred_plotly.html
[OK] 绘图已保存为 output_result/vis_html_pred\土木学院.xlsx_3f_pred_plotly.html
[OK] 绘图已保存为 output_result/vis_html_pred\土木学院.xlsx_4f_pred_plotly.html
[OK] 绘图已保存为 output_result/vis_html_pred\土木学院.xlsx_5f_pred_plotly.html
[OK] 绘图已保存为 output_result/vis_html_pred\建筑与设计学院.xlsx_-1f_pred_plotly.html
[OK] 绘图已保存为 output_result/vis_html_pred\建筑与设计学院.xlsx_1f_pred_plotly.html
[OK] 绘图已保存为 output_result/vis_html_pred\建筑与设计学院.xlsx_2f_pred_plotly.html
[OK] 绘图已保存为 output_result/vis_html_pred\建筑与设计学院.xlsx_3f_pred_plotly.html
[OK] 绘图

# 10.Visualizing graph structure and feature prediction （SA）

In [None]:
from collections import deque, defaultdict
import plotly.graph_objs as go
import pandas as pd
import os
import numpy as np
from glob import glob

def get_layer_groups(edges, node_ids, root_nodes):
    edge_set = set((min(a, b), max(a, b)) for a, b in edges)
    layers = {nid: 0 for nid in root_nodes}
    queue = deque(root_nodes)
    visited = set(root_nodes)
    while queue:
        current = queue.popleft()
        for neighbor in node_ids:
            if neighbor not in visited and (min(current, neighbor), max(current, neighbor)) in edge_set:
                layers[neighbor] = layers[current] + 1
                visited.add(neighbor)
                queue.append(neighbor)
    layer_to_nodes = defaultdict(list)
    for nid, d in layers.items():
        layer_to_nodes[d].append(nid)
    return layers, layer_to_nodes

# Calculate the overlap ratio of two set edges (the larger the better)
def overlap_ratio(e_new, e_orig):
    if not e_new:
        return 0
    intersect = len(e_new.intersection(e_orig))
    return intersect / len(e_new)

def sa_modify_edges(edges_cur, pre_nodes, nxt_nodes, orig_edges_set):
    edges_list = list(edges_cur)
    if np.random.rand() < 0.5 and edges_list:
        idx = np.random.randint(len(edges_list))
        u, v = edges_list[idx]
        count_u = sum(1 for e in edges_cur if u in e)
        count_v = sum(1 for e in edges_cur if v in e)
        if (count_u > 1) and (count_v > 1):
            new_edges = edges_cur.copy()
            new_edges.remove((u, v))
            return new_edges

    candidate_edges = []
    for u in pre_nodes:
        for v in nxt_nodes:
            e = (min(u, v), max(u, v))
            if e not in edges_cur:
                candidate_edges.append(e)
    if candidate_edges:
        new_edge = candidate_edges[np.random.randint(len(candidate_edges))]
        new_edges = edges_cur.copy()
        new_edges.add(new_edge)
        return new_edges
    return edges_cur

# SA core optimization, optimizing the overlap and connectivity between the two layers
def sa_optimize_layer_edges(pre_nodes, nxt_nodes, orig_edges_set, max_iter=3000, init_temp=1.0, final_temp=0.01):
    edges_cur = set(e for e in orig_edges_set if (e[0] in pre_nodes and e[1] in nxt_nodes) or (e[1] in pre_nodes and e[0] in nxt_nodes))
    def has_isolated(edges, pre_n, nxt_n):
        linked_pre = set()
        linked_nxt = set()
        for u,v in edges:
            if u in pre_n:
                linked_pre.add(u)
            if v in pre_n:
                linked_pre.add(v)
            if u in nxt_n:
                linked_nxt.add(u)
            if v in nxt_n:
                linked_nxt.add(v)
        isolated_pre = set(pre_n) - linked_pre
        isolated_nxt = set(nxt_n) - linked_nxt
        return isolated_pre, isolated_nxt
    isolated_pre, isolated_nxt = has_isolated(edges_cur, pre_nodes, nxt_nodes)
    for u in isolated_pre:
        edges_cur.add((min(u, nxt_nodes[0]), max(u, nxt_nodes[0])))
    for v in isolated_nxt:
        edges_cur.add((min(pre_nodes[0], v), max(pre_nodes[0], v)))
    temp = init_temp
    alpha = (final_temp/init_temp)**(1/(max_iter-1))
    best_edges = edges_cur.copy()
    best_score = overlap_ratio(edges_cur, orig_edges_set)
    for i in range(max_iter):
        isolated_pre, isolated_nxt = has_isolated(edges_new, pre_nodes, nxt_nodes)
        if isolated_pre or isolated_nxt:
            continue
        score_new = overlap_ratio(edges_new, orig_edges_set)
        score_diff = score_new - overlap_ratio(edges_cur, orig_edges_set)
        accept = False
        if score_diff >= 0:
            accept = True
        else:
            p = np.exp(score_diff / temp)
            if np.random.rand() < p:
                accept = True
        if accept:
            edges_cur = edges_new
            if score_new > best_score:
                best_score = score_new
                best_edges = edges_new
        temp *= alpha
    return best_edges


def full_depth_graph_with_sa_edges(
    df, orig_edge_list, node_data, id_col='id', name_col='Name',
    feature_names=None, save_path="depth_pred_interactive_sa.html",
    node_color='#21cbbb', edge_color='gray', title='预测结构交互图',
    rand_seed=42, sa_iter=3000
):
    np.random.seed(rand_seed)

    def get_root_layer_nodes_for_sheet(node_data, node_ids):
        for tag in ['EN', 'DT', 'LT']:
            filtered = node_data[
                node_data[id_col].isin(node_ids) & node_data[name_col].str.contains(tag, na=False)
            ][id_col].tolist()
            if filtered:
                return filtered
        return [node_ids[0]]

    node_ids = list(df[id_col].values)
    root_nodes = get_root_layer_nodes_for_sheet(node_data, node_ids)
    layers, layer_to_nodes = get_layer_groups(orig_edge_list, node_ids, root_nodes)
    orig_edges_set = set((min(u,v), max(u,v)) for u,v in orig_edge_list)

    new_edges_set = set()
    depth_list = sorted(layer_to_nodes.keys())
    # For each pair of adjacent layers, use SA to optimize the edge set
    for d, dn in zip(depth_list[:-1], depth_list[1:]):
        pre_n = layer_to_nodes[d]
        nxt_n = layer_to_nodes[dn]
        if pre_n and nxt_n:
            es = sa_optimize_layer_edges(pre_n, nxt_n, orig_edges_set, max_iter=sa_iter)
            new_edges_set |= es

    # Add non-adjacent layer original edges to ensure the overall structure
    for u, v in orig_edges_set:
        depth_u = layers.get(u, -1)
        depth_v = layers.get(v, -1)
        if abs(depth_u - depth_v) != 1:
            new_edges_set.add((min(u,v), max(u,v)))

    linked_nodes = set()
    for u,v in new_edges_set:
        linked_nodes.add(u); linked_nodes.add(v)
    for nid in node_ids:
        if nid not in linked_nodes and nid != root_nodes[0]:
            new_edges_set.add((min(nid, node_ids[0]), max(nid, node_ids[0])))

    # Node location
    max_x_gap = 1.1
    y_gap = 2.7 if len(layer_to_nodes) <= 10 else 1.5
    max_layer_width = max(len(nodes) for nodes in layer_to_nodes.values()) * max_x_gap
    positions = {}
    names = dict(zip(df[id_col], df[name_col]))
    for depth in sorted(layer_to_nodes.keys()):
        layer_nodes = sorted(layer_to_nodes[depth], key=lambda x: (str(names.get(x, '')), x))
        for i, node_id in enumerate(layer_nodes):
            positions[node_id] = (i * max_x_gap, depth * y_gap)
    # hover TXT
    def format_float_safe(value):
        try:
            if pd.isna(value): return ''
            return f"{float(value):.2f}"
        except:
            return str(value)
    hover_texts = {}
    ZERO_EPS = 1e-6
    def safe_parse(x):
        try: return float(x)
        except: return np.nan
    for _, row in df.iterrows():
        node_id = row[id_col]
        column_labels = ["Feature", "Real", "Pred", "Error"]
        actual_values = [format_float_safe(row[f]) for f in feature_names]
        predicted_values = [format_float_safe(row[f"pred_{f}"]) for f in feature_names]
        error_values = []
        for f in feature_names:
            v_true = safe_parse(row[f])
            v_pred = safe_parse(row[f"pred_{f}"])
            if pd.isna(v_true) or pd.isna(v_pred): error_values.append('')
            elif abs(v_true) < ZERO_EPS: error_values.append(f"{abs(v_pred - v_true):.4f}*")
            else: error_values.append(f"{abs((v_pred - v_true) / v_true * 100):.1f}%")
        data_rows = list(zip(feature_names, actual_values, predicted_values, error_values))
        table_cols = list(zip(*([column_labels] + data_rows)))
        col_widths = [max(len(str(item)) for item in col) for col in table_cols]
        def format_row(row):
            return " | ".join(str(val).ljust(width) for val, width in zip(row, col_widths))
        table = "<br>".join(format_row(row) for row in [column_labels] + data_rows)
        n_name = str(row[name_col])
        hover_texts[node_id] = f"<b>{n_name}</b><br><span style='font-family:monospace;white-space:pre'>{table}</span>"
    # Draw Edges
    edge_traces = []
    for u,v in new_edges_set:
        if u in positions and v in positions:
            x0,y0 = positions[u]
            x1,y1 = positions[v]
            edge_traces.append(go.Scatter(
                x=[x0,x1], y=[y0,y1], mode='lines',
                line=dict(color=edge_color, width=2),
                hoverinfo='skip', showlegend=False))
    # Drawing Nodes
    node_traces = []
    for node_id,(x,y) in positions.items():
        node_label = names[node_id]
        hover_info = hover_texts.get(node_id, node_label)
        node_traces.append(go.Scatter(
            x=[x], y=[y],mode='markers+text',
            text=node_label, textposition='top center',
            hovertext=hover_info,
            marker=dict(size=35, color=node_color, line=dict(width=2, color='#0c515b')),
            showlegend=False))
    # Horizontal depth mark
    depth_lines = []
    for depth in sorted(layer_to_nodes.keys()):
        y = depth * y_gap
        x_min = -1
        x_max = max_layer_width + 1
        depth_lines.append(go.Scatter(
            x=[x_min, x_max], y=[y,y], mode='lines',
            line=dict(color='lightgray', dash='dot'),
            hoverinfo='skip', showlegend=False
        ))
        depth_lines.append(go.Scatter(
            x=[x_min-1], y=[y], mode='text', 
            text=[f"Depth: {depth}"], 
            textfont=dict(size=12, color='gray'),
            hoverinfo='skip', showlegend=False
        ))
    fig = go.Figure(data=depth_lines + edge_traces + node_traces)
    fig.update_layout(
        title=title,
        xaxis=dict(visible=False, zeroline=False, showgrid=False),
        yaxis=dict(visible=False, zeroline=False, showgrid=False),
        plot_bgcolor='white', margin=dict(l=20, r=20, t=50, b=20)
    )
    fig.write_html(save_path)
    print(f"[SA] 绘图已保存 {save_path}")

# ===== Main program =====
node_dir = "output_result/node_pred"
edge_dir = "output_result/edge_pred"
save_dir = "output_result/vis_html_pred_sa"
os.makedirs(save_dir, exist_ok=True)
node_files = sorted(glob(os.path.join(node_dir, "*.csv")))
edge_files = sorted(glob(os.path.join(edge_dir, "*.csv")))

if len(node_files) != len(edge_files):
    raise ValueError("节点文件和边文件数量不匹配！")

for node_file, edge_file in zip(node_files, edge_files):
    sheet_name = os.path.splitext(os.path.basename(node_file))[0]
    node_data = pd.read_csv(node_file)
    edge_data = pd.read_csv(edge_file)
    if not set(['src_id', 'dst_id']).issubset(edge_data.columns):
        raise ValueError(f"边文件{edge_file}缺少'src_id', 'dst_id'列！")

    orig_edges = list(zip(edge_data['src_id'], edge_data['dst_id']))
    feature_names = [col for col in node_data.columns if col in [
        'ITG', 'BTW', 'CTR', 'ETR', 'S', 'L', 'WA', 'WFR', 'WWR', 'DN',
        'WN', 'DS', 'WS', 'H', 'V', 'a', 'b', 'OH', 'IH', 'CN', 'CA'
    ]]

    save_path = os.path.join(save_dir, f"{sheet_name}_sa_plotly.html")
    full_depth_graph_with_sa_edges(
        df=node_data,
        orig_edge_list=orig_edges,
        node_data=node_data,
        id_col='id',
        name_col='name',
        feature_names=feature_names,
        save_path=save_path,
        title=f"{sheet_name} SA 优化结构预测图",
        rand_seed=42,
        sa_iter=3000 # Adjustable number of simulated annealing iterations
    )
print(f"[SA] 所有绘图已生成保存在：{save_dir}")


[SA] 绘图已保存 output_result/vis_html_pred_sa\化工学院.xlsx_1f_sa_plotly.html
[SA] 绘图已保存 output_result/vis_html_pred_sa\化工学院.xlsx_2f_sa_plotly.html
[SA] 绘图已保存 output_result/vis_html_pred_sa\化工学院.xlsx_3f_sa_plotly.html
[SA] 绘图已保存 output_result/vis_html_pred_sa\化工学院.xlsx_4f_sa_plotly.html
[SA] 绘图已保存 output_result/vis_html_pred_sa\化工学院.xlsx_5f_sa_plotly.html
[SA] 绘图已保存 output_result/vis_html_pred_sa\土木学院.xlsx_1f_sa_plotly.html
[SA] 绘图已保存 output_result/vis_html_pred_sa\土木学院.xlsx_2f_sa_plotly.html
[SA] 绘图已保存 output_result/vis_html_pred_sa\土木学院.xlsx_3f_sa_plotly.html
[SA] 绘图已保存 output_result/vis_html_pred_sa\土木学院.xlsx_4f_sa_plotly.html
[SA] 绘图已保存 output_result/vis_html_pred_sa\土木学院.xlsx_5f_sa_plotly.html
[SA] 绘图已保存 output_result/vis_html_pred_sa\建筑与设计学院.xlsx_-1f_sa_plotly.html
[SA] 绘图已保存 output_result/vis_html_pred_sa\建筑与设计学院.xlsx_1f_sa_plotly.html
[SA] 绘图已保存 output_result/vis_html_pred_sa\建筑与设计学院.xlsx_2f_sa_plotly.html
[SA] 绘图已保存 output_result/vis_html_pred_sa\建筑与设计学院.xlsx_3f_sa_plotly.html
[SA] 绘图