In [None]:
# DASK client set
# Attach to the Dask scheduler described by SCHEDULER_FILE and make the private CESM2
# modules importable on every worker.

import os
import sys
from dask.distributed import Client

SCHEDULER_FILE = '/proj/kimyy/Dropbox/source/python/all/mpi/scheduler.json'
# Alternative scheduler file used before:
# '/proj/kimyy/Dropbox/source/python/all/mpi/scheduler_10.json'
CESM2_MODULE_PATH = '/proj/kimyy/Dropbox/source/python/all/Modules/CESM2'

client = Client(scheduler_file=SCHEDULER_FILE)


def setup_module_path(module_path=CESM2_MODULE_PATH):
    """Append ``module_path`` to this process's ``sys.path`` unless it is already present."""
    if module_path not in sys.path:
        sys.path.append(module_path)


# client.run executes the function on every worker of the cluster.
client.run(setup_module_path)

# Only a cell's last expression is rendered, so a bare `client` here would show nothing.
print(client)

# Path to paste into Jupyter's "File - Open from Path": the part of the notebook directory
# that follows '/all/', prefixed with '/all/' (just '/all/' if the directory is not under it).
notebook_path = os.path.abspath(".")
_, _, relative_path = notebook_path.partition('/all/')
relative_path = '/all/' + relative_path
relative_path

In [None]:
# load public modules

import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as patches
import matplotlib.ticker as mticker
import matplotlib.path as mpath
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from scipy import stats
from scipy.interpolate import griddata
import cmocean
from cmcrameri import cm
import warnings
warnings.simplefilter(action='ignore')
import pandas as pd
import cftime
import pop_tools
from pprint import pprint
import time
import subprocess
import re as re_mod
import cftime
import datetime
from scipy.stats import ttest_1samp

In [None]:
# load private modules

import sys

# Make the private CESM2 modules importable in this kernel. The membership guard keeps
# repeated executions of this cell from piling duplicate entries onto sys.path.
_cesm2_module_path = '/proj/kimyy/Dropbox/source/python/all/Modules/CESM2'
if _cesm2_module_path not in sys.path:
    sys.path.append(_cesm2_module_path)

from KYY_CESM2_NWP_preprocessing import CESM2_NWP_config
# import KYY_CESM2_preprocessing
# import importlib
# importlib.reload(KYY_CESM2_preprocessing)

In [None]:
def _make_nwp_config(var_name, year_s=1955, year_e=2020):
    """Build a CESM2_NWP_config covering ``year_s``..``year_e`` for variable ``var_name``.

    The years are assigned before ``setvar`` is called, matching the original setup order.
    """
    cfg = CESM2_NWP_config()
    cfg.year_s = year_s
    cfg.year_e = year_e
    cfg.setvar(var_name)
    return cfg


cfg_var_DIC = _make_nwp_config('DIC')    # 1. DIC (Total)
cfg_var_WVEL = _make_nwp_config('WVEL')  # 2. WVEL (Total)

# Load the POP gx1v7 grid only when the DIC configuration reports the ocean component.
if cfg_var_DIC.comp == 'ocn':
    ds_grid = pop_tools.get_grid('POP_gx1v7')

In [None]:
# define preprocessing function

# Names the preprocessors keep; every other coordinate / data variable is dropped while reading.
# Both analysed variables must be listed here. The second configuration in this notebook is
# WVEL (cfg_var_TEMP is never defined), and omitting it would strip the WVEL data variable
# from every file.
# NOTE(review): if WVEL is written on a vertical axis other than z_t / z_t_2, that axis'
# coordinate is not listed here and will be dropped — confirm.
exceptcv = ['time', 'lon', 'lat', 'lev', 'TAREA', 'TLONG', 'TLAT', 'z_t', 'z_t_2',
            cfg_var_DIC.var, cfg_var_WVEL.var]
# exceptcv=['time','lon','lat','lev', 'TAREA', 'TLONG', 'TLAT', 'z_t', 'dz', 'z_t_2', cfg_var_DIC.var, cfg_var_WVEL.var]

def process_coords_3d(ds, sd, ed, drop=True, except_coord_vars=exceptcv):
    """Preprocessor for ``xr.open_mfdataset``: trim one file to the variables of interest.

    Every coordinate and data variable whose name is not in ``except_coord_vars`` is
    collected (dimension coordinates included). With ``drop=True`` those variables are
    removed and the dataset is restricted to ``time`` in ``slice(sd, ed)``. With
    ``drop=False`` they are turned into coordinates and no time selection is applied.

    Parameters
    ----------
    ds : xarray.Dataset
        A single file as opened by ``open_mfdataset``.
    sd, ed :
        Bounds passed to ``ds.sel(time=slice(sd, ed))``.
    drop : bool
        Remove the collected variables (True) or only set them as coordinates (False).
    except_coord_vars : list of str
        Names to keep. The default is the module-level ``exceptcv`` as it was when this
        function was defined.

    Returns
    -------
    xarray.Dataset
    """
    unlisted = [name for name in (*ds.coords, *ds.data_vars) if name not in except_coord_vars]

    if not drop:
        return ds.set_coords(unlisted)

    # drop_vars is the supported replacement for the deprecated Dataset.drop(names).
    ds = ds.drop_vars(unlisted)
    ds = ds.sel(time=slice(sd, ed))
    # ds = ds.isel(z_t=slice(0, 39)) # ~39 layer (1000m)
    # ds = (ds.isel(z_t=slice(1, 39)) * ds.dz).sum(dim='z_t') / ds.dz.sum(dim='z_t')
    return ds



def process_coords_3d_LE(ds, sd, ed, drop=True, except_coord_vars=exceptcv):
    """
    Preprocessor function for CESM POP-style (LE) datasets passed to ``xr.open_mfdataset``.

    - Normalizes the vertical dimension: ``z_t_2`` (checked first) or ``z_t`` is renamed
      to ``depth``, and its values are replaced with ``z_t_new`` when the lengths match.
    - Rewrites ``z_t_2`` / ``z_t`` mentions in each data variable's ``coordinates`` attribute.
    - Collects every coordinate / data variable not in ``except_coord_vars``. ``depth`` is
      always kept. With ``drop=True`` these are removed and time is restricted to
      ``slice(sd, ed)``; with ``drop=False`` they are set as coordinates instead.

    If neither ``z_t_2`` nor ``z_t`` is a dimension, a warning is printed and only the
    vertical normalization is skipped. Variable trimming and time selection still run, so
    the file stays consistent with the others being concatenated.

    Parameters
    ----------
    ds : xarray.Dataset
    sd, ed :
        Bounds passed to ``ds.sel(time=slice(sd, ed))``.
    drop : bool
    except_coord_vars : list of str
        Names to keep. The default is the module-level ``exceptcv`` as it was when this
        function was defined.

    Returns
    -------
    xarray.Dataset
    """
    # 60 target levels. The values look like POP T-level centre depths in metres
    # (5 m ... 5375 m) — NOTE(review): confirm against the model grid.
    z_t_new = np.array([5.0000000e+00, 1.5000000e+01, 2.5000000e+01, 3.5000000e+01,
       4.5000000e+01, 5.5000000e+01, 6.5000000e+01, 7.5000000e+01,
       8.5000000e+01, 9.5000000e+01, 1.0500000e+02, 1.1500000e+02,
       1.2500000e+02, 1.3500000e+02, 1.4500000e+02, 1.5500000e+02,
       1.6509839e+02, 1.7547903e+02, 1.8629126e+02, 1.9766026e+02,
       2.0971138e+02, 2.2257828e+02, 2.3640883e+02, 2.5137015e+02,
       2.6765421e+02, 2.8548364e+02, 3.0511920e+02, 3.2686798e+02,
       3.5109348e+02, 3.7822760e+02, 4.0878464e+02, 4.4337769e+02,
       4.8273669e+02, 5.2772797e+02, 5.7937286e+02, 6.3886261e+02,
       7.0756329e+02, 7.8700250e+02, 8.7882520e+02, 9.8470581e+02,
       1.1062042e+03, 1.2445669e+03, 1.4004972e+03, 1.5739464e+03,
       1.7640033e+03, 1.9689442e+03, 2.1864565e+03, 2.4139714e+03,
       2.6490012e+03, 2.8893845e+03, 3.1334045e+03, 3.3797935e+03,
       3.6276702e+03, 3.8764519e+03, 4.1257681e+03, 4.3753926e+03,
       4.6251904e+03, 4.8750835e+03, 5.1250278e+03, 5.3750000e+03])

    # ------------------------------------------------------
    # 1) Find the vertical dimension (z_t_2 takes precedence over z_t)
    # ------------------------------------------------------
    if "z_t_2" in ds.dims:
        vertical_dim = "z_t_2"
    elif "z_t" in ds.dims:
        vertical_dim = "z_t"
    else:
        vertical_dim = None
        # Do not return early: the trimming and time selection in step 4 must still run.
        print("[Warning] No vertical coordinate (z_t or z_t_2) found — skipped.")

    if vertical_dim is not None:
        ds = ds.rename({vertical_dim: "depth"})
        # Drop any leftover z_t/z_t_2 coordinate variable if it exists
        ds = ds.drop_vars(["z_t", "z_t_2"], errors="ignore")

        # ------------------------------------------------------
        # 2) Replace coordinate values with z_t_new
        # ------------------------------------------------------
        if "depth" in ds.coords:
            if len(ds["depth"]) == len(z_t_new):
                ds = ds.assign_coords(depth=z_t_new)
            else:
                print(f"[Warning] depth length mismatch: {len(ds['depth'])} vs {len(z_t_new)}")
        else:
            print("[Warning] depth coordinate missing after renaming.")

        # ------------------------------------------------------
        # 3) Clean up coordinate references inside variable attributes
        # ------------------------------------------------------
        for v in ds.data_vars:
            if "coordinates" in ds[v].attrs:
                ds[v].attrs["coordinates"] = (
                    ds[v].attrs["coordinates"]
                    .replace("z_t_2", "depth")
                    .replace("z_t", "depth")
                )

    # ------------------------------------------------------
    # 4) Drop unnecessary coordinate variables and slice time
    # ------------------------------------------------------
    # 'depth' is always kept. The default exceptcv lists only z_t / z_t_2, so without this
    # the coordinate normalized above would be dropped again right away.
    keep = set(except_coord_vars) | {"depth"}
    coord_vars = [v for v in (*ds.coords, *ds.data_vars) if v not in keep]

    if drop:
        ds = ds.drop_vars(coord_vars, errors="ignore")
        ds = ds.sel(time=slice(sd, ed))
    else:
        ds = ds.set_coords(coord_vars)

    return ds

# Time window handed to the preprocessors, built from the DIC configuration's year range
# (NoLeap calendar). It runs from 1 Feb of the first year to 1 Jan of the year after the last.
# NOTE(review): the February start presumably matches monthly stamps placed at the end of each
# averaging month (the read cells shift time back by 15 days afterwards) — confirm.
start_date, end_date = (
    cftime.DatetimeNoLeap(cfg_var_DIC.year_s, 2, 1),
    cftime.DatetimeNoLeap(cfg_var_DIC.year_e + 1, 1, 1),
)

In [None]:
# Read the DIC output of the LE, ADA, WDA and ODA experiments. Then regrid the first member
# of each onto a regular lat/lon grid and cut out the lat 10-60 / lon 110-190 window.
#
# Quick-test subset: one member each for LE, ADA and ODA, and every WDA member.

import functools

import xcesm  # imported for its side effect: provides the `.utils.regrid()` accessor used below

# (experiment, member slice, preprocessor). The LE files get their own preprocessor because
# their vertical coordinate is normalized differently.
dic_read_plan = (
    ('LE', slice(11, 12), process_coords_3d_LE),
    ('ADA', slice(5, 6), process_coords_3d),
    ('WDA', slice(None), process_coords_3d),
    ('ODA', slice(15, 16), process_coords_3d),
)

for exp, members, preprocess in dic_read_plan:
    start_time = time.time()

    # <EXP>_path_load presumably populates <EXP>_file_list / <EXP>_ensembles used next.
    getattr(cfg_var_DIC, f'{exp}_path_load')(cfg_var_DIC.var)
    file_list = getattr(cfg_var_DIC, f'{exp}_file_list')[0][members]
    ensembles = [*getattr(cfg_var_DIC, f'{exp}_ensembles')][members]

    exp_ds = xr.open_mfdataset(
        file_list,
        chunks={'time': 12},
        combine='nested',
        # outer nesting level -> ensemble members (new dim 'concat_dim'), inner -> time
        concat_dim=[ensembles, 'time'],
        parallel=True,
        # partial binds the preprocessor now; a lambda in a loop would bind it late
        preprocess=functools.partial(preprocess, sd=start_date, ed=end_date),
        decode_cf=True,
        decode_times=True,
    )
    exp_ds = exp_ds.rename({'concat_dim': f'ens_{exp}'})

    # Shift every time stamp back by 15 days.
    new_time = exp_ds.time - np.array([datetime.timedelta(days=15)] * len(exp_ds.time))
    exp_ds = exp_ds.assign_coords(time=new_time)

    if exp == 'LE':
        # process_coords_3d_LE renames the vertical dimension to 'depth'; restore 'z_t'.
        exp_ds = exp_ds.rename({'depth': 'z_t'})

    setattr(cfg_var_DIC, f'{exp}_ds', exp_ds)

    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f'elapsed time for reading {exp}: {elapsed_time}')


start_time = time.time()

#regrids
lat_range = slice(10, 60)
lon_range = slice(110, 190)

for exp in (entry[0] for entry in dic_read_plan):
    regridded = (
        getattr(cfg_var_DIC, f'{exp}_ds')[cfg_var_DIC.var]
        .isel({f'ens_{exp}': 0})
        .utils.regrid()
        .sel(lat=lat_range, lon=lon_range)
        .sortby('time')
    )
    setattr(cfg_var_DIC, f'{exp}_ds_rgd', regridded)

# Give LE exactly the same z_t coordinate values as ODA.
cfg_var_DIC.LE_ds_rgd = cfg_var_DIC.LE_ds_rgd.assign_coords(z_t=cfg_var_DIC.ODA_ds_rgd.z_t)

end_time = time.time()
elapsed_time = end_time - start_time
print(f'elapsed time for regridding: {elapsed_time}')

In [None]:
# Read the WVEL output of the LE, ADA, WDA and ODA experiments. Then regrid the first member
# of each onto a regular lat/lon grid and cut out the lat 10-60 / lon 110-190 window.
#
# Quick-test subset: one member each for LE, ADA and ODA, and every WDA member.

import functools

import xcesm  # imported for its side effect: provides the `.utils.regrid()` accessor used below

# (experiment, member slice, preprocessor). The LE files get their own preprocessor because
# their vertical coordinate is normalized differently.
wvel_read_plan = (
    ('LE', slice(11, 12), process_coords_3d_LE),
    ('ADA', slice(5, 6), process_coords_3d),
    ('WDA', slice(None), process_coords_3d),
    ('ODA', slice(15, 16), process_coords_3d),
)

for exp, members, preprocess in wvel_read_plan:
    start_time = time.time()

    # <EXP>_path_load presumably populates <EXP>_file_list / <EXP>_ensembles used next.
    getattr(cfg_var_WVEL, f'{exp}_path_load')(cfg_var_WVEL.var)
    file_list = getattr(cfg_var_WVEL, f'{exp}_file_list')[0][members]
    ensembles = [*getattr(cfg_var_WVEL, f'{exp}_ensembles')][members]

    exp_ds = xr.open_mfdataset(
        file_list,
        chunks={'time': 12},
        combine='nested',
        # outer nesting level -> ensemble members (new dim 'concat_dim'), inner -> time
        concat_dim=[ensembles, 'time'],
        parallel=True,
        # partial binds the preprocessor now; a lambda in a loop would bind it late
        preprocess=functools.partial(preprocess, sd=start_date, ed=end_date),
        decode_cf=True,
        decode_times=True,
    )
    exp_ds = exp_ds.rename({'concat_dim': f'ens_{exp}'})

    # Shift every time stamp back by 15 days.
    new_time = exp_ds.time - np.array([datetime.timedelta(days=15)] * len(exp_ds.time))
    exp_ds = exp_ds.assign_coords(time=new_time)

    if exp == 'LE':
        # process_coords_3d_LE renames the vertical dimension to 'depth'; restore 'z_t'.
        exp_ds = exp_ds.rename({'depth': 'z_t'})

    setattr(cfg_var_WVEL, f'{exp}_ds', exp_ds)

    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f'elapsed time for reading {exp}: {elapsed_time}')


start_time = time.time()

#regrids
lat_range = slice(10, 60)
lon_range = slice(110, 190)

for exp in (entry[0] for entry in wvel_read_plan):
    regridded = (
        getattr(cfg_var_WVEL, f'{exp}_ds')[cfg_var_WVEL.var]
        .isel({f'ens_{exp}': 0})
        .utils.regrid()
        .sel(lat=lat_range, lon=lon_range)
        .sortby('time')
    )
    setattr(cfg_var_WVEL, f'{exp}_ds_rgd', regridded)

# Give LE exactly the same z_t coordinate values as ODA.
cfg_var_WVEL.LE_ds_rgd = cfg_var_WVEL.LE_ds_rgd.assign_coords(z_t=cfg_var_WVEL.ODA_ds_rgd.z_t)

end_time = time.time()
elapsed_time = end_time - start_time
print(f'elapsed time for regridding: {elapsed_time}')