In [None]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe
import cartopy
import dask
from tqdm.autonotebook import tqdm  # Fancy progress bars for our loops!
from dask.diagnostics import progress
import intake
import fsspec

%matplotlib inline
#plt.rcParams['figure.figsize'] = 12, 6
%config InlineBackend.figure_format = 'retina' 

In [None]:
col = intake.open_esm_datastore("https://storage.googleapis.com/cmip6/pangeo-cmip6.json")
expts_full = ['historical', 'ssp126', 'ssp245', 'ssp370', 'ssp585', 'piControl']

query = dict(
    experiment_id=expts_full,               # historical run, four SSP scenarios and the pre-industrial control
    table_id='Amon',                        # atmospheric variables (A) saved at monthly resolution (mon)
    variable_id=['tas', 'pr', 'ua', 'va'],  # near-surface air temperature, precipitation, zonal and meridional wind
    member_id='r1i1p1f1',                   # arbitrarily pick one realization per model (one set of initial conditions)
)

# require_all_on=["source_id"]: keep only models (source_id) that satisfy the whole
# query, i.e. provide every requested experiment/variable combination.
col_subset = col.search(require_all_on=["source_id"], **query)
# One sub-catalog per variable, in the order of query['variable_id'].
# NOTE: tas is fetched here but later cells only process pr, ua and va.
col_subset_var = [col_subset.search(variable_id=var_name) for var_name in query['variable_id']]
col_subset.df[['source_id', 'experiment_id', 'variable_id', 'member_id']].nunique()

In [None]:
def drop_all_bounds(ds):
    """Return ``ds`` without its coordinate-bounds variables.

    Every coordinate whose name contains ``_bounds`` or ``_bnds``
    (e.g. ``lat_bnds``, ``time_bounds``) is dropped. Data variables are not
    inspected.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to clean.

    Returns
    -------
    xarray.Dataset
        A new dataset without the matched coordinates.
    """
    drop_vars = [vname for vname in ds.coords
                 if '_bounds' in vname or '_bnds' in vname]
    # drop_vars is the supported replacement for the deprecated Dataset.drop
    # when removing variables by name.
    return ds.drop_vars(drop_vars)

def open_dset(df, plev_pa=85000):
    """Open the Zarr store referenced by the first row of ``df`` as an xarray Dataset.

    If the dataset has a ``plev`` coordinate, only the level(s) equal to
    ``plev_pa`` (within floating-point tolerance) are kept. The selection uses
    an index array, so a length-1 ``plev`` dimension remains while the ``plev``
    coordinate itself is dropped. Coordinate-bounds variables are then removed
    with ``drop_all_bounds``.

    Parameters
    ----------
    df : pandas.DataFrame
        Catalog rows for one (source_id, experiment_id) group. Must have a
        ``zstore`` column; only the first row is used.
    plev_pa : float, default 85000
        Pressure level to keep, in the units of the store's ``plev``
        coordinate (assumed to be Pa, i.e. 850 hPa by default -- TODO confirm
        per model).

    Returns
    -------
    xarray.Dataset

    Raises
    ------
    ValueError
        If ``plev`` exists but contains no level matching ``plev_pa``.
        Previously this surfaced as an opaque UnboundLocalError.
    """
    ds = xr.open_zarr(fsspec.get_mapper(df.zstore.values[0]), consolidated=True,
                      decode_times=True, use_cftime=True)
    if 'plev' in ds.coords:
        levels = ds.plev.values
        # Tolerant comparison instead of int() truncation of float levels.
        ind = np.flatnonzero(np.isclose(levels, plev_pa))
        if ind.size == 0:
            raise ValueError(f"No pressure level matching {plev_pa} found in plev={levels}")
        # Indexing with an array (not a scalar) keeps plev as a length-1 dim.
        ds = ds.isel(plev=ind).drop_vars('plev')
    return drop_all_bounds(ds)

def open_delayed(df):
    """Wrap ``open_dset(df)`` in a ``dask.delayed`` task; nothing is opened until it is computed."""
    lazy_open = dask.delayed(open_dset)
    return lazy_open(df)

from collections import defaultdict

# For each per-variable sub-catalog build {source_id: {experiment_id: delayed Dataset}}.
# Nothing is opened yet: every leaf is a dask.delayed call to open_dset.
# The loop variable is deliberately NOT named col_subset: reusing that name
# clobbered the combined catalog created in the previous cell.
dsets = []
for var_catalog in col_subset_var:
    dset = defaultdict(dict)
    for (source_id, experiment_id), df in var_catalog.df.groupby(by=['source_id', 'experiment_id']):
        dset[source_id][experiment_id] = open_delayed(df)
    dsets.append(dset)

In [None]:
# Materialise the delayed open_dset calls for pr, ua and va.
# dsets[0] (tas) is skipped because nothing below uses it.
# dask.compute returns a 1-tuple holding the dict with every delayed replaced by its Dataset.
dsets_ = []
for lazy_dset in dsets[1:]:
    (opened,) = dask.compute(dict(lazy_dset))
    dsets_.append(opened)

In [None]:
import pymannkendall as mkt
import esmvalcore.preprocessor as ecpr
import dask.array as da
import iris
import numpy as np
from cf_units import Unit
import itertools
def get_vname(ds):
    """Return the first variable of ``ds`` that is one of 'pr', 'ua' or 'va'.

    Variables are scanned in the order of ``ds.variables``.

    Raises
    ------
    RuntimeError
        If none of the three names is present.
    """
    wanted = ('pr', 'ua', 'va')
    found = [name for name in ds.variables.keys() if name in wanted]
    if not found:
        raise RuntimeError("Couldn't find a variable")
    return found[0]
            
def get_lat_name(ds):
    """Return the latitude coordinate name of ``ds``; 'lat' is preferred over 'latitude'.

    Raises
    ------
    RuntimeError
        If neither name is a coordinate of ``ds``.
    """
    found = next((name for name in ('lat', 'latitude') if name in ds.coords), None)
    if found is None:
        raise RuntimeError("Couldn't find a latitude coordinate")
    return found
    
def get_lon_name(ds):
    """Return the longitude coordinate name of ``ds``; 'lon' is preferred over 'longitude'.

    Raises
    ------
    RuntimeError
        If neither name is a coordinate of ``ds``.
    """
    present = [name for name in ('lon', 'longitude') if name in ds.coords]
    if not present:
        raise RuntimeError("Couldn't find a longitude coordinate")
    return present[0]

def regrid(ds):
    """Bilinearly regrid the variable of interest in ``ds`` onto a global 1-degree grid.

    The target grid has points at lat = -90..89 and lon = -180..179 in
    1-degree steps.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset holding one of the variables recognised by ``get_vname``, with
        lat/lon coordinates that xESMF can use.

    Returns
    -------
    xarray.Dataset
        A single-variable dataset (same variable name) on the target grid.
    """
    var_name = get_vname(ds)
    field = ds[var_name]
    target_grid = xr.Dataset({
        "lat": (["lat"], np.arange(-90, 90, 1.0)),
        "lon": (["lon"], np.arange(-180, 180, 1.0)),
    })
    # periodic=True: both grids are global, so longitude must wrap around.
    # Without it, target points between the last and first source longitude
    # are left unmapped by the bilinear interpolation.
    regridder = xe.Regridder(field, target_grid, 'bilinear', periodic=True)
    return regridder(field).to_dataset(name=var_name)



def jjas_mean(ds, months=(6, 7, 8, 9), years=(1950, 2014),
              lat_range=(-40, 40), lon_range=(5, 120)):
    """Seasonal-mean climatology of the variable of interest over a lat/lon box.

    The steps are:
    1. Average the selected months of each year into one value per year.
    2. Keep the years in ``years`` (inclusive).
    3. Average those yearly values.

    The defaults reproduce the original behaviour: the June-September (JJAS)
    mean over 1950-2014 for 40S-40N, 5E-120E.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a ``time`` coordinate, one variable recognised by
        ``get_vname`` and lat/lon coordinates. Latitude and longitude are
        assumed ascending, because the box is selected with label slices.
    months : sequence of int
        Calendar months (1-12) forming the season.
    years : (int, int)
        First and last year (inclusive) of the averaging period.
    lat_range, lon_range : (float, float)
        Inclusive label bounds of the spatial box.

    Returns
    -------
    xarray.DataArray
        Field averaged over the season and the years (no time/year dimension).
    """
    var_name = get_vname(ds)
    lat_name = get_lat_name(ds)
    lon_name = get_lon_name(ds)
    field = ds[var_name]
    # Boolean selection on the time coordinate. Unlike positional indexing with
    # concatenated GroupBy.groups entries, this neither assumes time is the
    # first dimension nor that the group indices are Python lists.
    in_season = field['time'].dt.month.isin(list(months))
    seasonal = field.sel(time=in_season)
    yearly = seasonal.groupby('time.year').mean()
    box = yearly.sel({
        'year': slice(*years),
        lat_name: slice(*lat_range),
        lon_name: slice(*lon_range),
    })
    return box.mean(dim='year')

In [None]:
from toolz.functoolz import juxt
expts = expts_full[0]  # 'historical': the run the 1950-2014 window in jjas_mean is meant for

# For each variable (pr, ua, va) build {source_id: JJAS-mean field of the
# selected experiment, regridded onto the common 1-degree grid}. The fields
# stay dask-backed here; they are computed in the next cell.
dsets_aligned_list = []
for dset_ in dsets_:
    dsets_aligned = {}
    for source_id, expt_dsets in tqdm(dset_.items()):
        # Skip models that lack the analysed experiment instead of failing
        # with a KeyError partway through the loop.
        if expt_dsets.get(expts) is None:
            print(f"Missing experiment for {source_id}")
            continue
        dsets_aligned[source_id] = expt_dsets[expts].pipe(regrid).pipe(jjas_mean)
    dsets_aligned_list.append(dsets_aligned)

In [None]:
# Trigger the reads/regridding one variable at a time (pr, ua, va), each with
# its own progress bar.
computed_aligned = []
for var_idx in (0, 1, 2):
    with progress.ProgressBar():
        computed_aligned.append(dask.compute(dsets_aligned_list[var_idx])[0])
dsets_aligned_list_1, dsets_aligned_list_2, dsets_aligned_list_3 = computed_aligned

In [None]:
# Computed per-model fields, in variable order: pr, ua, va.
dsets_algned_list_ = list((dsets_aligned_list_1, dsets_aligned_list_2, dsets_aligned_list_3))

In [None]:
def _source_coord(names):
    """Build the ``source_id`` coordinate used as the concatenation dimension."""
    return xr.DataArray(names, dims='source_id', name='source_id',
                        coords={'source_id': names})


def _stack_sources(aligned, coord):
    """Concatenate the per-model fields of ``aligned`` along the coordinate ``coord``."""
    return xr.concat([field.reset_coords(drop=True) for field in aligned.values()],
                     dim=coord)


source_ids = [list(aligned.keys()) for aligned in dsets_algned_list_]

# Entries 1 and 2 hold the wind components; entry 0 holds precipitation.
big_ds_wind = [_stack_sources(aligned, _source_coord(names))
               for aligned, names in zip(dsets_algned_list_[1:], source_ids[1:])]

source_da = _source_coord(source_ids[0])
big_ds_pr = _stack_sources(dsets_algned_list_[0], source_da)

In [None]:
# Quick look at the first model; x 86400 s/day converts the flux to mm/day.
pr_first_model_mm_day = big_ds_pr.isel(source_id=0) * 86400
p = pr_first_model_mm_day.plot()

In [None]:
# Combine the ua and va stacks into one dataset and cache it on disk.
ds_all = xr.merge(big_ds_wind)
ds_all.to_netcdf('/home/jovyan/pangeo/data/wind_jjas_mean_1950_2014.nc')
ds_all

In [None]:
# Cache the per-model JJAS precipitation stack on disk.
pr_output_path = '/home/jovyan/pangeo/data/pr_jjas_mean_1950_2014.nc'
big_ds_pr.to_netcdf(pr_output_path)

In [None]:
# Multi-model means (average over the source_id dimension).
# pr is multiplied by 86400 s/day to get mm/day (input presumably in
# kg m-2 s-1, the CMIP convention -- TODO confirm).
pr_mmm = (big_ds_pr * 86400).mean(dim='source_id')
# Drop only the length-1 pressure-level dimension left by open_dset's 850 hPa
# selection. A bare .squeeze() would also remove source_id when only one model
# is available, and the mean over 'source_id' would then fail.
wind_no_plev = ds_all.squeeze('plev') if 'plev' in ds_all.dims else ds_all
ds_wind_mmm = wind_no_plev.mean(dim='source_id')

In [None]:
def _mask_land(field):
    """Return ``field`` with land points masked by ESMValCore's ``mask_landsea``.

    ``field`` is converted to an iris cube with 'latitude'/'longitude'
    coordinates. Those coordinates get CF standard names and are re-attached
    as dimension coordinates 0 and 1, presumably so that ``mask_landsea`` can
    find them. The masked cube is converted back to an xarray DataArray.

    Assumes ``field`` is 2-D with dims ordered (lat, lon), as produced by the
    multi-model mean above -- TODO confirm for other inputs.
    """
    cube = field.rename({'lat': 'latitude', 'lon': 'longitude'}).to_iris()
    coords = {name: cube.coord(name) for name in ("latitude", "longitude")}
    for name, coord in coords.items():
        coord.standard_name = name
    for axis, name in enumerate(("latitude", "longitude")):
        cube.remove_coord(name)
        cube.add_dim_coord(coords[name], axis)
    return xr.DataArray.from_iris(ecpr.mask_landsea(cube, 'land'))


dsu = _mask_land(ds_wind_mmm.ua)
dsv = _mask_land(ds_wind_mmm.va)

In [None]:
import matplotlib.pyplot as plt
from cartopy import feature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import cartopy.crs as ccrs
import matplotlib.colors as colors
import matplotlib.colorbar as clb
import itertools
import esmvalcore.preprocessor as ecpr

# Multi-model-mean JJAS precipitation (filled contours) with 850 hPa wind
# vectors (land masked) on a PlateCarree map, saved as PNG.
fs = 10     # gridline label font size
lim = 10    # upper precipitation contour level, mm/day
skip = 3    # plot every `skip`-th wind vector in each direction

fig = plt.figure(dpi=400, figsize=(10, 10))
ax = fig.add_subplot(projection=ccrs.PlateCarree())
ax.coastlines()

gl = ax.gridlines(crs=ccrs.PlateCarree(), linewidth=2, color='grey',
                  alpha=0.3, linestyle='-', draw_labels=True)
gridline_settings = dict(
    top_labels=False,
    right_labels=False,
    xformatter=LONGITUDE_FORMATTER,
    yformatter=LATITUDE_FORMATTER,
    xlabel_style={'size': fs},
    ylabel_style={'size': fs},
)
for attr, value in gridline_settings.items():
    setattr(gl, attr, value)

X, Y = np.meshgrid(pr_mmm.lon, pr_mmm.lat)
pc = ax.contourf(X, Y, pr_mmm, levels=np.linspace(0, lim, 21),
                 cmap='Spectral_r', extend='both')

thin = (slice(None, None, skip), slice(None, None, skip))
qv = ax.quiver(X[thin], Y[thin], dsu[thin], dsv[thin],
               scale=160, scale_units='width', pivot='middle',
               width=0.002, headwidth=4)
ax.quiverkey(qv, 0.82, 0.71, 5, label=r'$5 \frac{m}{s}$ ', coordinates='figure')

cax, kw = clb.make_axes(ax, location='right', pad=0.05, shrink=0.5,
                        fraction=0.09, aspect=18)
cbar = fig.colorbar(pc, cax=cax, **kw)
cbar.ax.tick_params(labelsize=10)
cbar.ax.set_ylabel('pr: mm/day', size=12, weight='bold')

ax.set_title("JJAS mean precipitation and wind (850 hPa) 1950-2014 \n Multi model mean",
             fontsize=18, weight='bold')
fig.savefig('/home/jovyan/pangeo/plot/jjas_pr_wind_850_mmm_1950_2014.png',
            bbox_inches='tight', facecolor='white')