In [1]:
'''
DeepDA - Data assimilation framework for deep time paleoclimate projects

By :  Mingsong Li
      Penn State
      limingsonglms@gmail.com
     
Date: Feb 26, 2020

Updated

    Mar. 3. 2020
        partly clean the code; add MC for local_rad, withheld_rate, and scaled Rg
    June 2020
        include d18O of CESM by Zhu et al., 2019 Sci Adv
    August 2020
        Two options for the proxy order: all random & use the given list
    Oct. 12, 2020
        Add multi_seed for Monte Carlo simulations
    Nov. 2, 2020
        Add DeepMIP PSMs
        Add log_level

Note:
    if Mg/Ca proxy is included, need to run
        correct_cgenie_carb_ohm_cal.ipynb for the estimation of carb_ohm_cal_ben field
    if d13C proxy is included, need to run
        correct_cgenie_sed_caco3_13c.ipynb for the correction of d13C
'''

### Import packages
from DeepDA_lib import LMR_DA
from DeepDA_lib import modules_nc
from DeepDA_lib import DeepDA_psm
from DeepDA_lib import DeepDA_tools
import h5py
import time
import yaml
import numpy as np
import random
import pandas
import os
import shutil
from netCDF4 import Dataset
import numpy.ma as ma
import numpy.matlib as mat
from sys import platform as sys_pf
import matplotlib.pyplot as plt
if sys_pf == 'darwin':
    import matplotlib
    matplotlib.use("TkAgg")
    import matplotlib.pyplot as plt
%matplotlib inline
from mpl_toolkits.basemap import Basemap, shiftgrid, cm

try:
    import bayspar
except ImportError as e2:
    print('Warning:', e2)
try:
    import bayfox
except ImportError as e3:
    print('Warning:', e3)
try:
    import baymag
except ImportError as e4:
    print('Warning:', e4)


### Read config file

config_name = "DeepDA_config.yml"
#config_name = "petmproxy3slices_v0.0.10gt1.csvexp_petm78_og1_qc_obs_20200203_test2.yml"
# NOTE: yaml.load(FullLoader) is only appropriate for trusted, user-authored
# config files; prefer yaml.safe_load if configs may come from elsewhere.
with open(config_name, 'r') as f:
    yml_dict = yaml.load(f, Loader=yaml.FullLoader)
log_level = yml_dict['log_level']
if log_level > 1:
    print('>>  Import packages...  => Okay')
t = -1  # last time slice, for cGENIE
k = 0   # surface layer, SST

kcov_saving = 0  # save covariance? 0 = no, 1 = yes

# read config.yml settings
if log_level > 1:
    print('')
    print(' ########## Load yml config file ########## ')
    print('')
########## Proxy + PSM #########

# Settings block of the proxy database named first in proxies:use_from;
# all per-database proxy settings below are read from it.
proxy_cfg = yml_dict['proxies'][yml_dict['proxies']['use_from'][0]]

MCn = yml_dict['MonteCarlo']['number']
multi_seed = yml_dict['MonteCarlo']['multi_seed']
dir_proxy = yml_dict['core']['proxy_dir']
dir_proxy_data = dir_proxy + '/' + proxy_cfg['dbversion']
dir_proxy_save_dir = yml_dict['core']['wrkdir'] + '/'
dir_proxy_save = proxy_cfg['dbversion']
proxy_psm_type = proxy_cfg['proxy_psm_type']
proxy_assim2 = proxy_cfg['proxy_assim2']
proxy_order = proxy_cfg['proxy_order']
proxy_err_eval = proxy_cfg['proxy_err_eval']
proxy_blacklist = proxy_cfg['proxy_blacklist']
# proxy groups to assimilate: proxy_order without the blacklisted groups
proxy_list = [item for item in proxy_order if item not in proxy_blacklist]

psm_d18o_cfg = yml_dict['psm']['bayesreg_d18o_pooled']
psm_d18osw_adjust = psm_d18o_cfg['psm_d18osw_adjust']
d18osw_local_choice = psm_d18o_cfg['d18osw_local_choice']
d18osw_icesm_pco2 = psm_d18o_cfg['d18osw_icesm_pco2']

proxy_qc = yml_dict['proxies']['proxy_qc']
lon_label = proxy_cfg['lon_label']
lat_label = proxy_cfg['lat_label']

prior_source = yml_dict['prior']['prior_source']
prior_state_variable = yml_dict['prior'][prior_source]['state_variable']  # note: ['2d': xxx; '3d': xxx]
dum_lon_offset = yml_dict['prior'][prior_source]['dum_lon_offset']  # longitude offset
dir_prior = yml_dict['core']['prior_dir']
# One entry per prior ensemble member. os.listdir returns entries in arbitrary
# (filesystem-dependent) order; sorting makes the member <-> Xb column mapping
# reproducible across machines.
dir_prior_full = sorted(os.listdir(dir_prior))
prior_len = len(dir_prior_full)

nexp = yml_dict['core']['nexp']
data_period_id = proxy_cfg['data_period_id']
data_period_idstd = proxy_cfg['data_period_idstd']
recon_period = yml_dict['core']['recon_period']
recon_timescale = yml_dict['core']['recon_timescale_interval']
recon_period_full = np.arange(recon_period[0], recon_period[1] + 1, recon_timescale)
recon_period_len = recon_period_full.shape[0]
geologic_age = yml_dict['core']['geologic_age']
limit_hard_keys = list(yml_dict['prior'][prior_source]['limit_hard'].keys())

# save a copy of the config alongside the experiment outputs
config_save_name = dir_proxy_save_dir + dir_proxy_save + nexp + '.yml'
# shutil.copy instead of shelling out to `cp`: portable, safe for paths with
# spaces/shell metacharacters, and it raises instead of failing silently.
shutil.copy(config_name, config_save_name)
if log_level > 0:
    print('')
    print('The config (yml) file saved : ')
    print('')
    print(config_save_name)
    print('')
if log_level > 2:
    print('Set limit for {}'.format(limit_hard_keys))

nens = yml_dict['core']['nens']
save_ens_full = yml_dict['core']['save_ens_full']

# glassy d18O blacklist: labels in the 'Glassy' column to drop when
# proxy_d18o_glassy is enabled
proxy_d18o_glassy = yml_dict['proxies']['proxy_d18o_glassy']
proxy_assim3 = yml_dict['proxies'][yml_dict['proxies']['use_from'][0]]['proxy_assim3']
data_glassy_label_blacklist = proxy_assim3['Marine sediments_d18o_pooled_glassy']
# bayspar
search_tol_i = yml_dict['psm']['bayesreg_tex86']['search_tol']
nens_i = yml_dict['psm']['bayesreg_tex86']['nens']

# ========= dataset for plot =========
cGENIEGrid_path = yml_dict['core']['proj_dir'] + '/data_misc/cGENIEGrid.csv'
cGENIEGrid_df = pandas.read_csv(cGENIEGrid_path)
cGENIEGridB_lat36 = cGENIEGrid_df['lat']
cGENIEGridB_lon36 = cGENIEGrid_df['lon']
cGENIEGrid = cGENIEGrid_df.to_numpy()
if log_level > 2:
    print('>>  Load dataset for plot => Okay')

# ========= Monte Carlo =========
# Parameter lists swept by the nested loops of the DA section: localization
# radius, fraction of proxy sites assimilated, and observation-error scale.
local_rad_list = yml_dict['core']['local_rad_list']
locRadn = len(local_rad_list)
local_rad_list = np.asarray(local_rad_list)

proxy_frac_list = yml_dict['proxies']['proxy_frac']
proxy_fracn = len(proxy_frac_list)
proxy_frac_list = np.asarray(proxy_frac_list)
proxy_order_type = yml_dict['proxies'][yml_dict['proxies']['use_from'][0]]['proxy_order_type']

Rscale_list = yml_dict['core']['Rscale']
Rscalen = len(Rscale_list)
Rscale_list = np.asarray(Rscale_list)

if log_level > 1:
    print('>>  Prior member size: {}'.format(prior_len))
    print('>>  Recon_period {} - {}. '.format(recon_period[0], recon_period[1]))
    print('      List: {}'.format(recon_period_full))
    print('>>  Proxy error evaluation: {}'.format(proxy_err_eval))
    if log_level > 2:
        print('>>  Proxy full list:')
        print('      {}'.format(proxy_order))
        print('>>  Proxy blacklist:')
        print('      {}'.format(proxy_blacklist))
    print('>>  Proxy to be assimilated (some may not exist)')
    print('      {}'.format(proxy_list))
    print('>>  Proxy quality control selection: {}'.format(proxy_qc))

# Flag (1/0) whether any Mg/Ca proxy group is to be assimilated; it gates the
# preparation of the Mg/Ca PSM inputs (salinity, pH, omega).
mgca_proxy_groups = ('Marine sediments_mgca_pooled_bcp', 'Marine sediments_mgca_pooled_red')
data_psm_mgca_find = 1 if any(group in proxy_list for group in mgca_proxy_groups) else 0
if data_psm_mgca_find == 1 and log_level > 2:
    print('>>    Mg/Ca proxy found ')
    
if log_level > 1:
    print('')
    print('########## Read prior ######### ')
    print('')

########## Prior #########
# Flatten prior_state_variable into parallel lists of variable names and the
# '<sub-dir>/<file>.nc' path each variable is read from. For each of '2d' and
# '3d', the config holds an 'ncname' dict {sub-dir: file stem} plus, for every
# file stem, the list of variables read from that file.
prior_variable_dict = []  # 2d variable list
prior_nc_file_list = []  # 2d nc file list
prior_variable_dict_3d = []  # 3d variable list
prior_nc_file_list_3d = []  # 3d nc file list

variable_targets = {
    '2d': (prior_variable_dict, prior_nc_file_list),
    '3d': (prior_variable_dict_3d, prior_nc_file_list_3d),
}
for key, dim_cfg in prior_state_variable.items():
    if key not in variable_targets:
        continue  # only '2d' and '3d' entries describe state variables
    var_names, nc_files = variable_targets[key]
    for nc_dir, nc_stem in dim_cfg['ncname'].items():
        for var_name in dim_cfg[nc_stem]:
            var_names.append(var_name)
            nc_files.append(nc_dir + '/' + nc_stem + '.nc')

# prepare variable list for Xb
prior_variable2d_len = len(prior_variable_dict)
prior_variable3d_len = len(prior_variable_dict_3d)
if log_level > 2:
    print('>>  Number of 2d prior variables is: {}.'.format(prior_variable2d_len))
    if prior_variable2d_len > 0:
        print('      List:')
        for nc_file, var_name in zip(prior_nc_file_list, prior_variable_dict):
            print('        {}/{}'.format(nc_file, var_name))
    print('>>  Number of 3d prior variables is: {}'.format(prior_variable3d_len))
    if prior_variable3d_len > 0:
        print('      List:')
        for nc_file, var_name in zip(prior_nc_file_list_3d, prior_variable_dict_3d):
            print('        {}/{}'.format(nc_file, var_name))
   
# If there is no field in the model, convert model unit to proxy unit

if log_level > 2:
    print('>>  Reading prior state variables')
# Grid shape of the prior, taken from the first 3d variable of the first
# member at time index 0. If that cannot be read (no 3d variable configured,
# missing file or variable, ...), the hard-coded defaults are used:
# depth 16, horizontal 36 x 36.
# NOTE(review): the 3d reads in this file are commented "time-depth-lat-lon",
# so shape[1] would be latitude and shape[2] longitude despite the
# imax="lon" / jmax="lat" labels. This is harmless on a square 36 x 36 grid;
# confirm it before using a non-square grid.
dum_dmax, dum_imax, dum_jmax = 16, 36, 36
try:
    nc_path = dir_prior + '/' + dir_prior_full[0] + '/' + prior_nc_file_list_3d[0]
    with Dataset(nc_path) as nc:
        x1 = nc.variables[prior_variable_dict_3d[0]][0, :, :, :]
    dum_dmax = x1.shape[0]  # depth
    dum_imax = x1.shape[1]  # lon (see NOTE above)
    dum_jmax = x1.shape[2]  # lat (see NOTE above)
except Exception as err:
    if log_level > 2:
        print('>>  Warning: prior grid shape not read ({}); using default {} x {} x {}'.format(
            err, dum_dmax, dum_imax, dum_jmax))
# prepare 2d Xb for lon-lat state
dum_ijmax = dum_imax * dum_jmax  # lonn * latn
if log_level > 3:
    print('>>  Shape of dum_dmax {}, dum_imax {}, dum_jmax {}, dum_ijmax {}'.format(dum_dmax, dum_imax, dum_jmax, dum_ijmax))
# save units of each variable
prior_variable_units = list()
prior_variable_units_init = 0

# Background state matrices: one column per ensemble member; rows hold each
# variable's flattened grid, stacked variable by variable. Xb_shape is defined
# even without 2d variables because the Mg/Ca matrices below reuse it.
Xb_shape = (prior_variable2d_len * dum_jmax * dum_imax, prior_len)  # (varn * latn * lonn, members)
if prior_variable2d_len > 0:
    Xb = np.full(Xb_shape, np.nan)
# prep 3d version of Xb
if prior_variable3d_len > 0:
    Xb3d_shape = (prior_variable3d_len * dum_dmax * dum_jmax * dum_imax, prior_len)  # (varn * depthn * latn * lonn, members)
    Xb3d = np.full(Xb3d_shape, np.nan)
if log_level > 2:
    print('>>  Reading prior ...')
if data_psm_mgca_find == 1:
    if log_level > 3:
        print('>>  Prepare Mg/Ca related state variable ...')
    # for the Mg/Ca SST proxy: salinity, pH, omega (same layout as Xb)
    Xb_sal = np.full(Xb_shape, np.nan)
    Xb_ph = np.full(Xb_shape, np.nan)
    Xb_omega = np.full(Xb_shape, np.nan)
    spp = 'all'
    # cleaning flag per member: ``1`` for reductive, ``0`` for BCP (Barker).
    cleaningr = np.ones((prior_len, 1), dtype=int)
    cleaningb = np.zeros((prior_len, 1), dtype=int)
# Read the `units` attribute of every prior variable from the first member,
# 2d variables first and then 3d, appending to prior_variable_units in the same
# order as prior_variable_dict + prior_variable_dict_3d. If the file, the
# variable, or its units attribute cannot be read, the placeholder 'unit' is
# stored.
for nc_file, nc_field in zip(prior_nc_file_list + prior_nc_file_list_3d,
                             prior_variable_dict + prior_variable_dict_3d):
    nc_path = dir_prior + '/' + dir_prior_full[0] + '/' + nc_file
    try:
        with Dataset(nc_path) as nc:
            unit_j = nc.variables[nc_field].units
    except Exception:
        unit_j = 'unit'
    prior_variable_units.append(unit_j)
    
# loop for each member of an ensemble: column i of Xb / Xb3d (and of the Mg/Ca
# matrices) is filled from member directory dir_prior_full[i] at time index t.
if data_psm_mgca_find == 1:
    # Mg/Ca PSM file/field settings are the same for every member: read once.
    mgca_psm_cfg = yml_dict['psm']['bayesreg_mgca_pooled_red']
    water_saturation = mgca_psm_cfg['water_saturation']
    water_saturation_field = mgca_psm_cfg['water_saturation_field']
    psm_required_nc = mgca_psm_cfg['psm_required_nc']
    psm_required_nc_mg = mgca_psm_cfg['psm_required_nc_mg']

for i in range(prior_len):
    member_dir = dir_prior + '/' + dir_prior_full[i]
    if data_psm_mgca_find == 1:
        # NOTE(review): no '/' is inserted here, so psm_required_nc(_mg)
        # presumably start with '/'. Confirm against the config.
        name_nc_2d = member_dir + psm_required_nc
        name_nc_2d_mgca = member_dir + psm_required_nc_mg
        with Dataset(name_nc_2d) as nc:
            x00 = nc.variables['ocn_sur_sal'][t, :, :]  # time-lat-lon | surface water salinity
            x01 = nc.variables['misc_pH'][t, :, :]      # time-lat-lon | core top pH
        # 'surface' and 'bottom' both read the field named by
        # water_saturation_field (calcite saturation). Any other setting
        # leaves x02 unset; the try/except below then catches the NameError.
        if water_saturation in ['surface', 'bottom']:
            with Dataset(name_nc_2d_mgca) as nc:
                x02 = nc.variables[water_saturation_field][t, :, :]

    for j in range(prior_variable2d_len):
        # full directory of netcdf file
        name_nc_2d = member_dir + '/' + prior_nc_file_list[j]
        j0 = dum_ijmax * j
        j1 = dum_ijmax * (j + 1)
        nc_field = prior_variable_dict[j]
        with Dataset(name_nc_2d) as nc:
            x = nc.variables[nc_field][t, :, :]  # time-lat-lon
        # pCO2 from 1 to ppm
        if nc_field in ['atm_pCO2']:
            x = x * 1.0e+06
        Xb[j0:j1, i] = np.copy(x.reshape(dum_ijmax))  # var-lat-lon: Nx x 1

        if data_psm_mgca_find == 1:
            # The same Mg/Ca fields are written into every 2d-variable slot
            # (rows j0:j1), mirroring Xb's row layout.
            try:
                Xb_sal[j0:j1, i] = np.copy(x00.reshape(dum_ijmax))  # surface water salinity
                Xb_ph[j0:j1, i] = np.copy(x01.reshape(dum_ijmax))
                Xb_omega[j0:j1, i] = np.copy(x02.reshape(dum_ijmax))
            except Exception:
                # warn only once (first member)
                if i == 0 and log_level > 1:
                    print('>>  Warning: reading state variable error. ocn_sur_sal, misc_pH, carb_ohm_cal')
        # print the last member
        if log_level > 2 and i > prior_len - 2:
            print('    Last member: {}: {}: {}'.format(i, dir_prior_full[i], prior_variable_dict[j]))

    # if 3d variables are used. The index is deliberately not named `k`, so
    # the module-level surface-layer index k = 0 is not overwritten.
    for var3d in range(prior_variable3d_len):
        name_nc_3d = member_dir + '/' + prior_nc_file_list_3d[var3d]
        nc_field = prior_variable_dict_3d[var3d]
        k0 = dum_ijmax * dum_dmax * var3d
        k1 = dum_ijmax * dum_dmax * (var3d + 1)
        with Dataset(name_nc_3d) as nc:
            x = nc.variables[nc_field][t, :, :, :]  # time-depth-lat-lon
        Xb3d[k0:k1, i] = np.copy(x.reshape(dum_dmax * dum_ijmax))  # var-depth-lat-lon

# Mask netCDF fill values (>= 9.9692e+36) once, after all members are stored.
if prior_len > 0 and prior_variable2d_len > 0:
    Xb = np.ma.MaskedArray(Xb, Xb >= 9.9692e+36)
if prior_len > 0 and prior_variable3d_len > 0:
    Xb3d = np.ma.MaskedArray(Xb3d, Xb3d >= 9.9692e+36)

if log_level > 1:
    print('>>  Units of state variables {}: {}'.format(prior_variable_dict+prior_variable_dict_3d,prior_variable_units))

# Keep pristine copies of the prior. The Monte Carlo loop restores Xb, Xb3d and
# the Mg/Ca matrices from these *_prior arrays at the start of every iteration.
# NOTE(review): np.copy returns a plain ndarray (subok=False by default), so the
# fill-value mask applied to Xb/Xb3d is NOT carried over. The *_prior arrays
# hold the raw fill values (~9.97e36) as ordinary numbers. For the Mg/Ca
# matrices, the Monte Carlo loop replaces values > 3e36 with NaN by hand.
# Confirm that the other consumers of Xb handle fill values as intended.
# NOTE(review): Xb only exists when at least one 2d prior variable is
# configured, so the next line raises NameError for a 3d-only prior.
Xb_prior = np.copy(Xb)
if prior_variable3d_len > 0:
    Xb3d_prior = np.copy(Xb3d)

if data_psm_mgca_find == 1:
    # (this log message repeats the one printed when the matrices were allocated)
    if log_level > 3:
        print('>>  Prepare Mg/Ca related state variable ...')
    # for Mg/Ca SST proxy salinity, ph, omega
    Xb_sal_prior       = np.copy(Xb_sal)
    Xb_ph_prior        = np.copy(Xb_ph)
    Xb_omega_prior     = np.copy(Xb_omega)
    # ``1`` for reductive, ``0`` for BCP (Barker).
    cleaningr_prior = np.copy(cleaningr)
    cleaningb_prior = np.copy(cleaningb)
if log_level > 1:
    print('>>  Reading Prior => Okay')
    print('')
    print(' ########## Read proxies database ########## ')
    print('')
### read proxies database ###
proxies = pandas.read_csv(dir_proxy_data)
proxies_len0 = len(proxies)
if log_level > 3:
    print('>>  All proxy: {}'.format(proxies))
### keep only proxies belonging to the groups to assimilate (blacklist removed) ###
# A row is kept when its 'Proxy' type is listed in proxy_assim2[group] for at
# least one group in proxy_list. Selecting with a boolean mask replaces the
# row-by-row DataFrame.append (removed in pandas 2.0), avoids quadratic
# copying, and yields an empty frame (not a NameError) when nothing matches.
keep_proxy = np.array(
    [any(ptype in proxy_assim2[group] for group in proxy_list) for ptype in proxies['Proxy']],
    dtype=bool,
)
proxy_select0 = proxies.loc[keep_proxy].reset_index(drop=True)  # reset_index, avoid index error
if log_level > 3:
    for fname, ptype in zip(proxy_select0['File'], proxy_select0['Proxy']):
        print('>>    File {}, {} included'.format(fname, ptype))
proxies_select_len0 = len(proxy_select0)
if log_level > 1:
    print('>>  Proxy: selected proxy dataset number {}: those in blacklist removed!'.format(proxies_select_len0))

### check glassy only data or not
if proxy_d18o_glassy:
    # Drop rows whose 'Glassy' label is in the blacklist. A boolean mask
    # replaces the row-by-row DataFrame.append (removed in pandas 2.0), and an
    # empty frame results (not a NameError) when every row is dropped.
    keep_glassy = np.array(
        [label not in data_glassy_label_blacklist for label in proxy_select0['Glassy']],
        dtype=bool,
    )
    proxy_select = proxy_select0.loc[keep_glassy].reset_index(drop=True)  # reset_index, avoid index error

    if log_level > 3:
        print(proxy_select)
    proxies_select_len0 = len(proxy_select)
    if log_level > 1:
        print('>>  Proxy: selected proxy dataset number {}: those unknown/frosty removed!'.format(proxies_select_len0))
else:
    proxy_select = proxy_select0.copy()
    
#######     ########     #######     ########     #######     ########     #######     ######## 
#######                  OKAY, setting read, now let's DA                              ######## 
#######     ########     #######     ########     #######     ########     #######     ######## 

for locRadi in range(locRadn):
    locRad = local_rad_list[locRadi]
    if log_level > 2:
        print('')
        print('>>  Localization id {} radius distance {} km'.format(locRadi, locRad))
    if locRad is None:
        locRadv = 0 # for filename only
    else:
        locRadv = locRad
    for proxy_fraci in range(proxy_fracn):
        proxy_frac = proxy_frac_list[proxy_fraci]
        
        for Rscalei in range(Rscalen):
            Rscale = Rscale_list[Rscalei]
            #######     ########     #######     ########     #######     ########     #######     ########     
            if log_level > 1:
                print('')
                print('>>  Starting Monte Carlo ... ')
            #######     ########     #######     ########     #######     ########     #######     ########
            for MCi in range(MCn):
                # copy back:
                Xb = np.copy(Xb_prior)
                if prior_variable3d_len > 0:
                    Xb3d = np.copy(Xb3d_prior)
                if data_psm_mgca_find == 1:
                    Xb_sal       = np.copy(Xb_sal_prior)
                    Xb_ph        = np.copy(Xb_ph_prior)
                    Xb_omega     = np.copy(Xb_omega_prior)
                    # ``1`` for reductive, ``0`` for BCP (Barker).
                    cleaningr = np.copy(cleaningr_prior)
                    cleaningb = np.copy(cleaningb_prior)
                    # maybe used in Mg/Ca PSM
                    Xb_sal1 = np.copy(Xb_sal)
                    Xb_sal1[Xb_sal1> 3.0e+36] = np.nan
                    Xb_sal_mean = np.nanmean(Xb_sal1)
                    Xb_ph1 = np.copy(Xb_ph)
                    Xb_ph1[Xb_ph1> 3.0e+36] = np.nan
                    Xb_ph_mean = np.nanmean(Xb_ph1)
                    Xb_omega1 = np.copy(Xb_omega)
                    Xb_omega1[Xb_omega1> 3.0e+36] = np.nan
                    Xb_omega_mean = np.nanmean(Xb_omega1)
                    if log_level > 3:
                        print('')
                        print('>>    mean of Xb_sal {}, Xb_ph {}, Xb_omega {}'.format(Xb_sal_mean, Xb_ph_mean, Xb_omega_mean))
                
                ### Select a fraction of proxy sites ###
                if proxy_frac <= 1.0:
                    if log_level > 2:
                        print('')
                        print('>>  Proxy fraction is {}'.format(proxy_frac))
                    ## Seed the `random` module with this Monte Carlo member's seed.
                    ## Presumably proxy_frac_4da_eval draws from `random`, which would
                    ## make the site split reproducible (TODO confirm).
                    curr_seed = multi_seed[MCi]
                    if log_level > 2:
                        print('Setting current prior iteration seed: {}'.format(curr_seed))
                    random.seed(curr_seed)
                    sites_assim, sites_eval = DeepDA_psm.proxy_frac_4da_eval(proxy_select,proxy_frac,log_level)
                else:
                    # fraction > 1: assimilate every selected site, withhold none
                    sites_assim = proxy_select.copy()
                    sites_eval = []
                ###
                if log_level > 3:
                    print('>>  Selected proxy sties: ')
                    print(sites_assim)
                    print('>>  Un-selected proxy sties: ')
                    print(sites_eval)                
                
                ### sort proxy data using the given order ###
                proxies_frac_len = len(sites_assim)
                if proxy_order_type in ['use_list']:
                    # Collect row positions group by group, following proxy_order. A
                    # site whose type appears in several groups is included once per
                    # group, as in the original row-by-row append. The positional
                    # selection is then re-indexed 0..n-1, which also holds for a
                    # single row and an empty selection.
                    site_types = sites_assim['Proxy'].to_numpy()
                    order_rows = [
                        j
                        for group in proxy_order
                        for j in range(proxies_frac_len)
                        if site_types[j] in proxy_assim2[group]
                    ]
                    proxy_select_sort = sites_assim.iloc[order_rows].reset_index(drop=True)
                    if log_level > 2:
                        print('>>  Proxy order: use user-defined list.')
                else:
                    # NOTE(review): sample() without random_state draws from numpy's
                    # global RNG, which random.seed() does not seed, so this order is
                    # not reproducible from multi_seed.
                    proxy_select_sort = sites_assim.sample(frac=1).reset_index(drop=True)
                    if log_level > 2:
                        print('>>  Proxy order: use random list.')
                
                ### update proxies using sorted proxy order ###
                proxies =   proxy_select_sort.copy()
                proxies_len = len(proxies)

                if proxies_len0 > proxies_len:
                    if log_level > 1:
                        print('>>  Selected proxy data length {}'.format(proxies_len))

                ######## Ye   ########
                # for saving proxy unit data Ye
                Ye       = np.full((proxies_len,prior_len),np.nan)
                obvalue  = np.full((proxies_len,recon_period_len),np.nan)
                ob_err   = np.full((proxies_len,recon_period_len),np.nan) # data obs error
                ob_err0  = np.full((proxies_len,recon_period_len),np.nan) # PSM obs error
                ob_err_comb  = np.full((proxies_len,recon_period_len),np.nan) # PSM obs error
                yo_all = np.full((proxies_len,2),np.nan) # PSM obs error
                if log_level > 2:
                    print('>>  OKAY.')
                    print('')
                # check the consistency of the config.yml file and proxy database
                # AND get obs R
                if log_level > 2:
                    print('########## Check the consistency of the config.yml file and proxy database ##########')
                    print('')
                
                proxy_psm_type_dict = {}
                for j in range(proxies_len):
                    # Read proxy type from the database
                    data_psm_type = proxies['Proxy'][j]
                    # Read allowed proxy from the DTDA-config.yml
                    data_psm_type_find = 0
                    for key, value in proxy_assim2.items():
                        if log_level > 5:
                            print(key,value)
                        # check this proxy type exist or not, and how many times it occurrs
                        if data_psm_type in proxy_assim2[key]:
                            data_psm_type_find = data_psm_type_find + 1
                    if data_psm_type_find == 1:
                        for key, value in proxy_psm_type.items():
                            if data_psm_type in proxy_assim2[key]:
                                data_psm_key = key
                        proxy_psm_type_i = proxy_psm_type[data_psm_key]

                        proxy_psm_type_dict[j] =proxy_psm_type_i

                        if log_level > 3:
                            print('>>  {}. PSM for {} is {}'.format(j, data_psm_type,proxy_psm_type_i))

                    elif data_psm_type_find == 0:
                        if log_level > 3:
                            print('>>  Warning, {} in database is not find in config.yml dictionary'.format(data_psm_type))
                    else:
                        if log_level > 3:
                            print('>>  Warning, {} in database appears more than one time in config.yml dictionary'.format(data_psm_type))

                    # Now PSM type has been found. Let's precal Ye

                    if proxy_psm_type_i in ['bayesreg_mgca_pooled_red','bayesreg_mgca_pooled_bcp','deepmip_mgca']:
                        data_psm_mgca_find = 1

                if log_level > 3:
                    print('>>  Proxy_psm_type_dict: ')
                    print(proxy_psm_type_dict)
                if log_level > 2:
                    print('')
                    print('>>  All looks good.')
                    print('')

                    ##### Ye calculation ####

                    print('##########  Ye calculation  ##########')
                    print('')
                # precal_Ye
                proi = 0
                for j in range(proxies_len):
                    # Read proxy type from the database
                    data_psm_type = proxies['Proxy'][j]
                    proxy_psm_type_i = proxy_psm_type_dict[j]
                    psm_required_variable_key = list(yml_dict['psm'][proxy_psm_type_i]['psm_required_variables'].keys())[0]
                    if log_level > 3:
                        print(psm_required_variable_key)
                    # ID-ID match: proxy type matches with the prior type. This allows assimilate multiple proxy types for multiple state variables
                    if psm_required_variable_key in prior_variable_dict:
                        psm_required_variable_key_index = prior_variable_dict.index(psm_required_variable_key)
                    elif psm_required_variable_key in prior_variable_dict_3d:
                        psm_required_variable_key_index = prior_variable_dict_3d.index(psm_required_variable_key)

                ######################## FOR 2D field ONLY TO DO: adjusted to include 3d proxies ##############
                
                    # read lon lat for each line of proxy
                    dum_lat = proxies[lat_label][j]  # (paleo)latitude of this site
                    dum_lon = proxies[lon_label][j]  # (paleo)longitude of this site
                    yo_all[proi,:] = np.array([dum_lon, dum_lat])  # save location of this site

                    lonlat = modules_nc.cal_find_ij(dum_lon,dum_lat,dum_lon_offset,dum_imax,dum_jmax)
                    # output [lon, lat], 
                    # lon ranges from 0 (-180) to 35 (180), lat ranges from 0 (-90) to 35 (90)

                    Filei = proxies['File'][j]
                    # find 1d grid location
                    lonlati = lonlat[1] * dum_jmax + lonlat[0] + psm_required_variable_key_index * dum_ijmax
                    if log_level > 3:
                        print('>>  lonlat id is {}'.format(lonlati))
                    # read prior
                    prior_1grid = np.copy(Xb[lonlati,:])   # prior
                    #print(prior_1grid.shape)
                    if log_level > 3:
                        print(prior_1grid)
                    
                ######################## FOR 2D field ONLY. TO DO: adjusted to include 3d proxies ##############
                    if log_level > 2:
                        print('')
                        print('>>  {}. File: {}, grid [lon lat] {}, index {}, PSM for {} is {}'.format(j,Filei,lonlat,lonlati,data_psm_type,proxy_psm_type_i))
                    if log_level > 3:
                        if psm_required_variable_key in prior_variable_dict:                     
                            print('>>    Key Found: {} in prior_variable_dict 2d list, index = {}'.format(psm_required_variable_key, psm_required_variable_key_index))
                        elif psm_required_variable_key in prior_variable_dict_3d:
                            print('>>    Key Found: {} in prior_variable_dict_3d list, index = {}'.format(psm_required_variable_key, psm_required_variable_key_index))
                    if log_level > 3:
                        print('>>      Mean of Prior is {:.6f}, variance is {:.6f}'.format(np.mean(prior_1grid), np.var(prior_1grid)))

                    # Now PSM type has been found. Let's precal Ye

                    if proxy_psm_type_i in ['bayesreg_d18o_pooled']:
                        # Foraminiferal d18O PSM using bayfox.predict_d18oc.
                        # Step 1: local seawater d18O for this site, either a latitude-based
                        # estimate (zachos94) or a proxy-table column chosen by the iCESM pCO2
                        # level (1x/3x/6x/9x; any other value falls back to the 3x column).
                        if d18osw_local_choice in ['zachos94']:
                            # d18o_localsw using method by Zachos et al., 1994 PALEOCEANOGRAPHY
                            d18o_localsw = DeepDA_psm.d18o_localsw(abs(dum_lat))
                        else:
                            if d18osw_icesm_pco2 == 1.0:
                                proxy_col_d18osw = 'd18osw_1x'
                            elif d18osw_icesm_pco2 == 3.0:
                                proxy_col_d18osw = 'd18osw_3x'
                            elif d18osw_icesm_pco2 == 6.0:
                                proxy_col_d18osw = 'd18osw_6x'
                            elif d18osw_icesm_pco2 == 9.0:
                                proxy_col_d18osw = 'd18osw_9x'
                            else:
                                proxy_col_d18osw = 'd18osw_3x'
                            d18o_localsw = proxies[proxy_col_d18osw][j]
                            
                        # total d18osw = d18o_localsw + d18o_adj + psm_d18osw_adjust
                        # d18o_adj has been included in the bayfox model
                        # (psm_d18osw_adjust is only added for the zachos94 option)
                        #print('>>  Prior is {}'.format(prior_1grid))
                        if d18osw_local_choice in ['zachos94']:
                            prediction_d18O = bayfox.predict_d18oc(prior_1grid,d18o_localsw + psm_d18osw_adjust) # pool model for bayfox
                            if log_level > 3:
                                print('>>        Sea water d18O is {:.6f}, d18osw_adjust is {:.6f} '.format(d18o_localsw, psm_d18osw_adjust))
                        else:
                            prediction_d18O = bayfox.predict_d18oc(prior_1grid,d18o_localsw) # pool model for bayfox
                            if log_level > 3:
                                print('>>        Sea water d18O is {:.6f}'.format(d18o_localsw))
                        if log_level > 4:
                            print('>>  prediction_d18O.ensemble shape {}'.format(prediction_d18O.ensemble.shape))
                        
                        # Ye row for this proxy: the bayfox ensemble averaged along axis 1
                        Ye[proi,:] = np.mean(prediction_d18O.ensemble, axis = 1)
                        if log_level > 4:
                            print('>>  Ye is {}'.format(Ye[proi,:]))
                        if log_level > 3:
                            print('>>      Mean of  Ye  is {:.6f}, variance is {:.6f} '.format(np.mean(Ye[proi,:]), np.var(Ye[proi,:],ddof=1)))
                        # One observation / error entry per reconstruction period
                        for reconi in range(recon_period_len):
                            reconid = recon_period_full[reconi]
                            # observation value and squared per-period std from the proxy table
                            obvalue[proi,reconi] = proxies[data_period_id[reconid]][j]
                            ob_err[proi,reconi] = proxies[data_period_idstd[reconid]][j] ** 2
                            # PSM variance (scaled by Rscale): obs_estimate_r_d18o for
                            # 'proxy_err_psm', otherwise obs_estimate_r_fixed_d18o(15)
                            if proxy_err_eval in ['proxy_err_psm']:
                                if d18osw_local_choice in ['zachos94']:
                                    ob_err0[proi,reconi]= DeepDA_psm.obs_estimate_r_d18o(obvalue[proi,reconi], d18o_localsw+psm_d18osw_adjust) * Rscale
                                else:
                                    ob_err0[proi,reconi]= DeepDA_psm.obs_estimate_r_d18o(obvalue[proi,reconi], d18o_localsw) * Rscale
                            else:
                                ob_err0[proi,reconi]= DeepDA_psm.obs_estimate_r_fixed_d18o(15) * Rscale
                            # combined variance; a zero total is treated as missing (NaN)
                            ob_err_comb[proi,reconi] = np.nansum([ob_err[proi,reconi], ob_err0[proi,reconi]])
                            if ob_err_comb[proi,reconi] == 0: ob_err_comb[proi,reconi] = np.nan
                            if log_level > 3:
                                print('>>   {}. Proxy variance from PSM is {:.6f} vs. from PSM + time variance is {:.6f} '.format(reconi,ob_err0[proi,reconi], ob_err_comb[proi,reconi]))

                            # Quality control: against the combined variance for
                            # 'proxy_err_psm', against the PSM variance otherwise
                            if log_level > 1:
                                if proxy_qc is not None:
                                    print('>>   Quality Control (QC) ...')
                            if proxy_err_eval in ['proxy_err_psm']:
                                qc_i = DeepDA_psm.obs_qc(Ye[proi,:], obvalue[proi,reconi], ob_err_comb[proi,reconi], proxy_qc)
                            else:
                                qc_i = DeepDA_psm.obs_qc(Ye[proi,:], obvalue[proi,reconi], ob_err0[proi,reconi], proxy_qc)
                            #print(qc_i)
                            if qc_i:
                                if proxy_qc is not None:
                                    if log_level > 1:
                                        print('    Pass QC. ye {}, obs {}, obs_var {}, qc {}'.format(np.mean(Ye[proi,:]), obvalue[proi,reconi], ob_err_comb[proi,reconi], proxy_qc))
                            else:
                                # rejected: the variance is set to NaN (presumably excluded downstream)
                                ob_err_comb[proi,reconi] = np.nan
                                if proxy_qc is not None:   
                                    if log_level > 1:
                                        print('    Failed QC. ye {}, obs {}, obs_var {}, qc {}'.format(np.mean(Ye[proi,:]), obvalue[proi,reconi], ob_err_comb[proi,reconi], proxy_qc))
                        proi = proi + 1  # increment to the next proxy row
                        
                        
                    elif proxy_psm_type_i in ['deepmip_d18o']:
                        if d18osw_local_choice in ['zachos94']:
                            # d18o_localsw using method by Zachos et al., 1994 PALEOCEANOGRAPHY
                            d18o_localsw = DeepDA_psm.d18o_localsw(abs(dum_lat))
                        else:
                            if d18osw_icesm_pco2 == 1.0:
                                proxy_col_d18osw = 'd18osw_1x'
                            elif d18osw_icesm_pco2 == 3.0:
                                proxy_col_d18osw = 'd18osw_3x'
                            elif d18osw_icesm_pco2 == 6.0:
                                proxy_col_d18osw = 'd18osw_6x'
                            elif d18osw_icesm_pco2 == 9.0:
                                proxy_col_d18osw = 'd18osw_9x'
                            else:
                                proxy_col_d18osw = 'd18osw_3x'
                            d18o_localsw = proxies[proxy_col_d18osw][j]
                            
                        # total d18osw = d18o_localsw + d18o_adj + psm_d18osw_adjust
                        # d18o_adj has been included in the bayfox model
                        #print('>>  Prior is {}'.format(prior_1grid))
                        
                        if d18osw_local_choice in ['zachos94']:
                            Ye[proi,:] = DeepDA_psm.d18oc_linear_forward(prior_1grid,d18o_localsw + psm_d18osw_adjust)
                            #prediction_d18O = bayfox.predict_d18oc(prior_1grid,d18o_localsw + psm_d18osw_adjust) # pool model for bayfox
                            if log_level > 4:
                                print('>>        Sea water d18O is {:.6f}, d18osw_adjust is {:.6f} '.format(d18o_localsw, psm_d18osw_adjust))
                        else:
                            Ye[proi,:] = DeepDA_psm.d18oc_linear_forward(prior_1grid,d18o_localsw)
                            #prediction_d18O = bayfox.predict_d18oc(prior_1grid,d18o_localsw) # pool model for bayfox
                            if log_level > 4:
                                print('>>        Sea water d18O is {:.6f}'.format(d18o_localsw))
                        if log_level > 4:
                            print('>>  prediction_d18O.ensemble shape {}'.format(prediction_d18O.ensemble.shape))
                        
                        if log_level > 4:
                            print('>>  Ye is {}'.format(Ye[proi,:]))
                        if log_level > 3:
                            print('>>      Mean of  Ye  is {:.6f}, variance is {:.6f} '.format(np.mean(Ye[proi,:]), np.var(Ye[proi,:],ddof=1)))
                        for reconi in range(recon_period_len):
                            reconid = recon_period_full[reconi]
                            obvalue[proi,reconi] = proxies[data_period_id[reconid]][j]
                            ob_err[proi,reconi] = proxies[data_period_idstd[reconid]][j] ** 2
                            if proxy_err_eval in ['proxy_err_psm']:
                                if d18osw_local_choice in ['zachos94']:
                                    ob_err0[proi,reconi]= DeepDA_psm.obs_estimate_r_d18o(obvalue[proi,reconi], d18o_localsw+psm_d18osw_adjust) * Rscale
                                else:
                                    ob_err0[proi,reconi]= DeepDA_psm.obs_estimate_r_d18o(obvalue[proi,reconi], d18o_localsw) * Rscale
                                    
                            elif proxy_err_eval in ['proxy_err_psm_fixed']:
                                ob_err0[proi,reconi]= DeepDA_psm.obs_estimate_r_fixed_d18o(15) * Rscale
                            else:
                                ob_err0[proi,reconi] = yml_dict['psm'][proxy_psm_type_i]['psm_error'] * Rscale
                                
                            ob_err_comb[proi,reconi] = np.nansum([ob_err[proi,reconi], ob_err0[proi,reconi]])
                            if ob_err_comb[proi,reconi] == 0: ob_err_comb[proi,reconi] = np.nan
                            if log_level > 3:
                                print('>>   {}. Proxy variance from PSM is {:.6f} vs. from PSM + time variance is {:.6f} '.format(reconi,ob_err0[proi,reconi], ob_err_comb[proi,reconi]))

                            # Quality control
                            if log_level > 1:
                                if proxy_qc is not None:
                                    print('>>   Quality Control (QC) ...')
                            if proxy_err_eval in ['proxy_err_psm']:
                                qc_i = DeepDA_psm.obs_qc(Ye[proi,:], obvalue[proi,reconi], ob_err_comb[proi,reconi], proxy_qc)
                            elif proxy_err_eval in ['proxy_err_psm_fixed']:
                                qc_i = DeepDA_psm.obs_qc(Ye[proi,:], obvalue[proi,reconi], ob_err0[proi,reconi], proxy_qc)
                            else:
                                qc_i = DeepDA_psm.obs_qc(Ye[proi,:], obvalue[proi,reconi], ob_err0[proi,reconi], proxy_qc)
                            #print(qc_i)
                            if qc_i:
                                if proxy_qc is not None:
                                    if log_level > 1:
                                        print('    Pass QC. ye {}, obs {}, obs_var {}, qc {}'.format(np.mean(Ye[proi,:]), obvalue[proi,reconi], ob_err_comb[proi,reconi], proxy_qc))
                            else:
                                ob_err_comb[proi,reconi] = np.nan
                                if proxy_qc is not None:      
                                    if log_level > 1:
                                        print('    Failed QC. ye {}, obs {}, obs_var {}, qc {}'.format(np.mean(Ye[proi,:]), obvalue[proi,reconi], ob_err_comb[proi,reconi], proxy_qc))
                        proi = proi + 1  # increasement
                        
                    elif proxy_psm_type_i in ['cgenie_caco3', 'cgenie_caco3_13c']:
                        Ye[proi,:] = np.mean(prior_1grid)
                        for reconi in range(recon_period_len):
                            reconid = recon_period_full[reconi]
                            obvalue[proi,reconi] = proxies[data_period_id[reconid]][j]
                            ob_err[proi,reconi] = proxies[data_period_idstd[reconid]][j] ** 2
                            ob_err0[proi,reconi] = yml_dict['psm'][proxy_psm_type_i]['psm_error'] * Rscale
                            ob_err_comb[proi,reconi] = np.nansum([ob_err[proi,reconi], ob_err0[proi,reconi]])
                            # Quality control
                            if proxy_err_eval in ['proxy_err_psm']:
                                qc_i = DeepDA_psm.obs_qc(Ye[proi,:], obvalue[proi,reconi], ob_err_comb[proi,reconi], proxy_qc)
                            else:
                                qc_i = DeepDA_psm.obs_qc(Ye[proi,:], obvalue[proi,reconi], ob_err0[proi,reconi], proxy_qc)
                            if qc_i:
                                if proxy_qc is not None:
                                    if log_level > 1:
                                        print('    Pass QC. ye {}, obs {}, obs_var {}, qc {}'.format(np.mean(Ye[proi,:]), obvalue[proi,reconi], ob_err_comb[proi,reconi], proxy_qc))
                            else:
                                ob_err_comb[proi,reconi] = np.nan
                                if proxy_qc is not None: 
                                    if log_level > 1:
                                        print('    Failed QC. ye {}, obs {}, obs_var {}, qc {}'.format(np.mean(Ye[proi,:]), obvalue[proi,reconi], ob_err_comb[proi,reconi], proxy_qc))
                        proi = proi + 1  # increasement
                        
                    elif proxy_psm_type_i in ['bayesreg_tex86']:
                        # bayspar
                        #try:
                        prediction = bayspar.predict_tex_analog(prior_1grid, temptype = 'sst', search_tol = search_tol_i, nens=nens_i)
                        Ye[proi,:] = np.mean(prediction.ensemble, axis = 1)
                        if log_level > 3:
                            print('>>      Mean of  Ye   is {:.6f}, variance is {:.6f} '.format(np.mean(Ye[proi,:]), np.var(Ye[proi,:],ddof=1)))
                        for reconi in range(recon_period_len):
                            reconid = recon_period_full[reconi]
                            obvalue[proi,reconi] = proxies[data_period_id[reconid]][j]
                            ob_err[proi,reconi] = proxies[data_period_idstd[reconid]][j] ** 2
                            if proxy_err_eval in ['proxy_err_psm']:
                                ob_err0[proi,reconi]= DeepDA_psm.obs_estimate_r_tex86(np.array([31]), 'sst', 15)  * Rscale
                            else:
                                ob_err0[proi,reconi]= DeepDA_psm.obs_estimate_r_fixed_tex86(31)  * Rscale
                            #obvalue[proi,] = proxies['Lat'][j]
                            ob_err_comb[proi,reconi] = np.nansum([ob_err[proi,reconi], ob_err0[proi,reconi]])
                            if ob_err_comb[proi,reconi] == 0: ob_err_comb[proi,reconi] = np.nan
                            if log_level > 3:
                                print('>>   {}. Proxy variance from PSM is {:.6f}, from PSM and selected interval is {:.6f} '.format(reconi,ob_err0[proi,reconi], ob_err_comb[proi,reconi]))
                            # Quality control
                            if proxy_err_eval in ['proxy_err_psm']:
                                qc_i = DeepDA_psm.obs_qc(Ye[proi,:], obvalue[proi,reconi], ob_err_comb[proi,reconi], proxy_qc)
                            else:
                                qc_i = DeepDA_psm.obs_qc(Ye[proi,:], obvalue[proi,reconi], ob_err0[proi,reconi], proxy_qc)
                            if qc_i:
                                if proxy_qc is not None:
                                    if log_level > 1:
                                        print('    Pass QC. ye {}, obs {}, obs_var {}, qc {}'.format(np.mean(Ye[proi,:]), obvalue[proi,reconi], ob_err_comb[proi,reconi], proxy_qc))
                            else:
                                ob_err_comb[proi,reconi] = np.nan
                                if proxy_qc is not None: 
                                    if log_level > 1:
                                        print('    Failed QC. ye {}, obs {}, obs_var {}, qc {}'.format(np.mean(Ye[proi,:]), obvalue[proi,reconi], ob_err_comb[proi,reconi], proxy_qc))
                        proi = proi + 1  # increasement
                        #except:
                        #    if log_level > 2:
                        #        print('>>  Warning {}'.format(proxy_psm_type_i))
                        #        print('>>  search_tol too small for {}: mean sst is {}'.format(j, np.mean(prior_1grid)))
                        
                    elif proxy_psm_type_i in ['tex86h_forward']:
                        
                        Ye[proi,:] = DeepDA_psm.tex86h_forward(prior_1grid)
                        if log_level > 3:
                            print('>>      Mean of  Ye   is {:.6f}, variance is {:.6f} '.format(np.mean(Ye[proi,:]), np.var(Ye[proi,:],ddof=1)))
                        for reconi in range(recon_period_len):
                            reconid = recon_period_full[reconi]
                            obvalue[proi,reconi] = proxies[data_period_id[reconid]][j]
                            ob_err[proi,reconi] = proxies[data_period_idstd[reconid]][j] ** 2
                            if proxy_err_eval in ['proxy_err_psm']:
                                ob_err0[proi,reconi]= DeepDA_psm.obs_estimate_r_tex86(np.array([31]), 'sst', 15)  * Rscale
                            elif proxy_err_eval in ['proxy_err_psm_fixed']:
                                ob_err0[proi,reconi]= DeepDA_psm.obs_estimate_r_fixed_tex86(31)  * Rscale
                            else:
                                ob_err0[proi,reconi] = yml_dict['psm'][proxy_psm_type_i]['psm_error'] * Rscale
                            #obvalue[proi,] = proxies['Lat'][j]
                            ob_err_comb[proi,reconi] = np.nansum([ob_err[proi,reconi], ob_err0[proi,reconi]])
                            if ob_err_comb[proi,reconi] == 0: ob_err_comb[proi,reconi] = np.nan
                            if log_level > 3:
                                print('>>   {}. Proxy variance from PSM is {:.6f}, from PSM and selected interval is {:.6f} '.format(reconi,ob_err0[proi,reconi], ob_err_comb[proi,reconi]))
                            # Quality control
                            if proxy_err_eval in ['proxy_err_psm']:
                                qc_i = DeepDA_psm.obs_qc(Ye[proi,:], obvalue[proi,reconi], ob_err_comb[proi,reconi], proxy_qc)
                            else:
                                qc_i = DeepDA_psm.obs_qc(Ye[proi,:], obvalue[proi,reconi], ob_err0[proi,reconi], proxy_qc)
                            if qc_i:
                                if proxy_qc is not None:
                                    if log_level > 1:
                                        print('    Pass QC. ye {}, obs {}, obs_var {}, qc {}'.format(np.mean(Ye[proi,:]), obvalue[proi,reconi], ob_err_comb[proi,reconi], proxy_qc))
                            else:
                                ob_err_comb[proi,reconi] = np.nan
                                if proxy_qc is not None:      
                                    if log_level > 1:
                                        print('    Failed QC. ye {}, obs {}, obs_var {}, qc {}'.format(np.mean(Ye[proi,:]), obvalue[proi,reconi], ob_err_comb[proi,reconi], proxy_qc))
                        proi = proi + 1  # increasement

                    #elif proxy_psm_type_i in ['bayesreg_uk37']:
                        # 
                        #print('... bayesreg_uk37: To be done ...')
                        
                    elif proxy_psm_type_i in ['deepmip_mgca']:
                        salinity =  np.copy(Xb_sal[lonlati,:])
                        ph       =  np.copy(Xb_ph[lonlati,:])
                        mgcasw = yml_dict['psm'][proxy_psm_type_i]['mgcasw']
                        mgcacorr = DeepDA_psm.mgca_evans18_forward(prior_1grid,ph,mgcasw)
                        Ye[proi,:] = DeepDA_psm.mgca_sal_corr_forward(mgcacorr,salinity)
                        
                        for reconi in range(recon_period_len):
                            reconid = recon_period_full[reconi]
                            obvalue[proi,reconi] = proxies[data_period_id[reconid]][j]
                            ob_err[proi,reconi]  = proxies[data_period_idstd[reconid]][j] ** 2
                            if proxy_err_eval in ['proxy_err_psm', 'proxy_err_psm_fixed']:
                                clearning_one = cleaningb  # use barker cleaning model
                                if log_level > 3:
                                    print('>>    mean of Xb_sal {}, Xb_ph {}, Xb_omega {}, cleaning {}'.format(Xb_sal_mean, Xb_ph_mean, Xb_omega_mean, clearning_one[0]))
                            if proxy_err_eval in ['proxy_err_psm']:
                                ob_err0[proi,reconi] = DeepDA_psm.obs_estimate_r_mgca_pooled(obvalue[proi,reconi], clearning_one[0], np.mean(salinity), np.mean(ph), np.mean(omega), spp, geologic_age) * Rscale
                            elif proxy_err_eval in ['proxy_err_psm_fixed']:
                                ob_err0[proi,reconi] = DeepDA_psm.obs_estimate_r_fixed_mgca_pooled((15, 16), clearning_one[0], Xb_sal_mean, Xb_ph_mean, Xb_omega_mean, spp, geologic_age) * Rscale
                            else:
                                ob_err0[proi,reconi] = yml_dict['psm'][proxy_psm_type_i]['psm_error'] * Rscale
                                
                            ob_err_comb[proi,reconi] = np.nansum([ob_err[proi,reconi], ob_err0[proi,reconi]])
                            if ob_err_comb[proi,reconi] == 0: ob_err_comb[proi,reconi] = np.nan
                            if log_level > 3:
                                print('>>   {}. Proxy variance from PSM is {:.6f}, from PSM and selected interval is {:.6f} '.format(reconi,ob_err0[proi,reconi], ob_err_comb[proi,reconi]))
                            # Quality control
                            if proxy_err_eval in ['proxy_err_psm']:
                                qc_i = DeepDA_psm.obs_qc(Ye[proi,:], obvalue[proi,reconi], ob_err_comb[proi,reconi], proxy_qc)
                            else:
                                qc_i = DeepDA_psm.obs_qc(Ye[proi,:], obvalue[proi,reconi], ob_err0[proi,reconi], proxy_qc)
                            if qc_i:
                                if proxy_qc is not None:
                                    if log_level > 1:
                                        print('      Pass QC. ye {}, obs {}, obs_var {}, qc {}'.format(np.mean(Ye[proi,:]), obvalue[proi,reconi], ob_err_comb[proi,reconi], proxy_qc))
                            else:
                                ob_err_comb[proi,reconi] = np.nan
                                if proxy_qc is not None:  
                                    if log_level > 1:
                                        print('      Failed QC. ye {}, obs {}, obs_var {}, qc {}'.format(np.mean(Ye[proi,:]), obvalue[proi,reconi], ob_err_comb[proi,reconi], proxy_qc))
                        proi = proi + 1  # increasement

                    elif proxy_psm_type_i in ['bayesreg_mgca_pooled_red', 'bayesreg_mgca_pooled_bcp']:
                        # BAYMAG pooled Mg/Ca PSM. The cleaning method follows the PSM type:
                        # 'reductive' (cleaningr) or 'barker' (cleaningb).
                        if proxy_psm_type_i in ['bayesreg_mgca_pooled_red']:
                            clearning_one = cleaningr
                            proxy_explain = 'reductive'
                        elif proxy_psm_type_i in ['bayesreg_mgca_pooled_bcp']:
                            clearning_one = cleaningb
                            proxy_explain = 'barker'
                        #try:
                        # prior_1grid = np.copy(Xb[lonlati,:])   # prior
                        # salinity / pH / omega of the prior at this grid cell
                        salinity =  np.copy(Xb_sal[lonlati,:])#.reshape((prior_len,1))
                        ph       =  np.copy(Xb_ph[lonlati,:])#.reshape((prior_len,1))
                        omega    =  np.copy(Xb_omega[lonlati,:])#.reshape((prior_len,1))
                        prior_1grid = prior_1grid  # no-op self-assignment
                        if log_level > 3:
                            print('>>    mean of Xb_sal {}, Xb_ph {}, Xb_omega {}, cleaning {}'.format(Xb_sal_mean, Xb_ph_mean, Xb_omega_mean, clearning_one[0]))
                            
                            # shape diagnostics, only for the first Monte Carlo draw
                            if log_level > 4:
                                if MCi == 0:
                                    print('shape of prior_1grid {}, clearning_one {}, salinity {}, ph {}, omega {}'
                                          .format(prior_1grid.shape, clearning_one.shape, salinity.shape, ph.shape,omega.shape))
                        prediction_mgca = baymag.predict_mgca(prior_1grid, clearning_one[0], salinity, ph, omega, spp) # pooled baymag model with the cleaning chosen above
                        #prediction_mgca = baymag.predict_mgca(prior_1grid, cleaningr, salinity, ph, omega, spp) # pool model for baymag reductive
                        # seawater Mg/Ca correction for geologic_age
                        pred_mgca_adj = baymag.sw_correction(prediction_mgca, np.array([geologic_age]))
                        if log_level > 4:
                            if MCi == 0: 
                                print(prediction_mgca.ensemble.shape)
                                print(pred_mgca_adj.ensemble.shape)
                        # Ye row for this proxy: corrected ensemble averaged along axis 1
                        Ye[proi,:] = np.mean(pred_mgca_adj.ensemble, axis = 1)
                        if log_level > 3:
                            print('>>      Mean of  Ye   is {:.6f}, variance is {:.6f} '.format(np.mean(Ye[proi,:]), np.var(Ye[proi,:],ddof=1)))

                        for reconi in range(recon_period_len):
                            reconid = recon_period_full[reconi]
                            # observation value and squared per-period std from the proxy table
                            obvalue[proi,reconi] = proxies[data_period_id[reconid]][j]
                            ob_err[proi,reconi]  = proxies[data_period_idstd[reconid]][j] ** 2
                            #obs_estimate_r_mgca_pooled(obs, cleaning, salinity, ph, omega, spp, age):
                            # PSM variance scaled by Rscale: per-site means for 'proxy_err_psm',
                            # otherwise the fixed estimate using the *_mean values
                            if proxy_err_eval in ['proxy_err_psm']:
                                ob_err0[proi,reconi] = DeepDA_psm.obs_estimate_r_mgca_pooled(obvalue[proi,reconi], clearning_one[0], np.mean(salinity), np.mean(ph), np.mean(omega), spp, geologic_age) * Rscale
                            else:
                                #ob_err0[proi,reconi] = DeepDA_psm.obs_estimate_r_fixed_mgca_pooled((15, 16), clearning_one[0], np.mean(salinity), np.mean(ph), np.mean(omega), spp, geologic_age)
                                ob_err0[proi,reconi] = DeepDA_psm.obs_estimate_r_fixed_mgca_pooled((15, 16), clearning_one[0], Xb_sal_mean, Xb_ph_mean, Xb_omega_mean, spp, geologic_age) * Rscale
                            # combined variance; a zero total is treated as missing (NaN)
                            ob_err_comb[proi,reconi] = np.nansum([ob_err[proi,reconi], ob_err0[proi,reconi]])
                            if ob_err_comb[proi,reconi] == 0: ob_err_comb[proi,reconi] = np.nan
                            if log_level > 3:
                                print('>>   {}. Proxy variance from PSM is {:.6f}, from PSM and selected interval is {:.6f} '.format(reconi,ob_err0[proi,reconi], ob_err_comb[proi,reconi]))
                            # Quality control: combined variance for 'proxy_err_psm', PSM variance otherwise
                            if proxy_err_eval in ['proxy_err_psm']:
                                qc_i = DeepDA_psm.obs_qc(Ye[proi,:], obvalue[proi,reconi], ob_err_comb[proi,reconi], proxy_qc)
                            else:
                                qc_i = DeepDA_psm.obs_qc(Ye[proi,:], obvalue[proi,reconi], ob_err0[proi,reconi], proxy_qc)
                            if qc_i:
                                if proxy_qc is not None:
                                    if log_level > 1:
                                        print('      Pass QC. ye {}, obs {}, obs_var {}, qc {}'.format(np.mean(Ye[proi,:]), obvalue[proi,reconi], ob_err_comb[proi,reconi], proxy_qc))
                            else:
                                # rejected: the variance is set to NaN (presumably excluded downstream)
                                ob_err_comb[proi,reconi] = np.nan
                                if proxy_qc is not None:                    
                                    if log_level > 1:
                                        print('      Failed QC. ye {}, obs {}, obs_var {}, qc {}'.format(np.mean(Ye[proi,:]), obvalue[proi,reconi], ob_err_comb[proi,reconi], proxy_qc))
                        if log_level > 3:
                            print('        {}: mean salinity {}, ph {}, omega {}'.format(proxy_explain,np.mean(salinity), np.mean(ph), np.mean(omega)))
                        proi = proi + 1  # increment to the next proxy row

                    else:
                        # Unrecognised PSM type: nothing is written to Ye/obvalue/ob_err* for
                        # this proxy and proi is not advanced. 'a' is only a placeholder.
                        a = 1

                
                # Ensemble-mean Ye per proxy row, as a (1, n_rows) array; it is transposed
                # below so it can sit next to obvalue (rows = proxies, cols = periods).
                Ye_mean_print = (np.mean(Ye,axis=1))[np.newaxis]
                if log_level > 1:
                    print('')
                    print('>>  Summary of this Monte Carlo simulation')
                    if log_level > 2:
                        print('')
                        print('>>  Ye mean')
                    
                        print('>>  {}'.format(Ye_mean_print.T))
                        print('>>  Observation value')
                        print('>>  {}'.format(obvalue))
                        print('>>  Observation error (ob_err0)')
                        print('>>  {}'.format(ob_err0))
                        print('>>  Observation error (ob_err_comb, from PSM and time-variance)')
                        print('>>  {}'.format( ob_err_comb))
                    print('[Ye, Ob_value]')
                # first column: mean Ye; remaining columns: observations per reconstruction period
                Yeobvalue = np.concatenate((Ye_mean_print.T, obvalue), axis=1)
                if log_level > 1:
                    print('>>  {}'.format(Yeobvalue))
                if log_level > 0:
                    print('')
                    print('##########  Monte Carlo {} / {} => okay   ##########'.format(MCi+1, MCn))
                    print('')

                # Output directory for this experiment (plain string concatenation, so the
                # directory variables are expected to carry their own trailing '/')
                MC_dir = dir_proxy_save_dir + dir_proxy_save + nexp +'/'
                
                if not os.path.exists(MC_dir):
                    os.makedirs(MC_dir)
                # NetCDF file name
                # filename_short is a tuple of string pieces, joined with ''.join below
                filename_short = '_loc_', str(locRadv),'_proxy_frac_', str(proxy_frac),'_Rscale_',str(Rscale),'_MC_' + str(MCi) 
                nc_filename = MC_dir + ''.join(filename_short) + '.nc'
                hdf5name    = MC_dir + ''.join(filename_short) + '.hdf5'
                #hdf5name_short    = '_loc_', str(locRadv),'_proxy_frac_', str(proxy_frac),'_Rscale_',str(Rscale),'_MC_' + str(MCi) +'.hdf5'
                proxy_psm_type_dict_df = pandas.DataFrame.from_dict(proxy_psm_type_dict, orient='index')

                # Write the prior, Ye and observation arrays for this MC draw ('w' overwrites)
                with h5py.File(hdf5name, 'w') as f:
                    # if any 2d field selected
                    if prior_variable2d_len > 0:
                        f.create_dataset('Xb', data=Xb)
                    f.create_dataset('obvalue', data=obvalue)
                    # Ye is stored transposed relative to the in-memory array
                    f.create_dataset('Ye', data=np.transpose(Ye))
                    f.create_dataset('ob_err', data=ob_err)
                    f.create_dataset('ob_err0', data=ob_err0)
                    f.create_dataset('ob_err_comb', data=ob_err_comb)
                    f.create_dataset('yo_all', data=yo_all)
                    # If any 3d field saved
                    if prior_variable3d_len>0:
                        f.create_dataset('Xb3d', data=Xb3d)
                    # if Mg/Ca proxy are used
                    if data_psm_mgca_find == 1:
                        f.create_dataset('Xb_sal', data=Xb_sal)
                        f.create_dataset('Xb_ph', data=Xb_ph)
                        f.create_dataset('Xb_omega', data=Xb_omega)

                    # 'Date' is a raw time.time() epoch timestamp
                    metadata = {'Date': time.time(),
                                'proxy_dbversion':yml_dict['proxies'][yml_dict['proxies']['use_from'][0]]['dbversion'],
                                'exp_dir':dir_prior,
                                'Nens':str(prior_len)}

                    f.attrs.update(metadata)

                # append proxy to hdf5 file
                # (pandas to_hdf defaults to append mode, so these keys are added to the file above)
                proxies.to_hdf(hdf5name, key='proxies')
                proxy_psm_type_dict_df.to_hdf(hdf5name, key='proxy_psm_type_dict_df')

                # the evaluation (withheld) sites exist only when not all proxies are assimilated
                if proxy_frac < 1.0:
                    sites_eval.to_hdf(hdf5name, key='sites_eval')
                pandas.DataFrame(prior_variable_dict).to_hdf(hdf5name, key='prior_variable_dict')
                pandas.DataFrame(prior_variable_dict_3d).to_hdf(hdf5name, key='prior_variable_dict_3d')
                if log_level > 1:
                    print('>>  prior2proxyunit hdf5 file saved')
                    print('      {}'.format(hdf5name))
                    print('>>  Finished Step 1. Preparation')
                    print('>>  Now Run: Step 2. Data Assimilation ...')
                    print('')


                #######     ########     #######     ########     #######     ########     #######     ########
                # STEP 2 Data Assimilation
                #######     ########     #######     ########     #######     ########     #######     ########


                ########## Prior #########
                # Rebuild the 2d / 3d prior variable lists and the matching nc file paths.
                # For each dimension key ('2d' or '3d') of prior_state_variable, the 'ncname'
                # entry maps key1 -> value1, and prior_state_variable[key][value1] lists the
                # variables read from file key1/value1.nc.
                prior_variable_dict = []  # variable list
                prior_nc_file_list = []  # nc file list
                prior_variable_dict_3d = []  # variable list
                prior_nc_file_list_3d = []  # nc file list

                for key, value in prior_state_variable.items():
                    nc_keyvalue = prior_state_variable[key]['ncname']  # note: 2d dict
                    if log_level > 3:
                        print('>>  nc_keyvalue {}...'.format(nc_keyvalue))
                    for key1, value1 in nc_keyvalue.items():
                        if log_level > 3:
                            print('>>  {}: {}'.format(key1,value1))

                        for i in range(len(prior_state_variable[key][value1])):
                            if key in ['2d']:
                                prior_variable_dict.append(prior_state_variable[key][value1][i])
                                prior_nc_file_list.append(key1+'/'+value1+'.nc')
                            elif key in ['3d']:
                                prior_variable_dict_3d.append(prior_state_variable[key][value1][i])
                                prior_nc_file_list_3d.append(key1+'/'+value1+'.nc')

                # variable list
                prior_variable_len = len(prior_variable_dict)
                prior_variable3d_len = len(prior_variable_dict_3d)
                if log_level > 2:
                    print('>>  Number of 2d prior variables is: {}. List:'.format(prior_variable_len))
                    print('      {}'.format(prior_variable_dict))
                    print('>>  Number of 3d prior variables is: {}. List:'.format(prior_variable3d_len))
                    print('      {}'.format(prior_variable_dict_3d))

                # Allocate NaN-filled arrays for the DA product Xa, shaped
                # (rows, nens, recon_period_len). 2d rows = dum_ijmax * number of 2d variables;
                # 3d rows = dum_ijmax * dum_dmax * number of 3d variables.
                # Fixed: the 3d arrays were sized with prior_variable_len (the 2d count), which
                # gave a 0-row array when only 3d variables are listed. The "no variables"
                # branch below was also unreachable before.
                if prior_variable_len > 0:
                    Xa_output   = np.full((dum_ijmax * prior_variable_len, nens, recon_period_len),np.nan)
                    Xa_output_all = Xa_output
                    if prior_variable3d_len > 0:
                        Xa3d_output   = np.full((dum_ijmax * dum_dmax * prior_variable3d_len, nens, recon_period_len),np.nan)
                        Xa_output_all = np.concatenate((Xa_output, Xa3d_output), axis=0)
                    else:
                        if log_level > 1:
                            print('>>  No 3d variable listed in {}'.format(config_name))
                elif prior_variable3d_len > 0:
                    Xa3d_output   = np.full((dum_ijmax * dum_dmax * prior_variable3d_len, nens, recon_period_len),np.nan)
                    Xa_output_all = Xa3d_output
                    if log_level > 1:
                        print('>>  No 2d variable listed in {}'.format(config_name))
                else:
                    if log_level > 0:
                        print('>>  Error! No 3d or 2d variables are listed in {}'.format(config_name))

                # DA core script: serial (one observation at a time) ensemble Kalman
                # filter update via LMR_DA.enkf_update_array, restarted from the same
                # prior Xb0 for every reconstruction interval.

                proxies = pandas.read_hdf(hdf5name, 'proxies')
                proxy_psm_type_dict_df = pandas.read_hdf(hdf5name, 'proxy_psm_type_dict_df')
                proxy_psm_type_dict_list = proxy_psm_type_dict_df[0].values.tolist()

                # PSM types dispatched to each forward-model helper below
                cgenie_psm_list = ['bayesreg_tex86', 'tex86h_forward', 'bayesreg_d18o_pooled', 'deepmip_d18o', 'cgenie_caco3', 'cgenie_caco3_13c']
                mgca_psm_list = ['bayesreg_mgca_pooled_bcp', 'bayesreg_mgca_pooled_red', 'deepmip_mgca']

                with h5py.File(hdf5name, 'r') as f:
                    Xb = f.get('Xb')  # background 2d field data (None if absent)
                    Xb3d = f.get('Xb3d')  # background 3d field data, order: lon-lat-depth (None if absent)
                    # 2d rows first, then 3d rows; the output split after the loop relies on this order
                    if Xb is not None and Xb3d is not None:
                        Xball = np.concatenate((Xb, Xb3d), axis=0)
                    elif Xb is not None:
                        Xball = Xb
                    elif Xb3d is not None:
                        Xball = Xb3d
                    else:
                        # previously only printed, then failed with a NameError on Xball
                        raise ValueError('>>  Error! No 3d or 2d variables are listed in {}'.format(config_name))

                    Xb0 = np.copy(Xball)  # in-memory prior; axis 0 = state rows, axis 1 = ensemble members
                    obvalue_full = f.get('obvalue')
                    ob_err0_full = f.get('ob_err0')
                    ob_err_comb = f.get('ob_err_comb')
                    yo_all = f.get('yo_all')  # read location data
                    if any(item in mgca_psm_list for item in proxy_psm_type_dict_list):
                        # extra fields required by the Mg/Ca forward model
                        Xb_sal = f.get('Xb_sal')
                        Xb_omega = f.get('Xb_omega')
                        Xb_ph = f.get('Xb_ph')
                        if log_level > 1:
                            print('')
                            print('>>  Mg/Ca proxy found. Loading salinity, pH and omega')

                    Xa_output_all = np.full((Xball.shape[0], Xball.shape[1], recon_period_len), np.nan)
                    ob_len = obvalue_full.shape[0]

                    if log_level > 1:
                        print('>>  Reconstruction time intervals: {}; Observation data set number: {}'.format(recon_period_len,ob_len))
                    for reconi in range(recon_period_len):
                        Xball = Xb0.copy()  # every interval starts from the same prior
                        for obi in range(ob_len):
                            if log_level > 3:
                                print('>>ID Recon: {}, obser: {}'.format(reconi,obi))
                            yo_loc = yo_all[obi, :]  # read location
                            obvalue = obvalue_full[obi, reconi]  # read observation value
                            if proxy_err_eval in ['proxy_err_psm_mp']:
                                ob_err = ob_err_comb[obi, reconi]  # PSM model + interval data uncertainty
                            else:
                                ob_err = ob_err0_full[obi, reconi]  # PSM model uncertainty only

                            # forward model: prior estimate of the observation (Ye)
                            proxy_psm_type_i = proxy_psm_type_dict_df[0][obi]
                            if proxy_psm_type_i in cgenie_psm_list:
                                Ye = DeepDA_psm.cal_ye_cgenie(yml_dict,proxies,obi,Xball,proxy_assim2,proxy_psm_type,dum_lon_offset,dum_imax,dum_jmax)
                            elif proxy_psm_type_i in mgca_psm_list:
                                Ye = DeepDA_psm.cal_ye_cgenie_mgca(yml_dict,proxies,obi,Xball,proxy_psm_type_i,dum_lon_offset,dum_imax,dum_jmax,Xb_sal,Xb_ph,Xb_omega,geologic_age)
                            else:
                                # without this guard, Ye from the previous observation would be reused
                                if log_level > 0:
                                    print('>>  Warning! ID Recon: {}, obser: {}. Unsupported PSM type {}, skipped.'.format(reconi, obi, proxy_psm_type_i))
                                continue

                            # a NaN value or error would propagate NaN into the whole updated state
                            if np.isnan(obvalue) or np.isnan(ob_err_comb[obi, reconi]) or np.isnan(ob_err):
                                if log_level > 1:
                                    print('>>  ID Recon: {}, obser: {}. Skip invalid obs.'.format(reconi,obi))
                                continue

                            if log_level > 1:
                                print('>>  ID Recon: {}, obser: {}. Loc: {}. Mean of Ye {:.6f}, var {:.6f}, obs {:.6f}, obs_err {:.6f}'.format(reconi,obi,yo_loc,np.mean(Ye),np.var(Ye,ddof=1), obvalue, ob_err))
                            if locRad:
                                # horizontal localization weights, tiled over every variable/level block
                                covloc = modules_nc.covloc_eval(locRad, yo_loc, dum_jmax, dum_imax, cGENIEGrid)
                                covlocext = int(Xball.shape[0] / covloc.shape[0])
                                covloc = np.matlib.repmat(covloc, covlocext, 1).reshape((Xball.shape[0],))
                            else:
                                covloc = np.full((Xball.shape[0],), 1)
                            if log_level > 4:
                                print('>>  Shape of Xball {}, ye {}, ob_err {}, covloc {}'.format(Xball.shape, Ye.shape, ob_err.shape, covloc.shape))
                            Xa = LMR_DA.enkf_update_array(Xball, obvalue, Ye, ob_err, loc=covloc)

                            # June 17, 2020 set hard limit for given variables
                            Xa = DeepDA_tools.deepda_hard_limit(Xa, yml_dict, prior_variable_dict, dum_ijmax,log_level)

                            if reconi == 0 and obi == 0:
                                # prior covariance between the state and the first observation's Ye.
                                # NOTE(review): only set when observation 0 of interval 0 is valid;
                                # otherwise kcov_saving/kcov keep whatever value they had before.
                                kcov_saving = 1
                                ye = np.subtract(Ye, np.mean(Ye))
                                xbm = np.mean(Xball, axis=1)
                                Xbp = np.subtract(Xball, xbm[:, None])  # "None" means replicate in this dimension
                                kcov = np.dot(Xbp, np.transpose(ye)) / (nens - 1)
                            # update Xb using Xa, to assimilate next observation
                            Xball = np.copy(Xa)

                        # Xball is the analysis after the last valid observation (or the prior
                        # if none was valid); Xa may be stale from an earlier interval.
                        if log_level > 2:
                            print('>>  ... global mean is {}'.format(np.nanmean(Xball)))
                        Xa_output_all[:, :, reconi] = np.copy(Xball)

                    # split the analysis back into its 2d and 3d row blocks
                    if Xb is not None:
                        lenn1 = Xb.shape[0]
                        Xa_output_2d = Xa_output_all[0:lenn1, :, :]
                        if Xb3d is not None:
                            lenn2 = Xb3d.shape[0]
                            Xa_output_3d = Xa_output_all[lenn1:lenn2 + lenn1, :, :]
                    else:
                        # Xb3d is guaranteed to exist here (checked above)
                        lenn2 = Xb3d.shape[0]
                        Xa_output_3d = Xa_output_all[0:lenn2, :, :]
                if log_level > 1:
                    print('>>')
                    print('>>  Finished Step 2. Data Assimilation')
                    print('>>  Now      Step 3. Save results')
                    print('')
                if log_level > 4:
                    print(Xa_output_all.shape)

                # DA save output in the netCDF file.
                # For every 2d/3d prior variable this writes the prior (Xb) and posterior
                # (Xa) ensemble mean and variance, optionally the full ensembles, and the
                # prior state/first-observation covariance (kcov) when kcov_saving > 0.
                # Cells whose prior mean is >= 9.9692e+36 (fill values) are masked.
                with h5py.File(hdf5name, 'r') as f:
                    if log_level > 1:
                        print('>>  Start writing netCDF ...')
                    # context manager closes the file even if a write below fails
                    with Dataset(nc_filename, 'w', format='NETCDF4') as nf:
                        nf.description = 'DeepDA' + nc_filename
                        # Specifying dimensions
                        nf.createDimension('lon', len(cGENIEGridB_lon36))
                        nf.createDimension('lat', len(cGENIEGridB_lat36))
                        z = np.arange(0, 1, 1)  # single dummy level for 2d fields
                        nf.createDimension('z', len(z))  # level
                        nf.createDimension('nens', nens)  # number of ens
                        nf.createDimension('time', recon_period_len)

                        # Coordinate variables; units are set once here instead of per field
                        longitude = nf.createVariable('Longitude', 'f4', 'lon')
                        longitude[:] = cGENIEGridB_lon36.values
                        longitude.units = '°'

                        latitude = nf.createVariable('Latitude', 'f4', 'lat')
                        latitude[:] = cGENIEGridB_lat36.values
                        latitude.units = '°'

                        levels = nf.createVariable('Levels', 'i4', 'z')
                        levels[:] = z  # 2d level
                        if Xb3d is not None:
                            nf.createDimension('zt', len(zt))
                            zt_nc = nf.createVariable('zt', 'f4', 'zt')
                            zt_nc[:] = zt
                            zt_nc.units = 'm'
                        else:
                            levels.units = 'm'

                        if locRad:
                            # localization weights of the last assimilated observation (2d part)
                            covloc_nc = nf.createVariable('covloc', 'f4', ('lat', 'lon'))
                            covloc_nc[:, :] = np.copy(covloc[0:dum_ijmax].reshape(dum_jmax, dum_imax))

                        if Xb is not None:
                            if log_level > 2:
                                print('Writing 2d field.')
                            for nc_var_i, nc_var_name in enumerate(prior_variable_dict):
                                # rows of this variable inside the 2d state block
                                j0 = dum_ijmax * nc_var_i
                                j1 = dum_ijmax * (nc_var_i + 1)
                                if log_level > 0:
                                    print('')
                                    print('>>    ID from {} to {}: field is {}'.format(j0, j1,nc_var_name))

                                Xb0_i = np.copy(f.get('Xb')[j0:j1, :])

                                Xa_output_i = np.copy(Xa_output_2d[j0:j1, :, :])
                                Xa_outputi = Xa_output_i.reshape(dum_imax, dum_jmax, nens, recon_period_len)

                                XbNC_mean = nf.createVariable(nc_var_name + '_Xb_mean', 'f4', ('lat', 'lon', 'z'))
                                xbm = np.mean(Xb0_i, axis=1)
                                XbNC_mean[:, :, :] = np.copy(xbm.reshape(dum_jmax, dum_imax, 1))

                                XbNC_variance = nf.createVariable(nc_var_name + '_Xb_variance', 'f4', ('lat', 'lon', 'z'))
                                Xb_temp = np.copy(np.var(Xb0_i, axis=1).reshape(dum_jmax, dum_imax, 1))
                                Xb_temp = np.ma.MaskedArray(Xb_temp, np.copy(xbm.reshape(dum_jmax, dum_imax, 1)) >= 9.9692e+36)
                                XbNC_variance[:, :, :] = Xb_temp
                                if log_level > 0:
                                    print('>>               Xb mean is {:.8f}, std is {:.8f}, var is {:.8f}'.format(np.nanmean(XbNC_mean), np.sqrt(np.nanmean(Xb_temp)), np.nanmean(Xb_temp)))

                                XaNC_mean = nf.createVariable(nc_var_name + '_Xa_mean', 'f4', ('lat', 'lon', 'z', 'time'))
                                Xam_temp = np.copy(np.nanmean(Xa_outputi, axis=2).reshape(dum_jmax, dum_imax, 1, recon_period_len))
                                XaNC_mean[:, :, :, :] = Xam_temp

                                XaNC_variance = nf.createVariable(nc_var_name + '_Xa_variance', 'f4', ('lat', 'lon', 'z', 'time'))
                                if log_level > 4:
                                    print(Xa_outputi[0, 0:36, 0, 0])
                                Xa_temp = np.copy(np.ma.var(Xa_outputi, axis=2).reshape(dum_jmax, dum_imax, 1, recon_period_len))
                                Xa_temp = np.ma.MaskedArray(Xa_temp, Xam_temp >= 9.9692e+36)
                                XaNC_variance[:, :, :, :] = Xa_temp

                                for reconii in range(recon_period_len):
                                    XaNC_mean_i = XaNC_mean[:, :, :, reconii]
                                    XaNC_var_i = XaNC_variance[:, :, :, reconii]
                                    if log_level > 0:
                                        print('>>      Recon {}. Xa mean is {:.8f}, std is {:.8f}, var is {:.8f}'.format(reconii, np.nanmean(XaNC_mean_i),np.sqrt(np.nanmean(XaNC_var_i)), np.nanmean(XaNC_var_i)))

                                if save_ens_full:
                                    XaNC_full = nf.createVariable(nc_var_name + '_Xa_full', 'f4', ('lat', 'lon', 'nens', 'z', 'time'))
                                    XaNC_full[:, :, :, :, :] = np.copy(Xa_outputi.reshape(dum_jmax, dum_imax, nens, 1, recon_period_len))

                                    XbNC_full = nf.createVariable(nc_var_name + '_Xb_full', 'f4', ('lat', 'lon', 'nens', 'z'))
                                    XbNC_full[:, :, :, :] = np.copy(Xb0_i.reshape(dum_jmax, dum_imax, nens, 1))

                                if kcov_saving > 0:
                                    # (lat, lon) order to match the mask and the target variable;
                                    # the old (imax, jmax) shape only worked on square grids
                                    kcov_i = np.copy(kcov[j0:j1]).reshape(dum_jmax, dum_imax, 1)
                                    kcov_i = np.ma.MaskedArray(kcov_i, np.copy(xbm.reshape(dum_jmax, dum_imax, 1)) >= 9.9692e+36)
                                    cov_ob0 = nf.createVariable(nc_var_name + '_obs0' + '_cov', 'f4', ('lat', 'lon', 'z'))
                                    cov_ob0[:, :, :] = kcov_i

                        if Xb3d is not None:
                            if log_level > 1:
                                print('Writing 3d field.')
                            # kcov rows are 2d rows followed by 3d rows; offset past the 2d block
                            kcov_offset_3d = f.get('Xb').shape[0] if Xb is not None else 0
                            for nc_var_i, nc_var_name in enumerate(prior_variable_dict_3d):
                                # rows of this variable inside the 3d state block
                                j0 = dum_ijmax * dum_dmax * nc_var_i
                                j1 = dum_ijmax * dum_dmax * (nc_var_i + 1)
                                if log_level > 0:
                                    print('')
                                    print('>>    ID from {} to {}: field is {}'.format(j0, j1,nc_var_name))

                                Xb0_i = np.copy(f.get('Xb3d')[j0:j1, :])
                                Xa_output_i = np.copy(Xa_output_3d[j0:j1, :, :])
                                Xa_outputi = Xa_output_i.reshape(dum_imax, dum_jmax, dum_dmax, nens, recon_period_len)

                                XbNC_mean = nf.createVariable(nc_var_name + '_Xb_3d_mean', 'f4', ('zt', 'lat', 'lon'))
                                xbm = np.mean(Xb0_i, axis=1)
                                XbNC_mean[:, :, :] = np.copy(xbm.reshape(dum_dmax, dum_jmax, dum_imax))

                                XbNC_variance = nf.createVariable(nc_var_name + '_Xb_3d_variance', 'f4', ('zt', 'lat', 'lon'))
                                Xb_temp = np.copy(np.var(Xb0_i, axis=1).reshape(dum_dmax, dum_jmax, dum_imax))
                                Xb_temp = np.ma.MaskedArray(Xb_temp, np.copy(xbm.reshape(dum_dmax, dum_jmax, dum_imax)) >= 9.9692e+36)
                                XbNC_variance[:, :, :] = Xb_temp
                                if log_level > 0:
                                    print('>>               Xb mean is {:.8f}, std is {:.8f}, var is {:.8f}'.format(np.nanmean(XbNC_mean),np.sqrt(np.nanmean(Xb_temp)), np.nanmean(Xb_temp)))

                                XaNC_mean = nf.createVariable(nc_var_name + '_Xa_3d_mean', 'f4', ('zt', 'lat', 'lon', 'time'))
                                Xam_temp = np.copy(np.nanmean(Xa_outputi, axis=3).reshape(dum_dmax, dum_jmax, dum_imax, recon_period_len))
                                XaNC_mean[:, :, :, :] = Xam_temp

                                XaNC_variance = nf.createVariable(nc_var_name + '_Xa_3d_variance', 'f4', ('zt', 'lat', 'lon', 'time'))
                                Xa_temp = np.copy(np.ma.var(Xa_outputi, axis=3).reshape(dum_dmax, dum_jmax, dum_imax, recon_period_len))
                                Xa_temp = np.ma.MaskedArray(Xa_temp, Xam_temp >= 9.9692e+36)
                                XaNC_variance[:, :, :, :] = Xa_temp

                                for reconii in range(recon_period_len):
                                    XaNC_mean_i = XaNC_mean[:, :, :, reconii]
                                    XaNC_var_i = XaNC_variance[:, :, :, reconii]
                                    if log_level > 0:
                                        print('>>      Recon {}. Xa mean is {:.8f}, std is {:.8f}, var is {:.8f}'.format(reconii, np.nanmean(XaNC_mean_i), np.sqrt(np.nanmean(XaNC_var_i)), np.nanmean(XaNC_var_i)))

                                if save_ens_full:
                                    XaNC_full = nf.createVariable(nc_var_name + '_Xa_3d_full', 'f4', ('zt', 'lat', 'lon', 'nens', 'time'))
                                    XaNC_full[:, :, :, :, :] = np.copy(Xa_outputi.reshape(dum_dmax, dum_jmax, dum_imax, nens, recon_period_len))

                                    XbNC_full = nf.createVariable(nc_var_name + '_Xb_3d_full', 'f4', ('zt', 'lat', 'lon', 'nens'))
                                    XbNC_full[:, :, :, :] = np.copy(Xb0_i.reshape(dum_dmax, dum_jmax, dum_imax, nens))

                                if kcov_saving > 0:
                                    # this variable's own rows (previously every 3d variable got the first one's)
                                    kcov_i = np.copy(kcov[kcov_offset_3d + j0:kcov_offset_3d + j1]).reshape(dum_dmax, dum_jmax, dum_imax)
                                    kcov_i = np.ma.MaskedArray(kcov_i, np.copy(xbm.reshape(dum_dmax, dum_jmax, dum_imax)) >= 9.9692e+36)
                                    cov_ob0 = nf.createVariable(nc_var_name + '_3d_obs0' + '_cov', 'f4', ('zt', 'lat', 'lon'))
                                    cov_ob0[:, :, :] = kcov_i
                    # netCDF file is closed here by the context manager
                    if log_level > 1:
                        print('')
                        print('>>  Data saved in netCDF file:')
                if log_level > 1:
                    print('')
                    print('netCDF file saved : ')
                    print('')
                    print(nc_filename)
                    print('')
                    print('##########                    ##########')
                    print('##########   This loop done   ##########')
                    print('##########                    ##########')
                    print('')
                    print('')
                
# Export this Jupyter notebook as HTML for reference and move it into MC_dir.
# If nbconvert fails to produce the HTML file, warn instead of letting
# shutil.move crash on the missing file.
nbconvert_status = os.system('jupyter nbconvert --to html DeepDA_allMC.ipynb')
html_saved = os.path.exists("DeepDA_allMC.html")
if html_saved:
    shutil.move("DeepDA_allMC.html", MC_dir+"DeepDA_allMC.html")
elif log_level > 0:
    print('>>  Warning! jupyter nbconvert did not produce DeepDA_allMC.html (exit status {})'.format(nbconvert_status))
if log_level > 0:
    print('')
    print('########## All Done ##########')
    print('')
    if html_saved:
        print('This web page saved as DeepDA_allMC.html in the working directory : {}'.format(MC_dir))
########## Check the consistency of the config.yml file and proxy database ##########


Bad key "text.kerning_factor" on line 4 in
/Users/mingsongli/miniconda3/envs/deepda/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test_patch.mplstyle.
You probably need to get an updated matplotlibrc file from
http://github.com/matplotlib/matplotlib/blob/master/matplotlibrc.template
or from the matplotlib source distribution


>>  Import packages...  => Okay

 ########## Load yml config file ########## 


The config (yml) file saved : 

/volumes/DA/DeepDA/wrk/petmproxy3slices_v0.0.18.csv_petm18_v18_20210122_TOM_deepmip_MCsd100.yml

>>  Prior member size: 100
>>  Recon_period 1 - 2. 
      List: [1 2]
>>  Proxy error evaluation: proxy_err_psm_fixed
>>  Proxy to be assimilated (some may not exist)
      ['Marine sediments_uk37', 'Marine sediments_d18o_pooled', 'Marine sediments_tex86', 'Marine sediments_mgca_pooled_bcp', 'Marine sediments_mgca_pooled_red']
>>  Proxy quality control selection: None

########## Read prior ######### 





>>  Units of state variables ['ocn_sur_temp', 'atm_temp', 'atm_pCO2', 'ocn_sur_sal', 'misc_pH', 'carb_sur_ohm_cal', 'sed_CaCO3']: ['unit', 'degrees C', 'atm', 'unit', 'pH units (SWS)', 'unit', 'wt%']
>>  Reading Prior => Okay

 ########## Read proxies database ########## 

>>  Proxy: selected proxy dataset number 84: those in blacklist removed!
>>  Proxy: selected proxy dataset number 69: those unknown/frosty removed!

>>  Starting Monte Carlo ... 
>>  Selected index: [49, 53, 5, 33, 62, 51, 58, 50, 67, 19, 30, 22, 37, 13, 32, 8, 18, 60, 48, 6, 39, 16, 34, 45, 38, 9, 59, 68, 4, 21, 64, 35, 41, 57, 27, 20, 55, 44, 14, 46, 47, 63, 1, 25, 17, 0, 2, 12]
>>  Unselected index: [3, 7, 10, 11, 15, 23, 24, 26, 28, 29, 31, 36, 40, 42, 43, 52, 54, 56, 61, 65, 66]
>>  Selected proxy data length 48

>>  Summary of this Monte Carlo simulation
[Ye, Ob_value]
>>  [[ 4.03726263  5.85        3.00005161]
 [-2.17673401 -1.785898           nan]
 [ 0.69324747  0.85691489         nan]
 [-3.86405951 -2.478   

your performance may suffer as PyTables will pickle object types that it cannot
map directly to c-types [inferred_type->mixed,key->block1_values] [items->['File', 'Site', 'Type', 'Proxy', 'DepthOri', 'Glassy', 'mg_bcp_red']]

  pytables.to_hdf(path_or_buf, key, self, **kwargs)


>>  prior2proxyunit hdf5 file saved
      /volumes/DA/DeepDA/wrk/petmproxy3slices_v0.0.18.csv_petm18_v18_20210122_TOM_deepmip_MCsd100/_loc_0_proxy_frac_0.7_Rscale_0.7_MC_0.hdf5
>>  Finished Step 1. Preparation
>>  Now Run: Step 2. Data Assimilation ...

>>  No 3d variable listed in DeepDA_config.yml

>>  Mg/Ca proxy found. Loading salinity, pH and omega
>>  Reconstruction time intervals: 2; Observation data set number: 48
>>  ID Recon: 0, obser: 0. Loc: [-178.55  -56.15]. Mean of Ye 4.037263, var 3.132305, obs 5.850000, obs_err 0.373882
>>  ID Recon: 0, obser: 1. Loc: [  4.15 -70.  ]. Mean of Ye -3.362293, var 0.194197, obs -1.785898, obs_err 0.205473
>>  ID Recon: 0, obser: 2. Loc: [-161.    -48.75]. Mean of Ye 0.720844, var 0.000530, obs 0.856915, obs_err 0.004431
>>  ID Recon: 0, obser: 3. Loc: [-10.95  41.8 ]. Mean of Ye -4.345745, var 0.061585, obs -2.478000, obs_err 0.202299
>>  ID Recon: 0, obser: 4. Loc: [-155.45   28.35]. Mean of Ye 6.331565, var 0.769920, obs 5.420000, obs_er