In [2]:
import numpy as np

from graspologic.inference import LatentPositionTest, lpt_function
from graspologic.embed import AdjacencySpectralEmbed
from graspologic.simulations import sbm, rdpg
from graspologic.utils import symmetrize
from graspologic.plot import heatmap, pairplot
from matplotlib import pyplot as plt
import pandas as pd
import seaborn as sns

%matplotlib inline

In [3]:
# Embedding dimension handed to both latent-position-test variants (ASE).
n_components = 4

# Block connection probabilities, written as an upper triangle and mirrored by
# `symmetrize` into a full symmetric 4x4 matrix.
# NOTE(review): graspologic's `symmetrize` defaults to method="avg", which
# presumably averages the upper and lower triangles -- with zeros below the
# diagonal that would halve every off-diagonal entry (0.11 -> 0.055, ...).
# Confirm whether method="triu" was intended.
P = symmetrize(
    np.array(
        [
            [0.9, 0.11, 0.13, 0.2],
            [0, 0.7, 0.1, 0.1],
            [0, 0, 0.8, 0.1],
            [0, 0, 0, 0.85],
        ]
    )
)

# Result accumulators, one entry appended per simulated graph pair.
n_verts, p_vals_function, p_vals_class, p_vals_diff = [], [], [], []

In [None]:
# Compare the class-based (LatentPositionTest) and function-based (lpt_function)
# latent position tests on pairs of SBM graphs drawn from the same block
# matrix P, sweeping the number of vertices per block.

N_BOOTSTRAPS = 500  # bootstrap replicates requested from each test
N_TRIALS = 1        # graph pairs sampled per block size
GRAPH_SEED = 200    # seeds the graph sampling so a fresh run is repeatable
TEST_SEED = 200     # reseeded before each implementation (see below)

# Reset the accumulators here so that re-running this cell does not append a
# second copy of every result to the lists (and thus to df / the CSV).
n_verts = []
p_vals_function = []
p_vals_class = []
p_vals_diff = []

# Seed before the first graphs are drawn; assumes `sbm` samples from NumPy's
# global generator, as the in-loop reseeding already presumes for the tests.
np.random.seed(GRAPH_SEED)

for n in range(50, 201, 10):
    for trial in range(N_TRIALS):
        # One block of n vertices per row of P.
        A1 = sbm([n] * len(P), P)
        A2 = sbm([n] * len(P), P)

        # Reset the global generator to the same state before each
        # implementation so a p-value difference cannot come from a
        # different starting random state.
        np.random.seed(TEST_SEED)
        lpt_class = LatentPositionTest(n_bootstraps=N_BOOTSTRAPS, n_components=n_components)
        lpt_class.fit(A1, A2)
        p_val_class = lpt_class.p_value_

        np.random.seed(TEST_SEED)
        p_val_func, _, _ = lpt_function(A1, A2, n_bootstraps=N_BOOTSTRAPS, n_components=n_components)

        p_vals_function.append(p_val_func)
        p_vals_class.append(p_val_class)
        p_vals_diff.append(p_val_func - p_val_class)
        n_verts.append(n)

        # Progress indicator: the class-based p-value for this graph pair.
        print(p_val_class)

# Side-by-side p-values, indexed by vertices per block, persisted to disk.
p_vals_dict = {"p-values class": p_vals_class, "p-values function": p_vals_function}
df = pd.DataFrame(data=p_vals_dict, index=n_verts)
df.to_csv('p_values_lpt.csv')

# Long-form frame of (function - class) differences for plotting.
p_vals_diff_dict = {"p_value diff": p_vals_diff, "n_verts": n_verts}
df_diff = pd.DataFrame(data=p_vals_diff_dict)

0.6047904191616766
0.02594810379241517
0.16367265469061876
0.4530938123752495
0.26147704590818366
0.4231536926147705
0.8323353293413174
0.624750499001996
0.17365269461077845
0.8363273453093812
0.7624750499001997


In [None]:
# Strip plot of the per-graph-pair p-value difference (function - class)
# against the number of vertices per block, with a reference line at zero.
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
sns.stripplot(
    data=df_diff,
    x="n_verts",
    y="p_value diff",
    jitter=0.5,
    alpha=0.5,
    size=5,
    palette="deep",
    ax=ax,  # draw on the axes created above rather than pyplot's current axes
)
# Dashed zero line: points on it mean both implementations returned the same p-value.
ax.axhline(0, color="grey", linestyle="--", alpha=1)
ax.set(
    title="LPT class vs function",
    xlabel="vertices per block (n)",
    ylabel="p-value difference (function - class)",
);  # trailing semicolon keeps the return value out of the cell output