In [1]:
import os
import numpy as np
import torch
import dgl
import matplotlib.pyplot as plt
import networkx as nx
from scipy.spatial.distance import cdist
from torch.utils.data import Dataset

# 超像素图构建和图神经网络处理的工具函数
def sigma(dists, kth=8):
    knns = np.partition(dists, kth, axis=-1)[:, kth::-1]
    sigma = knns.sum(axis=1).reshape((knns.shape[0], 1)) / kth
    return sigma + 1e-8  # 避免 sigma 为 0

def compute_adjacency_matrix_images(coord, feat, use_feat=False, kth=8):
    coord = coord.reshape(-1, 2)
    c_dist = cdist(coord, coord)  # 计算节点坐标间的距离

    if use_feat:
        f_dist = cdist(feat, feat)  # 计算节点特征间的距离
        A = np.exp(- (c_dist / sigma(c_dist))**2 - (f_dist / sigma(f_dist))**2)
    else:
        A = np.exp(- (c_dist / sigma(c_dist))**2)

    A = 0.5 * (A + A.T)  # 将邻接矩阵对称化
    A[np.diag_indices_from(A)] = 0  # 对角线置零，避免自环
    return A

def compute_edges_list(A, kth=8+1):
    num_nodes = A.shape[0]
    new_kth = num_nodes - kth
    knns = np.argpartition(A, new_kth-1, axis=-1)[:, new_kth:-1]
    knns_d = np.partition(A, new_kth-1, axis=-1)[:, new_kth:-1]
    return knns, knns_d

# 血管图像数据集的类定义
class VesselSuperPix(Dataset):
    def __init__(self, data_dir, split, use_mean_px=True, use_coord=True, use_feat_for_graph_construct=False):
        self.split = split
        self.is_test = split.lower() in ['test', 'val']
        # 根据实际情况修改数据加载方式
        with open(os.path.join(data_dir, f'vessel_superpixels_{split}.pkl'), 'rb') as f:
            self.labels, self.sp_data = pickle.load(f)

        self.use_mean_px = use_mean_px
        self.use_feat_for_graph = use_feat_for_graph_construct
        self.use_coord = use_coord
        self.n_samples = len(self.labels)
        self.img_size = 512  # 根据血管图像的大小调整

    def precompute_graph_images(self):
        print(f'Precomputing data for the {self.split.upper()} set...')
        self.Adj_matrices, self.node_features, self.edges_lists = [], [], []
        for index, sample in enumerate(self.sp_data):
            mean_px, coord = sample[:2]
            coord = coord / self.img_size
            A = compute_adjacency_matrix_images(coord, mean_px, use_feat=self.use_feat_for_graph)
            edges_list, _ = compute_edges_list(A)
            N_nodes = A.shape[0]

            # 根据血管图像特征进行节点特征的设置
            x = None
            if self.use_mean_px:
                x = mean_px.reshape(N_nodes, -1)  # 使用像素强度作为特征
            if self.use_coord:
                coord = coord.reshape(N_nodes, 2)
                if self.use_mean_px:
                    x = np.concatenate((x, coord), axis=1)
                else:
                    x = coord
            if x is None:
                x = np.ones((N_nodes, 1))  # 默认特征

            # 此处可插入血管厚度、曲率等特征计算
            # 血管厚度特征计算
            # 血管曲率特征计算

            self.node_features.append(x)
            self.Adj_matrices.append(A)
            self.edges_lists.append(edges_list)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, index):
        g = dgl.DGLGraph()
        g.add_nodes(self.node_features[index].shape[0])
        g.ndata['feat'] = torch.Tensor(self.node_features[index])
        for src, dsts in enumerate(self.edges_lists[index]):
            g.add_edges(src, dsts[dsts != src])
        return g, self.labels[index]

# 图像显示函数
def show_image(plt, idx, dataset, alpha=0.5):
    x, label = dataset[idx]
    plt.imshow(x.numpy()[0], cmap='gray')
    plt.axis('off')
    plt.title.set_text(f"Label: {label} | Original Image")

# 图结构可视化函数
def plot_superpixels_graph(plt, sp_data, adj_matrix, label, feat_coord, with_edges):
    x_coord = sp_data[1]
    intensities = sp_data[0].reshape(-1)
    G = nx.from_numpy_matrix(adj_matrix)
    pos = dict(zip(range(len(x_coord)), x_coord.tolist()))
    rotated_pos = {node: (y, -x) for (node, (x, y)) in pos.items()}

    edge_list = []
    for src, dsts in enumerate(compute_edges_list(adj_matrix)[0]):
        for dst in dsts:
            edge_list.append((src, dst))

    nx.draw_networkx_nodes(G, rotated_pos, node_color=intensities, cmap=plt.cm.Reds, node_size=60)
    if with_edges:
        nx.draw_networkx_edges(G, rotated_pos, edge_list, alpha=0.3)
    title = f"Label: {label} | {'Using feat and coord for knn' if feat_coord else 'Using only coord for knn'}"
    if not with_edges:
        title = f"Label: {label} | Only superpixel nodes"
    plt.title.set_text(title)

# 主流程，示例图的处理和可视化
if __name__ == "__main__":
    use_feat_for_graph_construct = False
    data_no_feat_knn = VesselSuperPix("data/superpixels", split='train', use_feat_for_graph_construct=use_feat_for_graph_construct)
    data_no_feat_knn.precompute_graph_images()

    use_feat_for_graph_construct = True
    data_with_feat_knn = VesselSuperPix("data/superpixels", split='train', use_feat_for_graph_construct=use_feat_for_graph_construct)
    data_with_feat_knn.precompute_graph_images()

    # 可视化示例
    num_samples_plot = 3
    for f_idx, idx in enumerate(np.random.choice(int(len(data_no_feat_knn) / 2), num_samples_plot, replace=False)):
        f = plt.figure(f_idx, figsize=(23, 5))
        plt1 = f.add_subplot(141)
        show_image(plt1, idx, data_no_feat_knn, alpha=0.5)

        plt2 = f.add_subplot(142)
        plot_superpixels_graph(plt2, data_no_feat_knn.sp_data[idx],
                               data_no_feat_knn.Adj_matrices[idx],
                               data_no_feat_knn[idx][1],
                               data_no_feat_knn.use_feat_for_graph,
                               with_edges=False)

        plt3 = f.add_subplot(143)
        plot_superpixels_graph(plt3, data_no_feat_knn.sp_data[idx],
                               data_no_feat_knn.Adj_matrices[idx],
                               data_no_feat_knn[idx][1],
                               data_no_feat_knn.use_feat_for_graph,
                               with_edges=True)

        plt4 = f.add_subplot(144)
        plot_superpixels_graph(plt4, data_with_feat_knn.sp_data[idx],
                               data_with_feat_knn.Adj_m

SyntaxError: unexpected EOF while parsing (1591615324.py, line 155)

In [2]:
import torch
import dgl

print("PyTorch version:", torch.__version__)
print("DGL version:", dgl.__version__)

FileNotFoundError: Cannot find DGL C++ graphbolt library at /home/pxl/miniconda3/envs/tissue/lib/python3.8/site-packages/dgl/graphbolt/libgraphbolt_pytorch_2.3.0.so