In [15]:
from matplotlib import pyplot as plt
import xarray as xr
import numpy as np
import dask
from dask.diagnostics import progress
from tqdm.autonotebook import tqdm
import intake
import fsspec
import seaborn as sns
#import gcsfs
import cftime
import pandas as pd
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [16]:
chicago_LAT=41.88
chicago_LON=(360-87.6298)%360

In [17]:
col = intake.open_esm_datastore("https://storage.googleapis.com/cmip6/pangeo-cmip6.json")
col

Unnamed: 0,unique
activity_id,18
institution_id,36
source_id,88
experiment_id,170
member_id,657
table_id,37
variable_id,700
grid_label,10
zstore,514818
dcpp_init_year,60


In [18]:
# 2. Search for maximum temperature for July 30
expts = ['ssp370','historical']

cat = col.search(
    experiment_id=expts,
    table_id='day',
    variable_id='tasmax',
    #grid_label='gn'
)

query = dict(
    experiment_id=expts,
    table_id='day',
    variable_id=['tasmax'],
    member_id = 'r1i1p1f1',
)

col_subset = col.search(require_all_on=["source_id"], **query)
col_subset.df.groupby("source_id")[
    ["experiment_id", "variable_id", "table_id","member_id"]
].nunique()

Unnamed: 0_level_0,experiment_id,variable_id,table_id,member_id
source_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
ACCESS-CM2,2,1,1,1
AWI-CM-1-1-MR,2,1,1,1
BCC-ESM1,2,1,1,1
CMCC-ESM2,2,1,1,1
CanESM5,2,1,1,1
EC-Earth3,2,1,1,1
EC-Earth3-AerChem,2,1,1,1
EC-Earth3-Veg-LR,2,1,1,1
FGOALS-g3,2,1,1,1
GFDL-ESM4,2,1,1,1


In [19]:
df = col_subset.df
model_counts = df.groupby('source_id').size()
print(model_counts)

source_id
ACCESS-CM2           2
AWI-CM-1-1-MR        2
BCC-ESM1             2
CMCC-ESM2            2
CanESM5              2
EC-Earth3            2
EC-Earth3-AerChem    2
EC-Earth3-Veg-LR     2
FGOALS-g3            2
GFDL-ESM4            2
INM-CM4-8            2
INM-CM5-0            2
IPSL-CM6A-LR         2
KACE-1-0-G           2
MIROC6               2
MPI-ESM-1-2-HAM      2
MPI-ESM1-2-HR        2
MPI-ESM1-2-LR        2
MRI-ESM2-0           2
NorESM2-LM           2
NorESM2-MM           2
dtype: int64


In [20]:
def drop_all_bounds(ds):
    drop_vars = [vname for vname in ds.coords
                 if (('_bounds') in vname ) or ('_bnds') in vname]
    return ds.drop(drop_vars)

def open_dset(df):
    assert len(df) == 1
    ds = xr.open_zarr(fsspec.get_mapper(df.zstore.values[0]), consolidated=True)
    return drop_all_bounds(ds)

def open_delayed(df):
    return dask.delayed(open_dset)(df)

from collections import defaultdict
dsets = defaultdict(dict)

for group, df in col_subset.df.groupby(by=['source_id', 'experiment_id']):
    dsets[group[0]][group[1]] = open_delayed(df)

In [21]:
%%time
# Trigger computation
dsets_ = dask.compute(dict(dsets))[0]

CPU times: user 1.6 s, sys: 122 ms, total: 1.72 s
Wall time: 3.2 s


In [22]:
def extract_july30_data(ds, chicago_lat=41.88, chicago_lon=(360-87.6298)%360, buffer=0):
    """
    Extract data for July 30 from the dataset 'ds' for specific time and spatial range.

    Parameters:
    - ds (xarray.Dataset): Input dataset
    - chicago_lat (float, optional): Latitude for Chicago. Default is 41.88.
    - chicago_lon (float, optional): Longitude for Chicago. Default is -87.63.
    - buffer (float, optional): Buffer for spatial slice around Chicago. Default is 0.5.

    Returns:
    - xarray.Dataset: Dataset subsetted for July 30 and the specified space and time range.
    """

    
    # Filter by time and space
    subset = ds.sel(
        time=((ds['time.year'] >= 1850) & (ds['time.year'] <= 1879)) | 
              ((ds['time.year'] >= 2071) & (ds['time.year'] <= 2100)),
        lat=chicago_lat,
        lon=chicago_lon,method='nearest'
    )
    
    # Check and deal with different datetime types
    if isinstance(ds['time'].values[0], np.datetime64):
        subset_july30_condition = subset.time.dt.dayofyear == 211
    elif isinstance(ds['time'].values[0], cftime.datetime):
        subset_july30_condition = [date for date in subset['time'].values if (date.month == 7) and (date.day == 30)]
    else:
        raise ValueError("Unknown datetime type in the dataset.")


    # Extract July 30 data
    subset_july30 = subset.sel(time=subset_july30_condition)
    #####subset_july30 = subset

    return subset_july30

In [23]:
expt_da = xr.DataArray(expts, dims='experiment_id', name='experiment_id',
                       coords={'experiment_id': expts})

#Initialize an Empty Dictionary for Aligned Datasets:
dsets_aligned = {}

#Iterate Over dsets_ Dictionary:
for k, v in tqdm(dsets_.items()):
    expt_dsets = v.values()
    if any([d is None for d in expt_dsets]):
        print(f"Missing experiment for {k}")
        continue

    for ds in expt_dsets:
        ds.coords['year'] = ds.time.dt.year
        ds.coords['day']  = ds.time.dt.dayofyear
        #print(ds)
        print(ds.attrs['experiment_id'],ds.attrs['source_id'])
        print('----------')
    
    dsets_jul30_chicago = [v[expt].pipe(extract_july30_data)
                             .swap_dims({'time': 'year'})
                             .drop('time')
                      for expt in expts]

    # align everything 
    dsets_aligned[k] = xr.concat(dsets_jul30_chicago, join='outer',
                                 dim=expt_da)

  0%|          | 0/21 [00:00<?, ?it/s]

historical ACCESS-CM2
----------
ssp370 ACCESS-CM2
----------
historical AWI-CM-1-1-MR
----------
ssp370 AWI-CM-1-1-MR
----------
historical BCC-ESM1
----------
ssp370 BCC-ESM1
----------
historical CMCC-ESM2
----------
ssp370 CMCC-ESM2
----------
historical CanESM5
----------
ssp370 CanESM5
----------
historical EC-Earth3
----------
ssp370 EC-Earth3
----------
historical EC-Earth3-AerChem
----------
ssp370 EC-Earth3-AerChem
----------
historical EC-Earth3-Veg-LR
----------
ssp370 EC-Earth3-Veg-LR
----------
historical FGOALS-g3
----------
ssp370 FGOALS-g3
----------
historical GFDL-ESM4
----------
ssp370 GFDL-ESM4
----------
historical INM-CM4-8
----------
ssp370 INM-CM4-8
----------
historical INM-CM5-0
----------
ssp370 INM-CM5-0
----------
historical IPSL-CM6A-LR
----------
ssp370 IPSL-CM6A-LR
----------
historical KACE-1-0-G
----------
ssp370 KACE-1-0-G
----------
historical MIROC6
----------
ssp370 MIROC6
----------
historical MPI-ESM-1-2-HAM
----------
ssp370 MPI-ESM-1-2-HAM
---

In [24]:
with progress.ProgressBar():
    dsets_aligned_ = dask.compute(dsets_aligned)[0]

[                                        ] | 0% Completed | 15.61 sms



KeyboardInterrupt



In [None]:
source_ids = list(dsets_aligned_.keys())
source_da = xr.DataArray(source_ids, dims='source_id', name='source_id',
                         coords={'source_id': source_ids})

final_ds = xr.concat([ds.reset_coords(drop=True)
                    for ds in dsets_aligned_.values()],
                    dim=source_da)

final_ds

In [None]:
#Select data for specific location to plot
df_loc = final_ds.to_dataframe().reset_index()
df_eoc = final_ds.sel(experiment_id='ssp370').sel(year= slice(2071,2100)).to_dataframe().reset_index()
df_pi  = final_ds.sel(experiment_id='historical').sel(year= slice(1850,1879)).to_dataframe().reset_index()
df_loc

In [None]:
df_pi = df_pi.dropna()
df_pi

In [None]:
df_eoc = df_eoc.dropna()
df_eoc

In [None]:
df_loc = df_loc.dropna()
df_loc

In [None]:
sns.relplot(data=df_loc,x="year", y="tasmax", hue='experiment_id',
            kind="line", errorbar="sd", aspect=2);

In [None]:
quantiles = np.linspace(0,1.0,30)
quantiles

In [None]:
df_eoc_quants                  = df_eoc.groupby('source_id')['tasmax'].quantile(quantiles).reset_index()
df_eoc_quants.columns          = ['source_id','quantiles','tasmax']
df_eoc_quants['experiment_id'] = 'ssp370'
df_eoc_quants

In [None]:
df_pi_quants                  = df_pi.groupby('source_id')['tasmax'].quantile(quantiles).reset_index()
df_pi_quants.columns          = ['source_id','quantiles','tasmax']
df_pi_quants['experiment_id'] = 'historical'
df_pi_quants

In [None]:
df_quants = pd.concat([df_eoc_quants, df_pi_quants],ignore_index=True)
df_quants 

In [None]:
g = sns.relplot(data=df_eoc_quants,x="quantiles", y="tasmax",hue='experiment_id',
            kind="line", errorbar="sd", aspect=2);
g.fig.suptitle('July 30, EOC tasmax at Chicago: 21 CMIP6 models')

In [None]:
g = sns.relplot(data=df_quants,x="quantiles", y="tasmax",hue='experiment_id',
            kind="line", errorbar="sd", aspect=2);
g.fig.suptitle('July 30, tasmax at Chicago: 21 CMIP6 models')

In [None]:
############## Anomalies ###########################

In [None]:
# Pivot the table based on experiment_id
df_pivot = df_quants.pivot_table(index=['source_id', 'quantiles'], columns='experiment_id', values='tasmax')
df_pivot

In [None]:
# Calculate the difference
df_pivot['tasmax_ano'] = df_pivot['ssp370'] - df_pivot['historical']
df_pivot
# Reset the index to turn multi-index back to columns
df_quant_ano = df_pivot.reset_index()[['source_id', 'quantiles', 'tasmax_ano']]
df_quant_ano

In [None]:
g = sns.relplot(data=df_quant_ano,x="quantiles", y="tasmax_ano",
            kind="line", errorbar="sd", aspect=2);
g.fig.suptitle('July 30, tasmax anomaly at Chicago: 21 CMIP6 models')

In [None]:
# Group by quantiles and compute the standard deviation for Tasmax
df_quantano_std = df_quant_ano.groupby('quantiles')['tasmax_ano'].std().reset_index()

# Rename the column for clarity
df_quantano_std.rename(columns={'tasmax_ano': 'tasmax_sdev'}, inplace=True)
df_quantano_std

In [None]:
g = sns.relplot(data=df_quantano_std,x="quantiles", y="tasmax_sdev",
            kind="line", errorbar="sd", aspect=2);
g.fig.suptitle('July 30, tasmax std at Chicago: 21 CMIP6 models')