In [None]:
## import libraries

import torch

import numpy as np
import matplotlib.pyplot as plt

# Reproducibility: cuDNN must avoid nondeterministic algorithms AND must not
# autotune (benchmark mode may pick different kernels from run to run).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Seed torch's global RNG so any torch-side sampling done by the DA modules
# (not visible in this notebook) is repeatable. numpy is seeded separately in
# the cells that draw from it.
torch.manual_seed(0)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float32

# Keyword arguments for creating tensors with the chosen dtype on the chosen
# device, e.g. torch.zeros(n, **arr_kwargs).
arr_kwargs = {'dtype': dtype, 'device': device}

print(device, dtype)


In [None]:
## obs and truth

from generate_obs_and_truth import generate_truth, generate_obs, create_obs_op

# Model parameters.
# NOTE(review): Force, dt and K_kernel are not passed to generate_truth below;
# presumably generate_obs_and_truth hard-codes the same values -- confirm they
# match before changing them here.
Force = 15                  # external forcing
n = 960                     # state dimension
K_kernel = 32               # kernel width (previously `K`, which was then overwritten by the run length)
dt = 1/40                   # time step, corresponding to 3 hours

# Generate the truth run.
K_spinup = 12*30*8*2        # 2 years of spinup: 12 months * 30 days * 8 steps/day (3 h steps) * 2
T = K_spinup                # legacy alias for the spinup length
K = 10000                   # length of truth run
np.random.seed(0)
x0 = np.random.uniform(0, 1, n)     # initial condition

x_true = generate_truth(x0, K, K_spinup)

# Generate observations.
s_obs = 2.0                 # std of observation error
seed = 0
choice = 'grid'             # gridded observations
k = 0                       # obs locations do not change in time
m = 40                      # requested number of observations
delta_obs = n // m          # spacing between observed components on the grid

H, obs_loc = create_obs_op(n, choice, m, delta_obs, k)

# Use the number of rows H actually has (in case create_obs_op adjusts the count).
m = H.shape[0]

y_obs = generate_obs(x_true, s_obs, H, seed)

print(f'Percentage observed: {m/n*100}%')
print(f'Number of observations: {m}')


In [None]:
## MF-EnKF run settings

from MF_EnKF import MF_EnKF_run
from MF_EnKF_baseline import MF_EnKF_run_base
from localization import create_loc_mat

# Sweep the MF-EnKF (or its baseline, see `show`) over N_X, N_U, lambda and
# inflation alpha, averaging the final RMSE/spread over M observation realizations.

K_da = 1000             # length of DA experiments

k_obs = 2               # distance in time between obs

N_X_val = [2]           # number of HR runs
N_U_val = [50]          # number of ML runs

mode = 'ML'             # 'ML' for ML surrogate, 'LR' for low-res surrogate
r_low = 120             # dimension of low-res surrogate

s_mod_X = 0.0           # model error for M_X
s_mod_U = 0.0           # model error for M_U

adjust_corr = True      # Set A_Uhat = A_X? {True, False}
control = True          # Include control variate Uhat? {True, False}

enkf_type = 'det'       # 'det' for D-EnKF, 'pert' for perturbed obs EnKF

pert_option_val = [2]   # candidate perturbation methods (not swept below; `pert_option` is what is used)
## 1: eta_Uhat = R, eta_U = (2-lambda_)/lambda_ R
## 2: eta_Uhat = 1/lambda_^2 R, eta_U = 1/lambda_^2 R
## 3: eta_Uhat = R, eta_U = R

pert_option = 2         # method of creating perturbations inside MF-EnKF (unused if enkf_type = 'det')

# localization
loc = 'no'              # 'yes' if localization, 'no' if not
r = 100                 # first argument of create_loc_mat (presumably the localization radius)

# Build the localization matrix only when localization is on, so `rho_X`
# cannot drift out of sync with `loc`.
if loc == 'yes':
    rho_X = create_loc_mat(r, n, 'GC', 'periodic').to_dense().cpu().numpy()
else:
    rho_X = None

recenter_forecast = False       # recenter forecast ensemble?

# Initial ensemble: first row of x_true (presumably t = 0) plus Gaussian noise;
# 200 members are drawn, the filters take the first N_X / N_U they need (assumed).
np.random.seed(0)
s_mod_X_init = 2.0
X_init = x_true[0, :, None] + s_mod_X_init * np.random.randn(n, 200)

recenter = 'ML'         # recenter analysis ensemble? one of {'ML', 'none', 'all', 'control'}

lambda_val = [0.5]      # lambda parameter in MF-EnKF

M = 1                   # average over M obs realizations

show = 'mfenkf'         # show 'mfenkf' or 'base' (baseline)
alpha_val = [1.01]      # inflation factor

# Fail fast on settings that would otherwise silently produce nan/stale means.
if show not in ('mfenkf', 'base'):
    raise ValueError(f"show must be 'mfenkf' or 'base', got {show!r}")
if M < 1:
    # rmse[-0:] would select the whole list rather than nothing.
    raise ValueError(f'M must be at least 1, got {M}')

# Per-realization results, appended across the whole sweep.
rmse_1 = []             # final MF-EnKF RMSE
P_X_1 = []              # final MF-EnKF spread of X
P_Z_1 = []              # final MF-EnKF spread of Z

rmse_2 = []             # reserved for a recentered baseline (not filled below)
rmse_3 = []             # final baseline (no recenter) RMSE

rmse_means = []         # reserved (not filled below)

# One entry per (N_X, N_U, lambda, alpha): mean over the M realizations.
P_X_val = []            # MF-EnKF runs only
P_Z_val = []            # MF-EnKF runs only
rmse_val = []

for N_X in N_X_val:
    for N_U in N_U_val:
        for lambda_ in lambda_val:
            for alpha_X in alpha_val:
                # Same inflation for all three ensembles.
                alpha_Uhat = alpha_X
                alpha_U = alpha_X

                # NB: the index must not be `m`, which holds the number of observations.
                for i_real in range(M):
                    # A fresh but reproducible observation realization per run.
                    seed = i_real + 1
                    y_obs = generate_obs(x_true, s_obs, H, seed)

                    if show == 'mfenkf':
                        print(f'N_X = {N_X}, N_U = {N_U}, lambda = {lambda_}')
                        (rmse_Z, Z, P_X, P_Uhat, P_U, P_Z,
                         mu_X, mu_Uhat, mu_U, X, Uhat, U) = MF_EnKF_run(
                            lambda_, N_X, N_U, s_obs, rho_X, alpha_X, alpha_Uhat, alpha_U,
                            k_obs, y_obs, X_init, K_da, x_true, mode, H, r_low,
                            s_mod_X, s_mod_U, recenter, adjust_corr, control,
                            enkf_type, recenter_forecast, loc, pert_option)

                        rmse_1.append(rmse_Z[-1])
                        P_X_1.append(P_X[-1])
                        P_Z_1.append(P_Z[-1])
                        print(f'RMSE MF-EnKF: {rmse_Z[-1]}, spread X: {P_X[-1]}, spread Z: {P_Z[-1]}')

                    else:  # show == 'base'
                        recenter_base = False
                        rmse_Z, rmse_X, Z, P_Z, P_X, P_U, X, U = MF_EnKF_run_base(
                            N_X, N_U, s_obs, rho_X, alpha_X, k_obs, y_obs,
                            X_init, K_da, x_true, mode, H, r_low, enkf_type, recenter_base, loc)

                        print(f'RMSE MF-EnKF base (no recenter): {rmse_Z[-1]}')
                        rmse_3.append(rmse_Z[-1])

                # Average the M realizations just run, from the list that was
                # actually filled for the selected mode.
                if show == 'mfenkf':
                    rmse_mean = np.mean(rmse_1[-M:])
                    P_X_val.append(np.mean(P_X_1[-M:]))
                    P_Z_val.append(np.mean(P_Z_1[-M:]))
                    label = 'MF-EnKF'
                else:
                    rmse_mean = np.mean(rmse_3[-M:])
                    label = 'base'
                rmse_val.append(rmse_mean)
                print(f'Mean RMSE ({label}) for lambda = {lambda_}, alpha = {alpha_X}: {rmse_mean}')
