In [1]:
import numpy as np
from jax.config import config  # NOTE(review): newer JAX exposes this as `jax.config`; this import path may be deprecated there
# Enable 64-bit floats/complex128 (the notebook later builds complex128 inputs).
config.update("jax_enable_x64", True)
# Force JAX onto the GPU backend.
config.update('jax_platform_name', 'gpu')
import jax.numpy as jnp
from jax import jit, vmap, lax, grad
from jax.test_util import check_grads
from jax.interpreters import ad, batching, xla
import jax

from caustics import ehrlich_aberth
from functools import partial

gpu_ops: <module 'caustics.gpu_ops' from '/home/fb90/miniforge3/envs/caustics/lib/python3.9/site-packages/caustics-0.0.0-py3.9-linux-x86_64.egg/caustics/gpu_ops.cpython-39-x86_64-linux-gnu.so'>


In [2]:
@jit
def min_zero_avoiding(x):
    """Minimum of a 1-D array, skipping an exact-zero minimum.

    If the minimum of `x` is exactly 0, the smallest non-zero entry is returned
    instead; if every entry is 0 the result is 0. Otherwise the plain minimum
    (possibly negative) is returned.
    """
    ordered = jnp.sort(x)
    smallest = jnp.min(ordered)
    # After sorting, the first True of `ordered != 0` marks the smallest non-zero
    # entry (argmax returns 0 when no entry is non-zero).
    first_nonzero = ordered[jnp.argmax(ordered != 0, axis=0)]
    return jnp.where(smallest == 0., first_nonzero, smallest)

@jit
def max_zero_avoiding(x):
    """Maximum of a 1-D array, skipping an exact-zero maximum.

    A zero maximum means every entry is <= 0; in that case the largest non-zero
    entry is returned, computed as minus the smallest non-zero magnitude.
    """
    ordered = jnp.sort(x)
    largest = jnp.max(ordered)
    largest_nonzero = -min_zero_avoiding(jnp.abs(ordered))
    return jnp.where(largest == 0., largest_nonzero, largest)

@jit
def ang_dist(theta1, theta2):
    """Smallest (non-negative) angular distance between two angles, in radians."""
    full_turn = 2*jnp.pi
    # Going one way round the circle and the other; the shorter one wins.
    forward = (theta1 - theta2) % full_turn
    backward = (theta2 - theta1) % full_turn
    return jnp.stack([forward, backward]).min(axis=0)

@jit
def ang_diff(theta):
    """
    Smallest angular distance between each angle of a 1-D array and the next one,
    wrapping around: the last output entry is the distance between the last and
    the first angle.
    """
    successor = jnp.roll(theta, -1)  # theta[1], theta[2], ..., theta[0]
    return vmap(ang_dist)(theta, successor)

@jit
def add_angles(a, b):
    """Angle a + b, wrapped into [-pi, pi] via arctan2 of its sine and cosine."""
    ca, sa = jnp.cos(a), jnp.sin(a)
    cb, sb = jnp.cos(b), jnp.sin(b)
    # cos(a+b) = ca*cb - sa*sb ;  sin(a+b) = sa*cb + ca*sb
    return jnp.arctan2(sa*cb + ca*sb, ca*cb - sa*sb)

@jit
def subtract_angles(a, b):
    """Angle a - b, wrapped into [-pi, pi] via arctan2 of its sine and cosine."""
    ca, sa = jnp.cos(a), jnp.sin(a)
    cb, sb = jnp.cos(b), jnp.sin(b)
    # cos(a-b) = ca*cb + sa*sb ;  sin(a-b) = sa*cb - ca*sb
    return jnp.arctan2(sa*cb - ca*sb, ca*cb + sa*sb)


In [4]:
@partial(jit, static_argnames=('N',))
def compute_polynomial_coeffs(w, a, e1, N=2):
    """Coefficients p_0..p_5 of a degree-5 complex polynomial for source position `w`.

    Args:
        w: complex source position (a scalar; vmap over it for many positions).
        a: lens position parameter (presumably the two lenses sit at z = +/-a --
           TODO confirm).
        e1: mass fraction of the first lens (the notebook uses e2 = 1 - e1).
        N: static argument accepted for compatibility; it is not used here.

    Returns:
        Array [p_0, ..., p_5]. Callers treat p_0 as the highest-degree (z**5)
        coefficient, the ordering `jnp.polyval`/`jnp.polyder` expect.
    """
    wbar = jnp.conjugate(w)
    wbar2 = wbar**2
    a2, a3, a4, a5, a6 = a**2, a**3, a**4, a**5, a**6

    p_0 = -a2 + wbar2
    p_1 = a2*w - 2*a*e1 + a - w*wbar2 + wbar
    p_2 = 2*a4 - 2*a2*wbar2 + 4*a*wbar*e1 - 2*a*wbar - 2*w*wbar
    p_3 = (-2*a4*w + 4*a3*e1 - 2*a3 + 2*a2*w*wbar2 - 4*a*w*wbar*e1
           + 2*a*w*wbar + 2*a*e1 - a - w)
    p_4 = (-a6 + a4*wbar2 - 4*a3*wbar*e1 + 2*a3*wbar + 2*a2*w*wbar
           + 4*a2*e1**2 - 4*a2*e1 + 2*a2 - 4*a*w*e1 + 2*a*w)
    p_5 = (a6*w - 2*a5*e1 + a5 - a4*w*wbar2 - a4*wbar + 4*a3*w*wbar*e1
           - 2*a3*w*wbar + 2*a3*e1 - a3 - 4*a2*w*e1**2 + 4*a2*w*e1 - a2*w)

    coeffs = jnp.array([p_0, p_1, p_2, p_3, p_4, p_5])
    return coeffs


# Lens position: a = half of 0.9 (presumably the lenses sit at z = +/-a, i.e. a
# separation of 0.9 -- TODO confirm).
a = 0.5*0.9

# Lens mass fractions: e1 for the first lens, e2 = 1 - e1 for the second.
e1 = 0.8
e2 = 1. - e1
ncoeffs = 6  # degree-5 polynomial -> 6 coefficients

# Compute complex polynomial coefficients for a single source position.
POLY_IDX = 10  # which of the `w_points` source positions to study below
w_points = jnp.linspace(0.39, 0.4, 5000).astype(jnp.complex128)
wgrid = w_points[POLY_IDX][None, None]  # shape (1, 1): one polynomial
coeffs = vmap(vmap(lambda w: compute_polynomial_coeffs(w, a, e1)))(wgrid).reshape(1, -1)
coeffs.shape

(1, 6)

In [5]:
# Gradient check of `ehrlich_aberth` differentiated directly (root index 2), disabled.
# NOTE(review): the jax.jvp tangent of `ehrlich_aberth` in cell In [7] disagrees with the
# finite difference in cell In [6], so this check presumably fails -- confirm, or delete
# this cell.
# test_fn = lambda p: ehrlich_aberth(p)[2]
# check_grads(test_fn, (coeffs,), 1)

In [6]:
# Perturb one randomly chosen coefficient of the (single) polynomial and measure the
# resulting change in the roots by a forward finite difference.
# `coeffs` has shape (n_polys, deg + 1), so the coefficient index must be drawn along
# axis 1. (`len(coeffs)` is the number of polynomials -- 1 here -- which made
# `randint(0, 1)` always pick coefficient 0.)
perturb_rng = np.random.default_rng(42)  # fixed seed: reproducible on Restart & Run All
idx_perturbed = int(perturb_rng.integers(0, coeffs.shape[1]))
delta = jnp.zeros_like(coeffs).at[0, idx_perturbed].set(1e-07 + 2e-07j)

df_finite_diff = ehrlich_aberth(coeffs + delta) - ehrlich_aberth(coeffs)
df_finite_diff

DeviceArray([ 1.44348505e-10+2.88697021e-10j,
              5.98220016e-06-6.45226648e-06j,
             -8.75070133e-06+9.14325318e-07j,
             -1.14397014e-07-2.28794450e-07j,
              8.38463628e-06+1.67703210e-05j], dtype=complex128)

In [7]:
# Tangent of the roots along `delta` obtained by applying jax.jvp directly to
# `ehrlich_aberth`; compare with `df_finite_diff`.
f, df = jax.jvp(ehrlich_aberth, primals=(coeffs,), tangents=(delta,))
df

DeviceArray([-6.47938339e-07-1.29587668e-06j,
             -2.65091845e-07+3.31006422e-07j,
              4.23860244e-07-1.34696231e-08j,
              1.32588365e-06+2.65176730e-06j,
             -1.98405184e-07-3.96810368e-07j], dtype=complex128)

In [8]:
# Sanity check: the primal output of jax.jvp should match a plain call (expect all zeros).
ehrlich_aberth(coeffs) - f

DeviceArray([0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j], dtype=complex128)

In [26]:
@jax.custom_jvp
def fn(p):
    """Roots of a batch of polynomials, with an analytic (implicit-function) JVP.

    Args:
        p: polynomial coefficients, shape (size, deg + 1), highest degree first
           (the order `jnp.polyval`/`jnp.polyder` use in the JVP rule below).

    Returns:
        The roots from `ehrlich_aberth`, flattened to shape (size * deg,).
    """
    return ehrlich_aberth(p)


@fn.defjvp
def fn_jvp(args, tangents):
    """JVP of `fn` by implicit differentiation of P(z; p) = 0.

    For a simple root z of P(z) = sum_k p_k z**(deg - k), differentiating
    P(z(p); p) = 0 with respect to the coefficients gives

        dz = -(sum_k dp_k z**(deg - k)) / P'(z) = -polyval(dp, z) / polyval(P', z).

    The numerator is exactly the Jacobian row [z**deg, ..., z**0] contracted with
    the tangent `dp`, evaluated here with Horner's scheme instead of building the
    full (size, deg, deg + 1) table of powers. At a repeated root P'(z) = 0 and
    the tangent is inf/nan.
    """
    p = args[0]
    dp = tangents[0]
    # custom_jvp normally passes instantiated zero tangents; keep a guard against a
    # symbolic zero anyway.
    if type(dp) is ad.Zero:
        dp = lax.zeros_like_array(p)

    size = p.shape[0]      # number of polynomials
    deg = p.shape[1] - 1   # degree of each polynomial

    # Roots, reshaped from (size * deg,) to one row of roots per polynomial.
    z = ehrlich_aberth(p).reshape((size, deg))

    # Row i of a (size, k) coefficient array evaluated at every root in row i of z.
    polyval_rows = vmap(jnp.polyval)

    dP_dz = polyval_rows(vmap(jnp.polyder)(p), z)  # P'(z), shape (size, deg)
    dP_dp_dot_dp = polyval_rows(dp, z)             # sum_k dp_k z**(deg - k)

    dz = -dP_dp_dot_dp / dP_dz

    return z.reshape(-1), dz.reshape(-1)

In [27]:
# Forward finite difference of *all* roots through the wrapped solver. `fn` returns the
# same values as `ehrlich_aberth` (only its JVP differs), so this should match
# `df_finite_diff`.
df_finite_diff2 = fn(coeffs + delta) - fn(coeffs)
df_finite_diff2

DeviceArray([ 1.44348505e-10+2.88697021e-10j,
              5.98220016e-06-6.45226648e-06j,
             -8.75070133e-06+9.14325318e-07j,
             -1.14397014e-07-2.28794450e-07j,
              8.38463628e-06+1.67703210e-05j], dtype=complex128)

In [28]:
# Tangent of the roots along `delta` from the implicit-differentiation rule attached
# to `fn`; compare with `df_finite_diff2`.
f2, df2 = jax.jvp(fn, primals=(coeffs,), tangents=(delta,))
df2

DeviceArray([ 1.44348511e-10+2.88697023e-10j,
              5.98211575e-06-6.45203640e-06j,
             -8.75089857e-06+9.14470759e-07j,
             -1.14397141e-07-2.28794281e-07j,
              8.38495081e-06+1.67699016e-05j], dtype=complex128)

In [12]:
# First-order gradient check of a single root (index 2) of `fn`, using JAX's default
# tolerances and finite-difference step.
# NOTE(review): this cell ran as In [12], before `fn` was (re)defined in In [26], so the
# AssertionError below may come from an earlier definition of `fn` -- re-check after
# Restart & Run All. If it still fails, the default finite-difference step may be too
# coarse for this root; consider passing a smaller `eps` to check_grads -- TODO confirm.
check_grads(lambda p: fn(p)[2], (coeffs,), 1)

AssertionError: 
Not equal to tolerance rtol=1e-05, atol=1e-05
JVP tangent
Mismatched elements: 1 / 1 (100%)
Max absolute difference: 0.4168858
Max relative difference: 0.00310512
 x: array(-125.717262+47.900205j)
 y: array(-125.348052+48.093797j)