## Script: Plot Segmentation Using Otsu Thresholding

This notebook takes early-season orthomosaic images of canola fields and produces a column segmentation and a plot segmentation.

### Import Required Modules

In [None]:
import skimage.util as util
import skimage.filters as filt
import skimage.morphology as morph
import skimage.measure as meas
import skimage.color as color
import skimage.exposure as exposure
import skimage.io as io
import skimage.segmentation as seg
import scipy.ndimage.morphology as morph_sc
import os as os
import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv
import skimage.transform as transform
import skimage.restoration as restoration

%matplotlib inline

### Column Segmentation Algorithm

In [None]:
def segment_columns(I):
    '''
    Find column segmentation of the given image.

    Pixels whose green proportion G / (R + G + B) exceeds the Otsu threshold are
    treated as foreground. After morphological clean-up, exactly the six largest
    connected regions are kept as the columns.

    :param I: canola orthomosaic image to segment (assumed H x W x 3 RGB; channel 1
              is read as green)
    :return: ("success", logical image where True pixels represent columns), or
             ('Error! Found less than six regions.', '') when fewer than six
             regions survive post-processing.
    '''
    # calculate green proportion of the image; the epsilon avoids 0/0 on black pixels
    green_proportion = I[:, :, 1] / (np.sum(I, axis=2) + .0000001)

    # use the Otsu threshold to segment the image
    thresholded_image = green_proportion > filt.threshold_otsu(green_proportion)

    # Post-process the thresholded image. The size / footprint arguments are passed
    # positionally because their keyword names differ between scikit-image
    # releases (`min_size` was removed from remove_small_holes, `selem` was
    # renamed `footprint`).
    processed_image = morph.remove_small_holes(thresholded_image, 512)
    processed_image = morph.remove_small_objects(processed_image, 1024)
    processed_image = morph_sc.binary_fill_holes(processed_image)
    processed_image = morph.erosion(processed_image, morph.disk(4))

    # rank connected regions by area, largest first
    label_image = meas.label(processed_image)
    regions = sorted(meas.regionprops(label_image),
                     key=lambda region: region.area,
                     reverse=True)

    if len(regions) < 6:
        return ('Error! Found less than six regions.', '')

    # Keep exactly six regions, selected by label. Matching on area values would
    # also admit any further region whose area ties with the sixth largest.
    column_labels = [region.label for region in regions[:6]]
    column_segment = np.isin(label_image, column_labels).astype(np.uint)

    # post process the image to increase region areas to counter the previous large erosion
    all_columns_segmented_image = morph.dilation(column_segment, morph.disk(1))

    return ("success", all_columns_segmented_image.astype('bool'))

### Plot Segmentation Algorithm

In [None]:
def segment_plots(all_columns_segmented_image, original_image=None, plot_each_column=39):
    '''
    Find plot segmentation from a column-segmented image.

    Each connected column region is processed independently by
    detect_plots_inside_column, and the per-column plot masks are OR-ed together.

    :param all_columns_segmented_image: A logical image where columns are marked as foreground with True
    :param original_image: image the columns were segmented from. When omitted, the
                           module-level variable `image` is used (the behaviour the
                           validation driver relies on).
    :param plot_each_column: expected number of plots in each column (target count
                             for the per-column threshold search)
    :return: Logical image where True pixels represent plots.
    '''
    if original_image is None:
        # Backward compatibility: callers that pass only the column mask rely on
        # the global `image` set by the notebook driver.
        original_image = image

    label_image = meas.label(all_columns_segmented_image)
    props = meas.regionprops(label_image)

    # Boolean from the start so an image without columns still returns a logical image.
    all_plots_segmented_image = np.zeros(np.shape(all_columns_segmented_image), dtype=bool)

    # traverse each detected column
    for region in props:
        column_segmented_image = (label_image == region.label).astype(np.uint8)

        # determine plot segmentation for this column
        plots = detect_plots_inside_column(region.bbox, original_image,
                                           column_segmented_image, plot_each_column)

        # join the segmented plots
        all_plots_segmented_image = np.logical_or(all_plots_segmented_image, plots)

    return all_plots_segmented_image



def count_plots(arr, threshold, binary_image, r1, r2, c1, c2):
    '''
    Split a column mask into plots at the low points of its row profile and count them.

    Every run of consecutive profile values below `threshold` that is longer than
    two entries is treated as a gap between plots. Those rows are blanked in
    `binary_image` (restricted to columns c1:c2). The result is opened
    morphologically, and the remaining connected regions are counted, ignoring
    specks smaller than 5% of the mean region area.

    :param arr: 1-D row profile of the column; index 0 corresponds to image row r1
    :param threshold: profile values strictly below this mark divider rows
    :param binary_image: column mask; MODIFIED IN PLACE (callers pass a copy)
    :param r1: first image row covered by `arr`
    :param r2: unused; kept for interface compatibility
    :param c1: first image column of the column's bounding box
    :param c2: end (exclusive) image column of the column's bounding box
    :return: (number of retained regions, np.uint image with retained regions set to 1)
    '''
    # collect (start, end) index pairs, inclusive, of runs below the threshold
    dividers = []
    run_start = None
    for idx, value in enumerate(arr):
        if value < threshold:
            if run_start is None:
                run_start = idx
        elif run_start is not None:
            dividers.append((run_start, idx - 1))
            run_start = None
    if run_start is not None:
        # A run that reaches the end of the profile is a divider too. Previously
        # only runs touching the start of the profile were handled.
        dividers.append((run_start, len(arr) - 1))

    for start, end in dividers:
        if end - start > 1:
            binary_image[r1 + start:r1 + end + 1, c1:c2] = 0

    binary_image = morph.binary_opening(binary_image)

    label_image = meas.label(binary_image)
    props = meas.regionprops(label_image)

    binary_image_cleaned = np.zeros_like(binary_image, dtype=np.uint)
    if not props:
        # nothing left after opening; avoids averaging an empty list
        return (0, binary_image_cleaned)

    # discard regions smaller than 5% of the mean region area
    min_area = np.mean([region.area for region in props]) * 0.05

    cnt = 0
    for region in props:
        if region.area > min_area:
            cnt += 1
            binary_image_cleaned[label_image == region.label] = 1

    return (cnt, binary_image_cleaned)



def binary_search_threshold(arr, plot_number, binary_image, r1, r2, c1, c2, max_iterations=100):
    '''
    Bisect a threshold on `arr` so that count_plots finds `plot_number` plots.

    The search assumes that raising the threshold does not decrease the plot count.
    An exact count may be unreachable, for example when the count jumps from
    plot_number - 1 to plot_number + 1. The search therefore also stops when the
    float interval can no longer shrink or after `max_iterations` evaluations.
    It returns the segmentation whose count was closest to `plot_number`; the
    earliest one wins ties.

    :param arr: 1-D row profile of the column
    :param plot_number: desired number of plots
    :param binary_image: column mask; copied before each evaluation, never modified
    :param r1, r2, c1, c2: column bounding box, forwarded to count_plots
    :param max_iterations: upper bound on count_plots evaluations (must be >= 1)
    :return: plot segmentation image produced by count_plots
    '''
    low = np.min(arr)
    high = np.max(arr)

    best_image = None
    best_error = None

    for _ in range(max_iterations):
        mid = (low + high) / 2
        cnt, plot_image = count_plots(arr, mid, binary_image.copy(), r1, r2, c1, c2)

        error = abs(cnt - plot_number)
        if best_error is None or error < best_error:
            best_error, best_image = error, plot_image

        if cnt == plot_number:
            break

        previous_bounds = (low, high)
        if cnt > plot_number:
            high = mid
        else:
            low = mid

        if (low, high) == previous_bounds:
            # mid equals a bound: the interval has collapsed to adjacent floats,
            # so further iterations would repeat forever
            break

    return best_image



def detect_plots_inside_column(bbox, original_image, column_segmented_image, plot_each_column):
    '''
    Find plot segmentation of a given column image.

    The column's bounding box is cropped from the original image and converted to
    grayscale. Each row is summed, and adjacent row sums are added pairwise to form
    the row profile. binary_search_threshold then looks for a profile threshold
    that splits the column into `plot_each_column` plots.

    :param bbox: Boundary box of the column (minr, minc, maxr, maxc)
    :param original_image: Original input image
    :param column_segmented_image: Columns segmented image from the original image
    :param plot_each_column: Possible plots in each column
    :return: image where non-zero pixels represent plots in the column
    '''
    minr, minc, maxr, maxc = bbox

    column_cropped = original_image[minr:maxr, minc:maxc]
    gray_image = color.rgb2gray(column_cropped)

    sum_of_row_pixels = np.sum(gray_image, axis=1)

    # 2-row sliding-window sum: entry k is row k + row k+1 (length is rows - 1).
    sum_sliding_window = sum_of_row_pixels[:-1] + sum_of_row_pixels[1:]

    plot_image = binary_search_threshold(sum_sliding_window, plot_each_column,
                                         column_segmented_image, minr, maxr, minc, maxc)

    return plot_image
    

### Implement Measures

In [None]:
def dice_coefficient(bwA, bwG):
    '''
    Dice coefficient between two binary images: 2|A ∩ G| / (|A| + |G|).

    Non-zero pixels count as foreground, so integer masks (e.g. 0/255) are
    handled the same as boolean masks.

    :param bwA: a binary image (dtype='bool' preferred)
    :param bwG: a binary image of the same shape
    :return: the Dice coefficient between them, as a float in [0, 1]; 1.0 when
             both images are empty (identical empty segmentations)
    '''
    mask_a = np.asarray(bwA, dtype=bool)
    mask_g = np.asarray(bwG, dtype=bool)

    total = np.count_nonzero(mask_a) + np.count_nonzero(mask_g)
    if total == 0:
        # avoid 0/0 -> NaN, which would poison the mean DSC
        return 1.0

    intersection = np.count_nonzero(np.logical_and(mask_a, mask_g))
    return 2.0 * intersection / total

### Helper Functions

In [None]:
def display_segmentation(image, segmentation, title):
    '''
    Show `image` with the outlines of every connected region of `segmentation`
    drawn on top in orange.
    :param image: image on which to superimpose the segmentation
    :param segmentation: a binary (dtype='bool') segmentation
    :param title: a string to show as figure title
    '''
    # label connected regions, then outline each one (RGB orange = (1, .5, 0))
    outlined = seg.mark_boundaries(image, morph.label(segmentation), color=(1, .5, 0))

    plt.figure(figsize=(40, 40))
    plt.imshow(outlined)
    plt.title(title)
    plt.show()
    
    
    
def get_rgb_image(image):
    '''
    Return an RGB version of `image`.

    A 4-channel (RGBA) image is converted with skimage.color.rgba2rgb. Any other
    image is returned unchanged, as the same object.

    :param image: RGB or RGBA image
    :return: RGB image of the given image
    '''
    shape = image.shape
    has_alpha_channel = len(shape) > 2 and shape[2] == 4
    return color.rgba2rgb(image) if has_alpha_channel else image
    

### Validation Driver

This code goes through a set of sample input orthomosaics and segments the columns and plots of each image. It then computes the Dice similarity coefficient (DSC) for each type of segmentation.

The general approach for each image will be:
* Load the image and its ground truth images for plots and columns
* Segment the columns and return a binary image where column regions are marked with 1
* Compute the DSC from the segmented image (column) and the ground truth image (column)
* Display the original image with the segmentation superimposed on top.

* Segment the plots using the binary image produced by the earlier column segmentation step and return a binary image where plot regions are marked with 1
* Compute the DSC from the segmented image (plots) and the ground truth image (plots)
* Display the original image with the segmentation superimposed on top.

In [None]:
# define image paths
images_path = os.path.join('../../', 'Dataset/Orthomosaics/images/')
column_masks_path = os.path.join('../../', 'Dataset/Orthomosaics/column_masks/')
plot_masks_path = os.path.join('../../', 'Dataset/Orthomosaics/plot_masks/')

# images are resized to this width; the height keeps the original aspect ratio
r_width = 1980

column_dsc = []
plot_dsc = []


def _load_binary_mask(path, shape):
    '''
    Read a ground-truth mask, resize it to `shape` and return it as a boolean image.

    Nearest-neighbour resizing (order=0) without anti-aliasing keeps the mask
    values unchanged. Bilinear interpolation plus smoothing would instead produce
    small non-zero values around every boundary, which the boolean cast would turn
    into foreground and so inflate the mask.
    '''
    mask = get_rgb_image(io.imread(path))
    mask = transform.resize(mask, shape, order=0, anti_aliasing=False)
    if mask.ndim == 3:
        mask = color.rgb2gray(mask)
    return mask.astype('bool')


for root, dirs, files in os.walk(images_path):
    for filename in sorted(files):
        # ignore files that are not PNG files.
        if not filename.endswith('.png'):
            continue

        # Path relative to images_path, so an image in a sub-directory is paired
        # with the masks at the same relative location.
        relative_path = os.path.relpath(os.path.join(root, filename), images_path)

        # read the image and convert to rgb image if rgba.
        # NOTE: `image` stays a module-level name on purpose: segment_plots reads
        # it when no image is passed explicitly.
        image = get_rgb_image(io.imread(os.path.join(images_path, relative_path)))

        # resize to a width of r_width; the height is derived from the original aspect ratio
        height, width = image.shape[:2]
        target_shape = (int(height * r_width / width), r_width)
        image = transform.resize(image, target_shape)

        # load the ground-truth masks at the same size, as logical images
        plot_mask = _load_binary_mask(os.path.join(plot_masks_path, relative_path), target_shape)
        column_mask = _load_binary_mask(os.path.join(column_masks_path, relative_path), target_shape)

        # determine column segmentation
        status, columns_segmentation = segment_columns(image)

        # check if column segmentation is successful
        if status != 'success':
            print(status)
            continue

        display_segmentation(image, columns_segmentation, "Column Segmentation Marked on Original Image")
        dsc = dice_coefficient(column_mask, columns_segmentation)

        column_dsc.append(dsc)

        print("For " + filename)
        print("Computed DSC of Column Segmentation: " + str(dsc))

        # determine plots segmentation using columns segmented image
        plots_segmentation = segment_plots(columns_segmentation)
        display_segmentation(image, plots_segmentation, "Plots Segmentation Marked on Original Image")
        dsc = dice_coefficient(plot_mask, plots_segmentation)

        plot_dsc.append(dsc)

        print("Computed DSC of Plot Segmentation: " + str(dsc))

plot_dsc = np.array(plot_dsc)
column_dsc = np.array(column_dsc)

if column_dsc.size:
    print("\n\nComputed Mean DSC for column segmentation: " + str(np.mean(column_dsc)))
    print("Computed Mean DSC for plot segmentation: " + str(np.mean(plot_dsc)))
else:
    # np.mean of an empty array would only print nan (with a RuntimeWarning)
    print("\n\nNo images were segmented successfully; no mean DSC to report.")

In [None]:
import time

# Micro-benchmark: time a plain Python for/append loop.
# Kept at ten million items: range(1_000_000_000) would build a billion-element
# list, needing tens of GB of RAM and minutes of runtime.
n_items = 10_000_000

# perf_counter is the monotonic, high-resolution clock meant for measuring intervals
start = time.perf_counter()
a = range(n_items)
b = []
for i in a:
    b.append(i * 2)
end = time.perf_counter()
print(end - start)