In [None]:
import json
import os

import leafmap.foliumap as leafmap
import matplotlib.pyplot as plt
import pystac_client
import s3fs
import xarray as xr
from dask.distributed import Client as DaskClient
from dotenv import load_dotenv
from odc.stac import load
from planetary_computer import sign_url
from shapely.geometry import shape

In [None]:
# Load variables from a local .env file into the process environment.
# S3_BUCKET is read from os.environ by the upload cell near the end.
load_dotenv()

In [None]:
# Start a Dask distributed client with default arguments; the trailing
# expression shows the client summary as the cell output.
dask_client = DaskClient()
dask_client

In [None]:
# Interactive map for a visual check of the area of interest: the AOI from
# aoi.geojson is drawn on top of the ESA WorldCover 2021 basemap, with the
# matching built-in legend.
m = leafmap.Map(draw_export=True)
m.add_basemap("ESA WorldCover 2021")
m.add_legend(builtin_legend="ESA_WorldCover")
m.add_vector("aoi.geojson")
m

In [None]:
# Read the AOI file; only the parsing needs the file handle to be open.
with open("aoi.geojson") as file:
    feature_collection = json.load(file)

# Only the geometry of the first feature is used as the area of interest.
area_of_interest = feature_collection["features"][0]["geometry"]
geom = shape(area_of_interest)
# Bounds of the geometry as a list, passed as `bbox` to the search and load cells.
bbox = list(geom.bounds)

In [None]:
# STAC catalog and search parameters consumed by the search cell.
# `bbox` is not redefined here: it is the bounding box computed in the AOI cell.
catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1/"
)
collection = "modis-09Q1-061"
# Date range of the search; the search cell joins these as "start/end".
start_date = "2024-07"
end_date = "2025-07-18"

In [None]:
# Query the catalog for items of the chosen collection that fall inside the
# AOI bounding box and the configured date range, then fetch them all.
date_range = f"{start_date}/{end_date}"
search = catalog.search(collections=[collection], bbox=bbox, datetime=date_range)
items = search.item_collection()

In [None]:
# Number of STAC items returned by the search.
len(items)

In [None]:
# Load the red, NIR and QC bands of the matched items, restricted to the AOI
# bounding box. Asset URLs are passed through `sign_url` before access.
load_options = {
    "bands": ["red", "nir08", "sur_refl_qc_250m"],
    "bbox": bbox,
    # Spatial chunk sizes handed to the loader.
    "chunks": {"x": 2048, "y": 2048},
    "groupby": "solar_day",
    "patch_url": sign_url,
}
data = load(items, **load_options)

In [None]:
# Display the loaded dataset (dimensions, coordinates and variables).
data

In [None]:
# Mask out low-quality pixels.
# A pixel is kept only when bits 4-7 of the QC band are all zero.
# NOTE(review): in the MOD09Q1 QC layout these bits appear to describe band 1
# only; confirm whether the band 2 bits (8-11) should be tested as well.
#
# The mask must be boolean. Passing the QC values themselves (NaN where the
# test fails) as the condition would keep rejected pixels, because NaN is
# truthy, and drop pixels whose QC word is exactly 0, because 0 is falsy.
mask = (data.sur_refl_qc_250m & 0b11110000) == 0

data = data.where(mask)
data = data.drop_vars("sur_refl_qc_250m")

In [None]:
# NDVI = (NIR - red) / (NIR + red).
# A zero denominator with a non-zero numerator yields +/-inf, which clip()
# would silently turn into a valid-looking +/-1. Mask the zero denominator to
# NaN so those pixels are treated as missing instead.
denominator = data.nir08 + data.red
ndvi = (data.nir08 - data.red) / denominator.where(denominator != 0)
data["ndvi"] = ndvi.clip(-1, 1)
# Only the derived index is kept from here on.
data = data.drop_vars(["red", "nir08"])

In [None]:
# Display the dataset again; it now holds only the `ndvi` variable.
data

In [None]:
# Run the pending computation and replace `data` with the computed result.
# NOTE(review): this presumably holds the full NDVI cube in memory; confirm
# it fits for the chosen AOI and date range.
data = data.compute()

In [None]:
# Save the NDVI dataset to a local Zarr store so later cells can start from
# it; mode="w" overwrites a store left by a previous run.
data.to_zarr("data/ndvi.zarr", mode="w", consolidated=True)

In [None]:
# Start here if data is already downloaded
# Re-open the saved NDVI store. "time": -1 puts the whole time axis in one
# chunk; presumably this is for the time-wise interpolation and rolling mean
# applied in later cells (verify if the chunking is changed).
data = xr.open_zarr("data/ndvi.zarr", chunks={"time": -1, "x": 2048, "y": 2048})
ndvi = data.ndvi

In [None]:
# Display the NDVI array (dimensions, chunking and coordinates).
ndvi

In [None]:
# Average over both spatial dimensions, leaving one NDVI value per time step.
ndvi_time_series = ndvi.mean(dim=["x", "y"])

In [None]:
# Plot the spatial-mean NDVI per time step. The figure carries its own title
# and axis labels so it can be read without the surrounding code.
fig, ax = plt.subplots()
ax.plot(
    ndvi_time_series["time"].values,
    ndvi_time_series.values,
    marker=".",
)
ax.set_title("Mean NDVI Over Time")
ax.set_xlabel("Time")
ax.set_ylabel("Mean NDVI")
ax.grid(True)

In [None]:
# Visual sanity check: maps of the first six time steps, three per row.
ndvi.isel(time=slice(0, 6)).plot(col="time", col_wrap=3, cmap="viridis")

In [None]:
# Fill missing values along time: linear interpolation first, then
# backward/forward fill for whatever interpolation leaves at the ends.
filled = ndvi.interpolate_na("time", method="linear").bfill("time").ffill("time")

In [None]:
# Smooth each pixel's series with a centered 3-step rolling mean over time.
smoothed = filled.rolling(time=3, center=True).mean()

In [None]:
# The first and last time steps are NaN after the centered rolling average
# (the 3-step window is incomplete there), so drop them.
trimmed = smoothed.isel(time=slice(1, -1))

In [None]:
# Compare the spatial mean of the original NDVI with the spatial mean of the
# rolling-average series on one set of axes.
mean = ndvi.mean(dim=["x", "y"])
smoothed_mean = smoothed.mean(dim=["x", "y"])

fig, ax = plt.subplots()

# Original values as orange points.
ax.scatter(mean["time"].values, mean.values, s=10, color="orange", label="Original Data")

# Smoothed values as a connected line with point markers.
ax.plot(
    smoothed_mean["time"].values,
    smoothed_mean.values,
    marker=".",
    label="Rolling average",
)

ax.set_title("Mean Value Over Time")
ax.set_xlabel("Time")
ax.set_ylabel("Mean Value")
ax.grid(True)
ax.legend()

In [None]:
# Save the gap-filled, smoothed and trimmed NDVI to a local Zarr store;
# mode="w" overwrites a store left by a previous run.
trimmed.to_zarr("data/ndvi_processed.zarr", mode="w", consolidated=True)

In [None]:
# Upload the processed Zarr store to S3.
# S3_BUCKET must be set in the environment; os.environ[...] raises KeyError
# when it is missing.
bucket_name = os.environ["S3_BUCKET"]
s3_path = f"{bucket_name}/ndvi_processed.zarr"
# No credentials are passed explicitly; presumably s3fs resolves them from
# the environment. Verify for the deployment in use.
fs = s3fs.S3FileSystem()
# Recursive copy of the local store directory to the bucket path.
# NOTE(review): objects already under s3_path are not removed first; confirm
# that stale chunks from an earlier upload cannot linger.
fs.put("./data/ndvi_processed.zarr/", s3_path, recursive=True)

In [None]:
# Upload the bookkeeping files to the bucket as well.
# NOTE(review): last_end_date.txt is not written by any cell visible in this
# notebook; confirm it exists before this cell runs.
for local_file in ("last_end_date.txt", "aoi.geojson"):
    fs.put(local_file, bucket_name)