In [None]:
from wildfire.data_types import *
from wildfire.dataset_utils import *

import matplotlib.pyplot as plt
from einops import reduce
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.axes import Axes
import cartopy.crs as ccrs
import cartopy.feature as cfeature


# One PlateCarree projection is shared by both map panels; plot_ax also uses it
# as the CRS of the extent and of every image it draws.
projection = ccrs.PlateCarree()
fig, ax = plt.subplots(nrows=1, ncols=2, subplot_kw=dict(projection=projection))
# Reorder config.aoi into the [x0, x1, y0, y1] order that GeoAxes.set_extent
# expects (assumes config.aoi is [min_lon, min_lat, max_lon, max_lat] — TODO confirm).
aoi2 = [config.aoi[i] for i in (0, 2, 1, 3)]

full = config.patch_size
third = config.patch_size_km
# All-zero int16 patch of shape (full, full).
default_val = np.zeros((full, full), dtype=np.int16)
# Grid cell edge length, used when building each item's bounding box.
cs = config.cell_size
@dataclass
class Item:
    """One raster tile to draw on a map: pixel values plus where they go."""

    # Array passed to ``imshow`` as the image data.
    mask: np.ndarray
    # Geographic extent as [x0, x1, y0, y1] (lon min/max, lat min/max),
    # passed to ``imshow(extent=...)``.
    bbox: list[float]

def get_items(h5_path: str):
    """Yield one :class:`Item` per grid cell that recorded any fire pixels.

    Every dataset under the ``num_fire_pixels_by_day`` group of the HDF5 file
    is summed into a single 300x300 int32 count grid. The grid is indexed
    ``[y, x]``, with dataset rows mapped to y and columns to x. The file is
    closed once the grid is built. Then every cell produced by
    ``H5Grid.get_cells()`` with a non-zero total is yielded.

    Args:
        h5_path: Path of the HDF5 file to read.

    Yields:
        Item: ``mask`` is a (1, 1) int32 array holding the cell's summed
        fire-pixel count. ``bbox`` is ``[min_lon, min_lon + cs, min_lat,
        min_lat + cs]``, the ``[x0, x1, y0, y1]`` order used by
        ``imshow(extent=...)``.
    """
    # NOTE(review): assumes every per-day dataset is 2-D and fits in 300x300,
    # and that H5Grid cell indices stay below 300 — TODO confirm.
    by_xy = np.zeros((300, 300), dtype=np.int32)

    # Read everything up front inside a context manager so the file handle is
    # released even if the caller stops iterating early.
    with h5py.File(h5_path, "r") as h5:
        by_day = h5["num_fire_pixels_by_day"]
        for date in by_day:
            num_fire = by_day[date][()]  # materialize the dataset as a numpy array
            yy, xx = num_fire.shape
            by_xy[:yy, :xx] += num_fire

    for x, y, min_lon, min_lat in H5Grid.get_cells():
        if by_xy[y, x] == 0:
            continue
        bbox = [min_lon, min_lon + cs, min_lat, min_lat + cs]
        # Slice (not index) so imshow receives a 2-D (1, 1) array.
        yield Item(by_xy[y:y + 1, x:x + 1], bbox)

def plot_ax(
    ax: Axes,
    items: list[Item],
    title: str,
    *,
    cmap: str = "YlOrRd",
    vmin: float = 0,
    vmax: float = 5000,
    alpha: float = 0.8,
):
    """Draw base-map features and every item's raster onto one map axes.

    The view is clipped to the module-level ``aoi2`` extent. Coastlines,
    dotted borders and light-gray land are added, then each item's ``mask``
    is drawn with ``imshow`` at its ``bbox`` extent. The axes must support
    ``set_extent`` / ``add_feature`` (i.e. a cartopy GeoAxes).

    Args:
        ax: Target map axes.
        items: Rasters to draw, in drawing order.
        title: Axes title.
        cmap: Colormap for every image.
        vmin: Lower bound of the shared color scale.
        vmax: Upper bound of the shared color scale.
        alpha: Opacity of every image.

    Returns:
        The ``AxesImage`` of the last item drawn (usable for a colorbar), or
        ``None`` when *items* is empty.
    """
    ax.set_extent(aoi2, crs=projection)
    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.BORDERS, linestyle=":")
    ax.add_feature(cfeature.LAND, facecolor="lightgray")
    im = None  # stays None when there is nothing to draw (previously UnboundLocalError)
    for item in items:
        # Every image uses the same vmin/vmax, so they share one color scale.
        im = ax.imshow(
            item.mask, cmap=cmap, vmin=vmin, vmax=vmax,
            extent=item.bbox, transform=projection,
            alpha=alpha,
        )
    ax.set_title(title)
    return im

# Aggregate fire-pixel counts for the training and test HDF5 files.
items_train = list(get_items(h5_get_path("viirs", False)))
items_test = list(get_items(h5_get_path("viirs", True)))

im_train = plot_ax(ax[0], items_train, "2019-2020")
im_test = plot_ax(ax[1], items_test, "2023-2024")
# Both panels use the same color scale, so either image can drive the shared
# colorbar; prefer the right-hand panel, as before.
im = im_test if im_test is not None else im_train

# Leave the right 12% of the figure free for the colorbar axes.
fig.tight_layout(rect=[0, 0, 0.88, 1])
cax = fig.add_axes([0.88, 0.3, 0.02, 0.4])
path = os.path.join(config.root_path, "figures", "fires_2019-2024.pdf")
print(path)
if im is not None:
    plt.colorbar(im, cax=cax, label="Number of fire pixels")
# savefig does not create missing directories.
os.makedirs(os.path.dirname(path), exist_ok=True)
fig.savefig(path, dpi=150)



