In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import scipy as sp
from carbonplan_forest_risks import load, setup, plot, fit, utils, prepare, collect
import xarray as xr
from carbonplan_forest_risks.utils import get_store
import altair as alt
# Apply the CarbonPlan light Altair theme to every chart in this notebook.
# NOTE(review): the "carbonplan_light" theme is presumably registered by one of
# the carbonplan imports above — confirm, since enabling an unregistered theme fails.
alt.themes.enable("carbonplan_light")

In [None]:
# Lift Altair's max-rows guard on inline data so the gridded maps plotted below
# can be embedded in the chart specs (they presumably exceed the default limit).
alt.data_transformers.disable_max_rows()

In [None]:
# Coarsening factor applied to both x and y of the NLCD grid.
coarsen = 4

# NLCD classes counted as forest. Presumably 41/42/43 = deciduous/evergreen/mixed
# forest and 90 = woody wetlands — confirm against the NLCD legend.
forest_bands = [41, 42, 43, 90]

# Per-cell sum over the selected bands for the 2001 "az" store.
forest_cover = load.nlcd(store="az", year=2001).sel(band=forest_bands).sum("band")

# Flag cells whose summed value exceeds 0.25, then average the 0/1 flags over
# each coarsen x coarsen window. Trailing rows/columns that do not fill a whole
# window are dropped ("trim").
mask = (
    (forest_cover > 0.25)
    .astype("float")
    .coarsen(x=coarsen, y=coarsen, boundary="trim")
    .mean()
)

In [None]:
# Historical fire dataset; its first time slice defines which grid cells carry
# data at all. Those cells form the mask used to restrict the future projections.
historical_fire = xr.open_zarr(
    get_store("carbonplan-scratch", "data/fire_historical_v3.zarr")
)
# True where the first historical time step is not NaN.
# `drop=True` discards the scalar `time` coordinate left by integer selection,
# so the mask is purely (y, x) and aligns with annual (year-indexed) data.
# This replaces the deprecated `DataArray.drop("time")`.
fire_mask = historical_fire.historical.isel(time=0, drop=True).notnull()

In [None]:
# Future fire projections: annual totals on the coarsened `mask` grid,
# restricted to cells that have historical fire data.
ds = (
    xr.open_zarr(get_store("carbonplan-scratch", "data/fire_future_v3.zarr"))
    # Overwrites the store's x/y coordinate values with those of `mask`.
    # Assumes both grids have the same shape and ordering — TODO confirm.
    .assign_coords({"x": mask.x, "y": mask.y})
    # Sum within each calendar year. Masking once after the aggregation gives the
    # same result as also masking before it: pre-masked cells would sum to 0
    # (NaNs are skipped) and be reset to NaN by the final `where` anyway.
    .groupby("time.year")
    .sum()
    .where(fire_mask)
    .compute()
)

In [None]:
# Second batch of future fire projections, aggregated the same way as `ds`:
# annual sums on the `mask` grid, restricted to cells with historical fire data.
remaining_store = get_store(
    "carbonplan-scratch", "data/fire_future_v3_remaining.zarr"
)
annual_remaining = (
    xr.open_zarr(remaining_store)
    .assign_coords(x=mask.x, y=mask.y)
    .groupby("time.year")
    .sum()
)
ds_remaining = annual_remaining.where(fire_mask).compute()

In [None]:
# Combine both batches of future projections into a single Dataset.
# Re-running this cell is harmless: the variables merged in again are identical.
ds = xr.merge((ds, ds_remaining))

In [None]:
# Per-cell mean of the annual values over 2060–2089 (inclusive), for every
# variable in `ds`. The `year` coordinate produced by `groupby("time.year")`
# holds integers, so slice with integer bounds. String bounds only work through
# fragile int<->str promotion and lexicographic comparison.
future_maps = ds.sel(year=slice(2060, 2089)).mean(dim="year").compute()

# Spatial mean per year for every variable; NaN cells outside the fire mask are skipped.
future_ts = ds.mean(dim=["x", "y"]).compute()

In [None]:
# Climate models (GCMs) in the ensemble. Each name is paired with an (index, 0)
# tuple — presumably a subplot position; the cells below only use the names.
gcm_names = [
    "MRI-ESM2-0",
    "MIROC-ES2L",
    "MPI-ESM1-2-LR",
    "ACCESS-ESM1-5",
    "ACCESS-CM2",
    "CanESM5-CanOE",
]
gcms = [(name, (position, 0)) for position, name in enumerate(gcm_names)]

# Panel titles for the three risk metrics.
# TODO: the units of the two mortality metrics are blank ("[]") — fill them in.
titles = [
    "Burn area\n[fraction/year]",
    "Drought mortality\n[]",
    "Biotic agent mortality\n[]",
]

In [None]:
# Multi-model (GCM-ensemble) mean of the 2060–2089 maps for each SSP scenario.
# Variables are named "<GCM>_<scenario>". Each scenario's runs are selected by
# name, stacked along a temporary "vars" dimension, and averaged across models.
# Reuses `future_maps` from the earlier cell instead of recomputing the same
# 2060–2089 mean a second time.
scenario_dict = {
    scenario: (
        future_maps[[f"{gcm}_{scenario}" for gcm, _position in gcms]]
        .to_array(dim="vars")
        .mean(dim="vars")
    )
    for scenario in ["ssp245", "ssp370", "ssp585"]
}

In [None]:
# One ensemble-mean burn-area map per SSP scenario, laid out left to right
# with a shared color range so the panels are directly comparable.
scenario_charts = []
for scenario in ["ssp245", "ssp370", "ssp585"]:
    fire = plot.fire.simple_map(
        scenario_dict[scenario],
        clim=(0.001, 0.03),
        clabel="Burn area (fraction/year)",
        cmap=["#FFC59E", "#C4550A"],
        title1=scenario,
    )
    scenario_charts.append(fire)
figure = alt.hconcat(*scenario_charts)
figure