## DEnKF
In this project, we avoid modeling the process noise separately from the state transition by drawing on recent insights in stochastic neural networks (SNNs). More specifically, we build on a theoretical link between the Dropout training algorithm and Bayesian inference in deep Gaussian processes. Accordingly, after training a neural network with Dropout, it is possible to generate empirical samples from the predictive posterior via stochastic forward passes.

Hence, for the purposes of filtering, we can implicitly model the process noise by sampling the state from a neural network trained on the transition dynamics, i.e., ${\bf x}_{t}  \thicksim  f_{\pmb {\theta}} ({\bf x}_{t-1})$. In contrast to previous approaches, the transition network $f_{\pmb {\theta}}(\cdot)$ models both the system dynamics and the inherent noise in a consistent fashion, without imposing a diagonal noise covariance. We formulate DEnKF as an extension of the EnKF while keeping the core algorithmic steps intact.
In particular, we use an initial ensemble of $E$ members to represent the initial state distribution ${\bf X}_0 = [ {\bf x}^{1}_0, \dots, {\bf x}^{E}_0]$, $E \in \mathbb{Z}^+$.
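
As a minimal sketch of how such an initial ensemble could be built, the snippet below draws $E$ members around a given initial state. The Gaussian spread, the helper name `init_ensemble`, and the `init_std` parameter are assumptions made for illustration; the text above does not prescribe an initialization scheme. Members are stored as rows, matching the `[num_ensemble, dim_x]` layout used by the models in this notebook.

```python
import numpy as np

def init_ensemble(x0, num_ensemble, init_std=0.1, seed=0):
    """Draw E ensemble members around x0 (Gaussian spread is an assumption)."""
    rng = np.random.default_rng(seed)
    x0 = np.asarray(x0, dtype=np.float64)
    noise = rng.normal(0.0, init_std, size=(num_ensemble, x0.shape[0]))
    return x0[None, :] + noise  # shape [E, dim_x], one member per row

X0 = init_ensemble([0.0, 1.0, -1.0], num_ensemble=32)
print(X0.shape)  # (32, 3)
```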

### 1. Prediction Step
We leverage the stochastic forward passes from a trained state transition model to update each ensemble member: 

\begin{aligned}
    {\bf x}^{i}_{t|t-1} & \thicksim  f_{\pmb {\theta}} ({\bf x}^{i}_{t|t-1}|{\bf x}^{i}_{t-1|t-1}),\  \forall i \in \{1, \dots, E\}.
\end{aligned}

The matrix ${\bf X}_{t|t-1} = [{\bf x}^{1}_{t|t-1}, \cdots, {\bf x}^{E}_{t|t-1}]$ holds the updated ensemble members, which are propagated one step forward through the state space. Note that sampling from the transition model $f_{\pmb {\theta}}(\cdot)$ (using the SNN methodology described above) implicitly introduces process noise.
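
To illustrate how stochastic forward passes induce process noise, here is a small NumPy sketch that applies a fresh dropout mask to each ensemble member. The weights `W1` and `W2` are random placeholders standing in for a trained network, and the layer sizes and `p_drop` are hypothetical. Note that the process model in this notebook uses Flipout layers from `bayesian_torch` rather than plain dropout, so this is only an analogy for the sampling mechanism.

```python
import numpy as np

rng = np.random.default_rng(0)
dim_x, hidden, E = 3, 32, 8

# Hypothetical fixed weights standing in for a trained transition network.
W1 = rng.normal(0.0, 0.3, size=(dim_x, hidden))
W2 = rng.normal(0.0, 0.3, size=(hidden, dim_x))

def stochastic_transition(X_prev, p_drop=0.2):
    """One stochastic forward pass per ensemble member (rows of X_prev)."""
    h = np.maximum(X_prev @ W1, 0.0)        # ReLU hidden layer
    mask = rng.random(h.shape) >= p_drop    # independent dropout mask per member
    h = h * mask / (1.0 - p_drop)           # inverted-dropout scaling
    return X_prev + h @ W2                  # residual update: state + predicted change

# Start from identical members so any spread comes from the sampling alone.
X_prev = np.tile(np.array([0.5, -0.2, 1.0]), (E, 1))  # shape [E, dim_x]
X_pred = stochastic_transition(X_prev)
print(X_pred.std(axis=0))  # per-dimension spread across the ensemble
```

Because every member starts from the same state, the spread printed at the end is produced entirely by the random masks, which is the sense in which the sampled network carries its own noise model.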

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from bayesian_torch.layers.flipout_layers.linear_flipout import LinearFlipout
import torchvision.models as models
from einops import rearrange, repeat
import numpy as np
import math
import warnings
warnings.filterwarnings('ignore')

class ProcessModel(nn.Module):
    """
    The process model takes a state or a stack of states (t-n:t-1) and
    predicts the next state t. The process model is flexible: known dynamics
    can be injected into it, and the architecture can be changed to one that
    takes sequential data as input.

    input -> [batch_size, num_ensemble, dim_x]
    output ->  [batch_size, num_ensemble, dim_x]
    """

    def __init__(self, num_ensemble, dim_x):
        super(ProcessModel, self).__init__()
        self.num_ensemble = num_ensemble
        self.dim_x = dim_x

        self.bayes1 = LinearFlipout(in_features=self.dim_x, out_features=64)
        self.bayes2 = LinearFlipout(in_features=64, out_features=512)
        self.bayes3 = LinearFlipout(in_features=512, out_features=256)
        self.bayes4 = LinearFlipout(in_features=256, out_features=self.dim_x)

    def forward(self, last_state):
        batch_size = last_state.shape[0]
        last_state = rearrange(
            last_state, "bs k dim -> (bs k) dim", bs=batch_size, k=self.num_ensemble
        )
        x, _ = self.bayes1(last_state)
        x = F.relu(x)
        x, _ = self.bayes2(x)
        x = F.relu(x)
        x, _ = self.bayes3(x)
        x = F.relu(x)
        update, _ = self.bayes4(x)
        state = last_state + update
        state = rearrange(
            state, "(bs k) dim -> bs k dim", bs=batch_size, k=self.num_ensemble
        )
        return state

### 2. Update Step 
Given the updated ensemble members ${\bf X}_{t|t-1}$, a nonlinear observation model $h_{\pmb {\psi}}(\cdot)$ is applied to transform the ensemble members from the state space to observation space. Following our main rationale, the observation model is realized via a neural network with weights $\pmb {\psi}$. Accordingly, the update equations for the EnKF become:
    \begin{align}
    \label{eq:2}
        {\bf H}_t {\bf X}_{t|t-1} &= \left[ h_{\pmb {\psi}}({\bf x}^1_{t|t-1}), \cdots, h_{\pmb {\psi}}({\bf x}^E_{t|t-1}) \right],\\
        \label{eq:3}
        {\bf H}_t {\bf A}_{t} &=  {\bf H}_t {\bf X}_{t|t-1} \\
        &- \left[\frac{1}{E} \sum_{i=1}^E h_{\pmb {\psi}}({\bf x}^i_{t|t-1}),
        \cdots,
        \frac{1}{E} \sum_{i=1}^E h_{\pmb {\psi}}({\bf x}^i_{t|t-1})\right]. \nonumber
    \end{align}
${\bf H}_t {\bf X}_{t|t-1}$ is the ensemble of predicted observations, and ${\bf H}_t {\bf A}_{t}$ is its deviation from the sample mean of the predicted observations at time $t$.
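
The two quantities above can be sketched in a few lines of NumPy. The function `h` below is a hypothetical nonlinear map standing in for the learned observation model $h_{\pmb {\psi}}$, and the dimensions are arbitrary; members are stored as rows.

```python
import numpy as np

rng = np.random.default_rng(0)
E, dim_x, dim_z = 16, 4, 2

# Hypothetical nonlinear observation function standing in for h_psi.
C = rng.normal(size=(dim_x, dim_z))

def h(X):
    return np.tanh(X @ C)

X_pred = rng.normal(size=(E, dim_x))      # predicted ensemble, one member per row
HX = h(X_pred)                            # predicted observations, shape [E, dim_z]
HA = HX - HX.mean(axis=0, keepdims=True)  # deviations from the sample mean

print(np.allclose(HA.mean(axis=0), 0.0))  # True: the deviations are centered
```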

In [3]:
class ObservationModel(nn.Module):
    """
    The observation model takes a predicted state x_{t|t-1} and predicts the
    corresponding observations. Typically, the observation is part of the
    state (H as an identity matrix), unless some observations are used
    indirectly to update the state.

    input -> [batch_size, num_ensemble, dim_x]
    output ->  [batch_size, num_ensemble, dim_z]
    """

    def __init__(self, num_ensemble, dim_x, dim_z):
        super(ObservationModel, self).__init__()
        self.num_ensemble = num_ensemble
        self.dim_x = dim_x
        self.dim_z = dim_z

        self.linear1 = torch.nn.Linear(self.dim_x, 64)
        self.linear2 = torch.nn.Linear(64, 128)
        self.linear3 = torch.nn.Linear(128, 128)
        self.linear4 = torch.nn.Linear(128, 64)
        self.linear5 = torch.nn.Linear(64, self.dim_z)

    def forward(self, state):
        batch_size = state.shape[0]
        state = rearrange(
            state, "bs k dim -> (bs k) dim", bs=batch_size, k=self.num_ensemble
        )
        x = self.linear1(state)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        x = self.linear3(x)
        x = F.relu(x)
        x = self.linear4(x)
        x = F.relu(x)
        z_pred = self.linear5(x)
        z_pred = rearrange(
            z_pred, "(bs k) dim -> bs k dim", bs=batch_size, k=self.num_ensemble
        )
        return z_pred

EnKF treats observations as random variables. Hence, the ensemble can incorporate a measurement perturbed by a small stochastic noise, thereby accurately reflecting the error covariance of the best state estimate.
In our differentiable version of the EnKF, we also incorporate a sensor model that can learn projections between a latent space and higher-dimensional observation spaces, e.g., images. To this end, we leverage the SNN methodology to train a stochastic sensor model $s_{\pmb {\xi}}(\cdot)$:
\begin{aligned}\label{eq:sensor}
      \tilde{{\bf y}}^{i}_t & \thicksim  s_{\pmb {\xi}} (\tilde{{\bf y}}^{i}_t|{\bf y}_{t}),\  \forall i \in \{1, \dots, E\}.
\end{aligned}

where ${\bf y}_{t}$ represents the noisy observation. Sampling yields observations $\tilde{{\bf Y}}_t = [\tilde{{\bf y}}^{1}_t, \cdots, \tilde{{\bf y}}^{E}_t]$ and sample mean $\tilde{{\bf y}}_t = \frac{1}{E}\sum_{i=1}^E\tilde{{\bf y}}^i_t$.

In [4]:
class imgSensorModel(nn.Module):
    """
    The sensor model takes image observations and generates latent
    observations for the filter. A three-layer convolutional encoder
    followed by stochastic (Flipout) linear layers projects the vision
    inputs down to dim_x.

    images -> [batch, channels, height, width]
    out -> (obs [batch, ensemble, dim_x], obs_z [batch, 1, dim_x],
            encoding [batch, 1, 64])
    """

    def __init__(self, num_ensemble, dim_x):
        super(imgSensorModel, self).__init__()
        self.num_ensemble = num_ensemble
        self.dim_x = dim_x
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(3, 16, kernel_size=5, stride=2, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Dropout(p=0.1),
        )
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Dropout(p=0.1),
        )
        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            torch.nn.Dropout(p=0.1),
        )
        self.linear1 = torch.nn.Linear(64 * 7 * 7, 512)
        self.bayes1 = LinearFlipout(in_features=512, out_features=64)
        self.bayes2 = LinearFlipout(in_features=64, out_features=dim_x)

    def forward(self, images):
        batch_size = images.shape[0]
        x = images
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = x.view(x.size(0), -1)
        x = self.linear1(x)
        x = F.leaky_relu(x)
        x = repeat(x, "bs dim -> bs en dim", en=self.num_ensemble)
        x = rearrange(x, "bs k dim -> (bs k) dim")
        x, _ = self.bayes1(x)
        x = F.leaky_relu(x)
        encoding = x
        obs, _ = self.bayes2(x)
        obs = rearrange(
            obs, "(bs k) dim -> bs k dim", bs=batch_size, k=self.num_ensemble
        )
        obs_z = torch.mean(obs, axis=1)
        obs_z = rearrange(obs_z, "bs (k dim) -> bs k dim", k=1)
        encoding = rearrange(
            encoding, "(bs k) dim -> bs k dim", bs=batch_size, k=self.num_ensemble
        )
        encoding = torch.mean(encoding, axis=1)
        encoding = rearrange(encoding, "(bs k) dim -> bs k dim", bs=batch_size, k=1)

        return obs, obs_z, encoding

The innovation covariance ${\bf S}_t$ can then be calculated as:

\begin{aligned}
        {\bf S}_t &= \frac{1}{E-1}  ({\bf H}_t {\bf A}_t)  ({\bf H}_t {\bf A}_t)^T + r_{\pmb {\zeta}}(\tilde{{\bf y}}_t).
\end{aligned}

where $r_{\pmb {\zeta}}(\cdot)$ is the measurement noise model, implemented as an MLP. It takes a learned observation $\tilde{{\bf y}}_t$ at time $t$ and models the observation noise by constructing the diagonal of the noise covariance matrix. The final estimate of the ensemble ${\bf X}_{t|t}$ can be obtained by performing the measurement update step:

\begin{align}
{\bf A}_t &= {\bf X}_{t|t-1} - \frac{1}{E}\sum_{i=1}^E{\bf x}^i_{t|t-1},\\
{\bf K}_t &= \frac{1}{E-1} {\bf A}_t ({\bf H}_t {\bf A}_t)^T {\bf S}_t^{-1},\\
{\bf X}_{t|t} &= {\bf X}_{t|t-1} + {\bf K}_t (\tilde{{\bf Y}}_t - {\bf H}_t {\bf X}_{t|t-1}),
\end{align}
    
where ${\bf K}_t$ is the Kalman gain. In inference, the ensemble mean ${\bf \bar{x}}_{t|t} = \frac{1}{E}\sum_{i=1}^E {\bf x}^i_{t|t}$ is used as the updated state. 
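
The measurement update can be sketched for a single ensemble in NumPy as follows. This is an illustrative sketch under stated assumptions, not the notebook's implementation: the helper name `enkf_update` is hypothetical, members are stored as rows (so the matrix products are transposed relative to the column form of the equations), the observation model is taken to be the identity, and the diagonal of ${\bf R}$ is passed in directly instead of being predicted by $r_{\pmb {\zeta}}$.

```python
import numpy as np

def enkf_update(X_pred, HX, Y_tilde, r_diag):
    """EnKF measurement update with ensemble members stored as rows.

    X_pred: [E, dim_x], HX: [E, dim_z], Y_tilde: [E, dim_z], r_diag: [dim_z]
    """
    E = X_pred.shape[0]
    A = X_pred - X_pred.mean(axis=0, keepdims=True)  # state deviations
    HA = HX - HX.mean(axis=0, keepdims=True)         # observation deviations
    S = HA.T @ HA / (E - 1) + np.diag(r_diag)        # innovation covariance
    K = (A.T @ HA / (E - 1)) @ np.linalg.inv(S)      # Kalman gain, [dim_x, dim_z]
    return X_pred + (Y_tilde - HX) @ K.T             # updated ensemble, [E, dim_x]

rng = np.random.default_rng(0)
E, dim = 64, 2
X_pred = rng.normal(0.0, 1.0, size=(E, dim))         # broad predicted ensemble
HX = X_pred.copy()                                   # identity observation model (assumption)
y = np.array([2.0, -1.0])
Y_tilde = y + rng.normal(0.0, 0.05, size=(E, dim))   # sampled observations around y
X_new = enkf_update(X_pred, HX, Y_tilde, r_diag=np.full(dim, 0.05**2))
print(X_new.mean(axis=0))  # ensemble mean after the update
```

Because the assumed observation noise is small relative to the spread of the predicted ensemble, the gain is close to the identity and the updated ensemble mean moves most of the way toward `y`.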

In [5]:
class ObservationNoise(nn.Module):
    def __init__(self, dim_z, r_diag):
        """
        The observation noise model learns the observation noise covariance
        matrix R from the learned observation. The Kalman filter requires an
        explicit matrix for R, so we construct the diagonal of R here.

        input -> [batch_size, 1, encoding/dim_z]
        output -> [batch_size, dim_z, dim_z]
        """
        super(ObservationNoise, self).__init__()
        self.dim_z = dim_z
        self.r_diag = r_diag

        self.fc1 = nn.Linear(64, 32)
        self.fc2 = nn.Linear(32, self.dim_z)

    def forward(self, inputs):
        # use the device of the incoming tensor so CPU-resident models also work
        device = inputs.device
        batch_size = inputs.shape[0]
        constant = np.ones(self.dim_z) * 1e-3
        init = np.sqrt(np.square(self.r_diag) - constant)
        diag = self.fc1(inputs)
        diag = F.relu(diag)
        diag = self.fc2(diag)
        diag = torch.square(diag + torch.Tensor(constant).to(device)) + torch.Tensor(
            init
        ).to(device)
        diag = rearrange(diag, "bs k dim -> (bs k) dim")
        R = torch.diag_embed(diag)
        return R

### 3. DEnKF with all sub-modules
DEnKF contains four sub-modules: a state transition model, an observation model, an observation noise model, and a sensor model. The entire framework is implemented as a final class which is then used for end-to-end training.
The same implementation can also be found in `PyTorch/model/DEnKF.py`.

In [6]:
class DEnKF(nn.Module):
    def __init__(self, num_ensemble, dim_x, dim_z):
        super(DEnKF, self).__init__()
        self.num_ensemble = num_ensemble
        self.dim_x = dim_x
        self.dim_z = dim_z
        # initial value used to parameterize the diagonal of R
        self.r_diag = np.ones((self.dim_z)).astype(np.float32) * 0.1

        # instantiate model
        self.process_model = ProcessModel(self.num_ensemble, self.dim_x)
        self.observation_model = ObservationModel(
            self.num_ensemble, self.dim_x, self.dim_z
        )
        self.observation_noise = ObservationNoise(self.dim_z, self.r_diag)
        self.sensor_model = imgSensorModel(self.num_ensemble, self.dim_z)

    def forward(self, inputs, states):
        # decompose inputs and states
        batch_size = inputs.shape[0]
        raw_obs = inputs
        state_old, m_state = states

        ##### prediction step #####
        state_pred = self.process_model(state_old)
        m_A = torch.mean(state_pred, axis=1)
        mean_A = repeat(m_A, "bs dim -> bs k dim", k=self.num_ensemble)
        A = state_pred - mean_A
        A = rearrange(A, "bs k dim -> bs dim k")

        ##### update step #####
        H_X = self.observation_model(state_pred)
        mean = torch.mean(H_X, axis=1)
        H_X_mean = rearrange(mean, "bs (k dim) -> bs k dim", k=1)
        m = repeat(mean, "bs dim -> bs k dim", k=self.num_ensemble)
        H_A = H_X - m
        # transpose operation
        H_XT = rearrange(H_X, "bs k dim -> bs dim k")
        H_AT = rearrange(H_A, "bs k dim -> bs dim k")

        # get learned observation
        ensemble_z, z, encoding = self.sensor_model(raw_obs)
        y = rearrange(ensemble_z, "bs k dim -> bs dim k")
        R = self.observation_noise(encoding)

        # measurement update
        innovation = (1 / (self.num_ensemble - 1)) * torch.matmul(H_AT, H_A) + R
        inv_innovation = torch.linalg.inv(innovation)
        K = (1 / (self.num_ensemble - 1)) * torch.matmul(
            torch.matmul(A, H_A), inv_innovation
        )
        gain = rearrange(torch.matmul(K, y - H_XT), "bs dim k -> bs k dim")
        state_new = state_pred + gain

        # gather output
        m_state_new = torch.mean(state_new, axis=1)
        m_state_new = rearrange(m_state_new, "bs (k dim) -> bs k dim", k=1)
        m_state_pred = rearrange(m_A, "bs (k dim) -> bs k dim", k=1)
        output = (
            state_new.to(dtype=torch.float32),
            m_state_new.to(dtype=torch.float32),
            m_state_pred.to(dtype=torch.float32),
            z.to(dtype=torch.float32),
            ensemble_z.to(dtype=torch.float32),
            H_X_mean.to(dtype=torch.float32),
        )
        return output