In [2]:
import jax.numpy as jnp
import jax.scipy as jsp
from jax.scipy.special import factorial as fac

JAX computes factorials of whole arrays natively (`jax.scipy.special.factorial`), so the custom `tens_factorial` helper is no longer needed.

`jnp` operations act elementwise on whole arrays, so functions built from them are vectorized automatically. As a result, `jax_binom_coeff` reduces `tens_binom` to (essentially) one line:

In [3]:
def jax_binom_coeff(n, k):
    """Returns binomial coefficient, `n choose k`, elementwise.

    Arguments
    ---------
    n: int or JAX array
      number of possibilities
    k: int or JAX array
      number of unordered outcomes to choose; broadcast against `n`

    Returns
    -------
    JAX array (float)
      `n choose k` for each broadcast pair, with entries where `k > n` or
      `k < 0` set to zero.
    """
    # For k > n (or k < 0) one of the factorials in the denominator has a
    # negative argument, which makes the ratio inf (or already 0).
    coeff = fac(n) / (fac(k) * fac(n - k))

    # Zero out the inf entries with jnp.where rather than boolean-mask
    # `.at[mask].set(...)`: the output shape stays static, so this also works
    # under jax.jit / jax.vmap.
    return jnp.where(jnp.isinf(coeff), 0.0, coeff)

In [4]:
def sp_w_ylm(s, l, m):
    """Returns spin-weighted (l, m) spherical harmonic with spin weight s, as a function of cosi.

    The harmonic is evaluated at azimuth phi = 0 and polar angle
    theta = arccos(cosi), via the finite sum

        sYlm = (-1)**(l+m-s) * sqrt((l+m)! (l-m)! (2l+1) / (4 pi) / (l+s)! / (l-s)!)
               * sum_n (-1)**n C(l-s, n) C(l+s, n+s-m)
                       * sin(theta/2)**(2l - k_n) * cos(theta/2)**k_n,

    where k_n = 2n + s - m. This equals sin(theta/2)**(2l) * cot(theta/2)**k_n,
    but avoids the cotangent so the result stays finite at cosi = +/-1.
    The result is additionally scaled by sqrt(1/0.159) (see note below).

    Arguments
    ---------
    s: int
        spin weight of the sYlm, -2 for GWs
    l: JAX array
        l index of each QNM to be included in the model (1-D; a scalar is
        treated as a length-1 array)
    m: JAX array
        m index of each QNM to be included in the model, broadcast against `l`

    Returns
    -------
    callable
        function of cosi returning one value per (l, m) pair; cosi must be a
        scalar or broadcast against `l`.
    """
    l = jnp.atleast_1d(jnp.asarray(l))
    m = jnp.atleast_1d(jnp.asarray(m))

    def _sign(p):
        # (-1)**p from the parity of p; unlike integer `(-1)**p` this is also
        # well defined for negative p (possible when s > 0).
        return jnp.where(p % 2 == 0, 1, -1)

    # Summation index n = 0 .. max(l - s) on a leading axis, so all (l, m)
    # pairs are summed at once; out-of-range terms get zero weight from
    # jax_binom_coeff.
    n = jnp.arange(int(jnp.max(l - s)) + 1)[:, None]

    prefactor = _sign(l + m - s) * jnp.sqrt(fac(l + m) *
                                           fac(l - m) *
                                           ((2 * l) + 1) /
                                           (4 * jnp.pi) /
                                           fac(l + s) /
                                           fac(l - s))

    weights = _sign(n) * jax_binom_coeff(l - s, n) * jax_binom_coeff(l + s, n + s - m)

    # For non-zero weights both exponents are >= 0. Zero-weight terms get
    # exponent 0 so they cannot produce 0 * inf = nan at the poles.
    cos_pow = jnp.where(weights != 0, 2 * n + s - m, 0)
    sin_pow = 2 * l - cos_pow

    ## normalization constant is sqrt(1/0.159) ##
    # NOTE(review): 1/0.159 ~= 6.289 is close to 2*pi ~= 6.283; presumably an
    # approximation of sqrt(2*pi). Confirm before changing, since the outputs
    # are matched against the aesara implementation.
    norm = jnp.sqrt(1 / 0.159)

    def _ylm(cosi):
        # phi = 0, theta = arccos(cosi)
        sin_half = jnp.sqrt((1 - cosi) / 2)
        cos_half = jnp.sqrt((1 + cosi) / 2)
        series = jnp.sum(weights * sin_half ** sin_pow * cos_half ** cos_pow, axis=0)
        return prefactor * series * norm

    return _ylm

In [5]:
# -2Y_lm at cosi = 0.5: (3, 2) alone, (2, 1) alone, then both modes in a single
# vectorized call, which should reproduce the two single-mode values.
print(
    *(
        sp_w_ylm(-2, jnp.array(ells), jnp.array(ems))(0.5)
        for ells, ems in [([3], [2]), ([2], [1]), ([3, 2], [2, 1])]
    ),
    sep='\n',
)

[-0.52642703]
[1.0274793]
[-0.52642703  1.0274793 ]


In [9]:
def sp_w_ylm_p(cosi, l, m):
    """Returns (+) spin-weighted angular factor.

    Computed as (-2Y_lm(cosi) + -2Y_lm(-cosi)) * sqrt(5/pi) for each (l, m)
    pair, using sp_w_ylm with spin weight -2.
    """
    ylm = sp_w_ylm(-2, l, m)
    # sqrt(5/pi) to match normalization in Isi & Farr (2021)
    scale = jnp.sqrt(5 / jnp.pi)
    return scale * (ylm(cosi) + ylm(-cosi))

In [10]:
def sp_w_ylm_c(cosi, l, m):
    """Returns (x) spin-weighted angular factor.

    Computed as (-2Y_lm(cosi) - -2Y_lm(-cosi)) * sqrt(5/pi) for each (l, m)
    pair, using sp_w_ylm with spin weight -2. The sqrt(5/pi) factor matches
    the normalization used by sp_w_ylm_p (Isi & Farr 2021).
    """
    ylm = sp_w_ylm(-2, l, m)
    # The reflection theta -> pi - theta is exactly cosi -> -cosi. Writing it
    # as cos(pi - arccos(cosi)) only adds rounding error, plus an arccos whose
    # derivative diverges at cosi = +/-1.
    return (ylm(cosi) - ylm(-cosi)) * jnp.sqrt(5 / jnp.pi)

The above matches the SWSH calculations previously written with `aesara`, so it is ready to be implemented in the `ringdown` jaxify branch.

In [8]:
# Compare the general angular factors for (l, m) = (2, 2) at cosi = 0.5
# against the hardcoded expressions 1 + cosi**2 (+) and 2 * cosi (x).
cosi = 0.5
l, m = jnp.array([2]), jnp.array([2])

hcp = 1 + cosi**2
hcc = 2 * cosi

p = sp_w_ylm_p(cosi, l, m)
c = sp_w_ylm_c(cosi, l, m)

In [9]:
# Ratio of the JAX angular factors to the hardcoded (2, 2) expressions;
# values close to 1 mean the two agree.
print(f'JAX (+) angular factor : hardcoded (+) angular factor = {(p/hcp)[0]}')
print(f'JAX (x) angular factor : hardcoded (x) angular factor = {(c/hcc)[0]}')

JAX (+) angular factor : hardcoded (+) angular factor = 0.997840404510498
JAX (x) angular factor : hardcoded (x) angular factors = 0.9978403449058533


In [22]:
# Relative asymmetry between |-2Y_lm(cosi)| and |-2Y_lm(-cosi)| at cosi = 0.5
# for several (l, m) modes. sp_w_ylm is real-valued at phi = 0, so no complex
# conjugate is needed; the harmonic is built once and evaluated at both signs.
asym_swsh = sp_w_ylm(-2, jnp.array([4, 3, 2, 2]), jnp.array([4, 2, 1, 0]))
abs_pos = jnp.abs(asym_swsh(0.5))
abs_neg = jnp.abs(asym_swsh(-0.5))

num = abs_pos - abs_neg
den = abs_pos + abs_neg

In [23]:
jnp.divide(num, den)

Array([8.0000001e-01, 1.2500052e-01, 5.0000006e-01, 8.2039314e-08],      dtype=float32)

In [24]:
# (x) / (+) angular-factor ratio for the same modes as num/den above.
(
    sp_w_ylm_c(cosi=0.5, l=jnp.array([4, 3, 2, 2]), m=jnp.array([4, 2, 1, 0]))
    / sp_w_ylm_p(cosi=0.5, l=jnp.array([4, 3, 2, 2]), m=jnp.array([4, 2, 1, 0]))
)

Array([8.0000001e-01, 1.2500054e-01, 5.0000006e-01, 8.2039314e-08],      dtype=float32)

In [5]:
swsh = sp_w_ylm(s=-2, l=jnp.array([3, 2]), m=jnp.array([2, 1]))

In [8]:
jnp.sqrt(5 / jnp.pi) * (swsh(0.3) - swsh(-0.3))

Array([-0.25856543,  0.5711278 ], dtype=float32)

In [11]:
sp_w_ylm_c(cosi=0.3, l=jnp.array([3, 2]), m=jnp.array([2, 1]))

Array([-0.25856543,  0.5711278 ], dtype=float32)

In [37]:
cosi_min = 0.3
cosi_max = None

# When exactly one bound is given, fill in the other with a default
# just inside [-1, 1]; if both or neither are given, leave them unchanged.
if (cosi_min is None) != (cosi_max is None):
    cosi_min = -0.99 if cosi_min is None else cosi_min
    cosi_max = 0.99 if cosi_max is None else cosi_max

print(f'cosi_min: {cosi_min}, cosi_max: {cosi_max}')

cosi_min: 0.3, cosi_max: 0.99
