In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import scipy as sp
from carbonplan_forest_risks import load, setup, plot, fit, utils, prepare, collect
import xarray as xr
from carbonplan_forest_risks.utils import get_store
import altair as alt
import rioxarray
from carbonplan.data import cat
from carbonplan_styles.mpl import get_colormap
import cartopy.crs as ccrs
import cartopy
import cartopy.feature as cfeature
import matplotlib.pyplot as plt
import matplotlib
from mpl_toolkits.axes_grid1 import make_axes_locatable

# NOTE(review): this rebinds `utils`, replacing the `utils` imported from
# carbonplan_forest_risks in the import list above. No visible cell references
# `utils`, so confirm which module is wanted and alias or drop the other.
from carbonplan_data import utils
# Turn off Altair's maximum-row check for datasets passed to charts.
alt.data_transformers.disable_max_rows()

In [None]:
# Run configuration consumed by the data-loading cells below.
store = "az"  # forwarded as `store=` to load.nlcd and load.mtbs
tlim = ("1984", "2018")  # forwarded as `tlim=` to load.mtbs
coarsen = 4  # forwarded as `coarsen=` to load.mtbs

In [None]:
# Results container: one initially empty dict per data source. Later cells
# store the "raw", "annual" and "seasonal" entries under each source.
ds_dict = {source_name: {} for source_name in ("Observed", "Modeled")}

In [None]:
# Resolve the zarr store holding the fire results, then open it lazily and
# keep it as the raw "Modeled" dataset.
modeled_store = get_store(
    "carbonplan-forests", "risks/results/paper/fire_terraclimate_v6.zarr"
)
ds_dict["Modeled"]["raw"] = xr.open_zarr(modeled_store)

In [None]:
# Mask that is True wherever the first time step of the modeled data is not
# NaN. It is applied to both sources before the spatial mean is taken below.
# `drop_vars` is the supported replacement for the deprecated `Dataset.drop`
# when removing the leftover scalar `time` coordinate.
fire_mask = ~np.isnan(
    ds_dict["Modeled"]["raw"].isel(time=0).drop_vars("time")
)

In [None]:
# Forest mask: 1.0 where the NLCD 2001 values summed over the selected bands
# exceed 0.5, and 0.0 elsewhere.
# NOTE(review): bands 41/42/43/90 are presumably the NLCD forest classes plus
# woody wetlands — confirm against the NLCD legend.
forest_bands = [41, 42, 43, 90]
forest_cover = (
    load.nlcd(store=store, year=2001).sel(band=forest_bands).sum("band")
)
forest_mask = (forest_cover > 0.5).astype("float")

In [None]:
# Load the MTBS observations restricted by the forest mask, then overwrite
# their coordinates with the modeled dataset's — presumably so both sources
# share identical coordinate values (TODO confirm). Only the "monthly" variable
# is kept as the raw "Observed" entry.
modeled_raw = ds_dict["Modeled"]["raw"]
observed = load.mtbs(store=store, coarsen=coarsen, tlim=tlim, mask=forest_mask)
# x/y are replaced first, then lat/lon, in two separate steps.
observed = observed.assign_coords(x=modeled_raw.x, y=modeled_raw.y)
observed = observed.assign_coords(lat=modeled_raw.lat, lon=modeled_raw.lon)
ds_dict["Observed"]["raw"] = observed["monthly"]

In [None]:
# Reduce each source to two domain-mean series: one value per year ("annual")
# and one value per calendar month ("seasonal").
# The loop variable is `source`, not `setup`, so it does not rebind the `setup`
# module imported from carbonplan_forest_risks at the top of the notebook.
for source in ("Observed", "Modeled"):
    raw = ds_dict[source]["raw"]
    # Annual: sum within each calendar year, mask, then average over x and y.
    ds_dict[source]["annual"] = (
        raw.groupby("time.year")
        .sum()
        .where(fire_mask)
        .mean(dim=["x", "y"])
        .compute()
    )
    # Seasonal: mean of each calendar month across years, mask, then average
    # over x and y.
    ds_dict[source]["seasonal"] = (
        raw.groupby("time.month")
        .mean()
        .where(fire_mask)
        .mean(dim=["x", "y"])
        .compute()
    )

In [None]:
# Plot settings: a y-axis label per panel and a line color per data source.
plot_params = dict(
    annual=dict(y_label="Annual burn area\n(fraction/year)"),
    seasonal=dict(y_label="Monthly burn area\n(fraction/month)"),
    colors=dict(Modeled="#E87A3D", Observed="grey"),
)

In [None]:
# Global matplotlib styling for the figure produced below.
matplotlib.rc("font", family="sans-serif")
# NOTE(review): this sets the *serif* font list while the active family above
# is sans-serif, so "Helvetica Neue" is probably never selected. `font.sans-serif`
# may have been intended — confirm before changing, since it alters rendering.
matplotlib.rc("font", serif="Helvetica Neue")
# Use a real boolean: the string "false" is truthy in Python and only worked
# because matplotlib parses it when validating the setting.
matplotlib.rc("text", usetex=False)
# "svg.fonttype": "none" writes text as text elements in SVG output rather than
# converting glyphs to paths.
matplotlib.rcParams.update({"font.size": 14, "svg.fonttype": "none"})

In [None]:
# Two stacked panels comparing observed and modeled burn area:
# top = annual series, bottom = mean seasonal cycle. Saved as supp2.svg.
month_labels = [
    "Jan",
    "Feb",
    "Mar",
    "Apr",
    "May",
    "Jun",
    "Jul",
    "Aug",
    "Sep",
    "Oct",
    "Nov",
    "Dec",
]

fig, axarr = plt.subplots(nrows=2, figsize=(8, 8))
# The loop variable is `source`, not `setup`, so it does not rebind the `setup`
# module imported from carbonplan_forest_risks at the top of the notebook.
for source in ("Observed", "Modeled"):
    line_color = plot_params["colors"][source]
    ds_dict[source]["annual"].historical.plot(
        ax=axarr[0], color=line_color, label=source
    )
    ds_dict[source]["seasonal"].historical.plot(
        ax=axarr[1], color=line_color, label=source
    )

# Axis labels are applied once, after all lines are drawn, so these are the
# final values on each axis.
axarr[0].set_ylabel(plot_params["annual"]["y_label"])
axarr[1].set_ylabel(plot_params["seasonal"]["y_label"])
axarr[0].set_xlabel("")
axarr[1].set_xlabel("")
axarr[0].legend()
# Months are numbered 1-12 on the seasonal axis; show them by abbreviation.
axarr[1].set_xticks(np.arange(1, 13))
axarr[1].set_xticklabels(month_labels)
# Save through the figure handle rather than pyplot's implicit current figure.
fig.savefig("supp2.svg", format="svg", bbox_inches="tight")