## Imports

In [None]:
import xarray as xr
import pathlib
import numpy as np
import pandas as pd
import matplotlib as mpl
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import os
import xesmf
import time
import src.utils
import copy

## specify filepath for data
## (read from the DATA_FP environment variable; raises KeyError if it is unset)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])

## set plotting specs (white background, no grid lines)
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## bump up DPI for presentation
mpl.rcParams["figure.dpi"] = 100

## Shared functions

In [None]:
def trim(
    data, lon_range=(130, 290), lat_range=(-5, 5), lon_name="TLONG", lat_name="TLAT"
):
    """Select the part of ``data`` that falls inside a longitude/latitude box.

    The selection is done in index space: every ``nlon`` column and every
    ``nlat`` row that contains at least one grid cell inside the box is kept.
    The result is therefore the smallest index-aligned rectangle covering the
    box; on a curvilinear grid some retained cells may lie outside the box.

    Args:
        data: object with ``nlat``/``nlon`` dimensions and lon/lat coordinates
            (``lon_name``/``lat_name``) that together span both dimensions.
        lon_range: inclusive ``(min, max)`` longitude bounds.
        lat_range: inclusive ``(min, max)`` latitude bounds.
        lon_name: name of the longitude coordinate in ``data``.
        lat_name: name of the latitude coordinate in ``data``.

    Returns:
        ``data.isel(nlon=..., nlat=...)`` restricted to the retained rows/columns.
    """

    def isin_range(x, x_range):
        # inclusive on both ends
        return (x_range[0] <= x) & (x <= x_range[1])

    ## mask of grid cells inside the lon/lat box
    in_lonlat_range = isin_range(data[lon_name], lon_range) & isin_range(
        data[lat_name], lat_range
    )

    ## load mask to memory before using it as an indexer
    in_lonlat_range.load()

    ## retain every column/row with at least one cell inside the box
    x_idx = in_lonlat_range.any("nlat")
    y_idx = in_lonlat_range.any("nlon")

    return data.isel(nlon=x_idx, nlat=y_idx)

## 3D ocean data

In [None]:
def get_ensemble_ids():
    """Return the sorted, de-duplicated list of CESM2-LE ensemble member IDs.

    The IDs are harvested from the filenames in the WVEL monthly time-series
    directory. The variable choice is arbitrary; only the member IDs matter.
    Each ID is the fixed-position slice ``[-53:-28]`` of a file's full path.
    """
    wvel_dir = (
        pathlib.Path(
            "/glade/campaign/collections/rda/data/d651056/CESM2-LE/ocn/proc/tseries/month_1"
        )
        / "WVEL"
    )

    # NOTE(review): the slice is position-based, so it assumes every filename
    # follows the same fixed-width naming pattern -- confirm if new files appear.
    member_ids = {str(nc_file)[-53:-28] for nc_file in wvel_dir.glob("*.nc")}

    return sorted(member_ids)


def load_grid(lon_range, lat_range):
    """Build an ocean mask on the OISST grid, read from the PSL THREDDS server.

    Args:
        lon_range: ``(start, stop)`` passed to a label slice on ``lon``.
        lat_range: ``(start, stop)`` passed to a label slice on ``lat``.

    Returns:
        DataArray built from the first time step of the OISST long-term-mean
        file. Cells where SST is missing stay NaN and all other cells are 1.0.
        It is cropped to the given ranges and carries a boolean ``mask``
        coordinate (True where the value is not NaN) for regridding.
    """
    oisst_url = r"http://psl.noaa.gov/thredds/dodsC/Datasets/noaa.oisst.v2/new/sst.oisst.mon.ltm.1991-2020.nc"

    sst = (
        xr.open_dataset(oisst_url, decode_times=False)["sst"]
        .isel(time=0)
        .drop_vars("time")
    )

    ## keep NaN where SST is missing; every other cell becomes 1.0
    ocean = sst.where(np.isnan(sst), other=1.0)

    ## NOTE(review): label slicing returns an empty selection if ``lat`` is
    ## stored in descending order and ``lat_range`` is ascending -- confirm
    ## the latitude ordering of this file.
    ocean = ocean.sel(lon=slice(*lon_range), lat=slice(*lat_range))

    ## boolean mask coordinate used by the regridder
    ocean["mask"] = ~np.isnan(ocean)

    return ocean


def load_var(varname, ensemble_id):
    """Load one CESM2-LE ocean variable for one ensemble member.

    The monthly time-series files whose names contain ``ensemble_id`` are
    opened lazily (dask chunks) and trimmed to the equatorial Pacific
    (lon 130-290, lat -5.3 to 5.3). They are then converted from cm-based units:

    * ``HMXL*`` variables: values divided by 100 (cm -> m).
    * all other variables: keep the first 13 vertical levels and convert the
      vertical coordinate from cm to m.
    * ``*VEL*`` variables: additionally converted from cm/s to m/month
      (using a 30-day month).

    Finally ``TLONG``/``TLAT`` are renamed ``lon``/``lat`` and the
    ``ULONG``/``ULAT`` coordinates are dropped if present.

    NOTE(review): U-grid variables (e.g. UVEL/VVEL) are trimmed and labelled
    with the T-grid ``TLONG``/``TLAT`` coordinates, so any offset between the
    two grids is ignored.

    Args:
        varname: output variable name, also the name of its sub-directory
            (e.g. "TEMP", "UVEL", "WVEL", "HMXL_DR").
        ensemble_id: member ID as returned by ``get_ensemble_ids``.

    Returns:
        The processed, lazily-evaluated DataArray.

    Raises:
        FileNotFoundError: if no file for ``ensemble_id`` exists under the
            variable's directory (a subclass of the ``OSError`` xarray raises
            for an empty file list).
    """

    ## path to cesm2 lens data for this variable
    cesm2_fp = pathlib.Path(
        "/glade/campaign/collections/rda/data/d651056/CESM2-LE/ocn/proc/tseries/month_1"
    )
    data_fp = cesm2_fp / varname

    ## WVEL uses a different vertical coordinate than the other variables
    z_coord = "z_w_top" if varname == "WVEL" else "z_t"

    ## collect this member's files (sorted so the file order is deterministic)
    files = sorted(data_fp.glob(f"*{ensemble_id}*.nc"))
    if not files:
        raise FileNotFoundError(
            f"no {varname} files for ensemble member {ensemble_id!r} in {data_fp}"
        )

    ## open data for ensemble member
    data = xr.open_mfdataset(
        files,
        decode_timedelta=True,
        chunks={"time": 12, z_coord: 60, "nlat": 384, "nlon": 320},
        parallel=True,
    )[varname]

    ## trim to eq Pac
    data = trim(data, lat_range=[-5.3, 5.3], lon_range=[130, 290])

    if "HMXL" in varname:
        ## mixed-layer depth values: convert from cm to m
        data = data / 100

    else:
        ## keep the first 13 vertical levels
        ## NOTE(review): earlier comment said "top 150 m"; confirm the depth
        ## these 13 levels reach on this grid.
        data = data.isel({z_coord: slice(None, 13)})

        ## convert vertical coord from cm to m
        data = data.assign_coords({z_coord: data[z_coord].values / 100})

    ## update velocity units (cm/s -> m/month, 30-day month)
    if "VEL" in varname:
        m_per_cm = 1e-2
        s_per_mo = 3600 * 24 * 30
        data = data * m_per_cm * s_per_mo

    ## rename coords for regridding; drop U-grid coords when present
    data = data.rename({"TLONG": "lon", "TLAT": "lat"})
    data = data.drop_vars(["ULONG", "ULAT"], errors="ignore")

    return data


def preprocess_ensemble(varname, temp_dir):
    """Load ``varname`` for every ensemble member and cache each to netCDF.

    One file per member is written to ``temp_dir/{varname}_{member_id}.nc``.
    ``temp_dir`` is created if it does not exist. Members whose output file
    already exists are skipped, so an interrupted run can be resumed.

    Each file is first written under a ``.tmp`` name and only renamed to its
    final name after ``to_netcdf`` returns. A failed write therefore never
    leaves a truncated ``.nc`` file that a later run would skip as finished;
    a leftover ``.tmp`` file is simply overwritten on the next attempt.

    Args:
        varname: variable to process (passed to ``load_var``).
        temp_dir: output directory.
    """
    temp_dir = pathlib.Path(temp_dir)
    temp_dir.mkdir(parents=True, exist_ok=True)

    ## NOTE: member idx 60 has previously raised "NetCDF: HDF error" for WVEL;
    ## errors are not caught, so such a failure still stops the loop.
    for member_id in tqdm.tqdm(get_ensemble_ids()):

        ## save filepath
        save_fp = temp_dir / f"{varname}_{member_id}.nc"
        if save_fp.is_file():
            continue

        ## write to a temporary name, then move into place only on success
        tmp_fp = save_fp.with_name(save_fp.name + ".tmp")
        data = load_var(varname=varname, ensemble_id=member_id)
        data.to_netcdf(tmp_fp)
        tmp_fp.replace(save_fp)


def get_dy(dlat_deg):
    """Convert a latitude spacing to a meridional distance in meters.

    Args:
        dlat_deg: latitude spacing in degrees (scalar or numpy array).

    Returns:
        ``R * dlat`` in meters, with earth radius ``R = 6.378e6`` m.
    """

    ## convert from degrees to radians
    ## (previously referenced the undefined name ``dlat``, raising NameError)
    dlat_rad = dlat_deg / 180.0 * np.pi

    ## multiply by radius of earth
    ## (previously 6.378e8 cm, which contradicted the meters documented here)
    R = 6.378e6  # earth radius (meters)
    dlat_meters = R * dlat_rad

    return dlat_meters


def get_dx(lat_deg, dlon_deg):
    """Convert a longitude spacing at a given latitude to a zonal distance in meters.

    Args:
        lat_deg: latitude(s) in degrees.
        dlon_deg: longitude spacing(s) in degrees; combined element-wise with
            ``lat_deg`` (numpy broadcasting rules apply).

    Returns:
        ``R * cos(lat) * dlon`` in meters, with earth radius ``R = 6.378e6`` m.
    """

    def _to_rad(deg):
        return deg / 180.0 * np.pi

    earth_radius_m = 6.378e6
    lat_rad, dlon_rad = _to_rad(lat_deg), _to_rad(dlon_deg)

    return earth_radius_m * np.cos(lat_rad) * dlon_rad

#### Initialize cluster

In [None]:
from dask.distributed import LocalCluster, Client

## start a local dask cluster with 16 workers and attach a client to it
cluster = LocalCluster(n_workers=16)
client = Client(cluster)
## bare expression: displays the client summary as the cell output
client

#### Test: mixed-layer zonal advection term (U dT/dx) for one ensemble member

In [None]:
## sorted ensemble member IDs (harvested from the WVEL filenames)
ids = get_ensemble_ids()

In [None]:
## load UVEL and TEMP for the first ensemble member (WVEL/VVEL disabled)
# W = load_var(varname="WVEL", ensemble_id=ids[0])
U = load_var(varname="UVEL", ensemble_id=ids[0])
# V = load_var(varname="VVEL", ensemble_id=ids[0])
T = load_var(varname="TEMP", ensemble_id=ids[0])

In [None]:
## index selectors for a centred difference along nlon:
## x_c = interior points i, x_p = points i+1, x_m = points i-1
x_c = dict(nlon=slice(1, -1))
x_p = dict(nlon=slice(2, None))
x_m = dict(nlon=slice(None, -2))

lon_c = T.lon.isel(x_c)
lon_p = T.lon.isel(x_p)
lon_m = T.lon.isel(x_m)

## get finite difference for T: T[i+1] - T[i-1]
## (both shifted arrays are relabelled with the centre longitudes so the
## difference carries a single, consistent ``lon`` coordinate)
reset_lon = lambda x: x.assign_coords({"lon": lon_c})
dT = reset_lon(T.isel(x_p)) - reset_lon(T.isel(x_m))

## Get U * dT
## NOTE(review): U comes from UVEL but load_var labels it with T-grid
## coordinates; any offset between the U and T grids is ignored here.
UdT = U.isel(x_c) * dT

## normalize by x distance
## dlon spans i-1 -> i+1, matching the two-step span of dT above;
## dx is in meters and U is in m/month (load_var), so UdT_dx is in
## TEMP units per month
dlon = lon_p.values - lon_m.values
lat = T.lat.isel(x_c).values
dx = get_dx(lat_deg=lat, dlon_deg=dlon)
## dx is a plain numpy array, so it broadcasts positionally against UdT
## -- NOTE(review): assumes UdT's trailing dims are (nlat, nlon)
UdT_dx = UdT / dx

## get mixed layer
## time-mean mixed-layer depth (m after load_var), restricted to the same
## interior nlon points
mld = load_var(varname="HMXL_DR", ensemble_id=ids[0])
mld_clim = mld.mean("time").isel(x_c)

## average over mixed layer (go to 10m below)
## NOTE(review): the code uses mld_clim + 11 while the comment says 10 m;
## confirm which depth offset is intended (z_t is in m after load_var).
UdT_dx_ml = UdT_dx.where(UdT_dx.z_t <= (mld_clim + 11))
UdT_dx_ml_avg = UdT_dx_ml.mean("z_t")

## integrate over latitudes
## NaNs are zero-filled, the field is integrated along nlat (trapezoid rule)
## and divided by (n - 1), i.e. an average over nlat.
## NOTE(review): assumes nlat has unit spacing (no coordinate values);
## zero-filling NaNs pulls the average toward 0 where latitudes are missing.
UdT_dx_ml_avg = UdT_dx_ml_avg.where(~np.isnan(UdT_dx_ml_avg), other=0).integrate(
    "nlat"
) / (len(UdT_dx_ml.nlat) - 1)

## assign coords
## 1-D longitude per nlon column: the 2-D lon averaged over nlat
UdT_dx_ml_avg = UdT_dx_ml_avg.assign_coords(dict(lon=T.lon.isel(x_c).mean("nlat")))

In [None]:
## compute the lazy result and keep it in memory
UdT_dx_ml_avg.load()

In [None]:
## mixed-layer mean without the latitude averaging (exploratory scratch cell)
U_ = UdT_dx_ml.mean("z_t")

## fill NaN values
## NOTE(review): the first commented line below is truncated and would not run
# U_ = U_.where(~np.isnan(U_), other=0).integrate("nlat") / len(U_
# U___ = U_.isel(time=slice(None,600)).compute()

In [None]:
## save the nlat-averaged mixed-layer advection term under $DATA_FP/cesm/
UdT_dx_ml_avg.to_netcdf(DATA_FP / "cesm" / "uadv_test2.nc")

In [None]:
## plot the standard deviation over time as a function of longitude
plt.plot(UdT_dx_ml_avg.lon, UdT_dx_ml_avg.std("time"))
# plt.plot(U__.lon, U___.std("time"))

In [None]:
## interactive help for the DataArray.integrate method used above
help(UdT_dx_ml.mean("z_t").integrate)