In [1]:
## import libraries and functions

import numpy as np
import matplotlib.pyplot as plt
import torch

from generate_obs_and_truth import create_obs_op, generate_obs

# Compute device: torch.device('cuda:0') when a GPU is visible, otherwise the
# plain string 'cpu' (Tensor.to and torch factory functions accept both forms).
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = 'cpu'

# all tensors in this notebook are double precision
dtype = torch.float64

# keyword arguments for tensor factories, e.g. torch.zeros(..., **arr_kwargs)
arr_kwargs = dict(dtype=dtype, device=device)


In [None]:
import os

from initializeQG import generate_snapshots, generate_initial_condition

from generate_obs_and_truth import generate_truth

## generate snapshots

# The long free run takes a long time (!), so it is generated once and cached
# on disk. The cache file name is the one the MF-EnKF cell loads
# ('psi_long_run'). An existing file is reused as-is: delete it to regenerate,
# e.g. after changing K_long or n_snaps.
K_long = 1000      # paper: int(5e5)
n_snaps = 100      # paper: 2000

dt = 6*3600        # 6 hours

nl = 2
nx = 128
ny = 128

long_run_file = 'psi_long_run'

if os.path.exists(long_run_file):
    p_temp = torch.load(long_run_file, map_location=device, weights_only=True)
else:
    # start from a zero field of shape (1, nl, nx, ny)
    p = torch.zeros(1, nl, nx, ny, **arr_kwargs)
    p_temp = generate_snapshots(p, K_long, n_snaps, dt)
    torch.save(p_temp, long_run_file)

psi_long = torch.load(long_run_file, weights_only = True)


In [None]:
import os

## generate initial condition and truth

# The truth trajectory is cached under the file name the MF-EnKF cell loads
# ('psi_true'). An existing file is reused as-is: delete it to regenerate the
# truth from the current long run (p_temp).
K_truth = 5000     # length argument of generate_truth (presumably number of dt-steps; TODO confirm it covers the K_da DA cycles)

truth_file = 'psi_true'

if not os.path.exists(truth_file):
    p0_temp = generate_initial_condition(p_temp)
    p_true_temp = generate_truth(p0_temp, K_truth, dt, device)
    torch.save(p_true_temp, truth_file)

# reload so psi_true is exactly what later cells read from disk
psi_true = torch.load(truth_file, weights_only = True).to(device)

In [None]:
# model time step: 6 hours, in seconds
dt = 6 * 3600

## observations

# grid: nl layers of nx x ny points (boundary included)
nl = 2
nx, ny = 128, 128

# observation-network settings handed to create_obs_op
choice = 'satellite'
m = 300
delta_obs = 16
k = 0

s_obs = 2.0        # std of the observation error
seed = 0

H, obs_loc = create_obs_op(nx, ny, choice, m, delta_obs, k)

# rows of H (presumably the number of observations actually produced)
print(f'm is: {H.shape[0]}')

## MF-EnKF experiments

In [None]:
## MF-EnKF experiments: loop over the configuration grid and average the final
## RMSE over M independent observation realizations (seeds).
import itertools

from MF_EnKF import MF_EnKF_run
from localization import create_loc_mat
from initializeQG import generate_initial_ensemble
from baseline import baseline_run

torch.cuda.empty_cache()

K_da = 500      # paper: 1500   # length of DA cycle

# Load the truth and the long run. These are the file names the snapshot /
# truth cells are expected to write; they must exist before this cell runs.
psi_true = torch.load('psi_true', weights_only = True).to(device)
psi_long = torch.load('psi_long_run', weights_only = True)

# Model grid (boundary included). MF_EnKF_run receives these together with
# remove_bound.
nl = 2
nx = 128
ny = 128

remove_bound = True     # remove boundary in analysis step, True or False

s_obs = 2.0             # std of the observation error

# Analysis grid: with remove_bound the outermost row/column on every side is
# dropped, so H, rho_X and the observed field all live on (nx-2) x (ny-2).
# NOTE(review): the EnKF cell passes the trimmed size as nx, ny to EnKF_run,
# while here the full size is passed to MF_EnKF_run. Confirm that each
# matches what its run function expects.
n_trim = 2 if remove_bound else 0
nx_a = nx - n_trim
ny_a = ny - n_trim
# Computed once; the slice must agree with the grid H is built on.
psi_obs_field = psi_true[:, :, 1:-1, 1:-1] if remove_bound else psi_true

H, obs_loc = create_obs_op(nx_a, ny_a, choice, m, delta_obs, k)

# localization
r_loc = 5       # fastest is to save the localization matrix once, then rerun
rho_X = create_loc_mat(r_loc, nx_a**2, 'GC', 'euler')


mode = 'ML'         # choice of surrogate model, one of {'LR', 'ML'}
r_low = 64          # dimension of lower-dimensional surrogate

s_mod_X = 0.0       # model error for M_X
s_mod_U = 0.0       # model error for M_U

recenter_analysis_list = ['ML']     # recenter analysis ensemble, one of {'none', 'all' ,'control', 'ML'}
adjust_corr_list = [True]           # adjust correlation matrix A_Uhat to equal A_X, True or False

enkf_type_list = ['det']            # type of EnKF, 'det' or 'pert'
control_list = [True]               # include control variate Uhat, one of {True, False}

recenter_forecast = 'none'          # recenter the forecast step

recenter = True                     # recenter option for baseline method

N_large = 200                       # amount of maximum initial ensemble members
seed = 0                            # seed used to draw the initial ensemble
X_init = generate_initial_ensemble(psi_long, N_large, seed, device)       # initial ensemble

N_X_val = [25]                      # number of HR runs
N_U_val = [50]                      # number of ML runs

k_obs = 4                           # distance between obs

lambda_val = [0.5]                  # lambda parameter in MF-EnKF

alpha_X = 1.0                       # inflation factor
alpha_Uhat = alpha_X                # Uhat and U use the same inflation as X
alpha_U = alpha_X

show = 'mfenkf'                     # show either 'mfenkf' or 'base' results
if show not in ('mfenkf', 'base'):
    raise ValueError("show must be 'mfenkf' or 'base', got " + repr(show))

# Whether to form the dense multifidelity sample covariance S_Z after each run.
# A_X @ A_X.T is (rows of X) x (rows of X), which is very large for a full
# state vector, and S_Z is not read anywhere else in this notebook.
compute_S_Z = False

M = 10                              # average over M observation realizations
seed_val = range(M)

rmse_i = []                         # final MF-EnKF RMSE of every stable run
rmse_i_base = []                    # final baseline RMSE of every run (nan if unstable)
rmse_1 = []                         # same entries as rmse_i

# itertools.product yields the combinations in the same order as the
# equivalent nested for-loops (last list varies fastest).
config_grid = itertools.product(lambda_val, recenter_analysis_list, adjust_corr_list,
                                control_list, enkf_type_list, N_X_val, N_U_val)

for lambda_, recenter_analysis, adjust_corr, control, enkf_type, N_X, N_U in config_grid:
    # Final RMSE of each stable run of *this* configuration only. A slice of a
    # global list would mix in other configurations whenever a run fails.
    rmse_seeds = []

    for seed in seed_val:
        y_obs = generate_obs(psi_obs_field, s_obs, H, seed)

        if show == 'mfenkf':
            try:
                rmse_Z, Z, P_X, P_Uhat, P_U, P_Z, mu_X, mu_Uhat, mu_U, X, Uhat, U = MF_EnKF_run(
                    lambda_, N_X, N_U, s_obs, rho_X, alpha_X, alpha_Uhat, alpha_U, k_obs, y_obs,
                    X_init, K_da, psi_true, mode, H, r_low, s_mod_X, s_mod_U, recenter_analysis,
                    adjust_corr, control, enkf_type, recenter_forecast,
                    dt, nl, nx, ny, remove_bound)
            except Exception as err:
                # Filter divergence is expected for some settings: report it and
                # move on, but let KeyboardInterrupt/SystemExit through.
                print('Unstable (' + type(err).__name__ + ': ' + str(err) + ')')
                continue

            rmse_i.append(rmse_Z[-1])
            rmse_1.append(rmse_Z[-1])
            rmse_seeds.append(rmse_Z[-1])

            print('lambda is: ' + str(lambda_) + ', (N_X,N_U) is: (' + str(N_X) + ',' + str(N_U) + '), rmse MF-EnKF is: ' + str(rmse_Z[-1]), ', spread X: ' + str(P_X[-1]) + ', spread Z: ' + str(P_Z[-1].item()))

            if compute_S_Z:
                # scaled anomalies and
                # S_Z = A_X A_X^T + lambda^2 (A_Uhat A_Uhat^T + A_U A_U^T)
                #       - lambda (A_X A_Uhat^T + A_Uhat A_X^T)
                A_X = 1/np.sqrt(N_X-1) * (X - mu_X[:,None]).type(dtype)
                A_Uhat = 1/np.sqrt(N_X-1) * (Uhat - mu_Uhat[:,None]).type(dtype)
                A_U = 1/np.sqrt(N_U-1) * (U - mu_U[:,None]).type(dtype)
                S_Z = A_X @ A_X.T + lambda_**2*(A_Uhat @ A_Uhat.T) + lambda_**2*(A_U @ A_U.T) - lambda_*(A_X @ A_Uhat.T) - lambda_*(A_Uhat @ A_X.T)

        else:  # show == 'base'
            # NOTE(review): call signature not verified against
            # baseline.baseline_run. Confirm the argument order and return tuple.
            try:
                rmse_Z, rmse_X, Z, P_Z, P_X, P_U, X, U = baseline_run(
                    N_X, N_U, s_obs, rho_X, alpha_X, k_obs, y_obs,
                    X_init, K_da, psi_true, mode, H, r_low, enkf_type, dt, remove_bound, recenter)
            except Exception as err:
                rmse_i_base.append(np.nan)
                print('N_X,N_U is: ' + str(N_X) + ', ' + str(N_U) + ', rmse base is: ' + str(np.nan)
                      + ' (' + type(err).__name__ + ': ' + str(err) + ')')
                continue

            rmse_i_base.append(rmse_X[-1])
            rmse_seeds.append(rmse_X[-1])
            print('N_X,N_U is: ' + str(N_X) + ', ' + str(N_U) + ', rmse base is: ' + str(rmse_X[-1]))

    n_stable = len(rmse_seeds)
    mean_rmse = np.mean(rmse_seeds) if n_stable > 0 else float('nan')
    print('Mean RMSE: for lambda = ' + str(lambda_) + ': ' + str(mean_rmse)
          + ' (' + str(n_stable) + '/' + str(len(seed_val)) + ' stable runs)')



## EnKF experiments

In [None]:
## EnKF experiments (single-fidelity reference). Uses psi_true, X_init, K_da,
## k_obs, dt, nl, create_loc_mat and the observation settings (choice, m,
## delta_obs, k, s_obs) defined in earlier cells.
from EnKF import EnKF_run

remove_boundary = True      # remove boundary in analysis step
alpha_X_val = [1.0]         # inflation factor values

N_X = 30                    # number of HR runs

enkf_type = 'det'           # 'det' for the D-EnKF, 'pert' for the standard perturbed-observation EnKF

mode = 'HR'                 # 'HR' for M_X and 'ML' for M_U

loc = True                  # Localization True or False?

# localization
r_loc = 5

# Analysis grid: interior points only when the boundary is removed; the
# observed field is sliced the same way so it matches H.
# NOTE(review): EnKF_run receives the trimmed nx, ny here, whereas the MF-EnKF
# cell passes the full 128 x 128 grid to MF_EnKF_run. Confirm that each matches
# what its run function expects.
n_trim = 2 if remove_boundary else 0
nx = 128 - n_trim
ny = 128 - n_trim
psi_obs_field = psi_true[:, :, 1:-1, 1:-1] if remove_boundary else psi_true

H, obs_loc = create_obs_op(nx, ny, choice, m, delta_obs, k)
rho_X = create_loc_mat(r_loc, nx**2, 'GC', 'euler')

seed_val = range(10)        # observation realizations

for alpha_X in alpha_X_val:
    rmse_seeds = []         # final RMSE of each stable run for this alpha_X

    for seed in seed_val:
        y_obs = generate_obs(psi_obs_field, s_obs, H, seed)

        try:
            # output names kept for compatibility; these are EnKF (not MF-EnKF) results
            rmse_Z, Z_mfenkf, P_mfenkf, X_mfenkf = EnKF_run(N_X, s_obs, rho_X, alpha_X, k_obs, y_obs, X_init, K_da, psi_true, enkf_type, H, dt, remove_boundary,
                nl, nx, ny, mode, loc)
        except Exception as err:
            # report a diverging run and continue, without hiding KeyboardInterrupt
            print('Unstable run (alpha = ' + str(alpha_X) + ', seed = ' + str(seed) + ', '
                  + type(err).__name__ + ': ' + str(err) + ')')
            continue

        rmse_seeds.append(rmse_Z[-1])
        print('..............................')
        print('alpha = ' + str(alpha_X) + ', RMSE: ' + str(rmse_Z[-1]))

    mean_rmse = np.mean(rmse_seeds) if rmse_seeds else float('nan')
    print('Mean RMSE for alpha = ' + str(alpha_X) + ': ' + str(mean_rmse)
          + ' (' + str(len(rmse_seeds)) + '/' + str(len(seed_val)) + ' stable runs)')