In [None]:
from pystac_client import Client
from odc.stac import load

from sklearn.ensemble import AdaBoostRegressor

from odc.algo import mask_cleanup

import geopandas as gpd
import pandas as pd
import numpy as np
import xarray as xr

import folium

In [None]:
# GeoPackage holding the depth points used later for training and validation.
in_points = "tuvalu_20.gpkg"

# Read every feature into a GeoDataFrame; later cells rely on its "depth" column.
gdf = gpd.read_file(filename=in_points)

In [None]:
# Keep only the points whose "depth" value is greater than -50.
# NOTE(review): depths appear to be stored as negative numbers — confirm
# against the source data.
less_than_50 = gdf[gdf["depth"] > -50]

# Fixed seeds make both draws repeatable under Restart Kernel -> Run All.
sample_size = 25000
sample_seed = 42

# Training sample.
sample = less_than_50.sample(sample_size, random_state=sample_seed)

# Validation sample. Prefer points that were NOT used for training so the
# later error check is not flattered by points the model has already seen.
# If too few unused points remain, fall back to the full pool so this cell
# still runs whenever a single draw of `sample_size` is possible.
unused_points = less_than_50.drop(index=sample.index)
validation_pool = unused_points if len(unused_points) >= sample_size else less_than_50
second_sample = validation_pool.sample(sample_size, random_state=sample_seed + 1)

sample.explore(column="depth", cmap="Blues_r")

In [None]:
# STAC API endpoint that is searched for imagery.
catalog = "https://earth-search.aws.element84.com/v1"
client = Client.open(catalog)

# Extent of all points in lon/lat, padded by `buffer` degrees on every side.
buffer = 0.01
west, south, east, north = gdf.to_crs("epsg:4326").total_bounds
bbox = [west - buffer, south - buffer, east + buffer, north + buffer]

# Scenes over the padded extent in the requested date range with less than
# 30% reported cloud cover.
search = client.search(
    collections=["sentinel-2-c1-l2a"],
    bbox=bbox,
    datetime="2024-01/2024-09",
    query={"eo:cloud_cover": {"lt": 30}},
)
items = search.item_collection()

print(f"Found {len(items)} items")

In [None]:
# Bands to read: the reflectance bands plus the scene classification layer
# ("scl"), which is used for masking in the next step.
measurements = [
    "red",
    "green",
    "blue",
    "nir",
    "nir09",
    "swir16",
    "swir22",
    "coastal",
    "rededge1",
    "rededge2",
    "rededge3",
    "scl",
]

# chunks={} is passed so the load is lazy (dask-backed); the data is only
# materialised by the later .compute() call.
data = load(
    items,
    measurements=measurements,
    bbox=bbox,
    groupby="solar_day",
    chunks={},
)

# Scene classification (SCL) classes to discard:
#   0 = no data, 1 = saturated or defective, 3 = cloud shadow,
#   8 = cloud (medium probability), 9 = cloud (high probability).
# Class 0 is included so that pixels with no observation are also removed
# wherever `cloud_mask` is applied, not only where the band values are zero.
mask_flags = [0, 1, 3, 8, 9]
cloud_mask = ~data.scl.isin(mask_flags)
masked = data.where(cloud_mask).drop_vars("scl")

# Zero values are treated as missing; the rest are multiplied by 0.0001 and
# clipped to the range [0, 1].
# NOTE(review): confirm whether this collection also needs a reflectance
# offset applied in addition to the 0.0001 scale factor.
scaled = (masked.where(masked != 0) * 0.0001).clip(0, 1)

# Materialise the lazy arrays in memory.
scaled = scaled.compute()
scaled

In [None]:
def _normalised_difference(first, second):
    """Return (first - second) / (first + second), computed element-wise."""
    return (first - second) / (first + second)


# Spectral indices stored alongside the bands: NDVI, NDWI and MNDWI.
scaled["ndvi"] = _normalised_difference(scaled["nir"], scaled["red"])
scaled["ndwi"] = _normalised_difference(scaled["green"], scaled["nir"])
scaled["mndwi"] = _normalised_difference(scaled["green"], scaled["swir16"])

# Collapse the time axis into a single per-pixel median composite.
median = scaled.median(dim="time")

In [None]:
# One red/green/blue panel per time step, two panels per row.
rgb_by_time = scaled[["red", "green", "blue"]].to_array()
rgb_by_time.plot.imshow(col="time", col_wrap=2, vmin=0, vmax=0.2)

In [None]:
# Red/green/blue view of the median composite.
rgb_median = median[["red", "green", "blue"]].to_array()
rgb_median.plot.imshow(size=6, vmin=0, vmax=0.2)

In [None]:
# Put the training points in the same CRS as the imagery.
reprojected = sample.to_crs(median.odc.crs)

# xarray view of the sample carrying the projected x/y of every point.
sample_with_xy = sample.assign(x=reprojected.geometry.x, y=reprojected.geometry.y)
pts_da = sample_with_xy.to_xarray()

# Look up the nearest median-composite pixel for every point, then return
# to pandas so the values can be joined back onto the sample.
nearest_pixels = median.sel(pts_da[["x", "y"]], method="nearest")
pt_values_i = nearest_pixels.squeeze().compute().to_pandas()

In [None]:
# Columns that are not model inputs: pixel coordinates, CRS bookkeeping,
# the point geometry and a leftover index column.
non_feature_columns = ["y", "x", "spatial_ref", "geometry", "index_left"]

# Join the sampled pixel values onto the points (aligned on the index), drop
# the non-feature columns, then discard any row that still contains a NaN.
training_array = (
    pd.concat([sample, pt_values_i], axis=1)
    .drop(columns=non_feature_columns)
    .dropna()
)

training_array.head()

In [None]:
# Split the table into the regression target and the feature matrix.
# "depth" is selected by name rather than by position, so the split stays
# correct even if the column order of the input table changes. The remaining
# columns keep their existing order.
values = training_array["depth"].to_numpy()
training_data = training_array.drop(columns="depth").to_numpy()

In [None]:
# AdaBoost is a randomised estimator; fixing random_state makes the fitted
# model (and every prediction derived from it) repeatable between runs.
regr = AdaBoostRegressor(random_state=42)

# fit() returns the estimator itself, so `model` and `regr` are the same object.
model = regr.fit(training_data, values)

In [None]:
# Predict depth for every time slice, one 2-D (y, x) grid per slice.
predictions = []
grid_shape = (len(masked.y), len(masked.x))

for time_index in range(len(scaled.time)):
    # Missing values are replaced with -9999 so that every pixel can be
    # handed to the model.
    filled_slice = scaled.isel(time=time_index).fillna(-9999)

    # Flatten to one row per pixel and one column per variable.
    feature_table = filled_slice.to_array().stack(dims=["y", "x"]).transpose()

    # Fold the flat prediction vector back into the (y, x) grid.
    depth_grid = model.predict(feature_table).reshape(grid_shape)
    predictions.append(
        xr.DataArray(
            depth_grid, dims=["y", "x"], coords={"x": masked.x, "y": masked.y}
        )
    )

print(f"Completed predicting {len(scaled.time)} time slices")

In [None]:
# Stack the per-slice grids along the original time coordinate, wrap the
# result in a dataset with a single "depth" variable, and blank out every
# pixel rejected by the mask.
depth_stack = xr.concat(predictions, dim=scaled.time)
predicted = depth_stack.to_dataset(name="depth").where(cloud_mask)

In [None]:
# One predicted-depth panel per time step, two panels per row.
depth_by_time = predicted["depth"]
depth_by_time.plot.imshow(col="time", col_wrap=2, cmap="viridis")

In [None]:
# Per-pixel median of the predicted depth over time (mean is the alternative).
average = predicted["depth"].median(dim="time")

In [None]:
# Score of the model on the same data it was fitted on; closer to 1 is better.
# Because this is the training data, treat the number as an optimistic figure.
model.score(X=training_data, y=values)

In [None]:
# Centre of the search bounding box as (lat, lon), the order folium expects.
coords = (bbox[1] + bbox[3]) / 2, (bbox[0] + bbox[2]) / 2

# `layer_control` is not a folium.Map argument, so it has been removed; the
# layer switcher is provided by the explicit folium.LayerControl() below.
m = folium.Map(location=coords, zoom_start=12)

# Base layer: RGB rendering of the median composite.
visual = median.odc.to_rgba(["red", "green", "blue"], vmin=0, vmax=0.3)
visual.odc.add_to(m, name="RGB")

# Overlays: predicted depth for the first time slice and the median depth.
predicted.isel(time=0).depth.odc.add_to(m, name="Depth", cmap="Blues_r")
average.odc.add_to(m, name="Average Depth", cmap="Blues_r")

# Layer control
folium.LayerControl().add_to(m)

m

In [None]:
# Reproject the second sample into the imagery CRS and attach projected x/y.
second_sample_reprojected = second_sample.to_crs(median.odc.crs)
projected_geometry = second_sample_reprojected.geometry
pts_da = second_sample.assign(
    x=projected_geometry.x, y=projected_geometry.y
).to_xarray()

# Nearest-pixel lookup of the median predicted depth for every point,
# returned as a pandas Series named "depth_computed".
nearest_depths = average.sel(pts_da[["x", "y"]], method="nearest")
compare_depths = (
    nearest_depths.squeeze().compute().to_pandas().rename("depth_computed")
)

# Put the known and predicted depths side by side (aligned on the index).
appended = pd.concat([second_sample, compare_depths], axis=1)

# Error = known depth minus predicted depth.
appended["error"] = appended["depth"] - appended["depth_computed"]

appended["error"].hist(bins=20)

In [None]:
# A pixel counts as water when the sum of its two water indices is positive.
water = (median["mndwi"] + median["ndwi"]) > 0

# Clean the mask with an "opening" filter of size 5.
cleanup_steps = [("opening", 5)]
water_filtered = mask_cleanup(water, cleanup_steps)
water_filtered.plot.imshow(size=6)

In [None]:
# Keep the median depth only where the cleaned water mask is True.
final = average.where(cond=water_filtered)
final.plot.imshow()

In [None]:
# Save the masked depth raster as a Cloud Optimised GeoTIFF, replacing any
# existing file with the same name.
output_path = "tuvalu_depth_test2.tif"
final.odc.write_cog(output_path, overwrite=True)