## Building basic likelihoods

In [1]:
import sys
import os
import glob
import importlib

import tqdm

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as pp

%matplotlib inline

In [3]:
import jax
jax.config.update('jax_enable_x64', True)

import jax.random
import jax.numpy as jnp

In [4]:
import discovery as ds
import discovery.models.nanograv as ds_nanograv
import discovery.samplers.numpyro as ds_numpyro

In [5]:
import discovery.metamatrix as mm

Read nanograv pulsars

In [6]:
allpsrs = [ds.Pulsar.read_feather(psrfile) for psrfile in sorted(glob.glob('../data/*-[JB]*.feather'))]

In [7]:
psr = allpsrs[0]

## Graph tests

Simple graph, all constants, will be reduced to constant.

In [742]:
@mm.graph
def sumtwo(graph, a, b):
    a2 = 2 * a
    b3 = 3 * b
    result = a2 + b3

In [743]:
mm.print_graph(sumtwo(1, 2))

a: const = 1
b: const = 2
n0: node(a) = 2 * a
n1: node(b) = 3 * b
n2: node(n0, n1) = n0 + n1


In [744]:
mm.print_graph(mm.fold_constants(sumtwo(1, 2)))

a: const = 1
b: const = 2
n0: const = 2
n1: const = 6
n2: const = 8


In [745]:
mm.print_graph(mm.prune_graph(mm.fold_constants(sumtwo(1, 2))))

n2: const = 8


In [746]:
mm.func(sumtwo(1, 2))()

Array(8, dtype=int64, weak_type=True)

Now with one argument, will be reduced partially, and create a function of one argument.

In [747]:
mm.print_graph(sumtwo(None, 2))

a: arg
b: const = 2
n0: node(a) = 2 * a
n1: node(b) = 3 * b
n2: node(n0, n1) = n0 + n1


In [748]:
mm.print_graph(mm.fold_constants(sumtwo(None, 2)))

a: arg
b: const = 2
n0: node(a) = 2 * a
n1: const = 6
n2: node(n0, n1) = n0 + n1


In [749]:
mm.print_graph(mm.prune_graph(mm.fold_constants(sumtwo(None, 2))))

a: arg
n0: node(a) = 2 * a
n1: const = 6
n2: node(n0, n1) = n0 + n1


In [750]:
mm.func(sumtwo(None, 2))(5)

Array(16, dtype=int64, weak_type=True)

One of the inputs is a function. Everything folded.

In [751]:
def double(a, params={}):
    return 2.5 * a
double.args = ['a']

In [752]:
@mm.graph
def sumtwo(graph, a, b, f):
    result = f(a) + 3 * b

In [753]:
mm.print_graph(sumtwo(2, 3, double))

a: const = 2
b: const = 3
f: func = <function double at 0x40e20b740>
n0: node(f, a) = f(a)
n1: node(b) = 3 * b
n2: node(n0, n1) = n0 + n1


In [754]:
mm.print_graph(mm.fold_constants(sumtwo(2, 3, double)))

a: const = 2
b: const = 3
f: func = <function double at 0x40e20b740>
n0: const = 5.0
n1: const = 9
n2: const = 14.0


In [755]:
mm.print_graph(mm.prune_graph(mm.fold_constants(sumtwo(2, 3, double))))

n2: const = 14.0


One of the inputs is a function, cannot be evaluated immediately.

In [756]:
mm.print_graph(mm.prune_graph(mm.fold_constants(sumtwo(None, 3, double))))

a: arg
f: func = <function double at 0x40e20b740>
n0: node(f, a) = f(a)
n1: const = 9
n2: node(n0, n1) = n0 + n1


In [757]:
mm.func(sumtwo(None, 3, double))(4)

Array(19., dtype=float64, weak_type=True)

One of the inputs is a function that depends on params, won't be evaluated immediately.

In [758]:
def double(a, params):
    return params['m'] * a
double.args = ['a']
double.params = ['m']

In [759]:
mm.print_graph(sumtwo(2, 3, double))

a: const = 2
b: const = 3
f: func = <function double at 0x40e20bce0>
n0: node(f, a) = f(a)
n1: node(b) = 3 * b
n2: node(n0, n1) = n0 + n1


In [760]:
mm.print_graph(mm.fold_constants(sumtwo(2, 3, double)))

a: const = 2
b: const = 3
f: func = <function double at 0x40e20bce0>
n0: node(f, a) = f(a)
n1: const = 9
n2: node(n0, n1) = n0 + n1


In [761]:
mm.print_graph(mm.prune_graph(mm.fold_constants(sumtwo(2, 3, double))))

a: const = 2
f: func = <function double at 0x40e20bce0>
n0: node(f, a) = f(a)
n1: const = 9
n2: node(n0, n1) = n0 + n1


In [762]:
mm.func(sumtwo(2, 3, double)).params

['m']

In [763]:
mm.func(sumtwo(2, 3, double))(params={'m': 3})

Array(15, dtype=int64, weak_type=True)

## Discovery components

Simple noise solver

In [765]:
@mm.graph
def noisesolve(graph, y, N):
    result = N.solve(y)

All constants

In [766]:
yvec = psr.residuals
Nvec = ds.makenoise_measurement(psr, noisedict=psr.noisedict).N

In [767]:
Nsolve = noisesolve(yvec, Nvec)

In [768]:
mm.print_graph(Nsolve)

y: const = array(shape=(7758,), dtype=float64)
N: const = array(shape=(7758,), dtype=float64)
n0: node(N, y) = solve(N, y)


In [769]:
mm.print_graph(mm.prune_graph(mm.fold_constants(Nsolve)))

n0: const = array(shape=(7758,), dtype=float64), array(shape=(), dtype=float64)


First empty argument

In [770]:
Nsolve = noisesolve(None, Nvec)

In [771]:
mm.print_graph(mm.prune_graph(mm.fold_constants(Nsolve)))

y: arg
N: const = array(shape=(7758,), dtype=float64)
n0: node(N, y) = solve(N, y)


In [772]:
mm.func(Nsolve).params

[]

In [773]:
mm.func(Nsolve)(yvec)

(Array([-924684.05643357,  -15209.00619106, -256184.91074607, ...,
         457218.89618558, 3758690.70368335, 1397786.82265916],      dtype=float64),
 Array(-209930.62020633, dtype=float64))

Works also for a matrix

In [774]:
Fmat = ds.makegp_ecorr(psr, noisedict=psr.noisedict).F

In [775]:
mm.func(Nsolve)(Fmat)

(Array([[0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        ...,
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 2.05907192e+11],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 1.94299584e+12],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 5.57914432e+11]], dtype=float64),
 Array(-209930.62020633, dtype=float64))

What if N is a function?

In [776]:
Nfunc = ds.makenoise_measurement(psr).getN

In [777]:
Nfunc.params

['B1855+09_430_ASP_efac',
 'B1855+09_430_PUPPI_efac',
 'B1855+09_L-wide_ASP_efac',
 'B1855+09_L-wide_PUPPI_efac',
 'B1855+09_430_ASP_log10_t2equad',
 'B1855+09_430_PUPPI_log10_t2equad',
 'B1855+09_L-wide_ASP_log10_t2equad',
 'B1855+09_L-wide_PUPPI_log10_t2equad']

In [778]:
Nsolve = noisesolve(yvec, Nfunc)

In [779]:
mm.print_graph(mm.prune_graph(mm.fold_constants(Nsolve)))

y: const = array(shape=(7758,), dtype=float64)
N: func = <function makenoise_measurement.<locals>.getnoise at 0x3d82328e0>
n0: node(N, y) = solve(N, y)


In [780]:
mm.func(Nsolve).params

['B1855+09_430_ASP_efac',
 'B1855+09_430_ASP_log10_t2equad',
 'B1855+09_430_PUPPI_efac',
 'B1855+09_430_PUPPI_log10_t2equad',
 'B1855+09_L-wide_ASP_efac',
 'B1855+09_L-wide_ASP_log10_t2equad',
 'B1855+09_L-wide_PUPPI_efac',
 'B1855+09_L-wide_PUPPI_log10_t2equad']

In [781]:
mm.func(Nsolve)(params = ds.sample_uniform(mm.func(Nsolve).params))

(Array([ -76508.64853174,  -13787.42560781,  -38014.5492617 , ...,
         470538.49473593, 3713291.61543291, 1426449.03186325],      dtype=float64),
 Array(-207095.72133187, dtype=float64))

In [782]:
Nsolve = noisesolve(None, Nfunc)

In [783]:
mm.print_graph(mm.prune_graph(mm.fold_constants(Nsolve)))

y: arg
N: func = <function makenoise_measurement.<locals>.getnoise at 0x3d82328e0>
n0: node(N, y) = solve(N, y)


In [784]:
mm.func(Nsolve).params

['B1855+09_430_ASP_efac',
 'B1855+09_430_ASP_log10_t2equad',
 'B1855+09_430_PUPPI_efac',
 'B1855+09_430_PUPPI_log10_t2equad',
 'B1855+09_L-wide_ASP_efac',
 'B1855+09_L-wide_ASP_log10_t2equad',
 'B1855+09_L-wide_PUPPI_efac',
 'B1855+09_L-wide_PUPPI_log10_t2equad']

In [785]:
mm.func(Nsolve)(yvec, params = ds.sample_uniform(mm.func(Nsolve).params))

(Array([-1067985.5206509 ,   -16545.80953545,  -287263.80033425, ...,
          569309.84582354,  4678167.55704747,  1740315.15266447],      dtype=float64),
 Array(-210895.26040732, dtype=float64))

Prior inverse

In [786]:
@mm.graph
def noiseinv(graph, P):
    result = P.inv()

In [787]:
Pmat = ds.makegp_ecorr(psr, noisedict=psr.noisedict).Phi.N

In [788]:
Pinv = noiseinv(Pmat)

In [789]:
mm.print_graph(mm.fold_constants(Pinv))

P: const = array(shape=(360,), dtype=float64)
n0: const = array(shape=(360, 360), dtype=float64), array(shape=(), dtype=float64)


In [790]:
mm.func(Pinv).params

[]

In [791]:
mm.func(Pinv)()

(Array([[3.94679176e+13, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 3.94679176e+13, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 0.00000000e+00, 3.94679176e+13, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        ...,
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         1.92015299e+13, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 1.92015299e+13, 0.00000000e+00],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 1.92015299e+13]], dtype=float64),
 Array(-10482.98152208, dtype=float64))

In [792]:
Pfunc = ds.makegp_ecorr(psr).Phi.getN

In [793]:
Pfunc.params

['B1855+09_430_ASP_log10_ecorr',
 'B1855+09_430_PUPPI_log10_ecorr',
 'B1855+09_L-wide_ASP_log10_ecorr',
 'B1855+09_L-wide_PUPPI_log10_ecorr']

In [794]:
Pinv = noiseinv(Pfunc)

In [795]:
mm.print_graph(mm.prune_graph(mm.fold_constants(Pinv)))

P: func = <function makegp_ecorr.<locals>.getphi at 0x3e87351c0>
n0: node(P) = inv(P)


In [796]:
mm.func(Pinv).params

['B1855+09_430_ASP_log10_ecorr',
 'B1855+09_430_PUPPI_log10_ecorr',
 'B1855+09_L-wide_ASP_log10_ecorr',
 'B1855+09_L-wide_PUPPI_log10_ecorr']

In [797]:
mm.func(Pinv)(params=ds.sample_uniform(mm.func(Pinv).params))

(Array([[1.07556672e+15, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 1.07556672e+15, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 0.00000000e+00, 1.07556672e+15, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        ...,
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         4.09448888e+15, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 4.09448888e+15, 0.00000000e+00],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 4.09448888e+15]], dtype=float64),
 Array(-12113.61018479, dtype=float64))

Woodbury

In [824]:
@mm.graph
def noisesolve(graph, y, N):
    result = N.solve(y)

In [825]:
@mm.graph
def noiseinv(graph, P):
    result = P.inv()

In [826]:
@mm.graph
def woodbury(graph, y, Nsolve, F, Pinv):
    Nmy, lN = Nsolve(y)
    FtNmy = F.T @ Nmy

    NmF, _ = Nsolve(F)
    FtNmF = F.T @ NmF

    Pm, lP = Pinv
    cf, lS = (Pm + FtNmF).factor()

    logp = -0.5 * (y.T @ Nmy - FtNmy.T @ cf(FtNmy)) - 0.5 * (lN + lP + lS)

In [834]:
Nsolve = noisesolve(None, Nvec)
mm.print_graph(Nsolve)

y: arg
N: const = array(shape=(7758,), dtype=float64)
n0: node(N, y) = solve(N, y)


In [835]:
Pinv = noiseinv(Pfunc)
mm.print_graph(Pinv)

P: func = <function makegp_ecorr.<locals>.getphi at 0x3e87351c0>
n0: node(P) = inv(P)


In [836]:
onewood = woodbury(yvec, mm.func(Nsolve), Fmat, mm.func(Pinv))

In [837]:
mm.print_graph(mm.prune_graph(mm.fold_constants(onewood)))

Pinv: func = <PjitFunction of <function build_callable_from_graph.<locals>.f at 0x3c199e7a0>>
n2: const = array(shape=(), dtype=float64)
n4: const = array(shape=(360,), dtype=float64)
n9: const = array(shape=(360, 360), dtype=float64)
n10: node(Pinv) = Pinv[0]
n11: node(Pinv) = Pinv[1]
n12: node(n10, n9) = n10 + n9
n13: node(n12) = factor(n12)
n14: node(n13) = n13[0]
n15: node(n13) = n13[1]
n17: const = array(shape=(), dtype=float64)
n18: const = array(shape=(360,), dtype=float64)
n19: node(n14, n4) = n14(n4)
n20: node(n18, n19) = n18 @ n19
n21: node(n17, n20) = n17 - n20
n22: node(n21) = -0.5 * n21
n23: node(n2, n11) = n2 + n11
n24: node(n23, n15) = n23 + n15
n25: node(n24) = 0.5 * n24
n26: node(n22, n25) = n22 - n25


In [838]:
mm.func(onewood)(params = ds.sample_uniform(mm.func(onewood).params))

Array(98559.39376412, dtype=float64)

## Discovery likelihood

#### Measurement noise only

In [1426]:
# m = ds.PulsarLikelihood([psr.residuals,
#                          ds.makenoise_measurement_simple(psr)])

In [839]:
@mm.graph
def noisesolve(graph, y, N):
    result = N.solve(y)

In [840]:
@mm.graph
def noiseinv(graph, P):
    result = P.inv()

In [841]:
@mm.graph
def normal(g, y, Nsolve):
    Nmy, lN = Nsolve(y).split()
    logp = -0.5 * (y.T @ Nmy) - 0.5 * lN

In [842]:
class NoiseMatrix:
    def __init__(self, N):
        self.N = N

    @property
    def make_solve(self):
        return mm.func(noisesolve(None, self.N))    

    @property
    def make_inv(self):
        return mm.func(noiseinv(self.N))

In [843]:
ds.signals.matrix.NoiseMatrix1D_novar = NoiseMatrix
ds.signals.matrix.NoiseMatrix1D_var = NoiseMatrix

In [844]:
N = ds.makenoise_measurement(psr, noisedict=psr.noisedict)
graph = normal(psr.residuals, N.make_solve)

In [845]:
mm.print_graph(mm.prune_graph(mm.fold_constants(graph)))

n7: const = array(shape=(), dtype=float64)


In [846]:
mm.func(graph)()

Array(96686.81011374, dtype=float64)

In [847]:
N = ds.makenoise_measurement(psr)
graph = normal(psr.residuals, N.make_solve)

In [848]:
N.make_solve.params

['B1855+09_430_ASP_efac',
 'B1855+09_430_ASP_log10_t2equad',
 'B1855+09_430_PUPPI_efac',
 'B1855+09_430_PUPPI_log10_t2equad',
 'B1855+09_L-wide_ASP_efac',
 'B1855+09_L-wide_ASP_log10_t2equad',
 'B1855+09_L-wide_PUPPI_efac',
 'B1855+09_L-wide_PUPPI_log10_t2equad']

In [849]:
mm.print_graph(mm.prune_graph(mm.fold_constants(graph)))

y: const = array(shape=(7758,), dtype=float64)
Nsolve: func = <PjitFunction of <function build_callable_from_graph.<locals>.f at 0x43712e980>>
n0: node(Nsolve, y) = Nsolve(y)
n1: node(n0) = split(n0)[0]
n2: node(n0) = split(n0)[1]
n3: const = array(shape=(7758,), dtype=float64)
n4: node(n3, n1) = n3 @ n1
n5: node(n4) = -0.5 * n4
n6: node(n2) = 0.5 * n2
n7: node(n5, n6) = n5 - n6


In [850]:
mm.func(graph).params

['B1855+09_430_ASP_efac',
 'B1855+09_430_ASP_log10_t2equad',
 'B1855+09_430_PUPPI_efac',
 'B1855+09_430_PUPPI_log10_t2equad',
 'B1855+09_L-wide_ASP_efac',
 'B1855+09_L-wide_ASP_log10_t2equad',
 'B1855+09_L-wide_PUPPI_efac',
 'B1855+09_L-wide_PUPPI_log10_t2equad']

In [851]:
mm.func(graph)(params=ds.sample_uniform(mm.func(graph).params))

Array(90546.10232252, dtype=float64)

#### Add ECORR noise

In [852]:
# m = ds.PulsarLikelihood([psr.residuals,
#                          ds.makenoise_measurement(psr),
#                          ds.makegp_ecorr(psr)])

In [853]:
@mm.graph
def woodbury(g, y, Nsolve, F, Pinv):
    Nmy, lN = Nsolve(y)
    FtNmy = F.T @ Nmy

    NmF, _ = Nsolve(F)    
    FtNmF = F.T @ NmF

    Pm, lP = Pinv
    cf, lS = g.factor(Pm + FtNmF)

    logp = -0.5 * (y.T @ Nmy - FtNmy.T @ cf(FtNmy)) - 0.5 * (lN + lP + lS)

In [854]:
N = ds.makenoise_measurement(psr, noisedict=psr.noisedict)
graph = normal(psr.residuals, N.make_solve)

In [855]:
N = ds.makenoise_measurement(psr, noisedict=psr.noisedict)
egp = ds.makegp_ecorr(psr)
graph = woodbury(psr.residuals, N.make_solve, egp.F, egp.Phi.make_inv)

In [856]:
mm.print_graph(mm.prune_graph(mm.fold_constants(graph)))

Pinv: func = <PjitFunction of <function build_callable_from_graph.<locals>.f at 0x3ea6814e0>>
n2: const = array(shape=(), dtype=float64)
n4: const = array(shape=(360,), dtype=float64)
n9: const = array(shape=(360, 360), dtype=float64)
n10: node(Pinv) = Pinv[0]
n11: node(Pinv) = Pinv[1]
n12: node(n10, n9) = n10 + n9
n13: node(n12) = <function cholesky_factor at 0x393043ce0>
n14: node(n13) = n13[0]
n15: node(n13) = n13[1]
n17: const = array(shape=(), dtype=float64)
n18: const = array(shape=(360,), dtype=float64)
n19: node(n14, n4) = n14(n4)
n20: node(n18, n19) = n18 @ n19
n21: node(n17, n20) = n17 - n20
n22: node(n21) = -0.5 * n21
n23: node(n2, n11) = n2 + n11
n24: node(n23, n15) = n23 + n15
n25: node(n24) = 0.5 * n24
n26: node(n22, n25) = n22 - n25


In [857]:
mm.func(graph).params

['B1855+09_430_ASP_log10_ecorr',
 'B1855+09_430_PUPPI_log10_ecorr',
 'B1855+09_L-wide_ASP_log10_ecorr',
 'B1855+09_L-wide_PUPPI_log10_ecorr']

In [858]:
mm.func(graph)(params=ds.sample_uniform(mm.func(graph).params))

Array(96796.92856542, dtype=float64)

#### Constant GP + variable GP

In [859]:
# m = ds.PulsarLikelihood([psr.residuals,
#                          ds.makenoise_measurement(psr, psr.noisedict),
#                          ds.makegp_timing(psr, svd=True),
#                          ds.makegp_fourier(psr, ds.powerlaw, components=30, name='rednoise')])

In [860]:
importlib.reload(mm)

<module 'discovery.metamatrix' from '/Users/vallis/Documents/discovery/src/discovery/metamatrix.py'>

In [861]:
@mm.graph
def woodburysolve(g, y, Nsolve, F, Pinv):
    Nmy, lN = Nsolve(y)
    FtNmy = F.T @ Nmy

    NmF, _ = Nsolve(F)
    FtNmF = F.T @ NmF

    Pm, lP = Pinv
    cf, lS = g.factor(Pm + FtNmF)

    solve = Nmy - NmF @ cf(FtNmy)
    ld = lN + lP + lS

    result = g.pair(solve, ld)

In [862]:
N = ds.makenoise_measurement(psr, noisedict=psr.noisedict)
tgp = ds.makegp_timing(psr, svd=True)
rgp = ds.makegp_fourier(psr, ds.powerlaw, components=30, name='rednoise')

In [863]:
graph1 = woodburysolve(None, N.make_solve, tgp.F, tgp.Phi.make_inv)

In [864]:
mm.print_graph(mm.prune_graph(mm.fold_constants(graph1)))

y: arg
Nsolve: func = <PjitFunction of <function build_callable_from_graph.<locals>.f at 0x4370cd620>>
n0: node(Nsolve, y) = Nsolve(y)
n1: node(n0) = n0[0]
n2: node(n0) = n0[1]
n3: const = array(shape=(166, 7758), dtype=float64)
n4: node(n3, n1) = n3 @ n1
n6: const = array(shape=(7758, 166), dtype=float64)
n11: const = array(shape=(), dtype=float64)
n14: const = <function cholesky_factor.<locals>.solver at 0x447864d60>
n15: const = array(shape=(), dtype=float64)
n16: node(n14, n4) = n14(n4)
n17: node(n6, n16) = n6 @ n16
n18: node(n1, n17) = n1 - n17
n19: node(n2, n11) = n2 + n11
n20: node(n19, n15) = n19 + n15
n21: node(n18, n20) = <function GraphBuilder.pair.<locals>.<lambda> at 0x43342c900>


In [865]:
graph2 = woodbury(yvec, mm.func(graph1), rgp.F, rgp.Phi.make_inv)

In [866]:
mm.print_graph(mm.prune_graph(mm.fold_constants(graph2)))

Pinv: func = <PjitFunction of <function build_callable_from_graph.<locals>.f at 0x395d5d080>>
n2: const = array(shape=(), dtype=float64)
n4: const = array(shape=(60,), dtype=float64)
n9: const = array(shape=(60, 60), dtype=float64)
n10: node(Pinv) = Pinv[0]
n11: node(Pinv) = Pinv[1]
n12: node(n10, n9) = n10 + n9
n13: node(n12) = <function cholesky_factor at 0x38e0940e0>
n14: node(n13) = n13[0]
n15: node(n13) = n13[1]
n17: const = array(shape=(), dtype=float64)
n18: const = array(shape=(60,), dtype=float64)
n19: node(n14, n4) = n14(n4)
n20: node(n18, n19) = n18 @ n19
n21: node(n17, n20) = n17 - n20
n22: node(n21) = -0.5 * n21
n23: node(n2, n11) = n2 + n11
n24: node(n23, n15) = n23 + n15
n25: node(n24) = 0.5 * n24
n26: node(n22, n25) = n22 - n25


In [867]:
mm.func(graph2).params

['B1855+09_rednoise_gamma', 'B1855+09_rednoise_log10_A']

In [868]:
mm.func(graph2)(params=ds.sample_uniform(mm.func(graph2).params))

Array(89903.63134199, dtype=float64)

### Full single likelihood

In [869]:
# m = ds.PulsarLikelihood([psr.residuals,
#                          ds.makenoise_measurement(psr, psr.noisedict),
#                          ds.makegp_ecorr(psr, psr.noisedict),
#                          ds.makegp_timing(psr, svd=True),
#                          ds.makegp_fourier(psr, ds.powerlaw, components=30, name='rednoise')])

In [871]:
@mm.graph
def concat(g, a, b):
    result = g.node(lambda x, y: jnp.hstack([x, y]), [a, b])

In [872]:
import functools

class CompoundGP:
    def __init__(self, gplist):
        self.gplist = gplist

    @property
    def F(self):
        return functools.reduce(lambda x, y: mm.func(concat(x, y)), [gp.F for gp in self.gplist])
    
    @property
    def Phi(self):
        # won't work for 2D priors
        N = functools.reduce(lambda x, y: mm.func(concat(x, y)), [gp.Phi.N for gp in self.gplist]) 
        return NoiseMatrix(N)

In [873]:
@mm.graph
def delay(g, y, d):
    result = y - d

In [874]:
def CompoundDelay(residuals, delays):
    return functools.reduce(lambda x, y: mm.func(delay(x, y)), [residuals, *delays])

#### Clean example

In [875]:
N = ds.makenoise_measurement(psr, noisedict=psr.noisedict)

egp = ds.makegp_ecorr(psr, noisedict=psr.noisedict)
tgp = ds.makegp_timing(psr, svd=True)
cgp = CompoundGP([egp, tgp])

graph1 = woodburysolve(None, N.make_solve, cgp.F, cgp.Phi.make_inv)

cwcommon = ['cw_sindec', 'cw_cosinc', 'cw_log10_f0', 'cw_log10_h0', 'cw_phi_earth', 'cw_psi', 'cw_ra']
yd = CompoundDelay(yvec, [ds.makedelay(psr, ds.makedelay_binary(pulsarterm=True), common=cwcommon, name='cw')])

rgp = ds.makegp_fourier(psr, ds.powerlaw, components=30, name='rednoise')

graph2 = woodbury(yd, mm.func(graph1), rgp.F, rgp.Phi.make_inv)

logp = mm.func(graph2)

In [876]:
mm.print_graph(graph2, simplify=True)

y: func = <PjitFunction of <function build_callable_from_graph.<locals>.f at 0x4370cc400>>
Nsolve: func = <PjitFunction of <function build_callable_from_graph.<locals>.f at 0x447ee3ba0>>
Pinv: func = <PjitFunction of <function build_callable_from_graph.<locals>.f at 0x447ee3920>>
n0: node(Nsolve, y) = Nsolve(y)
n1: node(n0) = n0[0]
n2: node(n0) = n0[1]
n3: const = array(shape=(60, 7758), dtype=float64)
n4: node(n3, n1) = n3 @ n1
n9: const = array(shape=(60, 60), dtype=float64)
n10: node(Pinv) = Pinv[0]
n11: node(Pinv) = Pinv[1]
n12: node(n10, n9) = n10 + n9
n13: node(n12) = <function cholesky_factor at 0x383a177e0>
n14: node(n13) = n13[0]
n15: node(n13) = n13[1]
n16: node(y) = <function Sym.T.<locals>.<lambda> at 0x447ef8e00>
n17: node(n16, n1) = n16 @ n1
n18: node(n4) = <function Sym.T.<locals>.<lambda> at 0x447ef8f40>
n19: node(n14, n4) = n14(n4)
n20: node(n18, n19) = n18 @ n19
n21: node(n17, n20) = n17 - n20
n22: node(n21) = -0.5 * n21
n23: node(n2, n11) = n2 + n11
n24: node(n23, n1

In [877]:
p0 = ds.sample_uniform(logp.params)

In [878]:
p0

{'B1855+09_cw_phi_psr': 1.411928033421281,
 'B1855+09_rednoise_gamma': 0.3375543574402363,
 'B1855+09_rednoise_log10_A': -15.03119920193601,
 'cw_cosinc': 0.5033299616673554,
 'cw_log10_f0': -8.499182894688783,
 'cw_log10_h0': -14.928982381721593,
 'cw_phi_earth': 6.023302048435835,
 'cw_psi': 0.14147652042076786,
 'cw_ra': 0.9348743543830029,
 'cw_sindec': -0.4938450071370477}

In [879]:
logp(params=ds.sample_uniform(p0))

Array(76806.82934148, dtype=float64)

## Standard tests

In [1057]:
import discovery.metamatrix as mm
import discovery.metamath as mh

(For debugging)

In [1058]:
importlib.reload(mm)
importlib.reload(mh)
importlib.reload(ds.likelihood)
importlib.reload(ds)

<module 'discovery' from '/Users/vallis/Documents/discovery/src/discovery/__init__.py'>

Monkey patching

In [1127]:
ds.matrix.NoiseMatrix1D_novar = mh.NoiseMatrix
ds.matrix.NoiseMatrix1D_var = mh.NoiseMatrix
ds.matrix.NoiseMatrix2D_var = mh.NoiseMatrix

ds.matrix.WoodburyKernel = mh.WoodburyKernel
ds.matrix.CompoundGP = mh.CompoundGP
ds.matrix.CompoundDelay = CompoundDelay

ds.matrix.VectorNoiseMatrix1D_var = mh.NoiseMatrix
ds.matrix.VectorWoodburyKernel_varP = mh.VectorWoodburyKernel
ds.matrix.VectorCompoundGP = mh.CompoundGP

### Single pulsar likelihood

#### Measurement noise only, no backends

In [1142]:
m = ds.PulsarLikelihood([psr.residuals,
                         ds.makenoise_measurement_simple(psr)])

What are the active parameters?

In [1143]:
m.logL.params

['B1855+09_efac', 'B1855+09_log10_t2equad']

Sample random values from their priors

In [1144]:
p0 = ds.sample_uniform(m.logL.params); p0

{'B1855+09_efac': 0.9377216319852704,
 'B1855+09_log10_t2equad': -6.307874756634455}

Evaluate the likelihood

In [1147]:
m.logL(p0)

Array(88836.8723052, dtype=float64)

Try compiled version, grad

In [1064]:
jax.jit(m.logL)(p0)

Array(91977.44767544, dtype=float64)

In [1065]:
jax.grad(m.logL)(p0)

{'B1855+09_efac': Array(17719.96996108, dtype=float64, weak_type=True),
 'B1855+09_log10_t2equad': Array(73.97882498, dtype=float64, weak_type=True)}

#### Measurement noise only, nanograv backends, free parameters

In [1148]:
m = ds.PulsarLikelihood([psr.residuals,
                         ds.makenoise_measurement(psr)])

In [1149]:
m.logL.params

['B1855+09_430_ASP_efac',
 'B1855+09_430_ASP_log10_t2equad',
 'B1855+09_430_PUPPI_efac',
 'B1855+09_430_PUPPI_log10_t2equad',
 'B1855+09_L-wide_ASP_efac',
 'B1855+09_L-wide_ASP_log10_t2equad',
 'B1855+09_L-wide_PUPPI_efac',
 'B1855+09_L-wide_PUPPI_log10_t2equad']

In [None]:
p0 = ds.sample_uniform(m.logL.params)

In [1151]:
m.logL(p0)

Array(98829.22637937, dtype=float64)

#### Measurement noise only, nanograv backends, parameters from noisedict

In [1069]:
psr.noisedict

{'B1855+09_430_ASP_efac': 1.115935306813982,
 'B1855+09_430_ASP_log10_t2equad': -7.564164330699591,
 'B1855+09_430_PUPPI_efac': 1.000049037085653,
 'B1855+09_430_PUPPI_log10_t2equad': -6.572540211467256,
 'B1855+09_L-wide_ASP_efac': 1.043114017270374,
 'B1855+09_L-wide_ASP_log10_t2equad': -6.517929916655293,
 'B1855+09_L-wide_PUPPI_efac': 1.1118432332882,
 'B1855+09_L-wide_PUPPI_log10_t2equad': -7.755603780476984,
 'B1855+09_430_ASP_log10_ecorr': -6.798122106550257,
 'B1855+09_430_PUPPI_log10_ecorr': -5.6989064141929715,
 'B1855+09_L-wide_ASP_log10_ecorr': -6.120457109433745,
 'B1855+09_L-wide_PUPPI_log10_ecorr': -6.641667916624413,
 'B1855+09_red_noise_log10_A': -13.940953818371378,
 'B1855+09_red_noise_gamma': -3.68432133461766}

In [1070]:
m = ds.PulsarLikelihood([psr.residuals,
                         ds.makenoise_measurement(psr, psr.noisedict)])

In [1071]:
m.logL.params

[]

In [1074]:
m.logL({})

Array(96686.81011374, dtype=float64)

In [1075]:
jax.jit(m.logL)({}), jax.jit(jax.grad(m.logL))({})

(Array(96686.81011374, dtype=float64), {})

#### Add ECORR noise (GP), free params

In [1076]:
m = ds.PulsarLikelihood([psr.residuals,
                         ds.makenoise_measurement(psr),
                         ds.makegp_ecorr(psr)])

In [1077]:
m.logL.params

['B1855+09_430_ASP_efac',
 'B1855+09_430_ASP_log10_ecorr',
 'B1855+09_430_ASP_log10_t2equad',
 'B1855+09_430_PUPPI_efac',
 'B1855+09_430_PUPPI_log10_ecorr',
 'B1855+09_430_PUPPI_log10_t2equad',
 'B1855+09_L-wide_ASP_efac',
 'B1855+09_L-wide_ASP_log10_ecorr',
 'B1855+09_L-wide_ASP_log10_t2equad',
 'B1855+09_L-wide_PUPPI_efac',
 'B1855+09_L-wide_PUPPI_log10_ecorr',
 'B1855+09_L-wide_PUPPI_log10_t2equad']

In [1078]:
p0 = ds.sample_uniform(m.logL.params)

In [1079]:
m.logL(p0)

Array(97773.03750861, dtype=float64)

In [1080]:
jax.jit(m.logL)(p0), jax.jit(jax.grad(m.logL))(p0)

(Array(97773.03750861, dtype=float64),
 {'B1855+09_430_ASP_efac': Array(-109.80551039, dtype=float64, weak_type=True),
  'B1855+09_430_ASP_log10_ecorr': Array(-123.38423455, dtype=float64, weak_type=True),
  'B1855+09_430_ASP_log10_t2equad': Array(-169.77146571, dtype=float64, weak_type=True),
  'B1855+09_430_PUPPI_efac': Array(482.14891213, dtype=float64, weak_type=True),
  'B1855+09_430_PUPPI_log10_ecorr': Array(133.30281563, dtype=float64, weak_type=True),
  'B1855+09_430_PUPPI_log10_t2equad': Array(0.13309954, dtype=float64, weak_type=True),
  'B1855+09_L-wide_ASP_efac': Array(895.55098936, dtype=float64, weak_type=True),
  'B1855+09_L-wide_ASP_log10_ecorr': Array(38.23896488, dtype=float64, weak_type=True),
  'B1855+09_L-wide_ASP_log10_t2equad': Array(1252.13799406, dtype=float64, weak_type=True),
  'B1855+09_L-wide_PUPPI_efac': Array(4571.57930731, dtype=float64, weak_type=True),
  'B1855+09_L-wide_PUPPI_log10_ecorr': Array(157.61343895, dtype=float64, weak_type=True),
  'B1855+0

#### Add ECORR noise (GP), noisedict params

In [1081]:
m = ds.PulsarLikelihood([psr.residuals,
                         ds.makenoise_measurement(psr, psr.noisedict),
                         ds.makegp_ecorr(psr, psr.noisedict)])

In [1082]:
m.logL.params

[]

In [1084]:
m.logL({})

Array(100267.11920383, dtype=float64)

In [1085]:
jax.jit(m.logL)({}), jax.jit(jax.grad(m.logL))({})

(Array(100267.11920383, dtype=float64), {})

#### Add timing model

In [1086]:
m = ds.PulsarLikelihood([psr.residuals,
                         ds.makenoise_measurement(psr, psr.noisedict),
                         ds.makegp_ecorr(psr, psr.noisedict),
                         ds.makegp_timing(psr, svd=True)])

In [1087]:
m.logL.params

[]

In [1088]:
m.logL({})

Array(90998.49655392, dtype=float64)

In [1089]:
jax.jit(m.logL)({}), jax.jit(jax.grad(m.logL))({})

(Array(90998.49655392, dtype=float64), {})

#### Add red noise (powerlaw)

In [1158]:
m = ds.PulsarLikelihood([psr.residuals,
                         ds.makenoise_measurement(psr, psr.noisedict),
                         ds.makegp_ecorr(psr, psr.noisedict),
                         ds.makegp_timing(psr, svd=True),
                         ds.makegp_fourier(psr, ds.powerlaw, components=30, name='rednoise')])

In [1159]:
m.logL.params

['B1855+09_rednoise_gamma', 'B1855+09_rednoise_log10_A']

In [1154]:
p0 = ds.sample_uniform(m.logL.params)

In [1161]:
m.logL(p0)

Array(91001.08284231, dtype=float64)

In [1094]:
jax.jit(m.logL)(p0), jax.jit(jax.grad(m.logL))(p0)

(Array(90998.7269468, dtype=float64),
 {'B1855+09_rednoise_gamma': Array(0.50503406, dtype=float64, weak_type=True),
  'B1855+09_rednoise_log10_A': Array(1.05634722, dtype=float64, weak_type=True)})

#### Add red noise (powerlaw, fixed gamma)

In [1095]:
m = ds.PulsarLikelihood([psr.residuals,
                         ds.makenoise_measurement(psr, psr.noisedict),
                         ds.makegp_ecorr(psr, psr.noisedict),
                         ds.makegp_timing(psr, svd=True),
                         ds.makegp_fourier(psr, ds.partial(ds.powerlaw, gamma=4.33), components=30, name='rednoise')])

In [1096]:
m.logL.params

['B1855+09_rednoise_log10_A']

In [1097]:
p0 = ds.sample_uniform(m.logL.params)

In [1098]:
m.logL(p0)

Array(90998.49655564, dtype=float64)

In [1099]:
jax.jit(m.logL)(p0), jax.jit(jax.grad(m.logL))(p0)

(Array(90998.49655564, dtype=float64),
 {'B1855+09_rednoise_log10_A': Array(7.8946549e-06, dtype=float64, weak_type=True)})

#### Add red noise (free spectrum)

In [1162]:
m = ds.PulsarLikelihood([psr.residuals,
                         ds.makenoise_measurement(psr, psr.noisedict),
                         ds.makegp_ecorr(psr, psr.noisedict),
                         ds.makegp_timing(psr, svd=True),
                         ds.makegp_fourier(psr, ds.freespectrum, components=30, name='rednoise')])

In [1163]:
m.logL.params

['B1855+09_rednoise_log10_rho(30)']

In [1102]:
p0 = {'B1855+09_rednoise_log10_rho(30)': 1e-6 * np.random.randn(30)}

In [1166]:
m.logL(p0)

Array(90205.3381496, dtype=float64)

TO DO - Can this be correct?

In [1104]:
jax.jit(m.logL)(p0), jax.jit(jax.grad(m.logL))(p0)

(Array(90205.41583763, dtype=float64),
 {'B1855+09_rednoise_log10_rho(30)': Array([-4.60516845, -4.60517007, -4.60517016, -4.60517018, -4.60517018,
         -4.60517018, -4.60517018, -4.60517018, -4.60517018, -4.60517018,
         -4.60517018, -4.60517018, -4.60517019, -4.60517018, -4.60517014,
         -4.60517013, -4.60517019, -4.60517019, -4.60517019, -4.60517019,
         -4.60517019, -4.60517019, -4.60517019, -4.60517019, -4.60517019,
         -4.60517019, -4.60517019, -4.60517019, -4.60517019, -4.60517019],      dtype=float64)})

### Multiple pulsars

In [1167]:
psrs = allpsrs[:3]

#### Combined likelihood

In [1168]:
m = ds.ArrayLikelihood([ds.PulsarLikelihood([psr.residuals,
                                             ds.makenoise_measurement(psr, psr.noisedict),
                                             ds.makegp_ecorr(psr, psr.noisedict),
                                             ds.makegp_timing(psr, svd=True),
                                             ds.makegp_fourier(psr, ds.powerlaw, components=30, name='rednoise')])
                        for psr in psrs])

In [1169]:
m.logL.params

['B1855+09_rednoise_gamma',
 'B1855+09_rednoise_log10_A',
 'B1937+21_rednoise_gamma',
 'B1937+21_rednoise_log10_A',
 'B1953+29_rednoise_gamma',
 'B1953+29_rednoise_log10_A']

In [1170]:
p0 = ds.sample_uniform(m.logL.params)

In [1174]:
m.logL(p0)

Array(428610.75808808, dtype=float64)

In [1111]:
jax.jit(m.logL)(p0), jax.jit(jax.grad(m.logL))(p0)

(Array(473556.98158261, dtype=float64),
 {'B1855+09_rednoise_gamma': Array(5.71375813e-06, dtype=float64, weak_type=True),
  'B1855+09_rednoise_log10_A': Array(1.01289668e-05, dtype=float64, weak_type=True),
  'B1937+21_rednoise_gamma': Array(-8.31327206, dtype=float64, weak_type=True),
  'B1937+21_rednoise_log10_A': Array(-127.07457963, dtype=float64, weak_type=True),
  'B1953+29_rednoise_gamma': Array(1.71711591e-08, dtype=float64, weak_type=True),
  'B1953+29_rednoise_log10_A': Array(5.22661473e-08, dtype=float64, weak_type=True)})

#### Add common noise

Indicating parameters under common shares them among pulsars

In [1175]:
T = ds.getspan(psrs)

In [1176]:
m = ds.ArrayLikelihood([ds.PulsarLikelihood([psr.residuals,
                                             ds.makenoise_measurement(psr, psr.noisedict),
                                             ds.makegp_ecorr(psr, psr.noisedict),
                                             ds.makegp_timing(psr, svd=True),
                                             ds.makegp_fourier(psr, ds.powerlaw, components=30, T=T, name='rednoise'),
                                             ds.makegp_fourier(psr, ds.powerlaw, components=14, T=T, name='crn',
                                                               common=['crn_log10_A', 'crn_gamma'])])
                        for psr in psrs])

In [1177]:
m.logL.params

['B1855+09_rednoise_gamma',
 'B1855+09_rednoise_log10_A',
 'B1937+21_rednoise_gamma',
 'B1937+21_rednoise_log10_A',
 'B1953+29_rednoise_gamma',
 'B1953+29_rednoise_log10_A',
 'crn_gamma',
 'crn_log10_A']

In [1180]:
p0 = ds.sample_uniform(m.logL.params); p0

{'B1855+09_rednoise_gamma': 0.22692837299032842,
 'B1855+09_rednoise_log10_A': -14.855929802663269,
 'B1937+21_rednoise_gamma': 1.0599042474713591,
 'B1937+21_rednoise_log10_A': -15.116911973999551,
 'B1953+29_rednoise_gamma': 3.5757301091084486,
 'B1953+29_rednoise_log10_A': -14.37084627399558,
 'crn_gamma': 4.432798514507736,
 'crn_log10_A': -13.879467862230983}

In [1181]:
m.logL(p0)

Array(473878.04941868, dtype=float64)

In [1117]:
jax.jit(m.logL)(p0), jax.jit(jax.grad(m.logL))(p0)

(Array(435310.1023092, dtype=float64),
 {'B1855+09_rednoise_gamma': Array(27.69874745, dtype=float64, weak_type=True),
  'B1855+09_rednoise_log10_A': Array(59.75583282, dtype=float64, weak_type=True),
  'B1937+21_rednoise_gamma': Array(13726.66873139, dtype=float64, weak_type=True),
  'B1937+21_rednoise_log10_A': Array(25107.24685034, dtype=float64, weak_type=True),
  'B1953+29_rednoise_gamma': Array(5.46358814e-13, dtype=float64, weak_type=True),
  'B1953+29_rednoise_log10_A': Array(3.11287027e-12, dtype=float64, weak_type=True),
  'crn_gamma': Array(54.44095989, dtype=float64, weak_type=True),
  'crn_log10_A': Array(126.71157003, dtype=float64, weak_type=True)})

#### Parallelize red components

Coordinated timespan is required

In [1186]:
m = ds.ArrayLikelihood([ds.PulsarLikelihood([psr.residuals,
                                             ds.makenoise_measurement(psr, psr.noisedict),
                                             ds.makegp_ecorr(psr, psr.noisedict),
                                             ds.makegp_timing(psr, svd=True)]) for psr in psrs],
                       commongp = [ds.makecommongp_fourier(psrs, ds.powerlaw, components=30, T=T, name='rednoise'),
                                   ds.makecommongp_fourier(psrs, ds.powerlaw, components=14, T=T, name='crn',
                                                           common=['crn_log10_A', 'crn_gamma'])])

In [1187]:
m.logL.params

['B1855+09_rednoise_gamma',
 'B1855+09_rednoise_log10_A',
 'B1937+21_rednoise_gamma',
 'B1937+21_rednoise_log10_A',
 'B1953+29_rednoise_gamma',
 'B1953+29_rednoise_log10_A',
 'crn_gamma',
 'crn_log10_A']

In [1120]:
p0 = ds.sample_uniform(m.logL.params)

In [1189]:
m.logL(p0)

Array(473520.64257213, dtype=float64)

In [1122]:
jax.jit(m.logL)(p0), jax.jit(jax.grad(m.logL))(p0)

(Array(471632.66710763, dtype=float64),
 {'B1855+09_rednoise_gamma': Array(-9.94695056, dtype=float64, weak_type=True),
  'B1855+09_rednoise_log10_A': Array(-57.71674371, dtype=float64, weak_type=True),
  'B1937+21_rednoise_gamma': Array(0.00114001, dtype=float64, weak_type=True),
  'B1937+21_rednoise_log10_A': Array(0.00197886, dtype=float64, weak_type=True),
  'B1953+29_rednoise_gamma': Array(0.00083211, dtype=float64, weak_type=True),
  'B1953+29_rednoise_log10_A': Array(0.00173446, dtype=float64, weak_type=True),
  'crn_gamma': Array(4299.35184227, dtype=float64, weak_type=True),
  'crn_log10_A': Array(8259.82410556, dtype=float64, weak_type=True)})

#### Reuse Fourier vectors

`ds.makepowerlaw_crn` yields the sum of two powerlaws, with possibly different number of components.

In [1190]:
m = ds.ArrayLikelihood([ds.PulsarLikelihood([psr.residuals,
                                             ds.makenoise_measurement(psr, psr.noisedict),
                                             ds.makegp_ecorr(psr, psr.noisedict),
                                             ds.makegp_timing(psr, svd=True)]) for psr in psrs],
                       commongp = ds.makecommongp_fourier(psrs, ds.makepowerlaw_crn(components=14), components=30, T=T, name='rednoise',
                                                          common=['crn_log10_A', 'crn_gamma']))

In [1129]:
p0 = ds.sample_uniform(m.logL.params)

In [1191]:
m.logL(p0)

Array(473520.64257212, dtype=float64)

In [1131]:
jax.jit(m.logL)(p0), jax.jit(jax.grad(m.logL))(p0)

(Array(473786.13322098, dtype=float64),
 {'B1855+09_rednoise_gamma': Array(-8.55084333, dtype=float64, weak_type=True),
  'B1855+09_rednoise_log10_A': Array(-27.20138098, dtype=float64, weak_type=True),
  'B1937+21_rednoise_gamma': Array(8.32201633e-08, dtype=float64, weak_type=True),
  'B1937+21_rednoise_log10_A': Array(1.39321327e-07, dtype=float64, weak_type=True),
  'B1953+29_rednoise_gamma': Array(-0.03332742, dtype=float64, weak_type=True),
  'B1953+29_rednoise_log10_A': Array(-0.07797149, dtype=float64, weak_type=True),
  'crn_gamma': Array(68.46449385, dtype=float64, weak_type=True),
  'crn_log10_A': Array(44.07745783, dtype=float64, weak_type=True)})

#### Add global spatially correlated process

Note `ds.makeglobalgp_fourier` requires the ORF, but not the `common` specification, which is automatic.

In [1198]:
m = ds.ArrayLikelihood([ds.PulsarLikelihood([psr.residuals,
                                             ds.makenoise_measurement(psr, psr.noisedict),
                                             ds.makegp_ecorr(psr, psr.noisedict),
                                             ds.makegp_timing(psr, svd=True)]) for psr in psrs],
                       commongp = ds.makecommongp_fourier(psrs, ds.powerlaw, components=30, T=T, name='rednoise'),
                       globalgp = ds.makeglobalgp_fourier(psrs, ds.powerlaw, ds.hd_orf, components=14, T=T, name='gw'))

In [1199]:
p0 = ds.sample_uniform(m.logL.params)

In [1194]:
m.logL(p0)

Array(473547.1319047, dtype=float64)

In [1135]:
jax.jit(m.logL)(p0), jax.jit(jax.grad(m.logL))(p0)

(Array(473775.23128436, dtype=float64),
 {'B1855+09_rednoise_gamma': Array(2.34714251, dtype=float64, weak_type=True),
  'B1855+09_rednoise_log10_A': Array(3.8833717, dtype=float64, weak_type=True),
  'B1937+21_rednoise_gamma': Array(1.72711408, dtype=float64, weak_type=True),
  'B1937+21_rednoise_log10_A': Array(3.28187406, dtype=float64, weak_type=True),
  'B1953+29_rednoise_gamma': Array(-1.28797504, dtype=float64, weak_type=True),
  'B1953+29_rednoise_log10_A': Array(-2.50672347, dtype=float64, weak_type=True),
  'gw_gamma': Array(36.28119423, dtype=float64, weak_type=True),
  'gw_log10_A': Array(-33.7434324, dtype=float64, weak_type=True)})

#### Another way of doing this (useful if variable GPs differ among pulsars)

In [1195]:
m = ds.GlobalLikelihood([ds.PulsarLikelihood([psr.residuals,
                                             ds.makenoise_measurement(psr, psr.noisedict),
                                             ds.makegp_ecorr(psr, psr.noisedict),
                                             ds.makegp_timing(psr, svd=True),
                                             ds.makegp_fourier(psr, ds.powerlaw, components=30, name='rednoise')]) for psr in psrs],
                        globalgp = ds.makeglobalgp_fourier(psrs, ds.powerlaw, ds.hd_orf, components=14, T=T, name='gw'))

In [1196]:
m.logL.params

['B1855+09_rednoise_gamma',
 'B1855+09_rednoise_log10_A',
 'B1937+21_rednoise_gamma',
 'B1937+21_rednoise_log10_A',
 'B1953+29_rednoise_gamma',
 'B1953+29_rednoise_log10_A',
 'gw_gamma',
 'gw_log10_A']

In [1139]:
p0 = ds.sample_uniform(m.logL.params)

In [1197]:
m.logL(p0)

Array(473547.35835491, dtype=float64)

In [1141]:
jax.jit(m.logL)(p0), jax.jit(jax.grad(m.logL))(p0)

(Array(473586.95745345, dtype=float64),
 {'B1855+09_rednoise_gamma': Array(-12.48407864, dtype=float64, weak_type=True),
  'B1855+09_rednoise_log10_A': Array(-59.6615166, dtype=float64, weak_type=True),
  'B1937+21_rednoise_gamma': Array(-8.29678319, dtype=float64, weak_type=True),
  'B1937+21_rednoise_log10_A': Array(-127.74213009, dtype=float64, weak_type=True),
  'B1953+29_rednoise_gamma': Array(0.0042981, dtype=float64, weak_type=True),
  'B1953+29_rednoise_log10_A': Array(0.0238647, dtype=float64, weak_type=True),
  'gw_gamma': Array(2.91228473, dtype=float64, weak_type=True),
  'gw_log10_A': Array(6.95787697, dtype=float64, weak_type=True)})