In [1]:
import numpy as np
from scipy import optimize, stats
import tensorflow as tf

import sys
sys.path.append("../src/")
import logistic_regression as lr
import affine_operator_win_bounds as affine_ops

In [2]:
def where_bya_breaks(win, C, Sigma, beta_MLE):
    """Find the level alpha at which the bound b(y, alpha) crosses `win`.

    The bound is evaluated through `affine_ops.b_bound` with an identity
    matrix for `A` and zero vectors for `b` and `d`; only the first element
    of its return value is used.

    Args:
        win: observed win to compare the bound against.
        C: forwarded unchanged to `affine_ops.b_bound`.
        Sigma: square matrix; its first dimension fixes the problem size.
        beta_MLE: forwarded unchanged to `affine_ops.b_bound`.

    Returns:
        The alpha at which `bound(alpha) - win` changes sign, located by
        bisection with `xtol=1e-3`. If `win` already exceeds the bound at
        the smallest level considered (about 1e-10), that level is returned;
        if `win` is below the bound at the largest level considered
        (1 - 1e-10), that level is returned instead.

    Raises:
        AssertionError: if the bound has a non-negligible imaginary part.
    """
    dim = Sigma.shape[0]
    identity, zero_shift, zero_offset = np.eye(dim), np.zeros(dim), np.zeros(dim)

    def bound_at(alpha):
        # Only the first entry of b_bound's result is the bound value.
        raw = affine_ops.b_bound(
            identity, zero_shift, C, zero_offset, Sigma, beta_MLE, alpha)[0]
        real_part = np.real(raw)
        # The bound may come back complex-typed; insist it is numerically real.
        assert np.isclose(raw, real_part)
        return real_part

    # Search interval endpoints, kept just inside (0, 1).
    alpha_hi = 1 - (1e-10)
    alpha_lo = 1 - alpha_hi

    # Guard clauses: outside these the bisection would have no sign change.
    if win > bound_at(alpha_lo):
        return alpha_lo
    if win < bound_at(alpha_hi):
        return alpha_hi

    def gap(alpha):
        return bound_at(alpha) - win

    # Endpoints are passed high-to-low on purpose, matching the established call.
    return optimize.bisect(f=gap, a=alpha_hi, b=alpha_lo, xtol=1e-3)

def wins_and_c_value_breaks_pts(X, Y, beta):
    """Compute three estimates, their wins over the MLE, break levels and c-value.

    The three estimates are the MLE, the MAP, and our approximation
    (`beta_tilde`, from `lr.approx_post_mean`).

    Args:
        X: design matrix, forwarded to the `lr` estimators.
        Y: responses, forwarded to the `lr` estimators.
        beta: true parameter vector the estimates are compared against.

    Returns:
        A 5-tuple `(W_MAP_to_MLE, W_tilde_to_MLE, bya_break_MAP,
        bya_break_tilde, c_value)` where
        * each `W_*` is `||beta_MLE - beta||^2 - ||estimate - beta||^2`, so it
          is positive when the estimate is closer to `beta` than the MLE is;
        * each `bya_break_*` is the level returned by `where_bya_breaks` for
          the corresponding win;
        * `c_value` is the result of `affine_ops.c_value`.
    """
    # Compute the estimates
    beta_MLE, beta_MAP = lr.MLE(X, Y), lr.MAP(X, Y)
    beta_tilde, C, H = lr.approx_post_mean(X, Y, beta_MLE, return_tranform_and_cov=True)
    # NOTE(review): H is presumably a precision/Hessian matrix -- confirm in lr.
    Sigma = np.linalg.inv(H)

    # Take the dimension from the data itself. This used to read a notebook-level
    # global `D`, which fails on a fresh kernel and can silently disagree with
    # the shape of X / Sigma.
    D = Sigma.shape[0]

    # Compute wins (squared-error improvement over the MLE)
    sq_err_MLE = np.linalg.norm(beta_MLE-beta)**2
    W_MAP_to_MLE = sq_err_MLE - np.linalg.norm(beta_MAP-beta)**2
    W_tilde_to_MLE = sq_err_MLE - np.linalg.norm(beta_tilde-beta)**2

    # Compute alphas where bounds break
    bya_break_MAP = where_bya_breaks(W_MAP_to_MLE, C, Sigma, beta_MLE)
    bya_break_tilde = where_bya_breaks(W_tilde_to_MLE, C, Sigma, beta_MLE)

    # Compute c-value, using the same identity/zero affine operator as
    # where_bya_breaks.
    A, b, d = np.eye(D), np.zeros(D), np.zeros(D)
    c_value = affine_ops.c_value(beta_MLE, A, b, C, d, Sigma)

    return W_MAP_to_MLE, W_tilde_to_MLE, bya_break_MAP, bya_break_tilde, c_value

In [3]:
def compare_wins_and_bya_breaks(N=100, D=10, sigma_x=0.25, sigma_beta=1, n_reps=10, seed=42):
    """Repeat the simulation and collect wins, break levels and c-values.

    A single `beta` is drawn once (after seeding NumPy's global RNG), then
    `n_reps` datasets are generated from it with `lr.gen_data` and each is
    summarised by `wins_and_c_value_breaks_pts`.

    Args:
        N: first argument to `lr.gen_data` (presumably the sample size --
            confirm in `lr`).
        D: length of the simulated `beta`.
        sigma_x: forwarded to `lr.gen_data`.
        sigma_beta: scale multiplying the standard-normal draw of `beta`.
        n_reps: number of simulated datasets.
        seed: value passed to `np.random.seed` before drawing `beta`; the
            default reproduces the previously hard-coded seed.

    Returns:
        Five lists, each of length `n_reps`:
        `(Ws_MAP_to_MLE, Ws_tilde_to_MLE, byas_break_MAP, byas_break_tilde,
        c_values)`.
    """
    # Simulate beta
    np.random.seed(seed)
    beta = np.random.normal(size=D)*sigma_beta

    Ws_MAP_to_MLE, Ws_tilde_to_MLE, byas_break_MAP, byas_break_tilde = [], [], [], []
    c_values = []

    # Report about 25 times per run. An integer interval is required: the
    # earlier float modulo (`% (n_reps/25)`) never hit exactly zero for short
    # runs such as the default n_reps=10, so no progress was shown at all.
    report_every = max(1, n_reps // 25)

    print("beta:", beta)
    for rep in range(n_reps):
        if (rep+1) % report_every == 0: print("rep %02d/%02d"%(rep+1, n_reps))
        X, Y = lr.gen_data(N, beta, sigma_x)
        X, Y = tf.convert_to_tensor(X), tf.convert_to_tensor(Y)

        W_MAP_to_MLE, W_tilde_to_MLE, bya_break_MAP, bya_break_tilde, c_value = wins_and_c_value_breaks_pts(
            X, Y, beta)
        Ws_MAP_to_MLE.append(W_MAP_to_MLE)
        Ws_tilde_to_MLE.append(W_tilde_to_MLE)
        byas_break_MAP.append(bya_break_MAP)
        byas_break_tilde.append(bya_break_tilde)
        c_values.append(c_value)
    return Ws_MAP_to_MLE, Ws_tilde_to_MLE, byas_break_MAP, byas_break_tilde, c_values

In [4]:
def plot_bya_breaks(byas_break, title):
    """Plot empirical coverage of the bound against the nominal level.

    For each alpha on a 50-point grid over [0, 1], the "Actual" curve is one
    minus the fraction of break levels in `byas_break` that are >= alpha.
    A dashed diagonal marks the nominal level for comparison.

    Args:
        byas_break: sequence of break levels, one per simulated dataset.
        title: figure title.
    """
    break_levels = np.ravel(np.array(byas_break))
    levels = np.linspace(0, 1, 50)

    # Row i compares every break level with levels[i]; the row mean is the
    # fraction of break levels at or above that alpha.
    at_or_above = break_levels[np.newaxis, :] >= levels[:, np.newaxis]
    coverage = 1 - at_or_above.mean(axis=1)

    plt.figure(figsize=[6.5/3,6.5/4])
    plt.plot(levels, coverage, label="Actual")
    plt.plot([0,1],[0, 1], 'k--', label="Nominal")
    plt.xlabel(r"Confidence Level ($\alpha$)")
    plt.ylabel(r"Coverage ($\mathbb{P} \left[ W \ge b(y, \alpha)\right]$)")
    plt.xlim([0,1])
    plt.ylim([0,1])
    plt.legend(loc='lower right')
    plt.title(title)
    plt.show()

In [5]:
# Experiment configuration.
N = 1000
D = 25
sigma_x = 1./D
sigma_beta = np.sqrt(1/2)
n_reps = 2500

# Run the full simulation; every returned list has one entry per replicate.
(Ws_MAP_to_MLE,
 Ws_tilde_to_MLE,
 byas_break_MAP,
 byas_break_tilde,
 c_values) = compare_wins_and_bya_breaks(
    N=N, D=D, sigma_x=sigma_x, sigma_beta=sigma_beta, n_reps=n_reps)

beta: [ 0.35122995 -0.09776762  0.45798496  1.07694474 -0.16557144 -0.16555983
  1.11667209  0.5426583  -0.33196852  0.38364789 -0.32768579 -0.32932067
  0.17109316 -1.35289344 -1.2197011  -0.39759732 -0.71617975  0.22220642
 -0.64206998 -0.99864952  1.03637018 -0.15964795  0.04774965 -1.0074491
 -0.38493672]


  r = _zeros._bisect(f, a, b, xtol, rtol, maxiter, args, full_output, disp)


rep 100/2500
roots: [-56.89293349+20.30999201j -56.89293349-20.30999201j]
returning zero as the upper bound
roots: [-44.64776632+31.16655479j -44.64776632-31.16655479j]
returning zero as the upper bound
roots: [-30.12022389+35.29074055j -30.12022389-35.29074055j]
returning zero as the upper bound
roots: [-14.37965328+32.64914529j -14.37965328-32.64914529j]
returning zero as the upper bound
roots: [2.1105746+18.61333954j 2.1105746-18.61333954j]
returning zero as the upper bound
roots: [-56.89293349+20.30999201j -56.89293349-20.30999201j]
returning zero as the upper bound
roots: [-44.64776632+31.16655479j -44.64776632-31.16655479j]
returning zero as the upper bound
roots: [-30.12022389+35.29074055j -30.12022389-35.29074055j]
returning zero as the upper bound
roots: [-14.37965328+32.64914529j -14.37965328-32.64914529j]
returning zero as the upper bound
roots: [2.1105746+18.61333954j 2.1105746-18.61333954j]
returning zero as the upper bound
roots: [-56.89293364+20.30999179j -56.89293364-20

roots: [-26.92183213+5.44651179j -26.92183213-5.44651179j]
returning zero as the upper bound
roots: [-26.92183213+5.44651179j -26.92183213-5.44651179j]
returning zero as the upper bound
roots: [-26.92183227+5.4465118j -26.92183227-5.4465118j]
returning zero as the upper bound
roots: [-28.90236232+9.04234395j -28.90236232-9.04234395j]
returning zero as the upper bound
roots: [-28.90236232+9.04234395j -28.90236232-9.04234395j]
returning zero as the upper bound
roots: [-28.90236249+9.04234391j -28.90236249-9.04234391j]
returning zero as the upper bound
rep 600/2500
roots: [-26.20661271+1.78358384j -26.20661271-1.78358384j]
returning zero as the upper bound
roots: [-26.20661271+1.78358384j -26.20661271-1.78358384j]
returning zero as the upper bound
roots: [-26.20661286+1.78358409j -26.20661286-1.78358409j]
returning zero as the upper bound
roots: [-26.44831108+1.18304957j -26.44831108-1.18304957j]
returning zero as the upper bound
roots: [-26.44831108+1.18304957j -26.44831108-1.18304957j]


roots: [-26.20774859+1.49715978j -26.20774859-1.49715978j]
returning zero as the upper bound
roots: [-26.20774859+1.49715978j -26.20774859-1.49715978j]
returning zero as the upper bound
roots: [-26.20774875+1.49716013j -26.20774875-1.49716013j]
returning zero as the upper bound
rep 1200/2500
roots: [-26.58249436+5.41805992j -26.58249436-5.41805992j]
returning zero as the upper bound
roots: [-26.58249436+5.41805992j -26.58249436-5.41805992j]
returning zero as the upper bound
roots: [-26.58249451+5.41805994j -26.58249451-5.41805994j]
returning zero as the upper bound
roots: [-35.15128234+13.55235639j -35.15128234-13.55235639j]
returning zero as the upper bound
roots: [-20.87016485+14.94800674j -20.87016485-14.94800674j]
returning zero as the upper bound
roots: [-35.15128234+13.55235639j -35.15128234-13.55235639j]
returning zero as the upper bound
roots: [-20.87016485+14.94800674j -20.87016485-14.94800674j]
returning zero as the upper bound
roots: [-35.15128251+13.55235628j -35.15128251-1

roots: [-32.78900572+10.28375779j -32.78900572-10.28375779j]
returning zero as the upper bound
roots: [-20.29111328+8.34204155j -20.29111328-8.34204155j]
returning zero as the upper bound
roots: [-32.78900572+10.28375779j -32.78900572-10.28375779j]
returning zero as the upper bound
roots: [-20.29111328+8.34204155j -20.29111328-8.34204155j]
returning zero as the upper bound
roots: [-32.78900588+10.28375772j -32.78900588-10.28375772j]
returning zero as the upper bound
roots: [-20.29111387+8.3420421j -20.29111387-8.3420421j]
returning zero as the upper bound
roots: [-34.36173925+12.9845547j -34.36173925-12.9845547j]
returning zero as the upper bound
roots: [-21.02564434+14.81170287j -21.02564434-14.81170287j]
returning zero as the upper bound
roots: [-34.36173925+12.9845547j -34.36173925-12.9845547j]
returning zero as the upper bound
roots: [-21.02564434+14.81170287j -21.02564434-14.81170287j]
returning zero as the upper bound
roots: [-34.36173941+12.98455459j -34.36173941-12.98455459j]
r

roots: [-42.61497679+17.05107616j -42.61497679-17.05107616j]
returning zero as the upper bound
roots: [-28.04953438+22.8559411j -28.04953438-22.8559411j]
returning zero as the upper bound
roots: [-10.76924088+15.71379474j -10.76924088-15.71379474j]
returning zero as the upper bound
roots: [-42.61497679+17.05107616j -42.61497679-17.05107616j]
returning zero as the upper bound
roots: [-28.04953438+22.8559411j -28.04953438-22.8559411j]
returning zero as the upper bound
roots: [-10.76924088+15.71379474j -10.76924088-15.71379474j]
returning zero as the upper bound
roots: [-42.61497697+17.05107601j -42.61497697-17.05107601j]
returning zero as the upper bound
roots: [-28.04953508+22.85594108j -28.04953508-22.85594108j]
returning zero as the upper bound
roots: [-10.7692427+15.71379666j -10.7692427-15.71379666j]
returning zero as the upper bound
roots: [-54.38937608+20.88020641j -54.38937608-20.88020641j]
returning zero as the upper bound
roots: [-40.88098268+31.46580202j -40.88098268-31.465802

In [6]:
def save_as_tsv(fn, cols, data):
    """Write numeric rows to a tab-separated file with a header line.

    Args:
        fn: path of the file to create (overwritten if it exists).
        cols: column names, written verbatim as the first line.
        data: iterable of rows; every value in a row is rendered with
            "%f" (six decimal places).
    """
    def as_line(fields):
        return "\t".join(fields) + "\n"

    with open(fn, 'w') as out:
        out.write(as_line(cols))
        # Rows are written one at a time so `data` may be a one-shot iterator.
        for record in data:
            out.write(as_line(["%f" % value for value in record]))
    
# Pair each replicate's c-value with the break level of the MAP estimate and
# write the pairs out, one row per replicate.
calibration_path = "../results/logistic_regression/logistic_regression_cvals_and_calibration.tsv"
calibration_cols = ["c_value", "by_break"]
cvals_and_by_breaks = zip(c_values, byas_break_MAP)
save_as_tsv(calibration_path, calibration_cols, cvals_and_by_breaks)