# Modules and Dask setup

In [1]:
# DASK client set

import os
import sys
from dask.distributed import Client
# Connect this notebook to the Dask scheduler whose address is recorded in the
# scheduler file below. Alternative (disabled) configurations:
#   Client(scheduler_file=<same file>, threads_per_worker=2, n_workers=6)
#   Client(scheduler_file='/proj/kimyy/Dropbox/source/python/all/mpi/scheduler_10.json')
client = Client(
    scheduler_file='/proj/kimyy/Dropbox/source/python/all/mpi/scheduler.json',
)

def setup_module_path(module_path='/proj/kimyy/Dropbox/source/python/all/Modules/CESM2'):
    """Make ``module_path`` importable by appending it to ``sys.path`` once.

    Meant to be shipped to every Dask worker with ``client.run`` so the private
    CESM2 helper modules can be imported there. Calling it repeatedly is safe:
    the path is only appended if it is not already present.

    Parameters
    ----------
    module_path : str, optional
        Directory to add. Defaults to the private CESM2 module directory.
    """
    # Local import keeps the function self-contained when it is serialized
    # and executed on a remote worker.
    import sys

    if module_path not in sys.path:
        sys.path.append(module_path)

# Execute setup_module_path on the cluster's workers so that the private
# modules can be imported inside tasks (e.g. preprocess functions).
client.run(setup_module_path)

# Display the client summary (scheduler address, workers, resources).
client

# Path of this notebook's directory relative to the '/all/' source root, for
# Jupyter's "File - Open from Path" dialog. Falls back to the absolute path
# when the directory is not under an '/all/' tree.
# (Do not unpack into `_`: in IPython `_` is the last-output cache, and
# rebinding it stops IPython from updating `_`, `__`, `___` afterwards.)
notebook_path = os.path.abspath(".")
root_sep, path_below_root = notebook_path.partition('/all/')[1:]
relative_path = root_sep + path_below_root if root_sep else notebook_path
relative_path

'/all/Model/CESM2/Earth_System_Predictability/ASSM/Aleph'

In [2]:
# load public modules

import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as patches
import matplotlib.ticker as mticker
import matplotlib.path as mpath
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from scipy import stats
from scipy.interpolate import griddata
import cmocean
from cmcrameri import cm
import warnings
warnings.simplefilter(action='ignore')
import pandas as pd
import cftime
import pop_tools
from pprint import pprint
import time
import subprocess
import re as re_mod
import cftime
import datetime
from scipy.stats import ttest_1samp
import xcesm
# from scipy.stats import pearsonr
from scipy.stats import t

In [3]:
# load private modules

import sys

# Make the private CESM2 helpers importable in this kernel. The membership
# check keeps re-runs of this cell from appending duplicate sys.path entries.
CESM2_MODULE_PATH = '/proj/kimyy/Dropbox/source/python/all/Modules/CESM2'
if CESM2_MODULE_PATH not in sys.path:
    sys.path.append(CESM2_MODULE_PATH)
from KYY_CESM2_preprocessing import CESM2_config

# Directory where this notebook writes its NetCDF outputs.
savefilepath = "/mnt/lustre/proj/kimyy/tmp_python/HCST_skills_autocorr"


In [4]:
# change variables by command+F, for S-ST, T-REFHT, T-WS, P-SL, P-RECT, G-PP, S-SH, p-hotoC_TOT_zint_100m, F-AREA_BURNED (not for N-O3). 

# Analysis configuration: variable selection and the 1960-2020 period.
cfg_var_SST = CESM2_config()
cfg_var_SST.year_s, cfg_var_SST.year_e = 1960, 2020
cfg_var_SST.setvar('SST')

# Time window handed to the preprocessing functions: Feb 1 of the first year
# through Jan 1 of the year after the last one, on a no-leap calendar.
start_date, end_date = (
    cftime.DatetimeNoLeap(cfg_var_SST.year_s, 2, 1),
    cftime.DatetimeNoLeap(cfg_var_SST.year_e + 1, 1, 1),
)

# POP gx1v7 grid; its TLAT and TAREA are used when reading observations.
ds_grid = pop_tools.get_grid('POP_gx1v7')


In [5]:

def _extra_var_names(ds, except_coord_vars, varname):
    """List coords then data vars of ``ds`` not in ``except_coord_vars`` and not ``varname``.

    Builds a fresh exclusion set on every call, so the caller's (or the default)
    ``except_coord_vars`` is never mutated.
    """
    keep = set(except_coord_vars)
    keep.add(varname)
    return [name for name in (*ds.coords, *ds.data_vars) if name not in keep]


def _shift_and_annual_mean(ds):
    """Shift ``ds.time`` back by 15 days, then average over each calendar year.

    NOTE(review): the 15-day shift presumably moves monthly time stamps that sit
    at the end of each averaging period back into the month they represent, so
    that ``groupby('time.year')`` assigns each month to the right year - confirm.
    """
    import datetime
    import numpy as np

    offsets = np.array([datetime.timedelta(days=15)] * len(ds.time))
    ds = ds.assign_coords(time=ds.time - offsets)
    return ds.groupby('time.year').mean(dim='time', skipna=True)


def _round_horizontal_coords(ds, comp, decimals=4):
    """Round the horizontal coordinates of ``ds`` (assigned into ``ds``) and return it.

    ``lat``/``lon`` are rounded when ``comp`` is "atm" or "lnd"; ``TLAT``/``TLONG``
    when it is "ocn" or "ice"; any other ``comp`` leaves ``ds`` untouched.
    """
    if comp in ("atm", "lnd"):
        names = ("lat", "lon")
    elif comp in ("ocn", "ice"):
        names = ("TLAT", "TLONG")
    else:
        names = ()
    for name in names:
        ds[name] = ds[name].round(decimals)
    return ds


def process_coords_2d(
    ds, sd, ed, varname, comp, drop=True,
    except_coord_vars=("time", "lon", "lat", "TLONG", "TLAT"),
):
    """Preprocess one model file: keep ``varname`` and reduce it to annual means.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset opened from a single file (used as an ``open_mfdataset`` preprocess).
    sd, ed :
        Bounds for ``ds.sel(time=slice(sd, ed))``, applied before the time shift.
    varname : str
        Variable to keep; always added to the exclusion list.
    comp : str
        Component name; selects which horizontal coordinates are rounded.
    drop : bool
        True: drop every other variable, slice time, shift by 15 days, average
        per year and round coordinates. False: only promote the other variables
        to coordinates.
    except_coord_vars : iterable of str
        Names that are never dropped/promoted. Not modified by this function
        (previously a shared default list grew on every call).

    Returns
    -------
    xarray.Dataset
    """
    extra_vars = _extra_var_names(ds, except_coord_vars, varname)
    if not drop:
        return ds.set_coords(extra_vars)

    ds = ds.drop_vars(extra_vars)
    ds = ds.sel(time=slice(sd, ed))
    ds = _shift_and_annual_mean(ds)
    return _round_horizontal_coords(ds, comp)


def process_coords_2d_obs(
    ds, sd, ed, varname, comp, drop=True,
    except_coord_vars=("time", "lon", "lat", "TLONG", "TLAT"),
):
    """Preprocess one observation file into annual means.

    Unlike ``process_coords_2d`` no variables are dropped, no time slicing is
    done (``sd``/``ed`` are accepted only for a uniform signature) and there is
    no 15-day shift. A ``T`` time coordinate/dimension is renamed to ``time``.

    Parameters
    ----------
    ds : xarray.Dataset
    sd, ed :
        Unused.
    varname : str
        Only used when ``drop`` is False (excluded from promotion).
    comp : str
        Component name; selects which horizontal coordinates are rounded.
    drop : bool
        True: annual mean + rounding. False: promote every variable other than
        ``except_coord_vars`` and ``varname`` to coordinates (this path used to
        raise NameError because ``coord_vars`` was never defined).
    except_coord_vars : iterable of str
        Names never promoted when ``drop`` is False. Not modified.

    Returns
    -------
    xarray.Dataset
    """
    if not drop:
        return ds.set_coords(_extra_var_names(ds, except_coord_vars, varname))

    if 'T' in ds.coords or 'T' in ds.dims:
        ds = ds.rename({'T': 'time'})
    ds = ds.groupby('time.year').mean(dim='time', skipna=True)
    return _round_horizontal_coords(ds, comp)


def process_coords_2d_hcst(
    ds, sd, ed, varname, comp, drop=True,
    except_coord_vars=("time", "lon", "lat", "TLONG", "TLAT"),
):
    """Preprocess one hindcast file: like ``process_coords_2d`` without time slicing.

    ``sd``/``ed`` are accepted for a uniform signature but not used. See
    ``process_coords_2d`` for the meaning of the other parameters. Regridding
    (previously sketched via an xcesm accessor) is not performed.

    Returns
    -------
    xarray.Dataset
    """
    extra_vars = _extra_var_names(ds, except_coord_vars, varname)
    if not drop:
        return ds.set_coords(extra_vars)

    ds = ds.drop_vars(extra_vars)
    ds = _shift_and_annual_mean(ds)
    return _round_horizontal_coords(ds, comp)

In [6]:
# Read Observation dataset

start_time = time.time()

tmp_comp = cfg_var_SST.comp
cfg_var_SST.OBS_path_load(cfg_var_SST.var)

# Open the observation file and reduce it to annual means while loading.
cfg_var_SST.OBS_ds = xr.open_mfdataset(
    cfg_var_SST.OBS_file_list[0][0],
    chunks={'time': 12},
    parallel=True,
    preprocess=lambda ds: process_coords_2d_obs(ds, start_date, end_date, 'SST', tmp_comp),
    decode_cf=True,
    decode_times=True,
)

if cfg_var_SST.OBS_var == 'sla':
    print('global mean is removed for sea level')

# Rename the observational variable (OBS_var) to the configured name (var).
cfg_var_SST.OBS_ds = cfg_var_SST.OBS_ds.rename({cfg_var_SST.OBS_var: cfg_var_SST.var})

# Sea level (S-SH) only: subtract the TAREA-weighted mean over 60S-60N,
# computed over the nlat/nlon dimensions.
if cfg_var_SST.OBS_var == 'sla':
    lat_mask = (ds_grid.TLAT >= -60) & (ds_grid.TLAT <= 60)
    area_selected = ds_grid.TAREA.where(lat_mask, drop=True)  # cell areas inside the band
    SST_selected = cfg_var_SST.OBS_ds[cfg_var_SST.var].where(lat_mask, drop=True)  # data inside the band
    cfg_var_SST.OBS_ds['gm'] = (
        (SST_selected * area_selected).sum(dim=['nlat', 'nlon'])
        / area_selected.sum(dim=['nlat', 'nlon'])
    )
    cfg_var_SST.OBS_ds['SST'] = cfg_var_SST.OBS_ds[cfg_var_SST.var] - cfg_var_SST.OBS_ds['gm']

# Disabled steps kept for reference (they reference names not defined here):
# start_year = int(cfg_var_SST.OBS_ds.time.dt.year.values[0])
# end_year   = int(cfg_var_SST.OBS_ds.time.dt.year.values[-1])
#
# time_slice = slice(f"{start_year}-01-01", f"{end_year}-12-31")
#
# if cfg_var_SST.var == 'SST':
#     cfg_var_SST.OBS_ds = cfg_var_SST.OBS_ds.assign_coords(
#         time=cfg_var_SST.ODA_ds.sel(time=time_slice).time
#     )
#
# valid_data_count = (~cfg_var.OBS_ds[cfg_var.var].isnull()).sum(dim='time')
# total_time_steps = cfg_var.OBS_ds['time'].size
#
# threshold = 0.8
# mask = (valid_data_count / total_time_steps) >= threshold
#
# cfg_var.OBS_ds['mask_80_percent'] = mask
# cfg_var.OBS_ds['mask_80_percent'].compute()

end_time = time.time()
elapsed_time = end_time - start_time
print(f'elasped time for reading OBS: {elapsed_time}')

elasped time for reading OBS: 9.437170028686523


In [7]:
# get rolling mean variables, observational period
# 4-year running means. Each window is labelled with the mean of its four year
# values (e.g. 1960-1963 -> 1961.5); windows with fewer than four valid values are NaN.
obs_rolling_time_mean = cfg_var_SST.OBS_ds['year'].rolling(year=4, min_periods=4).mean()
cfg_var_SST.OBS_ds_4yr = (
    cfg_var_SST.OBS_ds
    .rolling(year=4, min_periods=4)
    .mean()
    .assign_coords(year=obs_rolling_time_mean)
)

# Keep only complete windows (those whose year label is not NaN).
valid_index = np.where(~np.isnan(cfg_var_SST.OBS_ds_4yr['year']))[0]
cfg_var_SST.OBS_ds_4yr = cfg_var_SST.OBS_ds_4yr.isel(year=valid_index)

# When the first window is 1961.5, drop it and keep positions 1..56 if the last
# window is 2017.5, otherwise positions 1..57.
if cfg_var_SST.OBS_ds_4yr.isel(year=0).year == 1961.5:
    cfg_var_SST.OBS_ds_4yr = cfg_var_SST.OBS_ds_4yr.isel(
        year=range(1, 57 if cfg_var_SST.OBS_ds_4yr.isel(year=-1).year == 2017.5 else 58)
    )

In [8]:

# Lag-1 (year-to-next-year) autocorrelation of the observations, for annual
# means and for 4-year running means.

def lag1_autocorr(da, mask_zeros=False, dim='year'):
    """Pearson correlation of ``da`` with itself shifted one step ahead along ``dim``.

    Only pairs where both the value and its successor are non-missing contribute.

    Parameters
    ----------
    da : xarray.DataArray
    mask_zeros : bool
        If True, values within 1e-8 of zero (the tolerance ``np.isclose(x, 0)``
        uses) are treated as missing first.
    dim : str
        Dimension along which to shift and correlate.

    Returns
    -------
    xarray.DataArray
        Correlation with ``dim`` reduced out.
    """
    if mask_zeros:
        # Equivalent to masking np.isclose(da, 0), but stays lazy for dask arrays.
        da = da.where(abs(da) > 1e-8)
    successor = da.shift({dim: -1})
    both_valid = da.notnull() & successor.notnull()
    return xr.corr(da.where(both_valid), successor.where(both_valid), dim=dim)


# Zero masking is applied when the observational variable is named 'SST'
# (NOTE(review): presumably zeros are fill values there - confirm).
obs_mask_zeros = cfg_var_SST.OBS_var == 'SST'

# individual (OBS)
autocorr_SST_OBS = lag1_autocorr(
    cfg_var_SST.OBS_ds['SST'].sel(year=slice(1965, 2020)),
    mask_zeros=obs_mask_zeros,
)

# individual 4yr (OBS)
autocorr_SST_OBS_ds_4yr = lag1_autocorr(
    cfg_var_SST.OBS_ds_4yr['SST'].sel(year=slice(1965, 2020)),
    mask_zeros=obs_mask_zeros,
)


In [9]:
# save temporary files (OBS autocorrelation maps)
start_time = time.time()

# to_netcdf does not create missing parent directories.
os.makedirs(savefilepath, exist_ok=True)
autocorr_SST_OBS.to_netcdf(savefilepath + "/autocorr_SST_OBS" + ".nc")
autocorr_SST_OBS_ds_4yr.to_netcdf(savefilepath + "/autocorr_SST_OBS_ds_4yr" + ".nc")

end_time = time.time()
elapsed_time = end_time - start_time
print('elasped time for saving SST corr, '  + ': ' + str(elapsed_time))

elasped time for saving SST corr, : 1.6517376899719238


In [10]:
# Display the directory the autocorrelation files were written to.
savefilepath

'/mnt/lustre/proj/kimyy/tmp_python/HCST_skills_autocorr'