## Imports

In [None]:
import xarray as xr
import pathlib
import numpy as np
import pandas as pd
import matplotlib as mpl
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import os
import xeofs as xe
import time
import src.utils

## root data directory, taken from the DATA_FP environment variable
DATA_FP = pathlib.Path(os.environ["DATA_FP"])

## plot styling: white axes background, grid lines off
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## figure resolution in dots per inch
mpl.rcParams.update({"figure.dpi": 100})

## Functions

In [None]:
def get_files(varname):
    """Return the sorted ``*.nc`` files stored for ``varname``.

    Files are looked up in ``cesm2_lens/Omon/<varname>`` under the
    RDA d651039 collection on GLADE. A missing directory yields ``[]``.
    """
    root = pathlib.Path("/glade/campaign/collections/rda/data/d651039")
    var_dir = root.joinpath("cesm2_lens", "Omon", varname)
    return sorted(var_dir.glob("*.nc"))


def check_member_file(i):
    """Return whether the tos and zos files at member index ``i`` match.

    Two paths are considered matching when their last 71 characters are
    identical.
    NOTE(review): presumably 71 characters covers the member-specific part
    of the filename while excluding the variable-name prefix; confirm.
    """
    tos_path, zos_path = (str(get_files(name)[i]) for name in ("tos", "zos"))
    return tos_path[-71:] == zos_path[-71:]


def check_member_files(n_members=100):
    """Check that the tos and zos file lists line up member-by-member.

    Args:
    - n_members: number of leading ensemble members to compare (default 100)

    Returns True when, for each of the first ``n_members`` files, the tos and
    zos paths share the same last 71 characters (the same rule used by
    ``check_member_file``). Returns False if either variable has fewer than
    ``n_members`` files, instead of raising IndexError.
    """
    ## list each directory once, not twice per member (was 200 globs)
    tos_files = get_files("tos")
    zos_files = get_files("zos")

    ## a missing file on either side means the lists cannot match
    if min(len(tos_files), len(zos_files)) < n_members:
        return False

    return all(
        str(tos)[-71:] == str(zos)[-71:]
        for tos, zos in zip(tos_files[:n_members], zos_files[:n_members])
    )


def load_member_varname(member_idx, varname):
    """Load a single variable for one ensemble member. Args:
    - member_idx: integer in [0,99]; index into the sorted file list
      returned by ``get_files(varname)``
    - varname: variable name; used both as the file directory and as the
      variable extracted from the opened dataset

    Returns the variable with size-1 dimensions squeezed out (and their
    coordinates dropped) and ``lat``/``lon`` renamed to
    ``latitude``/``longitude``.
    """
    member_file = get_files(varname)[member_idx]
    return (
        xr.open_dataset(member_file)[varname]
        .squeeze(drop=True)
        .rename({"lat": "latitude", "lon": "longitude"})
    )


def load_member_Th(member_idx):
    """Compute the T (from ``tos``) and h (from ``zos``) indices for one member.

    The two index sets are produced by ``src.utils.get_RO_T_indices`` and
    ``src.utils.get_RO_h_indices`` and returned merged into one object.
    """
    tos_data = load_member_varname(member_idx, "tos")
    T_part = src.utils.get_RO_T_indices(tos_data)

    zos_data = load_member_varname(member_idx, "zos")
    h_part = src.utils.get_RO_h_indices(zos_data)

    ## combine both index sets into one object
    return xr.merge([T_part, h_part])


def load_ensemble_Th(save_fp):
    """Return T/h indices for all 100 ensemble members, cached at ``save_fp``.

    If ``save_fp`` is an existing file it is simply opened. Otherwise the
    indices are computed member-by-member, concatenated along a new
    ``member`` dimension (0..99), written to ``save_fp``, and returned.
    """
    if save_fp.is_file():
        return xr.open_dataset(save_fp)

    members = pd.Index(np.arange(100), name="member")
    per_member = [load_member_Th(m) for m in tqdm.tqdm(members)]
    Th = xr.concat(per_member, dim=members)

    ## cache for subsequent runs
    Th.to_netcdf(save_fp)
    return Th


def preprocess_Th(Th, save_dir):
    """Compute the ensemble mean of ``Th`` and the anomalies relative to it.

    The mean is taken over the ``member`` dimension. Each result is written
    to ``save_dir`` (as ``Th_emean.nc`` and ``Th_anom.nc``) only when that
    file does not already exist; existing files are left untouched.

    Returns:
        (Th_emean, Th_anom)
    """
    emean = Th.mean("member")
    anom = Th - emean

    for result, fname in ((emean, "Th_emean.nc"), (anom, "Th_anom.nc")):
        out_fp = pathlib.Path(save_dir, fname)
        if not out_fp.is_file():
            result.to_netcdf(out_fp)

    return emean, anom

## $T$, $h$

In [None]:
## directory holding processed CESM output
save_dir = DATA_FP.joinpath("cesm")

## T/h indices for the full ensemble (computed once, then cached in Th.nc)
Th = load_ensemble_Th(save_dir.joinpath("Th.nc"))

## ensemble mean and per-member anomalies relative to it
Th_emean, Th_anom = preprocess_Th(Th, save_dir=save_dir)

## EOFs

In [None]:
def trim_to_eq_pac(data, *, lon_bounds=(100, 300), lat_bounds=(-30, 30)):
    """Trim data to the equatorial Pacific.

    Args:
    - data: object with a ``.sel`` method and ``longitude``/``latitude``
      coordinates (e.g. an xarray DataArray or Dataset)
    - lon_bounds: (start, stop) longitude labels; default (100, 300)
    - lat_bounds: (start, stop) latitude labels; default (-30, 30)

    Returns ``data.sel`` applied to the label slices.

    NOTE(review): label-based ``slice`` selection presumably relies on
    ascending coordinates and longitudes in a 0-360 convention; confirm
    for the input grid.
    """
    ## specify the lon/lat box to keep
    lonlat_idx = dict(longitude=slice(*lon_bounds), latitude=slice(*lat_bounds))

    return data.sel(lonlat_idx)


def load_ensemble(varname, trim_fn=None):
    """Load ``varname`` for all 100 members, stacked along a new ``member`` dim.

    If ``trim_fn`` is given, it is applied to each member's data right
    after loading and before concatenation.
    """

    def _load(i):
        member_data = load_member_varname(i, varname)
        return member_data if trim_fn is None else trim_fn(member_data)

    members = pd.Index(np.arange(100), name="member")
    return xr.concat([_load(i) for i in tqdm.tqdm(members)], dim=members)


def compute_eofs(varname):
    """Load EOFs for ``varname`` from disk, computing and saving them if absent.

    The cache file is ``DATA_FP/cesm/eofs_{varname}.nc``. When it is missing,
    the variable is loaded for all members and trimmed with
    ``trim_to_eq_pac``. A 300-mode ``xeofs`` EOF model is then fit jointly
    over ``time`` and ``member``, with coslat weighting, no standardization
    and no centering, and saved with the netcdf4 engine.

    NOTE(review): the data loaded here is not anomaly-filtered, and
    center=False presumably means no mean is removed before the
    decomposition. Confirm this is intended.
    """
    filename = DATA_FP / pathlib.Path(f"cesm/eofs_{varname}.nc")

    ## reuse pre-computed EOFs when available
    if filename.is_file():
        return src.utils.load_eofs(filename)

    data = load_ensemble(varname, trim_fn=trim_to_eq_pac)

    model = xe.single.EOF(
        n_modes=300, standardize=False, use_coslat=True, center=False
    )
    model.fit(data, dim=["time", "member"])
    model.save(filename, engine="netcdf4")

    return model

In [None]:
## load (or compute and cache) EOFs for each variable;
## keep them in separate names so the TOS result is not overwritten
print("loading TOS EOFs...")
eofs_tos = compute_eofs("tos")
print("\nloading ZOS EOFs...")
eofs_zos = compute_eofs("zos")

## CVDP indices

In [None]:
def get_cvdp_file(member_id):
    """Return the CVDP output path corresponding to ensemble member ``member_id``.

    Two fields are sliced out of the member's ``tos`` filename: characters
    [-36:-32] (the initialization year) and [-31:-28] (the member index).
    They are used to build ``CESM2-LENS_<year>.<idx>.cvdp_data.1850-2100.nc``
    inside ``DATA_FP/cesm/cvdp_output``.
    """
    tos_name = str(get_files("tos")[member_id])
    year_init, idx = tos_name[-36:-32], tos_name[-31:-28]

    out_dir = DATA_FP.joinpath("cesm", "cvdp_output")
    return out_dir / f"CESM2-LENS_{year_init}.{idx}.cvdp_data.1850-2100.nc"


def load_member_cvdp_idxs(member_id):
    """Open one member's CVDP file and keep only the selected index variables.

    The file is opened with ``decode_times=False``, so time values are left
    undecoded.
    """
    ## variables to keep (a list, so xarray returns a Dataset subset)
    keep = [
        "indian_ocean_dipole",
        "nino34",
        "north_pacific_meridional_mode",
        "south_pacific_meridional_mode",
        "tropical_indian_ocean",
        "north_tropical_atlantic",
        "atlantic_nino",
    ]

    cvdp = xr.open_dataset(get_cvdp_file(member_id), decode_times=False)
    return cvdp[keep]


def load_cvdp_idxs():
    """Load CVDP indices for all 100 members, stacked along a new ``member`` dim."""
    members = pd.Index(np.arange(100), name="member")
    per_member = [load_member_cvdp_idxs(m) for m in tqdm.tqdm(members)]
    return xr.concat(per_member, dim=members)

In [None]:
## directory holding processed CESM output
save_dir = DATA_FP.joinpath("cesm")

## CVDP indices for every member
cvdp_total = load_cvdp_idxs()

## anomalies relative to the ensemble mean, written out (overwrites any existing file)
cvdp_anom = cvdp_total - cvdp_total.mean("member")
cvdp_anom.to_netcdf(save_dir.joinpath("cvdp_anom.nc"))