# Ensemble Kalman Filters

In [76]:
import autoroot
from pathlib import Path
import numpyro
import numpyro.distributions as dist
from jax import config
config.update("jax_enable_x64", True)
import einx
import jax
import jax.numpy as jnp
import numpy as np
import jax.random as jr
import xarray as xr
from jaxtyping import Float, Array
import cola
from oi_toolz._src.ops.kernels import kernel_rbf, gram
from cola.linalg import Auto
from oi_toolz._src.ops.enskf import analysis_etkf

key = jr.key(123)

import matplotlib.pyplot as plt
import seaborn as sns
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.75)


%matplotlib inline

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Data

In [103]:
save_dir = Path("/Users/eman/code_projects/data/enskf")
Y = np.load(save_dir.joinpath("Y.npy"), allow_pickle=True)
R = np.load(save_dir.joinpath("R.npy"), allow_pickle=True)
Xf = np.load(save_dir.joinpath("Xf.npz"), allow_pickle=True)["arr_0"]
Yf = np.load(save_dir.joinpath("HXf.npy"), allow_pickle=True)

Xf.shape, Yf.shape, Y.shape, R.shape

((55296, 100), (293, 100), (293,), (293,))

In [104]:
def ETKF(Xf, HXf, Y, R):
    """
    Implementation adapted from pseudocode description in
    "State-of-the-art stochastic data assimialation methods" by Vetra-Carvalho et al. (2018),
    algorithm 7, see section 5.4.
    Errors: Calculation of W1 prime, divide by square root of eigenvalues. The mathematical formula in the paper has an error already.
    
    Dimensions: N_e: ensemble size, N_y: Number of observations: N_x: State vector size (Gridboxes x assimilated variables)
    
    Input:
    - Xf:  the prior ensemble (N_x x N_y) 
    - R: Measurement Error (Variance of pseudoproxy timerseries) ($N_y$ x 1$) -> converted to Ny x Ny matrix
    - HX^f: Model value projected into observation space/at proxy locations ($N_y$ x $N_e$)
    - Y: Observation vector ($N_y$ x 1)

    Output:
    - Analysis ensemble (N_x, N_e)
    """
    # number of ensemble members
    Ne=np.shape(Xf)[1]

    #Obs error matrix
    #Rmat=np.diag(R)
    Rmat_inv=np.diag(1/R)
    #Mean of prior ensemble for each gridbox   
    mX = np.mean(Xf, axis=1)
    #Perturbations from ensemble mean
    Xfp=Xf-mX[:,None]
    #Mean and perturbations for model values in observation space
    mY = np.mean(HXf, axis=1)
    HXp = HXf-mY[:,None]
    

    C=Rmat_inv @ HXp
    A1=(Ne-1)*np.identity(Ne)
    A2=A1 + (HXp.T @ C)
    

    #eigenvalue decomposition of A2, A2 is symmetric
    eigs, ev = np.linalg.eigh(A2) 

    #compute perturbations
    Wp1 = np.diag(np.sqrt(1/eigs)) @ ev .T
    Wp = ev @ Wp1 * np.sqrt(Ne-1)


    #differing from pseudocode
    d=Y-mY
    D1 = Rmat_inv @ d
    D2 = HXp.T @ D1
    wm=ev @ np.diag(1/eigs) @ ev.T @ D2  #/ np.sqrt(Ne-1) 

    #adding pert and mean (!row-major formulation in Python!)
    W=Wp + wm[:,None]

    #final adding up (most costly operation)
    Xa=mX[:,None] + Xfp @ W

    return Xa

In [105]:
out1 = ETKF(Xf, Yf, Y, R)
np.mean(out1, axis=1)

array([225.77951468, 225.75510173, 225.75508779, ..., 252.08298075,
       252.08490147, 252.08658989])

In [106]:
out2 = analysis_etkf(Xf, Yf, Y, cola.ops.Diagonal(R))
np.mean(out2, axis=1)

Array([225.77951468, 225.75510173, 225.75508779, ..., 252.08298075,
       252.08490147, 252.08658989], dtype=float64)

In [107]:
np.testing.assert_array_almost_equal(out1, out2)

In [108]:
%%timeit
ETKF(Xf, Yf, Y, R)

33.9 ms ± 4.38 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [109]:
%%timeit
analysis_etkf(Xf, Yf, Y, cola.ops.Diagonal(R))

48.8 ms ± 351 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [97]:
# inverse R
R_inv: Float[Array, "Dy Dy"] = cola.inv(R)

# Mean of prior ensemble for each gridbox 
mu_Xf: Float[Array, "Dx"] = jnp.mean(Xf.to_dense(), axis=0)

# perturbations from ensuemble mean
Xfp: Float[Array, "Ne Dx"] = einx.multiply("Ne Dx, Dx -> Ne Dx", Xf.to_dense(), mu_Xf)

# mean of predictions
mu_Yf: Float[Array, "Dy"] = jnp.mean(Yf.to_dense(), axis=0)

# perturbations from ensemble mean
Yfp: Float[Array, "Ne Dy"] = cola.ops.Dense(einx.multiply("Ne Dy, Dy -> Ne Dy", Yf.to_dense(), mu_Yf))

# conditional covariance
A: Float[Array, "Ne Ne"] = Yfp @ R_inv @ Yfp.T
A += (num_ensembles - 1) * cola.ops.I_like(A)

# eigenvalue decomposition
eigvals, eigvecs = cola.linalg.eig(A, k=num_ensembles, alg=Auto())

# compute perturbation
Wp: Float[Array, "Ne Ne"] = jnp.sqrt(num_ensembles - 1) * eigvecs @ cola.ops.Diagonal(jnp.sqrt(1/eigvals)) @ eigvecs .T

# calculate innovation
innovation: Float[Array, "Ne Ne"] = Y - mu_Yf

# calculate weighted perturbation
Wm = eigvecs @ cola.ops.Diagonal(jnp.sqrt(1/eigvals)) @ eigvecs.T @ Yfp @ R_inv @ innovation
W: Float[Array, "Ne Ne"] = einx.multiply("Ne Nb, Nb -> Ne Nb", Wp.to_dense(), Wm)

# calculate analysis
# X_correction = einx.multiply("N D, N M -> N D", Xfp, W.astype(jnp.float64))
X_correction = jnp.einsum("ND,NM->ND", Xfp, W.real)

X_analysis = mu_Xf + X_correction

X_analysis.shape

(500, 250)

In [98]:
np.testing.assert_array_almost_equal(X_analysis.T, out)

AssertionError: 
Arrays are not almost equal to 6 decimals

Mismatched elements: 125000 / 125000 (100%)
Max absolute difference: 115.16062195
Max relative difference: 2242.9838023
 x: array([[-0.00812 , -0.008037, -0.008131, ..., -0.008075, -0.008288,
        -0.008056],
       [-0.095191, -0.096171, -0.095221, ..., -0.095374, -0.094894,...
 y: array([[  1.655195,  15.632597,   0.224739, ...,  -4.710851,  19.105776,
        -24.890383],
       [ -1.75347 ,  -0.349574,  -0.156994, ...,   1.281412,  -1.318782,...

In [111]:
Xfp.shape, W.shape

((100, 10), (100, 100))

In [108]:
W.shape

(100, 100)

In [82]:
out

Array([ 0.0411188 +0.j,  0.01590232+0.j,  0.03530396+0.j,  0.0186358 +0.j,
        0.03107024+0.j, -0.02780267+0.j,  0.01394665+0.j, -0.0398399 +0.j,
        0.03386795+0.j, -0.02623202+0.j,  0.01253782+0.j,  0.03394068+0.j,
       -0.007868  +0.j,  0.00223607+0.j,  0.01911372+0.j, -0.03309728+0.j,
       -0.00735682+0.j, -0.00864039+0.j,  0.00438894+0.j, -0.02099753+0.j,
       -0.0057542 +0.j, -0.00128632+0.j,  0.03433282+0.j,  0.00253788+0.j,
       -0.03130005+0.j,  0.01149849+0.j, -0.00278476+0.j,  0.01562224+0.j,
        0.03071425+0.j, -0.00507221+0.j, -0.0106824 +0.j,  0.02584798+0.j,
        0.02040042+0.j, -0.04686123+0.j,  0.0355799 +0.j, -0.00279811+0.j,
       -0.03848739+0.j,  0.02042115+0.j,  0.05159569+0.j, -0.00696575+0.j,
        0.0537002 +0.j, -0.00717838+0.j,  0.01022297+0.j,  0.00349617+0.j,
        0.04747664+0.j, -0.04910013+0.j, -0.0296434 +0.j, -0.00660738+0.j,
        0.00197204+0.j, -0.01283245+0.j,  0.00236846+0.j,  0.01343237+0.j,
       -0.01865395+0.j, -