In [None]:
from dep_tools.grids import PACIFIC_GRID_10
from src.utils import get_gmw

from odc.stac import configure_s3_access
from dep_tools.searchers import PystacSearcher
from dep_tools.loaders import OdcLoader

import sys
# `run_task` is imported as a top-level module, so the `src` directory is put on
# the import path. The path is relative: this presumably assumes the notebook is
# started from the repository root (TODO confirm). Note `src.utils` above is
# imported via the package path instead.
sys.path.append('src')
from run_task import MangrovesProcessor

In [None]:
# Re-import edited project modules (e.g. run_task / src.utils) before each cell executes,
# so changes to the .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# --- What to process ---------------------------------------------------------
tile_id = "64,19"  # "x,y" index of the tile on the grid below
year = "2024"

# --- Tile geometry -------------------------------------------------------------
grid = PACIFIC_GRID_10
tile_index = tuple(map(int, tile_id.split(",")))
geobox = grid.tile_geobox(tile_index)

# --- GMW mangrove extents restricted to this tile ------------------------------
gmw = get_gmw()
# Reproject the tile footprint into the GMW layer's CRS before intersecting.
geom = geobox.geographic_extent.to_crs(gmw.crs)
areas = gmw.intersection(geom)

# --- Data access ---------------------------------------------------------------
configure_s3_access(cloud_defaults=True)

catalog = "https://stac.digitalearthpacific.org"
collection = "dep_s2_geomad"
searcher = PystacSearcher(catalog=catalog, collections=[collection], datetime=year)

# Only red and nir are needed to compute NDVI.
loader = OdcLoader(
    bands=["red", "nir"],
    groupby="solar_day",
    fail_on_error=False,
    clip_to_area=False,
)

# NOTE(review): `processor` is not used by any later cell in this notebook; the
# classification below is re-implemented inline. Consider calling it instead.
processor = MangrovesProcessor(areas)

In [None]:
# Check out the study site: show the tile's footprint on an interactive map
geobox.explore()

In [None]:
# Query the STAC catalog for items intersecting the tile
items = searcher.search(geobox)
print("Found", len(items), "items")

In [None]:
# Load the red and nir bands for the tile. Whether this is lazy/Dask-backed depends
# on how `loader` was configured (its chunking) — check the repr below rather than
# assuming it; an eager load of a full tile can take a while.
input_data = loader.load(items, geobox)
input_data

In [None]:
import xarray as xr

# Value written to pixels without a valid NDVI (outside the GMW mask, or missing input).
OUTPUT_NODATA = 255

# Reflectance scaling applied to the loaded integer bands.
REFLECTANCE_SCALE = 1 / 10_000
REFLECTANCE_OFFSET = 0

# NDVI class boundaries; each upper bound is inclusive:
#   ndvi <= 0.4         -> 0 (not mangrove)
#   0.4 < ndvi <= 0.7   -> 1 (open mangrove)
#   ndvi > 0.7          -> 2 (closed mangrove)
NDVI_OPEN_THRESHOLD = 0.4
NDVI_CLOSED_THRESHOLD = 0.7

# Always start from `input_data` so re-running this cell gives the same result.
# squeeze() drops length-1 dimensions (presumably the single time step).
data = input_data.squeeze()

# Scale and offset the data, clipping to the valid 0-1 reflectance range
data = (data * REFLECTANCE_SCALE + REFLECTANCE_OFFSET).clip(0, 1)

# Mask to only keep areas identified as mangroves in the GMW dataset
data = data.odc.mask(areas)

# Create NDVI. Where nir + red == 0 this is 0/0 = NaN and is treated as nodata below.
data["ndvi"] = (data.nir - data.red) / (data.nir + data.red)

# Classify in one pass: test the highest class first, then fall through.
# NaN fails every comparison and so lands in class 0 here; it is fixed by the mask below.
data["mangroves"] = xr.where(
    data.ndvi > NDVI_CLOSED_THRESHOLD,
    2,
    xr.where(data.ndvi > NDVI_OPEN_THRESHOLD, 1, 0),
)

# Mask nodata from the NDVI (masked-out or invalid pixels)
data["mangroves"] = data.mangroves.where(data.ndvi.notnull(), OUTPUT_NODATA)

# Only keep the mangroves band, stored as uint8 with its nodata value recorded
data = data[["mangroves"]].astype("uint8")
data.mangroves.odc.nodata = OUTPUT_NODATA

data

In [None]:
# Interactive map of the classes 0-2. Pass the array's real nodata value
# (OUTPUT_NODATA = 255); -9999 can never occur in uint8 data, so with it the 255
# pixels were not treated as nodata and rendered at the top of the colour range.
data.mangroves.odc.explore(vmin=0, vmax=2, nodata=OUTPUT_NODATA)

In [None]:
# Plot data. Yellow is not-mangrove (0), green is open (1) and dark green is closed (2).
# Nodata (OUTPUT_NODATA) is masked to NaN first so it is left blank. Otherwise values
# above the top level trigger extend="max" and steal the last colour. Exactly one
# colour is given per level interval [0,1), [1,2), [2,3).
data.mangroves.where(data.mangroves != OUTPUT_NODATA).plot.imshow(
    levels=[0, 1, 2, 3],
    colors=["yellow", "green", "darkgreen"],
);