In [None]:
# default_exp jupyter_plot

# jupyter_plot

> Create real-time plots in Jupyter Notebooks.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
# exports
import numbers

import IPython
import matplotlib.pyplot as plt

try:
    from lrcurve.plot_learning_curve import PlotLearningCurve
except:
    from lrcurve.plot_learning_curve import PlotLearningCurve

class ProgressPlot(PlotLearningCurve):
    """Real-time progress plots for Jupyter notebooks.

    Thin wrapper around lrcurve's ``PlotLearningCurve``: it builds the line,
    facet and x-axis configuration from plain lists and accepts updates as a
    nested dict, a nested list or a single number.

    Args:
        plot_names: Names of the plots; each one becomes a facet.
        line_names: Names of the lines drawn in every plot.
        line_colors: Line colors. Defaults to matplotlib's color cycle
            (``axes.prop_cycle``). If fewer colors than lines are given,
            the colors are cycled.
        x_lim: ``[min, max]`` x-axis limits forwarded to lrcurve. The default
            ``None`` entries leave the bounds unset (presumably auto-scaled).
        y_lim: ``[min, max]`` y-axis limits, applied to every plot.
        x_label: Label of the x-axis.
        x_iterator: If True, call ``update(y)`` and an internal counter
            (starting at 0, +1 per update) supplies x. If False, call
            ``update(x, y)`` with an explicit x.
        height: Plot height forwarded to ``PlotLearningCurve``.
        width: Plot width forwarded to ``PlotLearningCurve``.
        display_fn: Display function forwarded to ``PlotLearningCurve``.
        debug: Debug flag forwarded to ``PlotLearningCurve``.

    The ``y`` argument of ``update`` may be:
        * a dict ``{plot_name: {line_name: value}}`` (passed through as-is),
        * a list/tuple holding one list/tuple of line values per plot, in the
          order of ``plot_names`` and ``line_names``,
        * a single real number, when there is exactly one plot and one line.
    Any other type raises ``ValueError``.
    """

    def __init__(self,
                 plot_names=('plot',),
                 line_names=('line-1',),
                 line_colors=None,
                 x_lim=(None, None),
                 y_lim=(None, None),
                 x_label='iteration',
                 x_iterator=True,
                 height=None,
                 width=600,
                 display_fn=IPython.display.display,
                 debug=False
                 ):
        self.width = width
        self.height = height
        self.display_fn = display_fn
        self.debug = debug

        self._plot_is_setup = False
        # Store private copies. The defaults are immutable tuples, and copying
        # also keeps later edits to the caller's lists out of this plot.
        self._plots = list(plot_names)
        self.line_names = list(line_names)
        self.x_lim = self._copy_limit(x_lim)
        self.y_lim = self._copy_limit(y_lim)
        self.x_label = x_label

        # x value used by the next update() when x_iterator is True.
        self.iterator = 0

        # One color per line, cycling through the given colors or, if none
        # were given, through matplotlib's default color cycle.
        palette = list(line_colors) if line_colors else plt.rcParams['axes.prop_cycle'].by_key()['color']
        self.line_colors = [palette[i % len(palette)] for i in range(len(self.line_names))]

        # Bind the public update() to the variant matching the x mode.
        self.update = self._update_with_iter if x_iterator else self._update_with_x

        self._setup_plot()

    @staticmethod
    def _copy_limit(limit):
        """Return ``limit`` as a new list, or ``None`` unchanged."""
        return None if limit is None else list(limit)

    def _update_with_iter(self, y):
        """Add ``y`` at x = current counter value, then advance the counter."""
        self._update_with_x(self.iterator, y)
        self.iterator += 1

    def _update_with_x(self, x, y):
        """Append the values ``y`` at position ``x`` and call ``draw()``.

        Raises:
            ValueError: if ``y`` has an unsupported type or shape.
        """
        y = self._parse_y(y)
        self.append(x, y)
        self.draw()

    def _parse_y(self, y):
        """Normalize an update value to a ``{plot: {line: value}}`` dict.

        Raises:
            ValueError: if ``y`` is not a dict, list/tuple or real number,
                or does not match the configured plots/lines.
        """
        if isinstance(y, dict):
            return y
        if isinstance(y, (list, tuple)):
            return self._y_list_to_dict(y)
        # numbers.Real also covers NumPy scalars such as np.float32/np.int64,
        # which do not subclass the built-in int/float. Convert those to a
        # plain Python float so lrcurve receives the same kind of value it
        # gets for int/float inputs.
        if isinstance(y, numbers.Real):
            return self._y_scalar_to_dict(y if isinstance(y, (int, float)) else float(y))
        raise ValueError('Not supported data type for update. Should be one of dict/list/float.')

    def _y_list_to_dict(self, y):
        """Convert ``[[line values of plot 1], [line values of plot 2], ...]``.

        Raises:
            ValueError: if the number of plots or lines does not match, or an
                inner element is not a list/tuple.
        """
        if len(y) != len(self._plots):
            raise ValueError('Number of plot updates not equal to number of plots!')
        if not all(isinstance(yi, (list, tuple)) for yi in y):
            raise ValueError('Line updates not of type list!')
        if not all(len(yi) == len(self.line_names) for yi in y):
            raise ValueError('Number of line update values not equal to number of lines!')

        return {plot: dict(zip(self.line_names, y_i)) for plot, y_i in zip(self._plots, y)}

    def _y_scalar_to_dict(self, y):
        """Wrap a single number for the one-plot, one-line case.

        Raises:
            ValueError: if more than one plot or line is configured.
        """
        if not (len(self._plots) == 1 and len(self.line_names) == 1):
            raise ValueError('Can only update with int/float with one plot and one line.')

        return {self._plots[0]: {self.line_names[0]: y}}

    def _setup_plot(self):
        """Build the lrcurve configs and initialize ``PlotLearningCurve``."""
        line_config = {name: {'name': name, 'color': color}
                       for name, color in zip(self.line_names, self.line_colors)}
        # Each facet gets its own copy of the y-limits so the facets do not
        # share (and cannot mutate) a single list object.
        facet_config = {name: {'name': name, 'limit': self._copy_limit(self.y_lim)}
                        for name in self._plots}
        xaxis_config = {'name': self.x_label, 'limit': self._copy_limit(self.x_lim)}

        super().__init__(height=self.height, width=self.width,
                         line_config=line_config,
                         facet_config=facet_config,
                         xaxis_config=xaxis_config,
                         display_fn=self.display_fn,
                         debug=self.debug)
    

In [None]:
import numpy as np

In [None]:
# One plot with one line; x comes from the built-in iteration counter.
pp = ProgressPlot()
for step in range(1000):
    pp.update(np.log10(step + 1))
pp.finalize()

In [None]:
# Four lines in a single plot: the outer list holds one inner list per plot.
pp = ProgressPlot(line_names=['lin', 'log', 'cos', 'sin'], x_lim=[0, 1000], y_lim=[-1, 4])
for step in range(1000):
    line_values = [step / 250, np.log10(step + 1), np.cos(step / 100), np.sin(step / 100)]
    pp.update([line_values])
pp.finalize()

In [None]:
# Two plots with a single line each: one single-element list per plot.
pp = ProgressPlot(plot_names=['cos', 'sin'], line_names=['data'], x_lim=[0, 1000], y_lim=[-1, 1])
for step in range(1000):
    angle = step / 100
    pp.update([[np.cos(angle)], [np.sin(angle)]])
pp.finalize()

In [None]:
# With x_iterator=False every update supplies its own x value.
pp = ProgressPlot(x_iterator=False)
for step in range(1000):
    pp.update(10 * step, step / 100)
pp.finalize()