In [1]:
import jax
import jax.numpy as jnp
import discovery as ds
import eryn
import glob

from discoverysamplers.nessai_interface import DiscoveryNessaiBridge

In [2]:
# Number of pulsars (taken from the start of the sorted file list) used to
# build the likelihood. Defined once so the slice and the progress message
# below cannot drift apart.
N_PULSARS = 3

# Collect the feather files; sorting makes the "first N" selection
# independent of the order in which the filesystem lists them.
psrfiles = sorted(glob.glob('../data/NanoGrav_15yr/*-[JB]*.feather'))
if not psrfiles:
    # Fail here with a clear message rather than building a likelihood
    # from an empty pulsar list further down.
    raise FileNotFoundError(
        "No pulsar feather files found matching '../data/NanoGrav_15yr/*-[JB]*.feather'"
    )

allpsrs = [ds.Pulsar.read_feather(psrfile) for psrfile in psrfiles]

print(f"Loaded {len(allpsrs)} pulsars from feather files.")

# Slicing never raises, so fewer than N_PULSARS pulsars may be selected
# when the directory holds fewer files.
psrs = allpsrs[:N_PULSARS]

# Report the number actually selected instead of a hard-coded count.
print(f"Building likelihood for {len(psrs)} pulsars...")

# One PulsarLikelihood per selected pulsar, combined into an ArrayLikelihood.
# Each is built from the pulsar's residuals plus four components:
#   - makenoise_measurement and makegp_ecorr, both configured from psr.noisedict
#   - makegp_timing with svd=True
#   - makegp_fourier with a ds.powerlaw spectrum, 30 components, named
#     'rednoise' (this name appears in the parameter names, e.g.
#     '<pulsar>_rednoise_gamma')
m = ds.ArrayLikelihood([ds.PulsarLikelihood([psr.residuals,
                                        ds.makenoise_measurement(psr, psr.noisedict),
                                        ds.makegp_ecorr(psr, psr.noisedict),
                                        ds.makegp_timing(psr, svd=True),
                                        ds.makegp_fourier(psr, ds.powerlaw, components=30, name='rednoise')])
                for psr in psrs])
print("...done.")

Loaded 67 pulsars from feather files.
Building likelihood for 3 pulsars...
...done.


In [3]:

# Show the parameter names of the array log-likelihood; the bare expression
# is displayed as the cell output.
m.logL.params

['B1855+09_rednoise_gamma',
 'B1855+09_rednoise_log10_A',
 'B1937+21_rednoise_gamma',
 'B1937+21_rednoise_log10_A',
 'B1953+29_rednoise_gamma',
 'B1953+29_rednoise_log10_A']

In [4]:
# Build a LaTeX label for every likelihood parameter.
# A parameter whose name contains 'log10_A' or 'gamma' gets the matching
# symbol, subscripted with the text that precedes '_rednoise_' in its name
# (the pulsar name for the parameters listed above). Any other parameter
# keeps its raw name as its label.
# Order matters: 'log10_A' is tested before 'gamma'.
label_openers = (
    ('log10_A', r"$\log_{10} A_{"),
    ('gamma', r"$\gamma_{"),
)

latex_labels = {}
for name in m.logL.params:
    opener = next((tex for key, tex in label_openers if key in name), None)
    if opener is None:
        latex_labels[name] = name
    else:
        latex_labels[name] = opener + name.split('_rednoise_')[0] + r"}$"
print(latex_labels)

{'B1855+09_rednoise_gamma': '$\\gamma_{B1855+09}$', 'B1855+09_rednoise_log10_A': '$\\log_{10} A_{B1855+09}$', 'B1937+21_rednoise_gamma': '$\\gamma_{B1937+21}$', 'B1937+21_rednoise_log10_A': '$\\log_{10} A_{B1937+21}$', 'B1953+29_rednoise_gamma': '$\\gamma_{B1953+29}$', 'B1953+29_rednoise_log10_A': '$\\log_{10} A_{B1953+29}$'}


In [5]:
# Sanity check: draw one random parameter set with ds.sample_uniform and
# evaluate the log-likelihood at that point.
# NOTE(review): no seed is passed here, so the drawn values (and the
# log-likelihood shown below) will presumably differ between runs — confirm
# whether ds.sample_uniform supports seeding.
p0 = ds.sample_uniform(m.logL.params, n=1)
print(p0)
# Bare last expression: the log-likelihood value is displayed as the cell output.
m.logL(p0)

{'B1855+09_rednoise_gamma': 0.16528082861610516, 'B1855+09_rednoise_log10_A': -16.970918107546076, 'B1937+21_rednoise_gamma': 3.703796365577012, 'B1937+21_rednoise_log10_A': -16.096158910296648, 'B1953+29_rednoise_gamma': 1.364998024247709, 'B1953+29_rednoise_log10_A': -12.931125013285767}


Array(437726.75298959, dtype=float64)

In [6]:

# Prior specification passed to the sampler bridge, keyed by parameter name.
# Each entry is either a uniform range ('min'/'max') or a fixed value.
def _uniform(lower, upper):
    """Return a fresh uniform-prior spec covering [lower, upper]."""
    return {'dist': 'uniform', 'min': lower, 'max': upper}


def _fixed(value):
    """Return a fresh fixed-value spec holding the parameter at `value`."""
    return {'dist': 'fixed', 'value': value}


test_priors = {
    'B1855+09_rednoise_gamma': _uniform(0, 7),
    'B1855+09_rednoise_log10_A': _uniform(-20, -11),
    'B1937+21_rednoise_gamma': _uniform(0, 7),
    'B1937+21_rednoise_log10_A': _fixed(-13.5),
    'B1953+29_rednoise_gamma': _fixed(3),
    'B1953+29_rednoise_log10_A': _uniform(-20, -11),
}

In [7]:
# Wrap the likelihood `m` in the nessai bridge, passing the prior
# specification and the LaTeX labels built in the earlier cells.
bridge = DiscoveryNessaiBridge(m, priors=test_priors, latex_labels=latex_labels)

In [8]:
# Check how the bridge split the parameters: the names it will sample ...
print(bridge.sampled_names)
# ... and the parameters given 'fixed' priors, with their values.
print(bridge.fixed_params)


['B1855+09_rednoise_gamma', 'B1855+09_rednoise_log10_A', 'B1937+21_rednoise_gamma', 'B1953+29_rednoise_log10_A']
{'B1937+21_rednoise_log10_A': -13.5, 'B1953+29_rednoise_gamma': 3.0}


In [9]:
# Run the sampler through the bridge and keep the returned object.
# NOTE(review): no seed or sampler settings are passed here, so results
# presumably vary from run to run — confirm whether run_sampler accepts a seed.
sampler = bridge.run_sampler()



In [10]:
# Fetch the samples for the sampled (non-fixed) parameters. The result is
# indexed by "names" (the parameter names) and "chain" (the sample array).
samples = bridge.return_sampled_samples()
print(samples["names"])
# In the recorded run the chain shape is (3825, 4): four columns for the
# four sampled names printed above.
print(samples["chain"].shape) 

['B1855+09_rednoise_gamma', 'B1855+09_rednoise_log10_A', 'B1937+21_rednoise_gamma', 'B1953+29_rednoise_log10_A']
(3825, 4)


In [11]:
# Trace plot of the run; plot_fixed=True is passed to the bridge's plotting method.
fig = bridge.plot_trace(plot_fixed=True)
# NOTE(review): the recorded output of this cell echoes the `fig.show()`
# source line, which looks like a warning raised by this call (presumably a
# non-interactive backend) — confirm, and consider dropping the call if the
# figure already renders inline.
fig.show()

  fig.show()
  fig.show()


In [12]:
# Corner plot of the run, produced by the bridge.
fig2 = bridge.plot_corner()
# NOTE(review): the recorded output of this cell echoes the `fig2.show()`
# source line, which looks like a warning raised by this call (presumably a
# non-interactive backend) — confirm, and consider dropping the call if the
# figure already renders inline.
fig2.show()

  fig2.show()
