In [None]:
import xarray as xr
import zarr
import os
import bokeh
import dask
from matplotlib import pyplot as plt
import numpy as np
from adlfs import AzureBlobFileSystem
import seaborn as sns
import pandas as pd
from tqdm import tqdm
import matplotlib as mpl

In [None]:
from dask_gateway import Gateway

# Connect to the Dask Gateway and start from its default cluster options.
gateway = Gateway()
options = gateway.cluster_options()
# Per-worker resources.
# NOTE(review): the memory value is presumably in GB - confirm against the
# gateway's option definitions.
options.worker_memory = 24  # as high as 30 but you want tiny
options.worker_cores = 2
cluster = gateway.new_cluster(options)
# Let the cluster scale adaptively between 2 and 40 workers.
cluster.adapt(minimum=2, maximum=40)
client = cluster.get_client()
# Last expression of the cell: display the cluster object.
cluster

First, load in your list of models for which downscaled climate simulations are
available.


In [None]:
# Location of the bias-corrected stores inside the scratch container.
bias_correction_root = "carbonplan-scratch/downscaling/bias-correction"

# The storage account key is read from the environment, never hard-coded.
fs = AzureBlobFileSystem("carbonplan", account_key=os.environ["BLOB_ACCOUNT_KEY"])
file_list = fs.ls(bias_correction_root)

# Keep only the second-to-last "/"-separated component of every listed path.
files = [entry.split("/")[-2] for entry in file_list]

Then load up a sample dataset to take a look at the domain and understand what
you're working with.


In [None]:
# Open one store from the listing (index 40) as a sample dataset.
store_url = "downscaling/bias-correction/{}".format(files[40])
store = zarr.storage.ABSStore(
    "carbonplan-scratch",
    account_name="carbonplan",
    account_key=os.environ["BLOB_ACCOUNT_KEY"],
    prefix=store_url,
)
sample_ds = xr.open_zarr(store, consolidated=True)

# Quick look at the domain: precipitation at the first time step.
first_step_pr = sample_ds["pr"].isel(time=0)
first_step_pr.plot()

Then you can access the names of the individual GCMs that we have available.


In [None]:
# Store names are dot-separated; fields 1 and 2 joined together give the GCM
# name (e.g. "NUIST.NESM3"). The names are de-duplicated and then sorted:
# iteration order of a set of strings can change from one kernel session to
# the next, and the order of `models` determines the order of the `gcm`
# coordinate built by `create_data_cube` (and so the GCM numbers drawn on the
# plots further down).
models = sorted({".".join(filename.split(".")[1:3]) for filename in files})

# (experiment, scenario, time slice) triples, in the form expected by
# `create_data_cube`.
scenarios = [
    ("CMIP", "historical", slice("1950", "2016")),
    ("ScenarioMIP", "ssp245", slice("2015", "2100")),
    ("ScenarioMIP", "ssp370", slice("2015", "2100")),
    ("ScenarioMIP", "ssp585", slice("2015", "2100")),
]

Now let's make a GIANT DATA CUBE because they're fun and they make any
self-respecting pythonista feel proud. It'll also make the analysis way easier
and more readable and cut down on loops (And who does loops anymore these days
anyway? Certainly not _this_ gal...).


In [None]:
def create_data_cube(models, scenarios):
    """
    Open the bias-corrected zarr store of every GCM/scenario pair and stack
    them into one dataset with `gcm` and `scenario` dimensions.

    Parameters
    ----------
    models : iterable of str
        GCM names; each is inserted into the store name
        `{experiment}.{gcm}.{scenario}.Amon.gn`.
    scenarios : iterable of (str, str, slice)
        `(experiment, scenario, time_slice)` triples. `time_slice` is applied
        to the `time` coordinate of every store opened for that scenario.

    Returns
    -------
    xarray.Dataset
        The per-GCM datasets concatenated along `gcm`, then the per-scenario
        results concatenated along `scenario`.

    Raises
    ------
    ValueError
        If not a single store could be opened for one of the scenarios.

    Notes
    -----
    Stores that fail to open or process are reported (store name and error)
    and skipped, so the `gcm` coordinate only lists the GCMs that loaded.
    """
    # these keys break
    excluded_gcms = ("HAMMOZ-Consortium.MPI-ESM-1-2-HAM", "NUIST.NESM3")

    list_of_scenario_ds, list_of_scenario_coords = [], []
    for (experiment, scenario, time_slice) in scenarios:
        list_of_gcm_ds, list_of_gcm_coords = [], []
        for gcm in models:
            # Skip known-bad GCMs before building a store client for them.
            if gcm in excluded_gcms:
                continue

            store_url = f"downscaling/bias-correction/{experiment}.{gcm}.{scenario}.Amon.gn"
            # TODO (@joe): open these in an open_mfdataset kind of way once
            # that is released.
            store = zarr.storage.ABSStore(
                "carbonplan-scratch",
                prefix=store_url,
                account_name="carbonplan",
                account_key=os.environ["BLOB_ACCOUNT_KEY"],
            )

            try:
                ds = xr.open_zarr(store, consolidated=True).sel(time=time_slice)
                # Replace the time index with month-start timestamps of the
                # same length, starting at the first month present.
                # NOTE(review): presumably this harmonizes differing model
                # calendars so the GCMs can be concatenated - confirm.
                ds["time"] = pd.date_range(
                    start=ds.indexes["time"][0].strftime("%Y-%m"),
                    periods=ds.sizes["time"],
                    freq="MS",
                )
                ds = ds.drop(
                    ["height", "month", "member_id", "lat", "lon"]
                ).squeeze(drop=True)
            except Exception as err:
                # Best effort: report which store failed and why, then move
                # on. `Exception` (not a bare except) keeps Ctrl-C /
                # kernel interrupts working during this long loop.
                print(f"{store_url}: {err!r}")
                continue

            list_of_gcm_ds.append(ds)
            list_of_gcm_coords.append(gcm)

        if not list_of_gcm_ds:
            raise ValueError(
                f"no stores could be opened for {experiment}.{scenario}"
            )

        ds = xr.concat(
            list_of_gcm_ds,
            dim=xr.Variable("gcm", list_of_gcm_coords),
            coords="minimal",
            compat="override",
        )
        list_of_scenario_ds.append(ds)
        list_of_scenario_coords.append(scenario)

    ds = xr.concat(
        list_of_scenario_ds,
        dim=xr.Variable("scenario", list_of_scenario_coords),
        coords="minimal",
        compat="override",
    )

    return ds

Now we open each of the stores and combine them into two data cubes: one for the future scenarios and one for the historical period.


In [None]:
# Future data cube: the three SSP scenarios, each restricted to 2015-2100.
future_period = slice("2015", "2100")
future_scenarios = [
    ("ScenarioMIP", ssp, future_period)
    for ssp in ("ssp245", "ssp370", "ssp585")
]
future_ds = create_data_cube(models, future_scenarios)

In [None]:
# Historical data cube: a single "historical" scenario restricted to 1950-2015.
historical_scenarios = [("CMIP", "historical", slice("1950", "2015"))]
historical_ds = create_data_cube(models, historical_scenarios)

And now let's mask out a few special regions of interest to get at regional
responses. For now we'll just use some simple lat/lon bounding boxes but we
could always expand to shapefiles for more refined analyses (e.g.
states/counties/forest project regions).


In [None]:
# Regions of interest as simple bounding boxes: (name, lat bounds, lon bounds),
# each pair of bounds given as (lower, upper).
_region_bounds = (
    ("Pacific Northwest", (41, 49), (-130, -110)),
    ("West", (20, 49), (-130, -105)),
    ("Northeast", (41, 48), (-93, -66)),
    ("Southeast", (25, 37), (-93, -76)),
)

region_bounding_boxes = {
    name: {"lat": list(lat_bounds), "lon": list(lon_bounds)}
    for name, lat_bounds, lon_bounds in _region_bounds
}

In [None]:
def create_mask(lat_lon_bounds_dict, ds):
    """
    Create a boolean mask, aligned with your ds of interest, that is True
    inside a lat/lon bounding box.

    Requires ds to have coordinates of lat and lon and also a 'tasmax'
    variable with a `time` dimension (used as a hacky template for the array
    shape).

    Parameters
    ----------
    lat_lon_bounds_dict : dict
        `{"lat": [lower, upper], "lon": [lower, upper]}`. Both bounds are
        exclusive.
    ds : xarray.Dataset
        Dataset providing `lat`, `lon` and `tasmax`.

    Returns
    -------
    xarray.DataArray
        True where the cell lies strictly inside the box AND the first time
        step of `tasmax` is greater than zero; False everywhere else. The
        `member_id`, `month`, `lat`, `lon`, `time` and `height` coordinates
        are dropped and length-one dimensions are squeezed out.
    """
    lat_bounds, lon_bounds = (
        lat_lon_bounds_dict["lat"],
        lat_lon_bounds_dict["lon"],
    )

    # Build the bounding-box condition once instead of masking repeatedly.
    inside_box = (
        (ds.lat > lat_bounds[0])
        & (ds.lat < lat_bounds[1])
        & (ds.lon > lon_bounds[0])
        & (ds.lon < lon_bounds[1])
    )

    # hacky: only the first time step of tasmax is needed as a template, so
    # select it before masking rather than masking the full dataset.
    # NOTE(review): `> 0` doubles as a "has data" test, which presumably
    # relies on tasmax being in Kelvin - confirm.
    mask = (
        (ds["tasmax"].isel(time=0).where(inside_box) > 0)
        .drop(["member_id", "month", "lat", "lon", "time", "height"])
        .squeeze()
    )

    return mask

In [None]:
def zarr_on_azure(filename, mode, ds=None, bucket="carbonplan-scratch"):
    """This function either writes a ds to the desired bucket or
    reads from a zarr file into a ds.

    Parameters
    ----------
    filename : str
        Prefix (path) of the zarr store inside `bucket`.
    mode : {"w", "r"}
        "w" writes `ds` to the store (overwriting it); "r" opens the store.
    ds : xarray.Dataset, optional
        Dataset to write. Required when `mode == "w"`, ignored for "r".
    bucket : str
        Blob container holding the store.

    Returns
    -------
    xarray.Dataset or None
        The opened dataset in read mode; None in write mode.

    Raises
    ------
    TypeError
        If `filename` is not a string (e.g. the dataset was passed first).
    ValueError
        If `mode` is not "w" or "r", or if `mode == "w"` and `ds` is None.
    """
    # Validate before touching storage so that misuse (swapped arguments,
    # a typo in the mode) fails immediately with a clear message instead of
    # silently returning None or failing deep inside the storage layer.
    if not isinstance(filename, str):
        raise TypeError(
            "filename must be a string prefix, got {}; the expected order is "
            "zarr_on_azure(filename, mode, ds)".format(type(filename).__name__)
        )
    if mode not in ("w", "r"):
        raise ValueError("mode must be 'w' or 'r', got {!r}".format(mode))
    if mode == "w" and ds is None:
        raise ValueError("a dataset must be passed as `ds` when mode is 'w'")

    store = zarr.storage.ABSStore(
        bucket,
        prefix=filename,
        account_name="carbonplan",
        account_key=os.environ["BLOB_ACCOUNT_KEY"],
    )
    if mode == "w":
        ds.to_zarr(store, consolidated=True, mode="w")
        return None
    return xr.open_zarr(store, consolidated=True)

In [None]:
# Flip to True to recompute the regional summaries and overwrite the stored
# copies; when False the summaries are simply read back from storage later on.
recreate = False
if recreate:
    ds_dict = {"historical": historical_ds, "future": future_ds}
    for (period, period_ds) in ds_dict.items():
        for scenario in list(period_ds["scenario"].values):
            # breaking up scenario analyses to get around memory constraints
            list_of_region_ds, list_of_region_coords = [], []
            for (region, bounding_box) in tqdm(region_bounding_boxes.items()):
                # we use the sample_ds we opened above to give it a sense of the array shape we're working with
                mask = create_mask(bounding_box, sample_ds)
                # NOTE(review): `.sel(scenario=scenario)` removes the
                # `scenario` dimension, while the loading cell later selects
                # scenario="historical" from the stored historical summary -
                # confirm the stored layout is what that cell expects.
                masked_ds = (
                    period_ds.sel(scenario=scenario)
                    .where(mask)
                    .mean(dim=["x", "y"])
                    .load()
                )
                list_of_region_ds.append(masked_ds)
                list_of_region_coords.append(region)
            # then concatenate all of your masked things together
            region_ds = xr.concat(
                list_of_region_ds,
                dim=xr.Variable("region", list_of_region_coords),
                coords="minimal",
                compat="override",
            )
            print(
                "here we go! dask dask dask run run run for {}".format(scenario)
            )
            region_ds = region_ds.load()
            # Keyword arguments on purpose: the signature is
            # zarr_on_azure(filename, mode, ds=None, bucket=...), so the path
            # must go to `filename` and the dataset to `ds`.
            zarr_on_azure(
                filename="climate/regional_summaries_{}.zarr".format(scenario),
                mode="w",
                ds=region_ds,
                bucket="carbonplan-scratch",
            )

In [None]:
# From here on `scenarios` holds just the future scenario names.
scenarios = ["ssp245", "ssp370", "ssp585"]

# Read the stored regional summary of every future scenario...
ds_list = []
for scenario in scenarios:
    summary_path = "climate/regional_summaries_{}.zarr".format(scenario)
    ds = zarr_on_azure(summary_path, "r", bucket="carbonplan-scratch")
    ds_list.append(ds)

# ...and stack them along a `scenario` dimension.
scenario_dim = xr.Variable("scenario", scenarios)
future_ds = xr.concat(
    ds_list, dim=scenario_dim, coords="minimal", compat="override"
)

# The historical summary is read the same way, then reduced to the single
# "historical" entry with the `scenario` coordinate removed.
historical_summary = zarr_on_azure(
    "climate/regional_summaries_{}.zarr".format("historical"),
    "r",
    bucket="carbonplan-scratch",
)
historical_ds = historical_summary.sel(scenario="historical").drop("scenario")

In [None]:
# Ten-year windows of the future simulations, keyed by a label for the decade.
# Each entry starts with only its "time_slice".
decadal_dict = {
    label: {"time_slice": slice(first_year, last_year)}
    for label, first_year, last_year in (
        ("2020", "2015", "2024"),
        ("2050", "2045", "2054"),
        ("2080", "2075", "2084"),
    )
}

In [None]:
def temporal_summary(ds, aggregator, time_slice):
    """
    Reduce a time series to the average of its yearly aggregates.

    The data is restricted to `time_slice`, grouped by calendar year,
    aggregated within each year, and the yearly values are then averaged.

    Parameters
    ----------
    ds : xarray.DataArray or xarray.Dataset
        Object with a `time` coordinate.
    aggregator : {"sum", "mean"}
        How values are combined within each year.
    time_slice : slice
        Passed to `.sel(time=...)` to choose the period.

    Returns
    -------
    Same type as `ds`, with `time` reduced away.

    Raises
    ------
    ValueError
        If `aggregator` is not "sum" or "mean".
    """
    # Fail loudly on a typo instead of silently returning None.
    if aggregator not in ("sum", "mean"):
        raise ValueError(
            "aggregator must be 'sum' or 'mean', got {!r}".format(aggregator)
        )

    by_year = ds.sel(time=time_slice).groupby("time.year")
    if aggregator == "sum":
        yearly = by_year.sum()
    else:
        yearly = by_year.mean()
    return yearly.mean(dim="year")

In [None]:
def calculate_delta(historical, future, method):
    """
    Change of `future` relative to `historical`.

    Parameters
    ----------
    historical, future
        Values supporting arithmetic (numbers, arrays, xarray objects).
    method : {"absolute", "percentage"}
        "absolute" returns `future - historical`; "percentage" returns that
        difference as a percentage of `historical`.

    Returns
    -------
    The delta, of whatever type the arithmetic on the inputs produces.

    Raises
    ------
    ValueError
        If `method` is not "absolute" or "percentage".
    """
    if method == "absolute":
        return future - historical
    if method == "percentage":
        return (future - historical) / historical * 100
    # Fail loudly on a typo instead of silently returning None.
    raise ValueError(
        "method must be 'absolute' or 'percentage', got {!r}".format(method)
    )

In [None]:
# The historical baseline (1985-2014) is the same for every future decade, so
# it is summarized once, outside the loop. Precipitation is summed within each
# year; the temperatures are averaged.
baseline_slice = slice("1985", "2014")
historical_p = temporal_summary(historical_ds["pr"], "sum", baseline_slice)
historical_tasmax = temporal_summary(
    historical_ds["tasmax"], "mean", baseline_slice
)
historical_tasmin = temporal_summary(
    historical_ds["tasmin"], "mean", baseline_slice
)

for (period, time_dict) in decadal_dict.items():
    future_p = temporal_summary(future_ds["pr"], "sum", time_dict["time_slice"])
    future_tasmax = temporal_summary(
        future_ds["tasmax"], "mean", time_dict["time_slice"]
    )
    future_tasmin = temporal_summary(
        future_ds["tasmin"], "mean", time_dict["time_slice"]
    )

    # Store the deltas next to the decade's time slice: precipitation as a
    # percentage change, temperatures as absolute differences.
    decadal_dict[period]["delta_p"] = calculate_delta(
        historical_p, future_p, "percentage"
    )
    decadal_dict[period]["delta_tasmax"] = calculate_delta(
        historical_tasmax, future_tasmax, "absolute"
    )
    decadal_dict[period]["delta_tasmin"] = calculate_delta(
        historical_tasmin, future_tasmin, "absolute"
    )

Choose a nice color palette!


In [None]:
# Default color cycle for all following plots: three colors, which the scatter
# plots below use up one per scenario.
scenario_palette = ["teal", "olive", "tomato"]
mpl.rcParams["axes.prop_cycle"] = mpl.cycler(color=scenario_palette)

We have temperature and precip so we'll want to make an x-y scatter plot showing
the changes in climate for each of our different climate simulations. We'll
focus on the downscaled simulations since that's what is actually being fed into
the subsequent drought/insect/fire models. While repeating these analyses for
the raw vs. downscaled datasets would also be relevant, ideally the
downscaling/bias-correction method should preserve the precip/temp deltas and so
the difference between raw and downscaled deltas should be negligible.


In [None]:
# One figure per region, one panel per decade: change in temperature (x)
# against change in precipitation (y), one point per GCM and scenario.
for region in ["Pacific Northwest", "West", "Northeast", "Southeast"]:
    fig, axarr = plt.subplots(ncols=3, figsize=(14, 5))
    for (i, decade) in enumerate(["2020", "2050", "2080"]):
        ds_of_interest = decadal_dict[decade]
        for scenario in scenarios:
            # Materialize both deltas once per scenario and reuse them for
            # the scatter and for the labels.
            # NOTE(review): assumes `gcm` is the only dimension left after
            # selecting scenario and region, and that both deltas share the
            # same gcm order - confirm.
            delta_tasmax = (
                ds_of_interest["delta_tasmax"]
                .sel(scenario=scenario, region=region)
                .values
            )
            delta_p = (
                ds_of_interest["delta_p"]
                .sel(scenario=scenario, region=region)
                .values
            )

            axarr[i].scatter(x=delta_tasmax, y=delta_p, label=scenario)
            # Label each point with its 1-based position along the gcm axis.
            for (gcm_num, point) in enumerate(zip(delta_tasmax, delta_p)):
                axarr[i].annotate(str(gcm_num + 1), point)

        axarr[i].axhline(y=0, color="grey", lw=0.5)
        axarr[i].axvline(x=0, color="grey", lw=0.5)
        axarr[i].set_xlim(-1, 7)
        axarr[i].set_ylim(-10, 20)
        # Raw string: "\c" is not a valid escape sequence in a normal string.
        axarr[i].set_xlabel(r"Change in temperature [$^\circ$C]")
        axarr[i].set_ylabel("Change in precipitation [%]")
        axarr[i].set_title(decade, fontsize=16)

    # Legend on the last panel and figure title, each added once per figure
    # after all panels have been drawn.
    axarr[-1].legend()
    fig.suptitle(region, fontsize=20)

### TODO

- region bounds map
- 2015-2024, 2035-2044
- gcm codes
- show a little demo to talk about natural variability
