In [1]:
import torch
import os
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader 
from pathlib import Path
import pandas as pd
from torch_geometric.data import Batch


In [8]:
from torch_geometric.loader import NeighborLoader

In [29]:
class GraphDataset(InMemoryDataset):
    """
    PyG dataset wrapping a single precomputed graph loaded from a ``.pt`` file.

    Only the graph connectivity is kept: the ``Data`` object built here carries
    ``edge_index`` and ``edge_attr`` taken from the loaded graph and nothing else.

    Args:
        config: Mapping with the keys ``"data_dir"`` (passed to the base class
            as ``root``), ``"graph_dir"``, ``"mode"`` and ``"k"``. The graph is
            read from ``<graph_dir>/diagnosis_graph_<mode>_k<k>.pt``.
        transform: Forwarded unchanged to ``InMemoryDataset``.
        pre_transform: Forwarded unchanged to ``InMemoryDataset``.
    """
    def __init__(self, config, transform=None, pre_transform=None):
        super().__init__(root=config["data_dir"], transform=transform, pre_transform=pre_transform)

        # === 1. Graph ===
        graph_path = Path(config["graph_dir"]) / f"diagnosis_graph_{config['mode']}_k{config['k']}.pt"
        print(f"==> Loading precomputed graph from {graph_path}")
        # NOTE(review): weights_only=False performs a full unpickle of the file,
        # which can execute arbitrary code. Only load graph files you trust.
        self.graph_data = torch.load(graph_path, weights_only=False)

        # === 2. edge_index and edge_attr ===
        # Kept as direct attributes in addition to being stored inside `self.data`.
        self.edge_index = self.graph_data.edge_index
        self.edge_attr = self.graph_data.edge_attr

        # === 3. PyG `Data` ===
        self.data = Data(edge_index=self.edge_index, edge_attr=self.edge_attr)

    def __repr__(self):
        # Derive the name from the class so the repr cannot drift from the
        # actual class name (it previously reported "GraphOnlyDataset").
        return f"{type(self).__name__}(num_edges={self.data.num_edges})"



In [30]:
# === Generate DataLoader ===
def get_graph_dataloader(config, batch_size=32, shuffle=True):
    """
    Build a PyG ``DataLoader`` over a ``GraphDataset`` created from ``config``.

    Args:
        config: Passed straight to ``GraphDataset``.
        batch_size: Forwarded to ``DataLoader`` as ``batch_size``.
        shuffle: Forwarded to ``DataLoader`` as ``shuffle``.

    Returns:
        The ``DataLoader`` wrapping the freshly constructed dataset.
    """
    return DataLoader(
        GraphDataset(config),
        batch_size=batch_size,
        shuffle=shuffle,
    )

In [31]:
# Settings consumed by GraphDataset / get_graph_dataloader.
config = dict(
    data_dir="/home/mei/nas/docker/thesis/data/hdf/val",   # passed to the dataset as `root`
    graph_dir="/home/mei/nas/docker/thesis/data/graphs",   # folder holding the precomputed graph file
    mode="k_closest",                                      # part of the graph file name
    k=3,                                                   # part of the graph file name
)

In [33]:
# Build the loader and print only its first batch as a quick sanity check.
graph_loader = get_graph_dataloader(config, batch_size=32)
batch = next(iter(graph_loader), None)
if batch is not None:
    print(batch)

==> Loading precomputed graph from /home/mei/nas/docker/thesis/data/graphs/diagnosis_graph_k_closest_k3.pt
DataBatch(edge_index=[2, 35094], edge_attr=[35094], batch=[11698], ptr=[2])
