In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%matplotlib inline
%config IPython.matplotlib.backend = "retina"
from matplotlib import rcParams
rcParams["savefig.dpi"] = 300
rcParams["figure.dpi"] = 300

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import json, os, pickle, glob, logging, copy
logger = logging.getLogger(__name__)
import corner
#import h5pulsar
import astropy.units as u
import scipy.stats as sps

In [None]:
import la_forge.diagnostics as dg
import la_forge.core as co
from la_forge.utils import bayes_fac
from la_forge.rednoise import gorilla_bf

In [None]:
#from enterprise.signals import selections
#from enterprise_extensions.sampler import group_from_params
#from targeted_cws_ng15.models import cw_model_2
from targeted_cws_ng15.empirical_distr_new import (make_empirical_distributions,
                                                   EmpiricalDistribution1D,
                                                   EmpiricalDistribution2D)
from ipta_gwb_analysis import diagnostics
from enterprise.signals import parameter
from QuickCW.PulsarDistPriors import DMDistParameter#, PXDistParameter
from QuickCW.PulsarDistPriors import PXDistPrior, DMDistPrior
from targeted_cws_ng15.Dists_Parameters import PXDistParameter
#from DR3_noise_modeling import diagnostics
#from DR3_noise_modeling.empirical_distr import make_empirical_distributions

In [None]:
def get_prior_distr(core, param):
    idx = core.params.index(param)
    pline = core.priors[idx]
    prior_type = pline[pline.index(':')+1:pline.index('(')]
    # setup x-axis for the plot
    if prior_type == 'Uniform':
        pmin = float(pline[pline.index('pmin')+5:pline.index(', pmax')])
        pmax = float(pline[pline.index('pmax')+5:-1])
        x = np.linspace(pmin, pmax, 300)
        y = parameter.UniformPrior(x, pmin, pmax)
        #prior_dist = sps.uniform(loc=pmin, scale=pmax-pmin)
        #y = prior_dist.pdf(x)
    elif prior_type == 'Normal':
        mu = float(pline[pline.index('mu')+3:pline.index(', sigma')])
        sigma = float(pline[pline.index('sigma')+6:-1])
        #prior_dist = sps.norm(loc=mu, scale=sigma)
        pmin = np.min([mu-3*sigma, np.min(core.chain[core.burn:,idx])])
        pmax = np.max([mu+3*sigma, np.max(core.chain[core.burn:,idx])])
        x = np.linspace(pmin, pmax, 300)
        y = parameter.NormalPrior(x, mu, sigma)
        #y = prior_dist.pdf(x)
    elif prior_type == 'LinearExp':
        pmin = float(pline[pline.index('pmin')+5:pline.index(', pmax')])
        pmax = float(pline[pline.index('pmax')+5:-1])
        x = np.linspace(pmin, pmax, 300)
        y = parameter.LinearExpPrior(x, pmin, pmax)
    elif prior_type == 'PXDist':
        dist = float(pline[pline.index('dist=')+5:pline.index(', err')])
        err = float(pline[pline.index('err')+4:-1])
        pmin = np.min([dist-3*err, np.min(core.chain[core.burn:,idx])])
        pmax = np.max([dist+3*err, np.max(core.chain[core.burn:,idx])])
        x = np.linspace(pmin, pmax, 300)
        y = PXDistPrior(x, dist, err)
    elif prior_type == 'DMDist':
        dist = float(pline[pline.index('dist=')+5:pline.index(', err')])
        err = float(pline[pline.index('err')+4:-1])
        pmin = np.min([dist-2*err, np.min(core.chain[core.burn:,idx])])
        pmax = np.max([dist+2*err, np.max(core.chain[core.burn:,idx])])
        x = np.linspace(pmin, pmax, 300)
        y = DMDistPrior(x, dist, err)
    return x, y

In [None]:
def get_bayes_fac(c, amp_param='log10_mc', mask=None):
    """Savage-Dickey Bayes factor for `amp_param` from la_forge helpers.

    The prior bounds are parsed from the ``pmin=..., pmax=...)`` arguments of
    the parameter's prior line. la_forge's ``bayes_fac`` is tried first; if it
    returns NaN, ``gorilla_bf`` (20 bins) is used instead.

    :param c: la_forge Core with ``params``, ``priors``, ``chain`` and ``burn``
    :param amp_param: name of the amplitude-like parameter
    :param mask: optional boolean (or index) array-like applied to the
        post-burn-in samples; ``None`` uses all of them
    :return: the Bayes factor
    """
    idx = c.params.index(amp_param)
    pline = c.priors[idx]
    pmin = float(pline[pline.index('pmin')+5:pline.index(', pmax')])
    pmax = float(pline[pline.index('pmax')+5:pline.index(')')])
    samples = c.chain[c.burn:, idx]
    if mask is not None:
        # accept any array-like mask (a list used to be silently ignored)
        samples = samples[np.asarray(mask)]
    BF, _ = bayes_fac(samples, logAmin=pmin, logAmax=pmax)
    if np.isnan(BF):
        BF = gorilla_bf(samples, max=pmax, min=pmin, nbins=20)
    return BF

# DO THINGS SLOW

# Load chains

Running models.

In [None]:
plt.close('all')

# --- run configuration ---
# Mass-prior type of the run: `detection` (uniform in log10_mc), `upper_limit`
# (LinearExp in log10_mc, i.e. uniform in Mc) or, with both False, Normal.
detection = True
upper_limit = False
vary_crn = False  # CRN sampled (True) or fixed from runtime_info (False)
vary_fgw = True   # log10_fgw sampled (True) or fixed from runtime_info (False)
# the two flags select different prior strings/paths later; both True is inconsistent
assert not (detection and upper_limit), 'detection and upper_limit are mutually exclusive'

dataset = 'ng15_v1p1'
source_name = '3C66B'
#source_name = 'HS_0926+3608'
#source_name = 'HS_1630+2355'
#source_name = 'HS_1630+2355_altskyloc'
#source_name = 'NGC_3115'
#source_name = 'OJ287_UL'
#source_name = 'PKS_2131-021_UL'
#source_name = 'SDSS_J092911.35+2037'
#source_name = 'SDSS_J114857.33+1600'
#source_name = 'SDSS_J131706.19+2714'
#source_name = 'SDSS_J133516.17+1833'
#source_name = 'SDSS_J134855.27-0321'
#source_name = 'SDSS_J140704.43+2735'
#source_name = 'SDSS_J160730.33+1449'
#source_name = 'SDSS_J164452.71+4307_UL'
#source_name = 'SNU_J13120+0641'
#source_name = 'PKS_J0805-0111'

# paths can be overridden via environment variables for other machines/users
project_path = os.environ.get('TARGETED_CWS_PROJECT_PATH',
                              '/vast/palmer/home.grace/bbl29/targeted_cws_ng15')
outdir_path = os.environ.get('TARGETED_CWS_OUTDIR_PATH',
                             '/vast/palmer/home.grace/bbl29/project/targeted_cws_ng15')

In [None]:
# set up figure save directories: <project>/reports/figures/<dataset>/<source><suffixes>/
save_loc = f'{project_path}/reports/figures/{dataset}/{source_name}'
if detection:
    save_loc += '_det'
elif upper_limit:
    save_loc += '_UL'
if vary_fgw:
    save_loc += '_varyfgw'
if vary_crn:
    save_loc += '_varycrn'
save_loc += '/'
# makedirs also creates missing parents (os.mkdir fails on a fresh dataset) and
# exist_ok makes the cell safe to re-run
os.makedirs(save_loc, exist_ok=True)

In [None]:
# chain directory: <outdir>/data/chains/<dataset>/<source><suffixes>
chaindir = f'{outdir_path}/data/chains/{dataset}/{source_name}'#_old_prior'
if detection:
    chaindir += '_det'
elif upper_limit:
    chaindir += '_UL'
if vary_fgw:
    chaindir += '_varyfgw'
if vary_crn:
    chaindir += '_varycrn'
if not os.path.isdir(chaindir):
    # os.walk silently yields nothing for a missing directory
    raise FileNotFoundError(f'chain directory not found: {chaindir}')
# every subdirectory below chaindir (excluding chaindir itself), sorted so the
# chain order -- and therefore chaindirs[0], cs[0], ... -- is reproducible
chaindirs = sorted(dirpath + '/' for dirpath, _, _ in os.walk(chaindir)
                   if dirpath != chaindir)

nchains = len(chaindirs)
nchains

In [None]:
# load model params: prefer the copy at the top of chaindir, else the first chain
# subdirectory that has one. Paths are joined properly (chaindir has no trailing '/').
model_params_candidates = ([os.path.join(chaindir, 'model_params.json')] +
                           [os.path.join(cd, 'model_params.json') for cd in chaindirs])
model_params_path = next((path for path in model_params_candidates
                          if os.path.isfile(path)), None)
if model_params_path is None:
    raise FileNotFoundError(f'no model_params.json in {chaindir} or its subdirectories')
with open(model_params_path, 'r') as fin:
    model_params = json.load(fin)
# extra chain columns that follow the model parameters in each chain file
stats = ['lnpost', 'lnlike', 'chain_accept', 'pt_chain_accept']

In [None]:
# Load each chain directory as a parallel-tempered la_forge core, keeping only
# cores with more than MIN_CHAIN_LINES samples.
MIN_CHAIN_LINES = 200
cs = []
for i, cd in enumerate(chaindirs):
    print(f'{i+1}/{nchains}')
    try:
        core = co.Core(label=source_name, chaindir=cd,
                       params=model_params+stats, pt_chains=True)
        n_lines = len(core.chain)
    except Exception as err:
        # report the cause instead of a bare except (which also swallowed
        # KeyboardInterrupt); one unreadable directory should not stop the loop
        print(f'could not load from {cd}: {err!r}')
        continue
    if n_lines > MIN_CHAIN_LINES:
        cs.append(core)
    else:
        print(f'not enough lines in {cd}')

In [None]:
# Burn-in / thinning: read them from converged.txt when requested and present,
# otherwise use the hand-set values below.
use_converged = False
chain_info_path = f'{chaindir}/converged.txt'
have_chain_info = os.path.isfile(chain_info_path)


def _read_count(text, label):
    """Integer left after removing `label` and the newline from `text`."""
    return int(text.replace(label, '').replace('\n', ''))


if have_chain_info and use_converged:
    print('using converged.txt')
    with open(chain_info_path, 'r') as fi:
        chain_info = fi.readlines()
        nchains = _read_count(chain_info[0], 'Num chains = ')
        burn = _read_count(chain_info[1], 'burned from each chain = ')
        thin_by = _read_count(chain_info[2], 'thinned by = ')
else:
    nchains = len(cs)
    print(nchains, 'chains')
    # set by hand
    burn, thin_by = 8000, 2

In [None]:
# Post-burn-in log-posterior traces. Top panel overlays the cores; bottom panel
# places them end to end, as they will be concatenated.
total_samples = 0
total_postburn_samples = 0
total_thinned_samples = 0
fig, ax = plt.subplots(2, 1, figsize=(12, 6), sharex=False)
for i, core in enumerate(cs[:nchains]):
    core.set_burn(burn)
    lnpost = core.get_param('lnpost', to_burn=True)
    n_samples = len(lnpost)
    if n_samples == 0:
        print(f'no samples for c[{i}]')
    y = lnpost[::thin_by]
    # offset by the thinned samples actually taken from earlier cores, so the
    # segments abut exactly (ceil(total/thin) drifts from the sum of per-core ceils)
    x = np.arange(total_thinned_samples, total_thinned_samples + len(y))
    ax[0].plot(y, alpha=0.2, lw=0.5, c='k')
    ax[1].plot(x, y, lw=0.5)
    total_thinned_samples += len(y)
    total_postburn_samples += n_samples
    total_samples += len(core.get_param('lnpost', to_burn=False))
ax[0].set_ylabel(r'$\log$post')
ax[0].set_title(source_name)
ax[1].set_ylabel(r'$\log$post')
ax[-1].set_xlabel('sample')
fig.tight_layout()

print(total_samples, 'total samples')
print(total_postburn_samples, 'samples after burn in')
print(total_thinned_samples, 'thinned samples')
if total_samples > 0:
    print(np.round((total_samples-total_postburn_samples)*100/total_samples, 3),
          '% of samples burned')
else:
    print('no samples loaded')

In [None]:
# Drop cores that have no post-burn-in samples; walk indices from the end so
# popping does not shift the ones still to be checked.
for i in reversed(range(nchains)):
    n_kept = len(cs[i].get_param(cs[i].params[0], to_burn=True))
    if not n_kept:
        print(f'no samples for c[{i}]')
        cs.pop(i)
        nchains -= 1

In [None]:
# Reset the parameter names on every core. Indexing `cs[i]` relied on `i` left over
# from the previous cell's loop and only updated a single core.
for core in cs:
    core.params = model_params + stats

In [None]:
# Proposal diagnostics for all cores on one set of axes: the first core creates
# the figure, the rest are drawn onto it (cs[0] used to be drawn twice).
fig, ax = diagnostics.plot_proposals(cs[0], return_fig=True)
for core in cs[1:nchains]:
    diagnostics.plot_proposals(core, ax=ax)

Params of interest

In [None]:
# Per-core posteriors for a hand-picked set of parameters (toggle pulsars by
# (un)commenting entries below).
global_pars = ['cos_inc', 'log10_mc', 'phase0', 'psi']
pulsar_pars = [
    #'J1713+0747_cw_p_dist', 'J1713+0747_cw_p_phase',
    'J1713+0747_red_noise_gamma', 'J1713+0747_red_noise_log10_A',
    #'J1909-3744_cw_p_dist', 'J1909-3744_cw_p_phase',
    #'J1909-3744_red_noise_gamma', 'J1909-3744_red_noise_log10_A',
    'J1024-0719_cw_p_dist', 'J1024-0719_cw_p_phase',
    #'J1640+2224_cw_p_dist', 'J1640+2224_cw_p_phase',
    #'J1640+2224_red_noise_gamma', 'J1640+2224_red_noise_log10_A',
]
sampler_pars = ['lnpost', 'lnlike', 'chain_accept', 'pt_chain_accept']
crn_pars = ['crn_gamma', 'crn_log10_A'] if vary_crn else []
fgw_pars = ['log10_fgw'] if vary_fgw else []
pars = crn_pars + fgw_pars + global_pars + pulsar_pars + sampler_pars
save_str = f"{save_loc}/posteriors_ind_cores_select"
dg.plot_chains(cs, pars=pars, hist=True, ncols=4, title_y=1.01, save=save_str)

In [None]:
# Post-burn-in traces of a single parameter: overlaid (top) and end to end (bottom).
param = 'log10_fgw'
total_samples = 0
total_postburn_samples = 0
total_thinned_samples = 0
fig, ax = plt.subplots(2, 1, figsize=(12, 6), sharex=False)
for i, core in enumerate(cs[:nchains]):
    core.set_burn(burn)
    values = core.get_param(param, to_burn=True)
    n_samples = len(values)
    if n_samples == 0:
        print(f'no samples for c[{i}]')
    y = values[::thin_by]
    # offset by the thinned samples actually taken from earlier cores
    x = np.arange(total_thinned_samples, total_thinned_samples + len(y))
    ax[0].plot(y, alpha=0.2, lw=0.5, c='k')
    ax[1].plot(x, y, lw=0.5)
    total_thinned_samples += len(y)
    total_postburn_samples += n_samples
    total_samples += len(core.get_param(param, to_burn=False))
ax[0].set_ylabel(param)
ax[0].set_title(source_name)
ax[1].set_ylabel(param)
ax[-1].set_xlabel('sample')
fig.tight_layout()

Look at the Gelman–Rubin $\hat{R}$ statistics of the individual chains to check for bad chains or parameters.

In [None]:
# R-hat tolerance that loosens slightly with the number of chains compared
gr_threshold = 1 + 0.01*nchains
diagnostics.plot_grubin(cs, threshold=gr_threshold)

### Save concatenated chain

These samples will be more useful in one chain file. Note there may occasionally be a case where one or more cores get stuck in a lower-likelihood region. In these cases, we'll want to split the chains up. If we do this, we'll define the full set of chains in `chain_mix.txt`. There's an additional section to set up split chain arrays later.

In [None]:
# Stack the post-burn-in, thinned samples of every core into one T=1 core.
chains = [core.chain[burn::thin_by] for core in cs]
chain_array = np.concatenate(chains)
c = co.Core(chain=chain_array, label=source_name, params=model_params+stats,
            pt_chains=False, burn=0)

Reconstruct strain using:
\begin{align*}
    h_0 &= \frac{2\mathcal{M}^{5/3}(\pi f_{\text{GW}})^{2/3}}{d_L}
\end{align*}
or 
\begin{align*}
    \log_{10}h_0 &= \frac{5}{3}\log_{10}\mathcal{M} + \frac{2}{3}\log_{10}f_{\text{GW}} - \log_{10}d_L + \log_{10}(2\pi^{2/3})
\end{align*}
These equations are defined using $c = G = 1$, however the parameters defined in enterprise are not necessarily defined using those units. Let's convert them here:

| Param | c = G = 1 | enterprise |
| --- | --- | --- |
| $\mathcal{M}$ | $\text{s}$ | $M_\odot$ |
| $f_{\text{GW}}$ | $1/\text{s}$ | $1/\text{s}$ |
| $d_L$ | $\text{s}$ | $\text{Mpc}$ |

This means a mass in $M_\odot$ is converted to $\text{s}$ by multiplying by $(\text{kg}/M_\odot)\cdot G/c^3$, and a distance in $\text{Mpc}$ is converted to $\text{s}$ by multiplying by $(\text{m}/\text{Mpc})/c$.

In [None]:
# Physical constants (SI)
c_light = 299792458  # speed of light [m/s]
G = 6.67430e-11  # gravitational constant [N m^2/kg^2] = [m^3 kg^-1 s^-2]

# log10 offsets from enterprise units to geometric units (c = G = 1, seconds)
log10_msun_to_s = np.log10(u.Msun.to(u.kg)*G/c_light**3)  # M_sun -> s
log10_mpc_to_s = np.log10(u.Mpc.to(u.m)/c_light)  # Mpc -> s
log10_h0_prefactor = np.log10(2*np.pi**(2/3))  # log10 of the 2*pi^(2/3) in h0

# This cell mutates `c` in place; a second run would insert log10_h0 (and mc)
# again and misalign c.params with the chain columns.
if 'log10_h0' in c.params:
    raise RuntimeError('log10_h0 is already in c.params; re-run the chain '
                       'concatenation cell before running this cell again.')


def _uniform_bounds(prior_line):
    """Parse [pmin, pmax] from a '<name>:<Type>(pmin=a, pmax=b)' prior line."""
    lo = float(prior_line[prior_line.index('pmin')+5:prior_line.index(', pmax')])
    hi = float(prior_line[prior_line.index('pmax')+5:-1])
    return np.array([lo, hi])


def _runtime_constant(runtime_info, name):
    """Value of the first '<name>:Constant=<value>' entry in runtime_info."""
    entry = next(e for e in runtime_info if name in e)
    return float(entry.replace(f'{name}:Constant=', ''))


# add priors, runtime info (sorted so the choice of chain subdirectory is stable)
prior_path = sorted(glob.glob(chaindir + '/*/priors.txt'))[0]
c.priors = np.loadtxt(prior_path, dtype=str, delimiter='\t')
info_path = sorted(glob.glob(chaindir + '/*/runtime_info.txt'))[0]
c.runtime_info = np.loadtxt(info_path, dtype=str, delimiter='\t')

# derived parameter names go just before the 4 sampler-stat columns
c.params = c.params[:-4] + ['log10_h0'] + c.params[-4:]
if upper_limit:
    # linear chirp mass, used by the upper-limit plots
    c.params = c.params[:-4] + ['mc'] + c.params[-4:]

# log10_fgw: chain samples when sampled, else the fixed run-time value
if vary_fgw:
    log10_fgw = c('log10_fgw')
    log10_fgw_prior = _uniform_bounds([p for p in c.priors if 'log10_fgw' in p][0])
    # 2/3 log10_fgw is the frequency term of log10_h0
    log10_fgw_prior_scaled = 2/3*log10_fgw_prior
else:
    log10_fgw = _runtime_constant(c.runtime_info, 'log10_fgw')
log10_dL = _runtime_constant(c.runtime_info, 'log10_dL')
log10_dL_scaled = log10_dL + log10_mpc_to_s

# log10_h0 = 5/3 log10(Mc) + 2/3 log10(fgw) - log10(dL) + log10(2 pi^(2/3)) (geometric units)
log10_mc = c('log10_mc', to_burn=False)
log10_mc_scaled = log10_mc + log10_msun_to_s
log10_h0 = (5*log10_mc_scaled/3 + 2*log10_fgw/3 - log10_dL_scaled +
            log10_h0_prefactor)
c.chain = np.vstack([c.chain[:, :-4].T, log10_h0, c.chain[:, -4:].T]).T
if upper_limit:
    c.chain = np.vstack([c.chain[:, :-4].T, 10**log10_mc, c.chain[:, -4:].T]).T

# log10_h0 prior string, derived from the log10_mc prior
pline = [p for p in c.priors if 'log10_mc' in p][0]
if upper_limit or detection:
    log10_mc_prior = _uniform_bounds(pline)
    log10_mc_prior_scaled = log10_mc_prior + log10_msun_to_s
    # NOTE(review): LinearExp on log10_mc (uniform in Mc) maps to
    # p(log10_h0) ~ 10**(3/5*log10_h0), not LinearExp(log10_h0), so the LinearExp
    # labels below are only approximate -- confirm before relying on them.
    mc_prior_type = 'LinearExp' if upper_limit else 'Uniform'
    if vary_fgw:
        # log10_h0 = (5/3 log10_mc_scaled - log10_dL_scaled + const) + 2/3 log10_fgw:
        # a sum of two independent terms, so its prior is the convolution of theirs
        log10_mc_prior_scaled = (5/3*log10_mc_prior_scaled - log10_dL_scaled +
                                 log10_h0_prefactor)
        log10_h0_prior_str = (f'log10_h0:Convolve({mc_prior_type}(pmin={log10_mc_prior_scaled[0]}, '
                              f'pmax={log10_mc_prior_scaled[1]}), '
                              f'Uniform(pmin={log10_fgw_prior_scaled[0]}, '
                              f'pmax={log10_fgw_prior_scaled[1]}))')
    else:
        # fixed fgw: log10_h0 is a shift of 5/3 log10_mc, so map the bounds directly
        log10_h0_prior = (5*log10_mc_prior_scaled/3 + 2*log10_fgw/3 -
                          log10_dL_scaled + log10_h0_prefactor)
        log10_h0_prior_str = (f'log10_h0:{mc_prior_type}(pmin={log10_h0_prior[0]}, '
                              f'pmax={log10_h0_prior[1]})')
else:
    # informative Normal prior on log10_mc
    # NOTE(review): with vary_fgw, log10_fgw is an array here and so is log10_h0_mu
    log10_mc_mu = float(pline[pline.index('mu')+3:pline.index(', sigma')])
    log10_mc_sigma = float(pline[pline.index('sigma')+6:-1])
    log10_h0_mu = (5*(log10_mc_mu + log10_msun_to_s)/3 +
                   2*log10_fgw/3 - log10_dL_scaled + log10_h0_prefactor)
    # approximate the h0 width by rescaling the mc width with the ratio of the
    # sample ranges (exactly 5/3 when fgw is fixed)
    h0_col = c.chain[:, c.params.index('log10_h0')]
    mc_col = c.chain[:, c.params.index('log10_mc')]
    r = (np.max(h0_col) - np.min(h0_col))/(np.max(mc_col) - np.min(mc_col))
    log10_h0_sigma = r*log10_mc_sigma
    log10_h0_prior_str = f'log10_h0:Normal(mu={log10_h0_mu}, sigma={log10_h0_sigma})'
c.priors = np.concatenate([c.priors, [log10_h0_prior_str]])
if upper_limit:
    # linear chirp mass prior: uniform between the exponentiated log10_mc bounds
    log10_mc_prior_str = (f'mc:Uniform(pmin={10**log10_mc_prior[0]}, '
                          f'pmax={10**log10_mc_prior[1]})')
    c.priors = np.concatenate([c.priors, [log10_mc_prior_str]])

In [None]:
# persist the concatenated T=1 core (with the derived columns) next to the chains
core_h5_path = f'{chaindir}/core.h5'
c.save(core_h5_path)

### Save higher temp cores

In [None]:
# One concatenated core per higher temperature, mirroring the T=1 core `c`.
# `cT` is always defined (possibly empty) so the plotting cell below runs either way.
cT = {}
if isinstance(cs[0].hot_chains, dict):
    for T in cs[0].hot_chains.keys():
        print(T)
        chain_array = np.concatenate([c_i.hot_chains[T][burn::thin_by] for c_i in cs])
        core_T = co.Core(chain=chain_array, label=source_name,
                         params=c.params, pt_chains=False, burn=0)
        core_T.runtime_info = c.runtime_info
        core_T.priors = c.priors
        # local names, so the T=1 globals (log10_fgw, log10_mc, log10_h0) are
        # not overwritten by the last temperature
        log10_fgw_T = core_T('log10_fgw') if vary_fgw else log10_fgw
        log10_mc_T = core_T('log10_mc', to_burn=False)
        log10_h0_T = (5*(log10_mc_T + np.log10(u.Msun.to(u.kg)*G/c_light**3))/3 +
                      2*log10_fgw_T/3 - log10_dL_scaled + np.log10(2*np.pi**(2/3)))
        # insert the derived columns before the 4 sampler-stat columns, as for `c`
        core_T.chain = np.vstack([core_T.chain[:, :-4].T, log10_h0_T,
                                  core_T.chain[:, -4:].T]).T
        if upper_limit:
            core_T.chain = np.vstack([core_T.chain[:, :-4].T, 10**log10_mc_T,
                                      core_T.chain[:, -4:].T]).T
        core_T.save(f'{chaindir}/core_{T}.h5')
        cT[T] = core_T

In [None]:
# Compare the T=1 core with each higher-temperature core.
crn_pars = ['crn_gamma', 'crn_log10_A'] if vary_crn else []
fgw_pars = ['log10_fgw'] if vary_fgw else []
pars = crn_pars + fgw_pars + ['cos_inc', 'log10_mc', 'phase0',
                              'lnpost', 'lnlike', 'chain_accept', 'pt_chain_accept']
cores_by_temp = [c, *cT.values()]
temp_labels = ['T=1', *(f'T={T}' for T in cT)]
dg.plot_chains(cores_by_temp, pars=pars, hist=True,
               legend_labels=temp_labels, ncols=4, title_y=1.10)

### Load core (optional)

In [None]:
# Reload the saved T=1 core so the analysis below can start here.
# NOTE(review): confirm whether la_forge uses `params` when `corepath` is given;
# the saved core also carries the derived log10_h0 (and mc) columns.
core_h5_path = f'{chaindir}/core.h5'
c = co.Core(corepath=core_h5_path, label=source_name,
            params=model_params+stats, pt_chains=False, burn=0)

In [None]:
#np.savetxt(f'{chaindir}/{source_name}_log10_mc.txt', c('log10_mc'))
#f'{chaindir}/{source_name}_log10_mc.txt'

### Plot full chain

In [None]:
# log-posterior trace of the concatenated core
fig, ax = plt.subplots(figsize=(12, 3))
lnpost_all = c.get_param('lnpost', to_burn=True)
ax.plot(lnpost_all, c='k', alpha=1, lw=0.5)
ax.set(ylabel=r'$\log$post', xlabel='sample')
fig.tight_layout()

In [None]:
# effective sample size diagnostic, saved alongside the other figures
fig = diagnostics.plot_neff(c, return_fig=True)
fig.savefig(save_loc + 'neff.png')

In [None]:
# Gelman-Rubin diagnostic of the concatenated core (M=2 passed to plot_grubin)
fig = diagnostics.plot_grubin(c, M=2, return_fig=True)
fig.savefig(save_loc + 'GR.png')

Plot the parameters flagged by the Gelman–Rubin test, plus a few key parameters, with heavy extra thinning so that at most ~10k samples are shown per trace.

In [None]:
# Traces of the parameters flagged by Gelman-Rubin plus a few key parameters,
# thinned so each trace shows roughly 5k-10k points.
thin_extra = max(1, len(c.chain)//5000)  # never 0: a zero slice step raises ValueError
# presumably grubin(...)[1] lists indices of non-converged params; the cut drops
# the last three sampler-stat columns (lnlike, chain_accept, pt_chain_accept)
flagged_idxs = [idx for idx in diagnostics.grubin(c, M=2)[1] if idx < len(c.params)-3]
key_idxs = [c.params.index(p) for p in ['cos_inc', 'log10_mc', 'phase0', 'psi', 'lnpost']]
# de-duplicate (a key param may also be flagged) while keeping order
idxs = list(dict.fromkeys(flagged_idxs + key_idxs))
fig, ax = plt.subplots(len(idxs), 1, figsize=(10, np.min([len(idxs), 30])),
                       sharex=True, squeeze=False)
ax = ax[:, 0]  # squeeze=False keeps `ax` indexable even for a single row
for i, idx in enumerate(idxs):
    ax[i].plot(c.get_param(c.params[idx], to_burn=True)[::thin_extra],
               c='k', alpha=1, lw=0.2)
    ax[i].set_ylabel(c.params[idx], fontsize='xx-small')
ax[-1].set_xlabel('sample')
fig.tight_layout()
fig.subplots_adjust(hspace=0)

## Additional analyses

In [None]:
# Mass-prior label used by the plotting cells below (note: this replaces the
# numeric `log10_mc_prior` bounds array set by the h0-reconstruction cell).
if detection:
    log10_mc_prior = 'uniform'
elif upper_limit:
    log10_mc_prior = 'linearexp'
else:
    log10_mc_prior = 'normal'


def _runtime_value(line):
    """Value after the '=' in a runtime_info line, rounded to 2 decimals."""
    return np.round(float(line[line.index('=')+1:]), decimals=2)


# Fixed (non-sampled) run settings used in plot titles.
for line in c.runtime_info:
    if not vary_crn:
        if 'crn_log10_A' in line:
            crn_log10_A = _runtime_value(line)
        if 'crn_gamma' in line:
            crn_gamma = _runtime_value(line)
    if not vary_fgw and 'log10_fgw' in line:
        log10_fgw = _runtime_value(line)
    if 'log10_dL' in line:
        log10_dL = _runtime_value(line)
    if 'cos_gwtheta' in line:
        cos_gwtheta = _runtime_value(line)
    if 'gwphi' in line:
        gwphi = _runtime_value(line)

In [None]:
# define all the stuff
#idx_fgw = c.params.index('log10_fgw')
#mask = (c.chain[c.burn:, idx_fgw] < -7.3)*(c.chain[c.burn:, idx_fgw] > -7.32)

pars_select = ['cos_inc','phase0','psi', 'log10_h0', 'log10_mc']
#units = ['rad', 'rad', 'rad', 'calculated', r'$M_\odot$']
titles = [r'$\cos\iota$', '$\Phi_0$', r'$\psi$', r'$\log_{10}h_0$', r'$\log_{10}\mathcal{M}_c$']
labels = [r'$\cos\iota$', '$\Phi_0$', r'$\psi$', r'$\log_{10}h_0$', r'$\log_{10}\mathcal{M}_c$ [$M_\odot$]']
if log10_mc_prior == 'linearexp':
    pars_select = pars_select[:-1] + ['mc']
    titles = titles[:-1] + [r'$\mathcal{M}_c$']
    labels = labels[:-1] + [r'$\mathcal{M}_c$ [$M_\odot$]']
if vary_fgw:
    pars_select = ['cos_inc','phase0','psi', 'log10_fgw', 'log10_h0', 'log10_mc']
    titles = [r'$\cos\iota$', '$\Phi_0$', r'$\psi$',
              r'$\log_{10}f_{\rm{GW}}$', r'$\log_{10}h_0$', r'$\log_{10}\mathcal{M}_c$']
    labels = [r'$\cos\iota$', '$\Phi_0$', r'$\psi$',
              r'$\log_{10}f_{\rm{GW}}$', r'$\log_{10}h_0$', r'$\log_{10}\mathcal{M}_c$ [$M_\odot$]']
if vary_crn:
    pars_select = ['crn_gamma', 'crn_log10_A'] + pars_select
    titles = [r'$\gamma_{\rm{CRN}}$', r'$\log_{10}A_{\rm{CRN}}$'] + titles
    labels = [r'$\gamma_{\rm{CRN}}$', r'$\log_{10}A_{\rm{CRN}}$'] + labels

# plot
idxs = [c.params.index(p) for p in pars_select]
quantiles = [0.16,0.5,0.84]
chain = c.chain[c.burn:, idxs]
#chain = np.array([c.chain[c.burn:,idx][mask] for idx in idxs]).T
fig = corner.corner(chain, labels=labels, quantiles=quantiles,
                    title_quantiles=quantiles, hist_kwargs={'density':True},
                    titles=titles, show_titles=True, levels=(0.68, 0.95),
                    label_kwargs={'fontsize': 20}, title_kwargs={'fontsize': 14})
if log10_mc_prior == 'normal':
    if vary_crn:
        fig.suptitle(f'{source_name} w/ enterprise \n'
                     'informative mass priors, varied CRN \n'
                     r'$\cos\theta =$' + f'{cos_gwtheta} \n'
                     r'$\phi =$' + f'{gwphi} \n'
                     r'$\log_{10}f_{\rm{GW}} =$' + f'{log10_fgw} [Hz] \n'
                     r'$\log_{10}d_{\rm{L}} =$' + f'{np.round(log10_dL, decimals=3)} [Mpc] \n',
                     fontsize=25, x=1, horizontalalignment='right')
    else:
        fig.suptitle(f'{source_name} w/ enterprise \n'
                     'informative mass priors, fied CRN \n'
                     r'$\cos\theta =$' + f'{cos_gwtheta} \n'
                     r'$\phi =$' + f'{gwphi} \n'
                     r'$\log_{10}f_{\rm{GW}} =$' + f'{log10_fgw} [Hz] \n'
                     r'$\log_{10}d_{\rm{L}} =$' + f'{np.round(log10_dL, decimals=3)} [Mpc] \n'
                     r'$\log_{10}A_{\rm{CRN}} =$' + f'{crn_log10_A} \n'
                     r'$\gamma_{\rm{CRN}} =$' + f'{crn_gamma} \n',
                     fontsize=25, x=1, horizontalalignment='right')
elif log10_mc_prior == 'linearexp':
    UL = c.get_param_credint('mc', onesided=True, interval=95)
    if vary_crn:
        fig.suptitle(f'{source_name} w/ enterprise \n'
                     'uniform mass priors, varied CRN \n'
                     r'$\cos\theta =$' + f'{cos_gwtheta} \n'
                     r'$\phi =$' + f'{gwphi} \n'
                     r'$\log_{10}f_{\rm{GW}} =$' + f'{log10_fgw} [Hz] \n'
                     r'$\log_{10}d_{\rm{L}} =$' + f'{np.round(log10_dL, decimals=3)} [Mpc] \n'
                     r'$\mathcal{M}_c$ UL ='+f'{np.round(UL/1e9, decimals=3)}'+r'$\cdot 10^9$ M$_\odot$',
                     fontsize=25, x=1, horizontalalignment='right')
    else:
        fig.suptitle(f'{source_name} w/ enterprise \n'
                     'uniform mass priors \n'
                     r'$\cos\theta =$' + f'{cos_gwtheta} \n'
                     r'$\phi =$' + f'{gwphi} \n'
                     r'$\log_{10}f_{\rm{GW}} =$' + f'{log10_fgw} [Hz] \n'
                     r'$\log_{10}d_{\rm{L}} =$' + f'{np.round(log10_dL, decimals=3)} [Mpc] \n'
                     r'$\log_{10}A_{\rm{CRN}} =$' + f'{crn_log10_A} \n'
                     r'$\gamma_{\rm{CRN}} =$' + f'{crn_gamma} \n'
                     r'$\mathcal{M}_c$ UL ='+f'{np.round(UL/1e9, decimals=3)}'+r'$\cdot 10^9$ M$_\odot$',
                     fontsize=25, x=1, horizontalalignment='right')
else:
    BF = get_bayes_fac(c)
    if vary_crn:
        fig.suptitle(f'{source_name} w/ enterprise \n'
                     f'log-uniform mass priors, varied CRN \n'
                     r'$\cos\theta =$' + f'{cos_gwtheta} \n'
                     r'$\phi =$' + f'{gwphi} \n'
                     r'$\log_{10}f_{\rm{GW}} =$' + f'{log10_fgw} [Hz] \n'
                     r'$\log_{10}d_{\rm{L}} =$' + f'{np.round(log10_dL, decimals=3)} [Mpc] \n'
                     '($\mathcal{B} = $'+f'{np.round(BF, decimals=3)}',
                     fontsize=25, x=1, horizontalalignment='right')
    else:
        if vary_fgw:
            fig.suptitle(f'{source_name} w/ enterprise \n'
                         f'log-uniform mass priors \n'
                         r'$\cos\theta =$' + f'{cos_gwtheta} \n'
                         r'$\phi =$' + f'{gwphi} \n'
                         r'$\log_{10}d_{\rm{L}} =$' + f'{np.round(log10_dL, decimals=3)} [Mpc] \n'
                         r'$\log_{10}A_{\rm{CRN}} =$' + f'{crn_log10_A} \n'
                         r'$\gamma_{\rm{CRN}} =$' + f'{crn_gamma} \n'
                         '(SD BF: $\mathcal{B} = $'+f'{np.round(BF, decimals=2)})',
                         fontsize=25, x=1, horizontalalignment='right')
        else:
            fig.suptitle(f'{source_name} w/ enterprise \n'
                         f'log-uniform mass priors \n'
                         r'$\cos\theta =$' + f'{cos_gwtheta} \n'
                         r'$\phi =$' + f'{gwphi} \n'
                         r'$\log_{10}f_{\rm{GW}} =$' + f'{log10_fgw} [Hz] \n'
                         r'$\log_{10}d_{\rm{L}} =$' + f'{np.round(log10_dL, decimals=3)} [Mpc] \n'
                         r'$\log_{10}A_{\rm{CRN}} =$' + f'{crn_log10_A} \n'
                         r'$\gamma_{\rm{CRN}} =$' + f'{crn_gamma} \n'
                         '(SD BF: $\mathcal{B} = $'+f'{np.round(BF, decimals=2)})',
                         fontsize=25, x=1, horizontalalignment='right')

# Extract the axes
axes = np.array(fig.axes).reshape((len(idxs), len(idxs)))

# Loop over the diagonal
for i, p in enumerate(pars_select):
    ax = axes[i, i]
    x, y = get_prior_distr(c, p)
    ax.plot(x, y, 'C2')
    if not p == 'mc':
        ax.set_xlim([x.min(),x.max()])
    if p == 'log10_mc' and log10_mc_prior == 'linearexp':
        ax.axvline(UL, color='C0')
# Loop over the histograms
for yi, p1 in enumerate(pars_select): # rows
    for xi, p2 in enumerate(pars_select[:yi]): # cols
        ax = axes[yi, xi]
        y, _ = get_prior_distr(c, p1)
        x, _ = get_prior_distr(c, p2)
        ax.set_xlim([x.min(),x.max()])
        if not p1 == 'mc':
            ax.set_ylim([y.min(),y.max()])
fig.savefig(f'{save_loc}/corner.png')

Optional: a masked corner plot that uses only the samples inside a narrow $\log_{10}f_{\rm GW}$ window.

In [None]:
# define all the stuff
idx_fgw = c.params.index('log10_fgw')
mask = (c.chain[c.burn:, idx_fgw] < -7.3)*(c.chain[c.burn:, idx_fgw] > -7.32)

pars_select = ['cos_inc','phase0','psi', 'log10_h0', 'log10_mc']
#units = ['rad', 'rad', 'rad', 'calculated', r'$M_\odot$']
titles = [r'$\cos\iota$', '$\Phi_0$', r'$\psi$', r'$\log_{10}h_0$', r'$\log_{10}\mathcal{M}_c$']
labels = [r'$\cos\iota$', '$\Phi_0$', r'$\psi$', r'$\log_{10}h_0$', r'$\log_{10}\mathcal{M}_c$ [$M_\odot$]']
if log10_mc_prior == 'linearexp':
    pars_select = pars_select[:-1] + ['mc']
    titles = titles[:-1] + [r'$\mathcal{M}_c$']
    labels = labels[:-1] + [r'$\mathcal{M}_c$ [$M_\odot$]']
if vary_fgw:
    pars_select = ['cos_inc','phase0','psi', 'log10_fgw', 'log10_h0', 'log10_mc']
    titles = [r'$\cos\iota$', '$\Phi_0$', r'$\psi$',
              r'$\log_{10}f_{\rm{GW}}$', r'$\log_{10}h_0$', r'$\log_{10}\mathcal{M}_c$']
    labels = [r'$\cos\iota$', '$\Phi_0$', r'$\psi$',
              r'$\log_{10}f_{\rm{GW}}$', r'$\log_{10}h_0$', r'$\log_{10}\mathcal{M}_c$ [$M_\odot$]']
if vary_crn:
    pars_select = ['crn_gamma', 'crn_log10_A'] + pars_select
    titles = [r'$\gamma_{\rm{CRN}}$', r'$\log_{10}A_{\rm{CRN}}$'] + titles
    labels = [r'$\gamma_{\rm{CRN}}$', r'$\log_{10}A_{\rm{CRN}}$'] + labels

# plot
idxs = [c.params.index(p) for p in pars_select]
quantiles = [0.16,0.5,0.84]
#chain = c.chain[c.burn:, idxs]
chain = np.array([c.chain[c.burn:,idx][mask] for idx in idxs]).T
fig = corner.corner(chain, labels=labels, quantiles=quantiles,
                    title_quantiles=quantiles, hist_kwargs={'density':True},
                    titles=titles, show_titles=True, levels=(0.68, 0.95),
                    label_kwargs={'fontsize': 20}, title_kwargs={'fontsize': 14})
BF = get_bayes_fac(c, mask=mask)
fig.suptitle(f'{source_name} w/ enterprise \n'
             f'log-uniform mass priors \n'
             r'$\cos\theta =$' + f'{cos_gwtheta} \n'
             r'$\phi =$' + f'{gwphi} \n'
             r'$\log_{10}f_{\rm{GW}} =$' + f'-7.31 [Hz] \n'
             r'$\log_{10}d_{\rm{L}} =$' + f'{np.round(log10_dL, decimals=3)} [Mpc] \n'
             r'$\log_{10}A_{\rm{CRN}} =$' + f'{crn_log10_A} \n'
             r'$\gamma_{\rm{CRN}} =$' + f'{crn_gamma} \n'
             '(SD BF: $\mathcal{B} = $'+f'{np.round(BF, decimals=2)})',
             fontsize=25, x=1, horizontalalignment='right')

# Extract the axes
axes = np.array(fig.axes).reshape((len(idxs), len(idxs)))

# Loop over the diagonal
for i, p in enumerate(pars_select):
    ax = axes[i, i]
    x, y = get_prior_distr(c, p)
    ax.plot(x, y, 'C2')
    if not p == 'mc':
        ax.set_xlim([x.min(),x.max()])
    if p == 'log10_mc' and log10_mc_prior == 'linearexp':
        ax.axvline(UL, color='C0')
# Loop over the histograms
for yi, p1 in enumerate(pars_select): # rows
    for xi, p2 in enumerate(pars_select[:yi]): # cols
        ax = axes[yi, xi]
        y, _ = get_prior_distr(c, p1)
        x, _ = get_prior_distr(c, p2)
        ax.set_xlim([x.min(),x.max()])
        if not p1 == 'mc':
            ax.set_ylim([y.min(),y.max()])
fig.savefig(f'{save_loc}/corner_masked.png')

What about pulsar specific params? Plot for a few pulsars

In [None]:
# define all the stuff
# Corner plots of the global CW parameters together with one pulsar's parameters.
# (title, axis label) per pulsar-parameter suffix; unknown suffixes fall back to
# the raw parameter name, so any number/order of pulsar params is handled.
psr_par_labels = {
    'cw_p_dist': (r'$L_{\rm{psr}}$', r'$L_{\rm{psr}}$ [kpc]'),
    'cw_p_phase': (r'$\Phi_{\rm{psr}}$', r'$\Phi_{\rm{psr}}$'),
    'red_noise_gamma': (r'$\gamma_{\rm{psr}}$', r'$\gamma_{\rm{psr}}$'),
    'red_noise_log10_A': (r'$\log_{10}A_{\rm{psr}}$', r'$\log_{10}A_{\rm{psr}}$'),
}
global_pars = ['cos_inc', 'phase0', 'psi', 'log10_h0']
global_titles = [r'$\cos\iota$', r'$\Phi_0$', r'$\psi$', r'$\log_{10}h_0$']
plottable_priors = ('Uniform', 'Normal', 'LinearExp', 'PXDist', 'DMDist')

psrnames = ['J2017+0603']
for psrname in psrnames:
    psr_pars = [p for p in c.params if psrname in p]
    psr_titles, psr_labels = [], []
    for p in psr_pars:
        title, label = psr_par_labels.get(p.replace(f'{psrname}_', '', 1), (p, p))
        psr_titles.append(title)
        psr_labels.append(label)
    pars_select = psr_pars + global_pars
    titles = psr_titles + global_titles
    labels = psr_labels + global_titles  # global labels are the same as their titles

    # plot
    idxs = [c.params.index(p) for p in pars_select]
    fig = corner.corner(c.chain[c.burn:, idxs], labels=labels,
                        hist_kwargs={'density': True},
                        titles=titles, show_titles=True, levels=(0.68, 0.95),
                        label_kwargs={'fontsize': 20}, title_kwargs={'fontsize': 14})
    fig.suptitle(f'{source_name} global params + {psrname} params \n', fontsize=25)

    # prior curves, once per parameter; None where the prior type has no density
    # in get_prior_distr (e.g. Convolve(...) for log10_h0 when fgw is varied)
    prior_curves = {}
    for p in pars_select:
        pline = c.priors[c.params.index(p)]
        prior_type = pline[pline.index(':')+1:pline.index('(')]
        prior_curves[p] = (get_prior_distr(c, p) if prior_type in plottable_priors
                           else None)

    axes = np.array(fig.axes).reshape((len(idxs), len(idxs)))
    # diagonal: prior density
    for i, p in enumerate(pars_select):
        if prior_curves[p] is not None:
            x, y = prior_curves[p]
            axes[i, i].plot(x, y, 'C2')
            axes[i, i].set_xlim([x.min(), x.max()])
    # off-diagonal: match axis ranges to the prior ranges
    for yi, p1 in enumerate(pars_select):  # rows
        for xi, p2 in enumerate(pars_select[:yi]):  # cols
            ax = axes[yi, xi]
            if prior_curves[p2] is not None:
                x = prior_curves[p2][0]
                ax.set_xlim([x.min(), x.max()])
            if prior_curves[p1] is not None:
                y = prior_curves[p1][0]
                ax.set_ylim([y.min(), y.max()])
    #fig.savefig(f'{save_loc}/{psrname}_corner.png')

In [None]:
# every pulsar-term phase parameter in the core
cw_phase_pars = [name for name in c.params if 'cw_p_phase' in name]
dg.plot_chains(c, pars=cw_phase_pars, ncols=5,
               save=f"{save_loc}/cw_p_phase_params")

In [None]:
# release all open figures before the next batch of plots
plt.close(fig='all')

Also draw pulsar distances, with priors. See custom la forge function at bottom of this notebook for a version which draws priors.

In [None]:
# `plot_chains(..., draw_prior=True)` is the custom la_forge variant defined at the
# bottom of this notebook. On a fresh top-to-bottom run it does not exist yet, so
# fall back to the stock la_forge version (no prior curves) instead of NameError.
cw_p_pars = [p for p in c.params if 'cw_p_' in p]
if 'plot_chains' in globals():
    plot_chains(c, pars=cw_p_pars, ncols=6,
                save=save_loc+"/cw_p_dist_params", draw_prior=True)
else:
    print('custom plot_chains not defined yet; plotting without priors')
    dg.plot_chains(c, pars=cw_p_pars, ncols=6, save=save_loc+"/cw_p_dist_params")

In [None]:
# trace (no histogram) of one pulsar's distance parameter
psr_dist_pars = ['J1944+0907_cw_p_dist']
dg.plot_chains(c, pars=psr_dist_pars, hist=False)

Hot chain analysis: the higher-temperature chains still need to be included in the new core.

## Make full noise empirical distributions

These may be needed for a subsequent run. The current code can only build empirical distributions for parameters with uniform priors.

In [None]:
# non-stat parameters whose prior line has a 'pmin' bound (uniform-style priors)
full_params = [name for i, name in enumerate(c.params[:-4]) if 'pmin' in c.priors[i]]
full_params

In [None]:
# sanity check: parse pmax from the second prior line
prior_line = c.priors[1]
float(prior_line[prior_line.index('pmax')+5:prior_line.index(')')])

In [None]:
def _uniform_prior_bounds(prior_line):
    """Return ``(pmin, pmax)`` parsed from a ``...(pmin=<float>, pmax=<float>)`` prior string.

    Raises ValueError when the string has no pmin/pmax bounds (e.g. a Normal prior).
    """
    try:
        pmin_start = prior_line.index('pmin') + 5  # skip "pmin="
        pmin_end = prior_line.index(', pmax')
        pmax_start = prior_line.index('pmax') + 5  # skip "pmax="
        # search for ')' only after pmax so an earlier ')' cannot truncate the slice
        pmax_end = prior_line.index(')', pmax_start)
    except ValueError:
        raise ValueError('Could not parse pmin/pmax bounds from prior '
                         '{0!r}'.format(prior_line)) from None
    return float(prior_line[pmin_start:pmin_end]), float(prior_line[pmax_start:pmax_end])


def make_empirical_distributions_from_core(core, paramlist,
                                           burn=0, nbins=81, filename='distr.pkl',
                                           return_distribution=True, save_dists=True):
    """
        Construct empirical distributions from the samples stored in a la_forge core.

        :param core: la_forge core providing ``params``, ``priors``, ``chain``
                     and ``burn``; the prior strings set the histogram range
        :param paramlist: a list of parameter names, either single names or
                          pairs of names (given as a list or tuple of two)
        :param burn: number of additional initial samples to discard, on top
                     of ``core.burn``
        :param nbins: number of bin edges per dimension (``np.linspace`` between
                      the prior bounds)
        :param filename: pickle file written when ``save_dists`` is True
        :param return_distribution: whether to return the list of distributions
        :param save_dists: whether to pickle the list of distributions

        :return distr: list of empirical distributions (only if
                       ``return_distribution`` is True)

        :raises ValueError: if a requested parameter's prior has no pmin/pmax
                            bounds (only uniform-style priors are supported)
        """

    distr = []

    if not save_dists and not return_distribution:
        msg = "no distribution returned or saved, are you sure??"
        logger.info(msg)

    # first chain sample kept for every distribution (1D and 2D alike)
    start = core.burn + burn

    for pl in paramlist:

        if not isinstance(pl, (list, tuple)):
            pl = [pl]

        if len(pl) == 1:
            idx = core.params.index(pl[0])
            prior_min, prior_max = _uniform_prior_bounds(core.priors[idx])

            # get the bins for the histogram
            bins = np.linspace(prior_min, prior_max, nbins)

            distr.append(EmpiricalDistribution1D(pl[0], core.chain[start:, idx], bins))

        elif len(pl) == 2:

            # get the parameter indices
            idx = [core.params.index(name) for name in pl]

            # one set of bins per parameter, each spanning that parameter's prior
            bins = [np.linspace(*_uniform_prior_bounds(core.priors[i]), nbins) for i in idx]

            # samples are passed with shape (2, nsamples), one row per parameter
            distr.append(EmpiricalDistribution2D(list(pl), core.chain[start:, idx].T, bins))

        else:
            msg = 'WARNING: only 1D and 2D empirical distributions are currently allowed.'
            logger.warning(msg)

    # save the list of empirical distributions as a pickle file
    if save_dists:
        if len(distr) > 0:
            with open(filename, 'wb') as f:
                pickle.dump(distr, f)

            msg = 'The empirical distributions have been pickled to {0}.'.format(filename)
            logger.info(msg)
        else:
            msg = 'WARNING: No empirical distributions were made!'
            logger.warning(msg)

    if return_distribution:
        return distr

In [None]:
# Build this source's empirical distributions unless they already exist
# (set overwrite=True to rebuild).
# NOTE(review): `project_path` is only assigned in a later cell; run that first on a fresh kernel.
overwrite = False
emp_dist_file_name = f'{project_path}/empirical_dists/{source_name}_emp_dist.pkl'
if os.path.isfile(emp_dist_file_name) and not overwrite:
    print('empirical distribution already exists!')
else:
    if overwrite and os.path.isfile(emp_dist_file_name):
        print('overwriting empirical distribution')
    make_empirical_distributions_from_core(c, full_params, filename=emp_dist_file_name)

## Mark convergence

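If we are happy with the result, run the following cell. It writes a `converged.txt` file to the chain directory to mark that the run has converged and that we can perform more in-depth analyses. The file records the final number of samples, plus the number of chains, burn-in and thinning when those variables are defined.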

In [None]:
# Mark the run as converged by writing a small summary file to the chain directory.
# `nchains`, `burn` and `thin_by` come from earlier cells that may not have been
# run; if any of them is undefined, only the final sample count is recorded.
try:
    run_info = [f'Num chains = {nchains}\n',
                f'burned from each chain = {burn}\n',
                f'thinned by = {thin_by}\n']
except NameError:  # narrow on purpose: other errors should surface
    run_info = []
run_info.append(f'Final num samples = {len(c.get_param("lnpost"))}\n')
with open(chaindir+"/converged.txt", 'w') as f:
    f.writelines(run_info)

# Do things fast

Loop over all sources and repeat whatever steps need to be done for each one.

### Empirical distributions

In [None]:
# Load the pulsar objects from the NG15 hdf5 files.
# NOTE(review): `project_path` is only assigned in the next cell, and the
# `h5pulsar` import is commented out in the imports cell; both must be in
# place before this cell can run on a fresh kernel.
datastr = f'{project_path}/data/ePSRs/ng15_hdf5/*.hdf5'
# glob returns files in arbitrary, filesystem-dependent order; sort for reproducibility
hdf5_files = sorted(glob.glob(datastr))
if not hdf5_files:
    # fail here rather than building a PTA with zero pulsars later
    raise FileNotFoundError(f'No hdf5 pulsar files match {datastr}')
psrs = [h5pulsar.FilePulsar(hdf5_file) for hdf5_file in hdf5_files]
print('Loaded {0} pulsars from hdf5 files'.format(len(psrs)))

In [None]:
# Batch over sources: load each run's core, rebuild its PTA model and write the
# empirical distributions for the bounded-prior parameters.
# NOTE(review): needs `psrs` from the previous cell and `cw_model_2`, whose
# import is commented out in the imports cell.
project_path = '/vast/palmer/home.grace/bbl29/targeted_cws_ng15'
outdir_path = '/vast/palmer/home.grace/bbl29/project/targeted_cws_ng15'
sources = ['SDSS_J164452.71+4307']
stats = ['lnpost', 'lnlike', 'chain_accept', 'pt_chain_accept']
noisedict_path = f'{project_path}/noise_dicts/15yr_wn_dict.json'
psr_distance_path = f'{project_path}/psr_distances/pulsar_distances_15yr.pkl'
overwrite = True
# CRN params
log10_A = np.log10(6.4e-15)
gamma = 3.2


def _load_model_params(run_dir):
    """Load the parameter-name list saved with a run.

    Looks for ``model_params.json`` in `run_dir`, then in its ``0/`` and ``5/``
    sub-run directories; raises FileNotFoundError if none of them has it.
    """
    candidates = [os.path.join(run_dir, sub, 'model_params.json') for sub in ('', '0', '5')]
    for path in candidates:
        if os.path.isfile(path):
            with open(path, 'r') as fin:
                return json.load(fin)
    raise FileNotFoundError(f'model_params.json not found in any of {candidates}')


for source_name in sources:
    print(source_name)

    chaindir = f'{outdir_path}/data/chains/{source_name}/'
    save_loc = f'{project_path}/reports/figures/{source_name}/'
    model_params = _load_model_params(chaindir)

    c = co.Core(corepath=f'{chaindir}/core.h5', label=source_name,
                params=model_params+stats, pt_chains=False, burn=0)

    # get priors
    fname = f'{project_path}/priors/{source_name}_priors.json'
    with open(fname, 'r') as fin:
        priors = json.load(fin)

    # set up PTA object
    pta = cw_model_2(psrs, priors, noisedict_path=noisedict_path,
                     psr_distance_path=psr_distance_path,
                     log10_A_val=log10_A, gamma_val=gamma)

    # Model parameters (dropping the trailing sampler stats) whose prior has a
    # `pmin` default. Assumes c.params and pta.params share the same order.
    n_model = len(c.params) - len(stats)
    full_params = [c.params[i] for i in range(n_model)
                   if 'pmin' in pta.params[i].prior._defaults]

    emp_dist_file_name = f'{project_path}/empirical_dists/{source_name}_emp_dist.pkl'
    if not os.path.isfile(emp_dist_file_name) or overwrite:
        if overwrite and os.path.isfile(emp_dist_file_name):
            print('overwriting empirical distribution')
        try:
            make_empirical_distributions(pta, full_params, chain=c.chain,
                                         filename=emp_dist_file_name)
        except Exception as err:  # not bare: let KeyboardInterrupt/SystemExit through
            print(f'make_empirical_distributions failed for full_params: {err!r}')
            print('using param names from pta object')
            make_empirical_distributions(pta, pta.param_names, chain=c.chain,
                                         filename=emp_dist_file_name)
    else:
        print('empirical distribution already exists!')

In [None]:
def _resolve_plot_params(cores, pars, exclude):
    """Return the ordered list of parameter names for `plot_chains` to draw.

    `pars` is used as given; otherwise the parameters shared by every core (in
    the first core's order) are used, minus any listed in `exclude`.
    """
    if pars is not None and exclude is not None:
        raise ValueError('Please remove excluded parameters from `pars`.')
    if pars is not None:
        return list(pars)
    params = [p for p in cores[0].params
              if all(p in other.params for other in cores[1:])]
    for p in (exclude or []):
        params.remove(p)
    return params


def _guess_psr_name(par_name):
    """Guess a pulsar-name prefix: 8 chars for B-names, 10 for J-names, else the whole name."""
    if par_name.startswith('B'):
        return par_name[:8]
    if par_name.startswith('J'):
        return par_name[:10]
    return par_name


def plot_chains(core, hist=True, pars=None, exclude=None,
                ncols=3, bins=40, suptitle=None, color='k',
                publication_params=False, titles=None,
                linestyle=None, plot_map=False, truths=None,
                save=False, show=True, linewidth=1,
                log=False, title_y=1.01, hist_kwargs=None,
                plot_kwargs=None, legend_labels=None, real_tm_pars=True,
                legend_loc=None, draw_prior=False, prior_fmt='C2', **kwargs):
    """Plot histograms or traces of chains from one or more cores.

    Modified copy of ``la_forge.diagnostics.plot_chains`` that can also overlay
    each parameter's prior (from ``get_prior_distr``) on the histograms.

    Parameters
    ----------
    core : la_forge core, or list of cores
        When a list, the parameters common to all cores are plotted; fancy
        names and priors are taken from the first core.

    hist : bool, optional
        Plot histograms if True, otherwise traces of the (burned) chains.

    pars : list of str, optional
        Parameters to plot. Cannot be combined with `exclude`.

    exclude : list of str, optional
        Parameters to drop from the default (all common) parameter list.

    ncols : int, optional
        Number of columns of subplots.

    bins : int, optional
        Number of histogram bins.

    suptitle : str, optional
        Figure title; if None one is guessed from the pulsar-name prefix of
        the first parameter.

    color : str, optional
        Unused; kept for call compatibility with la_forge.

    publication_params : bool, optional
        If True use matplotlib's default figure size instead of [15, 4*nrows].

    titles : list of str, optional
        Per-subplot titles, in the same order as the plotted parameters.

    linestyle : list of str, optional
        Per-core histogram line styles (list of cores only); defaults to '-'.

    plot_map : bool or list of bool, optional
        Draw a dashed line at the MAP value (one flag per core for a list).

    truths : dict, optional
        Parameter name -> true value, drawn as a black dash-dot line
        (histograms only).

    save : str or False, optional
        Path passed to ``savefig``; nothing is saved if falsy.

    show : bool, optional
        Whether to call ``plt.show()``.

    linewidth : float, optional
        Line width of histograms and traces.

    log : bool, optional
        Passed to ``hist`` for log-scaled counts.

    title_y : float, optional
        y position of the figure title.

    hist_kwargs, plot_kwargs : dict, optional
        Extra keyword arguments for ``hist`` / trace ``plot`` calls.

    legend_labels : list of str, optional
        One label per core; legend patches are coloured C0, C1, ...

    real_tm_pars : bool, optional
        Forwarded to ``dg._get_gpar_kwargs``.

    legend_loc : str, optional
        Location of the figure legend.

    draw_prior : bool, optional
        Overlay the prior pdf on each histogram (ignored for traces).

    prior_fmt : str, optional
        Matplotlib format string for the prior curve.

    **kwargs
        ``xticks`` (tick positions for every subplot) and ``xlabel``
        (a figure-level x label).

    Raises
    ------
    ValueError
        If both `pars` and `exclude` are given, or there is nothing to plot.
    """
    hist_kwargs = {} if hist_kwargs is None else hist_kwargs
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs

    multi_core = isinstance(core, list)
    cores = core if multi_core else [core]
    ref_core = cores[0]

    params = _resolve_plot_params(cores, pars, exclude)
    if not params:
        raise ValueError('No parameters to plot.')

    fancy_par_names = ref_core.fancy_par_names
    if multi_core:
        if linestyle is None:
            linestyle = ['-' for _ in cores]
        if not isinstance(plot_map, list):
            plot_map = [plot_map for _ in cores]
    else:
        plot_map = [plot_map]

    psr_name = _guess_psr_name(params[0]) if suptitle is None else None

    nrows = (len(params) + ncols - 1) // ncols  # ceil division
    if publication_params:
        fig = plt.figure()
    else:
        fig = plt.figure(figsize=[15, 4*nrows])

    xticks = kwargs.get('xticks')
    for ii, p in enumerate(params):
        axis = fig.add_subplot(nrows, ncols, ii + 1)
        if hist:
            if truths is not None:
                if p in truths:
                    axis.axvline(truths[p], linewidth=2,
                                 color='k', linestyle='-.')
                else:
                    print(p + ' was not found in truths dict.')

            for jj, this_core in enumerate(cores):
                gpar_kwargs = dg._get_gpar_kwargs(this_core, real_tm_pars)
                # per-core line styles only apply when comparing several cores
                ls_kwargs = {'linestyle': linestyle[jj]} if multi_core else {}
                phist = axis.hist(this_core.get_param(p, **gpar_kwargs),
                                  bins=bins, density=True, log=log,
                                  linewidth=linewidth, histtype='step',
                                  **ls_kwargs, **hist_kwargs)
                if plot_map[jj]:
                    # colour the MAP line like this core's histogram
                    pcol = phist[-1][-1].get_edgecolor()
                    axis.axvline(this_core.get_map_param(p), linewidth=1,
                                 color=pcol, linestyle='--')

            if draw_prior:
                x, y = get_prior_distr(ref_core, p)
                axis.plot(x, y, prior_fmt)
        else:
            for this_core in cores:
                gpar_kwargs = dg._get_gpar_kwargs(this_core, real_tm_pars)
                axis.plot(this_core.get_param(p, to_burn=True, **gpar_kwargs),
                          lw=linewidth, **plot_kwargs)

        if titles is not None:
            axis.set_title(titles[ii])
        elif fancy_par_names is not None:
            # NOTE: assumes fancy_par_names is aligned with ref_core.params, so
            # look up p's position there (params may be a subset of the core).
            axis.set_title(fancy_par_names[ref_core.params.index(p)])
        elif psr_name is not None:
            axis.set_title(p.replace(psr_name+'_', ''))
        else:
            axis.set_title(p)

        axis.set_yticks([])
        if xticks is not None:
            axis.set_xticks(xticks)

    if suptitle is None:
        n_psr_pars = sum(psr_name in p for p in params)
        if n_psr_pars / len(params) > 0.5:
            suptitle = 'PSR {0} Noise Parameters'.format(psr_name)
        else:
            suptitle = 'Parameter Posteriors    '

    if legend_labels is not None:
        import matplotlib.patches as mpatches  # not imported at notebook level
        patches = [mpatches.Patch(color='C{0}'.format(ii), label=lab)
                   for ii, lab in enumerate(legend_labels)]
        fig.legend(handles=patches, loc=legend_loc)

    fig.tight_layout(pad=0.4)
    fig.suptitle(suptitle, y=title_y, fontsize=18)
    xlabel = kwargs.get('xlabel')
    if xlabel is not None:
        fig.text(0.5, -0.02, xlabel, ha='center', usetex=False)

    if save:
        fig.savefig(save, dpi=150, bbox_inches='tight')
    if show:
        plt.show()

    plt.close(fig)