In [None]:
import os
import tempfile
import warnings
from os.path import join as pjoin

import dask
import dask.dataframe as dd
import dask_geopandas as dgpd
import geopandas as gpd
import numpy as np
import pandas as pd
import rasterio
from scipy.fft import dst
import tqdm
import xarray as xr
from dask.diagnostics import ProgressBar
from rasterio.crs import CRS

from raster_tools import Raster, Vector, open_vectors, clipping, zonal
from raster_tools.dtypes import F32, U8, U16

import matplotlib.pyplot as plt

In [None]:
# change pandas max col display
pd.set_option('display.max_columns', 500)

In [None]:
# Location for temporary storage
TMP_LOC = "/home/jake/FireLab/Project/data/temp/"
DATA_LOC = "/home/jake/FireLab/Project/data/"

STATE = "OR"

# Location of clipped DEM files
DEM_DATA_DIR = pjoin(TMP_LOC, "dem_data")

# location of feature data files
FEATURE_DIR = pjoin(DATA_LOC, "FeatureData")
EDNA_DIR = pjoin(DATA_LOC, "terrain")
MTBS_DIR = pjoin(DATA_LOC, "MTBS_Data")

mtbs_df_path = pjoin(TMP_LOC, f"{STATE}_mtbs.parquet/")
mtbs_df_temp_path = pjoin(TMP_LOC, f"{STATE}_mtbs_temp.parquet/")
checkpoint_1_path = pjoin(TMP_LOC, "check1")
checkpoint_2_path = pjoin(TMP_LOC, "check2")

PATHS = {
    "states": pjoin(EDNA_DIR, "state_borders/cb_2018_us_state_5m.shp"),
    "dem": pjoin(EDNA_DIR, "us_orig_dem/us_orig_dem/orig_dem/hdr.adf"),
    "dem_slope": pjoin(EDNA_DIR, "us_slope/us_slope/slope/hdr.adf"),
    "dem_aspect": pjoin(EDNA_DIR, "us_aspect/aspect/hdr.adf"),
    "dem_flow_acc": pjoin(EDNA_DIR, "us_flow_acc/us_flow_acc/flow_acc/hdr.adf"),
    "gm_srad": pjoin(FEATURE_DIR, "gridmet/srad_1986_2020_weekly.nc"),
    "gm_vpd": pjoin(FEATURE_DIR, "gridmet/vpd_1986_2020_weekly.nc"),
    "aw_mat": pjoin(FEATURE_DIR, "adaptwest/Normal_1991_2020_MAT.tif"),
    "aw_mcmt": pjoin(FEATURE_DIR, "adaptwest/Normal_1991_2020_MCMT.tif"),
    "aw_mwmt": pjoin(FEATURE_DIR, "adaptwest/Normal_1991_2020_MWMT.tif"),
    "aw_td": pjoin(FEATURE_DIR, "adaptwest/Normal_1991_2020_TD.tif"),
    "dm_tmax": pjoin(FEATURE_DIR, "daymet/tmax_1986_2020.nc"),
    "dm_tmin": pjoin(FEATURE_DIR, "daymet/tmin_1986_2020.nc"),
    "biomass_afg": pjoin(
        FEATURE_DIR, "biomass/biomass_afg_1986_2020_{}.nc".format(STATE)
    ),
    "biomass_pfg": pjoin(
        FEATURE_DIR, "biomass/biomass_pfg_1986_2020_{}.nc".format(STATE)
    ),
    "landfire_fvt": pjoin(
        FEATURE_DIR, "landfire/LF2020_FVT_200_CONUS/Tif/LC20_FVT_200.tif"
    ),
    "landfire_fbfm40": pjoin(
        FEATURE_DIR, "landfire/LF2020_FBFM40_200_CONUS/Tif/LC20_F40_200.tif"
    ),
    "ndvi": pjoin(FEATURE_DIR, "ndvi/access/weekly/ndvi_1986_2020_weekavg.nc"),
    "mtbs_root": pjoin(MTBS_DIR, "MTBS_BSmosaics/"),
    "mtbs_perim": pjoin(MTBS_DIR, "mtbs_perimeter_data/mtbs_perims_DD.shp"),
}
YEARS = list(range(2016, 2021))
GM_KEYS = list(filter(lambda x: x.startswith("gm_"), PATHS))
AW_KEYS = list(filter(lambda x: x.startswith("aw_"), PATHS))
DM_KEYS = list(filter(lambda x: x.startswith("dm_"), PATHS))
BIOMASS_KEYS = list(filter(lambda x: x.startswith("biomass_"), PATHS))
LANDFIRE_KEYS = list(filter(lambda x: x.startswith("landfire_"), PATHS))
NDVI_KEYS = list(filter(lambda x: x.startswith("ndvi"), PATHS))
DEM_KEYS = list(filter(lambda x: x.startswith("dem"), PATHS))

---

In [None]:
# Filter out warnings from dask_geopandas and dask
warnings.filterwarnings(
    "ignore", message=".*initial implementation of Parquet.*"
)
warnings.filterwarnings(
    "ignore", message=".*Slicing is producing a large chunk.*"
)


def hillshade(slope, aspect, azimuth=180, zenith=45):
    # Convert angles from degrees to radians
    azimuth_rad = np.radians(azimuth)
    zenith_rad = np.radians(zenith)
    slope_rad = np.radians(slope)
    aspect_rad = np.radians(aspect)

    # Calculate hillshade
    shaded = np.sin(zenith_rad) * np.sin(slope_rad) + \
             np.cos(zenith_rad) * np.cos(slope_rad) * \
             np.cos(azimuth_rad - aspect_rad)
    # scale to 0-255
    shaded = 255 * (shaded + 1) / 2
    # round hillshade to nearest integer
    shaded = np.rint(shaded)
    # convert to uint8
    # Ensure non-finite values are not converted to int
    # shaded = np.where(np.isfinite(shaded), shaded.astype(np.uint8), np.nan)
    return shaded

def hillshade_partition(df, zenith, azimuth):
    # Apply the hillshade function to the slope and aspect columns
    df['hillshade'] = hillshade(df['dem_slope'], df['dem_aspect'], azimuth, zenith)
    return df

def timestamp_to_year_part(df):
    # Assuming 'ig_date' is the column with timestamp data
    df['year'] = df['ig_date'].dt.year
    return df


def get_nc_var_name(ds):
    # Find the data variable in a nc xarray.Dataset
    var_name = list(set(ds.keys()) - set(["crs", "day_bnds"]))[0]
    # var_name = list(set(ds.keys()) - set(["crs", "bnds"]))[1] # for DAYMET ONLY!!
    return var_name


def netcdf_to_raster(path, date):
    # This produces a Dataset. We need to grab the DataArray inside that
    # contains the data of interest.
    nc_ds = xr.open_dataset(path, chunks={"day": 1})#, decode_times=False)
    nc_ds2 = nc_ds.drop_vars(
        ["latitude_bnds", "longitude_bnds", "time_bnds"]
    ).rio.write_crs("EPSG:5071") # FOR NDVI ONLY!!
    # nc_ds2 = nc_ds.rio.write_crs("EPSG:5071")  # FOR DAYMET ONLY!!
    # nc_ds = nc_ds.rio.write_crs(
    #     nc_ds.coords["lambert_conformal_conic"].spatial_ref
    # )  # FOR DAYMET ONLY!!
    # nc_ds = nc_ds.rename({"lambert_conformal_conic": "crs"})  # FOR DAYMET ONLY!!
    # nc_ds2 = nc_ds.drop_vars(["lat", "lon"])  # FOR DAYMET ONLY!!
    # nc_ds = None # FOR DAYMET ONLY!!
    # nc_ds2 = nc_ds2.rename_vars({"x": "lon", "y": "lat"})  # FOR DAYMET ONLY!!
    # comment lines below for normal operation
    #ds_crs = CRS.from_epsg(5071) dont need this line
    #nc_ds.rio.write_crs(ds_crs) dont need this line
    # nc_ds2 = nc_ds.rio.write_crs(nc_ds.crs.spatial_ref)
    # nc_ds2 = nc_ds.rio.write_crs(nc_ds.crs) # for NDVI
    # print nc_ds dimensions
    # print(f"{nc_ds.dims = }")
    # Find variable name
    var_name = get_nc_var_name(nc_ds2)
    # print(f"var_name: {var_name}")
    # Extract
    var_da = nc_ds2[var_name]
    # print(f"{var_da = }")
    var_da = var_da.sel(time=date, method="nearest") # for DM and BM and NDVI
    # var_da = var_da.sel(day=date, method="nearest") # for GM
    # xrs = xr.DataArray(
    #     var_da.data, dims=("y", "x"), coords=(var_da.lat.data, var_da.lon.data)
    # ).expand_dims("band") # For non-NDVI
    xrs = xr.DataArray(
        var_da.data, dims=("y", "x"), coords=(var_da.latitude.data, var_da.longitude.data)
    ).expand_dims("band") # FOR NDVI ONLY!!
    xrs["band"] = [1]
    # Set CRS in raster compliant format
    xrs = xrs.rio.write_crs(nc_ds2.crs.spatial_ref)
    return Raster(xrs)


def extract_nc_data(df, nc_name):
    assert df.ig_date.unique().size == 1
    # print(f"{gm_name}: {df.columns = }, {len(df) = }")
    date = df.ig_date.values[0]
    print(f"{nc_name}: starting {date}")
    rs = netcdf_to_raster(PATHS[nc_name], date)
    bounds = gpd.GeoSeries(df.geometry).to_crs(rs.crs).total_bounds
    rs = clipping.clip_box(rs, bounds)
    if type(df) == pd.DataFrame:
        df = gpd.GeoDataFrame(df)
    feat = Vector(df, len(df))
    rdf = (
        zonal.extract_points_eager(feat, rs, skip_validation=True)
        .drop(columns=["band"])
        .rename(columns={"extracted": nc_name})
        .compute()
    )
    df[nc_name].values[:] = rdf[nc_name].values
    # print(f"{nc_name}: finished {date}")
    return df


def get_state_dem_path(dem_key, state):
    return pjoin(DEM_DATA_DIR, f"{state}_{dem_key}.tif")


def extract_dem_data(df, key):
    state = df.state.values[0]
    path = get_state_dem_path(key, state)
    rs = Raster(path)
    if type(df) == pd.DataFrame:
        df = gpd.GeoDataFrame(df)
    feat = Vector(df, len(df))
    rdf = (
        zonal.extract_points_eager(feat, rs, skip_validation=True)
        .drop(columns=["band"])
        .compute()
    )
    df[key].values[:] = rdf.extracted.values
    return df

def extract_tif_data(df, key):
    state = df.state.values[0]
    path = PATHS[key]
    rs = Raster(path)
    if type(df) == pd.DataFrame:
        df = gpd.GeoDataFrame(df)
    feat = Vector(df, len(df))
    rdf = (
        zonal.extract_points_eager(feat, rs, skip_validation=True)
        .drop(columns=["band"])
        .compute()
    )
    df[key].values[:] = rdf.extracted.values
    return df


def partition_extract_nc(df, key):
    # This func wraps extract_nc_data. It groups the partition in to sub
    # dataframes with the same date and then applies extract_nc_data to
    # each and reassembles the results into an output dataframe.
    parts = []
    for group in df.groupby("ig_date", sort=True):
        _, gdf = group
        parts.append(extract_nc_data(gdf, key))
    return pd.concat(parts)

def partition_extract_tif(df, key):
    # This func wraps extract_tif_data. It groups the partition in to sub
    # dataframes with the same date and then applies extract_tif_data to
    # each and reassembles the results into an output dataframe.
    parts = []
    for group in df.groupby("ig_date", sort=True):
        _, gdf = group
        parts.append(extract_tif_data(gdf, key))
    return pd.concat(parts)

def clip_and_save_dem_rasters(keys, paths, feature, state):
    feature = feature.compute()
    for k in tqdm.tqdm(keys, ncols=80, desc="DEM Clipping"):
        path = paths[k]
        out_path = get_state_dem_path(k, state)
        if os.path.exists(out_path):
            continue
        rs = Raster(path)
        (bounds,) = dask.compute(feature.to_crs(rs.crs).total_bounds)
        print('dem crs: ', rs.crs)
        crs = clipping.clip_box(rs, bounds)
        crs.save(out_path)


def build_mtbs_year_df(path, perims_df, state_label):
    rs = Raster(path)
    dfs = []
    for grp in perims_df.groupby("Ig_Date"):
        date, perim = grp
        df = (
            clipping.clip(perim, rs)
            .to_vector()
            .rename(columns={"value": "mtbs"})
            .drop(columns=["band", "row", "col"])
            .assign(state=state_label, ig_date=date)
            .astype({"mtbs": U8})
        )
        dfs.append(df)
    print('perim crs: ', perims_df.crs)
    return dd.concat(dfs)


def _build_mtbs_df(
    years, year_to_mtbs_file, year_to_perims, state, working_dir
):
    dfs = []
    it = tqdm.tqdm(years, ncols=80, desc="MTBS")
    for y in it:
        mtbs_path = year_to_mtbs_file[y]
        if not os.path.exists(mtbs_path):
            it.write(f"No data for {y}")
            continue
        perims = year_to_perims[y]
        ydf = build_mtbs_year_df(mtbs_path, perims, state)
        ypath = pjoin(working_dir, str(y))
        ydf.compute().to_parquet(ypath)
        ydf = dgpd.read_parquet(ypath)
        dfs.append(ydf)
    return dd.concat(dfs)


def build_mtbs_df(
    years, year_to_mtbs_file, year_to_perims, state, out_path, tmp_loc=TMP_LOC
):
    print("Building mtbs df")
    with tempfile.TemporaryDirectory(dir=tmp_loc) as working_dir:
        df = _build_mtbs_df(
            years, year_to_mtbs_file, year_to_perims, state, working_dir
        )
        with ProgressBar():
            df.to_parquet(out_path)
    return dgpd.read_parquet(out_path)


def add_columns_to_df(
    df,
    columns,
    part_func,
    out_path,
    col_type=F32,
    col_default=np.nan,
    part_func_args=(),
    tmp_loc=TMP_LOC,
    parallel=True,
):
    print(f"Adding columns: {columns}")
    # Add columns
    expanded_df = df.assign(**{c: col_type.type(col_default) for c in columns})
    with tempfile.TemporaryDirectory(dir=tmp_loc) as working_dir:
        # Save to disk before applying partition function. to_parquet() has a
        # chance of segfaulting and that chance goes WAY up after adding
        # columns and then mapping a function to partitions. Saving to disk
        # before mapping keeps the odds low.
        path = pjoin(working_dir, "expanded")
        expanded_df.to_parquet(path)

        expanded_df = dgpd.read_parquet(path)
        meta = expanded_df._meta.copy()
        for c in columns:
            expanded_df = expanded_df.map_partitions(
                part_func, c, *part_func_args, meta=meta
            )

        if parallel:
            with ProgressBar():
                expanded_df.to_parquet(out_path)
        else:
            # Save parts in serial and then assemble into single dataframe
            with tempfile.TemporaryDirectory(dir=tmp_loc) as part_dir:
                dfs = []
                for i, part in enumerate(expanded_df.partitions):
                    # Save part i
                    part_path = pjoin(part_dir, f"part{i:04}")
                    with ProgressBar():
                        part.compute().to_parquet(part_path)
                    # Save paths for opening with dask_geopandas later. Avoid
                    # opening more dataframes in this loop as doing so will
                    # likely cause a segfault. I have no idea why.
                    dfs.append(part_path)
                dfs = [dgpd.read_parquet(p) for p in dfs]
                # Assemble and save to final output location
                expanded_df = dd.concat(dfs)
                with ProgressBar():
                    expanded_df.to_parquet(out_path)
    return dgpd.read_parquet(out_path)

In [None]:
nc_ds = xr.open_dataset(PATHS['ndvi'], chunks={"day": 1})#, decode_times=False)
nc_ds2 = nc_ds.drop_vars(
    ["latitude_bnds", "longitude_bnds", "time_bnds"]
).rio.write_crs("EPSG:5071") # FOR NDVI ONLY!!

var_name = get_nc_var_name(nc_ds2)
print(f"var_name: {var_name}")
# Extract
date='2020-08-16'
var_da = nc_ds2[var_name]
var_da = var_da.sel(time=date, method="nearest") # for DM and BM and NDVI

In [None]:
xrs = xr.DataArray(
        var_da.data, dims=("y", "x"), coords=(var_da.latitude.data, var_da.longitude.data)
    ).expand_dims("band") # FOR NDVI ONLY!!
xrs["band"] = [1]
# # Set CRS in raster compliant format
xrs = xrs.rio.write_crs(nc_ds2.crs.spatial_ref)
rs = Raster(xrs)

In [None]:
df = mtbs_df_2016_2020
bounds = gpd.GeoSeries(df.geometry).to_crs(rs.crs).total_bounds
rs = clipping.clip_box(rs, bounds)

In [32]:
df.crs

<Projected CRS: ESRI:102039>
Name: USA_Contiguous_Albers_Equal_Area_Conic_USGS_version
Axis Info [cartesian]:
- [east]: Easting (metre)
- [north]: Northing (metre)
Area of Use:
- undefined
Coordinate Operation:
- name: unnamed
- method: Albers Equal Area
Datum: North American Datum 1983
- Ellipsoid: GRS 1980
- Prime Meridian: Greenwich

In [30]:
bounds = gpd.GeoSeries(df.geometry)
bounds

AttributeError: 'numpy.ndarray' object has no attribute 'crs'

In [None]:

# # State borders
# print("Loading state borders")
# stdf = open_vectors(PATHS["states"], 0).data.to_crs("EPSG:5071")
# states = {st: stdf[stdf.STUSPS == st].geometry for st in list(stdf.STUSPS)}
# state_shape = states[STATE]
# states = None
# stdf = None
# # MTBS Perimeters
# print("Loading MTBS perimeters")
# perimdf = open_vectors(PATHS["mtbs_perim"]).data.to_crs("EPSG:5071")
# state_fire_perims = perimdf.clip(state_shape.compute())
# state_fire_perims = (
#     state_fire_perims.assign(
#         Ig_Date=lambda frame: dd.to_datetime(
#             frame.Ig_Date, format="%Y-%m-%d"
#         )
#     )
#     .sort_values("Ig_Date")
#     .compute()
# )
# year_to_perims = {
#     y: state_fire_perims[state_fire_perims.Ig_Date.dt.year == y]
#     for y in YEARS
# }
# state_fire_perims = None
# year_to_mtbs_file = {
#     y: pjoin(PATHS["mtbs_root"], f"mtbs_{STATE}_{y}.tif")
#     for y in YEARS
# }
# print(year_to_mtbs_file)
mtbs_df_path = pjoin(TMP_LOC, f"{STATE}_mtbs.parquet")
mtbs_df_temp_path = pjoin(TMP_LOC, f"{STATE}_mtbs_temp.parquet")
checkpoint_1_path = pjoin(TMP_LOC, "check1")
checkpoint_2_path = pjoin(TMP_LOC, "check2")
if 0:
    # code below for creating a new dataset for a new state / region
    df = build_mtbs_df(
        YEARS,
        year_to_mtbs_file,
        year_to_perims,
        STATE,
        out_path=checkpoint_1_path,
    )
    # df = add_columns_to_df(
    #     df, GM_KEYS, partition_extract_gridmet, checkpoint_1_path
    # )
    clip_and_save_dem_rasters(DEM_KEYS, PATHS, state_shape, STATE)
    df = add_columns_to_df(
        df,
        DEM_KEYS,
        extract_dem_data,
        checkpoint_1_path,
        # Save results in serial to avoid segfaulting. Something about the
        # dem computations makes segfaults extremely likely when saving
        # The computations require a lot of memory which may be what
        # triggers the fault.
        parallel=False,
    )
    df = df.repartition(partition_size="100MB").reset_index(drop=True)
    print("Repartitioning")
    with ProgressBar():
        df.to_parquet(checkpoint_2_path)
if 1:
    # code below used to add new features to the dataset
    with ProgressBar():
        df = dgpd.read_parquet(checkpoint_2_path)
    df = add_columns_to_df(
        df, NDVI_KEYS, partition_extract_nc, checkpoint_1_path, parallel=False
    ) # for NetCDF data
    # df = add_columns_to_df(
    #     df, LANDFIRE_KEYS, partition_extract_tif, checkpoint_1_path, parallel=False
    # ) # for TIF data
    df = df.repartition(partition_size="100MB").reset_index(drop=True)
    print("Repartitioning")
    with ProgressBar():
        df.to_parquet(checkpoint_1_path)
if 0:
    with ProgressBar():
        df = dgpd.read_parquet(mtbs_df_temp_path)
    # df = df.assign(hillshade=U8.type(0))
    # df = df.map_partitions(hillshade_partition, 45, 180, meta=df._meta)
    # df = df.assign(year=U16.type(0))
    # df = df.map_partitions(timestamp_to_year_part, meta=df._meta)
    print(df.head())
    print("Repartitioning and saving ")
    df = df.repartition(partition_size="100MB").reset_index(drop=True)
    # df = df.assign(unique_id=str) does not work, added outside of dask
    # df = df.map_partitions(add_unique_identifier, meta=df._meta)
    with ProgressBar():
        # df.to_parquet(mtbs_df_temp_path)
        df.to_parquet(mtbs_df_path)

In [None]:
mtbs_df_2016_2020 = dgpd.read_parquet(checkpoint_2_path)

In [None]:
mtbs_df_2016_2020_computed_dem_only = mtbs_df_2016_2020.compute()

#### MTBS DATASET ADD UID AND EVENT ID

In [None]:
mtbs_dask_df = dgpd.read_parquet(mtbs_df_temp_path)
mtbs_dask_df.head()

In [None]:
mtbs_dataset = mtbs_dask_df.compute()

In [None]:
mtbs_dataset_2016_2020 = mtbs_dataset[mtbs_dataset.ig_date.dt.year >= 2016]

In [None]:
mtbs_df_2016_2020_computed_dem_only

In [None]:
mtbs_df_2016_2020_computed_dem_only[mtbs_df_2016_2020_computed_dem_only['ig_date'] == '2020-08-16'].plot(column='dem', cmap='terrain', legend=True)

In [None]:
mtbs_dataset_2016_2020[mtbs_dataset_2016_2020['ig_date'] == '2020-08-16'].plot(column='dem', cmap='terrain', legend=True)

In [None]:
# # adds unique id to each pixel
# mtbs_dataset.reset_index(inplace=True)
# mtbs_dataset['index'] = mtbs_dataset.index
# # rename index to unique_id
# mtbs_dataset.rename(columns={'index': 'unique_id'}, inplace=True)
# mtbs_dataset

In [None]:
# # rewrite to parquet
# mtbs_dataset.to_parquet(mtbs_df_path)

In [None]:
# # a function to round numeric values in a dataframe to 3 decimals (if they are > 1)
# def round_df(df):
#     columns_to_round = ['aw_mwmt', 'dm_tmax', 'dm_tmin', 'gm_srad', 'gm_vpd', 'aw_td', 'aw_mcmt', 'dem_aspect', 'dem_slope']
#     for col in columns_to_round:
#         df[col] = df[col].round(3)
#     return df

In [None]:
# load mtbs perimeters
print("Loading MTBS perimeters")
mtbs_perim = gpd.read_file(PATHS["mtbs_perim"])
mtbs_perim['Ig_Date'] = pd.to_datetime(mtbs_perim['Ig_Date'])
mtbs_perim.columns

In [None]:
# extract only the columns we need (Event_ID where startswith OR, Ig_Date, and geometry)
mtbs_perim = mtbs_perim[["Event_ID", "Ig_Date", "geometry"]]
mtbs_perim = mtbs_perim[mtbs_perim.Event_ID.str.startswith("OR")]
# drop rows where Ig_Date before 1986 or after 2020
mtbs_perim = mtbs_perim[mtbs_perim.Ig_Date.dt.year.between(1986, 2020)]
mtbs_perim.reset_index(drop=True, inplace=True)
len(mtbs_perim)

In [None]:
# mtbs_dataset['Fire_ID'] = 'None'
# mtbs_dataset

In [None]:
# fire_perim_geom_OR4310211883919860320 = mtbs_perim[mtbs_perim.Ig_Date.dt.year == 1986].loc[40, 'geometry']
# fire_ig_date_OR4310211883919860320 = mtbs_perim[mtbs_perim.Ig_Date.dt.year == 1986].loc[40, 'Ig_Date']
# fire_perim_geom_OR4310211883919860320

In [None]:
mtbs_perim.crs

In [None]:
mtbs_dask_df.crs

In [None]:
# convert the geometry in mtbs_dataset to lat lon
# mtbs_dask_df = mtbs_dask_df.to_crs(epsg=4326)
mtbs_perim = mtbs_perim.to_crs(mtbs_dask_df.crs)
# mtbs_dataset = mtbs_dataset.to_crs(epsg=4326)

In [None]:
def spatial_join(partition, mtbs_perim):
    # Convert the Dask partition to a GeoDataFrame
    gdf = gpd.GeoDataFrame(partition, geometry='geometry', crs=mtbs_perim.crs)

    # Perform the spatial join
    joined = gpd.sjoin(gdf, mtbs_perim, how="inner", predicate="intersects")

    # Filter by date if needed
    joined = joined[joined['ig_date'] == joined['Ig_Date']]
    # print(joined)

    return joined

result = mtbs_dask_df.map_partitions(spatial_join, mtbs_perim, align_dataframes=False)
with ProgressBar():
    final_result = result.compute()
print(len(final_result))

In [None]:
# TODO: compare final_results and mtbs_dataset to see what is missing from final_results after the spatial join

In [None]:
final_result.drop(columns=['index_right', 'Ig_Date'], inplace=True)

In [None]:
final_result.isna().sum()

In [None]:
# drop null values
final_result.dropna(inplace=True)
# adds unique id to each pixel
final_result.reset_index(inplace=True)
final_result['index'] = final_result.index
# rename index to unique_id
final_result.rename(columns={'index': 'unique_id'}, inplace=True)
final_result

In [None]:
# write final_result to dask so we can write to parquet
# final_result = dd.from_pandas(final_result, npartitions=100)

In [None]:
with ProgressBar():
    final_result.to_parquet(mtbs_df_path)

In [None]:
final_result = None

In [None]:
# read and check
mtbs_dataset_final = dgpd.read_parquet(mtbs_df_path)
mtbs_dataset_final.head()

In [None]:
mtbs_dataset_final.tail()

In [None]:
with ProgressBar():
    mtbs_final_computed = mtbs_dataset_final.compute()

In [None]:
# plot a fire topo map

eventID = 'OR4380511789220200816'

fire = final_result[final_result.Event_ID == eventID]

col = 'dem'

geo_mask = fire.geometry.unary_union
date_mask = fire.ig_date

fire.plot(column=col, s=1)

In [None]:
mtbs_dataset_2020_08_16 = mtbs_dataset[(mtbs_dataset.year == 2020) & (mtbs_dataset.ig_date.dt.month == 8) & (mtbs_dataset.ig_date.dt.day == 16)]
mtbs_dataset_2020_08_16

In [None]:
mtbs_dataset_2020_08_16.plot(column=col, s=1)

#### Data info
---

In [None]:
mtbs_data_fixed = dgpd.read_parquet(mtbs_df_path)
mtbs_data_fixed.head()

In [None]:
# read from mtbs_df_path
mtbs_data = dgpd.read_parquet(mtbs_df_temp_path)
mtbs_data.head()

In [None]:
with ProgressBar():
    mtbs_full = mtbs_dataset.compute()

---

#### Feature Data Info

In [None]:
# dem = need to explore dem and see why its is all scrambled (perhaps from mtbs builder projection changes?)