# Visualization of training samples

# Table of contents
1) [Path to files to process](#path-to-files-to-process)
2) [Visualization of tiles from above](#visualization-of-tiles-from-above)
3) [Analysis of training pickles](#analysis-of-training-pickles)
4) [Show cylinders on full tiles](#show-cylinders-on-full-tiles)
5) [Extract statistics](#extract-statistics)

### Dependencies and general utils

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
import seaborn as sns
import open3d as o3d
import laspy
import pickle
import pdal
import json
from tqdm import tqdm

### Path to files to process

In [None]:
# Directory scanned for the .laz tiles (absolute, machine-specific path: adapt before running)
src_tiles = r"D:\PDM_repo\Github\PDM\results\trainings\20250427_140314_test\pseudo_labels"
# Directory scanned for the training-sample pickle files (absolute, machine-specific path)
src_pickles = r"D:\PDM_repo\Github\PDM\data\dataset_tiles_100m\training_samples\loop2"

### Visualization of tiles from above

In [None]:
# Load one example tile and print the names of the point dimensions it carries,
# i.e. the values that can be passed as `color_by` to the plotting helpers.
tile_src = "../data/training_sample_visualization/color_grp_full_tile_311_out_gt.laz"
tile = laspy.read(tile_src)
print(list(tile.point_format.dimension_names))

#### utils

In [None]:
def plot_las_top_view(las_path, point_size=0.5, color_by='z'):
    """
    Load a LAS/LAZ file and show a top-down scatter plot (XY plane) colored by one point dimension.

    Parameters:
    - las_path: str, path to the .las or .laz file
    - point_size: float, size of each point in the scatter plot
    - color_by: str, name of the point dimension used to color the points
      (e.g. 'z', 'intensity', 'classification')

    Returns:
    - None. If `color_by` cannot be read from the file, the error is printed
      and nothing is plotted.
    """

    # Load LAS file
    las = laspy.read(las_path)

    # Get coordinates
    x = las.x
    y = las.y

    # Get values for coloring. Stop here when the dimension is unavailable:
    # continuing would fail on an undefined color array.
    try:
        c = las[color_by]
    except Exception as e:
        print("Not working!\n", e)
        return

    # Plot
    plt.figure(figsize=(10, 10))
    plt.scatter(x, y, c=c, s=point_size, cmap='viridis', marker='.')
    plt.xlabel("X")
    plt.ylabel("Y")
    plt.title(f"Top-Down View (colored by {color_by})")
    plt.colorbar(label=color_by)
    plt.axis("equal")
    plt.tight_layout()
    plt.show()

In [None]:
# Top-down view of the example tile, colored by its 'gt_semantic_segmentation' dimension
plot_las_top_view(tile_src, point_size=0.2, color_by='gt_semantic_segmentation')

### Analysis of training pickles

In [None]:
# Collect the center (first element) of every sample stored in the training pickles
# NOTE: pickle.load can execute arbitrary code; only point src_pickles at trusted files.
files = [name for name in os.listdir(src_pickles) if name.endswith('pickle')]
list_centers = []
for file in files:
    pickle_path = os.path.join(src_pickles, file)
    with open(pickle_path, 'rb') as in_file:
        training_pickle = pickle.load(in_file)
    list_centers.extend(tuple(sample[0]) for sample in training_pickle)

n_total = len(list_centers)
print("Total number of centers: ", n_total)
# Converting to a set drops the centers that occur more than once
list_centers = set(list_centers)
print("Without duplicates: ", len(list_centers))

In [None]:
# associate centers to tiles:
tiles = [x for x in os.listdir(src_tiles) if x.endswith('laz')]
centers_to_tiles = {tile_src: [] for tile_src in tiles}
for tile_src in tiles:
    tile = laspy.read(os.path.join(src_tiles, tile_src))
    x_min = tile.x.min()
    x_max = tile.x.max()
    y_min = tile.y.min()
    y_max = tile.y.max()
    for center in list_centers:
        if center[0] > x_min and center[0] < x_max and center[1] > y_min and center[1] < y_max:
            centers_to_tiles[tile_src].append(center)

lst_full = {x:y for x,y in centers_to_tiles.items() if len(y) > 0}
lst_empty = {x:y for x,y in centers_to_tiles.items() if len(y) == 0}

print("Full tiles: ")
for x,y in lst_full.items():
    print(f"\t{x} ({len(y)} samples)")
print("Empty tiles: ")
for x in lst_empty.keys():
    print("\t", x)

### Show cylinders on full tiles

#### utils

In [None]:
def plot_las_plus_centers(las_path, lst_centers, title="", point_size=0.5, radius=8, color_by='z'):
    """
    Load a LAS/LAZ file, show a top-down scatter plot (XY plane) colored by one
    point dimension, and outline a red circle around each given center.

    Parameters:
    - las_path: str, path to the .las or .laz file
    - lst_centers: iterable of (x, y) pairs, centers of the circles to draw
    - title: str, plot title; when empty, the file name of `las_path` is used
    - point_size: float, size of each point in the scatter plot
    - radius: float, radius of the circles, in the same units as the X/Y coordinates
    - color_by: str, name of the point dimension used to color the points
      (e.g. 'z', 'intensity', 'classification')

    Returns:
    - None. If `color_by` cannot be read from the file, the error is printed
      and nothing is plotted.
    """

    # Load LAS file
    las = laspy.read(las_path)

    # Get coordinates
    x = las.x
    y = las.y

    # Get values for coloring. Stop here when the dimension is unavailable:
    # continuing would fail on an undefined color array.
    try:
        c = las[color_by]
    except Exception as e:
        print("Not working!\n", e)
        return

    # Plot
    fig, ax = plt.subplots(figsize=(10, 10))
    sc = ax.scatter(x, y, c=c, s=point_size, cmap='viridis', marker='.')
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    if title == "":
        ax.set_title(f"{os.path.basename(las_path)} (colored by {color_by})")
    else:
        ax.set_title(f"{title} (colored by {color_by})")

    fig.colorbar(sc, ax=ax, label=color_by)
    # Equal aspect so the circles are not drawn as ellipses
    ax.set_aspect('equal')
    for (px, py) in lst_centers:
        circle = Circle((px, py), radius=radius, edgecolor='red', facecolor='none', linewidth=2)
        ax.add_patch(circle)
    fig.tight_layout()
    plt.show()

#### show centers

In [None]:
# Draw every tile with the circles of its associated centers
# (pass title=... to replace the default file-name title)
for tile_src in tiles:
    plot_las_plus_centers(
        las_path=os.path.join(src_tiles, tile_src),
        lst_centers=centers_to_tiles[tile_src],
        point_size=0.2,
        color_by='classification',
    )

### Extract statistics

#### Utils

In [None]:
def extract_points_in_circle(center, las, lst_features, radius=8):
    """
    Extract the values of the requested point dimensions for all points lying
    inside a circle in the XY plane.

    Parameters:
    - center: (x, y) pair, center of the circle
    - las: laspy.LasData (or any object exposing `.x`, `.y` and `las[name]`), the loaded point cloud
    - lst_features: list of str, names of the point dimensions to extract
    - radius: float, radius of the circle; points exactly on the circle are included

    Returns:
    - dict mapping each name in `lst_features` to the values of that dimension
      for the points inside the circle
    """
    cx, cy = center
    # Horizontal (XY) distance of every point to the center; Z is ignored
    dist = np.sqrt((las.x - cx)**2 + (las.y - cy)**2)
    mask = dist <= radius
    return {feature: las[feature][mask] for feature in lst_features}

#### Extract

In [None]:
# prepare dict of stats:
list_of_stats = ["semantic_frac_ground", "semantic_frac_tree", "frac_grey", "frac_ground", "frac_tree"]
stats_tot = {x: [] for x in list_of_stats}
# The first tile is used as reference for the list of point dimensions to extract
tile_test = laspy.read(os.path.join(src_tiles, tiles[0]))
lst_features = list(tile_test.point_format.dimension_names)
center_points = {x: [] for x in lst_features}
print("Centers:\n", list_centers)

# Group the centers by tile so that each tile is read from disk only once.
# A center contained in several tiles is assigned to the first one only.
centers_per_tile = {}
assigned_centers = set()
for tile_name, lst_centers in centers_to_tiles.items():
    for center in lst_centers:
        if center not in assigned_centers:
            assigned_centers.add(center)
            centers_per_tile.setdefault(tile_name, []).append(center)

# Report the centers without a tile; they are skipped instead of aborting the extraction
for center in list_centers:
    if center not in assigned_centers:
        print("DID NOT FIND A CORRESPONDING TILE TO ", center)

# loop on tiles, then on the centers of each tile:
with tqdm(total=len(assigned_centers), desc="Extracting data") as progress:
    for tile_name, lst_centers in centers_per_tile.items():
        tile = laspy.read(os.path.join(src_tiles, tile_name))
        for center in lst_centers:
            # find matching points
            results = extract_points_in_circle(center, tile, lst_features)
            for feature, vals in results.items():
                center_points[feature].append(vals)
            progress.update(1)

# Computing stats: fraction of points of each class (0, 1, 4) within every sample.
# NOTE(review): the `if <label> in x` filter leaves out samples that do not contain
# the class at all (instead of counting them as 0.0); it also skips empty samples,
# which avoids a division by zero. Confirm this is the intended statistic.
pseudo_labels_semantic = center_points['classification']
stats_tot['frac_grey'] = [len(x[x == 0])/len(x) for x in pseudo_labels_semantic if 0 in x]
stats_tot['frac_ground'] = [len(x[x == 1])/len(x) for x in pseudo_labels_semantic if 1 in x]
stats_tot['frac_tree'] = [len(x[x == 4])/len(x) for x in pseudo_labels_semantic if 4 in x]


    

    



In [None]:
# Showing distribution: one row per metric, histogram on the left, boxplot on the right
fig, axs = plt.subplots(3, 2, figsize=(12, 15))
axs = axs.flatten()
lst_metrics = ['frac_grey', 'frac_ground', 'frac_tree']
for hist_ax, box_ax, metric in zip(axs[0::2], axs[1::2], lst_metrics):
    values = stats_tot[metric]
    sns.histplot(values, bins=10, binrange=(0.0, 1.0), ax=hist_ax)
    sns.boxplot(values, ax=box_ax)
    hist_ax.set_title(f"Histogram of {metric}")
    box_ax.set_title(f"Boxplot of {metric}")