In [None]:
import arviz as az
import matplotlib.pylab as plt
import numpyro.distributions as dist

from makeCorner import *

# Raise ArviZ's cap on how many subplots one plotting call may draw (to 80),
# presumably so plots over the many hyperparameters of this model are not cut off.
az.rcParams["plot.max_subplots"] = 80

In [None]:
# Load the saved inference run and merge its (chain, draw) dimensions into a
# single "draws" dimension, so each posterior sample is addressed by one index.
cdf_file = "./RUNS/CBC_O3_Power_Law.cdf"
data = az.from_netcdf(cdf_file)
samps = data.posterior.stack({"draws": ("chain", "draw")})

In [None]:
def massModel_variation_all_m1(m1, alpha_ref, delta_alpha, width_alpha, middle_alpha,
                               mu_m1, sig_m1, log_f_peak, log_high_f_peak, width_f_peak, middle_f_peak,
                               mMax, high_mMax, width_mMax, middle_mMax,
                               mMin, dmMax, high_dmMax, width_dm, middle_dm, dmMin, zs,
                               pl_min=None, pl_max=None):
    """Redshift-dependent primary-mass model: power law + Gaussian peak with smooth edges.

    Evaluates

        [f * N(m1; mu_m1, sig_m1) + (1 - f) * PL(m1; alpha)] * low_filter * high_filter

    where the power-law slope ``alpha``, the peak fraction ``f`` (given as log10),
    the high-mass cutoff ``mMax`` and the high-mass taper width ``dmMax`` vary
    with ``zs`` through the ``sigmoid`` / ``sigmoid_no_delta`` helpers. Those
    helpers are not defined in this notebook (presumably they come from the
    ``makeCorner`` star import -- TODO confirm), so their exact parameterization
    is not documented here. The filters are applied without renormalization,
    so the result is not a normalized density in ``m1``.

    Parameters
    ----------
    m1 : float or array
        Primary mass(es); combined elementwise with the ``zs``-dependent terms.
    alpha_ref, delta_alpha, width_alpha, middle_alpha :
        Arguments to ``sigmoid`` producing the power-law slope at ``zs``.
    mu_m1, sig_m1 :
        Mean and standard deviation of the Gaussian peak (redshift independent).
    log_f_peak, log_high_f_peak, width_f_peak, middle_f_peak :
        Arguments to ``sigmoid_no_delta`` producing log10 of the peak fraction.
    mMax, high_mMax, width_mMax, middle_mMax :
        Arguments to ``sigmoid_no_delta`` producing the high-mass cutoff location.
    mMin, dmMin :
        Location and Gaussian width of the low-mass taper (redshift independent).
    dmMax, high_dmMax, width_dm, middle_dm :
        Arguments to ``sigmoid_no_delta`` producing the high-mass taper width.
    zs : float or array
        Redshift(s) at which the evolving parameters are evaluated.
    pl_min, pl_max : float, optional
        Interval on which the power-law component is normalized. When omitted,
        the module-level globals ``tmp_min`` / ``tmp_max`` are used; those are
        not defined anywhere in this notebook, so a fresh kernel raises
        ``NameError`` unless they are provided elsewhere (or passed here).

    Returns
    -------
    JAX array of the (unnormalized) model values.
    """
    # Only touch the globals when the caller did not supply explicit bounds.
    if pl_min is None:
        pl_min = tmp_min
    if pl_max is None:
        pl_max = tmp_max

    # Power law normalized on [pl_min, pl_max]. Note: undefined (0/0) at alpha_new == -1.
    alpha_new = sigmoid(alpha_ref, delta_alpha, width_alpha, middle_alpha, zs)
    p_m1_pl = (1.+alpha_new)*m1**(alpha_new)/(pl_max**(1.+alpha_new) - pl_min**(1.+alpha_new))

    # Normalized Gaussian peak.
    p_m1_peak = jnp.exp(-(m1-mu_m1)**2/(2.*sig_m1**2))/jnp.sqrt(2.*np.pi*sig_m1**2)

    new_mMax = sigmoid_no_delta(mMax, high_mMax, width_mMax, middle_mMax, zs)
    new_dmMax = sigmoid_no_delta(dmMax, high_dmMax, width_dm, middle_dm, zs)

    # Gaussian tapers applied only outside [mMin, new_mMax]; the factor is 1 inside.
    low_filter = jnp.exp(-(m1-mMin)**2/(2.*dmMin**2))
    low_filter = jnp.where(m1<mMin,low_filter,1.)
    high_filter = jnp.exp(-(m1-new_mMax)**2/(2.*new_dmMax**2))
    high_filter = jnp.where(m1>new_mMax,high_filter,1.)

    # The peak fraction is evolved in log10 space, then exponentiated.
    new_f_peak = sigmoid_no_delta(log_f_peak, log_high_f_peak, width_f_peak, middle_f_peak, zs)
    actual_f_peak = 10.**(new_f_peak)

    # Already a JAX array (jnp.exp / jnp.where outputs), so no extra jnp.array copy is needed.
    return (actual_f_peak*p_m1_peak + (1. - actual_f_peak)*p_m1_pl)*low_filter*high_filter

def merger_rate_z(z, alpha, beta, zp):
    """Unnormalized Madau-Dickinson-style redshift shape of the merger rate.

    Computes (1+z)**alpha / (1 + ((1+z)/(1+zp))**(alpha+beta)). At low
    redshift this grows like (1+z)**alpha; well beyond the turnover near
    ``zp`` (and for alpha + beta > 0) it falls off like (1+z)**(-beta).
    Works elementwise on scalars or NumPy/JAX arrays of ``z``.
    """
    one_plus_z = 1 + z
    low_z_growth = one_plus_z**alpha
    high_z_suppression = 1 + (one_plus_z / (1 + zp))**(alpha + beta)
    return low_z_growth / high_z_suppression

In [None]:
z_grid = np.linspace(0,10,500)


def _mass_model_args(samps, i):
    """Return the hyperparameters of ``massModel_variation_all_m1`` for posterior draw ``i``.

    The tuple holds every positional argument between ``m1`` and ``zs``, in
    signature order. The filter widths are stored in the posterior as
    ``log_dmMax`` / ``log_high_dmMax`` / ``log_dmMin`` and are converted with
    ``10.**`` here; ``log_f_peak`` / ``log_high_f_peak`` are passed through
    unchanged because the mass model exponentiates the peak fraction itself.
    """
    def draw(name):
        return samps[name].values[i]

    return (
        draw("alpha_ref"), draw("delta_alpha"), draw("width_alpha"), draw("middle_z_alpha"),
        draw("mu_m1"), draw("sig_m1"), draw("log_f_peak"),
        draw("log_high_f_peak"), draw("width_f_peak"), draw("middle_z_f_peak"),
        draw("mMax"), draw("high_mMax"), draw("width_mMax"), draw("middle_z_mMax"),
        draw("mMin"), 10.**draw("log_dmMax"), 10.**draw("log_high_dmMax"),
        draw("width_dm"), draw("middle_z_dm"), 10.**draw("log_dmMin"),
    )


def R(m, samps, zs=None):
    """Plot every posterior draw of the merger rate R(z, m1=m) as a thin line.

    For each draw, the mass model sliced at primary mass ``m`` is evaluated on
    the redshift grid and multiplied by ``merger_rate_z`` normalized to 1 at
    z = 0.2. The curve is then scaled by ``R20 / p(m1=20, z=0.2)``, so that at
    (m1 = 20, z = 0.2) it equals that draw's ``R20``.

    Parameters
    ----------
    m : float
        Primary mass (in solar masses, per the axis label) at which to slice.
    samps : xarray.Dataset
        Posterior with (chain, draw) stacked into one dimension, as built by
        the loading cell; must contain the mass-model, ``alpha_z``, ``beta_z``,
        ``zp`` and ``R20`` variables.
    zs : array-like, optional
        Redshift grid to evaluate on; defaults to the module-level ``z_grid``.
    """
    if zs is None:
        zs = z_grid
    # Reference point at which each curve is pinned to its R20.
    ref_mass, ref_z = 20, 0.2

    fig, ax = plt.subplots(figsize=(10, 6))
    for i in range(samps.alpha_ref.size):
        hyper = _mass_model_args(samps, i)
        alpha_z, beta_z, zp = (samps[k].values[i] for k in ("alpha_z", "beta_z", "zp"))

        # Mass-model slice at m1 = m, times the redshift shape normalized to 1 at ref_z.
        rate = massModel_variation_all_m1(m, *hyper, zs)
        rate = rate * merger_rate_z(zs, alpha_z, beta_z, zp)
        rate = rate / merger_rate_z(ref_z, alpha_z, beta_z, zp)
        # Overall amplitude: the rate at (ref_mass, ref_z) equals this draw's R20.
        p_ref = massModel_variation_all_m1(ref_mass, *hyper, ref_z)
        rate = rate * (samps["R20"].values[i] / p_ref)

        ax.plot(zs, rate, color='#1f78b4', lw=0.1, alpha=0.1, rasterized=True)

    ax.set_yscale('log')
    ax.set_ylim(1e-5, 2e3)
    ax.set_xlim(np.min(zs), np.max(zs))
    ax.set_xlabel(r"$z$", fontsize=35)
    ax.set_ylabel(r"$\mathcal{R}$($z$, $m_1$ =" + f" {m}" + r" $M_{\odot}$)", fontsize=35)
    ax.tick_params(axis='both', labelsize=25)
    ax.grid(visible=True, which='major', axis='both', linestyle='--')
    ax.grid(visible=True, which='minor', axis='both', linewidth=0.5)
    plt.show()

In [None]:
R(m=20, samps=samps)  # posterior draws of R(z) at m1 = 20 Msun

In [None]:
R(m=50, samps=samps)  # posterior draws of R(z) at m1 = 50 Msun

In [None]:
R(m=80, samps=samps)  # posterior draws of R(z) at m1 = 80 Msun