# Run the processor step-by-step

In [None]:
import folium
import joblib
from dask.distributed import Client
from dep_tools.grids import PACIFIC_GRID_10
from dep_tools.loaders import OdcLoader
from dep_tools.searchers import PystacSearcher
from odc.stac import configure_s3_access

In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to the project code are picked up
# without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
# Configure S3 access to use unsigned requests
configure_s3_access(aws_unsigned=True)

# Start the Dask client: 4 workers with 16 threads each and a 12GB memory limit
# (NOTE(review): presumably the limit applies per worker — confirm).
# The client is bound to a name and then echoed so the cell still displays it.
dask_client = Client(n_workers=4, threads_per_worker=16, memory_limit="12GB")
dask_client

In [None]:
# Tile to process. The commented-out alternative is kept for reproducing the
# failure noted next to it.
tile_index = (63, 20)  # SE Viti Levu
# tile_index = (108, 17)  # Failed with OOM
grid = PACIFIC_GRID_10
# Geobox of the selected tile; later cells use it for the search, the load and the map
geobox = grid.tile_geobox(tile_index)

# Passed to the STAC search as its `datetime` argument.
# NOTE(review): this name would shadow the stdlib `datetime` module if that is ever imported here.
datetime = "2024"
# NOTE: joblib.load unpickles the file, so only load model files from a trusted source
model = joblib.load("models/20250902c-alex.model")

# Display the tile's geobox as the cell output
geobox.explore()

In [None]:
# Search the STAC catalog for `dep_s2_geomad` items matching `datetime`
# for the selected tile's geobox, then report how many were found.
searcher = PystacSearcher(
    catalog="https://stac.digitalearthpacific.org",
    collections=["dep_s2_geomad"],
    datetime=datetime,
)
items = searcher.search(geobox)

n_items = len(items)
print(f"Found {n_items} items")

In [None]:
# Build the loader. Measurements are listed explicitly so that `count` is not
# loaded; each band appears exactly once ("green" was previously listed twice).
loader = OdcLoader(
    # NOTE(review): 2024 is an unusual chunk size (2048 is the common power of
    # two) — confirm it is intended.
    chunks=dict(x=2024, y=2024),
    fail_on_error=False,
    measurements=[
        "nir",
        "red",
        "blue",
        "green",
        "emad",
        "smad",
        "bcmad",
        "nir08",
        "nir09",
        "swir16",
        "swir22",
        "coastal",
        "rededge1",
        "rededge2",
        "rededge3",
    ],  # List measurements so we don't get count
)

# Load the found items onto the tile's geobox and display the result
input_data = loader.load(items, geobox)
input_data

In [None]:
# from processor import SeagrassProcessor

# # The actual processor, doing the work :muscle:
# # Uncomment this and run it to test the real thing.
# # The code in the cells below is copied from there, so replicates its work step-by-step

# processor = SeagrassProcessor(
#     model=model,
#     probability_threshold=60,
#     nodata_value=255,
#     fast_mode=True
# )

# results = processor.process(input_data)
# results

In [None]:
# from ipyleaflet import basemaps

# m = folium.Map(
#     location=geobox.geographic_extent.centroid.coords[0][::-1],
#     zoom_start=10,
#     tiles=basemaps.Esri.WorldImagery,
# )

# for var in results.data_vars:
#     results[var].odc.explore(m, name=var)

# # Layer control
# folium.LayerControl().add_to(m)

# m

In [None]:
from utils import (
    scale,
    calculate_band_indices,
    texture,
    do_prediction,
    probability_binary,
    extract_single_class,
)
from masking import all_masks
import xarray as xr

# Parameters for the step-by-step run
target_class_id = 4  # class id passed to do_prediction and extract_single_class below
fast_mode = True  # NOTE(review): not referenced by any later cell shown here — confirm it is still needed
probability_threshold = 60  # threshold passed to probability_binary below
nodata_value = 255  # written as nodata / _FillValue on every output variable below

# Scale data to values of 0-1 so that we can calculate indices properly,
# then squeeze out (and drop) any length-1 dimensions
scaled_data = scale(input_data).squeeze(drop=True)

# Load data into memory here, before we do intensive things like texture
loaded_data = scaled_data.compute()
loaded_data

In [None]:
# Preview the in-memory data as a red/green/blue composite, stretched over 0-0.3
loaded_data.odc.explore(vmin=0, vmax=0.3, bands=["red", "green", "blue"])

In [None]:
# Compute band indices from the in-memory scaled data.
# Later cells read `mndwi` and `blue` from the result.
data_indices = calculate_band_indices(loaded_data)
data_indices

In [None]:
# Display the `mndwi` index using the "Blues" colormap
data_indices.mndwi.odc.explore(cmap="Blues")

In [None]:
# Calculate the texture data on unmasked data: texture is derived from the
# `blue` variable with levels=32, and the result is computed immediately.
texture_data = texture(data_indices.blue, levels=32).compute()
texture_data

In [None]:
# Plot the `entropy` variable of the texture result as an image
texture_data.entropy.plot.imshow()

In [None]:
# Combine the two datasets (indices and texture) before applying the mask
combined_data = xr.merge([data_indices, texture_data])

# Mask all the data in one call, also returning the mask itself.
# The following cells rebuild the same kind of mask one step at a time and
# rebind `mask`; `masked_scaled` is not used by the later cells shown here.
masked_scaled, mask = all_masks(combined_data, return_mask=True)

# Plot the mask returned by all_masks
mask.plot.imshow()

In [None]:
from masking import mask_land, mask_deeps, mask_elevation, mask_surf, apply_mask

# Step-by-step masking starts from the combined (indices + texture) dataset
ds = combined_data

# Only the mask is needed here; the first return value is discarded
_, land_mask = mask_land(ds, return_mask=True)

# Plot the land mask
land_mask.plot.imshow()

In [None]:
# Build the deeps mask (first return value discarded) and plot it
_, deeps_mask = mask_deeps(ds, return_mask=True)
deeps_mask.plot.imshow()

In [None]:
# Build the elevation mask (first return value discarded) and plot it
_, elevation_mask = mask_elevation(ds, return_mask=True)
elevation_mask.plot.imshow()

In [None]:
# Pass the water_area_mask to mask_surf: the inverse of the land mask is
# supplied as the water area.
print("Applying surf mask...")
_, surf_mask = mask_surf(
    ds=ds,
    water_area_mask=~land_mask,
    return_mask=True,
    # You can also pass surf_blue_threshold, surf_green_threshold, etc. here if you want to customize them
)

# Plot the surf mask
surf_mask.plot.imshow()

In [None]:
# Combine all masks. All individual masks are now False for areas to KEEP,
# so the union is True wherever any one of them flags a pixel.
# This rebinds `mask`, replacing the one returned by all_masks earlier.
mask = land_mask | deeps_mask | elevation_mask | surf_mask
mask.plot.imshow()

In [None]:
# Apply the inverted combined mask: `mask` is False for areas to keep, so
# `~mask` is True there.
# NOTE(review): the meaning of the trailing positional arguments (None, False)
# is defined by masking.apply_mask — consider passing them by keyword.
final_data = apply_mask(ds, ~mask, None, False)
final_data

In [None]:
# mask.odc.explore(cmap="Reds_r", vmin=1, vmax=2)
# Plot the masked red/green/blue variables as an image, stretched over 0-0.3
final_data[["red", "green", "blue"]].to_array().plot.imshow(vmin=0, vmax=0.3, size=8)

In [None]:
# Run the prediction on the masked data
classification, probability = do_prediction(final_data, model, target_class_id)

# Threshold the probability layer using `probability_threshold`
seagrass_threshold = probability_binary(
    probability,
    probability_threshold,
    nodata_value=nodata_value,
)

# Extract the target class from the classification
seagrass_class = extract_single_class(
    classification,
    target_class_id,
)

# Name the thresholded layer after the threshold actually applied, so the
# variable name cannot drift from `probability_threshold`
# (60 -> "seagrass_threshold_60", the name used previously).
threshold_var = f"seagrass_threshold_{probability_threshold}"

output = xr.Dataset(
    {
        "classification": classification,
        "seagrass_probability": probability,
        threshold_var: seagrass_threshold,
        "seagrass": seagrass_class,
    }
)

# Print each variable's dtype and record the nodata value on it, both through
# the odc accessor and as the `_FillValue` attribute.
# NOTE(review): this assumes `nodata_value` fits every variable's dtype — confirm.
for var in output.data_vars:
    print(f"{var}: {output[var].dtype}")
    output[var].odc.nodata = nodata_value
    output[var].attrs["_FillValue"] = nodata_value

output

In [None]:
# Report the dtype of every output variable
for name, data_array in output.data_vars.items():
    print(name, data_array.dtype)

In [None]:
# Histogram of the seagrass probability values in 100 bins; the return value
# is assigned to `_` so it is not echoed as cell output
_ = output.seagrass_probability.plot.hist(bins=100)

In [None]:
from ipyleaflet import basemaps

# Centre the map on the tile: the centroid's coordinate pair is reversed
# before being handed to folium as `location`.
tile_centre = geobox.geographic_extent.centroid.coords[0]
m = folium.Map(
    location=tile_centre[::-1],
    zoom_start=10,
    tiles=basemaps.Esri.WorldImagery,
)

# Add every output variable to the map as its own named layer
for layer_name, layer in output.data_vars.items():
    layer.odc.explore(m, name=layer_name)

# Layer control
folium.LayerControl().add_to(m)

m