In [None]:
%load_ext autoreload
%autoreload 2
import xarray as xr
import os
from carbonplan_forest_risks.utils import get_store, albers_ak_transform
import numpy as np
import warnings
from carbonplan_forest_risks.load import terraclim
import pandas as pd
import matplotlib.pyplot as plt

warnings.filterwarnings('ignore')
account_key = os.environ.get('BLOB_ACCOUNT_KEY')

In [None]:
gcms = [
    ("CanESM5-CanOE", "r3i1p2f1"),
    ("MIROC-ES2L", "r1i1p1f2"),
    ("ACCESS-CM2", "r1i1p1f1"),
    ("ACCESS-ESM1-5", "r10i1p1f1"),
    ("MRI-ESM2-0", "r1i1p1f1"),
    ("MPI-ESM1-2-LR", "r10i1p1f1"),
]

In [None]:
v1_template = "cmip6/quantile-mapping/conus/4000m/monthly/{}.{}.{}.zarr"

v2_template = "cmip6/quantile-mapping-v2/conus/4000m/monthly/{}.{}.{}.zarr"

In [None]:
account_key = os.environ.get("BLOB_ACCOUNT_KEY")

In [None]:
models = []
for scenario in ["historical", "ssp245", "ssp370", "ssp585"]:
    for (gcm, ensemble_member) in gcms:
        models.append("{}-{}".format(gcm, scenario))
path = get_store(
    "carbonplan-downscaling",
    zarr_template.format(gcms[0][0], "historical", gcms[0][1]),
    account_key=account_key,
)
ds = xr.open_zarr(path)
df = pd.DataFrame(index=models, columns=ds.data_vars)

In [None]:
for scenario in ["historical", "ssp245", "ssp370", "ssp585"]:
    for (gcm, ensemble_member) in gcms:
        path = get_store(
            "carbonplan-downscaling",
            zarr_template.format(gcm, scenario, ensemble_member),
            account_key=account_key,
        )
        ds = xr.open_zarr(path)
        for var in ds.data_vars:
            df.loc["{}-{}".format(gcm, scenario), var] = ds[var].isnull().sum().values

In [None]:
ds_v1 = xr.open_zarr(
    get_store(
        "carbonplan-downscaling",
        zarr_template.format(gcms[0][0], "historical", gcms[0][1]),
        account_key=account_key,
    )
)

In [None]:
ds_v1.pdsi.mean(dim="time", skipna=False).isnull().plot()

In [None]:
terraclimate_v1 = xr.open_zarr(
    get_store(
        "carbonplan-downscaling",
        "obs/conus/4000m/{}/terraclimate_plus.zarr".format("monthly"),
    )
)

In [None]:
nans_v1 = terraclimate_v1.pdsi.isnull().sum(dim="time").compute()
nans_v1.plot()

In [None]:
terraclimate_v3 = xr.open_zarr(
    get_store(
        "carbonplan-downscaling",
        "obs/conus/4000m/{}/terraclimate_plus_v3.zarr".format("monthly"),
    )
)

In [None]:
nans_v3 = terraclimate_v3.pdsi.isnull().sum(dim="time").compute()
nans_v3.plot()

In [None]:
ds_v2 = xr.open_zarr(
    get_store(
        "carbonplan-downscaling",
        v2_template.format("ACCESS-ESM1-5", "ssp245", "r10i1p1f1"),
        account_key=account_key,
    )
)

In [None]:
ds2.vpd.mean(dim="time", skipna=False).plot()

In [None]:
ds2.pdsi.isnull().sum().values

In [None]:
ds_v1 = xr.open_zarr(
    get_store(
        "carbonplan-downscaling",
        v1_template.format("ACCESS-ESM1-5", "ssp245", "r10i1p1f1"),
        account_key=account_key,
    )
)

In [None]:
ds_v1.vap.min(dim="time", skipna=False).plot()

In [None]:
ds_v1.pdsi.isnull().sum().values

In [None]:
qm_template = "cmip6/quantile-mapping-v3/conus/4000m/monthly/{}.{}.{}.zarr"

terra_template = "cmip6/quantile-mapping-v3/conus/4000m/monthly/{}.{}.{}.zarr"

In [None]:
ds_v3_qm_vars = xr.open_zarr(
    get_store(
        "carbonplan-downscaling",
        qm_template.format("ACCESS-ESM1-5", "ssp245", "r10i1p1f1"),
        account_key=account_key,
    )
)

In [None]:
ds_v3_terra_vars = xr.open_zarr(
    get_store(
        "carbonplan-scratch",
        terra_template.format("ACCESS-ESM1-5", "ssp245", "r10i1p1f1"),
        account_key=account_key,
    )
)

In [None]:
ds_v3 = xr.merge([ds_v3_qm_vars, ds_v3_terra_vars])

In [None]:
# Check for nans

In [None]:
df = pd.DataFrame(index=ds_v3.data_vars, columns=["nulls", "negatives"])

In [None]:
for var in ds_v3.data_vars:
    print(var)
    df.loc[var, "nulls"] = ds_v3[var].isnull().sum().values

In [None]:
# Check aphysical values

In [None]:
for var in ds_v3.data_vars:
    print(var)
    df.loc[var, "negatives"] = (ds_v3[var] < 0).sum().values

In [None]:
df["greater than 1"] = np.nan

In [None]:
for var in ds_v3.data_vars:
    print(var)
    df.loc[var, "greater than 1"] = (ds_v3[var] > 1).sum().values

# Check that all PDSI values lie between -16 and 16


In [None]:
# check outside -16 + 16 for pdsi
(ds_v3["pdsi"] < -16).sum().values

In [None]:
(ds_v3["pdsi"] > 16).sum().values

In [None]:
assert (ds_v3["pdsi"] > 16).sum().values == 0

In [None]:
# Check the minimum value of vapor pressure
ds_v3["vap"].min(dim="time").plot()

In [None]:
ds_v1["vap"].min(dim=["x", "y"]).plot(label="v1")
ds_v3["vap"].min(dim=["x", "y"]).plot(label="v3")
plt.legend()

In [None]:
fig, axarr = plt.subplots(ncols=2, figsize=(8, 4))
ds_v1["rh"].min(dim="time").plot(ax=axarr[0], vmin=0, vmax=0.01)
axarr[0].set_title("v1")
ds_v3["rh"].min(dim="time").plot(ax=axarr[1], vmin=0, vmax=0.01)
axarr[1].set_title("v3")

In [None]:
ds_v1["rh"].min(dim=["x", "y"]).plot(label="v1", alpha=0.8)
ds_v3["rh"].min(dim=["x", "y"]).plot(label="v3", alpha=0.8)
plt.legend()

In [None]:
ds_v1["rh"].max(dim=["x", "y"]).plot(label="v1", alpha=0.8)
ds_v3["rh"].max(dim=["x", "y"]).plot(label="v3", alpha=0.8)
plt.legend()

# Count of rh == 1 instances


In [None]:
(ds_v1["rh"] == 1).sum(dim=["x", "y"]).plot(label="v1", alpha=0.8)
(ds_v3["rh"] == 1).sum(dim=["x", "y"]).plot(label="v3", alpha=0.8)
plt.legend()

# Check the variables Bill is using


In [None]:
# Confirm that both versions cover the same decades, then map PDSI for each decade

In [None]:
coarsened_v3 = (
    ds_v3.sel(time=slice("2020", "2099"))[["ppt", "pdsi"]].coarsen(time=120).min().compute()
)
coarsened_v1 = (
    ds_v1.sel(time=slice("2020", "2099"))[["ppt", "pdsi"]].coarsen(time=120).min().compute()
)

# minimum decadal PDSI


In [None]:
(coarsened_v3.pdsi - coarsened_v1.pdsi).plot(
    x="lon", y="lat", col="time", col_wrap=3, vmin=-16, vmax=16, cmap="RdBu"
)

In [None]:
coarsened_v1.pdsi.plot(x="lon", y="lat", col="time", col_wrap=3, vmin=-16, vmax=16, cmap="RdBu")

In [None]:
coarsened_v3.pdsi.plot(x="lon", y="lat", col="time", col_wrap=3, vmin=-16, vmax=16, cmap="RdBu")

# minimum decadal precipitation


In [None]:
(coarsened_v3.ppt - coarsened_v1.ppt).plot(
    x="lon", y="lat", col="time", col_wrap=3, vmin=-20, vmax=20, cmap="RdBu"
)

In [None]:
coarsened_v1.ppt.plot(x="lon", y="lat", col="time", col_wrap=3)

In [None]:
coarsened_v3.ppt.plot(x="lon", y="lat", col="time", col_wrap=3)

# means


In [None]:
coarsened_v3_mean = (
    ds_v3.sel(time=slice("2020", "2099"))[["ppt", "vpd", "pdsi", "tmin", "tmean"]]
    .coarsen(time=120)
    .mean()
    .compute()
)
coarsened_v1_mean = (
    ds_v1.sel(time=slice("2020", "2099"))[["ppt", "vpd", "pdsi", "tmin", "tmean"]]
    .coarsen(time=120)
    .mean()
    .compute()
)

# mean precip


In [None]:
((coarsened_v3_mean.ppt - coarsened_v1_mean.ppt) / coarsened_v1_mean.ppt * 100).plot(
    x="lon", y="lat", col="time", col_wrap=3, vmin=-10, vmax=10, cmap="RdBu"
)

In [None]:
coarsened_v1_mean.ppt.plot(x="lon", y="lat", col="time", col_wrap=3)

In [None]:
coarsened_v3_mean.ppt.plot(x="lon", y="lat", col="time", col_wrap=3)

# mean VPD


In [None]:
(coarsened_v3_mean.vpd - coarsened_v1_mean.vpd).plot(
    x="lon", y="lat", col="time", col_wrap=3, vmin=-0.02, vmax=0.02, cmap="RdBu"
)

# mean PDSI


In [None]:
(coarsened_v3_mean.pdsi - coarsened_v1_mean.pdsi).plot(
    x="lon", y="lat", col="time", col_wrap=3, vmin=-16, vmax=16, cmap="RdBu"
)

# maxes


In [None]:
coarsened_v3_max = (
    ds_v3.sel(time=slice("2020", "2099"))[["def", "vpd"]].coarsen(time=120).max().compute()
)
coarsened_v1_max = (
    ds_v1.sel(time=slice("2020", "2099"))[["def", "vpd"]].coarsen(time=120).max().compute()
)

# cwd max


In [None]:
(coarsened_v3_max["def"] - coarsened_v1_max["def"]).plot(
    x="lon", y="lat", col="time", col_wrap=3, cmap="RdBu"
)

In [None]:
coarsened_v3_max["def"].plot(x="lon", y="lat", col="time", col_wrap=3)

In [None]:
coarsened_v1_max["def"].plot(x="lon", y="lat", col="time", col_wrap=3)

In [None]:
ds_v1.sel(x=500000, y=1000000, method="nearest")["def"].plot(alpha=0.8, label="v1")
ds_v3.sel(x=500000, y=1000000, method="nearest")["def"].plot(alpha=0.8, label="v3")
plt.legend()

In [None]:
ds_v1.sel(x=500000, y=1000000, method="nearest")["pet"].plot(alpha=0.8, label="v1")
ds_v3.sel(x=500000, y=1000000, method="nearest")["pet"].plot(alpha=0.8, label="v3")
plt.legend()

In [None]:
ds_v1.sel(x=500000, y=1000000, method="nearest")["tmean"].plot(alpha=0.8, label="v1")
ds_v3.sel(x=500000, y=1000000, method="nearest")["tmean"].plot(alpha=0.8, label="v3")
plt.legend()

In [None]:
ds_v1.sel(x=500000, y=1000000, method="nearest")["srad"].plot(alpha=0.8, label="v1")
ds_v3.sel(x=500000, y=1000000, method="nearest")["srad"].plot(alpha=0.8, label="v3")
plt.legend()

In [None]:
ds_v1.sel(x=500000, y=1000000, method="nearest")["aet"].plot(alpha=0.8, label="v1")
ds_v3.sel(x=500000, y=1000000, method="nearest")["aet"].plot(alpha=0.8, label="v3")
plt.legend()

In [None]:
fig, axarr = plt.subplots(nrows=len(ds_v3.data_vars), figsize=(8, 50))
for i, var in enumerate(ds_v3.data_vars):
    ds_v1.sel(x=500000, y=1000000, method="nearest")[var].plot(ax=axarr[i], alpha=0.8, label="v1")
    ds_v3.sel(x=500000, y=1000000, method="nearest")[var].plot(ax=axarr[i], alpha=0.8, label="v3")
    axarr[i].legend()
plt.tight_layout()