In [1]:
%load_ext nb_black

<IPython.core.display.Javascript object>

In [1]:
from functools import partial
from itertools import product

import numpy as np
import pandas as pd
import seaborn as sns
from joblib import Parallel, delayed

from hyppo.ksample import Hotelling, KSample

from src import generate_binary_sbms, estimate_embeddings



In [2]:
def run_experiment(
    m, block_1, block_2, p, delta, n_components, reps, tests, alpha=0.05
):
    total_n = block_1 + block_2

    omni_corrects = np.zeros((reps, 2, len(tests)))
    mase_corrects = np.zeros((reps, 2, len(tests)))

    for i in np.arange(reps).astype(int):
        pop1, pop2, true_labels = generate_binary_sbms(
            m=m, block_1=block_1, block_2=block_2, p=p, delta=delta
        )

        for method in ["omni", "mase"]:
            embeddings = estimate_embeddings(
                pop1, pop2, method, n_components, sample_space=True
            )
            for idx, j in enumerate([0, 19]):
                for k, test in enumerate(tests):
                    X_nodes = embeddings[:m, j, :]
                    Y_nodes = embeddings[m:, j, :]
                    try:
                        res = test.test(
                            embeddings[:m, j, :], embeddings[m:, j, :], reps=500
                        )
                        pval = res[1]
                        if np.isnan(res[1]):
                            pval = 1
                    except:
                        pval = 1

                    if method == "mase":
                        mase_corrects[i, idx, k] = pval
                    else:
                        omni_corrects[i, idx, k] = pval

    omni_powers = (omni_corrects <= (alpha / total_n)).mean(axis=0)
    mase_powers = (mase_corrects <= (alpha / total_n)).mean(axis=0)

    to_append = [m, p, delta, *omni_powers.reshape(-1), *mase_powers.reshape(-1)]
    return to_append

In [4]:
spacing = 50

block_1 = 5  # different probability
block_2 = 15
p = 0.5
deltas = np.linspace(0, 1 - p, spacing + 1)
n_components = 2
reps = 50
ms = np.linspace(0, 250, spacing + 1)[1:].astype(int)
tests = [KSample("MGC"), Hotelling()]

partial_func = partial(
    run_experiment,
    block_1=block_1,
    block_2=block_2,
    p=p,
    reps=reps,
    n_components=n_components,
    tests=tests
)

In [5]:
task = 0

In [None]:
args = [dict(m=m, delta=delta) for m, delta in product(ms, deltas)]
args = args[task::4]
args = sum(zip(reversed(args), args), ())[: len(args)]

res = Parallel(n_jobs=-1, verbose=5)(delayed(partial_func)(**arg) for arg in args)

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 128 concurrent workers.


In [13]:
cols = [
    "m",
    "p",
    "delta",
    *[f"omni_power_node={i+1}" for i in [0, 19]],
    *[f"mase_power_node={i+1}" for i in [0, 19]],
]
res_df = pd.DataFrame(res, columns=cols)
res_df = res_df.sort_values(by=["m", "delta"])
res_df.to_csv(f"./results/2020401_weighted_correct_nodes_{task}.csv", index=False)

<IPython.core.display.Javascript object>

In [14]:
res_df

Unnamed: 0,m,p,delta,omni_power_node=1,omni_power_node=2,omni_power_node=3,omni_power_node=4,omni_power_node=5,omni_power_node=6,omni_power_node=7,...,mase_power_node=11,mase_power_node=12,mase_power_node=13,mase_power_node=14,mase_power_node=15,mase_power_node=16,mase_power_node=17,mase_power_node=18,mase_power_node=19,mase_power_node=20
1,10,0.5,0.00,0.04,0.16,0.00,0.04,0.08,0.20,0.00,...,0.08,0.08,0.08,0.08,0.04,0.08,0.04,0.04,0.08,0.08
3,10,0.5,0.01,0.00,0.08,0.04,0.04,0.04,0.00,0.04,...,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.00,0.04
5,10,0.5,0.02,0.00,0.08,0.00,0.00,0.00,0.04,0.08,...,0.04,0.12,0.04,0.12,0.04,0.08,0.08,0.08,0.04,0.04
7,10,0.5,0.03,0.08,0.00,0.04,0.00,0.12,0.04,0.08,...,0.08,0.00,0.12,0.08,0.04,0.12,0.12,0.08,0.08,0.12
9,10,0.5,0.04,0.08,0.04,0.08,0.08,0.08,0.04,0.04,...,0.00,0.00,0.00,0.00,0.00,0.04,0.00,0.00,0.04,0.04
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8,500,0.5,0.46,1.00,1.00,1.00,1.00,1.00,0.08,0.04,...,1.00,1.00,1.00,1.00,1.00,1.00,1.00,1.00,1.00,1.00
6,500,0.5,0.47,1.00,1.00,1.00,1.00,1.00,0.08,0.04,...,1.00,1.00,1.00,1.00,1.00,1.00,1.00,1.00,1.00,1.00
4,500,0.5,0.48,1.00,1.00,1.00,1.00,1.00,0.00,0.12,...,1.00,1.00,1.00,1.00,1.00,1.00,1.00,1.00,1.00,1.00
2,500,0.5,0.49,1.00,1.00,1.00,1.00,1.00,0.04,0.04,...,1.00,1.00,1.00,1.00,1.00,1.00,1.00,1.00,1.00,1.00


<IPython.core.display.Javascript object>