In [None]:
from pystac_client import Client
from odc.stac import load
import xarray as xr
import numpy as np

import odc.geo # noqa

In [None]:
# STAC endpoint opened with pystac_client and the collection searched below.
# Note the URL points at a "staging" host.
catalog = "https://stac.staging.digitalearthpacific.org"
collection = "dep_s2_mangroves"

In [None]:
# Area of interest, given as (lat, lon) corner pairs copied from Google Maps:
# `ll` is the lower-left corner and `ur` the upper-right one.
# Swap in one of the commented alternatives to look at a different site.

# # Southeast PNG
# ll = (-10.590125, 149.844629)
# ur = (-10.360110, 150.195631)

# # Ba river mouth, Fiji
# ll = (-17.500881, 177.608558)
# ur = (-17.420771, 177.702546)

# Fiji, Vanua Levu
ll = (-16.540442, 178.767840)
ur = (-16.482047, 178.825006)

# The corners are (lat, lon) but the bounding box is assembled in
# (min_lon, min_lat, max_lon, max_lat) order.
south, west = ll
north, east = ur
bbox = (west, south, east, north)

# Query the catalog for items of the collection that intersect the box
client = Client.open(catalog)
search = client.search(collections=[collection], bbox=bbox)
items = search.item_collection()

print(f"Found {len(items)} items")

In [None]:
# Read only the "mangroves" band, clipped to the bounding box, as int16.
# NOTE(review): chunks={} presumably requests a lazy, dask-backed load --
# confirm against the odc.stac documentation.
load_options = dict(bbox=bbox, bands=["mangroves"], dtype="int16", chunks={})
data = load(items, **load_options)
data

In [None]:
# One panel per time step (four per row), drawn with discrete class colours
class_levels = [0, 1, 2, 3]
class_colors = ["white", "yellow", "green", "darkgreen"]

data.mangroves.plot.imshow(
    col="time", col_wrap=4, levels=class_levels, colors=class_colors
)

In [None]:
# Class values of the "mangroves" band to tally
values_to_count = [0, 1, 2]

# Every dimension other than time is collapsed when counting pixels
spatial_dims = [dim for dim in data.mangroves.dims if dim != "time"]

# Comparing against a 1-D array along a new "values" dimension tests every
# class at once; summing the boolean result over the spatial dimensions gives
# the number of matching pixels per (time, value). A class that is absent in
# a given time step simply sums to 0, so no special case is needed.
class_values = xr.DataArray(
    values_to_count, dims="values", coords={"values": values_to_count}
)
pixel_counts = (data.mangroves == class_values).sum(dim=spatial_dims)

# Scale factor carried over unchanged from the loop-based version of this cell.
# NOTE(review): if the goal is area in hectares for 10 m pixels, each pixel is
# 10 * 10 = 100 m^2 and the factor should be 100 / 10000 -- confirm the
# intended unit; as written the result is neither a raw count nor hectares.
count_array = (
    (pixel_counts * 10 / 10000)
    .transpose("time", "values")
    .reset_coords(drop=True)  # keep only the time/values index coordinates
    .compute()  # materialise the (small) result if the data is lazily loaded
    .rename("count")
)
count_array

In [None]:
# One line per counted class value, plotted against time
count_array.plot.line(x="time", hue="values")

In [None]:
# Earliest and latest time steps, taken straight from the time coordinate.
# The exact timestamps are used for selection: .sel() does an exact label
# lookup, so a value truncated to the year only matches when the coordinate
# happens to sit on 1 January 00:00 and raises KeyError otherwise.
first_time = data.time.min().values
last_time = data.time.max().values

# Year labels, kept for titles / file names
first_year = first_time.astype("datetime64[Y]")
last_year = last_time.astype("datetime64[Y]")

# Mask negative values (presumably no-data -- confirm against the product
# definition) into a new variable instead of overwriting `data`, so the
# loaded dataset keeps its dtype and earlier cells can be re-run safely.
mangroves_valid = data.where(data.mangroves >= 0)

# Difference between the last and the first time step
change = mangroves_valid.sel(time=last_time) - mangroves_valid.sel(time=first_time)
change

In [None]:
# Map of the last-minus-first difference using the diverging "RdBu" colormap
change.mangroves.plot.imshow(cmap="RdBu")

In [None]:
# Interactive map of the pixels whose value changed, drawn over ESRI imagery
esri_imagery_url = "https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}"

# Hide unchanged pixels (difference of exactly 0) so only change is drawn
changed_pixels = change.where(change.mangroves != 0).mangroves

changed_pixels.odc.explore(
    cmap="RdBu", tiles=esri_imagery_url, attr="ESRI WorldImagery"
)

In [None]:
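# Optional export (disabled): uncomment to write each time step to a
# Cloud-Optimized GeoTIFF named after its year.
# for time in data.time:
#     year = time.values.astype("datetime64[Y]")
#     one_year_data = data.sel(time=time)
#     one_year_data.mangroves.odc.write_cog(f"mangroves_{year}.tif")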