In [1]:
from scipy import stats
import numpy as np
from scipy import optimize

In [2]:
def theta_norm_upper_bound(Y, alpha=0.05):
    Y_norm = sum(Y**2)
    df = Y.shape[0]
    if stats.chi2.cdf(Y_norm, df) < alpha: return 1e-5
    return optimize.newton(
        func=lambda lmbda: stats.ncx2(df, lmbda).cdf(Y_norm)-alpha,
        x0=max([1e-4,Y_norm-df]),
        tol=1e-4)

def b_bayes(Y, alpha, tau):
    upper = theta_norm_upper_bound(Y, (1-alpha)/2)
    Y_norm = sum(Y**2)
    g = 1/(1+tau**2)
    N = len(Y)
    offset = - ((g/2)*upper + (g**2)*Y_norm  )
    ncChi2_quantile = stats.ncx2(
        N,
        (1./4.)*upper).ppf((1-alpha)/2)
    ncChi2_term = 2*g*ncChi2_quantile
    bya = offset + ncChi2_term
    return bya, upper

def c_value(Y, tau):
    bya = lambda alpha: b_bayes(Y, alpha, tau)[0]
    c_lower_bound =  1e-10
    if bya(c_lower_bound) < 0: return c_lower_bound
    if bya(1.-c_lower_bound) > 0: return 1.

    c_val = optimize.bisect(
        f=lambda alpha: bya(alpha),
        a=c_lower_bound, b=1.-c_lower_bound,xtol=1e-3)
    return c_val

def where_bya_breaks(win, Y, tau):
    bya = lambda alpha: b_bayes(Y, alpha, tau)[0]
    c_lower_bound =  1e-10
    if win>bya(c_lower_bound): return c_lower_bound
    if win<bya(1.-c_lower_bound): return 1.-c_lower_bound
    
    where_breaks = optimize.bisect(
        f=lambda alpha: bya(alpha)-win,
        a=c_lower_bound, b=1.-c_lower_bound, xtol=1e-3)
    return where_breaks
    
def W_lower_bound_with_norm(Y, theta_norm, tau, alpha=0.05):
    Y_norm = sum(Y**2)
    g = 1/(1+tau**2)
    N = len(Y)
    
    offset = - ((g/2)*theta_norm + (g**2)*Y_norm  )
    ncChi2_quantile = stats.ncx2(
        N-1,
        (1./4.)*theta_norm).ppf((1-alpha)/2)
    ncChi2_term = 2*g*ncChi2_quantile
    bya = offset + ncChi2_term
    return bya

# Test for different $\theta$ and samples

In [3]:
empirical_bayes=True
np.random.seed(42)
alpha=0.9
N=50
n_bad_conf = 0
n_bad_by = 0
n_reps = 250
wins, bys, c_values, by_breaks = [], [], [], []
mle_errs, bayes_errs = [], []
theta_norms, Y_norms = [], []
print("starting")
base_fn = "../results/empirical_bayes//"
for tau_true in [0., 2.]:
    print("tau_true = ", tau_true)
    theta = tau_true*np.ones([N])
    for rep in range(n_reps):
        if rep%int(n_reps/10)==0:print("rep %04d/%04d"%(rep,n_reps))

        Y=theta +np.random.normal(size=theta.shape)

        if empirical_bayes:
            tau_hat = max([sum(Y**2)/(N-2) - 1., 0.])
            tau = np.sqrt(tau_hat)
        else:
            tau = tau_true

        bayes = Y - (1/(1+tau**2))*Y

        theta_norm, Y_norm = np.sqrt(np.sum(theta**2)), np.sqrt(sum(Y**2))
        mle_err, bayes_err = sum((Y-theta)**2), sum((theta-bayes)**2)
        win = mle_err - bayes_err

        by, upper = b_bayes(Y, alpha, tau)
        c_val = c_value(Y, tau)
        by_break = where_bya_breaks(win, Y, tau)

        theta_norms.append(theta_norm)
        Y_norms.append(Y_norm)
        mle_errs.append(float(mle_err))
        bayes_errs.append(float(bayes_err))
        wins.append(float(win))
        bys.append(float(by))
        c_values.append(c_val)
        by_breaks.append(by_break)


        if  win < by: n_bad_by += 1
        if  theta_norm**2 > upper:
            n_bad_conf += 1

    print("\n\nalpha", alpha)
    print("n_bad_conf/n_reps", n_bad_conf/n_reps)
    print("n_bad_cy/n_reps", n_bad_by/n_reps)
    
    # Save out results
    fn_out = base_fn + "JS_EBayes_theta_norm=%02.02f_N=%03d_alpha=%0.02f.tsv"%(np.linalg.norm(theta)**2, N, alpha)

    # For each replicate save
    # $\| P_1^\perp \theta\|,  \| P_1^\perp y\|, \|\hat \theta - \theta\|^2, \|\theta^* - \theta\|^2, W, b(y, \alpha), c(y)$ and $b_{break}$
    with open(fn_out, 'w') as f:
        f.write("\t".join([
            "Rep", "theta_norm", "Y_norm", "MLE_Err", "Bayes_Err", "Win", "bya", "c_value", "bya_breakpoint"
        ]) + "\n")
        for rep, rep_vals in enumerate(zip(
    theta_norms, Y_norms, mle_errs, bayes_errs, wins, bys, c_values, by_breaks)):
            f.write("%04d\t"%rep + "\t".join(["%0.05f"%val for val in rep_vals])+"\n")
    print("\n\n\n")

starting
tau_true =  0.0
rep 0000/0250
rep 0025/0250
rep 0050/0250
rep 0075/0250
rep 0100/0250
rep 0125/0250
rep 0150/0250
rep 0175/0250
rep 0200/0250
rep 0225/0250


alpha 0.9
n_bad_conf/n_reps 0.0
n_bad_cy/n_reps 0.04




tau_true =  2.0
rep 0000/0250
rep 0025/0250
rep 0050/0250
rep 0075/0250
rep 0100/0250
rep 0125/0250
rep 0150/0250
rep 0175/0250
rep 0200/0250
rep 0225/0250


alpha 0.9
n_bad_conf/n_reps 0.028
n_bad_cy/n_reps 0.072




