In [None]:
import numpy as np
import networkx as nx
import pandas as pd
import torch
import networkx.algorithms.community as comm
import os
from tqdm import tqdm
from urllib.request import urlopen
from zipfile import ZipFile
from io import BytesIO
import pickle
from torch_geometric.io import read_npz
from torch_geometric.utils import to_networkx



# `bter` lives one directory above this notebook: import it from there, and always
# return to the notebook directory afterwards, even if the import fails.
_notebook_dir = os.getcwd()
os.chdir("../")
try:
    from bter.bter import BTER
finally:
    os.chdir(_notebook_dir)

In [None]:
# os.chdir('../../dgd/')
# print(os.getcwd())

def _load_cora():
    """Download Cora-ML and return its largest connected component relabelled to 0..n-1,
    with each node's class stored in the ``"target"`` attribute."""
    cora_url = 'https://github.com/abojchevski/graph2gauss/raw/master/data/cora_ml.npz'
    with urlopen(cora_url) as resp:
        out = read_npz(BytesIO(resp.read()))

    G = to_networkx(out, to_undirected=True)
    # Assumes out.y is aligned with G's node order.
    node_classes = {n: out.y[i].item() for i, n in enumerate(G.nodes())}
    nx.set_node_attributes(G, node_classes, name="target")

    G = G.to_undirected()
    G = G.subgraph(max(nx.connected_components(G), key=len))
    return nx.convert_node_labels_to_integers(G)


def _load_sbm(root_path=None):
    """Load one randomly chosen graph (NumPy global RNG) from ``<root_path>/sbm/raw/sbm_200.pt``.

    Every node gets ``target=1.0`` and every edge ``type=1.0``.
    """
    if root_path is None:
        # Fall back to a notebook-level ``root_path`` variable, if one has been defined.
        root_path = globals().get("root_path")
    if root_path is None:
        raise ValueError("root_path must be given to load the 'sbm' dataset")
    file_path = os.path.join(root_path, "sbm/raw/sbm_200.pt")

    # The file holds (adjs, eigvals, eigvecs, n_nodes, max_eigval, min_eigval, same_sample, n_max);
    # only the adjacency matrices are needed here.
    # NOTE(review): torch.load unpickles the file - only use it on trusted data.
    adjs = torch.load(file_path)[0]
    adj = adjs[np.random.randint(len(adjs))]
    g = nx.from_numpy_array(adj.numpy())

    # Rebuild as a plain graph so that only the "target"/"type" attributes are kept.
    G = nx.Graph()
    G.add_nodes_from((i, {"target": t}) for i, t in enumerate(np.ones(g.order())))
    G.add_edges_from((u, v, {"type": 1.}) for u, v in g.edges())
    return G


def _load_facebook():
    """Download the SNAP Facebook page-page graph; the page type (as an int) is stored in ``"target"``."""
    with urlopen("https://snap.stanford.edu/data/facebook_large.zip") as resp:
        payload = resp.read()
    with ZipFile(BytesIO(payload)) as myzip:
        with myzip.open("facebook_large/musae_facebook_edges.csv") as fh:
            edgelist = pd.read_csv(fh)
        with myzip.open("facebook_large/musae_facebook_target.csv") as fh:
            class_df = pd.read_csv(fh)

    G = nx.from_pandas_edgelist(df=edgelist, source="id_1", target="id_2")

    unique_types = np.unique(class_df["page_type"])
    types = {target: i for i, target in enumerate(unique_types.tolist())}

    # Look labels up by node id (not row position) when the table has an "id" column.
    if "id" in class_df.columns:
        class_df = class_df.set_index("id")
    page_types = class_df["page_type"]
    node_classes = {n: types[page_types.at[n]] for n in G.nodes()}
    nx.set_node_attributes(G, node_classes, name="target")
    return G


def load_original(name = None, largest_cc = True, root_path = None):
    """Load one of the reference ("original") graphs.

    Parameters
    ----------
    name : str or None
        ``"cora"``: Cora-ML citation graph (downloaded; always its largest connected
        component, relabelled to integers). ``"sbm"``: a random graph from
        ``<root_path>/sbm/raw/sbm_200.pt``. Anything else: the SNAP Facebook
        page-page graph (downloaded).
    largest_cc : bool
        If True, restrict the result to its largest connected component.
    root_path : str or None
        Data root for ``"sbm"``. If None, a notebook-global ``root_path`` is used.

    Returns
    -------
    networkx.Graph
        An independent copy with self-loops removed and a ``"target"`` node attribute.
    """
    print(f"Loading original {name} graph")
    if name == "cora":
        G = _load_cora()
    elif name == "sbm":
        G = _load_sbm(root_path)
    else:
        G = _load_facebook()

    if largest_cc and G.number_of_nodes() > 0:
        G = G.subgraph(max(nx.connected_components(G), key=len))

    G = G.copy()
    G.remove_edges_from(list(nx.selfloop_edges(G)))
    return G
    
def general_graph_metrics(datamodule, return_values = False, descriptor="", largest_cc=True, seed=None):
    """Print (and optionally return) summary statistics of a graph.

    Parameters
    ----------
    datamodule : networkx.Graph or object with a ``.G`` graph attribute
    return_values : bool
        If True, also return the statistics as a dict.
    descriptor : str
        Label printed in the report header.
    largest_cc : bool
        If True, compute everything on the largest connected component.
    seed : int, RandomState or None
        Passed to ``louvain_communities`` so the community statistics are reproducible.

    Returns
    -------
    dict or None
        Keys "N Nodes", "N Edges", "Density", "Diameter" (-1. when it cannot be computed,
        e.g. for a disconnected graph), "Transitivity", "Clustering", "Num Comms",
        "Median Comm Size". None unless ``return_values``.
    """
    G = datamodule if isinstance(datamodule, nx.Graph) else datamodule.G

    if largest_cc and G.number_of_nodes() > 0:
        G = G.subgraph(max(nx.connected_components(G), key=len))

    communities = comm.louvain_communities(G, resolution=1, seed=seed)
    community_sizes = [len(part) for part in communities]
    median_size = np.median(community_sizes)
    num_communities = len(communities)

    density  = nx.density(G)
    try:
        diameter = nx.diameter(G)
    except (nx.NetworkXError, ValueError):
        # Diameter is undefined (e.g. disconnected graph when largest_cc=False).
        diameter = -1.
    n_nodes  = G.number_of_nodes()
    n_edges  = G.number_of_edges()
    transitivity = nx.transitivity(G)
    clustering   = nx.average_clustering(G)

    print(f"=" * 50 +
          f"{descriptor}\n"
          f'\nN Nodes: {n_nodes}\n'
          f'N Edges: {n_edges}\n'
          f'Num Comms: {num_communities}\n'
          f'Median Comm Size: {median_size}\n'
          f'Density: {density}\n'
          f'Diameter: {diameter}\n'
          f'Transitivity: {transitivity}\n'
          f'Clustering: {clustering}\n' + "=" * 50)

    if return_values:
        return {"N Nodes": n_nodes,
               "N Edges": n_edges,
               "Density": density,
               "Diameter": diameter,
               "Transitivity": transitivity,
               "Clustering": clustering,
               "Num Comms": num_communities,
               "Median Comm Size": median_size}

In [None]:
# G_baseline = load_original("cora")

# with open("../sampling_outputs/cora_sampled_forced_edges.pkl", "rb") as f:
#     G_sampled  = pickle.load(f)
    
# print(G_baseline, G_sampled)

In [None]:
resolutions = [*range(1, 20)]  # Louvain resolution values to sweep: 1, 2, ..., 19
print(resolutions)

def communities_to_edges(partitions, G):
    """Return the induced subgraph of ``G`` for every community of every partition.

    Despite the name, the result is a flat list of subgraphs (from ``G.subgraph``),
    not edge lists, ordered by partition and then by community.

    Parameters
    ----------
    partitions : iterable of partitions
        Each partition is an iterable of node collections (e.g. Louvain output).
    G : networkx.Graph
    """
    return [G.subgraph(community) for run in partitions for community in run]

def make_meta_graph(partition, G):
    """Build the (unweighted) community meta-graph of ``G``.

    Community ``i`` of ``partition`` becomes meta-node ``i``. Two meta-nodes are joined
    by a single edge if at least one edge of ``G`` connects the two communities.
    Intra-community edges are ignored, so communities with no inter-community edges
    do not appear in the result.

    Parameters
    ----------
    partition : iterable of node collections
        Community partition of ``G`` (e.g. ``louvain_communities`` output). Every endpoint
        of every edge of ``G`` must belong to some community.
    G : networkx.Graph

    Returns
    -------
    networkx.Graph
    """
    # Invert {community_id: nodes} into {node: community_id}.
    node_to_community = {node: cid for cid, members in enumerate(partition) for node in members}

    # Keep each unordered pair of distinct communities once; a set makes this O(E)
    # instead of a linear list scan per candidate edge.
    seen_pairs = set()
    meta_edges = []
    for u, v in G.edges():
        cu, cv = node_to_community[u], node_to_community[v]
        if cu == cv:
            continue
        pair = frozenset((cu, cv))
        if pair not in seen_pairs:
            seen_pairs.add(pair)
            meta_edges.append((cu, cv))

    metaG = nx.Graph()
    metaG.add_edges_from(meta_edges)
    return metaG

In [None]:
# original_h2_by_resolution = {}
# original_h1_by_resolution = {}
# original_statistics_by_resolution = {}

# for r in tqdm(resolutions):
#     partition = comm.louvain_communities(G_baseline, resolution=r)
    
#     h1_graphs = communities_to_edges([partition], G_baseline)
#     h2_graph  = make_meta_graph(partition, G_baseline)
    
#     original_h2_by_resolution[r] = h2_graph
#     original_h1_by_resolution[r] = h1_graphs
    
# original_h2_by_resolution = {}
# original_h1_by_resolution = {}
# original_statistics_by_resolution = {}

# for r in tqdm(resolutions):
#     partition = comm.louvain_communities(G_baseline, resolution=r)
    
#     h1_graphs = communities_to_edges([partition], G_baseline)
#     h2_graph  = make_meta_graph(partition, G_baseline)
    
#     original_h2_by_resolution[r] = h2_graph
#     original_h1_by_resolution[r] = h1_graphs
    
    
        

In [None]:
import matplotlib.pyplot as plt
from scipy.stats import kstest, ks_2samp, probplot

def ks_test_report(G_baseline, G_sampled, algorithm=nx.degree, hist_range=(0,100), name = "cora"):
    """Two-sample Kolmogorov-Smirnov test between per-node metric values of two graphs.

    Values come from ``all_metric`` on each graph's largest connected component.
    The raw scipy result and a short summary are printed.

    ``hist_range`` is accepted for signature compatibility and is not used.

    Returns
    -------
    tuple
        ``(statistic, p_value)``.
    """
    original_values = all_metric([G_baseline], algorithm)
    sampled_values = all_metric([G_sampled], algorithm)

    result = ks_2samp(original_values, sampled_values)
    statistic, p_value = result[0], result[1]

    print(result)
    print(f"\nFor {name}:\n"
          f"Statistic: {statistic}\n"
          f"P value: {p_value}\n")

    return statistic, p_value


def correlations(G_baseline, G_sampled, algorithm=nx.degree, name = "cora", bins=100):
    """Pearson correlation between the histograms of a per-node metric on two graphs.

    Metric values are taken over all nodes (``largest_cc=False``). Degrees are binned on
    [0, 100]; every other metric is binned on [0, 1]. The coefficient is printed.

    Parameters
    ----------
    bins : int
        Number of histogram bins (default 100, as before).

    Returns
    -------
    float
        The correlation coefficient. It is NaN if either histogram is constant, e.g.
        when no values fall inside the range.
    """
    values_original = all_metric([G_baseline], algorithm, largest_cc=False)
    values_sampled = all_metric([G_sampled], algorithm, largest_cc=False)

    hist_range = (0, 100) if algorithm == nx.degree else (0., 1.)

    counts_original, _ = np.histogram(values_original, bins=bins, range=hist_range)
    counts_sampled, _ = np.histogram(values_sampled, bins=bins, range=hist_range)

    corr = np.corrcoef(counts_original, counts_sampled)[0, 1]
    print(f"Correlation coefficient of {corr:.3} for {name}")
    return corr
    
    
    
def all_metric(graphs, metric_algorithm, largest_cc = True):
    """Evaluate a per-node metric on every node of every graph and pool the values.

    Parameters
    ----------
    graphs : iterable of networkx.Graph
    metric_algorithm : callable
        Called as ``metric_algorithm(g, n)`` for each node ``n`` of each graph ``g``
        (e.g. ``nx.degree``, ``nx.clustering``, ``nx.eccentricity``).
    largest_cc : bool
        If True, only the FIRST graph is replaced by a copy of its largest connected
        component; the other graphs are used as given. The caller's list is not modified.

    Returns
    -------
    list
        Metric values in graph order, then node order. Nodes for which the metric raises
        an ``Exception`` (some algorithms fail to converge) are skipped.
    """
    # Work on a local list so the caller's list is never mutated.
    graphs = list(graphs)
    if largest_cc and graphs and graphs[0].number_of_nodes() > 0:
        first = graphs[0]
        graphs[0] = first.subgraph(max(nx.connected_components(first), key=len)).copy()

    metric_list = []
    for g in graphs:
        for n in list(g.nodes()):
            try:
                metric_list.append(metric_algorithm(g, n))
            except Exception:
                # Best effort: skip nodes the metric cannot be computed for, but let
                # KeyboardInterrupt / SystemExit propagate.
                continue
    return metric_list

class reformat_centrality:
    """Cache a graph's betweenness centrality so that it can be looked up node by node."""

    def __init__(self, G):
        # Betweenness is computed for every node once, up front.
        centrality = nx.betweenness_centrality(G)
        self.betweeness_dict = centrality

    def for_node(self, node):
        """Return the cached betweenness of ``node``; raises KeyError for unknown nodes."""
        lookup = self.betweeness_dict
        return lookup[node]


def _normalised_histogram(values, bins, hist_range):
    """Histogram of ``values`` scaled to sum to 1 (all zeros if nothing falls in range)."""
    counts, edges = np.histogram(values, bins=bins, range=hist_range)
    total = counts.sum()
    proportions = counts / total if total else np.zeros(len(counts))
    return proportions, edges


def plot_vs(degrees_original, degrees_sampled, degrees_bter, ax, hist_range = (0,100), residual=False, bins=75):
    """Draw step histograms of a metric for the real graph and two synthetic graphs on ``ax``.

    Each histogram is normalised to proportions. With ``residual=False`` the three
    distributions are drawn (Real blue, HiGGs orange, BTER green). With ``residual=True``
    the differences Real - HiGGs and Real - BTER are drawn, with a dashed zero line.

    Parameters
    ----------
    degrees_original, degrees_sampled, degrees_bter : sequence of numbers
        Metric values (not necessarily degrees) for the real, HiGGs and BTER graphs.
    ax : matplotlib.axes.Axes
    hist_range : (float, float)
        Range over which values are binned.
    residual : bool
    bins : int
        Number of bins (default 75, as before).
    """
    real, edges = _normalised_histogram(degrees_original, bins, hist_range)
    higgs, _ = _normalised_histogram(degrees_sampled, bins, hist_range)
    bter_hist, _ = _normalised_histogram(degrees_bter, bins, hist_range)

    if residual:
        ax.stairs(real - higgs, edges, label="HiGGs", color="orange")
        ax.stairs(real - bter_hist, edges, label="BTER", color="green")
        ax.axhline(linestyle='--', color='blue', alpha=0.5)
    else:
        ax.stairs(real, edges, label="Real", color="blue")
        ax.stairs(higgs, edges, label="HiGGs", color="orange")
        ax.stairs(bter_hist, edges, label="BTER", color="green")
        

def vis_big_graph(G, largest_cc=False, label = ""):
    """Draw ``G`` with Graphviz's sfdp layout, save it as ``f"{label}.png"`` (600 dpi) and show it.

    Nodes are drawn (coloured by their ``"target"`` attribute) only when every node has a
    ``"target"``; otherwise only the edges are drawn. Requires pygraphviz for the layout.

    Parameters
    ----------
    G : networkx.Graph
    largest_cc : bool
        If True, draw only the largest connected component.
    label : str
        Output file stem.
    """
    if largest_cc and G.number_of_nodes() > 0:
        G = G.subgraph(max(nx.connected_components(G), key=len))

    pos = nx.drawing.nx_agraph.graphviz_layout(G, prog="sfdp", args='-Gsmoothing')

    fig, ax = plt.subplots(ncols=1, figsize=(6,6))

    nx.draw_networkx_edges(G, node_size=2, pos=pos, alpha=0.5, ax=ax)

    # Previously a bare `except: pass` silently skipped the node layer on ANY error;
    # now only the missing-"target" case is skipped, and real drawing errors surface.
    node_data = list(G.nodes(data=True))
    if all("target" in attrs for _, attrs in node_data):
        nx.draw_networkx_nodes(G, node_size=1, pos=pos, ax=ax,
                               node_color=[attrs["target"] for _, attrs in node_data])

    ax.axis('off')

    plt.tight_layout(h_pad=0, w_pad=0, pad=0)
    plt.savefig(f"{label}.png", dpi=600)
    plt.show()
        
def clean(data, intervals=None, edge_size=2):
    """Drop ``edge_size`` items from each end of ``data``.

    ``intervals`` is accepted for backward compatibility and is unused.
    The previous ``data[edge_size:-edge_size]`` returned an empty result for
    ``edge_size=0``; slicing to ``len(data) - edge_size`` handles that case.
    """
    return data[edge_size:len(data) - edge_size]
        
    
def plot_all(G_baseline, G_sampled, G_bter, name = "cora", clean=False):
    """Compare degree, clustering and eccentricity distributions of a real graph and two synthetic ones.

    The top row shows normalised histograms (Real / HiGGs / BTER). The bottom row shows
    Real minus each synthetic graph. Metrics are taken on each graph's largest connected
    component (``all_metric`` default). Histogram correlations are printed (see
    ``correlations``). The figure is saved as ``f"{name}_comparison.png"`` and shown.

    ``clean`` is accepted for backward compatibility and is unused.
    """
    fig, axes = plt.subplots(ncols=3, nrows=2, figsize=(12,7), sharex="col")

    # (metric, histogram range, x-axis label) for each column of the figure.
    panels = [
        (nx.degree, (0, 75), "Degree"),
        (nx.clustering, (0, 1), "Clustering"),
        (nx.eccentricity, (1, 50), "Eccentricity"),
    ]
    for col, (metric, hist_range, xlabel) in enumerate(panels):
        # Compute each metric once per graph and reuse it for both rows; previously
        # every metric (including the costly eccentricity) was computed twice.
        values = [all_metric([g], metric) for g in (G_baseline, G_sampled, G_bter)]
        plot_vs(*values, axes[0, col], hist_range=hist_range)
        plot_vs(*values, axes[1, col], hist_range=hist_range, residual=True)
        axes[1, col].set_xlabel(xlabel)

    correlations(G_baseline, G_sampled, algorithm=nx.degree, name=name+" degree HiGGs")
    correlations(G_baseline, G_bter, algorithm=nx.degree, name=name+" degree BTER")

    correlations(G_baseline, G_sampled, algorithm=nx.clustering, name=name+" clustering HiGGs")
    correlations(G_baseline, G_bter, algorithm=nx.clustering, name=name+" clustering BTER")

    axes[0, 0].set_ylabel("Proportion")
    axes[1, 0].set_ylabel("Difference (Original - Synthetic)")
    axes[0, 2].legend(shadow=True)
    plt.tight_layout(h_pad=0)

    plt.savefig(f"{name}_comparison.png")

    plt.show()

    
def plot_all_comms(G_baseline, G_sampled, G_bter, name = "cora", clean=False):
    """Like ``plot_all``, but each argument is a collection of graphs (e.g. community subgraphs).

    Metric values are pooled over all nodes of all graphs in each collection
    (``largest_cc=False``). No correlations are printed. The figure is saved as
    ``f"{name}_comparison.png"`` and shown.

    ``clean`` is accepted for backward compatibility and is unused.
    """
    fig, axes = plt.subplots(ncols=3, nrows=2, figsize=(12,7), sharex="col")

    # (metric, histogram range, x-axis label) for each column of the figure.
    panels = [
        (nx.degree, (0, 75), "Degree"),
        (nx.clustering, (0, 1), "Clustering"),
        (nx.eccentricity, (1, 50), "Eccentricity"),
    ]
    for col, (metric, hist_range, xlabel) in enumerate(panels):
        # Compute each metric once per collection and reuse it for both rows.
        values = [all_metric(graphs, metric, largest_cc=False)
                  for graphs in (G_baseline, G_sampled, G_bter)]
        plot_vs(*values, axes[0, col], hist_range=hist_range)
        plot_vs(*values, axes[1, col], hist_range=hist_range, residual=True)
        axes[1, col].set_xlabel(xlabel)

    axes[0, 0].set_ylabel("Proportion")
    axes[1, 0].set_ylabel("Difference (Original - Synthetic)")
    axes[0, 2].legend(shadow=True)
    plt.tight_layout(h_pad=0)

    plt.savefig(f"{name}_comparison.png")

    plt.show()
    

In [None]:
# G_baseline = load_original("fb_hierarchies")
# vis_big_graph(G_baseline, label="Real_fb")

# with open("../sampling_outputs/fb_sampled_forced_edges.pkl", "rb") as f:
#     G_sampled  = pickle.load(f)
# vis_big_graph(G_sampled, label="HiGGs_fb")
    
# bter = BTER(G_baseline)
# bter.fit(1)
# G_bter = bter.sample()
# G_bter.remove_edges_from(nx.selfloop_edges(G_bter))
# vis_big_graph(G_bter, label="BTER_fb")
    
# # print(G_baseline, G_sampled, G_bter)    
# plot_all(G_baseline, G_sampled, G_bter, name="fb")
# general_graph_metrics(G_baseline, descriptor="Real_FB")
# general_graph_metrics(G_sampled, descriptor="HiGGs_FB")
# general_graph_metrics(G_bter, descriptor="BTER_FB")

In [None]:

    
# G_baseline = load_original("cora")
# with open("../sampling_outputs/cora_sampled_forced_edges.pkl", "rb") as f:
#     G_sampled  = pickle.load(f)
    
# print(G_baseline, G_sampled)    
# ks_test_report(G_baseline, G_sampled)

# G_baseline = load_original("fb_hierarchies")
# with open("../sampling_outputs/fb_sampled_forced_edges.pkl", "rb") as f:
#     G_sampled  = pickle.load(f)
    
# print(G_baseline, G_sampled)    
# ks_test_report(G_baseline, G_sampled, algorithm=nx.degree, name="fb")
    

In [None]:
import numbers
# Function from https://stats.stackexchange.com/questions/403652/two-sample-quantile-quantile-plot-in-python
def _np_quantile(values, q, interpolation):
    """``np.quantile`` with the given estimation rule on both old and new NumPy."""
    try:
        # NumPy >= 1.22 renamed ``interpolation`` to ``method`` and deprecated the old name.
        return np.quantile(values, q, method=interpolation)
    except TypeError:
        # Older NumPy only accepts ``interpolation``.
        return np.quantile(values, q, interpolation=interpolation)


def qqplot(x, y, quantiles=None, interpolation='nearest', ax=None, rug=False, color="black",
           rug_length=0.05, rug_kwargs=None, label = "", **kwargs):
    """Draw a quantile-quantile plot for `x` versus `y`.

    Parameters
    ----------
    x, y : array-like
        One-dimensional numeric arrays.

    ax : matplotlib.axes.Axes, optional
        Axes on which to plot. If not provided, the current axes will be used.

    quantiles : int or array-like, optional
        Quantiles to include in the plot. This can be an array of quantiles, in
        which case only the specified quantiles of `x` and `y` will be plotted.
        If this is an int `n`, then the quantiles will be `n` evenly spaced
        points between 0 and 1. If this is None, then `min(len(x), len(y))`
        evenly spaced quantiles between 0 and 1 will be computed.

    interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'}
        Method used to estimate the quantiles (passed to numpy.quantile as
        ``method``, or as ``interpolation`` on NumPy < 1.22).

    rug : bool, optional
        If True, draw a rug plot representing both samples on the horizontal and
        vertical axes. If False, no rug plot is drawn.

    rug_length : float in [0, 1], optional
        Specifies the length of the rug plot lines as a fraction of the total
        vertical or horizontal length.

    rug_kwargs : dict of keyword arguments
        Keyword arguments to pass to matplotlib.axes.Axes.axvline() and
        matplotlib.axes.Axes.axhline() when drawing rug plots.

    color : color
        Colour of both the scatter markers and the connecting line.

    label : str
        Legend label, attached to the connecting line.

    kwargs : dict of keyword arguments
        Keyword arguments to pass to matplotlib.axes.Axes.scatter() when drawing
        the q-q plot.

    Returns
    -------
    (ax, x_quantiles, y_quantiles)
    """
    # Get current axes if none are provided
    if ax is None:
        ax = plt.gca()

    if quantiles is None:
        quantiles = min(len(x), len(y))

    # Compute quantiles of the two samples
    if isinstance(quantiles, numbers.Integral):
        quantiles = np.linspace(start=0, stop=1, num=int(quantiles))
    else:
        quantiles = np.atleast_1d(np.sort(quantiles))
    x_quantiles = _np_quantile(x, quantiles, interpolation)
    y_quantiles = _np_quantile(y, quantiles, interpolation)

    # Draw the rug plots if requested
    if rug:
        # Default rug plot settings
        rug_x_params = dict(ymin=0, ymax=rug_length, c='gray', alpha=0.5)
        rug_y_params = dict(xmin=0, xmax=rug_length, c='gray', alpha=0.5)

        # Override default setting by any user-specified settings
        if rug_kwargs is not None:
            rug_x_params.update(rug_kwargs)
            rug_y_params.update(rug_kwargs)

        # Draw the rug plots
        for point in x:
            ax.axvline(point, **rug_x_params)
        for point in y:
            ax.axhline(point, **rug_y_params)

    # Draw the q-q plot: markers at each quantile pair, joined by a labelled line.
    ax.scatter(x_quantiles, y_quantiles, s = 10, color=color, marker='x',**kwargs)
    ax.plot(x_quantiles, y_quantiles, alpha=0.5, color=color, label = label)

    return ax, x_quantiles, y_quantiles

In [None]:
def square_limits(ax):
    """Give ``ax`` identical x and y limits that cover both current ranges.

    The four current limits (xmin, xmax, ymin, ymax) are printed first. Returns None.
    """
    limits = [*ax.get_xlim(), *ax.get_ylim()]
    print(limits)
    low, high = min(limits), max(limits)
    ax.set_xlim([low, high])
    ax.set_ylim([low, high])

def qq_plots(G_baseline, G_sampled, G_bter, name = "..."):
    """Q-Q plots of degree, clustering and eccentricity: real graph vs. HiGGs and BTER samples.

    Metrics are taken on each graph's largest connected component (``all_metric``
    default), and 100 quantiles are used per panel. The plotted quantiles are written to
    ``f"{name}.csv"``. The figure is saved as ``f"{name}.png"`` and shown.
    """
    fig, (ax1, ax2, ax3) = plt.subplots(ncols=3, figsize=(12,4))

    columns = {}
    for ax, metric, title in ((ax1, nx.degree, "Degree"),
                              (ax2, nx.clustering, "Clustering"),
                              (ax3, nx.eccentricity, "Eccentricity")):
        # The baseline values are computed once and shared by both comparisons
        # (previously computed twice, which is costly for eccentricity).
        real = all_metric([G_baseline], metric)
        _, real_q, higgs_q = qqplot(real, all_metric([G_sampled], metric), quantiles=100,
                                    ax=ax, label="HiGGs", color="orange")
        _, _, bter_q = qqplot(real, all_metric([G_bter], metric), quantiles=100,
                              ax=ax, label="BTER", color="green")

        columns[f"{title}_quantiles_real"] = real_q
        columns[f"{title}_quantiles_higgs"] = higgs_q
        columns[f"{title}_quantiles_bter"] = bter_q

        ax.set_xlabel('Original Quantiles')
        ax.set_ylabel('Sampled Quantiles')
        ax.set_title(title)

    ax1.set_xscale('log')
    ax1.set_yscale('log')

    ax2.legend(shadow=True)

    pd.DataFrame(columns).to_csv(f"{name}.csv")

    plt.tight_layout()

    plt.savefig(f"{name}.png")

    plt.show()
    
    
def qq_plots_comms(G_baseline, G_sampled, G_bter, name = "..."):
    """Like ``qq_plots``, but each argument is a collection of graphs (e.g. community subgraphs).

    Metric values are pooled over all nodes of all graphs in each collection
    (``largest_cc=False``). The figure is saved as ``f"{name}.png"`` and shown. The
    plotted quantiles are then written to ``f"{name}.csv"``.
    """
    fig, (ax1, ax2, ax3) = plt.subplots(ncols=3, figsize=(12,4))

    columns = {}
    for ax, metric, title in ((ax1, nx.degree, "Degree"),
                              (ax2, nx.clustering, "Clustering"),
                              (ax3, nx.eccentricity, "Eccentricity")):
        # The baseline values are computed once and shared by both comparisons.
        real = all_metric(G_baseline, metric, largest_cc=False)
        _, real_q, higgs_q = qqplot(real, all_metric(G_sampled, metric, largest_cc=False),
                                    quantiles=100, ax=ax, label="HiGGs", color="orange")
        _, _, bter_q = qqplot(real, all_metric(G_bter, metric, largest_cc=False),
                              quantiles=100, ax=ax, label="BTER", color="green")

        columns[f"{title}_quantiles_real"] = real_q
        columns[f"{title}_quantiles_higgs"] = higgs_q
        columns[f"{title}_quantiles_bter"] = bter_q

        ax.set_xlabel('Original Quantiles')
        ax.set_ylabel('Sampled Quantiles')
        ax.set_title(title)

    ax1.set_xscale('log')
    ax1.set_yscale('log')

    ax2.legend(shadow=True)

    plt.tight_layout()

    plt.savefig(f"{name}.png")

    plt.show()

    pd.DataFrame(columns).to_csv(f"{name}.csv")


# G_baseline = load_original("cora")
# with open("../sampling_outputs/cora_sampled_forced_edges.pkl", "rb") as f:
#     G_sampled  = pickle.load(f)

# bter = BTER(G_baseline)
# bter.fit(5)
# G_bter = bter.sample()
# G_bter.remove_edges_from(nx.selfloop_edges(G_bter))
    


In [None]:
# G_baseline = load_original("fb_hierarchies")
# with open("../sampling_outputs/fb_sampled_forced_edges.pkl", "rb") as f:
#     G_sampled  = pickle.load(f)
    
# bter = BTER(G_baseline)
# bter.fit(1)
# G_bter = bter.sample()
# G_bter.remove_edges_from(nx.selfloop_edges(G_bter))

# qq_plots(G_baseline, G_sampled, G_bter, name="Facebook Page-Page")

In [None]:
import sys

# Make the sibling `dgd` checkout importable (needed for `analysis.spectre_utils`).
# Assumes the notebook runs from `<root>/higgs/notepad_code`, so `<root>` is two
# directory levels up. The path is computed without os.chdir, so the working
# directory (and every relative path used in later cells) is never disturbed.
pwd = os.path.dirname(os.path.dirname(os.getcwd()))
dgd_dir = os.path.join(pwd, "dgd")
# Avoid stacking duplicate entries when the cell is re-run.
if dgd_dir not in sys.path:
    sys.path.insert(0, dgd_dir)

In [None]:
from analysis.spectre_utils import FullSampleMetrics, ByClassSampleMetrics, PostProcessSampleMetrics

# print(G_baseline, G_sampled, G_bter)    

# general_graph_metrics(G_baseline, descriptor="Real_Cora")
# general_graph_metrics(G_sampled, descriptor="HiGGs_Cora")
# general_graph_metrics(G_bter, descriptor="BTER_Cora")

In [None]:
# G_baseline = load_original("cora")
# vis_big_graph(G_baseline, label="Real_cora")

# with open("/outputs/2023-03-31/09-32-43/sampling/sampled_graph.pkl", "rb") as f:
#     G_sampled  = pickle.load(f)
# vis_big_graph(G_sampled, label="HiGGs_cora")

# bter = BTER(G_baseline)
# bter.fit(1)
# G_bter = bter.sample()
# G_bter.remove_edges_from(nx.selfloop_edges(G_bter))


# vis_big_graph(G_bter, label="BTER_cora", largest_cc = False)

# CGs = [G_bter.subgraph(c) for c in nx.connected_components(G_bter)]
# CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
# G_bter = CGs[0]


# # plot_all(G_baseline, G_sampled, G_bter)

# qq_plots(G_baseline, G_sampled, G_bter, name=f"CORA_QQ")

# # general_graph_metrics(G_baseline, descriptor="Real_Cora")
# # general_graph_metrics(G_sampled, descriptor="HiGGs_Cora")
# # general_graph_metrics(G_bter, descriptor="BTER_Cora")

# # sampling_metrics = PostProcessSampleMetrics([G_baseline], use_wandb=False)
# # print("\nHiGGs\n---------------------------------------------------------------------------")
# # sampling_metrics([G_sampled], "MMD scores HiGGs CORA", 0, 0)
# # print("\nBTER\n---------------------------------------------------------------------------")
# # sampling_metrics([G_bter], "MMD scores BTER CORA", 0, 0)


In [None]:

# G_baseline = load_original("fb_hierarchies")
# vis_big_graph(G_baseline, label="Real_fb")
# with open("../sampling_outputs/fb_sampled_forced_edges.pkl", "rb") as f:
#     G_sampled  = pickle.load(f)
# vis_big_graph(G_sampled, label="HiGGs_fb")
    
# bter = BTER(G_baseline)
# bter.fit(1)
# G_bter = bter.sample()
# G_bter.remove_edges_from(nx.selfloop_edges(G_bter))



# vis_big_graph(G_bter, label="BTER_fb", largest_cc = False)

# CGs = [G_bter.subgraph(c) for c in nx.connected_components(G_bter)]
# CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
# G_bter = CGs[0]

# # plot_all(G_baseline, G_sampled, G_bter, name="fb")

# qq_plots(G_baseline, G_sampled, G_bter, name=f"FB_QQ")


# # general_graph_metrics(G_baseline, descriptor="Real_fb")
# # general_graph_metrics(G_sampled, descriptor="HiGGs_fb")
# # general_graph_metrics(G_bter, descriptor="BTER_fb")

# # sampling_metrics = PostProcessSampleMetrics([G_baseline], use_wandb=False)
# # print("\nHiGGs\n---------------------------------------------------------------------------")
# # sampling_metrics([G_sampled], "MMD scores HiGGs FB", 0, 0)
# # print("\nBTER\n---------------------------------------------------------------------------")
# # sampling_metrics([G_bter], "MMD scores BTER FB", 0, 0)


In [None]:
def vis_communities(graphs, largest_cc=False, label=""):
    """Draw several graphs together in one figure and save it as ``f"{label}.png"``.

    Each input graph (e.g. a Louvain community subgraph) is relabelled to a
    disjoint range of integers and composed into a single graph. That graph is
    laid out with Graphviz ``sfdp`` and drawn on a 6x6 axis: edges always, and
    nodes coloured by their ``"target"`` attribute when every node has one.

    Args:
        graphs: iterable of networkx graphs to combine.
        largest_cc: if True, draw only the largest connected component of the
            combined graph.
        label: file stem for the saved figure (written at dpi=600).
    """
    G = nx.Graph()
    for g in graphs:
        # Offset labels by the current node count so the pieces stay disjoint.
        g = nx.convert_node_labels_to_integers(g, first_label=G.order())
        G = nx.compose(G, g)

    # Guard against an empty composition, where there is no component to take.
    if largest_cc and G.number_of_nodes() > 0:
        G = G.subgraph(max(nx.connected_components(G), key=len))

    pos = nx.drawing.nx_agraph.graphviz_layout(G, prog="sfdp", args='-Gsmoothing')

    fig, ax = plt.subplots(ncols=1, figsize=(6, 6))

    nx.draw_networkx_edges(G, node_size=2, pos=pos, alpha=0.5, ax=ax)

    # Colour nodes by class label only when every node carries one. Otherwise
    # draw edges only (the previous best-effort behaviour), without hiding
    # unrelated errors behind a bare `except`.
    node_data = list(G.nodes(data=True))
    if all("target" in data for _, data in node_data):
        nx.draw_networkx_nodes(G, node_size=1, pos=pos, ax=ax,
                               node_color=[data["target"] for _, data in node_data])

    # ax.set_title(label)
    ax.axis('off')

    plt.tight_layout(h_pad=0, w_pad=0, pad=0)
    plt.savefig(f"{label}.png", dpi=600)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

In [None]:
# Cora: compare the real graph, the HiGGs sample and a BTER baseline -- whole
# graphs, the first five Louvain communities of each, and QQ plots.

# Fixed Louvain seed: louvain_communities is randomised, so without a seed the
# partitions (and therefore which five communities are drawn) change every run.
LOUVAIN_SEED = 42

# load_original("cora") already keeps only the largest connected component.
G_baseline = load_original("cora")
communities_baseline = comm.louvain_communities(G_baseline, resolution=1, seed=LOUVAIN_SEED)
communities_baseline = [G_baseline.subgraph(community) for community in communities_baseline]

vis_big_graph(G_baseline, label="Real_cora")
vis_communities(communities_baseline[:5], label="Cora_real_comms")
#================================================================================================================================================

# NOTE(review): pickle.load can execute arbitrary code; only load sampler
# outputs produced locally.
with open("../sampling_outputs/cora-clustering/sampling/sampled_graph.pkl", "rb") as f:
    G_sampled = pickle.load(f)

communities_sampled = comm.louvain_communities(G_sampled, resolution=1, seed=LOUVAIN_SEED)
communities_sampled = [G_sampled.subgraph(community) for community in communities_sampled]
vis_big_graph(G_sampled, label="HiGGs_cora")
vis_communities(communities_sampled[:5], label="Cora_HiGGs_comms")

#================================================================================================================================================

# TODO(review): BTER sampling is not seeded here; seed its RNG if BTER exposes one.
bter = BTER(G_baseline)
bter.fit(1)
G_bter = bter.sample()
# Materialise the self-loop list first so the graph is not mutated while the
# selfloop_edges generator is still iterating over it.
G_bter.remove_edges_from(list(nx.selfloop_edges(G_bter)))
vis_big_graph(G_bter, label="BTER_cora", largest_cc=False)

# Keep only the largest connected component, matching the real Cora graph.
CGs = [G_bter.subgraph(c) for c in nx.connected_components(G_bter)]
CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
G_bter = CGs[0]

communities_bter = comm.louvain_communities(G_bter, resolution=1, seed=LOUVAIN_SEED)
communities_bter = [G_bter.subgraph(community) for community in communities_bter]


vis_communities(communities_bter[:5], label="Cora_BTER_comms")

#================================================================================================================================================

# plot_all(G_baseline, G_sampled, G_bter, name="Cora")

qq_plots(G_baseline, G_sampled, G_bter, name="CORA_QQ")

In [None]:
#================================================================================================================================================

# general_graph_metrics(G_baseline, descriptor="Real_Cora")
# general_graph_metrics(G_sampled, descriptor="HiGGs_Cora")
# general_graph_metrics(G_bter, descriptor="BTER_Cora")

# sampling_metrics = PostProcessSampleMetrics([G_baseline], use_wandb=False)
# print("\nHiGGs\n---------------------------------------------------------------------------")
# sampling_metrics([G_sampled], "MMD scores HiGGs CORA", 0, 0)
# print("\nBTER\n---------------------------------------------------------------------------")
# sampling_metrics([G_bter], "MMD scores BTER CORA", 0, 0)

# plot_all_comms(communities_baseline, communities_sampled, communities_bter, name="Cora_Comm")

# QQ comparison of the Louvain community subgraphs built in the Cora cell:
# real vs. HiGGs vs. BTER.
qq_plots_comms(
    communities_baseline,
    communities_sampled,
    communities_bter,
    name="CORA_QQ_Comm",
)

# sampling_metrics = PostProcessSampleMetrics(communities_baseline, use_wandb=False)
# # print("\nVS Real\n---------------------------------------------------------------------------")
# # sampling_metrics([G_baseline.subgraph(community) for community in comm.louvain_communities(G_baseline, resolution=1)], "MMD scores VS Real CORA", 0, 0)
# print("\nHiGGs\n---------------------------------------------------------------------------")
# sampling_metrics(communities_sampled, "MMD scores HiGGs CORA", 0, 0)
# print("\nBTER\n---------------------------------------------------------------------------")
# sampling_metrics(communities_bter, "MMD scores BTER CORA", 0, 0)

In [None]:
# Facebook page-page: compare the real graph, the HiGGs sample and a BTER
# baseline -- whole graphs, the first five Louvain communities of each, and QQ plots.

# Fixed Louvain seed: louvain_communities is randomised, so without a seed the
# partitions (and therefore which five communities are drawn) change every run.
LOUVAIN_SEED = 42

G_baseline = load_original("fb_hierarchies")

communities_baseline = comm.louvain_communities(G_baseline, resolution=1, seed=LOUVAIN_SEED)
communities_baseline = [G_baseline.subgraph(community) for community in communities_baseline]
vis_big_graph(G_baseline, label="Real_fb")
vis_communities(communities_baseline[:5], label="FB_real_comms")
#================================================================================================================================================

# NOTE(review): pickle.load can execute arbitrary code; only load sampler
# outputs produced locally.
with open("../sampling_outputs/fb-clustering/sampling/fb_sampled_forced_edges.pkl", "rb") as f:
    G_sampled = pickle.load(f)

communities_sampled = comm.louvain_communities(G_sampled, resolution=1, seed=LOUVAIN_SEED)
communities_sampled = [G_sampled.subgraph(community) for community in communities_sampled]

vis_big_graph(G_sampled, label="HiGGs_fb")
vis_communities(communities_sampled[:5], label="FB_HiGGs_comms")
#================================================================================================================================================

# TODO(review): BTER sampling is not seeded here; seed its RNG if BTER exposes one.
bter = BTER(G_baseline)
bter.fit(1)
G_bter = bter.sample()
# Materialise the self-loop list first so the graph is not mutated while the
# selfloop_edges generator is still iterating over it.
G_bter.remove_edges_from(list(nx.selfloop_edges(G_bter)))
vis_big_graph(G_bter, label="BTER_fb", largest_cc=False)

# Keep only the largest connected component of the BTER sample.
CGs = [G_bter.subgraph(c) for c in nx.connected_components(G_bter)]
CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
G_bter = CGs[0]

communities_bter = comm.louvain_communities(G_bter, resolution=1, seed=LOUVAIN_SEED)
communities_bter = [G_bter.subgraph(community) for community in communities_bter]


vis_communities(communities_bter[:5], label="FB_bter_comms")

#================================================================================================================================================

# plot_all(G_baseline, G_sampled, G_bter, name="fb")

qq_plots(G_baseline, G_sampled, G_bter, name="FB_QQ")

#================================================================================================================================================

In [None]:
# general_graph_metrics(G_baseline, descriptor="Real_fb")
# general_graph_metrics(G_sampled, descriptor="HiGGs_fb")
# general_graph_metrics(G_bter, descriptor="BTER_fb")

# sampling_metrics = PostProcessSampleMetrics([G_baseline], use_wandb=False)
# print("\nHiGGs\n---------------------------------------------------------------------------")
# sampling_metrics([G_sampled], "MMD scores HiGGs FB", 0, 0)
# print("\nBTER\n---------------------------------------------------------------------------")
# sampling_metrics([G_bter], "MMD scores BTER FB", 0, 0)




# CGs = [G_bter.subgraph(c) for c in nx.connected_components(G_bter)]
# CGs = sorted(CGs, key=lambda x: x.number_of_nodes(), reverse=True)
# G_bter = CGs[0]

# plot_all_comms(communities_baseline, communities_sampled, communities_bter, name="FB_Comm")

# QQ comparison of the Louvain community subgraphs built in the Facebook cell:
# real vs. HiGGs vs. BTER.
qq_plots_comms(
    communities_baseline,
    communities_sampled,
    communities_bter,
    name="FB_QQ_Comm",
)


# sampling_metrics = PostProcessSampleMetrics(communities_baseline, use_wandb=False)
# # print("\nVS Real\n---------------------------------------------------------------------------")
# # sampling_metrics([G_baseline.subgraph(community) for community in comm.louvain_communities(G_baseline, resolution=1)], "MMD scores VS Real FB", 0, 0)
# print("\nHiGGs\n---------------------------------------------------------------------------")
# sampling_metrics(communities_sampled, "MMD scores HiGGs FB", 0, 0)
# print("\nBTER\n---------------------------------------------------------------------------")
# sampling_metrics(communities_bter, "MMD scores BTER FB", 0, 0)