In [None]:
import numpy as np
# Seed NumPy's legacy global RNG once, before any simulation runs.
# NOTE(review): presumably the graspologic samplers used below draw from this
# global generator, which is what would make a fresh run repeatable — confirm.
np.random.seed(1234556)

from graspologic.inference import LatentPositionTest, lpt_function
from graspologic.embed import AdjacencySpectralEmbed
from graspologic.simulations import sbm, rdpg
from graspologic.utils import symmetrize
from graspologic.plot import heatmap, pairplot
from matplotlib import pyplot as plt
import pandas as pd

%matplotlib inline

In [None]:
# Embedding dimension handed to ASE and to both latent position tests.
n_components = 4

# Block connection probabilities for a 4-block SBM, written as an upper
# triangle (diagonal = within-block, off-diagonal = between-block).
block_probs_upper = np.array(
    [
        [0.9, 0.11, 0.13, 0.2],
        [0.0, 0.7, 0.1, 0.1],
        [0.0, 0.0, 0.8, 0.1],
        [0.0, 0.0, 0.0, 0.85],
    ]
)

# NOTE(review): symmetrize is called with its default method. Confirm that the
# default mirrors the upper triangle rather than averaging it with the zero
# lower triangle, which would halve the off-diagonal probabilities.
P = symmetrize(block_probs_upper)

# Accumulators filled by the simulation loop: one entry per graph size.
p_vals_class, p_vals_function, p_val_diff, n_verts_list = [], [], [], []

In [None]:
# Compare the class-based and function-based latent position tests.
#
# For each block size n, repeatedly: sample an SBM graph from P, embed it with
# ASE, draw two RDPG graphs from that single embedding, and run both test
# entry points on the pair. The p-values are averaged over the repetitions.
# Reads `P` and `n_components` from the configuration cell.

N_MONTE_CARLO = 100  # graph pairs simulated (and averaged over) per graph size
N_BOOTSTRAPS = 150  # bootstrap replicates requested from each test

# Rebuild the accumulators here so that re-running this cell starts clean
# instead of appending a second copy of every result (which would leave the
# lists out of step with one another and with the saved CSV).
n_verts_list = []
p_vals_class = []
p_vals_function = []
p_val_diff = []

for n in range(50, 201, 10):
    n_verts_list.append(n)
    p_val_class = 0
    p_val_function = 0
    for _trial in range(N_MONTE_CARLO):
        # Four blocks of n vertices each.
        A = sbm([n] * 4, P)
        X = AdjacencySpectralEmbed(n_components=n_components).fit_transform(A)

        # Both graphs are sampled from the same latent positions X.
        A1 = rdpg(X, loops=False, rescale=False, directed=False)
        A2 = rdpg(X, loops=False, rescale=False, directed=False)

        lpt_class = LatentPositionTest(
            n_bootstraps=N_BOOTSTRAPS, n_components=n_components
        )
        lpt_class.fit(A1, A2)
        p_val_class += lpt_class.p_value_

        # Only the first of the three returned values is used.
        p_val, _unused_a, _unused_b = lpt_function(
            A1, A2, n_bootstraps=N_BOOTSTRAPS, n_components=n_components
        )
        p_val_function += p_val

    # Turn the running sums into means; the divisor is tied to the trial count.
    p_val_class /= N_MONTE_CARLO
    p_val_function /= N_MONTE_CARLO
    print("c: {}".format(p_val_class))
    print("f: {}".format(p_val_function))
    p_vals_class.append(p_val_class)
    p_vals_function.append(p_val_function)
    p_val_diff.append(p_val_class - p_val_function)

# Persist the averaged p-values, one row per graph size (in loop order).
p_vals_dict = {"p-values class": p_vals_class, "p-values function": p_vals_function}
df = pd.DataFrame(p_vals_dict)
df.to_csv('p_values_lpt.csv')

In [None]:
# Sanity check: the difference series and the x-axis values should have the
# same number of entries. Prints the two lengths, one per line.
for series in (p_val_diff, n_verts_list):
    print(len(series))

In [None]:
# Show the graph sizes and the matching class-minus-function p-value
# differences, each list on its own line.
print(n_verts_list, p_val_diff, sep="\n")

In [None]:
# Plot the mean p-value difference (class minus function) against graph size.
# plt.gca() returns the current Axes, i.e. the same Axes the pyplot-level
# calls would draw on.
ax = plt.gca()
ax.plot(n_verts_list, p_val_diff)
ax.set_xlabel("n_verts")
ax.set_ylabel("p-value difference")
ax.set_title("lpt class vs function")
plt.show()