In [None]:
import json
import os

import leafmap.foliumap as leafmap
import matplotlib.pyplot as plt
import numpy as np
import pystac_client
import s3fs
import xarray as xr
from dask.distributed import Client as DaskClient
from dotenv import load_dotenv
from odc.stac import load
from planetary_computer import sign_url
from shapely.geometry import shape

In [None]:
# Read variables from a .env file into the process environment; the S3 upload
# cells at the end of this notebook read S3_BUCKET from os.environ.
load_dotenv()

In [None]:
# Start a Dask distributed client with default settings. The bare name on the
# last line makes the notebook render the client summary as the cell output.
dask_client = DaskClient()
dask_client

In [None]:
# Interactive map for visually checking the area of interest against the
# ESA WorldCover 2021 basemap and its built-in legend.
m = leafmap.Map(draw_export=True)
m.add_basemap("ESA WorldCover 2021")
m.add_legend(builtin_legend="ESA_WorldCover")
# Overlay the AOI from aoi.geojson (the same file the bounding box is read from).
m.add_vector("aoi.geojson")
# Bare name as the last expression so the map is rendered as the cell output.
m

In [None]:
# Read the AOI file, then derive the search bounding box from the geometry of
# its first feature. Only file parsing happens while the handle is open.
with open("aoi.geojson") as file:
    feature_collection = json.load(file)

area_of_interest = feature_collection["features"][0]["geometry"]
geom = shape(area_of_interest)
bbox = list(geom.bounds)

In [None]:
# STAC search configuration: Planetary Computer catalog, Landsat Collection 2
# Level-2 collection, and the month-resolution date range to query.
# `bbox` is defined by the AOI cell above and is used as-is by the search.
catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1/"
)
collection = "landsat-c2-l2"
start_date = "2023-08"
end_date = "2025-07"

In [None]:
# Query the catalog for every item of the collection that intersects the AOI
# bounding box within the configured date range, then materialise the results.
date_range = f"{start_date}/{end_date}"
search = catalog.search(collections=[collection], bbox=bbox, datetime=date_range)
items = search.item_collection()

In [None]:
# Number of STAC items returned by the search.
len(items)

In [None]:
# Load only what NDVI needs (red, near-infrared) plus the QA band used for
# masking. Passing `chunks` keeps the result Dask-backed until it is computed.
bands = ["red", "nir08", "qa_pixel"]
data = load(
    items,
    bands=bands,
    bbox=bbox,
    resolution=300,
    groupby="solar_day",
    chunks={"x": 2048, "y": 2048},
    patch_url=sign_url,  # sign each asset URL before it is read
)

In [None]:
# Display the loaded dataset (dimensions, variables, chunking).
data

In [None]:
# Mask out nodata and cloud-affected pixels using the QA_PIXEL bit flags.
# Landsat Collection 2 QA_PIXEL: bit 0 is fill (nodata), bit 3 is cloud,
# bit 4 is cloud shadow.
FILL_BIT, CLOUD_BIT, CLOUD_SHADOW_BIT = 0, 3, 4
mask_bits = (1 << FILL_BIT) | (1 << CLOUD_BIT) | (1 << CLOUD_SHADOW_BIT)  # 0b00011001

# True wherever at least one of the selected flags is set.
mask = (data.qa_pixel & mask_bits) != 0

# Replace flagged pixels with NaN in every band; the QA band is not needed afterwards.
masked = data.where(~mask, other=np.nan)
data = masked.drop_vars("qa_pixel")

In [None]:
# Landsat Collection 2 Level-2 surface reflectance is stored as scaled
# integers: reflectance = DN * 0.0000275 - 0.2. The additive offset does not
# cancel in a normalised difference, so NDVI computed on raw DNs is biased
# towards zero. Convert both bands to reflectance first.
SR_SCALE = 0.0000275
SR_OFFSET = -0.2

red = data.red * SR_SCALE + SR_OFFSET
nir = data.nir08 * SR_SCALE + SR_OFFSET

ndvi = (nir - red) / (nir + red)
# Clip to the valid NDVI range; near-zero denominators can produce outliers.
data["ndvi"] = ndvi.clip(-1, 1)
# Only the NDVI variable is kept from here on.
data = data.drop_vars(["red", "nir08"])

In [None]:
# Execute the pending (lazy) computation and hold the result in memory.
data = data.compute()

In [None]:
# Cache the NDVI dataset in a local Zarr store; mode="w" overwrites any store
# left by a previous run so the cell can be re-executed.
data.to_zarr("data/ndvi.zarr", mode="w", consolidated=True)

In [None]:
# Start here if data is already downloaded
# (re-opens the local Zarr cache written by the cell above).
data = xr.open_zarr("data/ndvi.zarr")
ndvi = data["ndvi"]

In [None]:
# Display the NDVI array (dimensions, coordinates, chunking).
ndvi

In [None]:
# Average over both spatial dimensions, leaving one NDVI value per time step.
ndvi_time_series = ndvi.mean(dim=["x", "y"])

In [None]:
# Line plot of the spatially averaged NDVI, one dot per time step.
fig, ax = plt.subplots()
ax.plot(
    ndvi_time_series["time"].values,
    ndvi_time_series.values,
    marker=".",
)

In [None]:
# Small-multiples preview of the first six time steps, three panels per row.
ndvi.isel(time=slice(0, 6)).plot(col="time", col_wrap=3, cmap="viridis")

In [None]:
# Resample along time into 8-day bins, keeping the maximum NDVI in each bin.
eight_day = data.ndvi.resample(time="8D").max()

In [None]:
# Preview the first six 8-day composites, three panels per row.
eight_day.isel(time=slice(0, 6)).plot(col="time", col_wrap=3, cmap="viridis")

In [None]:
# Put the whole time axis in a single chunk (presumably so the time-wise gap
# filling below can run on Dask-backed data).
eight_day = eight_day.chunk({"time": -1})

# Fill gaps along time: linear interpolation first, then backward and forward
# fill for any values that are still missing afterwards.
interpolated = eight_day.interpolate_na("time", method="linear")
filled = interpolated.bfill("time").ffill("time")

In [None]:
# Preview the first six gap-filled 8-day composites, three panels per row.
filled.isel(time=slice(0, 6)).plot(col="time", col_wrap=3, cmap="viridis")

In [None]:
# Compare the spatial mean of the original observations with the spatial mean
# of the resampled, gap-filled series.
mean = ndvi.mean(dim=["x", "y"])
filled_mean = filled.mean(dim=["x", "y"])

fig, ax = plt.subplots()

# Original acquisitions as individual points.
ax.scatter(
    mean["time"].values,
    mean.values,
    s=10,
    color="orange",
    label="Original Data",
)
# Gap-filled 8-day series as a connected line.
ax.plot(
    filled_mean["time"].values,
    filled_mean.values,
    marker=".",
    label="Resampled and filled",
)

ax.set_title("Mean Value Over Time")
ax.set_xlabel("Time")
ax.set_ylabel("Mean Value")
ax.grid(True)
ax.legend()

In [None]:
# Bundle both 8-day products into one dataset: the composites as resampled
# (gaps still present) and the gap-filled version.
processed = xr.Dataset(
    data_vars={
        "ndvi_8d_raw": eight_day,
        "ndvi_8d_processed": filled,
    },
)

In [None]:
# Rechunk to 3 time steps x 100 x 100 pixels per chunk before writing.
processed = processed.chunk({"time": 3, "x": 100, "y": 100})

In [None]:
# Display the processed dataset (variables and chunk layout).
processed

In [None]:
# Write the processed dataset to a local Zarr store with consolidated
# metadata; mode="w" overwrites any store left by a previous run.
processed.to_zarr("data/ndvi_processed.zarr", mode="w", consolidated=True)

In [None]:
# Upload the processed Zarr store to the bucket named by the S3_BUCKET
# environment variable (a missing variable raises KeyError here).
bucket_name = os.environ["S3_BUCKET"]
s3_path = f"{bucket_name}/ndvi_processed.zarr"
fs = s3fs.S3FileSystem()
# Recursive copy of the local store directory to the S3 destination.
fs.put(
    lpath="./data/ndvi_processed.zarr/",
    rpath=s3_path,
    recursive=True,
)

In [None]:
# Upload the AOI file to the same bucket. The destination object key is
# spelled out so the result does not depend on the filesystem detecting that
# the bare bucket name is a directory.
fs.put("aoi.geojson", f"{bucket_name}/aoi.geojson")