# Background
Forests are stressed globally by human activities, whether through logging, land-use change for agriculture and industry, or pollution and climate change. Forests store carbon, sustain biodiversity, and provide clean air and water. We will continue to rely on these ecosystems, but monitoring the state of forest resources is challenging given their extent (about a third of the global land area) and their complex, dynamic nature.
Remote-sensing techniques, including light detection and ranging (LiDAR), can increase the volume and speed of spatial data capture in forests. Scanning forest structure with LiDAR produces point clouds from which further information, such as the size and shape of a tree, can be extracted. Before any information about an individual tree can be derived, however, the points belonging to each tree must be separated from the rest of the point cloud.
Existing tree segmentation algorithms tend to split the data into voxels, calculate eigenvectors (of the local point covariance) for the points in each voxel, and use this information to find surfaces in the point cloud that can be combined to form stems or branches. While exact techniques vary, these methods generally require sufficiently dense point clouds. Our starting insight was that when a point cloud is flattened onto its x and y dimensions, vertical tree stems produce dense clusters. Building on this, we produced an algorithm that segments trees from single LiDAR scans by applying a density threshold to detect stems and then clustering points from the bottom up.
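For context, the kind of per-voxel eigen-analysis described above can be sketched as follows. This is a generic illustration rather than any particular published method: the 10 cm voxel size, the three-point minimum, and the linearity/planarity/scattering ratios computed from the sorted covariance eigenvalues are our assumptions, and `voxel_eigen_features` is a hypothetical name.

```python
import numpy as np


def voxel_eigen_features(points, voxel_size=0.1, min_points=3):
    """Per-voxel shape descriptors from the eigenvalues of the local covariance.

    Returns {voxel index: (linearity, planarity, scattering)}; `voxel_size`
    is in the same units as `points` (metres in the example below).
    """
    points = np.asarray(points, dtype=float)
    keys = np.floor(points / voxel_size).astype(np.int64)
    # group points by voxel; `inverse` maps every point to its row in `voxels`
    voxels, inverse = np.unique(keys, axis=0, return_inverse=True)
    inverse = inverse.ravel()
    features = {}
    for v, key in enumerate(voxels):
        cell = points[inverse == v]
        if len(cell) < min_points:
            continue
        # eigvalsh returns ascending eigenvalues, so l1 >= l2 >= l3 after unpacking
        l3, l2, l1 = np.linalg.eigvalsh(np.cov(cell.T))
        if l1 <= 0:
            continue  # all points identical: no shape information
        features[tuple(int(k) for k in key)] = ((l1 - l2) / l1, (l2 - l3) / l1, l3 / l1)
    return features


# a stem-like line of points and a leaf-like flat patch, each inside one 10 cm voxel
line = np.array([[0.01 * i, 0.05, 0.05] for i in range(10)])
patch = np.array([[0.01 * i, 0.01 * j, 0.05] for i in range(1, 6) for j in range(1, 6)])
print(voxel_eigen_features(line))
print(voxel_eigen_features(patch))
```

A line of points yields linearity close to 1 and a flat patch yields planarity close to 1; such methods need enough points per voxel for these ratios to be meaningful, which is why sparse single scans are difficult for them.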

# Equipment

To gather the data we used an Ouster OS1-64 portable LiDAR scanner, which captures points in 64 vertical channels covering a 45° vertical and 360° horizontal field of view. Its maximum range is 120 m and its precision is close to ±1 mm.
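Because each scan is taken from a single position, point spacing grows with distance from the sensor. As a rough illustration, the vertical gap between neighbouring beam rings can be estimated from the specifications above, assuming (our assumption; real beam layouts may differ) that the 64 channels are spread evenly across the 45° field of view. `vertical_point_spacing` is a hypothetical helper.

```python
import math


def vertical_point_spacing(range_m, fov_deg=45.0, channels=64):
    """Approximate vertical gap (m) between adjacent beam rings at a given range."""
    # assumes the channels are spaced evenly across the vertical field of view
    step_deg = fov_deg / (channels - 1)
    return range_m * math.tan(math.radians(step_deg))


for r in (2, 5, 10, 20):
    print(f"{r:>2} m: {vertical_point_spacing(r) * 100:.1f} cm between rings")
```

Under that assumption, a stem 10 m from the sensor is sampled only about every 12 cm vertically.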
The Ouster SDK provides code for connecting the sensor to a laptop and recording scans.
To scan outdoors we needed an external power source and a tripod, which raised the sensor to a height of 1.3 m.
We captured 10 scans of 6 plots at 2 locations and segmented 7 of these scans with the algorithm.

# Algorithm

Once the data is loaded as a pandas DataFrame and an area within the point cloud has been selected for segmentation, the first step is to remove the terrain (ground) points. The ground-removal loop steps through the plot one square meter at a time, finds the minimum Z value in that square, and removes the 5 cm slice of points directly above that minimum. It then moves to the next square and repeats until the ground points have been removed from the whole plot. The results are displayed, and the data is split into stems (points up to 2 m above the lowest remaining point) and crowns (points more than 2 m above it).
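The square-by-square ground filter can also be written without nested loops. The sketch below is our assumption of an equivalent vectorised form, not the notebook's code: it bins points into 1 m cells with floor division (so negative coordinates land in the correct cell), takes each cell's minimum Z with a pandas `groupby(...).transform("min")`, and keeps points at least 5 cm above it. Column names default to those in the notebook's CSV; `remove_ground` is a hypothetical name.

```python
import numpy as np
import pandas as pd


def remove_ground(df, cell_mm=1000, slice_mm=50, x=" X (mm)", y=" Y (mm)", z=" Z (mm)"):
    """Drop points within `slice_mm` of the lowest point in each `cell_mm` square."""
    # index of the square each point falls in (floor keeps negative coordinates correct)
    cell_x = np.floor(df[x] / cell_mm).astype(int)
    cell_y = np.floor(df[y] / cell_mm).astype(int)
    # lowest Z of each square, broadcast back to every point in that square
    cell_min_z = df.groupby([cell_x, cell_y])[z].transform("min")
    return df[df[z] >= cell_min_z + slice_mm]


pts = pd.DataFrame({" X (mm)": [100, 200, 300, 1500, 1600],
                    " Y (mm)": [100, 100, 100, 100, 100],
                    " Z (mm)": [0, 30, 900, 400, 480]})
print(remove_ground(pts))
```

One grouped pass replaces a full DataFrame scan per square, which matters for large plots.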
The next step is to assign a density value to every point in the stem point cloud. Using the scipy.spatial KDTree class, the neighbours of each point within a 25 cm radius in the XY plane are found, and the number of neighbours is stored for each point as its density in the DataFrame. The user is then asked for a minimum density threshold, and points below it are removed; the threshold can be adjusted until the stems appear as distinct clusters. Finally, KMeans is applied with a user-specified number of clusters (one per stem) to segment the stems.
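The density-then-cluster step can be condensed as below. This is a sketch under assumptions, not the notebook's exact code: it uses `cKDTree.query_ball_point(..., return_length=True)` (available in recent SciPy) to count neighbours without materialising neighbour lists, the synthetic plot is made-up test data, and `xy_density` / `cluster_stems` are hypothetical names.

```python
import numpy as np
from scipy.spatial import cKDTree
from sklearn.cluster import KMeans


def xy_density(xy, radius=250.0):
    """Neighbour count (including the point itself) within `radius` mm in the XY plane."""
    tree = cKDTree(xy)
    return tree.query_ball_point(xy, r=radius, return_length=True)


def cluster_stems(xy, min_density, n_stems, radius=250.0):
    """Keep points at or above `min_density` and split them into `n_stems` groups."""
    density = xy_density(xy, radius)
    dense = xy[density >= min_density]
    labels = KMeans(n_clusters=n_stems, n_init=10, random_state=0).fit_predict(dense)
    return dense, labels


# synthetic plot (mm): two stems of 200 points each plus 100 scattered understory points
rng = np.random.default_rng(0)
xy = np.vstack([
    rng.normal([0, 0], 50, size=(200, 2)),
    rng.normal([3000, 1000], 50, size=(200, 2)),
    rng.uniform(-2000, 5000, size=(100, 2)),
])
dense, labels = cluster_stems(xy, min_density=50, n_stems=2)
print(len(dense), np.bincount(labels))
```

The scattered points have only one or two neighbours within 25 cm, so the threshold removes them and KMeans only has to separate the two dense stems.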
Some small helper functions were defined to implement the clustering algorithm and the later tree metric calculations. For a given point, the closest_node function finds the nearest point in a separate array and returns its index and distance; it is used later to assign canopy points to the appropriate stems. The furthest_node function finds the greatest distance between any two points in an array and is used to calculate DBH and crown width for individual trees.
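The metric helpers can be sketched as follows. Two choices here are our assumptions rather than the notebook's code: the greatest pairwise distance is computed only over convex-hull vertices (the farthest pair always lies on the hull, which avoids an n × n distance matrix for large crowns), and the 2 cm DBH slice is centred directly on 1.3 m above the tree's lowest point. `max_pairwise_distance` and `tree_metrics` are hypothetical names; units are mm.

```python
import numpy as np
from scipy.spatial import ConvexHull, QhullError
from scipy.spatial.distance import pdist


def max_pairwise_distance(points):
    """Largest Euclidean distance between any two points (0.0 if fewer than two)."""
    points = np.asarray(points, dtype=float)
    if len(points) < 2:
        return 0.0
    if len(points) >= 3:
        try:
            # the farthest pair always lies on the convex hull, so only hull vertices are compared
            points = points[ConvexHull(points).vertices]
        except QhullError:
            pass  # degenerate input (e.g. collinear points): compare all points
    return float(pdist(points).max())


def tree_metrics(xyz, breast_height=1300.0, half_slice=10.0):
    """Height, crown width and approximate DBH (all in mm) for one segmented tree."""
    xyz = np.asarray(xyz, dtype=float)
    z = xyz[:, 2]
    height = z.max() - z.min()
    crown_width = max_pairwise_distance(xyz[:, :2])
    # 2 cm thick slice centred 1.3 m above the tree's lowest point
    in_slice = np.abs(z - (z.min() + breast_height)) <= half_slice
    dbh = max_pairwise_distance(xyz[in_slice, :2])
    return height, crown_width, dbh


# synthetic tree: a 30 cm wide stem from 0 to 3 m and a 4 m wide crown ring at 4 m
angles = np.linspace(0, 2 * np.pi, 36, endpoint=False)
stem = np.array([[150 * np.cos(a), 150 * np.sin(a), h]
                 for h in range(0, 3000, 10) for a in angles])
crown = np.array([[2000 * np.cos(a), 2000 * np.sin(a), 4000] for a in angles])
height, crown_width, dbh = tree_metrics(np.vstack([stem, crown]))
print(height, round(crown_width), round(dbh))
```

With a single scan only one side of a stem is usually visible, so the widest extent of the slice approximates, rather than measures, the true diameter.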
The final step is to assign the crown points to the appropriate stems. This is done from the bottom up: the loop takes a 5 cm thick horizontal slice starting at 2 m height and, using the closest_node function, gives each point in the slice the cluster label of the nearest already-clustered point (a stem point or a crown point from a lower slice). If that nearest point is more than 1 m away, the point is not clustered. The next slice, from 2.05 to 2.10 m, is then processed the same way, and this repeats until the maximum Z value is reached and every point has been clustered or discarded. Finally, any clustered object with a vertical extent of less than 2 m is removed automatically and the results are displayed.
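The bottom-up pass can be sketched with one k-d tree query per slice instead of one distance computation per point. This is a hedged sketch, not the notebook's implementation: `grow_crowns` is a hypothetical name, coordinates are in mm, and (as an assumption) points left unclustered are not used as references for higher slices.

```python
import numpy as np
from scipy.spatial import cKDTree


def grow_crowns(stem_xyz, stem_labels, crown_xyz, z_start, slice_mm=50.0, max_dist_mm=1000.0):
    """Assign crown points to stems slice by slice, from z_start upwards.

    Returns one label per crown point; -1 marks points left unclustered.
    """
    ref_xyz = np.asarray(stem_xyz, dtype=float)
    ref_labels = np.asarray(stem_labels)
    crown_xyz = np.asarray(crown_xyz, dtype=float)
    crown_labels = np.full(len(crown_xyz), -1)
    if len(crown_xyz) == 0:
        return crown_labels
    lower, top = z_start, crown_xyz[:, 2].max()
    while lower < top:
        upper = lower + slice_mm
        idx = np.flatnonzero((crown_xyz[:, 2] > lower) & (crown_xyz[:, 2] <= upper))
        if len(idx):
            # nearest already-clustered point for every point in the slice
            dist, nearest = cKDTree(ref_xyz).query(crown_xyz[idx])
            keep = dist <= max_dist_mm
            new_labels = ref_labels[nearest[keep]]
            crown_labels[idx[keep]] = new_labels
            # newly clustered points become references for the slices above
            ref_xyz = np.vstack([ref_xyz, crown_xyz[idx[keep]]])
            ref_labels = np.concatenate([ref_labels, new_labels])
        lower = upper
    return crown_labels


# two stems (labels 0 and 1) sampled every 10 cm up to 2 m, 5 m apart
stem_xyz = np.array([[0, 0, z] for z in range(0, 2001, 100)]
                    + [[5000, 0, z] for z in range(0, 2001, 100)], dtype=float)
stem_labels = np.array([0] * 21 + [1] * 21)
crown_xyz = np.array([[300, 0, 2300], [4600, 0, 2600],
                      [2500, 0, 2400], [600, 0, 2900]], dtype=float)
print(grow_crowns(stem_xyz, stem_labels, crown_xyz, z_start=2000))
```

In the toy example the point at 2.9 m is more than 1 m from every stem point; it is attached to tree 0 only because the crown point at 2.3 m was clustered first, which is what growing upward slice by slice achieves. The point midway between the stems stays unclustered.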

# Outcomes

Our method segmented the majority of trees in the plots tested and outperformed an existing algorithm, 3D Forest. It performed better in stands with fewer non-tree objects and less understory vegetation, and it segmented individual trees better in sparse data than in dense data (the opposite of 3D Forest). The calculation of individual tree metrics can still be improved. The algorithm as a whole could also be improved by testing different radii for the KDTree density calculation, using DBSCAN rather than KMeans in some cases, adding a top-to-bottom clustering pass, adjusting the 1 m limit beyond which points are discarded, and differentiating between foliage and wood, possibly using reflectivity and/or infrared values. Finally, with a large enough sample, we would like to apply machine learning techniques to predict attributes such as tree and understory species from the point cloud.
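One of the suggested improvements, replacing KMeans with DBSCAN, would let the stems be found without the user entering their number. A minimal sketch of that idea (not something tested in this project): `eps_mm=150` and `min_samples=20` are placeholder values that would need tuning to each scan's density, `cluster_stems_dbscan` is a hypothetical name, and the three synthetic stems are made-up data. DBSCAN labels sparse points -1 (noise).

```python
import numpy as np
from sklearn.cluster import DBSCAN


def cluster_stems_dbscan(xy, eps_mm=150.0, min_samples=20):
    """Cluster dense XY stem points without fixing the number of stems in advance.

    Returns one label per point; DBSCAN marks noise with -1.
    """
    return DBSCAN(eps=eps_mm, min_samples=min_samples).fit_predict(xy)


# three synthetic stems (mm), 150 points each
rng = np.random.default_rng(1)
xy = np.vstack([
    rng.normal([0, 0], 50, size=(150, 2)),
    rng.normal([2500, 0], 50, size=(150, 2)),
    rng.normal([0, 3000], 50, size=(150, 2)),
])
labels = cluster_stems_dbscan(xy)
print("stems found:", len(set(labels.tolist()) - {-1}))
```

The trade-off is that `eps` and `min_samples` replace the stem count as the parameters to tune, and stems closer together than `eps` may merge.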

In [None]:
#import libraries used in segmentation algorithm
import math

import numpy as np
import pandas as pd
import seaborn as sns
import plotly.graph_objects as go
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

import scipy.spatial as spatial
from scipy.spatial import distance
from sklearn.cluster import KMeans

from IPython.display import clear_output

#function to find the closest point in the dataset to a new point by euclidean distance
#returns the index of the closest point and its distance
def closest_node(node, nodes):
    distances = distance.cdist([node], nodes)[0]
    closest_index = distances.argmin()
    closest_distance = distances[closest_index]
    
    return closest_index, closest_distance

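#function to find the greatest euclidean distance between any two points in the dataset
def furthest_node(nodes):
    nodes = np.asarray(nodes, dtype=float)
    if len(nodes) < 2:
        return 0.0
    
    #the furthest pair always lies on the convex hull, so only hull vertices need comparing
    #(avoids building an n x n distance matrix for large crowns)
    if len(nodes) >= 3:
        try:
            nodes = nodes[spatial.ConvexHull(nodes).vertices]
        except spatial.QhullError:
            pass  #degenerate input (e.g. collinear points): compare all points instead
    furthest_distance = distance.pdist(nodes).max()
    
    return furthest_distance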

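#function to round a value up to the nearest multiple of nearest_value
def round_up(n, nearest_value):
    #dividing (rather than multiplying by 1/nearest_value) keeps results exact for whole-number steps
    return math.ceil(n / nearest_value) * nearest_value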

#segmentation application function
def segment_interact():
    
    #import the LiDAR data csv file
    correct_file = 'n'
    
    while correct_file.lower()[0] != 'y':
        scan_file = input('File Name: ')
        scan = pd.read_csv(scan_file, low_memory=False)
        scan = scan[scan[' X (mm)'] != 0]
        
        #check if correct scan has been imported
        print(scan.head())
        correct_file = input('Is this the correct file? (y/n): ')
    
    #show 2d plot of the scan
    sns.set(rc = {'figure.figsize':(10,10)})
    sns.scatterplot(data=scan, x=" X (mm)", y=" Y (mm)", s=1)
    plt.show()
    
    #select plot area to be segmented or use entire scan
    select_area = input('Select a smaller plot within the scan? (y/n): ')
    
    if select_area.lower()[0] == 'y':
        
        adjust_plot = 'y'
        
        while adjust_plot.lower()[0] != 'n':
        
            set_x_min = int(input('Select a minimum X value (mm): '))
            set_x_max = int(input('Select a maximum X value (mm): '))
            set_y_min = int(input('Select a minimum Y value (mm): '))
            set_y_max = int(input('Select a maximum Y value (mm): '))

            scan_plot = scan[(scan[' X (mm)'] >= set_x_min) & (scan[' X (mm)'] <= set_x_max) & 
                             (scan[' Y (mm)'] >= set_y_min) & (scan[' Y (mm)'] <= set_y_max)]
            
            #show 2d plot of the plot
            sns.scatterplot(data=scan_plot, x=" X (mm)", y=" Y (mm)", s=1)
            plt.show()
            
            #check if any further adjustments to plot boundaries are needed
            adjust_plot = input('Would you like to further adjust the plot boundaries? (y/n): ')
        
    else:
        scan_plot = scan
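        #round the minimums down and the maximums up to whole meters so every point falls inside a square
        set_x_min = math.floor(scan[' X (mm)'].min() / 1000) * 1000
        set_x_max = round_up(scan[' X (mm)'].max(), 1000)
        set_y_min = math.floor(scan[' Y (mm)'].min() / 1000) * 1000
        set_y_max = round_up(scan[' Y (mm)'].max(), 1000)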
        
        
    #remove ground points (all points within 5cm of the lowest point in each square meter of the plot)
    current_x_min = set_x_min
    current_y_min = set_y_min
    ground_removed = scan_plot[0:0]

    #'<' rather than '!=' so the loops also stop when the plot is not a whole number of meters wide
    while current_y_min < set_y_max:
        while current_x_min < set_x_max:
            current_area = scan_plot[(scan_plot[' X (mm)'] >= current_x_min) & (scan_plot[' X (mm)'] < current_x_min + 1000) & 
                                     (scan_plot[' Y (mm)'] >= current_y_min) & (scan_plot[' Y (mm)'] < current_y_min + 1000)]
            square_z_min = current_area[' Z (mm)'].min()
            current_area = current_area[(current_area[' Z (mm)'] >= square_z_min + 50)]
            ground_removed = pd.concat([ground_removed, current_area])
            current_x_min += 1000

        current_y_min += 1000
        current_x_min = set_x_min
    
    #clear previous output and plot data with ground points removed
    clear_output(wait=True)
    sns.scatterplot(data=ground_removed, x=" X (mm)", y=" Y (mm)", s=10)
    plt.show()
    
    fig = plt.figure(figsize=(10,10))
    ax = fig.add_subplot(projection='3d')
    sc = ax.scatter(ground_removed[' X (mm)'], ground_removed[' Y (mm)'], ground_removed[' Z (mm)'], s=1, marker='o', alpha=1)
    plt.show()
    
    #begin segmentation of trees in the plot
    print('The ground points (i.e. terrain surface) have also been removed from the data')
    continue_segmentation = input('Proceed with segmentation of this data? (y/n): ')
    
    if continue_segmentation.lower()[0] != 'y':
        print('Rerun the function to start again.')
        return None
        
    else:
        #divide the data into below 2m (stem) and above 2m (crown) dataframes
        z_min = ground_removed[' Z (mm)'].min()
        z_max = round_up(ground_removed[' Z (mm)'].max(), 100)
        
        stems = ground_removed[ground_removed[' Z (mm)'] <= z_min + 2000]
        crown = ground_removed[ground_removed[' Z (mm)'] > z_min + 2000]
        
        #determine point density in x,y dimensions for all points
        xy_array = stems[[" X (mm)", " Y (mm)"]].to_numpy() 

        tree = spatial.KDTree(np.array(xy_array))
        radius = 250

        neighbors = tree.query_ball_tree(tree, radius)
        len_list = [len(lst) for lst in neighbors]

        stems = stems.assign(DENSITY=len_list)
        
        #plot stem dataset
        clear_output(wait=True)
        sns.scatterplot(data=stems, x=" X (mm)", y=" Y (mm)", s=10)
        plt.show()
        
        fig = plt.figure(figsize=(10,10))
        ax = fig.add_subplot(projection='3d')
        sc = ax.scatter(stems[' X (mm)'], stems[' Y (mm)'], stems[' Z (mm)'], s=1, marker='o', alpha=1)
        plt.show()
        
        #Remove non-stem points from the dataset by selecting a minimum point density
        mean_stem_density = stems['DENSITY'].mean()
        print(f'The mean density for points in the 0 to 2 m height range is {mean_stem_density} points per 25 cm radius')
        print('We recommend choosing an initial minimum density approximately equal to this value.')
        
        stems_discovered = 'y'
        
        while stems_discovered.lower()[0] != 'n':
            min_density = int(input('Enter a minimum density: '))
            stems_min_density = stems[stems['DENSITY'] >= min_density]
            clear_output(wait=True)
            sns.scatterplot(data=stems_min_density, x=" X (mm)", y=" Y (mm)", s=10)
            plt.show()
            
            fig = plt.figure(figsize=(10,10))
            ax = fig.add_subplot(projection='3d')
            sc = ax.scatter(stems_min_density[' X (mm)'], stems_min_density[' Y (mm)'], stems_min_density[' Z (mm)'], s=1, marker='o', alpha=1)
            plt.show()
            
            print('Stems should now appear as clear clusters, matching the number of trees to be segmented.')
            print('If this is not the case, the minimum density may need to be increased or decreased')
            stems_discovered = input('Would you like to adjust the minimum density? (y/n): ')
        
        #apply k-means to cluster stems
        clusters = int(input('Enter the number of stems to be clustered: '))
        
        stem_array = stems_min_density[[" X (mm)", " Y (mm)"]].to_numpy() 

        kmeans = KMeans(n_clusters=clusters, n_init=10, random_state=0).fit_predict(stem_array)
            
        clear_output(wait=True)
        plt.scatter(stem_array[:,0], stem_array[:,1], c=kmeans, s=40, cmap='viridis')
        plt.show()
        print(f'Number of stems: {clusters}')
        
        stems_clustered = stems_min_density.assign(CLUSTER=kmeans)
        stems_clustered = stems_clustered.drop('DENSITY', axis=1)
        
        #segmentation algorithm to add crowns to stems
        #first pass ascending
        current_z_min = z_min + 2000
        current_z_max = z_min + 2050

        #continue until the slice's lower bound passes the top of the crown
        while current_z_min < z_max:

            next_slice_df = crown[(crown[' Z (mm)'] > current_z_min) & (crown[' Z (mm)'] <= current_z_max)]

            #only points already assigned to a tree (not discarded ones) are used as references
            assigned = stems_clustered[stems_clustered.CLUSTER != -1]
            reference_points = assigned[[" X (mm)", " Y (mm)", " Z (mm)"]].to_numpy()
            reference_labels = assigned['CLUSTER'].to_numpy()
            next_slice = next_slice_df[[" X (mm)", " Y (mm)", " Z (mm)"]].to_numpy()

            cluster_list = []

            for point in next_slice:
                closest_index, closest_distance = closest_node(point, reference_points)

                #points more than 1m from every clustered point are left unclustered (-1)
                if closest_distance <= 1000:
                    cluster_list.append(reference_labels[closest_index])
                else:
                    cluster_list.append(-1)

            next_slice_df = next_slice_df.assign(CLUSTER=cluster_list)
            stems_clustered = pd.concat([stems_clustered, next_slice_df], axis=0)

            current_z_min += 50
            current_z_max += 50
        
        #separation of clustered and unclustered points
        stems_clustered_stemsonly = stems_clustered[stems_clustered.CLUSTER != -1]
        
        #automatic removal of objects < 2m in height
        trees = np.sort(stems_clustered_stemsonly['CLUSTER'].unique()).tolist()
        count = 0
        
        for tree_id in trees:
            current_tree = stems_clustered_stemsonly[stems_clustered_stemsonly['CLUSTER'] == tree_id]
            tree_height = current_tree[' Z (mm)'].max() - current_tree[' Z (mm)'].min()
            
            if tree_height < 2000:
                stems_clustered_stemsonly = stems_clustered_stemsonly[stems_clustered_stemsonly['CLUSTER'] != tree_id]
                count += 1
        
        print(f'{count} objects were removed from the dataset due to insufficient height.')
        
        #plot clustered trees
        fig = plt.figure(figsize=(10,10))
        ax = fig.add_subplot(projection='3d')

        sc = ax.scatter(stems_clustered_stemsonly[' X (mm)'], stems_clustered_stemsonly[' Y (mm)'], stems_clustered_stemsonly[' Z (mm)'], 
                        c=stems_clustered_stemsonly['CLUSTER'], s=10, marker='o', alpha=1, cmap='viridis')
        ax.view_init(0, 180)
        plt.show()

        fig = go.Figure()

        for tree in stems_clustered_stemsonly['CLUSTER'].unique().tolist():
            tree_df = stems_clustered_stemsonly[stems_clustered_stemsonly['CLUSTER'] == tree]
            fig.add_trace(go.Mesh3d(x = tree_df[' X (mm)'], y = tree_df[' Y (mm)'], 
                                    z = tree_df[' Z (mm)'], opacity=0.5))

        #keep equal scaling on all three axes of the 3D scene
        fig.update_layout(scene=dict(aspectmode='data'))

        fig.show()
        
    
    print('Segmentation Complete')
    segmented_trees = stems_clustered_stemsonly
    
    return segmented_trees

#Review individual tree point clouds and metrics
def individual_tree_metrics(segmented_df):
    #setup new dataframe to be edited (a copy, so the input is not modified) and select first tree for review
    metric_df = segmented_df.copy()
    metric_df['CLUSTER'] = metric_df['CLUSTER'].astype(int)
    
    trees = np.sort(metric_df['CLUSTER'].unique()).tolist()
    print(f'The following is a list of all segmented trees in the dataset: {trees}')
    print(f'There are a total of {len(trees)} trees')
    position = 0
    selected_tree = trees[position]
    
    continue_review = 'y'
    
    tree_lst = []
    width_lst = []
    height_lst = []
    dbh_lst = []
    
    #Measurement of Tree Metrics and Review Process
    while continue_review.lower()[0] != 'n' and position != len(trees):
        
        current_tree = metric_df[metric_df['CLUSTER'] == selected_tree]
        
        fig = plt.figure(figsize=(10,10))
        ax = fig.add_subplot(projection='3d')

        sc = ax.scatter(current_tree[' X (mm)'], current_tree[' Y (mm)'], current_tree[' Z (mm)'], s=10, marker='o', alpha=1)
        ax.view_init(0, 180)
        plt.show()
        
        #Calculate and Print Crown Width
        print(f'The current tree is tree {selected_tree}')
        xy_points = current_tree[[" X (mm)", " Y (mm)"]].to_numpy()
        crown_width = furthest_node(xy_points)
        print(f'Crown Width for Tree {selected_tree}: {crown_width} mm')

        #Calculate and Print Tree Height
        z_min = current_tree[' Z (mm)'].min()
        z_max = current_tree[' Z (mm)'].max()
        print(f'Tree Height for Tree {selected_tree}: {z_max - z_min} mm')
        
        #Calculate and Print DBH (diameter at breast height, 1.3m above the tree's lowest point),
        #measured as the widest extent of a 2cm thick slice centred on the point closest to breast height
        z_values = current_tree[" Z (mm)"].to_numpy()
        bh = z_min + 1300
        closest_index = (abs(z_values - bh)).argmin()
        bh_point = int(current_tree.iloc[closest_index][" Z (mm)"])
        dbh_range = current_tree[(current_tree[" Z (mm)"] >= bh_point - 10) & (current_tree[" Z (mm)"] <= bh_point + 10)]
        dbh_xy = dbh_range[[" X (mm)", " Y (mm)"]].to_numpy()
        dbh = furthest_node(dbh_xy)
        print(f'DBH for Tree {selected_tree}: {dbh} mm')
       
        #flag objects that do not meet the minimum height (2m) or DBH (5cm) requirements
        if (z_max - z_min) < 2000 or dbh < 50:
            print('This tree is either below 2m tall or has a DBH below 5cm.')
            print('This may not be a tree or it may be at least partially occluded.')
        else:
            print('This tree appears to meet the minimum requirements for height and DBH.')
        
        #choose to remove the tree, or keep it and record its metrics
        delete_tree = input('Would you like to remove this tree from the data? (y/n): ')
        
        if delete_tree.lower()[0] == 'y':
            metric_df = metric_df[metric_df['CLUSTER'] != selected_tree]
            print('Tree removed from data.')
        else:
            tree_lst.append(selected_tree)
            width_lst.append(crown_width)
            height_lst.append(z_max - z_min)
            dbh_lst.append(dbh)
        
        #move to next tree
        continue_review = input('Would you like to observe the next tree? (y/n): ')
        clear_output(wait=True)
        position += 1
        if position != len(trees):
            selected_tree = trees[position]
    
    #Plot final metric results
    fig = plt.figure(figsize=(10,10))
    ax = fig.add_subplot(projection='3d')

    sc = ax.scatter(metric_df[' X (mm)'], metric_df[' Y (mm)'], metric_df[' Z (mm)'], c=metric_df['CLUSTER'], 
                    s=10, marker='o', alpha=1, cmap='viridis')
    ax.view_init(0, 180)
    plt.show()
    

    fig = go.Figure()

    for tree in metric_df['CLUSTER'].unique().tolist():
        tree_df = metric_df[metric_df['CLUSTER'] == tree]
        fig.add_trace(go.Mesh3d(x = tree_df[' X (mm)'], y = tree_df[' Y (mm)'], z = tree_df[' Z (mm)'], opacity=0.5))

    #keep equal scaling on all three axes of the 3D scene
    fig.update_layout(scene=dict(aspectmode='data'))

    fig.show()
        
    print('Review of individual tree metrics complete')
    metric_table = pd.DataFrame(list(zip(tree_lst, width_lst, height_lst, dbh_lst)), columns =['Tree', 'Crown Width', 'Height', 'DBH'])
    
    return metric_df, metric_table

In [None]:
#Run segmentation algorithm
segmented_trees = segment_interact()

In [None]:
#Review segmented trees
metric_df, metric_table = individual_tree_metrics(segmented_trees)

In [None]:
#Save segmentation results as .csv
print(metric_df)
plot_name = 'for_accuracy'
file_name = f'selected_segmented_trees_{plot_name}.csv'
metric_df.to_csv(file_name)

In [None]:
#Save metric results as .csv
print(metric_table)
plot_name = '3dtrees'
file_name = f'segmented_tree_metrics_{plot_name}.csv'
metric_table.to_csv(file_name, index=False)