In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import skew, lognorm

from model import lorenz63_fdm, M63, lorenz96_fdm, M96
from assimilation import EnKF, M3DVar, ExtendedKF

In [24]:
X_nature = np.load('./L96/data/X_nature.npy')
X_ini = np.load('./L96/data/X_ini.npy')
obs = np.load('./L96/data/obs_normal.npy')
ts = np.load('./L96/data/time_span.npy')
Pb = np.load('./L96/data/Pb_assum.npy')
R = np.load('./L96/data/R.npy')

dt = 0.01

# generate initial ensemble
N_ens = 30
rng = np.random.RandomState(42)
X_ens_ini = rng.multivariate_normal(X_ini.ravel(), Pb, size=N_ens).T  # (3, N_ens)

def RMSE(forecast, nature):
    return np.sqrt(np.mean((forecast-nature)**2, axis=0))

In [25]:
def gen_lognorm_obs(nature, s, obs_intv):
    ndim, size = nature.shape
    
    obserr = lognorm.rvs(s=s, size=(ndim,size))
    tmean = lognorm.mean(s)
    obserr_adjust = obserr - tmean
    obs = (nature + obserr_adjust)[:,::obs_intv]
    return obs

gen_lognorm_obs(X_nature, np.sqrt(np.log(2)), 8).shape

(40, 200)

In [26]:
obs = np.load('./L96/data/obs_normal.npy')
param = {
    'X_ini': X_ini,
    'obs': obs[::2,:],
    'obs_interv': 8,
    'Pb': Pb,
    'R': 2 * np.eye(20),
    'H_func': lambda arr: arr[::2,:]
}
tdv = M3DVar(lorenz96_fdm, dt)
tdv.set_params(**param)
tdv.cycle()
print('normal:', RMSE(tdv.analysis, X_nature).mean())


print('lognorm:')
for itime in range(10):
    print(itime)

    obs = gen_lognorm_obs(X_nature, np.sqrt(np.log(2)), 8)
    param = {
        'X_ini': X_ini,
        'obs': obs[::2,:],
        'obs_interv': 8,
        'Pb': Pb,
        'R': 2 * np.eye(20),
        'H_func': lambda arr: arr[::2,:]
    }
    tdv = M3DVar(lorenz96_fdm, dt)
    tdv.set_params(**param)
    tdv.cycle()
    print(RMSE(tdv.analysis, X_nature).mean())

normal: 2.5359512669968196
lognorm:
0
2.420281839892869
1
2.413148984273553
2
2.5971396959256525
3
2.230487456393766
4
2.4125592601013137
5
2.1660475073389427
6
2.3596853238936184
7
2.367619097466803
8
2.4352828468404044
9
2.234067019140366


In [27]:
obs = np.load('./L96/data/obs_normal.npy')
params = {
    'X_ens_ini': X_ens_ini, 
    'obs': obs[::2,:],
    'obs_interv': 8, 
    'R': 2 * np.eye(20), 
    'H_func': lambda arr: arr[::2,:], 
    'alpha': 0.3,
    'inflat': 1.4
}
enkf = EnKF(lorenz96_fdm, dt)
enkf.set_params(**params)
enkf.cycle()
rmse = RMSE(enkf.analysis.mean(axis=0), X_nature)
print('normal:', rmse.mean())


print('lognorm')
for itime in range(10):
    obs = gen_lognorm_obs(X_nature, np.sqrt(np.log(2)), 8)
    params = {
        'X_ens_ini': X_ens_ini, 
        'obs': obs[::2,:],
        'obs_interv': 8, 
        'R': 2 * np.eye(20), 
        'H_func': lambda arr: arr[::2,:], 
        'alpha': 0.3,
        'inflat': 1.4
    }
    enkf = EnKF(lorenz96_fdm, dt)
    enkf.set_params(**params)
    enkf.cycle()
    rmse = RMSE(enkf.analysis.mean(axis=0), X_nature)
    print(itime, rmse.mean())

normal: 2.319934862958723
lognorm
0 2.6105918174031264
1 2.3508061930609876
2 2.6641224090110263
3 2.6292354988003566
4 2.1832772082298004
5 2.568400576017441
6 2.1187456814180323
7 2.3991196543294926
8 2.380552118655066
9 2.1855279675016575
