# Necessary imports

In [None]:
import os 
import sys
import datetime 
import numpy as np
import shapely
import matplotlib.pyplot as plt
import geopandas as gpd
import pandas as pd
import osgeo
from osgeo import gdal, gdal_array, osr, ogr
from tifffile import imwrite
from scipy.spatial.distance import euclidean
from fastdtw import fastdtw, dtw
from skimage.color import rgb2gray
from skimage.filters import sobel
from skimage.segmentation import slic, quickshift, watershed, felzenszwalb
from skimage.segmentation import mark_boundaries
from skimage.metrics import (adapted_rand_error, variation_of_information)
from skimage.measure import label
from skimage.util import img_as_float
from sentinelhub import MimeType, CRS, BBox, SentinelHubRequest, SentinelHubDownloadClient, \
    DataCollection, bbox_to_dimensions, DownloadRequest, SHConfig

In [None]:
# Function for plotting images
def plot_image(image, factor=1.0, clip_range=None, **kwargs):
    """Display ``image * factor`` on one large (15x15 inch) axes without ticks.

    When *clip_range* is given as ``(low, high)`` the scaled values are clipped
    to that interval before display. Any extra keyword arguments are forwarded
    to ``Axes.imshow``.
    """
    _, axis = plt.subplots(nrows=1, ncols=1, figsize=(15, 15))
    scaled = image * factor
    if clip_range is not None:
        scaled = np.clip(scaled, *clip_range)
    axis.imshow(scaled, **kwargs)
    axis.set_xticks([])
    axis.set_yticks([])

In [None]:
# Fill in your Sentinel Hub OAuth credentials here, unless they are already
# stored in config.json.
CLIENT_ID = ''
CLIENT_SECRET = ''

config = SHConfig()
# Override the stored configuration only when both values were filled in above.
if CLIENT_ID and CLIENT_SECRET:
    config.sh_client_id = CLIENT_ID
    config.sh_client_secret = CLIENT_SECRET

# Warn when either credential is still missing.
if not (config.sh_client_id and config.sh_client_secret):
    print("Warning! To use Process API, please provide the credentials (OAuth client ID and client secret).")

# Choose the location based on WGS 84 coordinates 

#### The easiest way to choose the location is http://bboxfinder.com/

In [None]:
# Area of interest as [xmin, ymin, xmax, ymax] (longitude/latitude, EPSG:4326).
bbox_coords_wgs84 = [14.323528, 50.094928, 14.400433, 50.139091]
resolution = 10  # pixel size passed to bbox_to_dimensions
bbox = BBox(bbox=bbox_coords_wgs84, crs='EPSG:4326')
# Output raster size; later cells use bbox_size[0] / bbox_size[1] as the
# width / height aspect ratio.
bbox_size = bbox_to_dimensions(bbox, resolution=resolution)

# Generation of the RGB image used as a background

In [None]:
evalscript_true_color = """
    //VERSION=3

    function setup() {
        return {
            input: [{
                bands: ["B02", "B03", "B04"]
            }],
            output: {
                bands: 3
            }
        };
    }

    function evaluatePixel(sample) {
        return [3.5*sample.B04, 3.5*sample.B03, 3.5*sample.B02];
    }
"""
# Sentinel-2 L2A source for 2020, mosaicked with mosaicking_order='leastCC';
# the resulting picture is only used as a background for the plots below.
true_color_source = SentinelHubRequest.input_data(
    time_interval=('2020-04-01', '2020-12-31'),
    data_collection=DataCollection.SENTINEL2_L2A,
    mosaicking_order='leastCC',
)
request_true_color = SentinelHubRequest(
    evalscript=evalscript_true_color,
    config=config,
    bbox=bbox,
    size=bbox_size,
    input_data=[true_color_source],
    responses=[SentinelHubRequest.output_response('default', MimeType.PNG)],
)
# Despite its name, `rgb_tiff` holds the decoded PNG response; it is divided by
# 255 wherever it is plotted.
rgb_tiff = request_true_color.get_data()[0]
plot_image(rgb_tiff/255)

# Definition of the time windows for the image requests

In [None]:
start = datetime.datetime(2021, 1, 1)
end = datetime.datetime(2021, 11, 30)
# Number of window edges; n_chunks - 1 time windows (one image each) are requested.
n_chunks = 23
if n_chunks < 2:
    raise ValueError('n_chunks must be at least 2 to form one time window')

# Split [start, end] into n_chunks - 1 equal windows. The last edge is set to
# `end` explicitly: dividing by n_chunks (as before) stopped one window short of
# `end`, and accumulated timedelta rounding could otherwise shift the final date.
tdelta = (end - start) / (n_chunks - 1)
edges = [(start + i * tdelta).date().isoformat() for i in range(n_chunks - 1)]
edges.append(end.date().isoformat())
slots = list(zip(edges, edges[1:]))

print('Time windows:\n')
for slot in slots:
    print(slot)

# Evaluation script to request NDVI images along with cloud masks

In [None]:
# Evalscript for the NDVI requests. It reads B04, B08 and CLM as digital
# numbers and returns two FLOAT32 bands per pixel:
#   band 0: NDVI = (B08 - B04) / (B08 + B04)
#   band 1: the raw CLM value (the cloud-masking cell treats non-zero as cloudy).
# NOTE(review): when B08 + B04 == 0 the division presumably yields NaN/Infinity
# in the script; downstream cells use NaN for missing data — confirm intended.
evalscript_ndvi = """
    //VERSION=3
    function setup() {
        return {
            input: [{
                bands: ["B04", "B08", "CLM"],
                units: "DN"}],
            output: {
                bands: 2,
                sampleType: SampleType.FLOAT32}};}
    function evaluatePixel(sample) {
      let ndvi = (sample.B08 - sample.B04) / (sample.B08 + sample.B04)
      return [ ndvi, sample.CLM ]}
"""

In [None]:
def get_ndvi(time_interval):
    """Build (but do not execute) a Sentinel Hub request for ``evalscript_ndvi``.

    The request covers the module-level ``bbox`` at ``bbox_size`` pixels, reads
    Sentinel-2 L2A data inside *time_interval* with
    ``mosaicking_order='leastCC'`` and asks for a TIFF response.
    """
    ndvi_source = SentinelHubRequest.input_data(
        data_collection=DataCollection.SENTINEL2_L2A,
        time_interval=time_interval,
        mosaicking_order='leastCC',
    )
    tiff_response = SentinelHubRequest.output_response('default', MimeType.TIFF)
    return SentinelHubRequest(
        evalscript=evalscript_ndvi,
        input_data=[ndvi_source],
        responses=[tiff_response],
        bbox=bbox,
        size=bbox_size,
        config=config,
    )

# This might take from a few seconds to several minutes, depending on the area size and the number of images requested

In [None]:
# Build one NDVI request per time window and keep its first download request,
# so that all windows can be fetched in a single batch.
list_of_requests = [get_ndvi(slot).download_list[0] for slot in slots]

# Fetch the batch with up to 5 parallel threads. Later cells pair
# data_ndvi[i] with slots[i].
data_ndvi = SentinelHubDownloadClient(config=config).download(
    list_of_requests, max_threads=5
)

# Use of the cloud mask to remove cloudy pixels

In [None]:
# Replace cloudy pixels with NaN. Each downloaded array holds NDVI in band 0
# and the CLM value in band 1; each result keeps a (rows, cols, 1) shape.
for c, raw in enumerate(data_ndvi):
    ndvi = raw[:, :, 0:1].astype(np.float32)  # copy of the NDVI band
    cloud_mask = raw[:, :, 1:2]
    # Any non-zero CLM value is treated as invalid. The previous
    # `(1 - CLM) * NDVI` followed by `== 0 -> NaN` kept CLM values other than
    # 0/1 (presumably 255 = no data) as large negative numbers, and also
    # turned genuine NDVI == 0 pixels into NaN.
    ndvi[cloud_mask != 0] = np.nan
    data_ndvi[c] = ndvi
stack = np.dstack(data_ndvi)

# Time to plot the images without cloud pixels

In [None]:
# Grid of all cloud-masked NDVI images, one panel per time window.
ncols = 4
# Ceil division: exactly enough rows for every image (previously a hand-set 9,
# which left empty rows or raised IndexError for more than 36 images).
nrows = max(1, -(-len(data_ndvi) // ncols))
aspect_ratio = bbox_size[0] / bbox_size[1]
subplot_kw = {'xticks': [], 'yticks': [], 'frame_on': False}

# squeeze=False keeps `axs` 2-D so axs[row][col] also works with a single row.
fig, axs = plt.subplots(ncols=ncols, nrows=nrows, figsize=(5 * ncols * aspect_ratio, 5 * nrows),
                        subplot_kw=subplot_kw, squeeze=False)

for idx, image in enumerate(data_ndvi):
    ax = axs[idx // ncols][idx % ncols]
    ax.imshow(np.clip(image, 2.5/255, 1))
    ax.set_title(f'{slots[idx][0]}  -  {slots[idx][1]}', fontsize=10)

plt.tight_layout()

# This cell removes the images whose cloudiness exceeds the set maximum

In [None]:
# Keep only time windows whose share of NaN (cloud-masked) pixels is below
# max_cloudiness percent.
max_cloudiness = 20

numrows = len(stack)
numcols = len(stack[0])
stay = []    # indices (into data_ndvi / slots) of the kept images
xslots = []  # time windows of the kept images, aligned with `stay`

for x, (image, slot) in enumerate(zip(data_ndvi, slots)):
    cloudiness = np.isnan(image).sum()*100/(numcols * numrows)
    if cloudiness < max_cloudiness:
        stay.append(x)
        xslots.append(slot)
    else:
        print(f'The image from {slot} contains {cloudiness}% of clouds and is therefore removed')

# The previous pairwise dstack indexed stay[0] and stay[1] unconditionally and
# crashed with fewer than two kept images.
if not stay:
    raise ValueError(f'No image has less than {max_cloudiness}% cloud cover; increase max_cloudiness')

# Stack the kept images to (rows, cols, n_kept), then move the time axis to
# the front: `stack` has shape (n_kept, rows, cols).
stack = np.moveaxis(np.dstack([data_ndvi[i] for i in stay]), -1, 0)

# Visualization after removal of too cloudy images

In [None]:
# Grid of the images that passed the cloudiness filter.
ncols = 4
# Ceil division gives exactly enough rows; int(n / ncols) + 1 added an empty
# row whenever n was a multiple of ncols.
nrows = max(1, -(-len(stack) // ncols))
aspect_ratio = bbox_size[0] / bbox_size[1]
subplot_kw = {'xticks': [], 'yticks': [], 'frame_on': False}

# squeeze=False keeps `axs` 2-D so axs[row][col] also works with a single row.
fig, axs = plt.subplots(ncols=ncols, nrows=nrows, figsize=(5 * ncols * aspect_ratio, 5 * nrows),
                        subplot_kw=subplot_kw, squeeze=False)

for idx, image in enumerate(stack):
    ax = axs[idx // ncols][idx % ncols]
    ax.imshow(np.clip(image, 2.5/255, 1))
    ax.set_title(f'{xslots[idx][0]}  -  {xslots[idx][1]}', fontsize=10)

plt.tight_layout()

# Interpolation and extrapolation of the pixels with 'nan' values

In [None]:
# Fill the NaN (cloud-masked) gaps of every pixel's time series in place.
# `stack` has shape (n_images, rows, cols), so stack[:, row, col] is one series.
numrows = stack.shape[1]
numcols = stack.shape[2]

# Every row and column is visited; range(n - 1) previously skipped the last ones.
for x in range(numrows):
    for y in range(numcols):
        pixel = stack[:, x, y]
        missing = np.isnan(pixel)
        if not missing.any() or missing.all():
            continue  # nothing to fill, or no valid value to fill from
        series = pd.Series(pixel)
        # Order-2 polynomial interpolation needs at least 3 valid points
        # (scipy's spline fit raises otherwise).
        # NOTE(review): this interpolates over the positional index, not the
        # actual (uneven) acquisition dates.
        if (~missing).sum() >= 3:
            series = series.interpolate(method='polynomial', order=2)
        # Gaps the polynomial step did not fill (e.g. at the ends of the series,
        # or everything when it was skipped) are filled linearly in both
        # directions; pandas extends the edge values outwards.
        series = series.interpolate(limit_direction='both')
        stack[:, x, y] = series.to_numpy()

# Plotting the images after interpolation & extrapolation

In [None]:
# Grid of the images after gap filling.
ncols = 4
# Ceil division gives exactly enough rows; int(((n + 1) / ncols) + 1) could add
# an empty row.
nrows = max(1, -(-len(stack) // ncols))
aspect_ratio = bbox_size[0] / bbox_size[1]
subplot_kw = {'xticks': [], 'yticks': [], 'frame_on': False}

# squeeze=False keeps `axs` 2-D so axs[row][col] also works with a single row.
fig, axs = plt.subplots(ncols=ncols, nrows=nrows, figsize=(5 * ncols * aspect_ratio, 5 * nrows),
                        subplot_kw=subplot_kw, squeeze=False)

for idx, image in enumerate(stack):
    ax = axs[idx // ncols][idx % ncols]
    ax.imshow(np.clip(image, 2.5/255, 1))
    ax.set_title(f'{xslots[idx][0]}  -  {xslots[idx][1]}', fontsize=10)

plt.tight_layout()

# Putting the images into a multichannel image (3D array)

In [None]:
# Reorder `stack` from (n_images, rows, cols) into a multichannel image of
# shape (rows, cols, n_images); like the swapaxes calls, this returns a view.
stackx = np.moveaxis(stack, 0, -1)

# Creation of stacked image from 3 images to run other segmentation examples

In [None]:
# Pick three images from the time series (at 1/4, 1/2 and 3/4 of its length)
# and stack them into a 3-band image for the methods limited to three images.
n_images = stackx.shape[2]
x = int(n_images/4)
y = int(n_images/2)
z = int(n_images/4*3)
three_bands = np.dstack([stackx[:, :, band] for band in (x, y, z)])
imgq = np.nan_to_num(three_bands, nan=0.0).astype('double')

# Segmentation based on one image, mainly for comparison and validation purpose

In [None]:
def image_segmentation (input_stack, max_distance, minimum_area):
    """Region-growing segmentation of a single 2-D raster (e.g. one NDVI image).

    Pixels are visited row by row. For each pixel, the absolute difference to
    each in-bounds forward neighbour (upper-right, right, lower-right, lower,
    lower-left) is computed. Only if every one of them is below
    ``max_distance`` is the pixel linked to them:
    - it receives a new segment id if it has none;
    - the segments of its right / upper-right neighbours are merged into its own;
    - its id is written to the lower neighbours.
    Segments smaller than ``minimum_area`` pixels are discarded afterwards.

    Parameters
    ----------
    input_stack : 2-D array
        Raster to segment; NaN values are replaced by 0 first.
    max_distance : float
        Exclusive upper bound on the difference between linked pixels.
    minimum_area : int
        Segments with fewer pixels than this are dropped.

    Returns
    -------
    numpy.ndarray
        Float array with the raster's shape: 1 for pixels of kept segments,
        0 elsewhere. The elapsed time is printed and the mask is plotted over
        the global ``rgb_tiff``.
    """
    begin_time = datetime.datetime.now()
    input_stack = np.nan_to_num(input_stack, copy=True, nan=0.0, posinf=None, neginf=None)
    numrows = len(input_stack)
    numcols = len(input_stack[0])
    # Segment ids: 1 means "unassigned"; real segments are numbered from 2.
    seg_output = np.ones((numrows, numcols))
    # Distance with which each pixel was last linked (0 = not linked yet).
    input_raster = np.zeros((numrows, numcols))
    n = 1
    # Forward neighbours as (row offset, column offset).
    neighbour_offsets = {
        'upright': (-1, 1),
        'right': (0, 1),
        'rightbot': (1, 1),
        'bot': (1, 0),
        'leftbot': (1, -1),
    }

    def distance(a, b):
        # Absolute difference of the two pixel values.
        return abs(a - b)

    for x in range(numrows):
        for y in range(numcols):
            current = input_stack[x, y]
            dist = {}
            for key, (dx, dy) in neighbour_offsets.items():
                nx, ny = x + dx, y + dy
                # Out-of-bounds neighbours are skipped: a negative index would
                # silently wrap around to the opposite edge of the raster.
                if 0 <= nx < numrows and 0 <= ny < numcols:
                    dist[key] = distance(current, input_stack[nx, ny])
            if not dist or max(dist.values()) >= max_distance:
                continue

            if seg_output[x, y] == 1:
                n += 1
                seg_output[x, y] = n
            seg_id = seg_output[x, y]

            if 'right' in dist:
                input_raster[x, y+1] = dist['right']
                right_id = seg_output[x, y+1]
                if right_id != 1:
                    # Right neighbour already has a segment: merge it into this one.
                    seg_output[seg_output == right_id] = seg_id
                else:
                    seg_output[x, y+1] = seg_id

            if 'rightbot' in dist:
                input_raster[x+1, y+1] = dist['rightbot']
                seg_output[x+1, y+1] = seg_id

            # NOTE(review): an already-linked lower / lower-left pixel is only
            # re-assigned when the new distance is LARGER than its stored one;
            # this mirrors the original rule — confirm it is intended.
            for key, (nx, ny) in (('bot', (x+1, y)), ('leftbot', (x+1, y-1))):
                if key in dist and (input_raster[nx, ny] == 0 or dist[key] > input_raster[nx, ny]):
                    input_raster[nx, ny] = dist[key]
                    seg_output[nx, ny] = seg_id

            if 'upright' in dist:
                upright_id = seg_output[x-1, y+1]
                if upright_id != 1:
                    # Upper-right neighbour's segment is merged into this one.
                    seg_output[seg_output == upright_id] = seg_id
                    input_raster[x-1, y+1] = dist['upright']

    # Pixel count of the segment each pixel belongs to, over the whole raster.
    # (`np.sum(img == img[:, None], axis=1)` counted equal ids per column only
    # and needed rows*rows*cols memory.)
    _, inverse, counts = np.unique(seg_output.ravel(), return_inverse=True, return_counts=True)
    segment_size = counts[inverse].reshape(seg_output.shape)
    seg_output[segment_size < minimum_area] = 1

    # Binary result: 1 = pixel of a kept segment, 0 = unassigned.
    img = (seg_output > 1).astype(float)
    print(datetime.datetime.now() - begin_time)
    plot_image(mark_boundaries(rgb_tiff/255, img.astype('int')))
    return img

### Example of how to run segmentation based on one image

In [None]:
# Segment the image at index 5 of the stack: differences below 0.15 link
# pixels, and segments need at least 5 pixels.
one_image_example = image_segmentation(stackx[:, :, 5], max_distance=0.15, minimum_area=5)

# Definition of the segmentation function based on Euclidean Distance (implemented as the sum of absolute differences between the pixels' time series)

### Much faster than DTW, but also less accurate

In [None]:
def euclidean_segmentation (input_stack, ed, minimum_area):
    """Region-growing segmentation of a multi-image stack indexed [row, col, time].

    Despite the name, the distance between two pixels is the sum of absolute
    differences of their time series (an L1 / Manhattan distance). Pixels are
    visited row by row. For each pixel, the distance to each in-bounds forward
    neighbour (upper-right, right, lower-right, lower, lower-left) is computed.
    Only if every one is below ``ed`` is the pixel linked to them:
    - it receives a new segment id if it has none;
    - the segments of its right / upper-right neighbours are merged into its own;
    - its id is written to the lower neighbours.
    Segments smaller than ``minimum_area`` pixels are discarded afterwards.

    Parameters
    ----------
    input_stack : 3-D array indexed [row, col, time]
        Stack to segment; NaN values are replaced by 0 first.
    ed : float
        Exclusive upper bound on the distance between linked pixels.
    minimum_area : int
        Segments with fewer pixels than this are dropped.

    Returns
    -------
    numpy.ndarray
        Float array of shape (rows, cols): 1 for pixels of kept segments,
        0 elsewhere. The elapsed time is printed and the mask is plotted over
        the global ``rgb_tiff``.
    """
    begin_time = datetime.datetime.now()
    input_stack = np.nan_to_num(input_stack, copy=True, nan=0.0, posinf=None, neginf=None)
    numrows = len(input_stack)
    numcols = len(input_stack[0])
    # Segment ids: 1 means "unassigned"; real segments are numbered from 2.
    seg_output = np.ones((numrows, numcols))
    # Distance with which each pixel was last linked (0 = not linked yet).
    input_raster = np.zeros((numrows, numcols))
    n = 1
    # Forward neighbours as (row offset, column offset).
    neighbour_offsets = {
        'upright': (-1, 1),
        'right': (0, 1),
        'rightbot': (1, 1),
        'bot': (1, 0),
        'leftbot': (1, -1),
    }

    def distance(a, b):
        # Sum of absolute differences between the two time series.
        return np.abs(a - b).sum()

    for x in range(numrows):
        for y in range(numcols):
            current = input_stack[x, y, :]
            dist = {}
            for key, (dx, dy) in neighbour_offsets.items():
                nx, ny = x + dx, y + dy
                # Out-of-bounds neighbours are skipped: a negative index would
                # silently wrap around to the opposite edge of the raster.
                if 0 <= nx < numrows and 0 <= ny < numcols:
                    dist[key] = distance(current, input_stack[nx, ny, :])
            if not dist or max(dist.values()) >= ed:
                continue

            if seg_output[x, y] == 1:
                n += 1
                seg_output[x, y] = n
            seg_id = seg_output[x, y]

            if 'right' in dist:
                input_raster[x, y+1] = dist['right']
                right_id = seg_output[x, y+1]
                if right_id != 1:
                    # Right neighbour already has a segment: merge it into this one.
                    seg_output[seg_output == right_id] = seg_id
                else:
                    seg_output[x, y+1] = seg_id

            if 'rightbot' in dist:
                input_raster[x+1, y+1] = dist['rightbot']
                seg_output[x+1, y+1] = seg_id

            # NOTE(review): an already-linked lower / lower-left pixel is only
            # re-assigned when the new distance is LARGER than its stored one;
            # this mirrors the original rule — confirm it is intended.
            for key, (nx, ny) in (('bot', (x+1, y)), ('leftbot', (x+1, y-1))):
                if key in dist and (input_raster[nx, ny] == 0 or dist[key] > input_raster[nx, ny]):
                    input_raster[nx, ny] = dist[key]
                    seg_output[nx, ny] = seg_id

            if 'upright' in dist:
                upright_id = seg_output[x-1, y+1]
                if upright_id != 1:
                    # Upper-right neighbour's segment is merged into this one.
                    seg_output[seg_output == upright_id] = seg_id
                    input_raster[x-1, y+1] = dist['upright']

    # Pixel count of the segment each pixel belongs to, over the whole raster.
    # (`np.sum(img == img[:, None], axis=1)` counted equal ids per column only
    # and needed rows*rows*cols memory.)
    _, inverse, counts = np.unique(seg_output.ravel(), return_inverse=True, return_counts=True)
    segment_size = counts[inverse].reshape(seg_output.shape)
    seg_output[segment_size < minimum_area] = 1

    # Binary result: 1 = pixel of a kept segment, 0 = unassigned.
    img = (seg_output > 1).astype(float)
    print(datetime.datetime.now() - begin_time)
    plot_image(mark_boundaries(rgb_tiff/255, img.astype('int')))
    return img

## Example of how to run the function: the arguments are the input stack, the maximum distance for segmentation and, last, the minimum size of one segment

In [None]:
# Segment the full time-series stack: summed absolute differences below 0.35
# link pixels, and segments need at least 5 pixels.
ed_segmentation_example = euclidean_segmentation(stackx, ed=0.35, minimum_area=5)

# Segmentation based on DTW using any number of images

In [None]:
def DTW_segmentation (input_stack, dtw, minimum_area):
    """Region-growing segmentation of a multi-image stack using DTW distances.

    The distance between two pixels is the ``fastdtw`` distance between their
    time series. Pixels are visited row by row. For each pixel, the distance
    to each in-bounds forward neighbour (upper-right, right, lower-right,
    lower, lower-left) is computed. Only if every one is below ``dtw`` is the
    pixel linked to them:
    - it receives a new segment id if it has none;
    - the segments of its right / upper-right neighbours are merged into its own;
    - its id is written to the lower neighbours.
    Segments smaller than ``minimum_area`` pixels are discarded afterwards.

    Parameters
    ----------
    input_stack : 3-D array indexed [row, col, time]
        Stack to segment; NaN values are replaced by 0 first.
    dtw : float
        Exclusive upper bound on the DTW distance between linked pixels.
        (Inside this function the name shadows the ``dtw`` imported from
        fastdtw; kept for callers.)
    minimum_area : int
        Segments with fewer pixels than this are dropped.

    Returns
    -------
    numpy.ndarray
        Float array of shape (rows, cols): 1 for pixels of kept segments,
        0 elsewhere. The elapsed time is printed and the mask is plotted over
        the global ``rgb_tiff``.
    """
    begin_time = datetime.datetime.now()
    input_stack = np.nan_to_num(input_stack, copy=True, nan=0.0, posinf=None, neginf=None)
    numrows = len(input_stack)
    numcols = len(input_stack[0])
    # Segment ids: 1 means "unassigned"; real segments are numbered from 2.
    seg_output = np.ones((numrows, numcols))
    # Distance with which each pixel was last linked (0 = not linked yet).
    input_raster = np.zeros((numrows, numcols))
    n = 1
    # Forward neighbours as (row offset, column offset).
    neighbour_offsets = {
        'upright': (-1, 1),
        'right': (0, 1),
        'rightbot': (1, 1),
        'bot': (1, 0),
        'leftbot': (1, -1),
    }

    def distance(a, b):
        # fastdtw returns (distance, warping path); only the distance is used.
        return fastdtw(a, b)[0]

    for x in range(numrows):
        for y in range(numcols):
            current = input_stack[x, y, :]
            dist = {}
            for key, (dx, dy) in neighbour_offsets.items():
                nx, ny = x + dx, y + dy
                # Out-of-bounds neighbours are skipped: a negative index would
                # silently wrap around to the opposite edge of the raster.
                if 0 <= nx < numrows and 0 <= ny < numcols:
                    dist[key] = distance(current, input_stack[nx, ny, :])
            if not dist or max(dist.values()) >= dtw:
                continue

            if seg_output[x, y] == 1:
                n += 1
                seg_output[x, y] = n
            seg_id = seg_output[x, y]

            if 'right' in dist:
                input_raster[x, y+1] = dist['right']
                right_id = seg_output[x, y+1]
                if right_id != 1:
                    # Right neighbour already has a segment: merge it into this one.
                    seg_output[seg_output == right_id] = seg_id
                else:
                    seg_output[x, y+1] = seg_id

            if 'rightbot' in dist:
                input_raster[x+1, y+1] = dist['rightbot']
                seg_output[x+1, y+1] = seg_id

            # NOTE(review): an already-linked lower / lower-left pixel is only
            # re-assigned when the new distance is LARGER than its stored one;
            # this mirrors the original rule — confirm it is intended.
            for key, (nx, ny) in (('bot', (x+1, y)), ('leftbot', (x+1, y-1))):
                if key in dist and (input_raster[nx, ny] == 0 or dist[key] > input_raster[nx, ny]):
                    input_raster[nx, ny] = dist[key]
                    seg_output[nx, ny] = seg_id

            if 'upright' in dist:
                upright_id = seg_output[x-1, y+1]
                if upright_id != 1:
                    # Upper-right neighbour's segment is merged into this one.
                    seg_output[seg_output == upright_id] = seg_id
                    input_raster[x-1, y+1] = dist['upright']

    # Pixel count of the segment each pixel belongs to, over the whole raster.
    # (`np.sum(img == img[:, None], axis=1)` counted equal ids per column only
    # and needed rows*rows*cols memory.)
    _, inverse, counts = np.unique(seg_output.ravel(), return_inverse=True, return_counts=True)
    segment_size = counts[inverse].reshape(seg_output.shape)
    seg_output[segment_size < minimum_area] = 1

    # Binary result: 1 = pixel of a kept segment, 0 = unassigned.
    img = (seg_output > 1).astype(float)
    print(datetime.datetime.now() - begin_time)
    plot_image(mark_boundaries(rgb_tiff/255, img.astype('int')))
    return img

In [None]:
# Segment the full time-series stack: DTW distances below 0.4 link pixels,
# and segments need at least 5 pixels.
DTW_segmentation_example = DTW_segmentation(stackx, dtw=0.4, minimum_area=5)

# Quickshift Segmentation (based on three images maximum)

In [None]:
# Quickshift on the 3-band composite; boundaries are drawn over the RGB image.
img = imgq.astype('double')
segments_quick = quickshift(img, ratio=0.8, kernel_size=4, max_dist=30, sigma=1)
segment_count = np.unique(segments_quick).size
plot_image(mark_boundaries(rgb_tiff/255, segments_quick))
print(f'Quickshift number of segments: {segment_count}')

# Watershed segmentation (based on a single image)

In [None]:
# Watershed on the Sobel gradient of the single image at index 5 of the stack.
# NaN (pixels never filled by interpolation) is replaced by 0, as in the
# custom segmentation functions.
img = np.nan_to_num(stackx[:, :, 5], nan=0.0).astype('double')
# The image is already single-band, so it goes to sobel directly: rgb2gray
# expects a 3-channel image and rejects 2-D input in current scikit-image.
gradient = sobel(img)
segments_watershed = watershed(gradient, markers=400, compactness=0.000006)
plot_image(mark_boundaries(rgb_tiff/255, segments_watershed))
print(f'Watershed number of segments: {len(np.unique(segments_watershed))}')

# SLIC segmentation (based on the RGB true-colour image)

In [None]:
# SLIC superpixels computed on the true-colour background image itself.
img = rgb_tiff.astype('float')
segments_slic = slic(img, n_segments=500, compactness=1000)
plot_image(mark_boundaries(rgb_tiff/255, segments_slic))
# The label was copy-pasted from the Quickshift cell.
print(f'SLIC number of segments: {len(np.unique(segments_slic))}')

# Felzenszwalb segmentation: also works on multiple images; the fastest and a well-working method

In [None]:
# Felzenszwalb graph-based segmentation on the whole multi-image stack.
img = stackx.astype('double')
segments_fz = felzenszwalb(img, sigma=0, scale=1500, min_size=5)
segment_total = np.unique(segments_fz).size
plot_image(mark_boundaries(rgb_tiff/255, segments_fz))
print(f'Felzenszwalb number of segments: {segment_total}')

# Import of the ground truth data, which needs to be a binary raster

In [None]:
# Ground-truth segmentation: band 1 of a binary raster, cast to int labels.
# NOTE(review): the raster presumably must match the rgb_tiff grid (same rows
# and columns) for the comparison below — confirm.
ds = gdal.Open("final_raster.tif")
if ds is None:
    # Unless GDAL exceptions are enabled, gdal.Open returns None on failure,
    # which would otherwise surface as an obscure AttributeError below.
    raise FileNotFoundError("Could not open ground-truth raster 'final_raster.tif'")
im_true = ds.GetRasterBand(1).ReadAsArray().astype('int')
plot_image(mark_boundaries(rgb_tiff/255, im_true))

# Validation of the segmentations

In [None]:
# Score the three custom segmentations against the ground truth `im_true`
# (adapted Rand error and split/merge variation of information), then plot the
# scores and the boundaries of every segmentation over the RGB image.
image = rgb_tiff

# NOTE(review): init_ls is not used anywhere in this cell.
init_ls = np.zeros(image.shape, dtype=np.int8)
init_ls[10:-10, 10:-10] = 1
im_test1 = DTW_segmentation_example.astype('int')
im_test2 = ed_segmentation_example.astype('int')
im_test3 = one_image_example.astype('int')

method_names = ['DTW segmentation', 'Euclidean Distance segmentation', 'Segmentation based on one image']

precision_list, recall_list, split_list, merge_list = [], [], [], []
for name, im_test in zip(method_names, (im_test1, im_test2, im_test3)):
    error, precision, recall = adapted_rand_error(im_true, im_test)
    splits, merges = variation_of_information(im_true, im_test)
    split_list.append(splits)
    merge_list.append(merges)
    precision_list.append(precision)
    recall_list.append(recall)
    print(f"\n## Method: {name}")
    print(f"Adapted Rand error: {error}")
    print(f"Adapted Rand precision: {precision}")
    print(f"Adapted Rand recall: {recall}")
    print(f"False Splits: {splits}")
    print(f"False Merges: {merges}")

short_method_names = ['DTW', 'ED', '1Im']

fig, axes = plt.subplots(2, 3, figsize=(9, 6), constrained_layout=True)
ax = axes.ravel()

# Two scatter panels: (axis, x values, y values, x label, y label, title).
score_panels = (
    (ax[0], merge_list, split_list,
     'False Merges (bits)', 'False Splits (bits)', 'Split Variation of Information'),
    (ax[1], precision_list, recall_list,
     'Precision', 'Recall', 'Adapted Rand precision vs. recall'),
)
for panel, xs, ys, xlabel, ylabel, title in score_panels:
    panel.scatter(xs, ys)
    for txt, xv, yv in zip(short_method_names, xs, ys):
        panel.annotate(txt, (xv, yv), verticalalignment='center')
    panel.set_xlabel(xlabel)
    panel.set_ylabel(ylabel)
    panel.set_title(title)
ax[1].set_xlim(0, 1)
ax[1].set_ylim(0, 1)

# Four boundary panels, filling ax[2] .. ax[5] in order.
boundary_panels = (
    (im_true, 'True Segmentation'),
    (im_test1, 'DTW Segmentation'),
    (im_test2, 'Euclidean Distance Segmentation'),
    (im_test3, 'One Image Segmentation'),
)
for panel, (segmentation, title) in zip(ax[2:], boundary_panels):
    panel.imshow(mark_boundaries(image, segmentation))
    panel.set_title(title)
    panel.set_axis_off()

plt.show()

# Export of the chosen segmentation output

In [None]:
def export_output (input, output_name):
    """Write a 2-D array to a single-band Float32 GeoTIFF georeferenced to the AOI.

    The raster is spread over ``bbox_coords_wgs84`` (``[xmin, ymin, xmax, ymax]``)
    in EPSG:4326, with its first row at the ``ymax`` edge.

    Parameters
    ----------
    input : 2-D array
        Raster to write, e.g. a segmentation mask. (The name shadows the
        ``input`` builtin inside this function; it is kept for callers.)
    output_name : str
        Path of the GeoTIFF to create.

    Raises
    ------
    RuntimeError
        If the GTiff driver cannot create the output dataset.
    """
    xmin, ymin, xmax, ymax = bbox_coords_wgs84
    nrows, ncols = np.shape(input)
    # Pixel size in degrees; the negative y step puts row 0 at the top (ymax).
    xres = (xmax - xmin) / float(ncols)
    yres = (ymax - ymin) / float(nrows)
    geotransform = (xmin, xres, 0, ymax, 0, -yres)

    output_raster = gdal.GetDriverByName('GTiff').Create(output_name, ncols, nrows, 1, gdal.GDT_Float32)
    if output_raster is None:
        # Without GDAL exceptions enabled, Create signals failure by returning None.
        raise RuntimeError(f'GDAL could not create {output_name!r}')
    output_raster.SetGeoTransform(geotransform)
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(4326)
    output_raster.SetProjection(srs.ExportToWkt())
    output_raster.GetRasterBand(1).WriteArray(input)
    output_raster.FlushCache()
    # Dereference the dataset explicitly so GDAL closes it and finalises the file.
    output_raster = None

# Example of how to run the export

In [None]:
# Write the DTW segmentation mask to a GeoTIFF in the working directory.
export_output(DTW_segmentation_example, output_name='dtw_segmentation_example.tif')