In [None]:
import fsspec
import xarray as xr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import intake

from dask.distributed import Client

from carbonplan.data import cat as core_cat

from cmip6_downscaling.data import cat
from cmip6_downscaling import config

from cmip6_downscaling.workflows.share import get_cmip_runs

from carbonplan_styles.mpl import set_theme

In [None]:
# Apply the CarbonPlan light matplotlib theme to every figure in this notebook.
set_theme(style="carbonplan_light")

In [None]:
# Dask client with one thread per worker; the lazy xarray work below
# (`.persist()`, `.load()`) is executed through it. The bare `client` line
# displays the client.
client = Client(threads_per_worker=1)
client

In [None]:
# Grid from the CarbonPlan data catalog (`grids.conus4k`); its `mask` variable
# restricts the observed spatial mean further down.
grid = core_cat.grids.conus4k.to_dask()

In [None]:
# Quick visual check of the grid mask.
grid.mask.plot()

Get a table of processed model runs.


In [None]:
# Table of processed runs; the loading cell below reads its `model`, `scenario`
# and `member` columns.
df = get_cmip_runs()

Load each run into an Xarray dataset.


In [None]:
# Open the CMIP6 intake-esm catalog once; it does not depend on the loop.
# It gets its own name so it does NOT shadow `cat` (the cmip6_downscaling
# catalog imported at the top), which later cells still need for `cat.obs`.
esm_cat = intake.open_esm_datastore(config.get("data_catalog.cmip.json"))

# data: run key "method.model.scenario.member" -> lazily opened dataset.
data = {}
index = []
for method in ["quantile-mapping"]:  # , "bias-corrected"
    for _, row in df.iterrows():
        key = f"{method}.{row.model}.{row.scenario}.{row.member}"
        data[key] = esm_cat.search(
            method=method,
            model=row.model,
            scenario=row.scenario,
            member=row.member,
        ).to_dask()
        index.append((method, row.model, row.scenario))

In [None]:
# Bound PDSI to [-16, 16] in every loaded run (the observations get the same
# bounds in the next cell). Each dataset is updated in place.
for ds in data.values():
    ds["pdsi"] = ds["pdsi"].clip(-16, 16)

In [None]:
# Chunk the grid and persist it so its mask is computed once and reused below.
grid = grid.chunk({"x": 50, "y": 50}).persist()

# NOTE(review): `cat` here must be the cmip6_downscaling catalog imported at the
# top; do not rebind that name in earlier cells.
obs_ds = cat.obs.to_dask()
# Same PDSI bounds as the downscaled runs.
obs_ds["pdsi"] = obs_ds["pdsi"].clip(-16, 16)

# Spatial mean of the observations inside the grid mask, computed eagerly.
obs_ts = (
    obs_ds
    .where(grid.mask)
    .mean(dim=("x", "y"))
    .load()
)

In [None]:
# Spatial mean (over x and y) of every run, loaded into memory. Despite the
# name, these series keep their native time step; they are resampled to annual
# values at plot time.
ann_ts = {}
for k, ds in data.items():
    print(k)
    if "month" in ds:
        # `Dataset.drop` is deprecated for removing variables/coordinates;
        # `drop_vars` is the supported equivalent.
        ds = ds.drop_vars("month")
        print("dropped month")
    ann_ts[k] = ds.mean(("x", "y")).load()

In [None]:
def multi_index_from_keys(keys):
    """Build a ``(method, model, scenario)`` MultiIndex from dotted run keys.

    Each key is split on ``"."`` and only its first three components are kept,
    so ``"method.model.scenario.member"`` contributes
    ``("method", "model", "scenario")``.
    """
    level_names = ["method", "model", "scenario"]
    tuples = []
    for key in keys:
        parts = key.split(".")
        tuples.append(tuple(parts[: len(level_names)]))
    return pd.MultiIndex.from_tuples(tuples, names=level_names)


def combine(data):
    """Stack per-run datasets and unstack them onto method/model/scenario dims.

    ``data`` maps dotted run keys to datasets. The runs are concatenated along a
    temporary ``run`` dimension indexed by ``multi_index_from_keys(data.keys())``.
    That dimension is then unstacked into separate ``method``, ``model`` and
    ``scenario`` dimensions, whose coordinates are cast to ``str``.
    """
    run_dim = xr.Variable("run", multi_index_from_keys(data.keys()))
    stacked = xr.concat(
        data.values(),
        dim=run_dim,
        coords="minimal",
        compat="override",
    )
    combined = stacked.unstack(dim="run")
    for level in ("method", "model", "scenario"):
        combined[level] = combined[level].astype(str)
    return combined

In [None]:
def _scenario_of(run_key):
    """Return the scenario field of a ``"method.model.scenario.member"`` run key."""
    return run_key.split(".")[2]


# Split runs on their scenario field rather than on a substring anywhere in the
# key, which could also match a model or method name. This uses the same dotted
# layout that `multi_index_from_keys` parses.
hist_ds = combine({k: ds for k, ds in ann_ts.items() if _scenario_of(k).startswith("hist")})
ssp_ds = combine({k: ds for k, ds in ann_ts.items() if _scenario_of(k).startswith("ssp")})

In [None]:
# Display the dict of per-run spatial-mean datasets.
# NOTE(review): this renders every dataset repr; `list(ann_ts)` would show only the keys.
ann_ts

In [None]:
# Inspect the combined historical dataset (method/model/scenario dims).
hist_ds

In [None]:
# ssp_method_diff = ssp_ds.isel(method=1) - ssp_ds.isel(method=0)
# hist_method_diff = hist_ds.isel(method=1) - hist_ds.isel(method=0)
# hist_method_diff

In [None]:
methods = hist_ds["method"].values
var_names = ["tmean", "ppt", "pet", "def", "soil", "vpd", "pdsi"]
colors = {"ssp245": "yellow", "ssp370": "orange", "ssp585": "red"}
# Window (years) of the trailing rolling mean applied to the annual series;
# 1 disables smoothing (the previous `rolling(time=1)` was a no-op).
rolling_years = 1


def _annual(da):
    """Annual (year-start) mean of ``da``, optionally smoothed by ``rolling_years``."""
    # "YS" is the current spelling of the deprecated "AS" year-start alias.
    annual = da.squeeze(drop=True).resample(time="YS").mean()
    if rolling_years > 1:
        annual = annual.rolling(time=rolling_years).mean()
    return annual


fig, axes = plt.subplots(
    nrows=len(var_names),
    ncols=len(methods),
    figsize=(6, 12),
    sharex=True,
    squeeze=False,
)

line_kwargs = dict(x="time", add_legend=False)
for i, var in enumerate(var_names):
    for j in range(len(methods)):
        ax = axes[i, j]
        # Historical runs in gray, each SSP in its own colour, observations in black.
        _annual(hist_ds[var].isel(method=j)).plot.line(
            ax=ax, color="gray", alpha=0.8, **line_kwargs
        )
        for scen in ssp_ds["scenario"].values:
            _annual(ssp_ds[var].isel(method=j).sel(scenario=scen)).plot.line(
                ax=ax, color=colors[scen], alpha=0.8, **line_kwargs
            )
        _annual(obs_ts[var]).plot.line(ax=ax, color="k", alpha=1, **line_kwargs)

        # Drop the per-panel titles and x labels that xarray adds; keep y labels
        # only on the first column.
        ax.set_xlabel("")
        ax.set_title("")
        if j != 0:
            ax.set_ylabel("")

# TODO: a method-difference column (second method minus first) used to follow
# the method columns; restore it once both downscaling methods are loaded again.
# The old `axes[i, -1]` label resets targeted that column and, without it,
# wrongly cleared the first column's y labels.

# PDSI is clipped to [-16, 16]; give its row a fixed y range.
for ax in axes[var_names.index("pdsi")]:
    ax.set_ylim(-25, 5)
fig.tight_layout()

In [None]:
hist_periods = [
    ("1970", "2000"),
]
ssp_periods = [("2020", "2050"), ("2060", "2090")]

# For every run, load its time mean over each analysis period. The inner keys
# join the period bounds with ":" (e.g. "1970:2000").
maps = {}
for k, ds in data.items():
    periods = hist_periods if "hist" in k else ssp_periods
    maps[k] = {
        ":".join(p): ds.sel(time=slice(*p)).mean(dim="time").load()
        for p in periods
    }

In [None]:
# Remove the "month" variable/coordinate from any period map that carries one,
# matching the handling of the spatial-mean series above. `drop_vars` replaces
# the deprecated `Dataset.drop` for this purpose.
for k, pmaps in maps.items():
    for kp, ds in pmaps.items():
        if "month" in ds:
            maps[k][kp] = ds.drop_vars("month")

In [None]:
# Observed climatology over the historical period. It is built from `obs_ds`,
# whose PDSI is already clipped to [-16, 16] like the downscaled runs, so the
# bias maps compare like with like. Re-opening the raw catalog entry skipped
# that clip and relied on the name `cat` not having been rebound.
obs_map = obs_ds.sel(time=slice(*hist_periods[0])).mean(dim="time").load()
obs_map

In [None]:
# Key format matches the period-maps cell: period bounds joined with ":".
hist_key = ":".join(hist_periods[0])
hist_maps = combine({k: maps[k][hist_key] for k in maps if "hist" in k})

# Short method labels, mapped from the actual coordinate values. A hard-coded
# ["bc", "qm"] raises when a different number of methods is loaded, and it
# silently depends on the sort order of the method coordinate.
method_labels = {"bias-corrected": "bc", "quantile-mapping": "qm"}
hist_maps["method"] = [method_labels.get(m, m) for m in hist_maps["method"].values]

In [None]:
# Quick look at the observed `ws` climatology; robust=True derives the colour
# limits from the 2nd-98th percentiles instead of the extremes.
obs_map["ws"].plot(robust=True)

In [None]:
# Leftover default: `plot_hist` takes `var` as a parameter, and the plotting loop
# further down rebinds this name for each variable.
var = "tmean"


def plot_hist(
    hist_maps,
    obs_map,
    var,
    vmin=None,
    vmax=None,
    vmind=None,
    vmaxd=None,
    cmap=None,
    cmapd=None,
):
    """Plot historical-period maps of ``var`` for every model and method.

    One row per model. For ``n`` entries along ``hist_maps.method``:

    * columns ``0 .. n-1``: the downscaled field for each method
      (scaled by ``vmin``/``vmax``/``cmap``);
    * columns ``n .. 2n-1``: each method minus ``obs_map[var]``
      (scaled by ``vmind``/``vmaxd``/``cmapd``);
    * column ``2n``, only when ``n == 2``: second method minus first method,
      on the same scale as the bias columns.

    Parameters
    ----------
    hist_maps : Dataset
        Period-mean maps with ``model`` and ``method`` dimensions.
    obs_map : Dataset
        Observed period-mean map for the same period.
    var : str
        Variable name present in both datasets.
    vmin, vmax, cmap :
        Colour scaling for the raw-field columns.
    vmind, vmaxd, cmapd :
        Colour scaling for the bias and method-difference columns.

    Returns
    -------
    fig, axes
        The figure and its 2-D array of axes.
    """
    methods = [str(m) for m in hist_maps["method"].values]
    n_methods = len(methods)
    # The method-difference column only makes sense for exactly two methods.
    with_diff = n_methods == 2
    ncols = 2 * n_methods + (1 if with_diff else 0)

    fig, axes = plt.subplots(
        ncols=ncols,
        nrows=hist_maps.sizes["model"],
        sharex=True,
        sharey=True,
        figsize=(12, 12),
        constrained_layout=True,
        squeeze=False,  # keep `axes` 2-D even for a single model or method
    )

    if with_diff:
        # The difference plotted below is method[1] - method[0]; label it so.
        axes[0, -1].set_title(f"{methods[1]}-{methods[0]}")

    raw_mappable = bias_mappable = None
    for i, model in enumerate(hist_maps["model"].values):
        axes[i, 0].set_ylabel(model)
        for j, method in enumerate(methods):
            if i == 0:
                axes[i, j].set_title(method)
                axes[i, j + n_methods].set_title(method + "-obs")

            field = hist_maps[var].isel(model=i, method=j)
            raw_mappable = field.plot(
                ax=axes[i, j],
                add_colorbar=False,
                vmin=vmin,
                vmax=vmax,
                cmap=cmap,
                add_labels=False,
            )

            bias = field - obs_map[var]
            bias_mappable = bias.plot(
                ax=axes[i, j + n_methods],
                add_colorbar=False,
                vmin=vmind,
                vmax=vmaxd,
                cmap=cmapd,
                add_labels=False,
            )

        if with_diff:
            method_diff = hist_maps[var].isel(model=i, method=1) - hist_maps[var].isel(
                model=i, method=0
            )
            method_diff.plot(
                ax=axes[i, -1],
                add_colorbar=False,
                vmin=vmind,
                vmax=vmaxd,
                cmap=cmapd,
                add_labels=False,
            )

    if raw_mappable is not None:
        # One shared colorbar for the raw columns, one for bias/difference columns.
        fig.colorbar(raw_mappable, ax=axes[-1, :n_methods], location="bottom", shrink=0.9)
        fig.colorbar(bias_mappable, ax=axes[-1, n_methods:], location="bottom", shrink=0.7)
    fig.suptitle(var, y=1.03, fontweight="bold")
    return fig, axes

In [None]:
# Colour limits and colormaps passed to `plot_hist`, per variable:
#   vmin / vmax / cmap    -> raw downscaled maps
#   vmind / vmaxd / cmapd -> bias (method - obs) and method-difference maps
kwargs = {
    "tmean": {"vmin": 0, "vmax": 25, "vmind": -0.5, "vmaxd": 0.5, "cmap": "cividis_r", "cmapd": "RdBu_r"},
    "ppt": {"vmin": 0, "vmax": 200, "vmind": -1, "vmaxd": 1, "cmap": "Blues", "cmapd": "BrBG"},
    "pet": {"vmin": 0, "vmax": 150, "vmind": -25, "vmaxd": 25, "cmap": "cividis", "cmapd": "RdBu_r"},
    "def": {"vmin": 0, "vmax": 150, "vmind": -25, "vmaxd": 25, "cmap": "cividis", "cmapd": "RdBu_r"},
    "soil": {"vmin": 0, "vmax": 150, "vmind": -25, "vmaxd": 25, "cmap": "cividis", "cmapd": "RdBu_r"},
    "vpd": {"vmin": 0, "vmax": 1, "vmind": -0.25, "vmaxd": 0.25, "cmap": "cividis", "cmapd": "RdBu_r"},
    "pdsi": {"vmin": -0.5, "vmax": 0.5, "vmind": -0.25, "vmaxd": 0.25, "cmap": "cividis", "cmapd": "RdBu_r"},
    "srad": {"vmin": 140, "vmax": 220, "vmind": -1, "vmaxd": 1, "cmap": "cividis", "cmapd": "RdBu_r"},
    "tdew": {"vmin": -15, "vmax": 15, "vmind": -1, "vmaxd": 1, "cmap": "cividis", "cmapd": "RdBu_r"},
}

# One figure per variable, in the order the settings are listed above.
for var in kwargs:
    fig, axes = plot_hist(hist_maps, obs_map, var=var, **kwargs[var])