In [1]:
from pathlib import Path
import tensorflow as tf

# collect every .jpg file under 'test_images' (searched recursively);
# sorted so the processing order and the menu order are reproducible
# regardless of the file system's directory-listing order
files = sorted(Path('test_images').rglob('*.jpg'))
assert files, "no .jpg files found under 'test_images'"

# load the saved U-Net model used by create_mask()
model = tf.keras.models.load_model('models/unet_model.h5')

In [2]:
def create_mask(input_image, input_size=(128, 128)):
    '''Produce the predicted segmentation mask for an input image.

    The image is resized to the model's input size, scaled to [0, 1] and run
    through the global ``model``. The highest-scoring class per pixel is then
    resized back to the image's original spatial size.

    Args:
        input_image: image array of shape (height, width, channels), e.g. as
            returned by ``skimage.io.imread`` for a colour JPEG.
        input_size: (height, width) the model is fed; defaults to (128, 128).
            NOTE(review): confirm this matches the saved model's input shape.

    Returns:
        Tensor of shape (height, width, 1) holding integer class indices.
        Assumes ``model.predict`` returns per-pixel class scores of shape
        (1, h, w, n_classes) -- TODO confirm against the model.
    '''
    # array images are (rows, cols, channels) == (height, width, channels)
    height, width = input_image.shape[:2]

    # prepare image for the model's prediction
    image_m = tf.image.resize(input_image, input_size)     # resize image
    image_m = tf.cast(image_m, tf.float32) / 255.0         # normalize to [0, 1]
    image_m = image_m[tf.newaxis, ...]                     # add batch axis

    # class index with the highest score for every pixel
    pred_mask = tf.argmax(model.predict(image_m, verbose=0), axis=-1)
    pred_mask = pred_mask[..., tf.newaxis]

    # resize back to the original size; nearest-neighbour keeps the values
    # valid class indices (bilinear would blend labels at region boundaries)
    return tf.image.resize(pred_mask[0], (height, width), method='nearest')

In [3]:
from skimage import segmentation
from skimage.future import graph

def _weight_mean_color(graph, src, dst, n):
  diff = graph.nodes[dst]['mean color'] - graph.nodes[n]['mean color']
  diff = np.linalg.norm(diff)
  return {'weight': diff}

def merge_mean_color(graph, src, dst):
    """Merge callback for ``graph.merge_hierarchical``.

    Adds node ``src``'s 'total color' and 'pixel count' into node ``dst``,
    then recomputes ``dst``'s 'mean color' from the accumulated totals.
    Only ``dst``'s attributes are modified.
    """
    dst_attrs = graph.nodes[dst]
    src_attrs = graph.nodes[src]
    dst_attrs['total color'] += src_attrs['total color']
    dst_attrs['pixel count'] += src_attrs['pixel count']
    dst_attrs['mean color'] = dst_attrs['total color'] / dst_attrs['pixel count']

In [4]:
# Region Adjacency Graph (RAG) method
def segment_image(img, compactness=60, thresh=80, n_segments=200):
    """Segment ``img`` into SLIC superpixels merged over a region adjacency graph.

    Args:
        img: image array passed to ``segmentation.slic``.
        compactness: SLIC trade-off between colour proximity and spatial proximity.
        thresh: regions joined by an edge whose weight (mean-colour distance,
            see ``_weight_mean_color``) is below this value are merged.
        n_segments: approximate number of initial SLIC superpixels.

    Returns:
        Label array with one integer label per merged region.
    """
    superpixels = segmentation.slic(
        img,
        compactness=compactness,
        n_segments=n_segments,
        start_label=1,
    )
    rag = graph.rag_mean_color(img, superpixels)
    return graph.merge_hierarchical(
        superpixels,
        rag,
        thresh=thresh,
        rag_copy=False,
        in_place_merge=True,
        merge_func=merge_mean_color,
        weight_func=_weight_mean_color,
    )

In [5]:
from skimage.io import imread
import cv2
import numpy as np

methods = ['Input Image', 'Unet Model', 'Edge-Based Method', 'Otsu Method', 
           'Threshold Method', 'Region Adjacency Graph']
images = []       # per input file: the input image followed by one result per method, in `methods` order
names = []        # image names (file name without directory or extension), used as menu labels
indexes = []      # position of each image in `images` / `names`

scale_percent = 50   # downscale factor (%) used before the edge-based and Otsu methods

# produce segmented images for all input images
for i, img in enumerate(files):
    temp = []  # results for the current image, in `methods` order
    # Path.stem strips the directory and the extension on any OS
    names.append(Path(img).stem)

    # input image as loaded by skimage (RGB channel order for colour images)
    image = np.array(imread(img))
    temp.append(image)

    # U-Net predicted mask
    temp.append(create_mask(image))

    # preprocessing shared by the edge-based and Otsu methods:
    # grayscale -> 3x3 Gaussian blur -> downscale to `scale_percent` % of the size
    image_otsu = cv2.imread(str(img), cv2.IMREAD_COLOR)
    if image_otsu is None:
        # cv2.imread signals an unreadable file by returning None instead of raising
        raise OSError(f'cv2.imread could not read image file: {img}')
    image_otsu = cv2.cvtColor(image_otsu, cv2.COLOR_BGR2GRAY)
    width = int(image_otsu.shape[1] * scale_percent / 100)
    height = int(image_otsu.shape[0] * scale_percent / 100)
    dim = (width, height)   # cv2.resize takes (width, height)
    image_otsu = cv2.GaussianBlur(image_otsu, (3, 3), 0)
    image_otsu = cv2.resize(image_otsu, dim)

    # edge-based method: Canny with hysteresis thresholds 100 and 200
    temp.append(cv2.Canny(image_otsu, 100, 200))

    # Otsu method: threshold picked automatically, output min-max rescaled to [0, 1]
    otsu_threshold, otsout = cv2.threshold(image_otsu, 0, 255, 
                                           cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    otsout = cv2.normalize(otsout, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX)
    temp.append(otsout)

    # fixed-threshold method at 127, applied to the full input `image`; for a
    # multi-channel image each channel is thresholded independently
    _, th = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY)
    temp.append(th)

    # RAG method
    temp.append(segment_image(image))

    # add input image and segmented images to the images list
    images.append(temp)
    indexes.append(i)

In [6]:
import matplotlib.pyplot as plt
from skimage import color

def show_images(menu_value, checkbox_value):
    '''Plot the input image and every segmentation result for one image.

    Draws a 2x3 grid with one panel per entry of the global ``methods``,
    taken from the matching entry of the global ``images``.

    Args:
        menu_value: image name; must be an entry of the global ``names``.
        checkbox_value: if truthy, every panel is drawn in grayscale.

    Raises:
        ValueError: if ``menu_value`` is not in ``names``.
    '''
    index = names.index(menu_value)

    fig, axes = plt.subplots(2, 3, figsize=(12, 10))

    for ax, title, panel in zip(axes.flat, methods, images[index]):
        panel = np.asarray(panel)   # also converts the U-Net mask tensor
        if checkbox_value:
            # matplotlib ignores the colormap for RGB data, so every 3-channel
            # panel (input image, per-channel threshold result) is converted
            if panel.ndim == 3 and panel.shape[-1] == 3:
                panel = color.rgb2gray(panel)
            ax.imshow(panel, cmap='gray')
        else:
            ax.imshow(panel)

        ax.axis('off')
        ax.set_title(title, fontsize=16, color='blue')

    plt.show()

In [7]:
def on_button_clicked(_):
    '''Button callback: redraw the ``out`` panel for the currently selected image.

    The ``out`` output widget is always cleared first. Images are plotted only
    when the menu holds a non-empty name; the empty first option means
    "nothing selected".
    '''
    with out:
        clear_output()
        selected = menu.value
        if len(selected) == 0:
            return
        show_images(selected, checkbox.value)

In [8]:
# some handy functions to use along widgets
from IPython.display import display, Markdown, clear_output
import ipywidgets as widgets

# menu for selecting an image; the leading '' option means "nothing selected"
menu = widgets.Dropdown(options=[''] + names, value='', description='Select image:')

# toggle for drawing every panel in grayscale
checkbox = widgets.Checkbox(value=False, description='Grayscale', disabled=False)

# build submit button
button = widgets.Button(description='Perform Segmentation')

# placeholder for displaying images
out = widgets.Output()

# run on_button_clicked whenever the button is pressed
button.on_click(on_button_clicked)

In [9]:
# stack the controls and the output area vertically; as the cell's last
# expression the box is rendered by Jupyter
widgets.VBox(children=[menu, checkbox, button, out])

VBox(children=(Dropdown(description='Select image:', options=('', 'cat', 'dog1', 'dog2'), value=''), Checkbox(…