In [None]:
import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
# Global seaborn/matplotlib styling for every figure in this notebook.
# NOTE(review): the sns.set(...) call right below also passes style='ticks',
# which presumably supersedes the "white" style chosen here — confirm and
# keep only one of the two.
sns.set_style("white")
# Notebook-sized context, 'ticks' style, default font scale, plus a faint
# (light grey, 0.75pt) grid switched on through the rc overrides.
sns.set(context='notebook',
        style='ticks',
        font_scale=1,
        rc={'axes.grid':True,
            'grid.color':'.9',
            'grid.linewidth':0.75})
from joblib import Parallel, delayed
import scipy.optimize as optimize
from scipy.optimize import minimize_scalar, brute
from tqdm import tqdm
from matplotlib.ticker import FormatStrFormatter

import warnings
# Silence every RuntimeWarning for the rest of the notebook session.
# NOTE(review): presumably aimed at the empty-slice / divide-by-zero warnings
# that the class-mean helpers below can trigger; be aware it also hides any
# other RuntimeWarning.
warnings.filterwarnings("ignore", category=RuntimeWarning) 

In [1]:
# create datasets
def generate_in_distribution_data(n, mu, sigma, pi_in):
    """Draw a labelled 1-D sample from a two-class Gaussian mixture.

    Class 1 is drawn from N(mu, sigma^2) and class 0 from N(-mu, sigma^2),
    both through NumPy's global random state (``np.random.normal``).

    Parameters
    ----------
    n : int
        Total number of samples to return.
    mu : float
        Mean of class 1; class 0 uses ``-mu``.
    sigma : float
        Standard deviation shared by both classes.
    pi_in : float
        Fraction of the ``n`` samples assigned to class 1
        (``round(n * pi_in)`` samples); class 0 receives the rest.

    Returns
    -------
    X : numpy.ndarray
        The samples, class-0 draws first, followed by class-1 draws.
    Y : numpy.ndarray
        Labels (0.0 / 1.0) aligned with ``X``.
    """
    n_1 = round(n*pi_in)
    # Class 0 takes the remainder. Rounding n*(1-pi_in) independently does not
    # always add up to n: with n=5, pi_in=0.5 both counts round to 2 and only
    # 4 samples were produced.
    n_0 = round(n) - n_1
    mu_1 = mu
    mu_0 = -mu_1
    # Class 0 is drawn before class 1, so the random stream is consumed in the
    # same order as before whenever the counts are unchanged.
    X_0 = np.random.normal(mu_0, sigma, n_0)
    X_1 = np.random.normal(mu_1, sigma, n_1)
    X = np.concatenate((X_0, X_1))
    Y = np.concatenate((np.zeros(n_0), np.ones(n_1)))
    return X, Y

def generate_out_distribution_data(n, mu, sigma, pi_out, delta):
    """Draw a labelled 1-D sample from a shifted two-class Gaussian mixture.

    Class 1 is drawn from N(mu + delta, sigma^2) and class 0 from
    N(-mu + delta, sigma^2), both through NumPy's global random state
    (``np.random.normal``).

    Parameters
    ----------
    n : int
        Total number of samples to return.
    mu : float
        Class-mean magnitude before the shift.
    sigma : float
        Standard deviation shared by both classes.
    pi_out : float
        Fraction of the ``n`` samples assigned to class 1
        (``round(n * pi_out)`` samples); class 0 receives the rest.
    delta : float
        Offset added to both class means.

    Returns
    -------
    X : numpy.ndarray
        The samples, class-0 draws first, followed by class-1 draws.
    Y : numpy.ndarray
        Labels (0.0 / 1.0) aligned with ``X``.
    """
    n_1 = round(n*pi_out)
    # Class 0 takes the remainder. Rounding n*(1-pi_out) independently does
    # not always add up to n: with n=1, pi_out=0.5 both counts round to 0 and
    # no sample at all was produced.
    n_0 = round(n) - n_1
    mu_1 = mu + delta
    mu_0 = -mu + delta
    # Class 0 is drawn before class 1, so the random stream is consumed in the
    # same order as before whenever the counts are unchanged.
    X_0 = np.random.normal(mu_0, sigma, n_0)
    X_1 = np.random.normal(mu_1, sigma, n_1)
    X = np.concatenate((X_0, X_1))
    Y = np.concatenate((np.zeros(n_0), np.ones(n_1)))
    return X, Y

def compute_singlehead_decision_rule(X_in, Y_in, X_out, Y_out):
    """Return the threshold midway between the pooled class means.

    The in- and out-of-distribution samples are concatenated and treated as a
    single data set. The mean of each class (label 0 and label 1) is taken
    over that pooled set, and the decision threshold is the midpoint of the
    two means.

    Parameters
    ----------
    X_in, X_out : array-like
        Sample values of the two data sets.
    Y_in, Y_out : array-like
        Labels (0 or 1) aligned with ``X_in`` / ``X_out``.

    Returns
    -------
    float
        The threshold ``c``.
    """
    pooled_X = np.concatenate((X_in, X_out))
    pooled_Y = np.concatenate((Y_in, Y_out))

    # A class without samples has a NaN mean; nan_to_num turns that into 0.
    class_means = [
        np.nan_to_num(np.mean(pooled_X[pooled_Y == label]))
        for label in (0, 1)
    ]

    # midpoint between the class-1 and class-0 means
    return (class_means[1] + class_means[0])/2

def compute_singlehead_weighted_decision_rule(X_in, Y_in, X_out, Y_out, alpha):
    """Return the threshold midway between alpha-weighted class means.

    Every in-distribution sample carries weight ``alpha`` and every
    out-of-distribution sample weight ``1 - alpha``. For each class the
    weighted mean is

        (alpha * sum(X_in[class]) + (1 - alpha) * sum(X_out[class]))
        / (alpha * n_class + (1 - alpha) * m_class)

    where ``n_class`` / ``m_class`` are the numbers of in- / out-of-
    distribution samples of that class. The threshold is the midpoint of the
    class-0 and class-1 weighted means.

    Parameters
    ----------
    X_in, X_out : numpy.ndarray
        Sample values of the two data sets.
    Y_in, Y_out : numpy.ndarray
        Labels (0 or 1) aligned with ``X_in`` / ``X_out``.
    alpha : float
        Weight of the in-distribution samples.

    Returns
    -------
    float
        The threshold ``c``.
    """
    def _weighted_class_mean(label):
        # alpha-weighted mean of all samples carrying `label`
        in_vals = X_in[Y_in == label]
        out_vals = X_out[Y_out == label]
        weighted_sum = (alpha * np.nan_to_num(np.sum(in_vals))
                        + (1-alpha) * np.nan_to_num(np.sum(out_vals)))
        # The class-1 denominator previously counted the class-0
        # out-of-distribution samples (Y_out==0); each class now uses the
        # counts of its own label.
        total_weight = alpha * len(in_vals) + (1-alpha) * len(out_vals)
        if total_weight == 0:
            # No weighted sample for this class: fall back to 0, the value
            # compute_singlehead_decision_rule uses for an empty class,
            # instead of returning 0/0 = NaN.
            return 0.0
        return weighted_sum / total_weight

    X_0_bar = _weighted_class_mean(0)
    X_1_bar = _weighted_class_mean(1)

    # estimate threshold
    c = (X_1_bar + X_0_bar)/2

    return c

def compute_empirical_risk(X, Y, c):
    """Return the misclassification rate of the threshold rule ``x > c``.

    A sample is predicted as class 1 when its value is strictly greater than
    ``c`` and as class 0 otherwise; the result is one minus the fraction of
    predictions that equal ``Y``.
    """
    predicted = (X > c).astype(int)
    accuracy = np.mean(predicted == Y)
    return 1 - accuracy

def expected_risk(n, m, delta):
    """Minimise a closed-form risk expression over the weight ``alpha``.

    For a given ``alpha`` the objective builds a mean offset ``mu_h`` and a
    spread ``sigma_h`` from ``n``, ``m`` and ``delta`` and combines them
    through the standard normal CDF. Values of ``alpha`` outside [0, 1] are
    penalised with a risk of 1. The minimum is located with a 100-point grid
    search on [0, 1] (``scipy.optimize.brute``) that is then refined by
    ``scipy.optimize.fmin``.

    NOTE(review): the constants 1 inside the CDF arguments presumably stand
    for a unit class mean and unit variance — confirm against the derivation.

    Parameters
    ----------
    n : int
        Number of in-distribution samples.
    m : int
        Number of out-of-distribution samples.
    delta : float
        Shift of the out-of-distribution data.

    Returns
    -------
    tuple
        ``(alpha, risk)``: the minimising weight and the objective value
        reported by the optimiser at that weight.
    """
    def loss(alpha):
        # penalise weights that leave the unit interval
        if alpha < 0.0 or alpha > 1.0:
            return 1
        total_weight = alpha*n + (1-alpha)*m
        mu_h = ((1-alpha)*m/total_weight)*delta
        sigma_h = np.sqrt((alpha**2*n + (1-alpha)**2*m)/total_weight**2)
        scale = np.sqrt(1 + sigma_h**2)
        lower_tail = norm.cdf((mu_h - 1)/scale)
        upper_tail = norm.cdf((mu_h + 1)/scale)
        return 0.5*(1 + lower_tail - upper_tail)

    search = brute(loss, ranges=[(0, 1)], Ns=100, full_output=True, finish=optimize.fmin)
    best_alpha, best_risk = search[0], search[1]
    return best_alpha[0], best_risk

def compute_risk_per_delta(
    delta,
    n_test,
    m_sizes,
    n = 4, 
    mu = 1, 
    sigma = 1, 
    pi_in = 0.5, 
    pi_out = 0.5,
    find_optimal_alpha=False
):
    """Empirical test risk for each out-of-distribution sample size.

    One in-distribution test set of size ``n_test`` and one in-distribution
    training set of size ``n`` are drawn once. For every ``m`` in ``m_sizes``
    a fresh out-of-distribution set of size ``m`` (shifted by ``delta``) is
    drawn, a threshold is fitted on the training data together with that set,
    and the threshold's error on the test set is recorded.

    Parameters
    ----------
    delta : float
        Shift applied to the out-of-distribution data.
    n_test : int
        Size of the in-distribution test set.
    m_sizes : iterable of int
        Out-of-distribution sample sizes to evaluate.
    n : int
        Size of the in-distribution training set.
    mu, sigma : float
        Class-mean magnitude and standard deviation of the generators.
    pi_in, pi_out : float
        Class-1 fraction for the in- / out-of-distribution generators.
    find_optimal_alpha : bool
        If True, use the weighted rule with the weight returned by
        ``expected_risk`` (or weight 1 when ``m == 0``); otherwise use the
        unweighted pooled rule.

    Returns
    -------
    list
        One empirical risk per entry of ``m_sizes``, in the same order.
    """
    # Drawn once and shared by every m (test set first, then training set).
    X_test, y_test = generate_in_distribution_data(n_test, mu, sigma, pi_in)
    X_in, Y_in = generate_in_distribution_data(n, mu, sigma, pi_in)

    risks = []
    for m in m_sizes:
        X_out, Y_out = generate_out_distribution_data(m, mu, sigma, pi_out, delta)
        if not find_optimal_alpha:
            threshold = compute_singlehead_decision_rule(X_in, Y_in, X_out, Y_out)
        else:
            # With no out-of-distribution data all weight goes to the
            # in-distribution samples and the optimiser is skipped.
            opt_alpha = 1 if m == 0 else expected_risk(n, m, delta)[0]
            threshold = compute_singlehead_weighted_decision_rule(
                X_in, Y_in, X_out, Y_out, opt_alpha
            )
        risks.append(compute_empirical_risk(X_test, y_test, threshold))
    return risks