## Imports

In [None]:
import xarray as xr
import pathlib
import numpy as np
import pandas as pd
import matplotlib as mpl
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import os
import xeofs as xe
import time
import src.utils

## specify filepath for data: the root data directory is read from the
## DATA_FP environment variable (raises KeyError if it is not set)
DATA_FP = pathlib.Path(os.environ["DATA_FP"])

## set plotting specs (white axes background, no grid lines)
sns.set(rc={"axes.facecolor": "white", "axes.grid": False})

## set figure DPI (increase this value for presentation-quality figures)
mpl.rcParams["figure.dpi"] = 100

## Functions

In [None]:
def get_files(varname, root=None):
    """Get sorted list of NetCDF files for given variable name.

    Args:
    - varname: name of the variable sub-directory to search (e.g., "tos")
    - root: directory containing one sub-directory per variable. If None
        (default), the CESM2 large-ensemble "Omon" directory on the GLADE
        filesystem is used.

    Returns:
    - list of pathlib.Path objects matching "<root>/<varname>/*.nc",
        sorted by path
    """

    ## fall back to default data location
    if root is None:
        mmlea_fp = pathlib.Path("/glade/campaign/collections/rda/data/d651039")
        root = mmlea_fp / "cesm2_lens/Omon"

    ## sort so that a given index always refers to the same file
    return sorted((pathlib.Path(root) / varname).glob("*.nc"))


def check_member_file(i):
    """Check that the i-th 'tos' and 'zos' files have matching names.

    Only the last 71 characters of each file path are compared.
    NOTE(review): the 71-character window presumably excludes the leading
    variable name from the comparison -- confirm against actual filenames.

    Args:
    - i: index into the sorted file lists returned by `get_files`
    """

    ## trailing portion of the path used for the comparison
    tail = slice(-71, None)

    ## i-th file path for each variable, as strings
    tos_path = str(get_files("tos")[i])
    zos_path = str(get_files("zos")[i])

    return tos_path[tail] == zos_path[tail]


def check_member_files():
    """Check that 'tos' and 'zos' files match for member indices 0 to 99.

    Returns the result of `np.all` over the 100 per-member checks.
    """

    ## run check for every member index
    results = []
    for member_idx in range(100):
        results.append(check_member_file(member_idx))

    return np.all(np.array(results))


def load_member_varname(member_idx, varname):
    """Load one variable for one ensemble member.

    Args:
    - member_idx: integer in [0,99]; index into the sorted list of files
        returned by `get_files(varname)`
    - varname: name of variable to load (also used to locate the files)

    Returns the variable with length-1 dimensions dropped and with the
    "lat"/"lon" coordinates renamed to "latitude"/"longitude".
    """

    ## open the file belonging to this member
    member_file = get_files(varname)[member_idx]
    dataset = xr.open_dataset(member_file)

    ## select the variable and drop length-1 dimensions
    field = dataset[varname].squeeze(drop=True)

    ## use long names for lon/lat
    return field.rename({"lat": "latitude", "lon": "longitude"})


def load_member_Th(member_idx):
    """Compute T and h indices for given member index.

    T indices are computed from the "tos" variable and h indices from the
    "zos" variable; the two results are merged into a single object.
    """

    ## T indices
    tos = load_member_varname(member_idx, "tos")
    T_idxs = src.utils.get_RO_T_indices(tos)

    ## h indices
    zos = load_member_varname(member_idx, "zos")
    h_idxs = src.utils.get_RO_h_indices(zos)

    ## combine T and h indices
    merged = xr.merge([T_idxs, h_idxs])
    return merged


def load_ensemble_Th(save_fp):
    """Load T and h data for all 100 ensemble members.

    If `save_fp` already exists it is opened and returned; otherwise the data
    is computed member-by-member, concatenated along a new "member"
    dimension, written to `save_fp`, and returned.

    Args:
    - save_fp: pathlib.Path of the NetCDF file used to cache the result
    """

    ## re-use cached result if available
    if save_fp.is_file():
        return xr.open_dataset(save_fp)

    ## new dimension: ensemble member
    member_dim = pd.Index(np.arange(100), name="member")

    ## do computation (one member at a time, with progress bar)
    members = []
    for i in tqdm.tqdm(member_dim):
        members.append(load_member_Th(i))
    data = xr.concat(members, dim=member_dim)

    ## cache result for next time
    data.to_netcdf(save_fp)

    return data


def preprocess_Th(Th, save_dir):
    """Pre-process Th data (compute ensemble mean and anomalies).

    Args:
    - Th: data with a "member" dimension
    - save_dir: directory in which "Th_emean.nc" and "Th_anom.nc" are saved

    Returns (Th_emean, Th_anom). Each result is written to file only if its
    file does not exist already; both are always recomputed from `Th`.
    """

    ## ensemble mean, and departures from it
    Th_emean = Th.mean("member")
    Th_anom = Th - Th_emean

    ## save each result to file if not already
    outputs = (("Th_emean.nc", Th_emean), ("Th_anom.nc", Th_anom))
    for filename, result in outputs:
        save_fp = pathlib.Path(save_dir, filename)
        if not save_fp.is_file():
            result.to_netcdf(save_fp)

    return Th_emean, Th_anom

## Preprocessing

In [None]:
## specify directory where CESM results are saved
save_dir = DATA_FP / "cesm"

## load T/h data for the full ensemble (read from "Th.nc" if it exists,
## otherwise computed and saved there)
Th = load_ensemble_Th(save_dir / "Th.nc")

## compute ensemble mean and anomalies (departures from ensemble mean)
Th_emean, Th_anom = preprocess_Th(Th, save_dir=save_dir)

In [None]:
def load_member_spatial(name, n):
    """Load data for given variable and ensemble member.

    The "historical" and "ssp585" simulations are loaded separately and
    concatenated along the "time" dimension.

    Args:
    - name: variable name passed to `load_member`; data for "ts" is returned
        from the "sst" variable
    - n: ensemble member identifier passed to `load_member`
    """

    ## load both simulations (historical first, then ssp585)
    hist, ssp585 = (
        load_member(sim=sim, name=name, n=n) for sim in ("historical", "ssp585")
    )

    ## update varname for sst
    if name == "ts":
        varname = "sst"
    else:
        varname = name

    ## join simulations in time and select the variable
    combined = xr.concat([hist, ssp585], dim="time")
    return combined[varname]


def load_ensemble_spatial(name, n_array=np.arange(1, 51)):
    """Load data for given variable and ensemble members.

    Args:
    - name: variable name passed to `load_member_spatial`
    - n_array: ensemble member identifiers to load (default: 1 to 50)

    Returns data for all requested members, concatenated along a new
    "member" dimension labelled with the values of `n_array`.
    """

    ## new dimension: ensemble member
    member_dim = pd.Index(n_array, name="member")

    ## load members one at a time, calling .compute() on each
    members = []
    for n in tqdm.tqdm(n_array):
        members.append(load_member_spatial(name, n).compute())

    return xr.concat(members, dim=member_dim)


def merge_ensemble_Th(load_dir, save_dir, sim):
    """Merge files from individual ensemble members into one.

    Args:
    - load_dir: pathlib.Path of directory holding the per-member NetCDF
        files; every file matching "*{sim}*.nc" is used
    - save_dir: pathlib.Path of directory where the merged file
        "Th_{sim}.nc" is saved
    - sim: simulation name, used to select input files and name the output

    Returns the merged dataset. If the merged file already exists it is
    opened and returned without re-doing the merge.
    """

    ## get filepath for saving
    save_fp = save_dir / f"Th_{sim}.nc"

    ## re-use merged file if it exists (save_fp is a file, not a directory)
    if save_fp.is_file():
        return xr.open_dataset(save_fp)

    ## Load data, stacking files along new "member" dimension.
    ## NOTE(review): files are ordered by sorted filename -- confirm this
    ## matches member order (e.g., that member numbers are zero-padded)
    Th = xr.open_mfdataset(
        sorted(load_dir.glob(f"*{sim}*.nc")),
        combine="nested",
        concat_dim="member",
    )

    ## assign values to member dimension (expects exactly 50 members)
    Th = Th.assign_coords({"member": pd.Index(np.arange(1, 51))})

    ## remove encoding setting which prevents saving (if present)
    for name in ["sst_trop", "ts_glob", "ssh_trop", "ssh_glob"]:
        Th[name].encoding.pop("_FillValue", None)

    ## save to file
    Th.to_netcdf(save_fp)

    return Th


def preprocess(Th, save_dir=DATA_FP / "mpi_Th"):
    """Preprocess T and h data (compute ensemble mean and anomalies).

    Args:
    - Th: data with a "member" dimension
    - save_dir: directory in which results are saved; the ensemble mean goes
        to "Th_ensemble_mean.nc" and the anomalies to "Th.nc"

    Returns (Th_emean, Th_anom). Each result is written to file only if its
    file does not exist already.
    """

    ## ensemble mean, and departures from it
    Th_emean = Th.mean("member")
    Th_anom = Th - Th_emean

    ## save each result to file if not already
    outputs = (("Th_ensemble_mean.nc", Th_emean), ("Th.nc", Th_anom))
    for filename, result in outputs:
        save_fp = pathlib.Path(save_dir, filename)
        if not save_fp.is_file():
            result.to_netcdf(save_fp)

    return Th_emean, Th_anom

## Compute $T$, $h$ variables

In [None]:
## get directory for saving temporary results
temp_dir = DATA_FP / "temp"

## compute T & h for each ensemble member
## NOTE(review): compute_ensemble_Th is not defined in this part of the
## notebook -- presumably it writes one file per member/simulation to
## temp_dir; confirm
compute_ensemble_Th(save_dir=temp_dir)

## merge per-member files into a single ensemble file for each simulation
kwargs = dict(load_dir=temp_dir, save_dir=DATA_FP / "mpi_Th")
Th_hist = merge_ensemble_Th(sim="historical", **kwargs)
Th_ssp = merge_ensemble_Th(sim="ssp585", **kwargs)

## merge hist. and ssp simulations along the time dimension
Th = xr.concat([Th_hist, Th_ssp], dim="time")

## preprocess (note: Th is re-bound to the anomalies from the ensemble mean)
Th_emean, Th = preprocess(Th)

### Look at output

In [None]:
## load ensemble-mean data into memory
Th_emean.load()

## function to deseason data: subtract each calendar month's mean
## from the values for that month
deseason = lambda x: x.groupby("time.month") - x.groupby("time.month").mean()

## remove seasonal cycle but add back time-mean
Th_emean_deseasoned = deseason(Th_emean) + Th_emean.mean("time")

## set up plot (two panels, stacked vertically)
fig, axs = plt.subplots(2, 1, figsize=(6, 4.5), layout="constrained")

## top panel: ensemble mean of "T_3", with and without the seasonal cycle
axs[0].plot(Th_emean.time, Th_emean["T_3"], lw=2, c="k", label="Ens. mean")
axs[0].plot(
    Th_emean.time,
    Th_emean_deseasoned["T_3"],
    lw=2,
    c="r",
    label="Ens. mean (de-seasoned)",
)

## label top panel
axs[0].legend()
axs[0].set_title("Niño 3 (Hist. & SSP585)")
axs[0].set_ylabel(r"$^{\circ}C$")

## bottom panel: plot 5 ensemble members (member labels 1 to 5)
for i in range(1, 6):
    axs[1].plot(Th.time, Th["T_3"].sel(member=i), lw=1, c="k", alpha=0.5)

axs[1].set_title("Internal variability (5 ensemble members)")
plt.show()

## Data compression with EOFs (**to-do**)

Steps:
1. Compute spatial patterns
2. Compress data
3. Check compression (compare the reconstruction of a randomly chosen ensemble member to the actual data)
4. Remove ensemble mean (external forcing)

### Function to compute/save/load EOFs

In [None]:
def load_eofs(eofs_fp, varname, **eofs_kwargs):
    """Load EOF model for given variable, computing and saving it if needed.

    Args:
    - eofs_fp: directory where EOF models are stored; the model for
        `varname` is kept in "{varname}.nc" inside it
    - varname: name of variable to compute EOFs for
    - eofs_kwargs: keyword arguments passed to `xe.single.EOF`

    Returns the EOF model: loaded from file if the file exists, otherwise
    fit to ensemble members 1 to 50 over the "time" and "member" dimensions
    and then saved to file.
    """

    ## file holding the model for this variable
    model_fp = eofs_fp / f"{varname}.nc"

    ## initialize EOF model
    model = xe.single.EOF(**eofs_kwargs)

    ## re-use pre-computed model if it exists
    if pathlib.Path(model_fp).is_file():
        return model.load(model_fp, engine="netcdf4")

    ## otherwise load spatial data for the ensemble
    ensemble = load_ensemble_spatial(name=varname, n_array=np.arange(1, 51))

    ## fit model, then save to file for next time
    model.fit(ensemble, dim=["time", "member"])
    model.save(model_fp, engine="netcdf4")

    return model

### Do the computation

In [None]:
## specify directory for saving EOF results
eofs_fp = DATA_FP / "mpi" / "eofs300"

## specify EOF specs (keyword arguments passed to the EOF model by load_eofs)
eofs_kwargs = dict(n_modes=300, standardize=False, use_coslat=True, center=False)

## EOFs for SST (stored under variable name "ts"); report wall-clock time
print("Computing EOFs for SST...")
t0 = time.time()
eofs_sst = load_eofs(eofs_fp, "ts", **eofs_kwargs)
tf = time.time()
print(f"Elapsed time: {(tf-t0)/60:.1f} minutes\n")

## EOFs for SSH; report wall-clock time
print("Computing EOFs for SSH...")
t0 = time.time()
eofs_ssh = load_eofs(eofs_fp, "ssh", **eofs_kwargs)
tf = time.time()
print(f"Elapsed time: {(tf-t0)/60:.1f} minutes")