In [None]:
# Linting - using black
# # %load_ext nb_black
# %load_ext lab_black

# Autoreload modules
%load_ext autoreload
%autoreload 2

In [None]:
pip install climpred[complete]

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

# Import third-party libraries
import xarray as xr
import climpred

xr.set_options(display_style="html")

# silence warnings if annoying
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Import tqdm
import tqdm

In [None]:
import intake

In [None]:
import intake_esm

In [None]:
# import dask_gateway

# # Create a connection to dask-gateway.
# gw = dask_gateway.Gateway("https://dask-gateway.jasmin.ac.uk", auth="jupyterhub")

# # Inspect and change the options if required before creating your cluster.
# options = gw.cluster_options()
# options.worker_cores = 2

# # Create a dask cluster, or, if one already exists, connect to it.
# # This stage creates the scheduler job in SLURM, so it may take some time
# # while your job queues.
# clusters = gw.list_clusters()
# if not clusters:
#     cluster = gw.new_cluster(options, shutdown_on_close=False)
# else:
#     cluster = gw.connect(clusters[0].name)

# # Create at least one worker, and allow your cluster to scale to three.
# cluster.adapt(minimum=1, maximum=3)

# # Get a dask client.
# client = cluster.get_client()

In [None]:
# Open the intake-esm catalog described by the pangeo-cmip6.json file hosted in
# the NCAR/intake-esm-datastore GitHub repository (requires network access).
col = intake.open_esm_datastore(
    "https://raw.githubusercontent.com/NCAR/intake-esm-datastore/master/catalogs/pangeo-cmip6.json"
)
# Bare expression: render the catalog summary as the cell output.
col

In [None]:
# Peek at the first rows of the catalog table to see which columns can be
# used as search facets.
col.df.head()

In [None]:
# Search facets: one model, one table, one variable, and the hindcast
# initialisation years 1961-2014 inclusive.
source_id = "CanESM5"
table_id = "Amon"
variable_id = "psl"
inits = list(np.arange(1961, 2015))

# Constrain col to experiment_id = ["dcppA-hindcast"] plus the facets above
cat_cmip = col.search(
    experiment_id=["dcppA-hindcast"],
    source_id=source_id,
    table_id=table_id,
    variable_id=variable_id,
    dcpp_init_year=inits,
)

# Store the initialisation year as an integer column on the filtered catalog.
cat_cmip.df["dcpp_init_year"] = cat_cmip.df["dcpp_init_year"].astype(int)

In [None]:
# Show the summary of the filtered catalog.
cat_cmip

In [None]:
# Sanity check: list the distinct source_id values left after the search.
print(cat_cmip.df["source_id"].unique())

In [None]:
# First rows of the filtered catalog table.
cat_cmip.df.head()

In [None]:
! pip install gcsfs

In [None]:
import gcsfs

In [None]:
# Preprocess the datasets
def preprocess(ds):
    """Put a dataset on a common time axis before the datasets are combined.

    Used as the ``preprocess`` hook of ``to_dataset_dict``. Delegates to
    ``climpred.shared.set_integer_time_axis`` on the ``"time"`` dimension
    (presumably replacing the dates with an integer index -- confirm against
    the climpred documentation) and returns whatever that helper returns.
    """
    return climpred.shared.set_integer_time_axis(ds, time_dim="time")


# Load the catalog entries into a dictionary of datasets, running
# `preprocess` on each one as it is opened.
dsets = cat_cmip.to_dataset_dict(
    preprocess=preprocess,
    zarr_kwargs={"use_cftime": True, "consolidated": True},
)
# Show the keys of the returned dictionary.
list(dsets)

In [None]:
# Pull the CanESM5 hindcast out of the dictionary, keep only the variable of
# interest and drop any length-1 dimensions.
hindcast = dsets["DCPP.CCCma.CanESM5.dcppA-hindcast.Amon.gn"][variable_id].squeeze()
hindcast

In [None]:
# Rename dimensions to be the same as the climpred dimensions
# (later cells rely on the names "init", "lead" and "member").
hindcast = climpred.shared.rename_to_climpred_dims(hindcast)

In [None]:
# Inspect the hindcast after the dimension rename.
hindcast

In [None]:
# Native grid of the hindcast: latitude and longitude coordinate values.
(hindcast["lat"].values, hindcast["lon"].values)

In [None]:
# Regrid the hindcast to 2.5x2.5 grid
# gridtype=lonlat
# xfirst=-180
# xinc=2.5
# xsize=144
# yfirst=-90
# yinc=2.5
# ysize=72
# NOTE(review): xesmf is imported here, several cells before the
# `pip install pangeo-xesmf` cell; in a fresh environment this import would
# fail -- confirm the intended cell order.
import xesmf as xe       # regridding

# NOTE(review): this target grid is regional (lat from 16 up to but not
# including 75, lon from 200 up to but not including 330) with 5.0-degree
# spacing, not the global 2.5-degree grid described above, and ds_out is
# reassigned by a later xe.util.grid_2d(...) cell -- confirm whether this
# definition is still needed.
ds_out = xr.Dataset(
    {
        "lat": (["lat"], np.arange(16, 75, 5.0), {"units": "degrees_north"}),
        "lon": (["lon"], np.arange(200, 330, 5.0), {"units": "degrees_east"}),
    }
)
ds_out

In [None]:
# Smooth to a 2.5 degree grid
# `ds` is another name for the same hindcast object, not a copy.
ds = hindcast

# Average latitude spacing: (max - min) / (number of points - 1).
(ds['lat'].max() - ds['lat'].min())/(ds['lat'].count()-1.)

In [None]:
# (Result of the previous cell: resolution of around 1.4 degrees lat.)
# Average longitude spacing, computed the same way.
(ds['lon'].max() - ds['lon'].min())/(ds['lon'].count()-1.)

In [None]:
# (Result of the previous cell: and around 1.4 degrees lon.)
# Display the hindcast for reference.
ds

In [None]:
# Build the target grid for regridding (the regridder itself is created in a
# later cell): bounds -180..180 and -90..90, both in 2.5-degree steps.
# This replaces any earlier definition of ds_out.
ds_out = xe.util.grid_2d(-180.0, 180.0, 2.5, -90.0, 90.0, 2.5)

In [None]:
# Latitudes of the target grid.
ds_out["lat"].values

In [None]:
# Longitudes of the target grid.
ds_out["lon"].values

In [None]:
# Latitudes of the source (hindcast) grid.
ds["lat"].values

In [None]:
# Longitudes of the source (hindcast) grid.
ds["lon"].values

In [None]:
! pip install pangeo-xesmf

In [None]:
import xesmf as xe

In [None]:
# Build the bilinear regridder from the hindcast grid (ds) to the target grid
# (ds_out). The dataset is global, so periodic=True is passed to avoid gaps
# along the central longitude.
regridder = xe.Regridder(ds, ds_out, method="bilinear", periodic=True)

In [None]:
# Smooth ds
# Spatially smooth the hindcast with climpred's xesmf-based helper using
# bilinear interpolation.
# NOTE(review): no target resolution or periodic flag is passed, so the
# library defaults apply -- confirm they match the intended 2.5-degree grid.
ds_smooth = climpred.smoothing.spatial_smoothing_xesmf(ds, method='bilinear')

In [None]:
# Set up arbitrary dates to get DJFM means
# Replace the lead coordinate with month-start ("MS") cftime dates, one per
# existing lead step, so lead steps can be selected by calendar month.
# NOTE(review): the start date is mid-month while freq is "MS"; confirm which
# month the first lead step ends up labelled with.
# NOTE(review): this assignment modifies hindcast in place, so `ds` (another
# name for the same object) changes as well.
hindcast["lead"] = xr.cftime_range(start="1960-10-16", freq="MS", periods=hindcast.lead.size)

hindcast

In [None]:
# Boolean mask over lead: True where the lead date falls in December, January,
# February or March. (No grouping or averaging happens here; the seasonal
# mean is taken in a later cell.)
hindcast_ = hindcast.lead.dt.month.isin([12, 1, 2, 3])

# Constrain hindcast to the DJFM lead steps only
hindcast_djfm = hindcast.sel(lead=hindcast_)

In [None]:
# Inspect the DJFM-only subset of the hindcast.
hindcast_djfm

In [None]:
# Shift the time back by 3 months (Dec->Sep, Jan->Oct, Feb->Nov, Mar->Dec) so
# that the four months of one winter share a calendar year, then average per
# year to get one DJFM mean per winter.
#
# DataArray.shift() moves the *data* and leaves the coordinate labels where
# they are, which drops the first three DJFM months and leaves the last
# annual bin empty. The time index itself has to be shifted instead.
# Expects `lead` to still hold the cftime dates assigned earlier.
hindcast_djfm = (
    hindcast_djfm
    .assign_coords(lead=hindcast_djfm.indexes["lead"].shift(-3, "MS"))
    .resample(lead="Y")
    .mean("lead")
)

In [None]:
# Inspect the lead coordinate produced by the annual resampling.
hindcast_djfm.lead

In [None]:
# Replace the resampled date labels on lead with consecutive integers
# 1..N (N = number of lead entries).
hindcast_djfm["lead"] = np.arange(1, hindcast_djfm.lead.size + 1)

# Record that the integer lead values are counted in years
hindcast_djfm.lead.attrs["units"] = "years"

In [None]:
# Inspect the hindcast with integer lead years.
hindcast_djfm

In [None]:
# Select the first lead time (label 1 on the integer lead axis); the lead
# dimension is dropped by this scalar selection.
hindcast_djfm_1 = hindcast_djfm.sel(lead=1)

In [None]:
# Climatology of the lead-1 hindcast, built in two averaging steps:
# 1) average across ensemble members ...
ensemble_mean = hindcast_djfm_1.mean(dim="member")

# 2) ... then average across initialisation dates.
climatology = ensemble_mean.mean(dim="init")

In [None]:
# Inspect the lead-1 climatology.
climatology

In [None]:
# Remove the climatology from the lead-1 hindcast to get lead-1 anomalies.
# The climatology was computed from hindcast_djfm_1 only, so it must be
# subtracted from that same lead-1 slice; subtracting it from hindcast_djfm
# would apply the lead-1 climatology to every lead.
hindcast_djfm_1_anom = hindcast_djfm_1 - climatology

In [None]:
# Inspect the resulting anomalies.
hindcast_djfm_1_anom