In [9]:
import torch
import os
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader 
from pathlib import Path
import pandas as pd
from torch_geometric.data import Batch
from torch_geometric.loader import NeighborLoader

In [42]:
class GraphDataset(InMemoryDataset):
    """
    PyG dataset that wraps a single precomputed patient graph.

    The graph is loaded from ``<graph_dir>/diagnosis_graph_<mode>_k<k>.pt`` and
    the node features from ``<data_dir>/final_flat.h5``. Feature rows are
    reordered to follow the graph's ``patient_ids``, so row ``i`` of
    ``graph_data.x`` holds the flat features of ``patient_ids[i]``.

    Args:
        config: Mapping providing ``"data_dir"``, ``"graph_dir"``, ``"mode"``
            and ``"k"``.
        transform: Forwarded unchanged to ``InMemoryDataset``.
        pre_transform: Forwarded unchanged to ``InMemoryDataset``.

    Raises:
        KeyError: If a patient id of the graph has no row in the flat features.
        ValueError: If the flat features contain several rows for one patient
            id, which would leave ``x`` misaligned with ``patient_ids``.
    """
    def __init__(self, config, transform=None, pre_transform=None):
        super().__init__(root=config["data_dir"], transform=transform, pre_transform=pre_transform)

        # === 1. Graph ===
        graph_path = Path(config["graph_dir"]) / f"diagnosis_graph_{config['mode']}_k{config['k']}.pt"
        print(f"==> Loading precomputed graph from {graph_path}")
        # NOTE(review): weights_only=False unpickles arbitrary Python objects;
        # only point this at graph files you produced yourself.
        self.graph_data = torch.load(graph_path, weights_only=False)

        # === 2. Extract edge_index, edge_attr and patient ids ===
        # Make edge_index contiguous once, so the dataset attribute and the
        # Data object refer to the same tensor.
        self.graph_data.edge_index = self.graph_data.edge_index.contiguous()
        self.edge_index = self.graph_data.edge_index
        self.edge_attr = self.graph_data.edge_attr

        raw_ids = self.graph_data.patient_ids
        if isinstance(raw_ids, torch.Tensor):
            # Tensor.numpy() only works for CPU tensors; move the ids to the
            # host first in case the graph was saved from a CUDA device.
            raw_ids = raw_ids.detach().cpu().numpy()
        self.patient_ids = [int(pid) for pid in raw_ids]

        # === 3. Load the flat features ===
        flat_path = Path(config["data_dir"]) / "final_flat.h5"
        print(f"==> Loading flat features from {flat_path}")
        flat_df = pd.read_hdf(flat_path)
        flat_df = flat_df.set_index("patient")
        flat_df.index = flat_df.index.astype(int)

        # === 4. Align flat features with the graph's patient ids ===
        missing = pd.Index(self.patient_ids, dtype="int64").difference(flat_df.index)
        if len(missing) > 0:
            raise KeyError(
                f"{len(missing)} graph patient id(s) have no row in {flat_path}, "
                f"e.g. {missing[:5].tolist()}"
            )
        flat_df_aligned = flat_df.loc[self.patient_ids]
        # .loc returns every matching row, so duplicated ids in the flat table
        # would silently yield more feature rows than there are patient ids.
        if len(flat_df_aligned) != len(self.patient_ids):
            raise ValueError(
                f"Flat features in {flat_path} contain duplicate patient ids: "
                f"selected {len(flat_df_aligned)} rows for {len(self.patient_ids)} ids"
            )
        x_flat = torch.tensor(flat_df_aligned.values, dtype=torch.float)

        # === 5. Complete the PyG `Data` object ===
        self.graph_data.x = x_flat
        self.graph_data.num_nodes = x_flat.shape[0]
        self.graph_data.patient_ids = torch.tensor(self.patient_ids, dtype=torch.long)

    def __len__(self):
        """Return 1: the dataset consists of exactly one graph."""
        return 1

    def __getitem__(self, idx):
        """Return the single graph; integer indices other than 0 / -1 are rejected."""
        if isinstance(idx, int) and idx not in (0, -1):
            raise IndexError(f"GraphDataset holds a single graph; index {idx} is out of range")
        return self.graph_data



In [43]:
# Settings consumed by GraphDataset.
config = {
    # Dataset root; GraphDataset also reads "final_flat.h5" from this directory.
    "data_dir": "/home/mei/nas/docker/thesis/data/hdf",
    # Directory holding the "diagnosis_graph_<mode>_k<k>.pt" files.
    "graph_dir": "/home/mei/nas/docker/thesis/data/graphs",
    # <mode> component of the graph file name.
    "mode": "k_closest",
    # <k> component of the graph file name.
    "k": 3,
}

In [45]:
# Build the single-graph dataset from the notebook-level config.
graph_dataset = GraphDataset(config)

# Sanity-check the loaded graph.
print("Graph data:", graph_dataset.graph_data)
print(f"Number of nodes in graph: {graph_dataset.graph_data.num_nodes}")
print(f"Number of edges in graph: {graph_dataset.graph_data.edge_index.size(1)}")

# Wrap the dataset in a DataLoader so the whole graph is fed to the model at once.
loader = DataLoader(dataset=graph_dataset, batch_size=1, shuffle=False)

# Inspect only the first batch.
batch = next(iter(loader), None)
if batch is not None:
    print("Batch x shape:", batch.x.shape)  # shape of the node feature matrix
    print("Batch edge_index shape:", batch.edge_index.shape)
    # Report the batch vector only when the attribute exists and is set.
    if getattr(batch, "batch", None) is not None:
        print("Batch batch shape:", batch.batch.shape)

==> Loading precomputed graph from /home/mei/nas/docker/thesis/data/graphs/diagnosis_graph_k_closest_k3.pt
==> Loading flat features from /home/mei/nas/docker/thesis/data/hdf/final_flat.h5
Graph data: Data(edge_index=[2, 35094], edge_attr=[35094], num_nodes=11698, patient_ids=[11698], x=[11698, 104])
Number of nodes in graph: 11698
Number of edges in graph: 35094
Batch x shape: torch.Size([11698, 104])
Batch edge_index shape: torch.Size([2, 35094])
Batch batch shape: torch.Size([11698])
