# Run Conditional DCorr

Conditional DCorr is a version of conditional independence testing that has also shown properties of "uniform" consistency for hypothesis testing.

We want to compute results here to compare against CoMIGHT

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import math
from collections import defaultdict
from itertools import product

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from hyppo.conditional import ConditionalDcorr
from joblib import Parallel, delayed
from sklearn.model_selection import StratifiedShuffleSplit

from sktree.stats.utils import (
    METRIC_FUNCTIONS,
    POSITIVE_METRICS,
    POSTERIOR_FUNCTIONS,
    REGRESSOR_METRICS,
    _compute_null_distribution_coleman,
    _non_nan_samples,
)

seed = 12345
rng = np.random.default_rng(seed)

# Parallelizable function for any simulation

In [3]:
def _run_parallel_posterior_sim(
    idx,
    n_samples,
    n_features,
    class_probs,
    seed,
    n_features_2,
    test_size,
    max_fpr,
    sim_type,
):
    n_jobs = 1
    results = defaultdict(list)
    rng = np.random.default_rng(seed)
    n_features_ends = [100, None]

    if sim_type == "confounder":
        npy_data = np.load(
            f"/Users/adam2392/Desktop/cancer/confounder/confounder_{idx}.npz"
        )
    elif sim_type == "collider":
        npy_data = np.load(
            f"/Users/adam2392/Desktop/cancer/collider/collider_{idx}.npz"
        )
    elif sim_type == "mediator":
        npy_data = np.load(
            f"/Users/adam2392/Desktop/cancer/mediator/mediator_{idx}.npz"
        )
    elif sim_type == "direct-indirect":
        npy_data = np.load(
            f"/Users/adam2392/Desktop/cancer/direct-indirect/direct-indirect_{idx}.npz"
        )
    elif sim_type == "independent":
        npy_data = np.load(
            f"/Users/adam2392/Desktop/cancer/independent/independent_{idx}.npz"
        )

    X = npy_data["X"]
    y = npy_data["y"]
    # print(X.shape, y.shape)

    X = X[:, : 100 + n_features_2]
    if n_samples < X.shape[0]:
        cv = StratifiedShuffleSplit(n_splits=1, train_size=n_samples)
        for train_idx, _ in cv.split(X, y):
            continue
        X = X[train_idx, :]
        y = y[train_idx, ...].squeeze()
    assert len(X) == len(y)
    assert len(y) == n_samples
    n_features_ends[1] = X.shape[1]

    posteriors_dict = dict()

    # now compute the pvalue when shuffling X2
    covariate_index = np.arange(n_features_ends[0], n_features_ends[1])

    # estimate (conditional) mutual information using KSG
    Z = X[:, covariate_index]
    mask_array = np.ones(X.shape[1])
    mask_array[covariate_index] = 0
    mask_array = mask_array.astype(bool)
    X_minus_Z = X[:, mask_array]

    cdcorr = ConditionalDcorr(bandwidth="scott")

    print(Z, np.var(Z))
    # print(X_minus_Z, np.var(X_minus_Z))
    print(Z.shape, X_minus_Z.shape, y.shape)
    try:
        cdcorr_stat, cdcorr_pvalue = cdcorr.test(
            x=X_minus_Z.copy().astype(np.float64),
            y=y.copy().astype(np.float64),
            z=Z.copy().astype(np.float64),
        )
    except Exception as e:
        errmsg = f"{idx, n_samples, n_features, n_features_2, np.var(Z), X_minus_Z.shape, y.shape, Z.shape}"
        e.args += (errmsg,)
        raise (e)

    np.savez(
        f"./varying-samples/{sim_type}/conddcorr_{n_samples}_{n_features_2}_{idx}.npz",
        n_samples=n_samples,
        n_features_2=n_features_2,
        y_true=y,
        cdcorr_state=cdcorr_stat,
        cdcorr_pvalue=cdcorr_pvalue,
    )
    # results["cdcorr_pvalue_x2"].append(cdcorr_pvalue)
    # results["cdcorr_stat_x2"].append(cdcorr_stat)

    # results["mvrf_posteriors"].append(comight_posteriors_x2)
    # results["mvrf_null_posteriors"].append(comight_null_posteriors_x2)
    return results

# PValue Computations Varying Dimensionality of X2

In [4]:
# NOTE: increase this when running for full figure to 100
n_repeats = 100

In [5]:
# number of features in the first view
n_features = 10
noise_dims = 90

n_samples = 256
max_features = 0.3
n_jobs = -1
test_size = 0.2

max_fpr = 0.1

# number of features in the second view
pows = np.arange(2, 13, dtype=int)
n_features_2_list = [0] + [2**pow for pow in pows]

n_features_2_list = [2**pow for pow in pows]
print(n_features_2_list)

class_probs = [0.5, 0.5]

[4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]


In [6]:
print(n_features_2_list)

[4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]


## Independent Results

In [7]:
_results_ind = Parallel(n_jobs=-1)(
    delayed(_run_parallel_posterior_sim)(
        idx_,
        n_samples,
        n_features,
        class_probs,
        seed,
        n_features_2_,
        test_size,
        max_fpr,
        "independent",
    )
    for (idx_, n_features_2_) in product(range(n_repeats), n_features_2_list)
)

# {idx, n_samples, n_features, n_features_2}

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(

[[ 2.20954603  0.12278706 -1.64224312 ...  0.4963114  -1.05766931
   1.13484185]
 [-0.02270628 -0.22531537  0.32086887 ...  0.48166069  1.60464735
  -0.76164337]
 [-0.6041721  -0.9451815  -1.0658785  ... -0.41822845 -0.16688235
   0.88128755]
 ...
 [-0.27811215  1.42828143 -0.19839126 ...  0.1169345   1.09908561
   0.23387404]
 [ 0.28990372 -0.48935689  0.55366137 ... -1.89206814  1.54556472
  -0.23755623]
 [-0.12557662  0.52325242 -1.0306507  ... -0.1511866  -0.09191958
  -1.02774067]] 0.9911961061658242
(256, 128) (256, 100) (256,)
[[ 0.974303   -0.47718166  0.21436605 ...  0.6261333   2.13917784
  -0.08569217]
 [ 0.78196064 -0.04991388 -0.06910147 ... -0.99782345 -0.52851396
  -0.08443322]
 [-0.27085029  1.40577703  1.4399205  ...  0.17637968  2.77876569
   0.95798701]
 ...
 [ 0.63364257 -1.45473417  0.50738316 ... -0.21697808 -0.89674537
  -1.39385544]
 [ 0.15138162  0.83088596  1.70119977 ... -0.53403395  0.15789014
  -2.17717262]
 [ 0.81729085  1.85496692 -0.75116705 ...  0.58246

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[-0.52657324 -0.60340806  1.15873376 ... -1.23968685  0.03564921
  -0.0702338 ]
 [ 1.28978219  1.01647865 -0.49440615 ... -0.22367843  1.16498566
  -1.36084559]
 [ 0.40050159 -1.06783523 -0.69241061 ...  0.05619994 -1.22922165
   0.33236605]
 ...
 [-0.40570296 -1.40627019 -0.36260142 ... -1.70264126  0.68237504
   0.33724757]
 [-1.51455633  0.24754675  0.70503185 ... -0.75411635  0.22074858
   1.01038625]
 [ 1.30678646 -0.41658768 -1.38859262 ... -1.84110569  1.50375787
   1.08141084]] 0.9958593932309547
(256, 256) (256, 100) (256,)
[[ 0.47924264  0.82558499 -0.44337661 ...  1.92240857  0.61349553
   1.87373436]
 [-0.58713263  0.5688483   1.21519488 ...  0.20106203 -0.61591228
  -0.49989105]
 [-1.46017309 -1.2383824  -0.11400582 ...  1.37978389 -0.23280354
   0.18083541]
 ...
 [ 0.60781986  1.29105276 -1.53318952 ... -0.85261942  0.33231265
   0.18351325]
 [-0.84282511 -1.18383023  0.74461981 ... -0.05004345  1.04261211
   0.10552481]
 [ 0.64565908 -0.13972369 -0.6941621  ...  0.95947

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[ 0.03202824  0.49954425 -0.41527446 -0.52009728]
 [-0.3253585   0.95940512 -0.22586575  1.69991204]
 [ 1.12381774 -0.86115089  0.41200143  0.43105298]
 ...
 [-0.16140014 -0.71059317 -0.699566    1.06338401]
 [ 1.21619321 -1.52359595  0.12514081  1.70941686]
 [-1.40494011 -1.56513408  1.70339027 -2.11104161]] 1.0351366531486683
(256, 4) (256, 100) (256,)
[[ 0.47384709  2.18659481 -1.92035681 ...  0.19398447  0.72016551
  -0.19607934]
 [-0.15385189  0.30421959 -1.04775129 ... -1.20983182  0.38658808
   0.52344959]
 [ 0.00901866 -0.48340307  0.59012501 ...  0.8590457   0.76233401
  -0.21579118]
 ...
 [-0.8783908   0.36904306  2.3959627  ... -0.36667321 -1.28568002
  -1.09899574]
 [-1.56181751 -0.65731141 -0.15840005 ...  0.52919125 -0.12775474
   0.58308711]
 [ 0.853182   -0.13509859  0.78497227 ... -1.32637399 -1.47569952
  -0.06600445]] 1.0076989098083424
(256, 32) (256, 100) (256,)
[[ 0.98807588 -0.30928775 -0.55286761 ... -1.51179992  0.35231145
  -1.47137597]
 [-0.63806814 -0.58461

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[ 0.57851068 -0.42099544  0.49534829 ...  0.27515588  0.11104906
   0.34621397]
 [-0.42061442 -0.21747371 -0.183551   ...  0.60000406 -1.71965103
  -0.43258557]
 [-1.75970605  1.45839022  0.24105405 ...  1.98456408 -0.42723049
   0.71691344]
 ...
 [-0.56414439 -0.07275106 -1.87393418 ...  0.35934901  0.16607828
   0.21797258]
 [ 0.37974436 -1.31846264 -1.16070044 ...  0.49545104  0.24158005
  -0.11701042]
 [-1.42382504  1.26372846 -0.87066174 ...  0.90219827 -0.46695317
  -0.06068952]] 1.013107101600468
(256, 16) (256, 100) (256,)
[[-1.39983224 -0.33000438 -0.50648614 ...  0.86403208 -0.24474865
   0.82221394]
 [-1.29323468 -0.52698517  0.96718806 ...  1.06903764 -0.31327008
  -1.73775152]
 [ 2.75475111 -1.36948414 -1.41767328 ...  0.04210493 -0.97493629
  -0.76580618]
 ...
 [ 0.72317966  0.55570912 -0.41157675 ...  0.62753949 -0.14055474
   0.45508487]
 [ 0.86606888  0.09335925 -0.35685445 ... -0.16900724  0.20274247
   0.4337858 ]
 [-0.04788449  0.04845545 -1.18600485 ...  1.0773098

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[ 2.23428609 -0.91018407  0.84682019 ...  0.02155053  1.57103592
   0.04091409]
 [-0.78454631 -1.70921836  0.25050423 ... -0.18365926  0.07445918
   1.17094809]
 [ 1.28978219  1.01647865 -0.49440615 ... -0.55581053 -0.48114026
  -0.09646797]
 ...
 [ 0.71288568 -2.11019064  0.45814698 ... -0.04910715 -0.14466469
   0.42274075]
 [ 1.49591287 -0.18154042 -1.09710796 ...  0.35989265 -1.9839797
  -1.51936303]
 [-0.02116609  1.73292901 -0.60968967 ... -1.21705086  1.72345889
  -0.75294199]] 0.9957048166388867
(256, 32) (256, 100) (256,)
[[ 0.99668841 -0.10034804 -0.06551609 -0.14367686]
 [ 1.02897562 -0.83057076 -0.96073013  1.48817083]
 [ 0.87494752  1.10374609  0.20043972  1.51349296]
 ...
 [-0.68210334 -1.1225833   0.2736558   0.45389271]
 [ 0.32733157  1.05487269  1.33754702 -0.09139602]
 [-1.11739611  1.38954602  1.1251978   0.72937394]] 0.9626155408289825
(256, 4) (256, 100) (256,)
[[ 0.62641659  0.41450868 -0.08946015 ...  0.78430355 -0.88391831
  -0.43354934]
 [-2.05999906 -0.246237

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(

[[ 0.24558676 -1.7465076  -0.60391022 ...  0.92733902  3.28937063
   1.10205968]
 [ 2.09388432  0.99084626 -0.84821706 ... -2.43782435  2.25099397
  -0.21040118]
 [-1.13238219  0.34038639 -1.53898441 ...  0.38585389 -1.17059973
  -0.11633221]
 ...
 [-0.69673035  1.19354291 -0.28156414 ... -1.15787238 -0.40552997
   2.12859349]
 [ 0.31354181  0.7050524  -0.55930473 ...  0.83107349 -0.88821385
  -1.30420352]
 [ 0.73148403 -0.03075807 -0.15759756 ... -0.29157372 -1.28426901
  -0.53472153]] 0.9994845786256527
(256, 512) (256, 100) (256,)
[[-0.12442506  1.21545346  2.44626703 ... -1.10277481 -0.48097182
   0.72087654]
 [ 0.18146442  0.10379128  1.09457003 ...  1.20947912 -0.74354233
  -0.01299742]
 [-0.05714025 -0.20102763 -0.60339863 ... -1.68126843  0.13012827
  -1.57917088]
 ...
 [ 0.94687678  0.22856079  0.80966336 ...  1.21132065 -0.31288694
  -1.88281833]
 [ 2.01317751 -2.25787637  1.73054237 ...  0.31088898  0.48285485
  -0.26297062]
 [-0.39721575  1.76340397 -1.79895635 ... -0.49336

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[-0.07494483 -0.33809001 -0.86880219 ... -0.45657594  1.72293875
  -1.96205461]
 [-2.16577352 -1.14078859 -0.04204548 ... -0.43228952 -0.8507396
   0.36328865]
 [-0.28167526 -1.23330815 -1.01349889 ... -0.48740697 -0.73683938
   1.02925449]
 ...
 [-0.71016226 -0.37836713  1.91882648 ... -0.87358835  1.34942386
  -0.17440306]
 [ 0.18204152 -0.97279068 -0.22251227 ... -0.02949146  1.1018466
   1.15246869]
 [-2.99838452  1.10073026 -0.42876926 ...  0.03473719 -0.57608131
   1.73554487]] 0.9873054615968113
(256, 8) (256, 100) (256,)
[[ 1.63936922 -1.19989647  1.0587036  ... -0.18498242  1.53810952
  -1.64080481]
 [-0.44685892  0.46994551  1.67869352 ...  2.96574128 -0.05954982
  -1.4181018 ]
 [-0.1411137  -0.23349026 -0.67798372 ...  0.50457844 -0.61659106
   0.30776325]
 ...
 [-0.36701732  0.78019548 -1.15599669 ...  0.38707751 -1.09117095
   0.6341431 ]
 [-0.82579148 -0.36856617  0.0535161  ... -0.837811   -1.87211993
  -0.95259269]
 [-1.09886424  1.21012488 -1.23355153 ... -0.54797622 

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


1.03112832589462
(256, 32) (256, 100) (256,)
[[-2.76133475e-01  2.52698891e+00 -2.91663552e-02 ...  6.04959650e-01
   2.15703155e+00  2.13367636e-01]
 [-7.59435722e-01  1.73291782e+00 -2.29077027e-03 ... -1.09326419e-01
  -9.66994576e-02  2.28479368e-01]
 [-5.56285512e-01 -4.19391307e-02  3.97996602e-01 ...  1.55517426e-01
   1.40266382e+00 -1.45301363e+00]
 ...
 [-3.63728984e-01 -1.56227493e+00  4.40716973e-01 ...  1.86625266e+00
   2.13708973e-01 -2.18995821e-01]
 [ 3.28129892e-01  7.90503261e-01 -3.48553353e-01 ... -6.36178979e-01
  -2.73806570e-01  1.05542069e-01]
 [-3.70029706e-01 -6.96327555e-01 -1.06380981e+00 ... -1.10613566e-01
   1.87148055e+00  8.76671707e-01]] 0.9976641142288315
(256, 4096) (256, 100) (256,)
[[-1.30962143 -0.3299232  -0.6410983  ... -0.85166963 -0.83392631
  -0.52848186]
 [ 0.56074345 -0.90152734  0.03101604 ... -0.29700335 -0.81043901
  -1.02396483]
 [-0.00464503  0.56818195 -1.54620437 ... -0.28654226 -0.99697736
   1.29161532]
 ...
 [ 1.09619046  0.77795

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[ 0.67861508 -1.44271569 -1.38513482 ... -0.12403111  0.36580899
   1.27006058]
 [ 0.84387565 -0.34292127 -1.68729046 ... -1.18669877 -1.0270327
  -0.01871357]
 [-0.45704027  1.72077328 -0.65197972 ...  0.25783478 -0.93780098
  -1.04235503]
 ...
 [-0.01109558  1.0246624   0.24882757 ...  1.01745337  1.00824681
  -0.40765016]
 [ 1.44622008 -0.6088446   0.09276402 ...  0.62996063  1.17932906
   0.25265707]
 [ 0.76264936 -0.76084462 -0.76318258 ...  0.11688334 -1.0996708
   0.04312617]] 0.9840154239593384
(256, 16) (256, 100) (256,)
[[-0.64965763  0.33927687 -0.94576125 ... -1.76065996 -1.15337422
  -2.24533339]
 [ 0.23359008  1.28561319  1.05891247 ...  0.35416855  1.38610651
  -0.43359605]
 [ 0.18146442  0.10379128  1.09457003 ...  1.26786397 -0.33022231
  -1.82744571]
 ...
 [-0.43059921 -1.34726557 -0.18177917 ... -0.41355585  2.46292999
  -0.99116712]
 [ 0.84679707 -0.072563   -0.97113237 ... -0.57767378 -0.13352165
   0.29084914]
 [-2.27056409 -0.38908663  0.32472262 ...  0.6940421 

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(

[[-1.8190358  -0.16942588  0.80771286 ... -0.22875598 -0.9601202
  -0.69907179]
 [-1.01388947  0.86585726  1.7647305  ... -1.57141583 -1.08137147
   0.63055094]
 [ 0.25321039 -0.73719154  1.64265089 ...  0.16545064  0.35221796
   0.0880718 ]
 ...
 [ 1.62288144 -0.43073376 -0.1914233  ... -0.67774119  0.02243161
  -0.81103955]
 [ 0.99415297  0.1363985   1.2652584  ... -0.45040922  1.75662163
  -0.46540856]
 [-0.95539372  0.66099836  0.83377311 ...  1.02087608 -1.63336698
  -0.48799483]] 1.0191735713982069
(256, 32) (256, 100) (256,)
[[ 0.4079352  -3.0108422  -0.63613032  0.54834372]
 [ 0.27913268  0.96492297 -1.86865442  0.08396707]
 [-0.71066202 -0.57548234 -0.5961355  -1.27680238]
 ...
 [ 1.60732163 -1.55640213 -2.29767513 -0.66118041]
 [ 0.14145026  1.30828199  0.7349177  -1.01158315]
 [ 1.62930105 -0.58009631  0.0849532  -0.54553416]] 1.0350313511118996
(256, 4) (256, 100) (256,)
[[ 0.27331931  1.09759463 -0.20871191 ... -1.15732211  1.57043058
   0.51630116]
 [-1.67621781 -0.668044

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[-0.58205918  0.12185123 -0.33759977 ... -1.48665911  0.49927268
   0.42045812]
 [-1.30417177  1.12155914  2.50174375 ... -0.70751371  1.04571521
  -0.1366694 ]
 [-0.46818159  0.31495908  0.85878947 ... -0.16070959 -0.75552775
   1.36202679]
 ...
 [-0.48781171 -0.21379163  2.13554822 ... -0.52125269  0.22248283
  -0.202819  ]
 [-0.47330136  0.86595058  0.05622119 ...  1.55245623 -0.37472548
   1.11464971]
 [-0.56318468  2.00644808  1.798839   ...  0.64047671  0.57829316
  -2.02492138]] 1.0118473373138566
(256, 64) (256, 100) (256,)
[[-0.0563896   1.02929723  0.50662081 ...  1.07930864 -0.84529692
   0.0710934 ]
 [-0.80322546 -0.02546051 -0.74936353 ... -0.73319675  0.79029401
   0.64070356]
 [-0.15668735  1.23213473  0.85875138 ... -0.67096111 -0.27546214
  -1.17283489]
 ...
 [-1.1159663  -0.28510709 -0.69720285 ... -0.57738102 -0.42194442
   0.49020968]
 [ 1.31380975  0.95827106 -0.29898074 ... -0.16596157 -0.10513608
   0.86324596]
 [-0.49982498  0.30138849 -1.2564726  ...  0.617742

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[ 0.87883389 -0.32202884 -0.23637812 ...  1.18766655 -2.1854413
  -0.05405194]
 [ 1.16770867  0.94142513 -0.86714477 ... -0.46804473  0.96943349
   0.80231635]
 [ 0.19752036 -0.48232164 -0.58397538 ... -0.92619159  1.2980759
   1.04496678]
 ...
 [-0.40919012 -0.42829989  0.75265189 ... -1.01575425 -1.37593352
   1.31652675]
 [ 0.96155359 -1.18969901  1.40664131 ... -0.82013541  0.95146433
   0.39642425]
 [-1.77227291  0.11260199 -0.11988912 ...  1.2084463  -0.36066588
  -0.43923161]] 0.9857807742765019
(256, 32) (256, 100) (256,)
[[ 0.33093176  0.17053393  0.12112408 -1.24770593]
 [-0.36799416  1.05977207  1.31469632 -0.57755129]
 [ 0.35797602  0.1317688   1.64258293 -0.46999911]
 ...
 [ 0.87280434 -0.40706287 -0.73059501  0.41925062]
 [ 0.78319796 -1.74560334  0.17810453  0.4255604 ]
 [-0.5413807   2.20215923  0.88798282 -0.78441246]] 1.0535426010871751
(256, 4) (256, 100) (256,)
[[-0.94927428 -0.05560572 -0.50701319 ... -0.7007765   0.39268951
   1.12107749]
 [-0.60797163  0.8972824

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[ 0.42520563 -1.09923389 -0.50813954 ...  0.12567785 -1.35376997
   1.19414103]
 [ 0.82414953 -0.1423708  -1.16032737 ...  0.44259823  2.12874733
  -0.17571477]
 [ 0.53895966  0.2173382   2.4106186  ... -0.12732219 -0.67140732
  -0.12601669]
 ...
 [-0.47946248  0.85089557 -0.14462426 ... -0.04563605 -0.38582957
   2.02457639]
 [ 0.72594131 -0.54169491  0.80603246 ...  1.06998408  1.04149898
   0.47648637]
 [-1.22324073  0.01812907 -1.15005207 ...  0.8387166  -0.35694968
  -0.83050626]] 1.0053369273997927
(256, 16) (256, 100) (256,)
[[ 0.59032837  0.84059939  0.14194434 ... -0.47012534 -0.29937797
   0.555616  ]
 [ 0.08854644 -0.40857374  0.41336605 ... -1.16288258 -1.54876198
  -1.99816158]
 [ 0.1492627  -1.48266204 -0.2136378  ... -0.26428817 -1.04482675
   0.02426919]
 ...
 [-0.15611991 -0.22494149  0.41462936 ...  0.31923336 -0.95855022
   2.11148448]
 [-1.3563004   2.4674929  -0.56820168 ... -0.69645304 -0.22752065
   0.33762567]
 [ 1.13169123 -0.66241824 -0.86572628 ...  0.669432

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(

[[-1.32341915  0.89530218 -0.71641604 ... -0.60560699  1.41545059
  -0.18875367]
 [-1.70307409  1.07289269 -0.21749624 ... -0.62914015 -0.53937106
   0.1582927 ]
 [-0.94632861 -0.12272664 -0.90901714 ... -2.04560527  1.51003058
  -1.47934772]
 ...
 [-0.74781772  0.42477909  0.44868205 ...  0.6815098  -1.10147372
   0.76269454]
 [ 0.55817281  0.56325571  0.38291749 ...  0.23694759 -0.59014963
   1.04417217]
 [ 0.17125479 -0.1349695   0.61021966 ... -1.26247805  1.48074828
   0.30661038]] 1.0194154698560067
(256, 32) (256, 100) (256,)
[[-0.21925574 -0.83918836 -0.57075079 ...  1.84480641  0.86588463
  -0.38910179]
 [ 0.28060352  1.96174599  0.452693   ... -0.83186769 -0.63459139
   0.41645657]
 [-1.21871709  1.89595903  2.42180752 ... -2.4451487  -1.37858394
   0.0230104 ]
 ...
 [ 0.99777322 -0.33099858  0.26326733 ...  0.50055199  0.53392906
   0.88759695]
 [ 0.34739327  1.36725395  0.24018456 ... -0.2180689  -1.27734542
  -1.9286867 ]
 [ 0.01021067  0.83180229  0.64306154 ... -0.518583

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[ 0.26959054 -1.70250773  0.12055811 ... -0.02327351 -1.16925448
  -1.60010689]
 [-0.40985523 -1.98246413  0.09443079 ...  2.69179621  1.45927563
  -0.10542513]
 [ 0.02971244 -0.38693027 -0.31735919 ... -0.19828567  0.15667692
  -0.11136844]
 ...
 [-0.17709083  0.16903973 -0.46931153 ... -2.10690083 -0.06356941
  -1.26317904]
 [ 1.68875686  1.08338614 -0.50962533 ... -0.18309573 -0.249074
   0.24088051]
 [ 1.31544476 -1.20442785  0.60382206 ... -0.48264503  1.19656931
   0.30513807]] 1.0014721149712948
(256, 128) (256, 100) (256,)
[[-2.20253874 -0.56403776 -0.61446397 ... -0.2198107   1.58599883
  -0.72981591]
 [-0.27642956 -0.63284408  0.1872734  ... -1.6791488   0.36370086
   0.08760503]
 [-0.06443925 -0.64106567 -0.91582347 ... -0.81557311 -0.75702263
   0.55393525]
 ...
 [ 1.66599698 -0.14843858  0.04500515 ...  0.74217297 -0.05388704
   0.64784273]
 [ 1.34951425  0.02115402 -0.58377283 ...  0.80597328 -1.0812892
   1.29859541]
 [-0.20445753  0.16129563 -1.27557359 ...  0.17179883

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[ 0.55357081  0.67206806  0.00940884 ... -0.6970075  -0.7911303
   1.7320029 ]
 [-0.23900507  0.61443475  1.36538018 ...  0.86398684 -0.42356679
  -1.29194764]
 [ 1.10732644 -0.22330036 -0.43552134 ... -1.13053425 -0.83380431
   0.2342036 ]
 ...
 [ 0.88681979 -0.23030201  0.91933864 ...  0.02771491  0.18109428
  -0.43109663]
 [-0.93779707  1.08946084 -1.28453487 ...  1.51000194 -0.20732916
  -0.12235468]
 [-1.864201    0.75772345  0.22131766 ... -1.61846072 -0.20573482
   0.76419935]] 0.9929254294359116
(256, 256) (256, 100) (256,)
[[ 1.46179824 -0.13230842 -0.14582547 ... -0.35418538  1.43165889
   0.91802078]
 [-0.09902176  0.30089443  0.96625314 ...  0.44176402 -0.846547
  -1.74872377]
 [ 0.16891763  0.51909523 -1.83318035 ...  0.78714    -0.52402386
  -0.03543736]
 ...
 [ 0.83002812 -0.72635547  2.37799088 ... -0.32369495 -1.45105493
   2.53910682]
 [-0.88410912  0.77652301 -1.0948836  ...  1.04432278  0.36443406
  -2.18184288]
 [ 1.7721003   0.88438483  1.2739733  ... -0.29346127

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


## Confounder Results

In [9]:
print("done")

done


In [10]:
_results_confounder = Parallel(n_jobs=-1)(
    delayed(_run_parallel_posterior_sim)(
        idx_,
        n_samples,
        n_features,
        class_probs,
        seed,
        n_features_2_,
        test_size,
        max_fpr,
        "confounder",
    )
    for (idx_, n_features_2_) in product(range(n_repeats), n_features_2_list)
)
# {idx, n_samples, n_features, n_features_2}

[[ 0.69860219  0.37501992 -1.92950143 ... -0.905097    0.25471235
  -1.76312482]
 [-0.48317953  1.39613195  1.49716422 ... -1.7462176  -0.18314323
   1.095924  ]
 [ 0.84364769  1.32447714 -0.13134295 ... -0.6705304   0.3327977
   2.03239257]
 ...
 [-0.22812356  0.17156737 -1.25165195 ... -0.03267069 -0.02128367
   0.06373303]
 [ 0.52815833 -1.80215407  1.02348332 ... -0.94692411  0.01531964
   1.54460441]
 [-0.59239972  1.40809558 -0.72562595 ... -1.00971275 -0.62037055
   0.31137712]] 0.9688006458993984
(256, 64) (256, 100) (256,)
[[-1.4822823  -0.27201538 -0.22809511 ...  0.4197406  -0.68726661
  -1.35602443]
 [ 0.57614976 -0.41344456 -1.44199079 ...  0.20370536  0.7848533
   1.98807818]
 [-0.32234524 -0.76291492  0.5066942  ... -0.30179647  0.63739647
  -0.91757485]
 ...
 [ 0.64993265  0.66619659 -0.77186554 ...  0.22657242 -1.86320856
  -0.82094898]
 [ 0.18254134 -0.91232347  1.02365771 ... -1.73276923 -0.24501561
  -0.3433405 ]
 [-0.17387559 -1.88769903 -0.24186695 ... -0.0133312 

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(

[[ 1.24288881  2.14500768 -0.25738752 ... -0.67193456  2.46090409
   0.5341129 ]
 [ 0.55392777  0.00801619 -0.38727234 ...  0.87621257 -1.10143492
  -0.93630161]
 [ 0.09359257 -1.8026496  -0.43908896 ... -1.08857626 -1.20854789
  -1.45292111]
 ...
 [-2.99554475 -0.45830399  0.34012481 ...  1.10467049  1.13760105
  -0.85430107]
 [ 0.40200619 -1.97344679  1.24503375 ... -0.66211788 -1.45304489
  -0.07268838]
 [-0.24427351  0.0565823  -0.39322393 ... -0.66206438 -0.22814404
   1.4708737 ]] 1.0014850631696746
(256, 128) (256, 100) (256,)
[[ 0.34536926 -0.7260329  -1.66904171 ...  0.86770721  0.85074271
   1.62288334]
 [ 0.24879151 -0.84362303 -0.91340673 ...  0.05098101  1.64894311
   0.68730976]
 [-0.58937419  1.60375299  1.7387922  ...  0.64722149  1.18573077
  -2.60252626]
 ...
 [ 1.8323723   1.05669668 -2.2120196  ...  0.40395701 -1.51259429
   0.38523577]
 [ 0.35948763  0.43619202  0.71778975 ... -0.7274238  -0.64555638
   0.01008674]
 [-0.52756424 -1.32409429  0.4783707  ... -0.84568

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[ 0.30007916 -0.47797042 -1.06110286 ... -0.14416984 -1.11997555
   0.01768794]
 [ 0.51499462  0.84066717 -0.43561975 ... -1.65810715 -1.0409949
   0.23724737]
 [ 0.09480095 -2.02538072 -0.74197671 ... -0.6001155   0.29422337
   1.43901062]
 ...
 [-0.42279735 -0.44827256  1.1519032  ...  0.38567797 -0.16968865
  -0.78354119]
 [-0.3932423   0.79171432 -0.81357235 ...  1.15651946  1.69260817
  -0.22318245]
 [ 0.01007324  1.32215141 -0.07058638 ... -0.11459466 -0.64151624
   1.64314177]] 0.9990048235121081
(256, 2048) (256, 100) (256,)
[[ 0.80497292  0.65354865 -0.70234682 ... -1.14551644  0.51532891
   1.34345974]
 [ 0.0659526   0.24798963 -1.17393409 ...  0.30940347 -0.71315227
   0.04958806]
 [ 0.77446647  0.26092312  1.09605871 ...  0.33929347 -0.56112621
  -0.08375187]
 ...
 [-1.17810996  0.61422967 -0.43464061 ... -1.1297454  -0.08006143
  -2.11135906]
 [ 0.26395729  0.11404312  0.91923839 ...  0.09129608  1.03413488
   0.77289437]
 [ 1.67657006 -1.1311996  -0.40744521 ... -0.62376

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)



[[ 0.19858892 -0.05452174 -1.57300447 -0.77656984]
 [ 0.54301009 -0.98428023 -0.18947216 -0.16894951]
 [ 0.34798764  0.9016335  -0.64275586  1.58699107]
 ...
 [-0.99979453 -0.91478988 -0.30741761 -1.59733027]
 [-0.70024736 -1.03762012  0.14705246 -1.78855738]
 [-0.91224885 -0.50993288 -0.13757106  0.9058283 ]] 0.9784065084220486
(256, 4) (256, 100) (256,)
[[ 1.17875718  0.97779733 -0.18059651 ...  0.89271876 -0.96682632
  -0.15346445]
 [-0.71970551  1.3255226   0.19388634 ...  1.2224361  -0.37702357
  -0.75708151]
 [ 1.02470274 -0.81335615 -0.34185602 ... -0.75648479 -0.10400736
  -0.5466936 ]
 ...
 [ 0.05999543  0.8210906   0.13688858 ... -0.29026948  0.49780937
  -1.08955128]
 [ 0.56224298  1.13345377 -0.19432334 ...  0.04759271 -0.37753136
  -1.47584204]
 [ 0.98895827 -2.13478942  0.26307323 ... -1.57238006  0.53394003
   0.52640363]] 1.0015703462552707
(256, 1024) (256, 100) (256,)
[[ 4.67137747e-01  3.39216828e-01  1.15087696e+00 ... -1.52532950e-01
   1.36428574e+00  4.80277703e

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[-1.11986766 -1.97957415  1.14244632 ...  0.31961337  0.1729181
   0.43685864]
 [ 0.32273087  1.03664631 -2.15182771 ...  0.20076077  0.07991751
  -0.67553163]
 [ 0.43379816  0.45318195  0.55665006 ...  0.34175161 -0.15141501
   0.43452399]
 ...
 [ 0.87717907  0.68786461 -0.5074671  ... -0.95659494  1.38355431
  -0.31984699]
 [-0.34188929  0.29705935  1.70165906 ... -0.08237464 -0.87271956
  -1.30060569]
 [ 0.56047158 -0.91956654 -1.75938233 ...  0.47330389 -0.0097634
  -0.47559842]] 1.0239623082536429
(256, 32) (256, 100) (256,)
[[ 0.36817108  0.47757109  0.17916922  0.49750814]
 [ 0.48935115  0.43040737  0.31535646  0.31592796]
 [-0.1719372   1.12500743 -0.30911864  0.54274353]
 ...
 [ 0.19032475 -1.08678248 -0.73551416  1.65536932]
 [-2.06903759 -0.25439902  0.70454564 -0.95050079]
 [ 1.52083531  0.80563547  0.56034655 -1.33530812]] 0.9295638497761356
(256, 4) (256, 100) (256,)
[[ 2.58976752 -0.309706    0.73601671 ...  0.92136412 -0.28880108
   1.33955303]
 [-1.38529929  2.8557555

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[-1.02062884  1.13324335  0.78604008 ... -1.73937366 -0.56349565
  -0.11581192]
 [-1.65251125  0.24836928  1.10046895 ... -1.9959649  -0.15522951
   0.0194566 ]
 [-0.20446664  1.83971935 -1.12801196 ... -0.39033464  0.86926508
  -0.10627839]
 ...
 [-1.51452808 -0.33498719 -1.25933421 ...  0.14015171  1.09367071
   0.89632877]
 [-0.48282178 -0.13745226 -0.98768619 ...  1.036101   -0.89502993
  -1.90272699]
 [ 0.41389778 -1.88771972 -0.22749348 ...  1.33191836 -1.03568438
  -0.07959432]] 0.9946742696883755
(256, 256) (256, 100) (256,)
[[-0.96219513 -0.83186026 -0.75466557 ...  0.87585262 -1.01891973
   0.56710471]
 [-0.36801877  0.42025706  0.80192743 ... -0.41959831 -0.24426961
   0.22875349]
 [-0.18342631  0.85949368 -0.66517974 ... -0.91097592  1.81999721
   0.52256986]
 ...
 [ 1.24385707 -1.44568715 -1.66747501 ...  0.69885843 -1.72108708
   1.67605512]
 [ 1.81978349 -0.27929121  0.28159059 ... -0.26088727  1.1964293
   2.09629782]
 [-0.45026274  1.3541502   0.61927164 ... -0.592101

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(

[[-1.06577778  7.8060824   1.07902084 ... -0.60760119 -1.34806109
   0.76465339]
 [ 0.81031436  7.69266471  1.98867367 ... -0.09657931  1.24758458
   0.1527515 ]
 [ 2.42585296  1.88697483  6.14426035 ... -1.67181593  2.19619952
  -1.38975731]
 ...
 [ 2.27264232 -0.62461654  1.95922929 ... 11.28554468  3.35882179
   0.7438451 ]
 [ 1.71607376  5.58205095  2.19639576 ...  7.06542991  0.37299338
   2.78830646]
 [-4.13812906  5.41680769 -1.5061572  ... -0.26314059 -3.16892538
   3.81381298]] 1257.0502336410607
(256, 32) (256, 100) (256,)
[[-0.54522308  1.45711485  0.67879557  3.34801566]
 [ 1.021638   -1.41324322  1.74721104 12.25233769]
 [ 1.02354633 -1.41516362  1.55078251  8.78742295]
 ...
 [ 0.80111366  0.02891896 -2.528302    0.32018204]
 [ 0.248562   -1.12265343 -1.00749067  0.23113896]
 [ 0.7935672   0.01718215  0.68772739  0.83165592]] 101.14226486705974
(256, 4) (256, 100) (256,)
[[ 0.15304286  0.57165746 12.51574006 ... -1.31513043  0.25328397
   0.68849651]
 [ 3.17979485 -0.97367

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[-0.80740034  4.01410421  0.44714506 ... -0.467064   -0.28761232
   0.28956884]
 [-0.70722048  1.61883837  7.90754998 ...  0.81546125  1.1553518
  -0.33580022]
 [ 6.27786639  1.35533746  2.0935221  ... -0.73915276 -0.88620839
   1.39154246]
 ...
 [ 0.35513405  9.92280379 -0.54458574 ...  0.25425381 -1.04138279
   1.77515366]
 [ 3.28679411  3.93648937  1.89396257 ... -1.40042455  0.75173317
   0.66653386]
 [ 0.54138166  3.90243436  0.77094577 ... -0.84530097 -1.71245533
   0.64608492]] 32.45643093692279
(256, 128) (256, 100) (256,)
[[ 6.12175342e-01  1.72573894e+00  4.83746441e+00 ...  1.98765973e-01
   3.16277289e+00  2.97635519e-01]
 [ 2.20229918e+00 -1.58385996e+00  1.47134127e+00 ...  9.79328749e-01
   2.09125490e+00  8.50065666e-01]
 [ 3.53324624e+00  4.89191741e+00 -2.23978537e-01 ...  2.24481012e+00
   1.84481795e+00  1.03626903e-01]
 ...
 [ 1.30646080e+00  3.10375226e+01  3.29045818e+00 ...  1.65747187e+00
   3.68032428e+00  2.77106708e+00]
 [ 4.82573589e-01 -2.14598058e-03 -4.

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[-1.21517263  9.30327951  3.72030831 ...  1.24684448 -0.18305691
   0.20789462]
 [ 3.43008638 20.08840662  9.8957743  ...  0.04790942  0.03219239
  -0.69267813]
 [-0.7300845  -0.86397638  6.77213623 ... -0.6352925   1.7407213
  -0.3742863 ]
 ...
 [-3.03910363  8.62087049 14.03117739 ...  0.61120523  0.31690562
   0.46103685]
 [ 5.30032594  3.25339683  1.0359435  ... -0.82006209  0.04370711
  -0.35602697]
 [ 5.95190209  1.44019385  4.16359167 ...  1.63769184 -0.17922528
   0.45430685]] 2.5840816480324964
(256, 4096) (256, 100) (256,)
[[ 9.98343291e+00  5.42800688e-01 -1.35529468e+00 ... -2.28287831e-01
   4.89016159e-02 -1.02696090e+00]
 [ 5.18973767e+00  1.74608093e+00 -1.57662246e+00 ...  1.36120954e+00
  -3.15918780e-01  9.00939641e-03]
 [ 4.85121078e+00  3.28359255e+00 -7.43536975e-01 ...  1.48054085e+00
  -2.23792382e+00  1.21805042e+00]
 ...
 [ 7.42738993e+00 -5.12864011e-01 -7.14622460e-01 ... -5.10572522e-01
   6.82888095e-03 -9.44972695e-01]
 [ 2.29838874e+00  2.19736074e+00  

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[ 1.48399987 -0.22824062 -1.18969606 ...  0.40562913  1.1793206
   0.05539156]
 [ 2.32019202  2.82974354 -1.28525093 ...  0.3552684  -0.15406672
   0.69894721]
 [ 8.5414744   0.37212837 -2.23570603 ...  1.04772908 -0.07246573
   0.01660459]
 ...
 [ 4.55831058  0.25336082  0.73367745 ...  0.78695739  0.18880076
  -0.14215645]
 [ 6.36725965  1.88290388  1.55971749 ... -0.22585068 -0.62725468
  -1.33200054]
 [ 3.73803369  1.34296015 -1.71151044 ...  0.85381439 -0.08895704
  -1.34431626]] 191.79083492496906
(256, 4096) (256, 100) (256,)
[[-0.5265175   5.27404853 12.21442486 ...  0.07146164 -0.20283102
   1.65652212]
 [-2.08822507  4.06139753  2.63455195 ...  1.12363695 -0.10132579
  -1.86913815]
 [-1.7481132   1.98914002  0.13383669 ... -0.39242893  0.24273869
  -0.86946171]
 ...
 [-1.11026688  1.45943013 -2.65518588 ... -0.97686283  0.43865475
  -0.98690851]
 [-0.25987707  1.45456372  2.12035561 ...  0.35890434  1.87774515
  -1.86998201]
 [ 2.97793704  1.07662102  4.43469327 ...  0.78617

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(

[[-0.57658078  2.37680161  6.46343285 ... -1.13060874 -0.19398408
   0.69475093]
 [ 1.23573256  0.79587776  2.38024089 ... -0.66788912 -0.45685497
   0.69498575]
 [ 4.00286075  0.89385997  5.66797926 ... -0.26358541 -0.57836736
   0.97867653]
 ...
 [ 7.20996693  1.85759656  5.70022145 ...  0.6292048   1.72202494
   0.415938  ]
 [ 5.39102202 11.04578786 11.75157781 ... -1.02891473 -0.8789437
   1.26430144]
 [16.1898468   4.97669923  4.61444066 ...  0.70491886  1.46190582
  -1.30516221]] 2.176371764814375
(256, 2048) (256, 100) (256,)
[[ 3.37264856e+01 -2.48871292e-01  2.22513525e+01 ...  1.64623643e+00
  -6.80562399e-01  6.00115646e-01]
 [ 6.80224369e+00  9.64618492e-02  2.04453297e+00 ... -2.66737502e-01
   1.07422567e+00 -4.98095486e-01]
 [ 8.82937276e-03 -2.43171577e-01 -7.51841623e-01 ... -8.19304121e-02
   5.11893708e-01  2.09027671e-01]
 ...
 [ 6.62758992e-01 -1.24631720e-01 -1.35805372e+00 ... -1.19723164e+00
   1.12069716e+00  3.99949089e-01]
 [ 1.36024165e+00  1.10205977e+00  7

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[ 5.93228277 -0.86680079 -1.56533532 ... -0.94952226  0.98743292
  -0.05574585]
 [ 1.07284399  0.08858542  0.53855492 ...  0.81305598 -1.1648704
   0.71911403]
 [ 6.05072613 -0.53316204  3.81581371 ... -0.1984517   0.95129593
   0.24623949]
 ...
 [ 1.75987782 35.96791519 -1.12545196 ... -0.73057127 -0.25452416
  -2.56062225]
 [ 0.59186946 -0.49570671  0.27964263 ...  0.58460698 -0.20188173
  -0.76722711]
 [ 6.71783438  1.75129263 -0.27172488 ...  1.02577082  0.56411765
   0.51787056]] 5.513229127918654
(256, 1024) (256, 100) (256,)
[[ 0.13416201 -1.64431317 -3.72111607 ...  0.42627778  1.08036889
  -1.74127941]
 [ 1.15816433 -2.88583968 -2.43705099 ... -1.58294663 -0.18838031
   0.62256296]
 [-0.70550851 -4.68326438 -3.9829565  ... -0.93325269  0.09613827
  -0.92601571]
 ...
 [ 2.55067639  0.80261964  2.14613929 ...  0.46423553  0.09554873
   0.33614302]
 [ 0.92586998  7.31901415 -1.03728184 ...  0.25858072 -0.74226135
   1.47076544]
 [ 0.40414353 -3.35209602  0.32727658 ... -1.260709

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[11.03222587  1.08443111  3.22116306 ... -1.15472494  1.2597136
  -0.37581938]
 [ 9.82213502 -0.40817138 10.0035709  ...  2.66258932  1.05332535
   0.36863979]
 [ 1.30739605  7.11691513  6.50943786 ...  0.98548562 -2.1326147
  -1.60452849]
 ...
 [-1.69509481  2.9445991   4.28546251 ... -1.28784149  0.13771963
  -0.93729537]
 [ 6.64162762 -0.47590889 -1.35918099 ... -1.88064146 -0.90150912
   0.45037179]
 [-0.4829292   2.19178569  3.85383284 ...  0.09095781  2.36026788
   1.80348439]] 3.1895421333779934
(256, 1024) (256, 100) (256,)
[[ 2.26374680e+00  4.19944485e-01 -6.94177800e-01 ...  1.14605147e+00
   6.04795059e-01  3.11519494e-01]
 [ 1.11529879e+00 -5.28197436e-01 -9.89274746e-01 ... -2.07395800e-02
   7.29037939e-01  4.70750458e-01]
 [ 1.43278844e+00  1.74504655e+00 -2.57757967e+00 ... -7.44686300e-01
  -1.51322438e-01  9.02136661e-01]
 ...
 [ 9.90627989e+00  1.90119231e+00 -7.44637614e-03 ... -9.91091481e-01
  -4.56291884e-01  5.35422545e-01]
 [ 3.81825980e-01 -1.36647383e-01 -1

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[ 0.9572535  -1.97010373 16.15516069  2.85559876]
 [ 4.50403159  0.05143647  2.49659599 -0.42476834]
 [ 5.74865684  7.11151776  1.64863685  3.7951555 ]
 ...
 [ 0.72761034 15.20830171 -0.55556931  0.58821168]
 [ 2.42968419 11.38195271  0.44427591 -0.03443369]
 [ 2.00106788 -1.39331872 -0.76360827 33.32114728]] 292.5561349018079
(256, 4) (256, 100) (256,)
[[ 2.72752608 19.44020641  8.36843115 ...  0.73575521 -0.70452644
   0.32448982]
 [ 1.51716076  0.82386742  0.20256288 ...  1.56879881 -0.23811311
  -1.88388859]
 [ 4.04450659  7.14444946  7.4634722  ...  0.55355044 -0.74598946
   1.65161927]
 ...
 [ 4.95209254  4.23784312 -1.54687317 ...  1.43613027  2.7202545
   1.78568224]
 [ 0.62068716  0.80781737 -0.51172433 ...  1.88045327  0.69368872
   0.24207917]
 [53.25621949  0.69155029  1.69500501 ... -0.42025194 -0.09982318
   1.1820762 ]] 155.84506389966424
(256, 1024) (256, 100) (256,)
[[-7.40554934e-01  1.14079260e-01  6.80900881e+00 ...  1.11726848e+00
  -5.67095647e-01 -2.28472173e-01

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[ 4.09658216e+00  1.97665422e+00 -1.89203775e+00 ...  2.07086987e+00
   1.84560040e-02  1.22426693e+00]
 [ 1.44969720e+00  2.77627757e+00  5.59820828e-02 ... -4.09554317e-01
  -8.14204585e-01 -2.81336224e-01]
 [-1.79040121e+00 -4.18858794e+00 -3.46247809e+00 ...  7.36739945e-01
   1.03408223e+00  2.52853004e-01]
 ...
 [ 2.35649463e+00  1.14769358e+01 -1.95030369e+00 ...  1.12071823e+00
   1.02428323e+00 -3.81470583e-02]
 [ 5.77384077e+00  6.89306581e+00  2.10015315e+01 ...  2.75172557e+00
   2.76478304e+00  7.57807661e-01]
 [ 5.83866224e+00  2.36319188e+00  2.42907388e+00 ...  6.23062455e-01
   6.35680648e-01 -7.16377800e-01]] 32489.82434910617
(256, 256) (256, 100) (256,)
[[ 2.28879975e+00  2.08607317e+00  1.53624784e+01 ... -2.89045149e+00
  -3.37536226e-01  5.57971849e-01]
 [-1.31463806e+00  1.79642927e+00  1.22182668e+01 ... -3.20157501e-02
   1.31572030e+00  2.47396355e+00]
 [ 5.88021973e+00  2.22479192e-01  1.20088117e+00 ...  8.05056435e-01
   3.01211835e+00 -1.67878705e+00]
 .

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)
  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


[[ 2.57350422  2.05502221 -2.64425397 ...  0.55023187 -0.32916926
   2.10228439]
 [19.75308172 13.55959101  2.5953823  ...  0.2082325   1.16196352
   1.60736469]
 [ 0.14883658  1.58542642 -0.03053576 ...  1.39483345 -0.98346761
   0.1539194 ]
 ...
 [ 1.55963853 -1.22376299 -1.65276581 ...  0.55546191 -0.23632313
   0.24065185]
 [-0.10151901  0.31093006  1.54943578 ... -1.58056168 -0.44585024
  -0.28332815]
 [ 2.46641299  0.2243282   1.36941829 ...  0.28241915  0.41786479
  -0.43984693]] 13.422970025565496
(256, 256) (256, 100) (256,)
[[13.49599847  0.05832109  1.43810677 ...  1.81570996 -2.27528409
   1.83237833]
 [13.79467694  2.68124973  2.0040745  ... -1.42712618  0.13469053
  -1.08397271]
 [ 3.44892348 -0.82320511  4.0590927  ...  0.12231475  0.35949197
   0.45211147]
 ...
 [ 1.67571918  2.76052305  2.33521542 ... -1.00200829  1.17928359
  -0.60026387]
 [ 1.22870438 -0.9984692   2.99747423 ... -2.25784101 -0.69198212
  -0.69038547]
 [ 3.54159832 -1.37887435 23.3407398  ...  0.60988

  denom = np.power(2 * np.pi, d / 2) * np.power(bandwidth_.prod(), 0.5)


KeyboardInterrupt: 

## Collider Results

In [None]:
_results_collider = Parallel(n_jobs=-1)(
    delayed(_run_parallel_posterior_sim)(
        idx_,
        n_samples,
        n_features,
        class_probs,
        seed,
        n_features_2_,
        test_size,
        max_fpr,
        "collider",
    )
    for (idx_, n_features_2_) in product(range(n_repeats), n_features_2_list)
)

## Direct-Indirect Results

In [None]:
print(n_features_2_list)

In [None]:
_results_directindirect = Parallel(n_jobs=-1)(
    delayed(_run_parallel_posterior_sim)(
        idx_,
        n_samples,
        n_features,
        class_probs,
        seed,
        n_features_2_,
        test_size,
        max_fpr,
        "direct-indirect",
    )
    for (idx_, n_features_2_) in product(range(n_repeats), n_features_2_list)
)