### Visualization of selected test samples
* Side-by-side comparison of the VIIRS and MODIS models (id=0 and id=2)


In [None]:
# Visual check: plot selected test predictions of the median MODIS and VIIRS runs side by side
import h5py
from matplotlib.axes import Axes
from wildfire.data_types import *
from wildfire.data_utils import *
from wildfire.training_utils import *

import matplotlib.pyplot as plt

def upsample(ndarray: np.ndarray, new_size, mode="bilinear"):
    """Resize the last two axes of an array with ``torch.nn.functional.interpolate``.

    Inputs with fewer than 4 dimensions are temporarily given leading
    singleton axes so ``interpolate`` sees a 4-D (N, C, H, W) tensor. Those
    axes are stripped again afterwards, so the result has the input's rank.

    Args:
        ndarray: Array to resize; its last two axes are the spatial ones.
        new_size: Target spatial size, forwarded as ``size=``.
        mode: Interpolation mode forwarded as ``mode=`` (e.g. "bilinear",
            "nearest").

    Returns:
        NumPy array with the same rank as ``ndarray`` and resized last two axes.
    """
    rank_in = len(ndarray.shape)
    n_pad = max(0, 4 - rank_in)
    # Prepend singleton axes (a reshape to (1, ..., 1, *shape)).
    as_4d = ndarray.reshape((1,) * n_pad + tuple(ndarray.shape))
    resized = torch.nn.functional.interpolate(
        torch.tensor(as_4d), size=new_size, mode=mode
    ).numpy()
    # The padded axes are still size 1, so slicing them off restores the rank.
    return resized.reshape(resized.shape[n_pad:])

def find_median_run_dir(run_dir: str, metric_key: str = "test/0.5/iou"):
    """Pick the repeated run whose test metric sits at the median position.

    Repeats of one experiment are looked up as ``{run_dir}_*/test_results.json``.

    Args:
        run_dir: Common path prefix of the repeat directories.
        metric_key: Key inside ``test_results.json`` to rank repeats by.

    Returns:
        ``(directory, value)`` of the repeat at index ``len // 2`` of the
        ascending-sorted metric values (the upper median for an even count).

    Raises:
        FileNotFoundError: If no matching ``test_results.json`` exists.
    """
    import os

    paths = sorted(glob(f"{run_dir}_*/test_results.json"))
    if not paths:
        raise FileNotFoundError(f"No {run_dir}_*/test_results.json files found")
    values = [json_load(p)[metric_key] for p in paths]
    median_idx = int(np.argsort(values)[len(values) // 2])
    # Return the directory that actually produced the value. Do not rebuild it
    # as f"{run_dir}_{median_idx}": glob sorting is lexicographic ("_10" sorts
    # before "_2"), and suffixes are not guaranteed to be 0..N-1.
    return os.path.dirname(paths[median_idx]), values[median_idx]


class DS:
    """Viewer for one training run's saved test-set outputs.

    Opens the run's ``test_results.json``, ``config.json`` and ``tensors.h5``,
    plus the satellite HDF5 file whose path is stored in the ``sat_path``
    attribute of the ``test`` group. Call ``close()`` (or use the instance as a
    context manager) to release both HDF5 files.
    """

    def __init__(self, run_dir: str):
        self.metrics = json_load(run_dir + "/test_results.json")
        self.config = json_load(run_dir + "/config.json")
        print(self.metrics["test/0.5/iou"])
        # Per-sample test tensors saved by the run.
        self.ds = h5py.File(run_dir + "/tensors.h5", "r")
        sat_path = self.ds["test"].attrs["sat_path"]
        self.sat = h5py.File(sat_path, "r")
        # Date keys; a date index t selects self.sat_dates[t].
        self.sat_dates = list(self.sat["num_fire_pixels_by_day"].keys())

    def close(self):
        """Close both HDF5 files opened in ``__init__``."""
        self.ds.close()
        self.sat.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    @staticmethod
    def _normalize(rgb):
        """Min-max scale ``rgb`` to [0, 1]; a constant image becomes all zeros."""
        lo, hi = rgb.min(), rgb.max()
        if hi > lo:
            return (rgb - lo) / (hi - lo)
        # Uniform image: avoid 0/0, which would produce NaNs.
        return np.zeros_like(rgb)

    def load_rgb(self, x, y, t):
        """Return a day-time RGB composite for grid cell (x, y) at date index t.

        Non-MODIS runs use channels 2, 1, 0 of the cell's ``hi`` dataset.
        MODIS runs stack band 0 of ``modis_250`` with bands 1 and 0 of
        ``modis_500``; the latter are bilinearly upsampled to the
        ``modis_250`` spatial size. Pixels equal to ``config.uint16_no_data``
        are replaced by the array minimum, and the result is min-max scaled
        to [0, 1].

        Returns:
            (H, W, 3) float array, or None when ``t`` is out of range or the
            cell / required datasets are missing.
        """
        if not 0 <= t < len(self.sat_dates):
            return None
        path = ["cells", H5Grid.get_cell_path(x, y), self.sat_dates[t], "day"]
        cell = h5_get_nested(self.sat, path)
        if cell is None:
            return None
        if not self.config["is_modis"]:
            # (C, H, W) -> pick 3 channels -> (H, W, 3)
            rgb = cell["hi"][...][[2, 1, 0]].transpose(1, 2, 0)
            rgb[rgb == config.uint16_no_data] = rgb.min()
            rgb = rgb.astype(np.float32)
        else:
            if "modis_250" not in cell or "modis_500" not in cell:
                return None
            r = cell["modis_250"][...][0:1].astype(np.float32)
            gb = cell["modis_500"][...][[1, 0]].astype(np.float32)
            r[r == config.uint16_no_data] = r.min()
            gb[gb == config.uint16_no_data] = gb.min()
            # Bring the two 500-product bands onto the 250-product grid.
            gb = upsample(gb, r.shape[1:])
            rgb = np.concatenate([r, gb], axis=0).transpose(1, 2, 0)
        return self._normalize(rgb)

    def load_stats(self, ax: Axes, sample, size):
        """Overlay per-pixel error classes for ``sample`` onto ``ax``.

        Ground truth is ``next_fire_cls > 7`` and the prediction is
        ``pred_prob > 0.5``. Pixels outside ``loss_mask`` are zeroed. Each
        class mask is then nearest-upsampled to ``size`` and drawn as an
        opaque overlay: false negatives blue-ish, false positives red-ish,
        true positives green-ish.
        """
        next_fire_cls = sample["next_fire_cls"][0]
        # Cast so ``~`` below is a logical NOT even if the mask is stored as ints.
        loss_mask = sample["loss_mask"][0].astype(bool)
        pred_prob = sample["pred_prob"][0]
        gt_binary = next_fire_cls > 7
        pred_binary = np.where(pred_prob > 0.5, 1.0, 0.0)
        fn = (1 - pred_binary) * gt_binary
        fp = pred_binary * (1 - gt_binary)
        tp = pred_binary * gt_binary

        stats = [fn, fp, tp]
        colors = [(0.3, 0.3, 1), (1, 0.3, 0.3), (0.3, 1, 0.3)]
        for stat, color in zip(stats, colors):
            stat[~loss_mask] = 0
            stat = upsample(stat, size, "nearest")
            # RGBA overlay: transparent everywhere except where this class is set.
            rgba = np.zeros((*stat.shape, 4))
            rgba[stat > 0] = (*color, 1.0)  # fully opaque colour
            ax.imshow(rgba)

    def load_sample(self, ax: Axes, idx: int):
        """Draw test sample ``idx``: RGB image at date index t + 1 plus error overlay.

        The sample is read from the first key listed under the ``test`` group.
        If no RGB image is available, "No rgb" is printed and nothing is drawn.

        Returns:
            The object ``h5_get_nested`` returned for the sample.
        """
        epoch_key = list(self.ds["test"])[0]
        sample = h5_get_nested(self.ds, ["test", epoch_key, str(idx)])
        t, y, x = sample["tyx"][...]
        rgb = self.load_rgb(x, y, t + 1)
        if rgb is None:
            print("No rgb")
            return sample
        ax.imshow(rgb)
        self.load_stats(ax, sample, rgb.shape[:2])
        return sample


# For each sensor, use the repeat whose test IoU is the median across repeats.
modis_dir = "/proj/cvl/users/x_juska/data/wildfire/runs/97"
viirs_dir = "/proj/cvl/users/x_juska/data/wildfire/runs/99"
modis_dir, modis_median_iou = find_median_run_dir(modis_dir)
viirs_dir, viirs_median_iou = find_median_run_dir(viirs_dir)
modis_run_id = os.path.basename(modis_dir)
viirs_run_id = os.path.basename(viirs_dir)
modis = DS(modis_dir)
viirs = DS(viirs_dir)

from wildfire.db import *
init_db(config.db_path)


samples = select(InfoSample, {"idx": 0, "iou": 0, "run_id": 0, "tp": 0, "next_num_fire": 0, "cur_num_fire": 0})
# NOTE: int() accepts "_" between digits (PEP 515), so a basename like "97_3"
# becomes 973; presumably this matches run_id in the DB — confirm.
modis_id = int(modis_run_id)
viirs_id = int(viirs_run_id)
samples_modis = [s for s in samples if s["run_id"] == modis_id]
samples_viirs = [s for s in samples if s["run_id"] == viirs_id]

# Group samples by idx
idx_to_samples = defaultdict(list)
for s in samples_modis + samples_viirs:
    idx_to_samples[s["idx"]].append(s)

# Score every idx that has exactly one MODIS and one VIIRS sample.
median = []
best = []
worst = []
for idx, group in idx_to_samples.items():
    # Split by run_id so the pair never depends on list order, and never
    # pairs two samples from the same run.
    pair_modis = [s for s in group if s["run_id"] == modis_id]
    pair_viirs = [s for s in group if s["run_id"] == viirs_id]
    if len(pair_modis) != 1 or len(pair_viirs) != 1:
        continue
    modis_sample, viirs_sample = pair_modis[0], pair_viirs[0]
    modis_iou = modis_sample["iou"]
    viirs_iou = viirs_sample["iou"]
    modis_iou_diff = abs(modis_iou - modis_median_iou)
    viirs_iou_diff = abs(viirs_iou - viirs_median_iou)
    shared = (idx, modis_iou, viirs_iou)
    # Worst-case distance of the two runs from their median IoU.
    median.append((max(modis_iou_diff, viirs_iou_diff), *shared))
    # Summed IoU; zeroed when the next day has too few fire pixels.
    enough = viirs_sample["next_num_fire"] > 20
    best.append(((modis_iou + viirs_iou) * enough, *shared))
    # "Worst" candidates must also have current fire and max IoU > 1.0
    # (IoU presumably in percent — confirm). Others get +300, so they sort
    # after eligible ones as long as IoU sums stay below that.
    enough = enough and viirs_sample["cur_num_fire"] > 5
    enough = enough and max(modis_iou, viirs_iou) > 1.0
    w = 300 if not enough else 0
    worst.append(((modis_iou + viirs_iou) + w, *shared))


median.sort()
best.sort(reverse=True)
worst.sort()

# Hand-picked sample indices to export.
idxs = {
    3205, 905,
    1096, 1575,
    457, 338,
}

print(f"Found {len(median)} matching samples")
cnt = 0
fig_dir = os.path.join(config.root_path, "figures", "samples")
os.makedirs(fig_dir, exist_ok=True)


for metric, idx, modis_iou, viirs_iou in best:
    if idx not in idxs:
        continue
    fig, axs = plt.subplots(1, 2)
    viirs.load_sample(axs[0], idx)
    print(idx)
    axs[0].set_title(f"VNP14 IoU={viirs_iou:.1f}")
    modis.load_sample(axs[1], idx)
    axs[1].set_title(f"MOD14 IoU={modis_iou:.1f}")
    fig.savefig(os.path.join(fig_dir, f"{cnt}.pdf"), dpi=150)
    plt.show()
    # Release the figure so figures don't pile up across iterations.
    plt.close(fig)
    cnt += 1
    if cnt > 20:
        break