In [2]:
import torch
from torch_geometric.data import InMemoryDataset
from pathlib import Path
from torch_geometric.loader import NeighborSampler
import pandas as pd
import numpy as np


In [3]:
class GraphDataset(InMemoryDataset):
    """
    PyG dataset for patient graphs (loads entire graph).
    """
    def __init__(self, config, transform=None, pre_transform=None):
        super(GraphDataset, self).__init__(root=config["data_dir"], transform=transform, pre_transform=pre_transform)

        # === 1. Load Entire Graph ===
        graph_path = Path(config["graph_dir"]) / f"diagnosis_graph_{config['mode']}_k{config['k']}.pt"
        print(f"==> Loading precomputed graph from {graph_path}")
        self.graph_data = torch.load(graph_path, weights_only=False)

        # === 2. Extract edge_index and edge_attr ===
        self.edge_index = self.graph_data.edge_index
        self.edge_attr = self.graph_data.edge_attr

        # === 3. Load Flat Features ===
        flat_path = Path(config["data_dir"]) / "final_flat.h5"
        print(f"==> Loading flat features from {flat_path}")
        flat_df = pd.read_hdf(flat_path).set_index("patient")
        flat_df.index = flat_df.index.astype(int)

        # Align patient IDs with graph
        self.graph_data.patient_ids = self.graph_data.patient_ids.clone().detach().long()
        self.graph_data.x = torch.tensor(flat_df.loc[self.graph_data.patient_ids.numpy()].values, dtype=torch.float)
        self.graph_data.num_nodes = self.graph_data.x.shape[0]

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        return self.graph_data

In [12]:
def graph_loader(graph_dataset, patient_ids, sizes, batch_size, shuffle):
    """
    只采样 LSTM batch 的 `patient_ids`，防止 `NeighborSampler` 额外加入邻居节点
    """
    # **获取 `patient_ids` 在 `graph_dataset` 中的索引**
    graph_patient_ids = graph_dataset.graph_data.patient_ids
    
    # 只选择当前 batch `patient_ids` 在 Graph 数据中的索引
    node_idx = torch.tensor([torch.where(graph_patient_ids == pid)[0][0] for pid in patient_ids if pid in graph_patient_ids], dtype=torch.long)
    
    # **创建 NeighborSampler**
    loader = NeighborSampler(
        edge_index=graph_dataset.graph_data.edge_index,
        node_idx=node_idx,  # 仅采样 `patient_ids` 的图节点
        sizes=sizes,
        batch_size=batch_size,  # `batch_size` 必须等于 LSTM batch 大小
        shuffle=shuffle
    )

    return loader

In [7]:
config = {  
    "data_dir": "/home/mei/nas/docker/thesis/data/hdf",
    "graph_dir": "/home/mei/nas/docker/thesis/data/graphs",
    "mode": "k_closest",
    "k": 3
          
}

In [4]:
patient_ids=[4854, 7605, 7343, 4752, 4989, 1690, 1590, 1266, 1829, 5949, 7599, 962, 4316, 7313, 7685, 4486, 4769, 3295, 2614, 5275, 5470, 3914, 104, 6128, 5850, 3608, 2876, 330, 6051, 1228, 7874]

In [8]:
graph_dataset = GraphDataset(config)

==> Loading precomputed graph from /home/mei/nas/docker/thesis/data/graphs/diagnosis_graph_k_closest_k3.pt
==> Loading flat features from /home/mei/nas/docker/thesis/data/hdf/final_flat.h5


In [14]:
import torch
import networkx as nx
from torch_geometric.utils import to_networkx

# **将 PyG 图转换为 NetworkX 图**
graph_nx = to_networkx(graph_dataset.graph_data, to_undirected=True)

# **检查是否是连通图**
is_connected = nx.is_connected(graph_nx)
print(f"✅ The graph is {'Connected' if is_connected else 'NOT Connected'}.")

# **如果图不是连通的，检查连通分量**
connected_components = list(nx.connected_components(graph_nx))
print(f"🔹 Number of connected components: {len(connected_components)}")

# **打印最大连通分量**
largest_cc = max(connected_components, key=len)
print(f"🔹 Largest component size: {len(largest_cc)}")

✅ The graph is Connected.
🔹 Number of connected components: 1
🔹 Largest component size: 11698


In [13]:
patient_ids_tensor = torch.tensor([int(pid) for pid in patient_ids], dtype=torch.long)

graph_batch = graph_loader(
    graph_dataset, 
    patient_ids_tensor,  # 传入 LSTM batch 里的 patient_ids
    sizes=[10, 10], 
    batch_size=len(patient_ids_tensor),  # **确保 batch_size 与 LSTM 一致**
    shuffle=True

)
batch_size, n_id, adjs = next(iter(graph_batch))
print(f"sampled node ids: {n_id.tolist()}")
print(f"Edge index shape: {adjs[0][0].shape}")

RuntimeError: max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.

In [4]:
import torch
torch.cuda.empty_cache()

In [None]:
from torch_geometric.loader import NeighborSampler

graph_dataset = GraphDataset(config)

sample_sizes = [10, 10]  # 每层采样邻居数量
batch_size = 32

input_nodes = torch.arange(0, graph_dataset.graph_data.num_nodes)  
loader =  NeighborSampler(
    graph_dataset.graph_data.edge_index,
    node_idx=graph_dataset.graph_data.val_mask,
    sizes=sample_sizes,
    batch_size=batch_size,
    shuffle=True
)

# 遍历一个 batch
for batch_size, n_id, adjs in loader:
    print("Batch size:", batch_size)
    print("Node IDs:", n_id)
    
    # 提取节点属性
    node_features = graph_dataset.graph_data.x[n_id]
    print("Node features shape:", node_features.shape)
    print("Node features:", node_features)
    
    for edge_index, e_id, size in adjs:
        print("Edge index shape:", edge_index.shape)
        print("Edge IDs:", e_id)
        print("Size:", size)
    
    

In [None]:
class GraphDataset(InMemoryDataset):
    """
    PyG dataset for patient graphs aligned with MultiModalDataset patient IDs.
    """
    def __init__(self, config, multimodal_patient_ids, transform=None, pre_transform=None):
        super(GraphDataset, self).__init__(root=config["data_dir"], transform=transform, pre_transform=pre_transform)

        # === 1. Load Graph ===
        graph_path = Path(config["graph_dir"]) / f"diagnosis_graph_{config['mode']}_k{config['k']}.pt"
        print(f"==> Loading precomputed graph from {graph_path}")
        self.graph_data = torch.load(graph_path, weights_only=False)

        # === 2. Extract edge_index and edge_attr ===
        self.edge_index = self.graph_data.edge_index
        self.edge_attr = self.graph_data.edge_attr

        # === 3. Align Graph Patient IDs with MultiModalDataset ===
        if isinstance(self.graph_data.patient_ids, torch.Tensor):
            graph_patient_ids = self.graph_data.patient_ids.numpy()
        else:
            graph_patient_ids = np.array(self.graph_data.patient_ids)

        # 将 MultiModalDataset 中存在的 patient_ids 作为筛选条件
        self.patient_ids = np.array([pid for pid in multimodal_patient_ids if pid in graph_patient_ids])

        # 找到对应索引位置（用 numpy 加速）
        patient_id_indices = np.where(np.isin(graph_patient_ids, self.patient_ids))[0]

        # === 4. Load Flat Features and Align ===
        flat_path = Path(config["data_dir"]) / "final_flat.h5"
        print(f"==> Loading flat features from {flat_path}")
        flat_df = pd.read_hdf(flat_path).set_index("patient")
        flat_df.index = flat_df.index.astype(int)

        # 对齐 flat 特征，确保 patient_ids 存在于 flat_df
        flat_df_aligned = flat_df.loc[self.patient_ids]
        x_flat = torch.tensor(flat_df_aligned.values, dtype=torch.float)
        self.graph_data.x = x_flat
        self.graph_data.num_nodes = len(self.patient_ids)

        # === 5. 更新 edge_index 以反映新的节点索引映射 ===
        old_to_new_idx = {old_idx: new_idx for new_idx, old_idx in enumerate(patient_id_indices)}

        # 使用张量筛选有效边（两端节点都在 patient_id_indices 中）
        src_nodes = self.edge_index[0].numpy()
        dst_nodes = self.edge_index[1].numpy()

        valid_src_mask = np.isin(src_nodes, patient_id_indices)
        valid_dst_mask = np.isin(dst_nodes, patient_id_indices)
        valid_edges_mask = valid_src_mask & valid_dst_mask

        # 筛选有效的 edge_index 和 edge_attr
        filtered_edge_index = self.edge_index[:, valid_edges_mask]
        filtered_edge_attr = self.edge_attr[valid_edges_mask]

        # 重映射边的索引
        remapped_src = np.array([old_to_new_idx[src] for src in filtered_edge_index[0].numpy()])
        remapped_dst = np.array([old_to_new_idx[dst] for dst in filtered_edge_index[1].numpy()])

        remapped_edge_index = torch.tensor([remapped_src, remapped_dst], dtype=torch.long)

        # 更新 graph_data
        self.graph_data.edge_index = remapped_edge_index
        self.graph_data.edge_attr = filtered_edge_attr
        self.graph_data.patient_ids = torch.tensor(self.patient_ids, dtype=torch.long)

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        return self.graph_data


In [None]:
def graph_loader (graph_dataset,sizes,batch_size,shuffle):
    
    loader = NeighborSampler(
        graph_dataset.graph_data.edge_index,
        node_idx=torch.arange(graph_dataset.graph_data.num_nodes),
        sizes=sizes,
        batch_size=batch_size,
        shuffle=shuffle
    )
    return loader

In [None]:
from torch_geometric.loader import NeighborSampler

graph_dataset = GraphDataset(config, multimodal_patient_ids=multi_modal_patient_ids)

train_loader = NeighborSampler(
    graph_dataset.graph_data.edge_index,
    node_idx=torch.arange(graph_dataset.graph_data.num_nodes),  # 所有与 MultiModalDataset 对齐的节点
    sizes=[10, 10],  # 采样邻居数量
    batch_size=32,
    shuffle=True
)
