# Developing scripts for data preparation

In [1]:
%load_ext autoreload
%autoreload 2
%load_ext lab_black

In [199]:
import dask
import yaml
import glob
import warnings
import xarray as xr
from src import utils

from functools import reduce, partial

In [140]:
def _load_config(name):
    """
    Read and return the contents of the config file ``<name>.yaml``

    The file name is resolved relative to the current working directory and
    parsed with ``yaml.BaseLoader``, so every scalar value is returned as a
    string.
    """
    config_file = f"{name}.yaml"
    with open(config_file) as stream:
        config = yaml.load(stream, Loader=yaml.BaseLoader)
    return config


def _maybe_translate_variables(variables, translation_dict):
    """
    Translate variables using provided dictionary where possible

    Parameters
    ----------
    variables : iterable of str
        Variable names to translate.
    translation_dict : dict
        Mapping from requested name to replacement name. In this module it is
        a config's ``rename`` section, whose keys are the output names and
        whose values are the names stored in the raw data (see
        ``_maybe_rename``).

    Returns
    -------
    list
        ``variables`` with each entry that is a key of ``translation_dict``
        replaced by its value; all other entries are passed through unchanged.
    """
    return [translation_dict.get(v, v) for v in variables]


def _maybe_rename(ds, rename):
    """
    Rename entries of ``ds`` according to ``rename``

    ``rename`` maps new name -> current name. Pairs whose current name is not
    found in ``ds`` are skipped. Pairs are applied one at a time, in the
    iteration order of ``rename``, and the renamed object is returned.
    """
    for new_name, old_name in rename.items():
        if old_name not in ds:
            continue
        ds = ds.rename({old_name: new_name})
    return ds


def _normalise(ds, norm_dict):
    """
    Rescale variables in a dataset according to provided dictionary

    Parameters
    ----------
    ds : xarray.Dataset (or any object supporting ``in``, item get and set)
        Dataset whose variables are rescaled. Matching variables are
        reassigned on ``ds`` itself, which is also returned.
    norm_dict : dict
        Mapping from variable name to multiplicative scale factor. Factors may
        be numbers or numeric strings. Names not present in ``ds`` are
        ignored.

    Returns
    -------
    ``ds`` with each matching variable multiplied by its factor.
    """
    for name, factor in norm_dict.items():
        if name in ds:
            # Configs are parsed with yaml.BaseLoader, which returns every
            # scalar as a string, so convert before using it in arithmetic.
            ds[name] = ds[name] * float(factor)
    return ds


def _composite_function(function_dict):
    """
    Build one callable that applies, in order, every function listed in a
    processing step of a config .yaml

    Each key of ``function_dict`` names a function in ``src.utils``. Its value
    holds the keyword arguments for that function, where ``""`` means none.
    The returned callable passes its argument through the functions in the
    dict's iteration order. An empty dict yields the identity function.
    """
    steps = [
        partial(getattr(utils, name), **({} if kwargs == "" else kwargs))
        for name, kwargs in function_dict.items()
    ]

    def apply_all(x):
        for step in steps:
            x = step(x)
        return x

    return apply_all


def JRA55(realm, variables):
    """
    Open JRA55 data following specifications in JRA55.yaml

    Parameters
    ----------
    realm : str
        Name of the ``<realm>.zarr.zip`` store under the configured path.
    variables : str or list of str
        Variable(s) to return, using output (renamed) names.

    Returns
    -------
    Dataset after the config's rename, normalise and postprocess steps (each
    applied only when present in JRA55.yaml).
    """
    cfg = _load_config("JRA55")
    var_list = [variables] if isinstance(variables, str) else variables

    if "rename" in cfg:
        var_list = _maybe_translate_variables(var_list, cfg["rename"])

    if "preprocess" in cfg:
        # A single store is opened, so there is no per-file hook to use it in
        warnings.warn(
            "preprocess functions were provided but not used because the data does not require concatenation"
        )

    store = f"{cfg['path']}/{realm}.zarr.zip"
    ds = xr.open_dataset(store, engine="zarr", chunks={}, use_cftime=True)
    ds = ds[var_list]

    for key, step in (("rename", _maybe_rename), ("normalise", _normalise)):
        if key in cfg:
            ds = step(ds, cfg[key])

    if "postprocess" in cfg:
        postprocess = _composite_function(cfg["postprocess"])
        ds = postprocess(ds)

    return ds


def HadISST(variables):
    """
    Open HadISST data following specifications in HadISST.yaml

    Reads ``<path>/ocean_month.zarr`` and sets every value that is not
    greater than -1000 to NaN. Then applies the config's rename, normalise and
    postprocess steps, each only when present.

    Parameters
    ----------
    variables : str or list of str
        Variable(s) to return, using output (renamed) names.
    """
    cfg = _load_config("HadISST")
    var_list = [variables] if isinstance(variables, str) else variables

    if "rename" in cfg:
        var_list = _maybe_translate_variables(var_list, cfg["rename"])

    if "preprocess" in cfg:
        # A single store is opened, so there is no per-file hook to use it in
        warnings.warn(
            "preprocess functions were provided but not used because the data does not require concatenation"
        )

    store = f"{cfg['path']}/ocean_month.zarr"
    raw = xr.open_dataset(store, engine="zarr", chunks={}, use_cftime=True)[var_list]
    # Values <= -1000 become NaN (presumably HadISST sentinel fill values - confirm)
    ds = raw.where(raw > -1000)

    for key, step in (("rename", _maybe_rename), ("normalise", _normalise)):
        if key in cfg:
            ds = step(ds, cfg[key])

    if "postprocess" in cfg:
        postprocess = _composite_function(cfg["postprocess"])
        ds = postprocess(ds)

    return ds


def EN422(variables):
    """
    Open EN.4.2.2 data following specifications in EN422.yaml

    Every ``*.nc`` file under the config's ``path`` is opened and combined
    with ``xr.open_mfdataset``.

    Parameters
    ----------
    variables : str or list of str
        Variable(s) to return, using output (renamed) names.

    Returns
    -------
    Dataset after the config's rename, normalise and postprocess steps (each
    applied only when present in EN422.yaml).
    """

    cfg = _load_config("EN422")

    if isinstance(variables, str):
        variables = [variables]

    if "rename" in cfg:
        variables = _maybe_translate_variables(variables, cfg["rename"])

    if "preprocess" in cfg:
        # NOTE(review): open_mfdataset below does combine several files, so a
        # preprocess hook could be passed through; it is currently ignored.
        warnings.warn(
            "preprocess functions were provided but not used because the data does not require concatenation"
        )

    # The data location comes from EN422.yaml, as for every other reader
    ds = xr.open_mfdataset(
        f"{cfg['path']}/*.nc",
        parallel=True,
        use_cftime=True,
    )[variables]

    if "rename" in cfg:
        ds = _maybe_rename(ds, cfg["rename"])

    if "normalise" in cfg:
        ds = _normalise(ds, cfg["normalise"])

    if "postprocess" in cfg:
        ds = _composite_function(cfg["postprocess"])(ds)

    return ds


def CAFEf6(realm, variables):
    """
    Open CAFEf6 forecast data following specifications in CAFEf6.yaml

    All ``c5-d60-pX-f6-????1101`` start directories under the configured path
    are combined with ``xr.open_mfdataset``. When the config has a
    ``preprocess`` section, it is applied to each file as it is opened.

    Parameters
    ----------
    realm : str
        Name of the ``<realm>.zarr.zip`` store inside each start directory.
    variables : str or list of str
        Variable(s) to return, using output (renamed) names.
    """
    cfg = _load_config("CAFEf6")
    var_list = [variables] if isinstance(variables, str) else variables

    if "rename" in cfg:
        var_list = _maybe_translate_variables(var_list, cfg["rename"])

    preprocess = (
        _composite_function(cfg["preprocess"]) if "preprocess" in cfg else None
    )

    # Only ????1101 start dates are globbed, which skips the May starts
    pattern = f"{cfg['path']}/c5-d60-pX-f6-????1101/{realm}.zarr.zip"
    files = sorted(glob.glob(pattern))

    combined = xr.open_mfdataset(
        files,
        compat="override",
        preprocess=preprocess,
        engine="zarr",
        coords="minimal",
        parallel=True,
    )
    ds = combined[var_list]

    for key, step in (("rename", _maybe_rename), ("normalise", _normalise)):
        if key in cfg:
            ds = step(ds, cfg[key])

    if "postprocess" in cfg:
        postprocess = _composite_function(cfg["postprocess"])
        ds = postprocess(ds)

    return ds


def CAFEf5(realm, variables):
    """
    Open CAFE-f5 forecast data following specifications in CAFEf5.yaml

    Opens ``<path>/NOV/<realm>.zarr.zip`` from the CAFE-f5 archive. It then
    appends the ``c5-d60-pX-f6-20201101`` run (the 2020 forecast) from the
    CAFE-f6 archive, whose path is read from CAFEf6.yaml. That run is cut to
    its first 10 ensemble members and converted with
    ``utils.convert_time_to_lead``.

    CAFEf5.yaml's rename, normalise and postprocess steps (each only when
    present) are applied to both parts separately. The parts are then
    concatenated along ``init``.

    Parameters
    ----------
    realm : str
        Name of the ``<realm>.zarr.zip`` store to open in both archives.
    variables : str or list of str
        Variable(s) to return, using output (renamed) names.
    """

    cfg = _load_config("CAFEf5")

    if isinstance(variables, str):
        variables = [variables]

    # NOTE(review): the translated names are also used to select from the
    # CAFE-f6 run, which assumes both archives store the same raw names.
    if "rename" in cfg:
        variables = _maybe_translate_variables(variables, cfg["rename"])

    if "preprocess" in cfg:
        warnings.warn(
            "preprocess functions were provided but not used because the data does not require concatenation"
        )

    ds = xr.open_dataset(
        f"{cfg['path']}/NOV/{realm}.zarr.zip", engine="zarr", chunks={}
    )[variables]

    # Append 2020 forecast from CAFE-f6
    cfg_f6 = _load_config("CAFEf6")

    ds_2020 = xr.open_dataset(
        f"{cfg_f6['path']}/c5-d60-pX-f6-20201101/{realm}.zarr.zip",
        engine="zarr",
        chunks={},
    )[variables]
    # Keep the first 10 members (presumably to match the CAFE-f5 ensemble size - confirm)
    ds_2020 = ds_2020.isel(ensemble=range(10))
    # assumes this yields an "init" dimension compatible with ds - confirm in utils
    ds_2020 = utils.convert_time_to_lead(ds_2020)

    if "rename" in cfg:
        ds = _maybe_rename(ds, cfg["rename"])
        ds_2020 = _maybe_rename(ds_2020, cfg["rename"])

    if "normalise" in cfg:
        ds = _normalise(ds, cfg["normalise"])
        ds_2020 = _normalise(ds_2020, cfg["normalise"])

    if "postprocess" in cfg:
        ds = _composite_function(cfg["postprocess"])(ds)
        ds_2020 = _composite_function(cfg["postprocess"])(ds_2020)

    return xr.concat([ds, ds_2020], dim="init")


def CAFE60v1(realm, variables):
    """
    Open CAFE60v1 data following specifications in CAFE60v1.yaml

    Parameters
    ----------
    realm : str
        Name of the ``<realm>.zarr.zip`` store under the configured path.
    variables : str or list of str
        Variable(s) to return, using output (renamed) names.

    Returns
    -------
    Dataset after the config's rename, normalise and postprocess steps (each
    applied only when present in CAFE60v1.yaml).
    """
    cfg = _load_config("CAFE60v1")
    var_list = [variables] if isinstance(variables, str) else variables

    if "rename" in cfg:
        var_list = _maybe_translate_variables(var_list, cfg["rename"])

    if "preprocess" in cfg:
        # A single store is opened, so there is no per-file hook to use it in
        warnings.warn(
            "preprocess functions were provided but not used because the data does not require concatenation"
        )

    store = f"{cfg['path']}/{realm}.zarr.zip"
    ds = xr.open_dataset(store, engine="zarr", chunks={})[var_list]

    for key, step in (("rename", _maybe_rename), ("normalise", _normalise)):
        if key in cfg:
            ds = step(ds, cfg[key])

    if "postprocess" in cfg:
        postprocess = _composite_function(cfg["postprocess"])
        ds = postprocess(ds)

    return ds


def CAFE_hist(realm, variables):
    """
    Open CAFE historical data following specifications in CAFE_hist.yaml

    Opens the ``c5-d60-pX-hist-19601101`` and ``c5-d60-pX-ctrl-19601101``
    runs for ``realm``. The control run is averaged over ``ensemble``. The
    config's rename, normalise and postprocess steps (each only when present)
    are applied to both. The historical run is returned with a drift term
    subtracted. The drift is the control-run ensemble mean minus its own time
    mean for each calendar month.

    Parameters
    ----------
    realm : str
        Name of the ``<realm>.zarr.zip`` store inside each run's ``ZARR``
        directory.
    variables : str or list of str
        Variable(s) to return, using output (renamed) names.
    """

    cfg = _load_config("CAFE_hist")

    if isinstance(variables, str):
        variables = [variables]

    if "rename" in cfg:
        variables = _maybe_translate_variables(variables, cfg["rename"])

    if "preprocess" in cfg:
        warnings.warn(
            "preprocess functions were provided but not used because the data does not require concatenation"
        )

    hist = xr.open_dataset(
        f"{cfg['path']}/c5-d60-pX-hist-19601101/ZARR/{realm}.zarr.zip",
        engine="zarr",
        chunks={},
    )[variables]

    # Ensemble mean is taken before postprocess; NOTE(review): postprocess
    # functions must therefore not rely on an "ensemble" dimension here.
    ctrl = xr.open_dataset(
        f"{cfg['path']}/c5-d60-pX-ctrl-19601101/ZARR/{realm}.zarr.zip",
        engine="zarr",
        chunks={},
    )[variables].mean("ensemble")

    if "rename" in cfg:
        hist = _maybe_rename(hist, cfg["rename"])
        ctrl = _maybe_rename(ctrl, cfg["rename"])

    if "normalise" in cfg:
        hist = _normalise(hist, cfg["normalise"])
        ctrl = _normalise(ctrl, cfg["normalise"])

    if "postprocess" in cfg:
        hist = _composite_function(cfg["postprocess"])(hist)
        ctrl = _composite_function(cfg["postprocess"])(ctrl)

    # Drift: control anomalies about each calendar month's own time mean
    drift = ctrl.groupby("time.month").map(lambda x: x - x.mean(["time"]))
    return hist - drift



In [141]:
# Open HadISST sea surface temperature with the per-dataset reader
had = HadISST(variables=["sst"])

  return self.array[key]


In [138]:
# Drift-corrected CAFE historical run, t_ref only
h0 = CAFE_hist(realm="atmos_isobaric_month", variables=["t_ref"])

In [135]:
# CAFE60v1 t_ref from the atmos_isobaric_month store
d60 = CAFE60v1(realm="atmos_isobaric_month", variables=["t_ref"])

In [25]:
# CAFE-f5 forecasts (with the 2020 start appended from CAFE-f6)
f5 = CAFEf5(realm="atmos_isobaric_month", variables=["t_ref"])

In [4]:
# CAFE-f6 forecasts combined across start dates
f6 = CAFEf6(realm="atmos_isobaric_month", variables=["t_ref"])

In [5]:
# JRA55 surface fields regridded to the CAFE grid
jra = JRA55(realm="surface_month_cafe-grid", variables=["t_ref", "precip"])

  return self.array[key]
  return self.array[key]


In [194]:
class _open:
    """
    Class containing the dataset-specific code for opening each available dataset

    Each public method is a static method looked up by dataset name (e.g.
    ``getattr(_open, "JRA55")``) and takes ``(path, realm, variables,
    preprocess)``:

    - ``path``: the ``path`` entry of the dataset's config .yaml
    - ``realm``: store / realm name (ignored by datasets that do not need it)
    - ``variables``: list of variable names as stored in the raw data
    - ``preprocess``: composite preprocess function or ``None``; only openers
      that combine many files make use of it
    """

    @staticmethod
    def JRA55(path, realm, variables, _):
        """Open ``<path>/<realm>.zarr.zip`` and select ``variables``."""
        return xr.open_dataset(
            f"{path}/{realm}.zarr.zip",
            engine="zarr",
            chunks={},
            use_cftime=True,
        )[variables]

    @staticmethod
    def HadISST(path, realm, variables, _):
        """Open ``<path>/<realm>.zarr`` and set values <= -1000 to NaN."""
        ds = xr.open_dataset(
            f"{path}/{realm}.zarr",
            engine="zarr",
            chunks={},
            use_cftime=True,
        )[variables]
        # presumably HadISST sentinel fill values - confirm
        return ds.where(ds > -1000)

    @staticmethod
    def EN422(path, _, variables, __):
        """Open and combine every ``<path>/*.nc`` file."""
        return xr.open_mfdataset(
            f"{path}/*.nc",
            parallel=True,
            use_cftime=True,
        )[variables]

    @staticmethod
    def CAFEf6(path, realm, variables, preprocess):
        """Combine the ``c5-d60-pX-f6-????1101`` starts, applying ``preprocess`` per file."""
        files = sorted(
            glob.glob(f"{path}/c5-d60-pX-f6-????1101/{realm}.zarr.zip")
        )  # Skip May starts

        return xr.open_mfdataset(
            files,
            compat="override",
            preprocess=preprocess,
            engine="zarr",
            coords="minimal",
            parallel=True,
        )[variables]

    @staticmethod
    def CAFEf5(path, realm, variables, _):
        """
        Open ``<path>/NOV/<realm>.zarr.zip`` and append the CAFE-f6
        ``c5-d60-pX-f6-20201101`` run (first 10 members, converted to lead
        time and latitude-truncated) along ``init_date``.
        """
        ds = xr.open_dataset(f"{path}/NOV/{realm}.zarr.zip", engine="zarr", chunks={})[
            variables
        ]

        # Append 2020 forecast from CAFE-f6
        cfg_f6 = _load_config("CAFEf6")
        ds_2020 = xr.open_dataset(
            f"{cfg_f6['path']}/c5-d60-pX-f6-20201101/{realm}.zarr.zip",
            engine="zarr",
            chunks={},
        )[variables].isel(ensemble=range(10))
        ds_2020 = utils.convert_time_to_lead(
            ds_2020, init_dim="init_date", lead_dim="lead_time"
        )
        ds_2020 = utils.truncate_latitudes(ds_2020)

        ds = ds.assign_coords(
            {"time": ds["time"].compute()}
        )  # Required for concat below
        return xr.concat([ds, ds_2020], dim="init_date")

    @staticmethod
    def CAFE60v1(path, realm, variables, _):
        """Open ``<path>/<realm>.zarr.zip`` and select ``variables``."""
        return xr.open_dataset(f"{path}/{realm}.zarr.zip", engine="zarr", chunks={})[
            variables
        ]

    @staticmethod
    def CAFE_hist(path, realm, variables, _):
        """
        Return the historical run minus a drift term. The drift is the
        control-run ensemble mean minus its own time mean for each calendar
        month. Both runs are latitude-truncated first.
        """
        hist = xr.open_dataset(
            f"{path}/c5-d60-pX-hist-19601101/ZARR/{realm}.zarr.zip",
            engine="zarr",
            chunks={},
        )[variables]

        ctrl = xr.open_dataset(
            f"{path}/c5-d60-pX-ctrl-19601101/ZARR/{realm}.zarr.zip",
            engine="zarr",
            chunks={},
        )[variables]

        hist = utils.truncate_latitudes(hist)
        ctrl = utils.truncate_latitudes(ctrl)

        drift = (
            ctrl.mean("ensemble")
            .groupby("time.month")
            .map(lambda x: x - x.mean(["time"]))
        )
        return hist - drift

    @staticmethod
    def _CanESM5_file(path, realm, init_year, member, variable):
        """
        Return the path of one CanESM5 dcppA-hindcast file

        The experiment label is ``s<init_year - 1>-r<member>i1p2f1``. The file
        covers ``<init_year>01`` to ``<init_year + 9>12``.
        """
        label = f"s{init_year - 1}-r{member}i1p2f1"
        return (
            f"{path}/{label}/{realm}/{variable}/gn/v20190429/"
            f"{variable}_{realm}_CanESM5_dcppA-hindcast_{label}_gn_"
            f"{init_year}01-{init_year + 9}12.nc"
        )

    @staticmethod
    def CanESM5(path, realm, variables, _):
        """
        Open CanESM5 dcppA-hindcast data (not implemented yet)

        Raises
        ------
        NotImplementedError
            Always. The initialisation years and ensemble members to open are
            not defined yet. Failing loudly is better than silently returning
            ``None``.
        """
        # TODO: for each (init year, member, variable), build a lazy array from
        # _open._CanESM5_file (e.g. dask.delayed + dask.array.from_delayed,
        # taking shape/dtype from one eagerly opened template file), then
        # stack the arrays along init/member dimensions.
        raise NotImplementedError("Opening CanESM5 data is not implemented yet")

In [195]:
def open_dataset(dataset, variables, realm=None):
    """
    Open a dataset following the specifications in ``<dataset>.yaml``

    Parameters
    ----------
    dataset : str
        Dataset name. It must match both a config file ``<dataset>.yaml`` and
        a method of ``_open``.
    variables : str or list of str
        Variable(s) to return, using output (renamed) names.
    realm : str, optional
        Store / realm name passed to the dataset opener. Openers that do not
        need it ignore it.

    Returns
    -------
    The opened dataset after the config's rename, normalise and postprocess
    steps (each applied only when present in the config).
    """
    cfg = _load_config(dataset)
    var_list = [variables] if isinstance(variables, str) else variables

    if "rename" in cfg:
        var_list = _maybe_translate_variables(var_list, cfg["rename"])

    preprocess = (
        _composite_function(cfg["preprocess"]) if "preprocess" in cfg else None
    )

    opener = getattr(_open, dataset)
    ds = opener(cfg["path"], realm, var_list, preprocess)

    for key, step in (("rename", _maybe_rename), ("normalise", _normalise)):
        if key in cfg:
            ds = step(ds, cfg[key])

    if "postprocess" in cfg:
        postprocess = _composite_function(cfg["postprocess"])
        ds = postprocess(ds)

    return ds

In [198]:
# EN.4.2.2 temperature (this opener needs no realm)
en422 = open_dataset("EN422", variables="temp")

In [185]:
# open_dataset's signature is (dataset, variables, realm=None); pass realm by
# keyword so it cannot be mistaken for the variable list
h0 = open_dataset("CAFE_hist", ["t_ref"], realm="atmos_isobaric_month")

  return self.array[key]
  return self.array[key]
  return self.array[key]


In [178]:
# open_dataset's signature is (dataset, variables, realm=None); pass realm by
# keyword so it cannot be mistaken for the variable list
had = open_dataset("HadISST", ["sst"], realm="ocean_month")

  return self.array[key]
