# Generating graph data

In [192]:
import networkx as nx
import numpy as np
from matplotlib import pyplot as plt
from ipywidgets import interactive, fixed

## Graph generator

In [202]:
def generate_graph(num_nodes, num_clusters, num_features, dist_params, p, q,
                   *, noise_std=0.2, threshold=0.5, rng=None):
    """
    Generate an undirected graph with clusters, where each cluster draws its node
    attributes from its own Gaussian distribution.

    Every node is assigned a cluster uniformly at random. For every unordered pair
    of distinct nodes a single Gaussian "edge score" is drawn, with mean `p` if both
    nodes share a cluster and `q` otherwise; the edge exists when the score exceeds
    `threshold`. The resulting adjacency matrix is symmetric with an empty diagonal.

    Params:
        num_nodes: Total nodes in the graph (int)
        num_clusters: Number of clusters in the graph (int)
        num_features: Length/dimensionality of node attribute vectors (int)
        dist_params: The parameters of the attribute distribution for each cluster.
            Sequence of at least num_clusters (μ, σ) tuples; entry c is used for cluster c.
        p: Mean of the intra-cluster edge score. This is not itself a probability:
            the intra-cluster edge probability is Φ((p - threshold) / noise_std).
        q: Mean of the inter-cluster edge score (same caveat as `p`).
        noise_std: Standard deviation of the edge score (default 0.2).
        threshold: Edge scores strictly above this value become edges (default 0.5).
        rng: Optional numpy Generator / RandomState for reproducible sampling.
            Defaults to numpy's global random state, so np.random.seed still applies.

    Returns:
        (g, clusters): the networkx Graph, with each node's attribute vector stored
        under the node attribute 'x', and a dict mapping node index -> cluster id.

    Raises:
        ValueError: if dist_params has fewer than num_clusters entries.
    """
    if rng is None:
        rng = np.random  # module-level functions share numpy's global state

    if len(dist_params) < num_clusters:
        raise ValueError(
            f"dist_params has {len(dist_params)} entries but num_clusters={num_clusters}"
        )

    # randomly (uniformly) assign clusters
    labels = rng.choice(num_clusters, size=num_nodes)
    clusters = {i: int(c) for i, c in enumerate(labels)}

    # use the corresponding cluster distribution to generate attributes
    node_attrs = {
        i: rng.normal(loc=dist_params[c][0], scale=dist_params[c][1], size=num_features)
        for i, c in clusters.items()
    }

    # Edge score per node pair: mean p for same-cluster pairs, q otherwise.
    same_cluster = labels[:, None] == labels[None, :]
    scores = rng.normal(loc=np.where(same_cluster, p, q), scale=noise_std)

    # Keep only the strict upper triangle (drops self-loops and the redundant (j, i)
    # draw), then mirror it so each undirected pair is decided by exactly one sample.
    # Without this, from_numpy_array would add an edge if EITHER direction passed.
    upper = np.triu(scores > threshold, k=1)
    A = (upper | upper.T).astype(np.int32)

    g = nx.from_numpy_array(A)
    nx.set_node_attributes(g, node_attrs, 'x')

    return g, clusters


def generate_and_draw(num_nodes, num_clusters, num_features, dist_params, p, q):
    """
    Generate a clustered graph with `generate_graph` and draw it, colouring each
    node by its cluster.

    All parameters are forwarded unchanged to `generate_graph`. Cluster ids beyond
    the size of the colour palette wrap around. With more clusters than colours the
    graph still draws, with repeated colours, instead of raising IndexError.
    """
    G, clusters = generate_graph(num_nodes, num_clusters, num_features, dist_params, p, q)

    colours = ['red', 'blue', 'green', 'orange', 'pink', 'yellow', 'gray']
    # Index by G.nodes so the colour list always follows the graph's own node order.
    node_colours = [colours[clusters[n] % len(colours)] for n in G.nodes]
    nx.draw(G, node_color=node_colours)
    plt.show()

In [204]:
# Per-cluster (μ, σ) of the Gaussian used for node attributes; entry c is cluster c.
dist_params = [(5, 0.9), (3, 0.4), (-1, 0.5), (0, 0.1)]

# Sliders for the node count and the edge-score means p / q; the remaining
# arguments are pinned with `fixed` so they are not exposed as widgets.
interactive_plot = interactive(
    generate_and_draw,
    num_nodes=(10, 35),
    num_clusters=fixed(4),
    num_features=fixed(2),
    dist_params=fixed(dist_params),
    p=(0.5, 1.0),
    q=(0.0, 0.3),
)
interactive_plot

interactive(children=(IntSlider(value=22, description='num_nodes', max=35, min=10), FloatSlider(value=0.75, de…