**For General Tips: Check out the examples in the Matplotlib Gallery under 'Event Handling' and 'Widgets'**

# Matplotlib Layers
* Renderer: Internal, handles final 'on screen' and 'on disk' versions
* Canvas: Manages `Figure` & `Renderer`
* Transforms: Manages the four coordinate systems: 'data', 'axes', 'figure', and 'screen'
* Artists: The middle layer. Given a `Renderer`, an artist can draw itself
* Axes & Figures Methods: Creates `Artists` and adds to draw-tree
* pyplot: Creates `Artists` and adds to draw-tree
---
* UI Events
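
To make the four coordinate systems concrete, here is a minimal sketch of moving a point between them. It assumes the headless `Agg` backend, and the figure size and axis limits are arbitrary values I picked for illustration.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend so the sketch runs without a GUI
import matplotlib.pyplot as plt

# 4 x 3 inches at 100 dpi gives a 400 x 300 pixel canvas
fig, ax = plt.subplots(figsize=(4, 3), dpi=100)
ax.set_xlim(0, 10)
ax.set_ylim(0, 1)

# data -> screen (display) pixels
center_px = ax.transData.transform((5, 0.5))

# axes fraction -> screen pixels; (0.5, 0.5) is the middle of the Axes
center_px_from_axes = ax.transAxes.transform((0.5, 0.5))

# screen pixels -> data, by inverting the data transform
back_to_data = ax.transData.inverted().transform(center_px)

# figure fraction -> screen pixels; (1, 1) is the top-right corner
fig_corner_px = fig.transFigure.transform((1, 1))

print(center_px, center_px_from_axes, back_to_data, fig_corner_px)
```

The data point (5, 0.5) sits in the middle of these limits, so the 'data' and 'axes' transforms land on the same pixel.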

### Axes & Figures Methods
* Manages the draw tree
* Manages figure size, dpi, axis scales, etc.
* Provides a namespace for plotting functions (like ax.plot())

### Artists
* Everything in the figure (text, images, etc.) is an `Artist`, because it draws on the canvas.
* Responsible for translating internal state into `Renderer` method calls
* Can be mutated and re-drawn with set commands.
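
A small sketch of the "mutate, then re-draw" idea, assuming the headless `Agg` backend. The data and style values are made up. The `stale` flag is how an artist records that it has changed since the last draw.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend so the sketch runs without a GUI
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ln, = ax.plot([0, 1, 2], [0, 1, 0], '-o')

# first render: afterwards the artist matches what is on the canvas
fig.canvas.draw()
stale_before = ln.stale

# mutate the artist through its set_* methods
ln.set_color('red')
ln.set_linewidth(3)
stale_after = ln.stale  # the setters mark the artist as needing a redraw

# ask the canvas to redraw when it next gets the chance
fig.canvas.draw_idle()

print(stale_before, stale_after)
```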

### Canvas
* Holds `Figure` instance
* Knows how to make a *Renderer* instance at the proper size & DPI
* For GUI backends typically uses multiple inheritance
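
The canvas role can be sketched without pyplot or a GUI by pairing a `Figure` with the Agg canvas by hand. The figure size and dpi below are arbitrary choices for illustration.

```python
import numpy as np
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg

# 4 x 3 inches at 50 dpi should render to 200 x 150 pixels
fig = Figure(figsize=(4, 3), dpi=50)
canvas = FigureCanvasAgg(fig)  # the canvas holds the Figure

ax = fig.add_subplot(111)
ax.plot([0, 1], [0, 1])

# the canvas builds a renderer at the right size and dpi and draws into it
canvas.draw()
width, height = canvas.get_width_height()
rgba = np.asarray(canvas.buffer_rgba())  # rendered pixels, rows x cols x 4

print(width, height, rgba.shape)
```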

# UI Events on Canvas
![Canvas_UI](jupyter/IMG_5933.JPG)

----
# Demo Setup
I created a conda environment named `mpl-tutorial`; activate it with `source activate mpl-tutorial`. The slides say to run the examples from `ipython --matplotlib=qt5`. For this Jupyter Notebook, I activate `mpl-tutorial` first and then launch the notebook.

----
# Demo: 00-explore.py
This demo connects callbacks to canvas events so that the terminal prints information about each interaction with the plot.

If you run it from a terminal, clicking on the plot prints the event's attributes, including the data coordinates of the click.

`source activate mpl-tutorial && python 00-explore.py`

In [8]:
# Adds interactive plotting to matplotlib
%matplotlib notebook

import matplotlib.pyplot as plt
import numpy as np

last_ev = None


def event_printer(event):
    """Helper function for exploring events.

    Prints all public attributes +
    """
    # capture the last event
    global last_ev
    last_ev = event
    for k, v in sorted(vars(event).items()):
        print(f'{k}: {v!r}')  # f-string (Python 3.6+): interpolates local variables
    print('-'*25)


th = np.linspace(0, 2*np.pi, 64)
fig, ax = plt.subplots()
# the `picker=5` kwarg turns on pick-events for this artist
ax.plot(th, np.sin(th), 'o-', picker=5)

# Shows location when button is clicked
cid = fig.canvas.mpl_connect('button_press_event', event_printer)
plt.show()
# fig.canvas.mpl_disconnect(cid)

# Shows location when button is released
cid = fig.canvas.mpl_connect('button_release_event', event_printer)

# Shows location of static cursor, when mouse wheel scrolls
cid = fig.canvas.mpl_connect('scroll_event', event_printer)

# Shows location of static cursor, when key is pressed or released
cid = fig.canvas.mpl_connect('key_press_event', event_printer)
cid = fig.canvas.mpl_connect('key_release_event', event_printer)

# Shows what Matplotlib object was selected when clicked
cid = fig.canvas.mpl_connect('pick_event', event_printer)

# After changing an artist, ask the GUI to redraw at its next idle moment
# fig.canvas.draw_idle()


# EXERCISE (10 - 15 minutes)
#
# play around with events interactively
#
#   - Try all 'active' events
#
#     ['button_press_event', 'button_release_event', 'scroll_event',
#      'key_press_event', 'key_release_event', 'pick_event']
#   - tweak the print line
#   - remove a callback
#   - add more than one callback to the canvas




----
# Demo: 01-callable.py
This demo shows how a key press can change the plot. It doesn't add much beyond 00-explore.py.

If you choose to run it, run it as: 
`source activate mpl-tutorial && ipython --matplotlib=qt5`
`run 01-callable.py`
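
I have not copied the script here, so the following is only a sketch of the idea its name suggests: a callable object used as a callback. `KeyCounter` and `fake_key_press` are hypothetical names of mine, and the key presses are simulated on the headless `Agg` backend rather than coming from a real keyboard.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend so the sketch runs without a GUI
import matplotlib.pyplot as plt
from matplotlib.backend_bases import KeyEvent


class KeyCounter:
    """Hypothetical callable callback that counts presses per key."""

    def __init__(self):
        self.counts = {}

    def __call__(self, event):
        self.counts[event.key] = self.counts.get(event.key, 0) + 1


fig, ax = plt.subplots()
counter = KeyCounter()
# any callable works as a callback, including an instance with __call__
cid = fig.canvas.mpl_connect('key_press_event', counter)


def fake_key_press(key):
    """Simulate a key press by sending a KeyEvent through the canvas."""
    event = KeyEvent('key_press_event', fig.canvas, key)
    fig.canvas.callbacks.process('key_press_event', event)


fake_key_press('c')
fake_key_press('c')

# after disconnecting, the callback no longer sees events
fig.canvas.mpl_disconnect(cid)
fake_key_press('c')

print(counter.counts)
```

Keeping the state on the object (instead of in a `global`, as in 00-explore.py) is the main reason to reach for a callable class.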

----
# Demo: 02-event_filter.py
This creates a plot where each click drops a point and Matplotlib connects the dots. Some lessons learned:

* After I change an artist, I need to request a redraw with `self.ln.figure.canvas.draw_idle()`. If the data changed, I first need to push it to the artist with `self.ln.set_data(self.xdata, self.ydata)`.
* In key events, `event.key` is a string even for number keys (pressing 5 gives `'5'`). In mouse events it is `None` unless a key is held down.
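
Both lessons in one small sketch, assuming the headless `Agg` backend. The points are made up, and the key press is simulated by building a `KeyEvent` directly instead of pressing a real key.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend so the sketch runs without a GUI
import matplotlib.pyplot as plt
from matplotlib.backend_bases import KeyEvent

fig, ax = plt.subplots()
ln, = ax.plot([], [], '-o')

# lesson 1: push the new data to the artist, then request a redraw
xdata, ydata = [], []
for x, y in [(0.1, 0.2), (0.4, 0.9)]:
    xdata.append(x)
    ydata.append(y)
    ln.set_data(xdata, ydata)
    fig.canvas.draw_idle()

# lesson 2: a number key arrives as a string, so convert before using it
event = KeyEvent('key_press_event', fig.canvas, '5')
if event.key is not None and event.key.isdigit():
    ln.set_linewidth(int(event.key))
    fig.canvas.draw_idle()

print(list(ln.get_xdata()), repr(event.key), ln.get_linewidth())
```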

In [12]:
import matplotlib.pyplot as plt
from itertools import cycle


class LineMaker:
    def __init__(self, ln):
        # stash the current data
        self.xdata = list(ln.get_xdata())
        self.ydata = list(ln.get_ydata())
        # stash the Line2D artist
        self.ln = ln
        self.color_cyle = cycle(['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
                                 '#9467bd', '#8c564b', '#e377c2', '#7f7f7f',
                                 '#bcbd22', '#17becf'])
        self.button_cid = ln.figure.canvas.mpl_connect('button_press_event',
                                                       self.on_button)
        self.key_cid = ln.figure.canvas.mpl_connect('key_press_event',
                                                    self.on_key)

    def on_button(self, event):
        """ This creates a point each time I click on the plot """
        # only consider events from the line's Axes
        if event.inaxes is not self.ln.axes:
            return

        # if not the left mouse button or a modifier key
        # is held down, bail
        if event.button != 1 or event.key is not None:
            print('key+button: {!r}+{!r}'.format(event.key, event.button))
            return

        # get the event location in data-space
        self.xdata.append(event.xdata)
        self.ydata.append(event.ydata)

        # update the artist data
        self.ln.set_data(self.xdata, self.ydata)

        # ask the GUI to re-draw the next time it can
        self.ln.figure.canvas.draw_idle()

    def on_key(self, event):
        """ This changes the plot's color if I type 'c'
            And I modified it to delete a point when I tap 'shift'"""
        # This is _super_ useful for debugging!
        # print(event.key)

        # event.key can be None for unrecognized keys, so bail early
        if event.key is None:
            return

        # if the key is c (any case)
        if event.key.lower() == 'c':
            # change the color
            self.ln.set_color(next(self.color_cyle))

            # ask the GUI to re-draw the next time it can
            self.ln.figure.canvas.draw_idle()
        
        # if the key is shift, delete last point
        if event.key == 'shift':
            self.xdata = self.xdata[:-1]
            self.ydata = self.ydata[:-1]
            self.ln.set_data(self.xdata, self.ydata)
            # ask the GUI to re-draw the next time it can
            self.ln.figure.canvas.draw_idle()
    
        # if numerical key, change width
        if event.key in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']:
            self.ln.set_linewidth(int(event.key))
            self.ln.figure.canvas.draw_idle()

        # if escape key, remove line
        if event.key == 'escape':
            self.ln.set_linewidth(0)
            self.ln.figure.canvas.draw_idle()
            
fig, ax = plt.subplots()
ln, = ax.plot([], [], '-o')
line_maker = LineMaker(ln)
plt.show()

# EXERCISE (15 minutes)

# - modify to remove the closest point when key == 'shift'
# - change the line width for [1-9]
# - clear the line when event.key == 'escape'


---- 
# Demo: 03-temperature.py
We will be working with Central Park data and the code below quickly shows what it looks like

In [16]:
import sys

sys.path.append('03-temperature')

from w_helpers import load_data

t = load_data('central_park')
t.plot(y='T')


### 01-picking.py
This plots the daily-aggregated temperatures as a line. Clicking a point prints the details (the DataFrame row) for that day.
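
Before the full script, here is a stripped-down sketch of what a pick event carries. It assumes the headless `Agg` backend and made-up data, and the click is simulated by building a `MouseEvent` at the pixel position of one point and handing it to `fig.pick`.

```python
import matplotlib
matplotlib.use('Agg')  # headless backend so the sketch runs without a GUI
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backend_bases import MouseEvent

fig, ax = plt.subplots()
x = np.arange(5)
y = x ** 2
ln, = ax.plot(x, y, 'o-', picker=5)  # picker=5: pick within 5 points

picked = []


def on_pick(event):
    # ignore picks on other artists
    if event.artist is not ln:
        return
    # event.ind holds the indices of the data points that were hit
    picked.extend(int(i) for i in event.ind)


fig.canvas.mpl_connect('pick_event', on_pick)
fig.canvas.draw()  # make sure limits and transforms are up to date

# simulate a left click exactly on the point with index 3
px, py = ax.transData.transform((x[3], y[3]))
click = MouseEvent('button_press_event', fig.canvas, px, py, button=1)
fig.pick(click)

print(picked)
```

The indices in `event.ind` are positions in the plotted data, which is why the script below can use them with `df.iloc`.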

In [18]:
import matplotlib.pyplot as plt
from w_helpers import load_data, aggregate_by_day, extract_day_of_hourly, label_date

import uuid

datasource = 'central_park'

temperature = load_data(datasource)
temperature = temperature[temperature['year'] >= 2017]
temperature_daily = aggregate_by_day(temperature)


class RowPrinter:
    def __init__(self, ln, df, picker=10):
        ln.set_picker(picker)
        # we can use this to ID our line!
        self.uid = str(uuid.uuid4())
        ln.set_gid(self.uid)
        self.ln = ln  # Line object
        self.df = df  # Data frame object
        self.cid = None
        self.connect()

    def connect(self):
        self.remove()
        # use the stored artist, not the module-level `ln`
        self.cid = self.ln.figure.canvas.mpl_connect('pick_event',
                                                     self)

    def __call__(self, event):
        # ignore picks on not-our-artist
        if event.artist is not self.ln:
            return
        # for each hit index, print out the row
        for i in event.ind:
            print(self.df.iloc[i])

    def remove(self):
        if self.cid is not None:
            self.ln.figure.canvas.mpl_disconnect(self.cid)
            self.cid = None


fig, ax = plt.subplots(2, 1)
ln, = ax[0].plot('mean', '-o', data=temperature_daily)  # Will plot DF temperature_daily['mean']
ax[0].set_xlabel('Date [UTC]')
ax[0].set_ylabel('Air Temperature [℃]')
ax[0].set_title(f'{datasource} temperature')

rp = RowPrinter(ln, temperature_daily)

# Plots the day's values below
one_day = extract_day_of_hourly(temperature, 2017, 10, 27)
ln, = ax[1].plot('mean', '-o', data=one_day)

plt.show()

# EXERCISE
# - make the print out nicer looking

# - open a new window with plot of day temperature
#   - fig, ax = plt.subplots()
#   - one_day = extract_day_of_hourly(temperature, 2015, 10, 18)
# - make picking add a label with `label_date`

# - use `get_gid` to filter artists instead of `is not`



### 03-interactive-temperature.py
Here, I have three linked interactive plots that let me drill down through the temperatures: from monthly aggregates over the years, to the days of a month, to the hours of a day.
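
The `w_helpers` module is not shown in these notes, so the following is only my guess at the kind of aggregation behind the drill-down. `aggregate_by_day_sketch` and `extract_day_sketch` are hypothetical stand-ins, not the real helpers, and the hourly temperatures are synthetic.

```python
import numpy as np
import pandas as pd

# synthetic hourly temperatures for three full days (made-up values)
hours = np.arange(72)
idx = pd.Timestamp('2017-10-26') + pd.to_timedelta(hours, unit='h')
hourly = pd.DataFrame({'T': 10 + 5 * np.sin(2 * np.pi * hours / 24)},
                      index=idx)


def aggregate_by_day_sketch(df):
    """Hypothetical stand-in: one row per day with summary statistics."""
    by_day = df['T'].groupby(df.index.floor('D'))
    return by_day.agg(['mean', 'std', 'min', 'max'])


def extract_day_sketch(df, year, month, day):
    """Hypothetical stand-in: the hourly rows that fall on one day."""
    day_start = pd.Timestamp(year=year, month=month, day=day)
    return df[df.index.floor('D') == day_start]


daily = aggregate_by_day_sketch(hourly)

# "drilling down": take a day from the aggregated table and fetch its hours
picked_day = daily.index[1]
one_day = extract_day_sketch(hourly, picked_day.year, picked_day.month,
                             picked_day.day)

print(daily)
print(len(one_day))
```

The `mean`, `std`, `min` and `max` columns mirror the column names the plotting code below asks for.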

In [19]:
import datetime as dt
import matplotlib.pyplot as plt
from cycler import cycler
from w_helpers import (load_data, aggregate_by_month, aggregate_by_day,
                       extract_day_of_hourly, extract_month_of_daily)


def setup_temperature_figure(**kwargs):
    """ Builds a 3 panel figure """
    fig, ax_lst = plt.subplots(3, 1, sharey=True, **kwargs)
    for ax in ax_lst:
        ax.set_ylabel('T [℃]')
        ax.grid(True)
    for ax, x_lab in zip(ax_lst, ['Date', 'days from start of month',
                                  'hours from midnight UTC']):
        ax.set_xlabel(x_lab)
    ax_lst[1].set_xlim(-1, 32)
    ax_lst[2].set_xlim(-1, 25)
    fig.tight_layout()
    return fig, ax_lst


def plot_aggregated_errorbar(ax, gb, label, picker=None, **kwargs):
    """ Creates an errorbar plot """
    kwargs.setdefault('capsize', 3)
    kwargs.setdefault('markersize', 5)
    eb = ax.errorbar(gb.index, 'mean',
                     yerr='std',
                     data=gb,
                     label=label,
                     picker=picker,
                     **kwargs)
    fill = ax.fill_between(gb.index, 'min', 'max', alpha=.5,
                           data=gb, color=eb[0].get_color())
    ax.legend()
    ax.figure.canvas.draw_idle()
    return eb, fill


class AggregatedTimeTrace:
    def __init__(self, hourly_data, label, yearly_ax, monthly_ax, daily_ax,
                 agg_by_day=None, agg_by_month=None, style_cycle=None):
        '''Class to manage 3-levels of aggregated temperature

        Parameters
        ----------
        hourly_data : DataFrame
            Temperature measured hourly

        label : str
            The name of this data set

        yearly_ax : Axes
            The axes to plot 'year' scale data (aggregated by month) to

        monthly_ax : Axes
            The axes to plot 'month' scale data (aggregated by day) to

        daily_ax : Axes
            The axes to plot 'day' scale data (un-aggregated hourly) to

        agg_by_day : DataFrame, optional

            Data already aggregated by day.  This is just to save
            computation, will be computed if not provided.

        agg_by_month : DataFrame, optional

            Data already aggregated by month.  This is just to save
            computation, will be computed if not provided.

        style_cycle : Cycler, optional
            Style to use for plotting

        '''
        # data
        self.data_by_hour = hourly_data

        if agg_by_day is None:
            agg_by_day = aggregate_by_day(hourly_data)
        self.data_by_day = agg_by_day

        if agg_by_month is None:
            agg_by_month = aggregate_by_month(hourly_data)
        self.data_by_month = agg_by_month

        # style
        if style_cycle is None:
            style_cycle = ((cycler('marker', ['o', 's', '^', '*',
                                              'x', 'v', '8', 'D',
                                              'H', '<']) +
                            cycler('color',
                                   ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728',
                                    '#9467bd', '#8c564b', '#e377c2', '#7f7f7f',
                                    '#bcbd22', '#17becf'])))
        self.style_cycle = style_cycle()
        # axes
        self.yearly_ax = yearly_ax
        self.monthly_ax = monthly_ax
        self.daily_ax = daily_ax
        # name
        self.label = label
        # these will be used for book keeping
        self.daily_artists = {}
        self.daily_index = {}
        self.hourly_artists = {}
        # artists
        self.yearly_art = plot_aggregated_errorbar(self.yearly_ax,
                                                   self.data_by_month,
                                                   self.label,
                                                   picker=5,
                                                   **next(self.style_cycle))

        # pick methods (keep one connection id per callback so that
        # each one can be disconnected later)
        self.y_cid = self.yearly_ax.figure.canvas.mpl_connect(
            'pick_event', self._yearly_on_pick)
        self.m_cid = self.yearly_ax.figure.canvas.mpl_connect(
            'pick_event', self._monthly_on_pick)
        self.d_cid = self.yearly_ax.figure.canvas.mpl_connect(
            'pick_event', self._daily_on_pick)


    def _yearly_on_pick(self, event):
        ''' Process picks on 'year' scale axes '''
        # if not the right axes, bail
        if event.mouseevent.inaxes is not self.yearly_ax:
            return
        # make sure the artists we expect exists and we picked it
        if self.yearly_art is None or event.artist is not self.yearly_art[0][0]:
            return
        # loop over the points we hit and plot the 'month' scale data
        for i in event.ind:
            row = self.data_by_month.iloc[i]
            self._plot_T_by_day(int(row['year']), int(row['month']))


    def _plot_T_by_day(self, year, month):
        ''' Plots the monthly data '''
        # get the data we need
        df = extract_month_of_daily(self.data_by_day, year, month)
        # format the label
        label = '{:s}: {:04d}-{:02d}'.format(self.label, year, month)
        # if we have already plotted this, don't bother
        if label in self.daily_artists:
            return
        # plot the data
        eb, fill = plot_aggregated_errorbar(self.monthly_ax, df, label,
                                            picker=5, **next(self.style_cycle))
        # set the gid of the line (which is what will be picked) to label
        eb[0].set_gid(label)
        # stash the artists so we can remove them later
        self.daily_artists[label] = [eb, fill]
        # stash the dates associated with the points so we can use in
        # plotting later
        self.daily_index[label] = df['index']


    def _monthly_on_pick(self, event):
        ''' Process picks on 'month' scale axes '''
        # if we are not in the right axes, bail
        if event.mouseevent.inaxes is not self.monthly_ax:
            return
        # get the label from the picked artist
        label = event.artist.get_gid()
        # if the shift key is held down, remove this data
        if event.mouseevent.key == 'shift':
            self.daily_index.pop(label, None)
            arts = self.daily_artists.pop(label, [])
            for art in arts:
                # work around a bug!
                if art in self.monthly_ax.containers:
                    self.monthly_ax.containers.remove(art)
                art.remove()
            # regenerate the legend
            self.monthly_ax.legend()
            # ask the GUI to redraw when convenient
            self.monthly_ax.figure.canvas.draw_idle()
            return
        # else, loop through the points we hit and plot the daily
        for i in event.ind:
            sel_date = self.daily_index[label][i]
            self._plot_T_by_hour(sel_date.year, sel_date.month, sel_date.day)


    def _plot_T_by_hour(self, year, month, day):
        ''' Plots the daily plot '''
        # get the hourly data for a single day
        df = extract_day_of_hourly(self.data_by_hour, year, month, day)
        # format the label
        label = '{:s}: {:04d}-{:02d}-{:02d}'.format(self.label, year, month, day)
        # A 'simple' plot
        self.daily_ax.plot('T', '-', picker=10, label=label, data=df,
                           **next(self.style_cycle))

        # Advances a day
        now = '{:04d}-{:02d}-{:02d}'.format(year, month, day)
        next_day = dt.datetime.strptime(now, '%Y-%m-%d') + dt.timedelta(days=1)
        df = extract_day_of_hourly(self.data_by_hour, next_day.year, next_day.month,
                                   next_day.day)
        label = next_day.strftime('{}: %Y-%m-%d'.format(self.label))
        self.daily_ax.plot('T', '-', picker=10, label=label, data=df,
                           **next(self.style_cycle))
 
        # Moves back a day
        now = '{:04d}-{:02d}-{:02d}'.format(year, month, day)
        prev_day = dt.datetime.strptime(now, '%Y-%m-%d') - dt.timedelta(days=1)
        df = extract_day_of_hourly(self.data_by_hour, prev_day.year, prev_day.month,
                                   prev_day.day)
        label = prev_day.strftime('{}: %Y-%m-%d'.format(self.label))
        self.daily_ax.plot('T', '-', picker=10, label=label, data=df,
                           **next(self.style_cycle))
      
        # update the legend
        self.daily_ax.legend()
        # ask the GUI to redraw the next time it can
        self.daily_ax.figure.canvas.draw_idle()


    def _daily_on_pick(self, event):
        ''' Process picks on the daily plot to remove lines '''
        if event.mouseevent.inaxes is not self.daily_ax:
            return
        # grab the canvas
        canvas = event.artist.figure.canvas
        # remove the artist
        event.artist.remove()
        # update the legend
        self.daily_ax.legend()
        # redraw the canvas next time it is convenient
        canvas.draw_idle()


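    def remove(self):
        for art in self.yearly_art:
            art.remove()
        self.yearly_art = None
        # disconnect every pick callback that was registered in __init__
        canvas = self.yearly_ax.figure.canvas
        for name in ('y_cid', 'm_cid', 'd_cid'):
            cid = getattr(self, name, None)
            if cid is not None:
                canvas.mpl_disconnect(cid)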


temperature = load_data('central_park')
fig, (ax_by_month, ax_by_day, ax_by_hour) = setup_temperature_figure()
temperature_at = AggregatedTimeTrace(temperature, 'temperature',
                                     ax_by_month, ax_by_day, ax_by_hour)
fig.suptitle('Temperature')
plt.show()

# EXERCISE (15 minutes)
# - plot 3 day windows centered on picked day
# - cycle through min/max, std bands, and no bands on key stroke


----
# Demo: 05-spectral & 06-xrf
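These demos are domain-specific (an energy spectrum and an XRF map) and contain a lot of advanced logic. They use the `SpanSelector` and `Lasso` widgets.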
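
The core of the span-selector callback in the first demo is a `searchsorted` lookup followed by a sum. Here is that step on its own, with tiny made-up arrays so the numbers can be checked by hand. `integrate_energy_window` is my name for it.

```python
import numpy as np

# made-up energy bin edges and a 4 (energy bins) x 3 (channels) spectrum
bins = np.array([0., 10., 20., 30., 40.])
spectrum = np.arange(12).reshape(4, 3)


def integrate_energy_window(spectrum, bins, lo, hi):
    """Sum the rows of `spectrum` whose energies fall between lo and hi."""
    # searchsorted turns the energies into row indices
    lo_ind, hi_ind = bins.searchsorted([lo, hi])
    return spectrum[lo_ind:hi_ind].sum(axis=0)


window = integrate_energy_window(spectrum, bins, 10, 30)
print(window)  # → [ 9 11 13]
```

In the demo, the `lo` and `hi` values come from the span the user drags, and the summed result replaces the y-data of the top line.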

In [23]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.widgets import SpanSelector
from matplotlib.colors import LogNorm
import numpy as np
import h5py

import sys
sys.path.append('05-spectral')


def plot_all_chan_spectrum(spectrum, bins, *, ax=None, **kwargs):

    def integrate_to_angles(spectrum, bins, lo, hi):
        """ Given an energy range, this sums the proper values """
        lo_ind, hi_ind = bins.searchsorted([lo, hi])
        return spectrum[lo_ind:hi_ind].sum(axis=0)

    if ax is None:
        fig, ax = plt.subplots(figsize=(13.5, 9.5))
    else:
        fig = ax.figure

    # Sets the axes on the top plot and lower right plot
    div = make_axes_locatable(ax)
    ax_r = div.append_axes('right', 2, pad=0.1, sharey=ax)
    ax_t = div.append_axes('top', 2, pad=0.1, sharex=ax)

    ax_r.yaxis.tick_right()
    ax_r.yaxis.set_label_position("right")
    ax_t.xaxis.tick_top()
    ax_t.xaxis.set_label_position("top")

    im = ax.imshow(spectrum, origin='lower', aspect='auto',
                   extent=(-.5, 383.5,
                           bins[0], bins[-1]),
                   norm=LogNorm())

    e_line, = ax_r.plot(spectrum.sum(axis=1), bins[:-1] + np.diff(bins))
    p_line, = ax_t.plot(spectrum.sum(axis=0))
    label = ax_t.annotate('[0, 70] keV', (0, 1), (10, -10),
                          xycoords='axes fraction',
                          textcoords='offset pixels',
                          va='top', ha='left')

    def update(lo, hi):
        """ Extracts data, and plots it"""
        p_data = integrate_to_angles(spectrum, bins, lo, hi)
        p_line.set_ydata(p_data)
        ax_t.relim()
        ax_t.autoscale(axis='y')

        label.set_text(f'[{lo:.1f}, {hi:.1f}] keV')
        fig.canvas.draw_idle()

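    # Implements the span selector on the right axes.
    # Note: Matplotlib 3.5 renamed `rectprops` to `props` and
    # `span_stays` to `interactive`; adjust for newer versions.
    span = SpanSelector(ax_r, update, 'vertical', useblit=True,
                        rectprops={'alpha': .5, 'facecolor': 'red'},
                        span_stays=True)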

    ax.set_xlabel('channel [#]')
    ax.set_ylabel('E [keV]')

    ax_t.set_xlabel('channel [#]')
    ax_t.set_ylabel('total counts')

    ax_r.set_ylabel('E [keV]')
    ax_r.set_xlabel('total counts')
    ax.set_xlim(-.5, 383.5)
    ax.set_ylim(bins[0], bins[-1])
    ax_r.set_xlim(left=0)  # `xmin=` is the older spelling of `left=`

    return spectrum, bins, {'center': {'ax': ax, 'im': im},
                            'top': {'ax': ax_t, 'p_line': p_line},
                            'right': {'ax': ax_r, 'e_line': e_line,
                                      'span': span}}


with h5py.File('05-spectral/germ.h5', 'r') as fin:
    spectrum = fin['spectrum'][:]
    bins = fin['bins'][:]

ret = plot_all_chan_spectrum(spectrum, bins)
plt.show()


# Exercise (15 minutes)
# - add span selector to top axes to change curve in right axes



In [25]:
# https://drive.google.com/open?id=0B5vxvuZBEEfTRGdXZ2NXUjNKUUk

import h5py
import matplotlib.gridspec as gridspec
import matplotlib.widgets as mwidgets
from matplotlib import path
import numpy as np

import sys
sys.path.append('06-xrf')

# uncomment this to set the backend
# import matplotlib
# matplotlib.use('Qt5Agg')
import matplotlib.pyplot as plt


class XRFInteract(object):
    def __init__(self, counts, positions, fig=None, pos_order=None,
                 norm=None):

        if pos_order is None:
            pos_order = {'x': 0,
                         'y': 1}
        # extract x/y data
        self.x_pos = xpos = positions[pos_order['x']]
        self.y_pos = ypos = positions[pos_order['y']]
        self.points = np.transpose((xpos.ravel(), ypos.ravel()))
        # sort out the normalization
        if norm is None:
            norm = np.ones_like(self.x_pos)

        norm = np.atleast_3d(norm[:])
        self.counts = counts[:] / norm

        # compute values we will use for extents below
        dx = np.diff(xpos.mean(axis=0)).mean()
        dy = np.diff(ypos.mean(axis=1)).mean()
        left = xpos[:, 0].mean() - dx/2
        right = xpos[:, -1].mean() + dx/2
        top = ypos[0].mean() - dy/2
        bot = ypos[-1].mean() + dy/2

        # create a figure if we must
        if fig is None:
            import matplotlib.pyplot as plt
            fig = plt.figure(tight_layout=True)
        # clear the figure
        fig.clf()
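        # set the window title (look at the tool bar)
        # (`canvas.set_window_title` was removed in Matplotlib 3.6,
        # so go through the figure manager instead)
        if fig.canvas.manager is not None:
            fig.canvas.manager.set_window_title('XRF map')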
        self.fig = fig
        # set up the figure layout
        gs = gridspec.GridSpec(2, 1, height_ratios=[4, 1])

        # set up the top panel (the map)
        self.ax_im = fig.add_subplot(gs[0, 0], gid='imgmap')
        self.ax_im.set_xlabel('x [?]')
        self.ax_im.set_ylabel('y [?]')
        self.ax_im.set_title(
            'shift-click to select pixel, '
            'alt-drag to draw region, '
            'right-click to reset')

        # set up the lower axes (the average spectrum of the ROI)
        self.ax_spec = fig.add_subplot(gs[1, 0], gid='spectrum')
        self.ax_spec.set_ylabel('counts [?]')
        self.ax_spec.set_xlabel('bin number')
        self.ax_spec.set_yscale('log')
        self.ax_spec.set_title('click-and-drag to select energy region')
        self._EROI_txt = self.ax_spec.annotate('ROI: all',
                                               xy=(0, 1),
                                               xytext=(0, 5),
                                               xycoords='axes fraction',
                                               textcoords='offset points')
        self._pixel_txt = self.ax_spec.annotate('map average',
                                                xy=(1, 1),
                                                xytext=(0, 5),
                                                xycoords='axes fraction',
                                                textcoords='offset points',
                                                ha='right')

        # show the initial image
        self.im = self.ax_im.imshow(self.counts[:, :, :].sum(axis=2),
                                    cmap='viridis',
                                    interpolation='nearest',
                                    extent=[left, right, bot, top]
                                    )
        # and colorbar
        self.cb = self.fig.colorbar(self.im, ax=self.ax_im)

        # and the ROI mask (overlay in red)
        self.mask = np.ones(self.x_pos.shape, dtype='bool')
        self.mask_im = self.ax_im.imshow(self._overlay_image,
                                         interpolation='nearest',
                                         extent=[left, right, bot, top],
                                         zorder=self.im.get_zorder())
        self.mask_im.mouseover = False  # do not consider for mouseover text

        # set up the spectrum, to start average everything
        self.spec, = self.ax_spec.plot(
            self.counts.mean(axis=(0, 1)),
            lw=2)

        # set up the selector widget for the spectrum
        # (Matplotlib 3.5 renamed `span_stays` to `interactive`)
        self.selector = mwidgets.SpanSelector(self.ax_spec,
                                              self._on_span,
                                              'horizontal',
                                              useblit=True, minspan=2,
                                              span_stays=True)
        # placeholder for the lasso selector
        self.lasso = None
        # hook up the mouse events for the XRF map
        self.cid = self.fig.canvas.mpl_connect('button_press_event',
                                               self._on_click)

    @property
    def _overlay_image(self):
        ret = np.zeros(self.mask.shape + (4,), dtype='uint8')
        if np.all(self.mask):
            return ret
        ret[:, :, 0] = 255
        ret[:, :, 3] = 100 * self.mask.astype('uint8')
        return ret

    def _on_click(self, event):
        # not in the right axes, bail
        ax = event.inaxes
        if ax is None or ax.get_gid() != 'imgmap':
            return
        # if right click, clear ROI
        if event.button == 3:
            return self._reset_spectrum()

        # if alt, start lasso
        if event.key == 'alt':
            return self._lasso_on_press(event)
        # if shift, select a pixel
        if event.key == 'shift':
            return self._pixel_select(event)

    def _reset_spectrum(self):
        self.mask = np.ones(self.x_pos.shape, dtype='bool')
        self.mask_im.set_data(self._overlay_image)
        new_y_data = self.counts.mean(axis=(0, 1))
        self.spec.set_ydata(new_y_data)
        self._pixel_txt.set_text('map average')
        self.ax_spec.relim()
        self.ax_spec.autoscale(True, axis='y')
        self.fig.canvas.draw_idle()

    def _pixel_select(self, event):

        x, y = event.xdata, event.ydata
        # get index by assuming even spacing
        # TODO use kdtree?
        diff = np.hypot((self.x_pos - x), (self.y_pos - y))
        y_ind, x_ind = np.unravel_index(np.argmin(diff), diff.shape)

        # get the spectrum for this point
        new_y_data = self.counts[y_ind, x_ind, :]
        self.mask = np.zeros(self.x_pos.shape, dtype='bool')
        self.mask[y_ind, x_ind] = True
        self.mask_im.set_data(self._overlay_image)
        self._pixel_txt.set_text(
            'pixel: [{:d}, {:d}] ({:.3g}, {:.3g})'.format(
                y_ind, x_ind,
                self.x_pos[y_ind, x_ind],
                self.y_pos[y_ind, x_ind]))

        self.spec.set_ydata(new_y_data)
        self.ax_spec.relim()
        self.ax_spec.autoscale(True, axis='y')
        self.fig.canvas.draw_idle()

    def _on_span(self, vmin, vmax):
        vmin, vmax = map(int, (vmin, vmax))
        new_image = self.counts[:, :, vmin:vmax].sum(axis=2)
        new_max = new_image.max()
        self._EROI_txt.set_text('ROI: {}:{}'.format(vmin, vmax))
        self.im.set_data(new_image)
        self.im.set_clim(0, new_max)
        self.fig.canvas.draw_idle()

    def _lasso_on_press(self, event):
        self.lasso = mwidgets.Lasso(event.inaxes, (event.xdata, event.ydata),
                                    self._lasso_call_back)

    def _lasso_call_back(self, verts):
        p = path.Path(verts)

        new_mask = p.contains_points(self.points).reshape(*self.x_pos.shape)
        self.mask = new_mask
        self.mask_im.set_data(self._overlay_image)
        new_y_data = self.counts[new_mask].mean(axis=0)
        self._pixel_txt.set_text('lasso mask')
        self.spec.set_ydata(new_y_data)
        self.ax_spec.relim()
        self.ax_spec.autoscale(True, axis='y')
        self.fig.canvas.draw_idle()


# def make_text_demo(inp='BNL', n_chan=1000):
#     '''Make some synthetic data
#     '''
#     from matplotlib.figure import Figure
#     from matplotlib.backends.backend_agg import FigureCanvas
#     fig = Figure()
#     canvas = FigureCanvas(fig)
#     canvas.draw()
#     im_shape = fig.canvas.get_width_height()[::-1] + (3,)
#     t = fig.text(.5, .5, '', fontsize=350, ha='center', va='center')
#     counts = np.random.rand(*(im_shape[:2] + (n_chan,)))
#     x = np.linspace(0, 1, n_chan)
#     for j, l in enumerate(inp):
#         t.set_text(l)
#         fig.canvas.draw()
#         im = np.fromstring(fig.canvas.tostring_rgb(),
#                            dtype=np.uint8).reshape(im_shape)
#         im = 255 - np.mean(im, axis=2, keepdims=True)
#         counts += (150 * im * np.exp(-500 * ((1+j)/(len(inp) + 1) - x)**2)
#                    .reshape(1, 1, -1))
#         del im
#
#     return counts
#
#
# counts = make_text_demo()
# N, M = counts.shape[:2]
# X, Y = np.meshgrid(range(M), range(N))
# pos = np.stack([.01*X + 100, .01*Y + 50])
#
# xrf = XRFInteract(counts, pos)

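# to look at a data file
fn = '06-xrf/scan_3624.h5'
# read everything into memory so the file can be closed right away
with h5py.File(fn, 'r') as F:
    g = F['xrfmap']
    counts = g['detsum']['counts'][:]
    positions = g['positions']['pos'][:]
    norm = g['scalers']['val'][:, :, 0]

xrf = XRFInteract(counts, positions, norm=norm)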

# uncomment this line to use 'interactive' mode
# plt.ion()
plt.show()

