# Future scenario S/N ratio analysis

In [None]:
!mamba install -y --file /home/jovyan/pangeo/code/requirements.txt

In [None]:
from matplotlib import pyplot as plt
import xarray as xr
import numpy as np
import dask
from dask.diagnostics import progress
from tqdm.autonotebook import tqdm 
import intake
import fsspec
import seaborn as sns
#import esmvalcore.preprocessor as ecpr
#import pymannkendall as mkt
import pandas as pd

%matplotlib inline

In [None]:
# Open the Pangeo CMIP6 catalogue from its JSON description.
col = intake.open_esm_datastore(
    "https://storage.googleapis.com/cmip6/pangeo-cmip6.json"
)
col

# Experiments used throughout the notebook.
# there is currently a significant amount of data for these runs
expts_full = [
    'historical',
    'ssp126',
    'ssp245',
    'ssp370',
    'ssp585',
    'piControl',
]

# Search facets passed to the catalogue. An 'activity_id' facet
# (e.g. 'DAMIP') is deliberately left out.
query = {
    'experiment_id': expts_full,
    'table_id': 'Amon',
    'variable_id': ['tas', 'pr', 'ua', 'va'],
    'member_id': 'r1i1p1f1',
}

# require_all_on='source_id' asks intake-esm to keep only the models that
# satisfy the complete query.
col_subset = col.search(require_all_on='source_id', **query)

# One sub-catalogue per variable, in the order of query['variable_id'].
col_subset_var = [
    col_subset.search(variable_id=variable)
    for variable in query['variable_id']
]
print(f'Number of models found: {col_subset.df.source_id.nunique()}')

In [None]:
# Remove every catalogue row that belongs to the FGOALS-f3-L model.
# NOTE(review): col_subset_var is built from col_subset before this filter
# runs, so the per-variable catalogues presumably still contain FGOALS-f3-L —
# confirm whether they need to be rebuilt after this cell.
col_subset.df = col_subset.df[col_subset.df['source_id'] != 'FGOALS-f3-L']
# Bare expression so the notebook displays the filtered catalogue.
col_subset

In [None]:
def drop_all_bounds(ds):
    """Return ``ds`` without its cell-boundary coordinates.

    Every coordinate whose name contains ``_bounds`` or ``_bnds`` is dropped.
    Only ``ds.coords`` is inspected; data variables are left untouched.

    Parameters
    ----------
    ds : object exposing ``coords`` and ``drop_vars`` (an ``xarray.Dataset``)

    Returns
    -------
    The result of ``ds.drop_vars`` for the matching coordinate names.
    """
    bound_names = [
        name for name in ds.coords
        if '_bounds' in name or '_bnds' in name
    ]
    # ``Dataset.drop`` is deprecated for dropping variables; ``drop_vars`` is
    # the supported spelling and does the same thing here.
    return ds.drop_vars(bound_names)

def open_dset(df):
    """Open the first zarr store listed in ``df`` and prepare it for analysis.

    The store path is taken from ``df.zstore.values[0]``; any further rows of
    ``df`` are ignored. Times are decoded with cftime. When the dataset has a
    ``plev`` coordinate, only the level whose value truncates to 85000 is
    kept (presumably 850 hPa expressed in Pa — confirm) and the ``plev``
    coordinate variable is dropped. Boundary coordinates are removed through
    ``drop_all_bounds``.

    Parameters
    ----------
    df : table with a ``zstore`` column (e.g. a group of catalogue rows)

    Returns
    -------
    The opened dataset with the level selection and bounds removal applied.

    Raises
    ------
    ValueError
        If the dataset has a ``plev`` coordinate but no level equal to 85000.
    """
    store = fsspec.get_mapper(df.zstore.values[0])
    ds = xr.open_zarr(store, consolidated=True, decode_times=True, use_cftime=True)
    if 'plev' in ds.coords:
        # Same criterion as int(level) == 85000, applied to all levels at once.
        matches = np.flatnonzero(np.trunc(ds.plev.values) == 85000)
        if matches.size == 0:
            # Previously this case surfaced as an unbound-variable error.
            raise ValueError("Dataset has a 'plev' coordinate but no 85000 level")
        # A length-1 array indexer is used on purpose (as before), so the
        # plev dimension is presumably retained with size 1.
        ds = ds.isel(plev=matches[:1]).drop_vars('plev')
    return drop_all_bounds(ds)

def open_delayed(df):
    """Return a dask ``delayed`` call of ``open_dset(df)``.

    Nothing is opened here; the dataset is only loaded when the returned
    object is computed.
    """
    lazy_open = dask.delayed(open_dset)
    return lazy_open(df)

from collections import defaultdict

# dsets[i][source_id][experiment_id] -> delayed dataset, where i follows the
# order of the per-variable catalogues in col_subset_var.
dsets = []
# The loop variable is deliberately NOT called ``col_subset``: reusing that
# name would overwrite the module-level catalogue with the last per-variable
# sub-catalogue and break any cell that is (re)run afterwards.
for var_subset in col_subset_var:
    dset = defaultdict(dict)

    grouped = var_subset.df.groupby(by=['source_id', 'experiment_id'])
    for (model, experiment), df in grouped:
        dset[model][experiment] = open_delayed(df)
    dsets.append(dset)
# Displayed as the cell output: one entry per queried variable.
len(dsets)

In [None]:
# Evaluate the delayed datasets of the second queried variable (index 1 of
# dsets) while showing a dask progress bar. dask.compute returns a tuple with
# one result per argument, hence the single-element unpacking.
with progress.ProgressBar():
    (dsets_,) = dask.compute(dict(dsets[1]))

In [None]:
# Let intake-esm build a dictionary of datasets from the second per-variable
# catalogue, using the same zarr options as open_dset.
zarr_open_options = dict(consolidated=True, decode_times=True, use_cftime=True)
dset_dict = col_subset_var[1].to_dataset_dict(zarr_kwargs=zarr_open_options)

In [None]:
# Keys of dset_dict as a list, then display the dataset stored under the
# eleventh key.
ss = list(dset_dict)
dset_dict[ss[10]]

In [None]:
import pymannkendall as mkt
import esmvalcore.preprocessor as ecpr
import dask.array as da
import iris
import itertools
import xesmf as xe
def get_vname(ds):
    """Return the name of the first recognised variable in ``ds``.

    ``ds.variables`` is scanned in its own order and the first name that is
    one of ``tas``, ``pr``, ``ua`` or ``va`` is returned. ``tas`` is included
    because the catalogue query in this notebook requests it alongside the
    other three.

    Raises
    ------
    RuntimeError
        If none of the recognised names is present.
    """
    known_names = ('tas', 'pr', 'ua', 'va')
    for v_name in ds.variables.keys():
        if v_name in known_names:
            return v_name
    raise RuntimeError("Couldn't find a variable")
            
def get_lat_name(ds):
    """Return the latitude coordinate name used by ``ds``.

    ``lat`` is preferred over ``latitude`` when both exist.

    Raises
    ------
    RuntimeError
        If ``ds.coords`` contains neither name.
    """
    candidates = ('lat', 'latitude')
    found = next((name for name in candidates if name in ds.coords), None)
    if found is None:
        raise RuntimeError("Couldn't find a latitude coordinate")
    return found
    
def get_lon_name(ds):
    """Return the longitude coordinate name used by ``ds``.

    ``lon`` is preferred over ``longitude`` when both exist.

    Raises
    ------
    RuntimeError
        If ``ds.coords`` contains neither name.
    """
    candidates = ('lon', 'longitude')
    found = next((name for name in candidates if name in ds.coords), None)
    if found is None:
        raise RuntimeError("Couldn't find a longitude coordinate")
    return found

def regrid(ds):
    """Bilinearly regrid ``ds`` onto a regular 1-degree lat/lon grid.

    The target grid has latitudes -90..89 and longitudes -180..179 in steps
    of one degree. A new xESMF regridder is built on every call.

    Parameters
    ----------
    ds : dataset accepted by ``xesmf.Regridder`` as the source grid

    Returns
    -------
    The regridded dataset produced by calling the regridder on ``ds``.

    Raises
    ------
    RuntimeError
        From ``get_vname`` when ``ds`` holds no recognised variable.
    """
    # Validation only: fail early if ds does not carry an expected variable.
    get_vname(ds)

    target_grid = xr.Dataset({
        "lat": (["lat"], np.arange(-90, 90, 1.0)),
        "lon": (["lon"], np.arange(-180, 180, 1.0)),
    })
    # periodic=True makes the bilinear interpolation wrap around in
    # longitude; without it, target cells between the last and first source
    # longitudes are left unmapped. Assumes ds is a global grid — the
    # notebook feeds it CMIP6 'Amon' output.
    regridder = xe.Regridder(ds, target_grid, 'bilinear', periodic=True)
    return regridder(ds)


def jjas_mean(ds):
    """Return the yearly June-September mean of the dataset's variable.

    The variable is looked up with ``get_vname``. Time steps whose month is
    6, 7, 8 or 9 are selected along the ``time`` dimension, grouped by
    ``time.year`` and averaged. The result is a dataset holding that single
    variable, indexed by ``year`` instead of ``time``.

    Parameters
    ----------
    ds : dataset with a ``time`` coordinate that supports ``time.month`` and
        ``time.year``

    Raises
    ------
    RuntimeError
        From ``get_vname`` when ``ds`` holds no recognised variable.
    """
    var_name = get_vname(ds)

    # NOTE(review): no spatial subsetting is applied — the result covers the
    # full grid of ds. An earlier lat -40..40 / lon 5..120 selection only fed
    # the month lookup and never reached the output; confirm whether a
    # regional mean is actually wanted.
    jjas_months = (6, 7, 8, 9)
    in_season = np.isin(ds['time.month'].values, jjas_months)

    # Index explicitly along 'time' instead of relying on it being the
    # leading axis of the array.
    seasonal = ds[var_name].isel(time=np.flatnonzero(in_season))
    return seasonal.groupby('time.year').mean().to_dataset(name=var_name)

In [None]:
from toolz.functoolz import juxt

# Coordinate used to label the new experiment axis created by xr.concat.
expt_da = xr.DataArray(expts_full, dims='experiment_id', name='experiment_id',
                       coords={'experiment_id': expts_full})

dsets_aligned_list = []

# dsets_ holds the datasets computed from dsets[1]
# (original note: selecting variable 'pr' only).
dset_ = dsets_

# source_id -> dataset of yearly JJAS means with an 'experiment_id' axis.
dsets_aligned = {}
for k, v in tqdm(dset_.items()):
    print(k)
    # Skip the model unless every experiment in expts_full has a dataset.
    # Checking by key also covers experiments that are absent from the
    # mapping, which would otherwise raise KeyError in the loop below.
    if any(v.get(expt) is None for expt in expts_full):
        print(f"Missing experiment for {k}")
        continue

    # workaround for
    # https://github.com/pydata/xarray/issues/2237#issuecomment-620961663
    dsets_jjas_mean = []
    for expt in expts_full:
        ds = v[expt].pipe(regrid).pipe(jjas_mean)
        if expt == 'historical':
            # Only the years 1950-2014 of the historical run are kept.
            ds = ds.sel(year=slice(1950, 2014))
        dsets_jjas_mean.append(ds)

    # Stack the experiments along 'experiment_id'; join='outer' keeps the
    # union of the other coordinates (e.g. the years) across experiments.
    dsets_aligned[k] = xr.concat(dsets_jjas_mean, join='outer',
                                 dim=expt_da)
dsets_aligned_list.append(dsets_aligned)

In [None]:
# Evaluate the aligned per-model datasets under a dask progress bar.
# dask.compute returns one result per argument, hence the tuple unpacking.
with progress.ProgressBar():
    (dsets_aligned_list_,) = dask.compute(dsets_aligned_list[0])

In [None]:
# First five model names of the aligned mapping, shown as the cell output.
keys = list(dsets_aligned_list[0])[:5]
keys

In [None]:
# Reduced mapping holding only the first two models of the aligned datasets.
dsets_cr = {
    name: dsets_aligned_list[0][name]
    for name in list(dsets_aligned_list[0])[:2]
}

In [None]:
# Evaluate only the reduced two-model mapping; this rebinds
# dsets_aligned_list_ to the smaller result.
with progress.ProgressBar():
    (dsets_aligned_list_,) = dask.compute(dsets_cr)

In [None]:
# Model names of the computed datasets, shown as the cell output.
source_ids = list(dsets_aligned_list_)
source_ids

In [None]:
# Coordinate labelling the new model axis.
source_da = xr.DataArray(
    source_ids,
    dims='source_id',
    name='source_id',
    coords={'source_id': source_ids},
)
# Drop non-index coordinates from each model's dataset, then stack all
# models along 'source_id'.
per_model = [
    ds.reset_coords(drop=True) for ds in dsets_aligned_list_.values()
]
big_ds = xr.concat(per_model, dim=source_da)