# OPTIONAL practical 3: EnKF in L63

The objective of this practical is to perform state estimation and joint state/parameter estimation in the
Lorenz 1963 system. In this small (3-variable) model, localisation is not required. The practical "EnKF in L96" 
should be completed before this one. 

Let's start by importing all the functions we use in this tutorial. 

In [None]:
import dataclasses
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import xarray as xr
from mpl_toolkits.mplot3d import Axes3D
from tools.obs import gen_obs, createH_L63
from tools.diag import rmse_spread
from tools.misc import createTime
from tools.L63_model import lorenz63
from tools.enkf import kfs
from tools.plots import plotL63, plotL63obs, plotL63DA_kf, plotpar, plotRMSP

## Nature run 

This section generates the nature run of the experiment, i.e. what we consider to be
the "truth". The Lorenz-63 model consists of 3 state variables,
$x_{0}(t)$, $x_{1}(t)$ and $x_{2}(t)$. The time-dependence of these variables is given by the differential equations

$$
\begin{eqnarray}
\frac{\partial x_{0}}{\partial t} &=& -\theta_{0} x_{0} + \theta_{0} x_{1} \\
\frac{\partial x_{1}}{\partial t} &=& -x_{0} x_{2} + \theta_{2} x_{0} - x_{1} \\
\frac{\partial x_{2}}{\partial t} &=& x_{0} x_{1}  - \theta_{1} x_{2} \\
\end{eqnarray}
$$
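
As an illustration, the right-hand side of the system can be sketched in a few lines of NumPy. This is a minimal sketch, assuming the standard form of the Lorenz-63 system with the parameter ordering $\theta = [\sigma, \beta, \rho] = [10, 8/3, 28]$ and a fourth-order Runge-Kutta time step. The names `l63_rhs` and `rk4_step` are hypothetical; the function `lorenz63` from `tools.L63_model` used in this practical may be implemented differently.

```python
import numpy as np

def l63_rhs(x, theta):
    """Time derivative of the Lorenz-63 state.

    Assumes theta = [sigma, beta, rho] (standard Lorenz-63 parameters).
    """
    sigma, beta, rho = theta
    return np.array([
        sigma * (x[1] - x[0]),
        x[0] * (rho - x[2]) - x[1],
        x[0] * x[1] - beta * x[2],
    ])

def rk4_step(x, theta, dt):
    """Advance the state by one time step with classical Runge-Kutta 4."""
    k1 = l63_rhs(x, theta)
    k2 = l63_rhs(x + 0.5 * dt * k1, theta)
    k3 = l63_rhs(x + 0.5 * dt * k2, theta)
    k4 = l63_rhs(x + dt * k3, theta)
    return x + dt * (k1 + 2.0 * k2 + 2.0 * k3 + k4) / 6.0

theta_true = [10.0, 8.0 / 3.0, 28.0]
x = np.array([-10.0, -10.0, 20.0])

# Integrate 1000 steps of size 0.01, i.e. up to T = 10 time units.
traj = [x]
for _ in range(1000):
    x = rk4_step(x, theta_true, 0.01)
    traj.append(x)
traj = np.array(traj)  # time along axis 0, state variable along axis 1
```
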

In the model, you can change the initial conditions $x(0) \overset{def}{=} [x_{0}(0),x_{1}(0),x_{2}(0)]=[-10,-10,20]$ 
and the final time $T$ (note that the model time step is $\Delta t=0.01$ time units). You can also play with the 
parameters $\theta$ of the model to see how the behaviour of the system changes. However,
for the first part of the practical, you should leave them at their "true" values of $\theta = [10, 8/3, 28]$. 
In the following cell, we set the model, the true initial state and the model parameters. 

In [None]:
model = 'L63' # Name of the model to use.
x0 = [-10,-10,20] # True initial conditions X(0)
paramtrue = [10.0,8/3.0,28.0] # True parameters
tmax = 10 # The final time of the nature run simulation
deltat = 0.01 # Model time step
Nx = np.size(x0) # Number of state variables

The following cell plots the values of the 3 state variables as a function of time and the trajectory
$(x_{0}(t),x_{1}(t),x_{2}(t))$ in 3D phase space for $0 \leq t \leq T$. 

In [None]:
t = createTime(0., tmax, deltat, 0)
xt = lorenz63(x0, tmax, deltat, 0, paramtrue) #Truth
plotL63(t,xt) #Plot state.

Next we create "artificial" observations from the truth we just generated. This is done by applying the 
linear observation operator $\mathbf{H}$ to the state $x(t)$.
$\mathbf{H}$ can take on many shapes, e.g. $\mathbf{H}$ can observe all variables (`obs_grid='xyz'`),
just $x_{0}(t)$ and $x_{1}(t)$ (`obs_grid='xy'`), a single variable $x_{0}(t)$ (`obs_grid='x'`), etc.

The state does not have to be observed at every time step. You can set the number of time steps between
observations with the variable `period_obs`. As a rule of
thumb, taking and assimilating observations every 8 steps yields a quasi-linear problem, 
whereas assimilating every 25 steps yields a fully non-linear problem.
    
Finally, we add some white random noise to our samples of the nature run. This noise represents the 
observational errors from e.g. instrument errors. The observational error covariance matrix $\mathbf{R}$ is 
assumed to be diagonal (common assumption), but you can choose the observational variance (i.e. the values on 
the diagonal of $\mathbf{R}$) with `var_obs`. 

In summary, the observations at time $t$, $y(t)$, are given as $y(t) = \mathbf{H}x(t) + \epsilon(t)$. Here
$\epsilon(t)$ is a random realisation from the normal distribution $\mathcal{N}(0,\mathbf{R})$. The settings
to generate observations together with a plot of them are created in the next cell. 
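
The relation $y(t) = \mathbf{H}x(t) + \epsilon(t)$ can be sketched for a single time as follows. This is only an illustration: `make_H` is a hypothetical helper that selects rows of the identity matrix, and it is not claimed to match what `createH_L63` and `gen_obs` do internally.

```python
import numpy as np

def make_H(obs_grid):
    """Hypothetical helper: build H by selecting rows of the 3x3 identity."""
    index = {'x': 0, 'y': 1, 'z': 2}
    return np.eye(3)[[index[c] for c in obs_grid]]

rng = np.random.default_rng(1)

x_true = np.array([-10.0, -10.0, 20.0])  # state at one time
H = make_H('xy')                         # observe x_0 and x_1 only
var_obs = 2.0
R = var_obs * np.eye(H.shape[0])         # diagonal observation error covariance

# One realisation of the observation error, eps ~ N(0, R).
eps = rng.multivariate_normal(np.zeros(H.shape[0]), R)
y = H @ x_true + eps
```

With `obs_grid='xy'`, $\mathbf{H}$ is a $2 \times 3$ matrix and $y$ has two components.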

In [None]:
obs_grid = 'xyz' # options are 'x','y','z','xy','xz','yz','xyz'
period_obs = 10 # number of time steps between observations
var_obs = 2 # error variance of the observations
_, H = createH_L63(obs_grid, 3) #Observation operator
seed = 1
tobs, y, R = gen_obs(t, xt, period_obs, H, var_obs, seed, skip0=True) #observations obtained from truth

#Plot the truth together with observations. 
exp_title = 'ob freq:'+str(period_obs)+', density:'+str(obs_grid)+', err var:'+str(var_obs)
plotL63obs(t, xt, tobs, H, y, exp_title)

## State estimation

Since we have only 3 variables in the Lorenz-63 model, we can easily afford to
run more ensemble members than variables (a luxury we seldom have in real life!). There is no
well-defined "physical distance" between these variables, so the concept of 
localisation does not apply. In this section we will set the observational standard deviation 
$\sigma_{obs}=\sqrt{2}$, and look at the effect of varying the following: the observational frequency `period_obs`, 
the observation density `obs_grid`, and the ensemble size `n_ens`. 

### Stochastic ensemble Kalman filter
First, we will try to assimilate the observations using the stochastic ensemble Kalman filter (`da_method='SEnKF'`). 
For reference, in the stochastic ensemble Kalman filter the analysis of the $n$th ensemble member, $x^{a,(n)}$, 
i.e. the ensemble member after the application of DA, is given by 

\begin{equation}
    x^{a,(n)} = x^{b,(n)} + \mathbf{K}(y - \mathbf{H}x^{b,(n)} + \mathbf{R}^{\frac{1}{2}}\epsilon)
\end{equation}

Here $x^{b,(n)}$ is the model state in the $n$th ensemble member at a time $t$ before DA, 
$\epsilon$ a realisation from the standard normal distribution $\mathcal{N}(0,\mathbf{I})$, drawn independently for each ensemble member, and 

\begin{equation}
\mathbf{K} = \mathbf{B} \mathbf{H}^{\rm{T}} (\mathbf{H} \mathbf{B} \mathbf{H}^{\rm{T}} + \mathbf{R})^{-1}
\end{equation}

the Kalman gain matrix with background error covariance
\begin{equation}
\mathbf{B} = \frac{1}{N_{ens}-1} \sum_{n=1}^{N_{ens}} (x^{b,(n)}-\overline{x^{b}}) 
(x^{b,(n)}-\overline{x^{b}})^{\rm{T}}
\end{equation}

and $\overline{x^{b}} = \frac{1}{N_{ens}}\sum_{n=1}^{N_{ens}} x^{b,(n)}$ the forecast ensemble mean.
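
The analysis equations above can be sketched in NumPy for a single assimilation time. This is a minimal sketch, not the implementation in `tools.enkf.kfs`: the function name `senkf_analysis` is hypothetical, the ensemble is stored with state variables along axis 0 and members along axis 1, and the Cholesky factor of $\mathbf{R}$ is used as $\mathbf{R}^{\frac{1}{2}}$ (for a diagonal $\mathbf{R}$ this is simply the square root of the diagonal).

```python
import numpy as np

def senkf_analysis(Xb, y, H, R, rng):
    """Stochastic EnKF analysis step (sketch).

    Xb : (n_state, n_ens) background ensemble
    y  : (n_obs,) observations
    """
    n_ens = Xb.shape[1]
    # Background perturbations and sample covariance B.
    A = Xb - Xb.mean(axis=1, keepdims=True)
    B = A @ A.T / (n_ens - 1)
    # Kalman gain K = B H^T (H B H^T + R)^-1.
    K = B @ H.T @ np.linalg.inv(H @ B @ H.T + R)
    # Perturbed observations: y + R^(1/2) eps, with eps ~ N(0, I) per member.
    eps = rng.standard_normal((y.size, n_ens))
    Y = y[:, None] + np.linalg.cholesky(R) @ eps
    return Xb + K @ (Y - H @ Xb)

rng = np.random.default_rng(1)
n_ens = 10
Xb = np.array([[-11.0], [-12.0], [15.0]]) + 2.0 * rng.standard_normal((3, n_ens))
H = np.eye(3)                       # observe all variables
R = 2.0 * np.eye(3)
y = np.array([-10.0, -10.0, 20.0])  # illustrative observation values

Xa = senkf_analysis(Xb, y, H, R, rng)
xa = Xa.mean(axis=1)                # analysis ensemble mean
```

For clarity the sketch inverts $\mathbf{H} \mathbf{B} \mathbf{H}^{\rm{T}} + \mathbf{R}$ explicitly; for larger problems, solving the linear system instead is usually preferable.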
 
We construct and run all the experiments in this section using the Experiment class that is defined below. 

In [None]:
@dataclasses.dataclass
class Experiment:
    """
    Class that holds all settings for a single Lorenz-63 ensemble Kalman filter experiment. 
    
    Any setting can be overridden by passing the new value as a keyword argument to the constructor. 
    
    Methods 
    -------
    create_observations
        Sample observations from truth. 
    run
        Run the DA model and store the output. 
    calculate_metrics
        Calculate performance metrics.
    plot_metrics
        Plot the metrics produced by self.calculate_metrics as function of time. 
    plot_state
        Plot ensemble mean for forecast and analysis ensemble together with truth and observations 
        as function of time. 
    plot_parameters
        Plot value of the different model parameters. 
        
    Attributes
    ----------
    x0 : np.ndarray 
        Initial model state. Default is true initial model state.
    param : np.ndarray 
        Model parameters. Default is true model parameters. 
    seed : int
        Seed for random number generator used to create observational errors. 
    period_obs : int>=1
        Number of time steps between observations.
    obs_grid: str
        Observation operator to be used. 
    var_obs: float>0
        Observational error variance. 
    n_ens: int>=1
        Number of ensemble members.
    da_method: 'SEnKF' | 'ETKF'
        Ensemble Kalman method to be used. 
    inflation: float
        Ensemble inflation. If inflation=0 no ensemble inflation is applied. 
    alpha: float>=0
        To be discovered. 
    Xb: 3D np.ndarray   
        Ensemble of background states with time along the 0th axis, grid position along the 1st axis and 
        ensemble member along the 2nd axis. 
    xb: 2D np.ndarray 
        Background ensemble mean with time along the 0th axis and grid position along the 1st axis. 
    Xa: 3D np.ndarray   
        Ensemble of analysis states with time along the 0th axis, grid position along the 1st axis and 
        ensemble member along the 2nd axis. 
    xa: 2D np.ndarray 
        Analysis ensemble mean with time along the 0th axis and grid position along the 1st axis. 
    
    """
    
    #Default model settings. 
    dt: float = 0.01 
    x0: np.ndarray = dataclasses.field(default_factory=lambda:np.array(x0))
    param: np.ndarray = dataclasses.field(default_factory=lambda:np.array(paramtrue))
        
    #Default observation operator settings.
    seed: int = 1 
    period_obs: int = 10 
    obs_grid: str = 'xyz' 
    var_obs: float = 2.0 
        
    #Default data assimilation system settings.
    n_ens: int = 10 
    da_method: str = 'SEnKF' 
    inflation: float = 0.01 
    alpha: float = 0.0
        
    def create_observations(self):
        """ Sample observations from truth. """
        _, self.H = createH_L63(self.obs_grid, 3)
        self.tobs, self.y, self.R = gen_obs(t, xt, self.period_obs, self.H, self.var_obs, self.seed, skip0=True)
    
    def run(self):
        """ Run the DA model and store the output. """
        #Create observations
        self.create_observations()
        
        #State background/analysis
        self.Xb, self.xb, self.Xa, self.xa, \
            _, _ = kfs(self.x0, self.param, lorenz63, t, self.tobs,
                       self.y, self.H, self.R, 
                       self.inflation, self.n_ens, self.da_method,
                       back0='fixed', desv=2.0, seed=self.seed)
        
        #Parameter background/analysis 
        n_time = len(self.tobs)
        self.Pa = np.ones((3, self.n_ens, n_time)) * self.param[:, None, None]
        self.pa = np.mean(self.Pa, axis=1)
        
    def calculate_metrics(self, step):
        """ 
        Calculate performance metrics.
        
        Parameters
        ----------
        step : int 
            Number of time steps between states for which metrics will be calculated. 
            
        Returns
        -------
        m : xr.Dataset object
            Dataset containing time series of RMSE and ensemble spread.
        
        """
        m = xr.Dataset(coords = {'DA':(['DA'],['background','analysis']), 
                                 'time':(['time'], t[::step])}
                      )
        
        #Initialise
        m['rmse'] = (['DA','time'], np.zeros((2,len(t[::step]))) )
        m['spread'] = (['DA','time'], np.zeros((2,len(t[::step]))) )
        
        #Background metrics
        m['rmse'][0], m['spread'][0] = rmse_spread(xt, self.xb, self.Xb, step)
        
        #Analysis metrics
        m['rmse'][1], m['spread'][1] = rmse_spread(xt, self.xa, self.Xa, step)

        return m
    
    def plot_metrics(self, step):
        """
        Plot the metrics produced by self.calculate_metrics as function of time. 
        
        Parameters
        ----------
        step : int 
            Number of time steps between states for which metrics will be calculated. 
            
        """
        m = self.calculate_metrics(step)
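        #Use the subsampled time axis so that it matches the length of the metrics when step>1.
        plotRMSP(str(self), m['time'].data, m['rmse'].sel(DA='background').data, m['rmse'].sel(DA='analysis').data,
                 m['spread'].sel(DA='background').data, m['spread'].sel(DA='analysis').data)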
        
    def plot_state(self):
        """
        Plot ensemble mean for forecast and analysis ensemble together with truth and observations 
        as function of time. 
        """
        plotL63DA_kf(t, xt, self.tobs, self.H, self.y, self.Xb, self.xb, self.Xa, 
                    self.xa, str(self))
        
    def plot_parameters(self):
        """
        Plot value of the different model parameters. 
        """
        ptrue = np.ones((3, len(self.tobs))) * np.array(paramtrue)[:, None]
        plotpar(len(self.param), self.tobs, ptrue, self.Pa, self.pa)
        
    def __str__(self):
        """ Name of the experiment. 
        
        Returns
        -------
        String with name of experiment. 
        
        """
        return ('ob freq:'+str(self.period_obs)+', density:'+str(self.obs_grid)+
                ', err var:'+str(self.var_obs)+', N_ens='+str(self.n_ens)+', method='+str(self.da_method)+
                ', rho='+str(self.inflation))
    
def time_average(metrics):
    """ Root-mean-square of the metrics over the time dimension. """
    return np.sqrt((metrics**2).mean('time'))

As an example, we rerun the model with an initial condition that deviates from the initial condition of the 
truth and assimilate observations of all variables every 10 time steps. We plot the truth (black) together with the output 
just before each DA correction is applied (blue) and just after the DA correction (purple). 
To see the difference between the latter two more distinctly, zoom in 
on one of the observations. We also plot the RMSE and the ensemble spread. 
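
For reference, one common way to define the two metrics is sketched here. This is an assumption about the definitions, not a description of `tools.diag.rmse_spread`, which may differ in detail. The hypothetical function `rmse_and_spread` follows the array layout documented in the `Experiment` class: time along axis 0, state variable along axis 1 and ensemble member along axis 2.

```python
import numpy as np

def rmse_and_spread(x_true, X_ens):
    """RMSE of the ensemble mean and ensemble spread at each time (sketch).

    x_true : (n_time, n_state) truth
    X_ens  : (n_time, n_state, n_ens) ensemble
    """
    x_mean = X_ens.mean(axis=2)
    # RMSE: root of the squared error of the ensemble mean, averaged over variables.
    rmse = np.sqrt(np.mean((x_mean - x_true) ** 2, axis=1))
    # Spread: root of the ensemble variance, averaged over variables.
    spread = np.sqrt(np.mean(X_ens.var(axis=2, ddof=1), axis=1))
    return rmse, spread

# Synthetic example: truth is zero, members scatter around it with unit variance.
rng = np.random.default_rng(0)
x_true_demo = np.zeros((5, 3))
X_demo = rng.standard_normal((5, 3, 50))
rmse_demo, spread_demo = rmse_and_spread(x_true_demo, X_demo)
```

In a well-behaved ensemble the RMSE and the spread are typically of comparable size; a spread much smaller than the RMSE suggests an overconfident ensemble.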

In [None]:
#Create experiment with an initial state that deviates from the truth. 
exp = Experiment(x0=np.array([-11,-12,15]))

#Run the experiment
exp.run()

#Plot model output
exp.plot_state()

#Plot metrics as function of time. 
exp.plot_metrics(1)

#Calculate the root-mean-square values over time and show them as a table. 
time_average(exp.calculate_metrics(1)).to_dataframe()

## Joint state-parameter estimation

DA can be used not only to correct the state with the aid of the observations, but also to correct the model 
parameters. In this section we will start the model using the incorrect parameters $\theta=[6.0,3.0,25.0]$, 
the true initial conditions, an observational error standard deviation of $\sigma_{obs}=1$ and observations of the 
values of $x_{1}$ and $x_{2}$ every 10 time steps. We encode our experiment settings in the class `ParameterExperiment`
below and run it. 
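
A common way to estimate parameters with an ensemble Kalman filter is state augmentation: the parameters are appended to the state vector and updated through their sample cross-covariances with the observed variables. The sketch below shows one such update with a stochastic EnKF step. It is an assumption that `kfs` with `param_estimate=True` works along these lines, and all names in the sketch are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n_ens = 20

# Ensemble of states (3, n_ens) and of parameters (3, n_ens).
X = rng.normal([-10.0, -10.0, 20.0], 1.0, size=(n_ens, 3)).T
P = rng.normal([6.0, 3.0, 25.0], 0.5, size=(n_ens, 3)).T

# Augmented ensemble z = [x; theta] with shape (6, n_ens).
Z = np.vstack([X, P])

# Observe x_1 and x_2 ('yz'); the parameters themselves are never observed.
H = np.array([[0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0]])
H_aug = np.hstack([H, np.zeros((2, 3))])
R = np.eye(2)
y = np.array([-9.0, 21.0])  # illustrative observation values

# Sample covariance of the augmented ensemble and the Kalman gain.
A = Z - Z.mean(axis=1, keepdims=True)
B = A @ A.T / (n_ens - 1)
K = B @ H_aug.T @ np.linalg.inv(H_aug @ B @ H_aug.T + R)

# Stochastic update with perturbed observations.
Y = y[:, None] + np.linalg.cholesky(R) @ rng.standard_normal((2, n_ens))
Za = Z + K @ (Y - H_aug @ Z)
Xa, Pa = Za[:3], Za[3:]  # updated states and updated parameters
```

In this toy ensemble the states and parameters are drawn independently, so the cross-covariances are only sampling noise. In a cycled experiment each member is propagated with its own parameter values, which is what builds up informative state-parameter correlations.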

In [None]:
class ParameterExperiment(Experiment):
    """
    Class representing DA experiments in which both state and model parameters are corrected. 
    """
    
    def run(self):
        #Create observations
        self.create_observations()
        
        #State and parameter  background/analysis
        self.Xb, self.xb, self.Xa, self.xa, \
            self.Pa, self.pa, _, _ = kfs(self.x0, self.param, lorenz63, t, self.tobs, 
                                         self.y, self.H, self.R, self.inflation,
                                         self.n_ens, self.da_method,
                                         back0='fixed', desv=1.0, alpha=self.alpha,
                                         param_estimate=True, seed=self.seed)
        
    def __str__(self):
        return ('ob freq:'+str(self.period_obs)+', density:'+str(self.obs_grid)+
                ', err var:'+str(self.var_obs)+', N_ens='+str(self.n_ens)+', method='+str(self.da_method)+
                ', rho='+str(self.inflation)+', param='+str(self.param))
    
#Create experiment with initial parameter values that deviate from the truth. 
exp = ParameterExperiment(param=np.array([6.0,3.0,25.0]), var_obs=1.0, obs_grid='yz', alpha=0.1, da_method='ETKF')

#Run the experiment
exp.run()

#Plot model state as function of time.
exp.plot_state()

#Plot estimated parameter value as function of time.
exp.plot_parameters()

The final estimate for $\theta_{1}$ is not much better than the initial guess. 
Let's see what can be done about that. Using `da_method='ETKF'`, `obs_grid='yz'` and `var_obs=1.0`, 
run the `ParameterExperiment` with each combination of inputs in the table below. 

In [None]:
#Generate all possible argument combinations.
s = [(param1, inflation1, alpha1) 
     for param1 in [np.array([6.,3.,25.]), np.array([3.,1.,-1.])] 
     for inflation1 in np.array([0.01, 0.1]) 
     for alpha1 in np.array([0.,0.1])]

#Put the combinations into a Pandas dataframe. 
settings = pd.DataFrame()
settings['param'] = [setting1[0] for setting1 in s]
settings['inflation'] = [setting1[1] for setting1 in s]
settings['alpha'] = [setting1[2] for setting1 in s]
settings.index.name='exp_no'

#Display combinations.
settings

1. What is the meaning of the parameter `alpha`?

2. To what extent does the final estimate of the model parameters depend on the initial guess in `param`? 