In [30]:
from jax import grad as jgrad
import jax.numpy as jnp
from jax.numpy import linalg as jnpLA
import numpy as np

from jax import config
config.update("jax_enable_x64", True)

def fun(x):
    """
    Product of the distinct entries of x, padded with 1s to length 4.

    The static ``size=4`` keeps the output length fixed; the fill value 1 does
    not change the product.
    """
    return jnp.prod(jnp.unique(x, size=4, fill_value=1))


# Gradient of fun. Named fun_jgrad (not FD) so re-running this cell cannot
# overwrite the figure-of-merit function FD defined in a later cell.
fun_jgrad = jgrad(fun)
x = np.array([-2., -2.])
# For repeated entries JAX routes the whole derivative to the first occurrence.
print(fun_jgrad(x))

[1. 0.]


In [7]:
import adaptive as adp

import jax.numpy as jnp
from jax import grad
from jax.scipy.special import erf as jax_erf
from jax.numpy import linalg as jnpLA
from parameters import D1_ND, Parameters, Initial_bigrating
from twobox_updated import TwoBox

def FD(grating: TwoBox) -> float:
    """
    Calculate the grating single-wavelength figure of merit FD.

    Builds the 4x4 Jacobian J of the (y, phi) dynamics from the grating's Q
    coefficients and their angle/wavelength derivatives. The state is ordered
    (y, phi, v_y, v_phi), following the f<accel>_<variable> names below.
    FD is the product of the distinct real parts of J's eigenvalues.

    Parameters
    ----------
    grating :   TwoBox instance. Reads grating.return_Qs(),
                grating.gaussian_width (beam width) and grating.wavelength.

    Returns
    -------
    0-d JAX array (despite the ``float`` annotation), so that jax.grad can
    differentiate through it.

    Notes
    -----
    Reads the module-level constants I0, L, m and c (set by
    ``I0,L,m,c=Parameters()``). They must exist before FD is called.

    NOTE(review): differentiating FD also requires grating.return_Qs() to be
    JAX-traceable. A TracerArrayConversionError was observed from FD_grad;
    confirm the TwoBox internals.
    """
    Q1, Q2, PD_Q1_angle, PD_Q2_angle, PD_Q1_wavelength, PD_Q2_wavelength = grating.return_Qs()
    w = grating.gaussian_width
    w_bar = w/L

    # Starting wavelength set to 1
    lam = grating.wavelength  # needs to be lambda'

    D = 1/lam
    g = (jnp.power(lam, 2) + 1)/(2*lam)

    # Set-up (not sure about whether left or right makes sense - constraints)
    Q1R = Q1
    Q2R = Q2
    PD_Q1R_angle = PD_Q1_angle
    PD_Q2R_angle = PD_Q2_angle
    PD_Q1R_omega = (lam/D)*PD_Q1_wavelength
    PD_Q2R_omega = (lam/D)*PD_Q2_wavelength

    # Symmetry: left-hand quantities from the right-hand ones
    Q1L = Q1R
    Q2L = -Q2R
    PD_Q1L_angle = -PD_Q1R_angle
    PD_Q2L_angle = PD_Q2R_angle
    PD_Q1L_omega = PD_Q1R_omega
    PD_Q2L_omega = -PD_Q2R_omega

    # Factors shared by several Jacobian entries (hoisted; formulas unchanged)
    exp_term = jnp.exp(-1/(2*w_bar**2))
    erf_term = (w/2)*jnp.sqrt(jnp.pi/2)*jax_erf(1/(w_bar*jnp.sqrt(2)))
    doppler_factor = (D + 1)/(D*(g + 1))
    pref_y = D**2*(I0/(m*c))
    pref_phi = D**2*(12*I0/(m*c*L**2))

    ####################################
    # y acc
    fy_y = -pref_y*(Q2R - Q2L)*(1 - exp_term)
    fy_phi = -pref_y*(PD_Q2R_angle + PD_Q2L_angle)*erf_term
    fy_vy = -pref_y*doppler_factor*(Q1R + Q1L + PD_Q1R_angle + PD_Q1L_angle)*erf_term
    fy_vphi = pref_y*(2*(Q2R - Q2L) - D*(PD_Q2R_omega - PD_Q2L_omega))*(w/2)**2*(1 - exp_term)

    ####################################
    # phi acc
    fphi_y = pref_phi*(Q1R + Q1L)*(erf_term - (L/2)*exp_term)
    fphi_phi = pref_phi*(PD_Q1R_angle - PD_Q1L_angle - (Q2R - Q2L))*(w/2)**2*(1 - exp_term)
    # NOTE(review): identical to fphi_phi apart from doppler_factor; confirm this is intended.
    fphi_vy = pref_phi*(PD_Q1R_angle - PD_Q1L_angle - (Q2R - Q2L))*(w/2)**2*(1 - exp_term)*doppler_factor
    fphi_vphi = -pref_phi*(2*(Q1R + Q1L) - D*(PD_Q1R_omega + PD_Q1L_omega))*(w/2)**2*(erf_term - (L/2)*exp_term)

    # Build the Jacobian matrix (velocity columns scaled by 1/c)
    J = jnp.array([
        [0, 0, 1, 0],
        [0, 0, 0, 1],
        [fy_y, fy_phi, fy_vy/c, fy_vphi/c],
        [fphi_y, fphi_phi, fphi_vy/c, fphi_vphi/c],
    ])

    # Eigenvalues only. jax.grad through jnp.linalg.eig raises NotImplementedError
    # because eigenvector derivatives are not implemented; eigvals is differentiable.
    eigenvalues = jnpLA.eigvals(J)

    # For real J, complex eigenvalues come in conjugate pairs sharing a real part.
    # unique() keeps one copy and pads to 4 with 1, which leaves the product unchanged.
    unique_real = jnp.unique(jnp.real(eigenvalues), size=4, fill_value=1)
    return jnp.prod(unique_real)

# Grating attributes set by FD_params_func, in the order of its `params` argument.
_GRATING_PARAM_NAMES = (
    "grating_pitch", "grating_depth", "box1_width", "box2_width", "box_centre_dist",
    "box1_eps", "box2_eps", "gaussian_width", "substrate_depth", "substrate_eps",
)


def FD_params_func(grating, params):    # , symboxes: bool=False, onebox: bool=False
    """
    Evaluate FD(grating) with the grating geometry temporarily set to `params`.

    This wrapper exists so jax.grad can differentiate FD with respect to the
    geometry (see FD_grad below).

    Parameters
    ----------
    grating :   TwoBox instance. The attributes named in _GRATING_PARAM_NAMES are
                overwritten for the duration of the call.
    params  :   sequence of 10 values, in the order of _GRATING_PARAM_NAMES.

    Returns
    -------
    FD(grating) evaluated at `params`.

    Raises
    ------
    ValueError : `params` does not have exactly 10 entries.

    The original attribute values are restored before returning, including on
    error. Otherwise a jax.grad call would leave leaked JAX tracers on the shared
    grating object, to be read later by untraced code.
    """
    params = tuple(params)
    if len(params) != len(_GRATING_PARAM_NAMES):
        raise ValueError(
            f"expected {len(_GRATING_PARAM_NAMES)} grating parameters, got {len(params)}"
        )

    saved = {name: getattr(grating, name) for name in _GRATING_PARAM_NAMES}
    try:
        for name, value in zip(_GRATING_PARAM_NAMES, params):
            setattr(grating, name, value)
        return FD(grating)
    finally:
        for name, value in saved.items():
            setattr(grating, name, value)


# Gradient of FD with respect to every entry of `params` (argnums=1), including
# gaussian_width. Returns a list of 10 values in the same order. jax.grad requires
# these entries to be floating-point.
FD_grad = grad(FD_params_func, argnums=1)


def _adaptive_learner_data(func, l_range, goal):
    """
    Adaptively sample `func` over `l_range` with adaptive.Learner1D.

    goal : int -> passed as npoints_goal; float -> passed as loss_goal.

    Returns the learner data as a NumPy array sorted by sample point. Column 0
    holds the sample points; the remaining columns hold the function value(s).

    Raises ValueError if `goal` is neither int nor float.
    """
    learner = adp.Learner1D(func, bounds=l_range)
    if isinstance(goal, int):
        adp.runner.simple(learner, npoints_goal=goal)
    elif isinstance(goal, float):
        adp.runner.simple(learner, loss_goal=goal)
    else:
        raise ValueError("Sampling goal type not recognised. Must be int for npoints_goal or float for loss_goal.")
    data = learner.to_numpy()
    # The trapezoid rule needs ascending abscissae; sort explicitly rather than
    # relying on the learner's storage order.
    return data[np.argsort(data[:, 0], kind="stable")]


def _trapezoid(y, x, axis=-1):
    """Trapezoid-rule integral: np.trapezoid on NumPy >= 2.0, np.trapz on older versions."""
    integrate = getattr(np, "trapezoid", None) or np.trapz
    return integrate(y, x, axis=axis)


def FOM_uniform(grating: TwoBox, final_speed: float=20., goal: float=0.1, return_grad: bool=True) -> float:
    """
    Calculate wavelength expectation of FD FOM (figure of merit) for the given grating over a fixed wavelength range.
    Assumes a uniform probability distribution for wavelength.

    Parameters
    ----------
    grating     :   TwoBox instance containing the grating parameters
    final_speed :   Final sail speed as percentage of light speed
    goal        :   Stopping goal for wavelength integration passed to adaptive runner. If int, use npoints_goal; if float, use loss_goal.
    return_grad :   Return [FOM, FOM gradient]

    Returns
    -------
    FOM, or [FOM, FOM_grad] when return_grad is True. FOM_grad is the
    wavelength-averaged FD_grad, one entry per grating parameter in the order
    grating_pitch, grating_depth, box1_width, box2_width, box_centre_dist,
    box1_eps, box2_eps, gaussian_width, substrate_depth, substrate_eps.

    Raises
    ------
    ValueError : goal is neither int nor float.

    grating.wavelength is restored on exit, including on error. When return_grad
    is True, the geometry attributes (which FD_grad may overwrite) are restored too.
    """
    laser_wavelength = grating.wavelength  # copy the starting wavelength
    Doppler = D1_ND([final_speed/100, 0])
    l_min = 1  # l = grating frame wavelength normalised to laser frame wavelength
    l_max = l_min/Doppler
    l_range = (l_min, l_max)

    # Perturbation probability density function (PDF): uniform over l_range
    PDF_unif = 1/(l_max - l_min)

    # Snapshot the parameters FD_grad differentiates with respect to. They are
    # passed by value, so the sampling never reads back values FD_grad wrote.
    param_names = (
        "grating_pitch", "grating_depth", "box1_width", "box2_width", "box_centre_dist",
        "box1_eps", "box2_eps", "gaussian_width", "substrate_depth", "substrate_eps",
    )
    params = [getattr(grating, name) for name in param_names]

    # One-argument functions for the learner (l = normalised wavelength)
    def weighted_FD(l):
        grating.wavelength = l*laser_wavelength
        return float(PDF_unif*FD(grating))

    def weighted_FD_grad(l):
        grating.wavelength = l*laser_wavelength
        return PDF_unif*np.array(FD_grad(grating, params))

    try:
        FD_data = _adaptive_learner_data(weighted_FD, l_range, goal)
        FOM = _trapezoid(FD_data[:, 1], FD_data[:, 0])
        if not return_grad:
            return FOM

        # Average the per-wavelength gradient over wavelength, component-wise
        FD_grad_data = _adaptive_learner_data(weighted_FD_grad, l_range, goal)
        FOM_grad = _trapezoid(FD_grad_data[:, 1:], FD_grad_data[:, 0], axis=0)
        return [FOM, FOM_grad]
    finally:
        grating.wavelength = laser_wavelength  # restore user-initialised wavelength
        if return_grad:
            for name, value in zip(param_names, params):
                setattr(grating, name, value)


# GLOBAL OPTIMISATION ###########################################################################
## FIXED PARAMETERS ##
wavelength = 1. # Laser wavelength
# Passed to TwoBox as its `angle` argument (presumably the incidence angle; confirm in twobox_updated).
angle = 0.
Nx = 100 # Number of grid points
nG = 25 # Number of Fourier components

# Relaxation parameter. It should be infinite unless you need to avoid a singular matrix at grating cutoffs.
# Also, the optimiser finds a large-magnitude, noisy rNeg1 when Qabs = np.inf.
Qabs = 1e7


# FD() reads I0, L, m and c as module globals, so this line must run before FD,
# FD_grad or FOM_uniform is called. Presumably laser intensity, sail length, sail
# mass and speed of light (FD uses I0/(m*c) and w/L); confirm against parameters.Parameters.
I0,L,m,c=Parameters()
# Initial geometry, in the same order in which FD_params_func unpacks its `params`.
grating_pitch, grating_depth, box1_width, box2_width, box_centre_dist, box1_eps, box2_eps, gaussian_width, substrate_depth, substrate_eps = Initial_bigrating()


# Initial twobox grating (the object that FD / FOM_uniform evaluate)
grating = TwoBox(grating_pitch, grating_depth, box1_width, box2_width, box_centre_dist, box1_eps, box2_eps, 
                 gaussian_width, substrate_depth, substrate_eps,
                 wavelength, angle, Nx, nG, Qabs)

In [8]:
# Snapshot the current grating parameters (same order as FD_params_func) and
# evaluate the FD gradient with respect to all of them.
grating_pitch = grating.grating_pitch
grating_depth = grating.grating_depth
box1_width = grating.box1_width
box2_width = grating.box2_width
box_centre_dist = grating.box_centre_dist
box1_eps = grating.box1_eps
box2_eps = grating.box2_eps
gaussian_width=grating.gaussian_width
substrate_depth=grating.substrate_depth
substrate_eps=grating.substrate_eps        

params = [grating_pitch, grating_depth, box1_width, box2_width, box_centre_dist, box1_eps, box2_eps,
            gaussian_width, substrate_depth, substrate_eps] 

# NOTE(review): this call raised TracerArrayConversionError (a traced complex128[23]
# array was converted to NumPy). The conversion presumably happens inside
# grating.return_Qs() / TwoBox, which FD calls. That code must be JAX-traceable for
# FD_grad to work; confirm in twobox_updated.
FD_grad(grating,params)

# FOM_uniform(grating,5,goal=0.1,return_grad=True)


TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape complex128[23]
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

In [15]:
import numpy as np
from autograd.extend import primitive, defvjp

@primitive
def unique_filled(x, filled_value):
    """
    Return an array shaped like `x` whose leading entries (in flat, row-major
    order) are the sorted unique values of `x`. All remaining entries are
    `filled_value`.

    Declared as an autograd primitive; its reverse-mode rule is
    unique_filled_vjp, registered with defvjp.

    Parameters
    ----------
    x            : array-like of any shape (not necessarily 4-dimensional).
    filled_value : value written into the slots not taken by unique values.
                   It is cast to x's dtype by np.full_like, so a float fill is
                   truncated for integer x.

    Returns
    -------
    np.ndarray with the same shape and dtype as `x`.
    """
    x = np.asarray(x)
    unique_values = np.unique(x)  # sorted ascending
    result = np.full_like(x, filled_value)
    result.flat[:unique_values.size] = unique_values
    return result

# Define the vector-Jacobian product (VJP) for backpropagation.
def unique_filled_vjp(ans, x, filled_value):
    """
    Build the vector-Jacobian product of `unique_filled` with respect to `x`.

    Output slot k (flat order) holds the k-th sorted unique value of x. Its
    cotangent g[k] is therefore routed to the FIRST occurrence of that value in
    x (row-major order). Later duplicates and the padded `filled_value` slots
    receive no gradient. This matches the convention jnp.unique's gradient shows
    in this notebook ([8, 0, 0, 0] at x = [1, 1, 1, 1]).

    Returns
    -------
    vjp(g) -> np.ndarray shaped like x: the gradient for x only. autograd's
    defvjp takes one VJP maker per differentiated argument, so nothing is
    returned for filled_value.
    """
    x_arr = np.asarray(x)
    flat_x = x_arr.ravel()
    # first_idx[k] = flat position in x of the first occurrence of the k-th unique value
    unique_values, first_idx = np.unique(flat_x, return_index=True)

    def vjp(g):
        g_flat = np.asarray(g).ravel()
        grad = np.zeros(flat_x.shape, dtype=np.result_type(flat_x, g_flat))
        grad[first_idx] = g_flat[:unique_values.size]
        return grad.reshape(x_arr.shape)

    return vjp

# Register unique_filled_vjp as the reverse-mode (VJP) rule for the @primitive
# unique_filled above. With a single VJP maker, autograd applies it to argument 0
# (x); filled_value gets no gradient.
defvjp(unique_filled, unique_filled_vjp)


In [116]:
def unique_filled(x, filled_value, size=4):
    """
    Return the sorted distinct values of `x`, padded with `filled_value` to length `size`.

    Built only from autograd.numpy (`npa`) operations so autograd can differentiate
    through it. This definition replaces the earlier @primitive of the same name.

    Parameters
    ----------
    x            : array with a .flatten() method; all entries are considered.
    filled_value : value appended until the result has `size` entries.
    size         : target length (default 4). If x has more than `size` distinct
                   values, all of them are returned (no truncation).

    Returns
    -------
    1-D array: the sorted unique values of x, followed by max(size - k, 0)
    copies of filled_value, where k is the number of distinct values.
    """
    # Sorting ensures differentiability of np.unique: after sorting, keep each
    # element that differs from its predecessor.
    sorted_x = npa.sort(x.flatten())
    if len(sorted_x) == 0:
        unique_values = sorted_x
    else:
        unique_values = sorted_x[np.concatenate(([True], npa.diff(sorted_x) != 0))]

    n_pad = size - len(unique_values)
    if n_pad > 0:
        # A single append instead of one per missing slot. Skipped when no padding
        # is needed, so the dtype is not changed by an empty append.
        unique_values = npa.append(unique_values, [filled_value] * n_pad)
    return unique_values

In [125]:
# Quick check of the autograd-compatible unique_filled: distinct values come back
# sorted, and no padding is added because all four entries are distinct.
x=np.array( [1,2,4,3] )
unique_filled(x,1)

array([1, 2, 3, 4])

In [123]:
import autograd.numpy as npa
from autograd import grad

# Test point where all four entries coincide, so unique() collapses them to one value.
# filled_value pads the unique values back to length 4; fun and jax_fun read it as a global.
x=np.array([1.,1.,1.,1.])
filled_value=2

def GRAD(x, h, func=None):
    """
    Central finite-difference gradient of a scalar function.

    Parameters
    ----------
    x    : point at which to differentiate (any shape).
    h    : finite-difference step.
    func : scalar function to differentiate. Defaults to the global `jax_fun`
           (the original behaviour), which is looked up at call time.

    Returns
    -------
    np.ndarray of float with the same shape as x, where
    grad[i] = (func(x + h e_i) - func(x - h e_i)) / (2 h).
    """
    if func is None:
        func = jax_fun
    x = np.asarray(x)
    grad = np.zeros(x.shape)
    for idx in np.ndindex(x.shape):
        e = np.zeros(x.shape)
        e[idx] = 1.0
        grad[idx] = (func(x + h * e) - func(x - h * e)) / (2 * h)
    return grad

def fun(x):
    """Product of unique_filled(x, filled_value), built from autograd.numpy so autograd can differentiate it."""
    padded_unique = unique_filled(x, filled_value)
    return npa.prod(padded_unique)

# Autograd: value and gradient of prod(unique_filled(x)) at the all-equal test point.
print(fun(x))
fun_grad=grad(fun)
print(fun_grad(x))


from jax import numpy as jnp
from jax import grad as jgrad

def jax_fun(x):
    """JAX version of fun: product of the distinct entries of x, padded to length 4 with the global filled_value."""
    padded_unique = jnp.unique(x, size=4, fill_value=filled_value)
    return jnp.prod(padded_unique)
# Same quantity via JAX, for comparison with the autograd result printed above.
print(jax_fun(x))
jax_fun_grad=jgrad(jax_fun)
print(jax_fun_grad(x))

# Central finite-difference cross-check; h is set in the `h=1e-6` cell.
# print(GRAD(x,h))

8.0
[8. 0. 0. 0.]
8.0
[8. 0. 0. 0.]


In [45]:
# Finite-difference step for GRAD(x, h).
# NOTE(review): an assignment displays nothing, so the saved output shown for this
# cell is stale, left over from an earlier version of the cell.
h=1e-6


array([10.,  5., 10.,  2.])

In [10]:
import autograd.numpy as npa
from autograd import grad

def fun2(x):
    """Print the integers 0 through 3, one per line; `x` is ignored."""
    print(*range(4), sep="\n")

# Smoke test of fun2: prints 0, 1, 2, 3 (the argument is unused).
fun2(1)

# NOTE(review): if re-enabled, these lines would rebind the global name FD, shadowing
# the figure-of-merit function FD used by FD_params_func / FOM_uniform.
# FD=grad(fun2)
# FD(2.)

0
1
2
3


In [19]:
import numpy as np
import jax.numpy as jnp

# Demo: jnp.unique returns the distinct values sorted ascending.
# A seeded generator makes the random entries (and this cell's output) reproducible.
rng = np.random.default_rng(0)
a = jnp.array([-2, 2, rng.random(), rng.random()])

print(a)
print(jnp.unique(a))


[-2.          2.          0.17277419  0.63298563]
[-2.          0.17277419  0.63298563  2.        ]


In [20]:
# import autograd.numpy as npa
# from autograd import grad
from jax.scipy.special import erf as autograd_erf
# from autograd.numpy import linalg as npaLA