In [None]:
# Reload imported modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
%matplotlib inline
# NOTE(review): high-DPI inline figures are normally requested through
# `InlineBackend.figure_format`; confirm this option has any effect.
%config IPython.matplotlib.backend = "retina"
from matplotlib import rcParams
# Use 300 dpi both for on-screen figures and for saved files.
rcParams["savefig.dpi"] = 300
rcParams["figure.dpi"] = 300

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import json, os, pickle, glob
import tqdm
import corner
import scipy.linalg as sl
import logging
from IPython.display import display, Math
logger = logging.getLogger(__name__)

In [None]:
import la_forge.diagnostics as dg
import la_forge.core as co
from la_forge.rednoise import gorilla_bf
from la_forge.utils import epoch_ave_resid
from h5pulsar import FilePulsar
from enterprise.signals.utils import ConditionalGP
#from la_forge.gp import Signal_Reconstruction
import enterprise.constants as const
from enterprise_extensions.empirical_distr import EmpiricalDistribution1D, EmpiricalDistribution2D
#from collections import OrderedDict
#DM_K = float(2.41e-4)

In [None]:
from dr3_noise import post_processing_utils as ppu
from dr3_noise import plot_utils as pu
from dr3_noise.models import model_singlepsr_noise

In [None]:
def plot_post_and_set_burn(cs, nchains, param='lnpost', fig_savedir=None,
                           min_burn=500, max_burn=10000, adapt_burnin=False,
                           single_core_burns=()):
    """Set a burn-in on every chain, drop chains that are too short, and plot traces.

    For each core the burn-in is ``min(max(min_burn, len(chain)//4), max_burn)``
    unless it is overridden through ``single_core_burns``.  The top panel
    overlays the post-burn trace of ``param`` for every chain; the bottom panel
    lays the retained chains end to end.

    Parameters
    ----------
    cs : list
        Cores, one per chain.  Modified in place: a core with fewer than 100
        post-burn samples is removed from the list.
    nchains : int
        Number of cores in ``cs`` to process.
    param : str
        Name of the parameter to trace.
    fig_savedir : str or None
        If truthy, the figure is saved as ``{fig_savedir}/{param}_trace.png``.
    min_burn, max_burn : int
        Lower and upper limits on the automatically chosen burn-in.
    adapt_burnin : bool
        If True, the minimum burn-in of a chain is raised to 1000 samples past
        the last sample whose ``lnpost`` lies more than 200 below the maximum
        of ``core('lnpost')``.  The adjustment is made per chain and does not
        carry over to the following chains.
    single_core_burns : sequence of (index, burn)
        Explicit burn-ins; ``index`` is the position in ``cs`` at the moment
        the chain is processed.

    Returns
    -------
    int
        Number of retained chains (0 if none survive).
    """
    total_samples = 0
    total_postburn_samples = 0
    fig, ax = plt.subplots(2, 1, figsize=(12, 6), sharex=False)
    i = 0  # index into cs; only advances when the current chain is kept
    for _ in range(nchains):
        core = cs[i]
        chain_min_burn = min_burn
        if adapt_burnin:
            idx_bad = np.where(core('lnpost', to_burn=False)
                               < np.max(core('lnpost')) - 200)[0]
            # No sample may fall below the threshold; then keep min_burn as is.
            if idx_bad.size > 0:
                chain_min_burn = np.max([min_burn, idx_bad[-1] + 1000])
        burn = np.min([np.max([chain_min_burn, len(core.chain)//4]), max_burn])
        for single_core_burn in single_core_burns:
            if single_core_burn[0] == i:
                burn = single_core_burn[1]
        core.set_burn(burn)
        y = core.get_param(param, to_burn=True)
        n_samples = len(y)
        x = np.arange(total_postburn_samples, total_postburn_samples + n_samples)
        ax[0].plot(y, alpha=0.2, lw=0.5, c='k')
        if n_samples < 100:
            print(f'no samples for cs[{i}]')
            cs.pop(i)
            nchains -= 1
        else:
            total_postburn_samples += n_samples
            total_samples += len(core.get_param(param, to_burn=False))
            ax[1].plot(x, y, lw=0.5)
            i += 1
    if len(cs) == 0:
        return 0
    ax[0].set_ylabel(param)
    ax[0].set_title(cs[0].label)
    ax[1].set_ylabel(param)
    ax[-1].set_xlabel('sample')
    fig.tight_layout()
    if fig_savedir:
        fig.savefig(f'{fig_savedir}/{param}_trace.png')

    print(f'burns: {[c.burn for c in cs]}')
    print(total_samples,'total samples')
    print(total_postburn_samples,'samples after burn in')
    print(np.round((total_samples-total_postburn_samples)*100/total_samples,3),'% of samples burned')
    return nchains

In [None]:
def plot_nmodel_traces(cs, fig_savedir, model_labels=('CRN', 'HD')):
    """Plot the trace of the hypermodel index ``nmodel`` for every chain.

    The top panel overlays the full trace of each chain and marks each chain's
    burn-in with a dashed vertical line; the bottom panel lays the post-burn
    traces end to end.

    Parameters
    ----------
    cs : list
        Cores whose ``nmodel`` parameter is plotted; ``cs[0]`` supplies the
        title (``label``) and the number of models (``nmodels``).
    fig_savedir : str or None
        If truthy, the figure is saved as ``{fig_savedir}/nmodel_trace.png``.
    model_labels : sequence of str
        Tick labels, one per model index.
    """
    total_postburn_samples = 0
    fig, ax = plt.subplots(2, 1, figsize=(12, 6), sharex=False)
    for core in cs:
        full_trace = core.get_param('nmodel', to_burn=False)
        postburn_trace = core.get_param('nmodel', to_burn=True)
        n_samples = len(postburn_trace)
        x = np.arange(total_postburn_samples, total_postburn_samples + n_samples)
        ax[0].plot(full_trace, alpha=0.2, lw=0.5, c='k')
        ax[0].axvline(core.burn, color='k', lw=1, ls='dashed', alpha=0.3)
        ax[1].plot(x, postburn_trace, lw=0.5)
        total_postburn_samples += n_samples
    nmodels = cs[0].nmodels
    for panel in ax:
        panel.set_yticks(np.arange(nmodels))
        panel.set_yticklabels(list(model_labels))
        # dashed separators half-way between neighbouring model indices
        for j in range(nmodels + 1):
            panel.axhline(j - 0.5, lw=0.5, linestyle='dashed', color='k')
        panel.set_ylim([-0.5, nmodels - 0.5])
        panel.set_ylabel('model')
    ax[0].set_title(cs[0].label)
    ax[-1].set_xlabel('sample')
    fig.tight_layout()
    if fig_savedir:
        fig.savefig(f'{fig_savedir}/nmodel_trace.png')

In [None]:
def plot_proposals(core,ylim=None,ax=None,return_fig=False):
    """Plot the acceptance-rate history of every jump proposal stored in a core.

    Parameters
    ----------
    core : object
        Must expose ``jumps``, a mapping from proposal name to a 1-D array of
        acceptance rates.  The entry named ``'jumps'`` is skipped.
    ylim : dict, sequence or None
        y-limits applied when this function creates the axes.  A dict is
        forwarded as keyword arguments to ``ax.set_ylim``; anything else is
        passed positionally (e.g. ``(0, 1)``).
    ax : matplotlib axes or None
        Axes to draw on.  If None a new figure is created and decorated with
        grid, legend, labels and title; if given, only the curves are added.
    return_fig : bool
        Only used when ``ax`` is None: return ``(fig, ax)`` instead of ``ax``.

    Returns
    -------
    ``ax`` or ``(fig, ax)`` when the axes were created here, otherwise None.
    """
    # Identity test: `ax == None` misbehaves for objects with a custom __eq__.
    created_axes = ax is None
    if created_axes:
        fig, ax = plt.subplots(figsize=[8,7])

    for ii, ky in enumerate(core.jumps.keys()):
        if ky == 'jumps':
            continue
        ls = '--' if ii >= 9 else '-'
        if ky[0] == 'c':
            lab = 'SCAM' if 'SCAM' in ky else 'AM'
        elif ky == 'DEJump_jump':
            lab = 'DEJump'
        else:
            # drop the first two and the last '_'-separated tokens of the key
            lab = ' '.join(ky.split('_')[2:-1])
            if 'gwb' in lab:
                lab = 'gwb log-uniform'
        style = dict(label=lab, ls=ls, lw=2, color=f'C{ii}')
        if lab == 'DEJump':
            # The DE history is shorter than the AM history; right-align it so
            # that both end on the same write-out iteration.
            deL = core.jumps[ky].size
            jL = core.jumps['covarianceJumpProposalAM_jump'].size
            nums = np.linspace(jL-deL, jL-1, deL)
            ax.plot(nums, core.jumps[ky], **style)
        else:
            ax.plot(core.jumps[ky], **style)

    if not created_axes:
        return None

    ax.grid()
    if ylim is not None:
        if isinstance(ylim, dict):
            ax.set_ylim(**ylim)
        else:
            ax.set_ylim(ylim)
    ax.legend(loc='upper left',ncol=2,fontsize=11)
    ax.set_ylabel('Acceptance Rate',fontsize=14)
    ax.set_xlabel('Write Out Iteration',fontsize=14)
    ax.set_title('Jump Proposal Acceptance Rates')
    if return_fig:
        return fig, ax
    return ax

# Load chains

In [None]:
# ------------
# run configuration: dataset, number of frequencies and noise model
dataset = 'lite_unfiltered_53'
Nfreqs = 13
noisetype = 'advnoise'
chain_name = f'HM_CRN{Nfreqs}_g4p3_{noisetype}'
model_name = f'HM_GWB{Nfreqs}_g4p3_{noisetype}'
# ------------
# directory layout: chains live on scratch, cores and figures under the project
project_path = '/vast/palmer/home.grace/bbl29/IPTA_DR2_analysis'
_chain_root = '/vast/palmer/scratch/mingarelli/bbl29/IPTA_DR2_analysis/chains'
_core_root = '/vast/palmer/home.grace/bbl29/project/IPTA_DR2_analysis'
# 'edr2' has its own directory; every other dataset is stored as 'dr2<dataset>'
_dataset_dir = 'edr2' if dataset == 'edr2' else f'dr2{dataset}'
chaindir = f'{_chain_root}/{_dataset_dir}/{chain_name}'
coresave = f'{_core_root}/{_dataset_dir}/{model_name}'
figsave = f'{project_path}/figs/{_dataset_dir}/{model_name}'
ePSR_dir = f'{project_path}/data/{dataset}_ePSRs'
load_pt_chains = False

In [None]:
# Fail early if the chains are missing; create the output directories
# (including missing parents) if they do not exist yet.
if not os.path.isdir(chaindir):
    raise FileNotFoundError(f'chaindir does not exist!! ({chaindir})')
for out_dir in (coresave, figsave):
    os.makedirs(out_dir, exist_ok=True)

# every sub-directory below chaindir is treated as one chain directory
chaindirs = [x[0]+'/' for x in os.walk(chaindir)][1:]
nchains = len(chaindirs)

cs = []
core = None  # defined up front so the cleanup below works even if nothing loads
for cd in tqdm.tqdm(chaindirs):
    try:
        with open(os.path.join(cd, 'model_params.json'), 'r') as f:
            hm_params = json.load(f)
        with open(os.path.join(cd, 'model_labels.json'), 'r') as f:
            hm_labels = json.load(f)
        core = co.HyperModelCore(label=f'HM GWB{Nfreqs} g4p3 {noisetype}', param_dict=hm_params,
                                 chaindir=cd, pt_chains=load_pt_chains)
        # keep only chains with more than 200 samples
        if len(core.chain) > 200:
            cs.append(core)
    except Exception as e:
        # best effort: report the directory that failed and carry on
        print(f'could not load from {cd}')
        print(f'Exception: {e}')
del core
nchains = len(cs)
if nchains == 0:
    print('Not enough samples!!')

In [None]:
# Print each chain's shape and, where the parameter list does not already name
# 'lnpost', append the names of the four bookkeeping columns to it.
for c in cs:
    print(c.chain.shape)
    if 'lnpost' in c.params:
        continue
    c.params += ['lnlike', 'lnpost', 'chain_accept', 'pt_chain_accept']

In [None]:
# Manually remove the chain at index 2 and decrement the chain counter to match.
# NOTE(review): the index is hard-coded and the cell is not idempotent —
# re-running it removes another chain; confirm which chain is meant.
c_bad = cs.pop(2)
nchains -= 1

In [None]:
# Pick burn-ins adaptively, drop chains that are too short and plot the traces.
# Individual chains can be overridden with e.g.
#     single_core_burns=[[0,20000],[3,25000]]
nchains = plot_post_and_set_burn(cs, nchains, fig_savedir=figsave, min_burn=500,
                                 max_burn=20000, adapt_burnin=True)
if nchains != 0:
    plot_nmodel_traces(cs, fig_savedir=figsave)
else:
    print('Not enough samples!!')

In [None]:
# plot proposals
fig, ax = plot_proposals(cs[0],return_fig=True)
for i in range(nchains):
    plot_proposals(cs[i], ax=ax)
fig.savefig(f'{figsave}/proposals.png')
pars = cs[0].params[:15] + [p for p in cs[0].params if 'dip' in p] + cs[0].params[-7:]
dg.plot_chains(cs, pars=pars, hist=True, ncols=5)

In [None]:
# make core with all cores together
chains = []
hot_chains = {}
for ci in cs:
    chains.append(ci.chain[ci.burn::])
chain_array = np.concatenate(chains)
c = co.HyperModelCore(chain=chain_array, label=cs[0].label, param_dict=hm_params,
                      params=cs[0].params, pt_chains=False, burn=0)
# add priors
prior_path = glob.glob(chaindir + '/*/priors.txt')[0]
c.priors = np.loadtxt(prior_path, dtype=str, delimiter='\t')
# add runtime info
info_path = glob.glob(chaindir + '/*/runtime_info.txt')[0]
c.runtime_info = np.loadtxt(info_path, dtype=str, delimiter='\t')
# add hot chains
if load_pt_chains:
    for T in cs[0].hot_chains:
        hot_chains[T] = []
        for ci in cs:
            hot_chains[T].append(ci.hot_chains[T][ci.burn::])
        hot_chains[T] = np.concatenate(hot_chains[T])
    c.hot_chains = hot_chains
c.save(f'{coresave}/core.h5')

In [None]:
# Load the combined core from {coresave}/core.h5 with no burn-in applied.
c = co.HyperModelCore(corepath=f'{coresave}/core.h5', burn=0)

In [None]:
from enterprise_extensions.model_utils import odds_ratio

In [None]:
# Odds ratio (and its uncertainty) computed from the sampled model index; the
# rendered label presents it as the HD-over-CRN Bayes factor.
BF, BF_err = odds_ratio(c('nmodel'))
display(Math(fr'\mathcal{{B}}^{{\rm{{HD}}}}_{{\rm{{CRN}}}} = {BF:0.3f} \pm {BF_err:0.3f}'))

In [None]:
# Compare the two sub-model cores (indices 0 and 1) on the first 20 parameters
# plus the log-likelihood.
# NOTE(review): 'CARN' may be a typo for 'CRN', the label used elsewhere in
# this notebook — confirm.
dg.plot_chains([c.model_core(i) for i in range(2)],legend_labels=['CARN','HD'], ncols=5,
               pars=c.params[:20] + ['lnlike'])

In [None]:
# la_forge diagnostics on the combined core (plot_grubin / plot_neff).
dg.plot_grubin(c)
dg.plot_neff(c)
# Optional project-level convergence check, currently disabled:
#if ppu.check_convergence(c, psrname=model_name, plot=True, Nsample_threshold=50000):
#    print(f'Convergence passed!!')

In [None]:
# Amplitude posterior under each sub-model: (model index, parameter, legend label)
amplitude_posteriors = [
    (0, 'crn_log10_A', 'Auto correlations only'),
    (1, 'gwb_log10_A', 'Auto+cross correlations'),
]
fig, ax = plt.subplots(figsize=(3,2), dpi=500)
for model_idx, amp_param, legend_label in amplitude_posteriors:
    ax.hist(c.model_core(model_idx)(amp_param), histtype='step', density=True,
            bins=30, label=legend_label)
ax.set_xlabel(r'$\log_{10}A_{\rm{CRN}}$')
ax.set_ylabel('PDF')
ax.legend(fontsize='xx-small')
fig.tight_layout()
fig.savefig(f'{figsave}/HM_gwb_log10_A.png', dpi=500, bbox_inches='tight')

### Pulsar plots

In [None]:
def make_labels(params, psrname, add_psrname=True):
    """Translate parameter names into LaTeX axis labels.

    Parameters
    ----------
    params : iterable of str
        Parameter names.
    psrname : str
        Pulsar name, placed on its own line above each per-pulsar label when
        ``add_psrname`` is True.
    add_psrname : bool
        Whether to prepend ``psrname`` and a newline to per-pulsar labels.

    Returns
    -------
    list of str
        One label per parameter.  The common-process amplitude
        (``'crn_log10_A'``) never receives the pulsar prefix; names that match
        no known pattern are returned unchanged.
    """
    prefix = psrname+'\n' if add_psrname else ''
    # (substring of the parameter name, LaTeX label), checked in this order
    per_pulsar_labels = (
        ('dm_gp_gamma', r'$\gamma_{\rm{DM}}$'),
        ('dm_gp_log10_A', r'$\log_{10}A_{\rm{DM}}$'),
        ('red_noise_gamma', r'$\gamma_{\rm{RN}}$'),
        ('red_noise_log10_A', r'$\log_{10}A_{\rm{RN}}$'),
        ('exp1_log10_Amp', r'$\log_{10}A_{\rm{exp}}$'),
        ('exp1_log10_tau', r'$\log_{10}\tau_{\rm{exp}}$'),
        ('exp1_t0', r'$t_{\rm{exp}}$'),
    )

    def _label(name):
        if name == 'crn_log10_A':
            return r'$\log_{10}A_{\rm{CRN}}$'
        for pattern, latex in per_pulsar_labels:
            if pattern in name:
                return prefix + latex
        return name

    return [_label(p) for p in params]

In [None]:
# Pulsar names are the part before the first '_' of every parameter containing 'J'.
psrnames = np.unique([p.split('_')[0] for p in c.params if 'J' in p])
for psrname in psrnames:
    # single-pulsar core for the same pulsar (path is hard-coded to dr2full/advnoise)
    corepath = f'/vast/palmer/home.grace/bbl29/project/IPTA_DR2_analysis/dr2full/advnoise/{psrname}/core.h5'
    c_psr = co.Core(corepath=corepath, burn=0)
    # first make a corner plot with the full run on it, minus the Common red noise
    params = [p for p in c.params if psrname in p]# + ['crn_log10_A']
    p_idxs = [c.params.index(p) for p in params]
    labels = make_labels(params, psrname, add_psrname=False)
    fig = corner.corner(c.chain[:,p_idxs], levels=(0.68,0.95), labels=labels,
                        color='k', plot_density=False, plot_datapoints=False,
                        no_fill_contours=True, hist_kwargs={'density':True},
                        label_kwargs={'fontsize':20})
    # next add single pulsar run (raises ValueError if that core lacks one of `params`)
    p_psr_idxs = [c_psr.params.index(p) for p in params]
    fig = corner.corner(c_psr.chain[:,p_psr_idxs], levels=(0.68,0.95), fig=fig,
                        color='C0', plot_density=False, plot_datapoints=False,
                        no_fill_contours=True, hist_kwargs={'density':True})
    # finally add CRN onto intrinsic red noise plot
    # NOTE(review): this assumes the last two entries of `params` are the red-noise
    # gamma and log10_A, so that the last panels are the red-noise ones — confirm.
    ndim = len(params)
    axes = np.reshape(fig.axes, (ndim,ndim))
    idx_crn = c.params.index('crn_log10_A')
    axes[-1,-1].hist(c.chain[:,idx_crn], color='C1',
                     density=True, bins=20, histtype='step')
    # add 13/3 line
    axes[-1,-2].axvline(13/3, color='C1', ls='dashed')
    axes[-2,-2].axvline(13/3, color='C1', ls='dashed')
    # legend built from proxy lines, one per overlaid data set
    lines = [mlines.Line2D([],[],color='C0',label='Single run')]
    lines += [mlines.Line2D([],[],color='k',label='PTA run')]
    lines += [mlines.Line2D([],[],color='C1',label='CRN')]
    fig.legend(handles=lines, fontsize=20)
    fig.suptitle(f'PSR {psrname}', fontsize=30)
    fig.savefig(f'{figsave}/{psrname}.png', dpi=300, bbox_inches='tight')

In [None]:
# Finally, let's identify pulsars with significant RN whose RN is no longer significant
fig, ax = plt.subplots(figsize=(9,7))
i = 0
lines = []
for psrname in psrnames:
    corepath = f'/vast/palmer/home.grace/bbl29/project/IPTA_DR2_analysis/dr2full/stdnoise/{psrname}/core.h5'
    c_psr = co.Core(corepath=corepath, burn=0)
    x = c_psr(f'{psrname}_red_noise_gamma')
    y = c_psr(f'{psrname}_red_noise_log10_A')
    BF = gorilla_bf(y, min=-20, max=-11)
    BF_new = gorilla_bf(c(f'{psrname}_red_noise_log10_A'), min=-20, max=-11)
    if (np.isnan(BF) or BF > 10) and (BF_new < 10):
        print(f'adding {psrname}')
        corner.hist2d(x, y, fig=fig, levels=(0.68,), color=f'C{i}',
                      plot_density=False, no_fill_contours=True, plot_datapoints=False,
                      contour_kwargs={'lw':1,'alpha':1})
        lines.append(mlines.Line2D([],[],color=f'C{i}',label=psrname))
        i += 1
ymed = c.get_param_median('crn_log10_A')
ylo = c.get_param_credint('crn_log10_A', interval=95)[0]
yhi = c.get_param_credint('crn_log10_A', interval=95)[1]
ax.errorbar(x=[13/3], y=[ymed], yerr=[[ymed-ylo],[yhi-ymed]], fmt='sk', ms=5)
ax.set_xlabel(r'$\gamma_{\rm{RN}}$')
ax.set_ylabel(r'$\log_{10}A_{\rm{RN}}$')
ax.set_xlim([0,7])
ax.set_ylim([-17,-11])
ax.set_title('Red noise posteriors (pulsars contributing to CP)')
ax.grid(lw=0.3)
ax.legend(handles=lines, fontsize='small')
plt.gca().set_aspect('equal')

# Make empirical distributions

Need list of the following:
- 2D dist for RN for each pulsar
- 2D dist for DM for each pulsar
- 1D dists for dip params for J1713
- 1D dist for the common red noise amplitude (`crn_log10_A`)

In [None]:
# Display the current core object.
c

In [None]:
def _parse_prior_bounds(prior):
    """Return ``(pmin, pmax)`` read from a prior description string.

    The string must contain ``pmin=<number>`` terminated by the first ``,`` and
    ``pmax=<number>`` terminated by the first ``)``.
    """
    pmin = float(prior[prior.index('pmin')+5:prior.index(',')])
    pmax = float(prior[prior.index('pmax')+5:prior.index(')')])
    return pmin, pmax


def make_empirical_distributions(c, psrnames, burn=0, nbins=40,
                                 filename=None, return_distribution=False):
    """Build empirical distributions from the posterior samples of one core.

    For every pulsar a 2-D distribution is made for the red-noise
    (gamma, log10_A) pair and for the DM-GP (gamma, log10_A) pair; for
    J1713+0747 three additional 1-D distributions cover the ``exp1`` dip
    parameters.  A 1-D distribution for ``crn_log10_A`` is always appended.
    Histogram edges span the prior range of each parameter.

    Parameters
    ----------
    c : core
        Must provide ``params``, ``chain``, ``burn`` and ``priors``; ``priors``
        is indexed with the same index as ``params``.
    psrnames : iterable of str
        Pulsar names used to build the parameter names.
    burn : int
        Currently unused; samples are taken from ``c.burn`` onward.
    nbins : int
        Number of bin edges per dimension.
    filename : str or None
        If given, the list of distributions is pickled to this path.
    return_distribution : bool
        If True, return the list of distributions.

    Returns
    -------
    list or None
    """
    # parameter groups: one inner list per distribution
    paramlist = []
    for psrname in psrnames:
        paramlist.append([f'{psrname}_red_noise_gamma',f'{psrname}_red_noise_log10_A'])
        paramlist.append([f'{psrname}_dm_gp_gamma',f'{psrname}_dm_gp_log10_A'])
        if psrname == 'J1713+0747':
            paramlist.append([f'{psrname}_exp1_log10_Amp'])
            paramlist.append([f'{psrname}_exp1_log10_tau'])
            paramlist.append([f'{psrname}_exp1_t0'])
    paramlist.append(['crn_log10_A'])

    distr = []
    for pl in paramlist:
        if not isinstance(pl, list):
            pl = [pl]

        if len(pl) == 1:
            idx = c.params.index(pl[0])
            prior_min, prior_max = _parse_prior_bounds(c.priors[idx])
            bins = np.linspace(prior_min, prior_max, nbins)
            distr.append(EmpiricalDistribution1D(pl[0], c.chain[c.burn:, idx], bins))

        elif len(pl) == 2:
            idx = [c.params.index(name) for name in pl]
            bins = [np.linspace(*_parse_prior_bounds(c.priors[k]), nbins) for k in idx]
            distr.append(EmpiricalDistribution2D(pl, c.chain[c.burn:, idx].T, bins))

        else:
            msg = 'WARNING: only 1D and 2D empirical distributions are currently allowed.'
            logger.warning(msg)

    # save the list of empirical distributions as a pickle file
    if filename is not None:
        if len(distr) > 0:
            with open(filename, 'wb') as f:
                pickle.dump(distr, f)

            msg = 'The empirical distributions have been pickled to {0}.'.format(filename)
            logger.info(msg)
        else:
            msg = 'WARNING: No empirical distributions were made!'
            logger.warning(msg)

    if return_distribution:
        return distr

In [None]:
# Build the empirical distributions from core `c` and pickle them.
# NOTE(review): the file name says "dr2full ... stdnoise" whereas the run
# configured at the top of the notebook is 'lite_unfiltered_53' / 'advnoise' —
# confirm the intended output name.
filename = f'{project_path}/empdists/dr2full_crn_stdnoise.pkl'
distr = make_empirical_distributions(c, psrnames, burn=0, nbins=40,
                                     filename=filename, return_distribution=True)

In [None]:
# Display the list of empirical distributions.
distr

In [None]:
# remake the whitened residuals plots so they are all in one place
for psrname in psrs_done:
    print(psrname)
    c = co.Core(corepath=f'{coresave_dir}/{psrname}/core.h5')
    fname = f'{ePSR_dir}/{psrname}.hdf5'
    psr = FilePulsar(fname)
    # set up noise flags + selections
    inc_ecorr = False
    signal_names = [f'{psrname}_dm_gp',f'{psrname}_red_noise']
    colors = ['C1','C3']
    # set up DM Nfreqs
    pta = model_singlepsr_noise(psr, Tspan=952746385.6296968,
                                # timing -  set svd false for GPs
                                tm_svd=False,
                                # white noise - set gp_ecorr True for GPs
                                tnequad=True, inc_ecorr=inc_ecorr, gp_ecorr=True,
                                #efeq_groups=efeq_groups_by_PTA, ecorr_groups=ecorr_groups_by_PTA,
                                log_equad_min=-10, log_equad_max=-4,
                                # DM
                                dm_var=True, dm_type='gp',
                                dmgp_kernel='diag', dm_psd='powerlaw',
                                dm_Nfreqs=30,
                                # solar wind
                                dm_sw_deter=False,
                                # dm dip
                                dm_expdip=dm_expdip, dm_expdip_basename='exp',
                                dm_expdip_tau_min=np.log10(5), dm_expdip_tau_max=np.log10(500), 
                                # red noise
                                log_A_min=-20, log_A_max=-11)
    for signal in pta.signals:
        if pta.signals[signal].signal_type == 'basis':
            pta.signals[signal].basis_combine=False
    n_realizations = 500
    pu.plot_gp_realizations(c, psr, pta,
                            signal_names, colors, method='enterprise',
                            n_realizations=n_realizations,
                            alpha=0.02, save=figsave)
    idxs = np.random.choice(np.arange(len(c('lnpost'))), n_realizations)
    gp = ConditionalGP(pta)
    mlv_params = {p:c(p)[c.map_idx] for p in c.params}
    mlv_GPs = gp.sample_processes(mlv_params, n=1)[0]
    mlv_correction = np.sum([mlv_GPs[sn] for sn in mlv_GPs], axis=0)
    pu.plot_resids(psr, correction=mlv_correction, save=f'{figsave_dir}/{psrname}',
                   correction_label='_enterprise_ConditionalGP_MLV')

In [None]:
from corner import hist2d
import matplotlib.lines as mlines

In [None]:
# 68% red-noise contours for every pulsar whose red-noise Bayes factor
# (from gorilla_bf) is NaN or larger than 10.
fig, ax = plt.subplots(figsize=(9,7))
i = 0  # colour index, advanced once per plotted pulsar
lines = []
for psrname in psrnames:
    # per-pulsar core kept in its own name so the global `c` is not overwritten
    c_psr = co.Core(corepath=f'{coresave_dir}/{psrname}/core.h5', burn=0)
    x = c_psr(f'{psrname}_red_noise_gamma')
    y = c_psr(f'{psrname}_red_noise_log10_A')
    BF = gorilla_bf(y, min=-20, max=-11)
    if np.isnan(BF) or BF > 10:
        # draw on the axes created above; hist2d takes `ax`, not `fig`
        hist2d(x, y, ax=ax, levels=(0.68,), color=f'C{i}',
               plot_density=False, no_fill_contours=True, plot_datapoints=False,
               contour_kwargs={'lw':1,'alpha':1})
        lines.append(mlines.Line2D([],[],color=f'C{i}',label=psrname))
        i += 1
ax.set_xlabel(r'$\gamma_{\rm{RN}}$')
ax.set_ylabel(r'$\log_{10}A_{\rm{RN}}$')
ax.set_xlim([0,7])
ax.set_ylim([-17,-11])
ax.set_title('Red noise posteriors')
ax.grid(lw=0.3)
ax.legend(handles=lines, fontsize='small')
ax.set_aspect('equal')

In [None]:
# 68% DM-noise contours for every pulsar whose DM-GP Bayes factor
# (from gorilla_bf) is NaN or larger than 10.
fig, ax = plt.subplots(figsize=(9,7))
i = 0  # colour index, advanced once per plotted pulsar
lines = []
for psrname in psrnames:
    # per-pulsar core kept in its own name so the global `c` is not overwritten
    c_psr = co.Core(corepath=f'{coresave_dir}/{psrname}/core.h5', burn=0)
    x = c_psr(f'{psrname}_dm_gp_gamma')
    y = c_psr(f'{psrname}_dm_gp_log10_A')
    BF = gorilla_bf(y, min=-20, max=-11)
    if np.isnan(BF) or BF > 10:
        # draw on the axes created above; hist2d takes `ax`, not `fig`
        hist2d(x, y, ax=ax, levels=(0.68,), color=f'C{i}',
               plot_density=False, no_fill_contours=True, plot_datapoints=False,
               contour_kwargs={'lw':1,'alpha':1})
        lines.append(mlines.Line2D([],[],color=f'C{i}',label=psrname))
        i += 1
ax.set_xlabel(r'$\gamma_{\rm{DM}}$')
ax.set_ylabel(r'$\log_{10}A_{\rm{DM}}$')
ax.set_title('DM noise posteriors')
ax.set_xlim([0,7])
ax.set_ylim([-17,-11])
ax.grid(lw=0.3)
ax.legend(handles=lines, fontsize='small')
ax.set_aspect('equal')

In [None]:
# Correlated-noise corner plots (red noise, DM GP, exp1 dip) for J1713+0747,
# marking median values rather than maximum-likelihood values.
# NOTE(review): `coresave_dir` and `figsave_dir` are not defined in any cell
# visible above.
psrname = 'J1713+0747'
c = co.Core(corepath=f'{coresave_dir}/{psrname}/core.h5')
pu.make_correlated_noise_corners(c, plot_ml_values=False, plot_med_values=True,
                                 noise_types=['red_noise','dm_gp','exp1'],
                                 save=f'{figsave_dir}/{psrname}')

In [None]:
# Display the figure output directory.
figsave_dir

In [None]:
# NOTE(review): looks like a leftover kernel-alive check — consider removing.
print('hi')

In [None]:
# Close every open matplotlib figure.
plt.close('all')

# Make noise dictionary

Here we will use median values to match DR2

In [None]:
# Median of every parameter of each single-pulsar core, excluding the last
# four columns of the parameter list.
noise_dict = {}
for psrname in psrnames:
    c = co.Core(corepath=f'{coresave_dir}/{psrname}/core.h5', burn=0)
    noise_dict.update({p: c.get_param_median(p) for p in c.params[:-4]})

In [None]:
# Serialise the noise dictionary as pretty-printed JSON with sorted keys.
filename = f'{project_path}/noisedicts/dr2full_stdnoise.json'
noise_json = json.dumps(noise_dict, indent=4, sort_keys=True)
with open(filename, 'w') as f:
    f.write(noise_json)

# Make empirical distributions

Need list of the following:
- 2D dist for RN for each pulsar
- 2D dist for DM for each pulsar
- 1D dists for dip params for J1713

In [None]:
# Load the single-pulsar core of J1713+0747 (no explicit burn-in given).
psrname = 'J1713+0747'
c = co.Core(corepath=f'{coresave_dir}/{psrname}/core.h5')

In [None]:
# here we need to make a chain with all the relevant parameters
# note red noise proposals may not be accepted since IRN is different from total RN
# would be better to use a fact like run
# One row per parameter: 4 per pulsar plus 3 dip parameters for J1713+0747.
# NOTE(review): the row count assumes J1713+0747 is in `psrnames`, and the
# assignment below assumes every core holds at least 200000 samples — confirm.
chain = np.zeros((len(psrnames)*4+3,200000))
param_list = []   # parameter groups, one inner list per distribution
params = []       # flat parameter names, in the same order as the rows of `chain`
i = 0             # next row of `chain` to fill
for psrname in psrnames:
    singlepsr_params = []
    singlepsr_params.append([f'{psrname}_red_noise_gamma',f'{psrname}_red_noise_log10_A'])
    singlepsr_params.append([f'{psrname}_dm_gp_gamma',f'{psrname}_dm_gp_log10_A'])
    if psrname == 'J1713+0747':
        singlepsr_params.append([f'{psrname}_exp1_log10_Amp'])
        singlepsr_params.append([f'{psrname}_exp1_log10_tau'])
        singlepsr_params.append([f'{psrname}_exp1_t0'])
    param_list.extend(singlepsr_params)
    singlepsr_params_flat = [p1 for p2 in singlepsr_params for p1 in p2]
    params.extend(singlepsr_params_flat)
    c = co.Core(corepath=f'{coresave_dir}/{psrname}/core.h5',burn=0)
    for p in singlepsr_params_flat:
        # first 200000 samples of this parameter
        chain[i] = c(p)[:200000]
        i += 1

In [None]:
def _parse_prior_bounds(prior):
    """Return ``(pmin, pmax)`` read from a prior description string.

    The string must contain ``pmin=<number>`` terminated by the first ``,`` and
    ``pmax=<number>`` terminated by the first ``)``.
    """
    pmin = float(prior[prior.index('pmin')+5:prior.index(',')])
    pmax = float(prior[prior.index('pmax')+5:prior.index(')')])
    return pmin, pmax


def make_empirical_distributions(psrnames, burn=0, nbins=40,
                                 filename=None, return_distribution=False):
    """Build empirical distributions from the single-pulsar cores.

    For every pulsar the core at ``{coresave_dir}/{psrname}/core.h5`` is loaded
    (``coresave_dir`` is read from the notebook namespace) and a 2-D
    distribution is made for the red-noise (gamma, log10_A) pair and for the
    DM-GP (gamma, log10_A) pair; for J1713+0747 three additional 1-D
    distributions cover the ``exp1`` dip parameters.  Histogram edges span the
    prior range of each parameter.

    Parameters
    ----------
    psrnames : iterable of str
        Pulsar names.
    burn : int
        Burn-in passed to each core when it is loaded.
    nbins : int
        Number of bin edges per dimension.
    filename : str or None
        If given, the list of distributions is pickled to this path.
    return_distribution : bool
        If True, return the list of distributions.

    Returns
    -------
    list or None
    """
    distr = []

    for psrname in psrnames:
        # load core
        c = co.Core(corepath=f'{coresave_dir}/{psrname}/core.h5', burn=burn)

        # parameter groups: one inner list per distribution
        paramlist = [
            [f'{psrname}_red_noise_gamma',f'{psrname}_red_noise_log10_A'],
            [f'{psrname}_dm_gp_gamma',f'{psrname}_dm_gp_log10_A'],
        ]
        if psrname == 'J1713+0747':
            paramlist.append([f'{psrname}_exp1_log10_Amp'])
            paramlist.append([f'{psrname}_exp1_log10_tau'])
            paramlist.append([f'{psrname}_exp1_t0'])

        for pl in paramlist:
            if not isinstance(pl, list):
                pl = [pl]

            if len(pl) == 1:
                idx = c.params.index(pl[0])
                prior_min, prior_max = _parse_prior_bounds(c.priors[idx])
                bins = np.linspace(prior_min, prior_max, nbins)
                distr.append(EmpiricalDistribution1D(pl[0], c.chain[c.burn:, idx], bins))

            elif len(pl) == 2:
                idx = [c.params.index(name) for name in pl]
                bins = [np.linspace(*_parse_prior_bounds(c.priors[k]), nbins) for k in idx]
                distr.append(EmpiricalDistribution2D(pl, c.chain[c.burn:, idx].T, bins))

            else:
                msg = 'WARNING: only 1D and 2D empirical distributions are currently allowed.'
                logger.warning(msg)

    # save the list of empirical distributions as a pickle file
    if filename is not None:
        if len(distr) > 0:
            with open(filename, 'wb') as f:
                pickle.dump(distr, f)

            msg = 'The empirical distributions have been pickled to {0}.'.format(filename)
            logger.info(msg)
        else:
            msg = 'WARNING: No empirical distributions were made!'
            logger.warning(msg)

    if return_distribution:
        return distr

In [None]:
# Build the per-pulsar empirical distributions and pickle them.
filename = f'{project_path}/empdists/dr2full_stdnoise.pkl'
distr = make_empirical_distributions(psrnames, burn=0, nbins=40,
                                     filename=filename, return_distribution=True)

## Check ConditionalGP ECORR

In [None]:
# Build the single-pulsar noise model for J1939+2134 and draw one set of GP
# realizations at the sample stored at c.map_idx.
psrname = 'J1939+2134'
c = co.Core(corepath=f'{coresave_dir}/{psrname}/core.h5', burn=0)
fname = f'{ePSR_dir}/{psrname}.hdf5'
psr = FilePulsar(fname)
# set up noise flags + selections
# NOTE(review): `inc_ecorr` is set to False here but the model below is built
# with the literal inc_ecorr=True; `signal_names` and `colors` are not used in
# this cell.
inc_ecorr = False
signal_names = [f'{psrname}_dm_gp',f'{psrname}_red_noise']
colors = ['C1','C3']
# set up DM Nfreqs
pta = model_singlepsr_noise(psr, Tspan=952746385.6296968,
                            # timing -  set svd false for GPs
                            tm_svd=False,
                            # white noise - set gp_ecorr True for GPs
                            tnequad=True, inc_ecorr=True, gp_ecorr=True,
                            #efeq_groups=efeq_groups_by_PTA, ecorr_groups=ecorr_groups_by_PTA,
                            log_equad_min=-10, log_equad_max=-4,
                            # DM
                            dm_var=True, dm_type='gp',
                            dmgp_kernel='diag', dm_psd='powerlaw',
                            dm_Nfreqs=30,
                            # solar wind
                            dm_sw_deter=False,
                            # dm dip
                            dm_expdip=False, dm_expdip_basename='exp',
                            dm_expdip_tau_min=np.log10(5), dm_expdip_tau_max=np.log10(500), 
                            # red noise
                            log_A_min=-20, log_A_max=-11)
# keep every basis signal separate instead of combining their bases
for signal in pta.signals:
    if pta.signals[signal].signal_type == 'basis':
        pta.signals[signal].basis_combine=False
gp = ConditionalGP(pta)
# parameter values of the sample at c.map_idx
mlv_params = {p:c(p)[c.map_idx] for p in c.params}
mlv_GPs = gp.sample_processes(mlv_params, n=1)[0]

In [None]:
# Hand-entered log10 ECORR values, keyed by B1937+21 backend; each value appears
# once under the plain name and once under the 'basis_ecorr' name.
# NOTE(review): no cell visible in this notebook reads `test_params`.
test_params = {'B1937+21_L-wide_ASP_log10_ecorr': -6.985907827570632,
 'B1937+21_L-wide_PUPPI_log10_ecorr': -6.933190952449128,
 'B1937+21_Rcvr1_2_GASP_log10_ecorr': -6.951805082863183,
 'B1937+21_Rcvr1_2_GUPPI_log10_ecorr': -6.9742651134780225,
 'B1937+21_Rcvr_800_GASP_log10_ecorr': -8.377559128584869,
 'B1937+21_Rcvr_800_GUPPI_log10_ecorr': -6.347443803757376,
 'B1937+21_S-wide_ASP_log10_ecorr': -6.560243521577494,
 'B1937+21_S-wide_PUPPI_log10_ecorr': -6.998893530255561,
 'B1937+21_basis_ecorr_L-wide_ASP_log10_ecorr': -6.985907827570632,
 'B1937+21_basis_ecorr_L-wide_PUPPI_log10_ecorr': -6.933190952449128,
 'B1937+21_basis_ecorr_Rcvr1_2_GASP_log10_ecorr': -6.951805082863183,
 'B1937+21_basis_ecorr_Rcvr1_2_GUPPI_log10_ecorr': -6.9742651134780225,
 'B1937+21_basis_ecorr_Rcvr_800_GASP_log10_ecorr': -8.377559128584869,
 'B1937+21_basis_ecorr_Rcvr_800_GUPPI_log10_ecorr': -6.347443803757376,
 'B1937+21_basis_ecorr_S-wide_ASP_log10_ecorr': -6.560243521577494,
 'B1937+21_basis_ecorr_S-wide_PUPPI_log10_ecorr': -6.998893530255561}

In [None]:
# Draw another set of GP realizations.
# NOTE(review): this samples with `mlv_params`, not the `test_params` defined
# just above — confirm which parameter set was intended.
test_GPs = gp.sample_processes(mlv_params, n=1)[0]

In [None]:
# Display the sampled processes.
test_GPs

In [None]:
# Print the MAP value of every ECORR parameter in the core.
ecorr_params = [name for name in c.params if 'ecorr' in name]
for p in ecorr_params:
    print(f'{p}: {c.get_map_param(p)}')

In [None]:
# Residual correction: the sum of every sampled process except the ECORR ones.
non_ecorr_processes = [process for name, process in mlv_GPs.items()
                       if 'ecorr' not in name]
correction = np.sum(non_ecorr_processes, axis=0)

In [None]:
# Corrected residuals of the NANOGrav TOAs with one sampled ECORR process on top.
fig, ax = plt.subplots(figsize=(8,4))
nu_min = np.min(psr.freqs)
nu_max = np.max(psr.freqs)
#mean_GPs = gp.get_mean_processes(mlv_params)
if isinstance(correction, np.ndarray):
    resids = psr.residuals - correction
else:
    resids = psr.residuals
# The loop variable is called `pta_name` so that it does not overwrite `pta`,
# the noise-model object that other cells pass to ConditionalGP.
for marker, pta_name in zip(['s'], ['NANOGrav']):
    if pta_name in psr.flags['pta']:
        # TOAs of this PTA, excluding the 'kaspi23' and 'kaspi14' groups
        mask = (psr.flags['pta'] == pta_name)*(psr.flags['group'] != 'kaspi23')*(psr.flags['group'] != 'kaspi14')
        ax.errorbar(psr.toas[mask]/const.day, resids[mask]*1e6, yerr=psr.toaerrs[mask]*1e6,
                    fmt=f'{marker}k', ms=5, marker=None, mew=0, alpha=0.1, lw=1, zorder=0,
                    label='RN+DM corrected Residuals')
        ax.plot(psr.toas[mask]/const.day, 1e6*mlv_GPs[f'{psrname}_basis_ecorr'][mask],
                'C1', label='ECORR realization from MLV posterior sample')
        #ax.plot(psr.toas[mask]/const.day, 1e6*mean_GPs[f'{psrname}_basis_ecorr'][mask],
        #        'C1', label='Sample process (MLV posterior sample)')
        #sc = ax.scatter(psr.toas[mask]/const.day, resids[mask]*1e6, s=5, marker=marker,
        #                c=psr.freqs[mask], cmap='Spectral', vmin=nu_min, vmax=nu_max)
#cbar = plt.colorbar(sc)
#cbar.set_label(r"$\nu$ (MHz)")
ax.set_xlabel("MJD")
ax.set_ylabel(r"Residual ($\mu s$)")
ax.grid(linewidth=0.3)
if isinstance(correction, np.ndarray):
    ax.set_title(psr.name+' Corrected, NG only')
else:
    ax.set_title(psr.name)
ax.legend()
fig.tight_layout()

In [None]:
# Corrected residuals of the NANOGrav TOAs with the ECORR mean process
# (scaled by 10) on top.
fig, ax = plt.subplots(figsize=(8,4))
nu_min = np.min(psr.freqs)
nu_max = np.max(psr.freqs)
mean_GPs = gp.get_mean_processes(mlv_params)
if isinstance(correction, np.ndarray):
    resids = psr.residuals - correction
else:
    resids = psr.residuals
# The loop variable is called `pta_name` so that it does not overwrite `pta`,
# the noise-model object that other cells pass to ConditionalGP.
for marker, pta_name in zip(['s'], ['NANOGrav']):
    if pta_name in psr.flags['pta']:
        # TOAs of this PTA, excluding the 'kaspi23' and 'kaspi14' groups
        mask = (psr.flags['pta'] == pta_name)*(psr.flags['group'] != 'kaspi23')*(psr.flags['group'] != 'kaspi14')
        ax.errorbar(psr.toas[mask]/const.day, resids[mask]*1e6, yerr=psr.toaerrs[mask]*1e6,
                    fmt=f'{marker}k', ms=5, marker=None, mew=0, alpha=0.1, lw=1, zorder=0,
                    label='RN+DM corrected Residuals')
        # NOTE(review): the curve is the mean process from get_mean_processes,
        # while the legend calls it a "realization" — confirm the wording.
        ax.plot(psr.toas[mask]/const.day, 1e7*mean_GPs[f'{psrname}_basis_ecorr'][mask],
                'C1', label=r'$10 \times$ ECORR realization from MLV posterior sample')
        #ax.plot(psr.toas[mask]/const.day, 1e7*mean_GPs[f'{psrname}_basis_ecorr'][mask],
        #        'C1', label='Sample process (MLV posterior sample)')
        #sc = ax.scatter(psr.toas[mask]/const.day, resids[mask]*1e6, s=5, marker=marker,
        #                c=psr.freqs[mask], cmap='Spectral', vmin=nu_min, vmax=nu_max)
#cbar = plt.colorbar(sc)
#cbar.set_label(r"$\nu$ (MHz)")
ax.set_xlabel("MJD")
ax.set_ylabel(r"Residual ($\mu s$)")
ax.grid(linewidth=0.3)
if isinstance(correction, np.ndarray):
    ax.set_title(psr.name+' Corrected, NG only')
else:
    ax.set_title(psr.name)
ax.legend()
fig.tight_layout()

# Compare different versions of GP reconstruction

In [None]:
# enterprise ConditionalGP for this pulsar; the DM, DM1 and DM2 timing-model
# parameters are passed as `tm_params`.
gp = ConditionalGP(
    pta,
    tm_params=['DM', 'DM1', 'DM2'],
    psr=psr,
)

In [None]:
# Processes to reconstruct, their plot colours, and the posterior samples the
# realizations are conditioned on.
n_realizations = 500
signal_names = [f'{psrname}_{suffix}' for suffix in
                ('basis_ecorr', 'dm_gp', 'red_noise', 'linear_timing_model')]
colors = ['C0', 'C1', 'C3', 'C4']
# Random draw (with replacement) over the samples returned by c('lnpost').
n_post = len(c('lnpost'))
idxs = np.random.choice(np.arange(n_post), n_realizations)
# The final slot is reserved for the core's MAP index; the plotting cells
# draw the last realization in black.
idxs[-1] = c.map_idx

In [None]:
# One ConditionalGP draw per selected posterior sample, stored per signal as
# an (n_realizations, n_TOA) array already scaled to plotting units.
signals = dict.fromkeys(signal_names)
dm_name = f'{psrname}_dm_gp'
ecorr_name = f'{psrname}_basis_ecorr'

for i in tqdm.tqdm(range(n_realizations)):
    idx = idxs[i]
    params = {p: c(p)[idx] for p in c.params}
    GPs = gp.sample_processes(params, n=1)[0]

    for signal_name in signal_names:
        # Allocate the storage on the first pass only.
        if i == 0:
            signals[signal_name] = np.zeros((n_realizations, len(psr.toas)))
        # The DM process is rescaled by 1e15*DM_K*nu^2 (plotted as
        # 1e-3 pc/cm^3); every other process is converted to microseconds.
        if signal_name == dm_name:
            scaling = 1e15*const.DM_K*psr.freqs**2
        else:
            scaling = 1e6
        signals[signal_name][i] = scaling*GPs[signal_name]

# plot processes
ngb = ["ASP", "GASP", "GUPPI", "PUPPI", "YUPPI"]
n_panels = len(signal_names)
fig, ax = plt.subplots(n_panels, 1, figsize=(8, 2*n_panels), sharex=True)
for i, signal_name in enumerate(signal_names):
    if signal_name == dm_name:
        units = r'($10^{-3}$ pc/cm$^3$)'
    else:
        units = r'($\mu$s)'
    if signal_name == ecorr_name:
        # ECORR panel: keep only TOAs whose group flag contains one of the
        # backend names in `ngb`.
        mask = [k for k, flag in enumerate(psr.flags['group'])
                if any(ngb_entry in flag for ngb_entry in ngb)]
    else:
        mask = np.ones(len(psr.toas), dtype=bool)
    mjd = psr.toas[mask]/const.day
    for j in range(n_realizations):
        # The last realization (index idxs[-1] = c.map_idx) is drawn opaque
        # in black on top of the faint coloured ensemble.
        is_last = (j == n_realizations-1)
        trace = signals[signal_name][j][mask]
        ax[i].plot(mjd, trace-np.mean(trace),
                   alpha=1 if is_last else 0.02,
                   color='k' if is_last else colors[i])
    ax[i].set_ylabel(f'{signal_name.replace(f"{psrname}_","")} {units}')
    ax[i].grid(lw=0.3)
ax[-1].set_xlabel('MJD')
xlim = ax[0].get_xlim()
fig.suptitle(f'{psrname} processes using ConditionalGP')
fig.tight_layout()
fig.subplots_adjust(hspace=0)
#fig.savefig(f'{figsave}/GPs.png')

# plot ecorrs
ptas = np.unique(psr.flags['pta'])
if 'NANOGrav' in psr.flags['pta']:
    fig, ax = plt.subplots(figsize=(8,3))
    # Median over realizations of the ECORR process at each TOA.
    ecorr_med = np.median(signals[ecorr_name], axis=0)
    flags_select = np.unique(psr.flags['group'][psr.flags['pta'] == 'NANOGrav'])
    backend_flags = [flag for flag in flags_select
                     if any(ngb_entry in flag for ngb_entry in ngb)]
    for flag in backend_flags:
        mask = psr.flags['group'] == flag
        ax.plot(psr.toas[mask]/const.day, ecorr_med[mask], '.', ms=2, label=flag)
    ax.grid(linewidth=0.3)
    ax.set_ylabel(f'NANOGrav: '+r'$\Delta t_{\rm{ECORR}}$ ($\mu s$)')
    ax.legend(fontsize='small')
    ax.set_xlabel('MJD')
    ax.set_title(f'ECORR | PSR {psrname} | ConditionalGP', fontsize='small')
    # Reuse the x-range of the process panels so the two figures line up.
    ax.set_xlim(xlim)
    fig.tight_layout()
    fig.subplots_adjust(hspace=0)
    #fig.savefig(f'{figsave}/ECORRs.png')

# mlv correction
# `GPs` is the draw left over from the final loop iteration above, i.e. the
# one conditioned on idxs[-1] = c.map_idx.
correction = np.sum([GPs[sn] for sn in signal_names], axis=0)
pu.plot_resids(psr, correction=correction, correction_label='_enterprise_ConditionalGP')

In [None]:
# la_forge-style reconstructor built from the same PTA and core.
# NOTE(review): `Signal_Reconstruction` is only defined in a later cell of
# this notebook (the la_forge import is commented out), so that cell has to be
# run first — confirm the intended execution order.
gp = Signal_Reconstruction(
    [psr],
    pta,
    core=c,
)

In [None]:
# Reconstruct each process with Signal_Reconstruction, one realization per
# selected posterior sample, stored as (n_realizations, n_TOA) arrays.
n_toas = len(psr.toas)
lf_ecorr = np.zeros((n_realizations, n_toas))
lf_RN = np.zeros((n_realizations, n_toas))
lf_DM = np.zeros((n_realizations, n_toas))
lf_TM = np.zeros((n_realizations, n_toas))
for i in tqdm.tqdm(range(n_realizations)):
    # Use the i-th selected sample. The previous code passed idxs[0] on every
    # iteration, so all realizations were conditioned on one posterior sample
    # and the last (black) trace was not the idxs[-1] = c.map_idx sample.
    # NOTE(review): `idxs` was drawn over len(c('lnpost')), while
    # Signal_Reconstruction indexes `core.chain` directly; if c(...) applies
    # the burn-in, an offset of `c.burn` is needed here — TODO confirm.
    sample_idx = idxs[i]
    lf_ecorr[i] = gp.reconstruct_signal(gp_type='basis_ecorr', idx=sample_idx)[psrname]
    lf_RN[i] = gp.reconstruct_signal(gp_type='red_noise', idx=sample_idx)[psrname]
    lf_DM[i] = gp.reconstruct_signal(gp_type='dm_gp', idx=sample_idx)[psrname]
    lf_TM[i] = gp.reconstruct_signal(gp_type='timing', idx=sample_idx)[psrname]
# Same scalings as the ConditionalGP cell: 1e6 for microseconds, and
# 1e15*DM_K*nu^2 for the DM process.
signals_lf = {f'{psrname}_basis_ecorr':1e6*lf_ecorr,
              f'{psrname}_dm_gp':1e15*const.DM_K*psr.freqs**2*lf_DM,
              f'{psrname}_red_noise':1e6*lf_RN,
              f'{psrname}_linear_timing_model':1e6*lf_TM}

In [None]:
# plot processes
ngb = ["ASP", "GASP", "GUPPI", "PUPPI", "YUPPI"]
fig, ax = plt.subplots(len(signal_names),1,figsize=(8,2*len(signal_names)),sharex=True)
for i, signal_name in enumerate(signal_names):
    units = r'($\mu$s)'
    mask = np.ones(len(psr.toas),dtype=bool)
    if signal_name == f'{psrname}_dm_gp':
        units = r'($10^{-3}$ pc/cm$^3$)'
    elif signal_name == f'{psrname}_basis_ecorr':
        # ECORR panel: keep only TOAs whose group flag contains one of the
        # backend names in `ngb`.
        mask = [k for k, flag in enumerate(psr.flags['group']) if
                any(ngb_entry in flag for ngb_entry in ngb)]
    for j in range(n_realizations):
        # Faint coloured ensemble; the last realization is drawn opaque black.
        is_last = (j == n_realizations-1)
        trace = signals_lf[signal_name][j][mask]
        ax[i].plot(psr.toas[mask]/const.day,
                   trace-np.mean(trace),
                   alpha=1 if is_last else 0.02,
                   color='k' if is_last else colors[i])
    ax[i].set_ylabel(f'{signal_name.replace(f"{psrname}_","")} {units}')
    ax[i].grid(lw=0.3)
ax[-1].set_xlabel('MJD')
xlim = ax[0].get_xlim()
fig.suptitle(f'{psrname} processes using SignalReconstruction')
fig.tight_layout()
fig.subplots_adjust(hspace=0)
#fig.savefig(f'{figsave}/GPs.png')

# plot ecorrs
if 'NANOGrav' in psr.flags['pta']:
    fig, ax = plt.subplots(figsize=(8,3))
    # Median over realizations of the ECORR process at each TOA.
    ecorr_med = np.median(signals_lf[f'{psrname}_basis_ecorr'], axis=0)
    flags_select = np.unique(psr.flags['group'][psr.flags['pta'] == 'NANOGrav'])
    for flag in [flag for flag in flags_select if
                 any(ngb_entry in flag for ngb_entry in ngb)]:
        mask = psr.flags['group'] == flag
        ax.plot(psr.toas[mask]/const.day, ecorr_med[mask],'.', ms=2, label=flag)
    ax.grid(linewidth=0.3)
    ax.set_ylabel(f'NANOGrav: '+r'$\Delta t_{\rm{ECORR}}$ ($\mu s$)')
    ax.legend(fontsize='small')
    ax.set_xlabel('MJD')
    ax.set_title(f'ECORR | PSR {psrname} | SignalReconstruction', fontsize='small')
    ax.set_xlim(xlim)
    fig.tight_layout()
    fig.subplots_adjust(hspace=0)

# mlv correction
# Pass the MLV index explicitly: in the Signal_Reconstruction class defined in
# this notebook, `mlv=True` is only honoured when `idx` is not None, so
# calling with `mlv=True` alone silently reconstructs from a random sample.
correction = gp.reconstruct_signal(gp_type='all', mlv=True, idx=gp.mlv_idx)[psrname]
pu.plot_resids(psr, correction=correction, correction_label='_la_forge')

In [None]:
# define my own version of Signal_Reconstruction for debugging...
class Signal_Reconstruction():
    '''
    Class for building Gaussian process realizations from enterprise models.

    This is a notebook-local copy kept for debugging. Names that the notebook
    does not import at top level (``OrderedDict``, ``scipy.sparse``,
    ``KernelMatrix``) are imported inside the methods that need them so the
    class can be defined and used on its own.
    '''

    def __init__(self, psrs, pta, chain=None, burn=None,
                 p_list='all', core=None):
        '''
        Parameters
        ----------

        psrs : list
            A list of enterprise.pulsar.Pulsar objects.

        pta : enterprise.signal_base.PTA
            The PTA object from enterprise that contains the signals for making
            realizations.

        chain : array
            Array which contains chain samples from Bayesian analysis.

        burn : int
            Length of burn. Defaults to 25% of the chain length when neither
            `burn` nor `core` supplies one.

        p_list : 'all', str, list of str, or list of int, optional
            Pulsar names (or integer positions in `psrs`) that dictate which
            pulsar signals to reproduce. Useful when looking at a full PTA.

        core : la_forge.core.Core, optional
            A core which contains the same information as the chain of samples.
            When `chain` is None, `core.chain` and `core.burn` are used.

        Raises
        ------
        ValueError
            If neither `chain` nor `core` is given.
        TypeError
            If `p_list` is not one of the supported forms.
        KeyError
            If a signal collection's pulsar name does not match `psrs`.
        '''
        # Local import: the notebook-level import of OrderedDict is commented
        # out.
        from collections import OrderedDict

        if not isinstance(psrs, list):
            psrs = [psrs]

        self.psrs = psrs
        self.pta = pta
        self.p_names = [psr.name for psr in psrs]

        if chain is None and core is None:
            raise ValueError('Must provide a chain or a la_forge.Core object.')
        if chain is None and core is not None:
            chain = core.chain
            burn = core.burn

        self.chain = chain
        if burn is None:
            self.burn = int(0.25*chain.shape[0])
        else:
            self.burn = burn

        # Same value as the module constant that is commented out at the top
        # of the notebook (`DM_K = float(2.41e-4)`); referencing the bare name
        # raised NameError.
        self.DM_K = float(2.41e-4)
        # NOTE(review): assumes column -4 of the chain holds the quantity to
        # maximise (presumably the log-posterior) — confirm against the
        # sampler's output layout. The argmax runs over the full chain.
        self.mlv_idx = np.argmax(chain[:, -4])
        self.mlv_params = self.sample_posterior(self.mlv_idx)

        if isinstance(p_list, str) and p_list == 'all':
            p_list = self.p_names
            p_idx = np.arange(len(self.p_names))
        elif isinstance(p_list, str):
            p_idx = [self.p_names.index(p_list)]
            p_list = [p_list]
        elif isinstance(p_list[0], str):
            p_idx = [self.p_names.index(p) for p in p_list]
        elif isinstance(p_list[0], (int, np.integer)):
            # Map the integer positions to the matching names. Previously
            # p_list was replaced by *all* pulsar names, which paired the
            # selected indices with the wrong names and disabled subsetting.
            p_idx = list(p_list)
            p_list = [self.p_names[ii] for ii in p_idx]
        else:
            raise TypeError("p_list must be 'all', a pulsar name, a list of "
                            "pulsar names, or a list of integer indices.")

        # find basis indices
        self.gp_idx = OrderedDict()
        self.common_gp_idx = OrderedDict()
        self.gp_freqs = OrderedDict()
        self.shared_sigs = OrderedDict()
        self.gp_types = []
        # Running offset of each pulsar's block in the PTA-wide phi.
        Ntot = 0
        for idx, pname in enumerate(self.p_names):
            sc = self.pta._signalcollections[idx]
            if sc.psrname != pname:
                raise KeyError('Pulsar name from signal collection does '
                               'not match name from provided list.')

            phi_dim = sc.get_phi(params=self.mlv_params).shape[0]
            if pname in p_list:
                self.gp_idx[pname] = OrderedDict()
                self.common_gp_idx[pname] = OrderedDict()
                self.gp_freqs[pname] = OrderedDict()
                # Running column offset inside this pulsar's basis.
                ntot = 0
                all_bases = []
                basis_signals = [sig for sig in sc._signals
                                 if sig.signal_type
                                 in ['basis', 'common basis']]

                # If the collection's phi is smaller than the sum of the
                # per-signal phis, some signals share basis columns.
                phi_sum = np.sum([sig.get_phi(self.mlv_params).shape[0]
                                  for sig in basis_signals])
                shared_bases = (phi_dim != phi_sum)

                self.shared_sigs[pname] = OrderedDict()

                for sig in basis_signals:
                    basis = sig.get_basis(params=self.mlv_params)
                    nb = basis.shape[1]
                    sig._construct_basis()
                    # Default so `freqs` is always bound, even for label
                    # containers that are neither dict nor array/list.
                    freqs = None
                    if isinstance(sig._labels, dict):
                        try:
                            freqs = list(sig._labels[''])[::2]
                        except TypeError:
                            freqs = sig._labels['']
                        except Exception:
                            freqs = None
                    elif isinstance(sig._labels, (np.ndarray, list)):
                        try:
                            freqs = list(sig._labels)[::2]
                        except TypeError:
                            freqs = sig._labels

                    # This was because svd timing bases weren't named originally.
                    # Maybe no longer needed.
                    if sig.signal_id == '':
                        ky = 'timing_model'
                    else:
                        ky = sig.signal_id

                    if ky not in self.gp_types:
                        self.gp_types.append(ky)

                    self.gp_freqs[pname][ky] = freqs

                    if shared_bases:
                        check = [np.array_equal(basis, M) for M in all_bases]
                        if any(check):
                            # This basis was already seen: reuse the columns
                            # of the signal that registered it first.
                            b_idx = check.index(True)
                            b_key = list(self.gp_idx[pname].keys())[b_idx]
                            self.shared_sigs[pname][ky] = b_key
                            self.gp_idx[pname][ky] = self.gp_idx[pname][b_key]
                            # TODO Fix the common signal idx collector!!!
                            if sig.signal_type == 'common basis':
                                self.common_gp_idx[pname][ky] = np.arange(Ntot+ntot, nb+Ntot+ntot)

                        else:
                            self.gp_idx[pname][ky] = np.arange(ntot, nb+ntot)
                            if sig.signal_type == 'common basis':
                                self.common_gp_idx[pname][ky] = np.arange(Ntot+ntot, nb+Ntot+ntot)

                            all_bases.append(basis)
                            ntot += nb
                    else:
                        self.gp_idx[pname][ky] = np.arange(ntot, nb+ntot)
                        if sig.signal_type == 'common basis':
                            self.common_gp_idx[pname][ky] = np.arange(Ntot+ntot, nb+Ntot+ntot)

                        ntot += nb

            Ntot += phi_dim
        self.p_list = p_list
        self.p_idx = p_idx

    def reconstruct_signal(self, gp_type='achrom_rn', det_signal=False,
                           mlv=False, idx=None, condition=False, eps=1e-16):
        """
        Parameters
        ----------
        gp_type : str, {'achrom_rn','gw','DM','none','all',timing parameters}
            Type of gaussian process signal to be reconstructed. In addition
            any GP in `psr.fitpars` or `Signal_Reconstruction.gp_types` may be
            called.

            ['achrom_rn','red_noise'] : Return the achromatic red noise.

            ['DM'] : Return the timing-model parts of dispersion model.

            [timing parameters] : Any of the timing parameters from the linear timing model. A list is available as `psr.fitpars`.

            ['timing'] : Return the entire timing model.

            ['gw'] : Gravitational wave signal. Works with common process in full PTAs.

            ['none'] : Returns no Gaussian processes. Meant to be used for returning only a deterministic signal.

            ['all'] : Returns all Gaussian processes.

        det_signal : bool
            Whether to include the deterministic signals in the reconstruction.

        mlv : bool
            Whether to use the maximum likelihood value for the reconstruction.
            Takes precedence over `idx`.

        idx : int, optional
            Index of the chain array to use. When omitted (and `mlv` is
            False) a random post-burn index is drawn.

        condition : bool
            Forwarded to `_get_b_common`; when True the full phi matrix is
            inverted explicitly instead of calling `pta.get_phiinv`.

        eps : float
            Forwarded to `_get_b_common` (currently unused there).

        Returns
        -------
        wave : dict
            Maps each pulsar name in `p_list` to a reconstruction of a single
            gaussian process signal realization.

        Raises
        ------
        ValueError
            If `gp_type` is 'none' without `det_signal`, is not an available
            type, or is a single timing parameter while an SVD timing basis
            is in use.
        IndexError
            If the requested GP indices are out of range.
        """

        # `mlv` has to win over the random draw. The previous ordering only
        # honoured `mlv=True` when an explicit `idx` was also given, so
        # `reconstruct_signal(mlv=True)` silently used a random sample.
        if mlv:
            idx = self.mlv_idx
        elif idx is None:
            idx = np.random.randint(self.burn, self.chain.shape[0])

        # get parameter dictionary
        params = self.sample_posterior(idx)
        self.idx = idx
        wave = {}

        TNrs, TNTs, phiinvs, Ts = self._get_matrices(params=params)

        for (p_ct, psrname, d, TNT, phiinv, T) in zip(self.p_idx, self.p_list,
                                                      TNrs, TNTs, phiinvs, Ts):
            wave[psrname] = 0

            # Add in deterministic signal if desired.
            if det_signal:
                wave[psrname] += self.pta.get_delay(params=params)[p_ct]

            # One random draw of all basis coefficients for this pulsar.
            b = self._get_b(d, TNT, phiinv)

            # Common signals need a joint draw over the whole PTA.
            if gp_type in self.common_gp_idx[psrname].keys():
                B = self._get_b_common(gp_type, TNrs, TNTs, params,
                                       condition=condition, eps=eps)

            # Red noise pieces
            # From here on `idx` is reused for basis-column indices; the
            # chain index was already stored in `self.idx`.
            psr = self.psrs[p_ct]
            if gp_type == 'none' and det_signal:
                pass
            elif gp_type == 'none' and not det_signal:
                raise ValueError('Must return a GP or deterministic signal.')
            elif gp_type == 'DM':
                tm_key = [ky for ky in self.gp_idx[psrname].keys()
                          if 'timing' in ky][0]
                dmind = np.array([ct for ct, p in enumerate(psr.fitpars)
                                  if 'DM' in p], dtype=int)
                # Index through the timing-model columns of the full basis
                # (as the single-parameter branch below does) instead of
                # using the raw fitpars positions, which are only correct
                # when the timing block starts at column 0.
                idx = self.gp_idx[psrname][tm_key][dmind]
                wave[psrname] += np.dot(T[:, idx], b[idx])

            elif gp_type in ['achrom_rn', 'red_noise']:
                if 'red_noise' not in self.shared_sigs[psrname]:
                    if 'red_noise' in self.common_gp_idx[psrname].keys():
                        idx = self.gp_idx[psrname]['red_noise']
                        cidx = self.common_gp_idx[psrname]['red_noise']
                        wave[psrname] += np.dot(T[:, idx], B[cidx])
                    else:
                        idx = self.gp_idx[psrname]['red_noise']
                        wave[psrname] += np.dot(T[:, idx], b[idx])
                else:
                    # Shared basis: rebuild phi without the overlapping
                    # signals and redraw the coefficients with it.
                    rn_sig = self.pta.get_signal('{0}_red_noise'.format(psrname))
                    sc = self.pta._signalcollections[p_ct]
                    phi_rn = self._shared_basis_get_phi(sc, params, rn_sig)
                    phiinv_rn = phi_rn.inv()
                    idx = self.gp_idx[psrname]['red_noise']
                    b = self._get_b(d, TNT, phiinv_rn)
                    wave[psrname] += np.dot(T[:, idx], b[idx])
            elif gp_type == 'timing':
                tm_key = [ky for ky in self.gp_idx[psrname].keys()
                          if 'timing' in ky][0]
                idx = self.gp_idx[psrname][tm_key]
                wave[psrname] += np.dot(T[:, idx], b[idx])
            elif gp_type in psr.fitpars:
                if any([ky for ky in self.gp_idx[psrname].keys()
                        if 'svd' in ky]):
                    raise ValueError('The SVD decomposition does not allow '
                                     'reconstruction of the timing model '
                                     'gaussian process realizations '
                                     'individually.')

                tm_key = [ky for ky in self.gp_idx[psrname].keys()
                          if 'timing' in ky][0]
                dmind = np.array([ct for ct, p in enumerate(psr.fitpars)
                                  if gp_type in p], dtype=int)
                idx = self.gp_idx[psrname][tm_key][dmind]
                wave[psrname] += np.dot(T[:, idx], b[idx])
            elif gp_type == 'all':
                wave[psrname] += np.dot(T, b)
            elif gp_type == 'gw':
                # NOTE(review): this branch tests for the key 'red_noise_gw'
                # but then indexes with 'gw'; the two look inconsistent —
                # confirm which key the model actually registers.
                if 'red_noise_gw' not in self.shared_sigs[psrname]:
                    # Parse whether it is a common signal.
                    if 'red_noise_gw' in self.common_gp_idx[psrname].keys():
                        idx = self.gp_idx[psrname]['gw']
                        cidx = self.common_gp_idx[psrname]['gw']
                        wave[psrname] += np.dot(T[:, idx], B[cidx])
                    else:  # If not common use pulsar Phi
                        idx = self.gp_idx[psrname]['gw']
                        wave[psrname] += np.dot(T[:, idx], b[idx])
                # Need to make our own phi when shared...
                else:
                    gw_sig = self.pta.get_signal('{0}_gw'.format(psrname))
                    sc = self.pta._signalcollections[p_ct]
                    phi_gw = self._shared_basis_get_phi(sc, params, gw_sig)
                    phiinv_gw = phi_gw.inv()
                    idx = self.gp_idx[psrname]['red_noise_gw']
                    b = self._get_b(d, TNT, phiinv_gw)
                    wave[psrname] += np.dot(T[:, idx], b[idx])
            elif gp_type in self.gp_types:
                try:
                    if gp_type in self.common_gp_idx[psrname].keys():
                        idx = self.gp_idx[psrname][gp_type]
                        cidx = self.common_gp_idx[psrname][gp_type]
                        wave[psrname] += np.dot(T[:, idx], B[cidx])
                    else:
                        idx = self.gp_idx[psrname][gp_type]
                        wave[psrname] += np.dot(T[:, idx], b[idx])
                except IndexError:
                    raise IndexError('Index is out of range. '
                                     'Maybe the basis for this is shared.')
            else:
                err_msg = '{0} is not an available gp_type. '.format(gp_type)
                err_msg += 'Available gp_types '
                err_msg += 'include {0}'.format(self.gp_types)
                raise ValueError(err_msg)

        return wave

    def _get_matrices(self, params):
        """Return per-pulsar lists (TNrs, TNTs, phiinvs, Ts) from the PTA,
        restricted to the selected pulsars when `p_list` is a subset."""
        TNrs = self.pta.get_TNr(params)
        TNTs = self.pta.get_TNT(params)
        phiinvs = self.pta.get_phiinv(params, logdet=False)  # ,method='partition')
        Ts = self.pta.get_basis(params)

        # The following takes care of common, correlated signals.
        if TNTs[0].shape[0]<phiinvs[0].shape[0]:
            phiinvs = self._div_common_phiinv(TNTs, params)

        # Excise pulsars if p_list not 'all'.
        if len(self.p_list)<len(self.p_names):
            TNrs = self._subset_psrs(TNrs, self.p_idx)
            TNTs = self._subset_psrs(TNTs, self.p_idx)
            phiinvs = self._subset_psrs(phiinvs, self.p_idx)
            Ts = self._subset_psrs(Ts, self.p_idx)

        return TNrs, TNTs, phiinvs, Ts

    def _get_b(self, d, TNT, phiinv):
        """Draw one random coefficient vector: the mean Sigma^-1 d plus noise
        with covariance Sigma^-1, where Sigma = TNT + phiinv (`phiinv` may be
        given as a 1-D diagonal or a full matrix)."""
        Sigma = TNT + (np.diag(phiinv) if phiinv.ndim == 1 else phiinv)
        try:
            u, s, _ = sl.svd(Sigma)
            mn = np.dot(u, np.dot(u.T, d)/s)
            # Li @ Li.T = u diag(1/s) u.T = Sigma^-1
            Li = u * np.sqrt(1/s)
        except np.linalg.LinAlgError:
            Q, R = sl.qr(Sigma)
            Sigi = sl.solve(R, Q.T)
            mn = np.dot(Sigi, d)
            u, s, _ = sl.svd(Sigi)
            # Here `s` are the singular values of the *inverse*, so the
            # square-root factor is u*sqrt(s). The previous sqrt(1/s) gave
            # draws with covariance Sigma instead of Sigma^-1.
            Li = u * np.sqrt(s)

        return mn + np.dot(Li, np.random.randn(Li.shape[0]))

    def _get_b_common(self, gp_type, TNrs, TNTs, params,
                      condition=False, eps=1e-16):
        """Joint coefficient draw over all pulsars for a common signal.

        The random vector is forced to be identical across pulsars on the
        columns belonging to `gp_type`. `eps` is accepted for interface
        compatibility but is not used.
        """
        # Local imports: scipy.sparse is not imported at notebook level.
        import scipy.sparse as sps
        import scipy.sparse.linalg  # noqa: F401  (makes sps.linalg available)

        if condition:
            phi = self.pta.get_phi(params)  # .astype(np.float128)
            phiinv = np.linalg.inv(phi)
            phiinv = sps.csc_matrix(phiinv)
        else:
            phiinv = self.pta.get_phiinv(params, logdet=False)

        sps_Sigma = sps.block_diag(TNTs, 'csc') + sps.csc_matrix(phiinv)
        Sigma = sl.block_diag(*TNTs) + phiinv  # .astype(np.float128)
        TNr = np.concatenate(TNrs)

        # NOTE(review): `cholesky` is not imported anywhere visible in this
        # notebook. It presumably refers to sksparse.cholmod.cholesky (the
        # factor's `.L()` is used below) — import it before reconstructing
        # common signals.
        ch = cholesky(sps_Sigma)  # noqa: F821
        Li = sps.linalg.inv(ch.L()).todense()
        mn = np.linalg.solve(Sigma, TNr)

        self.gp = np.random.randn(mn.shape[0])
        L = self.common_gp_idx[self.p_list[0]][gp_type].shape[0]
        common_gp = np.random.randn(L)

        # Use the same random numbers for the common process in every pulsar.
        for psrname in self.p_list:
            idxs = self.common_gp_idx[psrname][gp_type]
            self.gp[idxs] = common_gp

        B = mn + np.dot(Li, self.gp)

        # `B` may come back as a 1 x N matrix; flatten it to a 1-D array when
        # that conversion applies, otherwise keep it as is (best effort).
        try:
            B = np.array(B.tolist()[0])
        except Exception:
            pass

        return B

    def sample_params(self, index):
        """Return {param_name: value} for chain row `index`, one scalar per
        entry of `pta.param_names`."""
        return {par: self.chain[index, ct] for ct, par
                in enumerate(self.pta.param_names)}

    def sample_posterior(self, samp_idx, array_params=['alphas', 'rho', 'nE']):
        """Return the parameter dictionary for chain row `samp_idx`.

        Parameters whose names contain one of `array_params` are gathered
        into a single array entry (keyed by the name with '_0' removed);
        all other parameters are returned as scalars.
        """
        param_names = self.pta.param_names
        if any([any([array_str in par for par in param_names])
                for array_str in array_params]):
            # Check for any array params and make samples appropriate shape.
            mask = np.ones(len(param_names), dtype=bool)

            array_par_dict = {}
            for array_str in array_params:
                # Go through each type of possible array sample.
                mask &= np.array([array_str not in par
                                  for par in param_names], dtype=bool)
                if any([array_str+'_0' in par for par in param_names]):
                    array_par_name = [par.replace('_0', '')
                                      for par in param_names
                                      if array_str+'_0' in par][0]
                    array_idxs = np.where([array_str in par
                                           for par in param_names])[0]
                    par_array = self.chain[samp_idx, array_idxs]
                    array_par_dict.update({array_par_name: par_array})

            par_idx = np.where(mask)[0]
            par_sample = {param_names[p_idx]: self.chain[samp_idx, p_idx]
                          for p_idx in par_idx}
            par_sample.update(array_par_dict)

            return par_sample

        else:
            return {par: self.chain[samp_idx, ct]
                    for ct, par in enumerate(self.pta.param_names)}

    def _subset_psrs(self, likelihood_list, p_idx):
        """Pick the entries of `likelihood_list` at positions `p_idx`.

        Done with plain list indexing: the per-pulsar matrices generally have
        different shapes, and building a ragged `np.array` from them fails on
        recent NumPy versions.
        """
        return [likelihood_list[ii] for ii in p_idx]

    def _div_common_phiinv(self, TNTs, params):
        """Per-pulsar phi inverses computed from each signal collection
        (None where a collection has no phi)."""
        phivecs = [signalcollection.get_phi(params) for
                   signalcollection in self.pta._signalcollections]
        return [None if phivec is None else phivec.inv(logdet=False)
                for phivec in phivecs]

    def _make_sigma(self, TNTs, phiinv):
        """Block-diagonal TNT plus phiinv."""
        return sl.block_diag(*TNTs) + phiinv

    def _shared_basis_get_phi(self, sc, params, primary_signal):
        """Rewrite of get_phi where overlapping bases are ignored."""
        # Local import: KernelMatrix is not imported at notebook level.
        # NOTE(review): assumes KernelMatrix lives in enterprise.signals.utils
        # for the installed enterprise version — confirm.
        from enterprise.signals.utils import KernelMatrix

        phi = KernelMatrix(sc._Fmat.shape[1])

        idx_dict, _ = sc._combine_basis_columns(sc._signals)
        primary_idxs = idx_dict[primary_signal]

        # Make new list of signals with no overlapping bases: keep the
        # primary signal itself and every signal on different columns.
        new_signals = [sig for sig in idx_dict.keys()
                       if sig.signal_id == primary_signal.signal_id
                       or not np.array_equal(primary_idxs, idx_dict[sig])]

        for signal in new_signals:
            if signal in sc._idx:
                phi = phi.add(signal.get_phi(params), sc._idx[signal])

        return phi

    def _shared_basis_get_phiinv(self, sc, params, primary_signal):
        """Rewrite of get_phiinv where overlapping bases are ignored."""
        # Previously this called `_shared_basis_get_phi.get_phi(...)` on an
        # undefined module-level name.
        return self._shared_basis_get_phi(sc, params, primary_signal).inv()

In [None]:
def plot_resids(psr):
    """Plot the timing residuals of `psr` against MJD, one marker style per PTA.

    For each of NANOGrav, EPTA and PPTA present in ``psr.flags['pta']`` the
    residuals are drawn with error bars and over-plotted as points coloured
    by observing frequency (``psr.freqs``). Residuals and uncertainties are
    multiplied by 1e6 and labelled in microseconds.

    Parameters
    ----------
    psr : pulsar object
        Must provide `toas`, `residuals`, `toaerrs`, `freqs`, `flags['pta']`
        and `name`.
    """
    fig, ax = plt.subplots(figsize=(8,4))
    # Common colour scale across all PTAs.
    nu_min = np.min(psr.freqs)
    nu_max = np.max(psr.freqs)
    # Stays None when none of the three PTAs is present, so the colorbar and
    # legend can be skipped instead of raising on an unbound name.
    sc = None
    for marker, pta in zip(['s','o','*'], ['NANOGrav','EPTA','PPTA']):
        if pta in psr.flags['pta']:
            mask = psr.flags['pta'] == pta
            ax.errorbar(psr.toas[mask]/const.day, psr.residuals[mask]*1e6, yerr=psr.toaerrs[mask]*1e6,
                        fmt=f'{marker}k', ms=5, marker=None, mew=0, alpha=0.5, lw=1, zorder=0, label=pta)
            sc = ax.scatter(psr.toas[mask]/const.day, psr.residuals[mask]*1e6, s=5, marker=marker,
                            c=psr.freqs[mask], cmap='Spectral', vmin=nu_min, vmax=nu_max)
    if sc is not None:
        cbar = fig.colorbar(sc, ax=ax)
        cbar.set_label(r"$\nu$ (MHz)")
        ax.legend()
    ax.set_xlabel("MJD")
    ax.set_ylabel(r"Residual ($\mu s$)")
    ax.grid(linewidth=0.3)
    ax.set_title(psr.name)
    fig.tight_layout()
    
def plot_NG_resids(psr, exclude_groups=('kaspi', 'kaspi14', 'kaspi23')):
    """Plot the NANOGrav-only timing residuals of `psr` against MJD.

    Same layout as `plot_resids`, restricted to TOAs whose ``flags['pta']``
    is 'NANOGrav' and whose ``flags['group']`` is not in `exclude_groups`.

    Parameters
    ----------
    psr : pulsar object
        Must provide `toas`, `residuals`, `toaerrs`, `freqs`, `flags['pta']`,
        `flags['group']` and `name`.
    exclude_groups : sequence of str, optional
        Group flags to drop. The default keeps the original 'kaspi' exclusion
        and adds 'kaspi14'/'kaspi23', the groups excluded by the NG-only
        plotting cells elsewhere in this notebook.
    """
    fig, ax = plt.subplots(figsize=(8,4))
    nu_min = np.min(psr.freqs)
    nu_max = np.max(psr.freqs)
    # Stays None when the pulsar has no NANOGrav TOAs, so the colorbar and
    # legend can be skipped instead of raising on an unbound name.
    sc = None
    for marker, pta in zip(['s'], ['NANOGrav']):
        if pta in psr.flags['pta']:
            mask = ((psr.flags['pta'] == pta)
                    & ~np.isin(psr.flags['group'], list(exclude_groups)))
            ax.errorbar(psr.toas[mask]/const.day, psr.residuals[mask]*1e6, yerr=psr.toaerrs[mask]*1e6,
                        fmt=f'{marker}k', ms=5, marker=None, mew=0, alpha=0.5, lw=1, zorder=0, label=pta)
            sc = ax.scatter(psr.toas[mask]/const.day, psr.residuals[mask]*1e6, s=5, marker=marker,
                            c=psr.freqs[mask], cmap='Spectral', vmin=nu_min, vmax=nu_max)
    if sc is not None:
        cbar = fig.colorbar(sc, ax=ax)
        cbar.set_label(r"$\nu$ (MHz)")
        ax.legend()
    ax.set_xlabel("MJD")
    ax.set_ylabel(r"Residual ($\mu s$)")
    ax.grid(linewidth=0.3)
    ax.set_title(psr.name)
    fig.tight_layout()
    
def plot_avg_resids(psr, resids):
    """Plot averaged residuals per band against MJD.

    Parameters
    ----------
    psr : pulsar object
        Only `psr.name` is used (for the title).
    resids : sequence
        ``resids[0]`` must map each band label to a 2-D array whose columns
        0, 1 and 2 are read as TOA, residual and TOA uncertainty. Residuals
        and uncertainties are multiplied by 1e6 and labelled in microseconds.
    """
    fig, ax = plt.subplots(figsize=(8,4))
    bands = resids[0]
    for band in bands:
        band_data = bands[band]
        toas = band_data[:,0]
        res = band_data[:,1]
        toaerrs = band_data[:,2]
        # Black error bars underneath, one coloured point series per band.
        ax.errorbar(toas/const.day, res*1e6, yerr=toaerrs*1e6,
                    fmt='sk', ms=5, marker=None, mew=0, alpha=0.5, lw=1, zorder=0)
        ax.scatter(toas/const.day, res*1e6, s=5, marker='s', label=band)
    ax.set_xlabel("MJD")
    ax.set_ylabel(r"Residual ($\mu s$)")
    ax.grid(linewidth=0.3)
    ax.set_title(psr.name)
    # Skip the legend when there were no bands to label.
    if len(bands) > 0:
        ax.legend()
    fig.tight_layout()

In [None]:
# Count the non-zero entries stored under the '430' key of resids[1].
# NOTE(review): in file order `resids` is only assigned from
# `epoch_ave_resid` in the next cell; before that it is a plain residual
# array, so this cell depends on out-of-order execution — confirm.
np.count_nonzero(resids[1]['430'])

In [None]:
# Averaged residuals for `psr` from la_forge's epoch_ave_resid with dt=10
# (units of `dt` not visible here — TODO confirm). Overwrites the earlier
# `resids` array with the function's return value.
resids = epoch_ave_resid(psr, dt=10)

In [None]:
# Display the full return value of epoch_ave_resid for inspection.
resids

In [None]:
# Number of rows stored under the '430' key of resids[0].
len(resids[0]['430'])

In [None]:
# Compare the NANOGrav residuals with the averaged residuals before and after
# subtracting the per-TOA median of the la_forge timing-model, DM and ECORR
# reconstructions (the red-noise reconstruction is not subtracted).
resids = epoch_ave_resid(psr, dt=10)
median_correction = (np.median(lf_TM, axis=0)
                     + np.median(lf_DM, axis=0)
                     + np.median(lf_ecorr, axis=0))
resids_corr = epoch_ave_resid(psr, correction=median_correction, dt=10)
plot_NG_resids(psr)
plot_avg_resids(psr, resids)
plot_avg_resids(psr, resids_corr)

In [None]:
from la_forge.utils import epoch_ave_resid

In [None]:
# Recompute the uncorrected averaged residuals (same call as in the earlier
# cells) after re-importing epoch_ave_resid.
resids = epoch_ave_resid(psr, dt=10)

In [None]:
# Plot the averaged residuals per band with the helper defined above.
plot_avg_resids(psr, resids)

In [None]:
# Display the per-band arrays (first element of the epoch_ave_resid result).
resids[0]