# Calculate the seasonal means of errors in MDM and QDM+ sort predictions
- MDM: Moment Delta Mapping
- Consider the MDM operators L_1 and L_2 which map the mean and mean and stdev of a distribution respectively
- MDM error for operator L_i  = [L_i (Q_historical) ] - Q_future
- In this notebook, Q_historical is the pre-industrial quantile function and Q_future is the End of 21st century quantile function
- These quantile functions are constructed from detrended CESM2 LENS data
- We can similarly caluclate errors for the Gaussian operators G1 and G2 (see paper for more details)
- In this notebook, we compute the seasonal mean of the MDM error for various operators

In [1]:
import numpy as np
import xarray as xr
from distributed import Client
import dask_jobqueue
import matplotlib.pyplot as plt
import matplotlib as mtplt
import glob
import netCDF4 as nc
import zarr
import nc_time_axis
#import xskillscore as xs
import cartopy as cart
import matplotlib.colors as mcolors
import matplotlib.cm as cm
from scipy.special import erfinv, erf
import shapely
import warnings
from shapely.errors import ShapelyDeprecationWarning
warnings.filterwarnings("ignore", category=ShapelyDeprecationWarning) 



In [2]:
def to_daily(ds):
    year = ds.time.dt.year
    day = ds.time.dt.dayofyear

    # assign new coords
    ds = ds.assign_coords(year=("time", year.data), day=("time", day.data))

    # reshape the array to (..., "day", "year")
    return ds.set_index(time=("year", "day")).unstack("time")  

In [3]:
quants = np.linspace(0,1.0,30)
def compute_quantiles(ds, quantiles=quants):
    return ds.chunk(dict(year=-1)).quantile(quantiles, dim='year',skipna=False)

def implement_mdm(ds_obs,init_mean,final_mean,init_std,final_std):
    # Assuming coordinates year and day for ds_obs
    obs_mean = ds_obs.mean('year')
    sratio   = final_std/init_std
    ds_mdm   = obs_mean + (final_mean - init_mean) + sratio*(ds_obs - obs_mean)        
    return ds_mdm

def implement_shift(ds_obs,init_mean,final_mean):
    # Assuming coordinates year and day for ds_obs
    obs_mean   = ds_obs.mean('year')
    ds_shift   = obs_mean + (final_mean - init_mean) + ds_obs - obs_mean        
    return ds_shift

def implement_qdm(qobs, qinit, qfinal):
    # Implement quantile delta mapping. Assumes model data has corrd called 'mtime'
    #qinit    = init_data.quantile(quants,dim='mtime')
    #qfinal   = final_data.quantile(quants,dim='mtime')
    ds_qdm   = qobs + (qfinal - qinit)
    return ds_qdm

def is_sorted(arr):
    return np.all(arr[:-1] <= arr[1:]) or np.all(arr[:-1] >= arr[1:])

In [4]:
def gauss_quantile(mean,std,quantile):
    qvalue = mean + std * np.sqrt(2) * erfinv(2*quantile-1)
    return qvalue

In [5]:
def altspace(start, step, count, endpoint=False, **kwargs):
   stop = start+(step*count)
   return np.linspace(start, stop, count, endpoint=endpoint, **kwargs)

In [6]:
# #Chicago
LAT        = 41.8781
LON        = (360-87.6298)%360
# # #Bengaluru:
# LAT        = 12.9716
# LON        = 77.5946
######################
chic_lat  = 41.8781
chic_lon  = (360-87.6298)%360
ben_lat   = 12.9716
ben_lon   = 77.5946
########## CONUS ############
#CONUS lat-lon
top       = 50.0 # north lat
left      = -124.7844079+360 # west long
right     = -66.9513812+360 # east long
bottom    =  24.7433195 # south lat
################################
savefigs   = '/global/scratch/users/harsha/savefigs/Feb21/'
cesm2_path = '/global/scratch/users/harsha/LENS/cesm2/tasmax/'
cvals      = '/global/scratch/users/harsha/LENS/cesm2/cvals/'
cvals_det  = '/global/scratch/users/harsha/LENS/cesm2/cvals/detrended/'
loc        = 'Chicago/'
locn       = 'Chicago'
cvals1     = '/global/scratch/users/harsha/LENS/cesm2/tmax_mem'
###########
print('(lat,lon)=',LAT,LON)

(lat,lon)= 41.8781 272.3702


In [7]:
job_extra = ['--qos=cf_lowprio','--account=ac_cumulus'] 
#job_extra =['--qos=lr6_lowprio','--account=ac_cumulus','--constraint=lr6_m192']
#job_extra =['--qos=condo_cumulus_lr6','--account=lr_cumulus','--constraint=lr6_m192']
#job_extra =['--qos=lr_lowprio','--account=ac_cumulus']
cluster = dask_jobqueue.SLURMCluster(queue="cf1", cores=20, walltime='5:00:00', 
                local_directory='/global/scratch/users/harsha/dask_space/', 
                log_directory='/global/scratch/users/harsha/dask_space/', 
                job_extra_directives=job_extra, interface='eth0', memory="192GB") 
client  = Client(cluster) 
cluster.scale_up(2)

Perhaps you already have a cluster running?
Hosting the HTTP server on port 33461 instead


In [8]:
cluster

0,1
Dashboard: http://10.0.39.5:33461/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.0.39.5:44679,Workers: 0
Dashboard: http://10.0.39.5:33461/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [9]:
pi_year0 = '1850'
pi_year1 = '1879'
ic_year0 = '2071'
ic_year1 = '2100'
pi_year  = 1865
# eoc_year = 2085 #Central year used for detrending
ic_year  = 2086 #This is the correct year for the rolling window operation to work
doy      = 211 # day_of_year
## Change/Check these parameters before running
# season     = 'MAM'
# season0    = 'mam'
# months     = [3,4,5]
# day        =  91     #First day of the season
# date       = 'mar'

### Load detrended data

In [10]:
%%time
pi_det   = xr.open_zarr(cvals_det+'pi_detrended.zarr').detrended_tmax
eoc_det  = xr.open_zarr(cvals_det+'eoc_detrended.zarr').detrended_tmax

CPU times: user 3.06 s, sys: 650 ms, total: 3.71 s
Wall time: 4.11 s


### Apply MDM, compute quantile funtions and save data

In [11]:
# Compute mean and std over year i.e, annual mean and stds 
pi_amean  = pi_det.mean('year')
eoc_amean = eoc_det.mean('year')
pi_astd   = pi_det.std('year')
eoc_astd  = eoc_det.std('year')

In [12]:
pi_mdm = implement_mdm(pi_det,pi_amean,eoc_amean,pi_astd,eoc_astd)
pi_mdm

Unnamed: 0,Array,Chunk
Bytes,451.13 GiB,267.33 MiB
Shape,"(192, 288, 100, 365, 30)","(8, 60, 100, 73, 10)"
Dask graph,1800 chunks in 32 graph layers,1800 chunks in 32 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 451.13 GiB 267.33 MiB Shape (192, 288, 100, 365, 30) (8, 60, 100, 73, 10) Dask graph 1800 chunks in 32 graph layers Data type float64 numpy.ndarray",288  192  30  365  100,

Unnamed: 0,Array,Chunk
Bytes,451.13 GiB,267.33 MiB
Shape,"(192, 288, 100, 365, 30)","(8, 60, 100, 73, 10)"
Dask graph,1800 chunks in 32 graph layers,1800 chunks in 32 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [13]:
%%time
qpi_mdm = compute_quantiles(pi_mdm)
qeoc    = compute_quantiles(pi_det)
qpi     = compute_quantiles(eoc_det)
qpi_mdm 

CPU times: user 345 ms, sys: 24 ms, total: 369 ms
Wall time: 366 ms


Unnamed: 0,Array,Chunk
Bytes,451.13 GiB,802.00 MiB
Shape,"(30, 192, 288, 100, 365)","(30, 8, 60, 100, 73)"
Dask graph,600 chunks in 37 graph layers,600 chunks in 37 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 451.13 GiB 802.00 MiB Shape (30, 192, 288, 100, 365) (30, 8, 60, 100, 73) Dask graph 600 chunks in 37 graph layers Data type float64 numpy.ndarray",192  30  365  100  288,

Unnamed: 0,Array,Chunk
Bytes,451.13 GiB,802.00 MiB
Shape,"(30, 192, 288, 100, 365)","(30, 8, 60, 100, 73)"
Dask graph,600 chunks in 37 graph layers,600 chunks in 37 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [None]:
%%time
qpi_mdm.to_dataset().to_zarr(cvals_det+'mdm_percentile_'+'_'+ic_year0+'_'+ic_year1+'.zarr')

In [None]:
# %%time
# qpi.to_dataset().to_zarr(cvals_det+'pi_percentile_'+'_'+pi_year0+'_'+pi_year1+'.zarr')
# qeoc.to_dataset().to_zarr(cvals_det+'eoc_percentile_'+'_'+ic_year0+'_'+ic_year1+'.zarr')

### Compute Gaussian quantile function

In [None]:
####### Construct gaussians for model warming ################
qg        = xr.apply_ufunc(erfinv,2*quantiles-1)
#################
qgauss_det_ssn      = tgeoc_det_mroll0_ssn + tgeoc_det_sroll0_ssn * np.sqrt(2) *qg
qgauss_det_ssn.name = 'qgauss_detrended'
qgauss_det_ssn
## gaussm = gaussian with eoc mean only and pre-ind std
qgaussmo_det_ssn      = tgeoc_det_mroll0_ssn + tgpi_det_stack0_ssn.std('myear')* np.sqrt(2) *qg       
qgaussmo_det_ssn.name = 'qgaussmo_detrended'

In [None]:
#############################################################################################################

In [13]:
%%time
#t-temp, nw - no-warming, w-warming, ssn= season, g=global
pi_ssn   = ds_hist.TREFHTMX.sel(time=ds_hist.time.dt.month.isin(months)).\
         sel(time=slice(pi_year0,pi_year1))
eoc_ssn  = ds_ssp.TREFHTMX.sel(time=ds_ssp.time.dt.month.isin(months)).\
         sel(time=slice(ic_year0,ic_year1))

NameError: name 'months' is not defined

## Apply Moment Delta Mapping to detrended data 
- Before applying MDM to the detrended data, we need to aggregate data from the 30 years 1850-1879 along a single coordinate
- We need to do a similar aggregation for the 30-year period 2071-2100
- This is accomplished by using a rolling window centered at 1865 and 2086, repectively
- In order to understand how the rolling window is implmented in xarray, please check : https://docs.xarray.dev/en/stable/generated/xarray.DataArray.roll.html
- The stack_roll function works by first stacking data from the coordinates 'year' and 'member' onto a single coordinate called 'ym'
- We then use xr.roll to apply the rolling window to this new coordinate 

In [19]:
%%time
#Stack member and year into one dimension- 'ym'. Then, roll over this dimension with stride = N_mem.
#This works because of the way xarray stacks two dimensions. window_len =30. 
#I have called the 'window' dimension 'index'.
quants  = np.arange(0,1.01,0.01)
########################
N_mem       = 100 
window_len  = 30
#### July 30 #######
tgpi_detroll_ssn   = stack_roll(tgpi_det_ssn,N_mem*window_len,N_mem,stackdim_name='ym')
tgeoc_detroll_ssn  = stack_roll(tgeoc_det_ssn,N_mem*window_len,N_mem,stackdim_name='ym')
#########
tgpi_detroll0_ssn  = tgpi_detroll_ssn.sel(ym=[15])  
tgeoc_detroll0_ssn = tgeoc_detroll_ssn.sel(ym=[15])  
tgpi_detroll0_ssn

CPU times: user 3.62 s, sys: 176 ms, total: 3.8 s
Wall time: 3.84 s


Unnamed: 0,Array,Chunk
Bytes,113.71 GiB,45.78 MiB
Shape,"(192, 288, 92, 1, 3000)","(5, 40, 10, 1, 3000)"
Dask graph,3120 chunks in 45 graph layers,3120 chunks in 45 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 113.71 GiB 45.78 MiB Shape (192, 288, 92, 1, 3000) (5, 40, 10, 1, 3000) Dask graph 3120 chunks in 45 graph layers Data type float64 numpy.ndarray",288  192  3000  1  92,

Unnamed: 0,Array,Chunk
Bytes,113.71 GiB,45.78 MiB
Shape,"(192, 288, 92, 1, 3000)","(5, 40, 10, 1, 3000)"
Dask graph,3120 chunks in 45 graph layers,3120 chunks in 45 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [20]:
%%time
############# Save files ########################
tgeoc_detroll_ssnpath  = cvals_det + 'tgeoc_detrended_ssnroll_'+season0+'_'+ic_year0+'_'+ic_year1+'.zarr'
tgpi_detroll_ssnpath   = cvals_det + 'tgpi_detrended_ssnroll_'+season0+'_'+pi_year0+'_'+pi_year1+'.zarr'
# Save eoc detrended rolled array
tgeoc_detroll0_ssn     =  tgeoc_detroll0_ssn.rename('eoc_detrended').assign_coords(ym=('ym',[2086]))
tgpi_detroll0_ssn      =  tgpi_detroll0_ssn.rename('pi_detrended').assign_coords(ym=('ym',[2086]))
#
print(tgpi_detroll_ssnpath, tgeoc_detroll_ssnpath) 
# tgpi_detroll0_ssn.to_dataset().to_zarr(tgpi_detroll_ssnpath,mode='w') 
# tgeoc_detroll0_ssn.to_dataset().to_zarr(tgeoc_detroll_ssnpath,mode='w')

/global/scratch/users/harsha/LENS/cesm2/cvals/detrended/tgpi_detrended_ssnroll_mam_1850_1879.zarr /global/scratch/users/harsha/LENS/cesm2/cvals/detrended/tgeoc_detrended_ssnroll_mam_2071_2100.zarr
CPU times: user 22.3 ms, sys: 5.7 ms, total: 28 ms
Wall time: 25.3 ms


In [21]:
%%time
#Compute the pre-industrial + moments adjusted curves
#First compute rolling means and stdevs and save them
#Open
tgeoc_detroll_ssn  = xr.open_zarr(tgeoc_detroll_ssnpath).eoc_detrended
tgpi_detroll_ssn   = xr.open_zarr(tgpi_detroll_ssnpath).pi_detrended
#
tgeoc_det_mroll_ssn = tgeoc_detroll_ssn.mean(dim='index')
tgeoc_det_sroll_ssn = tgeoc_detroll_ssn.std(dim='index')
#############
tgeoc_det_mroll_ssn

CPU times: user 732 ms, sys: 52.7 ms, total: 785 ms
Wall time: 804 ms


Unnamed: 0,Array,Chunk
Bytes,38.81 MiB,15.62 kiB
Shape,"(192, 288, 92, 1)","(5, 40, 10, 1)"
Dask graph,3120 chunks in 4 graph layers,3120 chunks in 4 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 38.81 MiB 15.62 kiB Shape (192, 288, 92, 1) (5, 40, 10, 1) Dask graph 3120 chunks in 4 graph layers Data type float64 numpy.ndarray",192  1  1  92  288,

Unnamed: 0,Array,Chunk
Bytes,38.81 MiB,15.62 kiB
Shape,"(192, 288, 92, 1)","(5, 40, 10, 1)"
Dask graph,3120 chunks in 4 graph layers,3120 chunks in 4 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,8 B,8 B
Shape,"(1,)","(1,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 8 B 8 B Shape (1,) (1,) Dask graph 1 chunks in 2 graph layers Data type int64 numpy.ndarray",1  1,

Unnamed: 0,Array,Chunk
Bytes,8 B,8 B
Shape,"(1,)","(1,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,8 B,8 B
Shape,"(1,)","(1,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 8 B 8 B Shape (1,) (1,) Dask graph 1 chunks in 2 graph layers Data type int64 numpy.ndarray",1  1,

Unnamed: 0,Array,Chunk
Bytes,8 B,8 B
Shape,"(1,)","(1,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray


In [22]:
#Now prepare the pre-industrial data array in order to apply the moment delta mapping transformations
###Add a dummy dimension called 'ym' to tgpi_detrended after stacking year and member as 'myear'
############ det ###########
tgpi_det_stack_ssn = stacker(tgpi_det_ssn,'myear')
###################################
tgpi_det_stack0_ssn= tgpi_det_stack_ssn.expand_dims(dim={'ym':1}).\
                 assign_coords(ym=('ym',[2086]))
tgpi_det_stack0_ssn

Unnamed: 0,Array,Chunk
Bytes,113.71 GiB,15.26 MiB
Shape,"(1, 192, 288, 92, 3000)","(1, 5, 40, 10, 1000)"
Dask graph,9360 chunks in 37 graph layers,9360 chunks in 37 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 113.71 GiB 15.26 MiB Shape (1, 192, 288, 92, 3000) (1, 5, 40, 10, 1000) Dask graph 9360 chunks in 37 graph layers Data type float64 numpy.ndarray",192  1  3000  92  288,

Unnamed: 0,Array,Chunk
Bytes,113.71 GiB,15.26 MiB
Shape,"(1, 192, 288, 92, 3000)","(1, 5, 40, 10, 1000)"
Dask graph,9360 chunks in 37 graph layers,9360 chunks in 37 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [23]:
%%time
#############
tgmc_det_ssn  = mean_adj(tgpi_det_stack0_ssn, tgeoc_det_mroll_ssn)
tgmsc_det_ssn = mean_std_adj(tgpi_det_stack0_ssn, tgeoc_det_mroll_ssn,tgeoc_det_sroll_ssn)
tgmc_det_ssn

CPU times: user 17.9 s, sys: 825 ms, total: 18.7 s
Wall time: 50.3 s


Unnamed: 0,Array,Chunk
Bytes,113.71 GiB,15.26 MiB
Shape,"(1, 192, 288, 92, 3000)","(1, 5, 40, 10, 1000)"
Dask graph,9360 chunks in 49 graph layers,9360 chunks in 49 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 113.71 GiB 15.26 MiB Shape (1, 192, 288, 92, 3000) (1, 5, 40, 10, 1000) Dask graph 9360 chunks in 49 graph layers Data type float64 numpy.ndarray",192  1  3000  92  288,

Unnamed: 0,Array,Chunk
Bytes,113.71 GiB,15.26 MiB
Shape,"(1, 192, 288, 92, 3000)","(1, 5, 40, 10, 1000)"
Dask graph,9360 chunks in 49 graph layers,9360 chunks in 49 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


### Compute Gaussian quantile function

In [27]:
########## EOC gaussians ###########
tgeoc_det_mroll0_ssn = tgeoc_det_mroll_ssn.expand_dims(dim={'quantile':101}).\
               assign_coords(quantile=('quantile',quants))
tgeoc_det_sroll0_ssn = tgeoc_det_sroll_ssn.expand_dims(dim={'quantile':101}).\
               assign_coords(quantile=('quantile',quants))
# Create a DataArray for the quantiles
quantiles = xr.DataArray(np.arange(0, 1.01, 0.01), dims='quantile').\
assign_coords(quantile=('quantile',quants))
####### Construct gaussians for model warming ################
qg        = xr.apply_ufunc(erfinv,2*quantiles-1)
#################
qgauss_det_ssn      = tgeoc_det_mroll0_ssn + tgeoc_det_sroll0_ssn * np.sqrt(2) *qg
qgauss_det_ssn.name = 'qgauss_detrended'
qgauss_det_ssn
## gaussm = gaussian with eoc mean only and pre-ind std
qgaussmo_det_ssn      = tgeoc_det_mroll0_ssn + tgpi_det_stack0_ssn.std('myear')* np.sqrt(2) *qg       
qgaussmo_det_ssn.name = 'qgaussmo_detrended'
qgaussmo_det_ssn

Unnamed: 0,Array,Chunk
Bytes,3.83 GiB,1.54 MiB
Shape,"(101, 192, 288, 92, 1)","(101, 5, 40, 10, 1)"
Dask graph,3120 chunks in 52 graph layers,3120 chunks in 52 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 3.83 GiB 1.54 MiB Shape (101, 192, 288, 92, 1) (101, 5, 40, 10, 1) Dask graph 3120 chunks in 52 graph layers Data type float64 numpy.ndarray",192  101  1  92  288,

Unnamed: 0,Array,Chunk
Bytes,3.83 GiB,1.54 MiB
Shape,"(101, 192, 288, 92, 1)","(101, 5, 40, 10, 1)"
Dask graph,3120 chunks in 52 graph layers,3120 chunks in 52 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,8 B,8 B
Shape,"(1,)","(1,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 8 B 8 B Shape (1,) (1,) Dask graph 1 chunks in 2 graph layers Data type int64 numpy.ndarray",1  1,

Unnamed: 0,Array,Chunk
Bytes,8 B,8 B
Shape,"(1,)","(1,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,8 B,8 B
Shape,"(1,)","(1,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 8 B 8 B Shape (1,) (1,) Dask graph 1 chunks in 2 graph layers Data type int64 numpy.ndarray",1  1,

Unnamed: 0,Array,Chunk
Bytes,8 B,8 B
Shape,"(1,)","(1,)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray


In [28]:
%%time
qgmsc_ssndetpath     = cvals_det+'qgmsc_percentile_det_ssn_'+season0+'_'+ic_year0+'_'+ic_year1+'.zarr'
qgmc_ssndetpath      = cvals_det+'qgmc_percentile_det_ssn_'+season0+'_' +ic_year0+'_'+ic_year1+'.zarr'
#
qgauss_ssndetpath    = cvals_det+'qgauss_percentile_det_ssn_'+season0+'_' +ic_year0+'_'+ic_year1+'.zarr'
qgaussmo_ssndetpath  = cvals_det+'qgaussmo_percentile_det_ssn_'+season0+'_' +ic_year0+'_'+ic_year1+'.zarr'
#
qg_sqdm_ssndepath    = cvals_det+'qg_sqdm_percentile_det_ssn_'+season0 + '_' +ic_year0 + '_' + ic_year1 +'.zarr'
#
qgeoc_ssndetpath     = cvals_det+'qgeoc_percentile_det_ssn_'+season0+'_' +ic_year0+'_'+ic_year1+'.zarr'

print(qgmc_ssndetpath)
##################### Save #####################
qgmc_det_ssn    = qgmc_det_ssn.rename('qgmc_detrended').chunk({'lat':8})
qgmsc_det_ssn   = qgmsc_det_ssn.rename('qgmsc_detrended').chunk({'lat':8})

qgeoc_det_ssn   = qgeoc_det_ssn.rename('qgeoc_detrended').chunk({'lat':8})
qg_sqdm_ssn     = qg_sqdm_ssn.rename('qsqdm_detrended').chunk({'lat':8})

/global/scratch/users/harsha/LENS/cesm2/cvals/detrended/qgmc_percentile_det_ssn_mam_2071_2100.zarr
CPU times: user 2.29 s, sys: 121 ms, total: 2.41 s
Wall time: 2.43 s


In [None]:
%%time
#Save files
# qgauss_det_ssn.to_dataset().to_zarr(qgauss_ssndetpath,mode='w')
# qgmc_det_ssn.to_dataset().to_zarr(qgmc_ssndetpath,mode='w') 
# qgmsc_det_ssn.to_dataset().to_zarr(qgmsc_ssndetpath,mode='w')
# qgeoc_det_ssn.to_dataset().to_zarr(qgeoc_ssndetpath,mode='w') 
# qgaussmo_det_ssn.to_dataset().to_zarr(qgaussmo_ssndetpath,mode='w') 
#
qg_sqdm_ssn.to_dataset().to_zarr(qg_sqdm_ssndepath,mode='w') 

This may cause some slowdown.
Consider scattering data ahead of time and using futures.


In [None]:
# qgmc_det_ssn     = xr.open_zarr(qgmc_ssndetpath).qgmc_detrended
qgmsc_det_ssn    = xr.open_zarr(qgmsc_ssndetpath).qgmsc_detrended
#
# qgauss_det_ssn   = xr.open_zarr(qgauss_ssndetpath).qgauss_detrended
# qgeoc_det_ssn    = xr.open_zarr(qgeoc_ssndetpath).qgeoc_detrended
# qgaussmo_det_ssn = xr.open_zarr(qgaussmo_ssndetpath).qgaussmo_detrended
#
qg_sqdm_ssn      = xr.open_zarr(qg_sqdm_ssndepath).qsqdm_detrended

In [None]:
###### Error in prediction for EOC #######
qgmsc_detssn_err    = qgmsc_det_ssn.sel(ym=2086)  - qgeoc_det_ssn.sel(ym=2086)
qgmc_detssn_err     = qgmc_det_ssn.sel(ym=2086)   - qgeoc_det_ssn.sel(ym=2086)
#
qgauss_detssn_err   = qgauss_det_ssn.sel(ym=2086) - qgeoc_det_ssn.sel(ym=2086)
qgaussmo_detssn_err = qgaussmo_det_ssn.sel(ym=2086) - qgeoc_det_ssn.sel(ym=2086)
#
qg_sqdm_ssn_err     = qg_sqdm_ssn.sel(ym=2086) - qgeoc_det_ssn.sel(ym=2086)
qg_sqdm_ssn_err 

In [None]:
%%time
########### Save the absolute errors ############
####
qgmsc_detssn_errpath    = cvals_det+'qgmsc_detssn_err_' +season0 +'_'+ic_year0+'_'+ic_year1+'.zarr'
qgmc_detssn_errpath     = cvals_det+'qgmc_detssn_err_'  +season0 +'_'+ic_year0+'_'+ic_year1+'.zarr'
qgauss_detssn_errpath   = cvals_det+'qgauss_detssn_err_'+season0 +'_'+ic_year0+'_'+ic_year1+'.zarr'
qgaussmo_detssn_errpath = cvals_det+'qgaussmo_detssn_err_'+season0 +'_'+ic_year0+'_'+ic_year1+'.zarr'
qg_sqdm_ssn_errpath     = cvals_det+ 'qg_sqdm_err_'+season0 +'_'+ic_year0+'_'+ic_year1+'.zarr'
#### Save ########
print(qgmc_detssn_errpath)
qgmsc_detssn_err    = qgmsc_detssn_err.rename('Errors').chunk({'lat':20})
qgmc_detssn_err     = qgmc_detssn_err.rename('Errors').chunk({'lat':20})
qgauss_detssn_err   = qgauss_detssn_err.rename('Errors').chunk({'lat':20})
qgaussmo_detssn_err = qgaussmo_detssn_err.rename('Errors').chunk({'lat':20})
qg_sqdm_ssn_err     = qg_sqdm_ssn_err.rename('Errors').chunk({'lat':20})

############
# qgmsc_detssn_err.to_dataset().to_zarr(qgmsc_detssn_errpath,mode='w')
# qgmc_detssn_err.to_dataset().to_zarr(qgmc_detssn_errpath,mode='w')
# qgauss_detssn_err.to_dataset().to_zarr(qgauss_detssn_errpath,mode='w')
# qgaussmo_detssn_err.to_dataset().to_zarr(qgaussmo_detssn_errpath,mode='w')
qg_sqdm_ssn_err.to_dataset().to_zarr(qg_sqdm_ssn_errpath, mode='w')

In [None]:
# Open file and plot
qgmc_detssn_err     = xr.open_zarr(qgmc_detssn_errpath).Errors
qgmsc_detssn_err    = xr.open_zarr(qgmsc_detssn_errpath).Errors
qg_sqdm_ssn_err     = xr.open_zarr(qg_sqdm_ssn_errpath).Errors
#
qgaussmo_detssn_err = xr.open_zarr(qgaussmo_detssn_errpath).Errors
qgauss_detssn_err   = xr.open_zarr(qgauss_detssn_errpath).Errors

In [None]:
######## New colorbar only for shift + stretch and gaussian ############
x = 5
# create a figure and axis
fig, ax = plt.subplots(figsize=(6, 1))
fig.subplots_adjust(bottom=0.5)
# Define colormap
cmap = plt.get_cmap('RdBu_r')
# Make a norm object with the center at 0: TwoSlopeNorm
norm = mcolors.TwoSlopeNorm(vmin=-x, vcenter=0, vmax=x)
# Making numpy array from -3 to 3, with step 0.2
values = np.arange(-x, x+0.25, 0.25)
# Creating a mappable object and setting the norm and cmap for colorbar
mappable = cm.ScalarMappable(norm=norm, cmap=cmap)
mappable.set_array([])
# Creating a colorbar
ticks1 = altspace(-4,1,9)
cbar = plt.colorbar(mappable, ax=ax, orientation='vertical',ticks=ticks1)
cbar.set_label('')
plt.gca().set_visible(False)

In [None]:
%%time
# Creating a figure and axes
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,5), subplot_kw={'projection': cart.crs.PlateCarree()}\
                       , gridspec_kw = {'wspace':0.05, 'hspace':0.2})

# Plotting qgmsc_err
qgmsc_detssn_err.sel(quantile=0.9).mean('day').plot(ax=axs[0], transform=cart.crs.PlateCarree(),\
                                                    add_colorbar=False, cmap=cmap, norm=norm)
axs[0].coastlines(color="black")
axs[0].set_title(season+r': $L_2(Q_i) - Q_f$ at $p=0.9$')

#Plot QDM + sort
im = qg_sqdm_ssn_err.sel(quantile=0.9).mean('day').plot(ax=axs[1], transform=cart.crs.PlateCarree(),\
                                                   add_colorbar=False, cmap=cmap, norm=norm)
axs[1].coastlines(color="black")
axs[1].set_title(season+r': QDM + sort at $p=0.9$')

# Adding colorbar
cbar = plt.colorbar(im, ax=axs.ravel().tolist(), shrink=0.5, orientation='vertical')
cbar.set_label('Error')

plt.show()

In [None]:
%%time
# Creating a figure and axes
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,5), subplot_kw={'projection': cart.crs.PlateCarree()}\
                       , gridspec_kw = {'wspace':0.05, 'hspace':0.2})

# Plotting qgmsc_err
qgmsc_detssn_err.sel(quantile=0.9).mean('day').plot(ax=axs[0], transform=cart.crs.PlateCarree(),\
                                                    add_colorbar=False, cmap=cmap, norm=norm)
axs[0].coastlines(color="black")
axs[0].set_title(season+r': $L_2(Q_i) - Q_f$ at $p=0.9$')

# Plotting G2, gaussian error
im = qgauss_detssn_err.sel(quantile=0.9).mean('day').plot(ax=axs[1], transform=cart.crs.PlateCarree(),\
                                                          add_colorbar=False, cmap=cmap, norm=norm)
axs[1].coastlines(color="black")
axs[1].set_title(season+': $G_2(Q_i) - Q_f$ at $p=0.9$')

# Adding colorbar
cbar = plt.colorbar(im, ax=axs.ravel().tolist(), shrink=0.5, orientation='vertical')
cbar.set_label('Error')

plt.show()

In [None]:
%%time
# Creating a figure and axes
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10,5), subplot_kw={'projection': cart.crs.PlateCarree()}\
                       , gridspec_kw = {'wspace':0.05, 'hspace':0.2})

# Plotting qgmsc_err
qgmc_detssn_err.sel(quantile=0.9).mean('day').plot(ax=axs[0], transform=cart.crs.PlateCarree(),\
                                                    add_colorbar=False, cmap=cmap, norm=norm)
axs[0].coastlines(color="black")
axs[0].set_title(season+': $L_1(Q_i) - Q_f$ at $p=0.9$')

# Plotting da2
im = qgaussmo_detssn_err.sel(quantile=0.9).mean('day').plot(ax=axs[1], transform=cart.crs.PlateCarree(),\
                                                          add_colorbar=False, cmap=cmap, norm=norm)
axs[1].coastlines(color="black")
axs[1].set_title(season+r': $G_1(Q_i) - Q_f$ at $p=0.9$')

# Adding colorbar
cbar = plt.colorbar(im, ax=axs.ravel().tolist(), shrink=0.5, orientation='vertical')
cbar.set_label('Errors')

plt.show()

In [None]:
####### Coded until here ###########