In [71]:
from itertools import product

import graspy as gp
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from hyppo.ksample import KSample
from joblib import Parallel, delayed
from scipy.stats import foldnorm, truncnorm

from src import estimate_embeddings
%matplotlib inline

In [4]:
def generate_2_block_dirichlet(alpha_1, alpha_2, block_size):
    """Draw latent positions for a two-block model from Dirichlet distributions.

    Parameters
    ----------
    alpha_1, alpha_2 : sequence of float
        Dirichlet concentration parameters for the first and second block.
    block_size : int
        Number of rows drawn for each block.

    Returns
    -------
    np.ndarray
        The first block's draws stacked on top of the second block's draws.
    """
    # The first block is sampled before the second, so results under a fixed
    # global NumPy seed depend on this ordering.
    blocks = [np.random.dirichlet(alpha, block_size) for alpha in (alpha_1, alpha_2)]
    return np.vstack(blocks)

def generate_graphs(latent_positions, num_graphs):
    """Sample ``num_graphs`` graphs that share one set of latent positions.

    The edge-probability matrix is computed once with
    ``gp.simulations.p_from_latent`` and every graph comes from a separate
    ``gp.simulations.sample_edges`` call on it; both use ``loops=False``.

    Parameters
    ----------
    latent_positions : array-like
        Latent positions handed to ``p_from_latent``.
    num_graphs : int
        How many graphs to sample.

    Returns
    -------
    np.ndarray
        The sampled graphs collected into a single array, one per entry
        along the first axis.
    """
    edge_probs = gp.simulations.p_from_latent(latent_positions, loops=False)

    samples = []
    for _ in range(num_graphs):
        samples.append(gp.simulations.sample_edges(edge_probs, loops=False))

    return np.array(samples)

def generate_dirichlet_graphs(alpha_1, alpha_2, block_size, num_graphs, resample_latent=True):
    """Generate two populations of graphs from two-block Dirichlet latent positions.

    In the first population both blocks use ``alpha_1``; in the second
    population the first block uses ``alpha_1`` and the second ``alpha_2``.

    Parameters
    ----------
    alpha_1, alpha_2 : sequence of float
        Dirichlet concentration parameters.
    block_size : int
        Number of nodes per block.
    num_graphs : int
        Number of graphs per population.
    resample_latent : bool, default True
        If true, fresh latent positions are drawn for every graph; otherwise
        one set of latent positions per population is shared by all of its
        graphs.

    Returns
    -------
    tuple of np.ndarray
        ``(X, Y)``: the graphs of the first and second population.
    """

    def _sample_populations(count):
        # Draw order (latent 1, latent 2, graphs 1, graphs 2) is kept fixed so
        # that seeded runs consume the random stream identically.
        latent_first = generate_2_block_dirichlet(alpha_1, alpha_1, block_size)
        latent_second = generate_2_block_dirichlet(alpha_1, alpha_2, block_size)
        graphs_first = generate_graphs(latent_first, count)
        graphs_second = generate_graphs(latent_second, count)
        return graphs_first, graphs_second

    if not resample_latent:
        return _sample_populations(num_graphs)

    # One graph per freshly drawn pair of latent-position sets.
    draws = [_sample_populations(1) for _ in range(num_graphs)]
    X = np.vstack([first for first, _ in draws])
    Y = np.vstack([second for _, second in draws])
    return X, Y

In [69]:
def experiment(m, n, effect_size, reps=25):
    """Run the node-wise two-sample experiment for one ``(m, n, effect_size)`` setting.

    For each repetition, two populations of graphs are simulated, embedded
    with each method ('omni' and 'mase'), and a ``KSample("Dcorr")`` test is
    run per node on the two populations' embeddings. P-values are averaged
    over repetitions and then over the nodes of each block.

    Parameters
    ----------
    m : int
        Total number of graphs; each population gets ``m // 2``.
    n : int
        Total number of nodes; each block gets ``n // 2``.
    effect_size : float
        Added to the second Dirichlet parameter of the second block in the
        second population (``alpha_2 = [1, 1 + effect_size]``).
    reps : int, default 25
        Number of Monte Carlo repetitions.

    Returns
    -------
    list
        ``[m, n, effect_size, omni_block1, mase_block1, omni_block2,
        mase_block2]`` where the last four entries are mean p-values.
    """
    # Grid values frequently arrive as floats (e.g. from np.linspace); array
    # shapes, range() and slicing all require true integers.
    m, n = int(m), int(n)
    block_size = n // 2
    m_per_pop = m // 2
    # The simulated graphs have two blocks of block_size nodes, which is one
    # fewer than n when n is odd; iterate over the nodes that actually exist.
    n_nodes = 2 * block_size

    alpha_1 = [1, 1]
    alpha_2 = [1, 1 + effect_size]

    methods = ['omni', 'mase']
    pvals = np.zeros((reps, len(methods), n_nodes))

    for i in range(reps):
        X, Y = generate_dirichlet_graphs(alpha_1, alpha_2, block_size, m_per_pop)

        for j, method in enumerate(methods):
            # Assumes the embeddings are stacked X-first along axis 0 with
            # nodes on axis 1, as the indexing below implies — TODO confirm
            # against src.estimate_embeddings.
            embeddings = estimate_embeddings(X, Y, method, 2, sample_space=True)
            Xhat = embeddings[:m_per_pop]
            Yhat = embeddings[m_per_pop:]

            for node in range(n_nodes):
                test = KSample("Dcorr").test(Xhat[:, node, :], Yhat[:, node, :], auto=True)
                # Index 1 of the test result is stored as the p-value.
                pvals[i, j, node] = test[1]

    # Average over repetitions, then over the nodes of each block.
    pvals = pvals.mean(axis=0)
    avg_pval_1 = pvals[:, :block_size].mean(axis=1)
    avg_pval_2 = pvals[:, block_size:].mean(axis=1)

    return [m, n, effect_size, *avg_pval_1, *avg_pval_2]

In [84]:
# Integer grids 20, 40, ..., 200. m and n are used as counts, array shapes and
# range() bounds inside experiment(), so they must not be floats (which is
# what np.linspace would produce).
ms = np.arange(20, 201, 20)
ns = np.arange(20, 201, 20)
effect_sizes = [0, 0.2, 10]

args = [dict(m=m, n=n, effect_size=effect_size) for m, n, effect_size in product(ms, ns, effect_sizes)]
# Interleave the grid from both ends (last, first, second-to-last, second, ...)
# so large and small settings alternate; truncating to len(args) keeps each
# setting exactly once.
args = sum(zip(reversed(args), args), ())[: len(args)]

In [None]:
# Run every grid setting in parallel on all available cores (n_jobs=-1).
# Keyword arguments are used deliberately: assuming this is joblib's Parallel,
# its second positional parameter is `backend`, so the original `Parallel(-1, 4)`
# passed 4 as a backend rather than as the verbosity level.
res = Parallel(n_jobs=-1, verbose=4)(delayed(experiment)(**arg) for arg in args)