## Regrid CESM2-WACCM TF
CESM2-WACCM output is on a rotated-pole (curvilinear) grid with dimensions `nlat, nlon` rather than 1-D `lat, lon`.  Use rioxarray to regrid it to match the EN4 grid before applying quantile delta mapping (QDM) bias correction.

Note: the depth (`lev`) variable may be in cm rather than m.
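A quick units check can settle this before any resampling. The sketch below assumes that a maximum `lev` value above about 11,000 (deeper than any ocean trench, if read as metres) means the coordinate is in centimetres. The function name and the threshold are illustrative and do not come from the CESM2 metadata.

```python
import numpy as np

# Assumption: no real ocean depth expressed in metres exceeds this value.
MAX_PLAUSIBLE_DEPTH_M = 11_000.0

def depth_in_metres(lev):
    """Return depth levels in metres, converting from cm if values are implausibly large."""
    lev = np.asarray(lev, dtype=float)
    if np.nanmax(lev) > MAX_PLAUSIBLE_DEPTH_M:
        return lev * 0.01  # cm -> m
    return lev

lev_cm = np.array([500.0, 1500.0, 50_000.0, 550_000.0])  # made-up levels in cm
print(depth_in_metres(lev_cm))
```

The heuristic cannot detect centimetres if the file only holds very shallow levels (all below 110 m), so checking the `units` attribute is still worthwhile.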

7 Jul 2025 | EHU
- Test with newly processed TF data, which accounts for depth expressed in cm rather than m in the original CESM2-WACCM data.  Note that the depth variable in the TF dataset is most likely still in cm.
- 11 Jul: Project directly to the EN4 projection, rather than going through an intermediate rectilinear grid derived from the CESM curvilinear grid.
- 15 Jul: Fix `warp.reproject` command to get correct length of time dimension for CESM. Test `reproject_match` command -- still seems necessary.

In [None]:
import os
import sys
import copy
import csv
import time
import datetime
import math
import cartopy.crs as ccrs ## map projections
import pandas as pd
import numpy as np
import xarray as xr
import netCDF4 as nc
import matplotlib.pyplot as plt

import rioxarray
from rioxarray.rioxarray import affine_to_coords
from pyproj import CRS

In [None]:
## from cmethods.utils
import warnings
from typing import TYPE_CHECKING, Optional, Union, TypeVar

XRData_t = (xr.Dataset, xr.DataArray)
NPData_t = (list, np.ndarray, np.generic)
XRData = TypeVar("XRData", xr.Dataset, xr.DataArray)
NPData = TypeVar("NPData", list, np.ndarray, np.generic)
MAX_SCALING_FACTOR = 2 ## to allow multiplicative correction?


def check_xr_types(obs: XRData, simh: XRData, simp: XRData) -> None:
    """
    Checks if the parameters are in the correct type. **only used internally**
    """
    phrase: str = "must be type xarray.Dataset or xarray.DataArray"

    if not isinstance(obs, XRData_t):
        raise TypeError(f"'obs' {phrase}")
    if not isinstance(simh, XRData_t):
        raise TypeError(f"'simh' {phrase}")
    if not isinstance(simp, XRData_t):
        raise TypeError(f"'simp' {phrase}")

def check_np_types(
    obs: NPData,
    simh: NPData,
    simp: NPData,
) -> None:
    """
    Checks if the parameters are in the correct type. **only used internally**
    """
    phrase: str = "must be type list, np.ndarray or np.generic"

    if not isinstance(obs, NPData_t):
        raise TypeError(f"'obs' {phrase}")
    if not isinstance(simh, NPData_t):
        raise TypeError(f"'simh' {phrase}")
    if not isinstance(simp, NPData_t):
        raise TypeError(f"'simp' {phrase}")

def nan_or_equal(value1: float, value2: float) -> bool:
    """
    Returns True if the values are equal or at least one is NaN

    :param value1: First value to check
    :type value1: float
    :param value2: Second value to check
    :type value2: float
    :return: If any value is NaN or values are equal
    :rtype: bool
    """
    return np.isnan(value1) or np.isnan(value2) or value1 == value2
        
def ensure_dividable(
    numerator: Union[float, np.ndarray],
    denominator: Union[float, np.ndarray],
    max_scaling_factor: float,
) -> np.ndarray:
    """
    Ensures that the arrays can be divided. The numerator will be multiplied by
    the maximum scaling factor of the CMethods class if division by zero.

    :param numerator: Numerator to use
    :type numerator: np.ndarray
    :param denominator: Denominator that can be zero
    :type denominator: np.ndarray
    :return: Zero-ensured division
    :rtype: np.ndarray | float
    """
    with np.errstate(divide="ignore", invalid="ignore"):
        result = numerator / denominator

    if isinstance(numerator, np.ndarray):
        mask_inf = np.isinf(result)
        result[mask_inf] = numerator[mask_inf] * max_scaling_factor  # type: ignore[index]

        mask_nan = np.isnan(result)
        result[mask_nan] = 0  # type: ignore[index]
    elif np.isinf(result):
        result = numerator * max_scaling_factor
    elif np.isnan(result):
        result = 0.0

    return result

def get_pdf(
    x: Union[list, np.ndarray],
    xbins: Union[list, np.ndarray],
) -> np.ndarray:
    r"""
    Computes and returns the probability density function :math:`P(x)`
    of ``x`` based on ``xbins``.

    :param x: The vector to get :math:`P(x)` from
    :type x: list | np.ndarray
    :param xbins: The boundaries/bins of :math:`P(x)`
    :type xbins: list | np.ndarray
    :return: The probability density function of ``x``
    :rtype: np.ndarray

    .. code-block:: python
        :linenos:
        :caption: Compute the probability density function :math:`P(x)`

        >>> from cmethods.utils import get_pdf

        >>> x = [1, 2, 3, 4, 5, 5, 5, 6, 7, 8, 9, 10]
        >>> xbins = [0, 3, 6, 10]
        >>> print(get_pdf(x=x, xbins=xbins))
        [2 5 5]
    """
    pdf, _ = np.histogram(x, xbins)
    return pdf


def get_cdf(
    x: Union[list, np.ndarray],
    xbins: Union[list, np.ndarray],
) -> np.ndarray:
    r"""
    Computes and returns the cumulative distribution function :math:`F(x)`
    of ``x`` based on ``xbins``.

    :param x: Vector to get :math:`F(x)` from
    :type x: list | np.ndarray
    :param xbins: The boundaries/bins of :math:`F(x)`
    :type xbins: list | np.ndarray
    :return: The cumulative distribution function of ``x``
    :rtype: np.ndarray


    .. code-block:: python
        :linenos:
        :caption: Compute the cumulative distribution function :math:`F(x)`

        >>> from cmethods.utils import get_cdf

        >>> x = [1, 2, 3, 4, 5, 5, 5, 6, 7, 8, 9, 10]
        >>> xbins = [0, 3, 6, 10]
        >>> print(get_cdf(x=x, xbins=xbins))
        [0.         0.16666667 0.58333333 1.        ]
    """
    pdf, _ = np.histogram(x, xbins)
    cdf = np.insert(np.cumsum(pdf), 0, 0.0)
    return cdf / cdf[-1]


def get_inverse_of_cdf(
    base_cdf: Union[list, np.ndarray],
    insert_cdf: Union[list, np.ndarray],
    xbins: Union[list, np.ndarray],
) -> np.ndarray:
    r"""
    Returns the inverse cumulative distribution function as:
    :math:`F^{-1}_{x}\left[y\right]` where :math:`x` represents ``base_cdf`` and
    ``insert_cdf`` is represented by :math:`y`.

    :param base_cdf: The basis
    :type base_cdf: list | np.ndarray
    :param insert_cdf: The CDF that gets inserted
    :type insert_cdf: list | np.ndarray
    :param xbins: Probability boundaries
    :type xbins: list | np.ndarray
    :return: The inverse CDF
    :rtype: np.ndarray
    """
    return np.interp(insert_cdf, base_cdf, xbins)

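As a worked check of the CDF helpers above, the sketch below rebuilds the histogram-based CDF for the docstring example and inverts it by linear interpolation. `empirical_cdf` is a stand-alone copy written for this illustration, not an import from `cmethods`.

```python
import numpy as np

def empirical_cdf(x, xbins):
    """Histogram-based CDF evaluated at the bin edges, starting from 0."""
    counts, _ = np.histogram(x, xbins)
    cdf = np.insert(np.cumsum(counts), 0, 0.0)
    return cdf / cdf[-1]

x = [1, 2, 3, 4, 5, 5, 5, 6, 7, 8, 9, 10]
xbins = np.array([0.0, 3.0, 6.0, 10.0])
cdf = empirical_cdf(x, xbins)  # bin counts are 2, 5, 5 -> 0, 2/12, 7/12, 1

# Inverse CDF: interpolate the bin edges as a function of cumulative probability.
median = np.interp(0.5, cdf, xbins)
print(cdf, median)
```

The median lands at 5.4 rather than 5 because values are assumed to be spread evenly within each bin, which is the resolution limit that `n_quantiles` controls in the mapping.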

In [None]:
def quantile_delta_mapping(
    obs: NPData,
    simh: NPData,
    simp: NPData,
    n_quantiles: int,
    kind: str = "+",
    **kwargs,
    ) -> NPData:
    r"""
    Based on https://python-cmethods.readthedocs.io/en/latest/methods.html#quantile-delta-mapping

    kind: str, default + for additive, can be set to * for multiplicative
    """
    # check_adjust_called(
    #     function_name="quantile_delta_mapping",
    #     adjust_called=kwargs.get("adjust_called"),
    # )
    check_np_types(obs=obs, simh=simh, simp=simp)

    if not isinstance(n_quantiles, int):
        raise TypeError("'n_quantiles' must be type int")

    if kind=='+':
        obs, simh, simp = (
            np.array(obs),
            np.array(simh),
            np.array(simp),
        )  # to achieve higher accuracy
        global_max = kwargs.get("global_max", max(np.nanmax(obs), np.nanmax(simh)))
        global_min = kwargs.get("global_min", min(np.nanmin(obs), np.nanmin(simh)))

        if nan_or_equal(value1=global_max, value2=global_min):
            return simp

        wide = abs(global_max - global_min) / n_quantiles
        xbins = np.arange(global_min, global_max + wide, wide)

        cdf_obs = get_cdf(obs, xbins)
        cdf_simh = get_cdf(simh, xbins)
        cdf_simp = get_cdf(simp, xbins)

        # calculate exact CDF values of $F_{sim,p}[T_{sim,p}(t)]$
        epsilon = np.interp(simp, xbins, cdf_simp)  # Eq. 1.1
        QDM1 = get_inverse_of_cdf(cdf_obs, epsilon, xbins)  # Eq. 1.2
        delta = simp - get_inverse_of_cdf(cdf_simh, epsilon, xbins)  # Eq. 1.3
        return QDM1 + delta  # Eq. 1.4

    if kind=='*':
        obs, simh, simp = np.array(obs), np.array(simh), np.array(simp)
        global_max = kwargs.get("global_max", max(np.nanmax(obs), np.nanmax(simh)))
        global_min = kwargs.get("global_min", 0.0)
        if nan_or_equal(value1=global_max, value2=global_min):
            return simp

        wide = global_max / n_quantiles
        xbins = np.arange(global_min, global_max + wide, wide)

        cdf_obs = get_cdf(obs, xbins)
        cdf_simh = get_cdf(simh, xbins)
        cdf_simp = get_cdf(simp, xbins)

        epsilon = np.interp(simp, xbins, cdf_simp)  # Eq. 1.1
        QDM1 = get_inverse_of_cdf(cdf_obs, epsilon, xbins)  # Eq. 1.2

        delta = ensure_dividable(  # Eq. 2.3
            simp,
            get_inverse_of_cdf(cdf_simh, epsilon, xbins),
            max_scaling_factor=kwargs.get(
                "max_scaling_factor",
                MAX_SCALING_FACTOR,
            ),
        )
        return QDM1 * delta  # Eq. 2.4
    raise NotImplementedError(
        f"{kind=} not available for quantile_delta_mapping. Use '+' or '*' instead.",
    )


def apply_cmfunc(
    method: str,
    obs: XRData,
    simh: XRData,
    simp: XRData,
    **kwargs: dict,
) -> XRData:
    """
    Internal function used to apply the bias correction technique to the
    passed input data.
    """
    ## hard-code the QDM method; other cmethods techniques are not ported here
    if method != 'quantile_delta_mapping':
        raise NotImplementedError('Not implemented for methods other than quantile_delta_mapping')
    
    check_xr_types(obs=obs, simh=simh, simp=simp)
    # if method not in __METHODS_FUNC__:
    #     raise UnknownMethodError(method, __METHODS_FUNC__.keys())

    if kwargs.get("input_core_dims"):
        if not isinstance(kwargs["input_core_dims"], dict):
            raise TypeError("input_core_dims must be an object of type 'dict'")
        if not len(kwargs["input_core_dims"]) == 3 or any(
            not isinstance(value, str) for value in kwargs["input_core_dims"].values()
        ):
            raise ValueError(
                'input_core_dims must have three key-value pairs like: {"obs": "time", "simh": "time", "simp": "time"}',
            )

        input_core_dims = kwargs.pop("input_core_dims")
    else:
        input_core_dims = {"obs": "time", "simh": "time", "simp": "time"}

    result: XRData = xr.apply_ufunc(
        quantile_delta_mapping,
        obs,
        simh,
        # Need to spoof a fake time axis since 'time' coord on full dataset is
        # different than 'time' coord on training dataset.
        simp.rename({input_core_dims["simp"]: "__t_simp__"}),
        dask="parallelized",
        vectorize=True,
        # This will vectorize over the time dimension, so will submit each grid
        # cell independently
        input_core_dims=[
            [input_core_dims["obs"]],
            [input_core_dims["simh"]],
            ["__t_simp__"],
        ],
        # Need to denote that the final output dataset will be labeled with the
        # spoofed time coordinate
        output_core_dims=[["__t_simp__"]],
        kwargs=dict(kwargs),
    )

    # Rename to proper coordinate name.
    result = result.rename({"__t_simp__": input_core_dims["simp"]})

    # ufunc will put the core dimension to the end (time), so want to preserve
    # original order where time is commonly first.
    return result.transpose(*obs.rename({input_core_dims["obs"]: input_core_dims["simp"]}).dims)

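A compact way to sanity-check the additive mapping is to feed it identical observed and historical series: the two inverse-CDF terms then cancel and the projection should come back unchanged. The sketch below is a simplified stand-alone version (it uses `np.linspace` for the bin edges instead of `np.arange`, and the name `qdm_additive` is made up for this example), so it illustrates the equations rather than reproducing the cell above line for line.

```python
import numpy as np

def qdm_additive(obs, simh, simp, n_quantiles):
    """Additive quantile delta mapping on 1-D series (simplified sketch)."""
    obs, simh, simp = (np.asarray(a, dtype=float) for a in (obs, simh, simp))
    lo = min(np.nanmin(obs), np.nanmin(simh))
    hi = max(np.nanmax(obs), np.nanmax(simh))
    xbins = np.linspace(lo, hi, n_quantiles + 1)  # assumption: evenly spaced edges

    def cdf(values):
        counts, _ = np.histogram(values, xbins)
        c = np.insert(np.cumsum(counts), 0, 0.0)
        return c / c[-1]

    eps = np.interp(simp, xbins, cdf(simp))          # quantile of each projected value
    mapped = np.interp(eps, cdf(obs), xbins)         # observed value at that quantile
    delta = simp - np.interp(eps, cdf(simh), xbins)  # projected change at that quantile
    return mapped + delta

rng = np.random.default_rng(0)
simh = rng.normal(4.0, 1.0, 500)  # synthetic "historical" series
obs = simh.copy()                 # pretend the model has no bias
simp = simh + 2.0                 # projection: uniform shift
result = qdm_additive(obs, simh, simp, n_quantiles=50)
print(np.abs(result - simp).max())
```

With a real bias the correction is only resolved to about one bin width, and projected values beyond the historical range are clamped to the last bin edge, which is worth keeping in mind when choosing `n_quantiles`.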

In [None]:
## Time utils from Bryan Riel
## pasting stuff from iceutils below.
#-*- coding: utf-8 -*-

def tdec2datestr(tdec_in, returndate=False):
    """
    Convert a decimal year to an ISO date string.
    """
    if isinstance(tdec_in, (list, np.ndarray)):
        tdec_list = copy.deepcopy(tdec_in)
    else:
        tdec_list = [tdec_in]
    current_list = []
    for tdec in tdec_list:
        year = int(tdec)
        yearStart = datetime.datetime(year, 1, 1)
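        # Gregorian leap-year rule: century years are leap only if divisible by 400
        if (year % 4 == 0 and year % 100 != 0) or year % 400 == 0: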
            ndays_in_year = 366.0
        else:
            ndays_in_year = 365.0
        days = (tdec - year) * ndays_in_year
        seconds = (days - int(days)) * 86400
        tdelta = datetime.timedelta(days=int(days), seconds=int(seconds))
        current = yearStart + tdelta
        if not returndate:
            current = current.isoformat(' ').split()[0]
        current_list.append(current)

    if len(current_list) == 1:
        return current_list[0]
    else:
        return np.array(current_list)


def datestr2tdec(yy=0, mm=0, dd=0, hour=0, minute=0, sec=0, microsec=0, dateobj=None):
    """
    Convert year, month, day, hours, minutes, seconds to decimal year.
    """
    if dateobj is not None:
        if type(dateobj) == str:
            yy, mm, dd = [int(val) for val in dateobj.split('-')]
            hour, minute, sec = [0, 0, 0]
        elif type(dateobj) == datetime.datetime:
            attrs = ['year', 'month', 'day', 'hour', 'minute', 'second']
            yy, mm, dd, hour, minute, sec = [getattr(dateobj, attr) for attr in attrs]
        elif type(dateobj) == np.datetime64:
            yy = dateobj.astype('datetime64[Y]').astype(int) + 1970
            mm = dateobj.astype('datetime64[M]').astype(int) % 12 + 1
            days = (
                (dateobj - dateobj.astype('datetime64[M]')) / np.timedelta64(1, 'D')
            )
            dd = int(days) + 1
            hour, minute, sec = [0, 0, 0]
        else:
            raise NotImplementedError('dateobj must be str, datetime, or np.datetime64.')

    # Make datetime object for start of year
    yearStart = datetime.datetime(yy, 1, 1, 0, 0, 0)
    # Make datetime object for input time
    current = datetime.datetime(yy, mm, dd, hour, minute, sec, microsec)
    # Compute number of days elapsed since start of year
    tdelta = current - yearStart
    # Convert to decimal year and account for leap year
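    # Gregorian leap-year rule: century years are leap only if divisible by 400
    if (yy % 4 == 0 and yy % 100 != 0) or yy % 400 == 0: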
        return float(yy) + tdelta.total_seconds() / (366.0 * 86400)
    else:
        return float(yy) + tdelta.total_seconds() / (365.0 * 86400)

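A short round-trip check of the decimal-year convention, written with the standard library only. The two function names are hypothetical stand-ins for the helpers above, and the year length comes from `calendar.isleap`.

```python
import calendar
import datetime

def to_decimal_year(dt):
    """Fraction of the calendar year elapsed at `dt`, added to the year."""
    year_start = datetime.datetime(dt.year, 1, 1)
    days_in_year = 366 if calendar.isleap(dt.year) else 365
    return dt.year + (dt - year_start).total_seconds() / (days_in_year * 86400.0)

def from_decimal_year(tdec):
    """Inverse of to_decimal_year, returning a datetime."""
    year = int(tdec)
    days_in_year = 366 if calendar.isleap(year) else 365
    return datetime.datetime(year, 1, 1) + datetime.timedelta(days=(tdec - year) * days_in_year)

mid_2000 = to_decimal_year(datetime.datetime(2000, 7, 2))      # 183 of 366 days elapsed
mid_1900 = to_decimal_year(datetime.datetime(1900, 7, 2, 12))  # 182.5 of 365 days elapsed
print(mid_2000, mid_1900, from_decimal_year(2000.5))
```

Both midpoints come out as exactly `.5`, which only holds for 1900 if it is treated as a 365-day year.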
### Load in data

In [None]:
DepthRange         = [0,500]
ShallowThreshold   = 100
PeriodObs0         = [1950,2015]
SelModel = 'CESM'

DirSave = f'/Users/eultee/Library/CloudStorage/OneDrive-NASA/Data/gris-iceocean-outfiles/Summer25Test'
DirIn = f'/Users/eultee/Library/CloudStorage/OneDrive-NASA/Data/gris-iceocean-outfiles/Summer25Test'

DirHadley = f'/Users/eultee/Library/CloudStorage/OneDrive-NASA/Data/gris-iceocean-outfiles'
HadleyFile = f'/tf-Hadley-1950_2020.nc'

In [None]:
## Load EN4 using xarray
## decode_times expects a bool; the string 'timeDim' was only being read as truthy
ds1 = xr.open_dataset(DirHadley+HadleyFile, decode_times=True)
ds1
# ds2 = ds.assign_coords({'timeDim': ds.time, 
#                   'latDim': ds.lat, 
#                   'lonDim': ds.lon,
#                   'depthDim': ds.depth})

# tfEN4 = ds2.tfdpavg0to500_bathymin100.rename({'timeDim': 'time',
#                                               'latDim': 'lat',
#                                               'lonDim': 'lon'})

In [None]:
## load in CESM TF for all time slices available, using multifile dataset
with xr.open_mfdataset(f'{DirIn}/tf-CESM2*.nc') as ds: 
    ds3 = ds.load()

ds3

In [None]:
ds_m = ds3.where(ds3.TF<1e20)
ds_m.mean()

Depth resampling seemed to smear the fill value across depth levels: the mean of the depth-resampled dataset was ~5e17 even with fill values masked out afterwards.  Creating the masked dataset `ds_m` *before* resampling addresses this.
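The smearing is easy to reproduce on a single made-up profile with plain `np.interp`. The fill value of 1e36 and the depth levels below are assumptions for illustration only.

```python
import numpy as np

FILL = 1.0e36  # assumption: a typical large fill value, not read from the file
depth = np.array([0.0, 100.0, 200.0, 300.0])  # native levels (m)
tf = np.array([4.0, 3.0, 2.0, FILL])          # deepest level is below the seabed
target = np.array([50.0, 250.0])              # levels to resample onto

# Interpolating first blends the fill value into the neighbouring level.
smeared = np.interp(target, depth, tf)

# Masking first turns the contaminated level into NaN instead.
masked = np.where(tf < 1e20, tf, np.nan)
clean = np.interp(target, depth, masked)

print(smeared, clean)
```

After interpolation the contaminated value at 250 m is about 5e35, so a later mask at 1e20 would still catch it here; with finer target spacing or other weights the blend can fall below the threshold and slip through, which is why masking comes first.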

## Express depth in m, then resample to Hadley depths

In [None]:
## resample CESM to Hadley depths?
# ds1.depth
tf_CESM_inM = ds_m.assign_coords(new_depth = ('lev', ds3.indexes['lev'].values*0.01))
tf_CESM_inM = tf_CESM_inM.drop_indexes('lev')
tf_CESM_inM = tf_CESM_inM.set_xindex('new_depth').drop_vars('lev')
tf_CESM_inM = tf_CESM_inM.rename({'new_depth': 'lev'})
tf_CESM_inM

In [None]:
tfCESM_resampled = tf_CESM_inM.interp(lev=ds1.depth.values[0:30]).rename({'lev': 'depth'})

In [None]:
tfCESM_resampled

We had to do one extra step (above) to deal with re-scaling the native depth dimension from cm to m. Then we re-sampled to depth levels that match EN4. Now we proceed with applying a DateTimeIndex, and reprojecting from the rotated pole spatial grid to a regular grid matching EN4.

### Apply DateTimeIndex

In [None]:
test_ds_full = tfCESM_resampled.TF.sel( 
                    time=slice('1950', '2020'))

## aligning the time indices
test_ds_full = test_ds_full.assign_coords(new_time = ('time', test_ds_full.indexes['time'].to_datetimeindex().values))
test_ds_full = test_ds_full.drop_indexes('time')
test_ds_full = test_ds_full.set_xindex('new_time').drop_vars('time')

## aligning the names of the variables between obs and sim
test_ds_full = test_ds_full.to_dataset()
test_ds_full = test_ds_full.rename({'new_time': 'time'})
test_ds_full

In [None]:
tobs_ds_full = ds1.TF

### Reproject obs to match CESM grid
Because the grids are offset from each other, we will need to warp/resample using `rioxarray` before we run the QDM correction.

Note that the rotated-pole grid of CESM2-WACCM is different from the original test case, so `reproject_match` is unlikely to work directly here.  Try [a different approach](https://gist.github.com/j08lue/e792b3c912c33e9191734af7e795b75c) with rasterio:
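One way to sanity-check whatever the warp produces is an independent brute-force lookup. The sketch below is not what `rasterio.warp.reproject` does internally: it is a plain nearest-neighbour search in lon/lat degrees (ignoring spherical geometry), on a made-up toy grid, with a hypothetical function name.

```python
import numpy as np

def nearest_neighbour_regrid(values, lon2d, lat2d, lon_out, lat_out):
    """Sample a curvilinear field onto a regular lon/lat grid by nearest source cell."""
    src_lon, src_lat, flat = lon2d.ravel(), lat2d.ravel(), values.ravel()
    lon_g, lat_g = np.meshgrid(lon_out, lat_out)
    out = np.empty(lon_g.shape)
    for idx in np.ndindex(lon_g.shape):
        # Squared planar distance in degrees from this target point to every source cell
        d2 = (src_lon - lon_g[idx]) ** 2 + (src_lat - lat_g[idx]) ** 2
        out[idx] = flat[np.argmin(d2)]
    return out

# Toy skewed grid: 2-D lon/lat arrays indexed by (nlat, nlon), like the CESM output
jj, ii = np.meshgrid(np.arange(4), np.arange(5), indexing="ij")
lon2d = -60.0 + ii + 0.3 * jj
lat2d = 60.0 + jj - 0.2 * ii
values = (10 * jj + ii).astype(float)  # encodes the source cell index

out = nearest_neighbour_regrid(values, lon2d, lat2d,
                               lon_out=np.array([-60.0, -59.0]),
                               lat_out=np.array([60.0]))
print(out)
```

The loop is far too slow for the full dataset, but on a small window around a fjord it gives numbers to compare against the warped field.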

In [None]:
import rasterio
import rasterio.warp

In [None]:
## confirm that this version has src_geoloc_array option -- must be 1.4 or greater
rasterio.__version__

11 Jul: try going directly to the EN4 projection, instead of a default rectilinear and then from there to EN4.

In [None]:
tobs_ds_full.sel(depth=500, method='nearest')

In [None]:
lon2d = test_ds_full["lon"].values
lat2d = test_ds_full["lat"].values
src_height, src_width = lon2d.shape

WGS84 = rasterio.crs.CRS.from_epsg(4326)

# dst_transform, dst_width, dst_height = rasterio.warp.calculate_default_transform(
#     src_crs=WGS84,
#     dst_crs=WGS84,
#     width=src_width,
#     height=src_height,
#     src_geoloc_array=(lon2d, lat2d),
# )

lon_EN4 = tobs_ds_full['lon'].values
lat_EN4 = tobs_ds_full['lat'].values

dst_transform, dst_width, dst_height = rasterio.warp.calculate_default_transform(
    src_crs=WGS84,
    dst_crs=WGS84,
    width=len(lon_EN4),
    height=len(lat_EN4),
    src_geoloc_array=(lon2d, lat2d),
)

11 Jul: This reprojection ran without errors, but later cells show that it causes problems: the data include time slices, so `len` returns the length of the time dimension rather than the number of data values at a given time step.

15 Jul: Fixed problem with time dimension -- take the length of `destination` from the CESM time dimension `test_ds_full.time` rather than from EN4.
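The shape bookkeeping behind that fix, sketched with made-up sizes. As I understand it, `rasterio.warp.reproject` treats the leading axis of a 3-D array as bands, so the destination needs the same leading length as the source (time here) and the target grid's height and width on the other two axes.

```python
import numpy as np

# Made-up sizes: 12 time steps on an 6 x 8 source grid, warped to a 4 x 5 target grid
n_time, src_height, src_width = 12, 6, 8
dst_height, dst_width = 4, 5

source = np.random.default_rng(0).normal(size=(n_time, src_height, src_width))

# The leading axis comes from the source (CESM time); only y and x come from the target.
destination = np.full((source.shape[0], dst_height, dst_width), np.nan)
print(source.shape, destination.shape)
```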

In [None]:
len(test_ds_full.sel(depth=500, method='nearest').time)

In [None]:
## make rectilinear
source = test_ds_full['TF'].sel(depth=500, method='nearest').values
# destination = np.full((len(source), dst_height, dst_width), np.nan)
## go straight to EN4
## first value in np.full((t,y,x)) is length of the time dimension
destination = np.full((len(test_ds_full.sel(depth=500, method='nearest').time),
                          dst_height, dst_width), np.nan)

data, transform = rasterio.warp.reproject(
    source,
    destination=destination,
    src_crs=WGS84,
    dst_crs=WGS84,
    dst_transform=dst_transform,
    dst_nodata=np.nan, ## previously had the Verjans fill value here, but now that we've done
    ## a `where` command to mask the dataset before depth resampling, the missing
    ## values have been replaced by NaNs
    src_geoloc_array=np.stack((lon2d, lat2d))
)

In [None]:
# sm = np.ma.masked_where(source>1e+20, source)
np.nanmean(source)

In [None]:
np.nanmean(data)

In [None]:
coords = affine_to_coords(transform, width=dst_width, height=dst_height, x_dim="x", y_dim="y")

In [None]:
coords.update(time=test_ds_full["time"])

In [None]:
filtered_attrs = test_ds_full['TF'].attrs.copy()
filtered_attrs.pop("grid_mapping", None)

In [None]:
da = xr.DataArray(data, coords=coords, dims=("time", "y", "x"), name='TF', attrs=filtered_attrs)


7 Jul: First tested the naive reprojection, writing the CRS with `x_dim="lon"` and `y_dim="lat"` even though the dims are `nlon`, `nlat` (two-dimensional, rotated-pole system).  This didn't work.  Try with the rectilinear array instead.

Note that it looks like we'll have to do this process at each level as a DataArray rather than on the whole Dataset at once.
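The level-by-level pattern can be sketched independently of the warp itself. Below, `coarsen_by_two` is a hypothetical stand-in for the per-level reprojection, and the array layout `(time, depth, y, x)` is an assumption about how the data would be arranged.

```python
import numpy as np

def regrid_by_level(cube, regrid_stack):
    """Apply a (time, y, x) -> (time, y_new, x_new) function to each depth level and restack."""
    levels = [regrid_stack(cube[:, k]) for k in range(cube.shape[1])]
    return np.stack(levels, axis=1)  # back to (time, depth, y_new, x_new)

def coarsen_by_two(stack):
    """Hypothetical stand-in for the warp: keep every second cell in y and x."""
    return stack[:, ::2, ::2]

cube = np.arange(2 * 3 * 4 * 6, dtype=float).reshape(2, 3, 4, 6)
out = regrid_by_level(cube, coarsen_by_two)
print(out.shape)
```

Keeping the per-level function as an argument means the same loop works whether the inner step ends up being `warp.reproject`, `reproject_match`, or both.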

In [None]:
da_masked = da.where(da<1.0e20) ## remove fill values

ax = plt.axes(projection=ccrs.Robinson())
# ax.set_xlim(da_masked.x.max(), da_masked.x.min())
# ax.set_ylim(da_masked.y.min(), da_masked.y.max())
da_masked.sel(time='1990-01-01', method='nearest').plot(ax=ax, transform=ccrs.PlateCarree()) ## specify x and y coordinates
ax.set_extent([da_masked.x.min(), da_masked.x.max(),
               da_masked.y.min(), da_masked.y.max()], crs=ccrs.PlateCarree())
ax.coastlines(); ax.gridlines();

In [None]:
da_masked.x.values

In [None]:
da_masked.mean()

## Assign spatial reference
It seems that the `warp.reproject` sequence above has not brought the CESM test set (`da_masked`) onto the same grid as the EN4 observations: `tobs_ds_full` has very different spatial dimensions, as we can see below.  We may still need a `reproject_match` call with the spatial reference specified.
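Rather than comparing the reprs by eye, a small helper can test whether two sets of 1-D coordinates describe the same grid. The function name, tolerance, and example coordinates below are all illustrative.

```python
import numpy as np

def grids_match(x_a, y_a, x_b, y_b, tol=1e-6):
    """True if both grids have the same number of cells and the same coordinates."""
    x_a, y_a, x_b, y_b = (np.asarray(v, dtype=float) for v in (x_a, y_a, x_b, y_b))
    if x_a.shape != x_b.shape or y_a.shape != y_b.shape:
        return False
    return bool(np.allclose(x_a, x_b, atol=tol) and np.allclose(y_a, y_b, atol=tol))

# Made-up coordinates: a 1-degree grid versus a warped grid with a different spacing
lon_obs, lat_obs = np.arange(-80.0, -10.0, 1.0), np.arange(55.0, 85.0, 1.0)
lon_sim, lat_sim = np.linspace(-80.0, -10.0, 50), np.linspace(55.0, 85.0, 40)

print(grids_match(lon_obs, lat_obs, lon_obs, lat_obs))
print(grids_match(lon_obs, lat_obs, lon_sim, lat_sim))
```

With the notebook's objects this would be called on `tobs_ds_full.lon`, `tobs_ds_full.lat`, `da_masked.x` and `da_masked.y`, assuming those coordinates are one-dimensional.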

In [None]:
tobs_ds_full

In [None]:
da_masked

In [None]:
cc = CRS("EPSG:4326")

tobs_ds_full.rio.write_crs(cc, inplace=True).rio.set_spatial_dims(
    x_dim="lon",
    y_dim="lat",
    inplace=True,
).rio.write_coordinate_system(inplace=True)

da_masked.rio.write_crs(cc, inplace=True).rio.set_spatial_dims(
    x_dim="x",
    y_dim="y",
    inplace=True,
).rio.write_coordinate_system(inplace=True)

In [None]:
## test reproject_match
test_d = 500

# sim_level = da_masked.sel(depth=test_d, method='nearest')
sim_level = da_masked ## already a depth slice
obs_level = tobs_ds_full.sel(depth=test_d, method='nearest')
obs_repr = obs_level.rio.reproject_match(sim_level)
# obs_slices[d] = obs_repr
# # obs_slices[d] = {'dims': ("time", "lat", "lon"), 'data': obs_repr.data}
# sim_slices[d] = sim_level

# obtemp = xr.concat([obs_slices[d] for d in tobs_ds_full.depth.values[0:30]], dim=tobs_ds_full.depth.values[0:30])
# obtemp = obtemp.drop_vars('depth') ## drop a 1D depth variable that carried through
# simtemp = xr.concat([sim_slices[d] for d in tobs_ds_full.depth.values[0:30]], dim=tobs_ds_full.depth.values[0:30])
# simtemp = simtemp.drop_vars('depth')

# tobs_repr_match = obtemp.rename({'concat_dim': 'depth', 'x': 'lon', 'y': 'lat'})
# tsim_match = simtemp.rename({'concat_dim': 'depth'})

In [None]:
obs_repr

Okay, this will work fine on slices.  Clean up the process now so that we apply consistent spatial reference, perform slicing of the dataset only once, etc.