# Field Alignment for Data Assimilation

This notebook uses short code examples to illustrate field alignment in the context of data assimilation, focusing on the alignment phase. The first part treats a Gaussian discretized on a regular grid. The second part implements the same approach for a Gaussian represented with a particle discretization. Finally, the goal is to apply a force correction and to penalize interpenetration.

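```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax import Array, grad, jacfwd, random
import matplotlib.pyplot as plt
import numpy as np

# Double precision keeps JAX gradients consistent with SciPy's float64 BFGS line search
jax.config.update("jax_enable_x64", True)
```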

## Field alignment on a regular grid

We first define the background state $y_b$, the observations $y_{obs}$ of a ground truth, the observation operator $H$ (`h` in the code) and the regular grid $x$ with spacing `dx`. All these quantities are stored in the class `Data`, which also stores the field $y_i$ as it evolves during the iterations.

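```python
@dataclass
class Data:
    x: np.ndarray      # regular grid on [0, 1), periodic (period=1 in apply_displacement)
    dx: float          # grid spacing
    y_b: np.ndarray    # background state
    y_i: np.ndarray    # current field, updated at each alignment iteration
    y_obs: np.ndarray  # observations of the ground truth
    h: np.ndarray      # observation operator H, shape (n_obs, n_grid)
```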
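One possible test case for the first part, a Gaussian on a regular periodic grid, can be built as follows. The helper `make_gaussian_case` and its defaults (100 grid points, a shift of 0.1, a width of 0.05, observing every second point, noise-free observations) are hypothetical illustration choices, not values taken from the original study.

```python
import numpy as np


def make_gaussian_case(n=100, center=0.5, shift=0.1, width=0.05, obs_every=2):
    """Hypothetical setup: background Gaussian vs. a shifted 'truth' on a periodic grid."""
    x = np.arange(n) / n  # regular grid on [0, 1)
    dx = 1.0 / n

    def gaussian(c):
        # Periodic distance to the center c, mapped to [-0.5, 0.5)
        d = (x - c + 0.5) % 1.0 - 0.5
        return np.exp(-((d / width) ** 2))

    y_b = gaussian(center)            # background state
    y_true = gaussian(center + shift)  # ground truth, displaced by `shift`
    h = np.eye(n)[::obs_every]        # H: keep every obs_every-th grid value
    y_obs = h @ y_true                # noise-free observations (assumption)
    return x, dx, y_b, y_obs, h


x, dx, y_b, y_obs, h = make_gaussian_case()
```

The arrays can then be packed as `Data(x, dx, y_b, y_b.copy(), y_obs, h)`, so that the iterations start from the background.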

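We define a function `apply_displacement` that displaces the field $y$ by a deformation vector $q$, that is, it evaluates $y(x - q)$ by linear interpolation on the periodic grid $x$. Its Jacobian with respect to $q$ is obtained by forward-mode automatic differentiation.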

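```python
def apply_displacement(q: Array, x: Array, y: Array) -> Array:
    """Return the field y displaced by q, i.e. y(x - q), by periodic linear interpolation."""
    return jnp.interp(x - q, x, y, period=1.0)


# Jacobian of the displaced field with respect to q (first argument)
grad_apply_displacement = jacfwd(apply_displacement)
```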
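A quick check of what `apply_displacement` and its Jacobian compute: a uniform displacement of 0.1 should move a Gaussian peak from $x = 0.5$ to $x = 0.6$, and `jacfwd` should return a diagonal matrix whose entries are minus a one-sided finite difference of $y$ (at grid nodes `jnp.interp` appears to use the interval to the right). The grid size and Gaussian width are arbitrary illustration values, and the function is redefined so the block runs on its own.

```python
import jax.numpy as jnp
from jax import Array, jacfwd


def apply_displacement(q: Array, x: Array, y: Array) -> Array:
    # Same definition as above: y sampled at x - q, periodic linear interpolation
    return jnp.interp(x - q, x, y, period=1.0)


n = 100
dx = 1.0 / n
x = jnp.arange(n) / n
y = jnp.exp(-(((x - 0.5) / 0.05) ** 2))  # Gaussian centred at x = 0.5

# A uniform displacement q = 0.1 moves the peak from x = 0.5 to x = 0.6
y_shifted = apply_displacement(jnp.full(n, 0.1), x, y)
peak_index = int(jnp.argmax(y_shifted))

# Jacobian at q = 0: each output i depends only on q_i, so G is diagonal
G = jacfwd(apply_displacement)(jnp.zeros(n), x, y)
forward_diff = (jnp.roll(y, -1) - y) / dx  # periodic forward difference of y
```

Because $G$ is diagonal, the product $H G\,\delta q$ used in the cost below is simply the observed first-order change of the field when each grid point is displaced by $\delta q$.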

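We then define the cost function for a deformation increment $\delta q$. The displaced field is linearized around the current state $y_i$, and a regularization term penalizes large gradients of $\delta q$ so that the displacement field stays smooth.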
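Written out, the cost evaluated by `cost_inc` reads as follows. This is our transcription of the code, with $\nabla$ the finite-difference gradient computed by `jnp.gradient`:

$$
J(\delta q) = \frac{1}{2\sigma_{obs}^2}\,\lVert H G\,\delta q - \eta \rVert^2
+ \frac{w}{2}\,\lVert \nabla \delta q \rVert^2
+ \frac{w}{2}\,\operatorname{tr}\!\left(\nabla \delta q\, \nabla \delta q^{\top}\right),
\qquad \eta = y_{obs} - H y_i,
\qquad G = \left.\frac{\partial\, y_i(x - q)}{\partial q}\right|_{q = 0}.
$$

Since $\operatorname{tr}(v v^{\top}) = \lVert v \rVert^2$, the last two terms coincide in this 1D setting, so the effective smoothness weight is $w$. Because $G$ is frozen at the current field $y_i$, $J$ is quadratic in $\delta q$.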

```python
def cost_inc(dq: Array, data: Data, sigma_obs: Array, w: float) -> Array:
    # Innovation: observations minus the observed current field
    eta = data.y_obs - data.h @ data.y_i
    # Jacobian of y_i(x - q) at q = 0 (diagonal, so no transpose is needed)
    G = grad_apply_displacement(jnp.zeros(len(data.x)), data.x, data.y_i)
    nabla_dq = jnp.gradient(dq, data.dx)
    residual = data.h @ G @ dq - eta

    return (
        0.5 * (residual @ residual) / sigma_obs**2
        + w / 2 * (nabla_dq @ nabla_dq)
        # Second smoothness term: trace(outer(v, v)) = v @ v, so in 1D it equals the term above
        + w / 2 * jnp.trace(jnp.outer(nabla_dq, nabla_dq))
    )


grad_cost_inc = grad(cost_inc)  # gradient with respect to dq (first argument)
```

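We solve the minimization iteratively. At each outer iteration, BFGS minimizes the incremental cost linearized around the current field $y_i$; the increment $\delta q$ is added to the total displacement $q$, and $y_i$ is warped by $\delta q$ before the next linearization. The loop stops when the cost changes by less than a relative tolerance, or after `itmax` iterations.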

```python
from scipy.optimize import minimize


def solve_alignment_inc(
    data: Data, sigma_obs: Array, w: float, itmax: int = 50, rtol: float = 0.1
) -> tuple[Array, Array]:
    """Align data.y_i with data.y_obs; return the total displacement q and the aligned field."""
    n = len(data.x)
    dq0 = np.zeros(n)  # each increment is searched from zero
    q = jnp.zeros(n)
    loss_prev = float(cost_inc(dq0, data, sigma_obs, w))
    print("initial cost:", loss_prev)

    for it in range(1, itmax + 1):
        # Minimize the cost linearized around the current field y_i
        result = minimize(
            cost_inc, dq0, args=(data, sigma_obs, w), method="BFGS", jac=grad_cost_inc
        )
        q = q + result.x
        # Update y_i: warp the current field by the increment
        data.y_i = apply_displacement(result.x, data.x, data.y_i)
        # Misfit of the warped field (cost at dq = 0 after the update)
        loss = float(cost_inc(dq0, data, sigma_obs, w))
        print(f"{it=} new cost: {loss}")
        if np.isclose(loss_prev, loss, rtol=rtol):
            break
        loss_prev = loss

    return q, data.y_i
```
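One property of the outer loop is worth keeping in mind: `q` accumulates the increments by summation, while `y_i` is warped successively. For a uniform shift the two agree, but warping by $\delta q_1$ and then by $\delta q_2$ amounts to a single displacement $q(x) = \delta q_2(x) + \delta q_1\big(x - \delta q_2(x)\big)$, not $\delta q_1 + \delta q_2$. The sketch below checks this numerically with two hand-picked increments (a sinusoid followed by a constant shift, both hypothetical); `apply_displacement` is redefined so the block runs standalone.

```python
import jax.numpy as jnp


def apply_displacement(q, x, y):
    # Same definition as in the notebook: y sampled at x - q, periodic linear interpolation
    return jnp.interp(x - q, x, y, period=1.0)


n = 200
x = jnp.arange(n) / n
y0 = jnp.exp(-(((x - 0.5) / 0.1) ** 2))

a = 0.05 * jnp.sin(2 * jnp.pi * x)  # first increment: non-uniform
b = jnp.full(n, 0.05)               # second increment: uniform shift (10 grid cells)

# Two successive warps, as performed on y_i by the outer loop
y_two_steps = apply_displacement(b, x, apply_displacement(a, x, y0))

# Single warp with the summed displacement (what q accumulates)
y_sum = apply_displacement(a + b, x, y0)

# Single warp with the composed displacement q(x) = b(x) + a(x - b(x))
q_composed = b + 0.05 * jnp.sin(2 * jnp.pi * (x - b))
y_composed = apply_displacement(q_composed, x, y0)

err_sum = float(jnp.max(jnp.abs(y_two_steps - y_sum)))
err_composed = float(jnp.max(jnp.abs(y_two_steps - y_composed)))
```

Here `err_composed` stays at round-off level while `err_sum` is of order 0.1. With the small increments typical of later iterations the gap should shrink, so the returned `q` is best read as an approximation of the total displacement, whereas the returned field is the result of the successive warps.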