In [12]:
import os
import joblib
import pandas as pd
import numpy as np
import geopandas as gpd
from shapely.geometry import Point
import folium
from branca.colormap import linear
import ipywidgets as widgets
from IPython.display import display, HTML

Define paths and settings

In [13]:
# All inputs and outputs live under one data directory; change DATA_DIR to relocate them.
DATA_DIR = "../data"
dataset_path = f"{DATA_DIR}/outputs/combined_dataset.parquet"
rf_path = f"{DATA_DIR}/models/rf_model.pkl"
lr_path = f"{DATA_DIR}/models/lr_model.pkl"
# The models are read with joblib.load, which expects a joblib/pickle file; a Keras
# .h5 file at this path would need a different loader (e.g. keras.models.load_model).
cnn_path = f"{DATA_DIR}/models/cnn_model.pkl"
out_dir = f"{DATA_DIR}/maps"   # generated HTML risk maps are written here
os.makedirs(out_dir, exist_ok=True)

SAMPLE_SIZE_DEFAULT = 5000   # initial value of the "Sample size" slider
FEATURES = ["dNBR", "SPI", "VCI", "NDVI"]   # model input columns; presumably in training order — TODO confirm

Load the list of fire events

In [14]:
print("Reading fire list (lightweight)...")
# Read only the one column needed here; rows for a single fire are loaded later on demand.
fire_column = pd.read_parquet(dataset_path, columns=["fire_name"], engine="pyarrow")["fire_name"]
# Drop missing names (str(NaN) would appear as a "nan" option that matches no rows) and
# de-duplicate after the str() conversion, since distinct raw values can share a label.
fire_names = sorted({str(name) for name in fire_column.dropna().unique()})
print(f"Found {len(fire_names)} fire events.")

Reading fire list (lightweight)...
Found 9 fire events.


Load models

In [15]:
def _load_optional_model(path, label):
    """Load a joblib-serialized model from `path`, or return None if the file is missing.

    Parameters
    ----------
    path : str
        Location of the serialized model file.
    label : str
        Short name used in the status messages (e.g. "RF", "LR", "CNN").

    Returns
    -------
    object or None
        The deserialized model, or None when `path` does not exist.
    """
    if not os.path.exists(path):
        print(f"{label} model not found at:", path)
        return None
    # joblib.load unpickles arbitrary objects: only point it at trusted files.
    model = joblib.load(path)
    print(f"Loaded {label} model:", path)
    return model


# Always bind all three names so the later `is not None` checks skip a missing
# model instead of failing with NameError.
rf_model = _load_optional_model(rf_path, "RF")
lr_model = _load_optional_model(lr_path, "LR")
cnn_model = _load_optional_model(cnn_path, "CNN")

Loaded RF model: ../data/models/rf_model.pkl
Loaded LR model: ../data/models/lr_model.pkl
Loaded CNN model: ../data/models/cnn_model.pkl


Probability prediction helper

In [16]:
def safe_predict_proba(model, X):
    """Return one risk score per row of `X` from a fitted model.

    The model's output is obtained from ``predict_proba`` when available,
    otherwise from ``predict``. It is reduced to one value per row as follows:

    * 2-D with two columns: the second column is returned (for scikit-learn
      estimators this is the probability of ``classes_[1]``). This applies to
      both ``predict_proba`` and ``predict`` output, so a two-output softmax
      network is scored the same way as a binary classifier.
    * any other 2-D output: the per-row maximum (for one column this is simply
      that column).
    * 1-D output: returned unchanged.

    Parameters
    ----------
    model : object or None
        Fitted model exposing ``predict_proba`` or ``predict``; None is allowed.
    X : array-like
        Feature matrix, passed to the model unchanged.

    Returns
    -------
    numpy.ndarray or None
        Array of scores, or None when `model` is None.

    Raises
    ------
    ValueError
        If the model has neither ``predict_proba`` nor ``predict``.
    """
    if model is None:
        return None
    if hasattr(model, "predict_proba"):
        scores = np.asarray(model.predict_proba(X))
    elif hasattr(model, "predict"):
        scores = np.asarray(model.predict(X))
    else:
        raise ValueError("Model does not support predict_proba or predict.")

    if scores.ndim == 2:
        if scores.shape[1] == 2:
            # Binary output: take the positive-class column in both branches.
            return scores[:, 1]
        return scores.max(axis=1)
    return scores

UI Setup

In [17]:
# Interactive controls: which fire to map, how many points to sample, and a trigger button.
fire_dropdown = widgets.Dropdown(
    options=fire_names,
    description="Fire:",
)
sample_slider = widgets.IntSlider(
    value=SAMPLE_SIZE_DEFAULT,
    min=500,
    max=20000,
    step=500,
    description="Sample size:",
    continuous_update=False,
)
generate_btn = widgets.Button(
    description="Generate Map",
    button_style="primary",
)

# Lay the controls out in one row; the Output area below receives the callback's prints and map.
controls = widgets.HBox(children=[fire_dropdown, sample_slider, generate_btn])
output_html_widget = widgets.Output()

display(controls)
display(output_html_widget)

HBox(children=(Dropdown(description='Fire:', options=('Caldor', 'Camp', 'Carr', 'Creek', 'Dixie', 'Glass', 'Th…

Output()

Map generation function

In [18]:
def _load_fire_sample(fire_name, sample_size):
    """Read the rows for one fire, drop incomplete rows, and draw a reproducible sample.

    Raises
    ------
    ValueError
        If no complete rows exist for `fire_name`.
    """
    required = FEATURES + ["latitude", "longitude", "severity"]
    # Read only the columns used below; this runs on every button click.
    df = pd.read_parquet(dataset_path, columns=["fire_name"] + required, engine="pyarrow")
    # The dropdown labels are built with str(), so compare on the same representation.
    df_fire = df[df["fire_name"].astype(str) == str(fire_name)].dropna(subset=required)
    if df_fire.empty:
        raise ValueError(f"No rows found for fire: {fire_name}")
    # Fixed random_state: the same fire and sample size always yield the same points.
    return df_fire.sample(n=min(sample_size, len(df_fire)), random_state=42)


def _add_point_layer(fmap, points, colors, name, show, fill_opacity):
    """Add a toggleable FeatureGroup of small, stroke-less circle markers to `fmap`.

    `colors` supplies one marker colour per row of `points`, in row order.
    """
    fg = folium.FeatureGroup(name=name, show=show)
    for lat, lon, color in zip(points["latitude"], points["longitude"], colors):
        folium.CircleMarker(
            location=[lat, lon], radius=2, color=color, fill=True,
            fill_color=color, fill_opacity=fill_opacity, stroke=False,
        ).add_to(fg)
    fg.add_to(fmap)


def generate_map(fire_name, sample_size):
    """Build, save and return a folium map of model risk predictions for one fire.

    Parameters
    ----------
    fire_name : str
        Value of the ``fire_name`` column to map (compared as a string).
    sample_size : int
        Maximum number of points to plot; fewer are used if the fire has fewer
        complete rows.

    Returns
    -------
    folium.Map
        Map with one hidden-by-default layer per loaded model, a visible layer of
        High/Moderate severity points, a colour legend and a layer control. The
        same map is also written to ``<out_dir>/<fire_name>_risk_map.html``.

    Raises
    ------
    ValueError
        If no complete rows exist for `fire_name`.
    """
    print(f"Loading fire: {fire_name}")
    df_sample = _load_fire_sample(fire_name, sample_size)

    gdf = gpd.GeoDataFrame(
        df_sample,
        geometry=gpd.points_from_xy(df_sample.longitude, df_sample.latitude),
        crs="EPSG:4326",
    )
    X = df_sample[FEATURES].values

    m = folium.Map(location=[gdf.latitude.mean(), gdf.longitude.mean()],
                   zoom_start=10, tiles="Esri.WorldImagery")

    colormap = linear.YlOrRd_09.scale(0, 1)
    colormap.caption = "Predicted Fire Risk Probability"

    # (risk column, layer name, model); a model that was not loaded is None and skipped.
    model_layers = [
        ("RF_Risk", "Random Forest Risk", rf_model),
        ("LR_Risk", "Logistic Regression Risk", lr_model),
        ("CNN_Risk", "CNN Risk", cnn_model),
    ]
    for col, layer_name, model in model_layers:
        if model is None:
            continue
        # Clip so every value lies inside the colormap's [0, 1] domain.
        gdf[col] = np.clip(safe_predict_proba(model, X), 0, 1)
        colors = [colormap(value) for value in gdf[col]]
        _add_point_layer(m, gdf, colors, layer_name, show=False, fill_opacity=0.7)

    # NOTE(review): assumes `severity` holds text labels such as "High"/"Moderate" — confirm.
    burned = gdf[gdf["severity"].isin(["High", "Moderate"])]
    _add_point_layer(m, burned, ["cyan"] * len(burned),
                     "Actual Burned (High/Moderate)", show=True, fill_opacity=0.6)

    colormap.add_to(m)
    folium.LayerControl(collapsed=False).add_to(m)

    outpath = os.path.join(out_dir, f"{fire_name}_risk_map.html")
    m.save(outpath)
    print(f"Saved map to {outpath}")
    return m

Generate on button click

In [19]:
def on_generate_clicked(b):
    """Button callback: build the map for the current widget selection and show it inline.

    Parameters
    ----------
    b : ipywidgets.Button
        The clicked button, passed by ``on_click`` (unused).
    """
    with output_html_widget:
        output_html_widget.clear_output()
        fire = fire_dropdown.value
        sample = sample_slider.value
        print(f"Generating map for {fire} (sample={sample})...")
        try:
            m = generate_map(fire, sample)
            # Render the folium Map through its rich HTML repr. An <iframe src=...> holding
            # an absolute filesystem path is resolved by the browser as a URL on the
            # Jupyter server and does not load. The HTML file saved by generate_map stays on disk.
            display(m)
        except Exception as e:
            # Keep the UI usable and report the failure in the output area.
            print("Error:", e)


generate_btn.on_click(on_generate_clicked)

print("Ready — select a fire and click 'Generate Map'.")

Ready — select a fire and click 'Generate Map'.
