# General imports

In [None]:
import numpy as np
import geopandas as gpd
import pprint

from geograypher.cameras.derived_cameras import MetashapeCameraSet
from geograypher.cameras.segmentor import SegmentorPhotogrammetryCameraSet
from geograypher.meshes import TexturedPhotogrammetryMesh
from geograypher.predictors.derived_segmentors import LookUpSegmentor
from geograypher.utils.prediction_metrics import compute_and_show_cf, compute_comprehensive_metrics
from geograypher.utils.indexing import find_argmax_nonzero_value
from geograypher.utils.visualization import show_segmentation_labels
from geograypher.constants import (
    EXAMPLE_CAMERAS_FILENAME,
    EXAMPLE_MESH_FILENAME,
    EXAMPLE_IMAGE_FOLDER,
    EXAMPLE_LABELS_FILENAME,
    EXAMPLE_PREDICTED_LABELS_FOLDER,
    EXAMPLE_DTM_FILE,
    EXAMPLE_AGGREGATED_FACE_LABELS_FILE,
    EXAMPLE_PREDICTED_VECTOR_LABELS_FILE,
    EXAMPLE_IDS_TO_LABELS,
    EXAMPLE_LABEL_COLUMN_NAME,
    TEN_CLASS_VIS_KWARGS,
)

# Processing parameters

In [None]:
# --- Caching -----------------------------------------------------------------
# When True, load the previously saved aggregation from
# AGGREGATED_FACE_LABELS_FILE instead of re-running the projection step
USE_CACHED_AGGREGATION = False

# --- Processing knobs --------------------------------------------------------
HEIGHT_ABOVE_GROUND_THRESH = 2  # Minimum height above the DTM to count as not ground
MESH_DOWNSAMPLE_TARGET = 0.25  # Fraction of the original mesh to downsample to
AGGREGATE_IMAGE_SCALE = 0.25  # Resolution scale at which images are aggregated
BUFFER_RADIUS_METER = 50  # Keep cameras within this radius of the labeled data
# NOTE(review): MESH_DOWNSAMPLE_TARGET and MESH_VIS_KWARGS are not referenced by
# any cell visible here — confirm whether they are still needed.
MESH_VIS_KWARGS = TEN_CLASS_VIS_KWARGS

# --- Label schema ------------------------------------------------------------
LABEL_COLUMN_NAME = EXAMPLE_LABEL_COLUMN_NAME
IDS_TO_LABELS = EXAMPLE_IDS_TO_LABELS

# --- Input / output paths (defaults point at the bundled example data) -------
CAMERAS_FILENAME = EXAMPLE_CAMERAS_FILENAME
MESH_FILENAME = EXAMPLE_MESH_FILENAME
IMAGE_FOLDER = EXAMPLE_IMAGE_FOLDER
LABELS_FILENAME = EXAMPLE_LABELS_FILENAME
PREDICTED_IMAGE_LABELS_FOLDER = EXAMPLE_PREDICTED_LABELS_FOLDER
DTM_FILE = EXAMPLE_DTM_FILE
AGGREGATED_FACE_LABELS_FILE = EXAMPLE_AGGREGATED_FACE_LABELS_FILE
PREDICTED_VECTOR_LABELS_FILE = EXAMPLE_PREDICTED_VECTOR_LABELS_FILE

# Load the mesh

In [None]:
# Load the photogrammetry mesh with the class-ID-to-name mapping used throughout.
# The transform is read from the same camera file that builds the camera set
# below.
mesh = TexturedPhotogrammetryMesh(
    MESH_FILENAME,
    # Use the configurable CAMERAS_FILENAME rather than the EXAMPLE_ constant so
    # that retargeting the notebook at new data also updates the mesh transform.
    transform_filename=CAMERAS_FILENAME,
    IDs_to_labels=IDS_TO_LABELS,
)

# Load the camera set and subsample

In [None]:
# Build the camera set from the Metashape camera file and image folder, then
# keep only the cameras within BUFFER_RADIUS_METER of the labeled ROI
camera_set = MetashapeCameraSet(CAMERAS_FILENAME, IMAGE_FOLDER).get_subset_ROI(
    ROI=LABELS_FILENAME,
    buffer_radius=BUFFER_RADIUS_METER,
)

In [None]:
# Visualize the ROI-filtered camera subset built in the previous cell
camera_set.vis()

In [None]:
# Render the mesh together with the selected cameras.
# force_xvfb is set, presumably so rendering works without a physical display.
mesh.vis(
    camera_set=camera_set,
    force_xvfb=True,
)

# Show the per-image predictions

In [None]:
# Preview the precomputed per-image predictions next to their source images
show_segmentation_labels(
    label_folder=PREDICTED_IMAGE_LABELS_FOLDER,
    image_folder=IMAGE_FOLDER,
    IDs_to_labels=IDS_TO_LABELS,
)

In [None]:
# A segmentor that reads predictions from PREDICTED_IMAGE_LABELS_FOLDER for the
# images under IMAGE_FOLDER (presumably a lookup rather than running a model).
# The class count comes from the mesh's label names.
segmentor = LookUpSegmentor(
    base_folder=IMAGE_FOLDER,
    lookup_folder=PREDICTED_IMAGE_LABELS_FOLDER,
    num_classes=len(mesh.get_label_names()),
)
# Pair every camera in the subset with that segmentor
segmentor_camera_set = SegmentorPhotogrammetryCameraSet(camera_set, segmentor=segmentor)

In [None]:
# Per-face aggregated label values: either recomputed from the segmented
# cameras (and cached to disk) or loaded from a previous run's cache.
# NOTE(review): np.save appends ".npy" when the path lacks that suffix but
# np.load does not, so the cache only round-trips if
# AGGREGATED_FACE_LABELS_FILE already ends in ".npy" — confirm.
if not USE_CACHED_AGGREGATION:
    # Project each camera's segmentation onto the mesh and aggregate per face;
    # the second return value is not needed here
    aggregated_face_labels, _ = mesh.aggregate_projected_images(
        segmentor_camera_set,
        aggregate_img_scale=AGGREGATE_IMAGE_SCALE,
    )
    # Save so a later run can set USE_CACHED_AGGREGATION = True
    np.save(AGGREGATED_FACE_LABELS_FILE, aggregated_face_labels)
else:
    aggregated_face_labels = np.load(AGGREGATED_FACE_LABELS_FILE)

In [None]:
# Reduce the aggregated values to one class per face with
# find_argmax_nonzero_value. Then let label_ground_class relabel faces it treats
# as ground (using DTM_FILE and HEIGHT_ABOVE_GROUND_THRESH) with NaN.
# set_mesh_texture=False presumably leaves the mesh's own texture untouched.
predicted_face_classes = mesh.label_ground_class(
    labels=find_argmax_nonzero_value(aggregated_face_labels, keepdims=True),
    height_above_ground_threshold=HEIGHT_ABOVE_GROUND_THRESH,
    DTM_file=DTM_FILE,
    ground_ID=np.nan,
    set_mesh_texture=False,
)

# Show the projected and aggregated face predictions

In [None]:
# Color the mesh by the predicted per-face classes (ground faces are NaN).
# NOTE(review): MESH_VIS_KWARGS from the parameters cell is never passed to this
# call — confirm whether it was meant to control the colormap here.
mesh.vis(
    vis_scalars=predicted_face_classes,
)

# Use the mesh predictions to generate per-polygon labels

In [None]:
# Read the labeled polygons, then give each one a label derived from the
# predicted classes of the mesh faces that overlap it
polygons = gpd.read_file(LABELS_FILENAME)
predicted_polygon_labels = mesh.label_polygons(
    face_labels=predicted_face_classes,
    polygons=polygons,
)

# Compute prediction accuracy

In [None]:
# Ground-truth class for each polygon, read from the label column
ground_truth_labeling = polygons[LABEL_COLUMN_NAME]
# Pass the full class list so the confusion matrix has a consistent layout even
# when only a subset of classes is present at a given site.
# The last entry is dropped because no polygon is labeled as ground.
# NOTE(review): this assumes the ground class is the LAST value in
# IDS_TO_LABELS — confirm against the label mapping.
all_classes = list(IDS_TO_LABELS.values())[:-1]
# Confusion matrix and overall accuracy (compute_and_show_cf also displays it,
# per its name)
cf_matrix, _, accuracy = compute_and_show_cf(
    pred_labels=predicted_polygon_labels,
    gt_labels=ground_truth_labeling,
    labels=all_classes,
)
print(f"Accuracy was {accuracy}")
# More detailed metrics derived from the confusion matrix
comprehensive_metrics = compute_comprehensive_metrics(
    cf_matrix=cf_matrix,
    class_names=all_classes,
)
# PrettyPrinter.pprint writes to stdout itself and returns None. Wrapping it in
# print() emitted a stray "None" after the metrics, so call it directly.
pp = pprint.PrettyPrinter(indent=2)
print("Comprehensive metrics:")
pp.pprint(comprehensive_metrics)