# >>Disclaimer

Here I implement the widget as a class, so it can be used not only for sampling but also for other methods (gradient descent, etc.), and it can be controlled entirely from within itself (no input arguments needed, just the GUI).

In [1]:
# Report the installed library versions next to the versions this notebook was tested with
%run config_check.py

You are currently using this configuration -
Python:  3.6.4 |Anaconda, Inc.| (default, Jan 16 2018, 18:10:19) 
[GCC 7.2.0]
Numpy:  1.14.0
Scipy:  1.0.0
Matplotlib:  2.1.2
Pystan:  2.17.1.0
ipywidgets:  7.2.1

The code was tested on Python:  3.6.4 |Anaconda, Inc.| (default, Jan 16 2018, 18:10:19) [GCC 7.2.0], Numpy:  1.14.0, Scipy:  1.0.0, Matplotlib:  2.1.2, Pystan:  2.17.1.0, ipywidgets:  7.1.1


In [2]:
# <api>
import numpy as np
import bqplot
import ipywidgets
from ipywidgets import widgets, Layout
from IPython import display
from jupyter_cms.loader import load_notebook

In [3]:
# Import the target distributions and the sampling methods defined in the
# previous parts of this series; the widget class below relies on both.
smpl_trgt, smpl_mtd = (load_notebook('./Sampling_targets.ipynb'),
                       load_notebook('./Sampling_methods.ipynb'))

## Animation widget for 2D distributions

In [4]:
#<api>
class AnimationWidget(object):
    '''Interactive bqplot/ipywidgets dashboard that animates 2D samplers.

    Samples are drawn from a target distribution with one of the methods of
    the "Sampling_methods" notebook ('MH', 'MHP', 'MSS', 'HMC', 'HMCi').
    The user can scrub/play through the chain, switch method and target,
    tune the method's parameters and redraw, all from the GUI.
    '''

    def __init__(self, target=None, method=None, cmap='Blues', rej_points=False):
        '''Create an animated widget for the given sampling method.

        Parameters
        ----------
        target : Target, optional
            Target distribution object (its ``size``, ``x``, ``y`` and ``Z``
            attributes are used for plotting). Defaults to
            ``smpl_trgt.MultNorm()``.
        method : str, optional
            Short name of the sampling method: 'MH', 'MHP', 'MSS', 'HMC' or
            'HMCi'. Defaults to 'MH'.
        cmap : str
            Colormap of the target heat map: 'Blues' or 'Rosenblues'.
        rej_points : bool
            Whether to show the proposed (rejected) points, e.g. for MHP.
            ATTENTION! It slows down the playback.

        Raises
        ------
        ValueError
            If ``method`` or ``cmap`` is not one of the supported names.
        '''
        # Fall back to the default target / method if none was given
        self.target = smpl_trgt.MultNorm() if target is None else target
        self.size = self.target.size
        self.cmap = cmap
        self.method = 'MH' if method is None else method

        # Build the sampler and set N, maxN and the title for the chosen method
        self.chose_method()

        # Produce datapoints for the line (draw the samples)
        self.data = smpl_mtd.SamplesArray(self.drawing)
        self._draw_samples()

        # Sampler summary ("" when the sampler class has no __str__ of its own)
        self.drawing_summary = self._drawing_summary()

        #### Create the bqplot canvas
        self.create_canvas()

        #### Create control widgets

        # Data point widget: how many samples of the chain are shown
        self.sliderD = widgets.IntSlider(value=1, min=1, max=self.N, step=10, description='Samples',
                                         continuous_update=True)
        self.sliderD.observe(self.on_sliderD_change, names='value')

        # Total amount of data points; takes effect on the next redraw
        self.sliderN = widgets.IntSlider(value=self.N, min=10, max=self.maxN, step=10, description='Total')
        self.sliderN.observe(self.on_sliderN_change, names='value')

        # Slider for the play step (the bigger the step, the quicker the play)
        self.sliderS = widgets.IntSlider(value=10, min=1, max=20, step=1, description='Play step',
                                         continuous_update=True)
        self.sliderS.observe(self.on_sliderS_change, names='value')

        # Play widget, linked on the front end to the data point slider
        self.play = widgets.Play(interval=50, value=1, min=1, max=self.N, step=10, description="Press play")
        self.link = widgets.jslink((self.play, 'value'), (self.sliderD, 'value'))

        # Redraw button
        self.button = widgets.Button(description="Redraw")
        self.button.on_click(self.redraw)

        # Widget for selecting the method; starts at the method actually in use
        # so that re-selecting any other entry (including 'MH') fires a change.
        self.method_select = widgets.Dropdown(
            options={'Metropolis-Hastings': 'MH',
                     'Metropolis-Hastings with proposals': 'MHP',
                     'Slice sampling': 'MSS',
                     'Hamiltonian Monte-Carlo': 'HMC',
                     'HMC with integration steps': 'HMCi'},
            value=self.method, description='Sampling method:', disabled=False)
        self.method_select.observe(self.on_method_select_change, names='value')

        # Widget for selecting the target
        # NOTE(review): always starts at 'MN', even when a custom target was passed.
        self.target_select = widgets.Dropdown(
            options={'Multivariate normal': 'MN',
                     'Bimodal normal': 'BN',
                     'Donut': 'Donut',
                     'Banana (Rosenbrock function)': 'Banana'},
            value='MN', description='Target distribution:', disabled=False)
        self.target_select.observe(self.on_target_select_change, names='value')

        # Sampler summary panel (included in the layout for some methods only)
        self.text = widgets.HTML(value=self.drawing_summary)

        # Create figure and the method-specific dashboard
        self.rej_points = rej_points
        self.create_figure(rejected_points=self.rej_points)
        self.ui = widgets.HBox([self.fig])
        self.add_spec_method_param()

    def _draw_samples(self):
        '''Draw self.N samples into self.data.

        For 'HMCi' only N / Tau draws are requested, followed by
        ``reassembly_HMCi_samples()`` on the samples array.
        '''
        if self.method == 'HMCi':
            self.data.draw(int(self.N / self.drawing.Tau))
            self.data.reassembly_HMCi_samples()
        else:
            self.data.draw(self.N)

    def _drawing_summary(self):
        '''Return str(self.drawing) if its class defines __str__, otherwise "".

        Comparing the *bound* method ``self.drawing.__str__`` with
        ``object.__str__`` is always unequal, so the class attribute has to be
        checked instead.
        '''
        if type(self.drawing).__str__ is object.__str__:
            return ""
        return str(self.drawing)

    def create_canvas(self):
        '''Create the bqplot scales, axes and marks for the current target and samples.

        Raises
        ------
        ValueError
            If self.cmap is neither 'Blues' nor 'Rosenblues'.
        '''
        self.x_sc = bqplot.LinearScale(min=-self.size, max=self.size)
        self.y_sc = bqplot.LinearScale(min=-self.size, max=self.size)

        # Choose the colormap; an unknown name would otherwise leave a missing
        # (or stale, from a previous call) colour scale behind.
        if self.cmap == 'Blues':
            self.col_sc = bqplot.ColorScale(scheme='Blues')
        elif self.cmap == 'Rosenblues':
            self.col_sc = bqplot.ColorScale(colors=['#3182bd', '#9ecae1', '#deebf7'], max=10)
        else:
            raise ValueError("Unknown colormap %r; expected 'Blues' or 'Rosenblues'" % (self.cmap,))

        self.ax_x = bqplot.Axis(scale=self.x_sc, orientation='horizontal', num_ticks=7)
        self.ax_y = bqplot.Axis(scale=self.y_sc, orientation='vertical', num_ticks=7)
        self.target_heat = bqplot.HeatMap(x=self.target.x,
                                          y=self.target.y,
                                          color=self.target.Z,
                                          scales={'color': self.col_sc, 'x': self.x_sc, 'y': self.y_sc})
        # Sample chain
        self.line = bqplot.Lines(x=self.data.getX(self.N),
                                 y=self.data.getY(self.N),
                                 colors=['red'], opacities=[0.6],
                                 scales={'x': self.x_sc, 'y': self.y_sc})
        # Proposed (rejected) points, only added to the figure if requested
        self.rejected_points = bqplot.Scatter(x=[],
                                              y=[],
                                              colors=['red'],
                                              scales={'x': self.x_sc, 'y': self.y_sc},
                                              default_opacities=[0.4],
                                              default_size=24,
                                              disabled=True)
        # Segment from the last accepted point to the last proposal
        self.prop_line = bqplot.Lines(x=[], y=[],
                                      colors=['green'], default_opacities=[0.4],
                                      scales={'x': self.x_sc, 'y': self.y_sc})
        # HMCi integration trajectory
        self.traj_line = bqplot.Lines(x=[], y=[],
                                      colors=['black'], default_opacities=[0.4],
                                      scales={'x': self.x_sc, 'y': self.y_sc})

    def create_figure(self, rejected_points=False):
        '''Create the bqplot figure; include the rejected-points mark if requested.'''
        marks = [self.target_heat, self.line]
        if rejected_points:
            marks.append(self.rejected_points)
        marks += [self.prop_line, self.traj_line]

        self.fig = bqplot.Figure(marks=marks,
                                 axes=[self.ax_x, self.ax_y],
                                 title=self.title,
                                 animation_duration=0,
                                 min_aspect_ratio=1,
                                 max_aspect_ratio=1,
                                 fig_margin={'top': 60, 'bottom': 60, 'left': 60, 'right': 60},
                                 background_style={'fill': 'white'},
                                 padding_x=0.0,
                                 padding_y=0.0)
        self.fig.layout.height = '600px'
        self.fig.layout.width = '600px'

    def add_spec_method_param(self):
        '''Create the widgets for the current method's parameters and rebuild self.ui.

        Leaves self.ui untouched for an unknown method.
        '''
        if self.method in ('MH', 'MHP'):
            self.sliderD.description = "Samples"
            self.sliderP = widgets.IntSlider(value=50, min=1, max=200, description='Proposal size')
            self.sliderP.observe(self.on_sliderP_change, names='value')
            extra = [self.sliderP, self.text]
        elif self.method == 'MSS':
            self.sliderD.description = "Samples"
            self.sliderW = widgets.FloatSlider(value=1.0, min=0.1, max=10.0, description='w')
            self.sliderW.observe(self.on_sliderW_change, names='value')
            self.toggle_button = widgets.ToggleButtons(options=['Straight', 'Skew'], description='Directions:',
                                                       tooltips=['[[1,0], [0,1]]', '[[1,1], [1,-1]]'])
            self.toggle_button.observe(self.on_toggle_button, names='value')
            extra = [self.sliderW, self.toggle_button]
        elif self.method in ('HMC', 'HMCi'):
            # For HMCi the data slider walks over integration steps, not samples
            self.sliderD.description = "Integr. steps" if self.method == 'HMCi' else "Samples"
            self.sliderTau = widgets.IntSlider(value=42, min=1, max=100, description='Tau')
            self.sliderTau.observe(self.on_sliderTau_change, names='value')
            self.sliderDtau = widgets.FloatLogSlider(value=0.04, base=10, min=-3, max=0, step=0.1,
                                                     description='dtau')
            self.sliderDtau.observe(self.on_sliderDtau_change, names='value')
            extra = [self.sliderTau, self.sliderDtau]
            if self.method == 'HMCi':
                extra.append(self.text)
        else:
            return

        common = [self.method_select, self.target_select,
                  widgets.HBox([self.button, self.play]),
                  self.sliderD, self.sliderN, self.sliderS]
        self.ui = widgets.HBox([self.fig, widgets.VBox(common + extra)])

    def chose_method(self):
        '''Create the sampler for self.method and set N, maxN and the title.

        Raises
        ------
        ValueError
            If self.method is not one of 'MH', 'MHP', 'MSS', 'HMC', 'HMCi'.
        '''
        if self.method == 'MH':
            self.drawing = smpl_mtd.MH(target=self.target, prop_step=50)
            self.N = 2000
            self.maxN = 5000
            self.title = "Metropolis-Hastings sampling"
        elif self.method == 'MHP':
            self.drawing = smpl_mtd.MHP(target=self.target, prop_step=50)
            self.N = 2000
            self.maxN = 5000
            self.title = "Metropolis-Hastings sampling with proposals"
        elif self.method == 'MSS':
            self.drawing = smpl_mtd.MSS(target=self.target, direct='Straight')
            self.N = 100
            self.maxN = 500
            self.title = "Slice sampling"
        elif self.method == 'HMC':
            self.drawing = smpl_mtd.HMC(target=self.target)
            self.N = 100
            self.maxN = 500
            self.title = "Hamiltonian Monte-Carlo sampling"
        elif self.method == 'HMCi':
            self.drawing = smpl_mtd.HMCi(target=self.target)
            self.N = 2100
            self.maxN = 5000
            self.title = "HMC with integration steps"
        else:
            raise ValueError("Unknown sampling method %r; expected one of "
                             "'MH', 'MHP', 'MSS', 'HMC', 'HMCi'" % (self.method,))

    def show(self):
        '''Display the widget.'''
        display.display(self.ui)

    def _plot_up_to(self, num):
        '''Plot the first `num` samples and, if available, the proposals.

        The proposal line joins the last accepted point with the last proposal.
        '''
        xs = self.data.getX(num)
        ys = self.data.getY(num)
        self.line.x = xs
        self.line.y = ys
        if self.data.prop:
            xs_all = self.data.getXall(num)
            ys_all = self.data.getYall(num)
            self.rejected_points.x = xs_all
            self.rejected_points.y = ys_all
            self.prop_line.x = [xs[-1], xs_all[-1]]
            self.prop_line.y = [ys[-1], ys_all[-1]]

    def redraw(self, button):
        '''Action, when the Redraw button is clicked: resample with the "Total" N.'''
        self.N = self.sliderN.value

        # Redraw the samples from a fresh start
        self.drawing.reset_start()
        self.drawing.reset_sampling()
        self._draw_samples()

        # Reset sliders parameters
        self.sliderD.max = self.N
        self.sliderD.value = self.N
        self.sliderD.step = self.sliderS.value
        self.play.max = self.N
        self.play.value = self.N
        self.play.step = self.sliderS.value

        # Refresh the summary; clear it if this sampler has none, so that a
        # summary of a previously selected method does not linger.
        self.text.value = self._drawing_summary()

        # Plot the new line and the rejected points / integration steps
        self._plot_up_to(self.N)

    def _reset_total_slider(self):
        '''Point the "Total" slider at the current method's N and maxN.

        The slider clamps its value into [min, max], and every value change
        fires on_sliderN_change, which overwrites self.N. Setting the value
        before the max clamped it to the *old* max (e.g. switching from MSS or
        HMC back to MH left N at 500). So remember the wanted values, adjust
        the range first, then set the value.
        '''
        new_N, new_maxN = self.N, self.maxN
        self.sliderN.max = new_maxN
        self.sliderN.value = new_N
        self.N = new_N

    def on_sliderD_change(self, change):
        '''Show the chain up to the slider position (plus the HMCi trajectory).'''
        num = self.sliderD.value
        self._plot_up_to(num)
        if self.data.prop and self.method == 'HMCi':
            if num <= 1:
                self.traj_line.x, self.traj_line.y = [], []
            else:
                self.traj_line.x, self.traj_line.y = self.calculate_trajectory(num)

    def on_sliderN_change(self, change):
        '''Remember the requested total number of samples (used on redraw).'''
        self.N = self.sliderN.value

    def on_sliderS_change(self, change):
        '''Apply the play step to both the data slider and the play widget.'''
        self.sliderD.step = self.sliderS.value
        self.play.step = self.sliderS.value

    def on_sliderP_change(self, change):
        '''Set the proposal step size of MH / MHP.'''
        self.drawing.prop_step = self.sliderP.value

    def on_sliderW_change(self, change):
        '''Set the slice sampling parameter w.'''
        self.drawing.w = self.sliderW.value

    def on_toggle_button(self, change):
        '''Set the slice sampling directions ('Straight' or 'Skew').'''
        self.drawing.direct = change['new']

    def on_sliderTau_change(self, change):
        '''Set the HMC parameter Tau.'''
        self.drawing.Tau = self.sliderTau.value

    def on_sliderDtau_change(self, change):
        '''Set the HMC parameter dtau.'''
        self.drawing.dtau = self.sliderDtau.value

    def on_target_select_change(self, change):
        '''Switch the target distribution (and its colormap), then rebuild everything.'''
        choice = self.target_select.value
        if choice == 'MN':
            self.target, self.cmap = smpl_trgt.MultNorm(), 'Blues'
        elif choice == 'BN':
            self.target, self.cmap = smpl_trgt.BimodMultNorm(), 'Blues'
        elif choice == 'Donut':
            self.target, self.cmap = smpl_trgt.Donut(), 'Blues'
        elif choice == 'Banana':
            self.target, self.cmap = smpl_trgt.Rosenbrock(), 'Rosenblues'
        self.size = self.target.size
        self.on_method_select_change(self.method_select.value)

    def on_method_select_change(self, change):
        '''Rebuild the canvas, sampler and dashboard for the selected method, then redraw.'''
        # Refresh widget settings
        self.method = self.method_select.value
        display.clear_output(wait=True)
        self.create_canvas()
        self.chose_method()
        self.data = smpl_mtd.SamplesArray(self.drawing)

        # Refresh the dashboard
        self.create_figure(rejected_points=self.rej_points)
        self.add_spec_method_param()
        self._reset_total_slider()

        # Clear the lines, it looks better this way
        self.line.x = np.array([])
        self.line.y = np.array([])

        self.show()
        self.redraw(self.button)

    def calculate_trajectory(self, position):
        '''Collect the HMCi integration trajectory (steps) for drawing.

        Walks backwards over the proposed points up to `position` and collects
        them until one coincides with the last accepted point (both
        coordinates equal). As before, the very first proposed point is
        never included.

        Returns
        -------
        [trajX, trajY] : lists of the collected coordinates, newest first.
        '''
        last_x = self.data.getX(position)[-1]
        last_y = self.data.getY(position)[-1]
        proposed_x = self.data.getXall(position)
        proposed_y = self.data.getYall(position)

        traj_x, traj_y = [], []
        for i in range(1, len(proposed_x)):
            px, py = proposed_x[-i], proposed_y[-i]
            # Stop only on the same *point*; a shared single coordinate is not enough
            if px == last_x and py == last_y:
                break
            traj_x.append(px)
            traj_y.append(py)
        return [traj_x, traj_y]

In [5]:
# Build the dashboard with the default target (MultNorm) and method (MH),
# without the proposal markers (they slow down the playback), and show it.
aw = AnimationWidget(target=None, method=None, cmap='Blues', rej_points=False)
aw.show()

HBox(children=(Figure(axes=[Axis(num_ticks=7, scale=LinearScale(max=3.0, min=-3.0)), Axis(num_ticks=7, orienta…