In [1]:
import numpy as np
import os
import open3d as o3d

class Run_alternately_error(Exception):
    """Raised when Refiner.refine_tree / Refiner.tidy_up are called out of order."""

    _MESSAGE = "Refinement and tidy-up must be run alternately."

    def __str__(self):
        return self._MESSAGE


class Refinement_error(Exception):
    """Raised by Refiner.tidy_up when the point count does not add up after editing."""

    _MESSAGE = (
        "The number of points in all points was different before refinement. "
        "Something went wrong during refinement. "
        "Not overwriting all_points.npy..."
    )

    def __str__(self):
        return self._MESSAGE


class Refiner:
    """Alternating refine / tidy-up workflow for the per-point tree labels of one plot.

    ``<plot_path>/all_points.npy`` holds one row per point. Columns 0-2 are
    written out as x/y/z coordinates; the last column is the label, which is
    either a tree id or ``NON_TREE_LABEL`` for points that belong to no tree.

    * ``refine_tree`` cuts a target tree, the trees around it and the nearby
      non-tree points out of ``all_points``. It writes each group to
      ``<plot_path>/refine/<key>.ply`` so the groups can be edited outside
      Python.
    * ``tidy_up`` reads those files back, plus the optional ``new.ply`` and
      ``remove.ply``. It checks that no point was lost or duplicated, then
      saves ``all_points.npy``, ``removed_points.npy`` and ``todo.npy``.

    The two methods must be called alternately; otherwise
    ``Run_alternately_error`` is raised.
    """

    # Label of points that are not assigned to any tree.
    NON_TREE_LABEL = 9999

    def __init__(self, plot_path):
        """Load all_points.npy, removed_points.npy and todo.npy from `plot_path`."""
        self.plot_path = plot_path
        self.all_points = np.load(f"{plot_path}/all_points.npy")
        # Number of points that must be accounted for at the next tidy_up:
        # len(all_points) + newly removed points + refined points.
        # Re-synced after every successful tidy_up.
        self.all_points_len = len(self.all_points)
        self.removed_points = np.load(f"{plot_path}/removed_points.npy")
        # Maps the key of refine/<key>.ply to the label of the points written there.
        # An empty dict means "no refinement in progress".
        self.mapping = dict()
        # Tree labels still to refine; tidy_up drops the tree it just refined.
        self.todo = list(np.load(f"{plot_path}/todo.npy"))
        self.color_palette = np.array([[128, 0, 0],
                                       [170, 110, 40],
                                       [128, 128, 0],
                                       [0, 128, 128],
                                       [0, 0, 128],
                                       [0, 0, 0],
                                       [230, 25, 75],
                                       [245, 130, 48],
                                       [255, 225, 25],
                                       [210, 245, 60],
                                       [60, 180, 75],
                                       [70, 240, 240],
                                       [0, 130, 200],
                                       [145, 30, 180],
                                       [240, 50, 230],
                                       [250, 190, 212],
                                       [255, 215, 180],
                                       [255, 250, 200],
                                       [170, 255, 195],
                                       [220, 190, 255],
                                       [255, 255, 255]])

        # Reorder the colours. This drops the last entry (white), leaving 20
        # colours, which are then scaled from 0-255 to the [0, 1] range
        # Open3D expects.
        self.color_palette = self.color_palette[[8, 6, 10, 12, 5, 0, 2, 3, 1, 4, 7, 9, 11, 13, 15, 14, 16, 17, 18, 19], :]
        self.color_palette = self.color_palette / 255

    def refine_tree(self, tree_number, margin=1):
        """Cut a tree and its surroundings out of ``all_points`` into ply files.

        Writes the following files under ``refine/``:

        * ``0.ply``: non-tree points inside the neighbourhood box.
        * ``1.ply``: the whole target tree.
        * ``2.ply``, ``3.ply``, ...: every other tree with at least one point
          inside the box. The whole tree is cut out, not only the part inside
          the box.

        The cut points are removed from ``self.all_points`` until ``tidy_up``
        merges them back.

        Args:
            tree_number: label of the tree to refine.
            margin: amount by which the target tree's x/y bounding box is grown
                on every side to define the neighbourhood. Defaults to 1, the
                previously hard-coded value.

        Raises:
            Run_alternately_error: if the previous refinement was not tidied up.
            ValueError: if no point carries the label ``tree_number``.
            OSError: if a ply file cannot be written. Files written before the
                failure stay registered in ``mapping``, so ``tidy_up`` can
                restore them.
        """
        if len(self.mapping) != 0:
            raise Run_alternately_error

        tree_indices = self.all_points[:, -1] == tree_number
        if not tree_indices.any():
            raise ValueError(f"No points with label {tree_number} in all_points.")
        tree = self.all_points[tree_indices]

        min_x, max_x = np.min(tree[:, 0]), np.max(tree[:, 0])
        min_y, max_y = np.min(tree[:, 1]), np.max(tree[:, 1])

        # Axis-aligned x/y box around the target tree, grown by `margin`.
        chunk_index = (
            (self.all_points[:, 0] >= min_x - margin)
            & (self.all_points[:, 0] <= max_x + margin)
            & (self.all_points[:, 1] >= min_y - margin)
            & (self.all_points[:, 1] <= max_y + margin)
        )
        trees_in_chunk = np.unique(self.all_points[chunk_index, -1]).astype("int")
        trees_in_chunk = trees_in_chunk[
            (trees_in_chunk != tree_number) & (trees_in_chunk != self.NON_TREE_LABEL)
        ]

        os.makedirs(f"{self.plot_path}/refine", exist_ok=True)

        # Each mask is evaluated against all_points *after* the previous cut.
        non_tree_chunk_index = chunk_index & (self.all_points[:, -1] == self.NON_TREE_LABEL)
        self._cut_out_to_ply(non_tree_chunk_index, 0, self.NON_TREE_LABEL)
        self._cut_out_to_ply(self.all_points[:, -1] == tree_number, 1, tree_number)
        for key, label in enumerate(trees_in_chunk, start=2):
            self._cut_out_to_ply(self.all_points[:, -1] == label, key, label)

    def _cut_out_to_ply(self, mask, key, label):
        """Move the rows of all_points selected by `mask` into refine/<key>.ply.

        The file is written before the rows are dropped from all_points. A
        failed write therefore leaves those points in place.
        """
        selected = self.all_points[mask]

        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(selected[:, :-1])
        # Cycle through the palette. Indexing past its end used to raise
        # IndexError when there were more than 18 neighbouring trees.
        color = self.color_palette[key % len(self.color_palette)]
        pcd.colors = o3d.utility.Vector3dVector(np.ones((len(selected), 3)) * color)

        path = f"{self.plot_path}/refine/{key}.ply"
        if not o3d.io.write_point_cloud(path, pcd):
            raise OSError(f"Could not write {path}")

        self.all_points = self.all_points[~mask]
        self.mapping[key] = label

    @staticmethod
    def _read_ply_xyz(path):
        """Return the points of a ply file as an array rounded to 2 decimals."""
        return np.round(np.asarray(o3d.io.read_point_cloud(path).points), 2)

    @staticmethod
    def _with_label(xyz, label):
        """Append a constant label column to an array of xyz rows."""
        return np.hstack([xyz, np.full((len(xyz), 1), label, dtype=float)])

    @classmethod
    def _next_tree_label(cls, labels):
        """Return a tree label that is larger than every tree label in `labels`."""
        tree_labels = labels[labels != cls.NON_TREE_LABEL]
        new_label = int(tree_labels.max()) + 1 if len(tree_labels) else 0
        if new_label == cls.NON_TREE_LABEL:
            new_label += 1
        return new_label

    def tidy_up(self):
        """Merge the edited refine/*.ply files back into all_points and persist everything.

        Reads these files from ``refine/``; coordinates are rounded to 2
        decimals:

        * ``<key>.ply`` for every key of ``mapping``. Missing files are skipped.
        * ``new.ply``, if present. It is added as a brand-new tree with an
          unused label.
        * ``remove.ply``, if present. Its points are appended to
          removed_points.

        Nothing is saved unless all points cut out by ``refine_tree`` are
        accounted for.

        Raises:
            Run_alternately_error: if refine_tree was not run before.
            Refinement_error: if points were lost or duplicated while editing.
                all_points.npy is then left untouched.
        """
        if len(self.mapping) == 0:
            raise Run_alternately_error

        refine_dir = f"{self.plot_path}/refine"

        # Edited groups get back the label they were cut out with.
        refined_parts = [np.empty((0, 4))]
        for key, label in self.mapping.items():
            path = f"{refine_dir}/{key}.ply"
            if os.path.exists(path):
                refined_parts.append(self._with_label(self._read_ply_xyz(path), label))
        refined_points = np.vstack(refined_parts)

        # A newly created tree needs a label no other tree uses.
        # Previously its point count was used, which could collide with an existing tree.
        new_path = f"{refine_dir}/new.ply"
        if os.path.exists(new_path):
            used_labels = np.concatenate([self.all_points[:, -1], refined_points[:, -1]])
            new_tree = self._with_label(self._read_ply_xyz(new_path), self._next_tree_label(used_labels))
            refined_points = np.vstack([refined_points, new_tree])

        remove_path = f"{refine_dir}/remove.ply"
        has_removed = os.path.exists(remove_path)
        removed_points = self._read_ply_xyz(remove_path) if has_removed else np.empty((0))

        if len(self.all_points) + len(removed_points) + len(refined_points) != self.all_points_len:
            raise Refinement_error

        self.all_points = np.vstack([self.all_points, refined_points])
        np.save(f"{self.plot_path}/all_points.npy", self.all_points)
        # Removed points are gone for good, so the next cycle must compare
        # against the new total. A stale value made every later tidy_up fail.
        self.all_points_len = len(self.all_points)

        if has_removed:
            self.removed_points = np.vstack([self.removed_points, removed_points])
            np.save(f"{self.plot_path}/removed_points.npy", self.removed_points)
            os.remove(remove_path)

        for key in self.mapping:
            path = f"{refine_dir}/{key}.ply"
            if os.path.exists(path):
                os.remove(path)

        if os.path.exists(new_path):
            os.remove(new_path)

        # Mark the target tree as done. Use .get because key 1 is absent if
        # refine_tree failed before writing it.
        refined_tree = self.mapping.get(1)
        if refined_tree in self.todo:
            self.todo.remove(refined_tree)

        np.save(f"{self.plot_path}/todo.npy", np.array(self.todo))

        self.mapping = dict()


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
# Rebuild refine_data/all_points.npy, the file the Refiner loads, from its two stored parts.
# NOTE(review): re-running this cell after a tidy_up overwrites the refined
# all_points.npy with whatever all_points1/2.npy contain. Confirm the parts
# are updated elsewhere, or that resetting is intended.
all_points1, all_points2 = (np.load(f"refine_data/all_points{part}.npy") for part in (1, 2))
all_points = np.vstack([all_points1, all_points2])
np.save("refine_data/all_points.npy", all_points)

In [4]:
# define paths
# Resolved against the notebook's working directory, like the "refine_data/..."
# paths used by the other cells. This replaces a hard-coded, per-user absolute
# path that only worked on one machine.
plot_path = os.path.abspath("refine_data")

In [5]:
# define refiner object
# Loads all_points.npy, removed_points.npy and todo.npy from plot_path.
refiner = Refiner(plot_path=plot_path)

In [6]:
# First entry of the todo list: a tree label that has not been tidied up after refinement yet.
refiner.todo[0]

12

In [7]:
# Cut tree 11 and its surroundings out into refine/*.ply for editing.
# NOTE(review): the previous cell shows todo[0] == 12, yet tree 11 is refined here. Confirm this is intended.
refiner.refine_tree(11)

In [8]:
# Merge the edited refine/*.ply files back and save all_points, removed_points and todo.
refiner.tidy_up()

In [9]:
# triple check that points have correct length
# Every point must be either still in all_points or in removed_points.npy.
# Read removed_points.npy from plot_path, the directory the Refiner writes
# it to, not from the separate tree_learning_data repository.
ORIGINAL_N_POINTS = 88240807  # len(all_points) before any refinement

# current length
current_n_points = len(refiner.all_points) + len(np.load(f"{plot_path}/removed_points.npy"))
print(current_n_points)
# original length
print(ORIGINAL_N_POINTS)
assert current_n_points == ORIGINAL_N_POINTS, "points were lost or duplicated during refinement"

88240807
88240807


In [10]:
# before pushing
# Delete the merged file, presumably so only all_points1/2.npy get committed (TODO confirm).
# NOTE(review): refinements saved only in all_points.npy are lost here unless
# they were written back into the parts first.
os.remove(os.path.join("refine_data", "all_points.npy"))