# Flatten tiles

# Table of contents
1) [Flattening a tile](#flattening-a-tile)
2) [Visualize the removed points](#visualize-the-removed-points)

### Dependencies and general utils

In [None]:
# pip install scipy

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import open3d as o3d
import laspy
# import pdal
import json
import scipy
import copy
import pickle
from tqdm import tqdm
from scipy.spatial import cKDTree

### Flattening a tile

#### Example

In [None]:
def remove_duplicates(laz_file):
    """Drop duplicated points from ``laz_file`` in place.

    The x/y/z coordinates are rounded to 2 decimals and a KD-tree is used to
    list every pair of points lying within 1e-2 of each other. For each such
    pair, the point at the second index of the pair is discarded.

    Args:
        laz_file: Object exposing ``x``, ``y``, ``z`` and an indexable
            ``points`` attribute (here: the result of ``laspy.read``).
            Its ``points`` attribute is reassigned to the filtered record.
    """
    # One row per point, rounded so that near-identical coordinates coincide
    xyz = np.round(np.vstack((laz_file.x, laz_file.y, laz_file.z)), 2).T
    close_pairs = cKDTree(xyz).query_pairs(1e-2)

    # Keep everything by default, then flag the second member of each pair
    keep = [True] * len(xyz)
    for _, second in close_pairs:
        keep[second] = False

    # Apply the mask to the point record
    laz_file.points = laz_file.points[keep]

In [None]:
# Candidate input tiles; the later assignment is the one actually used.
tile_src = r"..\data\flattening_testing\color_grp_full_tile_128.laz"
tile_src = r"..\data\flattening_corrections\test\color_grp_full_tile_128.laz"

# Load the tile, drop duplicated points and report the count before/after.
laz = laspy.read(tile_src)
print(len(laz))
remove_duplicates(laz)
print(len(laz))

# NOTE: the de-duplicated tile is written back over the source file.
laz.write(tile_src)

# One row per point (columns x, y, z), plus two independent working copies
# that the interpolation and flattening cells modify.
points = np.vstack((laz.x, laz.y, laz.z)).T
points_flatten = points.copy()
points_interpolated = points.copy()

print(list(laz.point_format.dimension_names))
print(points.shape)

In [None]:
grid_size = 10
# Divide the tile into square XY cells of side `grid_size` and collect the
# points falling in each cell (the per-cell minimum Z is extracted later).
x_min, y_min = np.min(points[:, :2], axis=0)
x_max, y_max = np.max(points[:, :2], axis=0)

# Cell edges along each axis; the last edge is the tile maximum itself.
x_bins = np.append(np.arange(x_min, x_max, grid_size), x_max)
y_bins = np.append(np.arange(y_min, y_max, grid_size), y_max)

# N edges delimit N - 1 cells.
n_x_cells = x_bins.size - 1
n_y_cells = y_bins.size - 1

# grid[i][j] -> list of (x, y, z) tuples of the points inside cell (i, j)
grid = {i: {j: [] for j in range(n_y_cells)} for i in range(n_x_cells)}

# Cell index of every point. A point lying exactly on the upper edge of the
# tile (extent being a multiple of grid_size) would get index == number of
# cells, so indices are clamped to the last valid cell.
x_idx = np.clip(((points[:, 0] - x_min) // grid_size).astype(int), 0, n_x_cells - 1)
y_idx = np.clip(((points[:, 1] - y_min) // grid_size).astype(int), 0, n_y_cells - 1)

for (px, py, pz), xbin, ybin in tqdm(
    zip(points, x_idx.tolist(), y_idx.tolist()), total=len(points)
):
    grid[xbin][ybin].append((px, py, pz))


In [None]:
# Create grid_min: for every non-empty cell, record its lowest Z value and
# the XY position of the point where that minimum is reached.
grid_used = np.zeros((x_bins.size - 1, y_bins.size - 1))
lst_grid_min = []
lst_grid_min_pos = []
for x, column in grid.items():
    for y, cell_points in column.items():
        cell = np.array(cell_points)
        if cell.shape[0] == 0:
            # Empty cell: flagged as unused, contributes no minimum
            grid_used[x, y] = 0
            continue
        grid_used[x, y] = 1
        cell_z = cell[:, 2]
        lst_grid_min.append(np.min(cell_z))
        arg_min = np.argmin(cell_z)
        lst_grid_min_pos.append(cell[arg_min, 0:2])
print(grid_used)
arr_grid_min_pos = np.vstack(lst_grid_min_pos)
print(arr_grid_min_pos.shape)

In [None]:
# Interpolate the per-cell minimum heights at the XY position of every point
# to obtain a floor height under each point.
import scipy.interpolate  # plain `import scipy` may not load this submodule on older SciPy releases

points_xy = np.asarray(points)[:, 0:2]
# Points where no value can be interpolated are filled with NaN rather than
# a numeric sentinel, so a genuine interpolated height can never be mistaken
# for "missing".
interpolated_min_z = scipy.interpolate.griddata(
    arr_grid_min_pos,
    np.array(lst_grid_min),
    points_xy,
    method="cubic",
    fill_value=np.nan,
)

# A point is valid only if its floor height is a finite number; this also
# rejects any NaN produced by the interpolation itself.
mask_valid = np.isfinite(interpolated_min_z)

# Rebuilt from `points` (boolean indexing returns a copy) so that re-running
# this cell does not fail on an already-filtered array.
points_interpolated = points[mask_valid]
points_interpolated[:, 2] = interpolated_min_z[mask_valid]

print(f"Original number of points: {points.shape[0]}")
print(f"Interpollated number of points: {points_interpolated.shape[0]} ({int(points_interpolated.shape[0] / points.shape[0]*100)}%)")


In [None]:
# Export toggle; in this notebook it is only referenced by the commented-out
# `if save:` Open3D export snippets.
save = False

In [None]:
# save mask: pickle the boolean validity mask next to the source tile
mask_path = tile_src.split('.laz')[0] + f"_mask_{grid_size}m.pcl"
with open(mask_path, '+wb') as file:
    pickle.dump(mask_valid, file)

In [None]:
# save floor: write the valid points with Z replaced by the interpolated
# floor height to a new LAZ file next to the source tile.
floor_path = tile_src.split('.laz')[0] + f"_floor_{grid_size}m.laz"

# Every dimension of the source tile, restricted to the valid points
filtered_points = {}
for dim in laz.point_format.dimension_names:
    filtered_points[dim] = getattr(laz, dim)[mask_valid]

# Empty LAS container sharing the source point format and version
header = laspy.LasHeader(point_format=laz.header.point_format, version=laz.header.version)
new_las = laspy.LasData(header)

#   _Assign filtered and modified data
for dim, values in filtered_points.items():
    setattr(new_las, dim, values)
# Coordinates are set last so they take precedence over the copied dimensions
new_las.x = points_interpolated[:, 0]
new_las.y = points_interpolated[:, 1]
new_las.z = points_interpolated[:, 2]

print(len(new_las))
#   _Save new file
new_las.write(floor_path)
print("Saved file: ", floor_path)


# if save:
#     pcd.points = o3d.utility.Vector3dVector(points_interpolated)
#     o3d.io.write_point_cloud(tile_src.split('.pcd')[0] + f"_floor_{grid_size}m.pcd", pcd, write_ascii=True)

In [None]:
# Reload the floor tile from disk and report its point count
floor_tile_path = tile_src.split('.laz')[0] + f"_floor_{grid_size}m.laz"
test_laz = laspy.read(floor_tile_path)
print(len(test_laz))

In [None]:
# Flatten: keep the valid points and subtract the interpolated floor height
# from their Z, so heights become relative to the floor.
points_flatten = points_flatten[mask_valid]
points_flatten[:, 2] -= points_interpolated[:, 2]
# points_flatten[:,2] = np.clip(0, points_flatten[:,2] - points_interpolated[:,2], np.inf)

flatten_path = tile_src.split('.laz')[0] + f"_flatten_{grid_size}m.laz"

# Every dimension of the source tile, restricted to the valid points
filtered_points = {}
for dim in laz.point_format.dimension_names:
    filtered_points[dim] = getattr(laz, dim)[mask_valid]

# Empty LAS container sharing the source point format and version
header = laspy.LasHeader(point_format=laz.header.point_format, version=laz.header.version)
new_las = laspy.LasData(header)

#   _Assign filtered and modified data
for dim, values in filtered_points.items():
    setattr(new_las, dim, values)
# Coordinates are set last so they take precedence over the copied dimensions
new_las.x = points_flatten[:, 0]
new_las.y = points_flatten[:, 1]
new_las.z = points_flatten[:, 2]

#   _Save new file
new_las.write(flatten_path)
print("Saved file: ", flatten_path)

# if save:
#     pcd.points = o3d.utility.Vector3dVector(points_flatten)
    
#     o3d.io.write_point_cloud(tile_src.split('.pcd')[0] + f"_flatten_{grid_size}m.pcd", pcd, write_ascii=True)


### Visualize the removed points

In [None]:
# Find removed points: rows of `points` rejected by the validity mask
# (boolean indexing yields a new array, `points` itself is left untouched)
removed_points = points[~mask_valid]

In [None]:
# Visualize removed points: wrap them in an Open3D point cloud and hand it
# to the Open3D viewer.
pcd_new = o3d.geometry.PointCloud()
pcd_new.points = o3d.utility.Vector3dVector(removed_points)
o3d.visualization.draw_geometries([pcd_new])

In [None]:
# Load a flattened tile and print its point count.
# NOTE(review): hard-coded absolute path to a `flatten` sub-folder; it is not
# derived from `tile_src` — confirm it points at the intended file.
src_flatten_tile = r"D:\PDM_repo\Github\PDM\data\flattening_corrections\test\flatten\color_grp_full_tile_128_flatten_10m.laz"
las = laspy.read(src_flatten_tile)
print(len(las))

In [None]:
# Check for NaNs in the XYZ coordinates
xyz = las.xyz
has_nan_xyz = np.isnan(xyz).any()
print("Contains NaN in XYZ:", has_nan_xyz)

# Then look at each dimension of the point format; only dimensions that come
# back as floating-point numpy arrays are tested for NaN.
for dim in las.point_format.dimension_names:
    print("looking at dim: ", dim)
    data = getattr(las, dim)
    is_float_array = isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.floating)
    if is_float_array and np.isnan(data).any():
        print(f"NaNs found in dimension: {dim}")