In [None]:
########## IMPORTS ##########

# common
import rasterio
import numpy as np
import matplotlib.pyplot as plt

# test & display
import itertools
import traceback
from pprint import pprint
import importlib # allows updating the functions by rerunning this import cell (otherwise the functions remain stored and frozen in the cache)

# Imports from src
import os
import sys

PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))  # add the project root directory to the PYTHONPATH
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

import src.main_dtod_test as main_test
importlib.reload(main_test)
from src.ground_truth import (
    shapefile_to_mask,
    confusion_from_masks,
    metrics_from_masks,
)       

# NB: itertools, traceback, pprint, importlib, os and sys are already included in the standard library 

In [None]:
########## DATA ##########


# Input rasters (path_img1 / path_img2) and ground-truth shapefile used by every cell below.
path_img1 = r"C:\Users\guigu\Documents\pro_asus\vigisar\data\preprocessed\zta5\zta5_pre_sigma.tif"
path_img2 = r"C:\Users\guigu\Documents\pro_asus\vigisar\data\preprocessed\zta5\zta5_post_sigma.tif"
shp_path = r"C:\Users\guigu\Documents\pro_asus\vigisar\data\preprocessed\zta5\zta5_grd_truth\zta5_grd_truth.shp"
# NOTE(review): out_path is not passed to any call in this cell and points to a
# different user profile (gbonlieu) than the three input paths (guigu) - confirm it is still wanted.
out_path = r"C:\Users\gbonlieu\Documents\codepythonoutil\outil_detection_changement\ouputs\output_vrac\t3bis_out\main_out.tif"
# The first image is used as the reference grid for rasterizing the shapefile.
ref_raster_path = path_img1

# Rasterize the ground-truth shapefile onto the reference raster.
truth = shapefile_to_mask(shp_path=shp_path, ref_raster_path=ref_raster_path) # the out_path=out_path argument is left disabled
# Pixel counts; assumes the mask only contains 0 and 1 so that sum() counts the white pixels - TODO confirm.
nb_whites, nb_blacks = truth.sum(), truth.size - truth.sum()
print (f"Number of whites: {nb_whites}\nNumber of blacks: {nb_blacks}")
# Display the ground-truth mask.
fig, ax = plt.subplots(figsize=(6, 6))
im = ax.imshow(truth,
               cmap="gray",
               interpolation="nearest",  # no smoothing
               vmin=0, vmax=1)           # keep the 0/1 values clearly separated
ax.set_title("Truth")
plt.show()

# fig.colorbar(im, ax=ax, label="Pixel Values")


In [None]:
########## SINGLE TEST & DISPLAY ##########

"""
Small test script to run main_dtod_test() on two rasters and print the result.
"""


def test_main_display(path_img1: str, path_img2: str, shp_path: str,
                      n: int, k: float = 1.0, closings: bool = False,
                      p: int = 27, d: float = 0.5, a: int = 3000, 
                      out_path: str | None = None):
    """
    Run main_dtod_test() once, score the result against the ground truth and
    display the resulting mask annotated with the parameters and the metrics.

    Parameters
    ----------
    path_img1, path_img2 : str
        Paths of the two rasters given to main_dtod_test(). path_img1 is also
        the reference raster used to rasterize the shapefile.
    shp_path : str
        Path of the ground-truth shapefile.
    n, k, closings, p, d, a, out_path
        Forwarded unchanged to main_dtod_test().

    Returns
    -------
    None
        The function only displays a figure and prints the raster profile.
    """

    # --- Compute predicted mask ---
    # main_dtod_test() returns (profile, mask); both are used below.
    profile, filt = main_test.main_dtod_test(path_img1, path_img2,
                          n, k=k, closings=closings,
                          p=p, d=d, a=a, out_path=out_path)

    # --- Load and rasterize ground-truth on the grid of the first image ---
    mask_ref = shapefile_to_mask(shp_path, ref_raster_path=path_img1)

    # --- Compute performance metrics ---
    mets = metrics_from_masks(mask_ref, filt)

    # --- Figure: predicted mask ---
    fig, ax = plt.subplots(figsize=(10, 7))
    ax.imshow(filt, cmap="gray")

    # --- Title: parameters ---
    ax.set_title(
        f"Result mask\n\n"
        f"n={n}, k={k}, p={p}, d={d}, a={a}, closings={closings}",
        fontsize=11
    )

    # --- Text with metrics (including Kappa), centred just below the axes ---
    txt = (
        f"MCC={mets['mcc']:.3f}   " 
        f"Kappa={mets['kappa']:.3f}   " 
        f"F1={mets['F1']:.3f}   " 
        f"P={mets['precision']:.3f}   " 
        f"R={mets['recall']:.3f}   " 
        f"Acc={mets['accuracy']:.3f}"
    )

    ax.text(0.5, -0.08, txt, transform=ax.transAxes, ha="center", va="top", fontsize=10)

    plt.show()
    print("Full profile :")
    pprint(profile)


# path_img1 = r"C:\Users\gbonlieu\Documents\codepythonoutil\outil_detection_changement\data\raw\testnotebooks\t5\t5_pre_8bands.tif"
# path_img2 = r"C:\Users\gbonlieu\Documents\codepythonoutil\outil_detection_changement\data\raw\testnotebooks\t5\t5_post_8bands.tif"
# shp_path = r"C:\Users\gbonlieu\Documents\codepythonoutil\outil_detection_changement\data\raw\testnotebooks\t5\t5_grd_truth\t5_grd_truth_0709.shp"
# out_path = r"C:\Users\gbonlieu\Documents\codepythonoutil\outil_detection_changement\ouputs\output_vrac\t3bis_out\main_out.tif"

# Single run on the rasters and shapefile defined in the DATA cell; out_path keeps its default (None).
test_main_display(path_img1, path_img2, shp_path,
                  n=1, k=1.1, closings=False, p=30, d=0.5, a=3000)


In [None]:
########## FILTER PARAMETERS 2D TEST & DISPLAY  ##########



def test_parametres_2D(
    path_img1: str, path_img2: str,
    idx1: int,
    idx2: int,
    list1: list,
    list2: list,
    closings: bool = False,
    n: int = 1,
    k: float = 1.15,
    p: int = 30,
    d: float = 0.5,
    a: int = 3000,
    shp_path: str | None = None,
):
    """
    Explores all possible combinations of TWO parameters among (n, k, p, d, a).

    Variable parameters are selected using their indices:
    - 1 -> n (tile size / number of tiles)
    - 2 -> k (threshold multiplier)
    - 3 -> p (filter parameter p)
    - 4 -> d (filter parameter d)
    - 5 -> a (minimum area)

    idx1, idx2 : integers in [1,5] and must be different
    list1, list2 : lists of values tested for these two parameters

    Other non-explored parameters:
    - default values (n=… k=1.15, p=30, d=0.5, a=4000)
    - unless explicitly passed in arguments
    """

    # ---- Basic checks ----
    if idx1 == idx2:
        raise ValueError("idx1 and idx2 must be different (from 1 to 5).")

    for idx in (idx1, idx2):
        if idx < 1 or idx > 5:
            raise ValueError("Indices must be between 1 and 5.")

    # ---- Mapping index → parameter name ----
    index_to_name = {1: "n", 2: "k", 3: "p", 4: "d", 5: "a"}

    name1 = index_to_name[idx1]
    name2 = index_to_name[idx2]

    # ---- Base parameter values ----
    base_params = {"n": n, "k": k, "p": p, "d": d, "a": a}

    # ---- Ground truth mask (optional) ----
    mask_ref = None
    if shp_path is not None:
        mask_ref = shapefile_to_mask(shp_path=shp_path, ref_raster_path=path_img1)

    # ---- Read raster shape for figure geometry ----
    with rasterio.open(path_img1) as src:
        W = src.width
        H = src.height

    # ---- Build combination grid (param1, param2) ----
    combos = list(itertools.product(list1, list2))
    n_combos = len(combos)

    # ---- "Square-like" grid ----
    ncols = len(list2)
    nrows = len(list1)

    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * (W / 100), nrows * (H / 100)))
    axes = np.array(axes).reshape(nrows, ncols)  # to avoid errors when there is only one row or column, this forces axes to always be a 2D array with the correct shape

    # ---- Main loop over combinations ----
    for idx, (val1, val2) in enumerate(combos):  # enumerate() lets you iterate over a list while keeping the current index
        row = idx // ncols
        col = idx % ncols
        ax = axes[row, col]

        # start with base parameters
        params = base_params.copy()
        params[name1] = val1
        params[name2] = val2

        try:
            # call main algorithm
            result = main_test.main_dtod_test(path_img1, path_img2, n=params["n"], k=params["k"], closings=closings, p=params["p"], d=params["d"], a=params["a"])[1]

            im = ax.imshow(result, cmap="gray")

            # title with the two explored parameters    
            ax.set_title(f"{name1}={val1:.2f}, {name2}={val2:.2f}", fontsize=15)

            # metrics if ground truth provided
            if mask_ref is not None:
                # metrics_from_masks returns the 'kappa' key
                mets = metrics_from_masks(mask_ref, result)
                txt = (
                    f"MCC={mets['mcc']:.3f} "
                    f"Kappa={mets['kappa']:.3f}\n" # Added Kappa here
                    f"F1={mets['F1']:.3f} "
                    f"P={mets['precision']:.3f} "
                    f"R={mets['recall']:.3f} "
                    f"Acc={mets['accuracy']:.3f}"
                )
                ax.text(0.5, -0.05, txt, transform=ax.transAxes, ha="center", va="top", fontsize=11)

        except Exception as e:
            # display error message for this combination
            print(f"\n=== ERROR DURING PARAMETER COMBINATION {val1}, {val2} ===")
            print("Message :", e)
            import traceback
            traceback.print_exc() 

    # ---- Hide unused axes if grid not full ----
    for idx in range(n_combos, nrows * ncols):
        row = idx // ncols
        col = idx % ncols
        axes[row, col].axis("off")

    plt.tight_layout()
    plt.show()

# path_img1 = r"C:\Users\gbonlieu\Documents\codepythonoutil\outil_detection_changement\data\raw\testnotebooks\t3bis\t3bis_pre_varVV.tif"
# path_img2 = r"C:\Users\gbonlieu\Documents\codepythonoutil\outil_detection_changement\data\raw\testnotebooks\t3bis\t3bis_0208_varVV.tif"
# shp_path = r"C:\Users\gbonlieu\Documents\codepythonoutil\outil_detection_changement\data\raw\testnotebooks\t3bis\t3bis_grd_truth\t3bis_grdtruth_0204shp\t3bis_grdtruth_0204.shp"
# Candidate values for each parameter. Only list_k and list_p are passed to the
# call below; list_n, list_d and list_a are defined but not used by it.
list_n = [int(x) for x in np.linspace(1, 3, 3)]
list_k = [x for x in np.linspace(1, 1.4, 5)]
list_p = [round(x) for x in np.linspace(25, 35, 10)] 
list_d = [x for x in np.linspace(0.5, 0.5, 1)]
list_a = [round(x) for x in np.linspace(3000, 4000, 5)]


# Explore k (index 2, grid rows) against p (index 3, grid columns).
# The k=1.1 given here is overridden by the values of list_k, since k is an explored parameter.
test_parametres_2D(
    path_img1,
    path_img2,
    idx1=2,        
    idx2=3,        
    list1=list_k,
    list2=list_p,
    closings=False,
    n=1,
    k=1.1,
    d=0.5,
    a=3000,
    shp_path=shp_path
)

In [None]:
########## FILTER PARAMETERS GRID SEARCH ##########


def grid_search_best_params(
    path_img1: str,
    path_img2: str,
    list_n: list[int],
    list_k: list[float],
    list_p: list[int],
    list_d: list[float],
    list_a: list[int],
    closing: bool = False,
    shp_path: str | None = None,
    top: int = 3,
    mcc_ref: float = 0.6
):
    """
    Explores all possible parameter combinations (n, k, p, d, a),
    evaluates metrics against ground truth, and:

      - returns the top 3 tuples for MCC, precision (with MCC>mcc_ref), recall (with MCC>mcc_ref)
      - displays a 3*3 figure: top 3 MCC, top 3 Precision, top 3 Recall

    Each subplot shows:
      - the binary output of main_dtod(...)
      - title: values of n, k, p, d, a
      - text under image: MCC, Kappa, F1, P, R, Acc
    """

    if shp_path is None:
        raise ValueError("shp_path cannot be None: ground truth is required ")

    # ----- 1) Ground truth rasterized on reference image -----
    mask_ref = shapefile_to_mask(shp_path=shp_path, ref_raster_path=path_img1)

    # ----- 2) Read raster to get dimensions -----
    with rasterio.open(path_img1) as src:
        w = src.width
        h = src.height

    # ----- 3) Build combinations -----
    all_combos = list(itertools.product(list_n, list_k, list_p, list_d, list_a))

    perf = []  # perf is a list of dict, each dict stores params, metrics, result

    # ----- 4) Loop over all combinations -----
    for (n, k, p, d, a) in all_combos:
        try:
            # main algorithm call
            result = main_test.main_dtod_test(path_img1, path_img2, n=n, k=k, closing=closing, p=p, d=d, a=a)[1]

            mets = metrics_from_masks(mask_ref, result)

            perf.append({
                "params": {"n": n, "k": k, "p": p, "d": d, "a": a},
                "metrics": mets,
                "result": result,
            })

        except Exception as e:
            # display error message for this combination
            print(f"[WARNING] Skipped combination (n={n}, k={k}, p={p}, d={d}, a={a}) "
                  f"due to error: {e}")
            continue

    if not perf:  # list perf empty
        raise RuntimeError("No valid combination produced a result.")

    # ----- 5) Sort by each metric -----
    def sort_by(metric_name: str):
        '''
        sorted sorts perf, it needs a key= to know what to compare on 
        lambda takes a dict r from perf and returns r["metrics"][metric_name] (in our case mcc or precision or recall)
        perf is thus sorted by values of metric_name (mcc, precision, recall)
        ''' 
        return sorted(perf, key=lambda r: r["metrics"][metric_name], reverse=True)
    
    sorted_mcc = sort_by("mcc")
    sorted_prec = sort_by("precision")
    sorted_rec = sort_by("recall")

    # keep top X
    top_mcc = sorted_mcc[:top]
    top_prec = [r for r in sorted_prec if r["metrics"]["mcc"] >= mcc_ref][:top]
    top_rec = [r for r in sorted_rec if r["metrics"]["mcc"] >= mcc_ref][:top]


    # ----- 6) Display figure (3×3: top MCC / top Precision / top Recall) -----
    fig, axes = plt.subplots(3, top, figsize=(top * (w / 100), 3 * (h / 100)))

    axes = np.atleast_2d(axes)  # ensure 2D even for top=1

    def show_row(row_idx: int, selected_results, row_title: str):
        '''
        displays the result of each line (line 1: mcc, line 2: precision, line 3: recall)
        selected_results will thus be top_mcc, top_prec, top_prec
        '''
        for i, res in enumerate(selected_results):  # enumerate() lets you iterate over a list while keeping the current index
            ax = axes[row_idx, i]
            img = res["result"]
            params = res["params"]
            mets = res["metrics"]

            ax.imshow(img, cmap="gray")

            ax.set_title(f"{row_title} #{i+1}\n"
                f"n={params['n']}, k={params['k']:.2f}, p={params['p']}, d={params['d']:.2f}, a={params['a']}",
                fontsize=15)

            txt = (
                f"MCC={mets['mcc']:.3f} "
                f"Kappa={mets['kappa']:.3f} "  
                f"F1={mets['F1']:.3f} "
                f"P={mets['precision']:.3f} "
                f"R={mets['recall']:.3f} "
                f"Acc={mets['accuracy']:.3f}"
            )

            ax.text(
                0.5,
                -0.08,
                txt,
                transform=ax.transAxes,
                ha="center",
                va="top",
                fontsize=15,
            )

        # hide empty axes
        for j in range(len(selected_results), top):
            axes[row_idx, j].axis("off")

    show_row(0, top_mcc, "Top MCC")
    show_row(1, top_prec, "Top Precision")
    show_row(2, top_rec, "Top Recall")

    plt.tight_layout()
    plt.show()


    # ----- 7) Returned values: best tuples -----
    return {"top_mcc": top_mcc, "top_precision": top_prec, "top_recall": top_rec}

# Tests

# path_img1 = r"C:\Users\gbonlieu\Documents\codepythonoutil\outil_detection_changement\data\raw\testnotebooks\t3bis\t3bis_pre_varVV.tif"
# path_img2 = r"C:\Users\gbonlieu\Documents\codepythonoutil\outil_detection_changement\data\raw\testnotebooks\t3bis\t3bis_0208_varVV.tif"
# shp_path = r"C:\Users\gbonlieu\Documents\codepythonoutil\outil_detection_changement\data\raw\testnotebooks\t3bis\t3bis_grd_truth\t3bis_grdtruth_0204shp\t3bis_grdtruth_0204.shp"
# Search grid: n, p and a are frozen to a single value each (1, 30, 3000);
# only k (6 values) and d (9 values) vary.
list_n = [int(x) for x in np.linspace(1, 1, 1)]
list_k = [x for x in np.linspace(0.9, 1.4, 6)]
list_p = [round(x) for x in np.linspace(30, 30, 1)] 
list_d = [x for x in np.linspace(0.2, 0.6, 9)]
list_a = [round(x) for x in np.linspace(3000, 3000, 1)]

# Run the grid search; precision / recall rankings only keep runs with MCC >= 0.8.
best = grid_search_best_params(
    path_img1=path_img1,
    path_img2=path_img2,
    list_n=list_n,
    list_k=list_k,
    list_p=list_p,
    list_d=list_d,
    list_a=list_a,
    closing=False,
    shp_path=shp_path,
    top=3,
    mcc_ref=0.8
)

# Example : get the best tuple for MCC (first entry of the MCC ranking)
best_mcc_1 = best["top_mcc"][0]
print("Best MCC :", best_mcc_1["metrics"]["mcc"])
print("Params :", best_mcc_1["params"])
