In [1]:
%load_ext nb_black

<IPython.core.display.Javascript object>

In [2]:
import warnings
from functools import partial
from itertools import product

import numpy as np
import pandas as pd
import seaborn as sns
from joblib import Parallel, delayed
from hyppo.ksample import Hotelling, KSample

from src import generate_binary_sbms



<IPython.core.display.Javascript object>

In [6]:
def run_experiment(
    m, block_1, block_2, p, delta, reps, tests, alpha=0.05,
):
    """Estimate per-node rejection rates of two-sample tests on paired SBM draws.

    For each of ``reps`` repetitions a fresh pair of graph populations is drawn
    with ``generate_binary_sbms`` and every test in ``tests`` is applied to the
    edge vectors of two nodes: the first node (index 0) and the last node
    (index ``block_1 + block_2 - 1``).  A rejection is counted when the
    p-value is at most ``alpha / (block_1 + block_2)``.

    Parameters
    ----------
    m, block_1, block_2, p, delta
        Forwarded unchanged to ``generate_binary_sbms``.
        ``block_1 + block_2`` is also used as the total number of nodes.
    reps : int
        Number of independent repetitions.
    tests : sequence
        Objects exposing ``test(x, y, reps=...)`` whose return value has the
        p-value at index 1.
    alpha : float, default 0.05
        Significance level before division by the number of nodes.

    Returns
    -------
    list
        ``[m, p, delta, *powers]`` where ``powers`` is the rejection-rate
        array of shape ``(2, len(tests))`` flattened in node-major order.
    """
    total_n = block_1 + block_2

    # Only the relevant nodes are tested: the first and the last one.
    # (For the 5 + 15 block sizes used in this notebook these are 0 and 19.)
    node_ids = (0, total_n - 1)

    # Default every entry to 1 ("no evidence") so failed tests never reject.
    pvals = np.ones((reps, len(node_ids), len(tests)))

    for i in range(reps):
        X, Y, _ = generate_binary_sbms(m, block_1, block_2, p, delta)
        for idx, j in enumerate(node_ids):
            # Edge vector of node j in every sampled graph, with the self-loop
            # column removed.  Assumes X and Y are indexed (sample, node, node).
            # Independent of the test, so computed once per node.
            X_nodes = np.delete(X[:, j, :], j, axis=1)
            Y_nodes = np.delete(Y[:, j, :], j, axis=1)
            for k, test in enumerate(tests):
                # NOTE(review): confirm every entry of `tests` accepts the
                # `reps` keyword; if one does not, the TypeError is caught
                # below and that test is scored as p = 1 on every repetition.
                try:
                    pval = test.test(X_nodes, Y_nodes, reps=500)[1]
                    if not np.isnan(pval):
                        pvals[i, idx, k] = pval
                except Exception as err:
                    # Best effort: keep the run alive, but make the failure
                    # visible instead of silently recording p = 1.
                    warnings.warn(
                        f"{type(test).__name__} failed on node {j} "
                        f"(rep {i}): {err!r}; treating p-value as 1",
                        RuntimeWarning,
                        stacklevel=2,
                    )

    powers = np.mean(pvals <= (alpha / total_n), axis=0)
    return [m, p, delta, *powers.reshape(-1)]

<IPython.core.display.Javascript object>

In [5]:
# Experiment Parameters
# Constants shared by every run
block_1 = 5
block_2 = 15
p = 0.5
reps = 50
tests = [KSample("MGC"), Hotelling()]

# Varying
spacing = 50
deltas = np.linspace(0, 1 - p, spacing + 1)
ms = np.linspace(0, 500, spacing + 1)[1:]  # skip the leading m = 0

# Full (m, delta) grid, thinned to every other point.
args = [dict(m=m, delta=delta) for m, delta in product(ms, deltas)][::2]
# Alternate between the end and the start of the grid, truncated to the
# original length, so the task order mixes the two ends of the grid.
args = tuple(arg for pair in zip(reversed(args), args) for arg in pair)[: len(args)]

# Bind the constants so each task only has to supply `m` and `delta`.
partial_func = partial(
    run_experiment,
    block_1=block_1,
    block_2=block_2,
    p=p,
    reps=reps,
    tests=tests,
)

# n_jobs=-2: use all CPUs but one.
res = Parallel(n_jobs=-2, verbose=7)(
    delayed(partial_func)(**arg) for arg in args
)

[Parallel(n_jobs=-2)]: Using backend LokyBackend with 63 concurrent workers.
[Parallel(n_jobs=-2)]: Done   2 tasks      | elapsed:  4.0min
[Parallel(n_jobs=-2)]: Done  74 tasks      | elapsed: 421.6min
[Parallel(n_jobs=-2)]: Done 162 tasks      | elapsed: 452.8min
[Parallel(n_jobs=-2)]: Done 266 tasks      | elapsed: 874.8min
[Parallel(n_jobs=-2)]: Done 386 tasks      | elapsed: 1287.2min
[Parallel(n_jobs=-2)]: Done 522 tasks      | elapsed: 1681.8min
[Parallel(n_jobs=-2)]: Done 674 tasks      | elapsed: 2103.2min
[Parallel(n_jobs=-2)]: Done 842 tasks      | elapsed: 2709.0min
[Parallel(n_jobs=-2)]: Done 1026 tasks      | elapsed: 3132.4min
[Parallel(n_jobs=-2)]: Done 1226 tasks      | elapsed: 3669.1min
[Parallel(n_jobs=-2)]: Done 1442 tasks      | elapsed: 4159.7min
[Parallel(n_jobs=-2)]: Done 1674 tasks      | elapsed: 4699.3min
[Parallel(n_jobs=-2)]: Done 1922 tasks      | elapsed: 5245.2min
[Parallel(n_jobs=-2)]: Done 2186 tasks      | elapsed: 5815.1min
[Parallel(n_jobs=-2)]: Don

In [12]:
# Rebuild each result row as its first three entries followed by the
# flattened contents of every entry from index 5 onward (entries 3 and 4
# are dropped).
# NOTE(review): this expects r[5:] to hold iterables, which does not match a
# flat list of scalars; confirm which version of `run_experiment` produced
# `res` before relying on this cell.
new_res = [
    r[:3] + [value for group in r[5:] for value in group]
    for r in res
]

In [17]:
# Sanity check: number of entries in the first flattened result row.
len(new_res[0])

21

In [None]:
# Leading column names for the results table.
cols = ["m", "p", "delta"]