In [None]:
!pip install sbi --quiet
!pip install deeplenstronomy --quiet
!pip install lenstronomy --quiet
!pip install deeplenstronomy --quiet
!pip install getdist --quiet

In [None]:
!pip uninstall matplotlib -y
!pip install matplotlib==3.1.3

In [None]:
!pip install pandas

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import torch
import torch.nn as nn 
import torch.nn.functional as F 
from sbi import utils, inference
# from sbi import inference
from sbi.inference import SNPE, simulate_for_sbi, prepare_for_sbi
import numpy as np
import scipy as sp

import deeplenstronomy.deeplenstronomy as dl
from deeplenstronomy.visualize import view_image
from sbi.analysis import pairplot

# import main simulation class of lenstronomy
from lenstronomy.Util import util
from lenstronomy.LightModel.light_model import LightModel
from lenstronomy.LensModel.lens_model import LensModel
from lenstronomy.Data.imaging_data import ImageData
from lenstronomy.ImSim.image_model import ImageModel
import lenstronomy.Util.image_util as image_util
from lenstronomy.Data.psf import PSF

import deeplenstronomy.deeplenstronomy as dl
from deeplenstronomy.visualize import view_image
import pickle
import pandas as pd

In [None]:
!pip install astroML --quiet

In [None]:
from matplotlib.font_manager import FontProperties
from matplotlib import gridspec
# Global figure style: long, padded ticks on both axes, black tick colour,
# 20 pt text and 2 pt axes / patch outlines. Applied in one update; keys are
# independent, so this is equivalent to setting them one at a time.
plt.rcParams.update({
    'xtick.major.pad': '7.0',
    'xtick.major.size': '7.5',
    'xtick.major.width': '1.5',
    'xtick.minor.pad': '7.0',
    'xtick.minor.size': '3.5',
    'xtick.minor.width': '1.0',
    'ytick.major.pad': '7.0',
    'ytick.major.size': '7.5',
    'ytick.major.width': '1.5',
    'ytick.minor.pad': '7.0',
    'ytick.minor.size': '3.5',
    'ytick.minor.width': '1.0',
    'xtick.color': 'k',
    'ytick.color': 'k',
    'font.size': 20,
    'axes.linewidth': 2,
    'patch.linewidth': 2,
})
# Optional LaTeX text rendering via astroML (disabled):
#from astroML.plotting import setup_text_plots
#setup_text_plots(fontsize=20, usetex=True)

In [None]:
# Fix every random number generator the notebook touches so that
# Restart & Run All reproduces the same simulations and fits.
import random

SEED = 100

torch.manual_seed(SEED)          # torch generator (training, posterior sampling)
torch.cuda.manual_seed(SEED)     # safe no-op when CUDA is unavailable
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)             # legacy global NumPy RNG (np.random.normal / uniform below)
random.seed(SEED)                # Python stdlib RNG (was previously left unseeded)

# Trade cuDNN speed for determinism.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False
# PYTHONHASHSEED only takes effect if set before the interpreter starts,
# so it is intentionally not assigned here.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
DATA_DIR = '/content/drive/MyDrive/data_sbi/ts5_10_s/'

num_sed = 10000

# Parameter table of the SED grid. Columns 1-5 are read below; column 0 is not
# used here. Mass and tau are converted to log10.
df = pd.read_csv(DATA_DIR + 'table.txt', header=None)
mass = np.log10(df[1].values)
met = df[4].values
age = df[2].values
tau = np.log10(df[3].values)
dust = df[5].values

# One row per grid SED, columns ordered (logM, age, log tau, met, dust) to match
# the prior bounds and the simulator's input vector. Vectorised replacement for
# the former row-by-row Python loop; also yields shape (0, 5) for an empty table.
data_j = np.column_stack([mass, age, tau, met, dust])

# Quick look at the grid coverage. Labels are raw strings so that the mathtext
# backslashes are not interpreted as (invalid) Python escape sequences.
for values, label in [
    (mass, 'logM$_{*}$'),
    (met, r'log(Z/Z$_{\odot})$'),
    (tau, r'log$_{10}(\tau)$'),   # tau was log10-transformed above
]:
    plt.hist(values, bins=20)
    plt.xlabel(label)
    plt.show()

In [None]:
import time

# Load the resampled SED table (space-delimited, no header) and report how long
# it took. perf_counter is monotonic and intended for measuring intervals,
# unlike time.time, which can jump with wall-clock adjustments.
start = time.perf_counter()
sed_spline = pd.read_csv(DATA_DIR + 'sed_resamp_all.txt', header=None, delimiter=' ').to_numpy()
end = time.perf_counter()
print("Time taken to load table:", end - start)

# Work in log10 units.
# NOTE(review): non-positive entries would become -inf/NaN here — confirm the table has none.
sed1 = np.log10(sed_spline)

# The simulator maps a row index of data_j straight onto a row of sed1, so the
# two tables must describe the same grid, one row per model.
assert sed1.shape[0] == data_j.shape[0], (
    f"SED table has {sed1.shape[0]} rows but parameter table has {data_j.shape[0]}"
)

In [None]:
# Gaussian noise level in log10 units: log10(1.05), i.e. roughly a 5% flux error.
unc_factor = np.log10(1.05)
# Number of wavelength bins per SED. Taken from the array shape rather than an
# arbitrary row, so it does not require the table to have at least three rows.
len_sed = sed1.shape[1]


def simulator2(parameter_set):
    """Nearest-neighbour SED simulator with Gaussian noise.

    Finds the grid model in ``data_j`` closest (Euclidean distance in raw
    parameter units) to ``parameter_set`` and returns its log10 SED from
    ``sed1`` with i.i.d. Gaussian noise of standard deviation ``unc_factor``
    added to every bin.

    Args:
        parameter_set: 5 parameters ordered (logM, age, log tau, met, dust), as a
            tensor or array; a leading batch dimension of size 1 also works
            through broadcasting.

    Returns:
        torch.Tensor of length ``len_sed`` (float64).
    """
    theta = np.asarray(parameter_set, dtype=float)
    # NOTE(review): parameters have very different ranges (logM spans 5 dex,
    # dust 0.9), so the unscaled distance is dominated by mass — confirm intended.
    distances = np.linalg.norm(data_j - theta, axis=1)
    min_index = np.argmin(distances)
    noisy_sed = np.random.normal(sed1[min_index], unc_factor, len_sed)
    return torch.tensor(noisy_sed)


# Smoke test: one noisy realisation for a representative parameter vector.
test_sim = simulator2(torch.tensor([8.8,  1.9,  0.15, -0.5,  0.7]))
plt.plot(test_sim)
plt.xlabel('wavelength bin')
plt.ylabel('log10 SED (with noise)')
plt.show()

In [None]:
from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi
import sbi.utils as utils
from sbi.inference.base import infer
from sbi.utils.get_nn_models import posterior_nn

# Box-uniform prior over (logM, age, log tau, met, dust).
prior_min = [8,0.01,-1,-2.0,0.1]
prior_max = [13,4,1,0.2,1.0]
prior2 = utils.torchutils.BoxUniform(low=torch.as_tensor(prior_min),
                                    high=torch.as_tensor(prior_max))

# Simulation budget: 1e5 * newfac simulations.
newfac = 20
num_sim = int(1e5) * newfac

# Masked autoregressive flow hyper-parameters.
hfeat = 25
ntrans = 10

# Length of one simulated SED; recorded in the cache file name.
ndata = len(simulator2(torch.tensor([10.3,1.0,-0.1,0.0,0.3])))

# Set to True to reuse a previously trained posterior from DATA_DIR.
posterior_already_run = False

unc_term = '5pct'
# NOTE(review): the "ts2_10k" prefix does not match DATA_DIR (ts5_10_s); kept
# as-is so existing cached posteriors still load.
pickle_filename = "ts2_10k_%s_h%s_n%s_nsim%s_data%s.pkl"%(unc_term,hfeat,ntrans,num_sim,ndata)

# Both branches need the sbi-wrapped simulator and prior (used again later for
# posterior predictive checks and SBC), so wrap them once up front.
simulator, prior = prepare_for_sbi(simulator2, prior2)

if posterior_already_run:
    # SECURITY: pickle.load can execute arbitrary code; only load files you wrote.
    with open(DATA_DIR + pickle_filename, 'rb') as handle:
        posterior = pickle.load(handle)
else:
    density_estimator_build_fun = posterior_nn(model='maf', hidden_features=hfeat, num_transforms=ntrans) #model = maf or nsf

    # Named `snpe` (not `inference`) so it does not shadow the `sbi.inference`
    # module imported at the top of the notebook.
    snpe = SNPE(prior=prior, density_estimator=density_estimator_build_fun)

    theta, x = simulate_for_sbi(simulator, proposal=prior, num_simulations=num_sim)
    snpe = snpe.append_simulations(theta, x)

    start = time.perf_counter()
    density_estimator = snpe.train()  # train the neural density estimator
    end = time.perf_counter()
    print()
    print("Time to train neural network:", end - start)

    posterior = snpe.build_posterior(density_estimator)

    # Cache the trained posterior so later runs can set posterior_already_run = True.
    with open(DATA_DIR + pickle_filename, "wb") as handle:
        pickle.dump(posterior, handle)

# **Inference and Tests**

In [None]:
# Parameter row (logM, age, log tau, met, dust) of the grid model used as the mock observation.
data_j[568, :]

In [None]:
#OBSERVATIONS
# Ground-truth parameters of the mock observation, shape (1, 5), float64.
# Slicing keeps the batch dimension and copies directly from the array, which
# avoids torch's slow (and warned-about) list-of-ndarrays constructor path.
obs_index = 568
obs_par = torch.tensor(data_j[obs_index:obs_index + 1])


def get_models(parameter_set):
    """Return the noise-free log10 SED of the grid model nearest to ``parameter_set``.

    Uses the same unscaled Euclidean nearest-neighbour lookup as ``simulator2``
    but adds no noise.
    """
    distances = np.linalg.norm(data_j - np.asarray(parameter_set, dtype=float), axis=1)
    ii = np.argmin(distances)
    return torch.tensor(sed1[ii])


obs1 = simulator2(obs_par)  # one noisy realisation of the true SED
observation = obs1
print(len(observation))
observation

In [None]:
from sbi import analysis as analysis
from sbi.inference.base import infer

# Axis labels in parameter order; raw strings keep the mathtext backslashes
# from being read as (invalid) Python escape sequences.
param_labels = ['logM$_{*}$', 'age', 'tau', r'log(Z/Z$_{\odot})$', 'dust']

samples1 = posterior.sample((10000,), x=observation)
log_probability = posterior.log_prob(samples1, x=observation)

# Corner plot of the posterior with the true parameters marked in red; save
# through the returned figure handle rather than relying on "current figure".
fig, _ = analysis.pairplot([samples1], figsize=(6,6), labels=param_labels, points=obs_par,
                           points_colors="red")
fig.savefig(DATA_DIR + 'posterior_5param.pdf', dpi=300)


def _print_summary(name, values):
    """Print the median and the 16th/84th percentile offsets from it."""
    med = np.median(values)
    print(f"median {name}:", med)
    print("16-50 percentile:", np.percentile(values, 16) - med)
    print("84-50 percentile:", np.percentile(values, 84) - med)


_print_summary("logM", samples1[:, 0])
_print_summary("metallicity", samples1[:, 3])

In [None]:
# Recovery test: draw parameters from the prior, simulate one noisy SED for
# each draw, fit it with the trained posterior and record the posterior median
# and 1-sigma width for every parameter.
n_coverage_tests = 500
n_posterior_draws = 100

# Same bounds as the training prior, so the test draws cannot drift from it.
low = np.asarray(prior_min, dtype=float)
high = np.asarray(prior_max, dtype=float)

true_params = np.empty((n_coverage_tests, low.size))
medians = np.empty_like(true_params)
sigmas = np.empty_like(true_params)

for t in range(n_coverage_tests):
    theta_true = np.random.uniform(low, high)
    # float32, as the original torch.tensor([...python floats...]) produced.
    x_sim = simulator2(torch.tensor(theta_true, dtype=torch.float32))
    samples = posterior.sample((n_posterior_draws,), x=x_sim, show_progress_bars=False)
    p16, p50, p84 = np.percentile(np.asarray(samples), [16, 50, 84], axis=0)

    true_params[t] = theta_true
    medians[t] = p50
    # Symmetric 1-sigma width: half the 16-84 percentile range, which equals
    # sigma for a Gaussian posterior. The former sqrt(lower**2 + upper**2)
    # inflated this by sqrt(2) and biased the pull diagnostics below.
    sigmas[t] = 0.5 * (p84 - p16)

# Per-parameter arrays (columns: logM, age, log tau, met, dust) under the names
# the plotting cells below use.
par1_rand, par2_rand, par3_rand, par4_rand, par5_rand = true_params.T
ms1, ms2, ms3, ms4, ms5 = medians.T
us1, us2, us3, us4, us5 = sigmas.T

In [None]:
def _plot_recovery_residuals(true_vals, recovered, sigma, label):
    """Plot the residual vs. truth with 1-sigma error bars, then the pull histogram.

    The residual ``recovered - true_vals`` is in the same units as ``sigma``, so
    the error bars are consistent with it. (Previously ``sigma`` was drawn as
    an error bar on the dimensionless pull.) The pull histogram should look
    like a unit Gaussian if the posterior widths are calibrated.
    """
    residual = recovered - true_vals

    plt.figure(figsize=(10, 4))
    plt.errorbar(true_vals, residual, yerr=sigma, fmt='o', markersize=2)
    plt.axhline(0, color='black', ls='--')
    plt.xlabel(f'True value ({label})')
    plt.ylabel('Median - truth')
    plt.show()

    plt.hist(residual / sigma, bins=30)
    plt.xlabel(f'Pull ({label})')
    plt.show()


for true_vals, recovered, sigma, label in [
    (par1_rand, ms1, us1, 'logM'),
    (par2_rand, ms2, us2, 'Age'),
    (par4_rand, ms4, us4, 'Metallicity'),
]:
    _plot_recovery_residuals(true_vals, recovered, sigma, label)

In [None]:
# Recovered posterior median against the true value, with a dashed 1:1 line
# drawn over each panel's range. Only the logM line is semi-transparent.
_recovery_panels = [
    (par1_rand, ms1, 'True Value (logM)', 'Recovered Value (logM)', (7, 14), {'alpha': 0.75}),
    (par4_rand, ms4, 'True Value (Metallicity)', 'Recovered Value (Metallicity)', (-2, 0.2), {}),
    (par2_rand, ms2, 'True Value (Age)', 'Recovered Value (Age)', (0.01, 4), {}),
]

for truth, recovered, x_label, y_label, (lo, hi), line_style in _recovery_panels:
    plt.scatter(truth, recovered, s=50)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    diagonal = np.linspace(lo, hi, 100)
    plt.plot(diagonal, diagonal, color='black', ls='--', **line_style)
    plt.show()

In [None]:
# Posterior predictive check (PPC), run after the neural posterior is trained:
# push posterior draws back through the simulator and check that the observed
# data fall inside the spread of the regenerated data.
x_o = observation
posterior.set_default_x(x_o)

# Parameter draws from the posterior conditioned on x_o (sampling the posterior
# itself is outside the scope of SBI's simulator).
posterior_samples = posterior.sample((2000,))

# Regenerate data from each posterior draw.
x_pp = simulator(posterior_samples)

# Collapse each SED to a single summary (its median over wavelength bins) so the
# check can be shown as a 1-D marginal.
x_pp_median = x_pp.median(dim=1, keepdim=True).values
x_o_median = x_o.median()

_ = analysis.pairplot(samples=x_pp_median,
                      points=x_o_median,
                      labels=['Median(Splined SED)'],
                      points_colors="red",
                      figsize=(6,6))

##################################################################################################################
# Simulation-Based Calibration (SBC): for many prior draws, check that the rank of
# the true parameter among posterior samples is uniformly distributed.
from torch import eye, ones, zeros
from torch.distributions import MultivariateNormal

from sbi.analysis import check_sbc,run_sbc, get_nltp, sbc_rank_plot
from sbi.inference import SNPE, SNPE_C, prepare_for_sbi, simulate_for_sbi
from sbi.simulators import linear_gaussian, diagonal_linear_gaussian

num_sbc_runs = 1000  # number of SBC runs; should be in the hundreds, ideally ~1000
# Ground-truth parameters and the corresponding simulated observations for SBC.
thetas = prior.sample((num_sbc_runs,))
xs = simulator(thetas)

# Posterior samples drawn per SBC run (1000 is better but slower).
num_posterior_samples = 100

ranks, dap_samples = run_sbc(thetas, xs, posterior, num_posterior_samples=num_posterior_samples)
check_stats = check_sbc(ranks, thetas, dap_samples, num_posterior_samples=num_posterior_samples)

print(f"kolmogorov-smirnov p-values \ncheck_stats['ks_pvals'] = {check_stats['ks_pvals'].numpy()}")
print(f"c2st accuracies \ncheck_stats['c2st_ranks'] = {check_stats['c2st_ranks'].numpy()}")
print(f"- c2st accuracies check_stats['c2st_dap'] = {check_stats['c2st_dap'].numpy()}")

f,ax=sbc_rank_plot(ranks=ranks,num_posterior_samples=num_posterior_samples,
              plot_type="hist",
              num_bins=None)

# The rank range is set by the number of posterior samples actually drawn; the
# previous hard-coded 1000 misscaled the uniform CDF reference (100 were drawn).
f, ax = sbc_rank_plot(ranks, num_posterior_samples, plot_type="cdf")

In [None]:
# Per-wavelength posterior predictive check: keep every `wave_stride`-th bin of
# the regenerated and observed SEDs so the pairplot stays readable. Labels are
# generated from the number of retained bins, so they always match the number
# of plotted dimensions (previously 8 labels were hard-coded).
wave_stride = 18
x_pp_sub = x_pp[:, ::wave_stride]
x_o_sub = x_o[::wave_stride]
wave_labels = ['wave%d' % (i + 1) for i in range(x_pp_sub.shape[1])]

_ = analysis.pairplot(samples=x_pp_sub,
                      points=x_o_sub,
                      labels=wave_labels,
                      points_colors="red")

#plt.savefig(DATA_DIR + 'bestfit_ppc_2.pdf',dpi=300)