In [None]:
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import xgi

from src import GenerativeModels

In [None]:
def get_position(fname, H):
    """Load a cached 2-D node layout for ``H``, or compute and cache a new one.

    Parameters
    ----------
    fname : str
        Path of the JSON cache file mapping node IDs to ``[x, y]`` lists.
    H : xgi.Hypergraph
        Hypergraph whose nodes are laid out. Its node IDs are used to restore
        the original key types of a cached layout.

    Returns
    -------
    dict
        Maps each node ID to a length-2 ``numpy.ndarray`` position.

    Notes
    -----
    If the cache is missing, unreadable, or not valid JSON, a new layout is
    computed with ``xgi.pairwise_spring_layout``, normalised by
    ``transform_pos``, and written to ``fname``. Missing parent directories
    are created. A readable cache is returned as-is, even if ``H`` has changed
    since it was written.
    """
    try:
        with open(fname, "r") as file:
            pos_stored = json.load(file)
    except (OSError, ValueError):
        # OSError: cache missing/unreadable. ValueError also covers
        # json.JSONDecodeError (a corrupt cache). Any other error propagates.
        pos = transform_pos(xgi.pairwise_spring_layout(H))

        directory = os.path.dirname(fname)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(fname, "w") as output_file:
            json.dump({i: p.tolist() for i, p in pos.items()}, output_file)
        return pos

    # JSON object keys are always strings, so integer node IDs come back as
    # "0", "1", ... . Map them back to the hypergraph's own node IDs so that
    # lookups such as xgi.draw(H, pos) find every node. Unknown keys are kept
    # unchanged.
    node_by_label = {str(node): node for node in H.nodes}
    return {node_by_label.get(i, i): np.array(p) for i, p in pos_stored.items()}

def transform_pos(pos):
    """Rotate a 2-D layout onto its principal axes and stretch it to [-1, 1]^2.

    Parameters
    ----------
    pos : dict
        Maps node IDs to 2-D coordinates (anything ``np.array`` can stack
        into an ``(n_nodes, 2)`` array).

    Returns
    -------
    dict
        Same keys, in the same order, each mapped to a length-2
        ``numpy.ndarray``. Each axis is rescaled independently so that its
        minimum maps to -1 and its maximum to 1.
    """
    from sklearn.decomposition import PCA

    nodes = list(pos.keys())
    coords = np.array(list(pos.values()))

    # Express the points in the basis of the two principal axes.
    components = PCA(n_components=2).fit(coords).components_
    rotated = coords @ np.linalg.inv(components)

    lows = np.min(rotated, axis=0)
    highs = np.max(rotated, axis=0)

    # rescale so that its in [-1, 1] x [-1, 1]
    xs = np.interp(rotated[:, 0], [lows[0], highs[0]], [-1, 1])
    ys = np.interp(rotated[:, 1], [lows[1], highs[1]], [-1, 1])

    return {node: np.array([x, y]) for node, x, y in zip(nodes, xs, ys)}

def get_hypergraph(fname, epsilon, n=100, m=3, k=2):
    """Load a cached hypergraph, or sample a connected one and cache it.

    Parameters
    ----------
    fname : str
        Path of the xgi JSON file used as a cache.
    epsilon : float
        Forwarded to ``GenerativeModels.uniform_planted_partition_hypergraph``.
        In this notebook it is the plotted ``epsilon_3`` value.
    n, m, k : int, optional
        Forwarded unchanged to ``uniform_planted_partition_hypergraph``.
        They default to the previously hard-coded 100, 3 and 2.
        NOTE(review): presumably the number of nodes, the edge size, and the
        number of communities; confirm against ``src.GenerativeModels``.

    Returns
    -------
    xgi.Hypergraph
        The cached hypergraph if ``fname`` could be read. Otherwise, a newly
        sampled hypergraph that ``xgi.is_connected`` reports as connected.

    Notes
    -----
    When the cache is read, ``epsilon``, ``n``, ``m`` and ``k`` are ignored.
    Sampling repeats until a connected hypergraph is produced, so parameters
    that rarely yield connected samples can make this loop run for a long
    time. Missing parent directories of ``fname`` are created before writing.
    """
    try:
        return xgi.read_json(fname)
    except (OSError, ValueError):
        # Cache miss: OSError means missing/unreadable, and ValueError also
        # covers json.JSONDecodeError (corrupt file). Fall through and
        # regenerate. Any other error propagates instead of being masked.
        pass

    while True:
        edgelist = GenerativeModels.uniform_planted_partition_hypergraph(n, m, k, epsilon)
        # Drop edges that contain a repeated node (e.g. (3, 3, 7)).
        cleaned_edgelist = [edge for edge in edgelist if len(edge) == len(set(edge))]
        H = xgi.Hypergraph(cleaned_edgelist)
        if xgi.is_connected(H):
            break

    directory = os.path.dirname(fname)
    if directory:
        os.makedirs(directory, exist_ok=True)
    xgi.write_json(H, fname)
    return H

In [None]:
# Three hypergraphs, one per epsilon_3 value, each cached as JSON under Data/vis.
H1, H2, H3 = [
    get_hypergraph(path, eps)
    for path, eps in (
        ("Data/vis/vis1.json", 0),
        ("Data/vis/vis2.json", 0.75),
        ("Data/vis/vis3.json", 0.95),
    )
]

# Matching node layouts, also cached so the figure is stable across re-runs.
pos1, pos2, pos3 = [
    get_position(path, hypergraph)
    for path, hypergraph in zip(
        ("Data/vis/pos1.json", "Data/vis/pos2.json", "Data/vis/pos3.json"),
        (H1, H2, H3),
    )
]

In [None]:
# One panel per epsilon_3 value, stacked vertically, saved as PNG and PDF.
panels = [
    (H1, pos1, r"$\epsilon_3 = 0$"),
    (H2, pos2, r"$\epsilon_3 = 0.75$"),
    (H3, pos3, r"$\epsilon_3 = 0.95$"),
]

fig, axes = plt.subplots(3, 1, figsize=(6, 12))
for ax, (hypergraph, layout, title) in zip(axes, panels):
    ax.set_title(title, fontsize=14)
    xgi.draw(hypergraph, layout, ax=ax, node_size=10)
fig.tight_layout()

# savefig does not create missing directories; make sure the target exists.
os.makedirs("Figures/Fig1", exist_ok=True)
fig.savefig("Figures/Fig1/community_structure_visualization.png", dpi=1000)
fig.savefig("Figures/Fig1/community_structure_visualization.pdf", dpi=1000)
plt.show()