In [None]:
%matplotlib inline
import matplotlib.image as mpimg
import numpy as np
import matplotlib.pyplot as plt
import os,sys
from scipy import ndimage
from PIL import Image
from sklearn.cluster import KMeans
from sklearn import linear_model
from sklearn import svm
from sklearn import preprocessing as prp

In [None]:
from helpers_img import *
from Post_processing import *

In [None]:
# Helper functions
def load_image(infilename):
    """Read an image file from disk and return it as a numpy array (matplotlib.image.imread)."""
    return mpimg.imread(infilename)

def img_float_to_uint8(img):
    """Linearly rescale an array to the full uint8 range [0, 255].

    The minimum value maps to 0 and the maximum to 255. A constant input
    (max == min) maps to all zeros instead of dividing by zero (which used to
    produce NaN before the uint8 cast, e.g. for an all-zero predicted mask).

    Args:
        img: numeric numpy array of any shape.

    Returns:
        uint8 array with the same shape as ``img``.
    """
    rimg = img - np.min(img)
    peak = np.max(rimg)
    if peak > 0:
        # Same operation order as before: normalise to [0, 1], then scale.
        rimg = rimg / peak * 255
    return rimg.round().astype(np.uint8)

# Place an image and its ground truth side by side
def concatenate_images(img, gt_img):
    """Concatenate ``img`` and ``gt_img`` along axis 1.

    A 3-D ``gt_img`` is concatenated unchanged. A 2-D ``gt_img`` is rescaled to
    uint8 and replicated into three identical channels, and ``img`` is rescaled
    to uint8 as well, so both halves share the same dtype.
    """
    if len(gt_img.shape) == 3:
        return np.concatenate((img, gt_img), axis=1)

    gt_gray8 = img_float_to_uint8(gt_img)
    gt_rgb = np.zeros((gt_img.shape[0], gt_img.shape[1], 3), dtype=np.uint8)
    for channel in range(3):
        gt_rgb[:, :, channel] = gt_gray8
    return np.concatenate((img_float_to_uint8(img), gt_rgb), axis=1)

def img_crop(im, w, h):
    """Split an image into w x h patches.

    Patches are emitted band by band: the outer loop steps along axis 1 by
    ``h`` and the inner loop steps along axis 0 by ``w``. Works for 2-D
    (mask/grayscale) and 3-D (multi-channel) arrays; edge patches are smaller
    when the image size is not a multiple of the patch size.
    """
    n_rows, n_cols = im.shape[0], im.shape[1]
    return [
        im[r:r + w, c:c + h]
        for c in range(0, n_cols, h)
        for r in range(0, n_rows, w)
    ]

def value_to_class(v, threshold=None):
    """Map a patch's foreground score to a binary class label.

    Args:
        v: scalar or array; its sum is compared with the threshold (callers in
            this notebook pass ``np.mean(patch)``).
        threshold: cut-off; the result is 1 when ``sum(v) > threshold``.
            Defaults to the module-level ``foreground_threshold`` so existing
            calls behave exactly as before.

    Returns:
        1 for foreground (road), 0 otherwise.
    """
    if threshold is None:
        threshold = foreground_threshold
    return 1 if np.sum(v) > threshold else 0
    
# Convert a flat array of patch labels back into a pixel image

def label_to_img(imgwidth, imgheight, w, h, labels):
    """Paint patch labels into a float image of shape (imgwidth, imgheight).

    ``labels`` must follow ``img_crop``'s patch order: outer loop over axis 1
    in steps of ``h``, inner loop over axis 0 in steps of ``w``.
    """
    im = np.zeros([imgwidth, imgheight])
    positions = ((r, c) for c in range(0, imgheight, h) for r in range(0, imgwidth, w))
    for idx, (r, c) in enumerate(positions):
        im[r:r + w, c:c + h] = labels[idx]
    return im

def make_img_overlay(img, predicted_img):
    """Blend a red prediction mask over an RGB image.

    Pixels where ``predicted_img`` is 1 are tinted red; the mask layer is
    blended with weight 0.2. Returns a PIL RGBA image.
    """
    red_mask = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8)
    red_mask[:, :, 0] = predicted_img * 255

    base = Image.fromarray(img_float_to_uint8(img), 'RGB').convert("RGBA")
    tint = Image.fromarray(red_mask, 'RGB').convert("RGBA")
    return Image.blend(base, tint, 0.2)

In [None]:
def rotation(orig, gts):
    """Augment image and ground-truth lists with 90/180/270 degree rotations.

    Returns two new lists: the originals first, then the three rotations of
    item 0, then of item 1, and so on. Both lists are extended in the same
    order, so image/ground-truth pairs stay aligned. Prints both lengths.
    """
    angles = [90, 180, 270]
    rotated_imgs = []
    for img in orig:
        rotated_imgs.extend(ndimage.rotate(img, angle) for angle in angles)
    rotated_gts = []
    for gt_img in gts:
        rotated_gts.extend(ndimage.rotate(gt_img, angle) for angle in angles)
    all_imgs = orig + rotated_imgs
    all_gts = gts + rotated_gts
    print(len(all_imgs))
    print(len(all_gts))
    return all_imgs, all_gts

def add_gray_dimension(img):
    """Return the weighted grayscale (0.299 R + 0.587 G + 0.114 B) as an (H, W, 1) array.

    Only the first three entries of the last axis are used.
    """
    weights = [0.299, 0.587, 0.114]
    gray = np.dot(img[..., :3], weights)
    return gray.reshape(gray.shape[0], gray.shape[1], 1)

def add_laplacian(img):
    """Compute Laplacian edge responses of an image.

    Returns:
        (Laplacian of the (H, W, 1) grayscale image, Laplacian of ``img``).

    NOTE(review): ``ndimage.laplace`` sums second derivatives over every axis,
    so for a multi-channel ``img`` the channel axis contributes too; kept as-is
    to preserve the learned feature values.
    """
    # ``scipy.ndimage.filters`` is a deprecated alias namespace; the function
    # lives directly on ``scipy.ndimage``.
    laplbew = ndimage.laplace(add_gray_dimension(img))
    lapl = ndimage.laplace(img)
    return laplbew, lapl

def add_sobel(img):
    """Sobel gradient magnitude over the first two axes (zero padding at the borders)."""
    gradients = [ndimage.sobel(img, axis=ax, mode='constant') for ax in (0, 1)]
    return np.hypot(*gradients)

def add_segment(im, rng=None):
    """Coarse foreground/background segmentation channel.

    Steps:
    1. Heavy Gaussian blur, then threshold at the mean.
    2. Add Gaussian noise and binarise at 0.5.
    3. Morphological opening followed by a closing.
    4. Collapse channels with ``add_gray_dimension``.

    Args:
        im: image array. Presumably (H, W, C) with C >= 3, because
            ``add_gray_dimension`` reads the first three entries of the last
            axis -- TODO confirm.
        rng: optional ``numpy.random.Generator`` for the noise term. When
            None, the global ``np.random`` state is used, as before. That makes
            the feature non-reproducible unless the global seed is set.

    Returns:
        The result of ``add_gray_dimension`` on the cleaned binary mask.
    """
    n = 10
    l = 256
    # NOTE(review): sigma is a scalar, so the blur also runs along the channel
    # axis of a multi-channel image; kept as-is to preserve feature values.
    im = ndimage.gaussian_filter(im, sigma=l / (4. * n))
    # ``np.float`` was removed in NumPy 1.24; builtin ``float`` is the same dtype.
    mask = (im > im.mean()).astype(float)
    mask += 0.1 * im
    if rng is None:
        noise = np.random.randn(*mask.shape)
    else:
        noise = rng.standard_normal(mask.shape)
    img = mask + 0.2 * noise
    binary_img = img > 0.5
    open_img = ndimage.binary_opening(binary_img)
    # Remove small holes left after the opening
    close_img = ndimage.binary_closing(open_img)
    return add_gray_dimension(close_img)

In [None]:
def compute_F1(Y, Z):
    """F1 score of binary predictions ``Z`` against reference labels ``Y``.

    ``Y`` is rounded to the nearest integer before comparison, with halves going
    to even, like Python's ``round``. Every sample that is not a true negative,
    false negative or true positive is counted as a false positive, as in the
    original loop.

    Args:
        Y: 1-D sequence of reference labels (0/1, possibly as floats).
        Z: 1-D sequence of predicted labels (0/1), same length as ``Y``.

    Returns:
        The F1 score as a float. When there are no true positives, precision or
        recall is 0 or undefined, so 0.0 is returned. The old version raised
        ``ZeroDivisionError`` in that case.

    Raises:
        ValueError: if ``Y`` and ``Z`` have different lengths.
    """
    y_true = np.round(np.asarray(Y, dtype=float)).ravel()
    y_pred = np.asarray(Z).ravel()
    if y_true.shape[0] != y_pred.shape[0]:
        raise ValueError("Y and Z must have the same length")

    tn = np.count_nonzero((y_true == 0) & (y_pred == 0))
    fn = np.count_nonzero((y_true == 1) & (y_pred == 0))
    tp = np.count_nonzero((y_true == 1) & (y_pred == 1))
    fp = y_true.shape[0] - tn - fn - tp

    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

In [None]:
def add_features(img):
    """Stack extra per-pixel feature channels onto an image along axis 2.

    Channel order in the result:
    1. the original channels
    2. grayscale
    3. Sobel magnitude
    4. Laplacian of the grayscale image
    5. Laplacian of the image
    6. segmentation mask
    """
    gray = add_gray_dimension(img)
    sob = add_sobel(img)
    lapbew, lap = add_laplacian(img)
    seg = add_segment(img)
    return np.concatenate((img, gray, sob, lapbew, lap, seg), axis=2)

In [None]:
def extract_features(img):
    """Feature vector for one patch: cubic polynomial expansion of channel statistics.

    The per-channel mean and variance over the two spatial axes are joined,
    means first, into a single row. That row is expanded with
    ``PolynomialFeatures(3)``, giving bias, linear, quadratic and cubic terms.
    Returns a 1-D array.
    """
    means = np.mean(img, axis=(0, 1))
    variances = np.var(img, axis=(0, 1))
    stats_row = np.append(means, variances).reshape(1, -1)
    expanded = prp.PolynomialFeatures(3).fit_transform(stats_row)
    return expanded.reshape(-1,)

In [None]:
def post_processing(label, threshold, size_min, verbarg, horbarg):
    """Clean one image's flat list of patch labels.

    Pipeline:
    1. Complete rows and columns with at least ``threshold`` ROAD labels.
    2. Drop connected components smaller than ``size_min``.
    3. Clean stray labels beside full columns, then beside full rows.
    4. Drop small components once more.
    """
    filled = complete_lines(label, threshold)
    pruned = remove_isolated_connected_component(filled, size_min)
    vert_clean = clean_garbage_vert(pruned, verbarg)
    both_clean = clean_garbage_hor(vert_clean, horbarg)
    return remove_isolated_connected_component(both_clean, size_min)

In [None]:
# Obtain the training set:
# load the first n images (at most 75) and their ground-truth masks.
# NOTE(review): os.listdir order is arbitrary; the test-set cell relies on the
# same listing order, so sort both listings together if this ever changes.
root_dir = "training/"
image_dir = root_dir + "images/"
gt_dir = root_dir + "groundtruth/"
files = os.listdir(image_dir)
n = min(75, len(files))  # number of training images (capped at 75)
imgs = [load_image(image_dir + name) for name in files[:n]]
gt_imgs = [load_image(gt_dir + name) for name in files[:n]]

# Augmentation: add 90/180/270 degree rotations of every image/mask pair,
# then append the extra feature channels to every image.
imgs, gt_imgs = rotation(imgs, gt_imgs)
imgs_augm = [add_features(img) for img in imgs]

In [None]:
foreground_threshold = 0.25  # a patch is road (1) when the mean of its ground-truth pixels exceeds this
patch_size = 16  # each patch is 16x16 pixels

# Cut every augmented image and every mask into patches, then flatten the
# per-image lists into a single array of patches each.
img_patches = np.asarray([patch
                          for img in imgs_augm
                          for patch in img_crop(img, patch_size, patch_size)])
gt_patches = np.asarray([patch
                         for gt in gt_imgs
                         for patch in img_crop(gt, patch_size, patch_size)])

In [None]:
# One feature vector and one 0/1 label per training patch
X = np.asarray([extract_features(patch) for patch in img_patches])
Y = np.asarray([value_to_class(np.mean(patch)) for patch in gt_patches])

In [None]:
# Weakly regularised logistic regression (C=1e5); class_weight="balanced"
# reweights classes inversely to their frequency. The F1 below is on the training data.
logreg = linear_model.LogisticRegression(C=1e5, class_weight="balanced")
Z = logreg.fit(X, Y).predict(X)
print('F1_score = ' + str(compute_F1(Y, Z)))

In [None]:
# Obtain the held-out test set: the training-folder images not used for training.
root = "training/"
image_dir = root + "images/"
gt_dir = root + "groundtruth/"
files = os.listdir(image_dir)
# Indices 0..n-1 were used for training, so the test set starts at n.
# It previously started at n+1, which silently dropped image n from both sets.
test_files = files[n:]
imgs_te = [load_image(image_dir + name) for name in test_files]
gt_imgs_te = [load_image(gt_dir + name) for name in test_files]

imgs_te_aug = [add_features(img) for img in imgs_te]

In [None]:
# Cut the augmented test images and their masks into patches (flattened over all images)
img_patches_te = np.asarray([patch
                             for img in imgs_te_aug
                             for patch in img_crop(img, patch_size, patch_size)])
gt_patches_te = np.asarray([patch
                            for gt in gt_imgs_te
                            for patch in img_crop(gt, patch_size, patch_size)])

In [None]:
# Features and reference labels for the test patches
X_te = np.asarray([extract_features(patch) for patch in img_patches_te])
Y_te = np.asarray([value_to_class(np.mean(patch)) for patch in gt_patches_te])

# Patch-level F1 of the logistic regression on the held-out images
Z_te = logreg.predict(X_te)
print('F1_score = ' + str(compute_F1(Y_te, Z_te)))

In [None]:
# Post-process the predictions one image at a time.
# Each image yields a 25x25 grid of 16x16 patches (625 labels, the grid size
# the post-processing helpers are used with). The loop must run once per
# test image. It previously ran once per *patch* (len(gt_patches_te)), which
# made 625x too many calls on empty slices.
patches_per_img = 625
Z_pp = []
for i in range(len(gt_imgs_te)):
    img_labels = Z_te[i * patches_per_img:(i + 1) * patches_per_img]
    Z_pp.extend(post_processing(img_labels, 18, 9, 3, 3))

print('F1_score = ' + str(compute_F1(Y_te, Z_pp)))

In [None]:
img_idx = 7
w = gt_imgs_te[0].shape[0]
h = gt_imgs_te[0].shape[1]
img_labels = Z_te[img_idx * 625:(img_idx + 1) * 625]  # 625 = 25x25 patches per image
predicted_im = label_to_img(w, h, patch_size, patch_size, img_labels)

# Two separate panels: both images used to be drawn onto the same axes, so the
# second imshow painted over the side-by-side comparison.
fig1, (ax_side, ax_overlay) = plt.subplots(1, 2, figsize=(20, 10))
cimg = concatenate_images(imgs_te[img_idx], predicted_im)
ax_side.imshow(cimg, cmap='Greys_r')
ax_side.set_title('Test image | predicted mask')

new_img = make_img_overlay(imgs_te[img_idx], predicted_im)
ax_overlay.imshow(new_img)
ax_overlay.set_title('Prediction overlay');

In [None]:
img_slice = slice(img_idx * 625, (img_idx + 1) * 625)  # labels of the selected test image
print('F1_score = ' + str(compute_F1(Y_te[img_slice], Z_te[img_slice])))

In [None]:
# Raw logistic-regression labels (25x25 = 625 patches) of the selected test image
result_logistic = Z_te[625 * img_idx:625 * (img_idx + 1)]

# End of the logistic-regression section

In [None]:
def complete_lines(label, threshold):
    '''Mark nearly complete columns and rows of the patch grid as ROAD.

    The flat label vector is reshaped column-major (``order='F'``) into a
    square grid. First, every column with at least ``threshold`` ROAD (1)
    labels becomes all ROAD. Then, on the updated grid, every row with at
    least ``threshold`` ROAD labels becomes all ROAD.

    INPUT:  label     - flat sequence of 0/1 patch labels whose length is a perfect square
            threshold - minimum number of ROAD labels for a line to be completed
    OUTPUT: new flat list of labels, in the same column-major order
    Raises: ValueError if the number of labels is not a perfect square.
    '''
    label = np.array(label)
    # Use .size: int(np.sqrt(label.shape)) converts a 1-element array to a
    # scalar, which is deprecated since NumPy 1.25.
    side = int(np.sqrt(label.size))
    if side * side != label.size:
        raise ValueError("number of labels must be a perfect square")
    matrix_label = label.reshape((side, side), order='F')

    # Columns with at least `threshold` ROAD labels are filled in
    matrix_label[:, matrix_label.sum(axis=0) >= threshold] = 1

    # Rows are checked afterwards, so they see the completed columns
    matrix_label[matrix_label.sum(axis=1) >= threshold, :] = 1

    # Flatten back in the same column-major order
    return matrix_label.flatten(order='F').tolist()

In [None]:
new_label = complete_lines(result_logistic, 18)

# Overlay the line-completed labels on the selected test image
mask_res = label_to_img(400, 400, 16, 16, new_label)
image_plot = make_img_overlay(imgs_te[img_idx], mask_res)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image_plot)

In [None]:
def remove_isolated_connected_component(label, size_min):
    '''Drop ROAD components smaller than ``size_min`` patches.

    The flat label vector is reshaped column-major into a square grid.
    Connected components of non-zero labels are found with
    ``scipy.ndimage.label``, whose default structuring element gives
    4-connectivity in 2-D. Every component with fewer than ``size_min`` cells
    is set to 0.

    INPUT:  label    - flat sequence of 0/1 labels whose length is a perfect square
            size_min - minimum component size (in patches) to keep
    OUTPUT: new flat list of labels, in the same column-major order
    Raises: ValueError if the number of labels is not a perfect square.
    '''
    labels_arr = np.array(label)
    side = int(np.sqrt(labels_arr.size))
    if side * side != labels_arr.size:
        raise ValueError("number of labels must be a perfect square")
    matrix_label = labels_arr.reshape((side, side), order='F')

    # ``ndimage.measurements`` is a deprecated namespace; use ``ndimage.label``.
    components, n_components = ndimage.label(matrix_label)
    # Sizes of all components in one pass (index 0 is the background).
    sizes = np.bincount(components.ravel(), minlength=n_components + 1)
    too_small = sizes < size_min
    too_small[0] = False  # background cells are already 0
    matrix_label[too_small[components]] = 0

    return matrix_label.flatten(order='F').tolist()


In [None]:
new_label2 = remove_isolated_connected_component(new_label, 9)

# Overlay the labels after removing components smaller than 9 patches
mask_res = label_to_img(400, 400, 16, 16, new_label2)
image_plot = make_img_overlay(imgs_te[img_idx], mask_res)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image_plot)

In [None]:
def complete_lines_almostfull(label, max_zeros=3):
    ''' Close short gaps inside ROAD segments of every column and row.

    The flat label vector is reshaped column-major into a square grid.
    Each column is scanned top to bottom. After the first ROAD (1) cell, any
    run of 0 cells that is followed by another 1 and is shorter than
    ``max_zeros`` is set to 1. Rows are then scanned left to right on the
    updated grid in the same way. Zeros before the first 1 or after the last
    1 of a line are left untouched.

    INPUT:  label     - flat sequence of 0/1 labels whose length is a perfect square
            max_zeros - gaps of length < max_zeros are filled (default 3, as before)
    OUTPUT: new flat list of labels, in the same column-major order
    Raises: ValueError if the number of labels is not a perfect square.
    '''

    def _fill_short_gaps(line):
        # Fill, in place, every interior run of zeros shorter than max_zeros.
        started = False
        gap = 0
        for idx, value in enumerate(line):
            if value == 1:
                if started and gap > 0:
                    if gap < max_zeros:
                        line[idx - gap:idx] = 1
                    gap = 0
                started = True
            elif value == 0 and started:
                gap += 1

    labels_arr = np.array(label)
    side = int(np.sqrt(labels_arr.size))
    if side * side != labels_arr.size:
        raise ValueError("number of labels must be a perfect square")
    matrix_label = labels_arr.reshape((side, side), order='F')

    # Slices are views, so the helper's edits land in matrix_label.
    for column in range(matrix_label.shape[1]):
        _fill_short_gaps(matrix_label[:, column])
    for row in range(matrix_label.shape[0]):
        _fill_short_gaps(matrix_label[row, :])

    return matrix_label.flatten(order='F').tolist()

In [None]:
# Disabled experiment: gap filling with complete_lines_almostfull.
# NOTE(review): `image` is not defined in any visible cell of this notebook;
# use imgs_te[img_idx] (as the other display cells do) if re-enabling this.
#new_label3 = complete_lines_almostfull(new_label2)

# DISPLAY THE IMAGE
#mask_res = label_to_img(400, 400, 16, 16, new_label3)
#image_plot = make_img_overlay(image, mask_res)
#plt.figure(figsize=(10, 10))
#plt.imshow(image_plot)

In [None]:
def clean_garbage_vert(label, max_distance):
    '''Remove stray ROAD labels in the columns next to fully-ROAD columns.

    The flat label vector is reshaped column-major into a square n x n grid.
    For every column that is entirely ROAD (a "full" column), examine the
    band of ``max_distance`` columns on its right and/or left. A side is only
    examined when the adjacent column exists and is not itself full. In each
    grid row of a band, if the band holds fewer than ``max_distance`` ROAD
    cells there, that row's band cells are set to 0.

    Which sides are cleaned depends on the column's position:
    - closer than ``max_distance`` to the left border: right side only
    - closer than ``max_distance`` to the right border: left side only
    - elsewhere: both sides

    Changes from the earlier version:
    - The grid size comes from the input instead of a hard-coded 25.
    - The right band is cleared over the same columns that are counted.
      Previously its last column was never cleared, unlike the left band.
    - Neighbour checks are bounds-guarded. The old code raised IndexError
      for a full last column when max_distance == 1.

    INPUT:  label        - flat sequence of 0/1 labels whose length is a perfect square
            max_distance - width (in patches) of the band cleaned next to a full column
    OUTPUT: new flat list of labels, in the same column-major order
    Raises: ValueError if the number of labels is not a perfect square.
    '''
    labels_arr = np.array(label)
    n = int(np.sqrt(labels_arr.size))
    if n * n != labels_arr.size:
        raise ValueError("number of labels must be a perfect square")
    matrix_label = labels_arr.reshape((n, n), order='F')

    def _clean_band(start, stop):
        # Zero, per grid row, the band columns [start, stop) when that row's
        # band holds fewer than max_distance ROAD cells.
        band = matrix_label[:, start:stop]  # a view into matrix_label
        sparse_rows = band.sum(axis=1) < max_distance
        band[sparse_rows, :] = 0

    full_columns = np.where(matrix_label.sum(axis=0) == n)[0]
    for column in full_columns:
        right_open = column + 1 < n and matrix_label[:, column + 1].sum() < n
        left_open = column >= 1 and matrix_label[:, column - 1].sum() < n
        if column < max_distance:
            if right_open:
                _clean_band(column + 1, column + max_distance + 1)
        elif column > n - max_distance:
            if left_open:
                _clean_band(column - max_distance, column)
        else:
            if right_open:
                _clean_band(column + 1, column + max_distance + 1)
            if left_open:
                _clean_band(column - max_distance, column)

    return matrix_label.flatten(order='F').tolist()

In [None]:
new_label4 = clean_garbage_vert(new_label2, 3)

# Overlay the labels after cleaning next to full columns
mask_res = label_to_img(400, 400, 16, 16, new_label4)
image_plot = make_img_overlay(imgs_te[img_idx], mask_res)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image_plot)

In [None]:
def clean_garbage_hor(label, max_distance):
    '''Remove stray ROAD labels in the rows next to fully-ROAD rows.

    The flat label vector is reshaped column-major into a square n x n grid.
    For every row that is entirely ROAD (a "full" row), examine the band of
    ``max_distance`` rows below and/or above it. A side is only examined when
    the adjacent row exists and is not itself full. In each grid column of a
    band, if the band holds fewer than ``max_distance`` ROAD cells there,
    that column's band cells are set to 0.

    Which sides are cleaned depends on the row's position:
    - closer than ``max_distance`` to the top: the band below only
    - closer than ``max_distance`` to the bottom: the band above only
    - elsewhere: both bands

    Changes from the earlier version:
    - The grid size comes from the input instead of a hard-coded 25.
    - The lower band is cleared over the same rows that are counted.
      Previously its last row was never cleared, unlike the upper band.
    - Neighbour checks are bounds-guarded. The old code raised IndexError
      for a full last row when max_distance == 1.

    INPUT:  label        - flat sequence of 0/1 labels whose length is a perfect square
            max_distance - height (in patches) of the band cleaned next to a full row
    OUTPUT: new flat list of labels, in the same column-major order
    Raises: ValueError if the number of labels is not a perfect square.
    '''
    labels_arr = np.array(label)
    n = int(np.sqrt(labels_arr.size))
    if n * n != labels_arr.size:
        raise ValueError("number of labels must be a perfect square")
    matrix_label = labels_arr.reshape((n, n), order='F')

    def _clean_band(start, stop):
        # Zero, per grid column, the band rows [start, stop) when that column's
        # band holds fewer than max_distance ROAD cells.
        band = matrix_label[start:stop, :]  # a view into matrix_label
        sparse_cols = band.sum(axis=0) < max_distance
        band[:, sparse_cols] = 0

    full_rows = np.where(matrix_label.sum(axis=1) == n)[0]
    for row in full_rows:
        down_open = row + 1 < n and matrix_label[row + 1, :].sum() < n
        up_open = row >= 1 and matrix_label[row - 1, :].sum() < n
        if row < max_distance:
            if down_open:
                _clean_band(row + 1, row + max_distance + 1)
        elif row > n - max_distance:
            if up_open:
                _clean_band(row - max_distance, row)
        else:
            if down_open:
                _clean_band(row + 1, row + max_distance + 1)
            if up_open:
                _clean_band(row - max_distance, row)

    return matrix_label.flatten(order='F').tolist()

In [None]:
new_label5 = clean_garbage_hor(new_label4, 3)

# Overlay the labels after cleaning next to full rows
mask_res = label_to_img(400, 400, 16, 16, new_label5)
image_plot = make_img_overlay(imgs_te[img_idx], mask_res)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image_plot)

In [None]:
new_label6 = remove_isolated_connected_component(new_label5, 9)

# Overlay the labels after the final small-component removal
mask_res = label_to_img(400, 400, 16, 16, new_label6)
image_plot = make_img_overlay(imgs_te[img_idx], mask_res)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image_plot)

In [None]:
def post_processing(label, threshold, size_min, verbarg, horbarg):
    """Apply the full label-cleaning pipeline to one image's flat patch labels.

    NOTE(review): this redefines ``post_processing`` from an earlier cell with
    the same pipeline. Whichever definition ran last is the one in effect.
    """
    pipeline = (
        (complete_lines, threshold),
        (remove_isolated_connected_component, size_min),
        (clean_garbage_vert, verbarg),
        (clean_garbage_hor, horbarg),
        (remove_isolated_connected_component, size_min),
    )
    for step, param in pipeline:
        label = step(label, param)
    return label

In [None]:
def calcul_F1(mask, prediction):
    '''Pixel-level F1 score of a 2-D prediction against a 2-D ground-truth mask.

    ``mask`` is rounded to the nearest integer before comparison, with halves
    going to even, like ``round``. Every pixel that is not a true negative,
    false negative or true positive counts as a false positive.

    INPUT:  mask       - 2-D array of ground-truth values (0/1 or in [0, 1])
            prediction - 2-D array of 0/1 predictions with the same shape as ``mask``
    OUTPUT: F1 score (float), or 0 when there is no true positive
    '''
    truth = np.round(np.asarray(mask, dtype=float))
    pred = np.asarray(prediction)
    # Vectorised counts replace the former per-pixel Python double loop.
    TN = np.count_nonzero((truth == 0) & (pred == 0))
    FN = np.count_nonzero((truth == 1) & (pred == 0))
    TP = np.count_nonzero((truth == 1) & (pred == 1))
    FP = truth.size - TN - FN - TP

    F1_score = 0
    # TP > 0 implies TP+FP > 0, TP+FN > 0 and precision+recall > 0.
    if TP > 0:
        precision = TP / (TP + FP)
        recall = TP / (TP + FN)
        F1_score = 2 * precision * recall / (precision + recall)
    return F1_score

In [None]:
# Rasterise the cleaned patch labels into a 400x400 pixel mask and score it against the ground truth
prevision = label_to_img(400, 400, 16, 16, new_label5)
F1 = calcul_F1(gt_imgs_te[img_idx], prevision)

In [None]:
print(f"{F1}")

In [None]:
def create_model(img, gt, n_cluster, patch_size, cluster_threshold=0.3):
    """Fit a K-means model on one image's patches and label each cluster.

    Each patch of ``img`` becomes a feature vector (``extract_features``).
    Each ground-truth patch becomes a 0/1 label via ``value_to_class`` on the
    patch mean, which reads the module-level ``foreground_threshold``. After
    K-means clustering, a cluster is labelled ROAD (1) when the fraction of
    ROAD patches assigned to it is at least ``cluster_threshold`` (default
    0.3, as before).

    Returns:
        (fitted KMeans model, int array of per-cluster labels of length n_cluster)
    """
    img_patches = img_crop(img, patch_size, patch_size)
    gt_patches = img_crop(gt, patch_size, patch_size)

    # The former local ``foreground_threshold = 0.25`` was dead code:
    # value_to_class reads the module-level value, not a local of this function.
    Y = np.asarray([value_to_class(np.mean(patch)) for patch in gt_patches])
    X = np.asarray([extract_features(patch) for patch in img_patches])

    model = KMeans(n_clusters=n_cluster, random_state=2, init='k-means++', n_init=20).fit(X)

    cluster_ids = model.labels_
    road_per_cluster = np.bincount(cluster_ids, weights=Y, minlength=n_cluster)
    size_per_cluster = np.bincount(cluster_ids, minlength=n_cluster)
    # An empty cluster gives 0/0 = nan; nan >= threshold is False -> label 0 (as before).
    with np.errstate(divide='ignore', invalid='ignore'):
        road_fraction = np.divide(road_per_cluster, size_per_cluster)
        clusters_label = 1 * (road_fraction >= cluster_threshold)

    return model, clusters_label

In [None]:
def assign_label(patch, models):
    """Predict a patch's label by voting over per-image K-means models.

    Each ``(kmeans, cluster_labels)`` entry votes with the label of the cluster
    the patch falls into. The patch is ROAD (1) when the mean vote is >= 0.3.
    """
    features = np.asarray(extract_features(patch)).reshape(1, -1)
    votes = [model[1][model[0].predict(features)] for model in models]
    return 1 * (np.mean(votes) >= 0.3)

In [None]:
def calculate_accuracy(img, gt, models):
    """Patch-level F1 of the K-means ensemble on one image.

    Uses the module-level ``patch_size``. Returns ``(F1, labels)``, where
    ``labels`` is the list of predicted patch labels in ``img_crop`` order.
    """
    img_tiles = img_crop(img, patch_size, patch_size)
    gt_tiles = img_crop(gt, patch_size, patch_size)

    reference = np.asarray([value_to_class(np.mean(tile)) for tile in gt_tiles])
    predicted = [assign_label(tile, models) for tile in img_tiles]
    return compute_F1(reference, predicted), predicted

In [None]:
# Fit one K-means model per training image (rotations included);
# each entry of `models` is a (model, cluster_labels) pair.
k = 10
models = [create_model(img, gt, k, patch_size) for img, gt in zip(imgs, gt_imgs)]

In [None]:
# Evaluate the K-means ensemble on the selected test image
img_test, gt_test = imgs_te[img_idx], gt_imgs_te[img_idx]
F1_score, result_k = calculate_accuracy(img_test, gt_test, models)

In [None]:
# Patch-level F1 of the K-means ensemble on the selected test image
print(f'F1 score = {F1_score}')

In [None]:
# DISPLAY THE IMAGE (K-means prediction before post-processing)
# This cell used `label`, which no visible notebook-level code defines (it only
# exists inside functions), so it raised NameError on a fresh kernel. The
# K-means patch labels of this image are `result_k`.
mask_res = label_to_img(400, 400, 16, 16, result_k)
image_plot = make_img_overlay(img_test, mask_res)
plt.figure(figsize=(10, 10))
plt.imshow(image_plot)

In [None]:
new_result = complete_lines(result_k, 17)

# Overlay the K-means labels after line completion
mask_res = label_to_img(400, 400, 16, 16, new_result)
image_plot = make_img_overlay(imgs_te[img_idx], mask_res)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image_plot)

In [None]:
new_result2 = remove_isolated_connected_component(new_result, 9)

# Overlay the K-means labels after dropping components smaller than 9 patches
mask_res = label_to_img(400, 400, 16, 16, new_result2)
image_plot = make_img_overlay(imgs_te[img_idx], mask_res)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image_plot)

In [None]:
new_result4 = clean_garbage_vert(new_result2, 2)

# Overlay the K-means labels after cleaning next to full columns
mask_res = label_to_img(400, 400, 16, 16, new_result4)
image_plot = make_img_overlay(imgs_te[img_idx], mask_res)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image_plot)

In [None]:
new_result5 = clean_garbage_hor(new_result4, 2)

# Overlay the K-means labels after cleaning next to full rows
mask_res = label_to_img(400, 400, 16, 16, new_result5)
image_plot = make_img_overlay(imgs_te[img_idx], mask_res)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image_plot)

In [None]:
new_result6 = remove_isolated_connected_component(new_result5, 4)

# Overlay the K-means labels after the final small-component removal
mask_res = label_to_img(400, 400, 16, 16, new_result6)
image_plot = make_img_overlay(imgs_te[img_idx], mask_res)
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image_plot)