In [None]:
import inspect
import matplotlib.pyplot as plt
import numpy as np
from L96_model import L96, L96s, L96_eq1_xdot, L96_2t_xdot_ydot, RK4
import time

# Random generator shared by every stochastic step in the notebook (no fixed seed).
rng = np.random.default_rng()

# Experiment settings; later cells read them as config['<name>'].
config = {
    # --- model ---
    'K': 40,                              # length of the X state vector (first argument to L96)
    'J': 10,                              # second argument to L96
    'obs_freq': 10,                       # spacing, in time steps, between observation times
    'F_truth': 10,                        # forcing F given to the truth model
    'F_fcst': 10,                         # forcing F given to the forecast model (GCM)
                                          # alternative: 10+np.concatenate((np.linspace(-1.8,2,20),np.linspace(2,-1.8,20)))
    'GCM_param': np.array([0, 0, 0, 0]),  # np.polyval coefficients subtracted from the GCM tendency
    # --- run lengths and stepping ---
    'ns_da': 20000,                       # number of time steps covered by the DA experiment
    'ns': 20000,                          # truth run spans si*ns
    'ns_spinup': 200,                     # spin-up run spans si*ns_spinup
    'dt': 0.005,                          # time step passed to L96 and GCM
    'si': 0.005,                          # first argument to M_truth.run (presumably the sampling interval)
    # --- data assimilation ---
    'B_loc': 10,                          # localization radius passed to cov_loc
    'DA': 'EnKF',                         # one of 'EnKF', 'HyEnKF', '3DVar', 'Replace', 'None'
    'nens': 100,                          # ensemble size
    'inflate_opt': "multiplicative",      # forwarded to DA_methods.ens_inflate
    'inflate_factor': 0.01,               # forwarded to DA_methods.ens_inflate
    'hybrid_factor': 0.1,                 # weight of B_clim in the 'HyEnKF' covariance blend
    # --- parameter estimation ---
    'param_DA': False,                    # augment the state with GCM_param when True
    'param_sd': [0.01, 0.02, 0.1, 0.5],   # std of the perturbations applied to each parameter
    'param_inflate': 'DA',                # 'fixed' (redraw around the mean) or 'DA' (keep posterior)
    # --- observations ---
    'obs_density': 0.2,                   # int(K*obs_density) locations are observed at each obs time
    'DA_freq': 10,                        # time steps integrated per DA cycle
    'obs_sigma': 0.5,                     # standard deviation of the observation noise
}

In [None]:
def s(k,K):
    """Map index k of a K-point grid to the centre of its cell on the interval [-1, +1].

    k=0 gives -1 + 1/K and k=K-1 gives +1 - 1/K.
    """
    shifted = 0.5 + k
    return 2 * shifted / K - 1

def get_dist(i,j,K):
    """Shortest separation between indices i and j on a periodic ring of K points."""
    gap = abs(i-j)
    if gap <= K/2:
        return gap
    # Going the other way round the ring is shorter.
    return K-gap

def GCM(X0, F, dt, nt, param=[0]):
    """Integrate the forecast model for `nt` forward-Euler steps.

    Every step adds dt * (L96_eq1_xdot(X, F) - np.polyval(param, X)) to the state.

    Args:
        X0: initial state vector; it is copied, so the caller's array is untouched.
        F: forcing forwarded to L96_eq1_xdot.
        dt: step size.
        nt: number of steps to take.
        param: polynomial coefficients in np.polyval order (highest power first);
            the polynomial evaluated at X is subtracted from the tendency.

    Returns:
        (hist, time): `hist` has shape (nt+1, len(X0)) with hist[0] equal to X0;
        `time` holds the nt+1 values dt*n.
    """
    t_axis = dt*np.arange(nt+1)
    hist = np.full((nt+1, len(X0)), np.nan)
    state = X0.copy()
    hist[0] = state

    for step in range(1, nt+1):
        tendency = L96_eq1_xdot(state, F) - np.polyval(param, state)
        state = state + dt * tendency
        hist[step] = state
        t_axis[step] = dt*step
    return hist, t_axis

# Build the observation operator; observations are taken directly in model space (linear H)
def ObsOp(K,l_obs,t_obs,i_t):
    """Return the (nobs, K) selection matrix for the observations valid at time index i_t.

    Args:
        K: length of the model state.
        l_obs: integer array of observed locations; its last axis has length nobs.
        t_obs: array of observation time indices with the same shape as l_obs.
        i_t: time index whose observations are selected (entries where t_obs == i_t).

    Returns:
        H with H[r, c] = 1 where c is the location of the r-th selected observation,
        and 0 elsewhere.
    """
    n_rows = l_obs.shape[-1]
    observed_cols = l_obs[t_obs==i_t]
    H = np.zeros((n_rows, K))
    H[np.arange(n_rows), observed_cols] = 1
    return H

# localize covariance matrix based on the Gaspari-Cohn function
def cov_loc(B,loc=0):
    """Taper a covariance matrix with Gaspari-Cohn weights on a periodic domain.

    Args:
        B: 2-D covariance matrix of shape (M, N). Distances are computed with
            get_dist using M as the period of the ring.
        loc: localization radius passed to gaspari_cohn (0 keeps only the
            zero-distance entries).

    Returns:
        (B*W, W): the element-wise localized matrix and the float weight matrix W.
    """
    M,N = B.shape
    X,Y = np.ix_(np.arange(M),np.arange(N))
    dist=np.vectorize(get_dist)(X,Y,M)
    # otypes must be given explicitly: without it np.vectorize infers the output
    # dtype from the first element (distance 0, for which gaspari_cohn may return
    # the integer 1) and then truncates every fractional weight to 0.
    W=np.vectorize(gaspari_cohn,otypes=[float])(dist,loc)
    return B*W,W

def gaspari_cohn(distance,radius):
    """Gaspari-Cohn localization weight for a given separation.

    Every path returns a float. This matters when the function is wrapped with
    np.vectorize, which infers the output dtype from the first value it sees:
    an integer result at distance 0 would make the whole weight array integer
    and truncate all fractional weights to 0.

    Args:
        distance: separation between the two points.
        radius: localization length scale; the weight reaches 0 at 2*radius.
            A radius of 0 gives weight 1 at distance 0 and 0 everywhere else.

    Returns:
        1.0 at zero distance, the piecewise fifth-order polynomial value for
        0 < distance/radius <= 2, and 0.0 beyond that.
    """
    if distance==0:
        return 1.0
    if radius==0:
        return 0.0

    ratio=distance/radius
    if ratio<=1:
        return float(-ratio**5/4+ratio**4/2+5*ratio**3/8-5*ratio**2/3+1)
    if ratio<=2:
        return float(ratio**5/12-ratio**4/2+5*ratio**3/8+5*ratio**2/3-5*ratio+4-2/3/ratio)
    return 0.0

def find_obs(loc,obs,t_obs,l_obs,period):
    """Collect the observations made at one location within a time window.

    Args:
        loc: location index to look for.
        obs: observation values; row r holds the observations taken at time t_obs[r, 0].
        t_obs: observation time indices; only column 0 is inspected.
        l_obs: observed locations, with the same row layout as obs.
        period: (start, end) pair; rows with start < t_obs[:, 0] <= end are used.

    Returns:
        1-D array with one entry per observation time in the window: the value
        observed at `loc`, or NaN if `loc` was not observed at that time.
    """
    obs_times = t_obs[:,0]
    in_window = (obs_times>period[0]) & (obs_times<=period[1])
    rows = np.where(in_window)[0]

    values = np.full(rows.shape, np.nan)
    for slot, row in enumerate(rows):
        hit = l_obs[row]==loc
        if np.any(hit):
            values[slot] = obs[row][hit]
    return values

def running_ave(X,N):
    """Periodic running mean of X along axis 0 using an N-point window.

    The window is built from cyclic shifts (np.roll), so the ends wrap around.
    Shifts run from -N/2 to N/2-1 for even N and from -(N-1)/2 to (N-1)/2 for odd N.
    """
    even_window = (N%2==0)
    first_shift = -N/2 if even_window else -(N-1)/2
    stop_shift = N/2 if even_window else (N+1)/2

    total = np.zeros(X.shape)
    for shift in np.arange(first_shift, stop_shift):
        total = total + np.roll(X, int(shift), axis=0)
    return total/N

In [None]:
# Set up the "truth" 2-scale L96 model and generate initial conditions from a short spinup
M_truth = L96(config['K'], config['J'], F=config['F_truth'], dt=config['dt'])
# Start from random X and all-zero Y (0*M_truth.j has the shape of the Y index array)
M_truth.set_state(rng.standard_normal((config['K'])), 0*M_truth.j)
X_spinup,Y_spinup,t_spinup = M_truth.run(config['si'], config['si']*config['ns_spinup'])
# The last spin-up state becomes the initial condition of the truth run
X_init=X_spinup[-1,:]
Y_init=Y_spinup[-1,:]

# Run L96 to generate the "truth"
# NOTE(review): later cells index X_truth/t_truth by time step, so run() is assumed
# to return ns+1 samples along axis 0 - confirm against L96_model.
M_truth.set_state(X_init, Y_init)
X_truth,Y_truth,t_truth = M_truth.run(config['si'], config['si']*config['ns'])

# # generate climatological background covariance for 2-scale L96 model
# B_clim = np.cov(X_truth.T)
# np.save('B_clim_L96.npy', B_clim)

# Four-panel overview of the truth run
plt.figure(figsize=(12,10))

# Panel 1: final snapshot of X and Y, with the initial state dotted for reference
plt.subplot(221)
plt.plot(M_truth.k, X_truth[-1,:], label='X')
plt.plot(M_truth.j/M_truth.J, Y_truth[-1,:], label='Y')
# The title is a raw string: '\D' is not a valid escape sequence in a normal string
# literal and triggers a warning on recent Python versions.
plt.legend(); plt.xlabel('k'); plt.title(r'$X,Y$ @ $t=N\Delta t$')
# Plotted after legend() so the initial-state curves stay out of the legend
plt.plot(M_truth.k, X_truth[0,:], 'k:')
plt.plot(M_truth.j/M_truth.J, Y_truth[0,:], 'k:')

# Panel 2: sample time-series X[0](t), Y[0](t)
plt.subplot(222)
plt.plot(t_truth, X_truth[:,0], label='X')
plt.plot(t_truth, Y_truth[:,0], label='Y')
plt.xlabel('t'); plt.title('$X[0,t]$, $Y[0,t]$')

# Panel 3: full model history of X
plt.subplot(223)
plt.contourf(M_truth.k,t_truth,X_truth); plt.colorbar(orientation='horizontal')
plt.xlabel('k'); plt.ylabel('t'); plt.title('$X(k,t)$')

# Panel 4: full model history of Y
plt.subplot(224)
plt.contourf(M_truth.j/M_truth.J,t_truth,Y_truth); plt.colorbar(orientation='horizontal')
plt.xlabel('k'); plt.ylabel('t'); plt.title('$Y(k,t)$');

# # generate climatological background covariance for 1-scale L96 model
# M_1s = L96s(config['K'], F=config['F_truth'], dt=config['dt'], method=RK4)
# M_1s.set_state(X_init)
# X1_truth,t1_truth = M_1s.run(config['si']*config['ns'])
# B_clim_1s = np.cov(X1_truth.T)
# np.save('B_clim_1s.npy', B_clim_1s)

In [None]:
# Sample the "truth" to generate observations at certain times (t_obs) and locations (l_obs)
n_obs_per_time = int(config['K']*config['obs_density'])

# Time-step indices at which observations exist: obs_freq, 2*obs_freq, ..., ns_da
obs_steps = np.arange(config['obs_freq'], config['ns_da']+config['obs_freq'], config['obs_freq'])
# One row per observation time, with the time index repeated for each observed location
t_obs = np.tile(obs_steps, [n_obs_per_time, 1]).T

# Draw a fresh set of distinct locations for every observation time
l_obs = np.zeros(t_obs.shape, dtype='int')
for obs_row in l_obs:
    obs_row[:] = rng.choice(config['K'], n_obs_per_time, replace=False)

# Truth sampled at (time, location) plus Gaussian noise of standard deviation obs_sigma
obs_noise = config['obs_sigma']*rng.standard_normal(l_obs.shape)
X_obs = X_truth[t_obs, l_obs]+obs_noise
# print(X_obs.shape)

# Observation error covariance: independent errors, so a scaled identity
R = config['obs_sigma']**2*np.eye(n_obs_per_time)

# plt.figure(figsize=[6,4])
# plt.scatter(t_obs,X_obs)

In [None]:
import DA_methods
import importlib
# Reload so edits to DA_methods.py are picked up without restarting the kernel
importlib.reload(DA_methods)

t0 = time.time()

# Sizes used throughout this cell. Computing the cycle count once keeps the
# array sizes and the loop range consistent even if ns_da is not a multiple
# of DA_freq.
n_cycles=int(config['ns_da']/config['DA_freq'])
n_param=len(config['GCM_param'])

# load pre-calculated climatological background covariance matrix from a long simulation
B_clim=np.load('B_clim_L96s.npy')
B_loc,W_clim=cov_loc(B_clim,loc=config['B_loc'])

# set up array to store DA increments
X_inc=np.zeros((n_cycles,config['K'],config['nens']))
if config['DA']=='3DVar':
    X_inc=np.squeeze(X_inc)
t_DA=np.zeros(n_cycles)

# initialize ensemble with perturbations
# ensX holds the full ensemble history and is allocated once. Appending with
# np.concatenate every cycle would copy the whole history each time.
i_t=0
ensX=np.zeros((n_cycles*config['DA_freq']+1,config['K'],config['nens']))
ensX[0:1]=X_init[None,:,None]+rng.standard_normal((1,config['K'],config['nens']))
X_post=ensX[0].copy()

if config['param_DA']:
    mean_param=np.zeros((n_cycles,n_param))
    spread_param=np.zeros((n_cycles,n_param))
    param_scale=config['param_sd']
    # Localization for the augmented state: state-state block is localized,
    # every block involving parameters keeps weight 1.
    W=np.ones((config['K']+n_param,config['K']+n_param))
    W[0:config['K'],0:config['K']]=W_clim

else:
    W=W_clim
    # zero spread: every member gets exactly GCM_param
    param_scale=np.zeros(config['GCM_param'].shape)

# one column of model parameters per ensemble member
ens_param=np.zeros((n_param,config['nens']))
for i in range(n_param):
    ens_param[i,:]=config['GCM_param'][i]+rng.normal(scale=param_scale[i],size=config['nens'])

# DA cycles
# (for a quick smoke test, loop over range(1) instead)
for cycle in range(n_cycles):

    # set up array to store model forecast for each DA cycle
    ensX_fcst=np.zeros((config['DA_freq']+1,config['K'],config['nens']))

    # integrate every member DA_freq steps from its latest analysis
    for n in range(config['nens']):
        ensX_fcst[...,n] = GCM(X_post[0:config['K'],n], config['F_fcst'], config['dt'], config['DA_freq'], ens_param[:,n])[0]
    i_t=i_t+config['DA_freq']

    X_prior=ensX_fcst[-1,...]  # get prior from model integration

    # call DA
    t_DA[cycle]=t_truth[i_t]
    obs_now=t_obs==i_t  # mask selecting the observations valid at this analysis time
    if config['DA']=='EnKF':
        H=ObsOp(config['K'],l_obs,t_obs,i_t)
        # augment state vector with parameters when doing parameter estimation
        if config['param_DA']:
            H=np.concatenate((H,np.zeros((H.shape[0],n_param))),axis=-1)
            X_prior=np.concatenate((X_prior,ens_param))
        B_ens = np.cov(X_prior)
        B_ens_loc = B_ens*W
        # NOTE(review): EnKF/ens_inflate are assumed to return an ensemble with the
        # same (state, member) layout as X_prior - confirm in DA_methods.
        X_post=DA_methods.EnKF(X_prior,X_obs[obs_now],H,R,B_ens_loc)
        X_post=DA_methods.ens_inflate(X_post,X_prior,config['inflate_opt'],config['inflate_factor'])
        if config['param_DA']:
            param_post=X_post[-n_param:None,:]
    elif config['DA']=='HyEnKF':
        H=ObsOp(config['K'],l_obs,t_obs,i_t)
        # blend of the flow-dependent ensemble covariance and the climatological one
        B_ens = np.cov(X_prior)*(1-config['hybrid_factor'])+B_clim*config['hybrid_factor']
        B_ens_loc = B_ens*W
        X_post=DA_methods.EnKF(X_prior,X_obs[obs_now],H,R,B_ens_loc)
        X_post=DA_methods.ens_inflate(X_post,X_prior,config['inflate_opt'],config['inflate_factor'])
    elif config['DA']=='3DVar':
        X_prior=np.squeeze(X_prior)
        H=ObsOp(config['K'],l_obs,t_obs,i_t)
        X_post=DA_methods.Lin3dvar(X_prior,X_obs[obs_now],H,R,B_loc,3)
        X_post=X_post[:,None]
    elif config['DA']=='Replace':
        # Work on a copy: X_prior is a view of ensX_fcst, and aliasing it would
        # make the increment below identically zero.
        X_post=X_prior.copy()
        # the same observed value is inserted into every ensemble member
        X_post[l_obs[obs_now],:]=X_obs[obs_now][:,None]
    elif config['DA']=='None':
        X_post=X_prior
    else:
        # An unrecognized option would otherwise silently reuse the previous X_post.
        raise ValueError("Unknown config['DA'] option: {!r}".format(config['DA']))

    # get current increments
    # X_prior is indexed on its first axis only, because the '3DVar' branch
    # squeezes it to 1-D.
    X_inc[cycle,:]=np.squeeze(X_post[0:config['K'],:])-X_prior[0:config['K']]
    # get posterior info about the estimated parameters
    if config['param_DA']:
        mean_param[cycle,:]=param_post.mean(axis=-1)
        spread_param[cycle,:]=param_post.std(axis=-1)
        # prepare parameter ensemble for next DA cycle
        if config['param_inflate']=='fixed':
            # redraw around the posterior mean with the prescribed spread
            for i in range(n_param):
                ens_param[i,:]=param_post.mean(axis=-1)[i]+rng.normal(scale=param_scale[i],size=config['nens'])
        elif config['param_inflate']=='DA':
            ens_param=param_post

    # reset initial conditions for next DA cycle
    ensX_fcst[-1,:,:]=X_post[0:config['K'],:]
    # store this cycle's steps; row 0 of ensX_fcst duplicates the previous analysis
    ensX[i_t-config['DA_freq']+1:i_t+1]=ensX_fcst[1:None,...]

if config['DA']=='3DVar':
    X_inc=X_inc[...,None]

t1 = time.time()
print(t1-t0)

In [None]:
# post processing and visualization
# Quantities reused by several panels are computed once; ensX can be large.
n_steps=config['ns_da']+1
t_plot=t_truth[0:n_steps]
meanX=np.mean(ensX,axis=-1)            # ensemble mean
X_err=meanX-X_truth[0:n_steps,:]       # ensemble-mean error against the truth
ens_spread=np.std(ensX,axis=-1)        # ensemble standard deviation
clim=np.max(np.abs(X_err))             # symmetric colour limit for the error map

fig, axes=plt.subplots(3,2,figsize=(12,15))

# [0,0] error of the ensemble mean as a function of location and time
ch=axes[0,0].contourf(M_truth.k,t_plot,X_err,
                           cmap='bwr',vmin=-clim,vmax=clim,extend='neither')
plt.colorbar(ch,ax=axes[0,0],orientation='horizontal')
axes[0,0].set_xlabel('s'); axes[0,0].set_ylabel('t'); axes[0,0].set_title('X - X_truth')

# [0,1] domain-averaged RMSE and ensemble spread versus time
axes[0,1].plot(t_plot, np.sqrt((X_err**2).mean(axis=-1)),label='RMSE')
axes[0,1].plot(t_plot, np.mean(ens_spread,axis=-1),label='Spread')
axes[0,1].legend()
axes[0,1].set_xlabel('t'); axes[0,1].set_title('RMSE (X - X_truth)')

# [1,0] time-averaged RMSE and mean analysis increments versus location
axes[1,0].plot(M_truth.k, np.sqrt((X_err**2).mean(axis=0)),label='RMSE')
# increments divided by the length of one DA cycle (DA_freq*si)
X_inc_ave=X_inc/config['DA_freq']/config['si']
axes[1,0].plot(M_truth.k, X_inc_ave.mean(axis=(0,-1)),label='Inc')
axes[1,0].plot(M_truth.k, running_ave(X_inc_ave.mean(axis=(0,-1)),7),label='Inc Ave')
axes[1,0].plot(M_truth.k, np.ones(M_truth.k.shape)*(config['F_fcst']-config['F_truth']),label='F_bias')
axes[1,0].plot(M_truth.k, np.ones(M_truth.k.shape)*X_inc_ave.mean(),'k:',label='Ave Inc')
axes[1,0].legend()
axes[1,0].set_xlabel('s'); axes[1,0].set_title('RMSE (X - X_truth)')

# [1,1] truth, ensemble mean and observations at one location over a short window
plot_start,plot_end=1000, 1400
plot_start_DA, plot_end_DA=int(plot_start/config['DA_freq']), int(plot_end/config['DA_freq'])
plot_x=0
axes[1,1].plot(t_truth[plot_start:plot_end],X_truth[plot_start:plot_end,plot_x],label='truth')
axes[1,1].plot(t_truth[plot_start:plot_end],meanX[plot_start:plot_end,plot_x],label='forecast')
axes[1,1].scatter(t_DA[plot_start_DA:plot_end_DA],find_obs(plot_x,X_obs,t_obs,l_obs,[plot_start,plot_end]),label='obs')
axes[1,1].legend()

# [2,0] estimated parameters: mean (solid) and spread (dashed); labels show the
# average over the second half of the experiment
if config['param_DA']:
    second_half=int(len(t_DA)/2)
    for i,c in zip(np.arange(len(config['GCM_param'])),['r','b','g','k']):
        axes[2,0].plot(t_DA,mean_param[:,i],c+'-',label='C{} {:3f}'.format(i,mean_param[second_half:None,i].mean()))
        axes[2,0].plot(t_DA,spread_param[:,i],c+'--',label='SD {:3f}'.format(spread_param[second_half:None,i].mean()))
    axes[2,0].legend()

# [2,1] text summary of the run
# NOTE(review): '{:3f}' sets a minimum width of 3, not 3 decimals; '{:.3f}' may have
# been intended. Left unchanged so the figure text stays the same.
summary_text=('GCM param={}\nRMSE={:3f}\nSpread={:3f}\nDA={},{},{}\nDA_freq={}\nB_loc={}\ninflation={},{}\nobs_density={}\nobs_sigma={}\nF_truth={}\nobs_freq={}'
              .format(config['GCM_param'],np.sqrt((X_err**2).mean()),
                      np.mean(ens_spread),config['DA'],
                      config['nens'],config['hybrid_factor'],config['DA_freq'],config['B_loc'],
                      config['inflate_opt'],config['inflate_factor'],config['obs_density'],config['obs_sigma'],
                      config['F_truth'],config['obs_freq']))
axes[2,1].text(0.1,0.1,summary_text,fontsize=15)

# Save the configuration and the figure under a random experiment number.
# Numbers are not checked for collisions, so an existing pair can be overwritten.
exp_number=np.random.randint(1,10000)
# the context manager closes the file even if the write fails
with open('config_{0}.txt'.format(exp_number),"w") as f:
    f.write( str(config) )
plt.savefig('fig_{0}.jpg'.format(exp_number))