# Step 3: Quantile Delta Mapping
This notebook is to develop an efficient application of *Step 3* of Verjans's workflow: perform a QDM correction of ocean thermal forcing from a CMIP model, based on the EN4 reanalysis product.

The full workflow is outlined in Vincent's Readme1.txt in [this Zenodo archive](https://zenodo.org/records/7931326).  We are modifying the workflow to deploy it efficiently for ISMIP7.

20 Nov 2024 | EHU

Edits:
- 6 Dec 2024: Read in EN4 with xarray. Replace Verjans `qmProjCannon2015` function with numpy/xarray-optimized QDM implementations that Denis found

In [1]:
import os
import sys
import copy
import csv
import time
import numpy as np
import netCDF4 as nc
import matplotlib.pyplot as plt
import pyproj
from scipy import interpolate

sys.path.append('/home/theghub/ehultee/ISMIP7-utils/python-cmethods')
from cmethods import adjust

# from verjansFunctions import qmProjCannon2015

Initial run settings from Vincent.  Replace most of this with our own file selection, eventually.  Just check that this works.

In [2]:
savingQMtf       = True
cwd              = os.getcwd()+'/'

SelModel       = 'MIROCES2L'

To2015hist                 = False
To2100histssp585           = True
To2100histssp126           = False

DepthRange         = [0,500]
ShallowThreshold   = 100
PeriodObs0         = [1950,2015]
SigmaExclusion     = 4 #number of sdevs beyond which we constrain values in QDM
yrWindowProj       = 30 #number of years running window CDF in projection period 

DirEN4         = f'{cwd}Verjans_InputOutput/'
EN4file        = f'dpavg_tf_EN4anl_Dp{DepthRange[0]}to{DepthRange[1]}_bathymin{ShallowThreshold}.nc'

dim2d = True #CMIP models ususally have 2d lat and lon arrays
DirModTFnc      = f'{cwd}InputOutput/'
# NeighborGridloc = f'{cwd}InputOutput/nrngb_{SelModel}_toECCO2arcticAndEN4_bathymin100.csv'
   
transformer1 = pyproj.Transformer.from_crs('epsg:4326','epsg:3413') #transformer of lat-lon to x-y(3413)
transformer2 = pyproj.Transformer.from_crs('epsg:3413','epsg:4326') #transformer of x-y(3413) to lat-lon

In [3]:
# CMIP models #
if(SelModel=='IPSLCM6A'):
    if(To2015hist):
        ls_members     = [f'r{id}' for id in range(1,32+1)]
        ls_members.remove('r2') #no r2 member for IPSLCM6A
    elif(To2100histssp585 or To2100histssp126):
        ls_members     = ['r1','r3','r4','r6','r14']
if(SelModel=='MIROCES2L'):
    if(To2015hist):
        ls_members     = [f'r{id}' for id in range(1,30+1)]
    elif(To2100histssp585 or To2100histssp126):
        ls_members     = [f'r{id}' for id in range(1,10+1)]

# ### Load geographical and Nearest-Neighbor data from Model ###
# dataNgb = np.genfromtxt(NeighborGridloc,dtype=str,delimiter=',',skip_header=1)
# lat1dMod     = dataNgb[:,0].astype(float) #1d vector
# lon1dMod     = dataNgb[:,1].astype(float) #1d vector
# indyMod      = dataNgb[:,2].astype(float).astype(int)
# indxMod      = dataNgb[:,3].astype(float).astype(int)
# indyEN4      = dataNgb[:,6].astype(float).astype(int)
# indxEN4      = dataNgb[:,7].astype(float).astype(int)
# sectorSlater = dataNgb[:,8]
    
# ### Load EN4 product ###
# ds           = nc.Dataset(DirEN4+EN4file)
# timeEN4      = np.array(ds.variables['time'])
# latsEN4      = np.array(ds.variables['lat'])
# lonsEN4      = np.array(ds.variables['lon'])
# tfEN4        = np.array(ds.variables[f'tfdpavg{DepthRange[0]}to{DepthRange[1]}_bathymin{ShallowThreshold}'])
# ds.close()

In [18]:
## Load EN4 using xarray
import xarray as xr
ds = xr.open_dataset(DirEN4+EN4file, decode_times='timeDim')
ds2 = ds.assign_coords({'timeDim': ds.time, 
                  'latDim': ds.lat, 
                  'lonDim': ds.lon,
                  'depthDim': ds.depth})
ds2

tfEN4 = ds2.tfdpavg0to500_bathymin100.rename({'timeDim': 'time',
                                              'latDim': 'lat',
                                              'lonDim': 'lon'})
tfEN4

In [30]:
# import pandas as pd
## Bryan Riel, please save me. Decimal year to datetime is the bane of this notebook.
## pasting stuff from iceutils below.
#-*- coding: utf-8 -*-

import datetime
import numpy as np
import time
import copy
import math
import sys

def tdec2datestr(tdec_in, returndate=False):
    """
    Convert a decimaly year to an iso date string.
    """
    if isinstance(tdec_in, (list, np.ndarray)):
        tdec_list = copy.deepcopy(tdec_in)
    else:
        tdec_list = [tdec_in]
    current_list = []
    for tdec in tdec_list:
        year = int(tdec)
        yearStart = datetime.datetime(year, 1, 1)
        if year % 4 == 0:
            ndays_in_year = 366.0
        else:
            ndays_in_year = 365.0
        days = (tdec - year) * ndays_in_year
        seconds = (days - int(days)) * 86400
        tdelta = datetime.timedelta(days=int(days), seconds=int(seconds))
        current = yearStart + tdelta
        if not returndate:
            current = current.isoformat(' ').split()[0]
        current_list.append(current)

    if len(current_list) == 1:
        return current_list[0]
    else:
        return np.array(current_list)

time_arr = tdec2datestr(tfEN4.time.values)
time_arr

array(['1900-01-15', '1900-02-14', '1900-03-15', ..., '2021-11-14',
       '2021-12-15', '2022-01-14'], dtype='<U10')

Okay, finally successfully converted.  We need the obs dataset to have the same time type as the modeled one in order to use QDM `adjust`.

TODO: update the time coordinate in `tfEN4` to be this datetime type, rather than float year.

In [31]:
tfEN4.time = time_arr

AttributeError: cannot set attribute 'time' on a 'DataArray' object. Use __setitem__ styleassignment (e.g., `ds['name'] = ...`) instead of assigning variables.

In [15]:
tfEN4.sel(time=slice(PeriodObs0[0], PeriodObs0[1]))

---

Now time to try something fun: the cmethods Quantile Delta Mapping.  See [example notebook](https://github.com/ehultee/gris-iceocean-process/blob/main/python-cmethods_examples.ipynb) added by DF.

QDM `adjust` from cmethods needs datasets defined and input as simulated historical (`simh`), simulated projection (`simp`), and observed historical against which to bias-correct (`obs`).

Slice the EN4 dataset for the obs period defined by Vincent's `PeriodObs0`.  Import the example dataset of CESM2 TF for the same depth range and bathymetric threshold.

In [6]:
ds3 = xr.open_dataset('/home/theghub/ehultee/data/tfdpavg-CESM2-2024-11-14.nc')
ds3

This is a short example dataset.  For the sake of argument, let's take a very short correction period over the first half, and use the second half as the projection.

In [16]:
# to adjust a 3d dataset
qdm_result = adjust(
    method = "quantile_delta_mapping",
    obs = tfEN4.sel(time=slice('2000','2007')),
    simh = ds3.TF.sel(time=slice('2000', '2007')),
    simp = ds3.TF.sel(time=slice('2007', '2014')),
    n_quantiles = 100,
    kind = "+", # to calculate the relative rather than the absolute change, "*" can be used instead of "+" (this is prefered when adjusting precipitation)
)

ValueError: cannot align objects with join='exact' where index/labels/sizes are not equal along these coordinates (dimensions): 'time' ('time',)

## Define some functions 

---
Vincent's legacy code below here

In [None]:
for mm,member in enumerate(ls_members):
    print(f'Starting member {member}')
    ### Load Model ###
    if(To2100histssp585):
        Modloc = DirModTFnc+f'ensemble{SelModel}_hist2100ssp585_M{member}_TFdpavg_Dp0to500_bathymin100.nc'
    elif(To2100histssp126):
        Modloc = DirModTFnc+f'ensemble{SelModel}_hist2100ssp126_M{member}_TFdpavg_Dp0to500_bathymin100.nc'
    elif(To2015hist):
        Modloc = DirModTFnc+f'ensemble{SelModel}_hist_M{member}_TFdpavg_Dp0to500_bathymin100.nc'
    ds       = nc.Dataset(Modloc)
    if(mm==0):
        timeMod  = np.array(ds.variables['time'])
        depthMod = np.array(ds.variables['depth'])
        latsMod  = np.array(ds.variables['lat'])
        lonsMod  = np.array(ds.variables['lon'])
    tfMod    = np.array(ds.variables[f'tfdpavg{DepthRange[0]}to{DepthRange[1]}_bathymin{ShallowThreshold}'])
    ds.close()
    
    ### Constrain observational period to be within model time ###
    if(mm==0):
        PeriodObs    = copy.deepcopy(PeriodObs0)
        PeriodObs[0] = int(np.round(max(PeriodObs[0],timeMod[0]),decimals=0))
        PeriodObs[1] = int(np.round(min(PeriodObs[1],timeMod[-1]),decimals=0))
        indsEN4      = np.logical_and(timeEN4>=PeriodObs[0],timeEN4<=PeriodObs[1])
        iM0          = np.where(timeMod>=PeriodObs[0])[0][0]
        iM1          = np.where(timeMod<=PeriodObs[1])[0][-1]
        
    if(timeMod[0]<PeriodObs[0] and timeMod[-1]>PeriodObs[1]):
        # If model extends before and after PeriodObs: remove pre-PeriodObs #
        print('Model extending before and after PeriodObs: two separate QDMs')
        ls_timeMod = [timeMod[0:iM1+1],timeMod[iM0:]]
        ls_tfMod   = [tfMod[0:iM1+1,:,:],tfMod[iM0:,:,:]]
        nQDM       = 2 #two QDMs to perform
    else:
        ls_timeMod = [timeMod]
        ls_tfMod   = [tfMod]
        nQDM       = 1 #a single QDM to perform
        
    ls_indsModPeriod = [np.logical_and(mytiming>=PeriodObs[0],mytiming<=PeriodObs[1]) for mytiming in ls_timeMod]
    ls_indsModProj   = [myindsmodperiod==False for myindsmodperiod in ls_indsModPeriod]
    ls_timehistMod   = [ls_timeMod[kk][ls_indsModPeriod[kk]] for kk in range(nQDM)] #should be identical
    ls_timeprojMod   = [ls_timeMod[kk][ls_indsModPeriod[kk]==False] for kk in range(nQDM)]
    
    ### Prepare full output of QM with comparison for statistics ###
    #nProc          = len(np.where(sectorSlater!='NoSector')[0])
    nProc          = len(lat1dMod)
    rawhistall     = np.zeros((nProc,len(ls_timehistMod[0])))
    bchistall      = np.zeros((nProc,len(ls_timehistMod[0])))
    ls_rawprojall  = [np.zeros((nProc,len(ls_timeprojMod[kk]))) for kk in range(nQDM)]
    ls_bcprojall   = [np.zeros((nProc,len(ls_timeprojMod[kk]))) for kk in range(nQDM)]
    compEN4histall = np.zeros((nProc,len(ls_timehistMod[0])))
    qmtfdpavg      = 1.1e20*np.ones_like(tfMod)
    
    ### QM applied to each Model grid point ###
    trackrowindex  = np.zeros(nProc).astype(int)
    rr = 0
    tic = time.time()
    for rowind in range(len(lat1dMod)):
        if(np.random.uniform(0,1)>0.999):
            print(f'{rowind}/{len(lat1dMod)}')
        
        # ii,jj of the gridpoint in Mod and Nearest-Neighbor of EN4 #
        iiMod,jjMod,iiEN4,jjEN4 = indyMod[rowind],indxMod[rowind],indyEN4[rowind],indxEN4[rowind]
        
        for qq in range(nQDM):
            # Apply constraints #
            constrMin = 0
            constrMax = np.nan
            # Quantile Mapping #
            qqtimeMod = ls_timeMod[qq]
            qqtfMod   = ls_tfMod[qq][:,iiMod,jjMod]
            qqtimingModHist,qqtfbcModHist,qqtimingModProj,qqtfbcModProj = \
                qmProjCannon2015(timeEN4,tfEN4[:,iiEN4,jjEN4],qqtimeMod,qqtfMod,PeriodObs,
                                             windowProj=yrWindowProj,ConstraintMin=constrMin,ConstraintMax=constrMax,exclSigma=SigmaExclusion)

            if(qq==0 and qqtimingModHist[0]<qqtimingModProj[0]):
                srqmtf = np.append(qqtfbcModHist,qqtfbcModProj)
            elif(qq==0 and qqtimingModProj[0]<qqtimingModHist[0]):
                srqmtf = np.append(qqtfbcModProj,qqtfbcModHist)
            elif(qq==1):
                srqmtf = np.append(srqmtf,qqtfbcModProj)
            
            # Save corresponding raw and bias-corrected time series #
            if(qq==0):
                rawhistall[rr,:] = qqtfMod[ls_indsModPeriod[0]]
                bchistall[rr,:]  = np.copy(qqtfbcModHist)
            ls_rawprojall[qq][rr,:] = np.copy(qqtfMod[ls_indsModProj[qq]])
            ls_bcprojall[qq][rr,:]  = np.copy(qqtfbcModProj)

        # Interpolate nearest-neighbor EN4 time series on the Model hist time #
        finterp = interpolate.interp1d(timeEN4,tfEN4[:,iiEN4,jjEN4],kind='linear',bounds_error=True) #linear interpolator
        compEN4histall[rr,:] = finterp(ls_timehistMod[0])
        # Save depth-average QM-corrected TF #
        qmtfdpavg[:,iiMod,jjMod] = np.copy(srqmtf)
        # Keep track of row index in the nearest neighbor matrix #
        trackrowindex[rr] = np.copy(rowind)
        
        rr += 1
    
    ### Compute Stats ###
    rawRMSE = np.sqrt(np.mean((rawhistall-compEN4histall)**2))
    bcRMSE  = np.sqrt(np.mean((bchistall-compEN4histall)**2))
    rawBias = np.mean(rawhistall-compEN4histall)
    bcBias  = np.mean(bchistall-compEN4histall)
    print(f'Raw {SelModel} bias, RMSE: {np.round(rawBias,3)} K   {np.round(rawRMSE,3)} K')
    print(f'QM-corrected {SelModel} bias, RMSE: {np.round(bcBias,3)} K   {np.round(bcRMSE,3)} K')
    
    if(savingQMtf):
                
        if(To2015hist):
            nameout = f'ensemble{SelModel}_qdmhist_M{member}_ObsPer{PeriodObs[0]}to{PeriodObs[1]}.nc'
        if(To2100histssp585):
            nameout = f'ensemble{SelModel}_qdmhist2100ssp585_M{member}_ObsPer{PeriodObs[0]}to{PeriodObs[1]}.nc'
        if(To2100histssp126):
            nameout = f'ensemble{SelModel}_qdmhist2100ssp126_M{member}_ObsPer{PeriodObs[0]}to{PeriodObs[1]}.nc'
        if(dim2d):
            nny = np.shape(latsMod)[0]
            nnx = np.shape(latsMod)[1]
        if(dim2d==False):
            nny = len(latsMod)
            nnx = len(lonsMod)
        
        ### Open netcdf ###
        outnc        = nc.Dataset(DirModTFnc+nameout,'w',format='NETCDF4')
        timedim      = outnc.createDimension('timeDim',size=len(timeMod)) 
        zdim         = outnc.createDimension('depthDim',size=len(depthMod)) 
        latdim       = outnc.createDimension('latDim',nny) 
        londim       = outnc.createDimension('lonDim',nnx) 
        
        time_nc         = outnc.createVariable('time','f4',('timeDim',))
        depth_nc        = outnc.createVariable('depth','f4',('depthDim',))
        if(dim2d):
            lat_nc      = outnc.createVariable('lat','f4',('latDim','lonDim',))
            lon_nc      = outnc.createVariable('lon','f4',('latDim','lonDim',))
        if(dim2d==False):
            lat_nc      = outnc.createVariable('lat','f4',('latDim',))
            lon_nc      = outnc.createVariable('lon','f4',('lonDim',))
        tfdpavg_nc      = outnc.createVariable(f'tfdpavg{DepthRange[0]}to{DepthRange[1]}_bathymin{ShallowThreshold}','f4',('timeDim','latDim','lonDim',))
        qmen4tfdpavg_nc = outnc.createVariable(f'qmen4tfdpavg{DepthRange[0]}to{DepthRange[1]}_bathymin{ShallowThreshold}','f4',('timeDim','latDim','lonDim',))
            
        time_nc[:]             = timeMod
        depth_nc[:]            = depthMod
        if(dim2d):
            lat_nc[:,:]        = latsMod
            lon_nc[:,:]        = lonsMod
        if(dim2d==False):
            lat_nc[:]          = latsMod
            lon_nc[:]          = lonsMod
        tfdpavg_nc[:,:,:]      = tfMod
        qmen4tfdpavg_nc[:,:,:] = qmtfdpavg
        
        depth_nc.units         = 'meter'
        time_nc.units          = 'yr'
        tfdpavg_nc.units       = 'degC'
        qmen4tfdpavg_nc.units  = 'degC'
        outnc.close()