In [None]:
from pathlib import Path
from itertools import cycle
from datetime import datetime

import geopandas as gpd
from rasterio.plot import show
import rasterio
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import pandas as pd
import folium
from dateutil.relativedelta import relativedelta

from src.notebook_helper import calculate_zoom_level, contrast_stretch
from src.util import tif_paths, geojson_paths, create_config
from src.scripts.extract_grid_intermediates import extract_grid_intermediates

# Define Variables

In [None]:
# Run configuration: config file path plus the start/end dates of the run.
CONFIG_FILE = Path("../run_config.yaml")
START_DATE = datetime(2024, 1, 1)
# `years=1` (plural) adds one year. The singular `year=1` would instead *set*
# the year to 1, making END_DATE == datetime(1, 1, 1).
END_DATE = START_DATE + relativedelta(years=1)
# NOTE(review): END_DATE is not passed to create_config below — confirm whether
# the config should be built with it.

# Create a config; SAVE_PATH is the root under which per-grid result folders live.
CONFIG, SAVE_PATH = create_config(CONFIG_FILE, start_date=START_DATE)

# Inspect All Grids
Plot all the grids on a world map

In [None]:
# Read every grid GeoJSON from the configured grid directory, tagging each
# frame with the stem of the file it came from.
grid_paths = geojson_paths(CONFIG.grid_dir)
grids = [gpd.read_file(p).assign(filename=p.stem) for p in grid_paths]

# Stack all grids into one GeoDataFrame with a fresh 0..n-1 index.
global_gdf = gpd.GeoDataFrame(pd.concat(grids, ignore_index=True))

# Projected copy in a single UTM CRS estimated from all grids; used below for
# centroid calculations.
# NOTE(review): one UTM zone for grids spread across the globe distorts distant
# grids; acceptable only because the centroids are used just for ordering.
local_gdf = global_gdf.to_crs(global_gdf.estimate_utm_crs())

In [None]:
# Step 1: order the grids by their projected centroid — ascending y first,
# then ascending x — and renumber the rows.
projected_centroids = local_gdf.geometry.centroid
global_gdf["centroid_x"] = projected_centroids.x
global_gdf["centroid_y"] = projected_centroids.y
global_gdf = global_gdf.sort_values(["centroid_y", "centroid_x"]).reset_index(drop=True)

# Step 2: turn matplotlib's 20-entry "tab20" colormap into hex strings and hand
# them out round-robin in the sorted order. `color_cycle` stays live: later
# cells keep drawing colors from it.
colors = plt.get_cmap("tab20", 20)
color_palette = list(map(colors, range(colors.N)))
hex_colors = [matplotlib.colors.rgb2hex(rgba[:3]) for rgba in color_palette]
color_cycle = cycle(hex_colors)

global_gdf["color"] = [next(color_cycle) for _ in global_gdf.index]

In [None]:
# Frame the map on the combined extent of all grids: center on the middle of
# the total bounding box and let calculate_zoom_level choose the zoom.
all_bounds = global_gdf.total_bounds
minx, miny, maxx, maxy = all_bounds
center_lat = (miny + maxy) / 2
center_lon = (minx + maxx) / 2
zoom_level = calculate_zoom_level(all_bounds)

base_map = folium.Map(
    location=[center_lat, center_lon],
    zoom_start=zoom_level,
    width=800,
    height=600,
)

# One GeoJson layer per grid, named after its source file (also shown in the
# popup) and filled with the color assigned in the previous cell.
for geom, grid_name, fill in zip(global_gdf["geometry"], global_gdf["filename"], global_gdf["color"]):
    folium.GeoJson(
        geom,
        name=grid_name,
        popup=folium.Popup(grid_name, parse_html=True),
        # Default argument binds this grid's color now; a plain closure would
        # only ever see the last loop value.
        style_function=lambda _feature, fill=fill: {
            "fillColor": fill,
            "color": "black",
            "weight": 1,
            "fillOpacity": 0.5,
        },
    ).add_to(base_map)

base_map

# Inspect UDM extents compared to an AOI

### !!! Set the GRID_ID !!!

In [None]:
# Grid to inspect in the cells below. It must name both a results folder
# SAVE_PATH/<GRID_ID> and a grid file CONFIG.grid_dir/<GRID_ID>.geojson.
GRID_ID = "25059125"

In [None]:
# Folder holding this grid's results.
GRID_RESULTS_DIR = SAVE_PATH / GRID_ID

# GeoJSON files found in that folder.
# NOTE(review): named UDM_PATHS but not referenced by the cells below — confirm
# it is still needed.
UDM_PATHS = geojson_paths(GRID_RESULTS_DIR)

# The grid (AOI) geometry itself, from the configured grid directory.
grid_geojson = CONFIG.grid_dir / f"{GRID_ID}.geojson"
GRID = gpd.read_file(grid_geojson)

In [None]:
# Overlay this grid's saved search geometries on top of its AOI.
geojson_file = GRID_RESULTS_DIR / "search_geometries.geojson"
gdf = gpd.read_file(geojson_file)

# Center on the AOI bounding box, the same way the world-map cell does.
# GRID's coordinates are handed to folium as lat/lon, so GRID is treated as a
# geographic CRS. Taking .centroid there makes geopandas warn. Using the bbox
# also covers every feature in GRID, not just the first one.
grid_bounds = GRID.total_bounds
grid_minx, grid_miny, grid_maxx, grid_maxy = grid_bounds
grid_center = ((grid_miny + grid_maxy) / 2, (grid_minx + grid_maxx) / 2)
zoom_level = calculate_zoom_level(grid_bounds)
# Zoom out two steps from the AOI-fitted level, presumably so search geometries
# that extend past the AOI stay in view.
m = folium.Map(location=grid_center, zoom_start=zoom_level - 2)

# AOI as a solid blue layer named "AOI".
folium.GeoJson(
    GRID,
    name="AOI",
    style_function=lambda x: {
        "fillColor": "blue",
        "color": "blue",
        "weight": 2,
        "fillOpacity": 0.9,
    },
).add_to(m)

# Each search geometry gets the next color from the shared palette cycle. The
# default argument evaluates next() once per polygon and binds the result.
for geom in gdf["geometry"]:
    folium.GeoJson(
        geom,
        style_function=lambda feature, color=next(color_cycle): {
            "fillColor": color,
            "color": color,
            "weight": 2,
            "fillOpacity": 0.05,
        },
    ).add_to(m)

folium.LayerControl().add_to(m)

m

# Inspect downloaded grid images

### Extract the intermediates for visualization

In [None]:
# Run the project's extraction step for this grid and start date so that the
# rasters visualised below are available.
# NOTE(review): presumably writes GRID_RESULTS_DIR/{udm_cropped,
# files_asset_cropped, files_udm_cropped}, which the next cells read — confirm
# against extract_grid_intermediates.
extract_grid_intermediates(config_file=CONFIG_FILE, start_date=START_DATE, grid_id=GRID_ID)

## Visualize ALL Downloaded (and reprojected) UDMs
These are all the UDMs that were considered for asset download.

In [None]:
# Show every cropped UDM in a grid of up to 4 columns. Pixels where band 1 == 1
# are white; all others are black.
# NOTE(review): band 1 == 1 is presumably the "clear" flag — confirm UDM semantics.
udm_dir = GRID_RESULTS_DIR / "udm_cropped"
udm_paths = tif_paths(udm_dir)
num_images = len(udm_paths)
if num_images == 0:
    raise FileNotFoundError(f"No UDM rasters found in {udm_dir}")

cols = min(4, num_images)
# Ceiling division: just enough rows to hold every image.
rows = (num_images + cols - 1) // cols
image_size = 5

# squeeze=False always returns a 2-D array of Axes, so ravel() works for any layout.
fig, axes = plt.subplots(
    rows, cols, figsize=(image_size * cols, image_size * rows), squeeze=False
)
axes = axes.ravel()
for ax in axes:
    ax.axis("off")

for udm_path, ax in zip(udm_paths, axes):
    with rasterio.open(udm_path) as src:
        udm = (src.read(1) == 1).astype(np.uint8)

    # Invert the 0/1 mask explicitly. Bitwise ~ on uint8 yields 254/255, and
    # pinning vmin/vmax keeps a uniform mask from collapsing the color scale.
    show(1 - udm, ax=ax, cmap="binary", vmin=0, vmax=1, title=udm_path.stem)

plt.tight_layout()

## Visualize Downloaded Images (masked to AOI)

In [None]:
# Show every downloaded asset, cropped to the AOI, as a contrast-stretched
# 3-band composite in a grid of up to 3 columns.
asset_dir = GRID_RESULTS_DIR / "files_asset_cropped"
udm_dir = GRID_RESULTS_DIR / "files_udm_cropped"
image_paths = tif_paths(asset_dir)
# The UDMs are deliberately not zipped with the images: zip() stops at the
# shorter list, which would silently drop images if the counts differ.
# `udm_paths` stays defined as a notebook variable because the coverage-count
# cell works on the same UDM list.
udm_paths = tif_paths(udm_dir)

num_images = len(image_paths)
if num_images == 0:
    raise FileNotFoundError(f"No asset rasters found in {asset_dir}")

cols = min(3, num_images)
# Ceiling division: just enough rows to hold every image.
rows = (num_images + cols - 1) // cols
image_size = 5

# squeeze=False always returns a 2-D array of Axes, so ravel() works for any layout.
fig, axes = plt.subplots(
    rows, cols, figsize=(image_size * cols, image_size * rows), squeeze=False
)
axes = axes.ravel()
for ax in axes:
    ax.axis("off")

# NOTE(review): reading band 7 requires at least 7 bands per asset — confirm
# (7, 5, 3) is the intended display composite for the downloaded product.
display_bands = (7, 5, 3)
for img_path, ax in zip(image_paths, axes):
    with rasterio.open(img_path) as src:
        img = src.read(display_bands, masked=True)

    img = contrast_stretch(img, p_high=97, p_low=2)
    show(img, ax=ax, title=img_path.stem)

plt.tight_layout()

## Visualize coverage counts

In [None]:
# Per-pixel count of how many downloaded UDMs have band 1 == 1. The full count
# is shown on the left, and a copy clipped at 5 on the right.
# Rebuild the UDM list here so this cell does not depend on hidden state from
# the previous cell.
udm_paths = tif_paths(GRID_RESULTS_DIR / "files_udm_cropped")
if not udm_paths:
    raise FileNotFoundError(f"No UDM rasters found in {GRID_RESULTS_DIR / 'files_udm_cropped'}")

cols = 2
rows = 1
image_size = 5

fig, axes = plt.subplots(
    rows, cols, figsize=(image_size * cols, image_size * rows), squeeze=False
)
axes = axes.ravel()
for ax in axes:
    ax.axis("off")

# Accumulate in uint32: a uint8 counter would silently wrap past 255 UDMs.
# Every UDM is assumed to share the first one's shape; += raises if not.
with rasterio.open(udm_paths[0]) as src:
    counter = np.zeros((src.height, src.width), dtype=np.uint32)

for file in udm_paths:
    with rasterio.open(file) as src:
        counter += (src.read(1) == 1).astype(np.uint32)

# Smallest non-zero count (0 marks pixels no UDM flagged); 0 if nothing was counted.
nonzero_counts = counter[counter > 0]
count_min = nonzero_counts.min() if nonzero_counts.size else 0
print("Counts min:", count_min, "max:", counter.max())

show(counter, cmap="inferno", title="Counts", ax=axes[0])
# Saturate the color scale at 5 to spread out the low counts.
_ = show(counter.clip(0, 5), cmap="inferno", title="Clipped Counts", ax=axes[1])