## Imports

In [None]:
import xarray as xr
import pathlib
import numpy as np
import pandas as pd
import matplotlib as mpl
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import os
import xesmf
import time
import src.utils
import copy

## specify filepath for data
## (root directory is read from the DATA_FP environment variable; a KeyError
## is raised here if that variable is not set)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])

## set plotting specs
## (white axes background, grid lines turned off)
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## bump up DPI for presentation
mpl.rcParams["figure.dpi"] = 100

## Shared functions

In [None]:
def trim(
    data, lon_range=[130, 290], lat_range=[-5, 5], lon_name="TLONG", lat_name="TLAT"
):
    """Cut ``data`` down to the index box that spans a lon/lat range.

    A cell counts as "inside" when its longitude and latitude both fall
    within the (inclusive) ranges. Every ``nlon`` column and every ``nlat``
    row that holds at least one inside cell is kept, so the result is a
    rectangular block in index space and may still include cells that lie
    outside the requested range.

    Args:
        data: Object indexable by ``lon_name``/``lat_name`` with ``nlat`` and
            ``nlon`` dimensions.
        lon_range: ``[min, max]`` longitude bounds (inclusive).
        lat_range: ``[min, max]`` latitude bounds (inclusive).
        lon_name: Name of the longitude coordinate on ``data``.
        lat_name: Name of the latitude coordinate on ``data``.

    Returns:
        ``data`` indexed along ``nlon`` and ``nlat`` with the boolean masks
        described above.
    """

    def _within(values, bounds):
        """Elementwise test for bounds[0] <= values <= bounds[1]."""
        return (bounds[0] <= values) & (values <= bounds[1])

    ## cells that satisfy both the longitude and the latitude constraint
    lon_ok = _within(data[lon_name], lon_range)
    lat_ok = _within(data[lat_name], lat_range)
    inside = lon_ok & lat_ok

    ## the mask is used as an indexer below, so pull it into memory first
    inside.load()

    ## columns / rows containing at least one cell inside the range
    keep_along_nlon = inside.any("nlat")
    keep_along_nlat = inside.any("nlon")

    return data.isel(nlon=keep_along_nlon, nlat=keep_along_nlat)

## Subsurface ocean data

In [None]:
def get_ensemble_ids():
    """Return the sorted, unique ensemble-member IDs found in the CESM2-LE archive.

    The IDs are parsed from the paths of the ``*.nc`` files in the ``WVEL``
    directory. The choice of variable is arbitrary: the files are never
    opened, only their names are used.

    Returns:
        list[str]: Unique ID substrings in ascending order (empty if no file
        matches).
    """

    ## path to cesm2 lens data
    cesm2_fp = pathlib.Path(
        "/glade/campaign/collections/rda/data/d651056/CESM2-LE/ocn/proc/tseries/month_1"
    )

    ## directory whose filenames are scanned (arbitrary, just want the ids)
    data_fp = cesm2_fp / "WVEL"

    ## The ID is a fixed character window counted from the end of each path.
    ## NOTE(review): this assumes every file shares the same trailing layout
    ## (variable name and date stamp of constant length) -- confirm before
    ## pointing this at a different variable directory.
    ensemble_ids = {str(f)[-53:-28] for f in data_fp.glob("*.nc")}

    ## unique values, sorted
    return sorted(ensemble_ids)


def load_var(varname, ensemble_id):
    """Open ``varname`` for one ensemble member and reduce it to a lon-depth section.

    Steps: lazily open all files of the member, trim to latitudes
    [-1.5, 1.5] and longitudes [130, 290], keep the first 27 vertical levels
    and every second ``nlon`` column, then average over ``nlat``.

    Args:
        varname: Variable (and sub-directory) name in the archive.
        ensemble_id: Substring identifying the member's files.

    Returns:
        The ``nlat``-averaged data with an extra ``lon`` coordinate holding
        the ``nlat``-mean of ``TLONG``.
    """

    ## WVEL and WTT use a different vertical coordinate name
    z_coord = "z_w_top" if varname in ["WVEL", "WTT"] else "z_t"

    ## directory holding this variable's files
    archive_root = pathlib.Path(
        "/glade/campaign/collections/rda/data/d651056/CESM2-LE/ocn/proc/tseries/month_1"
    )
    var_dir = archive_root / varname

    ## lazily open every file that belongs to this member
    member_files = var_dir.glob(f"*{ensemble_id}*.nc")
    chunking = {"time": 18, z_coord: 60, "nlat": 384, "nlon": 320}
    dataset = xr.open_mfdataset(
        member_files,
        decode_timedelta=True,
        chunks=chunking,
        parallel=True,
    )
    field = dataset[varname]

    ## trim to eq Pac
    field = trim(field, lat_range=[-1.5, 1.5], lon_range=[130, 290])

    ## first 27 vertical levels (described in the original notes as the top
    ## 300 m) and every second longitude column
    level_and_lon = {z_coord: slice(None, 27), "nlon": slice(None, None, 2)}
    field = field.isel(level_and_lon)

    ## attach a 1-D longitude, then collapse the latitude dimension
    mean_lon = field["TLONG"].mean("nlat")
    field = field.assign_coords({"lon": mean_lon})

    return field.mean("nlat")


def preprocess_ensemble(varname, temp_dir):
    """Preprocess ``varname`` for every ensemble member and save each result.

    For every ID returned by ``get_ensemble_ids`` the output of ``load_var``
    is written to ``<temp_dir>/<varname>_<id>.nc``. Members whose output file
    already exists are skipped, so an interrupted run can simply be restarted.

    Args:
        varname: Name of the variable to process.
        temp_dir: Directory for the per-member netCDF files; created if it
            does not exist yet.

    Returns:
        None
    """

    ## make sure the output directory exists before anything is written
    temp_dir = pathlib.Path(temp_dir)
    temp_dir.mkdir(parents=True, exist_ok=True)

    ## get ensemble ids
    ## NOTE: member index 60 has thrown "NetCDF: HDF error" for WVEL; slice
    ## the list (e.g. ensemble_ids[61:]) to resume past it
    ensemble_ids = get_ensemble_ids()

    ## loop through members
    for i in tqdm.tqdm(ensemble_ids):

        ## save filepath
        save_fp = temp_dir / f"{varname}_{i}.nc"

        ## already written by an earlier run
        if save_fp.is_file():
            continue

        ## Write under a scratch name and rename on success: a failure part
        ## way through must not leave a truncated file at save_fp, because
        ## the is_file() check above would treat it as finished next time.
        partial_fp = save_fp.with_name(save_fp.name + ".partial")
        data = load_var(varname=varname, ensemble_id=i)
        data.to_netcdf(partial_fp)
        partial_fp.replace(save_fp)

    return

#### Initialize cluster

In [None]:
from dask.distributed import LocalCluster, Client

## start a local dask cluster with 24 workers and connect a client to it
cluster = LocalCluster(n_workers=24)
client = Client(cluster)
## bare expression: last value of the cell, so the notebook shows the client
client

In [None]:
## preprocess each variable for every ensemble member (all calls are currently commented out)
# preprocess_ensemble(
# varname="TEMP", temp_dir=pathlib.Path(DATA_FP, "cesm", "temp_temp_v2")
# )

# preprocess_ensemble(
#     varname="WVEL", temp_dir=pathlib.Path(DATA_FP, "cesm", "wvel_temp_v2")
# )

# preprocess_ensemble(
#     varname="WTT", temp_dir=pathlib.Path(DATA_FP, "cesm", "wtt_temp")
# )

# preprocess_ensemble(
#     varname="UET", temp_dir=pathlib.Path(DATA_FP, "cesm", "uet_temp")
# )

# preprocess_ensemble(
#     varname="UVEL", temp_dir=pathlib.Path(DATA_FP, "cesm", "uvel_sub_temp")
# )

# preprocess_ensemble(
#     varname="VNT", temp_dir=pathlib.Path(DATA_FP, "cesm", "vnt_temp")
# )

# preprocess_ensemble(
#     varname="ADV_3D_TEMP", temp_dir=pathlib.Path(DATA_FP, "cesm", "adv_temp")
# )

# preprocess_ensemble(
#     varname="TEND_TEMP", temp_dir=pathlib.Path(DATA_FP, "cesm", "ddt_T_temp")
# )

### One of the WVEL files (ensemble index 60) can't be opened; the cause is unknown

In [None]:
# ids = get_ensemble_ids()
# temp_dir = "wvel_temp"
# save_fp = pathlib.Path(DATA_FP, "cesm", "wvel_temp", f"WVEL_{ids[60]}.nc")

# ## load data
# d = load_var("WVEL", ids[60])

## Test off-by-one error with SST

In [None]:
def load_var_test(varname, ensemble_id):
    """Return a box-mean time series of ``varname`` for one ensemble member.

    The member's files are opened lazily, trimmed to latitudes [-5, 5] and
    longitudes [210, 270], reduced to the first ``z_t`` level and averaged
    over ``nlat`` and ``nlon``.

    Args:
        varname: Variable (and sub-directory) name in the archive.
        ensemble_id: Substring identifying the member's files.

    Returns:
        The averaged data at ``z_t`` index 0.
    """

    ## files of this member inside the variable's directory
    archive_root = pathlib.Path(
        "/glade/campaign/collections/rda/data/d651056/CESM2-LE/ocn/proc/tseries/month_1"
    )
    member_files = (archive_root / varname).glob(f"*{ensemble_id}*.nc")

    ## open lazily (explicit chunking intentionally left disabled)
    dataset = xr.open_mfdataset(
        member_files,
        decode_timedelta=True,
        # chunks={"time": 360, "z_t": 60, "nlat": 384, "nlon": 320},
        parallel=True,
    )
    field = dataset[varname]

    ## trim to the averaging box
    boxed = trim(field, lat_range=[-5, 5], lon_range=[210, 270])

    ## first vertical level, averaged over both horizontal dimensions
    first_level = boxed.isel(z_t=0)
    return first_level.mean(["nlat", "nlon"])

In [None]:
## box-mean TEMP series for the first ensemble member
T = load_var_test("TEMP", get_ensemble_ids()[0])
## pull the values into memory (trailing ';' keeps the cell output quiet)
T.load();

In [None]:
## freshly computed series without its final time step
T_ = T.isel(time=slice(None,-1))

## "T_3" for member 0 from the saved Th.nc file, at the same time stamps
Th = xr.open_dataset(pathlib.Path(os.environ["DATA_FP"], "cesm", "Th.nc"))
T2 = Th["T_3"].sel(member=0, time=T_.time).compute()

## same selection from the saved Th_anom.nc file (note: `Th` is rebound here)
Th = xr.open_dataset(pathlib.Path(os.environ["DATA_FP"], "cesm", "Th_anom.nc"))
T3 = Th["T_3"].sel(member=0, time=T_.time).compute()

In [None]:
## subtract each calendar month's mean from both series (monthly anomalies)
T_prime = T_.groupby("time.month") - T_.groupby("time.month").mean()
T2_prime = T2.groupby("time.month") - T2.groupby("time.month").mean()

In [None]:
from scipy.stats import pearsonr as r

## Pearson correlation of the raw series: aligned, then T_[1:] against
## T2[:-1], then T_[:-1] against T2[1:] (one-step offsets in each direction)
print(r(T_.values, T2.values)[0])
print(r(T_.values[1:], T2.values[:-1])[0])
print(r(T_.values[:-1], T2.values[1:])[0])
print()
## same three comparisons for the monthly anomalies
print(r(T_prime.values, T2_prime.values)[0])
print(r(T_prime.values[1:], T2_prime.values[:-1])[0])
print(r(T_prime.values[:-1], T2_prime.values[1:])[0])

In [None]:
## helper: drop the first time step of a series
shift = lambda x : x.isel(time=slice(1,None))

## time window to plot
t = dict(time=slice("1860","1865"))
fig,ax = plt.subplots(figsize=(7,3))
## solid lines: T and T2 over the window; dashed line: T with its first
## step dropped, so it is drawn one position earlier on the x-axis
ax.plot(T.sel(t))
ax.plot(T2.sel(t))
ax.plot(shift(T.sel(t)), ls="--")
plt.show()
# ax.plot(T3.sel(t))

In [None]:
## same box-mean series for UVEL, first ensemble member
U = load_var_test("UVEL", get_ensemble_ids()[0])
## pull the values into memory (trailing ';' keeps the cell output quiet)
U.load();

## Surface data ($u$, $v$)

In [None]:
def get_ensemble_ids():
    """Return the sorted, unique ensemble-member IDs found in the CESM2-LE archive.

    The IDs are parsed from the paths of the ``*.nc`` files in the ``UVEL``
    directory. The choice of variable is arbitrary: the files are never
    opened, only their names are used.

    Returns:
        list[str]: Unique ID substrings in ascending order (empty if no file
        matches).
    """

    ## path to cesm2 lens data
    cesm2_fp = pathlib.Path(
        "/glade/campaign/collections/rda/data/d651056/CESM2-LE/ocn/proc/tseries/month_1"
    )

    ## directory whose filenames are scanned (arbitrary, just want the ids)
    data_fp = cesm2_fp / "UVEL"

    ## The ID is a fixed character window counted from the end of each path.
    ## NOTE(review): this assumes every file shares the same trailing layout
    ## (variable name and date stamp of constant length) -- confirm before
    ## pointing this at a different variable directory.
    ensemble_ids = {str(f)[-53:-28] for f in data_fp.glob("*.nc")}

    ## unique values, sorted
    return sorted(ensemble_ids)


def load_var(varname, ensemble_id):
    """Open the first ``z_t`` level of ``varname`` for one member, trimmed to the tropics.

    The member's files are opened lazily, reduced to ``z_t`` index 0 and
    trimmed to latitudes [-15, 15] and longitudes [120, 300] using the
    ``ULONG``/``ULAT`` coordinates.

    Args:
        varname: Variable (and sub-directory) name in the archive.
        ensemble_id: Substring identifying the member's files.

    Returns:
        The trimmed top-level data.
    """

    ## files of this member inside the variable's directory
    archive_root = pathlib.Path(
        "/glade/campaign/collections/rda/data/d651056/CESM2-LE/ocn/proc/tseries/month_1"
    )
    member_files = (archive_root / varname).glob(f"*{ensemble_id}*.nc")

    ## open lazily
    dataset = xr.open_mfdataset(
        member_files,
        decode_timedelta=True,
        chunks={"time": 120, "nlat": 384, "nlon": 320},
        parallel=True,
    )

    ## get top layer
    top_layer = dataset[varname].isel(z_t=0)

    ## trim to eq Pac, locating cells with the ULONG/ULAT coordinates
    region = dict(
        lat_range=[-15, 15],
        lon_range=[120, 300],
        lon_name="ULONG",
        lat_name="ULAT",
    )
    return trim(top_layer, **region)


def preprocess_ensemble(varname, temp_dir):
    """Preprocess ``varname`` for every ensemble member and save each result.

    For every ID returned by ``get_ensemble_ids`` the output of ``load_var``
    is written to ``<temp_dir>/<varname>_<id>.nc``. Members whose output file
    already exists are skipped, so an interrupted run can simply be restarted.

    Args:
        varname: Name of the variable to process.
        temp_dir: Directory for the per-member netCDF files; created if it
            does not exist yet.

    Returns:
        None
    """

    ## make sure the output directory exists before anything is written
    temp_dir = pathlib.Path(temp_dir)
    temp_dir.mkdir(parents=True, exist_ok=True)

    ## get ensemble ids
    ensemble_ids = get_ensemble_ids()

    ## loop through members
    for i in tqdm.tqdm(ensemble_ids):

        ## save filepath
        save_fp = temp_dir / f"{varname}_{i}.nc"

        ## already written by an earlier run
        if save_fp.is_file():
            continue

        ## Write under a scratch name and rename on success: a failure part
        ## way through must not leave a truncated file at save_fp, because
        ## the is_file() check above would treat it as finished next time.
        partial_fp = save_fp.with_name(save_fp.name + ".partial")
        data = load_var(varname=varname, ensemble_id=i)
        data.to_netcdf(partial_fp)
        partial_fp.replace(save_fp)

    return

### Initialize cluster

In [None]:
from dask.distributed import LocalCluster, Client

## start a local dask cluster with 12 workers and connect a client to it
cluster = LocalCluster(n_workers=12)
client = Client(cluster)
## bare expression: last value of the cell, so the notebook shows the client
client

### Compute

In [None]:
## write per-member surface files for both horizontal velocity components;
## members that already have an output file are skipped
preprocess_ensemble(varname="UVEL", temp_dir=pathlib.Path(DATA_FP, "cesm", "uvel_temp"))

preprocess_ensemble(varname="VVEL", temp_dir=pathlib.Path(DATA_FP, "cesm", "vvel_temp"))

## 3D ocean data

In [None]:
def get_ensemble_ids():
    """Return the sorted, unique ensemble-member IDs found in the CESM2-LE archive.

    The IDs are parsed from the paths of the ``*.nc`` files in the ``WVEL``
    directory. The choice of variable is arbitrary: the files are never
    opened, only their names are used.

    Returns:
        list[str]: Unique ID substrings in ascending order (empty if no file
        matches).
    """

    ## path to cesm2 lens data
    cesm2_fp = pathlib.Path(
        "/glade/campaign/collections/rda/data/d651056/CESM2-LE/ocn/proc/tseries/month_1"
    )

    ## directory whose filenames are scanned (arbitrary, just want the ids)
    data_fp = cesm2_fp / "WVEL"

    ## The ID is a fixed character window counted from the end of each path.
    ## NOTE(review): this assumes every file shares the same trailing layout
    ## (variable name and date stamp of constant length) -- confirm before
    ## pointing this at a different variable directory.
    ensemble_ids = {str(f)[-53:-28] for f in data_fp.glob("*.nc")}

    ## unique values, sorted
    return sorted(ensemble_ids)

def load_grid(lon_range, lat_range):
    """Build a lon/lat target grid with a validity mask from the OISST climatology.

    The first time step of the remote ``sst`` field is turned into an array
    that is 1.0 wherever the source has data and NaN elsewhere, cut to the
    requested ranges, and given a boolean ``mask`` coordinate (True where the
    value is not NaN).

    Args:
        lon_range: ``(start, stop)`` passed to a ``lon`` label slice.
        lat_range: ``(start, stop)`` passed to a ``lat`` label slice.

    Returns:
        The masked array described above, carrying the ``mask`` coordinate.
    """

    ## load sst data (opened remotely; time decoding is switched off)
    oisst_url = r"http://psl.noaa.gov/thredds/dodsC/Datasets/noaa.oisst.v2/new/sst.oisst.mon.ltm.1991-2020.nc"
    remote = xr.open_dataset(oisst_url, decode_times=False)
    first_step = remote["sst"].isel(time=0).drop_vars("time")

    ## keep NaN where the source is NaN, put 1.0 everywhere else
    ## (the original notes describe this as "ones over ocean")
    is_missing = np.isnan(first_step)
    lsm = first_step.where(is_missing, other=1.0)

    ## sel lon/lat range
    lon_window = slice(*lon_range)
    lat_window = slice(*lat_range)
    lsm = lsm.sel(lon=lon_window, lat=lat_window)

    ## add binary mask for regridding
    lsm["mask"] = ~np.isnan(lsm)

    return lsm


def load_var(varname, ensemble_id):
    """Open ``varname`` for one ensemble member and regrid it onto the OISST grid.

    Steps: lazily open all files of the member, trim to latitudes
    [-6.5, 6.5] and longitudes [130, 290], keep the first 27 vertical levels
    and every second ``nlon`` column, then bilinearly regrid onto the grid
    returned by ``load_grid`` (also thinned to every second longitude).

    Args:
        varname: Variable (and sub-directory) name in the archive.
        ensemble_id: Substring identifying the member's files.

    Returns:
        The regridded data with the ``mask`` variable dropped.
    """

    ## get path to data
    cesm2_fp = pathlib.Path(
        "/glade/campaign/collections/rda/data/d651056/CESM2-LE/ocn/proc/tseries/month_1"
    )

    ## path to data
    data_fp = cesm2_fp / varname

    ## get z-coordinate name; WVEL and WTT both use "z_w_top", matching the
    ## rule applied by the subsurface loader in this notebook
    z_coord = "z_w_top" if varname in ("WVEL", "WTT") else "z_t"

    ## open data for ensemble member
    data = xr.open_mfdataset(
        data_fp.glob(f"*{ensemble_id}*.nc"),
        decode_timedelta=True,
        chunks={"time": 12, z_coord: 60, "nlat": 384, "nlon": 320},
        parallel=True,
    )[varname]

    ## trim to eq Pac (half a degree wider in latitude than the target grid)
    data = trim(data, lat_range=[-6.5, 6.5], lon_range=[130, 290])

    ## subset longitude and keep the first 27 vertical levels (described in
    ## the original notes as the top 300 m)
    data = data.isel({z_coord: slice(None, 27), "nlon": slice(None, None, 2)})

    ## rename coords for regridding
    ## NOTE(review): the surface-data loader in this notebook locates UVEL
    ## and VVEL with ULONG/ULAT, whereas TLONG/TLAT are used here for every
    ## variable -- confirm this is intended for the velocity fields.
    data = data.rename({"TLONG": "lon", "TLAT": "lat"})

    ## get target grid
    grid = load_grid(lon_range=[130, 290], lat_range=[-6, 6])

    ## downsample in longitude (consistent with data)
    grid = grid.isel(lon=slice(None, None, 2))

    ## get regridding object
    regridder = xesmf.Regridder(data, grid, "bilinear")

    ## do the regridding
    data_regrid = regridder(data).drop_vars("mask")

    return data_regrid


def preprocess_ensemble(varname, temp_dir):
    """Preprocess ``varname`` for every ensemble member and save each result.

    For every ID returned by ``get_ensemble_ids`` the output of ``load_var``
    is written to ``<temp_dir>/<varname>_<id>.nc``. Members whose output file
    already exists are skipped, so an interrupted run can simply be restarted.

    Args:
        varname: Name of the variable to process.
        temp_dir: Directory for the per-member netCDF files; created if it
            does not exist yet.

    Returns:
        None
    """

    ## make sure the output directory exists before anything is written
    temp_dir = pathlib.Path(temp_dir)
    temp_dir.mkdir(parents=True, exist_ok=True)

    ## get ensemble ids
    ## NOTE: member index 60 has thrown "NetCDF: HDF error" for WVEL; slice
    ## the list (e.g. ensemble_ids[61:]) to resume past it
    ensemble_ids = get_ensemble_ids()

    ## loop through members
    for i in tqdm.tqdm(ensemble_ids):

        ## save filepath
        save_fp = temp_dir / f"{varname}_{i}.nc"

        ## already written by an earlier run
        if save_fp.is_file():
            continue

        ## Write under a scratch name and rename on success: a failure part
        ## way through must not leave a truncated file at save_fp, because
        ## the is_file() check above would treat it as finished next time.
        partial_fp = save_fp.with_name(save_fp.name + ".partial")
        data = load_var(varname=varname, ensemble_id=i)
        data.to_netcdf(partial_fp)
        partial_fp.replace(save_fp)

    return

#### Initialize cluster

In [None]:
from dask.distributed import LocalCluster, Client

## start a local dask cluster with 16 workers and connect a client to it
cluster = LocalCluster(n_workers=16)
client = Client(cluster)
## bare expression: last value of the cell, so the notebook shows the client
client

### compute

In [None]:
## run the 3D preprocessing one variable at a time; only the WVEL call is
## active, the calls for the other variables are left commented out
# preprocess_ensemble(
#     varname="VVEL", temp_dir=pathlib.Path(DATA_FP, "cesm", "vvel_3d_temp")
# )

# preprocess_ensemble(
#     varname="TEMP", temp_dir=pathlib.Path(DATA_FP, "cesm", "temp_3d_temp")
# )

# preprocess_ensemble(
#     varname="UVEL", temp_dir=pathlib.Path(DATA_FP, "cesm", "uvel_3d_temp")
# )

preprocess_ensemble(
    varname="WVEL", temp_dir=pathlib.Path(DATA_FP, "cesm", "wvel_3d_temp")
)