In [None]:
# Contact address of the notebook author, kept in the conventional __author__ name.
__author__ = 'kgeorge2@gmail.com'

# common notebook utils
### koshy george, kgeorge2@gmail.com

### In addition to [jupyter](http://jupyter.org/), you need to have [ipywidgets](https://github.com/ipython/ipywidgets) installed 

The contributions of this notebook are:
1. define a custom ipywidget called ProgressImageWidget
2. define a handy Plotter class that accepts samples in various channels and constructs a plot


### ProgressImageWidget

<code>ProgressImageWidget</code> will display any image assigned to its <code>value</code>, if the image is a [datauri](https://en.wikipedia.org/wiki/Data_URI_scheme) string. 

<code>
p=common.utils.ProgressImageWidget()
display(p)
</code>


..., lots of stuff


Now if you assign some new image content to the <code>value</code> attribute of <code>p</code>, the widget displayed by the <code>display</code> call above will show the new image content.
<code>
p.value = new_image_content_as_png_dataurl
</code>

In [None]:

import ipywidgets as widgets
from traitlets import Unicode, validate
from IPython import display


class ProgressImageWidget(widgets.DOMWidget):
    """
      ipywidget class to display incremental progress of training as an image.

      The image to show is whatever string is assigned to ``value``; the
      accompanying text of this notebook expects it to be a PNG data URI.
      All three traits are tagged ``sync=True``.
    """
    # Name of the front-end view class; matches the `ProgressImageView` key
    # exported by the 'progress_image' javascript module in this notebook.
    _view_name = Unicode('ProgressImageView').tag(sync=True)
    # Name of the javascript module that defines the view; matches the
    # define('progress_image', ...) call in this notebook.
    _view_module = Unicode('progress_image').tag(sync=True)
    # Image content; the view uses this string as the `src` of its <img>.
    value = Unicode().tag(sync=True)

In [None]:
%%javascript
// Forget any previously registered 'progress_image' module so that re-running
// this cell registers the definition below.
require.undef('progress_image');

define('progress_image', ["jupyter-js-widgets"], function(widgets) {

    // View for ProgressImageWidget: renders the model's `value` string as the
    // `src` attribute of a single <img> element.
    var ProgressImageView = widgets.DOMWidgetView.extend({
        // Create the <img> element and show the value the model already holds,
        // so that a widget displayed after its value was assigned is not blank.
        render: function() {
            this.$img = $('<img />')
                .appendTo(this.$el);
            this.update();
        },

        // Copy the model's current `value` into the <img> and then defer to
        // the parent class, forwarding whatever arguments were received.
        update: function() {
            this.$img.attr('src', this.model.get('value'));
            return ProgressImageView.__super__.update.apply(this, arguments);
        },

        events: {"change": "handle_value_change"},

        // Write the <img>'s current `src` back to the model.
        // `$img` is a jQuery object, so the attribute has to be read with
        // .attr('src'); a plain `.src` property lookup yields undefined.
        handle_value_change: function(event) {
            var src = this.$img.attr('src');
            // Fall back to an empty string when no src has been set yet.
            this.model.set('value', src === undefined ? '' : src);
            this.touch();
        },

    });

    return {
        ProgressImageView : ProgressImageView
    }
});


### Plotter

Now for constructing the [datauri](https://en.wikipedia.org/wiki/Data_URI_scheme), which shows the progress graph, we employ another helper class called <code>Plotter</code> defined in <code>common/utils.ipynb</code>. We can construct a plotter with <code>xlabel, ylabel</code> and <code>title</code> parameters. An instance of the <code>Plotter</code> class will return a PNG datauri when <code>plotter.plot()</code> is called.

We can also add many channels to the plot. For each channel we must supply an upper bound on the number of samples that will be added to the channel. See <code>Plotter.add_channel</code>. All the channels will be shown on the same plot.

We can add as many samples to each channel as we please (up to that channel's upper bound), and when we call <code>plotter.plot()</code>, a PNG datauri containing the plot will be returned.

In [None]:
import numpy as np
import io, base64
import os
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties



class Plotter(object):
    """
      A utility class to plot training/test data.

      add_channel: add as many channels as you want, each with an upper
                   bound on the number of samples it can hold
      add_sample: add a sample (x, y) to any channel
      plot: return a data URL containing a single plot of all channels
      plot_and_save_fig: write the same plot to a file

      The constructor requires the keyword arguments ``xlabel``, ``ylabel``
      and ``title``; every keyword passed is stored as an instance attribute.
    """

    # Image format handed to matplotlib's savefig; also used to build the
    # data URL prefix and to validate the extension in plot_and_save_fig.
    format='PNG'

    def __init__(self,  **kwds):
        """
          Store the plot labels and start with no channels.

          Raises AssertionError if xlabel, ylabel or title is missing/empty.
        """
        #need to have these keywords for initialization
        assert(kwds.get('xlabel'))
        assert(kwds.get('ylabel'))
        assert(kwds.get('title'))
        self.__dict__.update(kwds)
        #initialize empty extents: [xmin, xmax, ymin, ymax]
        self.extents=[np.inf, -np.inf, np.inf, -np.inf]
        #channel_name -> dict with keys plot_x, plot_y, legend, next_sample_index
        self.channels={}

    #num_samples == upper bound on the number of samples that can be added for this channel
    def add_channel(self, num_samples=-1, **kwds):
        """
          Register (or reset) a channel.

          num_samples: upper bound on the number of samples for this channel;
                       must be supplied as a non-negative integer.
          channel_name (keyword): key under which the channel is stored.
          legend (keyword): text shown for this channel in the plot legend.

          Raises AssertionError if channel_name or legend is missing/empty,
          and ValueError if num_samples is negative (including the default).
        """
        assert(kwds.get('channel_name'))
        assert(kwds.get('legend'))
        # Validate before touching self.channels so a bad call leaves no
        # half-initialised channel behind.
        if num_samples < 0:
            raise ValueError(
                'num_samples must be a non-negative upper bound on the number '
                'of samples, got %r' % (num_samples,))
        channel = self.channels.setdefault(kwds['channel_name'], {})
        channel['plot_x'] = np.zeros(num_samples, dtype=np.float32)
        channel['plot_y'] = np.zeros_like( channel['plot_x']  )
        channel['legend'] = kwds['legend']
        channel['next_sample_index'] = 0

    #add a sample to a channel
    def add_sample(self, x, y, channel_name=''):
        """
          Append the sample (x, y) to an existing channel.

          Raises AssertionError if channel_name is empty or unknown, and
          IndexError if the channel already holds num_samples samples.
        """
        assert(channel_name)
        assert(self.channels.get(channel_name))
        channel = self.channels[channel_name]
        next_index = channel['next_sample_index']
        if next_index >= len(channel['plot_x']):
            raise IndexError(
                'channel %r is full (%d samples)' % (channel_name, next_index))
        channel['plot_x'][next_index] = x
        channel['plot_y'][next_index] = y
        channel['next_sample_index'] += 1
        self.update_extents_(x, y)

    #internal routine to keep track of extents
    def update_extents_(self, x, y):
        """
          Update self.extents for a new sample.

          Only the x maximum follows the data; the x minimum is pinned to 0
          and the y range to [0, 1].
        """
        self.extents[0 ] = 0 # min(x, self.extents[0])
        self.extents[1 ] = max(x, self.extents[1])
        self.extents[2 ] = 0 # min(y, self.extents[2])
        self.extents[3 ] = 1 # max(y, self.extents[3])

    def plot_core_(self):
        """
          Render all channels on one figure and return it as an io.BytesIO
          positioned at the start, encoded in Plotter.format.
        """
        fontP = FontProperties()
        fontP.set_size('small')
        fig = plt.figure()
        try:
            ax = fig.add_subplot(1, 1, 1)
            #plot each channel, using only the samples added so far
            for k, v in self.channels.items():
                next_sample_index = v['next_sample_index']
                ax.plot(v['plot_x'][0:next_sample_index],
                        v['plot_y'][0:next_sample_index],
                        label=v['legend'])
            # Only build a legend when there is something to list.
            if self.channels:
                ax.legend(loc='lower left', prop=fontP)
            ax.set_title(self.title)
            ax.set_xlabel(self.xlabel)
            ax.set_ylabel(self.ylabel)
            #return the plot as an in-memory image
            buf = io.BytesIO()
            fig.savefig(buf, format=Plotter.format)
            buf.seek(0)
            return buf
        finally:
            # Release the figure even if plotting or saving failed.
            fig.clear()
            plt.close(fig)

    #plot routine
    def plot(self):
        """
          Return the plot as a base64 data URL string.
        """
        buf = self.plot_core_()
        # b64encode returns bytes on Python 3; decode so the pieces can be
        # concatenated into one text string.
        encoded = base64.b64encode(buf.read()).decode('ascii')
        dataurl = "data:image/" + Plotter.format + ";base64," + encoded
        return dataurl

    def plot_and_save_fig(self, savepath=''):
        """
          Render the plot and write it to savepath.

          Raises AssertionError if savepath is empty or its extension does
          not match Plotter.format (case-insensitive).
        """
        assert(savepath)
        assert(os.path.splitext(savepath)[1].lower() == '.' + Plotter.format.lower())
        buf = self.plot_core_()
        with open(savepath, 'wb') as fp:
            fp.write(buf.read())
    
        