In [1]:
import os, sys
module_path = os.path.abspath(os.path.join('../'))
if module_path not in sys.path:
    sys.path.append(module_path)

%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from skimage import io
from skimage.color import rgb2grey
from skimage.filters import sobel
from skimage import filters

def imread_convert(f):
    return rgb2grey(io.imread(f))
out = widgets.Output()
from skimage.segmentation import flood
from lib.labelling import *
import glob

In [20]:
from IPython.display import clear_output
from matplotlib.widgets import LassoSelector
from matplotlib.path import Path

VALID_IMAGE_TYPES = ['jpeg', 'png', 'bmp', 'gif', 'jpg'] # same as supported by keras



class image_segmenter:
    def __init__(self, img_dir, classes, overlay_alpha=.5,figsize=(10,10)):
        """
        TODO allow for intializing with a shape instead of an image
        
        parameters
        ----------
        classes : Int or list
            Number of classes or a list of class names
        ensure_rgba : boolean
            whether to force the displayed image to have an alpha channel to enable transparent overlay
        """
#         if not os.path.exists(img_dir):
#             raise ValueError(f"{img_dir} is not a valid file path")
        
#         self.img_dir = img_dir
#         self.img_files = 
        self.img_dir = img_dir
        if not os.path.isdir(self.img_dir):
            raise ValueError(f"{img_dir} is not a folder")
        #ensure that there is a sibling directory named masks
        self.mask_dir = os.path.abspath(img_dir).rsplit('/', 1)[0] + '/masks'
        if not os.path.exists(self.mask_dir):
            os.mkdir(self.mask_dir)
        elif not os.path.isdir(self.mask_dir):
            raise ValueError(f'{self.mask_dir} already exists and is not a folder')

        self.image_paths = []
#         self.mask_paths = []
        for type_ in VALID_IMAGE_TYPES:
            self.image_paths += (glob.glob(self.img_dir.rstrip('/')+f'/*.{type_}'))
#             self.mask_paths += glob.glob(self.mask_dir+f'/*.{type_}')
        self.shape = None        
        
        plt.ioff() # see https://github.com/matplotlib/matplotlib/issues/17013
        self.fig = plt.figure(figsize=figsize)
        self.ax = self.fig.gca()
        lineprops = {'color': 'black', 'linewidth': 1, 'alpha': 0.8}
        self.lasso = LassoSelector(self.ax, self.onselect,lineprops=lineprops, button=1)
        self.lasso.set_visible(True)
        self.fig.canvas.mpl_connect('button_press_event', self.onclick)
        self.fig.canvas.mpl_connect('button_release_event', self._release)
        self.panhandler = panhandler(self.fig)
        
        self.new_image(0)

        # setup lasso stuff
        

        plt.ion()
        
        if isinstance(classes, int):
            classes = np.arange(classes)
        if len(classes)<=10:
            self.colors = 'tab10'
        elif len(classes)<=20:
            self.colors = 'tab20'
        else:
            raise ValueError(f'Currently only up to 20 classes are supported, you tried to use {len(classes)} classes')
        
        self.colors = plt.get_cmap(self.colors)(np.arange(len(classes)))[:,:3]
        
        self.class_dropdown = widgets.Dropdown(
                options=[(str(classes[i]), i) for i in range(len(classes))],
                value=0,
                description='Class:',
                disabled=False,
            )
        self.lasso_button = widgets.Button(
            description='lasso select',
            disabled=False,
            button_style='success', # 'success', 'info', 'warning', 'danger' or ''
            icon='mouse-pointer', # (FontAwesome names without the `fa-` prefix)
        )
        self.flood_button = widgets.Button(
            description='flood fill',
            disabled=False,
            button_style='', # 'success', 'info', 'warning', 'danger' or ''
            icon='fill-drip', # (FontAwesome names without the `fa-` prefix)
        )
        
        self.erase_check_box = widgets.Checkbox(
            value=False,
            description='Erase Mode',
            disabled=False,
            indent=False
        )
        
        self.reset_button = widgets.Button(
            description='reset',
            disabled=False,
            button_style='', # 'success', 'info', 'warning', 'danger' or ''
            icon='refresh', # (FontAwesome names without the `fa-` prefix)
        )
        self.reset_button.on_click(self.reset)
        def button_click(button):
            if button.description == 'flood fill':
                self.flood_button.button_style='success'
                self.lasso_button.button_style=''
                self.lasso.set_active(False)
            else:
                self.flood_button.button_style=''
                self.lasso_button.button_style='success'
                self.lasso.set_active(True)
        
        self.lasso_button.on_click(button_click)
        self.flood_button.on_click(button_click)
        self.overlay_alpha = overlay_alpha
    def new_image(self, img_idx):
        self.img = io.imread(self.image_paths[img_idx])
        self.img_idx = img_idx
        img_path = self.image_paths[self.img_idx]
        
        self.mask_path = self.mask_dir + f'/{os.path.basename(img_path)}'
        
        if self.img.shape != self.shape:
            self.shape = self.img.shape
            pix_x = np.arange(self.shape[0])
            pix_y = np.arange(self.shape[1])
            xv, yv = np.meshgrid(pix_y,pix_x)
            self.pix = np.vstack( (xv.flatten(), yv.flatten()) ).T
            self.displayed = self.ax.imshow(self.img)
            if os.path.exists(self.mask_path):
                self.class_mask = io.imread(self.mask_path)
            else:
                self.class_mask = -np.ones([self.shape[0],self.shape[1]],dtype=np.uint8)
        else:
            self.displayed.set_data(self.img)
            if os.path.exists(self.mask_path):
                self.class_mask = io.imread(self.mask_path)
                # should probs check that the first two dimensions are the same as the img
            else:
                self.class_mask[:,:] = -1

        #ensure that the _nav_stack is empty
        self.fig.canvas.toolbar._nav_stack.clear()
        #add the initial view to the stack so that the home button works.
        self.fig.canvas.toolbar.push_current()
        

    def _release(self, event):
        with out:
            self.panhandler.release(event)

    def reset(self,*args):
        self.displayed.set_data(self.img)
        self.class_mask[:,:] = -1
        self.fig.canvas.draw()

    def onclick(self, event):
        """
        handle clicking to remove already added stuff
        """
        if event.button == 1:
            if event.xdata is not None and not self.lasso.active:
                with out:
                    # transpose x and y bc imshow transposes
                    self.indices = flood(self.class_mask,(np.int(event.ydata), np.int(event.xdata)))
                    self.updateArray()
        elif event.button == 3:
            with out:
                self.panhandler.press(event)

    def updateArray(self):
        with out:
            array = self.displayed.get_array().data
            
            if self.erase_check_box.value:
                self.class_mask[self.indices] = -1
                array[self.indices] = self.img[self.indices]
            else:
                self.class_mask[self.indices] = self.class_dropdown.value
                # https://en.wikipedia.org/wiki/Alpha_compositing#Straight_versus_premultiplied           
                c_overlay = self.colors[self.class_dropdown.value]*255*self.overlay_alpha
                array[self.indices] = c_overlay + self.img[self.indices]*(1-self.overlay_alpha)
            self.displayed.set_data(array)
        self.ax.set_title(np.sum(array==1.1))
        
    def onselect(self,verts):
        self.verts = verts
        p = Path(verts)

        self.indices = p.contains_points(self.pix, radius=0).reshape(450,540)

        self.updateArray()
        self.fig.canvas.draw_idle()
        
    def render(self):
        layers = [widgets.HBox([self.lasso_button, self.flood_button])]
        layers.append(widgets.HBox([self.reset_button, self.class_dropdown,self.erase_check_box]))
        layers.append(self.fig.canvas)    
        return widgets.VBox(layers)
    def save_mask(self):
        io.imsave(self.mask_path, self.class_mask)
    def _ipython_display_(self):
        display(self.render())

out.clear_output()
plt.close('all')
tstimage = io.imread('test-image.jpg')
obj = image_segmenter('images', ['yeast','not yeast'])
# zoom_factory(obj.ax)
from sidecar import Sidecar
from ipywidgets import IntSlider
sc = Sidecar(title='Segmentation area')
sl = IntSlider(description='Some slider')
with sc:
    # force the _nav_stack to record the initial position so the home button works as expected
#     panhandler(obj.fig)

    zoom_factory(obj.ax)
    display(obj)
out

Output()

In [21]:
obj.save_mask()

  io.imsave(self.mask_path, self.class_mask)


In [5]:

img_dir = 'fold'
if not os.path.isdir(self.img_dir)
#ensure that there is a sibling directory named masks
mask_dir = os.path.abspath('fold/').rsplit('/', 1)[0] + '/masks'
if not os.path.exists(mask_dir):
    os.mkdir(mask_dir)
elif not os.path.isdir(mask_dir):
    raise ValueError(f'{mask_dir} exists and is not a folder')

images = []
masks = []
for type_ in VALID_IMAGE_TYPES:
    images += (glob.glob(img_dir.rstrip('/')+f'/*.{type_}'))
    masks += glob.glob(mask_dir+f'/*.{type_}')

In [59]:
masks

['/home/ian/Documents/AC295/AC295-final-project-JWI/notebooks/masks/test-image.jpg']

In [63]:
os.path.basename(masks[0])

'test-image.jpg'

In [33]:
os.path.basename('fold/fold2')

'fold2'

In [64]:
io.imsave('yikes.png',np.random.randn(100,100))



In [65]:
io.imread('yikes.png')

array([[181, 106, 121, ..., 118, 130, 102],
       [157, 154, 167, ..., 111, 154, 184],
       [108, 139,  75, ..., 224, 175, 143],
       ...,
       [104, 108,  78, ..., 125, 122, 138],
       [113, 155, 170, ..., 135, 184, 124],
       [164, 122,  79, ..., 124, 135, 140]], dtype=uint8)

In [31]:
io.imsave('yikes.png',obj.class_mask)#,quality=100)

  io.imsave('yikes.png',obj.class_mask)#,quality=100)


In [32]:
plt.imshow(io.imread('yikes.png'))

<matplotlib.image.AxesImage at 0x7fe45e0d0df0>

In [23]:
plt.imshow(io.imread('masks/test-image.jpg'))

<matplotlib.image.AxesImage at 0x7fe4743ad370>