In [None]:
import numpy as np
import scipy
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import corner
import os, pickle

In [None]:
import la_forge.core as co
from la_forge.rednoise import gorilla_bf

In [None]:
from nautilus import Sampler, Prior
from scipy.stats import uniform
import json, cloudpickle, glob
import h5py
import tqdm
from h5pulsar import FilePulsar
from enterprise_extensions.blocks import common_red_noise_block
from enterprise.signals import signal_base
from dr3_noise.models import model_singlepsr_noise
from IPTA_DR2_analysis.model_blocks import adv_noise_block, full_Tspan, lite_Tspan

# First, check the chains/posteriors

In [None]:
# Per-pulsar caches keyed by pulsar name: finished samplers, samplers still
# running, and the PTA objects. Re-running this cell empties all three.
sampler_dict, unfinished_sampler_dict, pta_dict = {}, {}, {}

In [None]:
# --- Run configuration: dataset selection and every path derived from it ---
datastr = 'lite_unfiltered'
Np = 53  # subset tag used in directory names (presumably the pulsar count -- confirm)
# `litec` is read further down in this cell and by later cells, so it must be
# defined here; otherwise a fresh-kernel "Run All" stops with a NameError.
litec = False
edr2 = False
gwb_Nfreqs = 13
project_path = '/home/bbl29/IPTA_DR2_analysis'

# Directory holding the per-pulsar sampler files and *_converged.txt markers.
if datastr == 'full':
    outdir = '/vast/palmer/home.grace/bbl29/project/IPTA_DR2_analysis/dr2full/factlike'
else:
    outdir = f'/vast/palmer/home.grace/bbl29/project/IPTA_DR2_analysis/dr2{datastr}/CRN{gwb_Nfreqs}_FL'

# Noise dictionary, later passed to pta.set_default_params.
noisepath = f'{project_path}/noisedicts/dr2{datastr}_advnoise.json'
with open(noisepath, 'r') as f:
    noise_params = json.load(f)

# Figure output directory: start from the default, then specialise per run.
figsave_dir = f'{project_path}/figs/dr2{datastr}/CRN{gwb_Nfreqs}_FL'
if litec and datastr == 'full':
    figsave_dir = f'{project_path}/figs/dr2litec/CRN{gwb_Nfreqs}_FL'
if edr2 and datastr == 'full':
    figsave_dir = f'{project_path}/figs/edr2/CRN{gwb_Nfreqs}_FL'
if datastr == 'lite_unfiltered' and Np in (21, 43, 53):
    # The suffix follows Np. Previously Np == 53 was routed to the "_21"
    # directory (copy-paste slip), overwriting the 21-pulsar figures.
    figsave_dir = f'{project_path}/figs/dr2{datastr}/CRN{gwb_Nfreqs}_FL_{Np}'
# Make sure savefig calls in later cells have somewhere to write.
os.makedirs(figsave_dir, exist_ok=True)

In [None]:
# Pick the Tspan constant that matches the selected dataset.
if datastr in ('full', 'lite_unfiltered'):
    Tspan = full_Tspan
elif datastr == 'lite':
    Tspan = lite_Tspan
else:
    # Fail here instead of hitting a NameError (or silently reusing a stale
    # Tspan from an earlier run) when the model is built further down.
    raise ValueError(f'No Tspan defined for datastr={datastr!r}')

In [None]:
# Directory with one data file per pulsar for the selected dataset.
if edr2:
    psrdir = '/vast/palmer/home.grace/bbl29/IPTA_DR2_analysis/data/edr2_ePSRs'
elif Np is not None:
    psrdir = f'/vast/palmer/home.grace/bbl29/IPTA_DR2_analysis/data/{datastr}_{Np}_ePSRs'
else:
    psrdir = f'/vast/palmer/home.grace/bbl29/IPTA_DR2_analysis/data/{datastr}_ePSRs'
psrpaths = glob.glob(f'{psrdir}/*')

# Pulsar name = file name without its extension, up to the first underscore.
# splitext handles extensions of any length; the previous fixed [:-4] slice
# left a trailing '.' on names coming from '.hdf5' files.
# np.unique already returns a sorted array.
all_psrnames = np.unique(
    [os.path.splitext(os.path.basename(p))[0].split('_')[0] for p in psrpaths]
)

# Pulsars that have a "<name>_converged.txt" marker in the output directory.
psrnames = np.sort(
    [os.path.basename(p).split('_')[0] for p in glob.glob(f'{outdir}/*_converged.txt')]
)
converged = set(psrnames)
missing_psrs = [p for p in all_psrnames if p not in converged]
print(f'Missing {len(missing_psrs)}/{len(all_psrnames)} PSRs: {missing_psrs}')

In [None]:
def get_prior_distr(param):
    """Build a scipy uniform distribution from a parameter's pmin/pmax bounds.

    The bounds are read from ``str(param)``, which must contain the text
    ``pmin=<lo>, pmax=<hi>)``.

    Parameters
    ----------
    param : object
        Any object whose string form embeds ``pmin=..., pmax=...)``; in this
        notebook, the entries of ``pta.params``.

    Returns
    -------
    Frozen ``scipy.stats.uniform`` with ``loc=pmin`` and ``scale=pmax-pmin``.

    Raises
    ------
    ValueError
        If the bounds cannot be located or are not valid floats.

    Notes
    -----
    The result is always uniform between the bounds, whatever prior family
    the string names.
    """
    pline = str(param)
    _, found_min, after_min = pline.partition('pmin=')
    pmin_txt, found_max, after_max = after_min.partition(', pmax=')
    if not (found_min and found_max):
        raise ValueError(f'could not find pmin/pmax bounds in {pline!r}')
    # Stop at the closing parenthesis rather than assuming it is the last
    # character, so trailing text after ')' (e.g. a size suffix) is tolerated.
    pmax_txt = after_max.partition(')')[0]
    pmin, pmax = float(pmin_txt), float(pmax_txt)
    return uniform(loc=pmin, scale=pmax - pmin)

In [None]:
# Shell listing of the edr2 pulsar directory. The path is hard-coded, so this
# does not follow `psrdir` when another dataset is selected.
! ls /vast/palmer/home.grace/bbl29/IPTA_DR2_analysis/data/edr2_ePSRs

In [None]:
# For every pulsar: load its data, build the single-pulsar CRN + noise PTA,
# and open the nautilus sampler stored under `outdir`.
for psrname in all_psrnames:
    if psrname in sampler_dict:
        # finished sampler already loaded by an earlier run of this cell
        continue
    filename = f'{outdir}/{psrname}_sampler.hdf5'

    # Load the pulsar: try the HDF5 file first, fall back to the pickle.
    try:
        psrpath = f'{psrdir}/{psrname}.hdf5'
        with open(psrpath, 'rb') as f:
            psr = FilePulsar(f)
    except Exception:
        # `Exception` (not a bare except) so Ctrl-C still interrupts the loop.
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # `psrdir` at files you trust.
        psrpath = f'{psrdir}/{psrname}.pkl'
        with open(psrpath, 'rb') as f:
            psr = pickle.load(f)

    # Build the model. The CRN uses `gwb_Nfreqs` components so it stays in
    # step with the configuration (and the CRN{gwb_Nfreqs} directory names).
    crn = common_red_noise_block(psd='powerlaw', prior='log-uniform', Tspan=Tspan,
                                 components=gwb_Nfreqs, gamma_val=13/3,
                                 logmin=-18, logmax=-12, orf=None, name='crn')
    noise = adv_noise_block(psr, full_pta_analysis=True, dataset=datastr, psr_model=True,
                            tm_marg=True, tm_svd=True)
    signals = crn + noise
    pta = signal_base.PTA([signals(psr)])
    pta.set_default_params(noise_params)
    pta_dict[psrname] = pta

    # make prior: one entry per PTA parameter, in PTA order
    prior = Prior()
    for name, param in zip(pta.param_names, pta.params):
        prior.add_parameter(name, dist=get_prior_distr(param))
    print(f'{psrname}: dim = {prior.dimensionality()}')

    # get sampler
    sampler = Sampler(prior, pta.get_lnlikelihood, filepath=filename, pass_dict=False)
    if sampler.explored:
        sampler_dict[psrname] = sampler
        # A run that finished since the last execution of this cell must not
        # linger in the "unfinished" bookkeeping.
        unfinished_sampler_dict.pop(psrname, None)
    else:
        print(f'{psrname} still unfinished')
        unfinished_sampler_dict[psrname] = sampler

In [None]:
def corner_labels(params):
    """Translate parameter names into LaTeX axis labels for corner plots.

    Each name is first mapped to a short noise tag (first matching substring
    wins; if none matches, the tag is the name's middle underscore-separated
    pieces). The tag is then inserted into a label template chosen by the
    parameter kind. Names that match no template are returned unchanged.

    Parameters
    ----------
    params : iterable of str
        Parameter names.

    Returns
    -------
    list of str
        One label per input name, in the same order.
    """
    # (substring, tag) pairs, tested in this order.
    noise_tags = (
        ('dm_gp', 'DM'),
        ('sw_gp', 'SW'),
        ('chrom_gp', 'Chr'),
        ('red_noise', 'RN'),
        ('gw', 'CRN'),
        ('crn', 'CRN'),
        ('exp', 'd'),
    )
    # (substring, label builder) pairs, tested in this order.
    label_makers = (
        ('gamma', lambda tag: r'$\gamma_{\rm{' + tag + r'}}$'),
        ('log10_A', lambda tag: r'$\log_{10}A_{\rm{' + tag + r'}}$'),
        ('log10_sigma_ne', lambda tag: r'$\log_{10}\sigma_{n_e}$'),
        ('log10_tau', lambda tag: r'$\log_{10}\tau_{\rm{' + tag + r'}}$'),
        ('t0', lambda tag: r'${t_{0,}}_{\rm{' + tag + r'}}$'),
    )

    def _label(name):
        tag = next((t for key, t in noise_tags if key in name), None)
        if tag is None:
            tag = '_'.join(name.split('_')[1:-1])
        for key, make in label_makers:
            if key in name:
                return make(tag)
        return name

    return [_label(p) for p in params]

In [None]:
# plot in progress run
# Corner plot of the first still-running sampler, with the prior and the
# single-pulsar PTMCMC chain overlaid on the diagonal.
for psrname in unfinished_sampler_dict:
    # Use a dedicated name: the global `corepath` belongs to the full-PTA
    # cells and must not be overwritten from here.
    psr_corepath = f'/vast/palmer/home.grace/bbl29/project/IPTA_DR2_analysis/dr2{datastr}/advnoise/{psrname}/core.h5'
    c_psr = co.Core(corepath=psr_corepath, burn=0)
    sampler = unfinished_sampler_dict[psrname]
    pta = pta_dict[psrname]
    points, log_w, log_l = sampler.posterior()
    ndim = points.shape[1]
    print(points.shape)  # reuse the posterior fetched above
    print('log Z: {:.2f}'.format(sampler.log_z))
    fig, axes = plt.subplots(ndim, ndim, figsize=(3*ndim, 3*ndim))
    fig = corner.corner(points, bins=20, #weights=np.exp(log_w),
                        labels=corner_labels(sampler.prior.keys),
                        plot_datapoints=True, plot_density=True, label_kwargs={'fontsize':20},
                        fill_contours=False, no_fill_contours=True, levels=(0.68, 0.95),
                        hist_kwargs={'density':True},
                        fig=fig)#, range=np.ones(ndim) * 0.999)
    axes = np.array(fig.axes).reshape(ndim,ndim)
    for i in range(ndim):
        param = pta.params[i]
        # prior bounds via the shared helper instead of re-parsing the repr
        pmin, pmax = get_prior_distr(param).support()
        x = np.linspace(pmin,pmax,300)
        axes[i,i].plot(x, param.get_pdf(x), color='C2')
        # add PTMCMC result (skipped for the last parameter)
        if i < ndim-1:
            axes[i,i].hist(c_psr(pta.param_names[i]), color='C1',
                           density=True, histtype='step', bins=20)
    fig.suptitle(psrname, fontsize=50)
    # only the first unfinished pulsar is shown
    break

In [None]:
# bayes fac with weights
def bayes_fac(array, weights=None, max=-4, min=-10, nbins=None):
    """Ratio of a flat prior density to the posterior density in the lowest bin.

    The samples are histogrammed (density-normalised, optionally weighted) on
    ``nbins`` equal bins spanning ``[min, max]``. The flat prior density
    ``1/(max-min)`` is then divided by the height of the first bin.

    Parameters
    ----------
    array : array_like
        Samples of the parameter.
    weights : array_like, optional
        Per-sample weights passed to ``np.histogram``.
    max, min : float
        Upper and lower edge of the histogram range.
    nbins : int, optional
        Number of bins; defaults to ``int(max - min)``.

    Returns
    -------
    float
        ``prior / density_in_first_bin``, or ``nan`` when the first bin is empty.
    """
    span = max - min
    flat_prior = 1/span
    n_bins = int(span) if nbins is None else nbins
    edges = np.linspace(min, max, n_bins+1)
    density = np.histogram(array, bins=edges, weights=weights, density=True)[0]
    lowest = density[0]
    # An empty first bin would give an infinite ratio; report NaN instead.
    return np.nan if lowest == 0 else flat_prior/lowest

In [None]:
def Flike(sampler_dict, psrnames, bin_num=100):
    """Factorized-likelihood distribution of ``crn_log10_A``.

    For each pulsar, the weighted posterior samples of ``crn_log10_A`` are
    histogrammed on a common grid over [-18, -12]. The per-pulsar densities
    are multiplied bin by bin and the product is wrapped in a normalised
    histogram distribution.

    Parameters
    ----------
    sampler_dict : dict
        Maps pulsar name to a sampler exposing ``prior.keys`` and ``posterior()``.
    psrnames : sequence of str
        Pulsars to include in the product.
    bin_num : int
        Number of bin *edges* (``bin_num - 1`` bins).

    Returns
    -------
    scipy.stats.rv_histogram
        Distribution built from the product histogram.
    """
    bins = np.linspace(-18,-12,bin_num)
    log_hist = np.zeros((len(psrnames),len(bins)-1))
    for jj,p in enumerate(psrnames):
        sampler = sampler_dict[p]
        idx = sampler.prior.keys.index('crn_log10_A')
        points, log_w, _ = sampler.posterior()
        density = np.histogram(points[:,idx], bins=bins, density=True, weights=np.exp(log_w))[0]
        # the small floor keeps empty bins finite in log space
        log_hist[jj,:] = np.log(density + 1e-20)
    # Accumulate in log space and rescale by the maximum. A direct product of
    # many densities can underflow to all zeros, which rv_histogram would turn
    # into a NaN pdf. rv_histogram normalises, so the rescaling is harmless.
    log_product = log_hist.sum(axis=0)
    finalhist = np.exp(log_product - log_product.max())
    finalhist_dist = scipy.stats.rv_histogram([finalhist,bins])
    return finalhist_dist

In [None]:
# get the CRN from PTMCMCSampler
# Locate the full-PTA PTMCMC core that matches the current configuration.
corepath = '/vast/palmer/home.grace/bbl29/project/IPTA_DR2_analysis/'
if litec and datastr == 'full':
    corepath += f'dr2litec/CRN{gwb_Nfreqs}_g4p3_advnoise/core.h5'
elif Np is not None:
    corepath += f'dr2{datastr}_{Np}/CRN{gwb_Nfreqs}_g4p3_advnoise/core.h5'
else:
    corepath += f'dr2{datastr}/CRN{gwb_Nfreqs}_g4p3_advnoise/core.h5'
if os.path.isfile(corepath):
    c = co.Core(corepath=corepath, burn=0)
else:
    # Reset `c` so a core left over from a previous configuration is never
    # silently reused by the comparison cells.
    c = None
    print('No file for full-PTA search')

In [None]:
fig, ax = plt.subplots(figsize=(6,3),dpi=300)
# Only pulsars with a finished sampler have an entry in `sampler_dict`;
# iterating over all of `all_psrnames` raises KeyError while runs are pending.
FL_psrnames = np.array([p for p in all_psrnames if p in sampler_dict])

# plot hists: pulsars with a per-pulsar ratio above 3 get their own colour
# and legend entry, all others are drawn faintly in black
color_idx = 2
for psrname in FL_psrnames:
    s = sampler_dict[psrname]
    points, log_w, _ = s.posterior()
    idx = s.prior.keys.index('crn_log10_A')
    weights = np.exp(log_w)
    psr_BF = bayes_fac(points[:,idx], weights=weights, max=-12, min=-18)
    if psr_BF > 3:
        ax.hist(points[:,idx], weights=weights, alpha=1, bins=60, color=f'C{color_idx}',
                histtype='step', density=True, range=(-18,-12), label=psrname)
        color_idx += 1
    else:
        ax.hist(points[:,idx], weights=weights, alpha=0.1, bins=60,
                histtype='step', color='k', density=True, range=(-18,-12))

# compute fact like stuff
log10A_dist = Flike(sampler_dict, FL_psrnames, bin_num=61)
a = np.linspace(-18,-12,10000)
ax.plot(a,log10A_dist.pdf(a),'-C0',lw=1,label='Factorized likelihood analysis')
# Compute Savage-Dickey Bayes factor: flat prior density (1/6 over [-18,-12])
# over the mean FL density below log10_A = -16
BF = (1/6)/np.mean(log10A_dist.pdf(a)[a < -16])

# overplot CRN from PTMCMCSampler
if os.path.isfile(corepath):
    ax.hist(c('crn_log10_A'), bins=60, density=True, histtype='stepfilled',
            color='C1', range=(-18,-12), alpha=0.5, label='Full PTA analysis')

# labels, etc
ax.set_xlabel(r'$\log_{10}A_{\rm{CRN}}$')
ax.set_xlim([-18,-12])
ylim = ax.get_ylim()  # linear-scale limits, captured before switching to log
ax.set_ylabel('PDF')
ax.semilogy()
ax.set_ylim([1e-11,ylim[1]+5])
ax.legend(fontsize='small')
if litec and datastr == 'full':
    title = f'IPTA DR2litec Fact Like analysis: '
else:
    title = f'IPTA DR2{datastr} Fact Like analysis: '
title += fr'$\log_{{10}}\mathcal{{B}}^{{CRN+IRN}}_{{IRN}}$ = {np.log10(BF):1.2f}'
ax.set_title(title)
fig.tight_layout()
fig.savefig(f'{figsave_dir}/FL.png', dpi=300, bbox_inches='tight')

## Dropout factors

Taylor et al. (2022) define an FL-derived dropout factor, which compares the probability of measuring the recovered common process (CP) with pulsar $p$ included against the probability of measuring it without pulsar $p$:
\begin{align*}
    \text{Dropout Factor} = \frac{\mathcal{Z}_{p,0}}{\mathcal{Z}_{p,1}}\left\langle\frac{p(A_{\rm{CP}}|d,\mathcal{H}_1)}{p(A_{\rm{CP}})}\right\rangle_p.
\end{align*}
Here, $\mathcal{H}_0$ is the model with CP in all pulsars, $\mathcal{H}_1$ is the model with CP in all pulsars except pulsar $p$. $\mathcal{Z}_{p,0}$ is the Bayesian evidence of the noise + CRN in pulsar $p$, while $\mathcal{Z}_{p,1}$ is the evidence of just the noise in pulsar $p$. $p(A_{\rm{CP}}|d,\mathcal{H}_1)$ is the posterior on $A_{\rm{CP}}$ with pulsar $p$ not included, $p(A_{\rm{CP}})$ is the prior, and $\langle\rangle_p$ indicates the enclosed ratio is computed and averaged over all posterior samples from pulsar $p$. As such, the relevant quantities to compute are:
- $\mathcal{Z}_{p,0}/\mathcal{Z}_{p,1}$ is just the Savage–Dickey (SD) density ratio for the FL CRN in pulsar $p$.
- $p(A_{\rm{CP}}|d,\mathcal{H}_1)$ can be computed as a histogram using the FL with all pulsars except the dropped pulsar.
- $\langle\cdot\rangle_p$ is a weighted average over the weighted posterior samples from pulsar $p$.

In [None]:
def compute_DF(psr_drop, sampler_dict, psrnames):
    """Dropout factor for one pulsar from the factorized likelihood.

    Multiplies the pulsar's own ``bayes_fac`` ratio for ``crn_log10_A`` by the
    weighted average, over that pulsar's posterior samples, of the FL density
    built from all *other* pulsars divided by the flat prior density (1/6).

    Parameters
    ----------
    psr_drop : str
        Name of the pulsar being dropped.
    sampler_dict : dict
        Maps pulsar name to its sampler.
    psrnames : sequence of str
        All pulsars entering the analysis (may include ``psr_drop``).

    Returns
    -------
    float
        The dropout factor.
    """
    # weighted crn_log10_A samples of the dropped pulsar
    dropped = sampler_dict[psr_drop]
    samples, log_weights, _ = dropped.posterior()
    log10A = samples[:, dropped.prior.keys.index('crn_log10_A')]
    sample_weights = np.exp(log_weights)

    # single-pulsar ratio on the [-18, -12] prior range
    sd_ratio = bayes_fac(log10A, weights=sample_weights, max=-12, min=-18)

    # factorized likelihood from the remaining N-1 pulsars
    remaining = np.unique([name for name in psrnames if name != psr_drop])
    fl_without = Flike(sampler_dict, remaining, bin_num=61)

    # posterior-to-prior ratio, averaged over the dropped pulsar's samples
    prior_density = 1/6
    mean_ratio = np.average(fl_without.pdf(log10A)/prior_density, weights=sample_weights)
    return sd_ratio*mean_ratio

In [None]:
# Display the pulsar list that was passed to Flike for the FL figure.
FL_psrnames

In [None]:
# Dropout factor for every FL pulsar, ordered from largest to smallest.
#FL_psrnames = np.unique([p.split('_')[0] for p in c.params if 'J' in p])
dropout_by_psr = {
    name: compute_DF(name, sampler_dict, FL_psrnames) for name in FL_psrnames
}
DFs = dict(sorted(dropout_by_psr.items(), key=lambda entry: entry[1], reverse=True))

In [None]:
# Display the dropout factors (sorted in descending order by the previous cell).
DFs

In [None]:
# Dropout-factor plot: one point per pulsar, in the sorted order of `DFs`.
if datastr == 'full' and not litec:
    fig, ax = plt.subplots(figsize=(20,3),dpi=300)
else:
    fig, ax = plt.subplots(figsize=(10,3),dpi=300)
ax.axhline(1, alpha=0.5, lw=1, color='k')  # reference line at a factor of 1
df_names = list(DFs.keys())
for i, psr_drop in enumerate(df_names):
    ax.plot([i], [DFs[psr_drop]], 'oC0')
# Ticks and labels both come from `DFs`, so they cannot get out of step if
# `FL_psrnames` was changed after the factors were computed.
ax.set_xticks(np.arange(len(df_names)))
ax.set_xticklabels(df_names, rotation=90, fontsize='large')
ax.set_ylabel('CRN Dropout Factor', fontsize='x-large')
if datastr == 'full' and litec:
    ax.set_title(f'IPTA DR2litec CRN FL Dropout analysis', fontsize='x-large')
else:
    ax.set_title(f'IPTA DR2{datastr} CRN FL Dropout analysis', fontsize='x-large')
ax.semilogy()
ax.grid(which='both',lw=0.3)
# bbox_inches='tight' (as for the other saved figures) keeps the rotated
# pulsar names from being clipped.
fig.savefig(f'{figsave_dir}/dropout.png', bbox_inches='tight')

# Sanity check

Let's compare the individual noise parameters from the factorized-likelihood runs to those from the full PTA analysis.

In [None]:
# One corner plot per finished pulsar: nested-sampling posterior (black),
# prior (green, diagonal) and, when available, the full-PTA MCMC chain (orange).
# The chain `c` is only usable when its core file was found, so apply the same
# guard as the FL figure instead of failing on a missing `c`.
have_full_pta = os.path.isfile(corepath)
for psrname in sampler_dict:
    print(psrname)
    sampler = sampler_dict[psrname]
    pta = pta_dict[psrname]
    keys = sampler.prior.keys
    # Hand-tuned axis ranges; each list must have one (lo, hi) pair per
    # sampled parameter of the corresponding model.
    if psrname == 'J1713+0747':
        ranges = [(0,7),(-20,-11),(0,7),(-20,-11),(-6.3,-5),(np.log10(5),np.log10(500)),(54742,54768),
                  (0,7),(-20,-11),(-18,-12)]
    elif (f'{psrname}_chrom_gp_gamma' in keys and
          f'{psrname}_sw_gp_log10_sigma_ne' in keys):
        ranges = [(0,7),(-20,-11),(0,7),(-20,-11),(0,7),(-20,-11),(-4,2),(-18,-12)]
    elif f'{psrname}_chrom_gp_gamma' in keys or psrname == 'J1012+5307':
        ranges = [(0,7),(-20,-11),(0,7),(-20,-11),(0,7),(-20,-11),(-18,-12)]
    elif f'{psrname}_sw_gp_log10_sigma_ne' in keys:
        ranges = [(0,7),(-20,-11),(0,7),(-20,-11),(-4,2),(-18,-12)]
    else:
        ranges = [(0,7),(-20,-11),(0,7),(-20,-11),(-18,-12)]

    points, log_w, log_l = sampler.posterior()
    weights = np.exp(log_w)
    mask = weights > 1e-6  # drop samples with negligible weight
    ndim = points.shape[1]
    fig, axes = plt.subplots(ndim, ndim, figsize=(3*ndim, 3*ndim))
    lines = []
    fig = corner.corner(points[mask], bins=20, weights=weights[mask], range=ranges,
                        labels=corner_labels(keys),
                        plot_datapoints=True, plot_density=True, label_kwargs={'fontsize':25},
                        fill_contours=False, no_fill_contours=True, levels=(0.68, 0.95),
                        hist_kwargs={'density':True}, fig=fig)#, range=np.ones(ndim) * 0.999)
    lines.append(mlines.Line2D([],[],color='k',label='Fact like (NS)'))
    axes = np.array(fig.axes).reshape(ndim,ndim)

    # prior pdf on each diagonal panel
    for i in range(ndim):
        param = pta.params[i]
        pmin, pmax = get_prior_distr(param).support()
        x = np.linspace(pmin,pmax,300)
        axes[i,i].plot(x, param.get_pdf(x), color='C2')

    # add PTMCMC result
    if have_full_pta:
        # full-PTA chain names: 'gw' in a parameter name is mapped to 'crn'
        chain_names = [name.replace('gw','crn') for name in pta.param_names]
        for i in range(ndim):
            bins=20
            if i == ndim-1:
                bins=60  # finer binning for the last parameter
            axes[i,i].hist(c(chain_names[i]),
                           color='C1', density=True, histtype='step', bins=bins, range=ranges[i])
            for j in range(ndim):
                if i > j:
                    ax = axes[i,j]
                    x = c(chain_names[j])
                    y = c(chain_names[i])
                    corner.hist2d(x, y, ax=ax, bins=20, levels=(0.68, 0.95), color='C1',
                                  plot_datapoints=False, fill_contours=False,
                                  no_fill_contours=True, plot_density=False,
                                  range=[ranges[j],ranges[i]])
                    # keep tick marks on the left column and bottom row only
                    if j > 0:
                        ax.set_yticks([])
                    if i < ndim - 1:
                        ax.set_xticks([])
        lines.append(mlines.Line2D([],[],color='C1',label='Full PTA (MCMC)'))

    fig.legend(handles=lines, fontsize=30)
    fig.suptitle(psrname, fontsize=50)
    fig.savefig(f'{figsave_dir}/{psrname}.png', dpi=300, bbox_inches='tight')
    plt.show()
    # free the figure before the next pulsar; these figures are large
    plt.close(fig)

In [None]:
# Inspection cell: shape of `points` left over from the last pulsar of the loop above.
points.shape

In [None]:
# Inspection cell: labels for the last sampler handled by the loop above.
corner_labels(sampler.prior.keys)

In [None]:
# Inspection cell: axis ranges chosen for the last pulsar of the loop above.
ranges

In [None]:
# Close every open matplotlib figure.
plt.close('all')