In [83]:
import torch
import dgl.transforms as T
from dgl.dataloading import NeighborSampler, DataLoader

from loguru import logger

transform = T.Compose([T.RowFeatNormalizer()])

from dgl.data import PPIDataset
dataset = (PPIDataset(mode='train', raw_dir='dataset', transform=transform), PPIDataset(mode='valid', raw_dir='dataset', transform=transform), PPIDataset(mode='test', raw_dir='dataset', transform=transform))


from dgl.data import CiteseerGraphDataset
dataset2 = CiteseerGraphDataset('dataset', transform=transform)

from torch_geometric.data import Data, Batch
from torch_geometric.loader import NeighborLoader

def merge_from_data_list(data_list):
    batch_data = Batch.from_data_list(data_list)
    data =  Data(
        x=batch_data.x,
        edge_index=batch_data.edge_index,
        y=batch_data.y
    )
    # TODO: also support edge features and edge lable.
    
    return data

from torch_geometric.datasets import PPI
dataset3 = [
    merge_from_data_list(PPI('dataset/PPI', split='train')),
    merge_from_data_list(PPI('dataset/PPI', split='val')),
    merge_from_data_list(PPI('dataset/PPI', split='test')),
]

from torch_geometric.datasets import Planetoid
dataset4 = Planetoid('dataset', 'CiteSeer', split='public', transform=transform)

  NumNodes: 3327
  NumEdges: 9228
  NumFeats: 3703
  NumClasses: 6
  NumTrainingSamples: 120
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


In [47]:
Batch.from_data_list(PPI('dataset/PPI', split='train'))

DataBatch(x=[44906, 50], edge_index=[2, 1226368], y=[44906, 121], batch=[44906], ptr=[21])

In [41]:
dataset3[0]

Data(x=[44906, 50], edge_index=[2, 1226368], y=[44906, 121])

In [34]:
dataset4[0]

Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])

In [53]:
dataset2[0]

Graph(num_nodes=3327, num_edges=9228,
      ndata_schemes={'feat': Scheme(shape=(3703,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'train_mask': Scheme(shape=(), dtype=torch.bool)}
      edata_schemes={})

In [84]:
train_data = dataset[0]
val_data = dataset[1]
test_data = dataset[2]

# for graph in train_data:
#     graph.ndata['train_mask'] = torch.ones(graph.num_nodes(), dtype=bool)
# for graph in val_data:
#     graph.ndata['val_mask'] = torch.ones(graph.num_nodes(), dtype=bool)
# for graph in test_data:
#     graph.ndata['test_mask'] = torch.ones(graph.num_nodes(), dtype=bool)

In [75]:
train_data[0]

Graph(num_nodes=1767, num_edges=34085,
      ndata_schemes={'label': Scheme(shape=(121,), dtype=torch.float32), 'feat': Scheme(shape=(50,), dtype=torch.float32), '_ID': Scheme(shape=(), dtype=torch.int64), 'train_data': Scheme(shape=(), dtype=torch.bool), 'train_mask': Scheme(shape=(), dtype=torch.bool)}
      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64)})

In [43]:
val_data

Dataset("ppi", num_graphs=2, save_path=dataset/ppi_4b14ad03)

In [22]:
test_data

Dataset("ppi", num_graphs=2, save_path=dataset/ppi_4b14ad03)

In [9]:
def get_loader_no_sampling(train_data, val_data, test_data):
    model_config = {
            "base_model": "GraphSAGE_DGL",
            "num_layers": 2,
            "num_neighbors": [25, 10],
            "hidden_node_channels": 256,
            "dropout": 0,
            "jk": None
        }

    logger.warning(
        "Sampling strategy is set to be None. Full graph will be used without mini-batching! Batch_size is ignored! ")
    num_neighbors = [-1] * model_config["num_layers"]
    model_config.pop("num_neighbors", None)

    general_config = general_config = {
        "framework": "transductive", # Must be transductive or inductive

        "sampling_strategy": "None",  # Must be choosen from sampling_strategy options

        # Used if sampling_strategy is SAGE; Must be choosen from SAGE_inductive_options
        "SAGE_inductive_option": "strict",

        # Enable to use sampling strategy when predicting (val, test, inference). Default: False.
        "sample_when_predict": False,

        "seed": 118010142,
        "device": "cpu",
        "tqdm": False,
        "save_model": True,
        "criterion": "loss",
        "num_epochs": 1000,
        "patience": 100,
        "num_workers": 2,
        "persistent_workers": True
    }

    logger.info(
        f"\ntrain_data={train_data}\nval_data={val_data}\ntest_data={test_data}")

    neighborsampler = NeighborSampler(num_neighbors)

    train_loader = DataLoader(
        graph=train_data,
        indices=train_data.nodes()[train_data.ndata['train_mask']],
        graph_sampler=neighborsampler,
        batch_size=512,
        num_workers=general_config["num_workers"],
        persistent_workers=general_config["persistent_workers"],
    )

    val_loader = DataLoader(
        graph=val_data,
        indices=val_data.nodes()[val_data.ndata['val_mask']],
        graph_sampler=neighborsampler,
        batch_size=512,
        num_workers=general_config["num_workers"],
        persistent_workers=general_config["persistent_workers"],
    )

    test_loader = DataLoader(
        graph=test_data,
        indices=test_data.nodes()[test_data.ndata['test_mask']],
        graph_sampler=neighborsampler,
        batch_size=512,
        num_workers=general_config["num_workers"],
        persistent_workers=general_config["persistent_workers"],
    )

    return train_loader, val_loader, test_loader

In [10]:
train_loader, val_loader, test_loader = get_loader_no_sampling(train_data, val_data, test_data)

[32m2024-06-21 09:12:08.048[0m | [1mINFO    [0m | [36m__main__[0m:[36mget_loader_no_sampling[0m:[36m38[0m - [1m
train_data=Graph(num_nodes=2708, num_edges=10556,
      ndata_schemes={'feat': Scheme(shape=(1433,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'train_mask': Scheme(shape=(), dtype=torch.bool)}
      edata_schemes={})
val_data=Graph(num_nodes=2708, num_edges=10556,
      ndata_schemes={'feat': Scheme(shape=(1433,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'train_mask': Scheme(shape=(), dtype=torch.bool)}
      edata_schemes={})
test_data=Graph(num_nodes=2708, num_edges=10556,
      ndata_schemes={'feat': Scheme(shape=(1433,), dtype=torch.float32), 'label': Scheme(shape=(), dtype=torch.int64), 'val_mask': Scheme(shape=(), dt

In [47]:
train_data.nodes()

tensor([   0,    1,    2,  ..., 2705, 2706, 2707])

In [11]:
for input_nodes, output_nodes, blocks in train_loader:
    print(blocks)
    print(blocks[0].srcdata['feat'])
    

[Block(num_src_nodes=1664, num_dst_nodes=644, num_edges=3834), Block(num_src_nodes=644, num_dst_nodes=140, num_edges=638)]
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
