# Data assimilation applied to a shallow water model

## Practical sessions

Alban Farchi, CEREA, [alban.farchi@enpc.fr](mailto:alban.farchi@enpc.fr)

During these sessions, you will apply two classical data assimilation methods to a shallow water model. The objective is for you to better understand these methods, figure out how they are implemented in practice, and identify their key parameters.

## The shallow water model

A shallow water model is well adapted to flows below a free surface when the depth is much smaller than the horizontal dimensions. It is commonly used to represent a lake or river. Its equations describe the time evolution of the water height $h(x)$ and the horizontal velocity $u(x)$ in a fixed-length domain.

A simplified unidimensional version reads
$$
\begin{aligned}
    \frac{\partial h}{\partial t} + \frac{\partial(hu)}{\partial x} &= 0,\\
    \frac{\partial(hu)}{\partial t} + \frac{\partial(hu^2)}{\partial x} + gh\frac{\partial h}{\partial x} &= 0.
\end{aligned}
$$

Several boundary conditions may be defined. In our case, we rely on the following conditions:
1. on the left, a constant inflow $Q=hu$;
2. on the right, a homogeneous Neumann condition for $h$ and $u$ (zero spatial derivative at the boundary), so that the fluxes through the boundary are determined by the state of the system next to it.

The equations are discretized and numerically solved using schemes detailed by, e.g., [Honnorat, 2007](https://tel.archives-ouvertes.fr/tel-00273318).

## The truth simulation

In this series of experiments, we will use twin simulations.
1. We run a reference simulation. The result is considered to be the **true situation** and is called the **truth**.
2. From the truth we extract **synthetic observations**.
3. Using the observations only, we try to reconstruct the truth using a dedicated **data assimilation** algorithm.

Let us start with the truth. The simulation domain is discretised using `Nx=101` grid points. At the initial time, the horizontal velocity $u(x)$ is zero, and the water height $h(x)$ is a crenel (top-hat) profile: $h(x)=1$ everywhere except in the center of the domain, where $h(x)=1.05$. The simulation is run for `Nt=500` time iterations. The other model parameters are `dx=1` (the horizontal step), `dt=0.03` (the time step), `Q=0.1` (the constant inflow on the left), and `g=9.81` (the acceleration due to gravity).

In [None]:
# import standard modules
import numpy as np
import scipy
from matplotlib import pyplot as plt
from matplotlib import animation
import seaborn as sns
from tqdm import trange
from IPython.display import HTML

# for plot customisation: seaborn theme first, then explicit matplotlib overrides
sns.set_context('notebook')
sns.set_style('darkgrid')
plt.rc('axes', linewidth=1, edgecolor='k')  # thin black axes frame
plt.rc('figure', dpi=300)  # high-resolution figures
palette = sns.color_palette('deep')  # colour palette shared by every plot below

# import custom shallow water model
from shallow_water_model import ShallowWaterModel

In [None]:
# create the model (horizontal step dx, time step dt, constant left inflow Q,
# gravitational acceleration g)
sw_model = ShallowWaterModel(
    Nx=101,
    dx=1,
    dt=0.03,
    Q=0.1,
    g=9.81,
)

# create a driver
def forecast_driver(model, state, Nt):
    """Integrate `model` for Nt time steps from `state`, without assimilation.

    `state` is advanced in place by `model.forward`. Returns the pair
    ``(traj_h, traj_u)``, two arrays of shape (Nt+1, model.Nx): row 0 holds
    the initial fields and row k the fields after k steps.
    """
    shape = (Nt + 1, model.Nx)
    traj_h, traj_u = np.empty(shape), np.empty(shape)

    # row 0 is the initial condition
    traj_h[0], traj_u[0] = state.h, state.u

    for step in trange(Nt, desc='running forward model'):
        model.forward(state)
        # copy the freshly updated fields into the next row
        traj_h[step + 1], traj_u[step + 1] = state.h, state.u

    return traj_h, traj_u

In [None]:
# truth run: crenel initial condition (h = 1.05 in the centre of the domain),
# integrated for Nt time steps (Nt is reused by every experiment below)
state = sw_model.new_state_crenel(h_anom=1.05)
Nt = 500
truth_h, truth_u = forecast_driver(model=sw_model, state=state, Nt=Nt)

In [None]:
# animate the truth water height, one frame every `freq` time steps
fig, ax = plt.subplots(figsize=(12, 6))
ax.set(xlim=(0, 100), ylim=(0.99, 1.08),
       xlabel='Domain', ylabel='Water height', title='Truth simulation')
(line,) = ax.plot([], [], color=palette[0])
x = np.arange(sw_model.Nx)

def animate(frame):
    """Show the water-height profile at time step `frame`."""
    line.set_data(x, truth_h[frame])
    return (line,)

freq = 10
anim = animation.FuncAnimation(fig, animate, frames=range(0, Nt + 1, freq),
                               interval=75, blit=True)
plt.close(fig)
HTML(anim.to_jshtml())

In [None]:
# animate the truth horizontal velocity, one frame every `freq` time steps
fig, ax = plt.subplots(figsize=(12, 6))
ax.set(xlim=(0, 100), ylim=(-0.2, 0.2),
       xlabel='Domain', ylabel='Horizontal velocity', title='Truth simulation')
(line,) = ax.plot([], [], color=palette[0])
x = np.arange(sw_model.Nx)

def animate(frame):
    """Show the velocity profile at time step `frame`."""
    line.set_data(x, truth_u[frame])
    return (line,)

freq = 10
anim = animation.FuncAnimation(fig, animate, frames=range(0, Nt + 1, freq),
                               interval=75, blit=True)
plt.close(fig)
HTML(anim.to_jshtml())

## The observations

At each time step, `Ny=3` observations are available: the water height values at $x=79$, $x=80$, and $x=81$.

In [None]:
# observation function
def apply_observation_operator(h):
    """Apply the observation operator to the vector h.

    The observations are the water heights at grid points x=79, 80 and 81,
    so this returns a copy of ``h[..., 79:82]``: a length-3 vector for a
    single state of shape (Nx,), or an array of shape (..., 3) when leading
    axes (e.g. time) are present.
    """
    # copy so the observation vector never aliases the model state
    return np.asarray(h)[..., 79:82].copy()

# extract the observations from the truth, one observation vector per time step
Ny = 3
observations = np.empty((Nt+1, Ny))
for t in range(Nt+1):
    # observe the truth at time step t only, not the whole trajectory
    observations[t] = apply_observation_operator(truth_h[t])

In [None]:
# plot the time series of the three observed water heights
time = np.arange(Nt + 1) * sw_model.dt
fig, ax = plt.subplots(figsize=(12, 6))
ax.set(xlim=(0, 15), ylim=(0.99, 1.08),
       xlabel='Time', ylabel='Water height', title='Observations')
ax.plot(time, observations[:, 0], color=palette[0], label='$h(x=79)$')
ax.plot(time, observations[:, 1], color=palette[2], label='$h(x=80)$')
ax.plot(time, observations[:, 2], color=palette[3], label='$h(x=81)$')
ax.legend()
plt.show()

## Simulation without assimilation

Let us first run a simulation without data assimilation. For this perturbed simulation, we use a different initial condition: the water height is $h(x)=1$ everywhere.

In [None]:
# free run from the perturbed initial condition (h = 1 everywhere)
state = sw_model.new_state_crenel(h_anom=1)
perturbed_h, perturbed_u = forecast_driver(model=sw_model, state=state, Nt=Nt)

In [None]:
# make a fancy animation for water height of the perturbed (free) run
fig = plt.figure(figsize=(12, 6))
ax = plt.gca()
ax.set_xlim(0, 100)
ax.set_ylim(0.99, 1.08)
ax.set_xlabel('Domain')
ax.set_ylabel('Water height')
# this cell shows the perturbed run, not the truth
ax.set_title('Perturbed simulation')
line, = ax.plot([], [], c=palette[0])
x = np.arange(sw_model.Nx)
def animate(t):
    """Show the perturbed water-height profile at time step t."""
    line.set_data(x, perturbed_h[t])
    return (line,)
freq = 10
anim = animation.FuncAnimation(fig, animate, frames=range(0, Nt+1, freq), interval=75, blit=True)
plt.close(fig)
HTML(anim.to_jshtml())

In [None]:
# spatial MAE and RMSE of h at every time step for the free run
perturbed_mae = np.mean(np.abs(perturbed_h - truth_h), axis=1)
perturbed_rmse = np.sqrt(np.mean(np.square(perturbed_h - truth_h), axis=1))

In [None]:
# time series of the MAE in h for the free run
time = np.arange(Nt + 1) * sw_model.dt
fig, ax = plt.subplots(figsize=(12, 6))
ax.set(xlim=(0, 15), ylim=(0, 0.01), xlabel='Time', ylabel='MAE',
       title='Error in water height')
ax.plot(time, perturbed_mae, color=palette[1], label='without assimilation')
ax.legend()
plt.show()

In [None]:
# time series of the RMSE in h for the free run
time = np.arange(Nt + 1) * sw_model.dt
fig, ax = plt.subplots(figsize=(12, 6))
ax.set(xlim=(0, 15), ylim=(0, 0.02), xlabel='Time', ylabel='RMSE',
       title='Error in water height')
ax.plot(time, perturbed_rmse, color=palette[1], label='without assimilation')
ax.legend()
plt.show()

## Simulation with optimal interpolation

In [None]:
# background error covariance matrix: identity
B = np.eye(sw_model.Nx)

# observation error covariance matrix: identity
R = np.eye(Ny)

# observation operator: row i selects h at grid point 79 + i (x = 79, 80, 81)
H = np.eye(Ny, sw_model.Nx, k=79)

# BLUE analysis
def compute_analysis_blue(hb, B, y, R, H):
    """Compute the BLUE analysis.

    Returns ``ha = hb + K (y - H hb)`` with the gain
    ``K = B H^T (H B H^T + R)^{-1}``.

    Parameters
    ----------
    hb : background state, shape (n,)
    B : background error covariance matrix, shape (n, n)
    y : observation vector, shape (p,)
    R : observation error covariance matrix, shape (p, p)
    H : linear observation operator, shape (p, n)

    The inputs are not modified; a new array is returned.
    """
    hb = np.asarray(hb, dtype=float)
    innovation = y - H @ hb
    BHt = B @ H.T
    # solve (H B H^T + R) w = innovation instead of forming the inverse
    weights = np.linalg.solve(H @ BHt + R, innovation)
    return hb + BHt @ weights

# create a driver
def blue_driver(model, state, Nt, observations, B, R, H):
    """Run a BLUE (optimal interpolation) experiment for Nt time steps.

    `state` is advanced in place with `model.forward`. At the initial time and
    after every forecast step, `state.h` is replaced by the analysis from
    `compute_analysis_blue` with the fixed matrices B, R and H (u is not
    analysed).

    Returns (traj_h_forecast, traj_h_analysis, traj_u), each of shape
    (Nt+1, model.Nx).
    """
    shape = (Nt + 1, model.Nx)
    traj_h_forecast = np.empty(shape)
    traj_h_analysis = np.empty(shape)
    traj_u = np.empty(shape)

    def analyse(step):
        # replace h by its analysis at `step` and record it
        state.h = compute_analysis_blue(state.h, B, observations[step], R, H)
        traj_h_analysis[step] = state.h

    # the initial condition plays the role of the first forecast
    traj_u[0] = state.u
    traj_h_forecast[0] = state.h
    analyse(0)

    for t in trange(Nt, desc='running BLUE'):
        model.forward(state)
        traj_h_forecast[t + 1] = state.h
        traj_u[t + 1] = state.u
        analyse(t + 1)

    return (traj_h_forecast, traj_h_analysis, traj_u)

In [None]:
# BLUE experiment started from the perturbed initial condition (h = 1)
state = sw_model.new_state_crenel(h_anom=1)
blue_h_forecast, blue_h_analysis, blue_u = blue_driver(
    model=sw_model, state=state, Nt=Nt, observations=observations, B=B, R=R, H=H)

In [None]:
# animate the water height: truth vs BLUE analysis
fig, ax = plt.subplots(figsize=(12, 6))
ax.set(xlim=(0, 100), ylim=(0.99, 1.08), xlabel='Domain',
       ylabel='Water height', title='Truth vs BLUE simulation')
(truth_line,) = ax.plot([], [], color=palette[0], label='truth')
(blue_line,) = ax.plot([], [], color=palette[2], label='BLUE analysis')
ax.legend()
x = np.arange(sw_model.Nx)

def animate(frame):
    """Move both curves to time step `frame`."""
    truth_line.set_data(x, truth_h[frame])
    blue_line.set_data(x, blue_h_analysis[frame])
    return (truth_line, blue_line)

freq = 10
anim = animation.FuncAnimation(fig, animate, frames=range(0, Nt + 1, freq),
                               interval=75, blit=True)
plt.close(fig)
HTML(anim.to_jshtml())

In [None]:
# animate the horizontal velocity: truth vs BLUE run (u itself is not analysed)
fig, ax = plt.subplots(figsize=(12, 6))
ax.set(xlim=(0, 100), ylim=(-0.2, 0.2), xlabel='Domain',
       ylabel='Horizontal velocity', title='Truth vs BLUE simulation')
(truth_line,) = ax.plot([], [], color=palette[0], label='truth')
(blue_line,) = ax.plot([], [], color=palette[2], label='BLUE analysis')
ax.legend()
x = np.arange(sw_model.Nx)

def animate(frame):
    """Move both curves to time step `frame`."""
    truth_line.set_data(x, truth_u[frame])
    blue_line.set_data(x, blue_u[frame])
    return (truth_line, blue_line)

freq = 10
anim = animation.FuncAnimation(fig, animate, frames=range(0, Nt + 1, freq),
                               interval=75, blit=True)
plt.close(fig)
HTML(anim.to_jshtml())

In [None]:
# zoom on the observed region: truth, BLUE analysis and BLUE forecast
fig, ax = plt.subplots(figsize=(12, 6))
ax.set(xlim=(77, 83), ylim=(1, 1.02), xlabel='Domain',
       ylabel='Water height', title='Truth vs BLUE simulation')
(truth_line,) = ax.plot([], [], color=palette[0], label='truth')
(blue_line_a,) = ax.plot([], [], color=palette[2], label='BLUE analysis')
(blue_line_f,) = ax.plot([], [], color=palette[3], label='BLUE forecast')
ax.legend()
x = np.arange(sw_model.Nx)

def animate(frame):
    """Move the three curves to time step `frame`."""
    truth_line.set_data(x, truth_h[frame])
    blue_line_a.set_data(x, blue_h_analysis[frame])
    blue_line_f.set_data(x, blue_h_forecast[frame])
    return (truth_line, blue_line_a, blue_line_f)

freq = 10
anim = animation.FuncAnimation(fig, animate, frames=range(0, Nt + 1, freq),
                               interval=75, blit=True)
plt.close(fig)
HTML(anim.to_jshtml())

In [None]:
# spatial MAE and RMSE of the BLUE analysis of h at every time step
blue_analysis_mae = np.mean(np.abs(blue_h_analysis - truth_h), axis=1)
blue_analysis_rmse = np.sqrt(np.mean(np.square(blue_h_analysis - truth_h), axis=1))

In [None]:
# MAE in h: free run vs BLUE analysis
time = np.arange(Nt + 1) * sw_model.dt
fig, ax = plt.subplots(figsize=(12, 6))
ax.set(xlim=(0, 15), ylim=(0, 0.01), xlabel='Time', ylabel='MAE',
       title='Error in water height')
ax.plot(time, perturbed_mae, color=palette[1], label='without assimilation')
ax.plot(time, blue_analysis_mae, color=palette[2], label='BLUE analysis')
ax.legend()
plt.show()

In [None]:
# RMSE in h: free run vs BLUE analysis
time = np.arange(Nt + 1) * sw_model.dt
fig, ax = plt.subplots(figsize=(12, 6))
ax.set(xlim=(0, 15), ylim=(0, 0.02), xlabel='Time', ylabel='RMSE',
       title='Error in water height')
ax.plot(time, perturbed_rmse, color=palette[1], label='without assimilation')
ax.plot(time, blue_analysis_rmse, color=palette[2], label='BLUE analysis')
ax.legend()
plt.show()

## Simulation with ensemble Kalman filter

In [None]:
# observation error covariance matrix: diagonal, variance 1e-4
R = 1e-4 * np.eye(Ny)

# observation operator: row i selects h at grid point 79 + i (x = 79, 80, 81)
H = np.eye(Ny, sw_model.Nx, k=79)

# covariance
def compute_covariance(E):
    """Compute the sample covariance matrix of ensemble E.

    E holds one ensemble member per row, shape (Ne, n). Returns the unbiased
    sample covariance (normalised by Ne - 1), of shape (n, n).

    Raises ValueError when fewer than two members are given, since the
    sample covariance is then undefined.
    """
    E = np.asarray(E, dtype=float)
    Ne = E.shape[0]
    if Ne < 2:
        raise ValueError('at least two ensemble members are required')
    anomalies = E - E.mean(axis=0)
    return anomalies.T @ anomalies / (Ne - 1)

# EnKF analysis
def compute_analysis_enkf(Ef, y, R, H, perturb_obs=False, rng=None):
    """Compute the EnKF analysis.

    Each member of the forecast ensemble Ef (one member per row) is updated
    with `compute_analysis_blue`, using the sample covariance of Ef as B.

    Parameters
    ----------
    Ef : forecast ensemble, shape (Ne, n)
    y : observation vector, shape (p,)
    R : observation error covariance matrix, shape (p, p)
    H : observation operator, shape (p, n)
    perturb_obs : if True, member i assimilates ``y + eps_i`` with
        ``eps_i ~ N(0, R)`` (stochastic EnKF, Burgers et al. 1998). With the
        default (False) every member assimilates the same y, which
        underestimates the analysis ensemble spread.
    rng : numpy.random.Generator used for the perturbations; a fresh
        default generator is created when None. Ignored unless perturb_obs.

    Returns
    -------
    (Ea, B) : the analysis ensemble, shape (Ne, n), and the forecast sample
        covariance B, shape (n, n), used for the update.
    """
    Ne = Ef.shape[0]
    B = compute_covariance(Ef)
    if perturb_obs:
        rng = np.random.default_rng() if rng is None else rng
        member_obs = rng.multivariate_normal(np.asarray(y, dtype=float), R, size=Ne)
    else:
        member_obs = [y] * Ne
    Ea = np.zeros(Ef.shape)
    for i in range(Ne):
        Ea[i] = compute_analysis_blue(Ef[i], B, member_obs[i], R, H)
    return (Ea, B)

# create a driver
def enkf_driver(model, ensemble, Nt, observations, R, H):
    """Run an EnKF experiment on the water height h for Nt time steps.

    The ensemble is advanced in place with `model.forward_ensemble`. At the
    initial time and after every forecast step, `ensemble.h` is replaced by
    the EnKF analysis from `compute_analysis_enkf` (u is not analysed).

    Returns
    -------
    (traj_h_forecast, traj_h_analysis, traj_u, traj_B)
        h forecasts, h analyses and u, each of shape (Nt+1, Ne, Nx), and the
        sample covariance B used at each analysis, of shape (Nt+1, Nx, Nx).
    """
    member_shape = (Nt + 1, ensemble.Ne, model.Nx)
    traj_h_forecast = np.empty(member_shape)
    traj_h_analysis = np.empty(member_shape)
    traj_u = np.empty(member_shape)
    traj_B = np.empty((Nt + 1, model.Nx, model.Nx))

    def assimilate(step):
        # analyse h at `step`, then record the analysis and its covariance
        ensemble.h, traj_B[step] = compute_analysis_enkf(ensemble.h, observations[step], R, H)
        traj_h_analysis[step] = ensemble.h

    # the initial ensemble plays the role of the first forecast
    traj_u[0] = ensemble.u
    traj_h_forecast[0] = ensemble.h
    assimilate(0)

    for t in trange(Nt, desc='running EnKF'):
        model.forward_ensemble(ensemble)
        traj_h_forecast[t + 1] = ensemble.h
        traj_u[t + 1] = ensemble.u
        assimilate(t + 1)

    return (traj_h_forecast, traj_h_analysis, traj_u, traj_B)

In [None]:
# EnKF experiment: 25-member ensemble built around the perturbed initial
# condition (mean_h_anom=1, std_h_anom=0.02)
ensemble = sw_model.new_ensemble_crenel(Ne=25, mean_h_anom=1, std_h_anom=0.02)
(enkf_h_forecast, enkf_h_analysis, enkf_u, enkf_B) = enkf_driver(
    model=sw_model, ensemble=ensemble, Nt=Nt, observations=observations, R=R, H=H)

In [None]:
# animate the water height: truth vs EnKF ensemble-mean analysis
fig, ax = plt.subplots(figsize=(12, 6))
ax.set(xlim=(0, 100), ylim=(0.99, 1.08), xlabel='Domain',
       ylabel='Water height', title='Truth vs EnKF simulation')
(truth_line,) = ax.plot([], [], color=palette[0], label='truth')
(enkf_line,) = ax.plot([], [], color=palette[2], label='EnKF analysis')
ax.legend()
x = np.arange(sw_model.Nx)

def animate(frame):
    """Move the truth and ensemble-mean curves to time step `frame`."""
    truth_line.set_data(x, truth_h[frame])
    enkf_line.set_data(x, enkf_h_analysis[frame].mean(axis=0))
    return (truth_line, enkf_line)

freq = 10
anim = animation.FuncAnimation(fig, animate, frames=range(0, Nt + 1, freq),
                               interval=75, blit=True)
plt.close(fig)
HTML(anim.to_jshtml())

In [None]:
# make a fancy animation for water height (with ensemble)
fig = plt.figure(figsize=(12, 6))
ax = plt.gca()
ax.set_xlim(0, 100)
ax.set_ylim(0.99, 1.08)
ax.set_xlabel('Domain')
ax.set_ylabel('Water height')
ax.set_title('Truth vs EnKF simulation')
# one thin line per ensemble member
enkf_lines = []
for i in range(enkf_h_analysis.shape[1]):
    mem_line, = ax.plot([], [], c=palette[3], lw=0.25)
    enkf_lines.append(mem_line)
truth_line, = ax.plot([], [], c=palette[0], label='truth')
enkf_line, = ax.plot([], [], c=palette[2], label='EnKF analysis')
plt.legend()
x = np.arange(sw_model.Nx)
def animate(t):
    """Update the truth, the ensemble mean and every member to time step t."""
    truth_line.set_data(x, truth_h[t])
    enkf_line.set_data(x, enkf_h_analysis[t].mean(axis=0))
    for member_line, member_h in zip(enkf_lines, enkf_h_analysis[t]):
        member_line.set_data(x, member_h)
    return tuple(enkf_lines+[truth_line, enkf_line])
freq = 20
# no init_func: with blit=True, FuncAnimation draws the first frame itself
# when no init function is given
anim = animation.FuncAnimation(fig, animate, frames=range(0, Nt+1, freq), interval=75, blit=True)
plt.close(fig)
HTML(anim.to_jshtml())

In [None]:
# animate the horizontal velocity: truth vs EnKF ensemble mean (u not analysed)
fig, ax = plt.subplots(figsize=(12, 6))
ax.set(xlim=(0, 100), ylim=(-0.2, 0.2), xlabel='Domain',
       ylabel='Horizontal velocity', title='Truth vs EnKF simulation')
(truth_line,) = ax.plot([], [], color=palette[0], label='truth')
(enkf_line,) = ax.plot([], [], color=palette[2], label='EnKF analysis')
ax.legend()
x = np.arange(sw_model.Nx)

def animate(frame):
    """Move the truth and ensemble-mean curves to time step `frame`."""
    truth_line.set_data(x, truth_u[frame])
    enkf_line.set_data(x, enkf_u[frame].mean(axis=0))
    return (truth_line, enkf_line)

freq = 10
anim = animation.FuncAnimation(fig, animate, frames=range(0, Nt + 1, freq),
                               interval=75, blit=True)
plt.close(fig)
HTML(anim.to_jshtml())

In [None]:
# compute the modulus of the largest eigenvalue of B at every time step.
# B is a symmetric sample covariance, so use the symmetric eigensolver: it
# returns real eigenvalues, works on the whole (Nt+1, Nx, Nx) stack at once,
# and avoids relying on `scipy.linalg` being reachable after `import scipy`.
largest_eigval = np.abs(np.linalg.eigvalsh(enkf_B)).max(axis=-1)

In [None]:
# time evolution of the largest eigenvalue of B, on a log scale
time = np.arange(Nt + 1) * sw_model.dt
fig, ax = plt.subplots(figsize=(12, 6))
ax.set_yscale('log')
ax.set(xlim=(0, 15), ylim=(1e-5, 1), xlabel='Time',
       ylabel='Module of largest B eigenvalue')
ax.plot(time, largest_eigval, color=palette[0])
plt.show()

In [None]:
# spatial MAE and RMSE of the EnKF ensemble-mean analysis of h
enkf_analysis_mae = np.mean(np.abs(enkf_h_analysis.mean(axis=1) - truth_h), axis=1)
enkf_analysis_rmse = np.sqrt(np.mean(np.square(enkf_h_analysis.mean(axis=1) - truth_h), axis=1))

In [None]:
# MAE in h: free run vs BLUE vs EnKF
time = np.arange(Nt + 1) * sw_model.dt
fig, ax = plt.subplots(figsize=(12, 6))
ax.set(xlim=(0, 15), ylim=(0, 0.01), xlabel='Time', ylabel='MAE',
       title='Error in water height')
ax.plot(time, perturbed_mae, color=palette[1], label='without assimilation')
ax.plot(time, blue_analysis_mae, color=palette[2], label='BLUE analysis')
ax.plot(time, enkf_analysis_mae, color=palette[3], label='EnKF analysis')
ax.legend()
plt.show()

In [None]:
# RMSE in h: free run vs BLUE vs EnKF
time = np.arange(Nt + 1) * sw_model.dt
fig, ax = plt.subplots(figsize=(12, 6))
ax.set(xlim=(0, 15), ylim=(0, 0.02), xlabel='Time', ylabel='RMSE',
       title='Error in water height')
ax.plot(time, perturbed_rmse, color=palette[1], label='without assimilation')
ax.plot(time, blue_analysis_rmse, color=palette[2], label='BLUE analysis')
ax.plot(time, enkf_analysis_rmse, color=palette[3], label='EnKF analysis')
ax.legend()
plt.show()

In [None]:
# animate the EnKF sample covariance B with a symmetric colour scale
fig, ax = plt.subplots(figsize=(12, 6))
ax.grid(False)
ax.set(title='B matrix')
vmax = 5e-5
im = ax.imshow(enkf_B[0], origin='lower',
               extent=[0, sw_model.Nx, 0, sw_model.Nx],
               vmin=-vmax, vmax=vmax, cmap='RdBu')
fig.colorbar(im)

def animate(frame):
    """Show B at time step `frame`."""
    im.set_array(enkf_B[frame])
    return (im,)

freq = 10
anim = animation.FuncAnimation(fig, animate, frames=range(0, Nt + 1, freq),
                               interval=75, blit=True)
plt.close(fig)
HTML(anim.to_jshtml())

## Assimilation with observation errors

In [None]:
# extract the observations from the truth and add Gaussian observation errors
Ny = 3
std_obs = 0.01
# fixed seed so the perturbed observations are reproducible between runs
obs_rng = np.random.default_rng(seed=2021)
pert_observations = np.empty((Nt+1, Ny))
for t in range(Nt+1):
    # observe the truth at time step t only, with one independent draw per observation
    pert_observations[t] = apply_observation_operator(truth_h[t]) + std_obs * obs_rng.standard_normal(Ny)

In [None]:
# plot the time series of the three perturbed observations
time = np.arange(Nt + 1) * sw_model.dt
fig, ax = plt.subplots(figsize=(12, 6))
ax.set(xlim=(0, 15), ylim=(0.99, 1.08),
       xlabel='Time', ylabel='Water height', title='Observations')
ax.plot(time, pert_observations[:, 0], color=palette[0], label='$h(x=79)$')
ax.plot(time, pert_observations[:, 1], color=palette[2], label='$h(x=80)$')
ax.plot(time, pert_observations[:, 2], color=palette[3], label='$h(x=81)$')
ax.legend()
plt.show()

In [None]:
# BLUE with perturbed observations, using the same matrices as before:
# B and R are identities, H selects h at grid points 79, 80 and 81
B = np.eye(sw_model.Nx)
R = np.eye(Ny)
H = np.eye(Ny, sw_model.Nx, k=79)

state = sw_model.new_state_crenel(h_anom=1)
pert_blue_h_forecast, pert_blue_h_analysis, pert_blue_u = blue_driver(
    model=sw_model, state=state, Nt=Nt, observations=pert_observations, B=B, R=R, H=H)

In [None]:
# EnKF with perturbed observations: R has variance 1e-4 (= std_obs**2),
# H selects h at grid points 79, 80 and 81
R = 1e-4 * np.eye(Ny)
H = np.eye(Ny, sw_model.Nx, k=79)

ensemble = sw_model.new_ensemble_crenel(Ne=25, mean_h_anom=1, std_h_anom=0.02)
(pert_enkf_h_forecast, pert_enkf_h_analysis, pert_enkf_u, pert_enkf_B) = enkf_driver(
    model=sw_model, ensemble=ensemble, Nt=Nt, observations=pert_observations, R=R, H=H)

In [None]:
# animate the water height with perturbed observations: truth vs BLUE vs EnKF
fig, ax = plt.subplots(figsize=(12, 6))
ax.set(xlim=(0, 100), ylim=(0.99, 1.08), xlabel='Domain',
       ylabel='Water height', title='Truth vs BLUE vs EnKF simulation')
(truth_line,) = ax.plot([], [], color=palette[0], label='truth')
(blue_line,) = ax.plot([], [], color=palette[1], label='BLUE analysis')
(enkf_line,) = ax.plot([], [], color=palette[2], label='EnKF analysis')
ax.legend()
x = np.arange(sw_model.Nx)

def animate(frame):
    """Move the three curves to time step `frame`."""
    truth_line.set_data(x, truth_h[frame])
    blue_line.set_data(x, pert_blue_h_analysis[frame])
    enkf_line.set_data(x, pert_enkf_h_analysis[frame].mean(axis=0))
    return (truth_line, blue_line, enkf_line)

freq = 10
anim = animation.FuncAnimation(fig, animate, frames=range(0, Nt + 1, freq),
                               interval=75, blit=True)
plt.close(fig)
HTML(anim.to_jshtml())

In [None]:
# spatial MAE and RMSE of h for the perturbed-observation experiments
# (EnKF errors are computed on the ensemble mean)
pert_blue_analysis_mae = np.mean(np.abs(pert_blue_h_analysis - truth_h), axis=1)
pert_enkf_analysis_mae = np.mean(np.abs(pert_enkf_h_analysis.mean(axis=1) - truth_h), axis=1)
pert_blue_analysis_rmse = np.sqrt(np.mean(np.square(pert_blue_h_analysis - truth_h), axis=1))
pert_enkf_analysis_rmse = np.sqrt(np.mean(np.square(pert_enkf_h_analysis.mean(axis=1) - truth_h), axis=1))

In [None]:
# MAE in h: perturbed observations (solid) vs exact observations (dashed)
time = np.arange(Nt + 1) * sw_model.dt
fig, ax = plt.subplots(figsize=(12, 6))
ax.set(xlim=(0, 15), ylim=(0, 0.01), xlabel='Time', ylabel='MAE',
       title='Error in water height')
ax.plot(time, perturbed_mae, color=palette[1], label='without assimilation')
ax.plot(time, pert_blue_analysis_mae, color=palette[2], linestyle='-', label='BLUE analysis (pert. obs)')
ax.plot(time, pert_enkf_analysis_mae, color=palette[3], linestyle='-', label='EnKF analysis (pert. obs)')
ax.plot(time, blue_analysis_mae, color=palette[2], linestyle='--', label='BLUE analysis')
ax.plot(time, enkf_analysis_mae, color=palette[3], linestyle='--', label='EnKF analysis')
ax.legend()
plt.show()

In [None]:
# RMSE in h: perturbed observations (solid) vs exact observations (dashed)
time = np.arange(Nt + 1) * sw_model.dt
fig, ax = plt.subplots(figsize=(12, 6))
ax.set(xlim=(0, 15), ylim=(0, 0.02), xlabel='Time', ylabel='RMSE',
       title='Error in water height')
ax.plot(time, perturbed_rmse, color=palette[1], label='without assimilation')
ax.plot(time, pert_blue_analysis_rmse, color=palette[2], linestyle='-', label='BLUE analysis (pert. obs)')
ax.plot(time, pert_enkf_analysis_rmse, color=palette[3], linestyle='-', label='EnKF analysis (pert. obs)')
ax.plot(time, blue_analysis_rmse, color=palette[2], linestyle='--', label='BLUE analysis')
ax.plot(time, enkf_analysis_rmse, color=palette[3], linestyle='--', label='EnKF analysis')
ax.legend()
plt.show()

## EnKF with (h, u) cross-covariances

In [None]:
# observation error covariance matrix: diagonal, variance 1e-4
R = 1e-4 * np.eye(Ny)

# observation operator on the stacked state [h, u] (2*Nx components):
# row i selects h at grid point 79 + i; u is not observed
H = np.eye(Ny, 2 * sw_model.Nx, k=79)

# create a driver
def enkf_driver_full(model, ensemble, Nt, observations, R, H):
    """Run an EnKF experiment that analyses h and u jointly for Nt time steps.

    Each analysis uses `compute_analysis_enkf` on the stacked ensemble
    [h, u] of shape (Ne, 2*Nx), so the sample covariance B contains the
    (h, u) cross-covariances and observations of h also correct u; H must
    therefore have 2*Nx columns. The ensemble is advanced in place with
    `model.forward_ensemble`, and its h and u arrays are overwritten in place
    with each analysis.

    Returns
    -------
    (traj_h_forecast, traj_h_analysis, traj_u_forecast, traj_u_analysis, traj_B)
        The first four have shape (Nt+1, Ne, Nx); traj_B has shape
        (Nt+1, 2*Nx, 2*Nx).
    """
    # take the grid size from the model being driven, not from a global
    Nx = model.Nx
    shape = (Nt + 1, ensemble.Ne, Nx)
    traj_h_forecast = np.empty(shape)
    traj_h_analysis = np.empty(shape)
    traj_u_forecast = np.empty(shape)
    traj_u_analysis = np.empty(shape)
    traj_B = np.empty((Nt + 1, 2 * Nx, 2 * Nx))

    def _analyse(step):
        # joint analysis of the stacked [h, u] ensemble at `step`
        Ef = np.concatenate([ensemble.h, ensemble.u], axis=1)
        Ea, B = compute_analysis_enkf(Ef, observations[step], R, H)
        traj_h_analysis[step] = Ea[:, :Nx]
        traj_u_analysis[step] = Ea[:, Nx:]
        traj_B[step] = B
        # write the analysis back into the ensemble's existing arrays
        ensemble.h[:] = traj_h_analysis[step]
        ensemble.u[:] = traj_u_analysis[step]

    # the initial ensemble plays the role of the first forecast
    traj_h_forecast[0] = ensemble.h
    traj_u_forecast[0] = ensemble.u
    _analyse(0)

    for t in trange(Nt, desc='running EnKF'):
        model.forward_ensemble(ensemble)
        traj_h_forecast[t + 1] = ensemble.h
        traj_u_forecast[t + 1] = ensemble.u
        _analyse(t + 1)

    return (traj_h_forecast, traj_h_analysis, traj_u_forecast, traj_u_analysis, traj_B)

In [None]:
# full (h, u) EnKF experiment with the perturbed observations
ensemble = sw_model.new_ensemble_crenel(Ne=25, mean_h_anom=1, std_h_anom=0.02)
(full_enkf_h_forecast, full_enkf_h_analysis,
 full_enkf_u_forecast, full_enkf_u_analysis, full_enkf_B) = enkf_driver_full(
    model=sw_model, ensemble=ensemble, Nt=Nt, observations=pert_observations, R=R, H=H)

In [None]:
# animate the water height: truth vs full EnKF ensemble-mean analysis
fig, ax = plt.subplots(figsize=(12, 6))
ax.set(xlim=(0, 100), ylim=(0.99, 1.08), xlabel='Domain',
       ylabel='Water height', title='Truth vs full EnKF simulation')
(truth_line,) = ax.plot([], [], color=palette[0], label='truth')
(enkf_line,) = ax.plot([], [], color=palette[2], label='EnKF analysis')
ax.legend()
x = np.arange(sw_model.Nx)

def animate(frame):
    """Move the truth and ensemble-mean curves to time step `frame`."""
    truth_line.set_data(x, truth_h[frame])
    enkf_line.set_data(x, full_enkf_h_analysis[frame].mean(axis=0))
    return (truth_line, enkf_line)

freq = 10
anim = animation.FuncAnimation(fig, animate, frames=range(0, Nt + 1, freq),
                               interval=75, blit=True)
plt.close(fig)
HTML(anim.to_jshtml())

In [None]:
# animate the horizontal velocity: truth vs full EnKF ensemble-mean analysis
fig, ax = plt.subplots(figsize=(12, 6))
ax.set(xlim=(0, 100), ylim=(-0.2, 0.2), xlabel='Domain',
       ylabel='Horizontal velocity', title='Truth vs full EnKF simulation')
(truth_line,) = ax.plot([], [], color=palette[0], label='truth')
(enkf_line,) = ax.plot([], [], color=palette[2], label='EnKF analysis')
ax.legend()
x = np.arange(sw_model.Nx)

def animate(frame):
    """Move the truth and ensemble-mean curves to time step `frame`."""
    truth_line.set_data(x, truth_u[frame])
    enkf_line.set_data(x, full_enkf_u_analysis[frame].mean(axis=0))
    return (truth_line, enkf_line)

freq = 10
anim = animation.FuncAnimation(fig, animate, frames=range(0, Nt + 1, freq),
                               interval=75, blit=True)
plt.close(fig)
HTML(anim.to_jshtml())

In [None]:
# animate the augmented [h, u] sample covariance B (2*Nx x 2*Nx)
fig, ax = plt.subplots(figsize=(12, 6))
ax.grid(False)
ax.set(title='B matrix')
vmax = 5e-5
im = ax.imshow(full_enkf_B[0], origin='lower',
               extent=[0, 2*sw_model.Nx, 0, 2*sw_model.Nx],
               vmin=-vmax, vmax=vmax, cmap='RdBu')
fig.colorbar(im)

def animate(frame):
    """Show B at time step `frame`."""
    im.set_array(full_enkf_B[frame])
    return (im,)

freq = 10
anim = animation.FuncAnimation(fig, animate, frames=range(0, Nt + 1, freq),
                               interval=75, blit=True)
plt.close(fig)
HTML(anim.to_jshtml())

In [None]:
# spatial MAE and RMSE of the full EnKF ensemble-mean analysis of h
full_enkf_analysis_mae = np.mean(np.abs(full_enkf_h_analysis.mean(axis=1) - truth_h), axis=1)
full_enkf_analysis_rmse = np.sqrt(np.mean(np.square(full_enkf_h_analysis.mean(axis=1) - truth_h), axis=1))

In [None]:
# plot the time series of MAE: EnKF analysing h only vs EnKF analysing (h, u)
# jointly, both assimilating the perturbed observations
fig = plt.figure(figsize=(12, 6))
ax = plt.gca()
ax.set_xlim(0, 15)
ax.set_ylim(0, 0.01)
ax.set_xlabel('Time')
ax.set_ylabel('MAE')
ax.set_title('Error in water height')
time = sw_model.dt * np.arange(Nt+1)
ax.plot(time, perturbed_mae, c=palette[1], label='without assimilation')
ax.plot(time, pert_enkf_analysis_mae, c=palette[3], ls='--', label='EnKF analysis (only h)')
# the "full" curve must come from the (h, u) EnKF run
ax.plot(time, full_enkf_analysis_mae, c=palette[3], ls='-', label='EnKF analysis (full)')
plt.legend()
plt.show()

In [None]:
# plot the time series of RMSE: EnKF analysing h only vs EnKF analysing (h, u)
# jointly, both assimilating the perturbed observations
fig = plt.figure(figsize=(12, 6))
ax = plt.gca()
ax.set_xlim(0, 15)
ax.set_ylim(0, 0.02)
ax.set_xlabel('Time')
ax.set_ylabel('RMSE')
ax.set_title('Error in water height')
time = sw_model.dt * np.arange(Nt+1)
ax.plot(time, perturbed_rmse, c=palette[1], label='without assimilation')
ax.plot(time, pert_enkf_analysis_rmse, c=palette[3], ls='--', label='EnKF analysis (only h)')
# the "full" curve must come from the (h, u) EnKF run
ax.plot(time, full_enkf_analysis_rmse, c=palette[3], ls='-', label='EnKF analysis (full)')
plt.legend()
plt.show()