# Step 1: Pre-processing model and reanalysis data

---

## Instructions for activating the Jupyter kernel for the `cmip6hack-multigen` conda environment

In a JupyterLab terminal, navigate to the `/cmip6hack-multigen/` folder and run the command:
```bash
source spinup_env.sh
```
This creates the `cmip6hack-multigen` conda environment and registers it as a Python kernel for Jupyter.

Then switch the kernel to `cmip6hack-multigen` (drop-down menu in the top right-hand corner) and restart the notebook.

### Pre-process climate model output in GCS

This notebook uses [`intake-esm`](https://intake-esm.readthedocs.io/en/latest/) to ingest and organize climate model output from several model generations and save their time-mean fields locally. It also pre-processes the ERA5 reanalysis used as the observational reference.

In [4]:
# Automatically reload edited modules before each cell executes, so changes to the
# local helper modules imported below (util, preprocess, qc) are picked up without
# restarting the kernel. Running this cell a second time only prints a harmless
# "already loaded" notice.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [22]:
import os
import sys
import numpy as np
import pandas as pd
import xarray as xr
import xskillscore as xs
import xesmf as xe
from tqdm.autonotebook import tqdm  # Fancy progress bars for our loops!
import intake
# util.py is in the local directory
# it contains code that is common across project notebooks
# or routines that are too extensive and might otherwise clutter
# the notebook design
import util
import preprocess as pp
import qc

In [23]:
# Settings shared by every processing step below.
varnames = "tas pr psl".split()   # CMIP `variable_id`s to process

start_year, end_year = '1981', '2010'
timeslice = slice(start_year, end_year)   # period selected with .sel(time=...)

coarsen_size = 2   # lat/lon factor passed to .coarsen()

In [24]:
# Model-output collections keyed by generation/activity id (the loop below logs ids
# such as 'far', 'sar', 'tar', 'cmip3'); each value supports intake-esm `.search`.
col_dict = pp.get_ipcc_collection()

In [None]:
# Prototype of the ensemble-loading pipeline. NOTE(review): the cells further down
# build `ens_dict` via `pp.load_ensembles`; the `ds_dict` built here is only
# inspected, never saved.


def _snap_time_to_month_start(ds):
    """Decode CF metadata and move every timestamp to the first day of its month.

    Decoding happens here (rather than via ``zarr_kwargs={'decode_times': True}``)
    because that option was broken when this notebook was written.
    Returns a new Dataset.
    """
    ds = xr.decode_cf(ds)
    time = ds['time'].astype('<M8[ns]')
    first_of_month = np.array(
        pd.to_datetime(
            util.vec_dt_replace(pd.Series(time.values), day=1.)
        )
    )
    # Newer xarray rejects assignment to ``.values`` of a dimension coordinate,
    # so build a replacement coordinate (keeping the original attrs/encoding).
    return ds.assign_coords(
        time=('time', first_of_month, ds['time'].attrs, ds['time'].encoding)
    )


def _preprocess_member(ds, key, varname, mip_id):
    """Turn one raw catalog dataset into a regridded, coarsened DataArray.

    Parameters
    ----------
    ds : xarray.Dataset
        One entry of ``cat.to_dataset_dict(aggregate=False)``.
    key : str
        Its catalog key. Dot-separated fields are read from it:
        field 0 -> ``mip_id`` coordinate, fields 1-2 -> ``attrs['name']``,
        field 2 -> ``source_id`` coordinate, field 4 -> ``member_id``.
    varname : str
        Variable to extract from ``ds``.
    mip_id : str
        Generation key from ``col_dict``; stored in ``attrs['mip_id']``.

    Returns
    -------
    xarray.DataArray or None
        ``None`` (after printing the reason) when the member is skipped.

    Uses the notebook-level ``timeslice`` and ``coarsen_size`` settings.
    """
    parts = key.split(".")

    # Harmonise spatial dimension names across model generations.
    if ('longitude' in ds.dims) and ('latitude' in ds.dims):
        ds = ds.rename({'longitude': 'lon', 'latitude': 'lat'})

    ds = _snap_time_to_month_start(ds)

    # Skip members whose month-snapped time axis contains duplicate timestamps.
    repeats = len(ds['time']) - len(np.unique(ds['time']))
    if repeats != 0:
        print(f"Skip {key} due to datetime conflict.")
        return None

    ds = ds.squeeze()  # get rid of member_id (for now)
    ds = ds.chunk({'lat': ds['lat'].size, 'lon': ds['lon'].size, 'time': 30})

    if timeslice is not None:
        ds = ds.sel(time=timeslice)

    # Silence the regridder's chatter, but report failures *outside* the
    # HiddenPrints context (which presumably suppresses stdout) so the skip
    # message is visible. Catch Exception, not everything, so Ctrl-C still works.
    try:
        with util.HiddenPrints():
            da = util.regrid_to_common(ds[varname])
    except Exception as err:
        print(f"Skip {key} due to regridding conflict: {err!r}")
        return None

    da.attrs.update(ds.attrs)
    da = qc.quality_control(da, varname, key, mip_id)
    da.attrs['name'] = "-".join(parts[1:3])

    # Keep only the lat/lon/time coordinates.
    extra_coords = [c for c in da.coords if c not in ('lat', 'lon', 'time')]
    da = da.drop_vars(extra_coords)

    member_id = parts[4]
    da = da.expand_dims(
        {'ensemble': np.array([da.attrs['name'] + "-" + member_id])}, 0
    )
    da = da.assign_coords({
        'member_id': member_id,
        'source_id': parts[2],
        'mip_id': parts[0],
    })
    da.attrs['mip_id'] = mip_id

    return da.coarsen({'lat': coarsen_size, 'lon': coarsen_size}, boundary='exact').mean()


ds_dict = {}
for mip_id, col in tqdm(col_dict.items()):
    ds_dict[mip_id] = {}
    for varname in varnames:
        print("Loaded: variable_id `", varname, "` from activity_id `",mip_id,"`")
        cat = col.search(
            experiment_id='historical',
            variable_id=varname,
            table_id='Amon'
        )
        if cat.df.size == 0:
            continue

        with util.HiddenPrints():
            dset_dict = cat.to_dataset_dict(
                aggregate=False,
                zarr_kwargs={'consolidated': True, 'decode_times': False}
            )

        ds_dict[mip_id][varname] = {}
        for key, ds in dset_dict.items():
            ds_new = _preprocess_member(ds, key, varname, mip_id)
            if ds_new is not None:
                ds_dict[mip_id][varname][key] = ds_new

HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))

Loaded: variable_id ` tas ` from activity_id ` far `
Loaded: variable_id ` pr ` from activity_id ` far `
Loaded: variable_id ` psl ` from activity_id ` far `
Loaded: variable_id ` tas ` from activity_id ` sar `
Loaded: variable_id ` pr ` from activity_id ` sar `
Loaded: variable_id ` psl ` from activity_id ` sar `
Loaded: variable_id ` tas ` from activity_id ` tar `
Skip TAR.MPIfM.MPIfM.historical.r1i1p1f1.Amon.tas.gn before datetime conflict.
Loaded: variable_id ` pr ` from activity_id ` tar `
Skip TAR.MPIfM.MPIfM.historical.r1i1p1f1.Amon.pr.gn before datetime conflict.
Loaded: variable_id ` psl ` from activity_id ` tar `
Skip TAR.MPIfM.MPIfM.historical.r1i1p1f1.Amon.psl.gn before datetime conflict.
Loaded: variable_id ` tas ` from activity_id ` cmip3 `
Skip CMIP3.CSIRO-QCCCE.csiro_mk3_5.historical.r1i1p1f1.Amon.tas.gn before datetime conflict.
Loaded: variable_id ` pr ` from activity_id ` cmip3 `
Skip CMIP3.CSIRO-QCCCE.csiro_mk3_5.historical.r1i1p1f1.Amon.pr.gn before datetime confli

In [15]:
# Spot-check one processed CMIP6 member from the prototype loop.
# NOTE(review): raises KeyError if this member was skipped above.
example_key = 'CMIP.BCC.BCC-CSM2-MR.historical.r1i1p1f1.Amon.tas.gn'
ds_dict['cmip6']['tas'][example_key]

In [4]:
# Build the per-generation ensembles used by every later cell. Its log output matches
# the prototype loop above; presumably the same pipeline packaged in `preprocess`.
# NOTE(review): `coarsen_size` is not passed here -- confirm load_ensembles coarsens
# by the same factor as the ERA5 load below.
ens_dict = pp.load_ensembles(varnames, timeslice=timeslice)

HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))

Loaded: variable_id ` tas ` from activity_id ` far `
Loaded: variable_id ` pr ` from activity_id ` far `
Loaded: variable_id ` psl ` from activity_id ` far `
Loaded: variable_id ` tas ` from activity_id ` sar `
Loaded: variable_id ` pr ` from activity_id ` sar `
Loaded: variable_id ` psl ` from activity_id ` sar `
Loaded: variable_id ` tas ` from activity_id ` tar `
Skip TAR.MPIfM.MPIfM.historical.Amon.gn before datetime conflict.
Loaded: variable_id ` pr ` from activity_id ` tar `
Skip TAR.MPIfM.MPIfM.historical.Amon.gn before datetime conflict.
Loaded: variable_id ` psl ` from activity_id ` tar `
Skip TAR.MPIfM.MPIfM.historical.Amon.gn before datetime conflict.
Loaded: variable_id ` tas ` from activity_id ` cmip3 `
Skip CMIP3.CSIRO-QCCCE.csiro_mk3_5.historical.Amon.gn before datetime conflict.
Loaded: variable_id ` pr ` from activity_id ` cmip3 `
Skip CMIP3.CSIRO-QCCCE.csiro_mk3_5.historical.Amon.gn before datetime conflict.
Loaded: variable_id ` psl ` from activity_id ` cmip3 `
Load

  **blockwise_kwargs,


#### Extracting the time mean

In [5]:
# Collapse the time axis: presumably applies xr.Dataset.mean(dim=['time'],
# keep_attrs=True) to every dataset in `ens_dict`.
# NOTE(review): this overwrites `ens_dict`, so re-running the cell alone would try
# to average over a 'time' dimension that is already gone.
ens_dict = util.dict_func(
    ens_dict,
    xr.Dataset.mean,
    on_self=True,
    dim=['time'],
    keep_attrs=True,
)

HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))




In [6]:
# Force evaluation of the (presumably dask-backed) lazy time means so later
# cells work on in-memory data.
ens_dict = util.dict_func(
    ens_dict,
    xr.Dataset.compute,
    on_self=True,
)

HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))

  x = np.divide(x1, x2, out)
  x = np.divide(x1, x2, out)
  x = np.divide(x1, x2, out)





#### Adding the ensemble mean to the ensemble

In [7]:
# Add the ensemble mean to each ensemble in `ens_dict` (see util.add_ens_mean).
# NOTE(review): overwrites `ens_dict`; re-running this cell alone would presumably
# add a second ensemble-mean entry -- re-run from the loading cell instead.
ens_dict = util.add_ens_mean(ens_dict)

### Pre-process observational data products

In [8]:
# ERA5 monthly reanalysis, cut to the same `timeslice` and coarsened with the same
# `coarsen_size` setting as the model processing (previously a hard-coded 2 that
# would silently diverge if the config cell changed).
era5 = pp.load_era(
    "../data/raw/reanalysis/ERA5_mon_2d.nc",
    timeslice=timeslice,
    coarsen_size=coarsen_size,
)

### Save interim files

In [12]:
# Persist the processed ERA5 fields, overwriting any previous store.
# `mode` is passed by keyword: newer xarray takes `chunk_store` as the second
# positional argument of `to_zarr`, so a positional "w" would be misread.
# The trailing `;` suppresses the ZarrStore repr.
interim_path = "../data/interim/"
era5.to_zarr(os.path.join(interim_path, "era5_timemean"), mode="w");

<xarray.backends.zarr.ZarrStore at 0x7f794f021950>

In [13]:
# Write each generation's time-mean ensemble to its own Zarr store, overwriting.
# `mode` is passed by keyword because newer xarray takes `chunk_store` as the
# second positional argument of `to_zarr`.
for key, ens in ens_dict.items():
    ens.to_zarr(os.path.join(interim_path, f"{key}_timemean"), mode="w")