In [1]:
from matplotlib import pyplot as plt
import xarray as xr
import numpy as np
import dask_jobqueue
import dask
from distributed import Client
from dask.diagnostics import progress
from tqdm.autonotebook import tqdm
import cartopy.io.shapereader as shpreader
import cartopy.feature as cfeature
import intake
import fsspec
#import seaborn as sns
import gcsfs
import cftime
import pandas as pd
import xesmf as xe

  from tqdm.autonotebook import tqdm


ModuleNotFoundError: No module named 'gcsfs'

In [None]:
################################
# Filesystem locations (absolute paths under /global/scratch).
# `cvals` is the prefix of every zarr store opened or written in the cells shown here.
savefigs   = '/global/scratch/users/harsha/savefigs/Feb21/'
cesm2_path = '/global/scratch/users/harsha/LENS/cesm2/tasmax/'
cvals      = '/global/scratch/users/harsha/LENS/cesm2/cvals/detrended/'
############
# Reference years handed to detrend_data as `central_year` for the 'pi' and
# 'eoc' arrays respectively; the fitted trend is zero at these years.
pi_year  = 1865
eoc_year = 2085

In [None]:
# Extra scheduler directives forwarded through `job_extra_directives`;
# the commented alternatives below select other QoS/account combinations.
job_extra = ['--qos=cf_lowprio','--account=ac_cumulus'] 
#job_extra =['--qos=lr6_lowprio','--account=ac_cumulus','--constraint=lr6_m192']
#job_extra =['--qos=condo_cumulus_lr6','--account=lr_cumulus','--constraint=lr6_m192']
#job_extra =['--qos=lr_lowprio','--account=ac_cumulus']
# Describe one SLURM job (queue, cores, memory, walltime) and where dask
# keeps its scratch files and logs.
cluster = dask_jobqueue.SLURMCluster(queue="cf1", cores=10, walltime='5:00:00', 
                local_directory='/global/scratch/users/harsha/dask_space/', 
                log_directory='/global/scratch/users/harsha/dask_space/', 
                job_extra_directives=job_extra, interface='eth0', memory="192GB") 
# Attach a distributed client to the cluster.
client  = Client(cluster) 
# NOTE(review): confirm whether scale_up(3) counts jobs or worker processes
# for the installed dask / dask-jobqueue versions.
cluster.scale_up(3)

In [None]:
cluster

In [None]:
# Reference point used for spot checks in the (commented) diagnostic cells.
# The longitude is shifted into the 0-360 range by the modulo expression.
chicago_lat=41.88
chicago_lon=(360-87.6298)%360
# #
# cvals_ = '/Users/hrh/Desktop/TwoMoments21/cvals/cmip6/'

In [None]:
# calculate global means
def get_lat_name(ds):
    """Return the name of the latitude coordinate of ``ds``.

    'lat' is tried first, then 'latitude'; the first one present in
    ``ds.coords`` is returned.

    Raises
    ------
    RuntimeError
        If neither name is among the coordinates.
    """
    candidates = ('lat', 'latitude')
    found = next((name for name in candidates if name in ds.coords), None)
    if found is None:
        raise RuntimeError("Couldn't find a latitude coordinate")
    return found

def global_mean(ds):
    """Cosine-of-latitude weighted mean over every dimension except 'quantile'.

    The latitude coordinate (looked up with get_lat_name) is converted from
    degrees to radians, its cosine is rescaled so the weights average to one,
    and the weighted data are averaged over all dimensions of ``ds`` other
    than 'quantile'.
    """
    latitudes = ds[get_lat_name(ds)]
    weight = np.cos(np.deg2rad(latitudes))
    # Rescale in place so that the weights have unit mean.
    weight /= weight.mean()
    reduce_over = set(ds.dims) - {'quantile'}
    weighted = ds * weight
    return weighted.mean(reduce_over)

def detrend_data(ds, central_year):
    """Remove the linear trend along 'year' from a DataArray.

    A first-degree polynomial is fitted along the 'year' dimension and
    ``slope * (year - central_year)`` is subtracted from the data, so values
    at ``central_year`` are left unchanged.

    Parameters
    ----------
    ds : xarray.DataArray
        Array with a 'year' dimension. The fit result is read through
        ``polyfit_coefficients``, so a DataArray (not a Dataset) is expected.
    central_year : number
        Year at which the removed trend is zero.

    Returns
    -------
    xarray.DataArray
        The detrended array, carrying the same name as ``ds``.
    """
    # Fit a linear function along 'year' and extract the slope.
    pcoeffs = ds.polyfit(dim='year', deg=1)
    slope   = pcoeffs.polyfit_coefficients.sel(degree=1)

    # Trend relative to the central year.
    ds_trend = slope * (ds['year'] - central_year)

    # Detrend by subtracting the trend from the data.
    ds_detrended = ds - ds_trend

    # The subtraction combines arrays with different names, which leaves the
    # result unnamed; restore the input name so the result can be turned into
    # a Dataset (e.g. for writing to zarr) without a separate rename step.
    return ds_detrended.rename(ds.name)

In [None]:
# Open the intake-esm datastore described by the JSON file at this URL and display it.
col = intake.open_esm_datastore("https://storage.googleapis.com/cmip6/pangeo-cmip6.json")
col

In [None]:
# 2. Search the catalog for daily maximum temperature ('tasmax', table 'day')
#    in the ssp370 and historical experiments.
expts = ['ssp370','historical']

# NOTE(review): `cat` is not referenced again in the cells shown here.
cat = col.search(
    experiment_id=expts,
    table_id='day',
    variable_id='tasmax',
    #grid_label='gn'
)

# Same search restricted to a single ensemble member.
query = dict(
    experiment_id=expts,
    table_id='day',
    variable_id=['tasmax'],
    member_id = 'r1i1p1f1',
)

# require_all_on=["source_id"] is presumably meant to keep only models that
# satisfy the whole query (both experiments) - confirm against intake-esm docs.
col_subset = col.search(require_all_on=["source_id"], **query)
# Number of distinct values of each listed column per model.
col_subset.df.groupby("source_id")[
    ["experiment_id", "variable_id", "table_id","member_id"]
].nunique()

In [None]:
# Number of catalog rows per model in the filtered subset.
df = col_subset.df
model_counts = df.groupby('source_id').size()
print(model_counts)

In [None]:
df['activity_id'].unique()

In [None]:
# def drop_all_bounds(ds):
#     drop_vars = [vname for vname in ds.coords
#                  if (('_bounds') in vname ) or ('_bnds') in vname]
#     return ds.drop(drop_vars)

# def open_dset(df):
#     assert len(df) == 1
#     ds = xr.open_zarr(fsspec.get_mapper(df.zstore.values[0]), consolidated=True)
#     return drop_all_bounds(ds)

# def open_delayed(df):
#     return dask.delayed(open_dset)(df)

# from collections import defaultdict
# dsets = defaultdict(dict)

# for group, df in col_subset.df.groupby(by=['source_id', 'experiment_id']):
#     dsets[group[0]][group[1]] = open_delayed(df)

In [None]:
# %%time
# # Trigger computation
# dsets_ = dask.compute(dict(dsets))[0]

In [None]:
# Define the coarse 1.5-degree target grid to regrid onto.
# Longitude stops before 360 because 360 is the same meridian as 0; including
# both would put a duplicated column on the global grid.
ds_out = xr.Dataset({'lat': (['lat'], np.arange(-90, 91, 1.5)),
                     'lon': (['lon'], np.arange(0, 360, 1.5))})

In [None]:
def drop_feb29(ds):
    """Convert ``ds`` to the '365_day' calendar unless it uses '360_day'.

    The calendar name is taken from ``ds.time.encoding['calendar']`` when it
    is present. If the encoding carries no calendar, the calendar reported by
    ``ds.time.dt.calendar`` is used instead, so a 360-day dataset is still
    recognised and left untouched.

    The model name (``source_id`` attribute, if any) and the detected calendar
    are printed for diagnostics.

    Returns
    -------
    The input unchanged for '360_day' data, otherwise the result of
    ``ds.convert_calendar('365_day')``.
    """
    calendar = ds.time.encoding.get('calendar', None)
    if calendar is None:
        # No calendar recorded in the encoding: fall back to the calendar of
        # the time values themselves.
        calendar = ds.time.dt.calendar
    # .get keeps the diagnostic print from failing when source_id is absent.
    print(ds.attrs.get('source_id'), calendar)
    if calendar != '360_day':
        ds = ds.convert_calendar('365_day')
    return ds


def to_daily(ds):
    """Replace the 'time' dimension by separate 'year' and 'dayofyear' dimensions.

    If the first time value is neither a numpy datetime64 nor a cftime
    datetime, the time coordinate is first cast to 'datetime64[ns]' (this
    assignment modifies the passed object). Year and day-of-year are then
    attached as coordinates along 'time', used as a two-level index, and the
    index is unstacked.
    """
    first_stamp = ds['time'].values[0]
    is_datetime_like = (
        isinstance(first_stamp, np.datetime64)
        or isinstance(first_stamp, cftime.datetime)
    )
    if not is_datetime_like:
        # convert time coordinate to datetime64 objects
        ds['time'] = ds['time'].astype('datetime64[ns]')

    # Calendar components of every time step.
    time_parts = ds.time.dt
    ds = ds.assign_coords(
        year=("time", time_parts.year.data),
        dayofyear=("time", time_parts.dayofyear.data),
    )

    # Index time by (year, dayofyear) and unstack into two dimensions.
    indexed = ds.set_index(time=("year", "dayofyear"))
    return indexed.unstack("time")


def extract_data(ds, periods=((1850, 1879), (2071, 2100))):
    """
    Keep only the years of 'ds' that fall inside the given periods.

    Parameters:
    - ds (xarray.Dataset or xarray.DataArray): Input with a 'year' coordinate.
    - periods (iterable of (first_year, last_year)): Year ranges to keep. Each
      pair is selected with a label slice along 'year'. Defaults to
      1850-1879 and 2071-2100.

    Returns:
    - The selections concatenated along 'year', in the order of `periods`.
      No seasonal or spatial subsetting is performed.
    """
    subsets = [ds.sel(year=slice(first, last)) for first, last in periods]
    return xr.concat(subsets, dim='year')

def is_leap(year):
    """Return whether ``year`` is a leap year (divisible by 4, and by 400 if divisible by 100)."""
    divisible_by_four = year % 4 == 0
    if not divisible_by_four:
        return divisible_by_four
    return (year % 100 != 0) or (year % 400 == 0)


In [None]:
# Default probability levels: 30 evenly spaced values from 0 to 1 inclusive.
quants = np.linspace(0.0, 1.0, num=30)

def compute_quantiles(ds, quantiles=quants):
    """Quantiles of ``ds`` along 'year' at the given probability levels.

    The data are rechunked so that 'year' forms a single chunk before the
    quantiles are taken; missing values are not skipped (skipna=False).
    """
    whole_year_axis = ds.chunk({'year': -1})
    return whole_year_axis.quantile(quantiles, dim='year', skipna=False)

def regrid(ds, ds_out):
    """Regrid ``ds`` onto ``ds_out`` with xESMF using the 'nearest_s2d' method.

    Weights are recomputed on every call (reuse_weights=False). The
    'experiment_id' and 'source_id' attributes of the input are copied onto
    the result, in case the regridder did not carry them over.
    """
    # Read the identifying attributes before regridding.
    kept_attrs = {key: ds.attrs[key] for key in ('experiment_id', 'source_id')}

    ds_new = xe.Regridder(ds, ds_out, 'nearest_s2d', reuse_weights=False)(ds)

    # Re-attach the identifying attributes to the regridded object.
    for key, value in kept_attrs.items():
        ds_new.attrs[key] = value
    return ds_new

def process_data(ds, quantiles=quants):
    """Run the preprocessing chain on one dataset and regrid it onto ``ds_out``.

    Steps: drop_feb29 -> to_daily -> extract_data -> regrid. Returns None
    (after printing a message) when no years remain after extraction or when
    the result has fewer than 365 'dayofyear' entries.

    ``quantiles`` is accepted but not used by this function.
    """
    ds = extract_data(to_daily(drop_feb29(ds)))

    n_years = len(ds['year'])
    if n_years == 0:
        print("The dataset is empty. Skipping...")
        return None

    n_days = len(ds['dayofyear'])
    if n_days < 365:
        print('The dataset has less than 365 days. Skipping ..')
        return None

    # Regrid onto the module-level target grid.
    return regrid(ds, ds_out=ds_out)


In [None]:
# with progress.ProgressBar():

#     expt_da = xr.DataArray(expts, dims='experiment_id', name='experiment_id',
#                            coords={'experiment_id': expts})

#     # Initialize an Empty Dictionary for Aligned Datasets:
#     dsets_aligned = {}

#     # Iterate Over dsets_ Dictionary:

#     for k, v in tqdm(dsets_.items()):
#         # Initialize a dictionary for this source_id
#         dsets_aligned[k] = {}
        
#         skip_source_id = False

#         for expt in expts:
#             ds = v[expt].pipe(process_data)

#             # Check if the dataset is empty and skip this source_id if so
#             if ds is None:
#                 print(f"Skipping {expt} for {k} because the dataset is empty")
#                 skip_source_id = True
#                 break
            
#             # Store the dataset in the dictionary
#             # dsets_aligned[k][expt] = ds
#             # Compute the dataset and store it in the dictionary
#             dsets_aligned[k][expt] = ds.compute()
#             print(dsets_aligned[k][expt])

#         if skip_source_id:
#             del dsets_aligned[k]
#             continue

In [None]:
# with progress.ProgressBar():
#     dsets_aligned_ = dask.compute(dsets_aligned)[0]

In [None]:
# source_ids = list(dsets_aligned.keys())
# source_da = xr.DataArray(source_ids, dims='source_id', name='source_id',
#                          coords={'source_id': source_ids})

# # final_ds = {expt: xr.concat([ds.get(expt, xr.Dataset()).reset_coords(drop=True)
# #                              for ds in dsets_aligned.values()],
# #                             dim=source_da)
# #             for expt in expts}

# final_ds = {expt: xr.concat([ds[expt].reset_coords(drop=True)
#                              for ds in dsets_aligned.values() if expt in ds and ds[expt] is not None],
#                             dim=source_da, coords='minimal')
#             for expt in expts}

# final_ds

In [None]:
# final_ds_pi = xr.concat([ds['historical'].reset_coords(drop=True)
#                                  for ds in dsets_aligned.values()],
#                                 dim=source_da)

# final_ds_eoc = xr.concat([ds['ssp370'].reset_coords(drop=True)
#                              for ds in dsets_aligned.values()],
#                             dim=source_da)
# final_ds_eoc

In [None]:
%%time
# final_ds_pi.to_dataset().to_zarr(cvals+'cmip6_pi_quantiles_annual.zarr',mode='w')
# final_ds_eoc.to_dataset().to_zarr(cvals+'cmip6_eoc_quantiles_annual.zarr',mode='w')
# final_ds_pi.to_zarr(cvals+'cmip6_pi_1.5grid_quantiles_annual.zarr',mode='w')
# final_ds_eoc.to_zarr(cvals+'cmip6_eoc_1.5grid_quantiles_annual.zarr',mode='w')

In [None]:
# Load the stored 'tasmax' arrays for the 'pi' and 'eoc' periods from the
# cvals directory; the commented lines load the 1.5-degree-grid variants instead.
final_ds_pi  = xr.open_zarr(cvals+'cmip6_pi_quantiles_annual.zarr').tasmax
final_ds_eoc = xr.open_zarr(cvals+'cmip6_eoc_quantiles_annual.zarr').tasmax
# final_ds_pi  = xr.open_zarr(cvals+'cmip6_pi_1.5grid_quantiles_annual.zarr').tasmax
# final_ds_eoc = xr.open_zarr(cvals+'cmip6_eoc_1.5grid_quantiles_annual.zarr').tasmax
final_ds_eoc

### Detrend data and save

In [None]:
%%time
# Remove the linear trend along 'year' from each period, pivoting on the
# period's reference year (pi_year / eoc_year).
ds_pi_det  = detrend_data(final_ds_pi,pi_year)
ds_eoc_det = detrend_data(final_ds_eoc,eoc_year)
ds_eoc_det

In [None]:
# %%time
# pcoeffs_ds_pi  = final_ds_pi.polyfit(dim='year',deg=1,skipna=False)
# pcoeffs_ds_eoc = final_ds_eoc.polyfit(dim='year',deg=1,skipna=False)
# pcoeffs_ds_eoc

In [None]:
# %%time
# m_pi  = pcoeffs_ds_pi.polyfit_coefficients.sel(degree=1)
# m_eoc = pcoeffs_ds_eoc.polyfit_coefficients.sel(degree=1)
# m_pi

In [None]:
# %%time
# #Save slopes
# m_pi.to_dataset().to_zarr(cvals+'cmip6_pi_trend.zarr',mode='w')
# m_eoc.to_dataset().to_zarr(cvals+'cmip6_eoc_trend.zarr',mode='w')

In [None]:
# m_pi  = xr.open_zarr(cvals+'cmip6_pi_trend.zarr').polyfit_coefficients
# m_eoc = xr.open_zarr(cvals+'cmip6_eoc_trend.zarr').polyfit_coefficients

In [None]:
# #Slopes for Chicago:
# print(m_eoc.sel(dayofyear=365).sel(lat=chicago_lat,lon=chicago_lon,method='nearest').values)

In [None]:
# print(m_pi.sel(dayofyear=365).sel(lat=chicago_lat,lon=chicago_lon,method='nearest').values)

In [None]:
# #Calculate trend
# pi_trend   = m_pi*(final_ds_pi['year']  - pi_year)
# eoc_trend  = m_eoc*(final_ds_eoc['year']- eoc_year)
# #Subtract trend
# ds_pi_det  = final_ds_pi  - pi_trend
# ds_eoc_det = final_ds_eoc - eoc_trend

In [None]:
# #
# ds_pi_det  = ds_pi_det.rename('tasmax')
# ds_eoc_det = ds_eoc_det.rename('tasmax')
# ds_pi_det

In [None]:
# %%time
# ds_eoc_det.sel(dayofyear=365).sel(lat=chicago_lat,lon=chicago_lon,method='nearest').sel(source_id='AWI-CM-1-1-MR').values

In [None]:
# ds_eoc_det.sel(dayofyear=365).sel(lat=chicago_lat,lon=chicago_lon,method='nearest').sel(source_id='AWI-CM-1-1-MR').values

In [None]:
# final_ds_eoc.sel(dayofyear=365).sel(lat=chicago_lat,lon=chicago_lon,method='nearest').sel(source_id='AWI-CM-1-1-MR').values

In [None]:
# Cast the source_id coordinate to str on both arrays
# (presumably so it serialises cleanly to zarr - TODO confirm).
ds_pi_det['source_id']  = ds_pi_det['source_id'].astype(str)
ds_eoc_det['source_id'] = ds_eoc_det['source_id'].astype(str)

In [None]:
# %%time
# ds_pi_det.to_dataset().to_zarr(cvals+'cmip6_pi_ann_detrended.zarr',mode='w')
# ds_eoc_det.to_dataset().to_zarr(cvals+'cmip6_eoc_ann_detrended.zarr',mode='w')
# ds_pi_det.to_dataset().to_zarr(cvals+'cmip6_pi_ann_1.5grid_detrended.zarr',mode='w')
# ds_eoc_det.to_dataset().to_zarr(cvals+'cmip6_eoc_ann_1.5grid_detrended.zarr',mode='w')

### Compute global and annual mean

In [None]:
# Reload the detrended arrays from disk. NOTE: this rebinds ds_pi_det and
# ds_eoc_det, replacing the in-memory results of the detrending cells.
ds_pi_det  = xr.open_zarr(cvals+'cmip6_pi_ann_detrended.zarr').tasmax
ds_eoc_det = xr.open_zarr(cvals+'cmip6_eoc_ann_detrended.zarr').tasmax
# ds_pi_det  = xr.open_zarr(cvals+'cmip6_pi_ann_1.5grid_detrended.zarr').tasmax
# ds_eoc_det = xr.open_zarr(cvals+'cmip6_eoc_ann_1.5grid_detrended.zarr').tasmax
#
# Quantiles over 'year' for each period, then their difference (eoc minus pi).
qpi  = compute_quantiles(ds_pi_det)
qeoc = compute_quantiles(ds_eoc_det)
qano = qeoc - qpi
qano

In [None]:
# %%time
# test = qano.sel(dayofyear=365).sel(lat=chicago_lat,lon=chicago_lon,method='nearest').\
# sel(quantile=0.03448,method='nearest').values

In [None]:
# test

In [None]:
# test.std()

In [None]:
%%time
# Spread of the quantile anomaly across models (the 'source_id' dimension):
# standard deviation, multi-model mean, and squared deviation from that mean.
qano_std            = qano.std(dim='source_id')
qano_cmip_mean      = qano.mean(dim='source_id')
qano_sq_deviation   = (qano - qano_cmip_mean)**2
qano_sq_deviation

In [None]:
%%time 
# Weighted mean of the absolute quantile anomaly over every dimension except 'quantile'.
qano_mae = global_mean(np.abs(qano))
# NOTE(review): stored under the name 'tmax', whereas the other arrays in this
# notebook use 'tasmax'-based names - confirm this is intended.
qano_mae = qano_mae.rename('tmax')
qano_mae

In [None]:
# mode='w' overwrites an existing store, so this cell can be re-run; it matches
# the mode used by the other to_zarr calls in this notebook.
qano_mae.to_dataset().to_zarr(cvals + 'cmip6_absano_agmean.zarr', mode='w')

In [None]:
# Weighted means (over all dimensions except 'quantile') of the inter-model
# standard deviation and of the squared deviation; the square root of the
# latter is the RMSD.
qano_std_agmean  = global_mean(qano_std)
qano_msd         = global_mean(qano_sq_deviation)
qano_rmsd        = np.sqrt(qano_msd)
qano_rmsd

In [None]:
# %%time
# qano_std.sel(dayofyear=365).sel(lat=chicago_lat,lon=chicago_lon,method='nearest').values

In [None]:
# %%time
# qano_std.sel(dayofyear=365).sel(lat=1,lon=5,method='nearest').values

In [None]:
%%time
# Name the arrays; these are the variable names read back from the zarr
# stores in the loading cells further down.
qano_std_agmean = qano_std_agmean.rename('tasmax_qanomaly')
qano_rmsd       = qano_rmsd.rename('qanomaly_rmsd')
# qano_std_agmean.to_dataset().to_zarr(cvals+'cmip6_quantiles_ano_agmean.zarr',mode='w')
# qano_rmsd.to_dataset().to_zarr(cvals+'cmip6_qano_ag_rmsd.zarr',mode='w')
# qano_std_agmean.to_dataset().to_zarr(cvals+'cmip6_quantiles_ano_1.5grid_agmean.zarr',mode='w')
# qano_rmsd.to_dataset().to_zarr(cvals+'cmip6_qano_ag_1.5grid_rmsd.zarr',mode='w')

In [None]:
# qano_std_agmean = xr.open_zarr(cvals+'cmip6_quantiles_ano_1.5grid_agmean.zarr').tasmax_qanomaly
# qano_std_agmean.values

In [None]:
# Load the stored result from disk. NOTE: this rebinds qano_std_agmean,
# replacing the value computed earlier in the notebook.
qano_std_agmean = xr.open_zarr(cvals+'cmip6_quantiles_ano_agmean.zarr').tasmax_qanomaly
qano_std_agmean.values

In [None]:
# Load the stored RMSD from disk. NOTE: this rebinds qano_rmsd, replacing the
# value computed earlier in the notebook.
qano_rmsd = xr.open_zarr(cvals+'cmip6_qano_ag_rmsd.zarr').qanomaly_rmsd
qano_rmsd.values

In [None]:
# Inter-model standard deviation of the quantile change versus probability level.
# NOTE(review): the annotation says 1.5-degree grid, but qano_std_agmean was last
# loaded from 'cmip6_quantiles_ano_agmean.zarr' (the 1.5grid load is commented
# out) - confirm the label matches the data being plotted.
qano_std_agmean.plot()
plt.gca().set(
    xlabel='Probability p',
    ylabel='Global, annual mean of standard deviation (K) ',
    title=r'Global, annual mean of stdev in $Q_f(p) -Q_i(p)$ among 18 CMIP6 models',
    ylim=(0, 1.8),
)
props = {'boxstyle': 'round', 'facecolor': 'wheat', 'alpha': 0.5}
plt.text(
    0.05, 0.95, r'18 models, $1.5^{\circ} \times 1.5^{\circ}$ grid',
    transform=plt.gca().transAxes, fontsize=14,
    verticalalignment='top', bbox=props,
)
plt.xlim(0,1.0)

In [None]:
# Inter-model standard deviation of the quantile change versus probability level.
# NOTE(review): the model count and grid size in the annotation are hard-coded
# text - confirm they match the loaded data.
qano_std_agmean.plot()
plt.gca().set(
    xlabel='Probability p',
    ylabel='Global, annual mean of standard deviation (K) ',
    title=r'Global, annual mean of stdev in $Q_f(p) -Q_i(p)$ among 18 CMIP6 models',
    ylim=(0, 1.8),
)
props = {'boxstyle': 'round', 'facecolor': 'wheat', 'alpha': 0.5}
plt.text(
    0.05, 0.95, r'18 models, $3^{\circ} \times 3^{\circ}$ grid',
    transform=plt.gca().transAxes, fontsize=14,
    verticalalignment='top', bbox=props,
)
plt.xlim(0,1.0)

In [None]:
# RMSD of the quantile change about the multi-model mean versus probability level.
# NOTE(review): the model count and grid size in the annotation are hard-coded
# text - confirm they match the loaded data.
qano_rmsd.plot()
plt.gca().set(
    xlabel='Probability p',
    ylabel='RMSD (K) ',
    title=r'RMSD of $Q_f(p) -Q_i(p)$',
    ylim=(0, 2.5),
)
props = {'boxstyle': 'round', 'facecolor': 'wheat', 'alpha': 0.5}
plt.text(
    0.05, 0.95, r'18 models, $3^{\circ} \times 3^{\circ}$ grid',
    transform=plt.gca().transAxes, fontsize=14,
    verticalalignment='top', bbox=props,
)
plt.xlim(0,1.0)