In [None]:
%load_ext autoreload
%autoreload 2

from collections import Counter
from pathlib import Path

import networkx as nx
import numpy as np
import torch
from torch_geometric.utils import from_networkx

from tadcsbm import (
    tadcsbm_simulator,
    generate_block_matrix,
    generate_transition_matrix,
    generate_degree_vector,
    generate_community_vector,
    gt_to_nx,
)

In [None]:
# Graph size.
n = 1_024   # Number of vertices.
e = 10_240  # Number of edges.

# Temporal community structure.
k = 8       # Number of communities.
t = 8       # Number of snapshots.
eta = 1.0   # Community stability rate.
gamma = 1   # Fix transition probabilities.

# Edge sampling.
beta = 1.0  # Edge sampling rate.

In [None]:
# Build the k-community matrix that is passed to the simulator below as
# `prop_mat` (presumably community-to-community edge propensities — confirm in
# tadcsbm). The bare expression displays it for inspection.
mat = generate_block_matrix(k)
mat

In [None]:
# Build the k-community matrix passed to the simulator below as `tau_mat`,
# parameterized by the community stability rate `eta` (presumably the
# per-snapshot community transition probabilities — confirm in tadcsbm).
# NOTE(review): the effect of uniform_all=False is not visible here; verify.
tau = generate_transition_matrix(k, eta, uniform_all=False)
tau

In [None]:
# Community label for each of the n vertices across k communities; used below
# to derive the community proportions `pi`.
# NOTE(review): with shuffle=False the labels are presumably assigned in
# contiguous, sorted order — confirm against generate_community_vector.
z = generate_community_vector(n, k, shuffle=False)
z

In [None]:
# Community proportions: fraction of the vertices in `z` carrying each label.
# Counts are read per label in range(k) rather than via Counter.items(), whose
# first-appearance order (and omission of empty communities) could misalign pi
# with the rows of `mat` / `tau` if z were shuffled or a community were empty.
# NOTE(review): assumes z's labels are the integers 0..k-1 — confirm.
community_sizes = Counter(z)
pi = [community_sizes[label] / len(z) for label in range(k)]

# Run the simulator for t snapshots with the matrices built above plus
# node and edge feature settings.
sbm = tadcsbm_simulator(
    snapshots=t,
    num_vertices=n,
    num_edges=e,
    pi=pi,
    prop_mat=mat,
    tau_mat=tau,
    num_feature_groups=k,
    feature_dim=32,
    feature_center_distance=6.0,
    feature_cluster_variance=1.0,
    edge_feature_dim=32,
    edge_center_distance=6.0,
    edge_cluster_variance=1.0,
    # NOTE(review): the config cell sets gamma = 1 ("Fix transition
    # probabilities"), but this argument is hard-coded False. Kept False to
    # preserve current behavior — reconcile with gamma.
    fixed_probabilities=False,
    reverse_snapshot_order=True,
    edge_sampling_rate=beta,  # config constant (currently 1.0)
)

In [None]:
# Export the simulated temporal graph as GraphML, NumPy feature arrays and a
# PyTorch Geometric data object, all under OUTPUT_DIR.
OUTPUT_DIR = Path("output")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)  # the writes below fail if it is missing

# Compose graph-tool graphs as a single NetworkX multigraph; each snapshot is
# converted with its index passed as `time`. Individual snapshots can also be
# saved via graph.save(...) on each entry of sbm.graph.
G = nx.compose_all([gt_to_nx(graph, time=snapshot) for snapshot, graph in enumerate(sbm.graph)])

# NOTE(review): labels and features are paired with nodes/edges purely by
# iteration order, and zip() silently truncates on a length mismatch — confirm
# that G.nodes()/G.edges() order matches the simulator's vertex/edge order.
nx.set_node_attributes(G, dict(zip(G.nodes(), sbm.graph_memberships)), "y")

# Written before the array-valued "x"/"edge_attr" attributes are attached,
# presumably because GraphML cannot store array-valued attributes.
nx.write_graphml(G, OUTPUT_DIR / "graph.graphml")

# Save node and edge features as NumPy arrays.
np.save(OUTPUT_DIR / "features_node.npy", sbm.node_features1)
np.save(OUTPUT_DIR / "features_edge.npy", sbm.edge_features)

# Set node and edge attributes in the NetworkX graph. On a multigraph, edges
# must be addressed as (u, v, key): set_edge_attributes unpacks 3-tuples there,
# and bare (u, v) pairs would also collide for parallel edges across snapshots.
nx.set_node_attributes(G, dict(zip(G.nodes(), sbm.node_features1)), "x")
edge_ids = G.edges(keys=True) if G.is_multigraph() else G.edges()
nx.set_edge_attributes(G, dict(zip(edge_ids, sbm.edge_features)), "edge_attr")

# Save as PyTorch Geometric data object.
data = from_networkx(G)
torch.save(data, OUTPUT_DIR / "data.pt")

print(G)
print(data)

___

In [None]:
# from tadcsbm.simulations.sbm_simulator import _TransitionNodeMemberships
# _TransitionNodeMemberships(sbm.graph_memberships, tau)