# Case Study


This notebook is intended to evaluate the FMC predictions in the days preceding a large fire event.

Details:
- Alexander Mountain Fire
- FMC forecasts for the 72-hour period preceding the fire, to see whether the model captured pre-fire danger
- During actual fire, coupled atmosphere-fire dynamics would be needed, outside scope of paper

## Setup

In [None]:
import os
import os.path as osp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error
import sys
import requests
# import shapely.geometry as geom
# import shapely.ops as ops

sys.path.append('../src')
from utils import time_range, read_yml, read_pkl, str2time
from viz import plot_styles, plot_one

### Read Results


In [None]:
# Directory holding the saved forecast outputs (also used by the next cell)
ml_forecast_dir = "../outputs/forecast_outputs"

# Forecast table; later cells read its stid, rep, date_time, fm, preds
# and squared_error columns.
rnn_preds_path = osp.join(ml_forecast_dir, "rnn_preds.csv")
rnn = pd.read_csv(rnn_preds_path)

In [None]:
# Station data loaded with the project helper read_pkl; later cells index it
# by station ID and read the "loc" entry (lat, lon, elev) of each station.
ml_data_path = osp.join(ml_forecast_dir, "ml_data.pkl")
ml_data = read_pkl(ml_data_path)

## Define study domain

Time: __
- Source: __

Spatial Domain: __
- Source: __

In [None]:
# Query the interagency fire-perimeter history for 2024 Colorado fires,
# ordered by size. The printed table shows Alexander Mountain as the
# largest fire in 2024.

url = (
    "https://services3.arcgis.com/T4QMspbfLg3qTGWY/ArcGIS/rest/services/"
    "InterAgencyFirePerimeterHistory_All_Years_View/FeatureServer/0/query"
)

where = (
    "FIRE_YEAR_INT = 2024 AND "
    "UNIT_ID LIKE 'CO%'"        # restrict to Colorado
)

params = {
    "where": where,
    "outFields": "*",
    "returnGeometry": "true",
    "outSR": 4326,
    "orderByFields": "GIS_ACRES DESC",   # sort largest -> smallest
    "resultRecordCount": 50,             # return top 50 if you want
    "f": "json",
}

# Without a timeout, requests waits indefinitely on a stalled connection.
resp = requests.get(url, params=params, timeout=60)
resp.raise_for_status()
data = resp.json()

# ArcGIS REST services can report a failed query inside the JSON body while
# still answering with HTTP 200, which raise_for_status() does not catch.
if "error" in data:
    raise RuntimeError(f"ArcGIS query failed: {data['error']}")

# Each feature carries an "attributes" dict; one table row per feature.
features = data.get("features", [])
rows = [f["attributes"] for f in features]
df = pd.DataFrame(rows)

print(df[["INCIDENT", "FIRE_YEAR_INT", "UNIT_ID", "GIS_ACRES"]].head(10))

In [None]:
# Fetch the perimeter record for the Alexander Mountain fire only.
url = (
    "https://services3.arcgis.com/T4QMspbfLg3qTGWY/ArcGIS/rest/services/"
    "InterAgencyFirePerimeterHistory_All_Years_View/FeatureServer/0/query"
)
where = (
    "FIRE_YEAR_INT = 2024 AND "
    "INCIDENT = 'Alexander Mountain' AND "
    "UNIT_ID = 'COARF'"
)
params = {
    "where": where,
    "outFields": "*",
    "returnGeometry": "true",   # perimeter polygon
    "outSR": 4326,              # lat/lon, optional but convenient
    "f": "json",
}

# Without a timeout, requests waits indefinitely on a stalled connection.
resp = requests.get(url, params=params, timeout=60)
resp.raise_for_status()
data = resp.json()

# ArcGIS REST services can report a failed query inside the JSON body while
# still answering with HTTP 200, which raise_for_status() does not catch.
if "error" in data:
    raise RuntimeError(f"ArcGIS query failed: {data['error']}")

In [None]:
# Get a lon/lat bounding box.
# The perimeter geometry is a list of rings, each a list of points whose
# first two entries are read as (lon, lat). Scan every point for the
# min/max values to produce (south, west, north, east).

fire_features = data.get("features", [])
if not fire_features:
    # Fail with a clear message instead of an IndexError on an empty result.
    raise ValueError(
        "Fire perimeter query returned no features; cannot build a bounding box"
    )

geom = fire_features[0]["geometry"]["rings"]
all_points = [pt for ring in geom for pt in ring]

lons = [p[0] for p in all_points]
lats = [p[1] for p in all_points]

south = min(lats)
north = max(lats)
west  = min(lons)
east  = max(lons)

bbox_fire = (south, west, north, east)
print(f"Fire Bounding Box: {bbox_fire}")

# Buffer (degrees). The latitude buffer is scaled by 0.67 relative to the
# longitude buffer.
# NOTE(review): presumably to keep the box roughly square on the ground at
# this latitude -- confirm.
buff = 0.5
bbox = (south-buff*.67, west-buff, north+buff*.67, east+buff)

print(f"Buffered Bounding Box: {bbox}")

In [None]:
# The fire start is given in Mountain time and then expressed in UTC.
firestart_naive = pd.to_datetime("2024-07-28 00:00:00")
firestart = firestart_naive.tz_localize("America/Denver")
print(f"Fire Start time (Mountain): {firestart}")

# Analysis window: from 72 hours before the fire start through the hour
# before it, at hourly steps.
fire_utc = firestart.tz_convert("UTC")
analysis_start_utc = fire_utc - pd.Timedelta(hours=72)
analysis_end_utc = fire_utc - pd.Timedelta(hours=1)
times = time_range(analysis_start_utc, analysis_end_utc, freq="1h")

print(f"Fire Start (UTC): {fire_utc}")
print(f"Analysis Start (UTC): {analysis_start_utc}")
print(f"Analysis End (UTC): {analysis_end_utc}")
print(f"{len(times)=}")

## Summarize Coverage

In [None]:
# Collect the stations lying strictly inside the buffered bounding box.
# bbox is ordered (south, west, north, east).
sts = []
for st in ml_data:
    loc = ml_data[st]["loc"]
    in_bbox = (
        bbox[0] < loc["lat"] < bbox[2]
        and bbox[1] < loc["lon"] < bbox[3]
    )
    if in_bbox:
        sts.append({
            'stid': st,
            'lon': loc["lon"],
            'lat': loc["lat"],
            'elev': loc['elev']
        })

# Naming the columns explicitly keeps df.stid available even when no station
# falls inside the box (an empty list would otherwise give a column-less frame).
df = pd.DataFrame(sts, columns=["stid", "lon", "lat", "elev"])


In [None]:
# Keep only the forecasts for in-domain stations inside the analysis window.
# utc=True always yields a tz-aware column, so the comparisons against the
# tz-aware analysis bounds work whether or not the CSV timestamps carry an
# offset (timestamps without an offset are interpreted as UTC).
rnn["date_time"] = pd.to_datetime(rnn.date_time, utc=True)

in_domain = rnn.stid.isin(df.stid)
in_window = (rnn.date_time >= analysis_start_utc) & (rnn.date_time <= analysis_end_utc)
rnn2 = rnn[in_domain & in_window]

# Drop stations that have no forecasts in the window.
df = df[df.stid.isin(rnn2.stid)].reset_index(drop=True)

print(f"Number of Stations in study region: {df.shape[0]}")

In [None]:
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from cartopy.io.img_tiles import StadiaMapsTiles

# Basemap tile source. The API key is taken from the STADIA_API_KEY
# environment variable when it is set; otherwise the key originally committed
# with this notebook is used so existing runs keep working.
# NOTE(review): a key committed to source is effectively public -- consider
# rotating it and supplying it through the environment only.
tile_provider = StadiaMapsTiles(
    os.environ.get("STADIA_API_KEY", "e3df6cd5-1ba5-4749-8587-f79893428032"),
    style="stamen_terrain"
)

def plot_points(df, zoom=9, tile_provider=None):
    """
    Plot point locations over a terrain basemap.

    df   : DataFrame with "lon" and "lat" columns (degrees)
    zoom : tile zoom level passed to ax.add_image()
    tile_provider : optional cartopy image-tile source to draw as the
        basemap. When None, a Stadia Maps "stamen_terrain" source is built,
        reading the API key from the STADIA_API_KEY environment variable and
        falling back to the key originally committed with this notebook.

    Returns (fig, ax). Raises ValueError when df has no rows.
    """
    import os

    import matplotlib.pyplot as plt
    import cartopy.crs as ccrs
    # Only StadiaMapsTiles is needed. The previous import of the unused,
    # deprecated Stamen class could make this function fail at call time.
    from cartopy.io.img_tiles import StadiaMapsTiles

    if len(df) == 0:
        # min()/max() of an empty column are NaN, which gives an invalid extent.
        raise ValueError("plot_points() needs at least one point to plot")

    proj = ccrs.PlateCarree()

    fig, ax = plt.subplots(
        figsize=(8, 6),
        subplot_kw={"projection": proj},
    )

    # Bounding box from data + 10% padding. A zero span (single point or
    # identical coordinates) falls back to a fixed 0.5 degree pad.
    lon_min, lon_max = df["lon"].min(), df["lon"].max()
    lat_min, lat_max = df["lat"].min(), df["lat"].max()
    pad_lon = (lon_max - lon_min) * 0.1 or 0.5
    pad_lat = (lat_max - lat_min) * 0.1 or 0.5

    ax.set_extent(
        [lon_min - pad_lon, lon_max + pad_lon,
         lat_min - pad_lat, lat_max + pad_lat],
        crs=proj,
    )

    if tile_provider is None:
        # NOTE(review): the fallback key is committed to source; consider
        # rotating it and supplying it through the environment only.
        tile_provider = StadiaMapsTiles(
            os.environ.get("STADIA_API_KEY", "e3df6cd5-1ba5-4749-8587-f79893428032"),
            style="stamen_terrain"
        )
    ax.add_image(tile_provider, zoom)

    # Scatter the points
    ax.scatter(
        df["lon"],
        df["lat"],
        s=30,
        transform=proj,
        edgecolor="black",
        linewidth=0.8,
        color="cyan"
    )

    # Labelled ticks on the left/bottom only, without drawing the grid lines.
    gl = ax.gridlines(
        draw_labels=True,
        x_inline=False,
        y_inline=False
    )

    gl.top_labels = False
    gl.right_labels = False
    gl.xlines = False
    gl.ylines = False

    return fig, ax

In [None]:
from shapely.geometry import Polygon

def plot_fire(ax, geom, **kwargs):
    """
    Draw a fire perimeter on an existing map.

    ax   : the Cartopy GeoAxes already created by plot_points()
    geom : ESRI JSON 'rings' list; each ring is a list of [lon, lat] points
    kwargs : passed to ax.add_geometries() (facecolor, edgecolor, alpha, etc.)

    Returns ax. Raises ValueError when geom is empty.
    """
    from shapely.geometry import LinearRing, Point

    if not geom:
        raise ValueError("geom must contain at least one ring")

    # ESRI JSON orients exterior rings clockwise and holes counter-clockwise,
    # so a multipart perimeter can hold several exterior rings. Treating every
    # ring after the first as a hole would misbuild such perimeters.
    shells, holes = [], []
    for ring in geom:
        (holes if LinearRing(ring).is_ccw else shells).append(ring)

    if not shells:
        # No clockwise ring at all: the orientation convention is not being
        # followed, so fall back to "first ring is the outer boundary, the
        # rest are holes".
        polys = [Polygon(geom[0], holes=list(geom[1:]) or None)]
    else:
        # Attach each hole to the first exterior ring that contains it.
        shell_polys = [Polygon(shell) for shell in shells]
        holes_by_shell = [[] for _ in shells]
        for hole in holes:
            probe = Point(hole[0])
            for owned, shell_poly in zip(holes_by_shell, shell_polys):
                # intersects() also accepts a vertex lying on the boundary.
                if shell_poly.intersects(probe):
                    owned.append(hole)
                    break
        polys = [
            Polygon(shell, holes=owned or None)
            for shell, owned in zip(shells, holes_by_shell)
        ]

    # Default styling if not provided
    kwargs.setdefault("facecolor", "red")
    kwargs.setdefault("edgecolor", "black")
    kwargs.setdefault("alpha", 0.6)

    ax.add_geometries(
        polys,
        crs=ccrs.PlateCarree(),
        **kwargs
    )

    return ax

In [None]:
# fig, ax = plot_points(df)
# plot_fire(ax, geom)

# plt.show()
# plt.savefig("../outputs/alexander_map.png", dpi=300, bbox_inches="tight")

In [None]:
# Data coverage summary: stations remaining in df after the filtering above
n_stations = df.shape[0]
print(f"Number of Stations in study region: {n_stations}")

In [None]:
# For every station, count the distinct forecast replications and the number
# of forecast rows (hours) belonging to the first replication listed.
nreps = []
nhours = []
for st in df.stid:
    station_preds = rnn2[rnn2.stid == st]
    rep_ids = station_preds.rep.unique()
    first_rep_rows = station_preds[station_preds.rep == rep_ids[0]]
    nreps.append(rep_ids.shape[0])
    nhours.append(first_rep_rows.shape[0])

# Add to station df and display
df_nice = df.assign(**{"Forecasted Hours": nhours})
df_nice.columns = ["STID", "Longitude", "Latitude", "Elevation", "Forecasted Hours"]

df_nice

## Join all weather

For stations, get weather time series from HRRR forecasts at those locations

In [None]:
from data_funcs import get_sts_and_times

# Station IDs kept for the case study
sts = df["stid"].tolist()

# NOTE(review): presumably subsets ml_data to these stations and the analysis
# hours -- see data_funcs.get_sts_and_times. Later cells read dat[stid]["data"].
dat = get_sts_and_times(ml_data, sts, times, data_dict="data")

## Analyze Accuracy

Each station has a set of forecasts for the time period with a number of replications that reflect uncertainty due to random weight initialization and train/test split. For this analysis we will look at the distribution of error for this set of predictions. This will be a much wider uncertainty than the +/- bounds from the overall error, as the overall error was averaged over replications. This analyzes the spread of the ~53 set of forecasts without averaging. NOTE: the hidden state was reset every 48 hours for forecasting, but we will ignore that here. Maintaining the same hidden state over the whole 72 hour period of interest might improve the forecast accuracy, so this is a cautious approach.

For analysis we will compare:
* The distribution of forecast RMSE (median, high/low, and 95% ci) for each station
* We will plot the forecast with the median error along with the 95% ci 

In [None]:
# RMSE of each (station, replication) forecast: mean of the squared errors
# within the group, then the square root.
mse_by_rep = rnn2.groupby(["stid", "rep"])["squared_error"].mean()
rmse = mse_by_rep ** 0.5

In [None]:
import math

# One RMSE histogram per station, laid out two panels per row.
stids = rmse.index.get_level_values("stid").unique()
n = len(stids)

ncols = 2
nrows = math.ceil(n / ncols)

fig, axes = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(10, 3 * nrows),
    sharex=True,
    sharey=True,
)
axes = axes.flatten()

# Common x-limits: the overall RMSE range padded by 0.2 on each side
all_vals = rmse.values
xmin = all_vals.min() - .2
xmax = all_vals.max() + .2

for panel, stid in zip(axes, stids):
    panel.hist(rmse.xs(stid, level="stid"), edgecolor="black")
    panel.set_title(stid)
    panel.grid(True)
    panel.set_ylabel("Frequency")
    panel.tick_params(labelleft=True)
    panel.set_xlim(xmin, xmax)

# Drop the panels left over when the station count does not fill the grid
for spare in axes[n:]:
    fig.delaxes(spare)

# Label the x-axis and show its tick labels on every remaining panel
for panel in fig.axes:
    panel.set_xlabel("RMSE (%)")
    panel.tick_params(labelbottom=True)

plt.tight_layout()
plt.savefig("../outputs/case_study_rmse_hist.png")

In [None]:
# RMSE of every (station, replication) forecast as a flat table with
# columns stid, rep, rmse.
mean_sq_err = rnn2.groupby(["stid", "rep"])["squared_error"].mean()
rmse_per = mean_sq_err.pow(0.5).rename("rmse").reset_index()

# Representative replication per station: the one whose RMSE lies closest
# to that station's median RMSE.
station_median = rmse_per.groupby("stid")["rmse"].transform("median")
dist_to_median = (rmse_per["rmse"] - station_median).abs()
closest_first = rmse_per.assign(abs_diff=dist_to_median).sort_values(["stid", "abs_diff"])
rep_choice = closest_first.drop_duplicates("stid").set_index("stid")["rep"]

# Per-station RMSE summary across replications: min/max and median
rmse_by_station = rmse_per.groupby("stid")["rmse"]
rmse_bounds = rmse_by_station.agg(rmse_low="min", rmse_high="max")
rmse_mid = rmse_by_station.agg(rmse_mid="median")

# Replication with the worst (largest) RMSE at each station
worst_first = rmse_per.sort_values(["stid", "rmse"], ascending=[True, False])
rep_worst = worst_first.drop_duplicates("stid").set_index("stid")["rep"]

In [None]:
# Per-station RMSE restricted to rows whose observed fm is in [0, 10)
low_fm_mask = (rnn2.fm >= 0) & (rnn2.fm < 10)
subset = rnn2[low_fm_mask].copy()

sq_err = (subset.fm - subset.preds) ** 2
rmse_0_10 = (
    subset.assign(sqerr=sq_err)
    .groupby("stid")
    .agg(rmse_010=("sqerr", lambda x: x.mean() ** 0.5))
)

rmse_0_10

In [None]:
# Build the station summary table: location, forecast hours and RMSE summary
df_nice = df.copy()
df_nice["Forecasted Hours"] = nhours
df_nice.columns = ["STID", "Longitude", "Latitude", "Elevation", "N. Hours"]
df_nice = pd.merge(df_nice, rmse_mid, left_on="STID", right_on="stid")
df_nice = pd.merge(df_nice, rmse_bounds, left_on="STID", right_on="stid")
# Left join: a station with no observations in the 0-10 FMC range is absent
# from rmse_0_10 and must stay in the table (as NaN) rather than be dropped.
df_nice = pd.merge(df_nice, rmse_0_10, left_on="STID", right_on="stid", how="left")
df_nice = df_nice.rename(columns={"rmse_mid": "Median RMSE", "rmse_low":"Min. RMSE", "rmse_high" : "Max. RMSE",  "rmse_010":"RMSE (0-10 FMC)"})

# Round the numeric summary columns to two decimals for presentation
rounded_cols = ["Elevation", "Median RMSE", "Min. RMSE", "Max. RMSE", "RMSE (0-10 FMC)"]
df_nice[rounded_cols] = df_nice[rounded_cols].round(2)
df_nice = df_nice.sort_values("STID")

In [None]:
# Station location table
location_cols = ["STID", "Elevation", "Longitude", "Latitude"]
df_nice[location_cols]

In [None]:
# Same location table, emitted as LaTeX
location_latex = df_nice[["STID", "Elevation", "Longitude", "Latitude"]].to_latex()
print(location_latex)

In [None]:
# Station accuracy table
accuracy_cols = ["STID", "N. Hours", "Median RMSE", "Min. RMSE", "Max. RMSE", "RMSE (0-10 FMC)"]
df_nice[accuracy_cols]

In [None]:
# Same accuracy table, emitted as LaTeX
accuracy_latex = df_nice[["STID", "N. Hours", "Median RMSE", "Min. RMSE", "Max. RMSE", "RMSE (0-10 FMC)"]].to_latex()
print(accuracy_latex)

In [None]:
import matplotlib.pyplot as plt
import math

# One panel per station showing observed FMC against the forecast from the
# replication chosen in rep_choice, two panels per row.
stids = rmse.index.get_level_values("stid").unique()
n = len(stids)

ncols = 2
nrows = math.ceil(n / ncols)

fig, axes = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(10, 3 * nrows),
    sharex=True,
    sharey=True,
)
axes = axes.flatten()

for ax, stid in zip(axes, stids):
    rep_med = rep_choice.loc[stid]
    preds = rnn2[(rnn2.stid == stid) & (rnn2.rep == rep_med)]

    ax.plot(preds.date_time, preds.fm, **plot_styles["fm"])
    ax.plot(preds.date_time, preds.preds, color="k", alpha=.7, label="RNN Forecast")

    ax.grid(True)
    ax.set_ylabel("FMC (%)", fontsize=13)
    ax.set_title(stid, fontsize=14)
    # Show tick labels on every panel despite the shared axes
    ax.tick_params(axis="both", labelbottom=True, labelleft=True, labelsize=12)
    ax.tick_params(axis="x", labelrotation=45)

# Remove unused axes
for spare in axes[n:]:
    fig.delaxes(spare)

# Hide the x tick labels on the first two panels
positions_to_hide = [0, 1]
for pos in positions_to_hide:
    axes[pos].tick_params(labelbottom=False)

# Single legend, placed to the right of the second panel
axes[1].legend(loc="center left", bbox_to_anchor=(1, 0.5))

plt.tight_layout()

In [None]:
# stids = rmse.index.get_level_values("stid").unique()
# n = len(stids)

# st = stids[0]
# d = dat[st]["data"]

# ncols = 1
# nrows = 2

# fig, axes = plt.subplots(
#     nrows=nrows, ncols=ncols,
#     figsize=(10, 3*nrows),
#     sharex=True, sharey=False
# )

# ax = axes[0]
# rep_med = rep_choice.loc[st]
# preds = rnn2[(rnn2.stid == st) & (rnn2.rep == rep_med)]
# ax.plot(preds.date_time, preds.fm, **plot_styles["fm"])
# ax.plot(preds.date_time, preds.preds, color="k", alpha=.7, label="RNN Forecast")
# ax.grid(True)
# ax.set_ylabel("FMC (%)", fontsize=13)
# ax.set_title(stid, fontsize=14)
# ax.tick_params(axis="x", labelrotation=45)
# ax.tick_params(axis="both", labelsize=12)
# ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

# ax = axes[1]
# ax.grid(True)
# # Left y-axis (Rain)
# ax.plot(d.date_time, d.rain, color="blue", label="Rain")
# ax.set_ylabel("Rain (mm/hr)")
# # Right y-axis (Wind)
# ax2 = ax.twinx()
# ax2.plot(d.date_time, d.wind, color="gray", label="Wind")
# ax2.set_ylabel(r"Wind Speed ($\text{m\,s}^{-1}$)")
# # Legend (combined)
# lines = ax.get_lines() + ax2.get_lines()
# labels = [l.get_label() for l in lines]
# ax.legend(lines, labels, loc="center left", bbox_to_anchor=(1.1, 0.5))
# idx = np.r_[np.arange(0, len(times), 12), len(times) - 1]
# xlabels = times[idx]
# xlabels = pd.to_datetime(xlabels).tz_convert("US/Mountain")
# ax.set_xticks(xlabels)
# ax.set_xticklabels(
#     xlabels.strftime("%Y-%m-%d\n%H:%M"),
#     rotation=45
# )

# plt.tight_layout()

In [None]:
# Font sizes applied to all figures created from here on
font_sizes = {
    "axes.labelsize": 12,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "axes.titlesize": 14,
    "legend.fontsize": 12,
    "legend.title_fontsize": 12,
}
plt.rcParams.update(font_sizes)

def plot_station(st):
    """
    Two-panel time series for one station.

    Top panel: observed FMC and the forecast from the replication stored in
    rep_choice for this station. Bottom panel: rain (left axis) and wind
    (right axis) from dat[st]["data"].

    st : station ID used to index dat, rep_choice and rnn2

    Returns (fig, axes), where axes holds the top and bottom panels.
    """
    weather = dat[st]["data"]

    n_panels = 2
    fig, axes = plt.subplots(
        nrows=n_panels,
        ncols=1,
        figsize=(10, 3 * n_panels),
        sharex=True,
        sharey=False,
    )
    fmc_ax, rain_ax = axes[0], axes[1]

    # --- Top panel: FMC + RNN forecast ---
    chosen_rep = rep_choice.loc[st]
    station_preds = rnn2[(rnn2.stid == st) & (rnn2.rep == chosen_rep)]

    fmc_ax.plot(station_preds.date_time, station_preds.fm, **plot_styles["fm"])
    fmc_ax.plot(
        station_preds.date_time,
        station_preds.preds,
        color="k",
        alpha=.7,
        label="RNN Forecast",
    )
    fmc_ax.grid(True)
    fmc_ax.set_ylabel("FMC (%)")
    fmc_ax.set_ylim(0, 22)
    fmc_ax.set_title(st, fontsize=14)
    fmc_ax.tick_params(axis="x", labelrotation=45)
    fmc_ax.tick_params(axis="both")
    fmc_ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

    # --- Bottom panel: rain on the left axis, wind on a twin right axis ---
    rain_ax.grid(True)
    rain_ax.plot(weather.date_time, weather.rain, color="blue", label="Rain")
    rain_ax.set_ylabel(r"Rain ($\text{mm h}^{-1}$)")
    rain_ax.set_ylim(0, 4)

    wind_ax = rain_ax.twinx()
    wind_ax.plot(weather.date_time, weather.wind, color="gray", label="Wind")
    wind_ax.set_ylabel(r"Wind Speed ($\text{m s}^{-1}$)")
    wind_ax.set_ylim(0, 20)

    # One legend covering the lines of both y-axes
    handles = rain_ax.get_lines() + wind_ax.get_lines()
    rain_ax.legend(
        handles,
        [h.get_label() for h in handles],
        loc="center left",
        bbox_to_anchor=(1.1, 0.5),
    )

    # Ticks every 12th analysis time plus the last one, labelled in Mountain time
    tick_idx = np.r_[np.arange(0, len(times), 12), len(times) - 1]
    tick_times = pd.to_datetime(times[tick_idx]).tz_convert("US/Mountain")
    rain_ax.set_xticks(tick_times)
    rain_ax.set_xticklabels(
        tick_times.strftime("%Y-%m-%d\n%H:%M"),
        rotation=45
    )

    plt.tight_layout()
    return fig, axes


In [None]:
# plot_station(st = df.stid[0])

In [None]:
# plot_station(st = df.stid[1])

In [None]:
# Plot the station at index label 2 of df and save the current figure to
# ../outputs/case_ts.png.
plot_station(st = df.stid[2])
plt.savefig("../outputs/case_ts.png")

In [None]:
# Plot the station at index label 3 of df (figure is shown, not saved).
plot_station(st = df.stid[3])