Imports

In [23]:
import os
import glob
import numpy as np
import pandas as pd
import rasterio
from rasterio.enums import Resampling
from rasterio.warp import reproject, calculate_default_transform

Config

In [24]:
raw_dir = "../data/raw"
processed_dir = "../data/processed"
os.makedirs(processed_dir, exist_ok=True)

fire_folders = [f for f in glob.glob(os.path.join(raw_dir, "*")) if os.path.isdir(f)]
print("Found fire folders:", fire_folders)

Found fire folders: ['../data/raw\\bootleg_fire', '../data/raw\\caldor_fire', '../data/raw\\camp_fire', '../data/raw\\carr_fire', '../data/raw\\dixie_fire', '../data/raw\\east_troublesome_fire', '../data/raw\\glass_fire', '../data/raw\\red_salmon_fire', '../data/raw\\zogg_fire']


Helper Functions

In [25]:
def load_and_align(path, ref_profile):
    """Load raster, reproject/align to reference profile, 
    and enforce correct dtype for categorical rasters."""
    
    categorical_layers = ["landcover", "modis_burned_area"]

    name = os.path.basename(path).lower()

    is_categorical = any(cat in name for cat in categorical_layers)

    # Choose resampling based on type
    resample = rasterio.enums.Resampling.nearest if is_categorical else rasterio.enums.Resampling.bilinear

    with rasterio.open(path) as src:
        data = src.read(1)

        # Reproject
        dst = np.empty((ref_profile["height"], ref_profile["width"]), dtype="float32")

        rasterio.warp.reproject(
            source=data,
            destination=dst,
            src_transform=src.transform,
            src_crs=src.crs,
            dst_transform=ref_profile["transform"],
            dst_crs=ref_profile["crs"],
            dst_resolution=(ref_profile["transform"][0], -ref_profile["transform"][4]),
            resampling=resample,
        )

    # Enforce integer categories
    if is_categorical:
        dst = np.rint(dst).astype(np.int16)

    return dst
    
def get_reference_profile(fire_dir):
    dem_path = os.path.join(fire_dir, "srtm_dem.tif")
    if not os.path.exists(dem_path):
        raise ValueError(f"No DEM found in {fire_dir}")
    with rasterio.open(dem_path) as src:
        return src.profile

Burn Severity Calculation

In [26]:
def classify_dnbr(dnbr):
    """
    MTBS burn severity classification based on dNBR.
    """
    classes = np.full(dnbr.shape, -1, dtype=np.int16)

    classes[(dnbr < 0.1)] = 0
    classes[(dnbr >= 0.1) & (dnbr < 0.27)] = 1
    classes[(dnbr >= 0.27) & (dnbr < 0.66)] = 2
    classes[(dnbr >= 0.66)] = 3
    
    return classes

Post Processing Validation Function

In [27]:
def validate_dataset(df, feature_schema):
    report = {}
    
    for col, expected in feature_schema.items():
        if col not in df.columns:
            report[col] = "Missing column"
            continue

        series = df[col]

        # Check for NaN or Inf
        nan_count = series.isna().sum()
        inf_count = np.isinf(series).sum()

        # Continuous variable ranges
        if expected != ("categorical",):
            low, high = expected
            out_of_range = ((series < low) | (series > high)).sum()
            report[col] = {
                "type": "continuous",
                "nan": int(nan_count),
                "inf": int(inf_count),
                "out_of_range": int(out_of_range),
                "min": float(series.min()),
                "max": float(series.max())
            }
        else:
            # For categorical, just check non-integer or negative values
            invalid = (~series.dropna().astype(float).apply(float.is_integer)).sum()
            negative = (series < 0).sum()
            report[col] = {
                "type": "categorical",
                "nan": int(nan_count),
                "inf": int(inf_count),
                "non_integer": int(invalid),
                "negative": int(negative),
                "unique_values": sorted(series.unique().tolist()[:20])
            }

    return report

Main Preprocessing Loop

In [28]:
# Validation schema
feature_schema = {
    "ndvi_pre": (-1, 1),
    "ndvi_post": (-1, 1),
    "dndvi": (-2, 2),
    "nbr_pre": (-1, 1),
    "nbr_post": (-1, 1),
    "dnbr": (-2, 2),
    "precip": (0, 300),       
    "temp": (230, 330),     
    "landcover": ("categorical",),
    "elevation": (-200, 9000),
    "severity": ("categorical",)
}

all_data = []

for fire_dir in fire_folders:
    fire_name = os.path.basename(fire_dir)
    print(f"\n# Processing {fire_name} #")

    try:
        ref_profile = get_reference_profile(fire_dir)
    except:
        print(f"  MISSING DEM - skipping {fire_name}")
        continue

    def load(name):
        p = os.path.join(fire_dir, name)
        if not os.path.exists(p):
            print(f"  Missing {name} - skipping fire.")
            return None
        return load_and_align(p, ref_profile)

    # Load features
    ndvi_pre = load("pre_ndvi.tif")
    ndvi_post = load("post_ndvi.tif")
    nbr_pre = load("pre_nbr.tif")
    nbr_post = load("post_nbr.tif")
    chirps = load("chirps_precip.tif")
    era5 = load("era5_temp.tif")
    landcover = load("landcover.tif")
    dem = load("srtm_dem.tif")
    modis = load("modis_burned_area.tif")

    # Skip if any are missing
    if any(x is None for x in [ndvi_pre, ndvi_post, nbr_pre, nbr_post, chirps, era5, landcover, dem, modis]):
        print("  Missing required layers - skipping.")
        continue

    # Derived features
    dndvi = ndvi_pre - ndvi_post
    dnbr = nbr_pre - nbr_post
    severity = classify_dnbr(dnbr)

    # Flatten valid pixels
    df = pd.DataFrame({
        "fire_name": fire_name,
        "ndvi_pre": ndvi_pre.flatten(),
        "ndvi_post": ndvi_post.flatten(),
        "dndvi": dndvi.flatten(),
        "nbr_pre": nbr_pre.flatten(),
        "nbr_post": nbr_post.flatten(),
        "dnbr": dnbr.flatten(),
        "precip": chirps.flatten(),
        "temp": era5.flatten(),
        "landcover": landcover.flatten(),
        "elevation": dem.flatten(),
        "severity": severity.flatten(),
    })

    # Drop invalid
    df = df.replace([np.inf, -np.inf], np.nan).dropna()
    df = df[df["severity"] >= 0]

    # Validate fire's dataset
    fire_report = validate_dataset(df, feature_schema)

    # Detect critical failures
    critical_fail = False
    for feat, stats in fire_report.items():
        if stats == "Missing column":
            print(f"  Validation Failed: missing {feat}")
            critical_fail = True
            break

        # Continuous: check for out of range values
        if stats["type"] == "continuous":
            if stats["out_of_range"] > 0:
                print(f"  Feature `{feat}` has {stats['out_of_range']} out-of-range values")
                critical_fail = True

        # Categorical: check for negative or non-integer values
        if stats["type"] == "categorical":
            if stats["non_integer"] > 0 or stats["negative"] > 0:
                print(f"  Feature `{feat}` has invalid categorical values ({stats})")
                critical_fail = True

    if critical_fail:
        print(f"  Skipping {fire_name} due to failed validation.")
        continue

    print(f"  Validation passed for {fire_name}.")

    print(f"  Added {len(df)} samples")

    all_data.append(df)

# Combine across all fires
full_df = pd.concat(all_data, ignore_index=True)
print("\nFinal Dataset Size:", len(full_df))

full_path = os.path.join(processed_dir, "stacked_dataset.parquet")
full_df.to_parquet(full_path)

print("\nSaved dataset to", full_path)



# Processing bootleg_fire #
  Validation passed for bootleg_fire.
  Added 6908032 samples

# Processing caldor_fire #
  Validation passed for caldor_fire.
  Added 4139624 samples

# Processing camp_fire #
  Validation passed for camp_fire.
  Added 3103047 samples

# Processing carr_fire #
  Validation passed for carr_fire.
  Added 3313036 samples

# Processing dixie_fire #
  Validation passed for dixie_fire.
  Added 8281472 samples

# Processing east_troublesome_fire #
  Validation passed for east_troublesome_fire.
  Added 3311549 samples

# Processing glass_fire #
  Validation passed for glass_fire.
  Added 1242110 samples

# Processing red_salmon_fire #
  Validation passed for red_salmon_fire.
  Added 2206710 samples

# Processing zogg_fire #
  Validation passed for zogg_fire.
  Added 3448449 samples

Final Dataset Size: 35954029

Saved dataset to ../data/processed\stacked_dataset.parquet
