In [1]:
from matplotlib import pyplot as plt
import xarray as xr
import numpy as np
import dask
from dask.diagnostics import progress
from tqdm.autonotebook import tqdm 
import intake
import fsspec
import seaborn as sns
%matplotlib inline

  from tqdm.autonotebook import tqdm


In [2]:
# Open the CMIP6 intake-esm catalogue published at the URL below.
col = intake.open_esm_datastore("https://storage.googleapis.com/cmip6/pangeo-cmip6.json")

# Experiments to search for; the last entry ('piControl') is sliced off again
# further down wherever `expts_full[:-1]` is used.
expts_full = ['historical', 'ssp126', 'ssp245', 'ssp370', 'ssp585', 'piControl']

# Search facets shared by every lookup in this notebook.
query = {
    'experiment_id': expts_full,
    'table_id': 'Amon',
    'variable_id': ['tas', 'pr', 'ua', 'va'],
    'member_id': 'r1i1p1f1',
}

# Apply the query; `require_all_on` is evaluated per model (source_id).
col_subset = col.search(require_all_on=["source_id"], **query)

# One sub-catalogue per variable, in the same order as query['variable_id'].
col_subset_var = [
    col_subset.search(variable_id=var_name)
    for var_name in query['variable_id']
]

# Per model: number of distinct experiments / variables / tables that matched.
col_subset.df.groupby("source_id")[["experiment_id", "variable_id", "table_id"]].nunique()

Unnamed: 0_level_0,experiment_id,variable_id,table_id
source_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
ACCESS-CM2,6,4,1
AWI-CM-1-1-MR,6,4,1
BCC-CSM2-MR,6,4,1
CAMS-CSM1-0,6,4,1
CESM2-WACCM,6,4,1
CMCC-CM2-SR5,6,4,1
CMCC-ESM2,6,4,1
CanESM5,6,4,1
EC-Earth3,6,4,1
EC-Earth3-Veg,6,4,1


## Sample Plotting

In [87]:
# Load every dataset of the first sub-catalogue (index 0 is 'tas' in query['variable_id']).
dset_dict = col_subset_var[0].to_dataset_dict(
    zarr_kwargs=dict(consolidated=True, decode_times=True, use_cftime=True)
)
# Keys that mention the historical experiment.
ss = [key for key in dset_dict if 'historical' in key]
# NOTE(review): index 15 is a hard-coded pick; which model it selects depends on
# the order of the keys returned above.
ds = dset_dict[ss[15]]


--> The keys in the returned dictionary of datasets are constructed as follows:
	'activity_id.institution_id.source_id.experiment_id.table_id.grid_label'


In [None]:
# Display the key of the dataset that was picked for the sample plots.
ss[15]

In [None]:
# Plot `tas` for the first member at the first three time steps, one panel per step.
ds.tas.isel(member_id=0, time=[0,1,2]).plot(col="time", col_wrap=3,robust=True)


# $\textbf{Preparing data}$

In [3]:
def drop_all_bounds(ds):
    """Return ``ds`` without its cell-boundary coordinates.

    Every coordinate whose name contains ``_bounds`` or ``_bnds`` is removed.
    Only ``ds.coords`` is inspected, so other variables are left alone.

    Parameters
    ----------
    ds : object exposing ``coords`` and ``drop_vars`` (e.g. ``xarray.Dataset``)

    Returns
    -------
    Whatever ``ds.drop_vars`` returns for the selected names.
    """
    bound_names = [vname for vname in ds.coords
                   if '_bounds' in vname or '_bnds' in vname]
    # `Dataset.drop` is deprecated for removing variables; `drop_vars` is the
    # dedicated replacement and takes the same list of names.
    return ds.drop_vars(bound_names)

def open_dset(df):
    """Open the single zarr store referenced by ``df`` and strip its bounds.

    ``df`` must contain exactly one row; its ``zstore`` value is opened with
    consolidated metadata and cftime-decoded times, and the result is passed
    through ``drop_all_bounds``.
    """
    assert len(df) == 1
    store_url = df.zstore.values[0]
    mapper = fsspec.get_mapper(store_url)
    opened = xr.open_zarr(mapper, consolidated=True, use_cftime=True, decode_times=True)
    return drop_all_bounds(opened)

def open_delayed(df):
    """Wrap ``open_dset(df)`` in ``dask.delayed`` so the store is opened lazily."""
    lazy_open = dask.delayed(open_dset)
    return lazy_open(df)

from collections import defaultdict

# One {source_id: {experiment_id: delayed dataset}} mapping per variable,
# in the same order as `col_subset_var`.
dsets = []
# The loop variable is deliberately NOT called `col_subset`: that name holds the
# full catalogue subset built earlier and must not be overwritten by this cell.
for var_subset in col_subset_var:
    dset = defaultdict(dict)

    for (source_id, experiment_id), df in var_subset.df.groupby(by=['source_id', 'experiment_id']):
        dset[source_id][experiment_id] = open_delayed(df)
    dsets.append(dset)

In [4]:
# Run the delayed opens: one plain {source_id: {experiment_id: dataset}} dict
# per variable (dask.compute returns a 1-tuple, hence the [0]).
dsets_ = [
    dask.compute(dict(dset))[0]
    for dset in dsets
]

In [None]:
import esmvalcore.preprocessor as ecpr
import cartopy.crs as ccrs
import cartopy.feature as feature
from cartopy.mpl.gridliner import LATITUDE_FORMATTER, LONGITUDE_FORMATTER
import xesmf as xe

#########################################################
# Walk one sample field through the masking workflow and plot every stage.
ds = dset_dict[ss[15]]
drop_vars = [vname for vname in ds.coords
             if '_bounds' in vname or '_bnds' in vname]

# `drop_vars` replaces the deprecated `Dataset.drop` for removing variables.
ds = ds.drop_vars(drop_vars)
# Round-trip through iris so the esmvalcore preprocessors can be applied.
ds_cube = ds.tas.isel(member_id=0, time=0).to_iris()
ds_cube = ecpr.regrid(ds_cube, '1x1', scheme='linear')
ds_masked = ecpr.mask_landsea(ds_cube, mask_out='sea')
ds_masked = xr.DataArray.from_iris(ds_masked)
ds_masked = ds_masked.squeeze()
ds_masked_og = ds_masked.sel({'lat':slice(0,40), 'lon':slice(60, 100)})

####################### Indian Region Masking ###########
mask_ds = xr.open_dataset('/home/jovyan/pangeo/data/India_mask.nc')

# Target grid for the mask: the lat/lon points of the cropped model field.
mask_reg = xr.Dataset({"lat": (["lat"], ds_masked_og.lat.values),
                       "lon": (["lon"], ds_masked_og.lon.values),})

regridder = xe.Regridder(mask_ds, mask_reg, "bilinear")
mask_reg = regridder(mask_ds)

# Keep model values only where the regridded mask is exactly 1.
ds_masked = xr.where(mask_reg.mask==1, ds_masked_og, np.nan)
####################################################
# Each panel is paired explicitly with the grid it lives on, instead of trying
# one grid and falling back to the other inside a bare `except`.
model_grid = np.meshgrid(ds_masked.lon, ds_masked.lat)
mask_grid = np.meshgrid(mask_ds.lon, mask_ds.lat)
panels = [
    ('model data', ds_masked_og, model_grid),
    ('Mask (25km)', mask_ds.mask, mask_grid),
    ('regridded mask', mask_reg.mask, model_grid),
    ('masked data', ds_masked, model_grid),
]

fig, axs = plt.subplots(2, 2, subplot_kw={'projection':ccrs.PlateCarree()}, figsize=(10,10), dpi=300)
for ax, (title, field, (grid_lon, grid_lat)) in zip(axs.flat, panels):
    gl = ax.gridlines(crs=ccrs.PlateCarree(), linewidth=2,
                      color='grey', alpha=0.3, linestyle='-',
                      draw_labels=True)
    ax.coastlines()
    gl.top_labels = False
    gl.right_labels = False
    gl.xformatter = LONGITUDE_FORMATTER
    gl.yformatter = LATITUDE_FORMATTER
    ax.contourf(grid_lon, grid_lat, field, transform=ccrs.PlateCarree())
    ax.set_extent([66,100, 7, 38])
    ax.set_title(title)
plt.savefig('/home/jovyan/pangeo/plot/masking_ind.png', bbox_inches='tight', facecolor='white', dpi=500)

In [None]:
# Draw the two fields side by side; calling .plot() twice without an explicit
# axes puts both on the same axes, so the second would hide the first.
fig_check, (ax_data, ax_mask) = plt.subplots(1, 2, figsize=(12, 4))
ds_masked.plot(ax=ax_data)
mask_reg.mask.plot(ax=ax_mask);

In [None]:
# Map of `tas` for the first member at the first time step.
ds.tas.isel(member_id=0,time=0).plot(col_wrap=1,robust=True)

In [114]:
import esmvalcore.preprocessor as ecpr

def get_lat_name(ds):
    """Return the latitude coordinate name used by ``ds``.

    ``'lat'`` is checked before ``'latitude'``; a ``RuntimeError`` is raised
    when neither is present in ``ds.coords``.
    """
    candidates = ('lat', 'latitude')
    found = next((name for name in candidates if name in ds.coords), None)
    if found is None:
        raise RuntimeError("Couldn't find a latitude coordinate")
    return found

def get_vname(ds):
    """Return the first of 'tas', 'pr', 'ua', 'va' that is a variable of ``ds``.

    Raises ``RuntimeError`` when none of them is present.
    """
    known = ('tas', 'pr', 'ua', 'va')
    match = next((name for name in known if name in ds.variables.keys()), None)
    if match is None:
        raise RuntimeError("Couldn't find a variable name")
    return match

def global_mean(ds):
    """Area-weighted mean of ``ds`` over every dimension except ``year``.

    Each grid cell is weighted by the cosine of its latitude. The reduction
    uses xarray's ``weighted(...).mean``, which normalises by the weights of
    the non-missing cells only. The previous ``(ds * weight).mean()`` form
    normalised the weights over all latitudes and was therefore biased for
    fields that contain NaNs (as produced by the masking steps in this
    notebook).

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray with a 'lat' or 'latitude' coordinate.

    Returns
    -------
    Same type as ``ds``, reduced over all dimensions other than ``year``.
    """
    lat = ds[get_lat_name(ds)]
    weight = np.cos(np.deg2rad(lat))
    other_dims = [dim for dim in ds.dims if dim != 'year']
    return ds.weighted(weight).mean(other_dims)
 
def regrid(ds):
    """Regrid the physical variable of ``ds`` onto esmvalcore's '1x1' target grid.

    The variable found by ``get_vname`` is converted to an iris cube, regridded
    with the 'linear' scheme and returned as a single-variable dataset that
    keeps the original variable name.
    """
    var_name = get_vname(ds)
    cube = ds[var_name].to_iris()
    cube_1x1 = ecpr.regrid(cube, '1x1', 'linear')
    return xr.DataArray.from_iris(cube_1x1).to_dataset(name=var_name)

def mask_out(ds):
    """Mask out the sea points of ``ds`` and return a land-only dataset indexed by ``year``.

    The time-like dimension is swapped for the ``year`` coordinate, the
    physical variable is passed through ``ecpr.mask_landsea(..., mask_out='sea')``
    via an iris cube, and the result is converted back to a dataset.

    Parameters
    ----------
    ds : xarray.Dataset holding one of the variables known to ``get_vname``
        plus ``time`` and ``year`` coordinates.
    """
    # Select the variable by name rather than by its position in ds.variables,
    # which depends on how the dataset happened to be assembled.
    var = get_vname(ds)
    # The time axis is called 'time', or 'dim_0' when it lost its name earlier
    # (presumably in an iris round trip — TODO confirm). Check explicitly
    # instead of hiding every possible error behind a bare `except`.
    time_dim = 'time' if 'time' in ds.dims else 'dim_0'
    dset_masked = ds.swap_dims({time_dim: 'year'}).drop_vars('time')
    cube = dset_masked[var].to_iris()
    land_only = xr.DataArray.from_iris(ecpr.mask_landsea(cube, mask_out='sea'))
    # from_iris hands the year axis back as 'dim_0'; restore its name.
    return land_only.rename({'dim_0': 'year'}).to_dataset()


In [115]:
from toolz.functoolz import juxt
# Every experiment except the last entry of `expts_full` ('piControl').
expts = expts_full[:-1]
expt_da = xr.DataArray(expts, dims='experiment_id', name='experiment_id',
                       coords={'experiment_id': expts})


# One {source_id: dataset} mapping per variable; only the first two variables
# of query['variable_id'] ('tas', 'pr') are processed.
dsets_aligned_list = []
for dset_ in dsets_[:2]:
    dsets_aligned = {}
    for k, v in tqdm(dset_.items()):

        # Skip models for which an experiment is absent or failed to load.
        if any(d is None for d in v.values()) or any(expt not in v for expt in expts):
            print(f"Missing experiment for {k}")
            continue

        # `assign_coords` adds the calendar year without modifying the datasets
        # stored in `dsets_`, so this cell can be re-run safely.
        # coarsen(year=12) instead of groupby('year') is a workaround for
        # https://github.com/pydata/xarray/issues/2237#issuecomment-620961663
        # NOTE(review): this assumes 12 time steps per calendar year — confirm.
        dsets_ann_mean = [
            v[expt]
            .assign_coords(year=v[expt].time.dt.year)
            .pipe(regrid)
            .pipe(mask_out)
            .pipe(global_mean)
            .coarsen(year=12)
            .mean()
            for expt in expts
        ]

        # Stack the experiments along a new `experiment_id` dimension; the
        # outer join keeps the union of all years.
        dsets_aligned[k] = xr.concat(dsets_ann_mean, join='outer',
                                     dim=expt_da)

    dsets_aligned_list.append(dsets_aligned)

  0%|          | 0/25 [00:00<?, ?it/s]

tas


  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  return np.asarray(array[self.key], dtype=None)


tas
tas
tas


  dtype = _decode_cf_datetime_dtype(data, units, calendar, self.use_cftime)
  return np.asarray(array[self.key], dtype=None)


tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas
tas


TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''

In [None]:
 with progress.ProgressBar():
    dsets_aligned_list_1 = dask.compute(dsets_aligned_list[0])[0]
        
with progress.ProgressBar():
    dsets_aligned_list_2 = dask.compute(dsets_aligned_list[1])[0]

In [None]:
# Gather the two computed mappings in one list.
# NOTE(review): the name is spelled `dsets_algned_list_` (missing 'i'); the
# following cells use the same spelling, so it must be renamed everywhere or not at all.
dsets_algned_list_ = [dsets_aligned_list_1, dsets_aligned_list_2]
type(dsets_algned_list_[0])

In [None]:
# Model names per variable, in the order the results are stored.
source_ids = [list(per_var) for per_var in dsets_algned_list_]

# For every variable, stack the per-model results along a new `source_id`
# dimension labelled with the model names.
big_ds = []
for ids, per_var in zip(source_ids, dsets_algned_list_):
    source_da = xr.DataArray(ids, dims='source_id', name='source_id',
                             coords={'source_id': ids})
    members = [member.reset_coords(drop=True) for member in per_var.values()]
    big_ds.append(xr.concat(members, dim=source_da))

In [None]:
# Add the `pr` variable of the second result to the first result so both
# variables live in one dataset, then write it to disk and display it.
ds_all = big_ds[0].assign(pr=big_ds[1]['pr'])
ds_all.to_netcdf('/home/jovyan/pangeo/data/tas_pr_landonly_timeseries.nc')
ds_all

In [None]:
# Reload the saved series, keep 1950-2100 and smooth with a 2-year rolling mean.
ds_all = xr.open_dataset('/home/jovyan/pangeo/data/tas_pr_landonly_timeseries.nc')
small_ds = (
    ds_all
    .sel(year=slice(1950, 2100))
    .rolling(year=2)
    .mean()
)
small_ds.source_id.shape

In [None]:
# Reference period: historical mean over cl_t0..cl_tf (inclusive label slice).
cl_t0, cl_tf = 1984, 2014
clim_ds = (
    small_ds
    .sel(experiment_id='historical', year=slice(cl_t0, cl_tf))
    .mean(dim='year')
)

# Precipitation as a percentage change, temperature as a difference.
dev_ds_pr = ((small_ds['pr'] - clim_ds['pr']) / clim_ds['pr']) * 100
dev_ds_tas = small_ds['tas'] - clim_ds['tas']

In [None]:
# Multi-model mean of each anomaly (average over `source_id`).
ens_ds_pr = dev_ds_pr.mean(dim='source_id')
ens_ds_tas = dev_ds_tas.mean(dim='source_id')

print(ens_ds_pr.shape)
ens_ds_tas.shape

In [None]:
import itertools as it
# Spread of the individual models around the ensemble mean: for every
# (experiment, year) the largest and smallest model deviation, ignoring NaNs.
# Reducing over the named `source_id` dimension replaces the element-wise loop
# and the hard-coded (5, 151) shapes, so the cell follows whatever number of
# experiments and years the data actually has.
model_dev_pr = dev_ds_pr - ens_ds_pr
model_dev_tas = dev_ds_tas - ens_ds_tas

# Plain (experiment_id, year) numpy arrays, as the plotting cell expects.
spread_max_pr = model_dev_pr.max(dim='source_id', skipna=True).transpose('experiment_id', 'year').values
spread_min_pr = model_dev_pr.min(dim='source_id', skipna=True).transpose('experiment_id', 'year').values

spread_max_tas = model_dev_tas.max(dim='source_id', skipna=True).transpose('experiment_id', 'year').values
spread_min_tas = model_dev_tas.min(dim='source_id', skipna=True).transpose('experiment_id', 'year').values

spread_tas ={'max_vals' : spread_max_tas, 
             'min_vals' : spread_min_tas}

spread_pr ={'max_vals' : spread_max_pr, 
             'min_vals' : spread_min_pr}

# Same order as `ens_dss` in the plotting cell: tas first, pr second.
spr = [spread_tas, spread_pr]

In [None]:
from matplotlib import cm
import matplotlib.colors as mcl
# One line colour per experiment, indexed in the order of `experiment_id`.
cl = ['k'] + ['blue', 'orange', 'green','red']

ncols, nrows=2,2
fig, axs = plt.subplots(nrows, ncols, dpi=600, figsize = (14,9))
ens_dss = [ens_ds_tas, ens_ds_pr]  # column 0: tas, column 1: pr

# Raw f-string: the LaTeX backslash must not be read as a string escape.
y_l = [rf'Relative to {cl_t0}-{cl_tf} ($\circ$C)', 
       f'Relative to {cl_t0}-{cl_tf} (%)']


for i in range(nrows):
    for j in range(ncols):
        ax = axs[i, j]
        ens_ds = ens_dss[j]
        if i==0:
            # Top row: every experiment, with the min/max model envelope.
            for idx,v in enumerate(ens_ds.experiment_id):
                expt_mean = ens_ds.sel(experiment_id=v.item())
                ax.plot(ens_ds.year, expt_mean, label=v.item(), c=cl[idx])
                ax.fill_between(ens_ds.year,
                                expt_mean + spr[j]['max_vals'][idx, :],
                                expt_mean + spr[j]['min_vals'][idx, :],
                                alpha=0.2, color=cl[idx])
            ax.set_xlim(1955,2100)
        else:
            # Bottom row: all experiments but the first, ensemble mean only.
            for idx,v in enumerate(ens_ds.experiment_id[1:]):
                expt_mean = ens_ds.sel(experiment_id=v.item())
                ax.plot(ens_ds.year, expt_mean, label=v.item(), c=cl[idx+1])
            ax.set_xlim(2020,2100)
        ax.set_xlabel('Year')
        ax.set_ylabel(y_l[j])
        ax.legend(loc='upper left')
        ax.grid(alpha=0.3)

fig.suptitle('CMIP6 Global mean (Land-only) T$_s$ and P', x=0.5, y =0.95, fontsize=28, weight='bold')
plt.savefig('/home/jovyan/pangeo/plot/tas_pr_landonly_timeseries.png', bbox_inches='tight', facecolor='white')

# Indian region masking and visualization

In [None]:
def get_lat_name(ds):
    """Name of the latitude coordinate of ``ds``: 'lat' if present, else 'latitude'.

    Raises ``RuntimeError`` when ``ds.coords`` contains neither.
    """
    if 'lat' in ds.coords:
        return 'lat'
    if 'latitude' in ds.coords:
        return 'latitude'
    raise RuntimeError("Couldn't find a latitude coordinate")

def get_lon_name(ds):
    """Name of the longitude coordinate of ``ds``: 'lon' if present, else 'longitude'.

    Raises ``RuntimeError`` when ``ds.coords`` contains neither.
    """
    if 'lon' in ds.coords:
        return 'lon'
    if 'longitude' in ds.coords:
        return 'longitude'
    raise RuntimeError("Couldn't find a longitude coordinate")
    
def get_vname(ds):
    """Return the first of 'tas', 'pr', 'ua', 'va' that is a variable of ``ds``.

    The membership test is done on each candidate name against
    ``ds.variables``. The previous version tested an unrelated, undefined
    ``lon_name`` against ``ds.coords`` and therefore failed with a NameError.

    Raises
    ------
    RuntimeError
        If none of the known variable names is present.
    """
    for vname in ['tas', 'pr', 'ua', 'va']:
        if vname in ds.variables.keys():
            return vname
    raise RuntimeError("Couldn't find a variable name")

In [None]:
import xesmf as xe
# India mask dataset; read as a module-level name by `mask_out_india` below.
mask_ds = xr.open_dataset('/home/jovyan/pangeo/data/India_mask.nc')

def regrid(ds):
    """Return the physical variable of ``ds`` regridded to esmvalcore's '1x1' grid.

    The variable is looked up with ``get_vname``, sent through
    ``ecpr.regrid(..., '1x1', 'linear')`` as an iris cube and wrapped back
    into a dataset under its original name.
    """
    var_name = get_vname(ds)
    regridded_cube = ecpr.regrid(ds[var_name].to_iris(), '1x1', 'linear')
    regridded = xr.DataArray.from_iris(regridded_cube)
    return regridded.to_dataset(name=var_name)

def mask_out_india(ds):
    """Crop ``ds`` to the Indian region and blank every cell outside the India mask.

    Steps: swap the ``time`` dimension for ``year``, cut a lat/lon box, regrid
    the module-level ``mask_ds`` bilinearly onto the grid points of that box,
    keep data only where the regridded mask equals 1, and finally crop to
    70-90 in longitude and 5-28 in latitude.

    Parameters
    ----------
    ds : xarray.Dataset holding one of 'tas', 'pr', 'ua', 'va' together with
        ``time`` and ``year`` coordinates.

    Returns
    -------
    xarray.DataArray of the masked variable (NaN outside the mask).

    Raises
    ------
    RuntimeError
        If none of the known variables is present in ``ds``.
    """
    # Select the variable by name; picking it by position in ds.variables
    # breaks as soon as the dataset carries an extra coordinate.
    var = next((name for name in ('tas', 'pr', 'ua', 'va') if name in ds.data_vars), None)
    if var is None:
        raise RuntimeError("Couldn't find a variable name")
    # `drop_vars` replaces the deprecated `Dataset.drop`.
    dset_masked = ds.swap_dims({'time': 'year'}).drop_vars('time')
    lat_var = get_lat_name(ds)
    lon_var = get_lon_name(ds)
    dset_masked = dset_masked.sel({lat_var:slice(6.5,38.5), lon_var:slice(66.5, 100.0)})
    # Target grid for the mask: the grid points of the selected box.
    mask_reg = xr.Dataset({"lat": (["lat"], dset_masked[lat_var].values),
                           "lon": (["lon"], dset_masked[lon_var].values),})

    regridder = xe.Regridder(mask_ds, mask_reg, "bilinear")
    mask_reg = regridder(mask_ds)
    dset_masked = xr.where(mask_reg.mask==1, dset_masked[var], np.nan)
    dset_masked_cropped = dset_masked.sel({lon_var:slice(70,90), lat_var:slice(5,28)})
    return dset_masked_cropped

def global_mean(ds):
    """Area-weighted mean of ``ds`` over every dimension except ``year``.

    Each grid cell is weighted by the cosine of its latitude. The reduction
    uses xarray's ``weighted(...).mean``, which normalises by the weights of
    the non-missing cells only. The previous ``(ds * weight).mean()`` form
    normalised the weights over all latitudes and was therefore biased for
    fields that contain NaNs, such as the output of ``mask_out_india``.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray with a 'lat' or 'latitude' coordinate.

    Returns
    -------
    Same type as ``ds``, reduced over all dimensions other than ``year``.
    """
    lat = ds[get_lat_name(ds)]
    weight = np.cos(np.deg2rad(lat))
    other_dims = [dim for dim in ds.dims if dim != 'year']
    return ds.weighted(weight).mean(other_dims)

In [None]:
from toolz.functoolz import juxt
# Every experiment except the last entry of `expts_full` ('piControl').
expts = expts_full[:-1]
expt_da = xr.DataArray(expts, dims='experiment_id', name='experiment_id',
                       coords={'experiment_id': expts})


# One {source_id: result} mapping per variable; only the first two variables
# of query['variable_id'] ('tas', 'pr') are processed.
dsets_aligned_list = []
for dset_ in dsets_[:2]:
    dsets_aligned = {}
    for k, v in tqdm(dset_.items()):

        # Skip models for which an experiment is absent or failed to load.
        if any(d is None for d in v.values()) or any(expt not in v for expt in expts):
            print(f"Missing experiment for {k}")
            continue

        # `assign_coords` adds the calendar year without modifying the datasets
        # stored in `dsets_`, so this cell can be re-run safely.
        # coarsen(year=12) instead of groupby('year') is a workaround for
        # https://github.com/pydata/xarray/issues/2237#issuecomment-620961663
        # NOTE(review): this assumes 12 time steps per calendar year — confirm.
        dsets_ann_mean = [
            v[expt]
            .assign_coords(year=v[expt].time.dt.year)
            .pipe(mask_out_india)
            .pipe(global_mean)
            .coarsen(year=12)
            .mean()
            for expt in expts
        ]

        # Stack the experiments along a new `experiment_id` dimension; the
        # outer join keeps the union of all years.
        dsets_aligned[k] = xr.concat(dsets_ann_mean, join='outer',
                                     dim=expt_da)

    dsets_aligned_list.append(dsets_aligned)

In [None]:
 with progress.ProgressBar():
    dsets_aligned_list_1 = dask.compute(dsets_aligned_list[0])[0]
        
with progress.ProgressBar():
    dsets_aligned_list_2 = dask.compute(dsets_aligned_list[1])[0]

In [None]:
# Gather the two computed mappings (index 0 and index 1 of dsets_aligned_list) in one list.
dsets_aligned_list_ = [dsets_aligned_list_1, dsets_aligned_list_2]

In [None]:
# Model names per variable, in the order the results are stored.
source_ids = [list(per_var) for per_var in dsets_aligned_list_]

# For every variable, stack the per-model results along a new `source_id`
# dimension labelled with the model names.
big_ds = []
for ids, per_var in zip(source_ids, dsets_aligned_list_):
    source_da = xr.DataArray(ids, dims='source_id', name='source_id',
                             coords={'source_id': ids})
    members = [member.reset_coords(drop=True) for member in per_var.values()]
    big_ds.append(xr.concat(members, dim=source_da))

In [None]:
# Wrap each stacked result in a dataset under its variable name
# (big_ds[0] -> 'tas', big_ds[1] -> 'pr').
ds_tas = big_ds[0].to_dataset(name='tas')
ds_pr = big_ds[1].to_dataset(name='pr')


In [None]:
# Combine both variables into one dataset, write it to disk and display it.
ds_all = xr.merge([ds_tas, ds_pr])
ds_all.to_netcdf('/home/jovyan/pangeo/data/tas_pr_indianreg_timeseries.nc')
ds_all

In [None]:
# Reload the saved series, keep 1950-2100 and smooth with a 2-year rolling mean.
ds_all = xr.open_dataset('/home/jovyan/pangeo/data/tas_pr_indianreg_timeseries.nc')
small_ds = (
    ds_all
    .sel(year=slice(1950, 2100))
    .rolling(year=2)
    .mean()
)

In [None]:
# Reference period: historical mean over cl_t0..cl_tf (inclusive label slice).
cl_t0, cl_tf = 1984, 2014
clim_ds = (
    small_ds
    .sel(experiment_id='historical', year=slice(cl_t0, cl_tf))
    .mean(dim='year')
)

# Precipitation as a percentage change, temperature as a difference.
dev_ds_pr = ((small_ds['pr'] - clim_ds['pr']) / clim_ds['pr']) * 100
dev_ds_tas = small_ds['tas'] - clim_ds['tas']

In [None]:
# Multi-model mean of each anomaly (average over `source_id`).
ens_ds_pr = dev_ds_pr.mean(dim='source_id')
ens_ds_tas = dev_ds_tas.mean(dim='source_id')

print(ens_ds_pr.shape)
ens_ds_tas.shape

In [None]:
import itertools as it
# Spread of the individual models around the ensemble mean: for every
# (experiment, year) the largest and smallest model deviation, ignoring NaNs.
# Reducing over the named `source_id` dimension replaces the element-wise loop
# and the hard-coded (5, 151) shapes, so the cell follows whatever number of
# experiments and years the data actually has.
model_dev_pr = dev_ds_pr - ens_ds_pr
model_dev_tas = dev_ds_tas - ens_ds_tas

# Plain (experiment_id, year) numpy arrays, as the plotting cell expects.
spread_max_pr = model_dev_pr.max(dim='source_id', skipna=True).transpose('experiment_id', 'year').values
spread_min_pr = model_dev_pr.min(dim='source_id', skipna=True).transpose('experiment_id', 'year').values

spread_max_tas = model_dev_tas.max(dim='source_id', skipna=True).transpose('experiment_id', 'year').values
spread_min_tas = model_dev_tas.min(dim='source_id', skipna=True).transpose('experiment_id', 'year').values

spread_tas ={'max_vals' : spread_max_tas, 
             'min_vals' : spread_min_tas}

spread_pr ={'max_vals' : spread_max_pr, 
             'min_vals' : spread_min_pr}

# Same order as `ens_dss` in the plotting cell: tas first, pr second.
spr = [spread_tas, spread_pr]

In [None]:
from matplotlib import cm
import matplotlib.colors as mcl
# One line colour per experiment, indexed in the order of `experiment_id`.
cl = ['k'] + ['blue', 'orange', 'green','red']

ncols, nrows=2,2
fig, axs = plt.subplots(nrows, ncols, dpi=600, figsize = (14,9))
ens_dss = [ens_ds_tas, ens_ds_pr]  # column 0: tas, column 1: pr

# Raw f-string: the LaTeX backslash must not be read as a string escape.
y_l = [rf'Relative to {cl_t0}-{cl_tf} ($\circ$C)', 
       f'Relative to {cl_t0}-{cl_tf} (%)']


for i in range(nrows):
    for j in range(ncols):
        ax = axs[i, j]
        ens_ds = ens_dss[j]
        if i==0:
            # Top row: every experiment, with the min/max model envelope.
            for idx,v in enumerate(ens_ds.experiment_id):
                expt_mean = ens_ds.sel(experiment_id=v.item())
                ax.plot(ens_ds.year, expt_mean, label=v.item(), c=cl[idx])
                ax.fill_between(ens_ds.year,
                                expt_mean + spr[j]['max_vals'][idx, :],
                                expt_mean + spr[j]['min_vals'][idx, :],
                                alpha=0.2, color=cl[idx])
            ax.set_xlim(1955,2100)
        else:
            # Bottom row: all experiments but the first, ensemble mean only.
            for idx,v in enumerate(ens_ds.experiment_id[1:]):
                expt_mean = ens_ds.sel(experiment_id=v.item())
                ax.plot(ens_ds.year, expt_mean, label=v.item(), c=cl[idx+1])
            ax.set_xlim(2020,2100)
        ax.set_xlabel('Year')
        ax.set_ylabel(y_l[j])
        ax.legend(loc='upper left')
        ax.grid(alpha=0.3)

fig.suptitle('CMIP6 Global mean (Indian region) T$_s$ and P', x=0.5, y =0.95, fontsize=28, weight='bold')
plt.savefig('/home/jovyan/pangeo/plot/tas_pr_indianreg_timeseries.png', bbox_inches='tight', facecolor='white')