# Compare our estimator to baselines from the literature, for several parameterizations of the two-spike Gaussian model.

In [None]:
import numpy as np
from sampling_utils import getSamples_gaussianTwoSpike
from syntheticExperimentWrappers import drawAndEstimate_GaussianTwoSpike
import matplotlib.pyplot as plt
from time import time
from multiprocessing import Pool, freeze_support
import pickle
import csv
import os

In [None]:
# Size of the multiprocessing pool used for each estimator.
# The amount of parallelism is chosen per baseline:
#   - KS (our estimator) is light on memory but heavy on compute.
#   - KR and DKW are very memory-intensive, because they keep statistics for
#     EACH X_i, whereas we only keep them over a grid.
#   - The MLE is compute-heavy, but its solver already parallelizes some work
#     internally, so it gets fewer workers.
NUM_CORES = {
    "KS": 30,
    "KR": 1,
    "MLE": 10,
    "DKW": 1,
}

In [None]:
# Experiment grid. For each estimator, each sample size n, and each size of
# the alternate effect in the two-spike model, we compute zetaHat at two
# thresholds: gamma = 0 and gamma = (1/2) * gamma_*.

# Axes of the sweep
estimators = ["KS", "KR", "MLE", "DKW"]
meanList = [2, 1.6, 1.2, 0.8, 0.4]  # alternate means; the values also appear in cache file names
thresholdTypes = ["zero", "half"]
nVals = [10**exponent for exponent in (4, 5, 6)]

# Settings shared by every configuration (forwarded to the estimation wrapper)
zeta = 0.1
sigma = 1
alpha = 0.05
tolerance = 1e-4
discretization = 900

# Number of independent draw-and-estimate repetitions per configuration
numSamples = 60

In [None]:
# Create the output tree: a top-level results folder with one sub-folder per
# estimator. exist_ok=True makes this safe to re-run and avoids the
# check-then-create race of testing os.path.exists() first.
folderName = "facetGraphs"
os.makedirs(folderName, exist_ok=True)
for estimator in estimators:
    os.makedirs(os.path.join(folderName, estimator), exist_ok=True)

In [None]:
# Run every (estimator, alternate mean, threshold type, n) configuration.
# Each configuration's results are pickled to its own file, so an interrupted
# run can be restarted and will only compute the configurations still missing.
for estimator in estimators:
    print(estimator)
    # For every estimator, test each value of the alternate mean
    for mu2 in meanList:
        print("   testing mean", mu2)
        # For every value of the alternate mean, estimate zetaHat(0) and zetaHat(1/2 mu2)
        for thresholdType in thresholdTypes:
            if thresholdType == "zero":
                threshold = 0
            elif thresholdType == "half":
                threshold = mu2/2.0
            else:
                # Fail fast: carrying on would reuse a stale (or undefined)
                # threshold and store results under a misleading file name.
                raise ValueError("undefined thresholdType: {!r}".format(thresholdType))
            print("      testing threshold", thresholdType, threshold)
            for n in nVals:
                print("         testing n", n)
                # One cache file per configuration; the plotting code reads the same path.
                cachePath = ("./facetGraphs/"+estimator+"/zetaHats_"+str(discretization)
                             +"_"+str(threshold)+"_"+str(mu2)+"_"+str(n)+".p")
                try:
                    # We may want to start and stop this task; this allows us to pick up where we left off
                    with open(cachePath, 'rb') as cacheFile:
                        zetaHats = pickle.load(cacheFile)
                except (FileNotFoundError, EOFError, pickle.UnpicklingError):
                    # The file is missing (or truncated by an interrupted write): compute it.
                    # Any other error, including KeyboardInterrupt, is allowed to propagate.
                    t = time()
                    # Redraw the values every time: one independent job per repetition
                    jobs = [(threshold, n, zeta, mu2, discretization, sigma, tolerance, estimator, alpha) for _ in range(numSamples)]
                    with Pool(NUM_CORES[estimator]) as p:
                        zetaHats = p.starmap(drawAndEstimate_GaussianTwoSpike, jobs)

                    print('    (Elapsed time: {0:8.1f} minutes)'.format((time()-t)/60))
                    # Close the file explicitly so the results are flushed to disk right away
                    with open(cachePath, 'wb') as cacheFile:
                        pickle.dump(zetaHats, cacheFile)
                    print('zeta hats',zetaHats)

## Plot the comparisons to baselines. Note: you must also run the "ashrBaseline.R" script to get comparisons to the empirical Bayes method Ashr (Stephens, 2017).

In [None]:
def _readAshrColumn(n, threshold, mu2, column):
    """Read one numeric column from the Ashr results CSV of a configuration.

    The file is looked up at
    ./facetGraphs/ashr/zetaHats_ashr_<threshold>_<mu2>_<n>.csv
    (presumably written by the "ashrBaseline.R" script -- confirm).

    Parameters:
        n, threshold, mu2: values identifying the configuration; they are
            only used to build the file name.
        column: zero-based index of the column to extract.

    Returns:
        A list of floats, one per data row. The first row is treated as a
        header and skipped; blank rows are ignored.

    Raises:
        FileNotFoundError: if the CSV file does not exist.
        ValueError: if a value in the column cannot be parsed as a float.
    """
    # A threshold with an integral value appears in the file name without a
    # decimal point (e.g. 1.0 -> "1").
    if int(threshold) == threshold:
        threshold = int(threshold)
    with open("./facetGraphs/ashr/zetaHats_ashr_"+str(threshold)+"_"+str(mu2)+"_"+str(n)+".csv", newline='') as csvfile:
        reader = csv.reader(csvfile, delimiter=',', quotechar='|')
        # Skip the first row; it is the header
        next(reader, None)
        # csv.reader yields [] for blank lines; skip them instead of raising IndexError
        return [float(row[column]) for row in reader if row]


def getAshrLCB(n, threshold, mu2):
    """Return the Ashr "lcb" values (second CSV column) for one configuration as a list of floats."""
    return _readAshrColumn(n, threshold, mu2, 1)


def getAshrPlugin(n, threshold, mu2):
    """Return the Ashr "plugin" values (first CSV column) for one configuration as a list of floats."""
    return _readAshrColumn(n, threshold, mu2, 0)

In [None]:
# Per-estimator plot styling, keyed by estimator identifier.

# Line / marker color
colors = {
    "KS": 'blue',
    "KR": 'orange',
    "MLE": 'red',
    "ashr-lcb": 'green',
    "ashr-plugin": 'purple',
    "DKW": 'magenta',
}

# Legend label (a newline separates the method name from its citation)
names = {
    "KS": 'KS (ours)',
    "KR": 'KR\n[Katsevich & Ramdas 2018]',
    "MLE": 'MLE (plugin)',
    "ashr-lcb": "Ashr (lcb)\n[Stephens 2017]",
    "ashr-plugin": "Ashr (plugin)\n[Stephens 2017]",
    "DKW": "MR-DKW\n[Meinshausen & Rice 2006]",
}

# Matplotlib marker symbol
markers = {
    "KS": 'o',
    "KR": 'v',
    "MLE": 's',
    "ashr-lcb": "^",
    "ashr-plugin": ">",
    "DKW": "<",
}

# Matplotlib defaults for this figure
plt.rcParams.update({
    "font.size": 20,
    "figure.figsize": [6, 3.5],
})

# Estimators to draw (all of them).
# NOTE: this rebinds the `estimators` name that the experiment loop iterates over.
estimators = [
    "KS", "KR", "DKW", "ashr-lcb", "MLE", "ashr-plugin",
]
# To draw only the conservative estimators, use this list instead:
#estimators=["KS", "KR", "ashr-lcb", "DKW"]

# Not referenced by the plotting code visible in this file
errorbar = "line"

# Upper y-axis limit per threshold type.
# For attractive plotting; change these if you'd like
ylimUpper = dict(zero=1, half=0.18)

# 2 x 3 grid of panels with the subplot area stretched to the figure edges
figsize = (14, 7)
fig, ax = plt.subplots(nrows=2, ncols=3, figsize=figsize)
fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0.1, hspace=0.1)

# Draw one panel per (threshold type, n) pair: row i is the threshold type,
# column j is the sample size n. Each panel shows, for every estimator, the
# median zetaHat across trials as a function of the alternate mean, with error
# bars reaching to (approximately) the 5th and 95th percentiles.
for j, n in enumerate(nVals):
    for i, thresholdType in enumerate(["zero", "half"]):
        # Dashed horizontal reference line at zeta
        ax[i,j].plot(meanList,
                     [zeta]*len(meanList),
                     linestyle='--',
                     color='k',
                     label="Truth")
        ax[i,j].set_ylim([0, ylimUpper[thresholdType]])
        for estimator in estimators:
            medianZetaHats = []
            lbZetaHats = []
            ubZetaHats = []

            for mu2 in meanList:
                # Get all trials of zetaHat for this mean in this plot
                if thresholdType == "zero":
                    threshold = 0
                elif thresholdType == "half":
                    threshold = mu2/2.0
                else:
                    # Fail fast rather than plotting data for a stale threshold
                    raise ValueError("Unrecognized threshold")
                if estimator == "ashr-lcb":
                    zetaHats = getAshrLCB(n, threshold, mu2)
                elif estimator == "ashr-plugin":
                    zetaHats = getAshrPlugin(n, threshold, mu2)
                else:
                    # Results cached by the experiment loop; close the file once it is read
                    cachePath = ("./facetGraphs/"+estimator+"/zetaHats_"+str(discretization)
                                 +"_"+str(threshold)+"_"+str(mu2)+"_"+str(n)+".p")
                    with open(cachePath, 'rb') as cacheFile:
                        zetaHats = pickle.load(cacheFile)

                # Get the median and the order statistics bounding the central ~90% of trials
                zetaHats = np.sort(zetaHats)
                medianZetaHats.append(np.median(zetaHats))
                lbZetaHats.append(zetaHats[int(len(zetaHats)*0.05)]) # 5th percentile
                ubZetaHats.append(zetaHats[int(len(zetaHats)*0.95)]) # 95th percentile

            # yerr takes distances from the plotted value: row 0 below, row 1 above
            ax[i,j].errorbar(meanList, medianZetaHats,
                             yerr=np.vstack([np.array(medianZetaHats) - np.array(lbZetaHats),
                                             np.array(ubZetaHats) - np.array(medianZetaHats)]),
                             label=names[estimator],
                             color=colors[estimator],
                             marker=markers[estimator],
                             linewidth=3,
                             ms=10,
                             alpha=0.7)
        if i == 0:
            # Top row, don't show x-axis values
            ax[i,j].set_xticklabels([])
            # But do label n-values
            ax[i,j].set_title(r'$n=10^'+str(int(np.log10(n)))+'$')
        if j != 0:
            # Only show y-values on the left axis
            ax[i,j].set_yticklabels([])
        if j == (len(nVals) - 1):
            # Right label threshold type
            ax[i,j].yaxis.set_label_position("right")
            if thresholdType == "zero":
                ax[i,j].set_ylabel(r'Threshold $\gamma = 0$', rotation=270, va='bottom')
            else:
                ax[i,j].set_ylabel(r'Threshold $\gamma = \frac{1}{2}\gamma_*$', rotation=270, va='bottom')
        if j == 0 and i == 1:
            # The y-axis label is attached to the bottom-left panel only
            ax[i,j].set_ylabel(r'Estimated mass $\widehat\zeta(\gamma)$',
                               ha='left')

# Single shared legend, anchored in figure coordinates near the bottom center.
# plt.legend() takes its entries from the current axes (presumably the last
# panel created -- confirm).
legendOptions = {
    "loc": 'lower center',
    "bbox_to_anchor": (0.5, 0.01),
    "bbox_transform": fig.transFigure,
    "ncol": 7,
    "borderaxespad": -7.5,
    "prop": {'size': 15},
}
plt.legend(**legendOptions)

# Overlay a frameless axes spanning the whole figure, used only to carry the
# shared x-axis label; its ticks and tick labels are hidden.
fig.add_subplot(111, frameon=False)
plt.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
plt.xlabel(r'Alternate mean $\gamma_*$')
#plt.ylabel(r'Estimated mass above the threshold, $\widehat\zeta(\gamma)$',
#          x=0, y=0.5)

# Write the figure in both vector and raster form, then display it
for outputName in ("facetPlot_final.pdf", "facetPlot_final.png"):
    plt.savefig(outputName, bbox_inches='tight')
plt.show()