# Evaluate detection and segmentation accuracy of TreeX

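Steps required:
- load the reference and prediction point clouds
- truncate the XYZ coordinates after 4 decimal places
- sort both point clouds by XYZ in ascending order (X first, then Y, then Z)
- verify that both point clouds contain identical coordinates after sorting
- make sure instance IDs are -1 for ground and 0, 1, 2, … for trees
- calculate the detection and segmentation metrics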

In [1]:
import laspy
from pointtree.evaluation import match_instances, instance_segmentation_metrics

In [2]:
# Read the reference (manual.laz) and the predicted (treex.laz) point clouds, in that order
ref, pred = (laspy.read(path) for path in ('data/manual.laz', 'data/treex.laz'))

In [3]:
import numpy as np

In [4]:
def get_xyz(las: "laspy.LasData", decimals: int = 4) -> np.ndarray:
    """Get the xyz coordinates of a point cloud, truncated to ``decimals`` decimal places.

    Digits beyond ``decimals`` decimal places are cut off (truncated towards
    zero), not rounded, so coordinates from two files can be compared for
    exact equality.

    NOTE(review): truncation is sensitive to floating-point representation.
    A value printed as 6.8463 may be stored as 6.84629999..., which after
    scaling truncates to 6.8462. Both files are affected the same way only if
    they store coordinates identically.

    Args:
        las (laspy.LasData | laspy.point.record.ScaleAwarePointRecord): point
            cloud data. Any object exposing ``x``, ``y`` and ``z`` arrays works.
        decimals (int, optional): number of decimal places to keep. Defaults to 4.

    Returns:
        numpy.ndarray: xyz coordinates of shape (N, 3).
    """
    scale = 10**decimals

    # Scale up, drop the fractional part (towards zero), scale back down.
    columns = [np.trunc(axis * scale) / scale for axis in (las.x, las.y, las.z)]

    return np.column_stack(columns)

def get_xyz_sorted(las, decimals: int = 4):
    """Sort the points of a point cloud by their truncated xyz coordinates.

    Points are ordered lexicographically: by X first, then by Y for equal X,
    then by Z for equal X and Y. ``np.lexsort`` is stable, so points whose
    truncated coordinates are identical keep their original file order.

    NOTE(review): for such duplicate coordinates, the reference and prediction
    files may list the points in different orders, which would misalign their
    per-point instance ids. Confirm the data contains no duplicates.

    Args:
        las (laspy.LasData): the point cloud data to sort.
        decimals (int, optional): number of decimal places the coordinates are
            truncated to before sorting (forwarded to ``get_xyz``). Defaults to 4.

    Returns:
        laspy.point.record.ScaleAwarePointRecord: the reordered points. ``las``
        itself is not reordered in place.
    """
    coords = get_xyz(las, decimals=decimals)  # shape (N, 3)

    # np.lexsort treats the LAST key as the primary one, hence the (Z, Y, X) order.
    sorted_indices = np.lexsort((coords[:, 2], coords[:, 1], coords[:, 0]))

    # Index the point record with the permutation and return the reordered record.
    return las.points[sorted_indices]

In [5]:
# Reorder both clouds by truncated XYZ so that, if both files hold the same points,
# index i refers to the same point in the reference and in the prediction.
sref, spred = (get_xyz_sorted(cloud) for cloud in (ref, pred))

In [6]:
# Truncated coordinates of the sorted reference points, shape (N, 3); displayed by the cell.
xyz_ref = get_xyz(sref)
xyz_ref

array([[-15.0486,   6.8463,  -4.6444],
       [-15.0476,   6.856 ,  -4.6444],
       [-15.0461,   6.544 ,  -4.6254],
       ...,
       [ 15.0382,  -6.5221,  -1.0231],
       [ 15.0394,  -6.7211,  -1.2054],
       [ 15.0486,  -6.5155,  -1.0131]])

In [7]:
# Truncated coordinates of the sorted predicted points, shape (N, 3); displayed by the cell.
xyz_pred = get_xyz(spred)
xyz_pred

array([[-15.0486,   6.8463,  -4.6444],
       [-15.0476,   6.856 ,  -4.6444],
       [-15.0461,   6.544 ,  -4.6254],
       ...,
       [ 15.0382,  -6.5221,  -1.0231],
       [ 15.0394,  -6.7211,  -1.2054],
       [ 15.0486,  -6.5155,  -1.0131]])

In [8]:
# Sanity check: after sorting, point i of the reference must be point i of the
# prediction, otherwise per-point instance ids cannot be compared.
# np.array_equal also handles differing point counts; an element-wise `==`
# would fail with a broadcasting error instead of this message.
if not np.array_equal(xyz_ref, xyz_pred):
    raise ValueError(
        'XYZ coordinates differ between reference and prediction! '
        f'(reference shape {xyz_ref.shape}, prediction shape {xyz_pred.shape})'
    )

In [11]:
def get_instance_ids(sorted_las):
    """Get the per-point instance ids of sorted point cloud points.

    Normalizes the id convention to -1 for ground and 0..n for trees:
    - if the smallest id is already -1, the ids are returned unchanged;
    - otherwise every id is shifted down by one (e.g. ground 0 -> -1,
      trees 1..n -> 0..n-1).

    Args:
        sorted_las (laspy.point.record.ScaleAwarePointRecord | laspy.LasData):
            sorted points exposing an ``instance_id`` dimension.

    Returns:
        numpy.ndarray: a new signed integer array with one instance id per
        point. It is a copy, never a view into ``sorted_las``.
    """
    # astype() always copies. Casting in both branches gives a consistent signed
    # dtype, and casting before subtracting avoids wrap-around on unsigned columns.
    ids = np.asarray(sorted_las.instance_id).astype(int)
    if ids.size == 0:
        # .min() is undefined on an empty array; there is nothing to shift.
        return ids
    return ids if ids.min() == -1 else ids - 1

In [12]:
# Reference (ground-truth) instance id per point; the cell shows the distinct ids.
target = get_instance_ids(sref)
np.unique(target)

array([-1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14])

In [13]:
# Predicted instance id per point; the cell shows the distinct ids.
prediction = get_instance_ids(spred)
np.unique(prediction)

array([-1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10])

In [14]:
# Match reference and predicted tree instances with pointtree's 'point2tree' method.
# Judging by the displayed result, match_result[0] has one entry per predicted instance
# and match_result[1] one entry per reference instance. Each entry is the matched id in
# the other set, or -1 if unmatched. TODO confirm against the pointtree documentation.
match_result = match_instances(xyz_ref, target, prediction, method='point2tree')

In [15]:
# Inspect the raw matching result.
match_result

(array([12,  3, 10,  6,  9,  5, -1, -1, -1,  2,  0]),
 array([10, -1,  9,  1, -1,  5,  3, -1, -1,  4,  2, -1,  0, -1, -1]))

In [16]:
# Segmentation metrics, using the per-reference-instance matches (match_result[1]).
instance_segmentation_metrics(target, prediction, match_result[1])

({'MeanIoU': 0.43183935, 'MeanPrecision': 0.8110652, 'MeanRecall': 0.5325294},
     TargetID  PredictionID       IoU  Precision    Recall
 0          0            10  0.588079   0.588079  1.000000
 1          1            -1  0.000000        NaN  0.000000
 2          2             9  0.300468   0.300468  1.000000
 3          3             1  0.985558   0.987968  0.997531
 4          4            -1  0.000000        NaN  0.000000
 5          5             5  0.948783   0.954706  0.993503
 6          6             3  0.995094   0.995137  0.999956
 7          7            -1  0.000000        NaN  0.000000
 8          8            -1  0.000000        NaN  0.000000
 9          9             4  0.943296   0.945168  0.997906
 10        10             2  0.892163   0.892394  0.999710
 11        11            -1  0.000000        NaN  0.000000
 12        12             0  0.824148   0.824601  0.999334
 13        13            -1  0.000000        NaN  0.000000
 14        14            -1  0.00000

In [None]:
def instance_detection_metrics(match_result: tuple[np.ndarray, np.ndarray], decimals: int = 7):
    """Calculate instance detection metrics (precision, recall, F1-score) from pointtree matching results.

    The metrics are rounded to ``decimals`` decimal places. A metric whose
    denominator is zero is reported as 0.0 instead of raising
    ZeroDivisionError. This covers no predicted instances, no reference
    instances, or no true positives.

    Args:
        match_result (tuple[np.ndarray, np.ndarray]): pointtree ``match_instances``
            result. The first array holds, per predicted instance, the matched
            reference id or -1. The second holds, per reference instance, the
            matched predicted id or -1.
        decimals (int, optional): decimal places the results are rounded to. Defaults to 7.

    Returns:
        dict: keys ``'Precision'``, ``'Recall'`` and ``'F1-Score'``, each rounded
        to ``decimals`` decimal places.
    """
    # asarray so plain lists are compared element-wise rather than as a whole.
    matched_predictions = np.asarray(match_result[0])
    matched_targets = np.asarray(match_result[1])

    tp = int(np.count_nonzero(matched_predictions != -1))  # predicted instances matched to a reference instance
    fp = int(np.count_nonzero(matched_predictions == -1))  # predicted instances matching no reference instance
    fn = int(np.count_nonzero(matched_targets == -1))  # reference instances matching no predicted instance

    precision = tp / (tp + fp) if tp + fp > 0 else 0.0  # user's accuracy (1 - commission)
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0  # producer's accuracy (1 - omission)
    # Harmonic mean; 0.0 when both precision and recall are 0 (i.e. tp == 0).
    f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0

    return {
        'Precision': np.round(precision, decimals),
        'Recall': np.round(recall, decimals),
        'F1-Score': np.round(f1, decimals),
    }

In [18]:
def accuracy_metrics(target:np.ndarray, prediction:np.ndarray, match_result):
    """Compute instance detection and instance segmentation metrics in one call.

    Args:
        target (np.ndarray): ground-truth instance id for each point.
        prediction (np.ndarray): predicted instance id for each point.
        match_result (tuple[np.ndarray, np.ndarray]): result from instance matching.

    Returns:
        tuple: three elements:
            - dict of detection metrics with keys :code:`"Precision"`,
              :code:`"Recall"` and :code:`"F1-Score"`.
            - dict of segmentation metrics averaged over the instance pairs
              (computed with ``include_unmatched_instances=True``), with keys
              :code:`"MeanIoU"`, :code:`"MeanPrecision"` and :code:`"MeanRecall"`.
            - pandas.DataFrame of per-pair segmentation metrics with columns
              :code:`"TargetID"`, :code:`"PredictionID"`, :code:`"IoU"`,
              :code:`"Precision"` and :code:`"Recall"`.
    """
    detection = instance_detection_metrics(match_result)

    # Segmentation only needs the per-reference-instance matches (second element).
    segmentation = instance_segmentation_metrics(
        target,
        prediction,
        match_result[1],
        include_unmatched_instances=True,
    )

    return (detection, *segmentation)

In [19]:
# Combined detection and segmentation metrics for the TreeX prediction.
accuracy_metrics(target, prediction, match_result)

({'Precision': 0.7272727, 'Recall': 0.5333333, 'F1-Score': 0.6153846},
 {'MeanIoU': 0.43183935, 'MeanPrecision': 0.8110652, 'MeanRecall': 0.5325294},
     TargetID  PredictionID       IoU  Precision    Recall
 0          0            10  0.588079   0.588079  1.000000
 1          1            -1  0.000000        NaN  0.000000
 2          2             9  0.300468   0.300468  1.000000
 3          3             1  0.985558   0.987968  0.997531
 4          4            -1  0.000000        NaN  0.000000
 5          5             5  0.948783   0.954706  0.993503
 6          6             3  0.995094   0.995137  0.999956
 7          7            -1  0.000000        NaN  0.000000
 8          8            -1  0.000000        NaN  0.000000
 9          9             4  0.943296   0.945168  0.997906
 10        10             2  0.892163   0.892394  0.999710
 11        11            -1  0.000000        NaN  0.000000
 12        12             0  0.824148   0.824601  0.999334
 13        13           