In [1]:
import jax.numpy as jnp
from jax import config
import jax

config.update("jax_enable_x64", True)

In [32]:
class taylorf2:
    def __init__(self, injection):
        self.injection = injection 

        # Frequency grid
        self.fmin = 10
        self.fmax = 1000
        self.n_bins = 1000
        self.frequency = jnp.linspace(10, 1000, num=self.n_bins+1)
        self.deltaf = (self.fmax - self.fmin) / self.n_bins # Grid spacing

        # Constants
        self.m_sun_sec = 4.92549094830932e-6
        self.PSD = 1e-40 * jnp.ones_like(self.frequency)

        # Data
        self.data = self.strain(self.injection, self.frequency)

    def strain(self, x, frequencies):
        time_coalescence = x[0]
        phase_coalescence = x[1]
        chirp_mass = x[2]
        symmetric_mass_ratio = x[3]
        Amplitude = x[4]

        # Redefined for cleaner expression
        f = frequencies
        eta = symmetric_mass_ratio
        Mc = chirp_mass
        phi = phase_coalescence
        
        expr = (Amplitude * jnp.exp(-1j * (-(jnp.pi/4) + 2 * f * jnp.pi * time_coalescence + (3 * (1 + jnp.pi**(2/3) * ((f * Mc * self.m_sun_sec) / eta**(3/5))**(2/3) * (3715/756 + (55 * eta)/9))) / (128 * jnp.pi**(5/3) * ((f * Mc * self.m_sun_sec) / eta**(3/5))**(5/3) * eta) - phi))) / f**(7/6)
        
        return expr

    def gradient_strain(self, x, frequencies):
        time_coalescence = x[0]
        phase_coalescence = x[1]
        chirp_mass = x[2]
        symmetric_mass_ratio = x[3]
        Amplitude = x[4]

        # Redefined for cleaner expression
        f = frequencies
        eta = symmetric_mass_ratio
        Mc = chirp_mass
        S = self.strain(x, frequencies)
        
        expr1 = -2j * f * jnp.pi * S
        expr2 = 1j * S
        expr3 = (5j * S * (252 + jnp.pi**(2/3) * ((f * Mc * self.m_sun_sec) / eta**(3/5))**(2/3) * (743 + 924 * eta))) / (32256 * Mc * jnp.pi**(5/3) * ((f * Mc * self.m_sun_sec) / eta**(3/5))**(5/3) * eta)
        expr4 = -((1j * S * (-743 + 1386 * eta)) / (16128 * f * Mc * self.m_sun_sec * jnp.pi * eta**(7/5)))
        expr5 = S / Amplitude
        
        return jnp.array([expr1, expr2, expr3, expr4, expr5])

    def inner_product(self, a, b, power_spectral_density, frequency_spacing):
        return (4 * jnp.sum(a.conjugate()[..., :-1] * b[..., :-1] / power_spectral_density[..., :-1] * frequency_spacing, axis=-1)).T

    def potential_single(self, x):
        residual = self.strain(x, self.frequency) - self.data
        return 0.5 * self.inner_product(residual, residual, power_spectral_density=self.PSD, frequency_spacing=self.deltaf).real

    def gradient_potential_single(self, x):
        residual = self.strain(x, self.frequency) - self.data
        gradient_residual = self.gradient_strain(x, self.frequency)
        return self.inner_product(gradient_residual, residual, power_spectral_density=self.PSD, frequency_spacing=self.deltaf).real

In [33]:
injection = jnp.array([0, 0, 30.0, 0.24, 2e-22]) # For gradient testing

model = taylorf2(injection)

In [34]:
import numpy as np

x = model.injection + np.random.uniform(low=0, high=0.0001, size=5)

# The model and its derivative are calculated correctly
test1 = jax.jacfwd(model.strain)(x,10)
test2 = model.gradient_strain(x, 10)

print(test1)
print(test2)

test3 = jax.jacrev(model.potential_single)(x)
test4 = model.gradient_potential_single(x)

# Last component agrees, but all others disagree
print(test3)
print(test4)



[ 3.14164900e-04-1.59343031e-04j -5.00008968e-06+2.53602310e-06j
 -5.99001132e-05+3.03810692e-05j -2.02051009e-04+1.02479367e-04j
  3.08175699e-02+6.07607293e-02j]
[ 3.14164900e-04-1.59343031e-04j -5.00008968e-06+2.53602310e-06j
 -5.99001132e-05+3.03810692e-05j -2.02051009e-04+1.02479367e-04j
  3.08175699e-02+6.07607293e-02j]
[ 4.33276945e+15 -3.36628128e+12  6.51655926e+14  1.60442535e+15
  1.22195271e+35]
[-1.16566675e+16 -9.15201910e+07 -1.00170435e+15 -1.65605575e+15
  1.22195271e+35]


In [None]:
injection = jnp.array[1 + 1j, 2 + 2j, 3 + 3j, 4 + 4j, 5 + 5j]

grid = jnp.linspace(0, 1000, 1000)
def model(x):
    phase = 1 + x[0] + x[1]
    return jnp.exp(-1j * phase)

In [50]:
class taylorf2:
    def __init__(self, injection):
        self.injection = injection 

        # Frequency grid
        self.frequency = jnp.linspace(10, 1000, num=1000)

        # Constants
        self.m_sun_sec = 1

        # Data
        self.data = self.strain(self.injection, self.frequency)

    def strain(self, x, frequencies):
        time_coalescence = x[0]
        phase_coalescence = x[1]
        chirp_mass = x[2]
        symmetric_mass_ratio = x[3]
        Amplitude = x[4]

        # Redefined for cleaner expression
        f = frequencies
        eta = symmetric_mass_ratio
        Mc = chirp_mass
        phi = phase_coalescence
        
        expr = (Amplitude * jnp.exp(-1j * (-(jnp.pi/4) + 2 * f * jnp.pi * time_coalescence + (3 * (1 + jnp.pi**(2/3) * ((f * Mc * self.m_sun_sec) / eta**(3/5))**(2/3) * (3715/756 + (55 * eta)/9))) / (128 * jnp.pi**(5/3) * ((f * Mc * self.m_sun_sec) / eta**(3/5))**(5/3) * eta) - phi))) / f**(7/6)
        
        return expr

    def gradient_strain(self, x, frequencies):
        time_coalescence = x[0]
        phase_coalescence = x[1]
        chirp_mass = x[2]
        symmetric_mass_ratio = x[3]
        Amplitude = x[4]

        # Redefined for cleaner expression
        f = frequencies
        eta = symmetric_mass_ratio
        Mc = chirp_mass
        S = self.strain(x, frequencies)
        
        expr1 = -2j * f * jnp.pi * S
        expr2 = 1j * S
        expr3 = (5j * S * (252 + jnp.pi**(2/3) * ((f * Mc * self.m_sun_sec) / eta**(3/5))**(2/3) * (743 + 924 * eta))) / (32256 * Mc * jnp.pi**(5/3) * ((f * Mc * self.m_sun_sec) / eta**(3/5))**(5/3) * eta)
        expr4 = -((1j * S * (-743 + 1386 * eta)) / (16128 * f * Mc * self.m_sun_sec * jnp.pi * eta**(7/5)))
        expr5 = S / Amplitude
        
        return jnp.array([expr1, expr2, expr3, expr4, expr5])

    def inner_product(self, a, b):
        return jnp.sum(a.conjugate() * b, axis=-1).T

    def potential_single(self, x):
        residual = self.strain(x, self.frequency) - self.data
        return 0.5 * self.inner_product(residual, residual).real

    def gradient_potential_single(self, x):
        residual = self.strain(x, self.frequency) - self.data
        gradient_residual = self.gradient_strain(x, self.frequency)
        return self.inner_product(gradient_residual, residual).real

# Instantiate class
injection = jnp.array([0, 0, 30.0, 0.24, 2e-22])
model = taylorf2(injection)

# Get a point to evaluate at
import numpy as np
x = model.injection + np.random.uniform(low=0, high=0.0001, size=5)

# The model and its derivative are calculated correctly
test1 = jax.jacfwd(model.strain)(x, 10)
test2 = model.gradient_strain(x, 10)

print(test1)
print(test2)

test3 = jax.jacrev(model.potential_single)(x)
test4 = model.gradient_potential_single(x)

# Last component agrees, but all others disagree
print(test3)
print(test4)

[ 1.80947897e-04-1.82906972e-04j -2.87987522e-06+2.91105486e-06j
 -2.69984327e-11+2.72907375e-11j -5.73301800e-10+5.79508786e-10j
  4.84333049e-02+4.79145469e-02j]
[ 1.80947897e-04-1.82906972e-04j -2.87987522e-06+2.91105486e-06j
 -2.69984327e-11+2.72907375e-11j -5.73301800e-10+5.79508786e-10j
  4.84333049e-02+4.79145469e-02j]
[-1.38430797e-25  2.26594132e-27  1.88105680e-32  3.99505513e-31
  2.25167808e-06]
[-2.36178644e-25  0.00000000e+00  7.96164087e-33 -4.00639860e-31
  2.25167808e-06]


In [51]:
x

Array([8.12576959e-05, 2.47812873e-06, 3.00000972e+01, 2.40001650e-01,
       6.01044027e-05], dtype=float64)

In [48]:
test2 = model.gradient_strain(x, jnp.array([10., 20.]))

In [49]:
test2.shape

(5, 2)

In [52]:
class taylorf2:
    def __init__(self, injection):
        self.injection = injection 

        # Frequency grid
        self.frequency = jnp.linspace(10, 1000, num=1000)

        # Constants
        self.m_sun_sec = 1

        # Data
        self.data = self.strain(self.injection, self.frequency)

    def strain(self, x, frequencies):
        time_coalescence = x[0]
        phase_coalescence = x[1]
        chirp_mass = x[2]
        symmetric_mass_ratio = x[3]
        Amplitude = x[4]

        # Redefined for cleaner expression
        f = frequencies
        eta = symmetric_mass_ratio
        Mc = chirp_mass
        phi = phase_coalescence
        
        expr = (Amplitude * jnp.exp(-1j * (-(jnp.pi/4) + 2 * f * jnp.pi * time_coalescence + (3 * (1 + jnp.pi**(2/3) * ((f * Mc * self.m_sun_sec) / eta**(3/5))**(2/3) * (3715/756 + (55 * eta)/9))) / (128 * jnp.pi**(5/3) * ((f * Mc * self.m_sun_sec) / eta**(3/5))**(5/3) * eta) - phi))) / f**(7/6)
        
        return expr

    def gradient_strain(self, x, frequencies):
        time_coalescence = x[0]
        phase_coalescence = x[1]
        chirp_mass = x[2]
        symmetric_mass_ratio = x[3]
        Amplitude = x[4]

        # Redefined for cleaner expression
        f = frequencies
        eta = symmetric_mass_ratio
        Mc = chirp_mass
        S = self.strain(x, frequencies)
        
        expr1 = -2j * f * jnp.pi * S
        expr2 = 1j * S
        expr3 = (5j * S * (252 + jnp.pi**(2/3) * ((f * Mc * self.m_sun_sec) / eta**(3/5))**(2/3) * (743 + 924 * eta))) / (32256 * Mc * jnp.pi**(5/3) * ((f * Mc * self.m_sun_sec) / eta**(3/5))**(5/3) * eta)
        expr4 = -((1j * S * (-743 + 1386 * eta)) / (16128 * f * Mc * self.m_sun_sec * jnp.pi * eta**(7/5)))
        expr5 = S / Amplitude
        
        return jnp.array([expr1, expr2, expr3, expr4, expr5])

    def inner_product(self, a, b):
        return jnp.sum(a.conjugate() * b, axis=-1).T

    def potential_single(self, x):
        residual = self.strain(x, self.frequency) - self.data
        return 0.5 * self.inner_product(residual, residual).real

    def gradient_potential_single(self, x):
        residual = self.strain(x, self.frequency) - self.data
        gradient_residual = self.gradient_strain(x, self.frequency)
        return self.inner_product(gradient_residual, residual).real

# Instantiate class
injection = jnp.array([0, 0, 30.0, 0.24, 2e-22])
model = taylorf2(injection)

# Get a point to evaluate at
import numpy as np
x = model.injection + np.random.uniform(low=0, high=0.0001, size=5)

# The model and its derivative are calculated correctly
test1 = jax.jacfwd(model.strain)(x, 10)
test2 = model.gradient_strain(x, 10)

print(test1)
print(test2)

test3 = jax.jacrev(model.potential_single)(x)
test4 = model.gradient_potential_single(x)

# Last component agrees, but all others disagree
print(test3)
print(test4)

[ 1.55899046e-04-1.57640696e-04j -2.48121038e-06+2.50892960e-06j
 -2.32607483e-11+2.35206094e-11j -4.93860040e-10+4.99377273e-10j
  4.84414772e-02+4.79062848e-02j]
[ 1.55899046e-04-1.57640696e-04j -2.48121038e-06+2.50892960e-06j
 -2.32607483e-11+2.35206094e-11j -4.93860040e-10+4.99377273e-10j
  4.84414772e-02+4.79062848e-02j]
[ 9.56444786e-26 -4.32936726e-28 -7.86066186e-35 -1.71946360e-33
  1.94030995e-06]
[-5.93223401e-26  0.00000000e+00 -3.71773745e-33  9.37662926e-32
  1.94030995e-06]


In [53]:
# This yields overflow ever
import numdifftools as nd
nd.Gradient(model.potential_single)(x)

array([-5.01416622e-27,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        1.94030995e-06])