In [59]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import shapely.ops as so
import shapely.geometry as sg
%matplotlib widget

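# TODO: Merge grains -- the merged outline keeps only the exterior ring, so holes are lost
# TODO: Create grains -- GrainPlot.create_grain is still a stub
# TODO: Polish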

In [60]:
class Grain(object):
    """A single grain outline that can be drawn and toggled on a matplotlib Axes.

    Attributes
    ----------
    xy : sequence
        Outline coordinates as ``[xs, ys]``; unpacked directly into ``Axes.fill``.
    patch : matplotlib Polygon or None
        The filled polygon created by :meth:`make_patch`.
    assigned_color :
        Face colour the patch received when it was created; restored when the
        grain is deselected.
    selected : bool
        Whether the grain is currently highlighted.
    """

    PROPS = {
        # 'alpha': 0.5,
    }
    HIGHLIGHT_PROPS = {
        # 'alpha': 0.7,
        'color': 'red'
    }

    def __init__(self, xy=None):
        # Data
        self.centroid = 0  # placeholder; not set anywhere else in this notebook yet
        self.id = 0        # placeholder; not set anywhere else in this notebook yet
        # Fresh list per instance: a shared mutable default ([]) would be aliased
        # by every Grain created without coordinates.
        self.xy = [] if xy is None else xy

        # Display
        self.ax = None
        self.assigned_color = None
        self.patch = None
        self.selected = False

    def make_patch(self, ax):
        """Draw the grain on ``ax`` as a pickable filled polygon and return it."""
        self.ax = ax
        # Axes.fill returns a list of polygons; a single outline yields exactly one.
        (self.patch,) = ax.fill(*self.xy, picker=True)
        self.assigned_color = self.patch.get_facecolor()
        return self.patch

    def select(self):
        """Toggle the selection highlight.

        Returns
        -------
        bool
            The new selection state (True when the grain is now highlighted).
        """
        # `not` rather than `~`: bitwise-inverting a bool yields the ints -1/0.
        self.selected = not self.selected
        if self.selected:
            color = self.HIGHLIGHT_PROPS.get('color', 'red')
        else:
            color = self.assigned_color
        self.patch.set_color(color)
        return self.selected


class GrainPlot(object):
    """Interactive matplotlib view for selecting, deleting and merging grains.

    Once :meth:`activate` has been called the figure responds to:

    * left-click on a grain -- toggle that grain's selection;
    * left-click elsewhere inside the axes -- place the cursor (only while no
      grain is selected);
    * ``delete`` -- delete every selected grain;
    * ``m`` -- merge the selected grains into one;
    * ``c`` -- create a grain (not implemented yet);
    * ``escape`` -- hide the cursor and clear the selection.

    Parameters
    ----------
    grains : list of Grain
        Grains to display; this list is mutated in place by delete/merge.
    img, label :
        Accepted but currently unused.
    """

    HIGHLIGHT_PROPS = {
        'alpha': 0.7,
        'animated': True
    }
    # Empty (0, 2) coordinate array; not referenced elsewhere in this class.
    NULL_XY = np.empty((0, 2))

    def __init__(self, grains, img=None, label=None):
        self.fig = plt.figure(figsize=(6, 4))
        self.ax = self.fig.add_subplot(aspect='equal')
        self.canvas = self.fig.canvas

        # Marker for the last clicked point; hidden until the user clicks empty space.
        self.cursor = patches.Circle((0, 0), radius=0.05, color='lime', visible=False)
        self.ax.add_patch(self.cursor)

        self.cids = []              # connection ids returned by canvas.mpl_connect
        self.selected_grains = []

        self.grains = grains
        for grain in grains:
            grain.make_patch(self.ax)

    def _redraw(self):
        """Request a repaint; matplotlib does not redraw by itself after artists
        are modified inside an event callback."""
        self.canvas.draw_idle()

    def unselect_grains(self):
        """Remove the highlight from every selected grain and clear the selection."""
        for grain in self.selected_grains:
            grain.select()
        self.selected_grains = []

    def unselect_all(self):
        """Hide the cursor and clear the grain selection."""
        self.set_cursor(False)
        self.unselect_grains()

    def set_cursor(self, xy):
        """Show the cursor at ``xy`` (data coordinates) or hide it.

        The cursor is hidden when ``xy`` is falsy or contains ``None`` (which is
        what matplotlib reports for presses without data coordinates).
        """
        if xy and None not in xy:
            self.cursor.set_center(xy)
            self.cursor.set_visible(True)
        else:
            self.cursor.set_center((-1, -1))
            self.cursor.set_visible(False)

    # Manage grains ---
    def create_grain(self, event=None):
        """Attempt to find and add a grain at the selected point.

        Not implemented yet. ``event`` is optional so the key handler can call
        this without a mouse event.
        """
        # put prompt
        # find grain around prompt
        # add grain to list
        # add grain to plot
        pass

    def delete_grains(self):
        """Delete all selected grains from the plot and from ``self.grains``."""
        for grain in self.selected_grains:
            grain.patch.remove()
            self.grains.remove(grain)
        self.selected_grains = []

    def merge_grains(self):
        """Replace the selected grains with a single grain covering their union.

        Does nothing unless at least two grains are selected, or if their union
        is not a single polygon (e.g. the selected grains are disjoint). Only the
        exterior ring of the union is kept, so any holes are dropped.
        """
        if len(self.selected_grains) < 2:
            return
        polys = [sg.Polygon(g.patch.get_path().vertices) for g in self.selected_grains]
        merged = so.unary_union(polys)
        # A MultiPolygon / GeometryCollection has no `.exterior`; leave the
        # selection untouched instead of crashing.
        if merged.geom_type != 'Polygon':
            return
        new_grain = Grain(merged.exterior.xy)
        new_grain.make_patch(self.ax)
        self.grains.append(new_grain)
        # Clear old grains
        self.delete_grains()

    # Events ---
    def onclick(self, event):
        """Handle a mouse press anywhere on the figure by placing the cursor."""
        # Only process single left-clicks while no grains are selected.
        if event.dblclick or event.button != 1 or self.selected_grains:
            return
        # Presses outside these axes have no usable data coordinates.
        if event.inaxes is not self.ax:
            return
        self.set_cursor((event.xdata, event.ydata))
        self._redraw()

    def onpick(self, event):
        """Handle a click on an existing grain by toggling its selection."""
        # Only process single left-clicks.
        if event.mouseevent.dblclick or event.mouseevent.button != 1:
            return
        self.set_cursor(False)
        for grain in self.grains:
            if event.artist is grain.patch:
                if grain.select():
                    self.selected_grains.append(grain)
                else:
                    self.selected_grains.remove(grain)
                break
        self._redraw()

    def onpress(self, event):
        """Dispatch key presses: c=create, delete=delete, m=merge, escape=clear."""
        key = event.key
        if key == 'c':
            self.create_grain(event)
        elif key == 'delete':
            self.delete_grains()
        elif key == 'm':
            self.merge_grains()
        elif key == 'escape':
            self.unselect_all()
        else:
            return
        self._redraw()

    def activate(self):
        """Connect the mouse/key handlers to the canvas, replacing any existing ones."""
        # Disconnect first so a second activate() does not register every handler
        # twice (a doubled pick handler would toggle a grain on and off again).
        self.deactivate()
        events = {
            'button_press_event': self.onclick,
            'pick_event': self.onpick,
            'key_press_event': self.onpress
        }
        for name, handler in events.items():
            self.cids.append(self.canvas.mpl_connect(name, handler))

    def deactivate(self):
        """Disconnect every handler registered by :meth:`activate`."""
        for cid in self.cids:
            self.canvas.mpl_disconnect(cid)
        self.cids = []


In [None]:
# Demo data: one unit square centred on each centroid. The squares overlap,
# which makes them convenient for trying out selection and merging.
centroids = [(0, 0), (0.5, 0.5), (1, 1)]
grains = [
    Grain([
        [cx - 0.5, cx + 0.5, cx + 0.5, cx - 0.5],  # x coordinates of the corners
        [cy + 0.5, cy + 0.5, cy - 0.5, cy - 0.5],  # matching y coordinates
    ])
    for cx, cy in centroids
]

# Draw the grains and start listening for mouse/keyboard events.
grain_plot = GrainPlot(grains)
grain_plot.activate()