## Our estimator operates by discretizing the parameter space and the observation space. A finer discretization is more accurate, but takes longer to run. In this notebook, we examine this trade-off in the two-spike Gaussian setting, to inform our choices in our other experiments.

In [None]:
from estimator import KS_test, binarySearch
from utils import construct_A_gaussian
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
from multiprocessing import Pool, freeze_support
from time import time

In [None]:
# For several different values of the discretization, estimate zetaHat(0) and zetaHat(1/10 mu2)
# as well as the time it takes to run.
# Do this across n-values, in case that changes the answer.

# Timing: Running the full set of trials took me ~16 hours, parallelized across a 96-Xeon machine (but only using 20 of them)

numTrials = 20   # independent random draws of the data per (mu2, n, threshold) setting
NUM_CORES = 20   # worker processes used to run the trials in parallel
sigma = 1        # standard deviation of the Gaussian noise
mu2s = [1, 2]    # location of the second spike (the first spike is at 0)
alpha = 0.05     # passed through to binarySearch
zeta = 0.1       # probability that an observation's mean is mu2 rather than 0
tolerance = 0.0001  # passed through to binarySearch
padding = 5      # how far the grid extends beyond [0, mu2] on each side

discretizations = [100, 300, 500, 700, 900, 1100]  # number of grid points to try
nVals = [10**4, 10**6]                             # sample sizes to try

folderName = "discretizationTests"
os.makedirs(folderName, exist_ok=True)

for mu2 in mu2s:
    for n in nVals:
        for threshold in [0, mu2/10.0]:
            # Suffix shared by every file written for this (n, threshold, mu2) setting.
            # The format is kept exactly as before so that existing result files are still found.
            settingTag = str(n)+str(threshold)+str(mu2)

            # Store all the results, so we can plot one line per random draw of the data.
            # Row = discretization, column = trial.
            zetaHatLists = np.zeros((len(discretizations), numTrials))
            times = []

            # Draw the data ONCE, and then try each discretization on that same draw.
            # This gives us a better idea of how discretization affects zetaHat, where the
            # randomness is only between trials, not between discretizations as well.
            # We draw numTrials copies of the data, so we can parallelize over the copies in the next loop.
            observationList = []
            for _ in range(numTrials):
                mu = np.random.choice([0, mu2], size=(n,1), replace=True, p=[1-zeta, zeta])
                noise = np.random.randn(n, 1)*sigma
                observationList.append(mu + noise)

            for (discIndex, disc) in enumerate(discretizations):
                print(n, threshold, disc)
                resultPath = folderName+"/zetaHats_"+settingTag+str(disc)+".p"
                try:
                    # We may want to start and stop this task; this allows us to pick up where we left off.
                    # NOTE: this is only a partial resume. The random draw of the observations is not
                    # saved, so cached results come from a different draw than anything recomputed in
                    # this run. To fix that properly, observationList would need to be saved as well.
                    # NOTE: pickle is only safe here because these files are written by this notebook.
                    with open(resultPath, 'rb') as resultFile:
                        zetaHats, elapsedTime = pickle.load(resultFile)
                except (OSError, EOFError, pickle.UnpicklingError):
                    # The file is missing or unreadable/truncated: re-collect it.
                    # Deliberately narrow, so that Ctrl-C (KeyboardInterrupt) or a genuine bug
                    # does not silently kick off an expensive recomputation.
                    t = time()
                    # Fit the observations to the new grid, get the A matrix
                    grid = np.linspace(0-padding, mu2+padding, disc)
                    A = construct_A_gaussian(grid, sigma**2)

                    # One job per random draw; each job is the argument tuple for binarySearch.
                    jobs = [(obs, threshold, tolerance, alpha, KS_test, grid, grid, sigma, A)
                            for obs in observationList]
                    with Pool(NUM_CORES) as p:
                        zetaHats = p.starmap(binarySearch, jobs)

                    elapsedTime = (time()-t)/60
                    print('    (Elapsed time: {0:8.1f} minutes)'.format(elapsedTime))
                    # Close the file promptly so that the cached result is fully flushed to disk.
                    with open(resultPath, 'wb') as resultFile:
                        pickle.dump((zetaHats, elapsedTime), resultFile)

                # Store results
                zetaHatLists[discIndex,:] = zetaHats
                times.append(elapsedTime)

            # One line per random draw, showing how its estimate moves with the discretization.
            for trial in range(numTrials):
                plt.plot(discretizations, zetaHatLists[:,trial])
            plt.title("Estimate vs. discretization for threshold "+str(threshold)+", mu2 "+str(mu2)+" and n="+str(n))
            plt.xlabel("Discretization")
            plt.ylabel("Estimate zetaHat")
            plt.savefig(folderName+"/plotEstimate_"+settingTag+".png")
            plt.show()

            # Wall-clock cost of each discretization (all trials together, run in parallel).
            plt.plot(discretizations, times)
            plt.title("Total elapsed time across "+str(numTrials)+" trials and "+str(NUM_CORES)+" cores")
            plt.xlabel("Discretization")
            plt.ylabel("Elapsed time (min)")
            plt.savefig(folderName+"/plotTime_"+settingTag+".png")
            plt.show()

