In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import timeit
import sys, os
sys.path.append(os.path.realpath('..'))
from scipy.interpolate import interp1d

from hyppo.independence import *
from hyppo.ksample import KSample
from hyppo.sims import linear

In [None]:
# Global seaborn theme: 'white' style, 'talk' context, fonts scaled by 1.5.
sns.set(context='talk', style='white', font_scale=1.5, color_codes=True)

# Default color cycle: the "Set1" palette with its entries at index 0 and 5
# left out (the remaining colors are cycled to fill n_colors=9).
PALETTE = sns.color_palette("Set1")
sns.set_palette([*PALETTE[1:5], *PALETTE[6:]], n_colors=9)

In [None]:
# Sample sizes at which every test is timed; also used as the x-values of the
# wall-time plot.
N = [50, 100, 200, 500, 1000, 2000, 5000, 10000]

# Test classes to benchmark, grouped by test-type label. The label is used as
# the prefix of the output CSV file name and decides how each test is built
# ("ksample" tests are wrapped in KSample, "fast" tests are run with auto=True).
TESTS = dict(
    indep=[Dcorr, MGC, HHG],
    ksample=[Hsic],
    fast=[Dcorr],
)

In [None]:
def estimate_wall_times(tests, *, test_type=None, **kwargs):
    """Time each test at every sample size in ``N`` and save the results as CSV.

    For every test class in ``tests`` and every ``n`` in ``N``, data is drawn
    with ``linear(n, 1, noise=True)`` and ``hyp_test.test(x, y, workers=-1,
    **kwargs)`` is timed with ``timeit`` (3 repeats of a single call each).
    The fastest repeat is kept.  One file per test is written to
    ``../benchmarks/perf/<test_type>_<test.__name__>.csv``.

    NOTE(review): ``plot_wall_times`` in this notebook reads its timings from
    ``../hyppo/perf/<file_name>.csv`` with keys such as ``"MGC"`` (no
    ``indep_`` prefix) -- confirm the files written here are the ones the plot
    is expected to load.

    Parameters
    ----------
    tests : iterable of classes
        Test classes to benchmark.  When ``test_type == "ksample"`` the class
        name is handed to ``KSample``; otherwise the class is instantiated
        with no arguments.
    test_type : str, optional (keyword-only)
        Label of the test group.  If omitted, the notebook-level global
        ``test_type`` is used, which is how the existing driver loop supplies
        it.
    **kwargs
        Extra keyword arguments forwarded to ``hyp_test.test``.

    Returns
    -------
    list
        Timings (seconds, one per entry of ``N``) of the *last* test in
        ``tests``; an empty list if ``tests`` is empty.

    Raises
    ------
    NameError
        If ``test_type`` is neither passed nor defined as a global.
    """
    if test_type is None:
        # Backward compatibility: older call sites rely on a global loop
        # variable named `test_type` instead of passing it explicitly.
        try:
            test_type = globals()["test_type"]
        except KeyError:
            raise NameError(
                "test_type was not passed and no global 'test_type' is defined"
            ) from None

    # Defined up front so that an empty `tests` returns [] instead of raising
    # UnboundLocalError at the return statement.
    times = []
    for test in tests:
        out_path = '../benchmarks/perf/{}_{}.csv'.format(test_type, test.__name__)
        # Create the output directory before the (slow) benchmark so a missing
        # directory cannot discard the timings at save time.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

        times = []
        for n in N:
            x, y = linear(n, 1, noise=True)
            if test_type == "ksample":
                hyp_test = KSample(test.__name__)
            else:
                hyp_test = test()
            # Plain-Python equivalent of `%timeit -n 1 -r 3`: three repeats of
            # one call each, so every entry is the duration of a single call.
            runs = timeit.repeat(
                lambda: hyp_test.test(x, y, workers=-1, **kwargs),
                repeat=3,
                number=1,
            )
            times.append(np.min(runs))
        np.savetxt(out_path, times, delimiter=',')
    return times

In [None]:
# Benchmark every group of tests listed in TESTS.
# NOTE: the loop variable must stay a notebook-level name called `test_type`;
# estimate_wall_times looks it up as a global when it is not passed explicitly.
kwargs = {}
for test_type in TESTS:
    # Rebuild the keyword arguments for each group so that the "fast" option
    # cannot leak into groups that happen to be iterated after it.
    kwargs = {"auto": True} if test_type == "fast" else {}
    estimate_wall_times(TESTS[test_type], **kwargs)

In [None]:
FONTSIZE = 30

# Legend label and line color for each timing file, keyed by the CSV base name
# that the plotting code loads. Rows are (file name, legend label, hex color);
# entries that share a color differ only in the implementation named in the
# label's parentheses.
TEST_METADATA = {
    file_name: {"test_name": label, "color": color}
    for file_name, label, color in (
        ("MGC", "MGC (hyppo)", "#e41a1c"),
        ("HHG", "HHG (hyppo)", "#4daf4a"),
        ("Dcorr", "Dcorr (hyppo)", "#377eb8"),
        ("ksample_Hsic", "MMD (hyppo)", "#ff7f00"),
        ("fast_Dcorr", "Fast Dcorr (hyppo)", "#984ea3"),
        ("HHG_hhg", "HHG (HHG)", "#4daf4a"),
        ("Dcorr_energy", "Dcorr (energy)", "#377eb8"),
        # NOTE(review): file name says Dcorr but the label says MMD -- confirm.
        ("Dcorr_kernlab", "MMD (kernlab)", "#ff7f00"),
    )
}


def plot_wall_times():
    """Plot the saved wall-time benchmarks on log-log axes and save a PDF.

    For every entry of ``TEST_METADATA`` the timings are loaded from
    ``../hyppo/perf/<file_name>.csv`` and drawn against the sample sizes in
    ``N`` (so each file must hold one value per entry of ``N``), using the
    entry's color and legend label.  The finished figure is written to
    ``../hyppo/figs/wall_times.pdf``.  Returns ``None``.

    NOTE(review): ``estimate_wall_times`` in this notebook writes its results
    to ``../benchmarks/perf/<test_type>_<name>.csv`` -- confirm that the files
    read here exist under these names and in this directory.
    """
    # Entries drawn dashed: the ones whose legend label names an
    # implementation other than hyppo.
    dashed_entries = ("HHG_hhg", "Dcorr_energy", "Dcorr_kernlab")

    fig, ax = plt.subplots(figsize=(10, 7))

    for file_name, metadata in TEST_METADATA.items():
        test_times = np.genfromtxt('../hyppo/perf/{}.csv'.format(file_name), delimiter=',')

        # Choose the line style per entry; carrying it over between
        # iterations would make every entry after the first dashed one dashed
        # as well, depending on dictionary order.
        line_kwargs = {"linestyle": "dashed"} if file_name in dashed_entries else {}
        ax.plot(N, test_times, color=metadata["color"], label=metadata["test_name"],
                lw=5, **line_kwargs)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlabel('Number of Samples')
    ax.set_ylabel('Execution Time\n(Seconds)')
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xticks([1e2, 1e3, 1e4])
    ax.set_yticks([1e-4, 1e-2, 1e0, 1e2, 1e4])

    # Legend goes below the axes, positioned in figure coordinates.
    leg = ax.legend(bbox_to_anchor=(0.5, -0.05), bbox_transform=fig.transFigure,
                    ncol=2, loc='upper center')
    leg.get_frame().set_linewidth(0.0)
    # `legendHandles` was renamed to `legend_handles` in newer Matplotlib
    # releases; support both spellings.
    handles = getattr(leg, "legend_handles", None)
    if handles is None:
        handles = leg.legendHandles
    for legobj in handles:
        legobj.set_linewidth(5.0)
    plt.savefig('../hyppo/figs/wall_times.pdf', transparent=True, bbox_inches='tight')

In [None]:
# Draw the wall-time comparison figure and write it to disk as a PDF.
plot_wall_times()