In [None]:
import os,sys
import time
import geopandas as gpd
import rioxarray as rxr
import rasterio
from rasterio.enums import Resampling
import xarray as xr
import threading

# Coordinate reference system used for every output raster (UTM zone 13N).
proj = 'EPSG:32613'

# Root folder of the project tree; all data paths below are built from it.
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
projdir = '/Users/max/Library/CloudStorage/OneDrive-Personal/mcook/aspen-fire/Aim1/'

# Mosaicked probability surface covering the case-study ROI.
prob_path = os.path.join(
    projdir,
    'data/spatial/mod/results/probability/srme_skcv_probability_mosaic_prop.tif',
)

# LANDFIRE raster that serves as the 30-meter reference grid for matching.
ref = os.path.join(
    projdir,
    'data/spatial/mod/landfire/lc16_evt_srme_aspen_r01_utm.tif',
)

In [None]:
# Functions

# Function to resample and snap grids
def resample_match_grid(
        in_img, to_img, scale_factor,
        crs, dtype, method='', out_path=None):
    """Resample a raster by ``scale_factor`` and snap it to a reference grid.

    The input is reprojected to ``crs`` with its pixel counts multiplied by
    ``scale_factor`` (using ``method``), aligned to the grid of ``to_img``
    with ``reproject_match``, and written as a tiled, ZSTD-compressed GeoTIFF.

    Parameters
    ----------
    in_img : str or path-like
        Raster to resample.
    to_img : str or path-like
        Reference raster whose grid the output is matched to.
    scale_factor : float
        Multiplier applied to the input height and width (``1/3`` collapses
        three input pixels into one output pixel along each axis).
    crs : CRS input accepted by rioxarray
        Target coordinate reference system for the resampling step.
    dtype : str or numpy dtype
        Data type of the written raster.
    method : rasterio.enums.Resampling, optional
        Resampling algorithm. An empty string or ``None`` falls back to
        ``Resampling.nearest``.
    out_path : str or path-like, optional
        Destination file. When omitted, the current working directory (looked
        up at call time) is used; when it names an existing directory, the
        file ``<in_img stem>_resamp.tif`` is written inside it.

    Returns
    -------
    str or path-like
        Path of the written GeoTIFF.

    Raises
    ------
    ValueError
        If ``scale_factor`` shrinks either dimension below one pixel.
    """
    # '' / None is not a usable resampling algorithm. Compare explicitly:
    # Resampling.nearest has the integer value 0 and is therefore falsy.
    if method is None or method == '':
        method = Resampling.nearest

    # Resolve the destination at call time; a directory cannot be written as
    # a raster, so derive a file name from the input in that case.
    if out_path is None:
        out_path = os.getcwd()
    if os.path.isdir(out_path):
        stem = os.path.splitext(os.path.basename(str(in_img)))[0]
        out_path = os.path.join(out_path, stem + '_resamp.tif')

    # Context managers release both file handles even if a step fails
    # (cache=False keeps the underlying files open until closed).
    with rxr.open_rasterio(to_img, masked=True, cache=False) as to_src, \
            rxr.open_rasterio(in_img, masked=True, cache=False) as in_src:
        toimg = to_src.squeeze()
        inimg = in_src.squeeze()

        # Define the resample dimensions
        new_height = int(inimg.rio.height * scale_factor)
        new_width = int(inimg.rio.width * scale_factor)
        if new_height < 1 or new_width < 1:
            raise ValueError(
                f"scale_factor={scale_factor!r} yields an empty output grid "
                f"({new_height} x {new_width})"
            )

        # Reproject w/ resampling
        resamp = inimg.rio.reproject(
            crs,
            shape=(new_height, new_width),
            resampling=method
        )

        # Snap to the reference grid, then write while the sources are open
        out_grid = resamp.rio.reproject_match(toimg)
        out_grid.rio.to_raster(
            out_path, tiled=True, lock=threading.Lock(), windowed=True,
            compress='zstd', zstd_level=9, num_threads='all_cpus',
            dtype=dtype, driver='GTiff'
        )

    return out_path
    
    
# Function to reclassify input image to binary
def reclassify_bin(img, to_img, geom, folder, proj4):
    """Reclassify a raster to binary, match it to a grid and clip it to an ROI.

    Cells with a value greater than 0 become 1 and all other cells become 0.
    Cells that ``masked=True`` turned into NaN compare False and therefore
    also become 0. The binary raster is aligned to the grid of ``to_img``,
    clipped to the geometries read from ``geom`` and written to
    ``<folder>/<img stem>_bin.tif``.

    Parameters
    ----------
    img : str or path-like
        Raster to reclassify.
    to_img : str or path-like
        Reference raster whose grid the output is matched to.
    geom : str or path-like
        Vector file holding the clip geometries.
    folder : str or path-like
        Directory that receives the output file.
    proj4 : CRS input accepted by geopandas
        CRS the clip geometries are reprojected to before clipping.

    Returns
    -------
    str
        Path of the written raster.
    """
    roi = gpd.read_file(geom).to_crs(crs=proj4).geometry

    # Context managers make sure both raster files are closed again.
    with rxr.open_rasterio(to_img, masked=True) as to_src, \
            rxr.open_rasterio(img, masked=True) as in_src:
        toimg = to_src.squeeze()
        in_img = in_src.squeeze()

        out_img = xr.where(in_img > 0, 1, 0).astype(rasterio.uint8)
        out_img_match = out_img.rio.reproject_match(toimg)
        # Declare the geometry CRS explicitly: without it rioxarray assumes
        # the geometries already share the raster CRS, which only holds when
        # proj4 equals the CRS of to_img.
        out_img_clip = out_img_match.rio.clip(roi, crs=roi.crs)

        # splitext handles any extension length ('.tif', '.tiff', none),
        # unlike slicing off the last four characters.
        stem = os.path.splitext(os.path.basename(str(img)))[0]
        out_file = os.path.join(folder, stem + '_bin.tif')
        out_img_clip.rio.to_raster(out_file)

    return out_file
    

In [None]:
begin = time.perf_counter()
# Generate 30m probability maps matched to the LANDFIRE reference grid:
# one aggregated with the mean and one with the median of the input pixels.
# (The max & sum aggregations of the binary surface are handled separately.)
#
# NOTE(review): the output paths previously used an undefined name `datamod`
# (NameError on a fresh kernel). They are now built from `projdir`, assuming
# `datamod` meant '<projdir>/data/spatial/mod' as in the other paths of this
# notebook -- confirm.
print("Calculating mean and median resample ...")
for label, method, rel_path in (
        ('mean', Resampling.average,
         'data/spatial/mod/results/probability/aspen_prob_30m_mn.tif'),
        ('median', Resampling.med,
         'data/spatial/mod/results/probability/aspen_prob_30m_md.tif')):
    # The function writes the raster to disk; no result needs to be kept.
    resample_match_grid(
        in_img=prob_path, to_img=ref,
        scale_factor=1/3, crs=proj, method=method, dtype="float32",
        out_path=os.path.join(projdir, rel_path)
    )
    print(f"Successfully exported the {label} ...")

# Elapsed seconds for this cell (no artificial sleep included in the timing)
print(time.perf_counter() - begin)

In [None]:
begin = time.perf_counter()
# Create the 10m binary raster based on the optimum threshold,
# starting with the 10m probability map
thresh = 424  # 'optimum' threshold based on model averages and F1 score
bin_out = os.path.join(projdir,'data/spatial/mod/results/classification/aspen_prob_10m_binOpt.tif')

print("Calculating binary resample ...")
# Read the probability surface lazily (dask chunks) and close the file again
# once the binary raster has been written.
with rxr.open_rasterio(
        prob_path, masked=True, cache=False, chunks=True, lock=False) as prob_src:
    prob = prob_src.squeeze()
    # Reclassify to binary based on threshold (cells below it become 0)
    bin10 = xr.where(prob >= thresh, 1, 0).astype(rasterio.uint8).rio.reproject(proj)
    # NOTE(review): the array is uint8 but is written as uint16; kept as-is so
    # existing outputs do not change -- confirm whether uint8 is sufficient.
    bin10.rio.to_raster(
        bin_out, tiled=True, lock=threading.Lock(), windowed=True,
        compress='zstd', zstd_level=9, num_threads='all_cpus',
        dtype='uint16', driver='GTiff'
    )

# Create the binary surfaces at 30m

# Maximum within 30m
max30 = resample_match_grid(
    in_img=bin_out, to_img=ref, scale_factor=1/3, crs=proj,
    method=Resampling.max, dtype="uint16",
    out_path=os.path.join(projdir,'data/spatial/mod/results/classification/aspen_prob_10m_binOpt_max30m.tif')
)
# # Sum total (how many s2 pixels in one 30m)
# sum30 = resample_match_grid(
#     in_img=bin_out, to_img=ref, scale_factor=1/3, crs=proj,
#     method=Resampling.sum, dtype="uint16",
#     out_path=os.path.join('aspen_prob_10m_binOpt_sum30m.tif')
# )

# Elapsed seconds for this cell (no artificial sleep included in the timing)
print(time.perf_counter() - begin)