In [1]:
import sys, os
import multiprocessing as mp
from joblib import Parallel, delayed

import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

from power_ksamp import power_ksamp_dimension
from hyppo.independence import CCA, Dcorr, HHG, Hsic, RV, MGC, KMERF
from hyppo.tools import *

sys.path.append(os.path.realpath('..'))

In [2]:
# Global plotting theme for every figure in this notebook.
# (seaborn is already imported in the first cell; this import is a harmless repeat.)
import seaborn as sns
sns.set(color_codes=True, style='white', context='talk', font_scale=1.5)
# Use the "Set1" palette minus entries 0 and 5 as the default colour cycle.
# NOTE(review): entry 0 is presumably the red "#e41a1c" that plot_power assigns
# explicitly to MGC/KMERF, so it is kept out of the automatic cycle - confirm.
PALETTE = sns.color_palette("Set1")
sns.set_palette(PALETTE[1:5] + PALETTE[6:])

These are some constants that are used in this notebook. If running this notebook, please only manipulate these constants if you are not running more tests. They define the dimensions tested and the number of replications. The simulations and the independence tests that are evaluated are also defined.

In [3]:
from rpy2.robjects import Formula, numpy2ri
from rpy2.robjects.packages import importr


class Manova:
    r"""
    Minimal Python front end to R's ``stats::manova`` via ``rpy2``.

    The instance keeps handles to the R ``stats`` and ``base`` packages and a
    single formula ``X ~ Y`` whose environment is refilled on every call to
    :meth:`statistic`.
    """
    def __init__(self):
        # Handles to the R packages that are called in ``statistic``.
        self.stats = importr('stats')
        self.r_base = importr('base')

        # Turn on rpy2's numpy conversion (global switch in rpy2).
        numpy2ri.activate()

        model = Formula('X ~ Y')
        self.formula = model
        # Variables named in the formula are looked up in this environment.
        self.env = model.environment

    def statistic(self, x, y):
        r"""
        Fit ``X ~ Y`` with R's ``manova`` and return one entry of its summary.

        ``x`` and ``y`` are stored in the formula environment as ``X`` and
        ``Y``; the fit is summarised with ``test="Pillai"`` and ``tol=0``.
        The returned value is element ``[0, 1]`` of the fourth component of
        that summary.
        NOTE(review): presumably this is the Pillai trace for the ``Y`` term -
        confirm against the R ``summary.manova`` documentation.
        """
        env = self.env
        env['X'] = x
        env['Y'] = y

        fit = self.stats.manova(self.formula)
        report = self.r_base.summary(fit, test="Pillai", tol=0)

        return report[3][0, 1]

In [17]:
# Largest dimension evaluated and the spacing between consecutive dimensions.
MAX_DIMENSION = 10
STEP_SIZE = 1
# Dimensions 1, 1 + STEP_SIZE, ..., up to and including MAX_DIMENSION.
DIMENSIONS = range(1, MAX_DIMENSION + STEP_SIZE, STEP_SIZE)
# Number of repeated power estimates that estimate_power averages per dimension.
POWER_REPS = 5

In [5]:
# Simulation identifier -> display title.
# The key is passed as `sim` to estimate_power and is part of each csv file name;
# the value is used as the subplot title in plot_power.
# plot_power lays these out on a 4x5 grid, so it expects exactly 20 entries.
SIMULATIONS = {
    "linear": "Linear",
    "exponential": "Exponential",
    "cubic": "Cubic",
    "joint_normal": "Joint Normal",
    "step": "Step",
    "quadratic": "Quadratic",
    "w_shaped": "W-Shaped",
    "spiral": "Spiral",
    "uncorrelated_bernoulli": "Bernoulli",
    "logarithmic": "Logarithmic",
    "fourth_root": "Fourth Root",
    "sin_four_pi": "Sine 4\u03C0",
    "sin_sixteen_pi": "Sine 16\u03C0",
    "square": "Square",
    "two_parabolas": "Two Parabolas",
    "circle": "Circle",
    "ellipse": "Ellipse",
    "diamond": "Diamond",
    "multiplicative_noise": "Multiplicative",
    "multimodal_independence": "Independence"
}

# Test classes to compare. Each class's __name__ is used in the csv file name
# and as the legend label in plot_power.
TESTS = [
    KMERF,
    MGC,
    Dcorr,
    Hsic,
    Manova,
    HHG,
    CCA,
    RV,
]

The following function calculates the estimated power ``POWER_REPS`` number of times and averages them. It does this while iterating over the dimensions in ``DIMENSIONS``.

**Note: We recommend running the next 2 cells only ONCE, to generate the csv files used to visualize the plots. This code takes a very long time to run, and if you do run it, we recommend using a machine with many cores.**

In [6]:
def estimate_power(sim, test):
    """
    Estimate power for one (simulation, test) pair at every dimension and save it.

    For each entry of ``DIMENSIONS`` the function calls
    ``power_ksamp_dimension(test, rot_ksamp, sim, p=dim)`` ``POWER_REPS`` times
    and averages the results. The averaged values are written to
    ``../ksample/ksamp_vs_dimension/<sim>_<test.__name__>.csv``.

    Parameters
    ----------
    sim : str
        Simulation identifier; forwarded to ``power_ksamp_dimension`` and used
        in the output file name.
    test : class
        Test class; forwarded to ``power_ksamp_dimension``, and its ``__name__``
        is used in the output file name.

    Returns
    -------
    est_power : numpy.ndarray
        One averaged value per entry of ``DIMENSIONS``, in the same order.
    """
    out_dir = '../ksample/ksamp_vs_dimension'
    # Create the output directory before the (very slow) computation so that a
    # missing folder cannot discard the results at save time. exist_ok=True
    # makes this safe when several parallel workers reach it together.
    os.makedirs(out_dir, exist_ok=True)

    est_power = np.array([
        np.mean([power_ksamp_dimension(test, rot_ksamp, sim, p=dim)
                 for _ in range(POWER_REPS)])
        for dim in DIMENSIONS
    ])
    np.savetxt('{}/{}_{}.csv'.format(out_dir, sim, test.__name__),
               est_power, delimiter=',')

    return est_power

In [7]:
# One task per (simulation, test) pair, spread over all available cores.
# Each task also writes its own csv file (see estimate_power).
outputs = Parallel(n_jobs=-1, verbose=100)(
    [
        delayed(estimate_power)(sim_name, test_cls)
        for sim_name in SIMULATIONS
        for test_cls in TESTS
    ]
)

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 16 concurrent workers.


Unable to determine R library path: Command '('/Library/Frameworks/R.framework/Resources/bin/Rscript', '-e', 'cat(Sys.getenv("LD_LIBRARY_PATH"))')' returned non-zero exit status 2.
Unable to determine R library path: Command '('/Library/Frameworks/R.framework/Resources/bin/Rscript', '-e', 'cat(Sys.getenv("LD_LIBRARY_PATH"))')' returned non-zero exit status 2.


[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:  1.2min
[Parallel(n_jobs=-1)]: Done   2 tasks      | elapsed:  1.5min


Unable to determine R library path: Command '('/Library/Frameworks/R.framework/Resources/bin/Rscript', '-e', 'cat(Sys.getenv("LD_LIBRARY_PATH"))')' returned non-zero exit status 2.


[Parallel(n_jobs=-1)]: Done   3 tasks      | elapsed:  2.8min
[Parallel(n_jobs=-1)]: Done   4 tasks      | elapsed:  3.0min
[Parallel(n_jobs=-1)]: Done   5 tasks      | elapsed:  4.7min
[Parallel(n_jobs=-1)]: Done   6 tasks      | elapsed:  6.0min
[Parallel(n_jobs=-1)]: Done   7 tasks      | elapsed: 11.3min
[Parallel(n_jobs=-1)]: Done   8 tasks      | elapsed: 13.0min
[Parallel(n_jobs=-1)]: Done   9 tasks      | elapsed: 13.1min


R[write to console]: 

Unable to determine R library path: Command '('/Library/Frameworks/R.framework/Resources/bin/Rscript', '-e', 'cat(Sys.getenv("LD_LIBRARY_PATH"))')' returned non-zero exit status 2.


[Parallel(n_jobs=-1)]: Done  10 tasks      | elapsed: 15.3min
[Parallel(n_jobs=-1)]: Done  11 tasks      | elapsed: 15.4min
[Parallel(n_jobs=-1)]: Done  12 tasks      | elapsed: 16.8min
[Parallel(n_jobs=-1)]: Done  13 tasks      | elapsed: 20.5min
[Parallel(n_jobs=-1)]: Done  14 tasks      | elapsed: 21.6min
[Parallel(n_jobs=-1)]: Done  15 tasks      | elapsed: 25.2min
[Parallel(n_jobs=-1)]: Done  16 tasks      | elapsed: 26.5min


Unable to determine R library path: Command '('/Library/Frameworks/R.framework/Resources/bin/Rscript', '-e', 'cat(Sys.getenv("LD_LIBRARY_PATH"))')' returned non-zero exit status 2.


[Parallel(n_jobs=-1)]: Done  17 tasks      | elapsed: 27.2min
[Parallel(n_jobs=-1)]: Done  18 tasks      | elapsed: 29.0min
[Parallel(n_jobs=-1)]: Done  19 tasks      | elapsed: 29.3min
[Parallel(n_jobs=-1)]: Done  20 tasks      | elapsed: 32.4min
[Parallel(n_jobs=-1)]: Done  21 tasks      | elapsed: 32.4min
[Parallel(n_jobs=-1)]: Done  22 tasks      | elapsed: 35.8min
[Parallel(n_jobs=-1)]: Done  23 tasks      | elapsed: 39.2min
[Parallel(n_jobs=-1)]: Done  24 tasks      | elapsed: 40.0min
[Parallel(n_jobs=-1)]: Done  25 tasks      | elapsed: 45.6min
[Parallel(n_jobs=-1)]: Done  26 tasks      | elapsed: 48.4min
[Parallel(n_jobs=-1)]: Done  27 tasks      | elapsed: 48.7min
[Parallel(n_jobs=-1)]: Done  28 tasks      | elapsed: 51.2min


R[write to console]: 



[Parallel(n_jobs=-1)]: Done  29 tasks      | elapsed: 51.3min
[Parallel(n_jobs=-1)]: Done  30 tasks      | elapsed: 58.0min


Unable to determine R library path: Command '('/Library/Frameworks/R.framework/Resources/bin/Rscript', '-e', 'cat(Sys.getenv("LD_LIBRARY_PATH"))')' returned non-zero exit status 2.


[Parallel(n_jobs=-1)]: Done  31 tasks      | elapsed: 63.7min
[Parallel(n_jobs=-1)]: Done  32 tasks      | elapsed: 69.8min
[Parallel(n_jobs=-1)]: Done  33 tasks      | elapsed: 72.8min
[Parallel(n_jobs=-1)]: Done  34 tasks      | elapsed: 73.0min
[Parallel(n_jobs=-1)]: Done  35 tasks      | elapsed: 75.6min
[Parallel(n_jobs=-1)]: Done  36 tasks      | elapsed: 85.5min
[Parallel(n_jobs=-1)]: Done  37 tasks      | elapsed: 104.9min


Unable to determine R library path: Command '('/Library/Frameworks/R.framework/Resources/bin/Rscript', '-e', 'cat(Sys.getenv("LD_LIBRARY_PATH"))')' returned non-zero exit status 2.


[Parallel(n_jobs=-1)]: Done  38 tasks      | elapsed: 106.2min
[Parallel(n_jobs=-1)]: Done  39 tasks      | elapsed: 106.6min
[Parallel(n_jobs=-1)]: Done  40 tasks      | elapsed: 108.0min
[Parallel(n_jobs=-1)]: Done  41 tasks      | elapsed: 108.3min
[Parallel(n_jobs=-1)]: Done  42 tasks      | elapsed: 109.4min
[Parallel(n_jobs=-1)]: Done  43 tasks      | elapsed: 114.5min


R[write to console]: 



[Parallel(n_jobs=-1)]: Done  44 tasks      | elapsed: 114.8min


Unable to determine R library path: Command '('/Library/Frameworks/R.framework/Resources/bin/Rscript', '-e', 'cat(Sys.getenv("LD_LIBRARY_PATH"))')' returned non-zero exit status 2.


[Parallel(n_jobs=-1)]: Done  45 tasks      | elapsed: 119.5min
[Parallel(n_jobs=-1)]: Done  46 tasks      | elapsed: 124.7min


R[write to console]: 



[Parallel(n_jobs=-1)]: Done  47 tasks      | elapsed: 126.1min
[Parallel(n_jobs=-1)]: Done  48 tasks      | elapsed: 126.6min
[Parallel(n_jobs=-1)]: Done  49 tasks      | elapsed: 127.7min
[Parallel(n_jobs=-1)]: Done  50 tasks      | elapsed: 133.7min
[Parallel(n_jobs=-1)]: Done  51 tasks      | elapsed: 137.7min
[Parallel(n_jobs=-1)]: Done  52 tasks      | elapsed: 140.5min
[Parallel(n_jobs=-1)]: Done  53 tasks      | elapsed: 150.8min
[Parallel(n_jobs=-1)]: Done  54 tasks      | elapsed: 152.8min
[Parallel(n_jobs=-1)]: Done  55 tasks      | elapsed: 155.3min
[Parallel(n_jobs=-1)]: Done  56 tasks      | elapsed: 155.4min
[Parallel(n_jobs=-1)]: Done  57 tasks      | elapsed: 157.1min
[Parallel(n_jobs=-1)]: Done  58 tasks      | elapsed: 165.1min
[Parallel(n_jobs=-1)]: Done  59 tasks      | elapsed: 174.0min
[Parallel(n_jobs=-1)]: Done  60 tasks      | elapsed: 174.3min
[Parallel(n_jobs=-1)]: Done  61 tasks      | elapsed: 175.7min
[Parallel(n_jobs=-1)]: Done  62 tasks      | elapsed: 1

Unable to determine R library path: Command '('/Library/Frameworks/R.framework/Resources/bin/Rscript', '-e', 'cat(Sys.getenv("LD_LIBRARY_PATH"))')' returned non-zero exit status 2.


[Parallel(n_jobs=-1)]: Done  66 tasks      | elapsed: 211.6min
[Parallel(n_jobs=-1)]: Done  67 tasks      | elapsed: 217.9min
[Parallel(n_jobs=-1)]: Done  68 tasks      | elapsed: 222.6min
[Parallel(n_jobs=-1)]: Done  69 tasks      | elapsed: 226.7min
[Parallel(n_jobs=-1)]: Done  70 tasks      | elapsed: 227.6min
[Parallel(n_jobs=-1)]: Done  71 tasks      | elapsed: 239.5min
[Parallel(n_jobs=-1)]: Done  72 tasks      | elapsed: 252.0min


Unable to determine R library path: Command '('/Library/Frameworks/R.framework/Resources/bin/Rscript', '-e', 'cat(Sys.getenv("LD_LIBRARY_PATH"))')' returned non-zero exit status 2.


[Parallel(n_jobs=-1)]: Done  73 tasks      | elapsed: 263.0min
[Parallel(n_jobs=-1)]: Done  74 tasks      | elapsed: 267.7min
[Parallel(n_jobs=-1)]: Done  75 tasks      | elapsed: 273.2min
[Parallel(n_jobs=-1)]: Done  76 tasks      | elapsed: 277.8min
[Parallel(n_jobs=-1)]: Done  77 tasks      | elapsed: 281.5min
[Parallel(n_jobs=-1)]: Done  78 tasks      | elapsed: 294.2min
[Parallel(n_jobs=-1)]: Done  79 tasks      | elapsed: 316.0min
[Parallel(n_jobs=-1)]: Done  80 tasks      | elapsed: 328.8min
[Parallel(n_jobs=-1)]: Done  81 tasks      | elapsed: 331.8min
[Parallel(n_jobs=-1)]: Done  82 tasks      | elapsed: 336.3min
[Parallel(n_jobs=-1)]: Done  83 tasks      | elapsed: 340.4min
[Parallel(n_jobs=-1)]: Done  84 tasks      | elapsed: 376.9min
[Parallel(n_jobs=-1)]: Done  85 tasks      | elapsed: 388.4min
[Parallel(n_jobs=-1)]: Done  86 tasks      | elapsed: 409.1min
[Parallel(n_jobs=-1)]: Done  87 tasks      | elapsed: 421.0min
[Parallel(n_jobs=-1)]: Done  88 tasks      | elapsed: 4

Unable to determine R library path: Command '('/Library/Frameworks/R.framework/Resources/bin/Rscript', '-e', 'cat(Sys.getenv("LD_LIBRARY_PATH"))')' returned non-zero exit status 2.


[Parallel(n_jobs=-1)]: Done 101 tasks      | elapsed: 772.7min
[Parallel(n_jobs=-1)]: Done 102 tasks      | elapsed: 777.0min
[Parallel(n_jobs=-1)]: Done 103 tasks      | elapsed: 780.3min
[Parallel(n_jobs=-1)]: Done 104 tasks      | elapsed: 781.8min
[Parallel(n_jobs=-1)]: Done 105 tasks      | elapsed: 783.2min


R[write to console]: 



[Parallel(n_jobs=-1)]: Done 106 tasks      | elapsed: 784.6min
[Parallel(n_jobs=-1)]: Done 107 tasks      | elapsed: 796.2min


Unable to determine R library path: Command '('/Library/Frameworks/R.framework/Resources/bin/Rscript', '-e', 'cat(Sys.getenv("LD_LIBRARY_PATH"))')' returned non-zero exit status 2.


[Parallel(n_jobs=-1)]: Done 108 tasks      | elapsed: 796.7min
[Parallel(n_jobs=-1)]: Done 109 tasks      | elapsed: 807.0min
[Parallel(n_jobs=-1)]: Done 111 out of 140 | elapsed: 811.7min remaining: 212.1min
[Parallel(n_jobs=-1)]: Done 113 out of 140 | elapsed: 813.6min remaining: 194.4min


Unable to determine R library path: Command '('/Library/Frameworks/R.framework/Resources/bin/Rscript', '-e', 'cat(Sys.getenv("LD_LIBRARY_PATH"))')' returned non-zero exit status 2.


[Parallel(n_jobs=-1)]: Done 115 out of 140 | elapsed: 835.8min remaining: 181.7min
[Parallel(n_jobs=-1)]: Done 117 out of 140 | elapsed: 844.4min remaining: 166.0min
[Parallel(n_jobs=-1)]: Done 119 out of 140 | elapsed: 856.7min remaining: 151.2min
[Parallel(n_jobs=-1)]: Done 121 out of 140 | elapsed: 867.2min remaining: 136.2min


Unable to determine R library path: Command '('/Library/Frameworks/R.framework/Resources/bin/Rscript', '-e', 'cat(Sys.getenv("LD_LIBRARY_PATH"))')' returned non-zero exit status 2.


[Parallel(n_jobs=-1)]: Done 123 out of 140 | elapsed: 877.2min remaining: 121.2min


R[write to console]: 



[Parallel(n_jobs=-1)]: Done 125 out of 140 | elapsed: 879.9min remaining: 105.6min
[Parallel(n_jobs=-1)]: Done 127 out of 140 | elapsed: 880.4min remaining: 90.1min
[Parallel(n_jobs=-1)]: Done 129 out of 140 | elapsed: 901.6min remaining: 76.9min
[Parallel(n_jobs=-1)]: Done 131 out of 140 | elapsed: 923.5min remaining: 63.4min


R[write to console]: 

R[write to console]: 

R[write to console]: 



[Parallel(n_jobs=-1)]: Done 133 out of 140 | elapsed: 959.2min remaining: 50.5min


R[write to console]: 



[Parallel(n_jobs=-1)]: Done 135 out of 140 | elapsed: 1029.5min remaining: 38.1min


R[write to console]: 

R[write to console]: 



[Parallel(n_jobs=-1)]: Done 137 out of 140 | elapsed: 1167.9min remaining: 25.6min


R[write to console]: 

R[write to console]: 



[Parallel(n_jobs=-1)]: Done 140 out of 140 | elapsed: 1204.4min finished


The following code loops over each saved independence test file and generates absolute power curves for each test and for each simulation modality.

In [24]:
def plot_power():
    fig, ax = plt.subplots(nrows=4, ncols=5, figsize=(25,20))
    
    plt.suptitle("Multivariate Three-Sample Testing Increasing Dimension", y=0.93, va='baseline')
    
    for i, row in enumerate(ax):
        for j, col in enumerate(row):
            count = 5*i + j
            sim = list(SIMULATIONS.keys())[count]
            
            for test in TESTS:
                test_name = test.__name__
                power = np.genfromtxt('../ksample/ksamp_vs_dimension/{}_{}.csv'.format(sim, test_name), delimiter=',')
#                 manova_power = np.genfromtxt('../ksample/ksamp_vs_dimension/{}_Manova.csv'.format(sim), delimiter=',')
                
                if test_name == "MGC":
                    col.plot(DIMENSIONS[1:], power[1:], color="#e41a1c", label=test_name, lw=4)
                elif test_name == "KMERF":
                    col.plot(DIMENSIONS[1:], power[1:], color="#e41a1c", label=test_name, lw=4, linestyle='dashed')
                elif test_name == "Manova":
                    col.plot(DIMENSIONS[1:], power[1:], color="#000000", label=test_name, lw=4)
                else:
                    col.plot(DIMENSIONS[1:], power[1:], label=test_name, lw=2)
                col.set_xticks([])
                if i == 3:
                    col.set_xticks([DIMENSIONS[1], DIMENSIONS[-1]])
                col.set_ylim(-0.05, 1.05)
                col.set_yticks([])
                if j == 0:
                    col.set_yticks([0, 1])
                col.set_title(SIMULATIONS[sim])
    
    fig.text(0.5, 0.08, 'Dimension', ha='center')
    fig.text(0.08, 0.5, 'Statistical Power Relative to Manova', va='center', rotation='vertical')
    leg = plt.legend(bbox_to_anchor=(0.5, 0.07), bbox_transform=plt.gcf().transFigure,
                     ncol=len(TESTS), loc='upper center')
    leg.get_frame().set_linewidth(0.0)
    for legobj in leg.legendHandles:
        legobj.set_linewidth(5.0)
    plt.subplots_adjust(hspace=.50)
    plt.savefig('../ksample/figs/ksamp_power_dimension.pdf', transparent=True, bbox_inches='tight')

In [25]:
# Draw the grid of power curves from the saved csv files and write the PDF.
plot_power()