In [5]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from collections import defaultdict
from tqdm import tqdm, trange
from torch_scatter import scatter_add
from torch_scatter import scatter_mean

class GraphEdgeBundler:
    def __init__(self, node_csv, edge_csv, layout_csv, num_control_points=5, device=None):
        """
        node_csv: path to node.csv
        edge_csv: path to edge.csv
        layout_csv: path to layout.csv (no header, each row is [x, y])
        num_control_points: number of points per edge (including endpoints)
        device: torch device
        """
        # 1. 读取数据
        self.node_df = pd.read_csv(node_csv)
        self.edge_df = pd.read_csv(edge_csv)
        self.layout = np.loadtxt(layout_csv, delimiter=',')
        self.num_nodes = len(self.node_df)
        self.dim = self.layout.shape[1]
        self.num_edges = len(self.edge_df)
        self.num_control_points = num_control_points
        # self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = torch.device('cpu')

        # 2. 基于edge.csv的source/target名字获取edge_list
        # 建立名字到index的映射
        name_to_idx = {name: idx for idx, name in enumerate(self.node_df['name'])}
        edge_list = [
            (name_to_idx[src], name_to_idx[tgt])
            for src, tgt in zip(self.edge_df['source'], self.edge_df['target'])
        ]
        self.edge_list = edge_list

        # 3. 节点坐标
        self.node_positions = np.asarray(self.layout)
        print("layout shape:", self.node_positions.shape)

        # 4. 边权重（如无weight列则全1）
        if 'weight' in self.edge_df.columns:
            weights = self.edge_df['weight'].to_numpy(dtype=np.float32)
        else:
            weights = np.ones(self.num_edges, dtype=np.float32)
        self.edge_weight = torch.tensor(weights, dtype=torch.float32, device=self.device)

        # 5. 初始化每条边的控制点（直线插值）
        self.init_edge_control_points()
        # 6. 构建可训练参数
        self.make_control_points_trainable()

    def init_edge_control_points(self):
        """
        For each edge, initialize control points by linear interpolation between endpoints.
        """
        control_points = []
        for src, dst in tqdm(self.edge_list, desc="Initializing edge control points"):
            p0 = self.node_positions[src]
            p1 = self.node_positions[dst]
            # Linear interpolation: shape (num_control_points, dim)
            points = np.linspace(p0, p1, self.num_control_points)
            control_points.append(points)
        # Shape: (num_edges, num_control_points, dim)
        self.edge_control_points = np.stack(control_points, axis=0)

    def make_control_points_trainable(self):
        """
        Make the internal control points (excluding endpoints) trainable torch parameters.
        Endpoints always anchored at node positions.
        """
        internal_points = self.edge_control_points[:, 1:-1, :]
        self.trainable_ctrl_pts = torch.nn.Parameter(
            torch.tensor(internal_points, dtype=torch.float32, device=self.device)
        )
        # Save for reference
        self.fixed_start_pts = torch.tensor(
            self.edge_control_points[:, 0, :], dtype=torch.float32, device=self.device
        )
        self.fixed_end_pts = torch.tensor(
            self.edge_control_points[:, -1, :], dtype=torch.float32, device=self.device
        )

    def get_full_control_points(self):
        """
        Return tensor of shape (num_edges, num_control_points, dim)
        with fixed endpoints and current trainable control points.
        """
        # Concatenate start, trainable, end
        return torch.cat([
            self.fixed_start_pts[:, None, :],
            self.trainable_ctrl_pts,
            self.fixed_end_pts[:, None, :]
        ], dim=1)

    def compute_edge_pairs(self, verbose=True):
        """
        高效生成所有共节点边对的edge_pairs，仅同序号的中间控制点互吸引
        Returns: list of (e1, seg, e2, seg, weight)
        """
        node2edges = defaultdict(list)
        for eid, (src, dst) in enumerate(self.edge_list):
            node2edges[src].append(eid)
            node2edges[dst].append(eid)

        pairs = []
        nodes_iter = tqdm(node2edges.items(), desc="Computing edge_pairs", disable=not verbose)
        for node, edges in nodes_iter:
            m = len(edges)
            for i in range(m):
                for j in range(i+1, m):
                    e1 = edges[i]
                    e2 = edges[j]
                    for seg in range(1, self.num_control_points-1):
                        pairs.append((e1, seg, e2, seg, 1.0))
        return pairs

    def build_edge2edge_index(self, verbose=True):
        """
        Build edge-edge adjacency: two edges are neighbors if they share a node.
        Returns (row, col) for message passing: edge idx pairs.
        """
        node2edges = defaultdict(list)
        for eid, (src, tgt) in enumerate(self.edge_list):
            node2edges[src].append(eid)
            node2edges[tgt].append(eid)

        row = []
        col = []
        node_items = list(node2edges.items())
        nodes_iter = tqdm(node_items, desc="Building edge2edge index", disable=not verbose)
        for node, edges in nodes_iter:
            m = len(edges)
            if m < 2:
                continue
            for i in range(m):
                for j in range(m):
                    if i != j:
                        row.append(edges[i])
                        col.append(edges[j])

        if row:
            row = torch.tensor(row, dtype=torch.long, device=self.device)
            col = torch.tensor(col, dtype=torch.long, device=self.device)
            return (row, col)
        else:
            return (torch.empty(0, dtype=torch.long, device=self.device), 
                    torch.empty(0, dtype=torch.long, device=self.device))

    @staticmethod
    def bundling_loss(control_points, edge_pairs, edge_weight, bundling_weight=1.0):
        """
        吸引力损失：让空间接近、方向相似的边段聚合
        control_points: (num_edges, num_points, dim)
        edge_pairs: list of (edge_idx1, seg_idx1, edge_idx2, seg_idx2, weight)
        edge_weight: (num_edges,) tensor
        """
        loss = 0.0
        for (e1, s1, e2, s2, w) in edge_pairs:
            p1 = control_points[e1, s1]  # (dim,)
            p2 = control_points[e2, s2]  # (dim,)
            # 捆绑loss加权，权重用边权的乘积（或其它方式）
            pair_weight = w * edge_weight[e1] * edge_weight[e2]
            loss = loss + pair_weight * ((p1 - p2)**2).sum()
        return bundling_weight * loss

    @staticmethod
    def repulsion_loss(control_points, edge_weight, repulsion_weight=1.0, min_dist=1e-2):
        """
        排斥力损失：防止不同边段重叠
        control_points: (num_edges, num_points, dim)
        edge_weight: (num_edges,) tensor
        """
        num_edges, num_points, dim = control_points.shape
        all_points = control_points.view(-1, dim)  # (num_edges * num_points, dim)
        # 每个点的边权
        point_weights = edge_weight.repeat_interleave(num_points)
        diff = all_points.unsqueeze(0) - all_points.unsqueeze(1)  # (N, N, dim)
        dist2 = (diff ** 2).sum(-1) + min_dist  # (N, N)
        mask = ~torch.eye(all_points.shape[0], dtype=torch.bool, device=control_points.device)
        # 排斥loss加权，点对权重为两点对应的边权乘积
        weight_mat = point_weights[:, None] * point_weights[None, :]  # (N, N)
        rep = weight_mat[mask] / dist2[mask]
        loss = rep.sum()
        return repulsion_weight * loss

    @staticmethod
    def smoothness_loss(control_points, edge_weight, smoothness_weight=1.0):
        """
        平滑损失：鼓励每条边的控制点序列平滑
        control_points: (num_edges, num_points, dim)
        edge_weight: (num_edges,) tensor
        """
        loss = 0.0
        for i, edge_points in enumerate(control_points):
            diff2 = edge_points[2:] - 2*edge_points[1:-1] + edge_points[:-2]
            # 平滑loss加权
            loss = loss + edge_weight[i] * (diff2**2).sum()
        return smoothness_weight * loss

    @staticmethod
    def total_loss(
        control_points,
        edge_pairs,
        edge_weight,
        weights=dict(bundling=1.0, repulsion=1.0, smoothness=0.1),
        min_dist=1e-2,
    ):
        l_bundle = GraphEdgeBundler.bundling_loss(control_points, edge_pairs, edge_weight, bundling_weight=weights['bundling'])
        l_rep = GraphEdgeBundler.repulsion_loss(control_points, edge_weight, repulsion_weight=weights['repulsion'], min_dist=min_dist)
        l_smooth = GraphEdgeBundler.smoothness_loss(control_points, edge_weight, smoothness_weight=weights['smoothness'])
        return l_bundle + l_rep + l_smooth


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_add

class EdgeGNNLayer(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.mlp_self = nn.Sequential(
            nn.Linear(in_dim, out_dim, bias=False),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim, bias=False)
        )
        self.mlp_neighbor = nn.Sequential(
            nn.Linear(in_dim, out_dim, bias=False),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim, bias=False)
        )

    def forward(self, edge_feats, edge2edge_index, edge_weight):
        # edge_feats: (num_edges, num_ctrl_pts-2, dim)
        # edge_weight: (num_edges,)
        batch, nctrl, dim = edge_feats.shape
        edge_feats_flat = edge_feats.reshape(batch, -1)  # (num_edges, nctrl*dim)

        # Message passing
        row, col = edge2edge_index
        self_msgs = self.mlp_self(edge_feats_flat)
        neighbor_msgs = self.mlp_neighbor(edge_feats_flat)

        neighbor_weights = edge_weight[row]
        weighted_msgs = neighbor_msgs[row] * neighbor_weights[:, None]
        agg_msgs = scatter_add(weighted_msgs, col, dim=0, dim_size=batch)
        agg_weights = scatter_add(neighbor_weights, col, dim=0, dim_size=batch).clamp(min=1e-6)
        agg_msgs = agg_msgs / agg_weights[:, None]

        out = torch.cat([self_msgs, agg_msgs], dim=1)
        out = F.relu(out)
        return out

class EdgeGNN(nn.Module):
    def __init__(self, num_layers, dim, ctrl_dim, hidden_dim, use_weight_as_feat=True):
        super().__init__()
        self.use_weight_as_feat = use_weight_as_feat
        self.dim = dim
        self.ctrl_dim = ctrl_dim
        self.layers = nn.ModuleList()
        in_dim = (ctrl_dim-2)*dim + (1 if use_weight_as_feat else 0)
        for i in range(num_layers):
            self.layers.append(EdgeGNNLayer(in_dim, hidden_dim))
            in_dim = hidden_dim * 2
        self.out_proj = nn.Linear(hidden_dim, (ctrl_dim-2)*dim)

    def forward(self, edge_feats, edge2edge_index, edge_weight, steps=1):
        x = edge_feats
        batch, nctrl, dim = x.shape
        x = x.reshape(batch, -1)
        if self.use_weight_as_feat:
            x = torch.cat([x, edge_weight[:, None]], dim=1)
        for layer in self.layers:
            x = layer(x.reshape(batch, nctrl, dim), edge2edge_index, edge_weight)
        out = self.out_proj(x)
        out = out.reshape(batch, nctrl, dim)
        return out

In [None]:
class EdgeBundlingGNNTrainer:
    def __init__(
        self,
        bundler: GraphEdgeBundler,
        weights=dict(bundling=1.0, repulsion=1.0, smoothness=0.1),
        gnn_steps=200, fine_steps=1000, lr_gnn=1e-3, lr_fine=1e-2, min_dist=1e-2, print_every=50,
    ):
        self.bundler = bundler
        self.weights = weights
        self.gnn_steps = gnn_steps
        self.fine_steps = fine_steps
        self.lr_gnn = lr_gnn
        self.lr_fine = lr_fine
        self.min_dist = min_dist
        self.print_every = print_every
        self.device = bundler.device

        self.ctrl_dim = bundler.num_control_points
        self.hidden_dim = 64
        self.num_layers = 2

        self.edge_pairs = self.bundler.compute_edge_pairs()
        self.edge2edge_index = self.bundler.build_edge2edge_index()

        self.edge_gnn = EdgeGNN(
            num_layers=self.num_layers, 
            dim=self.bundler.dim,
            ctrl_dim=self.ctrl_dim,
            hidden_dim=self.hidden_dim
        ).to(self.device)

    def train(self):
        optimizer = torch.optim.Adam(list(self.edge_gnn.parameters()) + [self.bundler.trainable_ctrl_pts], lr=self.lr_gnn)
        for step in trange(self.gnn_steps, desc="EdgeGNN"):
            optimizer.zero_grad()
            x = self.bundler.trainable_ctrl_pts  # (num_edges, num_ctrl_pts-2, dim)
            gnn_out = self.edge_gnn(x, self.edge2edge_index, self.bundler.edge_weight)
            control_points = torch.cat([
                self.bundler.fixed_start_pts[:, None, :],
                gnn_out,
                self.bundler.fixed_end_pts[:, None, :]
            ], dim=1)
            loss = GraphEdgeBundler.total_loss(
                control_points, self.edge_pairs, self.bundler.edge_weight,
                weights=self.weights, min_dist=self.min_dist
            )
            loss.backward()
            optimizer.step()
            if step % self.print_every == 0 or step == self.gnn_steps-1:
                print(f"GNN Step {step}, loss={loss.item():.6f}")

        optimizer_fine = torch.optim.Adam([self.bundler.trainable_ctrl_pts], lr=self.lr_fine)
        for step in trange(self.fine_steps, desc="Fine-tune"):
            optimizer_fine.zero_grad()
            control_points = self.bundler.get_full_control_points()
            loss = GraphEdgeBundler.total_loss(
                control_points, self.edge_pairs, self.bundler.edge_weight,
                weights=self.weights, min_dist=self.min_dist
            )
            loss.backward()
            optimizer_fine.step()
            if step % self.print_every == 0 or step == self.fine_steps-1:
                print(f"Fine-tune Step {step}, loss={loss.item():.6f}")

        print("Training finished.")
        return self.bundler.get_full_control_points().detach().cpu().numpy()

: 

In [8]:
bundler = GraphEdgeBundler(
        node_csv="periodical-clustering/data/2010s/journal_citation_net/node_filtered.csv",
        edge_csv="periodical-clustering/data/2010s/journal_citation_net/edge_filtered.csv",
        layout_csv="periodical-clustering/data/2010s/journal_citation_net/neulay_results/GAT_filtred_3d/fdl_iter_00801.csv",
        num_control_points=5
    )
trainer = EdgeBundlingGNNTrainer(bundler)
final_ctrl_pts = trainer.train()
print("First edge control points after training:")
print(final_ctrl_pts[0])

layout shape: (20037, 3)


Initializing edge control points: 100%|█████████████████████████████| 398011/398011 [00:03<00:00, 119699.23it/s]
Computing edge_pairs: 100%|██████████████████████████████████████████████| 20037/20037 [01:09<00:00, 287.64it/s]
Building edge2edge index: 100%|█████████████████████████████████████████| 20037/20037 [00:11<00:00, 1733.88it/s]

: 

: 