In [None]:
########## IMPORTS ##########

# common
import rasterio
import numpy as np
import matplotlib.pyplot as plt

# test & display
import itertools
import traceback
from pprint import pprint
import importlib # allows updating the functions by rerunning this import cell (otherwise the functions remain stored and frozen in the cache)

# Imports from src
import os
import sys

# Make the project root importable so that the `src.*` imports below resolve.
# The root is taken to be the parent of the kernel's working directory, i.e. the
# notebook is expected to be launched from a direct sub-folder of the project.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if PROJECT_ROOT not in sys.path:
    # Appended (not prepended) only once, even if this cell is re-run.
    sys.path.append(PROJECT_ROOT)

import src.main_multidate_test as main_test
importlib.reload(main_test)
from src.ground_truth import (
    shapefile_to_mask,
    confusion_from_masks,
    metrics_from_masks,
)       



In [None]:
########## DATA ##########

# NOTE(review): absolute, user-specific Windows paths; adapt them to your machine.
# Pre-event image; also the reference grid used to rasterize the ground truth.
path_img1 = r"C:\Users\gbonlieu\Documents\herramienta\explore\multi_date\t6_pre2_gamma.tif"
# Post-event images, passed as a list to main_multidate_test().
list_pathpost = [
    r"C:\Users\gbonlieu\Documents\herramienta\explore\multi_date\t6_post1_gamma.tif",
    r"C:\Users\gbonlieu\Documents\herramienta\explore\multi_date\t6_post2_gamma.tif",
    r"C:\Users\gbonlieu\Documents\herramienta\explore\multi_date\t6_post3_gamma.tif",
    r"C:\Users\gbonlieu\Documents\herramienta\explore\multi_date\t6_post4_gamma.tif",
    r"C:\Users\gbonlieu\Documents\herramienta\explore\multi_date\t6_post5_gamma.tif",
]
shp_path = r"C:\Users\gbonlieu\Documents\herramienta\change_detection_tool\data\preprocessed\t6\t6_grd_truth\t6_grd_truth_pre_11122017.shp"
out_path = r"C:\Users\gbonlieu\Documents\codepythonoutil\outil_detection_changement\ouputs\output_vrac\t3bis_out\main_out.tif"
ref_raster_path = path_img1

# Rasterize the ground-truth shapefile on the grid of the reference raster.
truth = shapefile_to_mask(shp_path=shp_path, ref_raster_path=ref_raster_path)  # out_path=out_path)

# Count "white" (non-zero) and "black" (zero) pixels. count_nonzero counts pixels
# instead of summing values, so the count stays correct whatever burn value the
# mask uses (bool, 0/1, 0/255, float), and the array is only scanned once.
nb_whites = int(np.count_nonzero(truth))
nb_blacks = truth.size - nb_whites
print(f"Number of whites: {nb_whites}\nNumber of blacks: {nb_blacks}")

fig, ax = plt.subplots(figsize=(6, 6))
im = ax.imshow(truth,
               cmap="gray",
               interpolation="nearest",  # no smoothing between pixels
               vmin=0, vmax=1)           # keep 0 and 1 values clearly separated
ax.xaxis.tick_top()
ax.set_title("Truth")
plt.show()

# fig.colorbar(im, ax=ax, label="Pixel Values")


In [None]:
########## SINGLE TEST & DISPLAY ##########

"""
Small test script: run main_multidate_test() on a pre-event raster and a list of
post-event rasters, score the predicted mask against the ground truth and
display it.
"""


def test_main_multidate_display(path_img1: str, list_pathpost: list[str], shp_path: str,
                                n: int, k: float = 1.0, closings: bool = False,
                                p: int = 27, d: float = 0.5, a: int = 3000,
                                out_path_multidate: str | None = None):
    """
    Run main_multidate_test() with the given parameters, score the predicted
    mask against the ground-truth shapefile and display it.

    The figure title lists the parameters, and the caption under the image lists
    MCC, F1, precision, recall and accuracy. The profile returned by
    main_multidate_test() is printed after the figure.

    Parameters
    ----------
    path_img1 : str
        Pre-event raster. Also passed as ``ref_raster_path`` to
        shapefile_to_mask() when rasterizing the ground truth.
    list_pathpost : list[str]
        Post-event rasters, forwarded unchanged to main_multidate_test().
    shp_path : str
        Ground-truth shapefile.
    n, k, closings, p, d, a
        Detection parameters forwarded to main_multidate_test() (see that
        function for their meaning); also shown in the figure title.
    out_path_multidate : str | None
        Forwarded to main_multidate_test(); presumably the output path of the
        result raster, None to skip writing — TODO confirm.

    Returns
    -------
    None
    """

    # --- Compute predicted mask ---
    profile, multidate = main_test.main_multidate_test(
        path_img1, list_pathpost, n,
        k=k, closings=closings, p=p, d=d, a=a,
        out_path_multidate=out_path_multidate,
    )

    # --- Load and rasterize ground truth on the grid of the pre-event image ---
    mask_ref = shapefile_to_mask(shp_path, ref_raster_path=path_img1)

    # --- Compute performance metrics ---
    mets = metrics_from_masks(mask_ref, multidate)

    # --- Figure ---
    fig, ax = plt.subplots(figsize=(10, 7))

    # Pin the colour range and disable interpolation, as for the "Truth" plot.
    # Without vmin/vmax, imshow autoscales to the data range, so a constant mask
    # (nothing detected, or everything flagged) would be drawn uniformly black.
    ax.imshow(multidate, cmap="gray", interpolation="nearest", vmin=0, vmax=1)

    # --- Title: parameters ---
    ax.set_title(
        f"Result mask\n\n"
        f"n={n}, k={k}, p={p}, d={d}, a={a}, closings={closings}",
        fontsize=11
    )

    # --- Caption with metrics (3 decimal places) ---
    txt = (
        f"MCC={mets['mcc']:.3f}   "
        f"F1={mets['F1']:.3f}   "
        f"P={mets['precision']:.3f}   "
        f"R={mets['recall']:.3f}   "
        f"Acc={mets['accuracy']:.3f}"
    )
    # Placed just below the axes; x ticks are moved to the top so they don't overlap.
    ax.text(0.5, -0.08, txt, transform=ax.transAxes, ha="center", va="top", fontsize=10)
    ax.xaxis.tick_top()
    plt.show()

    print("Full profile :")
    pprint(profile)


# Alternative dataset (t5, one pre / one post image), kept for reference only.
# NOTE(review): list_pathpost is a plain string here; it probably needs wrapping
# in a list to match the t6 usage above — TODO confirm before re-enabling.
# path_img1 = r"C:\Users\gbonlieu\Documents\codepythonoutil\outil_detection_changement\data\raw\testnotebooks\t5\t5_pre_8bands.tif"
# list_pathpost = r"C:\Users\gbonlieu\Documents\codepythonoutil\outil_detection_changement\data\raw\testnotebooks\t5\t5_post_8bands.tif"
# shp_path = r"C:\Users\gbonlieu\Documents\codepythonoutil\outil_detection_changement\data\raw\testnotebooks\t5\t5_grd_truth\t5_grd_truth_0709.shp"
# out_path_multidate = r"C:\Users\gbonlieu\Documents\codepythonoutil\outil_detection_changement\ouputs\output_vrac\t3bis_out\main_out.tif"

# Run a single parameter configuration on the data defined in the DATA cell.
test_main_multidate_display(
    path_img1,
    list_pathpost,
    shp_path,
    n=1,
    k=1.2,
    closings=False,
    p=30,
    d=0.5,
    a=3000,
)
