In [None]:
import functools
import html
import io
import sys
import time
import traceback

import PIL.Image
import PIL.ImageDraw
import numpy as np
import requests
import skimage.draw

from ipywidgets import widgets
from matplotlib import pyplot as plt

# scikit-image==0.16.2
console = widgets.HTML()
%matplotlib tk

In [None]:
def load_image(url, timeout=30):
    """Load an image from an http(s) URL or a local file and return it as RGB.

    Args:
        url: an ``http://`` / ``https://`` URL, or a filesystem path.
        timeout: seconds to wait for the HTTP request (unused for local files).

    Returns:
        A PIL image converted to mode ``"RGB"``.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        OSError: if the file cannot be read or the bytes are not an image.
    """
    if url.startswith(('http://', 'https://')):
        # Without a timeout requests can block the kernel forever.
        response = requests.get(url, timeout=timeout)
        # Fail loudly instead of handing an HTML error page to PIL.
        response.raise_for_status()
        image_bytes = response.content
    else:
        with open(url, 'rb') as f:
            image_bytes = f.read()
    return PIL.Image.open(io.BytesIO(image_bytes)).convert('RGB')


def log_msg(msg):
    """Prepend ``msg`` plus a horizontal rule to the ``console`` HTML widget.

    The newest message ends up at the top. ``msg`` may be any object; it is
    formatted as in an f-string and inserted as raw (unescaped) HTML.
    """
    entry = "{}<hr>".format(msg)
    console.value = entry + console.value

    
def log_errors(f):
    """Decorator: report any exception raised by ``f`` via ``log_msg``, then re-raise.

    In this notebook it wraps the mask editor's matplotlib event callbacks.
    The full traceback is HTML-escaped and wrapped in ``<pre>`` so it renders
    verbatim in the HTML console widget. The wrapped function's name and
    docstring are preserved with ``functools.wraps``.
    """
    @functools.wraps(f)
    def wrapper(*args, **kw):
        try:
            return f(*args, **kw)
        except Exception:
            # Escape the traceback text: raw "<...>" fragments would otherwise be
            # parsed as HTML tags by the widget and silently disappear.
            log_msg(f"<pre>{html.escape(traceback.format_exc())}</pre>")
            raise
    return wrapper


In [None]:
class MplMaskEditor(object):
    """Paint a mask over an image in an interactive matplotlib figure.

    Holding the left mouse button and dragging paints ``draw_value`` (255);
    the right button paints ``clear_value`` (0). The brush radius is read from
    an ``ipywidgets.IntSlider`` displayed above the figure. Event handlers only
    set ``draw_request``; a 50 ms canvas timer performs the repaint, so a burst
    of motion events results in a single redraw.

    Attributes:
        img: the image as a ``uint8`` numpy array.
        mask_draw: ``uint8`` mask with the same shape as ``img``; every channel
            of a painted pixel holds the same value.
        draw_val: value painted while a mouse button is held, else ``None``.
    """

    def __init__(self, img, seg_mask=None, brush_size=5):
        """Build the slider and figure, connect mouse events and start the timer.

        Args:
            img: anything ``np.array(img, dtype=np.uint8)`` accepts (e.g. a PIL image).
            seg_mask: optional initial mask whose height/width match ``img``;
                its non-zero pixels start out painted with ``draw_value``.
            brush_size: initial brush radius in pixels (slider range 1-100).

        Raises:
            ValueError: if ``seg_mask``'s height/width differ from the image's.
        """
        # display brush size selector as widget (IPython's display())
        self.brush_size_selector = widgets.IntSlider(value=brush_size, min=1, max=100)
        display(self.brush_size_selector)

        fig, ax = plt.subplots(1, 1, figsize=(8, 8))
        self.ax = ax
        self.fig = fig
        self.img = np.array(img, dtype=np.uint8)
        self.shape = [np.int32(_) for _ in self.img.shape]

        self.mask_draw = np.zeros(self.img.shape, dtype=np.uint8)
        if seg_mask is not None:
            self._load_seg_mask(seg_mask)
        self.draw_val = None

        # Image artists are created by the first _draw() and then updated in
        # place, so the user's zoom/pan survives each brush stroke.
        self._img_artist = None
        self._mask_artist = None
        # Initialised before the timer starts so its callback always finds it.
        self.draw_request = True

        # bind methods to canvas
        self.fig.canvas.mpl_connect('button_press_event', self._on_click)
        self.fig.canvas.mpl_connect('button_release_event', self._on_release)
        self.fig.canvas.mpl_connect('motion_notify_event', self._on_motion)

        self.draw_timer = self.fig.canvas.new_timer(interval=50)
        self.draw_timer.add_callback(self._draw)
        self.draw_timer.start()

        self._draw()

    def _load_seg_mask(self, seg_mask):
        """Initialise ``mask_draw`` from ``seg_mask``: non-zero pixels get ``draw_value``."""
        painted = np.asarray(seg_mask) != 0
        if painted.ndim == 3:
            # Multi-channel mask: a pixel counts as painted if any channel is non-zero.
            painted = painted.any(axis=-1)
        if painted.shape != self.mask_draw.shape[:2]:
            raise ValueError(
                f"seg_mask has shape {np.shape(seg_mask)}, expected height/width "
                f"{self.mask_draw.shape[:2]}")
        self.mask_draw[painted] = self.draw_value

    @property
    def brush_size(self):
        """Current brush radius in pixels, read from the slider widget."""
        return self.brush_size_selector.value

    def draw(self):
        """Request a repaint; the timer performs it on its next tick."""
        self.draw_request = True

    @log_errors
    def _draw(self):
        """Timer callback: repaint image and mask overlay if a redraw was requested."""
        if not self.draw_request:
            return
        self.draw_request = False
        if self._img_artist is None:
            self.ax.axis('off')
            self._img_artist = self.ax.imshow(self.img)
            # Fixed vmin/vmax keep a stable colour scale for 2-D masks when the
            # data is later replaced through set_data() (ignored for RGB masks).
            self._mask_artist = self.ax.imshow(self.mask_draw, alpha=0.5, vmin=0, vmax=255)
        else:
            self._img_artist.set_data(self.img)
            self._mask_artist.set_data(self.mask_draw)
        # Redraw this editor's own canvas, not whichever figure pyplot
        # currently considers "current".
        self.fig.canvas.draw_idle()

    def _disk_coords(self, row, col, radius):
        """Return ``(rows, cols)`` of a filled disk, clipped to the mask's height/width."""
        shape = self.mask_draw.shape[:2]
        # skimage.draw.circle was renamed to draw.disk (circle is gone in 0.19+).
        if hasattr(skimage.draw, 'disk'):
            return skimage.draw.disk((row, col), radius, shape=shape)
        return skimage.draw.circle(row, col, radius, shape=shape)

    def draw_mask_point(self, x, y, val):
        """Paint a disk of radius ``brush_size`` centred on data coords ``(x, y)``.

        ``x`` is the column and ``y`` the row. Pixels outside the image are
        clipped (negative indices no longer wrap to the opposite edge). Nothing
        happens if any argument is ``None``.
        """
        if x is None or y is None or val is None:
            return
        rows, cols = self._disk_coords(y, x, self.brush_size)
        self.mask_draw[rows, cols] = np.uint8(val)

    @property
    def draw_value(self):
        """Mask value painted with the left mouse button."""
        return 255

    @property
    def clear_value(self):
        """Mask value painted with the right mouse button (erase)."""
        return 0

    @log_errors
    def _on_click(self, event):
        """Start painting: left button paints, right button erases."""
        if event.inaxes != self.ax:
            return

        x, y = event.xdata, event.ydata

        # left click
        if event.button == 1:
            self.draw_val = self.draw_value

        # right click
        elif event.button == 3:
            self.draw_val = self.clear_value

        if self.draw_val is not None:
            self.draw_mask_point(x, y, self.draw_val)
            self.draw()

    @log_errors
    def _on_release(self, event):
        """Stop painting when any mouse button is released."""
        self.draw_val = None

    @log_errors
    def _on_motion(self, event):
        """While a button is held, keep painting along the mouse path."""
        if event.inaxes != self.ax:
            return
        x, y = event.xdata, event.ydata
        if self.draw_val is not None:
            self.draw_mask_point(x, y, self.draw_val)
            self.draw()

    def save_mask_as_png(self, path=None, invert=False, custom_max=255):
        """Save the mask as a single-channel (mode ``"L"``) PNG.

        Args:
            path: output file; defaults to ``"<time.time()>.png"``.
            invert: swap painted (``draw_value``) and cleared (``clear_value``) pixels.
            custom_max: value written for painted pixels instead of 255
                (must fit in ``uint8``).

        Returns:
            The path the image was written to.
        """
        draw_mask = self.mask_draw.copy()

        if invert:
            painted = draw_mask == self.draw_value
            cleared = draw_mask == self.clear_value

            draw_mask[painted] = self.clear_value
            draw_mask[cleared] = self.draw_value

        # replace max value (255) with custom value
        draw_mask[draw_mask == self.draw_value] = custom_max

        # All channels of the mask are painted identically, so one channel
        # carries the whole mask; a 2-D uint8 array becomes a mode "L" image.
        # (Passing mode="L" to Image.save() is ignored by Pillow.)
        if draw_mask.ndim == 3:
            draw_mask = np.ascontiguousarray(draw_mask[..., 0])

        if path is None:
            path = f"{time.time()}.png"
        PIL.Image.fromarray(draw_mask).save(path)
        return path
    


In [None]:
# Image to annotate and the initial brush radius (pixels).
# NOTE(review): absolute local path — adjust for your machine.
IMAGE_PATH = "/home/mike/Downloads/Isolated_white_t-shirt_front.png"
BRUSH_SIZE = 30

img = load_image(IMAGE_PATH)
mpl_editor = MplMaskEditor(img, brush_size=BRUSH_SIZE)

In [None]:
# Write the painted mask to "04.png", relative to the kernel's working directory.
mpl_editor.save_mask_as_png("04.png", invert=False)