In [2]:
import jax.numpy as jnp
from jax.scipy.special import factorial as fac

import numpy as np

JAX does factorial for JAX arrays on its own! No more need for `tens_factorial`.

To calculate binomial coefficients, `jax_binom_coeff` turns `tens_binom` into (basically) one line:

In [3]:
def jax_binom_coeff(n, k):
    """Returns binomial coefficient, `n choose k`.
    
    Arguments
    ---------
    n: int
      number of possibilities
    k: int
      number of unordered outcomes to choose
    """

    coeff = round(fac(n)) / round((fac(k)) * round(fac(n-k)))

    check_nk = jnp.isinf(coeff) # check if k > n, set answer to zero

    coeff = jnp.where(check_nk, 0, coeff)

    return coeff

In [131]:
def sp_w_ylm(s, l, m):
    """Returns spin-weighted (l, m) spherical harmonic with spin weight s, as a function of cosi.
     
    Arguments
    ---------
    s: int
        spin weight of the sYlm, -2 for GWs
    l: JAX array
        l index of each QNM to be included in the model
    m: JAX array
        l index of each QNM to be included in the model
    """
    
    ## with phi = 0
    ## th == np.arccos(cosi)
    
    r = l - s
    
    def _get_rs(r):
        shape = len(r)
        leng = max(r) + 1
        coll = jnp.broadcast_to(jnp.arange(leng), (shape, leng)).swapaxes(0,1)

        return coll
    
    rs = _get_rs(r)
    
    def sin_th_2(cosi):
        return jnp.sqrt((1-cosi)/2)
    
    def cos_th_2(cosi):
        return jnp.sqrt((1+cosi)/2)
    
    def cot_th_2(cosi):
        return cos_th_2(cosi)/sin_th_2(cosi)
    
    swsh = lambda cosi: ((-1)**(l+m-s) * jnp.sqrt(fac(l+m)*
                                                  fac(l-m)*
                                                  ((2*l)+1)/
                                                  (4*jnp.pi)/
                                                  fac(l+s)/
                                                  fac(l-s)) * (sin_th_2(cosi))**(2*l) * jnp.sum(jnp.array([(-1)**n * jax_binom_coeff(l-s, n) * jax_binom_coeff(l+s, n+s-m) * (cot_th_2(cosi))**((2*n)+s-m) 
                                                                                                                            for n in rs]), axis=0))
    
    return lambda cosi: swsh(cosi) * np.sqrt(1/0.1587577) # normalization constant

In [101]:
def sp_w_ylm_p(cosi, l, m):
    """Returns (+) spin-weighted angular factor"""
    return (sp_w_ylm(-2, l, m)(cosi) + sp_w_ylm(-2, l, m)(-cosi)) * np.sqrt(5/np.pi) # sqrt(5/pi) to match normalization in Isi & Farr (2021)

In [6]:
def sp_w_ylm_c(cosi, l, m):
    """Returns (x) spin-weighted angular factor"""
    return (sp_w_ylm(-2, l, m)(cosi) - sp_w_ylm(-2, l, m)(jnp.cos(jnp.pi-jnp.arccos(cosi)))) * np.sqrt(5/np.pi)

Within `model.py`, it is more convenient to implement `sp_w_ylm_p` and `sp_w_ylm_c` to return functions of `cosi`, similar to `sp_w_ylm` itself:

In [102]:
def calc_Yp(cosi, swsh):
    """Returns (+) angular factor for aligned model"""
    return (swsh(cosi) + swsh(-cosi)) * jnp.sqrt(5/jnp.pi) # sqrt(5/pi) to match normalization in Isi & Farr (2021)

def calc_Yc(cosi, swsh):
    """Returns (x) angular factor for aligned model"""
    return (swsh(cosi) - swsh(-cosi)) * jnp.sqrt(5/jnp.pi)

The above matches the the SWSH calculations that were written up using `aesara`! And below, agreement with the hardcoded  $l=m=2$  angular factors from the original aligned model code. Ready for implementation in `ringdown` jaxify branch.

In [134]:
cosi = 0.5

l = jnp.array([2])
m = jnp.array([2])
swsh = sp_w_ylm(-2, l, m)

p = calc_Yp(cosi, swsh)
c = calc_Yc(cosi, swsh)

hcp = (1 + (cosi)**2)
hcc = 2 * cosi

In [135]:
print(f'JAX (+) angular factor : hardcoded (+) angular factor = {(p/hcp)}')
print(f'JAX (x) angular factor : hardcoded (x) angular factor = {(c/hcc)}')

JAX (+) angular factor : hardcoded (+) angular factor = [0.99860257]
JAX (x) angular factor : hardcoded (x) angular factor = [0.99860257]


In [136]:
print('Percent error compared to hardcoded (+) angular factor = {:3f}%'.format(((hcp-p)[0]/hcp)*100))
print('Percent error compared to hardcoded (x) angular factor = {:3f}%'.format(((hcc-c)[0]/hcc)*100))

Percent error compared to hardcoded (+) angular factor = 0.139742%
Percent error compared to hardcoded (x) angular factor = 0.139743%
