In [25]:
%reset

In [26]:
import sys
from pathlib import Path

# Make the local "DGRB" package importable.
# Path.cwd()            → …/DGRB Scripts/Test - sbi for 1 SFG and 1 mAGN…
# Path.cwd().parents[0] → …/DGRB Scripts
# Path.cwd().parents[1] → …/home/users/ids29           ← where "DGRB/" lives
package_path = Path.cwd().parents[1] / "DGRB"   # /home/users/ids29/DGRB

# Only prepend the directory when it is not already on sys.path, so that
# re-running this cell does not keep stacking duplicate entries.
if str(package_path) not in sys.path:
    sys.path.insert(0, str(package_path))


In [27]:
import aegis
import numpy as np
import healpy as hp
import torch
import pickle as pk
from astropy import units as u
from astropy import constants as c
import matplotlib.pyplot as plt
from os import listdir
import os
from sbi.inference import SNLE, SNPE#, prepare_for_sbi, simulate_for_sbi
from sbi import utils as utils
from sbi import analysis as analysis
# from sbi.inference.base import infer
from getdist import plots, MCSamples
import pickle
from scipy.stats import norm
from scipy.integrate import quad, simpson
from joblib import Parallel, delayed
from sbi.neural_nets import posterior_nn

%matplotlib inline

In [28]:
# Run-size configuration shared by the cells below.
# `grains` is forwarded as the `grains=` keyword to my_AEGIS.create_sources and
# my_AEGIS.generate_photons_from_sources inside `simulator`.
# NOTE(review): presumably a grid resolution inside aegis — confirm against the aegis docs.
grains=1000
# Number of parameter vectors drawn from the round-1 posterior, i.e. number of simulations run.
num_simulations = 1000
# Passed as `n_jobs` to joblib.Parallel; -1 asks joblib to use all available CPUs.
num_workers = -1

In [29]:
# Inputs for the aegis.aegis(...) constructor call in the next cell.
# The first four are empty placeholders: the abundance/luminosity/spectrum callables and
# the source classes are assigned on the my_AEGIS object later (abun_lum_spec,
# source_class_list), and `parameter_range` is re-bound from a saved file before the
# prior is built.
parameter_range = [[], []]
abundance_luminosity_and_spectrum_list = []
source_class_list = []
parameter_names = []
energy_range = [1000, 100000] # MeV
# Wider energy window (0.5x the lower bound, 18x the upper bound), passed to aegis as
# `energy_range_gen`.
# NOTE(review): presumably the range photons are generated in before observation effects
# are applied — confirm against the aegis docs.
energy_range_gen = [energy_range[0]*0.5, energy_range[1]*18]
max_radius = 8.5 + 20*2 # kpc
exposure = 2000*10*0.2 # cm^2 yr
flux_cut = 1e-9 # photons/cm^2/s
# NOTE(review): the original trailing label here said "degrees", but np.pi (and the
# commented-out alternative 10*u.deg.to('rad')) are radian values — confirm the unit
# aegis expects.
angular_cut = np.pi # alternative kept from the original: 10*u.deg.to('rad')
angular_cut_gen = np.pi # alternative kept from the original: angular_cut*1.5
# NOTE(review): same remark as for angular_cut — the alternative 2*u.deg.to('rad') is in radians.
lat_cut = 0 # alternative kept from the original: 2*u.deg.to('rad')
lat_cut_gen = lat_cut*0.5

In [30]:
# Cosmology name and redshift range handed to the AEGIS constructor.
my_cosmology = 'Planck18'
z_range = [0, 14]

# Luminosity bounds, 10^37 to 10^50. The minimum was chosen by considering the
# Andromeda distance with Fermi as benchmark and 0.1 photon received at the detector.
luminosity_range = 10.0**np.array([37, 50])

# Build the AEGIS object. The source-class inputs are still empty placeholders here;
# they are attached to the object in a later cell.
my_AEGIS = aegis.aegis(
    abundance_luminosity_and_spectrum_list,
    source_class_list,
    parameter_range,
    energy_range,
    luminosity_range,
    max_radius,
    exposure,
    angular_cut,
    lat_cut,
    flux_cut,
    energy_range_gen=energy_range_gen,
    cosmology=my_cosmology,
    z_range=z_range,
    verbose=False,
)

# The generation-level cuts are not constructor arguments in the call above;
# they are set directly as attributes.
my_AEGIS.angular_cut_gen = angular_cut_gen
my_AEGIS.lat_cut_gen = lat_cut_gen

In [31]:
MEV_TO_ERG = 1.60218e-6  # erg per MeV


def _mean_photon_energy_MeV(Gamma, bounds):
    """Mean photon energy (MeV) of a power law dN/dE ∝ E**(-Gamma) on [bounds[0], bounds[1]].

    This is the ratio  ∫ E·E^-Gamma dE / ∫ E^-Gamma dE  evaluated in closed form.
    The closed form is singular for Gamma == 1 or Gamma == 2.
    """
    return ((-Gamma + 1) / (-Gamma + 2) *
            (bounds[1]**(-Gamma + 2) - bounds[0]**(-Gamma + 2)) /
            (bounds[1]**(-Gamma + 1) - bounds[0]**(-Gamma + 1)))


# Both source classes use the generation energy window (MeV).
gamma_energy_bounds = energy_range_gen

# Star-forming galaxies: photon index and mean photon energy (MeV, then erg).
Gamma_SFG = 2.2
E_photon_MeV_SFG = _mean_photon_energy_MeV(Gamma_SFG, gamma_energy_bounds)
E_photon_SFG = E_photon_MeV_SFG * MEV_TO_ERG

# Misaligned AGN: index deliberately forced to the SFG value by the user
# (the original note says it should actually be 2.25).
Gamma_mAGN = 2.2
E_photon_MeV_mAGN = _mean_photon_energy_MeV(Gamma_mAGN, gamma_energy_bounds)
E_photon_mAGN = E_photon_MeV_mAGN * MEV_TO_ERG

# Integration grids used by the luminosity functions below:
# log10(L_IR / L_sun) for SFGs and log10(L_5GHz / (erg/s)) for mAGN.
res = int(1e4)
log_LIRs = np.linspace(-5, 25, res)
log_L5Gs = np.linspace(20, 55, res)

In [32]:
def ZL_SFG1(z, l, params):
    """Luminosity function of star-forming galaxies (SFGs) versus photon luminosity.

    An infrared luminosity function Phi_IR(log L_IR) is convolved with a Gaussian
    scatter relation p(log L_gamma | log L_IR) by Simpson integration over the
    module-level grid ``log_LIRs``. The result is then converted from "per dex"
    to "per unit l" and from Mpc^-3 to kpc^-3.

    Parameters
    ----------
    z : array
        Not used: this model has no redshift dependence. Kept so the signature
        matches ``ZL_mAGN``.
    l : numpy.ndarray
        Luminosities. They are multiplied by ``E_photon_SFG`` (erg) to obtain erg/s,
        so they are presumably photon luminosities in photons/s — TODO confirm with aegis.
    params : sequence
        ``params[0]`` is log10 of the normalisation Phi_star
        (Mpc^-3 dex^-1 according to the reference comment kept below).

    Returns
    -------
    numpy.ndarray
        Array with the shape of ``l`` holding the luminosity function in kpc^-3 per unit ``l``.
    """
    Phi_star = 10**params[0]

    # IR luminosity function parameters, from Table 8 in Gruppioni et al.
    # (the fixed reference normalisation there is Phi_star = 10**(-2.08) Mpc^-3 dex^-1).
    Lstar = 10**(9.46)  # solar luminosities
    alpha = 1.00
    sigma = 0.50

    # Parameters of the log L_gamma – log L_IR relation and its Gaussian scatter.
    a = 1.09
    g = 40.8
    sigma_SF = 0.202

    # Everything below depends only on the fixed grid `log_LIRs`, so it is computed
    # once here instead of once per element of `l` as before.
    LIR = 10**log_LIRs  # solar luminosities
    # Gruppioni paper eqn (3)
    PhiIR_of_logLIRs = Phi_star * (LIR / Lstar)**(1 - alpha) * np.exp(-1 / (2 * sigma**2) * (np.log10(1 + LIR / Lstar))**2)
    L_IR_erg_second = LIR * 3.826e33  # erg/s
    mean_log_Lgamma = g + a * np.log10(L_IR_erg_second / 1e45)

    l_erg = l * E_photon_SFG  # erg/s
    LFs = np.zeros_like(l)

    # np.ndindex walks every element for any number of dimensions
    # (for a 2-D `l` this is the same order as the former i/j double loop).
    for idx in np.ndindex(LFs.shape):
        pdf_log_Lgamma = norm.pdf(np.log10(l_erg[idx]), loc=mean_log_Lgamma, scale=sigma_SF)
        LFs[idx] = simpson(PhiIR_of_logLIRs * pdf_log_Lgamma, x=log_LIRs)

    # 1/(l ln 10) converts per-dex to per-unit-l; 1e-9 converts Mpc^-3 to kpc^-3.
    return 1e-9 / np.log(10) / l * LFs


def spec_SFG1(energy, params):
    """Unnormalised power-law photon spectrum of SFGs: ``energy**(-2.2)``.

    ``params`` is accepted for interface compatibility but not used. The index is
    hard-coded here and has the same value as ``Gamma_SFG`` used for the mean
    photon energy; keep the two in sync when changing either.
    """
    spectral_index = 2.2
    return energy ** (-spectral_index)

In [33]:
def ZL_mAGN(z, l, params):
    """Luminosity function of misaligned AGN (mAGN) versus photon luminosity.

    A redshift-dependent 5 GHz radio luminosity function Phi_5G(log L_5G, z) is
    convolved with a Gaussian scatter relation p(log L_gamma | log L_5G) by Simpson
    integration over the module-level grid ``log_L5Gs``. The result is then converted
    from "per dex" to "per unit l" and from Mpc^-3 to kpc^-3.

    Parameters
    ----------
    z : numpy.ndarray
        Redshifts, indexed element-wise together with ``l`` (same shape expected).
    l : numpy.ndarray
        Luminosities. They are multiplied by ``E_photon_mAGN`` (erg) to obtain erg/s,
        so they are presumably photon luminosities in photons/s — TODO confirm with aegis.
    params : sequence
        ``params[1]`` is log10 of the normalisation phi1 (Mpc^-3).
        ``params[0]`` is not read here (it is the SFG normalisation in ``ZL_SFG1``).

    Returns
    -------
    numpy.ndarray
        Array with the shape of ``l`` holding the luminosity function in kpc^-3 per unit ``l``.
    """
    phi1 = 10**params[1]

    # Radio luminosity function parameters: Table 4 of the Yuan 2018 paper, second row
    # (the fixed reference normalisation there is phi1 = 10**(-3.749) Mpc^-3).
    p1 = 2.085
    p2 = -4.602
    z_c = 0.893
    k1 = 1.744
    L_star = 10**21.592  # W/Hz
    beta = 0.139
    gamma = 0.878

    # Parameters of the log L_gamma – log L_5G relation and its Gaussian scatter.
    b = 0.78
    d = 40.78
    sigma_mAGN = 0.880

    # Quantities that depend only on the fixed grid `log_L5Gs` (not on z or l) are
    # computed once here instead of once per element as before.
    L_5G = 10**log_L5Gs  # erg/s
    radio_bandwidth = 4.87e9  # Hz; width of the radio band centred around the blueshifted 5 GHz frequency
    diff_L5G = L_5G / radio_bandwidth * 1e-7  # W/Hz (erg -> Joule); luminosity per unit frequency
    mean_log_Lgamma = d + b * np.log10(L_5G / 1e40)
    e1_numerator = (1+z_c)**p1 + (1+z_c)**p2

    l_erg = l * E_photon_mAGN  # erg/s
    LFs = np.zeros_like(l)

    # np.ndindex walks every element for any number of dimensions
    # (for 2-D inputs this is the same order as the former i/j double loop).
    for idx in np.ndindex(LFs.shape):
        z_here = z[idx]

        # Redshift evolution terms of the radio luminosity function.
        e1 = e1_numerator / ( ((1+z_c)/(1+z_here))**p1 + ((1+z_c)/(1+z_here))**p2 )
        e2 = (1+z_here)**k1

        # Yuan 2018 paper equation 21. This is dN/dV dlog(diff_L5G), which equals
        # dN/dV dlog(L_5G) because the radio bandwidth is fixed. Output in Mpc^-3.
        scaled_L = diff_L5G / (e2 * L_star)
        Phi_5G = e1 * phi1 * ( scaled_L**beta + scaled_L**gamma )**-1

        pdf_log_Lgamma = norm.pdf(np.log10(l_erg[idx]), loc=mean_log_Lgamma, scale=sigma_mAGN)
        LFs[idx] = simpson(Phi_5G * pdf_log_Lgamma, x=log_L5Gs)

    # 1/(l ln 10) converts per-dex to per-unit-l; 1e-9 converts Mpc^-3 to kpc^-3.
    return 1e-9 / np.log(10) / l * LFs



def spec_mAGN(energy, params):
    """Unnormalised power-law photon spectrum of mAGN: ``energy**(-2.2)``.

    ``params`` is accepted for interface compatibility but not used. The index has
    been modified to match the SFG spectrum and has the same value as ``Gamma_mAGN``
    used for the mean photon energy; keep the two in sync when changing either.
    """
    spectral_index = 2.2
    return energy ** (-spectral_index)

In [34]:
# One [luminosity-function callable, spectrum callable] pair per source class.
als_SFG1 = [ZL_SFG1, spec_SFG1]
als_mAGN = [ZL_mAGN, spec_mAGN]

# Attach both populations to the AEGIS object; the class labels are listed in the
# same order as the pairs (SFG first, then mAGN) and are identical for both.
my_AEGIS.abun_lum_spec = [als_SFG1, als_mAGN]
my_AEGIS.source_class_list = ['extragalactic_isotropic_faint_single_spectrum'] * 2

In [42]:
# Forward model used for SBI: parameter vector -> mock-observed photons.
def simulator(params):
    """Run one AEGIS simulation for the parameter vector ``params``.

    ``params`` must provide ``.numpy()`` (e.g. a CPU torch tensor). Sources are
    created, photons are generated from them, and the photons are passed through
    ``my_AEGIS.mock_observe`` with the Fermi PSF3 response files below.

    Returns whatever ``my_AEGIS.mock_observe`` returns. No summary statistic
    (such as a photon count) is computed here.
    """
    theta = params.numpy()

    # Instrument description for the mock observation; the FITS paths are relative
    # to the current working directory.
    fermi_response = {
        'psf_fits_path': '../../DGRB/FERMI_files/psf_P8R3_ULTRACLEANVETO_V2_PSF.fits',
        'edisp_fits_path': '../../DGRB/FERMI_files/edisp_P8R3_ULTRACLEANVETO_V2_PSF.fits',
        'event_type': 'PSF3',
        'exposure_map': None,
    }

    sources = my_AEGIS.create_sources(theta, grains=grains, epsilon=1e-2)
    photons = my_AEGIS.generate_photons_from_sources(theta, sources, grains=grains)
    return my_AEGIS.mock_observe(photons, fermi_response)

In [36]:
# def manual_simulate_for_sbi(proposal, num_simulations=1000, num_workers=32):
#     """
#     Simulates the model in parallel using joblib.
#     Each simulation call samples a parameter from the proposal and passes the index to the simulator.
#     """
#     def run_simulation(i):
#         if i % 10 == 0:
#             print(f"i= {i}")
#         # Sample a parameter from the proposal (sbi.utils.BoxUniform has a .sample() method)
#         theta_i = proposal.sample()
#         photon_info = simulator(theta_i)
#         return theta_i, photon_info

#     # Run simulations in parallel using joblib.
#     results = Parallel(n_jobs=num_workers, timeout=None)(delayed(run_simulation)(i) for i in range(num_simulations))
#     theta_list, photon_info_list = zip(*results)

#     theta_tensor = torch.stack(theta_list, dim=0).to(torch.float32)
    
    
#     return theta_tensor, photon_info_list

In [44]:
class SNPE_C_Custom(SNPE):
    """SNPE subclass whose ``train`` optionally accepts a custom optimizer class.

    With ``optimizer_class=None`` (the default) ``train`` simply forwards to
    ``SNPE.train``. Otherwise ``self._build_neural_net`` is wrapped for the duration
    of the call so that every network it builds carries an ``optimizer`` attribute:
    a factory ``params -> optimizer_class(params, **optimizer_kwargs)``.

    NOTE(review): this only has an effect if sbi's training loop reads
    ``model.optimizer``; confirm against the installed sbi version, otherwise the
    custom optimizer is silently ignored.
    """

    def train(self, *a, optimizer_class=None, optimizer_kwargs=None, **kw):
        """Train the density estimator, optionally attaching a custom optimizer factory.

        Args:
            *a, **kw: forwarded unchanged to ``SNPE.train``.
            optimizer_class: optimizer constructor, or None for the default behaviour.
            optimizer_kwargs: keyword arguments for ``optimizer_class``; None means
                "no extra arguments".

        Returns:
            Whatever ``SNPE.train`` returns.
        """
        if optimizer_class is None:
            return super().train(*a, **kw)

        # The default None used to be **-unpacked directly, which raised TypeError
        # whenever a class was given without kwargs. Normalise to a private dict.
        opt_kwargs = dict(optimizer_kwargs or {})

        orig = self._build_neural_net

        def builder(*aa, **kk):
            model = orig(*aa, **kk)
            model.optimizer = lambda ps: optimizer_class(ps, **opt_kwargs)
            return model

        self._build_neural_net = builder
        try:
            return super().train(*a, **kw)
        finally:
            # Always restore the original builder, even if training raises.
            self._build_neural_net = orig

# Round 2 of sequential inference: rebuild the round-1 density estimator from its saved
# state, sample new parameters from its posterior at the observation, and simulate them.

# --- 1) load round-1 checkpoint ---------------------------------------
# NOTE(review): torch.load unpickles arbitrary Python objects — only load checkpoints
# you trust (consider weights_only=True if the file holds nothing but tensors).
ckpt = torch.load("snpe_round1_state.pt",
                  map_location="cpu")

keep_mask = ckpt["keep_mask"]          # kept summary dims
x_mean    = ckpt["x_mean"]              # for later use if needed (not used in this cell)
x_std     = ckpt["x_std"]              # not used in this cell
x_test    = ckpt["x_test"]              # already masked & z-scored


# ---------- 2) prior and density-estimator builder --------------------
# NOTE: this re-binds `parameter_range`, which was an empty placeholder ([[], []])
# when the AEGIS object was constructed earlier in the notebook.
parameter_range = torch.load('parameter_range_1SFG_1mAGN.pt')

# Box-uniform prior between the loaded lower (index 0) and upper (index 1) bounds.
prior = utils.BoxUniform(low=parameter_range[0], high=parameter_range[1])

# NOTE(review): the architecture must match the one that produced ckpt["de_state"],
# otherwise load_state_dict below will fail — confirm against the round-1 notebook.
net_builder = posterior_nn(
    model="nsf", hidden_features=128, num_transforms=8,
    dropout_probability=0.2, use_combined_loss=True,
    z_score_x="none", z_score_theta="none",
)


# fresh SNPE object
inf = SNPE_C_Custom(prior, net_builder)

# ---------- 3) build network once then load weights -------------------
# A two-sample dummy training run is used only to make sbi instantiate the network;
# its weights are overwritten by the checkpoint immediately afterwards.
theta_dummy = ckpt["theta"][:2]
x_dummy     = ckpt["x"][:2]
inf.append_simulations(theta_dummy, x_dummy)
_ = inf.train(
        training_batch_size = 1,     # ≥ 1
        validation_fraction = 0.5,   # anything > 0
        max_num_epochs      = 0,     # the run log reports "Epochs trained: 1", so one epoch still runs
        show_train_summary  = False,
)
inf._neural_net.load_state_dict(ckpt["de_state"])  # replace the dummy-trained weights with round-1 weights

# ---------- 4) posterior conditional on x₀ and draw θ_new -------------
posterior = inf.build_posterior(inf._neural_net, sample_with="mcmc")
posterior.set_default_x(x_test)                    # x₀

num_new   = num_simulations
theta_new = posterior.sample((num_new,))           # [num_simulations, 2]

# ---------- 5) simulate photons in parallel ---------------------------
# Earlier variant without progress reporting, kept for reference:
# def simulate_one(theta_vec):
#     return theta_vec, simulator(theta_vec)         # your simulator fn

# res = Parallel(n_jobs=num_workers)(
#     delayed(simulate_one)(th) for th in theta_new)

def simulate_one(i, theta_vec):
    """Simulate one parameter vector and return ``(theta_vec, simulator(theta_vec))``.

    ``i`` is the 0-based position of ``theta_vec`` in ``theta_new`` and is only used
    to print a progress line for every 10th item. Because the calls run in parallel
    workers, these lines can appear out of order (as they do in the captured output).
    """
    if (i + 1) % 10 == 0:
        print(f"Processing iteration {i + 1}")
    return theta_vec, simulator(theta_vec)  # your simulator fn

# NOTE: `res` was previously bound to the integer grid resolution (res = int(1e4));
# from here on it holds the list of (theta, photon_info) results instead.
res = Parallel(n_jobs=num_workers)(
    delayed(simulate_one)(i, th) for i, th in enumerate(theta_new)
)

# Split the (theta, photon_info) pairs into two tuples, then stack the thetas back
# into a single tensor.
theta_new_tensor, photon_info_new = zip(*res)
theta_new_tensor = torch.stack(theta_new_tensor)



# theta_new_tensor, photon_info_new  = manual_simulate_for_sbi(posterior,
#                                    num_simulations=num_simulations,
#                                    num_workers=num_workers)

  ckpt = torch.load("snpe_round1_state.pt",
  parameter_range = torch.load('parameter_range_1SFG_1mAGN.pt')


 Training neural network. Epochs trained: 1

Generating 20 MCMC inits via resample strategy:   0%|          | 0/20 [00:00<?, ?it/s]

Running vectorized MCMC with 20 chains:   0%|          | 0/6000 [00:00<?, ?it/s]

Processing iteration 100
Processing iteration 60
Processing iteration 40
Processing iteration 110
Processing iteration 90
Processing iteration 80
Processing iteration 170
Processing iteration 20
Processing iteration 140
Processing iteration 160
Processing iteration 70
Processing iteration 130
Processing iteration 120
Processing iteration 10
Processing iteration 50
Processing iteration 150
Processing iteration 190
Processing iteration 180
Processing iteration 30
Processing iteration 200
Processing iteration 210
Processing iteration 220
Processing iteration 230
Processing iteration 240
Processing iteration 250
Processing iteration 260
Processing iteration 270
Processing iteration 280
Processing iteration 290
Processing iteration 300
Processing iteration 310
Processing iteration 320
Processing iteration 330
Processing iteration 340
Processing iteration 350
Processing iteration 360
Processing iteration 370
Processing iteration 380
Processing iteration 390
Processing iteration 400
Processin



Processing iteration 830
Processing iteration 840
Processing iteration 850
Processing iteration 860
Processing iteration 870
Processing iteration 880
Processing iteration 890
Processing iteration 900
Processing iteration 910
Processing iteration 920
Processing iteration 930
Processing iteration 940
Processing iteration 950
Processing iteration 960
Processing iteration 970
Processing iteration 980
Processing iteration 990
Processing iteration 1000


In [45]:
# ---------- 6) persist the raw round-2 parameters and simulations -----
# Parameters go to a torch file, the per-simulation photon data to a pickle.
torch.save(theta_new_tensor, "case1_theta_round2.pt")

with open("case1_photon_info_round2.pkl", "wb") as photon_file:
    pickle.dump(photon_info_new, photon_file)

print("Round-2 simulations complete and saved.")

Round-2 simulations complete and saved.
