# Fig 1

In [None]:
import torch
import matplotlib.pyplot as plt
from scipy.stats import norm
import numpy as np
import matplotlib as mpl
import seaborn as sns

num_subnet_arr = [2, 4, 8, 16]
# Grid of z thresholds and, for each, the coverage P(|Z| < z) of a standard normal Z.
sigmas = torch.linspace(0, 5, 500)
ideal_coverages = norm.cdf(sigmas) - norm.cdf(-sigmas)

part_llr_pred = torch.load('YOUR_DIR_HERE/partition_llr_pred_arr.pt')
part_llr_true = torch.load('YOUR_DIR_HERE/partition_llr_true_arr.pt')
part_llr_unc = torch.load('YOUR_DIR_HERE/partition_llr_unc_arr.pt')

# trained up to 32 for bootstrap for bias checking purposes, so drop this training for fig 1
strap_llr_pred = torch.load('YOUR_DIR_HERE/bootstrap_llr_pred_arr.pt')[:-1]
strap_llr_true = torch.load('YOUR_DIR_HERE/bootstrap_llr_true_arr.pt')[:-1]
strap_llr_unc = torch.load('YOUR_DIR_HERE/bootstrap_llr_unc_arr.pt')[:-1]

num_trainings = part_llr_pred.shape[1]


def _z_and_coverage(pred, true, unc):
    """Return the z-scores ``(pred - true) / unc`` and their empirical coverage curve.

    The coverage is the fraction (mean over dim 2) of ``|z|`` values lying
    below each threshold in ``sigmas``; the thresholds form a new last axis.
    """
    z = (pred - true) / unc
    below = z.abs().unsqueeze(-1) < sigmas
    return z, below.float().mean(dim=2)


part_z_scores, part_coverages = _z_and_coverage(part_llr_pred, part_llr_true, part_llr_unc)
strap_z_scores, strap_coverages = _z_and_coverage(strap_llr_pred, strap_llr_true, strap_llr_unc)

color_pal = sns.color_palette('bright')
mpl.rc('font', family='Times New Roman')
mpl.rc('mathtext', fontset='cm')
# coverages[subnet_idx, training, sigma_idx, method]; method 0 = partition, 1 = bootstrap.
coverages = torch.stack([part_coverages[:4], strap_coverages[:4]], dim=-1)
fig = plt.figure(figsize=(16, 12))
# Five grid rows: coverage / residual (M=2, 4), an empty spacer row, coverage / residual (M=8, 16).
gs = fig.add_gridspec(5, 2, height_ratios=[2.5, 1, 0.5, 2.5, 1], hspace=0.04, wspace=0)

axs = np.array([
    [None, None] if grid_row == 2 else [fig.add_subplot(gs[grid_row, grid_col]) for grid_col in range(2)]
    for grid_row in range(5)
])

# (row, col, axis, labelsize), applied in this order; labelsize=0 hides labels on interior panels.
for tick_row, tick_col, tick_axis, tick_size in [
    (0, 0, 'both', 16),
    (0, 1, 'both', 0),
    (1, 0, 'both', 16),
    (1, 0, 'x', 0),
    (1, 1, 'both', 0),
    (3, 0, 'both', 16),
    (3, 1, 'both', 0),
    (4, 0, 'both', 16),
    (4, 1, 'y', 0),
    (4, 1, 'x', 16),
]:
    axs[tick_row, tick_col].tick_params(axis=tick_axis, which='major', direction='in', labelsize=tick_size)

method_names = ['Partition', 'Bootstrap']
colors = [color_pal[idx] for idx in range(2)]

# (title, coverage axes, residual axes, residual |y| limit, left column?, bottom row?)
fig1_panel_specs = [
    ("$M=2$", axs[0, 0], axs[1, 0], 0.2, True, False),
    ("$M=4$", axs[0, 1], axs[1, 1], 0.2, False, False),
    ("$M=8$", axs[3, 0], axs[4, 0], 0.1, True, True),
    ("$M=16$", axs[3, 1], axs[4, 1], 0.1, False, True),
]
for panel_title, cov_panel, resid_panel, resid_lim, is_left, is_bottom in fig1_panel_specs:
    cov_panel.set_title(panel_title, fontsize=22)
    # Reference curve for perfectly calibrated z-scores.
    cov_panel.plot(sigmas, ideal_coverages, 'k--', label='Ideal')
    resid_panel.axhline(y=0, color='k', linestyle='--')
    if is_bottom:
        resid_panel.set_xlabel('$z$', fontsize=18)
    if is_left:
        resid_panel.set_ylabel(r"$c(z) - \bar{c}(z)$", fontsize=18)
        cov_panel.set_ylabel(r'$c(z)$', fontsize=18)
    resid_panel.set_ylim(-resid_lim, resid_lim)

# (coverage axes, residual axes) for each entry of num_subnet_arr, in order (M=2, 4, 8, 16).
fig1_panel_axes = [
    (axs[0, 0], axs[1, 0]),
    (axs[0, 1], axs[1, 1]),
    (axs[3, 0], axs[4, 0]),
    (axs[3, 1], axs[4, 1]),
]
# Methods are drawn in reverse order of method_names (Bootstrap first, then
# Partition), which fixes the layering and legend order of the curves.
for method_idx in reversed(range(len(method_names))):
    for subnet_idx, num_subnets in enumerate(num_subnet_arr):
        cov_ax, resid_ax = fig1_panel_axes[subnet_idx]
        # coverages[subnet_idx, training, sigma_idx, method]: one curve per training.
        runs = coverages[subnet_idx, :, :, method_idx]
        mean_cov = runs.mean(axis=0)
        # Standard error of the mean across trainings.
        sem_cov = runs.std(axis=0) / num_trainings**0.5
        resid = mean_cov - ideal_coverages
        label = method_names[method_idx]
        color = colors[method_idx]
        cov_ax.plot(sigmas, mean_cov, label=label, color=color)
        cov_ax.fill_between(sigmas, mean_cov - sem_cov, mean_cov + sem_cov, alpha=0.5, color=color)
        resid_ax.plot(sigmas, resid, label=label, color=color)
        resid_ax.fill_between(sigmas, resid - sem_cov, resid + sem_cov, alpha=0.5, color=color)

# Legends only on the four coverage panels; the residual panels share their styling.
for legend_ax in (axs[0, 0], axs[0, 1], axs[3, 0], axs[3, 1]):
    legend_ax.legend(fontsize=16, frameon=False)
plt.suptitle("Gaussian Case Study, Coverage on Log Likelihood Ratio", fontsize=24)
plt.savefig("fig1.pdf", bbox_inches='tight')

# Fig 2

Note: This one takes a while (~20 minutes on my machine) because it trains the networks from scratch.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchmin import minimize
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Each class gets num_samples draws; train_frac of them go to training, the rest to validation.
num_samples = 50000
train_frac = 0.5
num_train_samples = int(train_frac * num_samples)
num_mixture_samples = num_train_samples  # same value: int(train_frac * num_samples)
num_val_samples = num_samples - num_train_samples

train_batch_size = 5000
val_batch_size = train_batch_size
num_subnets_arr = [2]

num_runs = 1
num_trainings = 1
units = 32
latent_dim = 32

# Result buffers indexed [subnet-count index, training, run]; the second group adds a trailing axis of size 3.
_result_shape = (len(num_subnets_arr), num_trainings, num_runs)
llr_true_arr, llr_pred_arr, llr_unc_arr = (torch.zeros(_result_shape) for _ in range(3))
true_arr, pred_arr, unc_arr = (torch.zeros(_result_shape + (3,)) for _ in range(3))
    
def model_with_weights_scaled(submodels, dataloader, w):
    """Evaluate ``submodels.submodel_all(x) @ w`` for every sample in ``dataloader``.

    Each batch is moved to the module-level ``device``, and its outputs are
    written, in loader order, into one preallocated 1-D tensor of length
    ``len(dataloader.dataset)``. The result stays attached to the autograd
    graph through ``w`` (``minimize`` differentiates ``mlc_min`` through it).
    """
    h_outs = torch.zeros(len(dataloader.dataset), device=device)
    offset = 0
    for batch in dataloader:
        inputs = batch[0]
        next_offset = offset + len(inputs)
        h_outs[offset:next_offset] = submodels.submodel_all(inputs.to(device)) @ w
        offset = next_offset
    return h_outs

def mlc_min(w, submodels, qloader, ploader):
    """Two-sided MLC objective for basis weights ``w`` (to be minimized).

    With h = submodels.submodel_all(x) @ w on every q sample (``qloader``)
    and p sample (``ploader``), returns

        -(mean_q h - mean_p (e^h - 1)) - (-mean_p h - mean_q (e^{-h} - 1)).

    The q loader is evaluated before the p loader.
    """
    h_q = model_with_weights_scaled(submodels, qloader, w)
    h_p = model_with_weights_scaled(submodels, ploader, w)
    # First term uses e^h; the second is the same form with e^{-h} and the roles of q/p swapped.
    forward_term = -(h_q.mean() - (torch.exp(h_p) - 1).mean())
    reverse_term = -(-h_p.mean() - (torch.exp(-h_q) - 1).mean())
    return forward_term + reverse_term

def MLC(y_true, model_outputs, net_idx):
    """Per-sample MLC loss for basis function ``net_idx``.

    The prediction h is column ``net_idx`` of ``model_outputs``. For a sample
    with label y (1 for q samples, 0 for p samples) the loss is

        -y*h - (1-y)*(1 - e^h) - (1-y)*(-h) - y*(1 - e^{-h}),

    which matches the per-sample terms of ``mlc_min``.

    Args:
        y_true: 1-D tensor of labels, one per row of ``model_outputs``.
        model_outputs: tensor whose last axis indexes basis functions
            (e.g. the output of ``Model.submodel_all``).
        net_idx: index of the basis function (column) to score.

    Returns:
        Tensor of per-sample losses with a singleton axis at dim 1
        (``(batch, 1)`` for 2-D ``model_outputs``).
    """
    # Select the column directly instead of via a one-hot matmul: identical for
    # finite values, independent of the module-level ``device``, and it avoids
    # 0 * inf = nan when some *other* basis output overflows.
    y_pred = model_outputs[..., net_idx].unsqueeze(1)
    y = y_true.unsqueeze(1)
    cont1 = -(y * y_pred)
    cont2 = -(1 - y) * (1 - torch.exp(y_pred))
    cont3 = -(1 - y) * (-y_pred)
    cont4 = -y * (1 - torch.exp(-y_pred))
    return cont1 + cont2 + cont3 + cont4
    
class Model(nn.Module):
    """Bank of ``num_subnets`` independent 1 -> units -> units -> 1 MLPs.

    ``submodel_all`` evaluates every subnet on the same input and appends a
    constant column of ones, giving a (batch, num_subnets + 1) basis matrix
    that the surrounding code combines with a weight vector ``w``.
    ``forward`` runs a separate MLP (``layer1``..``layer3``) that the code in
    this cell does not call.
    """

    def __init__(self, num_subnets):
        super().__init__()
        self.num_subnets = num_subnets

        # Stand-alone MLP used only by forward(). Layers are created in the
        # same order as before so parameter initialization draws are unchanged.
        self.layer1 = nn.Linear(1, units)
        self.layer2 = nn.Linear(units, units)
        self.layer3 = nn.Linear(units, 1)

        # Per-subnet layers: entry i of each list belongs to subnet i.
        self.layer1_list = nn.ModuleList([nn.Linear(1, units) for _ in range(self.num_subnets)])
        self.layer2_list = nn.ModuleList([nn.Linear(units, units) for _ in range(self.num_subnets)])
        self.layer3_list = nn.ModuleList([nn.Linear(units, 1) for _ in range(self.num_subnets)])

    def forward(self, x):
        """Apply the stand-alone MLP (layer1 -> layer2 -> layer3) to ``x``."""
        x = F.leaky_relu(self.layer1(x), negative_slope=0.2)
        x = F.leaky_relu(self.layer2(x), negative_slope=0.2)
        x = self.layer3(x)
        return x

    def submodel_all(self, x):
        """Return every subnet's output on ``x`` plus a trailing column of ones.

        Args:
            x: input batch with a trailing feature axis of size 1.

        Returns:
            Tensor of shape (batch, num_subnets + 1); column i is subnet i and
            the last column is the constant basis function 1.
        """
        outputs = []
        for i in range(self.num_subnets):
            hidden = F.leaky_relu(self.layer1_list[i](x), negative_slope=0.2)
            hidden = F.leaky_relu(self.layer2_list[i](hidden), negative_slope=0.2)
            outputs.append(self.layer3_list[i](hidden))
        basis = torch.cat(outputs, dim=1)
        # Allocate the constant column alongside the activations rather than on
        # the module-level ``device``, so the model works wherever it is moved.
        ones = torch.ones(basis.shape[0], 1, device=basis.device, dtype=basis.dtype)
        return torch.cat([basis, ones], dim=1)

def train_func(model_to_train, epochs, num_subnets):
    """Train each subnet of ``model_to_train`` in turn on its own bootstrap resample.

    For basis function ``start_idx`` this draws with-replacement resamples of
    the module-level ``x_train``/``y_train`` and ``x_val``/``y_val`` tensors.
    It freezes every parameter except that subnet's three layers and minimizes
    the MLC loss on that subnet's output column for ``epochs`` epochs. It then
    restores the parameters from the epoch with the lowest validation loss.

    Args:
        model_to_train: a ``Model`` whose subnets are trained in place.
        epochs: number of epochs per subnet.
        num_subnets: number of subnets to train (indices ``0 .. num_subnets-1``).

    Returns:
        (train_losses, val_losses): per-epoch mean batch losses, appended for
        every subnet in order.
    """
    train_losses = []
    val_losses = []
    opt = optim.Adam(model_to_train.parameters(), lr=1e-2)
    for start_idx in range(num_subnets):
        # Bootstrap resamples (indices drawn with replacement) of train and val.
        train_strap = np.random.randint(0, x_train.shape[0], x_train.shape[0])
        x_train_strap = x_train[train_strap]
        y_train_strap = y_train[train_strap]
        val_strap = np.random.randint(0, x_val.shape[0], x_val.shape[0])
        x_val_strap = x_val[val_strap]
        y_val_strap = y_val[val_strap]

        trainset_strap = torch.utils.data.TensorDataset(x_train_strap, y_train_strap)
        valset_strap = torch.utils.data.TensorDataset(x_val_strap, y_val_strap)
        trainloader_strap = torch.utils.data.DataLoader(trainset_strap, batch_size=train_batch_size, shuffle=True)
        valloader_strap = torch.utils.data.DataLoader(valset_strap, batch_size=val_batch_size, shuffle=False)

        # Freeze everything, then unfreeze only the subnet being trained.
        for param in model_to_train.parameters():
            param.requires_grad = False
        for layer_list in (model_to_train.layer1_list, model_to_train.layer2_list, model_to_train.layer3_list):
            for param in layer_list[start_idx].parameters():
                param.requires_grad = True
        print(f"Training basis function {start_idx}", flush=True)

        min_val_loss = 1e10
        best_state = None  # reset per subnet so a stale snapshot is never restored
        for epoch in range(epochs):
            running_loss = 0.0
            val_loss = 0.0
            batches = 0
            for data in trainloader_strap:
                batches += 1
                # set_to_none=True (the default only since PyTorch 2.0) keeps the
                # frozen subnets' grads at None so Adam skips them; zero-filled
                # grads would let Adam's momentum keep moving those weights.
                opt.zero_grad(set_to_none=True)
                inputs = data[0].to(device)
                train_outputs = model_to_train.submodel_all(inputs)
                loss = MLC(data[1].to(device), train_outputs, start_idx).mean()
                loss.backward()
                opt.step()
                running_loss += loss.item()
            val_batches = 0
            with torch.no_grad():
                for data in valloader_strap:
                    val_batches += 1
                    inputs = data[0].to(device)
                    val_outputs = model_to_train.submodel_all(inputs)
                    val_loss += MLC(data[1].to(device), val_outputs, start_idx).mean().item()
            mean_train_loss = running_loss/batches
            mean_val_loss = val_loss/val_batches
            train_losses.append(mean_train_loss)
            val_losses.append(mean_val_loss)
            if mean_val_loss < min_val_loss:
                min_val_loss = mean_val_loss
                # state_dict() holds references to the live parameter tensors, so
                # take copies; otherwise later epochs overwrite the "best" state
                # and the restore below silently does nothing.
                best_state = {k: v.detach().clone() for k, v in model_to_train.state_dict().items()}
            print(f"Epoch {epoch+1} train loss: {mean_train_loss}, val loss: {mean_val_loss}", flush=True)
        # No snapshot exists if no epoch improved (e.g. epochs == 0 or NaN losses).
        if best_state is not None:
            model_to_train.load_state_dict(best_state)
    return train_losses, val_losses

def neg_maximum_likelihood_f(f, model_outputs):
    """Negative log-likelihood of the mixture fraction ``f``.

    Each entry h of ``model_outputs`` contributes log(e^h * f + (1 - f));
    the result is minus the sum of these contributions.
    """
    # TODO: need to fix this implementation to allow for difference in sizes of data and prior
    per_sample = torch.exp(model_outputs) * f + (1 - f)
    return -torch.log(per_sample).sum()

def neg_maximum_likelihood_f_wrapper(*args):
    """Evaluate ``neg_maximum_likelihood_f(*args)`` and return it as a detached CPU NumPy array."""
    loss = neg_maximum_likelihood_f(*args)
    return loss.detach().cpu().numpy()

def calc_ai(f, model_outputs_all, w):
    """Return a = sum_n [e^{h_n} / (e^{h_n} f + 1 - f)^2] * phi_n, detached.

    Here phi_n is row n of ``model_outputs_all`` and h_n = phi_n @ w. This is
    the derivative with respect to ``w`` of d/df sum_n log(e^{h_n} f + 1 - f).
    """
    exp_h = torch.exp(model_outputs_all @ w)
    per_sample_weight = exp_h / (exp_h * f + (1 - f)) ** 2
    weighted_rows = per_sample_weight.unsqueeze(1) * model_outputs_all
    return weighted_rows.sum(axis=0).detach()

def calc_second_deriv(f, model_outputs_all, w):
    """Return d^2/df^2 sum_n log(e^{h_n} f + 1 - f) with h = model_outputs_all @ w, detached.

    Equals -sum_n (e^{h_n} - 1)^2 / (e^{h_n} f + 1 - f)^2.
    """
    exp_h = torch.exp(model_outputs_all @ w)
    numerator = (exp_h - 1) ** 2
    denominator = (exp_h * f + (1 - f)) ** 2
    return (-numerator / denominator).sum().detach()

def uncertainties(f, model_outputs_all, w, cov):
    """Combine the curvature term for ``f`` with the propagated weight covariance.

    With a = calc_ai(f, model_outputs_all, w), L'' = calc_second_deriv(f,
    model_outputs_all, w) and C = cov, returns the pair

        (|1/L''| + a^T C a / L''^2,   |1/L''| / (a^T C a / L''^2)),

    i.e. the sum of the two terms and the ratio of the first to the second.
    """
    ai = calc_ai(f, model_outputs_all, w)
    second_deriv = calc_second_deriv(f, model_outputs_all, w)
    curvature_term = torch.abs(1/second_deriv)
    # Evaluate a^T C a once, in float64, and reuse it for both outputs. The
    # ratio was previously recomputed separately in float32.
    quad_form = (ai.double()@cov.double()@ai.double()).float()
    weight_term = 1/second_deriv**2 * quad_form
    return curvature_term + weight_term, curvature_term / weight_term

def _sample_split_data():
    """Draw a fresh labelled dataset and return shuffled train/validation splits.

    q samples are ``torch.randn(num_samples) + .1`` and p samples are
    ``torch.randn(num_samples) - .1``. The first ``num_train_samples`` of each
    go to training and the rest to validation. Each split is shuffled with its
    own ``torch.randperm``; the RNG calls happen in the same order as before.

    Returns:
        (x_train, y_train, x_val, y_val): x tensors have shape (N, 1); y tensors
        have shape (N,) with 1.0 for q samples and 0.0 for p samples.
    """
    qs = torch.randn(num_samples) + .1
    ps = torch.randn(num_samples) - .1
    data_train = torch.concatenate((qs[:num_train_samples], ps[:num_train_samples]))
    data_val = torch.concatenate((qs[num_train_samples:], ps[num_train_samples:]))
    train_perm_key = torch.randperm(2*num_train_samples)
    val_perm_key = torch.randperm(2*num_val_samples)
    labels_train = torch.concatenate((torch.ones(num_train_samples), torch.zeros(num_train_samples)))
    labels_val = torch.concatenate((torch.ones(num_val_samples), torch.zeros(num_val_samples)))
    return (
        data_train[train_perm_key].unsqueeze(1),
        labels_train[train_perm_key],
        data_val[val_perm_key].unsqueeze(1),
        labels_val[val_perm_key],
    )


def _class_loaders(x, y, batch_size):
    """Return unshuffled (q, p) DataLoaders over the rows of ``x`` with y == 1 and y == 0."""
    q_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x[y == 1]), batch_size=batch_size, shuffle=False)
    p_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x[y == 0]), batch_size=batch_size, shuffle=False)
    return q_loader, p_loader


def _basis_matrix(net, loader):
    """Return ``net.submodel_all`` over every batch of ``loader`` as a detached CPU (K, N) matrix."""
    blocks = [net.submodel_all(batch[0].to(device)).detach().cpu().T for batch in loader]
    return torch.cat(blocks, dim=1)


for subnet_idx, num_subnets in enumerate(num_subnets_arr):
    print("Number of subnets is", num_subnets, flush=True)
    for training in range(num_trainings):
        if training%100 == 0:
            print("Starting training run", training, flush=True)
        # train_func reads these module-level tensors.
        x_train, y_train, x_val, y_val = _sample_split_data()

        model = Model(num_subnets)
        model.to(device)
        train_losses, val_losses = train_func(model, 150, num_subnets)
        for run in range(num_runs):
            if run%10 == 0:
                print("Starting run", run, flush=True)
            # Fit the basis weights on a fresh dataset, independent of the training draw.
            x_train, y_train, x_val, y_val = _sample_split_data()
            qloader, ploader = _class_loaders(x_train, y_train, train_batch_size)
            qloader_val, ploader_val = _class_loaders(x_val, y_val, val_batch_size)

            # Start from w = e_0 (first subnet only) and minimize the MLC objective over w.
            w00 = torch.zeros(num_subnets+1, device=device)
            w00[0] = 1.
            res_root = minimize(lambda w: mlc_min(w, model, qloader, ploader), x0=w00, method='newton-exact')
            w0 = res_root.x
            val_mlc = mlc_min(w0, model, qloader_val, ploader_val).detach()
            q_wgt_points = model_with_weights_scaled(model, qloader, w0).detach().cpu()
            p_wgt_points = model_with_weights_scaled(model, ploader, w0).detach().cpu()

            # Basis outputs phi_i(x) for every sample (p loader first, then q, as before).
            model_p_points = _basis_matrix(model, ploader)
            model_q_points = _basis_matrix(model, qloader)
            n_points = model_q_points.shape[1]

            # Per-sample gradient factors (d) and Hessian factors (c) of the MLC loss w.r.t. w,
            # hoisted out of the i/j loops since they do not depend on i or j.
            h_q = model_q_points*(1 + torch.exp(-q_wgt_points))
            h_p = model_p_points*(1 + torch.exp(p_wgt_points))
            q_hess_factor = torch.exp(-q_wgt_points)
            p_hess_factor = torch.exp(p_wgt_points)

            num_basis = num_subnets + 1
            dijq = torch.zeros((num_basis, num_basis))
            dijp = torch.zeros((num_basis, num_basis))
            cijq = torch.zeros((num_basis, num_basis))
            cijp = torch.zeros((num_basis, num_basis))
            for i in range(num_basis):
                for j in range(num_basis):
                    dijq[i,j] = (h_q[i]*h_q[j]).mean() - (h_q[i]).mean()*(h_q[j]).mean()
                    dijp[i,j] = (h_p[i]*h_p[j]).mean() - (h_p[i]).mean()*(h_p[j]).mean()
                    cijp[i,j] = -(model_p_points[i,:]*model_p_points[j,:]*p_hess_factor).mean()
                    cijq[i,j] = -(model_q_points[i,:]*model_q_points[j,:]*q_hess_factor).mean()
            cij = cijq + cijp
            dij = dijq + dijp
            # Sandwich estimator C^{-1} D^T C^{-T}, scaled by the number of q samples.
            cov_mat = torch.linalg.solve(cij.double(),torch.linalg.solve(cij.double(), dij.double()).T).float()/n_points
            cov_mat = cov_mat.to(device)
            print("Cov mat is", cov_mat, flush=True)
            print("Weights are", w0, flush=True)

mpl.rc('font',family='Times New Roman')
mpl.rc('mathtext', fontset='cm')
color_pal = sns.color_palette('bright')
test_pts = torch.linspace(-3,3,1000).unsqueeze(1).to(device)
fi = model.submodel_all(test_pts)
prediction = fi@w0
var = fi@cov_mat@fi.T

# Matplotlib needs host data: move everything off the (possibly CUDA) device once.
x_plot = test_pts[:, 0].detach().cpu()
pred_plot = prediction.detach().cpu()
std_plot = (var.diag()**0.5).detach().cpu()
# For q ~ N(0.1, 1) and p ~ N(-0.1, 1), log q(x)/p(x) = 0.2 x exactly.
true_plot = 0.2*x_plot
resid_plot = pred_plot - true_plot

fig, axs = plt.subplots(2,1, figsize=(8,6), sharex=True, gridspec_kw={'hspace': 0.04, 'height_ratios': [3, 1]})
axs[0].tick_params(direction='in')
axs[1].tick_params(direction='in')
axs[0].tick_params(labelsize=16)
axs[1].tick_params(labelsize=16)
axs[1].set_xlabel(r"$x$", fontsize=16)
axs[1].set_ylabel(r"$\log \hat{r}(x) - \log r(x)$", fontsize=16)
axs[0].set_ylabel(r"$\log \hat{r}(x)$", fontsize=16)
axs[0].set_xlim(-3,3)
axs[0].set_ylim(-0.7, 0.7)
axs[1].set_ylim(-0.2, 0.2)
# Raw strings: "\l" in a plain string literal is an invalid escape sequence.
axs[0].plot(x_plot, pred_plot, color=color_pal[0], label=r"Predicted $\log r(x)$")
axs[0].fill_between(x_plot, pred_plot - std_plot, pred_plot + std_plot, alpha=0.5, color=color_pal[0])
axs[1].fill_between(x_plot, resid_plot - std_plot, resid_plot + std_plot, alpha=0.5, color=color_pal[0])
axs[0].plot(x_plot, true_plot, 'k--', label=r"True $\log r(x)$")
# Same colour as the residual band (previously fell back to matplotlib's default cycle).
axs[1].plot(x_plot, resid_plot, color=color_pal[0])
axs[1].axhline(0, color='k', linestyle = '--')
axs[0].legend(fontsize=16, frameon=False)
# Derive M from the run that produced the plot instead of hard-coding 16
# (num_subnets_arr above is currently [2]).
axs[0].set_title(f"Estimated Density Ratio, $M={num_subnets}$, Bootstrap", fontsize=20)
plt.savefig("fig2.pdf", bbox_inches='tight')

# Fig 3

In [None]:
import matplotlib as mpl
import numpy as np
import seaborn as sns

# Using "raw_pred" because we do not include the bias correction in the main body
part_pred_16 = torch.load('YOUR_DIR_HERE/partition_raw_pred_arr.pt')[-1]
part_true_16 = torch.load('YOUR_DIR_HERE/partition_true_arr.pt')[-1]
part_unc_16 = torch.load('YOUR_DIR_HERE/partition_unc_arr.pt')[-1]

# Trained up to M=32 for bootstrap, so need the second to last element for 16
strap_pred_16 = torch.load('YOUR_DIR_HERE/bootstrap_raw_pred_arr.pt')[-2]
strap_true_16 = torch.load('YOUR_DIR_HERE/bootstrap_true_arr.pt')[-2]
strap_unc_16 = torch.load('YOUR_DIR_HERE/bootstrap_unc_arr.pt')[-2]

baseline_ensemble_pred_16 = torch.load('YOUR_DIR_HERE/baseline_ensemble_pred_arr.pt')[-1]
baseline_ensemble_true_16 = torch.load('YOUR_DIR_HERE/baseline_ensemble_true_arr.pt')[-1]
baseline_ensemble_unc_16 = torch.load('YOUR_DIR_HERE/baseline_ensemble_unc_arr.pt')[-1]


def _fig3_coverage(pred, true, unc):
    """Fraction (mean over dim 2) of ``|pred - true| / unc`` below each threshold in ``sigmas``.

    ``sigmas`` is the threshold grid defined in the Fig 1 cell.
    """
    below = ((pred - true).abs() / unc).unsqueeze(-1) < sigmas
    return below.float().mean(dim=2)


part_coverages = _fig3_coverage(part_pred_16, part_true_16, part_unc_16)
strap_coverages = _fig3_coverage(strap_pred_16, strap_true_16, strap_unc_16)
baseline_coverages = _fig3_coverage(baseline_ensemble_pred_16, baseline_ensemble_true_16, baseline_ensemble_unc_16)

num_trainings = strap_coverages.shape[1]
color_pal = sns.color_palette('bright')
mpl.rc('font', family='Times New Roman')
mpl.rc('mathtext', fontset='cm')
num_subnet_arr = [16]
f_arr = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5]
# Methods stacked on a new last axis (0 = partition, 1 = bootstrap, 2 = naive ensemble);
# [-1] then keeps the last entry of the leading axis.
coverages = torch.stack([per_method[:4] for per_method in (part_coverages, strap_coverages, baseline_coverages)], dim=-1)[-1, :, :]
fig = plt.figure(figsize=(16, 12))
# Five grid rows: coverage / residual (first three kappas), an empty spacer row, coverage / residual (last three).
gs = fig.add_gridspec(5, 3, height_ratios=[2.5, 1, 0.5, 2.5, 1], hspace=0.04, wspace=0)

axs = np.array([
    [None] * 3 if grid_row == 2 else [fig.add_subplot(gs[grid_row, grid_col]) for grid_col in range(3)]
    for grid_row in range(5)
])

# (row, col, axis, labelsize), applied in this order; labelsize=0 hides labels on interior panels.
for tick_row, tick_col, tick_axis, tick_size in [
    (0, 0, 'both', 16),
    (0, 1, 'both', 0),
    (0, 2, 'both', 0),
    (1, 0, 'both', 16),
    (1, 0, 'x', 0),
    (1, 1, 'both', 0),
    (1, 2, 'both', 0),
    (3, 0, 'both', 16),
    (3, 1, 'both', 0),
    (3, 2, 'both', 0),
    (4, 0, 'both', 16),
    (4, 1, 'y', 0),
    (4, 1, 'x', 16),
    (4, 2, 'y', 0),
    (4, 2, 'x', 16),
]:
    axs[tick_row, tick_col].tick_params(axis=tick_axis, which='major', direction='in', labelsize=tick_size)

method_names = ['Partition', 'Bootstrap', 'Naive Ensemble']
colors = [color_pal[0], color_pal[1], '#555555']
linestyles = ['-', '-', '-.']

# (title, coverage row, residual row, column), one panel per entry of f_arr.
fig3_panel_specs = [
    (r"$\kappa=0.01$", 0, 1, 0),
    (r"$\kappa=0.02$", 0, 1, 1),
    (r"$\kappa=0.05$", 0, 1, 2),
    (r"$\kappa=0.1$", 3, 4, 0),
    (r"$\kappa=0.2$", 3, 4, 1),
    (r"$\kappa=0.5$", 3, 4, 2),
]
for kappa_title, cov_row, res_row, panel_col in fig3_panel_specs:
    axs[cov_row, panel_col].set_title(kappa_title, fontsize=22)
    # Reference curve for perfectly calibrated z-scores.
    axs[cov_row, panel_col].plot(sigmas, ideal_coverages, 'k--', label='Ideal')
    axs[res_row, panel_col].axhline(y=0, color='k', linestyle='--')
    if res_row == 4:
        axs[res_row, panel_col].set_xlabel('$z$', fontsize=18)
    if panel_col == 0:
        axs[res_row, panel_col].set_ylabel(r"$c(z) - \bar{c}(z)$", fontsize=18)
        axs[cov_row, panel_col].set_ylabel(r'$c(z)$', fontsize=18)
    axs[res_row, panel_col].set_ylim(-0.05, 0.05)

for method_idx_swap in range(3):
    if method_idx_swap == 0:
        method_idx = 1
    elif method_idx_swap == 1:
        method_idx = 0
    elif method_idx_swap == 2:
        method_idx = 2
    for f_idx, f in enumerate(f_arr):
        if f_idx == 0:
            axs[0,0].plot(sigmas, coverages[:,f_idx,:,method_idx].mean(axis=0), label=f'{method_names[method_idx]}', color=colors[method_idx], linestyle=linestyles[method_idx])
            if method_idx != 2:
                axs[0,0].fill_between(sigmas, coverages[:,f_idx,:,method_idx].mean(axis=0) - coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, coverages[:,f_idx,:,method_idx].mean(axis=0) + coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, alpha=0.5, color=colors[method_idx], linestyle=linestyles[method_idx])
            axs[1,0].plot(sigmas, (coverages[:,f_idx,:,method_idx].mean(axis=0) - ideal_coverages), label=f'{method_names[method_idx]}', color=colors[method_idx], linestyle=linestyles[method_idx])
            if method_idx != 2:
                axs[1,0].fill_between(sigmas, (coverages[:,f_idx,:,method_idx].mean(axis=0) - ideal_coverages) - coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, (coverages[:,f_idx,:,method_idx].mean(axis=0) - ideal_coverages) + coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, alpha=0.5, color=colors[method_idx], linestyle=linestyles[method_idx])
        elif f_idx == 1:
            axs[0,1].plot(sigmas, coverages[:,f_idx,:,method_idx].mean(axis=0), label=f'{method_names[method_idx]}', color=colors[method_idx], linestyle=linestyles[method_idx])
            if method_idx != 2:
                axs[0,1].fill_between(sigmas, coverages[:,f_idx,:,method_idx].mean(axis=0) - coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, coverages[:,f_idx,:,method_idx].mean(axis=0) + coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, alpha=0.5, color=colors[method_idx], linestyle=linestyles[method_idx])
            axs[1,1].plot(sigmas, (coverages[:,f_idx,:,method_idx].mean(axis=0) - ideal_coverages), label=f'{method_names[method_idx]}', color=colors[method_idx], linestyle=linestyles[method_idx])
            if method_idx != 2:    
                axs[1,1].fill_between(sigmas, (coverages[:,f_idx,:,method_idx].mean(axis=0) - ideal_coverages) - coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, (coverages[:,f_idx,:,method_idx].mean(axis=0) - ideal_coverages) + coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, alpha=0.5, color=colors[method_idx], linestyle=linestyles[method_idx])
        elif f_idx == 2:
            axs[0,2].plot(sigmas, coverages[:,f_idx,:,method_idx].mean(axis=0), label=f'{method_names[method_idx]}', color=colors[method_idx], linestyle=linestyles[method_idx])
            if method_idx != 2:
                axs[0,2].fill_between(sigmas, coverages[:,f_idx,:,method_idx].mean(axis=0) - coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, coverages[:,f_idx,:,method_idx].mean(axis=0) + coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, alpha=0.5, color=colors[method_idx], linestyle=linestyles[method_idx])
            axs[1,2].plot(sigmas, (coverages[:,f_idx,:,method_idx].mean(axis=0) - ideal_coverages), label=f'{method_names[method_idx]}', color=colors[method_idx], linestyle=linestyles[method_idx])
            if method_idx != 2:
                axs[1,2].fill_between(sigmas, (coverages[:,f_idx,:,method_idx].mean(axis=0) - ideal_coverages) - coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, (coverages[:,f_idx,:,method_idx].mean(axis=0) - ideal_coverages) + coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, alpha=0.5, color=colors[method_idx], linestyle=linestyles[method_idx])
        elif f_idx == 3:
            axs[3,0].plot(sigmas, coverages[:,f_idx,:,method_idx].mean(axis=0), label=f'{method_names[method_idx]}', color=colors[method_idx], linestyle=linestyles[method_idx])
            if method_idx != 2:
                axs[3,0].fill_between(sigmas, coverages[:,f_idx,:,method_idx].mean(axis=0) - coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, coverages[:,f_idx,:,method_idx].mean(axis=0) + coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, alpha=0.5, color=colors[method_idx], linestyle=linestyles[method_idx])
            axs[4,0].plot(sigmas, (coverages[:,f_idx,:,method_idx].mean(axis=0) - ideal_coverages), label=f'{method_names[method_idx]}', color=colors[method_idx], linestyle=linestyles[method_idx])
            if method_idx != 2:
                axs[4,0].fill_between(sigmas, (coverages[:,f_idx,:,method_idx].mean(axis=0) - ideal_coverages) - coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, (coverages[:,f_idx,:,method_idx].mean(axis=0) - ideal_coverages) + coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, alpha=0.5, color=colors[method_idx], linestyle=linestyles[method_idx])
        elif f_idx == 4:
            axs[3,1].plot(sigmas, coverages[:,f_idx,:,method_idx].mean(axis=0), label=f'{method_names[method_idx]}', color=colors[method_idx], linestyle=linestyles[method_idx])
            if method_idx != 2:
                axs[3,1].fill_between(sigmas, coverages[:,f_idx,:,method_idx].mean(axis=0) - coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, coverages[:,f_idx,:,method_idx].mean(axis=0) + coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, alpha=0.5, color=colors[method_idx], linestyle=linestyles[method_idx])
            axs[4,1].plot(sigmas, (coverages[:,f_idx,:,method_idx].mean(axis=0) - ideal_coverages), label=f'{method_names[method_idx]}', color=colors[method_idx], linestyle=linestyles[method_idx])
            if method_idx != 2:
                axs[4,1].fill_between(sigmas, (coverages[:,f_idx,:,method_idx].mean(axis=0) - ideal_coverages) - coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, (coverages[:,f_idx,:,method_idx].mean(axis=0) - ideal_coverages) + coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, alpha=0.5, color=colors[method_idx], linestyle=linestyles[method_idx])
        elif f_idx == 5:
            axs[3,2].plot(sigmas, coverages[:,f_idx,:,method_idx].mean(axis=0), label=f'{method_names[method_idx]}', color=colors[method_idx], linestyle=linestyles[method_idx])
            if method_idx != 2:
                axs[3,2].fill_between(sigmas, coverages[:,f_idx,:,method_idx].mean(axis=0) - coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, coverages[:,f_idx,:,method_idx].mean(axis=0) + coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, alpha=0.5, color=colors[method_idx], linestyle=linestyles[method_idx])
            axs[4,2].plot(sigmas, (coverages[:,f_idx,:,method_idx].mean(axis=0) - ideal_coverages), label=f'{method_names[method_idx]}', color=colors[method_idx], linestyle=linestyles[method_idx])
            if method_idx != 2:
                axs[4,2].fill_between(sigmas, (coverages[:,f_idx,:,method_idx].mean(axis=0) - ideal_coverages) - coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, (coverages[:,f_idx,:,method_idx].mean(axis=0) - ideal_coverages) + coverages[:,f_idx,:,method_idx].std(axis=0)/num_trainings**0.5, alpha=0.5, color=colors[method_idx], linestyle=linestyles[method_idx])

# Attach a legend to the coverage panels (grid rows 0 and 3, every column);
# the rows beneath them (1 and 4) plot the same curves minus the ideal coverage.
for legend_row in (0, 3):
    for legend_col in range(3):
        axs[legend_row, legend_col].legend(fontsize=16, frameon=False)
plt.suptitle(r"Gaussian Case Study $M=16$, Coverage on Mixture Fraction $\kappa$", fontsize=24)
plt.savefig("fig3.pdf", bbox_inches='tight')

# Fig 7

In [None]:
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import norm

# We take M=8, 16, and 32 for this study. The saved bootstrap arrays are
# presumably ordered like num_subnet_arr below ([2, 4, 8, 16, 32]), so [2:]
# keeps the M=8, 16 and 32 runs (index 0 -> M=8, 1 -> M=16, 2 -> M=32).
bootstrap_pred = torch.load('YOUR_DIR_HERE/bootstrap_pred_arr.pt')[2:]
bootstrap_raw_pred = torch.load('YOUR_DIR_HERE/bootstrap_raw_pred_arr.pt')[2:]
bootstrap_true = torch.load('YOUR_DIR_HERE/bootstrap_true_arr.pt')[2:]
bootstrap_unc = torch.load('YOUR_DIR_HERE/bootstrap_unc_arr.pt')[2:]

bootstrap_pred_8 = bootstrap_pred[0]
bootstrap_true_8 = bootstrap_true[0]
bootstrap_unc_8 = bootstrap_unc[0]
bootstrap_raw_pred_8 = bootstrap_raw_pred[0]

bootstrap_pred_16 = bootstrap_pred[1]
bootstrap_true_16 = bootstrap_true[1]
bootstrap_unc_16 = bootstrap_unc[1]
bootstrap_raw_pred_16 = bootstrap_raw_pred[1]

bootstrap_pred_32 = bootstrap_pred[2]
bootstrap_true_32 = bootstrap_true[2]
bootstrap_unc_32 = bootstrap_unc[2]
bootstrap_raw_pred_32 = bootstrap_raw_pred[2]

num_subnet_arr = [2, 4, 8, 16, 32]
sigmas = torch.linspace(0, 5, 500)
# Two-sided standard-normal coverage P(|Z| < z) for each threshold z in `sigmas`.
ideal_coverages = norm.cdf(sigmas) - norm.cdf(-sigmas)
# Index along dim 3 of the prediction arrays that Figs 7 and 8 study
# (their titles label it kappa=0.1).
i = 3
color_pal = sns.color_palette('bright')
colors = [color_pal[5], color_pal[4]]  # [bias-corrected, uncorrected]
mpl.rc('font', family='Times New Roman')
mpl.rc('mathtext', fontset='cm')


def _coverage_at_index(z_scores, frac_idx, thresholds):
    """Return the empirical coverage of ``|z_scores|`` at dim-3 index ``frac_idx``.

    For every threshold in ``thresholds`` the result holds the fraction of
    ``|z|`` values strictly below it, averaged over dim 2, so an input of shape
    (A, B, C, F) yields (A, B, len(thresholds)). Index ``frac_idx`` is selected
    before the comparison is broadcast against ``thresholds``, so the boolean
    tensor is never built for the unused dim-3 entries; the result equals
    ``(z.abs().unsqueeze(-1) < thresholds).float().mean(dim=2)[:, :, frac_idx, :]``.
    """
    selected = z_scores[:, :, :, frac_idx].abs().unsqueeze(-1)
    return (selected < thresholds).float().mean(dim=2)


# z-scores of the bias-corrected ("bootstrap") and uncorrected ("raw")
# predictions, both normalised by the same uncertainty estimate.
bootstrap_z_scores_8 = (bootstrap_pred_8 - bootstrap_true_8) / bootstrap_unc_8
bootstrap_coverages_8 = _coverage_at_index(bootstrap_z_scores_8, i, sigmas)

bootstrap_z_scores_16 = (bootstrap_pred_16 - bootstrap_true_16) / bootstrap_unc_16
bootstrap_coverages_16 = _coverage_at_index(bootstrap_z_scores_16, i, sigmas)

bootstrap_z_scores_32 = (bootstrap_pred_32 - bootstrap_true_32) / bootstrap_unc_32
bootstrap_coverages_32 = _coverage_at_index(bootstrap_z_scores_32, i, sigmas)

bootstrap_raw_z_scores_8 = (bootstrap_raw_pred_8 - bootstrap_true_8) / bootstrap_unc_8
bootstrap_raw_coverages_8 = _coverage_at_index(bootstrap_raw_z_scores_8, i, sigmas)

bootstrap_raw_z_scores_16 = (bootstrap_raw_pred_16 - bootstrap_true_16) / bootstrap_unc_16
bootstrap_raw_coverages_16 = _coverage_at_index(bootstrap_raw_z_scores_16, i, sigmas)

bootstrap_raw_z_scores_32 = (bootstrap_raw_pred_32 - bootstrap_true_32) / bootstrap_unc_32
bootstrap_raw_coverages_32 = _coverage_at_index(bootstrap_raw_z_scores_32, i, sigmas)

# Size of dim 1 of the coverage arrays; used as N in the std/sqrt(N) error bands.
num_trainings = bootstrap_coverages_32.shape[1]

fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True, gridspec_kw={'wspace': 0})
# `i`, `colors` and the Times New Roman / cm font settings were already set
# earlier in this cell, so they are not re-declared here.
axs[0].set_ylabel("Count", fontsize=18)


def _plot_residual_hist(ax, m_subnets, pred, raw_pred, true):
    """Overlay bias-corrected and uncorrected residual histograms for one M.

    Residuals are taken at index 0 of dim 0 and index ``i`` of dim 3, then
    flattened. The dashed black line marks zero; the solid coloured lines mark
    the mean residual of each histogram.
    """
    res_bc = (pred - true)[0, :, :, i].flatten().numpy()
    res_raw = (raw_pred - true)[0, :, :, i].flatten().numpy()
    ax.hist(res_bc, bins='auto', alpha=0.5, label=r"$\hat{\kappa}_{\rm BC}$", color=colors[0])
    ax.hist(res_raw, bins='auto', alpha=0.5, label=r"$\hat{\kappa}$", color=colors[1])
    ax.set_xlabel("Residual", fontsize=18)
    ax.axvline(0, color='k', linestyle='--')
    ax.axvline(res_bc.mean(), color=colors[0])
    ax.axvline(res_raw.mean(), color=colors[1])
    ax.set_title(f"M={m_subnets} residuals", fontsize=22)
    ax.legend(fontsize=16, frameon=False)
    ax.tick_params(axis='both', which='major', direction='in', labelsize=16)
    ax.set_xlim(-0.15, 0.15)


fig7_panels = [
    (8, bootstrap_pred_8, bootstrap_raw_pred_8, bootstrap_true_8),
    (16, bootstrap_pred_16, bootstrap_raw_pred_16, bootstrap_true_16),
    (32, bootstrap_pred_32, bootstrap_raw_pred_32, bootstrap_true_32),
]
for panel_ax, panel_args in zip(axs, fig7_panels):
    _plot_residual_hist(panel_ax, *panel_args)

plt.suptitle(r"Gaussian Case Study, Uncorrected and Corrected Residuals on $\kappa=0.1$", fontsize=24, y=1.05)
plt.savefig("fig7.pdf", bbox_inches='tight')

# Fig 8

In [None]:
fig, axs = plt.subplots(2, 3, sharex="col", figsize=(14, 6), sharey="row", gridspec_kw={'hspace': 0.05, 'wspace': 0.04, 'height_ratios': [2, 1]})
# Every panel plots entry 0 along dim 0 of the coverage arrays.
subnet_idx = 0


def _mean_and_stderr(curves):
    """Return the mean of ``curves`` over dim 0 and its standard error std / sqrt(num_trainings)."""
    return curves.mean(axis=0), curves.std(axis=0) / num_trainings**0.5


# Columns are M = 8, 16, 32. Top row: coverage c(z) with the ideal
# standard-normal coverage dashed; bottom row: c(z) minus that ideal.
# Bands are +/- one standard error of the mean curve.
for col, m_subnets, bc_cov, raw_cov in (
    (0, 8, bootstrap_coverages_8, bootstrap_raw_coverages_8),
    (1, 16, bootstrap_coverages_16, bootstrap_raw_coverages_16),
    (2, 32, bootstrap_coverages_32, bootstrap_raw_coverages_32),
):
    ax_cov, ax_diff = axs[0, col], axs[1, col]
    ax_cov.tick_params(axis='both', which='major', direction='in', labelsize=16)
    ax_diff.tick_params(axis='both', which='major', direction='in', labelsize=16)
    ax_cov.set_xlim(0, 5)
    ax_cov.set_ylim(0, 1)
    ax_cov.plot(sigmas, ideal_coverages, 'k--', label='Ideal')

    for curves, label, color in (
        (bc_cov[subnet_idx, :, :], 'Bias correction', colors[0]),
        (raw_cov[subnet_idx, :, :], 'No bias correction', colors[1]),
    ):
        mean, err = _mean_and_stderr(curves)
        diff = mean - ideal_coverages
        ax_cov.plot(sigmas, mean, label=label, color=color)
        ax_cov.fill_between(sigmas, mean - err, mean + err, alpha=0.5, color=color)
        ax_diff.plot(sigmas, diff, label=label, color=color)
        ax_diff.fill_between(sigmas, diff - err, diff + err, alpha=0.5, color=color)

    ax_diff.axhline(y=0, color='k', linestyle='--')
    ax_diff.set_xlabel('$z$', fontsize=18)
    ax_diff.set_ylim(-0.3, 0.3)
    ax_cov.set_title(f'M={m_subnets}', fontsize=22)
    ax_cov.legend(fontsize=16, frameon=False)

axs[0, 0].set_ylabel(r'$c(z)$', fontsize=18)
axs[1, 0].set_ylabel(r"$c(z) - \bar{c}(z)$", fontsize=18)
plt.suptitle(r"Gaussian Case Study, Uncorrected and Corrected Coverage for $\kappa=0.1$", fontsize=24, y=1.05)
plt.savefig("fig8.pdf", bbox_inches='tight')