# Ensemble Kalman Filter

### Keiran Suchak

One of the issues faced by the traditional Kalman Filter (KF) is that, as the number of state variables $n$ increases, the process becomes intractable: the state covariance matrix has $n \times n$ entries, so the cost of storing and manipulating it grows at least quadratically with $n$ [1].
In such cases we can instead make use of the Ensemble Kalman Filter (EnKF).
The EnKF is much like the traditional KF; however, instead of storing a single model state, it stores an *ensemble* of model states, each of which is evolved under the same model dynamics.
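For a rough sense of scale, consider an illustrative (assumed) model with $n = 10^6$ state variables, an ensemble of $N = 100$ members, and 8-byte (64-bit) floating-point storage. The full covariance matrix and the ensemble then require

$$ n^2 = 10^{12} \text{ entries} \;\Rightarrow\; 8 \times 10^{12} \text{ bytes} \approx 8\ \text{TB for } P, $$
$$ nN = 10^{8} \text{ entries} \;\Rightarrow\; 8 \times 10^{8} \text{ bytes} \approx 800\ \text{MB for the ensemble}, $$

so in this example the ensemble needs $n/N = 10^4$ times less storage than the full covariance matrix.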

The aim of the EnKF is to use this ensemble of model states to approximate the state covariance matrix, rather than storing and evolving the full matrix explicitly.
As with the traditional KF, this is achieved using a two-step process:

1. Forecast
2. Update

where the forecast is given by
$$ X_{t+1}^f = M_t X_t $$

and the update is given by
$$ X_{t+1} = X_{t+1}^f + K_{t+1} \left( D_{t+1} - H_{t+1} X_{t+1}^f \right) $$
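A minimal NumPy sketch of one forecast/update cycle is given below. It assumes a time-invariant linear model $M$ and stores the ensemble column-wise as an $n \times N$ matrix, as in the notation that follows. The function names `forecast` and `analysis`, and the toy values of `M`, `H`, `R` and `d`, are illustrative assumptions rather than part of the original method description.

```python
import numpy as np


def forecast(X, M):
    """Forecast step: evolve every ensemble member (column of X) with the linear model M."""
    return M @ X


def analysis(Xf, d, H, R, rng):
    """Stochastic EnKF update X = Xf + K (D - H Xf), with the ensemble stored as n x N."""
    n, N = Xf.shape
    # Ensemble perturbations and sample covariance P_hat = X' X'^T / (N - 1)
    Xp = Xf - Xf.mean(axis=1, keepdims=True)
    P_hat = Xp @ Xp.T / (N - 1)
    # Kalman gain K = P_hat H^T (H P_hat H^T + R)^-1
    K = P_hat @ H.T @ np.linalg.inv(H @ P_hat @ H.T + R)
    # Perturbed observations: one noisy copy d_i = d + V_i per member (columns of D)
    V = rng.multivariate_normal(np.zeros(len(d)), R, size=N).T
    D = d[:, None] + V
    return Xf + K @ (D - H @ Xf)


rng = np.random.default_rng(0)
n, N = 2, 50
M = np.array([[1.0, 0.1], [0.0, 1.0]])  # assumed toy linear dynamics
H = np.eye(n)                           # assumed: observe the full state
R = 1e-4 * np.eye(n)                    # assumed: very accurate observations
X = rng.normal(size=(n, N))             # initial ensemble, one member per column
d = np.array([3.0, -1.0])               # assumed observation vector

Xf = forecast(X, M)
Xa = analysis(Xf, d, H, R, rng)
print(Xa.mean(axis=1))  # should lie close to d because R is small
```

With accurate observations of the full state the gain is close to the identity, so the analysis ensemble is pulled almost entirely onto the perturbed data.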

## Notation

* State vector size: $n$

* Ensemble size: $N$

* Observation vector size: $m$

* Ensemble matrix ($n \times N$):
$$ X = (x_1 , ... , x_N) $$
for an ensemble of $N$ state vectors, where each column $x_i$ is a vector of size $n$ containing a single model state
$$ x_i = (x_{i1} , ... , x_{in})$$

* Ensemble mean ($n$):
$$ \bar{x} = \frac{1}{N} \sum_{i=1}^N x_i $$

* Perturbation from mean ($n$):
$$ x'_i = x_i - \bar{x} $$

* Ensemble of perturbations ($n \times N$):
$$ X' = (x'_1 , ... , x'_N) $$

* State covariance matrix ($n \times n$): $P$

* Estimate of $P$ from finite ensemble ($n \times n$):
$$ \hat{P} = \frac{1}{N-1} X' X'^T $$

* Observation matrix ($m \times n$): $H$

* Data covariance matrix ($m \times m$): $R$

* Kalman gain matrix ($n \times m$):
$$ K = \hat{P} H^T ( H \hat{P} H^T + R )^{-1} $$

* Matrix of perturbed data ($m \times N$):
$$ D = (d_1 , ... , d_N) $$
where each $d_i$ consists of a replicate of the original observations, $d$, plus a normally distributed random vector:
$$ d_i = d + V_i $$
$$ V_i \sim \mathcal{N} (0, R) $$
for all $i \in \{1, \dots, N\}$.
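Each of the quantities above maps directly onto a NumPy array operation. The sketch below builds them for random illustrative values of the ensemble, `H`, `R` and `d` (none of which come from the text) and checks that $\hat{P}$ agrees with NumPy's unbiased sample covariance `np.cov`, which treats rows as variables and divides by $N - 1$.

```python
import numpy as np

rng = np.random.default_rng(1)
n, N, m = 4, 20, 2                      # state size, ensemble size, observation size

X = rng.normal(size=(n, N))             # ensemble matrix, one member per column
x_bar = X.mean(axis=1)                  # ensemble mean (length n)
X_prime = X - x_bar[:, None]            # ensemble of perturbations (n x N)
P_hat = X_prime @ X_prime.T / (N - 1)   # estimate of P from the finite ensemble

H = rng.normal(size=(m, n))             # observation matrix (m x n), illustrative
R = 0.5 * np.eye(m)                     # data covariance matrix (m x m), illustrative
K = P_hat @ H.T @ np.linalg.inv(H @ P_hat @ H.T + R)  # Kalman gain (n x m)

d = rng.normal(size=m)                  # original observations, illustrative
V = rng.multivariate_normal(np.zeros(m), R, size=N).T  # columns V_i ~ N(0, R)
D = d[:, None] + V                      # perturbed data (m x N), columns d_i = d + V_i

print(np.allclose(P_hat, np.cov(X)))    # True
print(K.shape, D.shape)                 # (4, 2) (2, 20)
```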

## General EnKF

```python
# Imports
from numpy import array, identity, zeros
from numpy.linalg import inv
from numpy.random import multivariate_normal
```

```python
# Class
class EnKF(object):
    """
    Base implementation of an Ensemble Kalman Filter (EnKF).
    Contains an empty forecast method to be overloaded for specific implementations.

    Parameters
    ----------
    XInit : numpy.array(n)
            Initial state vector

    H : numpy.array(m x n)
        Observation matrix

    R : numpy.array(m x m)
        Data covariance matrix

    ensembleSize : int
                   Ensemble size

    Attributes
    ----------
    X : numpy.array(N x n)
        Ensemble of state vectors, stored one member per row
        (the transpose of the n x N ensemble matrix in the notation above)

    H : numpy.array(m x n)
        Observation matrix

    Ht : numpy.array(n x m)
         Transpose of observation matrix

    R : numpy.array(m x m)
        Data covariance matrix

    N : int
        Ensemble size

    XBar : numpy.array(n)
           Mean state vector

    C : numpy.array(n x n)
        State covariance matrix
    """

    def __init__(self, XInit=None, H=None, R=None, ensembleSize=10):
        self.N = ensembleSize
        self.H = H
        self.Ht = self.H.T
        self.XBar = XInit
        # Initial spread of the ensemble around XInit
        self.C = identity(len(self.XBar))
        self.R = R
        # Private (name-mangled) method, so it must be called with its double underscore
        self.X = self.__makeStateEnsemble()
        self.forecastStates = []
        self.states = []

    # Main functions
    def step(self, observations=None):
        """
        Step the model state forward one timestep.
        Includes updating of state with data if observations are provided.
        """
        # Explicit None check: the truth value of a multi-element array is ambiguous
        if observations is not None:
            self.update(observations)
            self.XBar = self.__makeStateMean()
        self.states.append(self.XBar)
        self.forecast()
        self.XBar = self.__makeStateMean()
        self.forecastStates.append(self.XBar)

    def forecast(self):
        """
        Empty forecast method.
        This should be overloaded for specific implementations.
        """

    def update(self, observations):
        """
        Analysis step to assimilate observations.

        Parameters
        ----------
        observations : numpy.array(m)
                       Vector of observations
        """

        # Observation checking

        # Update
        K = self.__makeGainMatrix()
        D = self.__makePerturbedData(observations)
        ## Update each member of the ensemble: x_i <- x_i + K (d_i - H x_i)
        for i in range(self.N):
            a = D[i] - self.H @ self.X[i]
            self.X[i] = self.X[i] + K @ a

    # Auxiliary functions
    def __makeGainMatrix(self):
        """
        Calculate gain matrix:
            K = P H^T (H P H^T + R)^-1

        Returns
        -------
        Kalman gain matrix (n x m)
        """

        self.C = self.__makeStateCovarianceMatrix()
        # Matrix products (@), not element-wise products (*)
        x = self.H @ self.C @ self.Ht
        y = inv(x + self.R)
        z = self.C @ self.Ht @ y
        return z

    def __makeStateMean(self):
        """
        Calculate mean state vector from ensemble of state vectors.

        Returns
        -------
        Mean state vector
        """

        return self.X.mean(axis=0)

    def __makeStatePerturbations(self):
        """
        Calculate the perturbation of each member of the ensemble relative to the ensemble mean.

        Returns
        -------
        State perturbations (N x n)
        """

        # Use the mean of the current ensemble rather than a possibly stale self.XBar
        return self.X - self.__makeStateMean()

    def __makeStateCovarianceMatrix(self):
        """
        Calculate state covariance matrix:
            P = 1/(N-1) X' X'^T
        """

        A = self.__makeStatePerturbations()
        # Rows of A are the perturbations x'_i, so A.T @ A sums the outer products x'_i x'_i^T
        return A.T @ A / (self.N - 1)

    def __makePerturbedData(self, observations):
        """
        Create a matrix of data perturbed with normally distributed random noise.

        Parameters
        ----------
        observations : numpy.array(m)

        Returns
        -------
        Matrix of perturbed data (N x m), one row per ensemble member
        """

        initialData = array([observations for _ in range(self.N)])
        # Zero-mean noise with covariance R; the mean must be a vector of length m
        noise = multivariate_normal(zeros(len(observations)), self.R, self.N)
        return initialData + noise

    def __makeStateEnsemble(self):
        """
        Create initial matrix of state ensemble by adding normally distributed random noise to copies of the initial state.

        Returns
        -------
        Initial matrix of state ensemble (N x n)
        """

        return multivariate_normal(self.XBar, self.C, self.N)
```
        
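The per-member loop in `update` is intended to apply $x_i \leftarrow x_i + K (d_i - H x_i)$ to one row at a time. Because the ensemble and the perturbed data are stored one member per row, the same update can also be written as a single matrix expression, $X \leftarrow X + (D - X H^T) K^T$, which avoids the Python loop. The self-contained check below compares the two forms on random illustrative arrays; `update_loop` and `update_vectorised` are hypothetical helper names, not part of the class above.

```python
import numpy as np


def update_loop(X, D, H, K):
    """Update each member (row of X) in turn: x_i <- x_i + K (d_i - H x_i)."""
    X = X.copy()
    for i in range(X.shape[0]):
        X[i] = X[i] + K @ (D[i] - H @ X[i])
    return X


def update_vectorised(X, D, H, K):
    """Same update for all rows at once: X <- X + (D - X H^T) K^T."""
    return X + (D - X @ H.T) @ K.T


rng = np.random.default_rng(2)
N, n, m = 10, 3, 2
X = rng.normal(size=(N, n))  # one ensemble member per row
D = rng.normal(size=(N, m))  # one perturbed observation vector per row
H = rng.normal(size=(m, n))  # illustrative observation matrix
K = rng.normal(size=(n, m))  # illustrative gain matrix

print(np.allclose(update_loop(X, D, H, K), update_vectorised(X, D, H, K)))  # True
```

Row $i$ of $X H^T$ is $(H x_i)^T$, so each row of the vectorised result equals $\left(x_i + K(d_i - H x_i)\right)^T$, which is exactly what the loop computes.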

## Toy Model

## Test Implementation