# Step 3: Quantile Delta Mapping with reimplemented cmethods tools
Apply the reimplemented QDM functionality to Greenland thermal forcing. Streamlined for running on CCR.

26 Mar 2025 | EHU

In [None]:
import os
import sys
import copy
import csv
import time
import datetime
import math
import cartopy.crs as ccrs ## map projections
import pandas as pd
import numpy as np
import xarray as xr
import netCDF4 as nc
import matplotlib.pyplot as plt

import rioxarray
from pyproj import CRS

In [None]:
## from cmethods.utils
import warnings
from typing import TYPE_CHECKING, Optional, Union, TypeVar

# Concrete type tuples, used with isinstance() by check_xr_types / check_np_types.
XRData_t = (xr.Dataset, xr.DataArray)
NPData_t = (list, np.ndarray, np.generic)
# Constrained TypeVars with the same members, used only in function annotations.
XRData = TypeVar("XRData", xr.Dataset, xr.DataArray)
NPData = TypeVar("NPData", list, np.ndarray, np.generic)
# Fallback factor that quantile_delta_mapping (kind="*") passes to
# ensure_dividable, which multiplies the numerator by it wherever the
# division would otherwise be infinite.
MAX_SCALING_FACTOR = 2 ## to allow multiplicative correction?


def check_xr_types(obs: XRData, simh: XRData, simp: XRData) -> None:
    """
    Validate that every input is an xarray object. **only used internally**

    The inputs are checked in the order ``obs``, ``simh``, ``simp``.

    :raises TypeError: Naming the first input that is not an instance of
        one of the types listed in ``XRData_t``
    """
    phrase: str = "must be type xarray.core.dataarray.Dataset or xarray.core.dataarray.DataArray"

    for label, candidate in (("obs", obs), ("simh", simh), ("simp", simp)):
        if isinstance(candidate, XRData_t):
            continue
        raise TypeError(f"'{label}' {phrase}")

def check_np_types(
    obs: NPData,
    simh: NPData,
    simp: NPData,
) -> None:
    """
    Validate that every input is list- or numpy-like. **only used internally**

    The inputs are checked in the order ``obs``, ``simh``, ``simp``.

    :raises TypeError: Naming the first input that is not an instance of
        one of the types listed in ``NPData_t``
    """
    phrase: str = "must be type list, np.ndarray or np.generic"

    for label, candidate in (("obs", obs), ("simh", simh), ("simp", simp)):
        if isinstance(candidate, NPData_t):
            continue
        raise TypeError(f"'{label}' {phrase}")

def nan_or_equal(value1: float, value2: float) -> bool:
    """
    Tell whether two scalars are equal or whether either of them is NaN.

    :param value1: First value to check
    :type value1: float
    :param value2: Second value to check
    :type value2: float
    :return: Truthy if ``value1`` is NaN, else if ``value2`` is NaN,
        otherwise the result of ``value1 == value2``
    :rtype: bool
    """
    # Test the operands for NaN one after the other, stopping at the first hit.
    for candidate in (value1, value2):
        is_missing = np.isnan(candidate)
        if is_missing:
            return is_missing
    return value1 == value2
        
def ensure_dividable(
    numerator: Union[float, np.ndarray],
    denominator: Union[float, np.ndarray],
    max_scaling_factor: float,
) -> np.ndarray:
    """
    Divide ``numerator`` by ``denominator`` without producing inf or NaN.

    Wherever the plain quotient is infinite (e.g. non-zero divided by zero)
    the result is ``numerator * max_scaling_factor`` instead; wherever it is
    NaN (e.g. zero divided by zero, or a NaN input) the result is 0.

    Scalars and arrays may be mixed freely; the inputs are broadcast against
    each other and are never modified.

    :param numerator: Numerator to use
    :type numerator: float | np.ndarray
    :param denominator: Denominator that can be zero
    :type denominator: float | np.ndarray
    :param max_scaling_factor: Factor applied to the numerator where the
        quotient would be infinite
    :type max_scaling_factor: float
    :return: Zero-ensured division; a numpy scalar when both inputs are
        scalars, otherwise an array
    :rtype: np.ndarray | float
    """
    # Going through numpy (rather than Python's "/") lets np.errstate govern
    # the division: plain Python floats would raise ZeroDivisionError instead.
    num = np.asarray(numerator)
    den = np.asarray(denominator)

    with np.errstate(divide="ignore", invalid="ignore"):
        quotient = np.true_divide(num, den)
        fallback = num * max_scaling_factor
        result = np.where(np.isinf(quotient), fallback, quotient)

    # NaN handling comes last so that a NaN produced by the fallback itself
    # (e.g. inf * 0) is also mapped to zero.
    result = np.where(np.isnan(result), 0.0, result)

    # Unwrap 0-d results so scalar callers get a scalar back.
    return result[()] if result.ndim == 0 else result

def get_pdf(
    x: Union[list, np.ndarray],
    xbins: Union[list, np.ndarray],
) -> np.ndarray:
    r"""
    Histogram ``x`` over ``xbins`` and return the per-bin counts, used as
    the (un-normalised) probability density function :math:`P(x)`.

    :param x: The vector to get :math:`P(x)` from
    :type x: list | np.ndarray
    :param xbins: The boundaries/bins of :math:`P(x)`
    :type xbins: list | np.ndarray
    :return: Number of values of ``x`` falling into each bin
    :rtype: np.ndarray

    .. code-block:: python
        :linenos:
        :caption: Compute the probability density function :math:`P(x)`

        >>> from cmethods.utils import get_pdf

        >>> x = [1, 2, 3, 4, 5, 5, 5, 6, 7, 8, 9, 10]
        >>> xbins = [0, 3, 6, 10]
        >>> print(get_pdf(x=x, xbins=xbins))
        [2 5 5]
    """
    # np.histogram returns (counts, edges); only the counts are of interest.
    counts = np.histogram(x, bins=xbins)[0]
    return counts


def get_cdf(
    x: Union[list, np.ndarray],
    xbins: Union[list, np.ndarray],
) -> np.ndarray:
    r"""
    Compute the empirical cumulative distribution function :math:`F(x)`
    of ``x`` evaluated at the bin edges ``xbins``.

    The result has one entry per bin edge: it starts at 0 and is the running
    sum of the histogram counts divided by the total count.

    :param x: Vector to get :math:`F(x)` from
    :type x: list | np.ndarray
    :param xbins: The boundaries/bins of :math:`F(x)`
    :type xbins: list | np.ndarray
    :return: The cumulative distribution function of ``x``
    :rtype: np.ndarray


    .. code-block:: python
        :linenos:
        :caption: Compute the cumulative distribution function :math:`F(x)`

        >>> from cmethods.utils import get_cdf

        >>> x = [1, 2, 3, 4, 5, 5, 5, 6, 7, 8, 9, 10]
        >>> xbins = [0, 3, 6, 10]
        >>> print(get_cdf(x=x, xbins=xbins))
        [0.0, 0.16666667, 0.58333333, 1.]
    """
    counts = np.histogram(x, bins=xbins)[0]
    # Prepend a zero so the running total lines up with the bin edges.
    running_total = np.concatenate(([0], np.cumsum(counts)))
    return running_total / running_total[-1]


def get_inverse_of_cdf(
    base_cdf: Union[list, np.ndarray],
    insert_cdf: Union[list, np.ndarray],
    xbins: Union[list, np.ndarray],
) -> np.ndarray:
    r"""
    Evaluate the inverse cumulative distribution function
    :math:`F^{-1}_{x}\left[y\right]`, where :math:`x` represents ``base_cdf``
    and :math:`y` represents ``insert_cdf``.

    Implemented as a linear interpolation of ``xbins`` over ``base_cdf``,
    sampled at the probabilities in ``insert_cdf``.

    :param base_cdf: The basis
    :type base_cdf: list | np.ndarray
    :param insert_cdf: The CDF that gets inserted
    :type insert_cdf: list | np.ndarray
    :param xbins: Probability boundaries
    :type xbins: list | np.ndarray
    :return: The inverse CDF
    :rtype: np.ndarray
    """
    inverted = np.interp(x=insert_cdf, xp=base_cdf, fp=xbins)
    return inverted


In [None]:
def quantile_delta_mapping(
    obs: NPData,
    simh: NPData,
    simp: NPData,
    n_quantiles: int,
    kind: str = "+",
    **kwargs,
    ) -> NPData:
    r"""
    Quantile Delta Mapping bias correction of ``simp``.

    Based on https://python-cmethods.readthedocs.io/en/latest/methods.html#quantile-delta-mapping

    :param obs: Observed/reference series of the control period
    :param simh: Simulated series of the control period
    :param simp: Simulated series to be corrected
    :param n_quantiles: Number of bins used to build the empirical CDFs
    :param kind: ``"+"`` (default) for additive, ``"*"`` for multiplicative
    :param kwargs: Optional overrides:

        * ``global_max`` / ``global_min`` -- range covered by the bins.
          Defaults to the NaN-ignoring max/min over ``obs`` and ``simh``,
          except that ``global_min`` defaults to ``0.0`` when ``kind="*"``.
        * ``max_scaling_factor`` -- only for ``kind="*"``: factor handed to
          ``ensure_dividable`` (default ``MAX_SCALING_FACTOR``). The
          historical misspelling ``max_scaling_scaling`` is still honoured
          when ``max_scaling_factor`` is not given.
    :return: The corrected series; ``simp`` (as an array) is returned
        unchanged when the bin range is NaN or has zero width
    :raises TypeError: If the series are not list/numpy types or
        ``n_quantiles`` is not an int
    :raises NotImplementedError: If ``kind`` is neither ``"+"`` nor ``"*"``
    """
    check_np_types(obs=obs, simh=simh, simp=simp)

    if not isinstance(n_quantiles, int):
        raise TypeError("'n_quantiles' must be type int")

    if kind not in ("+", "*"):
        raise NotImplementedError(
            f"{kind=} not available for quantile_delta_mapping. Use '+' or '*' instead.",
        )
    additive = kind == "+"

    obs, simh, simp = np.array(obs), np.array(simh), np.array(simp)

    global_max = kwargs.get("global_max", max(np.nanmax(obs), np.nanmax(simh)))
    if additive:
        global_min = kwargs.get("global_min", min(np.nanmin(obs), np.nanmin(simh)))
    else:
        global_min = kwargs.get("global_min", 0.0)

    # Nothing to map when the range is undefined (all-NaN input) or empty.
    if nan_or_equal(value1=global_max, value2=global_min):
        return simp

    # The two variants size their bins differently: the additive one spans
    # [global_min, global_max], the multiplicative one uses global_max alone.
    if additive:
        wide = abs(global_max - global_min) / n_quantiles
    else:
        wide = global_max / n_quantiles
    xbins = np.arange(global_min, global_max + wide, wide)

    cdf_obs = get_cdf(obs, xbins)
    cdf_simh = get_cdf(simh, xbins)
    cdf_simp = get_cdf(simp, xbins)

    # calculate exact CDF values of $F_{sim,p}[T_{sim,p}(t)]$
    epsilon = np.interp(simp, xbins, cdf_simp)  # Eq. 1.1
    QDM1 = get_inverse_of_cdf(cdf_obs, epsilon, xbins)  # Eq. 1.2
    simh_at_epsilon = get_inverse_of_cdf(cdf_simh, epsilon, xbins)

    if additive:
        delta = simp - simh_at_epsilon  # Eq. 1.3
        return QDM1 + delta  # Eq. 1.4

    # The key was historically read as "max_scaling_scaling", which silently
    # ignored a caller-supplied "max_scaling_factor"; accept both spellings.
    max_scaling_factor = kwargs.get(
        "max_scaling_factor",
        kwargs.get("max_scaling_scaling", MAX_SCALING_FACTOR),
    )
    delta = ensure_dividable(  # Eq. 2.3
        simp,
        simh_at_epsilon,
        max_scaling_factor=max_scaling_factor,
    )
    return QDM1 * delta  # Eq. 2.4


def apply_cmfunc(
    method: str,
    obs: XRData,
    simh: XRData,
    simp: XRData,
    **kwargs: object,
) -> XRData:
    """
    Internal function used to apply the bias correction technique to the
    passed input data.

    ``quantile_delta_mapping`` is applied through ``xr.apply_ufunc`` along
    the core (time) dimension of each input; all other dimensions are
    vectorized over.

    :param method: Must be ``"quantile_delta_mapping"``; kept for interface
        compatibility, no other method is implemented here
    :param obs: Observed/reference data of the control period
    :param simh: Simulated data of the control period
    :param simp: Simulated data to be corrected
    :param kwargs: ``input_core_dims`` (optional dict with exactly the keys
        ``obs``, ``simh``, ``simp`` naming each input's core dimension;
        defaults to ``"time"`` for all three). Every other keyword
        (``n_quantiles``, ``kind``, ...) is forwarded to
        ``quantile_delta_mapping``.
    :return: Corrected data, labelled with the core coordinate of ``simp``
        and transposed to the dimension order of ``obs``
    :raises NotImplementedError: If ``method`` is not
        ``"quantile_delta_mapping"``
    :raises TypeError: If an input is not an xarray object or
        ``input_core_dims`` is not a dict
    :raises ValueError: If ``input_core_dims`` is malformed
    """
    ## hard-code the QDM method
    if method != 'quantile_delta_mapping':
        # The original raised an undefined exception class here, which
        # surfaced as a NameError.
        raise NotImplementedError('Not implemented for methods other than quantile_delta_mapping')

    check_xr_types(obs=obs, simh=simh, simp=simp)

    # Always take input_core_dims out of kwargs so that it is never
    # forwarded to quantile_delta_mapping, even when it is None/empty.
    input_core_dims = kwargs.pop("input_core_dims", None)
    if input_core_dims:
        if not isinstance(input_core_dims, dict):
            raise TypeError("input_core_dims must be an object of type 'dict'")
        if set(input_core_dims) != {"obs", "simh", "simp"} or any(
            not isinstance(value, str) for value in input_core_dims.values()
        ):
            raise ValueError(
                'input_core_dims must have three key-value pairs like: {"obs": "time", "simh": "time", "simp": "time"}',
            )
    else:
        input_core_dims = {"obs": "time", "simh": "time", "simp": "time"}

    result: XRData = xr.apply_ufunc(
        quantile_delta_mapping,
        obs,
        simh,
        # Need to spoof a fake time axis since 'time' coord on full dataset is
        # different than 'time' coord on training dataset.
        simp.rename({input_core_dims["simp"]: "__t_simp__"}),
        dask="parallelized",
        vectorize=True,
        # This will vectorize over the time dimension, so will submit each grid
        # cell independently
        input_core_dims=[
            [input_core_dims["obs"]],
            [input_core_dims["simh"]],
            ["__t_simp__"],
        ],
        # Need to denote that the final output dataset will be labeled with the
        # spoofed time coordinate
        output_core_dims=[["__t_simp__"]],
        kwargs=dict(kwargs),
    )

    # Rename to proper coordinate name.
    result = result.rename({"__t_simp__": input_core_dims["simp"]})

    # ufunc will put the core dimension to the end (time), so want to preserve
    # original order where time is commonly first.
    return result.transpose(*obs.rename({input_core_dims["obs"]: input_core_dims["simp"]}).dims)


In [None]:
## Time utils from Bryan Riel
## pasting stuff from iceutils below.
#-*- coding: utf-8 -*-

def tdec2datestr(tdec_in, returndate=False):
    """
    Convert a decimal year to an ISO date string.

    :param tdec_in: A single decimal year, or a list / numpy array of them
    :param returndate: If True, return ``datetime.datetime`` objects instead
        of ``'YYYY-MM-DD'`` strings
    :return: A single value when exactly one decimal year was given (also
        for a one-element list/array), otherwise a numpy array of values

    The fraction of the year is scaled by 365 or 366 days according to the
    Gregorian leap-year rule, so century years such as 1900 and 2100 are
    treated as non-leap years. Fractions of a second are truncated.
    """
    import calendar  # local import: not among the file-level imports

    # The input is only read, so it can be iterated directly without copying.
    if isinstance(tdec_in, (list, np.ndarray)):
        tdec_list = tdec_in
    else:
        tdec_list = [tdec_in]

    current_list = []
    for tdec in tdec_list:
        year = int(tdec)
        ndays_in_year = 366.0 if calendar.isleap(year) else 365.0
        days = (tdec - year) * ndays_in_year
        whole_days = int(days)
        seconds = (days - whole_days) * 86400
        current = datetime.datetime(year, 1, 1) + datetime.timedelta(
            days=whole_days, seconds=int(seconds)
        )
        if not returndate:
            # Keep only the date part of 'YYYY-MM-DD HH:MM:SS'
            current = current.isoformat(' ').split()[0]
        current_list.append(current)

    if len(current_list) == 1:
        return current_list[0]
    return np.array(current_list)


def datestr2tdec(yy=0, mm=0, dd=0, hour=0, minute=0, sec=0, microsec=0, dateobj=None):
    """
    Convert year, month, day, hours, minutes, seconds to decimal year.

    :param yy, mm, dd, hour, minute, sec, microsec: Calendar components,
        used when ``dateobj`` is None
    :param dateobj: Optional date given as one object; when present it
        overrides the calendar components (``microsec`` is still taken from
        the keyword argument):

        * ``str`` of the form ``'YYYY-MM-DD'`` (time of day set to 00:00:00)
        * ``datetime.datetime`` (year through second are used)
        * ``np.datetime64`` (time of day set to 00:00:00)
    :return: Decimal year as a float
    :raises NotImplementedError: If ``dateobj`` is of any other type

    The elapsed time since 1 January is divided by the length of that year,
    365 or 366 days according to the Gregorian leap-year rule (so 1900 and
    2100 count as non-leap years).
    """
    import calendar  # local import: not among the file-level imports

    if dateobj is not None:
        if isinstance(dateobj, str):
            yy, mm, dd = [int(val) for val in dateobj.split('-')]
            hour, minute, sec = 0, 0, 0
        elif isinstance(dateobj, datetime.datetime):
            attrs = ['year', 'month', 'day', 'hour', 'minute', 'second']
            yy, mm, dd, hour, minute, sec = [getattr(dateobj, attr) for attr in attrs]
        elif isinstance(dateobj, np.datetime64):
            # Years / months are counted from the 1970-01 epoch.
            yy = int(dateobj.astype('datetime64[Y]').astype(int)) + 1970
            mm = int(dateobj.astype('datetime64[M]').astype(int)) % 12 + 1
            days = (
                (dateobj - dateobj.astype('datetime64[M]')) / np.timedelta64(1, 'D')
            )
            dd = int(days) + 1
            hour, minute, sec = 0, 0, 0
        else:
            raise NotImplementedError('dateobj must be str, datetime, or np.datetime64.')

    # Make datetime object for start of year
    yearStart = datetime.datetime(yy, 1, 1, 0, 0, 0)
    # Make datetime object for input time
    current = datetime.datetime(yy, mm, dd, hour, minute, sec, microsec)
    # Compute time elapsed since start of year
    tdelta = current - yearStart
    # Convert to decimal year and account for leap year
    ndays_in_year = 366.0 if calendar.isleap(yy) else 365.0
    return float(yy) + tdelta.total_seconds() / (ndays_in_year * 86400)

### Load in data

In [None]:
# Run configuration. DepthRange and ShallowThreshold are only used below to
# build the EN4 file name.
DepthRange         = [0,500]
ShallowThreshold   = 100
# NOTE(review): PeriodObs0 is not referenced in the rest of this notebook;
# the obs selection further down hard-codes '1950'/'2015' instead.
PeriodObs0         = [1950,2015]
# Model label, used in the output title and file name.
SelModel = 'CESM'

# Output directory, and the directory / file name of the EN4 input.
# DirEN4 ends with a slash because it is concatenated directly with EN4file.
DirSave = f'/Users/eultee/Library/CloudStorage/OneDrive-NASA/Data/gris-iceocean-outfiles'
DirEN4         = f'/Users/eultee/Library/CloudStorage/OneDrive-NASA/Documents/ISMIP7/Verjans-process/'
EN4file        = f'dpavg_tf_EN4anl_Dp{DepthRange[0]}to{DepthRange[1]}_bathymin{ShallowThreshold}.nc'

In [None]:
## Load EN4 using xarray
# NOTE(review): decode_times is given the string 'timeDim' rather than a
# boolean -- confirm what was intended here.
ds = xr.open_dataset(DirEN4+EN4file, decode_times='timeDim')
# Register the time/lat/lon/depth variables as coordinates named after the
# *Dim dimensions (presumably each is 1-D along its matching dimension --
# TODO confirm).
ds2 = ds.assign_coords({'timeDim': ds.time, 
                  'latDim': ds.lat, 
                  'lonDim': ds.lon,
                  'depthDim': ds.depth})

# Extract the depth-averaged thermal forcing and give its dimensions the
# plain names used downstream. The variable name is hard-coded, so it must
# stay in step with DepthRange / ShallowThreshold.
tfEN4 = ds2.tfdpavg0to500_bathymin100.rename({'timeDim': 'time',
                                              'latDim': 'lat',
                                              'lonDim': 'lon'})

In [None]:
## load in CESM TF for all time slices available, using multifile dataset
# Every file matching tfdpavg*.nc under DirEN4 is opened as one dataset and
# loaded into ds3 inside the with-block.
# Note: this rebinds the name `ds`, previously used for the EN4 dataset.
with xr.open_mfdataset(f'{DirEN4}/tfdpavg*.nc') as ds: 
    ds3 = ds.load()

### Process grids to match
Because the grids are offset by 0.5° from each other, we will need to warp/resample using `rioxarray` before we run the QDM correction.

In [None]:
# Simulated thermal forcing (variable TF of the multi-file dataset).
test_ds_full = ds3.TF

## aligning the time indices
# Build a 'new_time' coordinate from the existing time index via
# to_datetimeindex() (assumes the index is a cftime index -- TODO confirm),
# then swap it in: drop the old index, index on new_time, drop the old
# 'time' variable.
test_ds_full = test_ds_full.assign_coords(new_time = ('time', test_ds_full.indexes['time'].to_datetimeindex().values))
test_ds_full = test_ds_full.drop_indexes('time')
test_ds_full = test_ds_full.set_xindex('new_time').drop_vars('time')

## aligning the names of the variables between obs and sim
# Convert to a Dataset so the data variable can be renamed to the obs name;
# new_time takes over the name 'time' in the same step.
test_ds_full = test_ds_full.to_dataset()
test_ds_full = test_ds_full.rename({'new_time': 'time', 'TF': 'tfdpavg0to500_bathymin100'})
# Bare expression: displays the dataset as the notebook cell output.
test_ds_full

In [None]:
# Restrict the observations to the 1950-2015 window.
# NOTE(review): the slice bounds are strings, while the time values are
# passed to tdec2datestr just below (i.e. treated as decimal years) --
# confirm that this selection picks the intended range.
tobs_ds_full = tfEN4.sel( 
                    time=slice('1950', '2015'))

# Replace the decimal-year time index by a datetime one: build 'new_time'
# from tdec2datestr, drop the old index and the old 'time' variable, then
# rename new_time back to 'time'.
tobs_ds_full = tobs_ds_full.assign_coords(new_time = ('time', pd.to_datetime(tdec2datestr(tobs_ds_full.time.values))))
tobs_ds_full = tobs_ds_full.drop_indexes('time')
tobs_ds_full = tobs_ds_full.set_xindex('new_time').drop_vars('time')
tobs_ds_full = tobs_ds_full.rename({'new_time': 'time'})

### Reproject obs to match CESM grid
Warp the offset grids to match using `rioxarray`.

In [None]:
# Both grids are tagged with the same CRS, EPSG:4326.
cc = CRS("EPSG:4326")

# Record the CRS on the obs object and declare lon/lat as its spatial x/y
# dimensions; inplace=True modifies tobs_ds_full itself, so the return
# value of the chain is not kept.
tobs_ds_full.rio.write_crs(cc, inplace=True).rio.set_spatial_dims(
    x_dim="lon",
    y_dim="lat",
    inplace=True,
).rio.write_coordinate_system(inplace=True)

# Same tagging for the simulated dataset.
test_ds_full.rio.write_crs(cc, inplace=True).rio.set_spatial_dims(
    x_dim="lon",
    y_dim="lat",
    inplace=True,
).rio.write_coordinate_system(inplace=True)

In [None]:
# Reproject the observations to match test_ds_full (the simulated data).
tobs_repr_match = tobs_ds_full.rio.reproject_match(test_ds_full)

## rename the coords
# The reprojected result uses 'x'/'y'; restore the lon/lat names used
# everywhere else in this notebook.
tobs_repr_match = tobs_repr_match.rename({'x': 'lon', 'y': 'lat'})

## Apply QDM correction
We will apply the QDM correction on annual, de-trended series with the monthly residual variability re-applied afterward.

### Separate annual and monthly var

In [None]:
# Annual means of the simulated series, labelled at year start ('YS').
annual_ds = test_ds_full.resample(time='YS').mean()

# Monthly residual: the monthly series minus the annual mean forward-filled
# onto the same month-end ('ME') stamps. Both operands use the 'ME' alias;
# the second one previously used 'M', the deprecated spelling of the same
# month-end frequency.
residual_ds = test_ds_full.resample(time='ME').ffill() - annual_ds.resample(time='ME').ffill()

### Detrend the data to be fit
QDM performs poorly when values in the "projection" series exceed those seen in the "historical".  Detrend the future, historical, and reanalysis series before QDM correction.

In [None]:
def detrend_ser(da, dim, deg=1, var='tfdpavg0to500_bathymin100', return_fit=True):
    """
    Remove a polynomial trend along ``dim`` from an xarray object.

    Based on a Gist by Ryan Abernathey.

    :param da: ``xr.DataArray`` or ``xr.Dataset`` to detrend
    :param dim: Name of the dimension to fit along (e.g. ``'time'``)
    :param deg: Degree of the fitted polynomial (default 1, a linear trend)
    :param var: For a Dataset, the data variable whose fit is used; its
        fitted trend is subtracted from the Dataset as a whole. Ignored for
        a DataArray.
    :param return_fit: If truthy, also return the fitted values so that the
        series can be reconstructed later
    :return: ``(detrended, fit)`` if ``return_fit`` else ``detrended``
    :raises TypeError: If ``da`` is neither a DataArray nor a Dataset
    """
    # polyfit names its output 'polyfit_coefficients' for a DataArray and
    # '<variable>_polyfit_coefficients' for each variable of a Dataset.
    if isinstance(da, xr.DataArray):
        coeff_name = 'polyfit_coefficients'
    elif isinstance(da, xr.Dataset):
        coeff_name = f'{var}_polyfit_coefficients'
    else:
        # Previously this case only printed a message and then failed with a
        # NameError on the undefined fit.
        raise TypeError(
            "Unrecognized input type. Expected xarray DataArray or Dataset, got {}".format(type(da))
        )

    p = da.polyfit(dim=dim, deg=deg)
    fit = xr.polyval(da[dim], p[coeff_name])

    if return_fit:  ## give back the fitted values to use in reconstructing a series
        return da - fit, fit
    return da - fit

# Training period 1950-1980: annual means of the reprojected observations,
# linearly detrended. 'YE' is the current spelling of the deprecated
# year-end alias 'A'.
detrended_obs = detrend_ser(tobs_repr_match.sel(time=slice('1950','1980')).resample(time='YE').mean(),
                            dim='time',
                            deg=1)[0]

# Simulated annual means, detrended separately for the historical
# (1950-1980) and the projection (1980-2014) windows.
detrended_simh = detrend_ser(annual_ds.sel(time=slice('1950', '1980')),
                             dim='time',
                             deg=1)[0]
detrended_simp = detrend_ser(annual_ds.sel(time=slice('1980', '2014')),
                             dim='time',
                             deg=1)[0]


In [None]:
## produce the QDM fit on detrended data
# simh's time dimension is renamed to 't_simh' and declared through
# input_core_dims, so it is handled as a core dimension separate from the
# 'time' of obs and simp.
# NOTE(review): kind="*" selects the multiplicative branch of
# quantile_delta_mapping, whose bins start at global_min = 0.0 by default.
# The inputs here are detrended series, so confirm that values below zero
# are handled as intended (or whether "+" was meant).
qdm_detrended = apply_cmfunc(
        method = "quantile_delta_mapping",
        obs = detrended_obs,
        simh = detrended_simh.rename({'time':'t_simh'}),
        simp = detrended_simp,
        n_quantiles = 100,
        input_core_dims={"obs": "time", "simh": "t_simh", "simp": "time"},
        kind = "*", # to calculate the relative rather than the absolute change, "*" can be used instead of "+" (this is preferred when adjusting precipitation)
    )

In [None]:
## get obs baseline to add to QDM-corrected annual variability

## set degree and dimension for polyfit
detrend_dim = 'time'
detrend_deg = 1

# Polynomial fit to the annual-mean observations of the training period
# (1950-1980). 'YE' is the current spelling of the deprecated year-end
# alias 'A'.
obs_pf = tobs_repr_match.sel(time=slice('1950','1980')).resample(time='YE').mean().polyfit(
    dim=detrend_dim, deg=detrend_deg)

In [None]:
## reconstruct dataset from reanalysis mean val, future trend, QDM, and monthly residual
# Degree-0 coefficient of the obs fit, i.e. the intercept of the fitted line.
# NOTE(review): despite the name this is not a time mean; confirm the
# intercept is the intended baseline.
reanalysis_meanval = obs_pf.polyfit_coefficients.sel(degree=0)
# NOTE(review): both operands of this subtraction are the same expression
# (the fitted trend of annual_ds over 1980-2014), so future_trendonly is
# zero everywhere and future_trend_series reduces to reanalysis_meanval.
# One operand was presumably meant to be something else -- confirm.
future_trendonly = (detrend_ser(annual_ds.sel(time=slice('1980', '2014')),
                             dim='time',
                             deg=1)[1]
                    - detrend_ser(annual_ds.sel(time=slice('1980', '2014')),
                             dim='time',
                             deg=1)[1])
future_trend_series = reanalysis_meanval + future_trendonly
qdm_dtr_series = qdm_detrended.tfdpavg0to500_bathymin100
# Forward-fill the annual pieces onto month-end stamps and add the monthly residual.
qdm_plus_resid = future_trend_series.resample(time='ME').ffill() + qdm_dtr_series.resample(time='ME').ffill() + residual_ds.tfdpavg0to500_bathymin100

# Bare expression: displays the result as the notebook cell output.
qdm_plus_resid

In [None]:
## test write out
# Only `date` is imported by name (it is used when building the output file
# name). The class datetime.datetime is reached through the module instead:
# importing it under the bare name `datetime` would shadow the `datetime`
# module that tdec2datestr / datestr2tdec rely on, and break them on re-run.
from datetime import date

now = datetime.datetime.now()
ds_temp = qdm_plus_resid.to_dataset(name='TF')
# ds_temp.TF.attrs = tf_out.attrs
# Global attributes describing the product and when it was created.
ds_out = ds_temp.assign_attrs(
    title=f'QDM-corrected ocean thermal forcing for {SelModel}',
    summary=(
        'TF computed following Verjans code, in a bounding'
        ' box around Greenland, for ISMIP7 Greenland forcing.'
        ' QDM correction applied to annual based on EN4 data, with'
        ' monthly residual added'
    ),
    institution='NASA Goddard Space Flight Center',
    creation_date=now.strftime('%Y-%m-%d %H:%M:%S'),
)

# Bare expression: displays the dataset as the notebook cell output.
ds_out

In [None]:
# Destination file: <DirSave>/tfQDM-<model>-<today's date>.nc
out_fn = f'{DirSave}/tfQDM-{SelModel}-{date.today()}.nc'

from dask.diagnostics import ProgressBar

# Write the netCDF file inside a dask progress-bar context.
with ProgressBar():
    ds_out.to_netcdf(path=out_fn)