In [2]:
from matplotlib import pyplot as plt
import xarray as xr
import numpy as np
import dask_jobqueue
import dask
from distributed import Client
from dask.diagnostics import progress
from tqdm.autonotebook import tqdm
import cartopy.io.shapereader as shpreader
import cartopy.feature as cfeature
import intake
import fsspec
#import seaborn as sns
import gcsfs
import cftime
from datetime import datetime, timedelta
import pandas as pd
import xesmf as xe

  from tqdm.autonotebook import tqdm


In [3]:
# --- Reference point locations (degrees; longitudes on the 0-360 convention) ---
chic_lat = 41.8781                 # Chicago
chic_lon = (360 - 87.6298) % 360   # 87.6298 W expressed as degrees east
ben_lat = 12.9716                  # Bengaluru
ben_lon = 77.5946

# --- CONUS bounding box (longitudes shifted from degrees west to 0-360) ---
top = 50.0                         # northern latitude
bottom = 24.7433195                # southern latitude
left = 360 - 124.7844079           # western longitude
right = 360 - 66.9513812           # eastern longitude

# --- Data locations ---
cesm2_path = '/global/scratch/users/harsha/LENS/cesm2/tasmax/'
cvals = '/global/scratch/users/harsha/LENS/cesm2/cvals/detrended/'
cmip6_cvals = cvals + 'cmip6/'

# --- Analysis periods ---
# Representative years for the two 30-year windows below (presumably their centres - confirm).
pi_year = 1865
eoc_year = 2085
# Day of year of interest on a no-leap calendar (211 -> 'July 30').
doy = 211
# Window bounds as strings, presumably for label-based time slicing.
pi_year0, pi_year1 = '1850', '1879'   # pre-industrial window
ic_year0, ic_year1 = '2071', '2100'   # later window (end of century)

In [4]:
def no_leap_date(day_of_year):
    """Return the 'Month DD' label of a day of year on a 365-day (no-leap) calendar.

    Parameters
    ----------
    day_of_year : int
        1-based day of year: 1 is January 1 and 365 is December 31.
        NumPy integer scalars are accepted too.

    Returns
    -------
    str
        Full month name and zero-padded day, e.g. ``no_leap_date(211) == 'July 30'``.

    Raises
    ------
    ValueError
        If ``day_of_year`` is outside 1..365. Such values used to roll silently into
        the neighbouring year (366 -> 'January 01'), producing a misleading label.
    """
    # int() also converts NumPy integer scalars, which timedelta may reject.
    day = int(day_of_year)
    if not 1 <= day <= 365:
        raise ValueError(
            f"day_of_year must be in 1..365 for a no-leap calendar, got {day_of_year!r}"
        )
    # 2021 is not a leap year, so its day numbering matches a 365-day calendar.
    start_date = datetime(2021, 1, 1)
    # Subtract 1 because January 1st is day 1.
    actual_date = start_date + timedelta(days=day - 1)
    return actual_date.strftime('%B %d')
###############################
# Calendar label for the selected day of year (no-leap calendar).
date = no_leap_date(day_of_year=doy)
date

'July 30'

In [5]:
# Extra SLURM directives (QoS and account) for each worker job.
job_extra = [
    '--qos=cf_lowprio',
    '--account=ac_cumulus',
]
# Alternative QoS/account/constraint combinations tried previously:
# job_extra = ['--qos=lr6_lowprio', '--account=ac_cumulus', '--constraint=lr6_m192']
# job_extra = ['--qos=condo_cumulus_lr6', '--account=lr_cumulus', '--constraint=lr6_m192']
# job_extra = ['--qos=lr_lowprio', '--account=ac_cumulus']

# Each SLURM job requests 10 cores and 192 GB on queue 'cf1' for at most 5 hours.
cluster = dask_jobqueue.SLURMCluster(
    queue="cf1",
    cores=10,
    memory="192GB",
    walltime='5:00:00',
    interface='eth0',
    job_extra_directives=job_extra,
    local_directory='/global/scratch/users/harsha/dask_space/',
    log_directory='/global/scratch/users/harsha/dask_space/',
)
client = Client(cluster)
# NOTE(review): scale_up is a legacy spelling; whether 3 counts SLURM jobs or worker
# processes depends on the distributed/dask-jobqueue version - cluster.scale(jobs=3)
# states the intent explicitly. Confirm before changing.
cluster.scale_up(3)

In [6]:
# Show the cluster widget (dashboard link, worker/thread/memory totals). The saved output
# reports 0 workers, presumably because the SLURM jobs were still queued when it rendered.
cluster

0,1
Dashboard: http://10.0.39.4:8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.0.39.4:32877,Workers: 0
Dashboard: http://10.0.39.4:8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [7]:
# calculate global means
def to_daily(ds):
    year       = ds.time.dt.year
    dayofyear  = ds.time.dt.dayofyear

    # assign new coords
    ds = ds.assign_coords(year=("time", year.data), dayofyear=("time", dayofyear.data))

    # reshape the array to (..., "day", "year")
    return ds.set_index(time=("year", "dayofyear")).unstack("time") 
    
def get_lat_name(ds):
    """Return the name of the latitude coordinate of ``ds``.

    Looks for a coordinate called ``'lat'`` first, then ``'latitude'``.

    Raises
    ------
    RuntimeError
        If neither name is present in ``ds.coords``.
    """
    match = next((name for name in ('lat', 'latitude') if name in ds.coords), None)
    if match is None:
        raise RuntimeError("Couldn't find a latitude coordinate")
    return match

def global_mean(ds, keep_dims=('quantile',)):
    """Area-weighted mean of ``ds`` over every dimension except ``keep_dims``.

    Weights are cos(latitude). The latitude coordinate is located with ``get_lat_name``.
    All remaining dimensions are averaged, not only lat/lon: for example ``source_id``
    and ``dayofyear`` are averaged as well if present.

    ``DataArray.weighted`` is used so that missing values are excluded from both the
    weighted sum and the sum of weights. The previous ``(ds * w).mean()``, with ``w``
    normalised to unit mean over the full latitude axis, was biased whenever some
    grid cells were NaN. Without NaNs the result is the same.

    Parameters
    ----------
    ds : xarray.DataArray or xarray.Dataset
        Data with a 'lat' or 'latitude' coordinate.
    keep_dims : str or iterable of str, default ('quantile',)
        Dimensions to keep; absent names are ignored.

    Returns
    -------
    Same xarray type as ``ds``, reduced to ``keep_dims``.
    """
    lat = ds[get_lat_name(ds)]
    weights = np.cos(np.deg2rad(lat))
    keep = {keep_dims} if isinstance(keep_dims, str) else set(keep_dims)
    # A list (not a set) keeps a deterministic order for the reduction.
    reduce_dims = [dim for dim in ds.dims if dim not in keep]
    return ds.weighted(weights).mean(reduce_dims)

In [8]:
def implement_mdm(ds_obs, init_mean, final_mean, init_std, final_std, dim='year'):
    """Mean-and-variability delta mapping of observations.

    Each observation is split into its mean along ``dim`` plus an anomaly. The mean is
    shifted by the model's change in mean, and the anomaly is scaled by the model's
    ratio of standard deviations::

        obs_mean + (final_mean - init_mean) + (final_std / init_std) * (ds_obs - obs_mean)

    Parameters
    ----------
    ds_obs : xarray object (or numpy.ndarray)
        Observed values containing the dimension ``dim``.
    init_mean, final_mean :
        Model mean for the initial and final period; must broadcast against ``ds_obs``.
    init_std, final_std :
        Model standard deviation for the initial and final period. A zero ``init_std``
        yields inf/NaN through the division.
    dim : hashable, default 'year'
        Dimension holding the sample years. It is passed positionally to
        ``ds_obs.mean``, so for a plain NumPy array pass an axis number instead.

    Returns
    -------
    The delta-mapped values, broadcast like ``ds_obs``.
    """
    obs_mean = ds_obs.mean(dim)
    std_ratio = final_std / init_std
    return obs_mean + (final_mean - init_mean) + std_ratio * (ds_obs - obs_mean)

def implement_qdm(qobs, qinit, qfinal):
    """Additive quantile delta mapping.

    Shifts every observed quantile by the model's change in the same quantile,
    ``qobs + (qfinal - qinit)``. The inputs are combined with plain arithmetic, so
    they must broadcast against each other; xarray inputs are aligned by coordinate.
    """
    quantile_change = qfinal - qinit
    return qobs + quantile_change

def is_sorted(arr):
    """Return True if ``arr`` is monotonically non-decreasing or non-increasing.

    Comparison runs along the first axis. The input is converted with ``np.asarray``
    first, because comparing shifted slices directly is wrong for non-ndarray inputs:

    - Python lists: ``arr[:-1] <= arr[1:]`` is one lexicographic comparison, so e.g.
      ``[3, 1, 2]`` was reported as sorted.
    - xarray DataArrays with an index coordinate: the two slices are aligned by label,
      so every element is compared with itself and the check always passed.
    - pandas Series: differently labelled slices cannot be compared and raised.

    Empty and single-element inputs count as sorted. Any NaN makes both comparisons
    fail, so the result is False.
    """
    values = np.asarray(arr)
    head, tail = values[:-1], values[1:]
    return bool(np.all(head <= tail) or np.all(head >= tail))

#
# Probability levels used throughout: 30 evenly spaced values from 0 to 1 inclusive.
quants = np.linspace(0.0, 1.0, num=30)
def compute_quantiles(ds, quantiles=quants):
    """Quantiles of ``ds`` across its ``year`` dimension.

    ``year`` is first rechunked into a single chunk, because xarray's quantile on dask
    arrays needs the reduced dimension in one chunk. ``skipna=False`` means a NaN in
    any year makes the corresponding quantiles NaN.
    """
    rechunked = ds.chunk({'year': -1})
    return rechunked.quantile(quantiles, dim='year', skipna=False)

In [9]:
# Annual tasmax quantiles for the pre-industrial (pi) and end-of-century (eoc) windows.
# Saved output: shape (18, 30, 365, 61, 121), presumably
# (source_id, quantile, dayofyear, lat, lon) - confirm.
ds_pi = xr.open_zarr(cvals + 'cmip6_pi_quantiles_annual.zarr')['tasmax']
ds_eoc = xr.open_zarr(cvals + 'cmip6_eoc_quantiles_annual.zarr')['tasmax']
ds_eoc

Unnamed: 0,Array,Chunk
Bytes,5.42 GiB,4.18 MiB
Shape,"(18, 30, 365, 61, 121)","(3, 8, 92, 16, 31)"
Dask graph,1536 chunks in 2 graph layers,1536 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 5.42 GiB 4.18 MiB Shape (18, 30, 365, 61, 121) (3, 8, 92, 16, 31) Dask graph 1536 chunks in 2 graph layers Data type float32 numpy.ndarray",30  18  121  61  365,

Unnamed: 0,Array,Chunk
Bytes,5.42 GiB,4.18 MiB
Shape,"(18, 30, 365, 61, 121)","(3, 8, 92, 16, 31)"
Dask graph,1536 chunks in 2 graph layers,1536 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [10]:
# Detrended annual tasmax for the same two windows (float64). The size-30 axis here is
# presumably the 30 years of each window rather than quantiles - confirm.
ds_pi_det = xr.open_zarr(cvals + 'cmip6_pi_ann_detrended.zarr')['tasmax']
ds_eoc_det = xr.open_zarr(cvals + 'cmip6_eoc_ann_detrended.zarr')['tasmax']
ds_eoc_det

Unnamed: 0,Array,Chunk
Bytes,10.84 GiB,2.79 MiB
Shape,"(18, 30, 365, 61, 121)","(1, 8, 92, 16, 31)"
Dask graph,4608 chunks in 2 graph layers,4608 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 10.84 GiB 2.79 MiB Shape (18, 30, 365, 61, 121) (1, 8, 92, 16, 31) Dask graph 4608 chunks in 2 graph layers Data type float64 numpy.ndarray",30  18  121  61  365,

Unnamed: 0,Array,Chunk
Bytes,10.84 GiB,2.79 MiB
Shape,"(18, 30, 365, 61, 121)","(1, 8, 92, 16, 31)"
Dask graph,4608 chunks in 2 graph layers,4608 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


### Select Observation and Model indices. 

In [None]:
# CMIP6 model identifiers: the `source_id` coordinate of the end-of-century quantiles.
model_list = ds_eoc.source_id
model_list

In [None]:
############ Select some ensemble as obs and another as model ######
# FIXME(review): this cell cannot run on a fresh kernel as written:
#   - `obs_list`, `tdpi` and `tdeoc` are not defined in any cell shown here;
#   - the arrays above carry a `source_id` coordinate, so `.sel(member=...)` presumably
#     raises unless they also have a `member` dimension - confirm or switch to source_id.
# None of the variables assigned below are used by the later visible cells.
pi_obs  = ds_pi.sel(member = obs_list)
eoc_obs = ds_eoc.sel(member= obs_list)
# Detrended data for the model subset.
pi_model_det  = ds_pi_det.sel(member = model_list)
eoc_model_det = ds_eoc_det.sel(member = model_list)
# Source of tdpi/tdeoc unknown (see FIXME above).
pi_model  = tdpi.sel(member = model_list)
eoc_model = tdeoc.sel(member = model_list)

In [None]:
########################################################

In [None]:
%%time
# FIXME(review): `qano` is not defined in any cell shown here. The plot titles below
# ("stdev in Q_f(p) - Q_i(p) among 18 CMIP6 models") suggest it is the quantile change
# ds_eoc - ds_pi with a `source_id` dimension - define it explicitly and confirm.
# Inter-model spread of the quantile change (xarray's std defaults to ddof=0).
qano_std            = qano.std(dim='source_id')
qano_cmip_mean      = qano.mean(dim='source_id')
# Squared deviation of each model from the multi-model mean.
qano_sq_deviation   = (qano - qano_cmip_mean)**2
# If qano is dask-backed these are lazy, so %%time measures graph construction only.
qano_sq_deviation

In [None]:
# Reduce to one value per quantile: global_mean averages over every dimension except
# 'quantile' - lat/lon with cos(lat) weights plus any other remaining dims such as
# dayofyear and, for the squared deviation, source_id.
qano_std_agmean  = global_mean(qano_std)
qano_msd         = global_mean(qano_sq_deviation)   # mean squared deviation from the multi-model mean
qano_rmsd        = np.sqrt(qano_msd)                 # root of the above
qano_rmsd

In [None]:
# %%time
# qano_std.sel(dayofyear=365).sel(lat=chic_lat,lon=chic_lon,method='nearest').values

In [None]:
# %%time
# qano_std.sel(dayofyear=365).sel(lat=1,lon=5,method='nearest').values

In [None]:
%%time
# Persist the per-quantile summaries so the cells below can reload them without recomputing.
# NOTE(review): 'cmip6_quantiles_ano_agmean.zarr' stores the area-weighted mean of the
# inter-model standard deviation, not a mean anomaly; the path is kept because the reload
# cell reads it.
qano_std_agmean = qano_std_agmean.rename('tasmax_qanomaly')
qano_rmsd       = qano_rmsd.rename('qanomaly_rmsd')
# mode='w' overwrites any existing store at these paths.
qano_std_agmean.to_dataset().to_zarr(cvals+'cmip6_quantiles_ano_agmean.zarr',mode='w')
qano_rmsd.to_dataset().to_zarr(cvals+'cmip6_qano_ag_rmsd.zarr',mode='w')

In [None]:
# Reload the saved inter-model-spread summary (one value per quantile).
qano_std_agmean = xr.open_zarr(cvals + 'cmip6_quantiles_ano_agmean.zarr')['tasmax_qanomaly']
qano_std_agmean.values

In [None]:
# Reload the saved RMS deviation summary (one value per quantile).
qano_rmsd = xr.open_zarr(cvals + 'cmip6_qano_ag_rmsd.zarr')['qanomaly_rmsd']
qano_rmsd.values

In [None]:
# Area-weighted mean inter-model standard deviation of the quantile change, per probability.
# Explicit figure/axes instead of the implicit pyplot "current axes".
fig, ax = plt.subplots()
qano_std_agmean.plot(ax=ax)
ax.set_xlabel('Probability p')
ax.set_ylabel('Global, annual mean of standard deviation (K) ')
ax.set_title(r'Global, annual mean of stdev in $Q_f(p) -Q_i(p)$ among 18 CMIP6 models')
props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
# NOTE: model count and grid spacing are hard-coded; update them if the inputs change.
ax.text(0.05, 0.95, r'18 models, $3^{\circ} \times 3^{\circ}$ grid', transform=ax.transAxes, fontsize=14,
        verticalalignment='top', bbox=props)
ax.set_ylim(0, 1.8)
# Trailing ';' suppresses the (left, right) tuple repr in the cell output.
ax.set_xlim(0, 1.0);

In [None]:
# RMS deviation of the models' quantile change from the multi-model mean, per probability.
# Explicit figure/axes instead of the implicit pyplot "current axes".
fig, ax = plt.subplots()
qano_rmsd.plot(ax=ax)
ax.set_xlabel('Probability p')
ax.set_ylabel('RMSD (K) ')
ax.set_title(r'RMSD of $Q_f(p) -Q_i(p)$')
props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
# NOTE: model count and grid spacing are hard-coded; update them if the inputs change.
ax.text(0.05, 0.95, r'18 models, $3^{\circ} \times 3^{\circ}$ grid', transform=ax.transAxes, fontsize=14,
        verticalalignment='top', bbox=props)
ax.set_ylim(0, 2.5)
# Trailing ';' suppresses the (left, right) tuple repr in the cell output.
ax.set_xlim(0, 1.0);