## Imports

In [None]:
import xarray as xr
import pathlib
import numpy as np
import pandas as pd
import matplotlib as mpl
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import os
import xesmf as xe
import time
import src.utils
import copy

## specify filepath for data (read from the DATA_FP environment variable;
## the lookup raises KeyError if that variable is not set)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])

## set plotting specs: white axes background, grid switched off
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## set the figure DPI used for presentation
mpl.rcParams["figure.dpi"] = 100

## Subsurface ocean data

In [None]:
def get_ensemble_ids():
    """Return the sorted, unique ensemble-member IDs found in the WVEL folder.

    Every ``*.nc`` file in the CESM2-LE ``WVEL`` directory is listed and a
    25-character substring is cut from its path at a fixed offset from the
    end (characters -53 to -28). Duplicates are removed and the remaining
    IDs are returned in sorted order.

    NOTE(review): the fixed offsets presumably rely on the CESM2-LE filename
    layout (including the length of the variable name) -- confirm before
    pointing this at a different folder.

    Returns:
        list of ID strings, sorted and without duplicates (empty if no
        files are found).
    """

    ## root of the cesm2 lens monthly ocean output
    lens_root = pathlib.Path(
        "/glade/campaign/collections/rda/data/d651056/CESM2-LE/ocn/proc/tseries/month_1"
    )

    ## folder for WVEL (choice of variable is arbitrary, just want the ids)
    var_dir = lens_root / "WVEL"

    ## cut the id out of each path; the set drops repeated ids
    unique_ids = {str(nc_file)[-53:-28] for nc_file in var_dir.glob("*.nc")}

    return sorted(unique_ids)


def load_var(varname, ensemble_id):
    """Load and preprocess one CESM2-LE ocean variable for one ensemble member.

    The member's files are opened lazily, trimmed to the equatorial Pacific,
    thinned to every other longitude, limited to the first 27 vertical
    levels, regridded bilinearly onto a rectilinear lat/lon grid and finally
    averaged over latitude.

    Args:
        varname: variable to load; only "TEMP" and "WVEL" are supported.
        ensemble_id: substring identifying the member in the filenames
            (e.g. an element of the list returned by ``get_ensemble_ids``).

    Returns:
        The regridded, latitude-averaged data, or ``None`` (after printing
        a message) when ``varname`` is not supported.
    """

    ## name of the vertical dimension for each supported variable
    z_coord_names = {"TEMP": "z_t", "WVEL": "z_w_top"}

    if varname not in z_coord_names:
        print("Not a valid variable")
        return None

    z_coord = z_coord_names[varname]

    ## get path to data
    cesm2_fp = pathlib.Path(
        "/glade/campaign/collections/rda/data/d651056/CESM2-LE/ocn/proc/tseries/month_1"
    )
    data_fp = cesm2_fp / varname

    ## open data for ensemble member (sorted so the file order is reproducible;
    ## glob itself returns files in arbitrary order)
    data = xr.open_mfdataset(
        sorted(data_fp.glob(f"*{ensemble_id}*.nc")),
        decode_timedelta=True,
        chunks={"time": 12, z_coord: 60, "nlat": 384, "nlon": 320},
        parallel=True,
    )[varname]

    ## trim to eq Pac
    data = trim(data, lat_range=[-1.5, 1.5], lon_range=[130, 290])

    ## subset longitude (every other point) and keep the first 27 levels.
    ## The vertical dimension is indexed through z_coord: hard-coding "z_t"
    ## here fails for WVEL, whose vertical dimension is "z_w_top"
    data = data.isel({z_coord: slice(None, 27), "nlon": slice(None, None, 2)})

    ## get target grid (for regridding): an all-NaN copy of the first time
    ## step, given new 1-D lat/lon dimensions built from the mean TLAT along
    ## nlon and the mean TLONG along nlat
    target_grid = np.nan * copy.deepcopy(data.isel(time=0))
    target_grid = target_grid.expand_dims(
        dict(
            lat=data.TLAT.mean("nlon").values,
            lon=data.TLONG.mean("nlat").values,
        )
    )

    ## Get rid of TLONG/TLAT on target grid (averaging removes nlat/nlon)
    target_grid = target_grid.mean(["nlat", "nlon"])

    ## rename coords for regridding (on original grid)
    data = data.rename({"TLAT": "lat", "TLONG": "lon"})

    ## regrid
    regridder = xe.Regridder(data, target_grid, method="bilinear")
    data_regrid = regridder(data)

    ## average over latitudes
    data_mean = data_regrid.mean("lat")

    return data_mean


def trim(
    data, lon_range=[130, 290], lat_range=[-5, 5], lon_name="TLONG", lat_name="TLAT"
):
    """Crop ``data`` to the rows/columns that touch a lon/lat window.

    A grid cell counts as inside the window when its longitude lies in
    ``lon_range`` and its latitude lies in ``lat_range`` (both bounds
    inclusive). Every ``nlon`` column and every ``nlat`` row that contains
    at least one such cell is kept, so the result is a rectangular index
    subset and may still include cells outside the window.

    Args:
        data: object indexed by ``nlat``/``nlon`` that exposes its longitude
            and latitude under ``lon_name``/``lat_name`` (presumably an
            xarray DataArray or Dataset -- TODO confirm).
        lon_range: [min, max] longitude to keep.
        lat_range: [min, max] latitude to keep.
        lon_name: name of the longitude coordinate in ``data``.
        lat_name: name of the latitude coordinate in ``data``.

    Returns:
        ``data`` subset along ``nlon`` and ``nlat``.
    """

    lons = data[lon_name]
    lats = data[lat_name]

    ## mask of grid cells lying inside both ranges
    lon_ok = (lon_range[0] <= lons) & (lons <= lon_range[1])
    lat_ok = (lat_range[0] <= lats) & (lats <= lat_range[1])
    inside = lon_ok & lat_ok

    ## load to memory
    inside.load()

    ## keep each column/row holding at least one cell inside the window
    keep_cols = inside.any("nlat")
    keep_rows = inside.any("nlon")

    return data.isel(nlon=keep_cols, nlat=keep_rows)


def preprocess_ensemble(varname, temp_dir):
    """Preprocess ``varname`` for every ensemble member and save to ``temp_dir``.

    Each member returned by ``get_ensemble_ids`` is processed with
    ``load_var`` and written to ``<temp_dir>/<varname>_<ensemble_id>.nc``.
    Members whose output file already exists are skipped, so a run can be
    resumed. NOTE: only the file's existence is checked -- a file left
    behind by an interrupted write is treated as complete.

    Args:
        varname: name of the variable to process (passed on to ``load_var``).
        temp_dir: directory for the per-member NetCDF files; created
            (including parents) if it does not exist yet.

    Returns:
        None.

    Raises:
        ValueError: if ``load_var`` returns ``None`` for ``varname``
            (i.e. the variable is not supported).
    """

    ## loop through members
    for i in tqdm.tqdm(get_ensemble_ids()):

        ## save filepath
        save_fp = pathlib.Path(temp_dir, f"{varname}_{i}.nc")

        ## skip members that have already been written
        if save_fp.is_file():
            continue

        data = load_var(varname=varname, ensemble_id=i)

        ## load_var returns None for an unsupported variable; fail with a
        ## clear message instead of an AttributeError on None.to_netcdf
        if data is None:
            raise ValueError(f"Could not load variable {varname!r}")

        ## make sure the output directory exists before writing
        save_fp.parent.mkdir(parents=True, exist_ok=True)
        data.to_netcdf(save_fp)

#### Initialize cluster

In [None]:
from dask.distributed import LocalCluster, Client

## start a local dask cluster with 4 workers and connect a client to it
cluster = LocalCluster(n_workers=4)
client = Client(cluster)
## bare expression at the end of the cell (presumably so the notebook
## displays the client summary)
client

In [None]:
## preprocess TEMP for each ensemble member; output goes to DATA_FP/cesm/temp_temp
preprocess_ensemble(varname="TEMP", temp_dir=pathlib.Path(DATA_FP, "cesm", "temp_temp"))