In [None]:
import hvplot.xarray
import numpy as np
import scipy.stats
import xarray as xr

import climepi
import climepi.climdata as climdata
import climepi.epimod as epimod

In [None]:
# Example dataset for London ("isimip_london"), restricted to three SSP scenarios.
ds_clim = climdata.get_example_dataset("isimip_london").sel(
    scenario=["ssp126", "ssp370", "ssp585"]
)
ds_clim

In [None]:
# Yearly averages of every variable (climepi accessor), then the temperature series
# plotted with one curve group per scenario.
ds_clim_ym = ds_clim.climepi.yearly_average()
ds_clim_ym.climepi.plot_time_series("temperature", by=["scenario"])

In [None]:
# Smooth each yearly-mean series with a polynomial in time. The residuals about this
# fit are used further down as the internal-variability estimate.
POLY_DEG = 4  # degree of the smoothing polynomial
data_vars = ds_clim_ym.climepi.get_non_bnd_data_vars()
# polyfit names its coefficient outputs "<var>_polyfit_coefficients"; map them back
# to the original variable names so the fitted dataset mirrors ds_clim_ym.
data_var_mapping = {
    f"{data_var}_polyfit_coefficients": data_var for data_var in data_vars
}
poly_coeff_data_vars = list(data_var_mapping)
fitted_polys = ds_clim_ym[data_vars].polyfit(dim="time", deg=POLY_DEG, full=True)
ds_clim_ym_fit = (
    xr.polyval(coord=ds_clim_ym.time, coeffs=fitted_polys[poly_coeff_data_vars])
    .rename(data_var_mapping)
    .squeeze("realization", drop=True)
)

In [None]:
# Output of the polyfit above: the coefficients, plus (because full=True) the extra
# fit diagnostics xarray returns (residuals, matrix rank, singular values).
fitted_polys

In [None]:
# Fitted polynomial evaluated at every time step, with the original variable names
# and the "realization" dimension dropped.
ds_clim_ym_fit

In [None]:
# Overlay the polynomial fits (p1) on the raw yearly temperature means (p2) to check
# the quality of the fit by eye.
p1, p2 = (
    ds.squeeze().climepi.plot_time_series("temperature", by=["scenario"])
    for ds in (ds_clim_ym_fit, ds_clim_ym)
)
p1 * p2

In [None]:
# Internal variability: variance of the residuals about the polynomial fit. With no
# `dim`, `.var()` pools over every dimension, giving one value per variable that is
# constant in time.
# NOTE(review): this would also pool over any spatial dimensions; fine for a
# single-site dataset, but revisit for gridded data.
variance_internal = (ds_clim_ym - ds_clim_ym_fit).var().load()
variance_internal

In [None]:
# Model uncertainty: spread of the fitted trends across models, averaged over scenarios.
variance_model = ds_clim_ym_fit.var(dim="model").mean(dim="scenario").load()
variance_model

In [None]:
# Scenario uncertainty: spread across scenarios of the multi-model-mean fitted trend.
variance_scenario = ds_clim_ym_fit.mean(dim="model").var(dim="scenario").load()
variance_scenario

In [None]:
# Sanity check of the decomposition. By the law of total variance (population variance,
# xarray's default ddof=0), the spread of the fits over all (model, scenario) pairs
# equals variance_model + variance_scenario, provided every scenario has the same set of
# models. So `check` should be ~0 everywhere.
variance_total_fit = ds_clim_ym_fit.var(dim=["model", "scenario"])
check = variance_total_fit - (variance_model + variance_scenario)
# Fail loudly on Restart & Run All instead of relying on eyeballing the numbers below.
xr.testing.assert_allclose(variance_total_fit, variance_model + variance_scenario)
check.load()

In [None]:
# Stack the three variance components along a new "variance_type" dimension; the dict
# order fixes the order of the labels along that dimension.
variance_components = {
    "internal": variance_internal,
    "scenario": variance_scenario,
    "model": variance_model,
}
ds_variances = xr.concat(
    list(variance_components.values()),
    dim=xr.Variable("variance_type", list(variance_components)),
    coords="minimal",
)
ds_variances

In [None]:
def make_variance_area_plot(_ds_variance, var, proportions=False):
    """Area plot over time of the variance components of one variable.

    Parameters
    ----------
    _ds_variance : xarray.Dataset
        Variance components with a ``variance_type`` dimension whose labels
        include "internal", "scenario" and "model".
    var : str
        Name of the data variable to plot.
    proportions : bool, default False
        If True, divide each component by the sum over ``variance_type`` so the
        plot shows fractions of the total.

    Returns
    -------
    The hvplot area plot, with one series per variance type.
    """
    if proportions:
        ds_plot = _ds_variance / _ds_variance.sum(dim="variance_type")
    else:
        ds_plot = _ds_variance
    da_var = ds_plot[var]
    # One data variable per variance type so hvplot can treat them as separate series.
    ds_by_type = xr.Dataset(
        {
            variance_type: da_var.sel(variance_type=variance_type, drop=True)
            for variance_type in ("internal", "scenario", "model")
        }
    )
    return ds_by_type.squeeze().hvplot.area(
        x="time", y=["scenario", "model", "internal"], group_label="Uncertainty type"
    )

In [None]:
# Fraction of the total temperature variance attributable to each uncertainty source.
make_variance_area_plot(ds_variances, var="temperature", proportions=True)

In [None]:
def _make_band(da_center, da_std, z):
    """Return a Dataset with ``low``/``high`` equal to ``da_center -/+ z * da_std``."""
    return xr.Dataset(
        {
            "low": da_center - z * da_std,
            "high": da_center + z * da_std,
        }
    )


def make_plume_plot(
    _ds,
    _ds_fit,
    _ds_variances,
    var,
    conf_level=90,
    scenario_baseline=None,
    model_baseline=None,
):
    """Nested uncertainty "plume" plot for ``var``: internal, model and scenario bands.

    Parameters
    ----------
    _ds : xarray.Dataset
        Raw yearly-mean data with ``scenario`` and ``model`` dimensions. It is only
        used when both baselines name a specific scenario and model.
    _ds_fit : xarray.Dataset
        Polynomial-fit counterpart of ``_ds``.
    _ds_variances : xarray.Dataset
        Variance components with a ``variance_type`` dimension labelled
        "internal", "scenario" and "model".
    var : str
        Data variable to plot.
    conf_level : float, default 90
        Two-sided confidence level in percent; must be strictly between 0 and 100.
    scenario_baseline, model_baseline : str
        Scenario / model label to centre the inner bands on, or "mean" to average
        over that dimension. Both are required. If ``scenario_baseline`` is "mean",
        ``model_baseline`` must also be "mean".

    Returns
    -------
    Overlay of the scenario, model and internal bands (widest first) plus the
    baseline as a black line.

    Raises
    ------
    ValueError
        If a baseline is missing, the baselines are inconsistent, or ``conf_level``
        is out of range.
    """
    if scenario_baseline == "mean" and model_baseline != "mean":
        raise ValueError(
            "If scenario_baseline is 'mean', model_baseline must be 'mean'"
        )
    if scenario_baseline is None or model_baseline is None:
        # Without this check, the None defaults reach `.sel(scenario=None, ...)` and
        # fail with an opaque indexing error.
        raise ValueError(
            "scenario_baseline and model_baseline must both be given "
            "(a coordinate label or 'mean')"
        )
    if not 0 < conf_level < 100:
        # Outside (0, 100), norm.ppf below gives infinite, NaN or inverted/degenerate
        # bands instead of an error.
        raise ValueError(
            f"conf_level must be strictly between 0 and 100, got {conf_level!r}"
        )

    da_fit = _ds_fit[var]
    if scenario_baseline == "mean":
        # Both baselines are "mean" here (enforced above): centre everything on the
        # multi-scenario, multi-model mean fit.
        da_baseline = da_fit.mean(dim=["model", "scenario"])
        da_fit_scenario_model_baseline = da_baseline
        da_fit_scenario_baseline_model_mean = da_baseline
    elif model_baseline == "mean":
        da_baseline = da_fit.sel(scenario=scenario_baseline).mean(dim="model")
        da_fit_scenario_model_baseline = da_baseline
        da_fit_scenario_baseline_model_mean = da_baseline
    else:
        # A single simulation. Show its raw data, centre the internal band on its own
        # fit and the model band on the multi-model mean fit for that scenario.
        da_baseline = _ds[var].sel(scenario=scenario_baseline, model=model_baseline)
        da_fit_scenario_model_baseline = da_fit.sel(
            scenario=scenario_baseline, model=model_baseline
        )
        da_fit_scenario_baseline_model_mean = da_fit.sel(
            scenario=scenario_baseline
        ).mean(dim="model")

    da_variances = _ds_variances[var]
    da_var_internal = da_variances.sel(variance_type="internal")
    da_std_internal = np.sqrt(da_var_internal)
    if scenario_baseline == "mean":
        # Model variance averaged over scenarios (as stored in _ds_variances).
        da_std_internal_model = np.sqrt(
            da_variances.sel(variance_type=["internal", "model"]).sum(
                dim="variance_type"
            )
        )
    else:
        # Model variance for the chosen scenario only.
        da_std_internal_model = np.sqrt(
            da_fit.sel(scenario=scenario_baseline).var(dim="model") + da_var_internal
        )

    # Two-sided normal quantile, e.g. conf_level=90 -> z ~= 1.645.
    z = scipy.stats.norm.ppf(0.5 + conf_level / 200)
    ds_internal = _make_band(da_fit_scenario_model_baseline, da_std_internal, z)
    ds_model = _make_band(da_fit_scenario_baseline_model_mean, da_std_internal_model, z)
    # Outermost band: all three components, always centred on the overall mean fit.
    ds_scenario = _make_band(
        da_fit.mean(dim=["model", "scenario"]),
        np.sqrt(da_variances.sum(dim="variance_type")),
        z,
    )

    p_baseline = da_baseline.hvplot.line(color="k", label="Baseline")
    p_internal = ds_internal.hvplot.area(x="time", y="low", y2="high", label="Internal")
    p_model = ds_model.hvplot.area(x="time", y="low", y2="high", label="Model")
    p_scenario = ds_scenario.hvplot.area(x="time", y="low", y2="high", label="Scenario")
    # Widest band first so the narrower bands and the baseline are drawn on top.
    return p_scenario * p_model * p_internal * p_baseline

In [None]:
# Temperature plume centred on the multi-scenario, multi-model mean fit. To centre it on
# a single simulation instead, set e.g. scenario_baseline = "ssp370" and
# model_baseline = "gfdl-esm4".
scenario_baseline = "mean"
model_baseline = "mean"
make_plume_plot(
    ds_clim_ym.squeeze(),
    ds_clim_ym_fit.squeeze(),
    ds_variances.squeeze(),
    "temperature",
    conf_level=90,
    scenario_baseline=scenario_baseline,
    model_baseline=model_baseline,
)

In [None]:
# Precipitation plume centred on the multi-scenario, multi-model mean fit. To centre it
# on a single simulation instead, set e.g. scenario_baseline = "ssp370" and
# model_baseline = "gfdl-esm4".
scenario_baseline = "mean"
model_baseline = "mean"
make_plume_plot(
    ds_clim_ym.squeeze(),
    ds_clim_ym_fit.squeeze(),
    ds_variances.squeeze(),
    "precipitation",
    conf_level=90,
    scenario_baseline=scenario_baseline,
    model_baseline=model_baseline,
)