# Predictions corrections

### Table of contents
1) [Load samples](#load-samples)
2) [Metrics](#metrics)
3) [Performances computation](#performances-computation)

### Dependencies and general utils

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import open3d as o3d
import laspy
import pdal
import json
import pickle
from tqdm import tqdm
from scipy.spatial import cKDTree
from time import time

### Load samples

In [None]:
# Input/output folders, written as raw strings with Windows-style separators
# and relative to the current working directory.
# NOTE(review): folder contents are inferred from the folder names — confirm.
src_preds = r"..\data\flattening_corrections\predictions"  # prediction files
src_gt = r"..\data\flattening_corrections\gt"  # ground truth files
src_floors = r"..\data\flattening_corrections\floors"  # floor files (z added back to flattened predictions)
src_masks = r"..\data\flattening_corrections\masks"  # pickled masks used to crop the ground truth
src_originals = r"..\data\flattening_corrections\originals"  # original (non-flattened) samples
src_flatten = r"..\data\flattening_corrections\flatten"  # flattened samples
src_results = r"..\data\flattening_corrections"  # presumably where results are written — not used in the visible cells

In [None]:
samples_num = [128, 129, 160, 210, 311, 633]
tilling_num = [0, 1, 5, 10, 20]


def _files_by_sample(src, sample_ids):
    """Map each sample id to every file under `src` whose name contains that id.

    Matches are accumulated over all sub-folders visited by `os.walk`, so a
    nested directory can no longer overwrite the files found in its parent.
    """
    matches = {num: [] for num in sample_ids}
    for root, _, files in os.walk(src):
        for file in files:
            for num in sample_ids:
                if str(num) in file:
                    matches[num].append(os.path.join(root, file))
    return matches


def _single_file_by_sample(src, sample_ids):
    """Map each sample id to the first file under `src` whose name contains that id.

    Raises:
        FileNotFoundError: if a sample id has no matching file.
    """
    single = {}
    for num, paths in _files_by_sample(src, sample_ids).items():
        if not paths:
            raise FileNotFoundError(f"No file found for sample {num} in {src}")
        single[num] = paths[0]
    return single


# Several files per sample (one per grid size)
list_preds = _files_by_sample(src_preds, samples_num)
list_masks = _files_by_sample(src_masks, samples_num)
list_floors = _files_by_sample(src_floors, samples_num)
list_flatten = _files_by_sample(src_flatten, samples_num)

# Exactly one file per sample is used
list_originals = _single_file_by_sample(src_originals, samples_num)
list_gt = _single_file_by_sample(src_gt, samples_num)


### Metrics

In [None]:
def compute_panoptic_quality(gt_instances, pred_instances):
    """
    Computes Panoptic Quality (PQ), Segmentation Quality (SQ), and Recognition Quality (RQ).

    Every ground truth instance is compared with all predicted instances and
    paired with the one of highest IoU. The pair counts as a true positive only
    if that IoU is strictly greater than 0.5; otherwise the ground truth
    instance is a false negative. Predictions that were never paired are false
    positives.

    :param gt_instances: List of sets, each containing point indices for a ground truth instance.
    :param pred_instances: List of sets, each containing point indices for a predicted instance.
    :return: Tuple (PQ, SQ, RQ, tp, fp, fn)
    """
    tp, fn = 0, 0
    iou_sum = 0

    # Indices of the predictions that were matched to a ground truth instance
    matched_pred = set()

    for gt in gt_instances:
        best_iou = 0
        best_pred = None

        for j, pred in enumerate(pred_instances):
            union = len(gt | pred)
            # Two empty instances have an empty union: count it as no overlap
            # instead of dividing by zero.
            iou = len(gt & pred) / union if union > 0 else 0

            if iou > best_iou:
                best_iou = iou
                best_pred = j

        # Threshold for a valid match
        if best_iou > 0.5:
            matched_pred.add(best_pred)
            tp += 1
            iou_sum += best_iou
        else:
            fn += 1  # Unmatched ground truth instance

    fp = len(pred_instances) - len(matched_pred)  # Unmatched predictions

    denominator = tp + 0.5 * (fp + fn)
    RQ = tp / denominator if denominator > 0 else 0
    SQ = iou_sum / tp if tp > 0 else 0
    PQ = SQ * RQ

    return PQ, SQ, RQ, tp, fp, fn


def compute_mean_iou(y_true, y_pred, num_classes=2):
    """
    Computes the mean Intersection over Union (mIoU) over classes 0..num_classes-1.

    A class that appears in neither the ground truth nor the predictions
    contributes an IoU of 0 to the average.

    :param y_true: Ground truth labels (N,)
    :param y_pred: Predicted labels (N,)
    :param num_classes: Total number of classes
    :return: Mean IoU score
    """
    per_class_iou = []

    for label in range(num_classes):
        hits = np.sum((y_true == label) & (y_pred == label))
        false_alarms = np.sum((y_true != label) & (y_pred == label))
        misses = np.sum((y_true == label) & (y_pred != label))

        union = hits + false_alarms + misses
        if union > 0:
            per_class_iou.append(hits / union)
        else:
            per_class_iou.append(0)

    return np.mean(per_class_iou)


def _group_positions(labels):
    """Return a dict mapping each label to the set of positions where it occurs."""
    groups = {}
    for pos, val in enumerate(labels):
        groups.setdefault(val, set()).add(pos)
    return groups


def get_segmentation(instance_list, semantic_list):
    """
    Converts per-point labels into lists of point-index sets.

    Positions are grouped in a single pass over each label sequence instead of
    rescanning the whole sequence once per distinct label. The output order
    still follows the iteration order of `set(...)` over the labels.

    :param instance_list: Per-point instance labels; label 0 is skipped.
    :param semantic_list: Per-point semantic labels; every label is kept.
    :return: (instances_format, semantic_format), two lists of sets of point indices.
    """
    # Computing instances (label 0 is not an instance)
    instance_groups = _group_positions(instance_list)
    instances_format = [
        instance_groups[instance] for instance in set(instance_list) if not instance == 0
    ]

    # Computing semantic
    semantic_groups = _group_positions(semantic_list)
    semantic_format = [semantic_groups[semantic] for semantic in set(semantic_list)]

    return instances_format, semantic_format


### Performances computation

#### Utils

In [None]:
def remove_duplicates(laz_file):
    """Drop, in place, one point of every pair of points closer than 1e-2.

    Coordinates are rounded to 2 decimals before the neighbour search. For each
    pair returned by the KD-tree, the second point of the pair is discarded.
    `laz_file.points` is overwritten; nothing is returned.
    """
    # Rounded (N, 3) coordinates used for the neighbour search
    xyz = np.round(np.vstack((laz_file.x, laz_file.y, laz_file.z)), 2).T
    close_pairs = cKDTree(xyz).query_pairs(1e-2)

    # Start by keeping every point, then drop the second member of each pair
    keep = [True] * len(xyz)
    for _, second in close_pairs:
        keep[second] = False

    # Filter the point records of the file
    laz_file.points = laz_file.points[keep]


def match_pointclouds(laz1, laz2):
    """Sort laz2 to match the order of laz1 without changing laz1's order.

    Points are matched on their (x, y, z) coordinates rounded to 2 decimals.
    laz2 is modified in place (its `points` are reordered) and also returned.

    Args:
        laz1: laspy.LasData object (reference order)
        laz2: laspy.LasData object (to be sorted)

    Returns:
        laz2 sorted to match laz1

    Raises:
        AssertionError: if the two clouds do not have the same number of points.
        KeyError: if a point of laz2 has no counterpart in laz1.
    """
    # Retrieve and round coordinates for robust matching
    coords_1 = np.round(np.vstack((laz1.x, laz1.y, laz1.z)), 2).T
    coords_2 = np.round(np.vstack((laz2.x, laz2.y, laz2.z)), 2).T

    # Verify laz2 is of the same size as laz1
    assert len(coords_2) == len(coords_1), "laz1 and laz2 must contain the same number of points"

    # Create a dictionary mapping from coordinates to indices in laz1
    coord_to_idx = {tuple(coord): idx for idx, coord in enumerate(coords_1)}

    # For every point of laz2, find its index in laz1
    try:
        matching_indices = np.array([coord_to_idx[tuple(coord)] for coord in coords_2])
    except KeyError as err:
        raise KeyError(f"laz2 point {err.args[0]} has no counterpart in laz1") from err

    # Ordering laz2 by increasing laz1 index puts it in laz1's order
    sorted_indices = np.argsort(matching_indices)

    # Apply sorting to all attributes of laz2
    laz2.points = laz2.points[sorted_indices]

    return laz2  # Now sorted to match laz1


#### Computing on all samples

In [None]:
import copy
import pickle

metrics = ['PQ', 'SQ', 'RQ', 'mIoU', 'Recall', 'Precision']
# One table per sample: one row per metric (same order as `metrics`),
# one column per grid size (same order as `tilling_num`).
metrics_res = {samp_num: np.zeros((len(metrics), len(tilling_num))) for samp_num in samples_num}


def _pick_tile_file(paths, tilling):
    """Return the first path that contains the '<tilling>m' grid-size tag."""
    tag = f'{tilling}m'
    return [x for x in paths if tag in x][0]


for samp_num in tqdm(samples_num, total=len(samples_num)):
    original_src = list_originals[samp_num]
    gt_src = list_gt[samp_num]
    laz_original = laspy.read(original_src)

    for j, tilling in enumerate(tilling_num):
        if tilling > 0:
            pred_src = _pick_tile_file(list_preds[samp_num], tilling)
            floor_src = _pick_tile_file(list_floors[samp_num], tilling)
            flatten_src = _pick_tile_file(list_flatten[samp_num], tilling)
            mask_src = _pick_tile_file(list_masks[samp_num], tilling)
        else:
            # Grid size 0: prediction on the original sample (no 'flatten' in the file name)
            pred_src = [x for x in list_preds[samp_num] if 'flatten' not in os.path.basename(x)][0]
            flatten_src = original_src

        laz_pred = laspy.read(pred_src)
        # The ground truth is re-read for every grid size because it is cropped in place below
        laz_gt = laspy.read(gt_src)
        laz_flatten = laspy.read(flatten_src)

        # Put the predictions in the same point order as the cloud they were computed on
        laz_pred = match_pointclouds(laz_flatten, laz_pred)

        pred_coords = np.vstack((laz_pred.x, laz_pred.y, laz_pred.z)).T

        remove_duplicates(laz_gt)
        laz_gt = match_pointclouds(laz_original, laz_gt)

        # Crop ground truth
        if 'flatten' in os.path.basename(pred_src):
            # add floor to preds
            # NOTE(review): assumes the floor file has the same point order as the flattened file — confirm
            laz_floor = laspy.read(floor_src)
            floor_coords = np.vstack((laz_floor.x, laz_floor.y, laz_floor.z)).T
            pred_coords[:, 2] = pred_coords[:, 2] + floor_coords[:, 2]
            laz_pred.x = pred_coords[:, 0]
            laz_pred.y = pred_coords[:, 1]
            laz_pred.z = pred_coords[:, 2]

            # load mask
            # NOTE: pickle.load can execute arbitrary code; only load mask files produced by this project.
            with open(mask_src, 'rb') as infile:
                mask = pickle.load(infile)
            laz_gt.points = laz_gt.points[mask]
            laz_gt = match_pointclouds(laz_pred, laz_gt)

        # Compute metrics
        gt_instances = laz_gt.gt_instance_segmentation
        gt_semantic = laz_gt.gt_semantic_segmentation
        pred_instances = laz_pred.PredInstance
        pred_semantic = laz_pred.PredSemantic

        gt_instances_format, gt_semantic_format = get_segmentation(gt_instances, gt_semantic)
        pred_instances_format, pred_semantic_format = get_segmentation(pred_instances, pred_semantic)

        PQ, SQ, RQ, tp, fp, fn = compute_panoptic_quality(gt_instances_format, pred_instances_format)
        mean_iou = compute_mean_iou(gt_semantic, pred_semantic)
        metrics_res[samp_num][0, j] = PQ
        metrics_res[samp_num][1, j] = SQ
        metrics_res[samp_num][2, j] = RQ
        metrics_res[samp_num][3, j] = mean_iou
        metrics_res[samp_num][4, j] = round(tp / (tp + fn), 2) if tp + fn > 0 else 0
        metrics_res[samp_num][5, j] = round(tp / (tp + fp), 2) if tp + fp > 0 else 0
    

In [None]:
# Print, for every sample number, its table of metrics (rows) per grid size (columns)
for sample_id, sample_table in metrics_res.items():
    print(sample_id)
    print(sample_table)

### Visualization

In [None]:
# Human-readable title of each sample
dict_sampnum_to_terrain = {
    128: "light slope sample",
    129: "bushes sample",
    160: "slope empty sample",
    210: "flat empty sample",
    311: "flat sample",
    633: "heavy slope sample"
}
# One heatmap per sample, laid out on a 2 x 3 grid
fig, axs = plt.subplots(2, 3, figsize=(15, 10))
for pos, samp_num in enumerate(samples_num):
    i = pos // 3
    j = pos % 3
    arr_res = metrics_res[samp_num]
    df_data = pd.DataFrame(
        data=arr_res,
        columns=tilling_num,
        index=metrics
    )
    # True where a cell holds the maximum of its row (best grid size for that metric)
    max_mask = df_data.eq(df_data.max(axis=1), axis=0)

    sns.heatmap(data=df_data, cmap="crest", annot=True, ax=axs[i, j], fmt=".2f")
    axs[i, j].set_title(dict_sampnum_to_terrain[samp_num])

    # Draw only horizontal grid lines
    axs[i, j].hlines(np.arange(1, df_data.shape[0]), *axs[i, j].get_xlim(), color="white", linewidth=0.8)

    # Emphasize the highest value of each row with a larger, bold font
    for text, (k, l) in zip(axs[i, j].texts, np.ndindex(df_data.shape)):
        text.set_color('black')
        if max_mask.iat[k, l]:
            text.set_fontsize(12)  # Larger font for the row maximum
            text.set_fontweight('bold')  # Bold font for the row maximum
        else:
            text.set_fontsize(10)  # Default font size
    # Axis labels only on the bottom row and the left column
    if i == 1:
        axs[i, j].set_xlabel('Grid size [m]')
    if j == 0:
        axs[i, j].set_ylabel('Metric [-]')
plt.suptitle("Flattening results for different grid sizes")
plt.tight_layout()