# Calculate the seasonal means of errors in MDM and QDM+ sort predictions
- MDM: Moment Delta Mapping
- Consider the MDM operators L_1 and L_2: L_1 maps only the mean of a distribution, while L_2 maps both the mean and the standard deviation
- MDM error for operator L_i  = [L_i (Q_historical) ] - Q_future
- In this notebook, Q_historical is the pre-industrial quantile function and Q_future is the End of 21st century quantile function
- These quantile functions are constructed from detrended CESM2 LENS data
- We can similarly calculate errors for the Gaussian operators G1 and G2 (see paper for more details)
- In this notebook, we compute the seasonal mean of the MDM error for various operators

In [1]:
import glob
import os
import warnings

import numpy as np
import xarray as xr
from distributed import Client
import dask_jobqueue
import matplotlib.pyplot as plt
import matplotlib as mtplt
import netCDF4 as nc
import zarr
import nc_time_axis
#import xskillscore as xs
import cartopy as cart
import matplotlib.colors as mcolors
import matplotlib.cm as cm
from scipy.special import erfinv, erf
import shapely
from shapely.errors import ShapelyDeprecationWarning
warnings.filterwarnings("ignore", category=ShapelyDeprecationWarning) 



In [2]:
def to_daily(ds):
    """Split the ``time`` axis of ``ds`` into separate ``year`` and ``day`` dimensions.

    Parameters:
    ds: object with a datetime-like ``time`` coordinate (``ds.time.dt`` must
        provide ``year`` and ``dayofyear``).

    Returns:
    The input with ``time`` replaced by a ("year", "day") multi-index and then
    unstacked, so ``year`` and ``day`` become dimensions of the result.
    """
    stamps = ds.time.dt

    # Label every time step with its calendar year and its day of year.
    labelled = ds.assign_coords(
        year=("time", stamps.year.data),
        day=("time", stamps.dayofyear.data),
    )

    # Index time by (year, day) and unfold that multi-index into two dimensions.
    indexed = labelled.set_index(time=("year", "day"))
    return indexed.unstack("time")

In [3]:
# Default probability levels: 30 evenly spaced values from 0 to 1 inclusive.
quants = np.linspace(0, 1.0, num=30)


def compute_quantiles(ds, quantiles=quants):
    """Return the quantiles of ``ds`` taken along its ``year`` dimension.

    Parameters:
    ds: array-like object exposing ``chunk`` and ``quantile`` (xarray-style).
    quantiles: probability levels to evaluate; defaults to ``quants``.

    Returns:
    Whatever ``ds.quantile`` returns for the requested levels.
    """
    # Put the whole 'year' axis into a single chunk before reducing over it.
    single_year_chunk = ds.chunk({'year': -1})
    # skipna=False is passed through unchanged to the quantile call.
    return single_year_chunk.quantile(quantiles, dim='year', skipna=False)

def implement_mdm(ds_obs,init_mean,final_mean,init_std,final_std):
    """Apply moment delta mapping (mean shift plus standard-deviation rescaling).

    Parameters:
    ds_obs: data to transform; must support ``.mean('year')``.
    init_mean, final_mean: means of the initial and final periods.
    init_std, final_std: standard deviations of the initial and final periods.

    Returns:
    ``mean(ds_obs) + (final_mean - init_mean)
    + (final_std / init_std) * (ds_obs - mean(ds_obs))``,
    where ``mean`` is taken over ``year``.
    """
    centre = ds_obs.mean('year')
    mean_change = final_mean - init_mean
    # Anomalies about the 'year' mean, stretched by the ratio of standard deviations.
    stretched_anomaly = (final_std/init_std)*(ds_obs - centre)
    return centre + mean_change + stretched_anomaly

def implement_shift(ds_obs,init_mean,final_mean):
    """Apply the mean-only delta mapping (shift by the change in mean).

    Parameters:
    ds_obs: data to transform; must support ``.mean('year')``.
    init_mean, final_mean: means of the initial and final periods.

    Returns:
    ``mean(ds_obs) + (final_mean - init_mean) + ds_obs - mean(ds_obs)``,
    evaluated left to right, where ``mean`` is taken over ``year``.
    """
    centre = ds_obs.mean('year')
    mean_change = final_mean - init_mean
    # Same left-to-right evaluation as the mean-and-stdev variant with a ratio of one.
    return centre + mean_change + ds_obs - centre

def implement_qdm(qobs, qinit, qfinal):
    """Apply quantile delta mapping to already-computed quantile functions.

    Parameters:
    qobs: quantiles of the data to be mapped.
    qinit: quantiles of the initial-period data.
    qfinal: quantiles of the final-period data.

    Returns:
    ``qobs + (qfinal - qinit)``, i.e. the change at each quantile added to ``qobs``.
    """
    quantile_change = qfinal - qinit
    return qobs + quantile_change

def is_sorted(arr):
    """Return True if ``arr`` is monotonically non-decreasing or non-increasing.

    Parameters:
    arr: array-like (NumPy array, list, tuple, ...). It is converted with
        ``np.asarray`` so that neighbouring elements are compared one by one;
        comparing two Python lists directly would be a single lexicographic
        comparison and would report e.g. ``[1, 3, 2]`` as sorted.

    Returns:
    bool: True when every element is <= its successor, or every element is >=
    its successor. Empty and single-element inputs count as sorted.
    """
    values = np.asarray(arr)
    earlier, later = values[:-1], values[1:]
    return bool(np.all(earlier <= later) or np.all(earlier >= later))

In [4]:
def gauss_quantile(mean,std,quantile):
    """Quantile function of a normal distribution.

    Parameters:
    mean: mean of the distribution.
    std: standard deviation of the distribution.
    quantile: probability level(s) at which to evaluate.

    Returns:
    ``mean + std * sqrt(2) * erfinv(2 * quantile - 1)``.
    """
    inverse_erf = erfinv(2*quantile-1)
    return mean + std * np.sqrt(2) * inverse_erf

In [5]:
def select_months(data, months):
    """
    Keep only the entries of `data` whose 'dayofyear' falls in the given months.

    Parameters:
    data (xarray.DataArray): Input with a 'dayofyear' coordinate numbered 1-365
        (365-day calendar, February has 28 days).
    months (list of int): Month numbers (1-12) to keep, e.g. [12, 1, 2] for DJF.

    Returns:
    xarray.DataArray: `data.where(..., drop=True)` restricted to the selected days.

    Raises:
    KeyError: if a month number is not in 1..12.
    """
    # Month lengths of a 365-day calendar, January first.
    month_lengths = (31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31)

    # Map each month number to its range of day-of-year values.
    day_spans = {}
    first_day = 1
    for number, length in enumerate(month_lengths, start=1):
        day_spans[number] = range(first_day, first_day + length)
        first_day += length

    # Collect the wanted days in the order the months were given.
    wanted_days = [day for month in months for day in day_spans[month]]

    # Mask on the coordinate and drop everything outside the selection.
    in_selection = data['dayofyear'].isin(wanted_days)
    return data.where(in_selection, drop=True)

In [6]:
def altspace(start, step, count, endpoint=False, **kwargs):
    """Return ``count`` evenly spaced values beginning at ``start``.

    Thin wrapper around ``np.linspace`` whose stop value is
    ``start + step*count``. With the default ``endpoint=False`` the spacing
    between consecutive values is exactly ``step``.

    Parameters:
    start: first value.
    step: spacing used to derive the stop value.
    count: number of samples.
    endpoint: forwarded to ``np.linspace``.
    **kwargs: forwarded to ``np.linspace``.
    """
    span = step*count
    return np.linspace(start, start+span, count, endpoint=endpoint, **kwargs)

In [7]:
###### Reference locations (longitudes on the 0-360 convention) ######
chic_lat  = 41.8781
chic_lon  = (360-87.6298)%360
ben_lat   = 12.9716
ben_lon   = 77.5946
###### Active location: swap the two lines below to switch city ######
LAT, LON  = chic_lat, chic_lon      # Chicago
# LAT, LON  = ben_lat, ben_lon      # Bengaluru
########## CONUS bounding box (longitudes on the 0-360 convention) ############
top       = 50.0                    # north lat
bottom    =  24.7433195             # south lat
left      = -124.7844079+360        # west long
right     = -66.9513812+360         # east long
########## Paths and location labels ############
savefigs   = '/global/scratch/users/harsha/savefigs/Feb21/'
cesm2_path = '/global/scratch/users/harsha/LENS/cesm2/tasmax/'
cvals      = '/global/scratch/users/harsha/LENS/cesm2/cvals/'
cvals_det  = '/global/scratch/users/harsha/LENS/cesm2/cvals/detrended/'
cvals1     = '/global/scratch/users/harsha/LENS/cesm2/tmax_mem'
loc        = 'Chicago/'
locn       = 'Chicago'
###########
print('(lat,lon)=',LAT,LON)

(lat,lon)= 41.8781 272.3702


In [8]:
# Extra SLURM directives passed to every submitted job (QoS and account).
job_extra = ['--qos=cf_lowprio','--account=ac_cumulus']
# Alternative QoS/account combinations:
#job_extra =['--qos=lr6_lowprio','--account=ac_cumulus','--constraint=lr6_m192']
#job_extra =['--qos=condo_cumulus_lr6','--account=lr_cumulus','--constraint=lr6_m192']
#job_extra =['--qos=lr_lowprio','--account=ac_cumulus']

# One SLURM job provides 20 cores and 192GB; scratch and log files go to dask_space.
cluster = dask_jobqueue.SLURMCluster(queue="cf1", cores=20, walltime='6:00:00',
                local_directory='/global/scratch/users/harsha/dask_space/',
                log_directory='/global/scratch/users/harsha/dask_space/',
                job_extra_directives=job_extra, interface='eth0', memory="192GB")
client  = Client(cluster)

# Request three SLURM jobs through the documented `scale` API. The previous
# `scale_up` call is not part of the documented dask-jobqueue interface.
# NOTE(review): `scale_up(3)` is believed to have requested 3 jobs as well - confirm
# the job count in the queue after this change.
cluster.scale(jobs=3)

Perhaps you already have a cluster running?
Hosting the HTTP server on port 46787 instead


In [9]:
# Show the cluster summary (dashboard address, worker count, threads, memory).
# Workers only appear once the queued jobs have started.
cluster

0,1
Dashboard: http://10.0.39.4:46787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.0.39.4:45905,Workers: 0
Dashboard: http://10.0.39.4:46787/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [10]:
## Change/Check these parameters before running
# First and last year of each 30-year window, kept as strings because they are
# concatenated into file names further down.
pi_year0, pi_year1 = '1850', '1879'
ic_year0, ic_year1 = '2071', '2100'
# Central years of the two windows.
pi_year  = 1865
# eoc_year = 2085 #Central year used for detrending
ic_year  = 2086 #This is the correct year for the rolling window operation to work
# day_of_year
doy      = 211

### Load detrended data

In [11]:
%%time
pi_det   = xr.open_zarr(cvals_det+'pi_detrended.zarr').detrended_tmax
eoc_det  = xr.open_zarr(cvals_det+'eoc_detrended.zarr').detrended_tmax
#
# pi_det   = pi_det.chunk({'member':100,'lat':8,'lon':10,'year':30,'dayofyear':35})
# eoc_det  = eoc_det.chunk({'member':100,'lat':8,'lon':10,'year':30,'dayofyear':35})

BrokenPipeError: [Errno 108] Cannot send after transport endpoint shutdown

### Apply MDM, compute quantile functions and save data

In [None]:
# Mean and standard deviation taken over the 'year' dimension, for the
# pre-industrial (pi) and end-of-century (eoc) periods.
pi_amean, pi_astd   = pi_det.mean('year'), pi_det.std('year')
eoc_amean, eoc_astd = eoc_det.mean('year'), eoc_det.std('year')

In [None]:
# L_2: map the pre-industrial data onto the end-of-century mean and standard deviation.
pi_mdm = implement_mdm(
    ds_obs=pi_det,
    init_mean=pi_amean,
    final_mean=eoc_amean,
    init_std=pi_astd,
    final_std=eoc_astd,
)

In [None]:
# L_1: shift the pre-industrial data by the change in mean only.
pi_shift = implement_shift(
    ds_obs=pi_det,
    init_mean=pi_amean,
    final_mean=eoc_amean,
)

In [15]:
%%time
# pi_mdm.to_dataset().to_zarr(cvals_det+'pi_mdm_intra_simu'+'_'+ic_year0+'_'+ic_year1+'.zarr',mode='w')

CPU times: user 0 ns, sys: 15 µs, total: 15 µs
Wall time: 26.5 µs


In [None]:
%%time
pi_mdm = xr.open_zarr(cvals_det+'pi_mdm_intra_simu'+'_'+ic_year0+'_'+ic_year1+'.zarr').detrended_tmax
pi_mdm

In [17]:
# %%time
# pi_mdm.sel(lat=LAT,lon=LON,method='nearest').sel(member=0,dayofyear=1).values

### Compute quantiles and errors for each season

#### Pick season and months #######

In [None]:
# Months belonging to each meteorological season.
SEASON_MONTHS = {
    'djf': [12, 1, 2],
    'mam': [3, 4, 5],
    'jja': [6, 7, 8],
    'son': [9, 10, 11],
}

# Change only `season`. `months` is looked up from it, so the month list and the
# season tag used in output file names cannot get out of sync.
season = 'mam'
months = list(SEASON_MONTHS[season])

In [None]:
# Restrict the data fields to the days of the chosen season.
pi_mdm_season, pi_season, eoc_season = (
    select_months(field, months) for field in (pi_mdm, pi_det, eoc_det)
)
# Same restriction for the means and standard deviations taken over 'year'.
pi_amean_season, eoc_amean_season, pi_astd_season, eoc_astd_season = (
    select_months(stat, months) for stat in (pi_amean, eoc_amean, pi_astd, eoc_astd)
)
# Mean-only shift recomputed from the seasonal subsets.
pi_shift_season = implement_shift(pi_season, pi_amean_season, eoc_amean_season)

In [None]:
%%time
qpi_mdm_season   = compute_quantiles(pi_mdm_season)
qpi_shift_season = compute_quantiles(pi_shift_season)
qeoc_season      = compute_quantiles(eoc_season)
qpi_season       = compute_quantiles(pi_season)

In [None]:
%%time
qmdm_err_season   = qpi_mdm_season   - qeoc_season
qshift_err_season = qpi_shift_season - qeoc_season
#
qmdm_err_season   = qmdm_err_season.chunk({'dayofyear':14})
qshift_err_season = qshift_err_season.chunk({'dayofyear':14})
qshift_err_season

In [None]:
# %%time
# qmdm_err_season.to_dataset().to_zarr(cvals_det+'mdm_err_'+season+'_'+ic_year0+'_'+ic_year1+'.zarr',mode='w')

Task exception was never retrieved
future: <Task finished name='Task-339424' coro=<Scheduler.handle_request_refresh_who_has() done, defined at /global/home/users/harsha/miniconda3/envs/pyenv/lib/python3.9/site-packages/distributed/scheduler.py:5622> exception=KeyError('tcp://10.0.39.13:41875')>
Traceback (most recent call last):
  File "/global/home/users/harsha/miniconda3/envs/pyenv/lib/python3.9/site-packages/distributed/scheduler.py", line 5638, in handle_request_refresh_who_has
    self.stream_comms[worker].send(
KeyError: 'tcp://10.0.39.13:41875'
Task exception was never retrieved
future: <Task finished name='Task-340556' coro=<Scheduler.handle_request_refresh_who_has() done, defined at /global/home/users/harsha/miniconda3/envs/pyenv/lib/python3.9/site-packages/distributed/scheduler.py:5622> exception=KeyError('tcp://10.0.39.13:41875')>
Traceback (most recent call last):
  File "/global/home/users/harsha/miniconda3/envs/pyenv/lib/python3.9/site-packages/distributed/scheduler.py", 

In [None]:
# %%time
# qshift_err_season.to_dataset().to_zarr(cvals_det+'shift_err_'+season+'_'+ic_year0+'_'+ic_year1+'.zarr',mode='w')

In [None]:
## Open files
# Reuse the stored seasonal error fields when they exist. Otherwise write the
# fields computed above once, so this cell does not depend on stores left over
# from an earlier session.
qmdm_err_store   = cvals_det+'mdm_err_'+season+'_'+ic_year0+'_'+ic_year1+'.zarr'
qshift_err_store = cvals_det+'shift_err_'+season+'_'+ic_year0+'_'+ic_year1+'.zarr'
if not os.path.exists(qmdm_err_store):
    # Name the variable explicitly so it matches the attribute read back below.
    qmdm_err_season.to_dataset(name='detrended_tmax').to_zarr(qmdm_err_store, mode='w')
if not os.path.exists(qshift_err_store):
    qshift_err_season.to_dataset(name='detrended_tmax').to_zarr(qshift_err_store, mode='w')
qmdm_err_season   = xr.open_zarr(qmdm_err_store).detrended_tmax
qshift_err_season = xr.open_zarr(qshift_err_store).detrended_tmax

In [None]:
# Spot check at the chosen location for member 0 on the first day of the season.
# The day is picked by position because the seasonal subset only holds the
# selected days: a label lookup for day 1 fails for any season without 1 January.
qshift_err_season.sel(lat=LAT,lon=LON,method='nearest').sel(member=0).isel(dayofyear=0).values

### Compute Gaussian quantile functions and errors

In [None]:
# Probability levels as a DataArray along 'quantile', labelled with `quants`.
quantiles = xr.DataArray(
    np.linspace(0,1.0,30),
    dims='quantile',
    coords={'quantile': ('quantile', quants)},
)

In [None]:
####### Construct gaussians for EOC ################
# erfinv(2p - 1): the standard-normal quantile function without its sqrt(2) factor.
qgauss_standard_normal = xr.apply_ufunc(erfinv, 2*quantiles-1)

# Gaussian built from the end-of-century mean and standard deviation.
qpi_gauss_season = (
    eoc_amean_season + eoc_astd_season * np.sqrt(2) * qgauss_standard_normal
).rename('qgauss_detrended')

## gaussmo = gaussian with eoc mean only and pre-ind std
qpi_gaussmo_season = (
    eoc_amean_season + pi_astd_season * np.sqrt(2) * qgauss_standard_normal
).rename('qgaussmo_detrended')
qpi_gaussmo_season

In [None]:
%%time
qgauss_err_season   = qpi_gauss_season   - qeoc_season
qgaussmo_err_season = qpi_gaussmo_season - qeoc_season
#
qgauss_err_season.name   = 'qgauss_error'
qgaussmo_err_season.name = 'qgaussmo_error'
#
qgauss_err_season

In [None]:
%%time
qgauss_err_season.to_dataset().to_zarr(cvals_det+'gauss_err_'+season+'_'+ic_year0+'_'+ic_year1+'.zarr',mode='w')
qgaussmo_err_season.to_dataset().to_zarr(cvals_det+'gaussmo_err_'+season+'_'+ic_year0+'_'+ic_year1+'.zarr',mode='w')

In [None]:
%%time
qgauss_err_season   = xr.open_zarr(cvals_det+'gauss_err_'+season+'_'+ic_year0+'_'+ic_year1+'.zarr').qgauss_error
qgaussmo_err_season = xr.open_zarr(cvals_det+'gaussmo_err_'+season+'_'+ic_year0+'_'+ic_year1+'.zarr').qgaussmo_error

### Plot MDM errors

In [None]:
######## New colorbar only for shift + stretch and gaussian ############
# This cell defines `cmap` and `norm`, which the map-plot cells below reuse, and
# draws a stand-alone colourbar for them.
# Half-width of the symmetric colour range: values span -x .. +x.
x = 5
# create a figure and axis
fig, ax = plt.subplots(figsize=(6, 1))
fig.subplots_adjust(bottom=0.5)
# Define colormap
cmap = plt.get_cmap('RdBu_r')
# Make a norm object with the center at 0: TwoSlopeNorm
norm = mcolors.TwoSlopeNorm(vmin=-x, vcenter=0, vmax=x)
# Array from -x to x inclusive in steps of 0.25 (not used further in this cell)
values = np.arange(-x, x+0.25, 0.25)
# Creating a mappable object and setting the norm and cmap for colorbar
mappable = cm.ScalarMappable(norm=norm, cmap=cmap)
mappable.set_array([])
# Creating a colorbar
# Integer ticks from -4 to 4 (9 values, step 1)
ticks1 = altspace(-4,1,9)
cbar = plt.colorbar(mappable, ax=ax, orientation='vertical',ticks=ticks1)
cbar.set_label('')
# Hide the current axes so that only the colourbar remains visible
plt.gca().set_visible(False)

In [None]:
%%time
member =0 
# Creating a figure and axes
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,5), subplot_kw={'projection': cart.crs.PlateCarree()}\
                       , gridspec_kw = {'wspace':0.05, 'hspace':0.2})

# Plotting qgmsc_err
qmdm_err_season.sel(quantile=0.9,method='nearest').sel(member=member).mean('dayofyear').plot(ax=axs[0], transform=cart.crs.PlateCarree(),\
                                                    add_colorbar=False, cmap=cmap, norm=norm)
axs[0].coastlines(color="black")
axs[0].set_title(season+r': $L_2(Q_h) - Q_f$ at $p=0.9$')

# Plotting G2, gaussian error
im = qgauss_err_season.sel(quantile=0.9,method='nearest').sel(member=member).mean('dayofyear').plot(ax=axs[1], transform=cart.crs.PlateCarree(),\
                                                          add_colorbar=False, cmap=cmap, norm=norm)
axs[1].coastlines(color="black")
axs[1].set_title(season+': $G_2(Q_h) - Q_f$ at $p=0.9$')

# Adding colorbar
cbar = plt.colorbar(im, ax=axs.ravel().tolist(), shrink=0.5, orientation='vertical')
cbar.set_label('Error')

plt.show()

In [None]:
%%time
# Creating a figure and axes
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,5), subplot_kw={'projection': cart.crs.PlateCarree()}\
                       , gridspec_kw = {'wspace':0.05, 'hspace':0.2})

# Plotting qgmsc_err
qshift_err_season.sel(quantile=0.9,method='nearest').sel(member=member).mean('dayofyear').plot(ax=axs[0], transform=cart.crs.PlateCarree(),\
                                                    add_colorbar=False, cmap=cmap, norm=norm)
axs[0].coastlines(color="black")
axs[0].set_title(season+': $L_1(Q_i) - Q_f$ at $p=0.9$')

# Plotting da2
im = qgaussmo_err_season.sel(quantile=0.9,method='nearest').sel(member=member).mean('dayofyear').plot(ax=axs[1], transform=cart.crs.PlateCarree(),\
                                                          add_colorbar=False, cmap=cmap, norm=norm)
axs[1].coastlines(color="black")
axs[1].set_title(season+r': $G_1(Q_i) - Q_f$ at $p=0.9$')

# Adding colorbar
cbar = plt.colorbar(im, ax=axs.ravel().tolist(), shrink=0.5, orientation='vertical')
cbar.set_label('Errors')

plt.show()

### Compute ensemble mean of absolute errors

In [None]:
# Element-wise absolute value of the seasonal MDM error.
# NOTE(review): the section title mentions an ensemble mean, but no reduction
# over 'member' is applied here - confirm whether that step is still to be added.
qmdm_abserr_season = np.abs(qmdm_err_season)


In [None]:
#############################################################################################################