In [490]:
import networkx as nx
from torch_geometric.utils import from_networkx
from torch_geometric.loader import DataLoader, NeighborLoader
from torch_geometric.data import Batch
from tqdm.notebook import tqdm
import torch
import numpy as np
import os
import pickle
from functools import partial
import multiprocessing as mp
from tqdm.contrib.concurrent import process_map
from sklearn.datasets import make_spd_matrix

In [491]:
# --- SBM dataset configuration -------------------------------------------
n_comm = 10           # number of communities (also the number of classes)
comm_size = 100       # nodes per community
p_in = 0.9            # edge probability inside a community
p_out = 0.01          # edge probability between communities
node_feat_dim = 5     # dimensionality of the Gaussian node features

# Per-community feature mean: community i is centred at i in every dimension.
# A wider separation that was tried previously: [0, 10, 20, 30, 40] * 2
means = list(range(n_comm))
# One variance per community, sized from n_comm instead of a hard-coded 10.
variances = [1] * n_comm

# Fraction of labels to corrupt in each community, drawn uniformly in [0, 0.2).
# A dedicated seeded generator keeps the label noise reproducible across
# kernel restarts without touching the global NumPy RNG state.
noise_seed = 0
noise_props = np.random.RandomState(noise_seed).rand(n_comm) * 0.2

In [492]:
def make_sbm_graph(n_comm, comm_size, p_in, p_out, node_feat_dim, means, variances, noise_props, seed):
    """Generate one random-partition (SBM) graph with Gaussian features and noisy labels.

    Args:
        n_comm: Number of communities.
        comm_size: Number of nodes in every community.
        p_in: Probability of an edge between two nodes of the same community.
        p_out: Probability of an edge between nodes of different communities.
        node_feat_dim: Dimensionality of each node feature vector.
        means: Sequence with at least ``n_comm`` entries; community ``i`` is
            sampled around ``[means[i]] * node_feat_dim``.
        variances: Accepted for interface compatibility but currently unused;
            the covariance is a random SPD matrix drawn per community.
        noise_props: Sequence with at least ``n_comm`` entries; ``noise_props[i]``
            is the fraction of community ``i`` whose label is replaced by a
            label drawn (with replacement) from the other communities.
        seed: Seed used both for the graph topology and for all feature and
            label sampling.

    Returns:
        The object produced by ``from_networkx`` with ``x`` set to a float32
        tensor of shape ``(n_comm * comm_size, node_feat_dim)`` and ``y`` set
        to an int64 tensor of (partially corrupted) community labels.

    Raises:
        ValueError: If ``means`` or ``noise_props`` has fewer than ``n_comm`` entries.
    """
    if len(means) < n_comm or len(noise_props) < n_comm:
        raise ValueError(
            f"means and noise_props need at least n_comm={n_comm} entries, "
            f"got {len(means)} and {len(noise_props)}"
        )

    # Private generator instead of np.random.seed(seed): a RandomState seeded
    # with the same value yields the same draw sequence, but the caller's
    # global NumPy RNG state is left untouched.
    rng = np.random.RandomState(seed)

    communities = [comm_size] * n_comm
    G = nx.random_partition_graph(communities, p_in, p_out, seed=seed)
    G = G.to_directed() if not nx.is_directed(G) else G
    # PyG doesn't like these graph-level attributes, so drop them. pop() with a
    # default avoids a KeyError if a networkx version does not set one of them.
    G.graph.pop('partition', None)
    G.graph.pop('name', None)

    node_feat = []
    labels = []
    for i in range(n_comm):
        # Random SPD covariance per community, drawn from the same generator
        # so the whole sample is determined by `seed`.
        cov = make_spd_matrix(node_feat_dim, random_state=rng)
        mean = [means[i]] * node_feat_dim
        x = rng.multivariate_normal(mean, cov, comm_size)

        # Start from clean labels, then overwrite a random subset with labels
        # of other communities.
        y = np.ones(comm_size) * i
        other_labels = [j for j in range(n_comm) if j != i]
        n_noisy_labels = int(noise_props[i] * comm_size)
        noise_values = rng.choice(other_labels, n_noisy_labels, replace=True)
        # Distinct positions, so exactly n_noisy_labels nodes are relabelled.
        noise_indices = rng.choice(comm_size, n_noisy_labels, replace=False)
        y[noise_indices] = noise_values

        node_feat.append(x)
        labels.append(y)

    node_feat = np.concatenate(node_feat)
    labels = np.concatenate(labels)
    # NOTE(review): the directed -> undirected round-trip is kept as in the
    # original so edge ordering of generated datasets does not change;
    # confirm whether from_networkx needs it before removing.
    data = from_networkx(nx.Graph(G))
    data.x = torch.tensor(node_feat, dtype=torch.float32)
    data.y = torch.tensor(labels, dtype=torch.int64)
    return data

In [493]:
def fn_tv(i):
    """Build one train/validation graph from the notebook-level settings, seeded with ``i``."""
    return make_sbm_graph(
        n_comm=n_comm,
        comm_size=comm_size,
        p_in=p_in,
        p_out=p_out,
        node_feat_dim=node_feat_dim,
        means=means,
        variances=variances,
        noise_props=noise_props,
        seed=i,
    )
def fn_test(i):
    """Build one test graph from the notebook-level settings, seeded with ``i``."""
    return make_sbm_graph(
        n_comm=n_comm,
        comm_size=comm_size,
        p_in=p_in,
        p_out=p_out,
        node_feat_dim=node_feat_dim,
        means=means,
        variances=variances,
        noise_props=noise_props,
        seed=i,
    )

In [494]:
# Generate the three splits in parallel. process_map manages its own worker
# processes, so no separate multiprocessing pool is needed; max_workers caps
# the parallelism at 8. Each graph's seed is its index in the ranges below,
# and the three ranges are disjoint so no graph appears in two splits.
train = process_map(fn_tv, range(20), max_workers=8)
val = process_map(fn_tv, range(21, 23), max_workers=8)
test = process_map(fn_test, range(26, 27), max_workers=8)


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [495]:
# Collate the list of graphs of each split into a single Batch object.
train_data, val_data, test_data = (
    Batch.from_data_list(graphs) for graphs in (train, val, test)
)

In [496]:
# Every split has one class per community.
train_data.num_classes = val_data.num_classes = test_data.num_classes = n_comm

In [497]:
# Create the output directory, including the parent 'data' directory if it is
# missing. exist_ok=True makes re-running the cell a no-op, while still
# raising if the path exists but is not a directory.
os.makedirs('data/SBM', exist_ok=True)

In [498]:
# Pickle each batched split to <folder>/<split name>.pkl.
folder = 'data/SBM'
splits = (('train', train_data), ('val', val_data), ('test', test_data))
for name, data in splits:
    with open(os.path.join(folder, f'{name}.pkl'), 'wb+') as f:
        pickle.dump(data, f)

In [503]:
# Feature vector of the first node in the first training graph.
train[0].x[0]

tensor([-1.4513,  0.2087, -1.0950,  1.9541,  1.2041])

In [504]:
# Feature vector of the first node in the first validation graph.
val[0].x[0]

tensor([ 1.1249,  0.0096,  1.7162, -0.3796, -0.7158])

In [505]:
# Feature vector of the first node in the first test graph.
test[0].x[0]

tensor([-0.1655, -0.8628,  1.1961, -0.9450, -0.2817])

In [502]:
# Scratch check: draw 10 samples from a zero-mean 3-D Gaussian whose
# covariance is a random SPD matrix (unseeded, so values vary per run).
mean = [0] * 3
cov = make_spd_matrix(3)
np.random.multivariate_normal(mean, cov, 10)

array([[-3.50711722,  1.74153693, -1.8102476 ],
       [-0.54209592,  0.55156439, -0.20988866],
       [-0.07083168,  2.12144096, -0.05377152],
       [ 2.23657732, -0.28146136,  0.44134555],
       [ 1.64122   , -1.37344051,  0.33691316],
       [-5.14698458,  2.38221859, -1.38946014],
       [-1.1308754 ,  1.00084772, -1.55658744],
       [ 0.31697078,  1.19494323,  0.37817609],
       [-0.29865434, -0.52432122, -0.47071862],
       [-1.22292098,  0.79991628,  0.25982859]])