In [None]:
import arviz as az
import matplotlib.pylab as plt
import numpyro.distributions as dist

from makeCorner import *

# Let ArviZ plotting functions draw up to this many subplots in one figure
# (rcParams key "plot.max_subplots").
MAX_ARVIZ_SUBPLOTS = 80
az.rcParams["plot.max_subplots"] = MAX_ARVIZ_SUBPLOTS

In [None]:
# Posterior samples from the CBC_O3_Power_Law run, stored as ArviZ InferenceData (netCDF).
cdf_file = "./RUNS/CBC_O3_Power_Law.cdf"
data = az.from_netcdf(filename=cdf_file)

# Merge the ("chain", "draw") dimensions into one "draws" dimension so every posterior
# sample can be addressed with a single flat index.
samps = data.posterior.stack({"draws": ("chain", "draw")})

In [None]:
def massModel_variation_all_m1(m1, alpha_ref, delta_alpha, width_alpha, middle_alpha,
                               mu_m1, sig_m1, log_f_peak, log_high_f_peak, width_f_peak, middle_f_peak,
                               mMax, high_mMax, width_mMax, middle_mMax,
                               mMin, dmMax, high_dmMax, width_dm, middle_dm, dmMin, zs):
    """Primary-mass density at ``m1`` whose hyperparameters evolve with redshift ``zs``.

    The density is a mixture of a power law and a Gaussian peak. It is multiplied by a
    half-Gaussian taper below ``mMin`` and another above a redshift-dependent maximum mass.
    The redshift-dependent quantities are:

    * power-law index  alpha(z)  = sigmoid(alpha_ref, delta_alpha, width_alpha, middle_alpha, zs)
    * maximum mass     mMax(z)   = sigmoid_no_delta(mMax, high_mMax, width_mMax, middle_mMax, zs)
    * upper taper width dmMax(z) = sigmoid_no_delta(dmMax, high_dmMax, width_dm, middle_dm, zs)
    * peak fraction    f_peak(z) = 10**sigmoid_no_delta(log_f_peak, log_high_f_peak,
                                                         width_f_peak, middle_f_peak, zs)

    NOTE(review): ``sigmoid``, ``sigmoid_no_delta``, ``tmp_min``, ``tmp_max``, ``jnp`` and
    ``np`` are not defined in this notebook. They presumably come from
    ``from makeCorner import *``; confirm. The sigmoid helpers are assumed to interpolate
    between a low-z and a high-z value around ``middle_*`` over a scale ``width_*``.
    The power law is normalised between the globals ``tmp_min`` and ``tmp_max``, NOT
    between ``mMin`` and ``mMax``.

    Parameters
    ----------
    m1 : scalar or array
        Mass(es) at which the density is evaluated; must broadcast against ``zs``.
    alpha_ref, delta_alpha, width_alpha, middle_alpha :
        Arguments to ``sigmoid`` that define the power-law index alpha(z).
    mu_m1, sig_m1 :
        Mean and standard deviation of the Gaussian peak (redshift-independent).
    log_f_peak, log_high_f_peak, width_f_peak, middle_f_peak :
        Arguments to ``sigmoid_no_delta`` giving log10 of the peak mixing fraction.
    mMax, high_mMax, width_mMax, middle_mMax :
        Arguments to ``sigmoid_no_delta`` giving the upper taper location mMax(z).
    mMin, dmMin :
        Location and width of the lower Gaussian taper (redshift-independent).
    dmMax, high_dmMax, width_dm, middle_dm :
        Arguments to ``sigmoid_no_delta`` giving the upper taper width dmMax(z).
    zs : array
        Redshift(s) at which the hyperparameters are evaluated.

    Returns
    -------
    jax array
        Mixture density times both tapers. It is not renormalised after tapering, so in
        general it does not integrate to one over ``m1``.
    """
    # Hyperparameters evaluated at zs. The helper calls happen in the same order as before.
    alpha_z = sigmoid(alpha_ref, delta_alpha, width_alpha, middle_alpha, zs)
    mMax_z = sigmoid_no_delta(mMax, high_mMax, width_mMax, middle_mMax, zs)
    dmMax_z = sigmoid_no_delta(dmMax, high_dmMax, width_dm, middle_dm, zs)
    log_f_peak_z = sigmoid_no_delta(log_f_peak, log_high_f_peak, width_f_peak, middle_f_peak, zs)

    # Power law m1**alpha, normalised on [tmp_min, tmp_max].
    slope_plus_one = 1. + alpha_z
    power_law = slope_plus_one*m1**alpha_z/(tmp_max**slope_plus_one - tmp_min**slope_plus_one)

    # Normalised Gaussian peak.
    var_m1 = sig_m1**2
    gaussian_peak = jnp.exp(-(m1 - mu_m1)**2/(2.*var_m1))/jnp.sqrt(2.*np.pi*var_m1)

    # Half-Gaussian tapers: active only below mMin or above mMax(z), otherwise 1.
    low_taper = jnp.where(m1 < mMin, jnp.exp(-(m1 - mMin)**2/(2.*dmMin**2)), 1.)
    high_taper = jnp.where(m1 > mMax_z, jnp.exp(-(m1 - mMax_z)**2/(2.*dmMax_z**2)), 1.)

    # Mix peak and power law with fraction f_peak(z) = 10**log_f_peak_z, then apply the tapers.
    f_peak_z = 10.**log_f_peak_z
    mixture = f_peak_z*gaussian_peak + (1. - f_peak_z)*power_law
    return jnp.array(mixture*low_taper*high_taper)

In [None]:
def merger_rate_z(z, alpha, beta, zp):
    r"""Unnormalised redshift evolution of the merger rate.

    Evaluates

        R(z) = (1+z)**alpha / (1 + ((1+z)/(1+zp))**(alpha+beta)),

    the same functional form as the Madau & Dickinson (2014) star-formation-rate fit.
    For alpha + beta > 0 it grows like (1+z)**alpha well below ``zp`` and decays like
    (1+z)**(-beta) well above it.

    Parameters
    ----------
    z : scalar or array
        Redshift(s).
    alpha, beta : scalar
        Low- and high-redshift power-law slopes.
    zp : scalar
        Redshift of the turnover.

    Returns
    -------
    Same shape as ``z``. It is not normalised; for example, R(0) is generally not 1.
    """
    one_plus_z = 1 + z
    numerator = one_plus_z**alpha
    denominator = 1 + (one_plus_z/(1 + zp))**(alpha + beta)
    return numerator/denominator

In [None]:
# Redshift distribution implied by each posterior draw. The mass model is evaluated at a
# fixed primary mass and each curve is normalised over the plotted redshift range.
# The hyperparameters come from `samps` (the draws loaded above). This cell previously read
# `samps_no_gwb` / `stacked_samples_no_GWB`, which no earlier cell defines, so it failed
# on a fresh kernel.
M1_REF = 20.    # primary mass at which massModel_variation_all_m1 is evaluated
Z_MAX = 10.     # upper edge of the redshift grid and of the x-axis
N_Z = 500       # number of redshift grid points

z_grid = np.linspace(0, Z_MAX, N_Z)

# np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid; use whichever exists.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz

# Redshift-only factors shared by every draw, computed once outside the loop:
#   1/(1+z)      source-frame rate -> detector-frame rate
#   4*pi*dVc/dz  rate per comoving volume (Gpc^3) -> rate per unit redshift
z_weight = (4.*np.pi*Planck15.differential_comoving_volume(z_grid).to(u.Gpc**3/u.sr).value
            / (1 + z_grid))

# Posterior variables in the positional order of massModel_variation_all_m1's arguments
# between `m1` and `zs`. The flag marks widths stored as log10 that the model takes in
# linear units. log_f_peak is passed unchanged because the model exponentiates it itself.
_MASS_MODEL_COLUMNS = (
    ("alpha_ref", False), ("delta_alpha", False), ("width_alpha", False), ("middle_z_alpha", False),
    ("mu_m1", False), ("sig_m1", False),
    ("log_f_peak", False), ("log_high_f_peak", False), ("width_f_peak", False), ("middle_z_f_peak", False),
    ("mMax", False), ("high_mMax", False), ("width_mMax", False), ("middle_z_mMax", False),
    ("mMin", False), ("log_dmMax", True), ("log_high_dmMax", True),
    ("width_dm", False), ("middle_z_dm", False), ("log_dmMin", True),
)

# Pull each variable out of the Dataset once, not once per draw.
mass_param_draws = [10.**samps[name].values if is_log10 else samps[name].values
                    for name, is_log10 in _MASS_MODEL_COLUMNS]
alpha_z_draws = samps.alpha_z.values
beta_z_draws = samps.beta_z.values
zp_draws = samps.zp.values
n_draws = samps.alpha_ref.size

fig, ax = plt.subplots(figsize=(10, 6))
for i in range(n_draws):
    mass_args = [param[i] for param in mass_param_draws]
    p_z = massModel_variation_all_m1(M1_REF, *mass_args, z_grid)
    p_z = p_z*merger_rate_z(z_grid, alpha_z_draws[i], beta_z_draws[i], zp_draws[i])*z_weight
    p_z = p_z/_trapezoid(p_z, z_grid)
    ax.plot(z_grid, p_z, color='#1f78b4', lw=0.1, alpha=0.1, rasterized=True)

ax.set_yscale('log')
ax.set_xlim(0, Z_MAX)
ax.set_ylim(1e-5, 5)
ax.set_xlabel(r"$z$", fontsize=35)
ax.set_ylabel(r"$p(z)$", fontsize=35)
ax.tick_params(axis="both", labelsize=25)
fig.savefig("p(z)_power_law.pdf", dpi=400, bbox_inches='tight')
plt.show()