In [1]:
import os
from functools import partial
from itertools import chain, islice, product

import numpy as np
import pandas as pd
from graspy.embed import OmnibusEmbed, selectSVD
from joblib import Parallel, delayed

from hyppo.ksample import KSample

from src import generate_binary_sbms



In [2]:
def estimate_omnimase(X, Y, n_components, sample_space=True):
    """Jointly embed two graph populations with an omnibus + MASE-style procedure.

    The two populations are stacked and embedded with graspy's
    ``OmnibusEmbed``. The per-graph latent positions are then concatenated
    column-wise, and a shared vertex basis of ``n_components`` columns is
    extracted with ``selectSVD``.

    Parameters
    ----------
    X, Y : np.ndarray
        Stacks of adjacency matrices, stacked along axis 0. Assumed to be
        shaped (m_x, n, n) and (m_y, n, n); ``n`` is read from ``X.shape[1]``.
    n_components : int
        Embedding dimension ``d`` used for both the omnibus embedding and
        the shared basis.
    sample_space : bool, default True
        If True, return one embedding per graph. If False, return only the
        shared basis.

    Returns
    -------
    np.ndarray
        If ``sample_space``: an array of shape (m_x + m_y, n, d), where
        entry ``i`` is ``B @ sqrt_i`` for the shared basis ``B`` and a square
        root ``sqrt_i`` of that graph's score matrix ``B.T @ A_i @ B``.
        Otherwise: the shared basis ``B`` of shape (n, d).

    Raises
    ------
    ValueError
        If ``X`` and ``Y`` do not have matching trailing (vertex) dimensions.
    """
    if X.shape[1:] != Y.shape[1:]:
        raise ValueError(
            f"X and Y must have matching vertex dimensions, got "
            f"{X.shape[1:]} and {Y.shape[1:]}"
        )

    graphs = np.vstack([X, Y])
    n = X.shape[1]

    omni = OmnibusEmbed(n_components)
    omni.fit(graphs)

    # latent_left_ is assumed to be (n_graphs, n, d). Moving the vertex axis
    # to the front and flattening gives an (n, n_graphs * d) matrix whose
    # leading left singular vectors form the shared basis.
    Xhat = np.swapaxes(omni.latent_left_, 0, 1).reshape(n, -1)
    latent_left, _, _ = selectSVD(Xhat, n_components)

    if not sample_space:
        return latent_left

    # Per-graph score matrices, shape (n_graphs, d, d).
    scores = latent_left.T @ graphs @ latent_left
    # Batched SVD; note numpy returns V already transposed (Vh).
    U, D, V = np.linalg.svd(scores)
    # Equivalent to U @ diag(sqrt(D)) @ V for every graph: scaling the
    # columns of U by sqrt(D) via broadcasting avoids building the diagonal
    # matrices one at a time in Python.
    root_scores = (U * np.sqrt(D)[:, np.newaxis, :]) @ V
    return latent_left @ root_scores

In [3]:
def run_experiment(
    m, block_1, block_2, p, delta, n_components, reps, alpha=0.05, n_perms=200
):
    """Monte Carlo rejection rates of a per-vertex Dcorr two-sample test.

    For each of ``reps`` replicates:
    - draw two graph populations with ``generate_binary_sbms``;
    - embed them jointly with ``estimate_omnimase(..., sample_space=True)``;
    - for two vertices (vertex 0 and vertex ``block_1 + block_2 - 1``), run
      hyppo's ``KSample("Dcorr")`` test. It compares that vertex's first
      ``m`` embeddings (presumably population 1) against the remaining
      embeddings.

    Parameters
    ----------
    m, block_1, block_2, p, delta
        Forwarded to ``generate_binary_sbms``.
    n_components : int
        Embedding dimension forwarded to ``estimate_omnimase``.
    reps : int
        Number of Monte Carlo replicates.
    alpha : float, default 0.05
        Significance level; a p-value ``<= alpha`` counts as a rejection.
    n_perms : int, default 200
        Forwarded as ``reps`` to ``KSample.test``.

    Returns
    -------
    list
        ``[m, p, delta, rate_first, rate_last]``, where each rate is the
        fraction of replicates rejecting at vertex 0 and at the last vertex,
        respectively.
    """
    n_vertices = block_1 + block_2
    # Vertex 0 is presumably in block_1 and the last vertex in block_2
    # (depends on generate_binary_sbms' vertex ordering).
    test_vertices = (0, n_vertices - 1)

    # Only the tested vertices get a column. An array sized for all vertices
    # would leave untested columns at 0, and `0 <= alpha` would then count
    # them as rejections.
    pvals = np.empty((reps, len(test_vertices)))

    for i in range(reps):
        pop1, pop2, _ = generate_binary_sbms(
            m=m, block_1=block_1, block_2=block_2, p=p, delta=delta
        )

        embeddings = estimate_omnimase(
            pop1, pop2, n_components, sample_space=True
        )
        for j, vert in enumerate(test_vertices):
            _, pval = KSample("Dcorr").test(
                embeddings[:m, vert, :], embeddings[m:, vert, :], reps=n_perms
            )
            pvals[i, j] = pval

    rejection_rates = (pvals <= alpha).mean(axis=0)
    return [m, p, delta, *rejection_rates]

In [4]:
# Simulation grid: `spacing` sets the number of steps for both the delta
# sweep and the sample-size sweep.
spacing = 50

block_1 = 5  # different probability
block_2 = 15
p = 0.5
# spacing + 1 evenly spaced deltas from 0 (the null) up to 1 - p inclusive,
# presumably so that p + delta never exceeds 1.
deltas = np.linspace(0, 1 - p, spacing + 1)
n_components = 2
reps = 25
# Sample sizes 10, 20, ..., 500 (the m = 0 grid point is dropped).
ms = np.linspace(0, 500, spacing + 1)[1:].astype(int)

partial_func = partial(
    run_experiment,
    block_1=block_1,
    block_2=block_2,
    p=p,
    reps=reps,
    n_components=n_components,
)

args = [dict(m=m, delta=delta) for m, delta in product(ms, deltas)]
# Interleave the grid from both ends (last, first, second-to-last, second,
# ...) so that large-m and small-m jobs are mixed across workers. Taking the
# first len(args) items visits every configuration exactly once. chain()
# replaces sum(..., ()), which re-copies the growing tuple at every step.
args = tuple(islice(chain.from_iterable(zip(reversed(args), args)), len(args)))

In [5]:
# Run every (m, delta) configuration in parallel on all available cores.
# The recorded run below took about 519 minutes on 128 workers.
res = Parallel(verbose=5, n_jobs=-1)(
    delayed(partial_func)(**kwargs) for kwargs in args
)

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 128 concurrent workers.
[Parallel(n_jobs=-1)]: Done  32 tasks      | elapsed:   28.9s
[Parallel(n_jobs=-1)]: Done 194 tasks      | elapsed: 80.4min
[Parallel(n_jobs=-1)]: Done 392 tasks      | elapsed: 142.9min
[Parallel(n_jobs=-1)]: Done 626 tasks      | elapsed: 160.6min
[Parallel(n_jobs=-1)]: Done 896 tasks      | elapsed: 249.9min
[Parallel(n_jobs=-1)]: Done 1202 tasks      | elapsed: 314.2min
[Parallel(n_jobs=-1)]: Done 1544 tasks      | elapsed: 368.1min
[Parallel(n_jobs=-1)]: Done 1922 tasks      | elapsed: 433.9min
[Parallel(n_jobs=-1)]: Done 2550 out of 2550 | elapsed: 519.1min finished


In [14]:
# Keep the first five entries of each result row: m, p, delta, and the
# rejection rates at the two tested vertices.
res_arr = np.array([row[:5] for row in res])

In [16]:
# "correct": rejection rate at vertex 0 (presumably in block_1, whose edge
# probability shifts by delta, so this is presumably power).
# "incorrect": rejection rate at the last vertex (presumably in block_2,
# so this is presumably a false-positive rate).
df = pd.DataFrame(
    data=res_arr,
    columns=["m", "p", "delta", "correct", "incorrect"],
)

In [17]:
# Preview the first rows of the results table.
df.head(n=5)

Unnamed: 0,m,p,delta,correct,incorrect
0,500.0,0.5,0.5,1.0,1.0
1,10.0,0.5,0.0,0.04,0.04
2,500.0,0.5,0.49,1.0,1.0
3,10.0,0.5,0.01,0.16,0.12
4,500.0,0.5,0.48,1.0,1.0


In [18]:
# Persist results; create the output directory first so a fresh checkout
# without ./results does not fail after the multi-hour run.
os.makedirs("./results", exist_ok=True)
df.to_csv("./results/20200313_omnimase.csv", index=False)