In [None]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe
import cartopy
import dask
from tqdm.autonotebook import tqdm  # Fancy progress bars for our loops!
from dask.diagnostics import progress
import intake
import fsspec

%matplotlib inline
#plt.rcParams['figure.figsize'] = 12, 6
%config InlineBackend.figure_format = 'retina' 

In [None]:
col = intake.open_esm_datastore("https://storage.googleapis.com/cmip6/pangeo-cmip6.json")

# Experiments to pull: the historical run, four SSP scenarios and the pre-industrial control.
expts_full = ['historical', 'ssp126', 'ssp245', 'ssp370', 'ssp585', 'piControl']

query = dict(
    experiment_id=expts_full,
    table_id='Amon',                        # atmospheric (A) variables saved at monthly (mon) resolution
    variable_id=['tas', 'pr', 'ua', 'va'],  # CMIP6 short names: temperature, precipitation, wind components
    member_id='r1i1p1f1',                   # arbitrarily pick one realization per model (one set of initial conditions)
)

# require_all_on="source_id": keep only models that match every requested value of the query.
col_subset = col.search(require_all_on=["source_id"], **query)
# One catalog subset per variable, in the same order as query['variable_id']
# (index 0 = tas, 1 = pr, 2 = ua, 3 = va); later cells index into this list.
col_subset_var = [col_subset.search(variable_id=var_name) for var_name in query['variable_id']]
col_subset.df[['source_id', 'experiment_id', 'variable_id', 'member_id']].nunique()

In [None]:
# Open every precipitation dataset (lazily) so the time coverage can be inspected.
pr_subset = col_subset_var[query['variable_id'].index('pr')]
dset_dict = pr_subset.to_dataset_dict(
    zarr_kwargs=dict(consolidated=True, decode_times=True, use_cftime=True)
)
# Keys of the pre-industrial control runs.
ss = list(filter(lambda key: 'piControl' in key, dset_dict.keys()))

In [None]:
# Report the first and last time stamp of each piControl dataset.
for key in ss:
    ds = dset_dict[key]
    first_time, last_time = ds.time[0].values, ds.time[-1].values
    print('Starting time:', first_time, '\tEnding time:', last_time, '\n')

In [None]:
def drop_all_bounds(ds):
    """Return ``ds`` without its cell-boundary coordinates.

    A coordinate is treated as a boundary variable when its name contains
    ``'_bounds'`` or ``'_bnds'`` (e.g. ``lat_bnds``, ``time_bounds``).
    Only ``ds.coords`` is inspected; data variables are left untouched.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to clean.

    Returns
    -------
    xarray.Dataset
        A new dataset with those coordinates removed.
    """
    bound_names = [name for name in ds.coords
                   if '_bounds' in name or '_bnds' in name]
    # ``Dataset.drop`` is deprecated for removing variables; ``drop_vars`` is
    # the supported replacement.
    return ds.drop_vars(bound_names)

def open_dset(df, plev=85000):
    """Open the zarr store referenced by the first row of ``df`` and tidy it.

    Parameters
    ----------
    df : pandas.DataFrame
        Catalog rows for one (source_id, experiment_id) group. Only
        ``df.zstore.values[0]`` is opened; any further rows are ignored.
    plev : int or float, optional
        Pressure level to keep when the dataset has a ``plev`` coordinate.
        Levels are compared after truncation to ``int`` and in the units of
        the store's ``plev`` coordinate (85000 is 850 hPa if ``plev`` is in Pa).

    Returns
    -------
    xarray.Dataset
        Dataset opened with ``xr.open_zarr`` (cftime-decoded times) with
        ``*_bounds`` / ``*_bnds`` coordinates removed. When a ``plev``
        coordinate exists, only the matching level is kept, as a length-1
        ``plev`` dimension whose coordinate variable is dropped.

    Raises
    ------
    ValueError
        If the dataset has a ``plev`` coordinate but no level matches ``plev``.
    """
    store_url = df.zstore.values[0]
    ds = xr.open_zarr(fsspec.get_mapper(store_url), consolidated=True,
                      decode_times=True, use_cftime=True)
    if 'plev' in ds.coords:
        levels = ds.plev.values
        matches = np.flatnonzero(levels.astype(int) == int(plev))
        if matches.size == 0:
            raise ValueError(
                f"Pressure level {plev} not found in {store_url}; available: {levels}")
        # A one-element index array keeps `plev` as a length-1 dimension
        # (a scalar index would squeeze it out).
        ds = ds.isel(plev=matches[:1]).drop_vars('plev')
    return drop_all_bounds(ds)

def open_delayed(df):
    """Build a ``dask.delayed`` task that runs ``open_dset(df)`` when computed."""
    lazy_open = dask.delayed(open_dset)
    return lazy_open(df)

from collections import defaultdict
# One dict per variable (same order as query['variable_id']), each mapping
# source_id -> experiment_id -> delayed dataset.
# The loop uses a distinct name so the `col_subset` catalog search from the
# query cell is not overwritten.
dsets = []
for var_subset in col_subset_var:
    dset = defaultdict(dict)
    for (source_id, experiment_id), df in var_subset.df.groupby(by=['source_id', 'experiment_id']):
        dset[source_id][experiment_id] = open_delayed(df)
    dsets.append(dset)

In [None]:
with progress.ProgressBar():
    # Run the delayed open_dset calls for precipitation only; `dsets` follows
    # the order of query['variable_id'].
    pr_index = query['variable_id'].index('pr')
    dsets_ = dask.compute(dict(dsets[pr_index]))[0]

In [None]:
import pymannkendall as mkt
import esmvalcore.preprocessor as ecpr
import dask.array as da
import iris
import numpy as np
from cf_units import Unit
import itertools
def get_vname(ds, candidates=('pr', 'ua', 'va')):
    """Return the name of the first variable of interest in ``ds``.

    Variables are scanned in ``ds.variables`` order. The first one whose
    name is in ``candidates`` is returned.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to inspect (anything exposing a ``variables`` mapping works).
    candidates : collection of str, optional
        Accepted variable names. The default does not include ``'tas'``;
        pass e.g. ``('tas',)`` to handle the temperature subset.

    Raises
    ------
    RuntimeError
        If no variable of ``ds`` is in ``candidates``.
    """
    for v_name in ds.variables.keys():
        if v_name in candidates:
            return v_name
    raise RuntimeError("Couldn't find a variable")
            
def get_lat_name(ds):
    """Return the latitude coordinate name of ``ds``: 'lat' if present, else 'latitude'.

    Raises RuntimeError when neither coordinate exists.
    """
    found = next((name for name in ('lat', 'latitude') if name in ds.coords), None)
    if found is None:
        raise RuntimeError("Couldn't find a latitude coordinate")
    return found
    
def get_lon_name(ds):
    """Return the longitude coordinate name of ``ds``: 'lon' if present, else 'longitude'.

    Raises RuntimeError when neither coordinate exists.
    """
    present = [name for name in ('lon', 'longitude') if name in ds.coords]
    if not present:
        raise RuntimeError("Couldn't find a longitude coordinate")
    return present[0]

def regrid(ds, method='bilinear', periodic=True):
    """Regrid the main variable of ``ds`` onto a global 1-degree lat/lon grid.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset holding one of the variables recognised by ``get_vname``.
    method : str, optional
        xESMF regridding method (default ``'bilinear'``).
    periodic : bool, optional
        Treat the source grid as periodic in longitude. For global source
        grids with non-conservative methods this lets xESMF fill target points
        that fall between the last and first source longitude.

    Returns
    -------
    xarray.Dataset
        Single-variable dataset on lat = -90..89 and lon = -180..179 (1-degree steps).
    """
    var_name = get_vname(ds)
    field = ds[var_name]
    target_grid = xr.Dataset({
        "lat": (["lat"], np.arange(-90, 90, 1.0)),
        "lon": (["lon"], np.arange(-180, 180, 1.0)),
    })
    # A new regridder per call: every model has its own source grid.
    regridder = xe.Regridder(field, target_grid, method, periodic=periodic)
    return regridder(field).to_dataset(name=var_name)



def jjas_mon_mean(ds, start='2005', end='2014'):
    """Return the monthly climatology of ``ds`` over ``start``..``end``.

    Despite the name, every calendar month in the period is kept, not only
    June-September. The result has a ``month`` dimension. Select JJAS
    afterwards, e.g. with ``.sel(month=[6, 7, 8, 9])``, if needed.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        Data with a datetime-like ``time`` coordinate (cftime works).
    start, end : str, optional
        Inclusive bounds of the averaging period, used as ``slice(start, end)``
        on ``time``. Year strings such as ``'2005'`` select whole years.

    Returns
    -------
    Same type as ``ds``, with ``time`` replaced by ``month``.
    """
    period = ds.sel({'time': slice(start, end)})
    return period.groupby('time.month').mean()

In [None]:
from toolz.functoolz import juxt
expt = expts_full[0]  # 'historical'
print(expt)

# Regrid each model's `expt` run to the common 1-degree grid and reduce it to a
# 2005-2014 monthly climatology; the fields are computed in the next cell.
dsets_aligned = {}
for k, v in tqdm(dsets_.items()):
    model_ds = v.get(expt)
    # Skip models without a usable `expt` dataset instead of failing with KeyError.
    if model_ds is None:
        print(f"Missing experiment for {k}")
        continue
    dsets_aligned[k] = model_ds.pipe(regrid).pipe(jjas_mon_mean)

dsets_aligned_list = [dsets_aligned]

In [None]:
with progress.ProgressBar():
    # dask.compute returns a tuple with one result per argument; unpack the single one.
    (dsets_algned_list_,) = dask.compute(dsets_aligned_list[0])

In [None]:
# Stack the per-model climatologies along a new `source_id` dimension
# labelled with the model names (in dict order).
source_ids = list(dsets_algned_list_.keys())
source_da = xr.DataArray(source_ids, dims='source_id', name='source_id',
                         coords={'source_id': source_ids})
per_model = [model_ds.reset_coords(drop=True) for model_ds in dsets_algned_list_.values()]
big_ds_pr = xr.concat(per_model, dim=source_da)

In [None]:
from pathlib import Path

# Cache the multi-model climatology so the download/regrid steps above need
# not be re-run. Create the output directory first; to_netcdf does not.
pr_cycle_path = Path('/home/jovyan/pangeo/data/precip_annual_cycle.nc')
pr_cycle_path.parent.mkdir(parents=True, exist_ok=True)
big_ds_pr.to_netcdf(pr_cycle_path)

In [None]:
# Reload the cached multi-model climatology.
pr_cycle_file = '/home/jovyan/pangeo/data/precip_annual_cycle.nc'
ds_pr = xr.open_dataset(pr_cycle_file)
ds_pr

In [None]:
import esmvalcore.preprocessor as ecpr
# Convert to an iris cube and give latitude/longitude CF standard names
# (presumably required so esmvalcore's preprocessors can find them — confirm).
ucube = ds_pr.pr.rename({'lat': 'latitude', 'lon': 'longitude'}).to_iris()
for coord_name in ("latitude", "longitude"):
    coord = ucube.coord(coord_name)
    # Look up which data dimension the coordinate spans instead of assuming
    # a (source_id, month, latitude, longitude) order.
    (dim,) = ucube.coord_dims(coord)
    coord.standard_name = coord_name
    # Re-attach as a dimension coordinate on the same axis.
    ucube.remove_coord(coord_name)
    ucube.add_dim_coord(coord, dim)
ucube

In [None]:
# Mask out sea points, restore `source_id` as the leading dimension and cut
# out the region 5-28N, 70-90E.
ds_pr = (
    xr.DataArray.from_iris(ecpr.mask_landsea(ucube, 'sea'))
    .swap_dims({'dim_0': 'source_id'})
    .sel(latitude=slice(5, 28), longitude=slice(70, 90))
)
# Area-weighted regional mean. On a regular lat/lon grid, cell area scales
# with cos(latitude), so an unweighted mean over-weights poleward rows.
# Masked (NaN) points are skipped and the weights renormalised.
area_weights = np.cos(np.deg2rad(ds_pr.latitude))
ds_pr_mean = ds_pr.weighted(area_weights).mean(dim=('latitude', 'longitude'))

In [None]:
# Maps for position 6 along `month` (presumably July, as months run 1-12)
# for the first four models.
july_maps = ds_pr.isel(month=6, source_id=[0, 1, 2, 3])
july_maps.plot(col='source_id', col_wrap=2)

In [None]:
import calendar
# Abbreviated month names for January..December (index 0 of month_abbr is empty).
mon = [calendar.month_abbr[m] for m in range(1, 13)]


In [None]:
from pathlib import Path

# One bar chart of the regional-mean annual cycle per model, five panels per row;
# the grid grows with the number of models instead of assuming exactly 25.
n_models = ds_pr_mean.sizes['source_id']
ncols = 5
nrows = max(1, -(-n_models // ncols))  # ceiling division
fig, axs = plt.subplots(nrows, ncols, dpi=300, figsize=(17, 18 * nrows / 5), squeeze=False)
fig.suptitle("Annual Precipitation cycle over India (2005-2014)", x=0.5, y=0.92, fontsize=22, weight='bold')

# Month lengths of a 365-day year, used to turn mean daily rates into monthly totals.
days_in_month = np.array([calendar.monthrange(2001, int(m))[1] for m in ds_pr_mean.month.values])

for k, ax in enumerate(axs.flat):
    if k >= n_models:
        # Spare panels in the last row: hide them rather than index past the last model.
        ax.set_visible(False)
        continue
    model_pr = ds_pr_mean.isel(source_id=k)
    # Presumably kg m-2 s-1 (CMIP6 `pr`); x 86400 s/day gives mm/day.
    pr_mm_day = model_pr.values * 86400
    # Annual total (mm) = sum over months of (mean daily rate x days in month).
    annual_total = float((pr_mm_day * days_in_month).sum())
    ax.bar(model_pr.month, pr_mm_day)
    ax.set_xticks(np.arange(1, 13, 3), mon[::3])
    ax.set_ylabel('Precipitation (mm/day)')
    ax.set_xlabel('Month')
    ax.set_title(str(model_pr.source_id.values))
    ax.text(x=0.2, y=0.9 * ax.get_ylim()[1], s=f'p = {annual_total:.1f} mm', c='b')

fig.subplots_adjust(hspace=0.4, wspace=0.3)
plot_path = Path('/home/jovyan/pangeo/plot/pr_annual_cycle_allmodels.png')
plot_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(plot_path, bbox_inches='tight', facecolor='white')