In [4]:
import random

import networkx as nx
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from torch_geometric_temporal.dataset import PedalMeDatasetLoader, ChickenpoxDatasetLoader
from torch_geometric_temporal.nn.recurrent import DCRNN
from torch_geometric_temporal.signal import DynamicGraphTemporalSignalBatch
from torch_geometric_temporal.signal import StaticGraphTemporalSignal
from torch_geometric_temporal.signal import temporal_signal_split
from torch_geometric_temporal.signal import temporal_signal_split


# Reference: PyTorch Geometric Temporal introduction, "Epidemiological Forecasting" example
# https://pytorch-geometric-temporal.readthedocs.io/en/latest/notes/introduction.html#epidemiological-forecasting

In [5]:
def get_edge_array(node_count, node_start, p=0.1, seed=None):
    """Return the edge list of a random G(n, p) graph with node ids shifted by ``node_start``.

    The graph is drawn with ``nx.gnp_random_graph``; shifting the ids lets several
    independently sampled graphs be stacked into one disjoint batch.

    Args:
        node_count: Number of nodes in the sampled graph.
        node_start: Offset added to every node id.
        p: Edge probability (default 0.1, the value previously hard-coded).
        seed: Optional seed forwarded to ``nx.gnp_random_graph`` for reproducibility.

    Returns:
        int64 array of shape ``(num_edges, 2)``, one row per edge reported by
        ``graph.edges()``. An edgeless graph yields shape ``(0, 2)`` rather than ``(0,)``.
    """
    graph = nx.gnp_random_graph(node_count, p, seed=seed)
    # reshape(-1, 2) keeps an empty edge list 2-D so np.concatenate works downstream;
    # int64 matches the integer type torch expects for edge indices on every platform.
    edges = np.array(list(graph.edges()), dtype=np.int64).reshape(-1, 2)
    return edges + node_start

def _generate_snapshot(node_count, feature_count, graph_count):
    """Build one snapshot made of ``graph_count`` disjoint random graphs stacked into a single batch.

    Returns:
        (edge_index, edge_weight, features, targets, batch) for the snapshot.
    """
    edge_blocks, weight_blocks, feature_blocks, target_blocks, batch_blocks = [], [], [], [], []
    for graph_id in range(graph_count):
        # Graph `graph_id` owns node ids [graph_id * node_count, (graph_id + 1) * node_count).
        # asarray + reshape normalise an edgeless graph (np.array([]) has shape (0,)) to an
        # int64 (0, 2) array so the concatenation below neither fails nor upcasts to float.
        edges = np.asarray(
            get_edge_array(node_count, graph_id * node_count), dtype=np.int64
        ).reshape(-1, 2)
        edge_blocks.append(edges)
        weight_blocks.append(np.ones(edges.shape[0]))
        feature_blocks.append(np.random.uniform(0, 1, (node_count, feature_count)))
        # Vectorised 0/1 labels (one draw per node) instead of one np.random.choice call per node.
        target_blocks.append(np.random.randint(0, 2, size=node_count))
        batch_blocks.append(np.full(node_count, graph_id))
    return (
        np.concatenate(edge_blocks).T,
        np.concatenate(weight_blocks),
        np.concatenate(feature_blocks),
        np.concatenate(target_blocks),
        np.concatenate(batch_blocks),
    )


def generate_signal(snapshot_count, node_count, feature_count, graph_count):
    """Generate a synthetic batched dynamic-graph temporal signal.

    Every snapshot contains ``graph_count`` independently sampled random graphs of
    ``node_count`` nodes each, stacked into one disjoint graph of
    ``graph_count * node_count`` nodes.

    Args:
        snapshot_count: Number of snapshots (time steps) to generate.
        node_count: Nodes per individual graph.
        feature_count: Features per node.
        graph_count: Graphs per snapshot; must be at least 1.

    Returns:
        Five lists, each of length ``snapshot_count``:
        edge_indices (2, num_edges) int64 arrays;
        edge_weights (num_edges,) arrays of ones;
        features (graph_count * node_count, feature_count) float64 arrays, uniform in [0, 1);
        targets (graph_count * node_count,) random 0/1 labels;
        batches (graph_count * node_count,) graph id of each node.

    Raises:
        ValueError: if ``graph_count`` < 1 (a snapshot would have no graphs).
    """
    if graph_count < 1:
        raise ValueError("graph_count must be at least 1")

    edge_indices, edge_weights, features, targets, batches = [], [], [], [], []
    for _ in range(snapshot_count):
        edge_index, edge_weight, x, y, batch = _generate_snapshot(node_count, feature_count, graph_count)
        edge_indices.append(edge_index)
        edge_weights.append(edge_weight)
        features.append(x)
        targets.append(y)
        batches.append(batch)

    return edge_indices, edge_weights, features, targets, batches

In [6]:
# Synthetic-data configuration: each snapshot stacks `graph_count` disjoint graphs of `n_count` nodes.
SEED = 42
snapshot_count = 250
n_count = 100
feature_count = 32
graph_count = 10

# Seed numpy (features/targets) and Python's global `random`. networkx's gnp_random_graph
# presumably falls back to the latter when no seed is passed; verify for the installed version.
random.seed(SEED)
np.random.seed(SEED)

# Pass the constants (not literals) so editing the config above actually changes the data.
edge_indices, edge_weights, features, targets, batches = generate_signal(
    snapshot_count, n_count, feature_count, graph_count
)

In [11]:
# Edge counts differ between snapshots because every graph is resampled independently.
tuple(edge_indices[i].shape for i in range(3))

((2, 4960), (2, 5038), (2, 4839))

In [19]:
# Preview the edge_index arrays of the first ten snapshots (numpy elides the middle of each).
edge_indices[0:10]

[array([[  0,   0,   0, ..., 991, 992, 993],
        [  4,  19,  30, ..., 994, 994, 996]]),
 array([[  0,   0,   0, ..., 991, 991, 993],
        [  8,  11,  28, ..., 993, 998, 998]]),
 array([[  0,   0,   0, ..., 989, 990, 990],
        [ 24,  36,  61, ..., 995, 992, 998]]),
 array([[  0,   0,   0, ..., 990, 990, 993],
        [  6,  10,  24, ..., 993, 997, 996]]),
 array([[  0,   0,   0, ..., 992, 994, 998],
        [  2,  12,  16, ..., 997, 995, 999]]),
 array([[  0,   0,   0, ..., 992, 992, 998],
        [ 12,  13,  22, ..., 995, 996, 999]]),
 array([[  0,   0,   0, ..., 991, 994, 997],
        [ 12,  27,  41, ..., 998, 998, 998]]),
 array([[  0,   0,   0, ..., 989, 989, 993],
        [ 56,  62,  63, ..., 997, 999, 996]]),
 array([[  0,   0,   0, ..., 992, 993, 994],
        [  5,  20,  69, ..., 996, 995, 997]]),
 array([[  0,   0,   0, ..., 993, 997, 997],
        [ 18,  34,  39, ..., 999, 998, 999]])]

In [None]:
# Wrap the per-snapshot numpy arrays in a batched dynamic-graph temporal signal.
dataset = DynamicGraphTemporalSignalBatch(
    edge_indices=edge_indices,
    edge_weights=edge_weights,
    features=features,
    targets=targets,
    batches=batches,
)


In [None]:
 for epoch in range(2):
        for snapshot in dataset:
            assert snapshot.edge_index.shape[0] == 2
            assert snapshot.edge_index.shape[1] == snapshot.edge_attr.shape[0]
            assert snapshot.x.shape == (1000, 32)
            assert snapshot.y.shape == (1000, )
            assert snapshot.batch.shape == (1000, )