Alexander W. Criswell 8/5/25

The idea here is to build the most bare-bones implementation of the formalism we developed at the Sprint, ignoring all realistic aspects, so that we can get the paper out without spinning our wheels on the finer details of the Global Fit implementation. Here I take the bare-bones model and do some basic optimization/acceleration before testing it in Eryn.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
# import jax.numpy as jnp
import jax.numpy as xp
# import jax.scipy.stats as xst
# import cupy.random as cst
import scipy.stats as xst
import jax; jax.config.update("jax_enable_x64", True)
from corner import corner
import legwork as lw
import astropy.units as u
from tqdm import tqdm
from math import factorial
plt.style.use('default')

In [2]:
## list the devices JAX can run on (the recorded output shows a single CUDA device)
jax.devices()

[CudaDevice(id=0)]

In [3]:
## this is a very silly class to deal with how terrible the structure of jax.random is
## we are essentially creating an object which spoofs the original scipy structure
## and forces splitting the jax key so we don't just get the same draw repeatedly
class jax_stat_dist():

    def __init__(self,jax_module,init_key,**stat_kwargs):
        '''
        Class to port scipy's object-oriented functionality for statistical distributions for use with JAX's random modules.

        Arguments
        ---------------
        jax_module (jax.random.xxx)   : A jax.random sampling function (e.g. jax.random.normal, jax.random.gamma)
                                        with signature f(key, ..., shape=...).
        init_key (int or jaxprng key) : Initial random key for the distribution. Ints are used as a seed.
        stat_kwargs                   : Keyword arguments for the statistical distribution.
                                        'loc' and 'scale' follow the scipy convention and are applied as
                                        loc + scale * draw; all other keywords (a, b, lam, ...) are passed
                                        straight to jax_module.
        '''

        ## set base dist
        self.dist = jax_module

        ## set key
        if isinstance(init_key, (int, np.integer)):
            self.key = jax.random.key(int(init_key))
        else:
            self.key = init_key

        ## set kwargs
        self.stat_kwargs = stat_kwargs

        return

    def draw_keys(self,N=2):
        '''
        Split the stored key into N keys, keep the first as the new stored key,
        and return the remaining N-1 keys (so consecutive calls never reuse a key).
        '''
        keys = jax.random.split(self.key,N)
        self.key = keys[0]
        return keys[1:]

    @staticmethod
    def _as_shape(size):
        '''Convert a scipy-style size (None, int, or sequence of ints) to a jax shape tuple.'''
        if size is None:
            return ()
        try:
            return tuple(int(s) for s in size)
        except TypeError:
            return (int(size),)

    def rvs(self,size=1,random_state=None):
        '''
        Draw samples, mirroring scipy's frozen-distribution rvs.

        Arguments
        ---------------
        size (None, int or tuple) : Output shape (None gives a scalar draw).
        random_state              : Accepted only for scipy API compatibility and ignored;
                                    randomness comes from the internally split JAX key.

        Returns
        ---------------
        jax array of the requested shape.
        '''
        shape = self._as_shape(size)
        ## one fresh subkey per call; the stored key advances so repeated calls differ
        subkey = self.draw_keys(N=2)[0]
        dist_kwargs = dict(self.stat_kwargs)
        loc = dist_kwargs.pop('loc', 0.0)
        scale = dist_kwargs.pop('scale', 1.0)
        draws = self.dist(subkey, shape=shape, **dist_kwargs)
        ## scipy's loc/scale convention: X = loc + scale * X_standard
        return loc + scale * draws

In [4]:
## smoke-test instance: jax.random.normal wrapped with loc=2, scale=4 and seed 42
testwrapper = jax_stat_dist(jax_module=jax.random.normal, init_key=42, loc=2, scale=4)

In [5]:
# @sync_numerical_libs
def truncnorm(loc=0.0, scale=1.0, size=None, a=None, b=None, key=None):
    """Draw from a normal distribution truncated to [a, b] using jax.random.

    Adapted from https://github.com/mattkinsey/bucky/blob/master/bucky/util/distributions.py

    Unlike scipy.stats.truncnorm, ``a`` and ``b`` are the truncation values themselves
    (in the same units as ``loc``), not numbers of standard deviations from ``loc``.
    The bounds are standardized internally and passed to jax.random.truncated_normal,
    so no rejection loop (and no in-place array mutation, which JAX forbids) is needed.

    Parameters
    ----------
    loc : float
        Mean of the parent normal distribution.
    scale : float
        Standard deviation of the parent normal distribution (must be > 0).
    size : None, int, or tuple of int
        Output shape. None draws a single value.
    a, b : float or None
        Lower / upper truncation values; None leaves that side unbounded.
    key : jax PRNG key, int, or None
        Random key. An int is used as a seed. None seeds a fresh key from OS entropy,
        so repeated calls give different draws.

    Returns
    -------
    jax array with at least one dimension.
    """
    import numpy as np  # local import: only used to seed a key and inspect `size`

    if key is None:
        key = jax.random.key(int(np.random.default_rng().integers(2**31 - 1)))
    elif isinstance(key, (int, np.integer)):
        key = jax.random.key(int(key))

    if size is None:
        shape = None
    elif np.ndim(size) == 0:
        shape = (int(size),)
    else:
        shape = tuple(int(s) for s in size)

    lower_val = -xp.inf if a is None else a
    upper_val = xp.inf if b is None else b
    ## truncated_normal samples a *standard* normal, so express the bounds in units of sigma
    lower = (xp.asarray(lower_val) - loc) / scale
    upper = (xp.asarray(upper_val) - loc) / scale
    std_draws = jax.random.truncated_normal(key, lower, upper, shape=shape)
    ## clip guards against the affine rescaling rounding a draw just past a bound
    draws = xp.clip(loc + scale * std_draws, lower_val, upper_val)
    return xp.atleast_1d(draws)



Some GW helper functions.

In [6]:
def get_mc(m_1,m_2):
    '''Return (m_1*m_2)**(3/5) / (m_1+m_2)**(1/5) (the chirp mass), in the units of the inputs.'''
    mass_product = m_1*m_2
    mass_total = m_1+m_2
    return mass_product**(3/5) / mass_total**(1/5)
def get_amp_freq(theta):
    '''
    Convert binary parameters to a GW strain amplitude and GW frequency (SI units).

    theta is indexed as theta[0]=m_1 [Msun], theta[1]=m_2 [Msun], theta[2]=d_L [kpc], theta[3]=a [AU].
    Returns (amp, fgw) with
        amp = (8/sqrt(5)) * G^2/c^4 * m_1*m_2 / (d_L*a)
        fgw = sqrt(G*(m_1+m_2)/a^3) / pi   [Hz]
    '''
    G = 6.6743e-11 ## m^3 kg^-1 s^-2
    c = 2.99792458e8 ## m/s
    kg_per_msun = (1*u.Msun).to(u.kg).value
    m_per_kpc = (1*u.kpc).to(u.m).value
    m_per_au = (1*u.AU).to(u.m).value

    m_1, m_2 = theta[0]*kg_per_msun, theta[1]*kg_per_msun
    d_L = theta[2]*m_per_kpc
    a = theta[3]*m_per_au

    amp = (8/xp.sqrt(5)) * (G**2/c**4) * (m_1*m_2)/(d_L*a)
    fgw = 1/xp.pi * xp.sqrt(G*(m_1+m_2)/a**3)
    return amp, fgw

Base classes for the conditional population priors. These are straightforward; for GPU acceleration they will only need their distributions swapped for the CuPy equivalents listed here: https://docs.cupy.dev/en/stable/reference/random.html

In [7]:
class HierarchicalPrior:

    '''
    Generic class to handle the population-informed priors.

    Arguments
    -------------
    prior_dict (dict)      : Dictionary of priors given as {'parameter_name':prior_function,...}
    conditional_map (func) : Function called as conditional_map(pop_theta, prior_dict) that returns a dict
                             {'parameter_name': frozen distribution} conditioned on the current values of the
                             population parameters pop_theta. Each distribution must provide a scipy-style
                             rvs(size, random_state=...) method.
    rng                    : Random state forwarded as random_state to every rvs call.
    kwargs                 : Any additional values needed by conditional_map. These will be added as attributes of the
                             HierarchicalPrior object, such that passing keyword_1=kwarg_1 will set self.keyword_1 = kwarg_1.

    '''

    def __init__(self,prior_dict,conditional_map,rng,**kwargs):
        ## prior dict of the form {parameter_name:prior_func}
        self.prior_dict = prior_dict
        ## conditional map is a function to condition the above priors on the current values of the population priors
        self.conditional_map = conditional_map
        ## set rng
        self.rng = rng
        ## set any additional kwargs needed by conditional_map function as object attributes
        for kw, value in kwargs.items():
            setattr(self,kw,value)

        return

    def condition(self,pop_theta):
        '''Store conditional_map(pop_theta, self.prior_dict) as self.conditional_dict.'''

        self.conditional_dict = self.conditional_map(pop_theta,self.prior_dict)

        return

    def sample_conditional(self,N=1):
        '''
        Draw N samples from every conditioned prior.

        Returns
        ------------
        theta (xp array, shape (N_params, N)) : one row per entry of self.conditional_dict, in its key order.

        Raises
        ------------
        AttributeError if condition() has not been called yet.
        '''
        conditional_dict = getattr(self, 'conditional_dict', None)
        if conditional_dict is None:
            raise AttributeError("No conditional priors set; call condition(pop_theta) before sample_conditional().")
        if not conditional_dict:
            return xp.empty((0,N))
        ## build each row once and stack, instead of copying the full output array for every row
        rows = [xp.broadcast_to(xp.asarray(dist.rvs(N,random_state=self.rng),dtype=float),(N,))
                for dist in conditional_dict.values()]
        return xp.stack(rows)


class GalacticBinaryPrior(HierarchicalPrior):
    '''
    Population-informed GB prior. Assumes:
    - Truncated-Gaussian-distributed masses, truncated to [m_min, m_max]
    - Power-law distributed orbital separations on [a_min, a_max]
    - Uniformly distributed inclinations (uniform in cos(i); not population-dependent) (TODO: not yet sampled)
    - (for now) Gamma-distributed distances (TODO: update to an analytic Galaxy model)
    - (TODO: add sky localization parameters)
    - (TODO: add fdot)

    condition() is overridden directly, so no conditional_map is needed and HierarchicalPrior.__init__ is not called.
    The sampled rows are ordered m_1, m_2, d_L, a.
    '''

    def __init__(self,rng):

        self.prior_dict = {'m_1':xst.truncnorm, ## in Msun
                           'm_2':xst.truncnorm, ## in Msun
                           'd_L':xst.gamma, ## in kpc
                           'a':xst.powerlaw ## in AU
        }

        ## minimum allowed distance in kpc (stored but not currently applied; the gamma prior is already > 0)
        self.d_min = 1e-3 ## no GBs closer than the closest known star
        self.a_min = 1e-4 ## no binaries with a semimajor axis comparable to their radius
        self.a_max = 1e-2 ## no binaries outside of LISA's frequency range
        self.m_min = 0.17 ## lowest-mass observed white dwarf
        self.m_max = 1.44 ## no WDs with mass above the Chandrasekhar limit

        ## store rng
        self.rng = rng

        return

    @staticmethod
    def _latest(value):
        '''Return the most recent entry of a chain-like value (value[-1]); scalars are returned unchanged.'''
        return value[-1] if np.ndim(value) > 0 else value

    def condition(self,pop_theta):
        '''
        Condition the resolved GB parameters on the population parameters.

        Arguments:
        ---------------
        pop_theta (dict) : The population parameter chains as produced by Eryn. Keys are population parameter names
                           ('m_mu', 'm_sigma', 'd_gamma_a', 'd_gamma_b', 'a_alpha'); the last entry of each is used.
        '''
        m_mu = self._latest(pop_theta['m_mu'])
        m_sigma = self._latest(pop_theta['m_sigma'])
        ## scipy's truncnorm takes its bounds in units of sigma from loc, so standardize the physical limits
        m_trunc_low = (self.m_min - m_mu)/m_sigma
        m_trunc_high = (self.m_max - m_mu)/m_sigma

        self.conditional_dict = {}
        self.conditional_dict['m_1'] = self.prior_dict['m_1'](a=m_trunc_low,
                                                              b=m_trunc_high,
                                                              loc=m_mu,
                                                              scale=m_sigma)
        ## m1 and m2 should come from the same distribution; we can label-switch later if we need to assert m1>m2.
        self.conditional_dict['m_2'] = self.prior_dict['m_2'](a=m_trunc_low,
                                                              b=m_trunc_high,
                                                              loc=m_mu,
                                                              scale=m_sigma)
        ## gamma distance prior; note d_gamma_b is used as the scipy *scale* (not a rate)
        self.conditional_dict['d_L'] = self.prior_dict['d_L'](a=self._latest(pop_theta['d_gamma_a']),
                                                              scale=self._latest(pop_theta['d_gamma_b']))
        ## condition semimajor axis prior
        ## NOTE: I am defining this as p(a) ~ a^{alpha}
        ## adding 1 because scipy defines the power law as p(x) ~ x^{alpha - 1}
        ## scipy's powerlaw lives on [loc, loc+scale], so scale is the *width* a_max - a_min.
        ## (strictly this is p(a) ~ (a - a_min)^alpha, which approaches a^alpha for a >> a_min)
        self.conditional_dict['a'] = self.prior_dict['a'](self._latest(pop_theta['a_alpha'])+1,
                                                          loc=self.a_min,
                                                          scale=self.a_max - self.a_min)
        return

Basic likelihood building blocks. This needs to be cleaned up and slimmed down.

In [8]:
## make some basic faux likelihoods for the GBs
class Likelihood():
    '''
    Base class for the analytic likelihood methods.

    Provides Gaussian log-density helpers shared by the GB and foreground likelihoods.
    Log-determinants are computed as sums of logs so that many small/large variances
    do not under/overflow the way log(prod(...)) does.
    '''

    def const_covar_gaussian_logpdf(self, theta, mu_vec, cov):
        """
        Compute log N(theta_i; mu_i, C) for each row i, with one covariance C shared by all rows.
        Adapted from Daniel W. on StackOverflow (https://stackoverflow.com/questions/48686934/numpy-vectorization-of-multivariate-normal)

        Args:
            theta : shape (n, d)
                Data points
            mu_vec : shape (n, d) (or broadcastable to it)
                Mean vectors
            cov : shape (d, d)
                Covariance matrix; ONLY its diagonal is used (off-diagonal terms are ignored)
        Returns:
            logpdfs : shape (n,)
                Log probabilities
        """
        _, d = theta.shape
        variances = xp.diag(cov)
        constant = d * xp.log(2 * xp.pi)
        log_determinant = xp.sum(xp.log(variances))
        deviations = theta - mu_vec
        return -0.5 * (constant + log_determinant + xp.sum(deviations**2 / variances, axis=1))

    def array_gaussian_logpdf(self, theta_vec, mu_vec, sigma):
        """
        Unnormalized Gaussian log-likelihood summed over all elements:
            -sum((theta_vec - mu_vec)**2) / (2*sigma)

        Args:
            theta_vec : array
                Model values
            mu_vec : array (broadcastable to theta_vec)
                Observed values
            sigma : float or array (broadcastable)
                NOTE: enters as 2*sigma in the denominator, i.e. it plays the role of a
                *variance*, not a standard deviation. The normalization constant is omitted.
        Returns:
            Scalar log likelihood.
        """
        return - xp.sum((theta_vec - mu_vec)**2)/(2*sigma)

    def vectorized_gaussian_logpdf(self, theta, mu_vec, cov_vec):
        """
        Compute log N(theta_i; mu_i, diag(cov_i)) for each row i, with a per-row diagonal covariance.
        Adapted from Daniel W. on StackOverflow (https://stackoverflow.com/questions/48686934/numpy-vectorization-of-multivariate-normal)

        Args:
            theta : shape (n, d)
                Data points
            mu_vec : shape (n, d)
                Mean vectors
            cov_vec : shape (n, d)
                Diagonals of the per-row covariance matrices
        Returns:
            logpdfs : shape (n,)
                Log probabilities
        """
        _, d = theta.shape
        constant = d * xp.log(2 * xp.pi)
        log_determinants = xp.sum(xp.log(cov_vec), axis=1)
        deviations = theta - mu_vec
        return -0.5 * (constant + log_determinants + xp.sum(deviations**2 / cov_vec, axis=1))

class GB_Likelihood(Likelihood):
    '''
    GB analytic likelihood class.

    The "observed" GB parameters (self.mu_vec) are the true parameters scattered once by a
    multivariate normal with covariance cov; ln_prob then scores candidate parameters against them.
    '''

    def __init__(self,theta_true,cov,sigma_of_f=False):
        '''
        theta_true (array) : true simulated parameter values, of shape N_res x N_theta
        cov (array)        : the N_theta x N_theta covariance matrix
        sigma_of_f (bool)  : Whether the provided covariance is a function of frequency (not implemented)

        Raises NotImplementedError if sigma_of_f is True.
        '''

        if sigma_of_f:
            raise NotImplementedError("Frequency-dependent GB covariances (sigma_of_f=True) are not implemented yet.")

        ## calculate the observed means with scatter from true vals
        self.mu_vec = xp.array([xst.multivariate_normal.rvs(mean=theta_true[ii,:],
                                                            cov=cov,size=1) for ii in range(theta_true.shape[0])])
        self.cov = cov
        self.ln_prob = self.ln_prob_const_sigma

    def ln_prob_const_sigma(self,theta):
        '''Per-binary log likelihood, shape (N_res,), using the shared covariance (diagonal only).'''
        return self.const_covar_gaussian_logpdf(theta,self.mu_vec,self.cov)

    def ln_prob_sigma_of_f(self,theta):
        '''Per-binary log likelihood with per-binary diagonal covariances; requires self.cov_vec (not set by __init__ yet).'''
        return self.vectorized_gaussian_logpdf(theta,self.mu_vec,self.cov_vec)

class Nres_Likelihood(Likelihood):
    '''
    N_res Poisson likelihood.

    The Poisson distribution has its mean fixed to the observed number of resolved binaries and is
    evaluated at the model's number of resolved binaries.
    NOTE(review): the conventional form would be Poisson(N_res_obs | mean=N_res_theta), i.e. the roles
    swapped; confirm which is intended.
    '''

    def __init__(self,N_res_obs):
        '''
        N_res_obs (int) : Observed number of resolved binaries
        '''

        self.N_res_obs = N_res_obs
        self.base_dist = xst.poisson(self.N_res_obs)
        self.ln_prob = self.ln_conditional_Poisson

    def ln_conditional_Poisson(self,N_res_theta):
        '''
        Log Poisson probability of N_res_theta under mean N_res_obs.
        Uses logpmf directly so counts far from the mean give a finite value instead of log(0) = -inf.
        '''
        return xp.asarray(self.base_dist.logpmf(N_res_theta))

class FG_Likelihood(Likelihood):
    '''
    Foreground analytic likelihood class.

    Compares the log10 of a model foreground PSD to the log10 of the data PSD with an
    (unnormalized) Gaussian; see Likelihood.array_gaussian_logpdf.
    '''

    def __init__(self,spec_data,cov,sigma_of_f=False):
        '''
        spec_data (array) : foreground PSD data, one value per frequency bin
        cov               : !! needs to be diagonal/scalar and in units of log amplitude; it is used as
                            the variance in array_gaussian_logpdf
        sigma_of_f (bool) : frequency-dependent covariance (not implemented)

        Raises NotImplementedError if sigma_of_f is True.
        '''

        if sigma_of_f:
            raise NotImplementedError("Frequency-dependent foreground covariances (sigma_of_f=True) are not implemented yet.")

        ## the data are used directly as the observed means (no extra scatter is added)
        self.mu_vec = xp.asarray(spec_data)
        ## the data never change, so take their log10 once rather than on every likelihood call
        self.log_mu_vec = xp.log10(self.mu_vec)
        self.cov = cov
        self.ln_prob = self.ln_prob_const_sigma

    def ln_prob_const_sigma(self,theta_spec):
        '''Scalar log likelihood of a model foreground PSD theta_spec (same binning as the data).'''
        return self.array_gaussian_logpdf(xp.log10(theta_spec),self.log_mu_vec,self.cov)

    def ln_prob_sigma_of_f(self,theta_spec):
        '''Frequency-dependent variant; requires self.cov_vec, which __init__ does not set yet.'''
        return self.vectorized_gaussian_logpdf(theta_spec,self.mu_vec,self.cov_vec)

And, finally, the population model itself. It still needs parallelization, everything moved to CuPy, and a wrapper so that it can be passed to Eryn.

### IMPORTANT NOTE -- The version here does not return the specific indices of the resolved binaries for the sake of simple optimization. 

Depending on how we want to handle interaction with the resolved binaries, this may need to be rethought. That being said, I am increasingly convinced that the specific tail draws that represent resolved binaries are not actually useful beyond allowing us to marginalize over specific realizations to get statistics on $N_{\rm res}(f_i)$.

In [9]:
class PopulationHyperPrior():
    '''
    Class for the actual hyperparameters.

    Holds one frozen scipy.stats distribution per population hyperparameter and draws from them.
    '''

    def __init__(self,hyperprior_dict=None):

        '''
        hyperprior_dict (dict or None) : {'hyperparameter_name': frozen scipy.stats distribution}.
                                         For now, if None, the defaults below are used; we can adjust them later.
        '''

        if hyperprior_dict is None:

            hyperprior_dict = {'m_mu':xst.norm(loc=0.6,scale=0.05),
                               'm_sigma':xst.invgamma(5),
                               'd_gamma_a':xst.uniform(loc=1,scale=9), ## U[1, 10]; these are pretty arbitrary
                               'd_gamma_b':xst.uniform(loc=1,scale=9), ## U[1, 10]; these are pretty arbitrary
                               ## NOTE(review): scipy's uniform(loc, scale) spans [loc, loc+scale], so this is U[0.25, 1.25] -- confirm intended
                               'a_alpha':xst.uniform(0.25,1.0)
                              }
        self.hyperprior_dict = hyperprior_dict
        return

    def sample(self,N=1,random_state=None):
        '''
        Draw N values of every hyperparameter.

        Arguments
        ---------------
        N (int)      : number of draws per hyperparameter
        random_state : None uses scipy's default (numpy's global random state);
                       an int seeds a single numpy Generator shared by all hyperparameters, so each gets an
                       independent stream; a Generator/RandomState is used as-is.

        Returns
        ---------------
        dict {name: array of N draws}, in hyperprior_dict order.
        '''
        if isinstance(random_state, (int, np.integer)):
            random_state = np.random.default_rng(random_state)
        return {key: dist.rvs(size=N, random_state=random_state) for key, dist in self.hyperprior_dict.items()}

class PopModel():
    '''
    Class to house the overall population model.

    run_model() draws (or takes) population hyperparameters, conditions the GalacticBinaryPrior on them,
    draws a synthetic galaxy of self.N binaries, converts it to (GW frequency, strain amplitude), and
    splits each frequency bin into resolved binaries and an unresolved foreground.
    '''

    def __init__(self,Ntot,rng,fbins='default',Tobs=4*u.yr,Nsamp=1):
        '''
        Arguments
        ---------------
        Ntot (int-like)         : Number of binaries drawn per galaxy realization (cast to int).
        rng                     : Random state passed to GalacticBinaryPrior for its rvs calls.
        fbins (str or array)    : 'default' for xp.arange(1e-4, 5e-3, 1e-5) [Hz]; otherwise an evenly
                                  spaced array of frequency-bin centres in Hz.
        Tobs (astropy Quantity) : Observation time; stored in seconds as self.Tobs.
        Nsamp (int)             : Number of hyperprior draws made by each call to sample_likelihood.
        '''

        self.hyperprior = PopulationHyperPrior()

        self.gbprior = GalacticBinaryPrior(rng)

        self.N = int(Ntot)

        if isinstance(fbins, str) and fbins == 'default':
            self.bin_width = 1e-5
            self.fbins = xp.arange(1e-4,5e-3,self.bin_width)
        else:
            self.fbins = fbins
            self.bin_width = self.fbins[1] - self.fbins[0]

        self.Tobs = Tobs.to(u.s).value

        ## approximate LISA instrumental PSD (without confusion noise) and response, evaluated at fbins
        self.approx_lisa_psd = lw.psd.lisa_psd(self.fbins*u.Hz,t_obs=self.Tobs*u.s,confusion_noise=None).value

        self.approx_lisa_rx = lw.psd.approximate_response_function(self.fbins*u.Hz,19.09*u.mHz).value

        self.Nsamp = Nsamp

        ## vectorize the per-bin sort/threshold over (bin index, LISA_rx, duration, duration_eff, noisePSD);
        ## f_idx, dwd_amps, wts, snr_thresh and compute_frac (positions 1,2,4,5,6) are passed through whole.
        ## otypes stops np.vectorize from making an extra trial call on the first bin and allows empty input.
        self.vectorized_f_sort = np.vectorize(self.f_sort_wrapper,excluded=[1,2,4,5,6],otypes=[float,int])

        return

    def construct_likelihood(self,data):
        '''
        Wrapper to build all the likelihoods.

        data (dict) : must contain 'fg' (foreground PSD data), 'fg_sigma' (its log-amplitude variance)
                      and 'Nres' (observed total number of resolved binaries).
        '''

        fg_data = data['fg']
        fg_sigma = data['fg_sigma']
        N_res_data = data['Nres']

        self.construct_fg_likelihood(fg_data,fg_sigma)
        self.construct_Nres_likelihood(N_res_data)

        return

    def construct_fg_likelihood(self,fg_data,fg_sigma):
        '''
        Method to attach the foreground likelihood to the PopModel.
        '''

        self.fg_like = FG_Likelihood(fg_data,fg_sigma)
        self.fg_ln_prob = self.fg_like.ln_prob

        return

    def construct_Nres_likelihood(self,N_res_obs):
        '''
        Method to attach the Poisson likelihood for the number of resolved binaries to the PopModel.
        '''
        self.Nres_like = Nres_Likelihood(N_res_obs)
        self.N_res_ln_prob = self.Nres_like.ln_prob

        return

    @staticmethod
    def rebin_calc_Nij(A, noisePSD, lowamp_PSD, wts, duration, duration_eff):
        '''
        Make the per-frequency SNR vector for the high-amplitude tail of one bin:

            SNR_k = sqrt( duration * A_k^2 / (noisePSD + lowamp_PSD + duration_eff * sum_{j<k} wts*A_j^2) )

        Because A is sorted ascending, each binary's noise includes only the quieter tail binaries.

        Arguments
        ------------
        A (float array)      : Sorted (ascending) DWD amplitudes
        noisePSD (float)     : Level of the noise PSD in the relevant frequency bin (i.e., S_n(f))
        lowamp_PSD (float)   : Level of the low-amplitude contribution to the foreground PSD in the relevant frequency bin
        wts (float)          : weights from fiducial population (1 for now)
        duration (float)     : duration in seconds of the observing run. Assume 4 years in general.
        duration_eff (float) : Effective duration in seconds given the frequency binning, i.e. 1/f_bin_width
        '''
        return xp.sqrt(duration*A**2/((noisePSD + lowamp_PSD + duration_eff * (xp.cumsum(wts*A**2) - wts*A**2) )))

    @staticmethod
    def f_sort_wrapper(i,f_idx,dwd_amps,LISA_rx,wts,snr_thresh,compute_frac,duration,duration_eff,noisePSD):
        '''
        Sort/threshold the binaries of a single frequency bin.

        Arguments
        ---------------
        i (int)               : index of the frequency bin to process
        f_idx (int array)     : frequency-bin index of every binary
        dwd_amps (array)      : strain amplitude of every binary
        LISA_rx (float)       : response in bin i; amplitudes are scaled by sqrt(LISA_rx)
        wts (float)           : weight; must be a scalar here because it is not masked per bin
        snr_thresh (float)    : SNR at or above which a binary counts as resolved
        compute_frac (float)  : roughly the loudest fraction of the bin whose SNRs are computed individually;
                                all quieter binaries are treated as unresolved and lumped into lowamp_PSD
        duration (float)      : observing time in seconds
        duration_eff (float)  : 1/bin width in seconds
        noisePSD (float)      : instrument noise PSD in bin i

        Returns
        ---------------
        foreground_amp_i (float) : sum of squared (response-scaled) amplitudes of unresolved binaries in bin i
        N_res_i (int)            : number of resolved binaries in bin i
        '''

        ## sqrt because we square the amplitudes to get Sgw
        fbin_amps_i = dwd_amps[f_idx == i]*xp.sqrt(LISA_rx)
        n_i = fbin_amps_i.shape[0]
        if n_i == 0:
            return 0.0, 0

        sorted_fbin_amps_i = xp.sort(fbin_amps_i)
        ## amplitude marking the start of the high tail; clamp so compute_frac outside (0, 1) can't index out of range
        cut_idx = min(max(int((1-compute_frac)*n_i),0),n_i-1)
        hightail_filt = sorted_fbin_amps_i > sorted_fbin_amps_i[cut_idx]
        low_amps = sorted_fbin_amps_i[~hightail_filt]
        high_tail = sorted_fbin_amps_i[hightail_filt]

        lowamp_PSD = duration_eff*xp.sum(wts*low_amps**2)
        fbin_Nij = PopModel.rebin_calc_Nij(high_tail,noisePSD,lowamp_PSD,wts,duration,duration_eff)
        resolved = fbin_Nij >= snr_thresh

        ## get the number of resolved binaries in this bin
        N_res_i = xp.sum(resolved)
        ## everything not resolved (all low-amplitude binaries plus unresolved tail binaries) is foreground
        foreground_amp_i = xp.sum(low_amps**2) + xp.sum(high_tail[~resolved]**2)

        return foreground_amp_i, N_res_i

    def rebin_sort_threshold(self,binaries,fs,noisePSD,duration,LISA_rx,wts=1,snr_thresh=7,compute_frac=0.1):
        '''
        Bin the binaries by frequency, then within each bin sort them by amplitude and threshold on SNR.

        Arguments
        -----------
        binaries (array, shape (2, N_dwd)) : row 0 = GW frequencies [Hz], row 1 = strain amplitudes
        fs (float array)         : frequency-bin centres (evenly spaced); a leading 0 Hz bin is dropped
        noisePSD (float array)   : noise PSD evaluated at fs (i.e., S_n(f))
        duration (float)         : observing time in seconds
        LISA_rx (float or array) : Approximate LISA response function evaluated at fs
        wts (float)              : weights from fiducial population (1 for now)
        snr_thresh (float)       : the SNR threshold to condition resolved vs. unresolved on
        compute_frac (float)     : roughly the loudest fraction of each bin whose SNRs are computed individually

        Returns
        -----------
        foreground_amp (array) : per-bin sum of squared amplitudes of unresolved sources (one entry per bin)
        N_res (int array)      : per-bin number of resolved DWDs
        '''
        dwd_fs = binaries[0,:]
        dwd_amps = binaries[1,:]

        ## drop a 0 Hz bin (no noise curve there), keeping fs, noisePSD and LISA_rx aligned
        fs_full = fs
        if fs_full[0] == 0:
            fs_full = fs_full[1:]
            noisePSD = noisePSD[1:]
            if np.ndim(LISA_rx) > 0:
                LISA_rx = LISA_rx[1:]

        ## find which frequency bin each binary is in: digitize against the upper bin edges, so
        ## index j <-> centre fs_full[j]; everything below the first edge lands in bin 0 and
        ## everything above the last edge gets index len(fs_full) and is never processed
        delf = fs_full[1] - fs_full[0]
        f_idx = xp.digitize(dwd_fs,fs_full+0.5*delf)
        duration_eff = 1/delf ## effective duration for new frequency resolution

        ## these should be 1 x Nf arrays, where Nf is the number of frequency bins
        foreground_amp, N_res = self.vectorized_f_sort(xp.arange(fs_full.shape[0]),f_idx,dwd_amps,
                                                       LISA_rx,wts,snr_thresh,compute_frac,duration,duration_eff,noisePSD)

        return foreground_amp, N_res

    def run_model(self,pop_theta=None,snr_thresh=7,compute_frac=0.2):
        '''
        Draw one galaxy realization and compute its foreground and resolved counts.

        Arguments
        -----------
        pop_theta (dict or None) : hyperparameter values; drawn from the hyperprior if None
        snr_thresh (float)       : resolvability SNR threshold
        compute_frac (float)     : see rebin_sort_threshold

        Returns
        -----------
        fbins[1:], fg_psd[1:], N_res[1:] : the lowest bin is discarded because it also collects every
                                           binary below the frequency grid
        '''

        ## draw pop hyperparameters
        if pop_theta is None:
            pop_theta = self.hyperprior.sample(1)

        ## condition the astro parameter distributions on the hyperprior draw
        self.gbprior.condition(pop_theta)

        ## draw a sample galaxy
        galaxy_draw = self.gbprior.sample_conditional(self.N)

        ## convert to phenomenological space
        amp_draws, fgw_draws = get_amp_freq(galaxy_draw)

        ## form array
        obs_draws = xp.array([fgw_draws,amp_draws])

        ## sort into resolved and unresolved binaries
        fg_sort, N_res = self.rebin_sort_threshold(obs_draws,
                                                   self.fbins,
                                                   self.approx_lisa_psd,
                                                   self.Tobs,
                                                   self.approx_lisa_rx,
                                                   wts=1,
                                                   snr_thresh=snr_thresh,
                                                   compute_frac=compute_frac)
        ## NOTE(review): this equals Tobs*bin_width*sum(A^2), whereas lowamp_PSD in f_sort_wrapper uses
        ## sum(A^2)/bin_width -- confirm the intended PSD normalization.
        fg_psd = (self.Tobs / self.bin_width**(-1))*fg_sort

        ## lowest bin is not accurate, discard
        return self.fbins[1:], fg_psd[1:], N_res[1:]

    def fg_N_ln_prob(self,pop_theta,return_spec=False):
        '''
        Function to get the model probability conditioned on only
        the per-bin foreground amplitude and the total number of resolved binaries.

        Eventually we can extend this to per-bin N_res.

        Returns
        -----------
        ln_prob (scalar), or, if return_spec, the concatenation [fg_psd, N_res per bin, ln_prob].
        '''
        ## call the population model
        fbins, fg_psd, N_res = self.run_model(pop_theta)

        ## call the fg likelihood
        ln_p_fg = self.fg_ln_prob(fg_psd)

        ## the Poisson term is for the *total* resolved count, so sum over frequency bins
        ln_p_Nres = self.N_res_ln_prob(xp.sum(N_res))

        ln_p = ln_p_fg + ln_p_Nres

        if return_spec:
            return xp.concatenate([xp.ravel(xp.asarray(fg_psd)),xp.ravel(xp.asarray(N_res)),xp.atleast_1d(ln_p)])
        return ln_p

    def sample_likelihood(self,save_spec=False):
        '''
        Draw self.Nsamp hyperprior samples and store them with their likelihoods in self.chain.

        Chain layout (rows x samples): the Npar hyperparameters, then (if save_spec) the foreground
        spectrum and N_res(f_i) over the kept bins, then the log likelihood in the last row.
        Repeated calls append new columns after the existing ones.
        '''

        ## run_model drops the lowest frequency bin, so each saved spectrum has len(fbins)-1 entries
        Npar = len(self.hyperprior.hyperprior_dict)
        Nf = len(self.fbins) - 1
        chain_dim = Npar + 2*Nf + 1

        new_chain = xp.empty((chain_dim,self.Nsamp)) ## last row is for the likelihood
        if hasattr(self,'chain'):
            start = self.chain.shape[1]
            self.chain = xp.append(self.chain,new_chain,axis=1)
        else:
            start = 0
            self.chain = new_chain

        ## JAX arrays are immutable, so every write goes through .at[...].set
        for ii in tqdm(range(start,start+self.Nsamp)):
            draw = self.hyperprior.sample(1)
            params = xp.array([draw[key] for key in draw.keys()]).flatten()
            self.chain = self.chain.at[:Npar,ii].set(params)
            if save_spec:
                self.chain = self.chain.at[Npar:,ii].set(self.fg_N_ln_prob(draw,return_spec=True))
            else:
                self.chain = self.chain.at[-1,ii].set(self.fg_N_ln_prob(draw))

        return self.chain

In [10]:
## seeded numpy Generator: the priors pass it as random_state to scipy rvs calls, so it must be a numpy RNG
test_rng = np.random.default_rng(42)

In [14]:
## test frequency bins: 1e-5 Hz spacing from 0.1 mHz up to (not including) 1 mHz
bin_width = 1e-5
dur_eff = 1/bin_width ## effective duration in seconds for this bin width
f_bins = xp.arange(1e-4,1e-3,bin_width)

In [15]:
## 1e7 binaries per realization, the numpy test rng, and the f_bins grid defined above
test_popmodel = PopModel(1e7,test_rng,fbins=f_bins)

In [16]:
## run_model returns the frequency bins first, followed by the foreground PSD and N_res per bin
test_fbins, test_fg, N_res = test_popmodel.run_model()[:3]

E0805 14:43:16.164726    4105 hlo_lexer.cc:443] Failed to parse int literal: 13328992393196351600448
E0805 14:43:16.164777    4105 hlo_lexer.cc:443] Failed to parse int literal: 13328992393196351600448


KeyboardInterrupt: 