# EnKF Analysis Update

$$X_t^n = \hat X_t^n + \sqrt{\frac{1+\zeta}{N-1}}(\hat X_t^n - \hat m_t) \tilde Z_t^T [R^{-1}-R^{-1}\tilde Z_t(I_{N\times N} + \tilde Z_t^TR^{-1} \tilde Z_t)^{-1}\tilde Z_t^TR^{-1}](Y_t + \gamma_t^n - H\hat X_t^n), \quad \textrm{ where } \tilde Z_t^n = \sqrt{\frac{1+\zeta}{N-1}}(H\hat X_t^n - H \hat m_t)$$

In the code, the analysis increment is split into part1 and part2. These come from distributing the multiplication across the subtraction inside the brackets above.

part1 = $\sqrt{\frac{1+ \zeta}{N-1}}(\hat X_{t}^n-\hat m_t)\tilde Z_t^T R^{-1}(Y_t + \gamma_t^n - H\hat X_t^n)$

part2 = $-\sqrt{\frac{1+ \zeta}{N-1}}(\hat X_{t}^n-\hat m_t)\tilde Z_t^T R^{-1}\tilde Z_t(I_{N\times N} + \tilde Z_t^TR^{-1} \tilde Z_t)^{-1}\tilde Z_t^TR^{-1}(Y_t + \gamma_t^n - H\hat X_t^n)$
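As a side check, the bracketed term in the update is the Woodbury form of $(R + \tilde Z_t \tilde Z_t^T)^{-1}$, which is why only an $N\times N$ matrix has to be inverted rather than a $d_y\times d_y$ one. Here is a minimal numerical sketch of that identity. The sizes, the seed, and the value r = 0.5 are arbitrary choices for illustration and are not taken from the notebook's example.

```python
import torch

torch.manual_seed(0)
d_y, N = 4, 3                                   # illustrative sizes (assumed)
Z = torch.randn(d_y, N, dtype=torch.double)     # stands in for \tilde Z_t
R = 0.5 * torch.eye(d_y, dtype=torch.double)    # assumed R = r I with r = 0.5
R_inv = torch.linalg.inv(R)

# Bracketed term: R^{-1} - R^{-1} Z (I + Z^T R^{-1} Z)^{-1} Z^T R^{-1}
inner = torch.linalg.inv(torch.eye(N, dtype=torch.double) + Z.T @ R_inv @ Z)
bracket = R_inv - R_inv @ Z @ inner @ Z.T @ R_inv

# Direct d_y x d_y inverse for comparison
direct = torch.linalg.inv(R + Z @ Z.T)

woodbury_ok = torch.allclose(bracket, direct)
print(woodbury_ok)
```

When $N$ is much smaller than $d_y$, the left-hand form is the cheaper of the two, which is presumably why the code uses it.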




In the code, $\hat X_t \in \mathbb{R}^{N\times d_x}$ and $Y_t \in \mathbb{R}^{N\times d_y}$, whereas in the expression above $\hat X_t \in \mathbb{R}^{d_x \times N}$ and $Y_t \in \mathbb{R}^{d_y\times N}$ (I should really change this). To account for the swapped dimensions, we transpose the expressions above. Since $R^{-1}$ and $(I_{N\times N} + \tilde Z_t^TR^{-1} \tilde Z_t)^{-1}$ are symmetric, this just reverses the order of the factors, with $\tilde Z_t$ kept as $d_y\times N$.

Code:

part1 = $\sqrt{\frac{1+ \zeta}{N-1}}(Y_t + \gamma_t^n - H\hat X_t^n)R^{-1}\tilde Z_t(\hat X_{t}^n-\hat m_t) $

part2 = $-\sqrt{\frac{1+ \zeta}{N-1}}(Y_t + \gamma_t^n - H\hat X_t^n)R^{-1}\tilde Z_t(I_{N\times N} + \tilde Z_t^TR^{-1} \tilde Z_t)^{-1}\tilde Z_t^TR^{-1}\tilde Z_t(\hat X_{t}^n-\hat m_t)$
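A quick way to convince yourself of the reordering is to build both layouts from the same random data and check that one is the transpose of the other. This is only a sketch: the sizes, seed, and the random stand-in for the innovation $Y_t + \gamma_t^n - H\hat X_t^n$ are made up for illustration, and it assumes $H = I$ and $R = rI$ as in the rest of this notebook.

```python
import math
import torch

torch.manual_seed(0)
N, d = 5, 3                      # illustrative sizes, d_x = d_y = d (assumed)
r, zeta = 0.1, 0.0               # illustrative values (assumed)
dt = torch.double
c = math.sqrt(1 + zeta) / math.sqrt(N - 1)

X_cols = torch.randn(d, N, dtype=dt)          # analytic layout: d_x x N
innov_cols = torch.randn(d, N, dtype=dt)      # random stand-in for Y + gamma - H X
Xc_cols = X_cols - X_cols.mean(dim=1, keepdim=True)
Z = c * Xc_cols                               # d_y x N in both layouts
R_inv = torch.eye(d, dtype=dt) / r
inner = torch.linalg.inv(torch.eye(N, dtype=dt) + Z.T @ R_inv @ Z)

# Analytic layout: results are d_x x N
part1_cols = c * Xc_cols @ Z.T @ R_inv @ innov_cols
part2_cols = -c * Xc_cols @ Z.T @ R_inv @ Z @ inner @ Z.T @ R_inv @ innov_cols

# Code layout: factors in reverse order, results are N x d_x
Xc_rows, innov_rows = Xc_cols.T, innov_cols.T
part1_rows = c * innov_rows @ R_inv @ Z @ Xc_rows
part2_rows = -c * innov_rows @ R_inv @ Z @ inner @ Z.T @ R_inv @ Z @ Xc_rows

layouts_agree = (torch.allclose(part1_cols.T, part1_rows)
                 and torch.allclose(part2_cols.T, part2_rows))
print(layouts_agree)
```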


Here is what each named variable in the code means, along with its dimension:

obs_perturb: $Y_t + \gamma^n_t \in \mathbb{R}^{N\times d_y}$

R: $R\in \mathbb{R}^{d_y \times d_y}$

Z: $\tilde Z_t^n =\sqrt{\frac{1+\zeta}{N-1}}(H\hat X_t^n - H \hat m_t) \in \mathbb{R}^{d_y\times N}$

RZ: $R^{-1}\tilde Z_t \in \mathbb{R}^{d_y \times N}$

nxn_inv: $(I_{N\times N} + \tilde Z_t^T R^{-1}\tilde Z_t)^{-1} \in \mathbb{R}^{N\times N}$

X_ct: $(\hat X_t^n - \hat m_t) \in \mathbb{R}^{N\times d_x}$

This code assumes that $H=I_{d_y\times d_x}$ (so $d_y=d_x$) and $R=rI_{d_y\times d_y}$ with a scalar $r > 0$. The equations carry a subscript $t$, but we can ignore it here. Each tensor also has an additional leading "bs" dimension, an artifact from the original code source. I set it to 1 for all variables, so it's safe to ignore.
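To make the role of the $R = rI$ assumption concrete, here is a small sketch showing that dividing by $r$ matches a general solve against $R$, and that the Cholesky route to the $N\times N$ inverse matches a direct inverse. The sizes, seed, and value of r are illustrative assumptions. `torch.linalg.solve` is shown only as what one might reach for with a non-scalar $R$; the notebook's code does not use it.

```python
import torch

torch.manual_seed(0)
d_y, N, r = 2, 3, 1e-2                        # illustrative values (assumed)
dt = torch.double
Z = torch.randn(d_y, N, dtype=dt)             # stands in for \tilde Z_t
R = r * torch.eye(d_y, dtype=dt)

RZ_general = torch.linalg.solve(R, Z)         # R^{-1} Z for a general invertible R
RZ_scalar = Z / r                             # shortcut, valid only because R = r I

# I + Z^T R^{-1} Z is symmetric positive definite, so Cholesky applies
M = torch.eye(N, dtype=dt) + Z.T @ RZ_scalar
M_inv_chol = torch.cholesky_inverse(torch.linalg.cholesky(M))
M_inv_direct = torch.linalg.inv(M)

shortcut_ok = torch.allclose(RZ_general, RZ_scalar)
cholesky_ok = torch.allclose(M_inv_chol, M_inv_direct)
print(shortcut_ok, cholesky_ok)
```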

In [1]:
import torch
import math
def test_analysis(X, Y, r, N_ensem, zeta, y_perturb, float_type = torch.double):
    # X: forecast ensemble, shape (bs, N_ensem, d_x)
    # Y: observational data, shape (N_ensem, d_y)
    # r: scalar observation variance, so R = r * I
    # N_ensem: number of ensemble members
    # zeta: covariance inflation constant
    # y_perturb: N_ensem random draws from N(0,R), shape (N_ensem, d_y)
    # float_type = torch.double : double precision computations
    
    ########################################
    #               Analysis               #
    ########################################
    device = 'cpu'
    bs = 1
    y_dim = Y.shape[1]

    HX = X # assume H is the identity matrix 

    X_m = X.mean(dim=-2).unsqueeze(-2)
    X_ct = X - X_m
    
    obs_perturb = Y + y_perturb
    
    Z = torch.empty((bs, y_dim, N_ensem), dtype=float_type).to(device=device)
            
    # populate Z for each bs
    for i in range(bs):
        Z[i,:,:] = math.sqrt(1+zeta)/math.sqrt(N_ensem-1) * X_ct.transpose(-1,-2)[i,:,:] # (*bs, y_dim, N_ensem)
            

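    RZ = 1/r * Z # this is R^{-1}Z, using R = r*I

    # Cholesky-factorize I + Z^T R^{-1} Z (symmetric positive definite), then invert it
    nxn_chol = torch.linalg.cholesky(torch.eye(N_ensem, dtype=float_type, device=device).expand(bs, N_ensem, N_ensem) + RZ.transpose(1,2) @ Z)
    nxn_inv = torch.cholesky_inverse(nxn_chol)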

    # Update
    part1 = 1/r*math.sqrt(1+zeta)/math.sqrt(N_ensem-1)*( (obs_perturb-HX)@Z.to(dtype=float_type) )@ X_ct.to(dtype=float_type)
    #                                                           ( N x dy       dy x N  )                     N x dx
    #                                                                    N x N                               N x dx
    #                                                                                     N x dx


    part2 = -math.sqrt(1+zeta)/math.sqrt(N_ensem-1)*( (obs_perturb-HX) @ RZ) @ nxn_inv @ ( RZ.transpose(-1,-2) @Z.to(dtype=float_type) )@ X_ct.to(dtype=float_type)
    #                                                    ( N x dy    dy x N )     N x N      ( N x dy                   dy x N )             N x dx
    #                                                         N x N               N x N                   N x N                              N x dx
    #                                                                                        N x dx
    X = X + part1 + part2   
    return X

# Example setup

In [2]:
torch.manual_seed(20)

N_ensem = 3
dx = dy = 2
r = 0.0001 # small observation noise, so the analysis update should land very close to the perturbed observations Y + y_perturb
zeta = 0

X0 = torch.Tensor([[1,1], [1,1],[1,1]]).to(dtype = torch.double)
x_perturb = torch.round(torch.normal(0,0.1,size = (N_ensem,dx))*10).to(dtype = torch.double)/10
print(x_perturb)

tensor([[-0.1000,  0.0000],
        [ 0.1000, -0.2000],
        [-0.2000, -0.0000]], dtype=torch.float64)


In [3]:
# latent space ensemble
X = (X0 + x_perturb).unsqueeze(0)
X

tensor([[[0.9000, 1.0000],
         [1.1000, 0.8000],
         [0.8000, 1.0000]]], dtype=torch.float64)

In [4]:
# this is the tensor we want to recover
Y = torch.Tensor([[1,1],[1,1],[1,1]]).to(dtype = torch.double)
Y

tensor([[1., 1.],
        [1., 1.],
        [1., 1.]], dtype=torch.float64)

In [5]:
# perturb the observational data
y_perturb = torch.round(torch.normal(0,math.sqrt(r),size = (N_ensem,dy))*1000).to(dtype = torch.double)/1000
Y + y_perturb

tensor([[0.9790, 0.9950],
        [0.9990, 1.0000],
        [0.9960, 0.9850]], dtype=torch.float64)

# Results from test_analysis function above

In [6]:
results = test_analysis(X, Y, r, N_ensem, zeta, y_perturb, float_type = torch.double)
results

tensor([[[0.9764, 0.9918],
         [0.9937, 0.9919],
         [0.9896, 0.9771]]], dtype=torch.float64)
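One way to read this result: with $H = I$, the update can be written as $X + K(Y + \gamma - X)$ with gain $K = C(C + rI)^{-1}$ and $C = \tilde Z \tilde Z^T$, so the analysis should approach the perturbed observations as $r \to 0$ whenever $C$ has full rank. The sketch below checks this on the example ensemble by measuring how far $K$ is from the identity for two values of $r$. The helper name `gain` and the second value r = 1e-8 are my own additions for illustration.

```python
import math
import torch

dt = torch.double
# Forecast ensemble from the example above, rows = members (N = 3, d_x = 2)
X = torch.tensor([[0.9, 1.0], [1.1, 0.8], [0.8, 1.0]], dtype=dt)
N, d = X.shape
Z = (X - X.mean(dim=0, keepdim=True)).T / math.sqrt(N - 1)   # d_y x N, zeta = 0
C = Z @ Z.T                                                  # sample covariance (H = I)
I_d = torch.eye(d, dtype=dt)

def gain(r):
    # Hypothetical helper: Kalman gain K = C (C + r I)^{-1} for H = I, R = r I
    return C @ torch.linalg.inv(C + r * I_d)

dist_r4 = torch.linalg.norm(gain(1e-4) - I_d).item()   # r used in the example
dist_r8 = torch.linalg.norm(gain(1e-8) - I_d).item()   # much smaller r (assumed)
print(dist_r4, dist_r8)
```

The distance shrinks with $r$, consistent with the analysis above sitting within a few thousandths of Y + y_perturb rather than exactly on it.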

# Computations based on original formula: dx=dy=2, N=3

$Z = \sqrt{\frac{1+\zeta}{N-1}}(\hat X - \hat m), \quad \hat m = \frac{1}{N_{ensem}}\sum_{n=1}^{N_{ensem}}\hat X^n$

In [7]:
# convert X and Y to transposes to match the analytic expression
X = X.transpose(-1,-2)
Y = Y.transpose(-1,-2)
y_perturb = y_perturb.transpose(-1,-2)

In [8]:
X.shape

torch.Size([1, 2, 3])

In [9]:
Y.shape

torch.Size([2, 3])

In [10]:
Z = math.sqrt(1+zeta)/math.sqrt(N_ensem-1) * (X - X.mean(2, keepdims=True))
Z

tensor([[[-0.0236,  0.1179, -0.0943],
         [ 0.0471, -0.0943,  0.0471]]], dtype=torch.float64)

$\sqrt{\frac{1+\zeta}{N_{ensem} - 1}}(\hat X - \hat m) \tilde Z^T$

In [11]:
A = math.sqrt(1+zeta)/math.sqrt(N_ensem -1) * (X - X.mean(2, keepdims=True)) @ Z.transpose(-1,-2)
A

tensor([[[ 0.0233, -0.0167],
         [-0.0167,  0.0133]]], dtype=torch.float64)

$(I + Z^T R^{-1}Z)^{-1}$

In [12]:
B = torch.linalg.inv(torch.eye(N_ensem, dtype= torch.double) + 1/r* Z.transpose(-1,-2)@ Z)
B

tensor([[[0.3921, 0.3197, 0.2882],
         [0.3197, 0.3380, 0.3422],
         [0.2882, 0.3422, 0.3695]]], dtype=torch.float64)

$Y + \gamma - X$

In [13]:
C = Y + y_perturb - X

$R^{-1} - R^{-1}Z(I+Z^TR^{-1}Z)^{-1}Z^TR^{-1}$

In [15]:
D = 1/r*torch.eye(dy,dtype= torch.double).unsqueeze(0) - 1/r*Z @ B @ Z.transpose(-1,-2)*1/r

In [17]:
(X + A @ D @ C).transpose(-1,-2) # transpose the results to match the results from test_analysis

tensor([[[0.9764, 0.9918],
         [0.9937, 0.9919],
         [0.9896, 0.9771]]], dtype=torch.float64)

Comparing these hand-calculated results with the output of test_analysis, the two agree to the four printed decimals:

In [18]:
results

tensor([[[0.9764, 0.9918],
         [0.9937, 0.9919],
         [0.9896, 0.9771]]], dtype=torch.float64)
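Rather than comparing printed tensors by eye, the same check can be done numerically. The following self-contained sketch recomputes both layouts from the inputs printed in the cells above and reports the largest absolute difference. It is a condensed restatement of the two calculations (no "bs" dimension, `torch.linalg.inv` in place of the Cholesky route), not the original function.

```python
import math
import torch

dt = torch.double
r, zeta = 1e-4, 0.0
# Inputs copied from the printed cells above (rows = ensemble members)
X = torch.tensor([[0.9, 1.0], [1.1, 0.8], [0.8, 1.0]], dtype=dt)
obs = torch.tensor([[0.979, 0.995], [0.999, 1.0], [0.996, 0.985]], dtype=dt)  # Y + y_perturb
N, d = X.shape
c = math.sqrt(1 + zeta) / math.sqrt(N - 1)

# Code layout (N x d): X + part1 + part2
X_ct = X - X.mean(dim=0, keepdim=True)
Z = c * X_ct.T                                            # d_y x N
inner = torch.linalg.inv(torch.eye(N, dtype=dt) + Z.T @ Z / r)
part1 = c / r * (obs - X) @ Z @ X_ct
part2 = -c * (obs - X) @ (Z / r) @ inner @ (Z.T / r) @ Z @ X_ct
rows = X + part1 + part2

# Analytic layout (d x N): X + A D C
Xc = X.T
A = c * (Xc - Xc.mean(dim=1, keepdim=True)) @ Z.T
D = torch.eye(d, dtype=dt) / r - (Z / r) @ inner @ (Z.T / r)
cols = Xc + A @ D @ (obs.T - Xc)

max_diff = (rows - cols.T).abs().max().item()
print(max_diff)
```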