In [1]:
import torch
import os
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader 
from pathlib import Path
import pandas as pd
from torch_geometric.data import Batch


In [2]:
from torch_geometric.loader import NeighborLoader

In [36]:
class GraphDataset(InMemoryDataset):
    """
    PyG dataset wrapping one precomputed patient graph.

    The graph (``edge_index``, ``edge_attr``, ``patient_ids``) is loaded from
    ``<graph_dir>/diagnosis_graph_<mode>_k<k>.pt``. Node features are the flat
    per-patient features read from ``<data_dir>/final_flat.h5`` and reordered
    so that row ``i`` belongs to ``patient_ids[i]``.

    Args:
        config: dict with keys ``data_dir``, ``graph_dir``, ``mode`` and ``k``.
        transform: optional callable applied to the graph on every access.
        pre_transform: forwarded to ``InMemoryDataset``.
    """

    def __init__(self, config, transform=None, pre_transform=None):
        super().__init__(root=config["data_dir"], transform=transform, pre_transform=pre_transform)

        # === 1. Graph ===
        graph_path = Path(config["graph_dir"]) / f"diagnosis_graph_{config['mode']}_k{config['k']}.pt"
        print(f"==> Loading precomputed graph from {graph_path}")
        # weights_only=False unpickles arbitrary objects: only load graph files you created yourself.
        self.graph_data = torch.load(graph_path, weights_only=False)

        # === 2. Extract edge_index, edge_attr and the node -> patient mapping ===
        self.edge_index = self.graph_data.edge_index
        self.edge_attr = self.graph_data.edge_attr
        self.patient_ids = self.graph_data.patient_ids.numpy()

        # === 3. Flat features, indexed by patient id ===
        flat_path = Path(config["data_dir"]) / "final_flat.h5"
        print(f"==> Loading flat features from {flat_path}")
        flat_df = pd.read_hdf(flat_path).set_index("patient")

        # === 4. Align flat features with the graph's node order ===
        x_flat = self._align_flat_features(flat_df, self.patient_ids)
        print(f"x_flat shape: {x_flat.shape}, num_nodes: {self.graph_data.num_nodes}")

        # === 5. PyG `Data` ===
        self.data = Data(x=x_flat, edge_index=self.edge_index, edge_attr=self.edge_attr)
        self.graph_data.x = x_flat
        self.graph_data.num_nodes = x_flat.shape[0]

    @staticmethod
    def _align_flat_features(flat_df, patient_ids):
        """Return the rows of ``flat_df`` as a float tensor, ordered like ``patient_ids``.

        Raises:
            ValueError: if ``flat_df``'s index contains duplicate patient ids
                (the node -> row mapping would be ambiguous).
            KeyError: if any id in ``patient_ids`` has no row in ``flat_df``.
        """
        if not flat_df.index.is_unique:
            raise ValueError("flat features contain duplicate 'patient' ids; cannot align them to graph nodes")
        positions = flat_df.index.get_indexer(patient_ids)
        # get_indexer marks unknown ids with -1, which would silently select the last row.
        missing = positions < 0
        if missing.any():
            raise KeyError(
                f"{int(missing.sum())} graph patient ids have no flat features, "
                f"e.g. {patient_ids[missing][:5].tolist()}"
            )
        return torch.tensor(flat_df.to_numpy()[positions], dtype=torch.float)

    def __len__(self):
        # The dataset always holds exactly one (full) graph.
        return 1

    def __getitem__(self, idx):
        """Return the single graph, with ``transform`` applied when one was given."""
        if isinstance(idx, int) and idx not in (0, -1):
            # Raising IndexError lets sequence-style iteration terminate.
            raise IndexError(f"GraphDataset holds a single graph; index {idx} is out of range")
        data = self.data
        return data if self.transform is None else self.transform(data)



In [13]:
# Root of the thesis data tree. Set the THESIS_DATA_ROOT environment variable
# to run on another machine; the default reproduces the original paths.
DATA_ROOT = os.environ.get("THESIS_DATA_ROOT", "/home/mei/nas/docker/thesis/data").rstrip("/")

config = {
    "data_dir": f"{DATA_ROOT}/hdf",        # directory containing final_flat.h5
    "graph_dir": f"{DATA_ROOT}/graphs",    # directory containing diagnosis_graph_<mode>_k<k>.pt
    "mode": "k_closest",                   # graph variant, part of the graph file name
    "k": 3,                                # k used in the graph file name (presumably neighbours per node)
}

In [4]:
# ===generate DataLoader ===
def get_graph_dataloader(config, batch_size=32, shuffle=True):
    """
    Build a PyG ``DataLoader`` over ``GraphDataset(config)``.

    Because ``GraphDataset`` reports a length of 1, the loader yields a single
    batch holding the whole graph, whatever ``batch_size`` and ``shuffle`` are.
    """
    loader = DataLoader(
        GraphDataset(config),
        batch_size=batch_size,
        shuffle=shuffle,
    )
    return loader

In [33]:
graph_loader = get_graph_dataloader(config, batch_size=32)

# Peek at the first batch only.
batch = next(iter(graph_loader), None)
if batch is not None:
    print(batch)

==> Loading precomputed graph from /home/mei/nas/docker/thesis/data/graphs/diagnosis_graph_k_closest_k3.pt
==> Loading flat features from /home/mei/nas/docker/thesis/data/hdf/final_flat.h5
x_flat shape: torch.Size([11698, 104]), num_nodes: 11698
DataBatch(x=[11698, 104], edge_index=[2, 35094], edge_attr=[35094], batch=[11698], ptr=[2])




In [37]:
from torch_geometric.loader import NeighborLoader

graph_dataset = GraphDataset(config)
graph_data = graph_dataset.graph_data

# Store edge_index contiguously before handing the graph to the sampler.
graph_data.edge_index = graph_data.edge_index.contiguous()

# Mini-batch sampling: 2 hops, up to 10 sampled neighbours per hop,
# 32 seed nodes per batch, seeds reshuffled every epoch.
hops = 2
graph_loader = NeighborLoader(
    graph_data,
    num_neighbors=[10] * hops,
    batch_size=32,
    shuffle=True,
)

print("Graph data attributes:", graph_data)

==> Loading precomputed graph from /home/mei/nas/docker/thesis/data/graphs/diagnosis_graph_k_closest_k3.pt
==> Loading flat features from /home/mei/nas/docker/thesis/data/hdf/final_flat.h5
x_flat shape: torch.Size([11698, 104]), num_nodes: 11698
Graph data attributes: Data(edge_index=[2, 35094], edge_attr=[35094], num_nodes=11698, patient_ids=[11698], x=[11698, 104])
