In [1]:
from jax import grad,jit, jacfwd
from matplotlib import pyplot as plt
import jax.numpy as jnp


# We will implement two versions of the gradient of L: one using automatic differentiation with JAX, the other following the analytic formula we derived. Comparing them gives a sanity check on the correctness of our implementation.

In [128]:
from jax import grad,jit, jacfwd
from matplotlib import pyplot as plt
import jax.numpy as jnp

from functools import partial

def single_meas_func(C1,C0,k,b,dist):
    """
        The small h function: reading of one sensor at distance `dist`.

        h(dist) = k * (dist - C1)**b + C0

        All operations are element-wise, so passing arrays of matching shape
        evaluates many sensors at once.
    """
    shifted = dist - C1
    return C0 + k * jnp.power(shifted, b)


def joint_meas_func(C1s,C0s,ks,bs,q,ps):
    """
        The big H function: the stacked readings of all sensors.

        Sensor i sits at ps[i]. Its reading depends on its distance to the
        target q and on its own parameters C1s[i], C0s[i], ks[i], bs[i].
        q must have the same number of components as each row of ps.
        Inputs may be lists or arrays; they are cast to jax arrays first.

        Returns an array with one entry per row of ps.
    """
    # Casting for the compatibility of jax.numpy (lists/tuples are accepted).
    C1_arr, C0_arr, k_arr, b_arr, sensor_pos = (
        jnp.array(v) for v in (C1s, C0s, ks, bs, ps)
    )

    # Euclidean distance from the target to every sensor, one per row of ps.
    distances = jnp.linalg.norm(q - sensor_pos, axis=1)
    return single_meas_func(C1_arr, C0_arr, k_arr, b_arr, distances)


def FIM(q,ps,sigma,C1s,C0s,ks,bs):
    """
       The computation of the Fisher Information Matrix (FIM).

       FIM = J^T J / sigma**2, where J = dH/dq is the Jacobian of the stacked
       measurement function H (joint_meas_func). J has one row per sensor and
       one column per component of q.

       Parameters
           q      : target position.
           ps     : sensor positions, one per row.
           sigma  : measurement-noise standard deviation (enters as 1/sigma**2).
           C1s, C0s, ks, bs : per-sensor model parameters, forwarded to H.
       Returns
           A (len(q), len(q)) array.
    """
    # Bind the per-sensor parameters so H depends only on (q, ps).
    H=partial(joint_meas_func, C1s,C0s,ks,bs)

    # Jacobian of H w.r.t. the zeroth argument (q), evaluated once and reused.
    # Not wrapped in jit here: H is a new partial object on every call, so an
    # inner jit would re-trace each time. Callers can jit FIM/L as a whole.
    J = jacfwd(H, argnums=0)(q, ps)
    return 1/(jnp.power(sigma,2)) * J.T.dot(J)

def L(q,ps,sigma,C1s,C0s,ks,bs):
    """
        The reward function big L: the determinant of the Fisher Information
        Matrix returned by FIM() for target q and sensors ps.

        Arguments are passed straight through to FIM().
    """
    fisher = FIM(q, ps, sigma, C1s, C0s, ks, bs)
    return jnp.linalg.det(fisher)

In [187]:
import numpy as np
def analytic_L(q,ps,sigma,C1s,C0s,ks,bs):
    """
        Closed-form det(FIM) for a 2-D target, used to cross-check L().

        Let r_i = |p_i - q| and r_hat_i = (p_i - q) / r_i. Let
        a_i = (k_i*b_i)**2 * (r_i - C1_i)**(2*b_i - 2), the squared slope of h_i
        at r_i, and let cos_ij = r_hat_i . r_hat_j. Expanding the 2x2
        determinant of FIM = J^T J / sigma**2 (Cauchy-Binet) gives

            det(FIM) = 1 / (2 * sigma**4) * sum_ij a_i a_j (1 - cos_ij**2).

        The exponent of sigma is 4, not 2: the FIM carries 1/sigma**2, and the
        determinant of a 2x2 matrix scales quadratically.

        Parameters
            q      : target position, 2 components.
            ps     : sensor positions, shape (n_p, 2).
            sigma  : measurement-noise standard deviation.
            C1s, ks, bs : per-sensor model parameters. C0s is accepted for
                          signature compatibility but does not affect det(FIM).
        Returns
            A 0-d jax array.
    """
    q = jnp.asarray(q)
    ps = jnp.asarray(ps)
    C1s, ks, bs = jnp.asarray(C1s), jnp.asarray(ks), jnp.asarray(bs)

    diff = ps - q
    r = jnp.linalg.norm(diff, axis=1)
    r_hat = diff / r[:, None]

    # Squared measurement slope per sensor.
    a = (ks * bs) ** 2 * (r - C1s) ** (2 * bs - 2)

    # Clip to [-1, 1] so rounding never pushes |cos| above 1.
    # jnp.clip also avoids passing a Python list to a jnp reduction, which JAX rejects.
    cos = jnp.clip(r_hat @ r_hat.T, -1.0, 1.0)

    return a @ (1 - cos ** 2) @ a / (2 * sigma ** 4)
def analytic_dLdp(q,ps,sigma,C1s,C0s,ks,bs):
    """
    Closed-form gradient of L = det(FIM) with respect to the sensor positions.

    Valid for a 2-D target only: it uses the tangential direction obtained by
    rotating r_hat by +90 degrees.

    With r_i = |p_i - q|, r_hat_i = (p_i - q)/r_i, t_hat_i = r_hat_i rotated by +90
    degrees, a_i = (k_i b_i)**2 (r_i - C1_i)**(2 b_i - 2), cos_ij = r_hat_i . r_hat_j
    and sin_ij = sin(theta_i - theta_j):

        dL/dr_i   = 2 (k_i b_i)**2 (b_i - 1) (r_i - C1_i)**(2 b_i - 3) / sigma**4
                    * sum_j a_j (1 - cos_ij**2)
        dL/deta_i = 2 a_i / sigma**4 * sum_j a_j cos_ij sin_ij
        dL/dp_i   = dL/dr_i * r_hat_i + (dL/deta_i / r_i) * t_hat_i

    Parameters
        q      : target position, 2 components.
        ps     : sensor positions, shape (n_p, 2).
        sigma  : measurement-noise standard deviation (det(FIM) scales as 1/sigma**4).
        C1s, ks, bs : per-sensor model parameters. C0s is accepted for signature
                      compatibility but does not affect the gradient.
    Returns
        numpy array of shape (n_p, 2); row i is dL/dp_i.
    """
    q = np.asarray(q, dtype=float)
    ps = np.asarray(ps, dtype=float)
    C1s, ks, bs = (np.asarray(v, dtype=float) for v in (C1s, ks, bs))

    diff = ps - q
    r = np.linalg.norm(diff, axis=1)                         # (n_p,)
    r_hat = diff / r[:, None]                                # radial unit vectors
    t_hat = np.stack([-r_hat[:, 1], r_hat[:, 0]], axis=1)    # r_hat rotated by +90 deg

    kb2 = (ks * bs) ** 2
    base = r - C1s
    a = kb2 * base ** (2 * bs - 2)                           # squared slope of h_i at r_i

    # Clip so rounding never pushes |cos| above 1.
    cos = np.clip(r_hat @ r_hat.T, -1.0, 1.0)
    # sin[i, j] = cross(r_hat_j, r_hat_i) = sin(theta_i - theta_j). It equals
    # sign(det([r_hat_j; r_hat_i])) * sqrt(1 - cos**2) but stays accurate near +-1.
    sin = np.outer(r_hat[:, 1], r_hat[:, 0]) - np.outer(r_hat[:, 0], r_hat[:, 1])

    sigma4 = np.asarray(sigma, dtype=float) ** 4
    dLdeta = 2 * a / sigma4 * ((cos * sin) @ a)
    dLdr = 2 * kb2 * (bs - 1) * base ** (2 * bs - 3) / sigma4 * ((1 - cos ** 2) @ a)

    return dLdr[:, None] * r_hat + (dLdeta / r)[:, None] * t_hat

# Sanity check: when b=1 and the sensors are spread around the target (here at 0°, 90° and 180°), the gradients dL/dp should all be 0.

In [2]:
import numpy as np
import jax.numpy as jnp

from dLdp import dLdp,analytic_dLdp

sigma=1.
C0s=jnp.array([0.,0.,0.])
C1s=jnp.array([0.1,0.1,0.1])
ks=jnp.array([1.5,0.5,1.])
bs=jnp.array([1.,1.,1.])

q=jnp.array([0.,0.])
ps=jnp.array([[1.,0],[0,1.],[-1,0.0]])

grad_analytic = analytic_dLdp(q,ps,sigma,C1s,C0s,ks,bs)
grad_autodiff = dLdp(q,ps,sigma,C1s,C0s,ks,bs)
print(grad_analytic)
print(grad_autodiff)

# Make the comparison fail loudly under Restart & Run All instead of relying on eyeballing.
# The tolerances allow for float32 rounding in the JAX path.
np.testing.assert_allclose(grad_analytic, grad_autodiff, rtol=1e-4, atol=1e-3)
np.testing.assert_allclose(grad_autodiff, 0.0, atol=1e-4)


[[ 0.  0.]
 [ 0.  0.]
 [-0.  0.]]
[[0. 0.]
 [0. 0.]
 [0. 0.]]


# When b=-2 and the sensors are spread around the target, the gradients should all point towards the target.

In [3]:
import numpy as np

sigma=1.
C0s=jnp.array([0.,0.,0.])
C1s=jnp.array([0.1,0.1,0.1])
ks=jnp.array([1.5,0.5,1.])
bs=jnp.array([-2.,-2.,-2.])

q=jnp.array([0.,0.])
ps=jnp.array([[0,1.],[1.,0],[-1,0]])

grad_analytic = analytic_dLdp(q,ps,sigma,C1s,C0s,ks,bs)
grad_autodiff = dLdp(q,ps,sigma,C1s,C0s,ks,bs)
print(grad_analytic)
print(grad_autodiff)

# Both implementations must agree (float32 tolerance on the JAX side).
np.testing.assert_allclose(grad_analytic, grad_autodiff, rtol=1e-4, atol=1e-3)
# Each gradient must point from its sensor towards the target:
# its dot product with (p_i - q) is negative.
assert np.all(np.sum(np.asarray(grad_autodiff) * np.asarray(ps - q), axis=1) < 0)

[[   -0.      -1062.2124 ]
 [ -212.44246     0.     ]
 [  849.76984    -0.     ]]
[[    0.      -1062.2124 ]
 [ -212.44247     0.     ]
 [  849.7699      0.     ]]


# When b=-2 and the sensors are clustered close together, the gradient should push them apart.

In [4]:
import numpy as np

sigma=1.
C0s=jnp.array([0.,0.,0.])
C1s=jnp.array([0.1,0.1,0.1])
ks=jnp.array([1.5,0.5,1.])
bs=jnp.array([-2.,-2.,-2.])

q=jnp.array([0.,0.])
ps=jnp.array([[0,1.],[0.1,1.],[-0.1,1]])

grad_analytic = analytic_dLdp(q,ps,sigma,C1s,C0s,ks,bs)
grad_autodiff = dLdp(q,ps,sigma,C1s,C0s,ks,bs)
print(grad_analytic)
print(grad_autodiff)

# Both implementations must agree (float32 tolerance on the JAX side).
np.testing.assert_allclose(grad_analytic, grad_autodiff, rtol=1e-4, atol=1e-3)

[[ 18.31319   -10.174047 ]
 [ 10.593998   -6.5559745]
 [-28.121792  -14.409895 ]]
[[ 18.313099  -10.173941 ]
 [ 10.593956   -6.5559235]
 [-28.121672  -14.409692 ]]


# So far the gradients of L computed by JAX and by the analytic formula agree (up to float32 rounding), so we can use either of them for gradient calculations.