In [1]:
from dustbi_simulator import *
from Functions import *

In [2]:
from astropy.cosmology import Planck18
import astropy.units as u

In [3]:
import numpy as np
import pandas as pd
# --- Simulation input library -------------------------------------------------
# Whitespace-delimited FITRES table; lines starting with '#' are comments.
# Raw string for the regex separator, consistent with the second read below.
df = pd.read_csv("INPUT_DES5YR_D2D.FITRES", comment="#", sep=r'\s+')

# E(B-V) = A_V / R_V
df['SIM_EBV'] = df.SIM_AV/df.SIM_RV

# Distance modulus at each redshift for the Planck18 cosmology.
df['MU'] = Planck18.distmod(df.zHD.values).value


# --- Sample to be fitted ------------------------------------------------------
dfdata = pd.read_csv("SIMS_FOR_TESTING/FITOPT000.FITRES.gz",
                     comment="#", sep=r'\s+')

dfdata['MU'] = Planck18.distmod(dfdata.zHD.values).value

#dfdata = pd.read_csv("../INVERSE_H0/D5YR_DATA/FITOPT000_MUOPT000.FITRES.gz", comment="#", sep=r'\s+')

# SIM_AV / SIM_RV are missing when the table has no simulation truth columns.
# Attribute access on a missing column raises AttributeError, so only that case
# is tolerated; any other failure is allowed to surface.
try:
    dfdata['SIM_EBV'] = dfdata.SIM_AV/dfdata.SIM_RV
except AttributeError:
    print("SIM_AV/SIM_RV columns not found in dfdata; SIM_EBV was not added.")

# Keep survey 10 with classifier probability >= 0.5. The explicit copy makes
# dfdata an independent frame, so later column assignments (e.g. MURES) do not
# trigger SettingWithCopyWarning.
dfdata = dfdata.loc[(dfdata.IDSURVEY == 10) & (dfdata.PROB_SNNV19 >= 0.5)].copy()

  dfdata['MU'] = Planck18.distmod(dfdata.zHD.values).value
  dfdata['SIM_EBV'] = dfdata.SIM_AV/dfdata.SIM_RV


In [4]:
bounds_dict = {
    "SIM_c"   : (-0.5, 0.5),
    "SIM_RV"  : (1.5, 5),
    "SIM_EBV" : (0,1),
    "SIM_beta": (0.5,4),
}

function_dict = {
    "SIM_c"   : DistGaussian,
    "SIM_RV"  : DistGaussian,
    "SIM_EBV" : DistExponential,
    "SIM_beta": DistGaussian,
}

# Optional population split. Each entry maps a parameter name to
# [column_in_dfdata, threshold]. Splitting is currently disabled: the dict is
# empty. To enable it, add one of the entries below.
#    "SIM_RV":["HOST_LOGMASS", 10],
#    "SIM_EBV":['HOST_LOGMASS', 10],
#    'SIM_c':['HOST_LOGMASS', 10]
split_dict = {}


#Prior dict is a weird one; it should be a tuple for each parameter and then a boolean statement.


# Prior ranges: one (low, high) tuple per parameter of the distribution that
# function_dict assigns to each quantity (two tuples for the Gaussian entries,
# one for the exponential entry).
# NOTE(review): presumably these are bounds of uniform priors built by
# prior_generator — confirm against its implementation.
priors_dict = {
    
    "SIM_c"   : [(-0.2, 0), (0.0, 0.1), ],
    "SIM_RV"  : [(1.5,4), (0,2), ],
    "SIM_EBV" : [(0.05, 0.3)],
    "SIM_beta": [(0,3), (0,1), ],
    
}

# Plot labels. Distribution class names map to the LaTeX symbols of their
# parameters; quantity names map to the LaTeX label of the quantity itself.
latex_dict = {
    
    'DistExponential':[r'$\tau$'],
    'DistGaussian':[r'$\mu$', r'$\sigma$'],
    'SIM_c':r"$c_{\rm int}$",
    'SIM_beta':r"$\beta_{\rm int}$",
    'SIM_RV':r"$R_V$",
    'SIM_EBV':r"$EBV$",
    
}


# Bundle handed to the helper functions. latex_dict is intentionally not part
# of it; it is passed separately where labels are built.
dicts = [bounds_dict, function_dict, split_dict, priors_dict]

In [5]:
param_names = ['SIM_c', 'SIM_RV', 'SIM_beta', 'SIM_EBV']
#param_names = ['SIM_c']


params_to_fit = parameter_generation(param_names, dicts)
priors = prior_generator(param_names, dicts)

Total priors added: 7
[0] <class 'sbi.utils.torchutils.BoxUniform'>
[1] <class 'sbi.utils.torchutils.BoxUniform'>
[2] <class 'sbi.utils.torchutils.BoxUniform'>
[3] <class 'sbi.utils.torchutils.BoxUniform'>
[4] <class 'sbi.utils.torchutils.BoxUniform'>
[5] <class 'sbi.utils.torchutils.BoxUniform'>
[6] <class 'sbi.utils.torchutils.BoxUniform'>


In [6]:
layout = build_layout(params_to_fit, dicts)

In [7]:
parameters_to_condition_on = ['c', 'mB', 'x1', 'zHD', 'cERR', 'mBERR', 'x1ERR', 'MU']

In [8]:
simulatinator = make_simulator(layout, df, param_names, parameters_to_condition_on, dicts, dfdata, is_split=True)


In [9]:
# Width of one observation row fed to the embedding network: one entry per
# conditioning column, doubled when any fitted parameter has a split configured.
needs_split = any(name in split_dict for name in param_names)
ndim = len(parameters_to_condition_on) * (2 if needs_split else 1)

print(ndim)

8


In [None]:
from joblib import Parallel, delayed

In [10]:
def batched_simulator(theta_batch):
    """Run the simulator serially over a batch of parameter vectors.

    Iterates over ``theta_batch`` in order, calls ``simulatinator`` on each
    element, and returns the results combined with ``torch.stack``.
    """
    simulated = []
    for theta in theta_batch:
        simulated.append(simulatinator(theta))
    return torch.stack(simulated)

In [None]:
def batched_simulator(theta_batch, n_jobs=-1):
    """Run the simulator over a batch of parameter vectors using joblib workers.

    NOTE: a serial function with the same name is defined in an earlier cell;
    running this cell replaces it for every later call.

    Args:
        theta_batch: iterable of parameter vectors; each element is passed to
            ``simulatinator`` unchanged.
        n_jobs: forwarded to ``joblib.Parallel``. The default of -1 keeps the
            previous behaviour; pass 1 to run without parallel workers when
            debugging.

    Returns:
        ``torch.stack`` of the per-element simulator outputs, in input order.
    """
    results = Parallel(n_jobs=n_jobs)(
        delayed(simulatinator)(theta)
        for theta in theta_batch
    )
    return torch.stack(results)

In [11]:
from sbi import analysis as analysis

# sbi
from sbi import utils as utils
from sbi.inference import NPE, simulate_for_sbi
from sbi.utils.user_input_checks import (
    check_sbi_inputs,
    process_prior,
    process_simulator,
)

In [12]:
# Check prior, simulator, consistency
prior, num_parameters, prior_returns_numpy = process_prior(priors)
simulation_wrapper = process_simulator(simulatinator, prior, prior_returns_numpy)
check_sbi_inputs(simulation_wrapper, prior)

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F




In [14]:
from sbi.inference import SNPE
from sbi.utils import MultipleIndependent

from sbi.neural_nets import posterior_nn




# Potentially Upgraded Version

In [15]:
from sbi import analysis as analysis
from sbi.inference import SNPE
from sbi.neural_nets import posterior_nn

class PopulationEmbeddingFull(nn.Module):
    """Attention-pooled embedding of a set of per-object feature rows.

    Pipeline: ``phi`` (two-layer ReLU MLP applied to every row) -> softmax
    attention weights over dim 1 -> weighted sum over dim 1 -> ``rho``
    (two-layer MLP producing the summary vector).

    Args:
        input_dim: number of features per row (defaults to the notebook-level
            ``ndim`` at class-definition time).
        hidden_dim: width of the hidden layers.
        output_dim: size of the returned summary vector.
    """

    def __init__(self, input_dim=ndim, hidden_dim=64, output_dim=32):
        super().__init__()

        # Per-row feature extractor.
        phi_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.phi = nn.Sequential(*phi_layers)

        # One unnormalised attention score per row.
        self.attention = nn.Linear(hidden_dim, 1)

        # Post-pooling head.
        rho_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.rho = nn.Sequential(*rho_layers)

    def forward(self, x):
        # assumes x is (batch, N, input_dim), per the shape notes of the
        # original author — TODO confirm against the simulator output
        features = self.phi(x)                      # (batch, N, hidden)

        # Scores are computed from all features (signal and errors alike);
        # restricting attention to the error columns is an open idea.
        scores = self.attention(features)
        weights = torch.softmax(scores, dim=1)

        pooled = torch.sum(features * weights, dim=1)   # (batch, hidden)
        return self.rho(pooled)
    
#might need to standardise errors and signal 

In [16]:
from sbi.inference import SNPE
from sbi.utils import MultipleIndependent

from sbi.neural_nets import posterior_nn

# Build the posterior density-estimator factory: an "nsf" flow conditioned on
# the summary vector produced by PopulationEmbeddingFull.
density_estimator = posterior_nn(
    model="nsf", # flow architecture name passed to sbi's posterior_nn
    embedding_net=PopulationEmbeddingFull(input_dim=ndim)
)

# NOTE(review): this passes the raw `priors` object from prior_generator, while
# an earlier cell also creates a processed `prior` via process_prior — confirm
# which one is intended here.
inference = SNPE(
    prior=priors,
    density_estimator=density_estimator, 
)





# Permutation Invariant Embedding 

In [None]:
from sbi.neural_nets import posterior_nn
from sbi.neural_nets.embedding_nets import FCEmbedding, PermutationInvariantEmbedding


In [None]:

# Alternative embedding built from sbi's stock components.
# Network applied to each trial (row) on its own.
single_trial_net = FCEmbedding(
    input_dim=ndim,
    num_hiddens=40,
    num_layers=2,
    output_dim=ndim,
)
# Wraps the per-trial network so the combined embedding does not depend on
# the order of the trials.
embedding_net = PermutationInvariantEmbedding(
    single_trial_net,
    trial_net_output_dim=ndim,
    num_layers=1,
    num_hiddens=10,
    output_dim=ndim,
)
density_estimator = posterior_nn("nsf", embedding_net=embedding_net)



# NOTE: running this cell rebinds `density_estimator` and `inference`,
# replacing the objects created in the attention-pooling section above.
inference = SNPE(
    prior=priors,
    density_estimator=density_estimator, 
)



In [None]:
import torch
import os
import time


start_time = time.perf_counter()

batch_size = 50
num_simulations = 4000
save_path = "simulations_v2.pt"

# If the file already exists, start fresh
if os.path.exists(save_path):
    os.remove(save_path)

# Everything simulated so far is kept in memory, so each checkpoint can be
# written directly instead of re-reading the whole file from disk per batch.
theta_chunks = []
x_chunks = []

for start in range(0, num_simulations, batch_size):
    # The last batch may be smaller when batch_size does not divide the total.
    current_bs = min(batch_size, num_simulations - start)

    # Sample and simulate
    theta_batch = priors.sample((current_bs,))
    x_batch = batched_simulator(theta_batch)

    # Append to SBI inference
    inference.append_simulations(theta_batch, x_batch)

    # Checkpoint after every batch: the file always holds all simulations
    # produced so far under the keys 'theta' and 'x'.
    theta_chunks.append(theta_batch)
    x_chunks.append(x_batch)
    torch.save(
        {
            'theta': torch.cat(theta_chunks, dim=0),
            'x': torch.cat(x_chunks, dim=0),
        },
        save_path,
    )

    print(f"Appended {start + current_bs}/{num_simulations} simulations and saved incrementally.")

print(f"All simulations saved incrementally to '{save_path}'")

end = time.perf_counter()

# Wall-clock seconds for the whole simulate-and-save loop.
elapsed = end - start_time
#print(f'Time taken: {elapsed:.6f} seconds')

In [None]:
print(f'Time taken: {elapsed/60:.6f} minutes')

In [None]:
# NOTE(review): the simulation loop in the cell above already calls
# append_simulations for every batch. This call appends whatever `theta_batch`
# and `x_batch` currently hold once more: straight after that loop it is the
# final batch again; after reloading "simulations_v2.pt" it is the full set.
# Confirm which workflow is intended so simulations are not counted twice.
inference.append_simulations(theta_batch, x_batch)

# Train the density estimator, holding out 10% of the simulations for validation.
density_estimator = inference.train(validation_fraction=0.1,
                                   show_train_summary=True)
# Option to consider: force_first_round_loss.
# When true, only the standard NPE loss is computed.
# Open idea: move to the sequential ("S") part of SNPE eventually.

print("\n inferred successfully")

posterior = inference.build_posterior(density_estimator)

# The whole posterior object is serialised with torch.save (pickle-based);
# only load this file from a trusted source.
torch.save(posterior, "posterior.pt")




 Training neural network. Epochs trained: 19

In [17]:
data = torch.load("simulations_v2.pt")
theta_batch = data["theta"]
x_batch = data["x"]


  data = torch.load("simulations_v2.pt")


In [None]:
def preprocess_data(param_names, parameters_to_condition_on, split_dict, dfdata):
    """Turn the observed sample into the tensor the posterior is conditioned on.

    Args:
        param_names: names of the fitted parameters; used only to decide
            whether a split is configured.
        parameters_to_condition_on: column names that make up the output.
        split_dict: maps a parameter name to ``[split_column, split_value]``.
        dfdata: DataFrame holding the observed sample.

    Returns:
        When no entry of ``param_names`` appears in ``split_dict``: the
        requested columns stacked along the last axis. Otherwise: the result
        of ``split_outputs`` for the first matching parameter.
    """
    columns_needed = parameters_to_condition_on+['x0', 'x0ERR', 'MU']
    tensors = preprocess_input_distribution(dfdata, columns_needed)

    # Disabled step: a SALT MCMC fit (start_distance / add_distance) used to
    # derive a MURES column here and append it to the conditioning columns.

    split_names = [name for name in param_names if name in split_dict]

    if not split_names:
        # No split configured: one column per conditioning quantity.
        return torch.stack(
            [tensors[column] for column in parameters_to_condition_on],
            dim=-1
        )

    # Only the first fitted parameter with a split entry is honoured.
    split_entry = split_dict[split_names[0]]
    split_column = split_entry[0]
    split_threshold = split_entry[1]

    split_tensor = torch.tensor(
        dfdata[split_column].to_numpy(),
        dtype=torch.float32,
    )

    return split_outputs(
        tensors,
        split_tensor,
        split_threshold,
        parameters_to_condition_on
    )

In [None]:
x = preprocess_data(param_names, parameters_to_condition_on, split_dict, dfdata)



In [None]:
true_vals = priors.sample()

new_x = simulatinator(true_vals)

In [None]:
true_vals

In [None]:
labels = unspool_labels(param_names, dicts, latex_dict, function_dict)

In [None]:
posterior_samples = posterior.sample((50000,), x=x)


In [None]:
fig, axes = analysis.pairplot(
    posterior_samples,
    labels=labels

);

In [None]:
true_vals

In [None]:
theta_hat = posterior_samples.mean(0)


In [None]:
theta_hat


In [None]:
posterior_samples.std(0)

In [None]:
true_params = true_vals#torch.tensor([-0.07, 0.053, 2, 0.95, 2.07, 0.22, 0.15, 0.12,])

In [None]:
from IPython.display import display, Math


In [None]:
# The posterior standard deviation does not depend on the loop index, so it is
# computed once instead of twice per iteration.
theta_std = posterior_samples.std(0)

for n in range(len(theta_hat)):

    # Offset of the posterior mean from the reference value, in units of the
    # posterior standard deviation.
    delta = theta_hat[n] - true_params[n]
    sigma = delta/theta_std[n]

    string = rf"{labels[n]} = {theta_hat[n]:.3f} +/- {theta_std[n]:.3f} which is {sigma:.3f} $\sigma$"

    display(Math(string))


In [None]:
simulatinator = make_simulator(layout, df, param_names, parameters_to_condition_on, dicts, 
                               dfdata, is_split=True, debug=True)


In [None]:
dft = simulatinator(theta_hat)

#dft = simulatinator(torch.tensor([[-0.1006,  0.0507,  2.7590,  1.0042,  1.4923,  0.5086,  0.3, 0.06]]))

In [None]:
import matplotlib.pyplot as plt

In [None]:
bins = np.linspace(-0.4, 0.4, 20)

#plt.hist(dft.loc[dft.HOST_LOGMASS < 10].c.values, histtype='step', bins=bins, label="sim output, low mass", density=True)
#plt.hist(dft.loc[dft.HOST_LOGMASS > 10].c.values, histtype='step', bins=bins, label="sim output, high mass", density=True)
plt.hist(dft.c.values, histtype='step', bins=bins, label="sim output, all mass", density=True)

plt.hist(dfdata.c.values, histtype='step', bins=bins, label="data", density=True)

plt.legend()
plt.xlabel("c")

In [None]:
# Compare the MURES distribution of the simulated sample (dft) with the data.
# NOTE(review): the `MURES` columns of `dft` and `dfdata` are assigned in cells
# near the end of this notebook. Unless the inputs already carry such a column,
# this cell fails on a fresh top-to-bottom run — confirm the intended order.
bins = np.linspace(-4, 10, 20)

#plt.hist(dft.loc[dft.HOST_LOGMASS < 10].c.values, histtype='step', bins=bins, label="sim output, low mass", density=True)
#plt.hist(dft.loc[dft.HOST_LOGMASS > 10].c.values, histtype='step', bins=bins, label="sim output, high mass", density=True)
plt.hist(dft.MURES.values, histtype='step', bins=bins, label="sim output, all mass", density=True)

plt.hist(dfdata.MURES.values, histtype='step', bins=bins, label="data", density=True)

plt.legend()
plt.xlabel("MURES")

In [None]:
bins = np.linspace(18, 26, 20)

plt.hist(dft.mB.values, histtype='step', bins=bins, label="sim output", density=True)
plt.hist(dfdata.mB.values, histtype='step', bins=bins, label="data", density=True)

plt.legend()
plt.xlabel("mB")

In [None]:
bins = np.linspace(0, 0.6, 20)

plt.hist(dft.loc[dft.HOST_LOGMASS < 10].SIM_EBV.values, histtype='step', bins=bins, label="low mass output", density=True)
plt.hist(dft.loc[dft.HOST_LOGMASS > 10].SIM_EBV.values, histtype='step', bins=bins, label="high mass output", density=True)


plt.hist(dfdata.SIM_EBV.values, histtype='step', bins=bins, label="data", density=True)


plt.legend()
plt.xlabel("E(B-V)")

# Calibrate some posteriors

In [None]:
import matplotlib.pyplot as plt

In [None]:
from sbi.diagnostics import run_sbc
from sbi.analysis.plot import sbc_rank_plot


In [None]:
# Obtain your `posterior_estimator` with NPE, NLE, NRE.
posterior = inference.build_posterior()

num_sbc_samples = 200  # choose a number of sbc runs, should be ~100s
prior_samples = prior.sample((num_sbc_samples,))
prior_predictives = batched_simulator(prior_samples)

num_posterior_samples = 4000

In [None]:
ranks, dap_samples = run_sbc(
    prior_samples,
    prior_predictives,
    posterior,
    num_posterior_samples=num_posterior_samples,
    use_batched_sampling=True, # `True` can give speed-ups, but can cause memory issues.
    num_workers=1
)




In [None]:
import joblib
import sbi

print("joblib version:", joblib.__version__)
print("sbi version:", sbi.__version__)

In [None]:

fig, ax = sbc_rank_plot(
    ranks,
    num_posterior_samples,
    plot_type="cdf",
    num_bins=20,
    figsize=(5, 3),
)

In [None]:
f, ax = sbc_rank_plot(
    ranks=ranks,
    num_posterior_samples=num_posterior_samples,
    plot_type="hist",
    num_bins=None,  # by passing None we use a heuristic for the number of bins.
)

In [None]:
labels

In [None]:
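# How to read the SBC rank histograms:
#
#   Flat histogram → posteriors are well-calibrated.
#
#   U-shaped → posteriors are too narrow (overconfident).
#
#   Bell-shaped → posteriors are too wide (underconfident).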

In [None]:
num_tarp_samples = 200  # number of TARP test simulations, should be ~100s
# generate ground truth parameters and corresponding simulated observations for TARP.
thetas = prior.sample((num_tarp_samples,))
xs = batched_simulator(thetas)

In [None]:
from sbi.diagnostics import check_sbc, check_tarp, run_sbc, run_tarp
# the tarp method returns the ECP values for a given set of alpha coverage levels.
ecp, alpha = run_tarp(
    thetas,
    xs,
    posterior,
    references=None,  # will be calculated automatically.
    num_posterior_samples=3000,
)

In [None]:
atc, ks_pval = check_tarp(ecp, alpha)
print(atc, "Should be close to 0")
print(ks_pval, "Should be larger than 0.05")

In [None]:
from sbi.analysis.plot import plot_tarp

plot_tarp(ecp, alpha)

In [None]:
def plot_tarp(
    ecp, alpha, title="",):
    """
    Plots the expected coverage probability (ECP) against the credibility
    level, alpha, for a given alpha grid.

    This definition replaces the ``plot_tarp`` imported from
    ``sbi.analysis.plot`` in the preceding cell. ``title`` defaults to an
    empty string so the two-argument call ``plot_tarp(ecp, alpha)`` keeps
    working after this cell has run.

    Args:
        ecp : numpy.ndarray
            Array of expected coverage probabilities.
        alpha : numpy.ndarray
            Array of credibility levels.
        title : str, optional
            Title for the plot. The default is "".

    Returns:
        fig : matplotlib.figure.Figure
            The figure object.
        ax : matplotlib.axes.Axes
            The axes object.

    """

    fig, ax = plt.subplots(figsize=(6, 6))

    ax.plot(alpha, ecp, color="blue", label="TARP")
    # The diagonal is the perfectly calibrated reference.
    ax.plot(alpha, alpha, color="black", linestyle="--", label="ideal")
    ax.set_xlabel(r"Credibility Level $\alpha$")
    ax.set_ylabel(r"Expected Coverage Probability")
    ax.set_xlim(0.0, 1.0)
    ax.set_ylim(0.0, 1.0)
    ax.set_title(title or "")
    ax.legend()
    return fig, ax


In [None]:
losses = inference.summary["training_loss"]
val_losses = inference.summary["validation_loss"]

# Plot
plt.plot(losses, label='Training Loss')
plt.plot(val_losses, label="Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
#plt.ylim([-10,10])
plt.title("Loss vs Epochs")
plt.legend()
plt.show()

In [None]:
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
from pyro.infer.autoguide.initialization import init_to_median

def distancinator(x0_obs, x0_err, x1_obs, x1_err, c_obs, c_err, dist_mod):
    """Pyro model linking light-curve fit quantities to the distance modulus.

    Global sample sites: ``alpha`` ~ Normal(0.1, 1.0), ``beta`` ~
    Normal(2.0, 3.0), ``M`` ~ Uniform(-21.5, -17.0) and ``sigma_int`` ~
    HalfNormal(0.3). Inside the ``sne`` plate each object gets a latent
    ``log10_x0`` ~ Normal(-3.0, 2.0). ``x0_obs`` is observed around
    ``10**log10_x0`` with scale ``x0_err``, and ``dist_mod`` is observed
    around the predicted value with scale ``total_err``.

    Args:
        x0_obs, x0_err: observed amplitude and its uncertainty.
        x1_obs, x1_err: observed x1 and its uncertainty.
        c_obs, c_err: observed c and its uncertainty.
        dist_mod: observed values for the ``cosmo`` site; its first dimension
            sets the plate size.
    """

    # Plate size: one entry per object.
    n = dist_mod.shape[0]

    alpha = pyro.sample("alpha", dist.Normal(0.1, 1.0))
    beta = pyro.sample("beta", dist.Normal(2.0, 3.0))
    M = pyro.sample("M", dist.Uniform(-21.5, -17.0))
    sigma_int = pyro.sample("sigma_int", dist.HalfNormal(0.3))

    with pyro.plate("sne", n):
        # Earlier variant with a bounded uniform prior on log10(x0):
        #log10_x0 = pyro.sample("log10_x0", dist.Uniform(-6.0, 0.0))
        #x0_true = 10.0 ** log10_x0
        log10_x0 = pyro.sample("log10_x0", dist.Normal(-3.0, 2.0))
        x0_true = torch.pow(10.0, log10_x0)

        pyro.sample("x0_obs",
                    dist.Normal(x0_true, x0_err),
                    obs=x0_obs)

        # Standardisation terms; note that M enters with a minus sign.
        correction = alpha * x1_obs - beta * c_obs - M

        # 2.5 / ln(10) converts a fractional error on x0 into magnitudes.
        mag_err = (2.5 / torch.log(torch.tensor(10.0))) * (x0_err / x0_true)
        # NOTE(review): x1_err and c_err are added in quadrature without being
        # scaled by alpha and beta — confirm this is the intended error model.
        total_err = torch.sqrt(mag_err**2 + sigma_int**2 + x1_err**2 + c_err**2)

        # NOTE(review): 10.635 is presumably the zero-point offset that turns
        # -2.5*log10(x0) into a peak magnitude — confirm the value.
        mean_mag = -2.5 * torch.log10(x0_true) + 10.635 + correction
        
        pyro.sample("cosmo",
                    dist.Normal(mean_mag, total_err),
                    obs=dist_mod)

In [None]:
def start_distance(NUM_WARMUP = 50, NUM_SAMPLES = 150, NUM_CHAINS = 1):
    """Build (but do not run) an MCMC sampler for the ``distancinator`` model.

    Args:
        NUM_WARMUP: number of warm-up steps.
        NUM_SAMPLES: number of samples to draw after warm-up.
        NUM_CHAINS: number of chains.

    Returns:
        A ``pyro.infer.MCMC`` object wrapping a JIT-compiled NUTS kernel that
        is initialised at the median and limited to tree depth 10.
    """
    kernel_options = dict(
        jit_compile=True,
        init_strategy=init_to_median(),
        max_tree_depth=10,
    )
    sampler_options = dict(
        warmup_steps=NUM_WARMUP,
        num_samples=NUM_SAMPLES,
        num_chains=NUM_CHAINS,
    )
    return MCMC(NUTS(distancinator, **kernel_options), **sampler_options)
    

In [None]:
salt_mcmc = start_distance()

salt_mcmc.run(
    torch.tensor(dfdata.x0.values),
    torch.tensor(dfdata.x0ERR.values),
    torch.tensor(dfdata.x1.values),
    torch.tensor(dfdata.x1ERR.values),
    torch.tensor(dfdata.c.values),
    torch.tensor(dfdata.cERR.values),
    torch.tensor(dfdata.MU.values)
)

In [None]:
blep = salt_mcmc.get_samples()

In [None]:
def add_distance(mcmc, df_tensor):
    """Compute distance-modulus residuals from posterior-mean nuisance parameters.

    Args:
        mcmc: object exposing ``get_samples()`` that returns a mapping with
            the keys ``'alpha'``, ``'beta'`` and ``'M'``.
        df_tensor: mapping with the entries ``'x1'``, ``'c'``, ``'mB'`` and
            ``'MU'``.

    Returns:
        ``MU - (alpha*x1 - beta*c - M + mB)`` evaluated element-wise, using
        the mean of each nuisance-parameter sample.
    """
    x1_obs = df_tensor['x1']
    c_obs = df_tensor['c']
    mB_obs = df_tensor['mB']

    nuisance = mcmc.get_samples()
    beta = nuisance['beta'].mean()
    alpha = nuisance['alpha'].mean()
    M0 = nuisance['M'].mean()

    # Fixed: the original referenced an undefined name `mB` here, which raised
    # NameError; the observed peak magnitude is `mB_obs`.
    correction = alpha * x1_obs - beta * c_obs - M0 + mB_obs

    MURES = df_tensor['MU'] - correction

    return MURES


In [None]:
# NOTE(review): M is added here, whereas the distancinator model and
# add_distance both subtract it. The result is also stored as `MURES` in the
# next cell although MU is not subtracted — confirm both are intended.
tripp = dfdata.mB + float(blep['alpha'].mean())*dfdata.x1.values - float(blep['beta'].mean())*dfdata.c.values + float(blep['M'].mean())

In [None]:
#plt.hist(tripp)
#plt.hist(tripp_t)
dfdata['MURES'] = tripp

In [None]:
salt_mcmc = start_distance()

salt_mcmc.run(
    torch.tensor(dft.x0.values),
    torch.tensor(dft.x0ERR.values),
    torch.tensor(dft.x1.values),
    torch.tensor(dft.x1ERR.values),
    torch.tensor(dft.c.values),
    torch.tensor(dft.cERR.values),
    torch.tensor(dft.MU.values)
)

In [None]:
blep = salt_mcmc.get_samples()

In [None]:
tripp_t = dft.mB + float(blep['alpha'].mean())*dft.x1.values - float(blep['beta'].mean())*dft.c.values + float(blep['M'].mean())

In [None]:
dft['MURES'] = tripp_t