In [1]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe
import cartopy
import dask
from tqdm.autonotebook import tqdm  # Fancy progress bars for our loops!
from dask.diagnostics import progress
import intake
import fsspec

%matplotlib inline
#plt.rcParams['figure.figsize'] = 12, 6
%config InlineBackend.figure_format = 'retina' 

  from tqdm.autonotebook import tqdm  # Fancy progress bars for our loops!


In [2]:
col = intake.open_esm_datastore("https://storage.googleapis.com/cmip6/pangeo-cmip6.json")
expts_full = ['historical', 'ssp126', 'ssp245', 'ssp370', 'ssp585', 'piControl']

query = dict(
    experiment_id=expts_full,               # historical, four SSP scenarios and the piControl run
    table_id='Amon',                        # atmospheric variables (A) saved at monthly resolution (mon)
    variable_id=['tas', 'pr', 'ua', 'va'],  # near-surface air temperature, precipitation and the two wind components
    # No level filter here: the 850 hPa level (plev == 85000) is picked when each store is opened (open_dset).
    member_id='r1i1p1f1',                   # arbitrarily pick one realization for each model (i.e. just one set of initial conditions)
)

# require_all_on=["source_id"]: keep only models for which all the query values are available
# (see the intake-esm `search` documentation for the exact semantics).
col_subset = col.search(require_all_on=["source_id"], **query)
# One catalog subset per variable, in the same order as query['variable_id'] (tas, pr, ua, va).
col_subset_var = [col_subset.search(variable_id=var_name) for var_name in query['variable_id']]
col_subset.df[['source_id', 'experiment_id', 'variable_id', 'member_id']].nunique()

source_id        25
experiment_id     6
variable_id       4
member_id         1
dtype: int64

In [None]:
# Open every precipitation store (col_subset_var[1] is the 'pr' subset) as a dict of datasets.
pr_catalog = col_subset_var[1]
dset_dict = pr_catalog.to_dataset_dict(
    zarr_kwargs=dict(consolidated=True, decode_times=True, use_cftime=True)
)
# Keys of the piControl runs, whose time spans are printed in the next cell.
ss = list(filter(lambda key: 'piControl' in key, dset_dict.keys()))

In [None]:
# Report the first and last time stamp of each piControl dataset.
for s in ss:
    ds = dset_dict[s]
    first_time = ds.time[0].values
    last_time = ds.time[-1].values
    print('Starting time:', first_time, '\tEnding time:', last_time, '\n')

In [3]:
def drop_all_bounds(ds):
    """Drop every coordinate whose name contains ``_bounds`` or ``_bnds``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset whose bounds coordinates should be removed.

    Returns
    -------
    xarray.Dataset
        A new dataset without those coordinates.
    """
    bounds_names = [name for name in ds.coords
                    if '_bounds' in name or '_bnds' in name]
    # `Dataset.drop` is deprecated for removing variables; `drop_vars` is the supported API.
    return ds.drop_vars(bounds_names)

def open_dset(df, plev_target=85000):
    """Open the zarr store listed in a catalog sub-frame and keep a single pressure level.

    Parameters
    ----------
    df : pandas.DataFrame
        Catalog rows for one (source_id, experiment_id) group; only the first
        ``zstore`` URL is opened.
    plev_target : int, optional
        Pressure level to keep, compared against ``int(plev)`` for each level
        (default 85000, i.e. 850 hPa assuming ``plev`` is in Pa).

    Returns
    -------
    xarray.Dataset
        The store's contents with bounds coordinates removed. If the store has a
        ``plev`` coordinate, only the matching level is selected and the ``plev``
        coordinate is dropped, but a length-1 ``plev`` dimension remains.

    Raises
    ------
    ValueError
        If the store has a ``plev`` coordinate but no level matches ``plev_target``.
    """
    # NOTE(review): only the first row is used; the `len(df) == 1` check was disabled upstream.
    store_url = df.zstore.values[0]
    ds = xr.open_zarr(fsspec.get_mapper(store_url), consolidated=True, decode_times=True, use_cftime=True)
    if 'plev' in ds.coords:
        plev = ds.plev.values
        matches = [lev for lev in plev if int(lev) == plev_target]
        if not matches:
            # Without this check the selection below hit an unbound index variable.
            raise ValueError(
                f"No pressure level matching {plev_target} in {store_url}; "
                f"available levels: {plev.tolist()}"
            )
        # Select with an index array (not a scalar) so the length-1 plev dimension is kept.
        ind = np.where(plev == matches[0])[0]
        ds = ds.isel(plev=ind).drop_vars('plev')
    return drop_all_bounds(ds)

def open_delayed(df):
    """Return a ``dask.delayed`` call of ``open_dset(df)``; the store is opened at compute time."""
    lazy_open = dask.delayed(open_dset)
    return lazy_open(df)

from collections import defaultdict

# One nested dict per variable, in col_subset_var order:
#   dsets[i][source_id][experiment_id] -> delayed open_dset(...)
# The loop variable is deliberately not named `col_subset`: reusing that name would
# overwrite the multi-variable catalog search from the query cell (hidden notebook state).
dsets = []
for var_subset in col_subset_var:
    dset = defaultdict(dict)
    for (source_id, experiment_id), df in var_subset.df.groupby(by=['source_id', 'experiment_id']):
        dset[source_id][experiment_id] = open_delayed(df)
    dsets.append(dset)

In [5]:
# dsets[0] is the 'tas' subset, which get_vname/regrid do not handle; only the
# pr, ua and va subsets are opened here.
with progress.ProgressBar():
    dsets_ = []
    for dset in dsets[1:]:
        (opened,) = dask.compute(dict(dset))
        dsets_.append(opened)

[########################################] | 100% Completed |  5.1s
[########################################] | 100% Completed |  6.0s
[########################################] | 100% Completed |  7.1s


In [38]:
import pymannkendall as mkt
import esmvalcore.preprocessor as ecpr
import dask.array as da
import iris
import numpy as np
from cf_units import Unit
import itertools
def get_vname(ds, candidates=('pr', 'ua', 'va')):
    """Return the first variable name of ``ds`` that appears in ``candidates``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to search; its variables are scanned in ``ds.variables`` order.
    candidates : iterable of str, optional
        Accepted variable names. The default keeps the original behaviour
        (precipitation and the two wind components); pass e.g. ``('tas',)``
        to handle other variables.

    Returns
    -------
    str
        The matching variable name.

    Raises
    ------
    RuntimeError
        If no variable of ``ds`` is in ``candidates``.
    """
    for v_name in ds.variables.keys():
        if v_name in candidates:
            return v_name
    raise RuntimeError("Couldn't find a variable")
            
def get_lat_name(ds):
    """Return the latitude coordinate name of ``ds`` ('lat' is preferred over 'latitude').

    Raises
    ------
    RuntimeError
        If ``ds`` has neither coordinate.
    """
    found = next((name for name in ('lat', 'latitude') if name in ds.coords), None)
    if found is None:
        raise RuntimeError("Couldn't find a latitude coordinate")
    return found
    
def get_lon_name(ds):
    """Return whichever of 'lon' / 'longitude' is a coordinate of ``ds`` ('lon' wins if both).

    Raises
    ------
    RuntimeError
        If ``ds`` has neither coordinate.
    """
    present = [name for name in ('lon', 'longitude') if name in ds.coords]
    if not present:
        raise RuntimeError("Couldn't find a longitude coordinate")
    return present[0]

def regrid(ds):
    """Bilinearly regrid the variable found by ``get_vname`` onto a 1-degree lat/lon grid.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing one of the variables recognised by ``get_vname``; its
        horizontal coordinates are expected to be named ``lat``/``lon`` for xESMF.

    Returns
    -------
    xarray.Dataset
        Single-variable dataset (same variable name) on the target grid.
    """
    var_name = get_vname(ds)
    field = ds[var_name]
    # NOTE(review): np.arange excludes its stop value, so the target latitudes are
    # -90..89 and longitudes -180..179 -- the latitude axis is not symmetric.
    target_grid = xr.Dataset({
        "lat": (["lat"], np.arange(-90, 90, 1.0)),
        "lon": (["lon"], np.arange(-180, 180, 1.0)),
    })
    # Assumes the input fields are global: periodic=True lets ESMF wrap the longitude
    # axis, otherwise bilinear weights are missing for target points that fall between
    # the last source longitude and the first one + 360 (a stripe along the seam).
    regridder = xe.Regridder(field, target_grid, 'bilinear', periodic=True)
    return regridder(field).to_dataset(name=var_name)



def jjas_mon_mean(ds, start='2005', end='2014'):
    """Monthly climatology of ``ds`` over the years ``start``..``end``.

    Despite "jjas" in the name, all twelve calendar months are returned (the
    result is saved as an annual cycle later in the notebook).

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a ``time`` coordinate that supports the ``time.month`` accessor.
    start, end : str, optional
        Bounds of the label-based time slice; the defaults reproduce the original
        2005-2014 window.

    Returns
    -------
    xarray.Dataset
        Time mean for each calendar month, indexed by a ``month`` dimension.
    """
    period = ds.sel({'time': slice(start, end)})
    return period.groupby('time.month').mean()

In [39]:
from toolz.functoolz import juxt
expt = expts_full[0]
print(expt)

# For each variable subset (pr, ua, va), regrid the chosen experiment of every model to
# the common 1-degree grid and reduce it to a monthly climatology. The results are
# materialised with dask.compute in the next cell.
dsets_aligned_list = []
for dset_ in dsets_:
    dsets_aligned = {}
    for k, v in tqdm(dset_.items()):
        # v maps experiment_id -> dataset for model k. Skip the model when the
        # experiment we need is absent (v[expt] would raise KeyError) or when any
        # experiment entry is None.
        if expt not in v or any(d is None for d in v.values()):
            print(f"Missing experiment for {k}")
            continue
        dsets_aligned[k] = v[expt].pipe(regrid).pipe(jjas_mon_mean)
    dsets_aligned_list.append(dsets_aligned)

historical


  0%|          | 0/25 [00:00<?, ?it/s]

  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)


  0%|          | 0/25 [00:00<?, ?it/s]

  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)


  0%|          | 0/25 [00:00<?, ?it/s]

  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)
  o = func(*args, **kwargs)


In [40]:
def _compute_with_progress(lazy):
    """Run ``dask.compute`` on one lazy object under a progress bar and return its result."""
    with progress.ProgressBar():
        return dask.compute(lazy)[0]


# One materialised dict of climatologies per variable subset.
dsets_aligned_list_1 = _compute_with_progress(dsets_aligned_list[0])
dsets_aligned_list_2 = _compute_with_progress(dsets_aligned_list[1])
dsets_aligned_list_3 = _compute_with_progress(dsets_aligned_list[2])

[########################################] | 100% Completed | 10.3s
[########################################] | 100% Completed | 48.4s
[########################################] | 100% Completed | 47.5s


In [56]:
# Order matters downstream: index 0 is precipitation, indices 1 and 2 are ua and va.
dsets_algned_list_ = list((dsets_aligned_list_1, dsets_aligned_list_2, dsets_aligned_list_3))
type(dsets_algned_list_[0])

dict

In [57]:
def _source_coord(ids):
    """Build the ``source_id`` index coordinate used to stack per-model datasets."""
    return xr.DataArray(ids, dims='source_id', name='source_id',
                        coords={'source_id': ids})


def _stack_sources(aligned, source_dim):
    """Concatenate the per-model datasets in ``aligned`` along ``source_dim``.

    Non-index coordinates are dropped from each dataset first
    (presumably so they cannot conflict during the concatenation).
    """
    return xr.concat([member.reset_coords(drop=True) for member in aligned.values()],
                     dim=source_dim)


# Model names per variable, in the same order as the dict values that get stacked.
source_ids = [list(aligned.keys()) for aligned in dsets_algned_list_]

# Entries 1 and 2 (wind components) are stacked into a list that is merged later.
big_ds_wind = [_stack_sources(aligned, _source_coord(ids))
               for aligned, ids in zip(dsets_algned_list_[1:], source_ids[1:])]

# Entry 0 is precipitation.
source_da = _source_coord(source_ids[0])
big_ds_pr = _stack_sources(dsets_algned_list_[0], source_da)

In [59]:
ds_all = xr.merge(big_ds_wind)
# Squeeze away length-1 dimensions (e.g. the plev axis left by open_dset) and write to
# the same file name that is read back at the end of the notebook; the previous name
# "wind_annual_cyle.nc" was misspelled and never read.
ds_all.squeeze().to_netcdf('/home/jovyan/pangeo/data/wind_annual_cycle.nc')

big_ds_pr.to_netcdf('/home/jovyan/pangeo/data/precip_annual_cycle.nc')


In [66]:
# Write to the file name that is read back later (the "wind_annual_cyle.nc" spelling was a typo).
ds_all.squeeze().to_netcdf('/home/jovyan/pangeo/data/wind_annual_cycle.nc')

In [75]:
# Drop length-1 dimensions before saving the wind annual cycle.
wind_squeezed = ds_all.squeeze()
wind_squeezed.to_netcdf('/home/jovyan/pangeo/data/wind_annual_cycle.nc')

In [76]:
# Re-open the saved wind annual cycle to check what was written.
wind_path = '/home/jovyan/pangeo/data/wind_annual_cycle.nc'
ds_wind = xr.open_dataset(wind_path)
ds_wind