In [None]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe
import cartopy
import dask
from tqdm.autonotebook import tqdm  # Fancy progress bars for our loops!
from dask.diagnostics import progress
import intake
import fsspec

%matplotlib inline
#plt.rcParams['figure.figsize'] = 12, 6
%config InlineBackend.figure_format = 'retina' 

In [None]:
# Open the Pangeo CMIP6 catalogue (an intake-esm datastore described by a JSON file).
col = intake.open_esm_datastore("https://storage.googleapis.com/cmip6/pangeo-cmip6.json")
col
expts_full = ['historical', 'ssp126', 'ssp245', 'ssp370', 'ssp585', 'piControl']

# Search criteria handed to the catalogue.
query = {
    'experiment_id': expts_full,                # every experiment listed in expts_full
    'table_id': 'Amon',                         # the "Amon" table only
    'variable_id': ['tas', 'pr', 'ua', 'va'],   # the four requested variables
    'member_id': 'r1i1p1f1',                    # a single ensemble member per model
}

# require_all_on="source_id" is intended to keep only models that satisfy the
# whole query -- see the intake-esm search documentation.
col_subset = col.search(require_all_on=["source_id"], **query)

# One sub-catalogue per variable, in the same order as query['variable_id'].
col_subset_var = [
    col_subset.search(variable_id=var_name)
    for var_name in query['variable_id']
]

# Number of distinct models / experiments / variables / members that matched.
col_subset.df[['source_id', 'experiment_id', 'variable_id', 'member_id']].nunique()

In [None]:
def drop_all_bounds(ds):
    """Return ``ds`` without its bounds coordinates.

    A coordinate counts as a bounds variable when its name contains
    ``'_bounds'`` or ``'bnds'`` (the latter also matches ``'_bnds'``).
    Only ``ds.coords`` is inspected; data variables are left untouched.

    Parameters
    ----------
    ds : object exposing ``coords`` and ``drop_vars`` (e.g. ``xarray.Dataset``)

    Returns
    -------
    Whatever ``ds.drop_vars`` returns for the matching coordinate names.
    """
    bound_names = [vname for vname in ds.coords
                   if '_bounds' in vname or 'bnds' in vname]
    # ``drop`` is deprecated in xarray for removing variables; ``drop_vars``
    # is the dedicated replacement.
    return ds.drop_vars(bound_names)

def open_dset(df):
    """Open the zarr store of the first row of ``df`` and strip bounds coordinates.

    If the dataset has a ``plev`` coordinate, only the first level whose
    integer-truncated value equals 85000 is kept (presumably 850 hPa expressed
    in Pa -- confirm against the data). The level is selected with an index
    array, so ``plev`` stays as a length-1 dimension; only the ``plev``
    coordinate variable is removed.

    Parameters
    ----------
    df : pandas.DataFrame
        Catalogue rows with a ``zstore`` column; only ``zstore.values[0]`` is read.

    Returns
    -------
    The opened dataset after ``drop_all_bounds``.

    Raises
    ------
    ValueError
        If ``plev`` is present but contains no 85000 entry. Previously this
        case failed with an ``UnboundLocalError`` because ``ind`` was never assigned.
    """
    store = fsspec.get_mapper(df.zstore.values[0])
    ds = xr.open_zarr(store, consolidated=True, decode_times=True, use_cftime=True)
    if 'plev' in ds.coords:
        # Same test as the original ``int(lev) == 85000`` loop, keeping only
        # the first match (the original loop stopped at the first hit).
        ind = np.flatnonzero(ds.plev.values.astype(int) == 85000)[:1]
        if ind.size == 0:
            raise ValueError(
                f"No plev == 85000 in {df.zstore.values[0]}; "
                f"available levels: {ds.plev.values}"
            )
        ds = ds.isel(plev=ind).drop_vars('plev')
    return drop_all_bounds(ds)

def open_delayed(df):
    """Return a dask ``Delayed`` that runs ``open_dset(df)`` when computed."""
    lazy_open = dask.delayed(open_dset)
    return lazy_open(df)

from collections import defaultdict

# dsets[i] maps source_id -> experiment_id -> delayed dataset for the i-th
# entry of col_subset_var.
dsets = []
# The loop variable is deliberately NOT named ``col_subset``: reusing that name
# replaced the notebook-level catalogue subset with the last per-variable
# sub-catalogue, so re-running later cells silently saw different data.
for var_subset in col_subset_var:
    dset = defaultdict(dict)

    for (source_id, experiment_id), df in var_subset.df.groupby(by=['source_id', 'experiment_id']):
        dset[source_id][experiment_id] = open_delayed(df)
    dsets.append(dset)

In [None]:
# Evaluate the delayed opens for every entry of ``dsets`` except the first one;
# ``dsets_`` keeps the remaining entries in their original order.
with progress.ProgressBar():
    dsets_ = []
    for dset in dsets[1:]:
        (opened,) = dask.compute(dict(dset))
        dsets_.append(opened)

In [None]:
# Exploratory cell: trends for a single model (ACCESS-CM2, historical run).
# dsets_ is dsets[1:], and dsets follows query['variable_id'] = ['tas', 'pr', 'ua', 'va'],
# so index 0 is pr, 1 is ua and 2 is va.
ds_v = dsets_[2]['ACCESS-CM2']['historical']
ds_u = dsets_[1]['ACCESS-CM2']['historical']
ds_pr = dsets_[0]['ACCESS-CM2']['historical']

# For each variable: collect the time indices of months 6-9 (``groups`` maps a
# month number to its indices -- presumably lists of integer positions, which
# ``+`` concatenates), average per calendar year, keep 1950-2014 and convert
# to an iris cube.
mind = ds_v.groupby('time.month')
min_sel = mind.groups[6] + mind.groups[7] + mind.groups[8] + mind.groups[9]
dsv_mean = ds_v.va[min_sel].groupby('time.year').mean().sel(year=slice(1950,2014)).to_iris()

mind = ds_u.groupby('time.month')
min_sel = mind.groups[6] + mind.groups[7] + mind.groups[8] + mind.groups[9]
dsu_mean = ds_u.ua[min_sel].groupby('time.year').mean().sel(year=slice(1950,2014)).to_iris()

mind = ds_pr.groupby('time.month')
min_sel = mind.groups[6] + mind.groups[7] + mind.groups[8] + mind.groups[9]
dspr_mean = ds_pr.pr[min_sel].groupby('time.year').mean().sel(year=slice(1950,2014)).to_iris()


# NOTE(review): the ``mk`` alias is not used in any visible cell (a later cell
# imports the same package as ``mkt``).
import pymannkendall as mk
import esmvalcore.preprocessor as ecpr
# Trend along 'year' via ESMValCore, then back to xarray; the leftover 'year'
# coordinate is dropped and length-1 dimensions are squeezed away.
dsu_tr = ecpr.linear_trend(dsu_mean, 'year')
dsu_tr = xr.DataArray.from_iris(dsu_tr).drop('year').squeeze()

dspr_tr = ecpr.linear_trend(dspr_mean, 'year')
dspr_tr = xr.DataArray.from_iris(dspr_tr).drop('year').squeeze()


dsv_tr = ecpr.linear_trend(dsv_mean, 'year')
dsv_tr = xr.DataArray.from_iris(dsv_tr).drop('year').squeeze()


# 2-D coordinate grids on the native wind grid (X, Y) and precipitation grid (X_, Y_).
X, Y = np.meshgrid(ds_u.lon, ds_u.lat)
X_, Y_ = np.meshgrid(ds_pr.lon, ds_pr.lat)
# Print the dimension names of every model's historical dataset in dsets_[0].
for i in dsets_[0].keys():
    print(dsets_[0][i]['historical'].dims.keys())

In [None]:
# Quick look at the single-model precipitation trend.
# ``skip`` is only referenced by the commented-out quiver call below.
skip=5
fig = plt.figure(dpi=200)

# NOTE(review): the factor 164 is unexplained; other cells scale trends by
# 64 or 65 (years) -- confirm what was intended here.
pl = plt.pcolormesh(dspr_tr*164)
fig.colorbar(pl)
#qv = plt.quiver(X[::skip,::skip], Y[::skip,::skip], 
#                dsu_tr[::skip, ::skip]*164,
#                dsv_tr[::skip, ::skip]*164,
#                scale=40, scale_units='width', pivot='middle',
#                width=0.003, headwidth=3)
#plt.quiverkey(qv, 0.8,0.89, 5, label= '5m/s',
#                     coordinates='figure')

In [None]:
# NOTE(review): ``ds_tr`` is not assigned in any cell visible in this notebook
# (the earlier cell defines dsu_tr, dsv_tr and dspr_tr). Unless it is defined
# elsewhere, this cell raises NameError on a fresh kernel -- confirm which
# trend was meant.
(ds_tr*164).plot(cmap='rainbow', vmin=-0.5, vmax=0.5)

In [None]:
# Scratch cell: print the index of the pressure level whose integer value is 850.
# NOTE(review): ``ds`` is not assigned at notebook level in any visible cell,
# and this compares against 850 whereas open_dset() matches 85000 -- confirm
# the units of ``plev`` before relying on this.
levs = ds.plev.values
for lev in levs:
    if int(lev)==850:
        ind = np.where(levs==lev)
        print(ind)
#levs[ind]


In [None]:
import pymannkendall as mkt
import esmvalcore.preprocessor as ecpr
import dask.array as da
import iris
import numpy as np
from cf_units import Unit
import itertools
def get_vname(ds):
    """Return the first variable name of ``ds`` that is 'pr', 'ua' or 'va'.

    Names are checked in the iteration order of ``ds.variables``.

    Raises
    ------
    RuntimeError
        If none of the three names is present.
    """
    wanted = ('pr', 'ua', 'va')
    match = next((name for name in ds.variables.keys() if name in wanted), None)
    if match is None:
        raise RuntimeError("Couldn't find a variable")
    return match
            
def get_lat_name(ds):
    """Return 'lat' if it is a coordinate of ``ds``, otherwise 'latitude'.

    Raises
    ------
    RuntimeError
        If neither name is found in ``ds.coords``.
    """
    if 'lat' in ds.coords:
        return 'lat'
    if 'latitude' in ds.coords:
        return 'latitude'
    raise RuntimeError("Couldn't find a latitude coordinate")
    
def get_lon_name(ds):
    """Return the longitude coordinate name of ``ds``: 'lon' is preferred over 'longitude'.

    Raises
    ------
    RuntimeError
        If neither name is found in ``ds.coords``.
    """
    candidates = ('lon', 'longitude')
    found = next((name for name in candidates if name in ds.coords), None)
    if found is None:
        raise RuntimeError("Couldn't find a longitude coordinate")
    return found

def regrid(ds):
    """Bilinearly regrid the pr/ua/va variable of ``ds`` onto a 1-degree lat/lon grid.

    The target grid has latitudes -90, -89, ..., 89 and longitudes
    -180, -179, ..., 179 (``np.arange`` excludes the stop value).

    Returns
    -------
    A dataset holding only the regridded variable, under its original name.
    """
    var_name = get_vname(ds)
    field = ds[var_name]

    target_lat = np.arange(-90, 90, 1.0)
    target_lon = np.arange(-180, 180, 1.0)
    target_grid = xr.Dataset({
        "lat": (["lat"], target_lat),
        "lon": (["lon"], target_lon),
    })

    # NOTE(review): the regridder is built with its default ``periodic``
    # setting; for global source grids xESMF documents ``periodic=True`` as the
    # way to avoid a gap at the longitude seam -- confirm whether it matters here.
    regridder = xe.Regridder(field, target_grid, 'bilinear')
    return regridder(field).to_dataset(name=var_name)


def mktest(ds, alpha):
    """Run ``mkt.original_test(ds, alpha)`` and return ``[slope, h]`` as a NumPy array.

    ``slope`` and ``h`` are the 8th and 2nd items of the nine-element result
    (per pymannkendall: the estimated slope and the significance flag).
    Because both go into one array, ``h`` is stored with the array's common dtype.
    """
    (_trend, significant, _p, _z, _tau, _s, _var_s,
     slope, _intercept) = mkt.original_test(ds, alpha)
    return np.array([slope, significant])

def jjas_trend_pr(ds):
    """Per-grid-cell Mann-Kendall slope of the June-September yearly mean, 1950-2014.

    Steps:
      1. pick the pr/ua/va variable via ``get_vname``;
      2. keep the time steps of months 6, 7, 8 and 9, average them per calendar
         year and keep the years 1950-2014;
      3. squeeze away length-1 dimensions other than ``year`` (``open_dset``
         leaves a singleton ``plev`` dimension on the wind fields);
      4. run ``mktest(series, 0.1)`` on every grid cell and store the slope.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain a ``time`` coordinate and one of the variables recognised
        by ``get_vname``.

    Returns
    -------
    xarray.DataArray
        2-D slope field named after the variable. Its two dimensions (and
        coordinates) are the two non-year dimensions of the yearly-mean field,
        in the same order -- the result used to be labelled ``("lon", "lat")``
        regardless of the actual order, and was unnamed, which breaks the later
        ``xr.merge`` / ``.pr`` / ``.ua`` / ``.va`` accesses.

    Raises
    ------
    ValueError
        If the yearly-mean field is not 3-D after squeezing.
    """
    var_name = get_vname(ds)
    mind = ds.groupby('time.month')
    mind_sel = mind.groups[6] + mind.groups[7] + mind.groups[8] + mind.groups[9]
    ds_sel = ds[var_name][mind_sel].groupby('time.year').mean().sel(year=slice(1950, 2014))

    # Remove singleton dimensions (e.g. the retained pressure level) so the
    # field is (year, y, x) as the loop below assumes.
    singleton_dims = [d for d in ds_sel.dims if d != 'year' and ds_sel.sizes[d] == 1]
    if singleton_dims:
        ds_sel = ds_sel.squeeze(singleton_dims, drop=True)
    if ds_sel.ndim != 3:
        raise ValueError(
            f"Expected a 3-D (year, y, x) field for '{var_name}', got dims {ds_sel.dims}"
        )

    row_dim, col_dim = ds_sel.dims[1], ds_sel.dims[2]

    # Evaluate the (possibly lazy) array once instead of once per grid cell.
    yearly = ds_sel.values
    slopes = np.empty(shape=yearly.shape[1:])
    for la, lo in itertools.product(range(yearly.shape[1]), range(yearly.shape[2])):
        slopes[la, lo] = mktest(yearly[:, la, lo], 0.1)[0]

    ds_trend = xr.DataArray(
        data=slopes,
        dims=[row_dim, col_dim],
        coords={
            row_dim: ([row_dim], ds_sel[row_dim].values),
            col_dim: ([col_dim], ds_sel[col_dim].values),
        },
        name=var_name,
    )
    return ds_trend

In [None]:
from toolz.functoolz import juxt

# Only the first entry of expts_full ('historical') is processed in this cell.
expts = expts_full[0]

# One dict per variable in ``dsets_``: model name -> regridded JJAS trend field.
dsets_aligned_list = []
for dset_ in dsets_:
    dsets_aligned = {}
    for k, v in tqdm(dset_.items()):
        # Skip a model as soon as any of its experiment entries is None.
        if any(d is None for d in v.values()):
            print(f"Missing experiment for {k}")
            continue

        # Regrid the selected experiment, then reduce it to a trend map.
        dsets_aligned[k] = jjas_trend_pr(regrid(v[expts]))
    dsets_aligned_list.append(dsets_aligned)

In [None]:
# Materialise the first three per-variable trend dictionaries, each under its
# own progress bar, and bind them to the three names used by later cells.
_computed = []
for _pos in (0, 1, 2):
    with progress.ProgressBar():
        (_result,) = dask.compute(dsets_aligned_list[_pos])
    _computed.append(_result)

dsets_aligned_list_1, dsets_aligned_list_2, dsets_aligned_list_3 = _computed

In [None]:
# Collect the three computed trend dictionaries (the variable name keeps its
# original spelling because later cells refer to it).
dsets_algned_list_ = [dsets_aligned_list_1, dsets_aligned_list_2, dsets_aligned_list_3]
# Model names per variable, in dictionary order.
source_ids = [list(dsets_aligned_.keys()) for dsets_aligned_ in dsets_algned_list_]
# Trend fields of the first dictionary with non-index coordinates removed.
datas = [ds.reset_coords(drop=True) for ds in dsets_algned_list_[0].values()]
source_da = xr.DataArray(source_ids[0], dims='source_id', name='source_id',
                         coords={'source_id': source_ids[0]})
# Quick look at the first model. The trend is scaled by 86400 (seconds per
# day; presumably converting a per-second flux to per-day -- confirm the units)
# and by 65, the number of years in 1950-2014, matching the scaling used by
# the final figures (this line previously used 64).
(datas[0]*86400*65).plot()

In [None]:
# Model names for each trend dictionary, in dictionary order.
source_ids = [list(model_trends.keys()) for model_trends in dsets_algned_list_]

# Stack the second and third dictionaries along a new ``source_id`` dimension
# labelled with the model names.
big_ds_wind = []
for names, dsets_aligned_ in zip(source_ids[1:], dsets_algned_list_[1:]):
    source_da = xr.DataArray(names, dims='source_id', name='source_id',
                             coords={'source_id': names})
    per_model = [ds.reset_coords(drop=True) for ds in dsets_aligned_.values()]
    big_ds_wind.append(xr.concat(per_model, dim=source_da))

# Same stacking for the first dictionary.
pr_names = source_ids[0]
source_da = xr.DataArray(pr_names, dims='source_id', name='source_id',
                         coords={'source_id': pr_names})
pr_per_model = [ds.reset_coords(drop=True) for ds in dsets_algned_list_[0].values()]
big_ds_pr = xr.concat(pr_per_model, dim=source_da)

In [None]:
# Display the list of stacked arrays built in the previous cell.
big_ds_wind

In [None]:
import cartopy.crs as ccrs

# Map of the precipitation trend for the model at position 20 along source_id.
plate_carree = ccrs.PlateCarree()
pr_one_model = big_ds_pr.isel(source_id=20)
p = pr_one_model.plot(
    subplot_kws={'projection': plate_carree},
    transform=plate_carree,
)
p.axes.coastlines()

In [None]:
# Merge the stacked wind arrays into one dataset, write it to disk and display it.
wind_trend_path = '/home/jovyan/pangeo/data/wind_trend_1950_2014.nc'
ds_all = xr.merge(list(big_ds_wind))
ds_all.to_netcdf(wind_trend_path)
ds_all

In [None]:
# Write the stacked precipitation trends to disk; later cells re-open this file.
big_ds_pr.to_netcdf('/home/jovyan/pangeo/data/pr_trend_1950_2014.nc')

In [None]:
import matplotlib.pyplot as plt
from cartopy import feature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import cartopy.crs as ccrs
import matplotlib.colors as colors
import matplotlib.colorbar as clb
import itertools
import esmvalcore.preprocessor as ecpr
# Re-open the wind-trend file written by an earlier cell, replacing the
# in-memory ``ds_all``.
ds_all = xr.open_dataset('/home/jovyan/pangeo/data/wind_trend_1950_2014.nc')


In [None]:
##################################
ucube = ds_all.ua.rename({'lat':'latitude', 'lon':'longitude'}).to_iris()
ulat = ucube.coord("latitude")
ulon = ucube.coord("longitude")


ulat.standard_name = "latitude"
ulon.standard_name = "longitude"
#usource.standard_name = "model"


ucube.remove_coord("latitude")
ucube.add_dim_coord(ulat, 1)
ucube.remove_coord("longitude")
ucube.add_dim_coord(ulon, 2)

ucube  = ecpr.mask_landsea(ucube, 'land')
dsu = xr.DataArray.from_iris(ucube).swap_dims({'dim_0': 'source_id'})
dsu

In [None]:
##################################################
vcube = ds_all.va.rename({'lat':'latitude', 'lon':'longitude'}).to_iris()
ulat = vcube.coord("latitude")
ulon = vcube.coord("longitude")

ulat.standard_name = "latitude"
ulon.standard_name = "longitude"

vcube.remove_coord("latitude")
vcube.add_dim_coord(ulat, 1)
vcube.remove_coord("longitude")
vcube.add_dim_coord(ulon, 2)

vcube = ecpr.mask_landsea(vcube, 'land')
dsv = xr.DataArray.from_iris(vcube).swap_dims({'dim_0': 'source_id'})
#dsv.isel(source_id=0).plot()

In [None]:
# Multi-panel figure: one map per model with the JJAS precipitation trend
# (filled contours) and the masked wind trend (arrows).
big_ds_pr = xr.open_dataset('/home/jovyan/pangeo/data/pr_trend_1950_2014.nc')

X, Y = np.meshgrid(ds_all.lon, ds_all.lat)
nrows = 5
ncols = 5

# Number of models that can actually be drawn: every input must provide the
# panel's index. Without this guard the fixed 5x5 grid raised an index error
# whenever fewer than nrows*ncols models were available. Models beyond the
# grid size are not drawn.
n_models = min(
    int(ds_all.sizes['source_id']),
    int(big_ds_pr.sizes['source_id']),
    int(dsu.sizes['source_id']),
    int(dsv.sizes['source_id']),
)
if n_models == 0:
    raise ValueError("No models available to plot")

fig, axs = plt.subplots(nrows, ncols, dpi=300, 
                        subplot_kw={'projection': ccrs.PlateCarree()}, 
                        figsize= (16,12))
fig.suptitle("JJAS trend analysis of Precipitation and wind (1950-2014)", 
             fontsize=24,
             y=0.95, 
             x=0.50,
             weight='bold')

# Settings shared by every panel (previously re-assigned on each iteration).
fs = 10                   # tick-label font size
lim = 5                   # colour-scale limit
skip = 3                  # draw every third wind arrow
bbox = [20, 100, -20, 40] # map extent passed to set_extent

k = 0
last_ax = None
for i, j in itertools.product(range(nrows), range(ncols)):
    ax = axs[i, j]
    if k >= n_models:
        # No model left for this panel: hide the empty axes.
        ax.set_visible(False)
        continue
    ax.coastlines()

    gl = ax.gridlines(crs=ccrs.PlateCarree(), linewidth=2, color='grey', 
                      alpha=0.3, linestyle='-', draw_labels=True)
    gl.top_labels = False
    gl.right_labels = False
    gl.xformatter = LONGITUDE_FORMATTER
    gl.yformatter = LATITUDE_FORMATTER
    gl.xlabel_style = {'size': fs}
    gl.ylabel_style = {'size': fs}

    # Trend scaled by 86400 (seconds per day) and 65 (years in 1950-2014).
    pc = ax.contourf(X, Y, big_ds_pr.pr.isel(source_id=k)*86400*65, cmap='seismic', extend='both', levels=np.linspace(-lim, lim, 15))

    qv = ax.quiver(X[::skip,::skip], Y[::skip,::skip], 
                    dsu.isel(source_id=k)[::skip, ::skip]*65,
                    dsv.isel(source_id=k)[::skip, ::skip]*65,
                    scale=15, scale_units='width', pivot='middle',
                    width=0.004, headwidth = 6)

    ax.set_title(ds_all.source_id[k].values, fontsize=9)
    ax.set_extent(bbox,crs=ccrs.PlateCarree())
    last_ax = ax
    k += 1

# Attach the arrow key to the last panel that was actually drawn (a hidden
# axes would not render it).
last_ax.quiverkey(qv, 0.925, 0.88, 2, label= r'$2 \frac{m s^{-1}}{65years}$ ',
                          coordinates='figure')
clb_ax_params = [0.925, 0.25, 0.015 ,0.5]
cbar_ax = fig.add_axes(clb_ax_params)
cb = fig.colorbar(pc, cax=cbar_ax,orientation='vertical')
cb.ax.tick_params(labelsize=14)
cb.ax.set_ylabel('pr: mm day$^{-1}$ / 65years', size=12, weight='bold')
fig.subplots_adjust(hspace=0.4)
#plt.savefig('/home/jovyan/pangeo/plot/_trend_wind_pr_ind_masked_hist_allmodels.png', bbox_inches='tight',  facecolor='white')

# Multi Model Mean Calculation

In [None]:
# Reload both trend files written earlier.
wind_trend_path = '/home/jovyan/pangeo/data/wind_trend_1950_2014.nc'
pr_trend_path = '/home/jovyan/pangeo/data/pr_trend_1950_2014.nc'
ds_all = xr.open_dataset(wind_trend_path)
big_ds_pr = xr.open_dataset(pr_trend_path)

# Multi-model mean: average over the source_id dimension.
mmm_wind = ds_all.mean(dim='source_id')
mmm_pr = big_ds_pr.mean(dim='source_id')

In [None]:
import matplotlib.pyplot as plt
from cartopy import feature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import cartopy.crs as ccrs
import matplotlib.colors as colors
import matplotlib.colorbar as clb
import itertools



# Single map of the multi-model-mean precipitation trend (filled contours)
# with the multi-model-mean wind trend (arrows).
fig, ax = plt.subplots(dpi=300, 
                        subplot_kw={'projection': ccrs.PlateCarree()}, 
                        figsize= (10,7))
ax.coastlines()

gl = ax.gridlines(crs=ccrs.PlateCarree(), linewidth=2, color='grey', 
                  alpha=0.3, linestyle='-', draw_labels=True)
fs=12
gl.top_labels = False
gl.right_labels = False
gl.xformatter = LONGITUDE_FORMATTER
gl.yformatter = LATITUDE_FORMATTER
gl.xlabel_style = {'size': fs}
gl.ylabel_style = {'size': fs}

# Build the coordinate grids here instead of relying on X, Y left over from an
# earlier cell; mmm_wind is derived from the same wind-trend dataset.
X, Y = np.meshgrid(mmm_wind.lon, mmm_wind.lat)

lim=5
# 86400 = seconds per day (this line previously used 84600, a digit
# transposition that under-scaled the field by about 2%); 65 = years in 1950-2014.
pc = ax.contourf(X, Y, mmm_pr.pr*86400*65, cmap='seismic', extend='both', levels=np.linspace(-lim, lim, 15))
bbox = [20,100,-20,40]
ax.set_extent(bbox,crs=ccrs.PlateCarree())

skip = 3  # draw every third wind arrow
qv = ax.quiver(X[::skip,::skip], Y[::skip,::skip], 
                mmm_wind.ua[::skip, ::skip]*65,
                mmm_wind.va[::skip, ::skip]*65,
                scale=15, scale_units='width', pivot='middle',
                width=0.003, headwidth = 5)

ax.quiverkey(qv, 1.04, 1.01, 1, label= r'$1 \frac{m s^{-1}}{65years}$ ', 
                      coordinates='axes', )

ax.set_title('JJAS Precipitation and 850hPa wind trend \n CMIP6 (historical) multi model mean (1950-2014)', fontsize=12, weight='bold')
cax,kw = clb.make_axes(ax,location='right',pad=0.05, shrink=0.6, fraction=0.1, aspect=17)
cbar = fig.colorbar(pc,cax=cax,**kw)
cbar.ax.tick_params(labelsize=10)
cbar.ax.set_ylabel('pr: mm day$^{-1}$ / 65years', size=12, weight='bold')
#plt.savefig('/home/jovyan/pangeo/plot/trend_mmm_ind_1950_2014.png', bbox_inches='tight', facecolor='white')