In [67]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
%matplotlib widget

In [68]:
class Grain(object):
    """A polygonal grain together with the matplotlib patch that displays it.

    Attributes
    ----------
    centroid, id :
        Initialised to ``0``; nothing in this class updates them.
    xy : list
        Positional arguments forwarded to ``ax.fill`` when the patch is
        created, e.g. ``[x_coords, y_coords]``.
    ax :
        Axes the patch was drawn on (``None`` until `make_patch` is called).
    assigned_color :
        Face colour the patch received when it was created; used to restore
        the patch when it is deselected.
    patch :
        The filled polygon returned by ``ax.fill`` (``None`` until
        `make_patch` is called).
    selected : bool
        Whether the grain is currently drawn in the selection colour.
    """

    def __init__(self, xy=None):
        """Create a grain.

        Parameters
        ----------
        xy : list, optional
            Coordinate sequences to hand to ``ax.fill``.  Defaults to a fresh
            empty list for every instance (a shared ``[]`` default would be
            mutated by all grains created without arguments).
        """
        # Data
        self.centroid = 0
        self.id = 0
        self.xy = [] if xy is None else xy

        # Display
        self.ax = None
        self.assigned_color = None
        self.patch = None
        self.selected = False

    def make_patch(self, ax):
        """Draw the grain on ``ax`` as a pickable filled polygon.

        Remembers the axes and the face colour the polygon was given, and
        returns the created patch.
        """
        self.ax = ax
        # ``fill`` returns a list of polygons; exactly one is expected here.
        self.patch, = ax.fill(*self.xy, picker=True)
        self.assigned_color = self.patch.get_facecolor()
        return self.patch

    def select(self, selected=True):
        """Recolour the patch: red when selected, its original colour otherwise.

        `make_patch` must have been called first; otherwise ``self.patch`` is
        still ``None`` and this raises ``AttributeError``.
        """
        color = 'red' if selected else self.assigned_color
        self.patch.set_color(color)
        self.selected = selected


class GrainPlot(object):
    """Interactive figure that draws grains and highlights the ones picked.

    Two overlay polygons are kept in ``self.highlights``: index 0 is red and
    index 1 is yellow.  Picking a grain copies its outline into the overlays
    (see `highlight_grain`); the overlays are redrawn by blitting on top of a
    background captured once in ``__init__``.

    Call `activate` to connect the event handlers and `deactivate` to
    disconnect them again.
    """

    # Shared styling for the overlays.  ``animated=True`` is the matplotlib
    # flag used for artists that are drawn manually via ``draw_artist``/blit.
    HIGHLIGHT_PROPS = {
            'alpha': 0.7,
            'animated': True
        }
    # Empty (0, 2) vertex array used to blank an overlay.
    NULL_XY = np.empty((0, 2))

    def __init__(self, grains):
        """Create the figure, the two highlight overlays and one patch per grain.

        Parameters
        ----------
        grains : iterable
            Objects providing ``make_patch(ax)``; each is drawn on the axes.
        """
        # Plot
        self.fig, self.ax = plt.subplots()
        self.canvas = self.fig.canvas
        self.cids = []

        # Highlighters (``fill`` returns a list, hence ``+=``)
        highlights = []
        for color in ['red', 'yellow']:
            highlights += self.ax.fill([], [], color=color, **self.HIGHLIGHT_PROPS)
        self.highlights = highlights

        # Grains
        self.grains = grains
        for grain in grains:
            grain.make_patch(self.ax)

        # Init
        self.fig.tight_layout()
        # Presumably gives the GUI time to render before the background is
        # cached -- confirm this is sufficient for the backend in use.
        plt.pause(0.1)
        # NOTE(review): captured once; nothing here refreshes it afterwards.
        self.bg = self.canvas.copy_from_bbox(self.fig.bbox)

    def highlight_grain(self, new_xy):
        """Toggle the highlight for the polygon whose vertices are ``new_xy``.

        * ``new_xy`` equals overlay 0: overlay 0 takes overlay 1's vertices
          and overlay 1 is cleared.
        * ``new_xy`` equals overlay 1: overlay 1 is cleared.
        * otherwise: overlay 0's vertices move to overlay 1 and ``new_xy``
          becomes overlay 0.

        The cached background is then restored, both overlays are redrawn and
        the canvas is blitted.
        """
        # Clear highlights
        canvas = self.canvas
        canvas.restore_region(self.bg)

        # Update highlights
        xys = [h.get_xy() for h in self.highlights]
        if np.array_equal(new_xy, xys[0]):
            xys[0] = xys[1]
            xys[1] = self.NULL_XY
        elif np.array_equal(new_xy, xys[1]):
            xys[1] = self.NULL_XY
        else:
            xys[1] = xys[0]
            xys[0] = new_xy
        for patch, xy in zip(self.highlights, xys):
            patch.set_xy(xy)
            self.ax.draw_artist(patch)

        # Update GUI
        canvas.blit(self.fig.bbox)
        canvas.flush_events()

    # Manage grains ---

    def add_grain(self, event):
        """Add a grain at the event location.  Not implemented yet.

        Planned steps: put prompt, find grain around prompt, add grain to
        list, add grain to plot.
        """
        pass

    def delete_grain(self, target):
        """Remove the highlight that matches ``target``.

        ``target`` may be a patch (anything with ``get_xy``) or an object that
        exposes one as ``.patch``.  Removing the grain itself from the plot
        and from ``self.grains`` is not implemented yet.
        """
        patch = target if hasattr(target, 'get_xy') else target.patch
        xy = patch.get_xy()
        # An overlay with no vertices means nothing is highlighted.
        if len(xy) == 0:
            return
        # remove highlight (passing an overlay's own vertices toggles it off)
        self.highlight_grain(xy)
        # TODO: delete grain from plot
        # TODO: delete grain from list
        # TODO: update GUI

    def merge_grains(self, grains: list):
        """Merge ``grains`` into one.  Not implemented yet.

        Planned steps: reset highlights, create joined grain, add grain to
        list.
        """
        pass

    # Events ---

    def onclick(self, event):
        """Handle a plain mouse click.  Not implemented (and not connected).

        Planned: place a green dot and remember the location in case a grain
        is added there.
        """
        pass

    def onpick(self, event):
        """
        Handle clicking on an existing grain
        """
        # Only process individual left-clicks
        if event.mouseevent.dblclick is True or event.mouseevent.button != 1:
            return
        # Update highlights
        self.highlight_grain(event.artist.get_xy())

    def onpress(self, event):
        """Dispatch a key press: ``c`` adds, ``d`` deletes, ``m`` merges."""
        if event.key == 'c':
            # The key event itself is handed on; it is not a pick event and
            # has no ``mouseevent`` attribute.
            self.add_grain(event)
        elif event.key == 'd':
            self.delete_grain(self.highlights[0])
        elif event.key == 'm':
            self.merge_grains(self.highlights)

    def activate(self):
        """Connect the pick and key-press handlers (no-op if already active)."""
        if self.cids:
            # Connecting twice would run every handler twice per event.
            return
        events = {
            # 'button_press_event': self.onclick,
            'pick_event': self.onpick,
            'key_press_event': self.onpress
        }
        for event, handler in events.items():
            self.cids.append(self.canvas.mpl_connect(event, handler))

    def deactivate(self):
        """Disconnect every handler registered by `activate`."""
        for cid in self.cids:
            self.canvas.mpl_disconnect(cid)
        self.cids = []


In [None]:
# Build one unit-square grain centred on each (overlapping) centroid
centroids = [(0, 0), (0.5, 0.5), (1, 1)]
grains = []
for cx, cy in centroids:
    left, right = cx - 0.5, cx + 0.5
    bottom, top = cy - 0.5, cy + 0.5
    # Corners listed from the top-left: TL, TR, BR, BL
    grains.append(Grain([[left, right, right, left],
                         [top, top, bottom, bottom]]))

# Draw the grains and connect the pick / key-press handlers
grain_plot = GrainPlot(grains)
grain_plot.activate()