In [1]:
import cv2
import numpy as np
from scipy import ndimage as ndi
from skimage.feature import peak_local_max
from skimage.segmentation import watershed
import matplotlib.pyplot as plt

from PIL import Image
import glob
import os

In [2]:
def watershed_segmentation(img, show_image = False):
    """Binarise an image with mean-shift smoothing followed by inverted Otsu thresholding.

    Parameters
    ----------
    img : numpy.ndarray
        Colour image; expected to be an 8-bit, 3-channel array, which is what
        ``cv2.pyrMeanShiftFiltering`` accepts.
        NOTE(review): the grayscale step uses ``COLOR_BGR2GRAY``, i.e. it treats
        the input as BGR. Callers that load images with PIL pass RGB data, so
        the red/blue weights are swapped there -- confirm whether that matters.
    show_image : bool, optional
        When True, additionally fill small contours, run a distance-transform
        based watershed and plot a 2x2 figure (input, binary mask, distance
        transform, watershed labels).

    Returns
    -------
    numpy.ndarray
        The binary mask (values 0 / 255) produced by the Otsu threshold.
        NOTE(review): with ``show_image=True`` the small-contour fill is drawn
        into this mask in place, so the returned array differs from the one
        returned with ``show_image=False``. Confirm whether the fill was meant
        to be applied unconditionally.
    """
    # Mean-shift filtering (spatial radius 20, colour radius 40) flattens
    # texture before thresholding.
    filtro = cv2.pyrMeanShiftFiltering(img, 20, 40)
    gray = cv2.cvtColor(filtro, cv2.COLOR_BGR2GRAY)
    _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)

    if show_image:
        contornos, _ = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        # Contours with an area below 1000 are filled with 255 (thickness -1 = filled).
        buracos = [con for con in contornos if cv2.contourArea(con) < 1000]
        cv2.drawContours(thresh, buracos, -1, 255, -1)

        dist = ndi.distance_transform_edt(thresh)

        # peak_local_max returns peak coordinates; the former ``indices=False``
        # switch no longer exists in current scikit-image, so the boolean peak
        # mask is built explicitly from those coordinates.
        peak_coords = peak_local_max(dist, min_distance=20, labels=thresh)
        local_max = np.zeros(dist.shape, dtype=bool)
        local_max[tuple(peak_coords.T)] = True

        # 8-connected labelling of the peaks gives one marker per seed region.
        markers = ndi.label(local_max, structure=np.ones((3, 3)))[0]
        labels = watershed(-dist, markers, mask=thresh)

        titulos = ['Original image', 'Binary Image', 'Distance Transform', 'Watershed']
        imagens = [img, thresh, dist, labels]
        fig, axes = plt.subplots(2, 2, figsize=(16, 12))
        for ax, imagem, titulo in zip(axes.ravel(), imagens, titulos):
            # Only the label image gets a colour map that separates regions.
            cmap = "jet" if titulo == 'Watershed' else "gray"
            ax.imshow(imagem, cmap=cmap)
            ax.set_title(titulo)
            ax.set_xticks([])
            ax.set_yticks([])
        plt.show()

    return thresh

In [3]:
def get_img(path_str):
    """Load the image file at ``path_str`` as an RGB NumPy array.

    The image is converted to RGB mode before being turned into an array, so
    palette, grayscale and RGBA files also yield three channels.

    Parameters
    ----------
    path_str : str
        Path of the image file to read.

    Returns
    -------
    numpy.ndarray
        Array of shape (height, width, 3) in RGB channel order.
    """
    # The context manager releases the file handle once the pixels are copied.
    with Image.open(path_str) as pil_img:
        return np.asarray(pil_img.convert("RGB"))

def read_true_mask(img_path_str):
    """Read a ground-truth mask from disk as a single-channel grayscale array.

    Parameters
    ----------
    img_path_str : str
        Path of the mask image.

    Returns
    -------
    numpy.ndarray
        2-D grayscale version of the mask image.

    Raises
    ------
    OSError
        If OpenCV cannot read the file (``cv2.imread`` returned ``None``).
    """
    mask_bgr = cv2.imread(img_path_str)
    if mask_bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError("Could not read mask image: {!r}".format(img_path_str))
    # cv2.imread yields BGR channel order, so convert from BGR.
    return cv2.cvtColor(mask_bgr, cv2.COLOR_BGR2GRAY)
    
def find_metrics(seg_mask_img, org_mask_img):
    """Compare a predicted mask with a reference mask.

    Any non-zero pixel counts as foreground in either mask.

    Parameters
    ----------
    seg_mask_img : array-like
        Predicted (segmented) mask.
    org_mask_img : array-like
        Reference (ground-truth) mask with the same shape.

    Returns
    -------
    tuple of float
        ``(f1, iou, pixacc)`` where
        ``f1 = 2 * |A & B| / (|A| + |B|)``,
        ``iou = |A & B| / |A | B|`` and
        ``pixacc`` is the fraction of pixels on which both masks agree.
        A ratio whose denominator is zero (e.g. neither mask has any
        foreground) is reported as 1.0, because the masks agree completely,
        instead of NaN.

    Raises
    ------
    ValueError
        If the two masks do not have the same shape.
    """
    seg_fg = np.asarray(seg_mask_img).astype(bool)
    org_fg = np.asarray(org_mask_img).astype(bool)
    if seg_fg.shape != org_fg.shape:
        raise ValueError(
            "mask shapes differ: {} vs {}".format(seg_fg.shape, org_fg.shape)
        )

    intersection = np.count_nonzero(seg_fg & org_fg)   # true positives
    union = np.count_nonzero(seg_fg | org_fg)
    cnt_seg = np.count_nonzero(seg_fg)
    cnt_org = np.count_nonzero(org_fg)
    cnt_tot = seg_fg.size
    cnt_false = cnt_tot - union                        # true negatives

    fg_total = cnt_seg + cnt_org
    f1 = 2 * intersection / fg_total if fg_total else 1.0
    iou = intersection / union if union else 1.0
    pixacc = (intersection + cnt_false) / cnt_tot if cnt_tot else 1.0

    return f1, iou, pixacc

def print_all(seg_mask_img, org_mask_img):
    """Compute the F1, IoU and pixel-accuracy scores for a mask pair and print them on one line."""
    scores = find_metrics(seg_mask_img, org_mask_img)
    f1_score, iou_score, pixacc_score = scores
    print(
        "f1 = ", f1_score,
        " iou = ", iou_score,
        " pixacc = ", pixacc_score,
    )


In [4]:
# Dataset layout: input images, reference masks, and the folder that receives
# the masks produced by the watershed pipeline.
img_path = './final_dataset/images'
mask_path = './final_dataset/masks'
out_path = './final_dataset/masks_watershed/'

# Both listings are sorted so that position i in one list is paired with
# position i in the other.
# NOTE(review): this presumes image and mask files sort into the same order -- verify.
imgs = sorted(glob.glob(os.path.join(img_path, "*.png")))
masks = sorted(glob.glob(os.path.join(mask_path, "*.png")))

# The number of masks drives how many image/mask pairs are processed.
num_images = len(masks)

# Next cell: generate a mask for every image, save it under out_path, and score it against the reference mask.

In [5]:

# Per-image scores, reset on every run of this cell so re-running it is idempotent.
f1_scores, iou_scores, pixacc_scores = [], [], []

# cv2.imwrite does not create missing directories and only reports failure
# through its return value, so make sure the target folder exists up front.
os.makedirs(out_path, exist_ok=True)

for i in range(num_images):
    img = get_img(imgs[i])
    seg_mask_img = watershed_segmentation(img, show_image=False)

    # Output files are numbered by position in the sorted listing.
    out_file = os.path.join(out_path, 'img_' + str(i).zfill(4) + '.png')
    if not cv2.imwrite(out_file, seg_mask_img):
        raise OSError("Could not write segmentation mask to {!r}".format(out_file))

    org_mask_img = read_true_mask(masks[i])
    f1, iou, pixacc = find_metrics(seg_mask_img, org_mask_img)
    f1_scores.append(f1)
    iou_scores.append(iou)
    pixacc_scores.append(pixacc)

    # Progress message every 10 images.
    if (i+1)%10 == 0 :
        print("Processed image", str(i+1))


Processed image 10
Processed image 20
Processed image 30
Processed image 40
Processed image 50
Processed image 60
Processed image 70
Processed image 80
Processed image 90
Processed image 100
Processed image 110
Processed image 120
Processed image 130
Processed image 140
Processed image 150
Processed image 160
Processed image 170
Processed image 180
Processed image 190
Processed image 200
Processed image 210
Processed image 220
Processed image 230
Processed image 240
Processed image 250
Processed image 260
Processed image 270
Processed image 280
Processed image 290
Processed image 300
Processed image 310
Processed image 320
Processed image 330
Processed image 340
Processed image 350
Processed image 360
Processed image 370
Processed image 380
Processed image 390
Processed image 400
Processed image 410
Processed image 420
Processed image 430
Processed image 440
Processed image 450
Processed image 460
Processed image 470
Processed image 480
Processed image 490
Processed image 500
Processed

In [6]:
# One F1 value is appended per scored image, so the length of f1_scores is the
# number of images that were scored.
num_processed = len(f1_scores)
print("Number of processed images", num_processed)

Number of processed images 1000


In [7]:
# Totals over the first num_processed entries of each per-image score list.
sum_f1 = sum(f1_scores[:num_processed])
sum_iou = sum(iou_scores[:num_processed])
sum_pixacc = sum(pixacc_scores[:num_processed])

if num_processed > 0:
    print("Average F1 score = ", str(sum_f1 / num_processed))
    print("Average IOU score = ", str(sum_iou / num_processed))
    print("Average Pixacc score = ", str(sum_pixacc / num_processed))
else:
    # Nothing was scored; dividing would raise ZeroDivisionError.
    print("No images were processed; average scores are undefined.")

Average F1 score =  0.3370689030117877
Average IOU score =  0.2408698741597864
Average Pixacc score =  0.7134477636718747
