# Consolidate
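Merge the CESM EOF data (surface, subsurface, and surface-velocity components and scores) into consolidated forced/anomaly files for faster loading.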

## Imports

In [1]:
import xarray as xr
import pathlib
import src.utils
import os
import numpy as np

## Functions

In [2]:
def fix_lon_coord(data_sub):
    """Turn the subsurface ``lon`` variable into a longitude coordinate on ``nlon``.

    Longitude values are read from ``lon`` at the first ``z_t`` level and
    attached as the ``nlon`` coordinate. The original ``lon`` variable is then
    dropped, and ``nlon`` is renamed to ``lon``.

    NOTE(review): assumes longitude does not vary with depth, since only
    the ``z_t=0`` slice is used -- TODO confirm.
    """
    surface_lon = data_sub.lon.isel(z_t=0).values
    indexed = data_sub.assign_coords({"nlon": surface_lon})
    return indexed.drop_vars("lon").rename({"nlon": "lon"})


def convert_cm_to_m_helper(data, z_coord_name):
    """Return ``data`` with coordinate ``z_coord_name`` divided by 100 (cm -> m)."""
    depth_in_m = data[z_coord_name].values / 100
    return data.assign_coords({z_coord_name: depth_in_m})


def convert_cm_to_m(data):
    """Convert both vertical coordinates (``z_t`` then ``z_w_top``) from cm to m."""
    converted = data
    for coord_name in ("z_t", "z_w_top"):
        converted = convert_cm_to_m_helper(converted, z_coord_name=coord_name)
    return converted


def load_sub():
    """Load subsurface EOF components and scores for temperature, w, and u.

    Reads ``eofs_temp.nc``, ``eofs_wvel.nc`` and ``eofs_uvel_sub.nc`` from
    ``$DATA_FP/cesm`` via ``src.utils.load_eofs``.

    Returns
    -------
    components_sub : xr.Dataset
        Spatial patterns ``T``, ``w``, ``u``. Longitude is fixed with
        ``fix_lon_coord``. ``w`` is relabelled from ``z_w_top`` onto the
        ``z_t`` coordinate values. Both vertical coordinates are converted
        from cm to m.
    scores_sub : xr.Dataset
        Scores ``T``, ``w``, ``u`` with ``member_id`` reset to 0..99. ``w``
        and ``u`` are scaled from cm/s to m/month (30-day month).
    """
    eofs_fp = pathlib.Path(os.environ["DATA_FP"], "cesm")

    # short name used in the merged datasets -> file suffix on disk
    file_suffix = {"T": "temp", "w": "wvel", "u": "uvel_sub"}
    eofs_sub = {}
    for short_name, suffix in file_suffix.items():
        eofs_sub[short_name] = src.utils.load_eofs(
            pathlib.Path(eofs_fp, f"eofs_{suffix}.nc")
        )

    # gather all spatial patterns into a single dataset
    patterns = xr.merge(
        [model.components().rename(short_name) for short_name, model in eofs_sub.items()]
    )
    patterns = fix_lon_coord(patterns)

    # w lives on z_w_top; relabel it position-wise with the z_t depths so
    # every variable shares one vertical axis
    w_on_t_grid = (
        patterns["w"]
        .rename({"z_w_top": "z_t"})
        .assign_coords({"z_t": patterns.z_t})
    )
    patterns = patterns.drop_vars("w")
    patterns["w"] = w_on_t_grid

    # give every variable the same member labels so the merge lines up
    member_labels = {"member_id": np.arange(100)}
    scores_sub = xr.merge(
        [
            model.scores().assign_coords(member_labels).rename(short_name)
            for short_name, model in eofs_sub.items()
        ]
    )

    patterns = convert_cm_to_m(patterns)

    # velocity scores: cm/s -> m/month
    m_per_cm = 1 / 100
    s_per_day = 86400
    s_per_month = s_per_day * 30
    for vel_name in ("w", "u"):
        scores_sub[vel_name] = scores_sub[vel_name] * m_per_cm * s_per_month

    return patterns, scores_sub


def load_surf():
    """Load surface (or near-surface) EOF components and scores.

    Reads ``eofs_{tos,zos,tauu,tauv,nhf,mlotst,pr}.nc`` from ``$DATA_FP/cesm``
    via ``src.utils.load_eofs``. The variables are renamed to ``sst``,
    ``ssh``, ``taux``, ``tauy``, ``nhf``, ``mld`` and ``pr``.

    Returns
    -------
    components : xr.Dataset
        Spatial patterns of every variable.
    scores : xr.Dataset
        Scores of every variable, with the ``member`` coordinate reset to
        0..99. ``ssh`` is multiplied by 100 (m -> cm), ``taux`` is sign-flipped
        (stress on atmosphere -> stress on ocean), and ``mld`` is divided by
        100 (cm -> m).
    """
    eofs_fp = pathlib.Path(os.environ["DATA_FP"], "cesm")

    # file suffixes on disk and the names used in the merged datasets
    names = ["tos", "zos", "tauu", "tauv", "nhf", "mlotst", "pr"]
    newnames = ["sst", "ssh", "taux", "tauy", "nhf", "mld", "pr"]

    eofs = {
        new: src.utils.load_eofs(pathlib.Path(eofs_fp, f"eofs_{old}.nc"))
        for new, old in zip(newnames, names)
    }

    # gather all spatial patterns into a single dataset
    components = xr.merge([model.components().rename(name) for name, model in eofs.items()])

    # reset member labels so they all match (NHF is labeled differently)
    member_coord = dict(member=np.arange(100))
    scores = xr.merge(
        [
            model.scores().assign_coords(member_coord).rename(name)
            for name, model in eofs.items()
        ]
    )

    # Unit/sign conversions are done by reassignment, not by mutating
    # ``.values`` in place. An in-place edit can alter buffers shared with
    # the EOF model objects, and it is silently lost when ``.values``
    # returns a temporary copy (lazy / dask-backed data).
    scores["ssh"] = scores["ssh"] * 100  # m -> cm
    scores["taux"] = scores["taux"] * -1  # stress on atm -> stress on ocn
    scores["mld"] = scores["mld"] / 100  # cm -> m

    return components, scores


def load_vel():
    """Load surface ocean velocity (u, v) EOF components and scores.

    Opens ``eofs_uvel.nc`` and ``eofs_vvel.nc`` from ``$DATA_FP/cesm`` and
    uses their ``components`` and ``scores`` variables.

    Returns
    -------
    vel_comps : xr.Dataset
        ``uvel`` and ``vvel`` spatial patterns, with the ``variable`` and
        ``z_t`` coordinates dropped.
    vel_scores : xr.Dataset
        ``uvel`` and ``vvel`` scores, with ``member_id`` renamed to
        ``member`` and relabelled 0..99.
    """
    eofs_fp = pathlib.Path(os.environ["DATA_FP"], "cesm")

    uvel_eofs = xr.open_dataset(eofs_fp / "eofs_uvel.nc")
    vvel_eofs = xr.open_dataset(eofs_fp / "eofs_vvel.nc")

    def _merge_uv(u_part, v_part):
        # name each part, drop the "variable"/"z_t" coords, combine into one Dataset
        labelled = []
        for part, label in ((u_part, "uvel"), (v_part, "vvel")):
            labelled.append(part.rename(label).drop_vars(["variable", "z_t"]))
        return xr.merge(labelled)

    vel_comps = _merge_uv(uvel_eofs.components, vvel_eofs.components)
    vel_scores = (
        _merge_uv(uvel_eofs.scores, vvel_eofs.scores)
        .rename({"member_id": "member"})
        .assign_coords(dict(member=np.arange(100)))
    )

    return vel_comps, vel_scores


def regrid_pop(pop_data, target_grid):
    """Regrid POP data onto the coordinates of ``target_grid``.

    ``lon`` is renamed to ``longitude``, and ``lat`` to ``latitude`` when
    present. The result is then interpolated onto ``target_grid``'s
    coordinates with ``interp_like``.

    Parameters
    ----------
    pop_data : xr.Dataset or xr.DataArray
        Data on the POP grid. Must contain ``lon``; ``lat`` is optional.
    target_grid : xr.Dataset or xr.DataArray
        Object whose coordinates define the output grid.

    Returns
    -------
    Same type as ``pop_data``, interpolated onto ``target_grid``.
    """
    # Look the name up among variables/coords explicitly. ``"lat" in obj`` is a
    # coordinate check only for a Dataset; on a DataArray it tests the array
    # *values*, which would silently skip the lat rename.
    if isinstance(pop_data, xr.Dataset):
        known_names = pop_data.variables
    else:
        known_names = pop_data.coords

    if "lat" in known_names:
        rename_dict = dict(lat="latitude", lon="longitude")
    else:
        rename_dict = dict(lon="longitude")

    pop_data = pop_data.rename(rename_dict)

    return pop_data.interp_like(target_grid)


def load_and_merge():
    """Load subsurface, surface, and velocity EOF data and merge each kind.

    The velocity and subsurface components are regridded onto the surface
    data's grid with ``regrid_pop``. Velocity uses latitude and longitude;
    subsurface uses longitude only. The subsurface scores' ``member_id`` is
    renamed to ``member`` so all scores share one member dimension.

    Returns
    -------
    (components_all, scores_all) : tuple of xr.Dataset
        Merged components and merged scores. Nothing is written to disk.
    """
    sub_comps, sub_scores = load_sub()
    surf_comps, surf_scores = load_surf()
    vel_comps, vel_scores = load_vel()

    # bring POP-grid fields onto the surface-data grid
    latlon_grid = surf_comps[["latitude", "longitude"]]
    vel_comps = regrid_pop(vel_comps, latlon_grid)
    sub_comps = regrid_pop(sub_comps, surf_comps["longitude"])

    sub_scores = sub_scores.rename({"member_id": "member"})

    merged_comps = xr.merge([surf_comps, vel_comps, sub_comps])
    merged_scores = xr.merge([surf_scores, vel_scores, sub_scores])
    return merged_comps, merged_scores


def compute_forced_anom(save_dir):
    """Return forced and anomaly datasets, computing and caching them if needed.

    If both ``forced.nc`` and ``anom.nc`` exist in ``save_dir``, they are
    opened and returned. Otherwise the EOF data is loaded with
    ``load_and_merge``. The scores are then split with
    ``src.utils.separate_forced`` and each part is merged with the components,
    which are renamed to ``<name>_comp``. Attributes are dropped, and both
    datasets are written to ``save_dir``.

    Parameters
    ----------
    save_dir : pathlib.Path or str
        Directory holding (or to receive) ``forced.nc`` and ``anom.nc``.
        Created if it does not exist.

    Returns
    -------
    (forced, anom) : tuple of xr.Dataset
    """
    save_dir = pathlib.Path(save_dir)
    forced_fp = save_dir / "forced.nc"
    anom_fp = save_dir / "anom.nc"

    # Check existence explicitly instead of catching every exception. A broad
    # except also hides corrupt files, permission errors, and interrupts, and
    # silently triggers a full recompute. It could also leave forced.nc open
    # while trying to overwrite it.
    if forced_fp.exists() and anom_fp.exists():
        return xr.open_dataset(forced_fp), xr.open_dataset(anom_fp)

    components, scores = load_and_merge()

    forced, anom = src.utils.separate_forced(scores)

    # suffix component names to avoid clashing with score variable names
    rename_dict = {name: f"{name}_comp" for name in components}
    renamed_components = components.rename(rename_dict)

    # drop attrs to avoid issues when saving
    forced = xr.merge([forced, renamed_components]).drop_attrs()
    anom = xr.merge([anom, renamed_components]).drop_attrs()

    save_dir.mkdir(parents=True, exist_ok=True)
    forced.to_netcdf(forced_fp)
    anom.to_netcdf(anom_fp)

    return forced, anom

## Compute

In [3]:
# consolidated cache lives under $DATA_FP/cesm/consolidated
SAVE_DIR = pathlib.Path(os.environ["DATA_FP"]) / "cesm" / "consolidated"
forced, anom = compute_forced_anom(save_dir=SAVE_DIR)