In [1]:
from IPython.display import HTML
# Load the shared stylesheet, then stretch the notebook container to the full page width.
for css_snippet in ("<head><link rel='stylesheet' type='text/css' href='../custom.css'></head>",
                    "<style>.container { width:100% !important; }</style>"):
    display(HTML(css_snippet))

In [2]:
from bqplot import *
import bqplot as bq
import bqplot.marks as bqm
import bqplot.scales as bqs
import bqplot.axes as bqa

import ipywidgets as widgets
from ipywidgets import Layout

import matplotlib as mpl
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt

from matplotlib import rc

import numpy as np
from numpy import array, float64

In [3]:
def get_limits(figure):
    """Return the data extent of a bqplot figure as ``(max_x, min_x, max_y, min_y)``.

    If the first mark has a non-empty selection, only that mark's selected
    points are considered. Otherwise the extent covers every mark that has
    both x and y data.

    Parameters
    ----------
    figure : bqplot.Figure
        Figure with at least one mark; the first mark must have non-empty
        ``x``/``y`` data (``max``/``min`` of an empty sequence raises).

    Returns
    -------
    tuple
        ``(max_x, min_x, max_y, min_y)``.
    """
    data_mark = figure.marks[0]
    selected = data_mark.selected

    # `is None` rather than `== None`: when set, `selected` is an array, and
    # `array == None` is element-wise, which makes the `if` raise ValueError.
    if selected is not None and len(selected) > 0:
        sel_x = np.take(data_mark.x, selected)
        sel_y = np.take(data_mark.y, selected)
        return max(sel_x), min(sel_x), max(sel_y), min(sel_y)

    max_x, min_x = max(data_mark.x), min(data_mark.x)
    max_y, min_y = max(data_mark.y), min(data_mark.y)

    for mark in figure.marks:
        # Marks without data cannot contribute to the extent.
        if len(mark.x) > 0 and len(mark.y) > 0:
            max_x = max(max_x, max(mark.x))
            min_x = min(min_x, min(mark.x))
            max_y = max(max_y, max(mark.y))
            min_y = min(min_y, min(mark.y))

    return max_x, min_x, max_y, min_y

In [4]:
def _first_or_default(values, default):
    """Return ``values[0]``, or ``default`` when ``values`` is None or empty."""
    # `is None` (not `== None`): these bqplot traits may be numpy arrays, for
    # which `== None` is element-wise and cannot be used in an `if`.
    if values is None or len(values) == 0:
        return default
    return values[0]


def _points_to_draw(mark):
    """Return ``(x, y)`` of a mark, restricted to its selection when >= 3 points are selected."""
    selected = mark.selected
    if selected is None or len(selected) < 3:
        return mark.x, mark.y
    return np.take(mark.x, selected), np.take(mark.y, selected)


def _apply_bqplot_scale(ax, which, scale, v_min, v_max):
    """Set scale type and padded limits of matplotlib axis ``which`` ('x' or 'y') from a bqplot scale."""
    set_scale = ax.set_xscale if which == 'x' else ax.set_yscale
    set_lim = ax.set_xlim if which == 'x' else ax.set_ylim

    if type(scale) is bqs.LogScale:
        set_scale('log')
        # Multiplicative padding keeps the margin uniform on a log axis.
        set_lim(v_min / 1.2, v_max * 1.2)
    elif type(scale) is bqs.LinearScale:
        span = v_max - v_min
        # Pad by 10% of the span on each side.
        set_lim(v_min - span * 0.1, v_max + span * 0.1)


def bqplot_to_matplotlib(fig):
    """Render a bqplot figure as a static, LaTeX-styled matplotlib figure.

    Only ``Scatter`` and ``Lines`` marks are drawn; other mark types are
    skipped. ``fig.axes[0]`` is treated as the x axis and ``fig.axes[1]`` as
    the y axis. Axis limits come from the bqplot scales' ``min``/``max``; any
    bound left as ``None`` (auto-ranging scale) is taken from the data via
    ``get_limits``. Global matplotlib state is modified: interactive mode is
    turned off and ``text.usetex`` is enabled.

    Parameters
    ----------
    fig : bqplot.Figure

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure (also pyplot's current figure).
    """
    plt.ioff()

    axis_x, axis_y = fig.axes[0], fig.axes[1]

    plt.rc('text', usetex=True)
    mpl.rcParams['errorbar.capsize'] = 3

    mpl_fig, ax0 = plt.subplots(nrows=1, ncols=1, figsize=(16, 12))
    plt.subplots_adjust(left=0.25, bottom=0.25, right=0.9, top=None, wspace=0.0, hspace=0.0)
    ax0.tick_params(axis='both', labelsize=30, pad=10, length=12)

    x_max, x_min = axis_x.scale.max, axis_x.scale.min
    y_max, y_min = axis_y.scale.max, axis_y.scale.min
    if any(bound is None for bound in (x_max, x_min, y_max, y_min)):
        # Scale bounds are None when the scale auto-ranges; use the data extent
        # instead of failing on None arithmetic below.
        data_x_max, data_x_min, data_y_max, data_y_min = get_limits(fig)
        x_max = data_x_max if x_max is None else x_max
        x_min = data_x_min if x_min is None else x_min
        y_max = data_y_max if y_max is None else y_max
        y_min = data_y_min if y_min is None else y_min

    _apply_bqplot_scale(ax0, 'x', axis_x.scale, x_min, x_max)
    _apply_bqplot_scale(ax0, 'y', axis_y.scale, y_min, y_max)

    # Title and axis labels do not depend on the marks, so set them once.
    if fig.title.split():
        ax0.set_title(r'\textrm{%s}' % fig.title, size=35, pad=20)
    ax0.set_xlabel(r'\textrm{%s}' % axis_x.label, size=30, labelpad=15)
    ax0.set_ylabel(r'\textrm{%s}' % axis_y.label, size=30, labelpad=15)

    for mark in fig.marks:
        # Only legend-enabled marks get a label, so the legend lists exactly those.
        label_kwargs = {}
        if mark.display_legend:
            label_kwargs['label'] = r'\textrm{%s}' % _first_or_default(mark.labels, '')

        if type(mark) is bqm.Scatter:
            data_x, data_y = _points_to_draw(mark)
            ax0.scatter(data_x,
                        data_y,
                        color=mark.colors[0],
                        s=40 * mark.stroke_width,
                        alpha=_first_or_default(mark.default_opacities, 1.0),
                        **label_kwargs)

            if mark.names is not None:
                # Annotations are positioned with the full (unselected) data.
                for i, name in enumerate(mark.names):
                    ax0.annotate(name, (mark.x[i], mark.y[i]), size=30)

        elif type(mark) is bqm.Lines:
            ax0.plot(mark.x,
                     mark.y,
                     color=mark.colors[0],
                     linewidth=mark.stroke_width,
                     alpha=_first_or_default(mark.opacities, 1.0),
                     **label_kwargs)

    if any(mark.display_legend for mark in fig.marks):
        # Legend anchored just outside the axes, relative to the whole figure.
        ax0.legend(bbox_transform=mpl_fig.transFigure,
                   bbox_to_anchor=(0.90, 0.90),
                   loc='upper left',
                   ncol=1,
                   borderaxespad=1.5,
                   frameon=False,
                   fontsize=25)

    if axis_x.grid_lines != 'none':
        ax0.grid(color='grey', alpha=0.5, linewidth=0.5)

    return mpl_fig

In [5]:
def refresh_mpl_figure(a):
    """Re-render the bqplot figure into the matplotlib preview output.

    The title is passed through ``parse_text`` only for the duration of the
    conversion and is restored afterwards. Previously the parsed title was
    written back permanently, so every refresh wrapped it in one more pair
    of ``$`` delimiters.

    Parameters
    ----------
    a : ipywidgets.Button
        The clicked button (unused; required by ``Button.on_click``).
    """
    plt.close('all')
    mpl_out.clear_output()

    original_title = bqp_fig.title
    bqp_fig.title = parse_text(original_title)
    try:
        with mpl_out:
            mpl_fig = bqplot_to_matplotlib(bqp_fig)
            display(mpl_fig)
    finally:
        bqp_fig.title = original_title

In [6]:
def parse_text(string):
    """Wrap ``string`` in ``$...$`` if it contains a LaTeX special character.

    Parameters
    ----------
    string : str
        Text typed by the user (title, legend, ...).

    Returns
    -------
    str
        ``"$" + string + "$"`` if any of ``# $ % & ~ _ ^ \\ { }`` occurs in
        ``string``, otherwise ``string`` unchanged.
    """
    # The backslash entry already covers the \( \) \[ \] delimiters. Those were
    # previously written as invalid escape sequences ('\('), which Python
    # 3.12+ reports with a SyntaxWarning.
    special_chars = ('#', '$', '%', '&', '~', '_', '^', '\\', '{', '}')

    if any(char in string for char in special_chars):
        return "$" + string + "$"
    return string

In [7]:
def rename_labels(change):
    """Copy the title and axis-label text widgets onto the bqplot figure.

    ``change`` is the traitlets notification from whichever text box was
    edited; it is ignored because all three values are re-read.
    """
    x_axis, y_axis = bqp_fig.axes[0], bqp_fig.axes[1]
    bqp_fig.title = figure_title.value
    x_axis.label = label_x.value
    y_axis.label = label_y.value

In [8]:
def show_grid(change):
    """Show dashed grid lines on both axes when the grid checkbox is ticked, none otherwise."""
    grid_style = 'dashed' if show_grid_checkbox.value else 'none'

    x_axis = bqp_fig.axes[0]
    x_axis.grid_lines = grid_style

    y_axis = bqp_fig.axes[1]
    y_axis.grid_lines = grid_style

In [9]:
def change_data_name(change):
    """Rename the mark whose style tab contains the widget that fired ``change``.

    Tabs live in ``middle_block.children[0]``. Row 0 of each tab holds the
    name widget, and the tab's index equals the mark's index in
    ``bqp_fig.marks``. The new name is passed through ``parse_text``.
    Nothing is changed if the widget is not found in any tab.
    """
    obj = change.owner
    tabs = middle_block.children[0].children

    # Same tab lookup path as the other change_data_* handlers; the old code
    # indexed middle_block.children[i], which only ever has one child.
    for i, tab in enumerate(tabs):
        if tab.children[0].children[0] is obj:
            bqp_fig.marks[i].name = parse_text(obj.value)
            return

In [10]:
def change_data_style(change):
    """Swap a mark between ``Scatter`` and ``Lines`` when its style dropdown changes.

    The dropdown sits in row 1 of its style tab, and the tab's index equals
    the mark's index in ``bqp_fig.marks``. The replacement keeps the old
    mark's data, colours, name, legend label, legend visibility and stroke
    width, so the tab's other widgets still describe it. Opacity is not
    carried over. ``bqp_fig.marks`` is reassigned with a copied list.

    Parameters
    ----------
    change : traitlets change notification
        ``change.owner`` is the dropdown; its value is "Scatter" or "Lines".
    """
    obj = change.owner

    index = None
    for i, tab in enumerate(middle_block.children[0].children):
        if tab.children[1].children[0] is obj:
            index = i
            break
    if index is None:
        # The dropdown is not in any style tab: do not replace an unrelated mark.
        return

    mark = bqp_fig.marks[index]
    marks = list(bqp_fig.marks)

    # Settings shared by both mark types.
    common = dict(
        scales={'x': bqp_fig.axes[0].scale,
                'y': bqp_fig.axes[1].scale},
        colors=mark.colors,
        name=mark.name,
        labels=mark.labels,
        display_legend=mark.display_legend,
        stroke_width=mark.stroke_width,
    )

    if obj.value == "Scatter":
        marks[index] = bqm.Scatter(
            x=mark.x,
            y=mark.y,
            enable_move=False,
            restrict_x=False,
            restrict_y=False,
            selected_style={'opacity': '1'},
            unselected_style={'opacity': '0.2'},
            selected=None,
            **common
        )

    elif obj.value == "Lines":
        # Sort the points by x (ties broken by y) so the polyline runs left to right.
        points = sorted(zip(mark.x, mark.y))
        marks[index] = bqm.Lines(
            x=[x for x, _ in points],
            y=[y for _, y in points],
            visible=True,
            **common
        )

    bqp_fig.marks = marks

In [11]:
def change_data_width(change):
    """Apply a width slider's value as the stroke width of its mark.

    The slider is row 2 of its style tab; the tab's position in
    ``middle_block.children[0]`` is the mark's index in ``bqp_fig.marks``.
    NOTE: if no tab matches, the loop ends on the last tab and that mark is
    updated.
    """
    slider = change.owner
    tabs = middle_block.children[0].children

    for tab_index, tab in enumerate(tabs):
        if tab.children[2].children[0] is slider:
            break

    bqp_fig.marks[tab_index].stroke_width = slider.value

In [12]:
def change_data_opacity(change):
    """Apply an opacity slider's value to its mark.

    The slider is row 3 of its style tab. The value is written both to
    ``opacities`` (read for Lines on export) and to ``default_opacities``
    (read for Scatter on export). NOTE: if no tab matches, the last tab's
    mark is updated.
    """
    slider = change.owner
    tabs = middle_block.children[0].children

    for tab_index, tab in enumerate(tabs):
        if tab.children[3].children[0] is slider:
            break

    target_mark = bqp_fig.marks[tab_index]
    target_mark.opacities = [slider.value]
    target_mark.default_opacities = [slider.value]

In [13]:
def change_data_color(change):
    """Set the colour of the mark whose tab holds the colour picker that fired.

    The picker is row 4 of its style tab. The mark gets a single-colour list.
    NOTE: if no tab matches, the last tab's mark is updated.
    """
    picker = change.owner
    tabs = middle_block.children[0].children

    for tab_index, tab in enumerate(tabs):
        if tab.children[4].children[0] is picker:
            break

    bqp_fig.marks[tab_index].colors = [picker.value]

In [14]:
def change_data_legend(change):
    """Use the legend text box's value as the legend label of its mark.

    The text box is row 6 of its style tab. NOTE: if no tab matches, the
    last tab's mark is updated.
    """
    legend_box = change.owner
    tabs = middle_block.children[0].children

    for tab_index, tab in enumerate(tabs):
        if tab.children[6].children[0] is legend_box:
            break

    bqp_fig.marks[tab_index].labels = [legend_box.value]

In [15]:
def change_show_legend(change):
    """Toggle whether a mark appears in the legend, following its checkbox.

    The checkbox is row 7 of its style tab. NOTE: if no tab matches, the
    last tab's mark is updated.
    """
    checkbox = change.owner
    tabs = middle_block.children[0].children

    for tab_index, tab in enumerate(tabs):
        if tab.children[7].children[0] is checkbox:
            break

    bqp_fig.marks[tab_index].display_legend = checkbox.value

In [16]:
def save_PNG(a):
    """Save pyplot's current figure as a 300-dpi PNG and show a download link.

    The file name comes from ``filename_text`` (``png_file`` when blank) and
    gets a ``.png`` suffix unless it already ends with one. Progress, success
    and failure (with the error text) are reported in ``export_message``. The
    name box is cleared afterwards in every case.

    Parameters
    ----------
    a : ipywidgets.Button
        The clicked button (unused; required by ``Button.on_click``).
    """
    from html import escape

    try:
        export_message.value = "<b>Creating PNG file...</b>"

        if not plt.get_fignums():
            # savefig would otherwise create and save an empty figure.
            raise RuntimeError("no figure to export, click 'Refresh figure' first")

        save_filename = filename_text.value.strip() or 'png_file'
        # Suffix check (not a substring test) so "plot.png.old" still gets ".png".
        if not save_filename.lower().endswith('.png'):
            save_filename += '.png'

        plt.savefig(save_filename, format='png', dpi=300, bbox_inches="tight")

        # Escape the user-supplied name and quote the attribute, so spaces or
        # quotes in the name cannot break the generated HTML.
        safe_name = escape(save_filename, quote=True)
        export_message.value = ("<span style='color:green'><b>PNG file '" + safe_name
                                + "' successfully created, download it here:</b></span>"
                                + "<form action='" + safe_name + "' target='_blank'>"
                                + "<button type='submit'>Download PNG</button></form>")

    except Exception as exc:
        # Show the cause (e.g. a LaTeX error) instead of hiding it.
        export_message.value = ("<span style='color:red'><b>An error has occurred: "
                                + escape(str(exc)) + "</b></span>")
    finally:
        filename_text.value = ''

In [17]:
def save_PDF(a):
    """Save pyplot's current figure as a PDF and show a download link.

    The file name comes from ``filename_text`` (``pdf_file`` when blank) and
    gets a ``.pdf`` suffix unless it already ends with one. Progress, success
    and failure (with the error text) are reported in ``export_message``. The
    name box is cleared afterwards in every case.

    Parameters
    ----------
    a : ipywidgets.Button
        The clicked button (unused; required by ``Button.on_click``).
    """
    from html import escape

    try:
        export_message.value = "<b>Creating PDF file...</b>"

        if not plt.get_fignums():
            # savefig would otherwise create and save an empty figure.
            raise RuntimeError("no figure to export, click 'Refresh figure' first")

        save_filename = filename_text.value.strip() or 'pdf_file'
        # Suffix check (not a substring test) so "plot.pdf.old" still gets ".pdf".
        if not save_filename.lower().endswith('.pdf'):
            save_filename += '.pdf'

        plt.savefig(save_filename, format='pdf', dpi=300, bbox_inches="tight")

        # Escape the user-supplied name and quote the attribute, so spaces or
        # quotes in the name cannot break the generated HTML.
        safe_name = escape(save_filename, quote=True)
        export_message.value = ("<span style='color:green'><b>PDF file '" + safe_name
                                + "' successfully created, download it here:</b></span>"
                                + "<form action='" + safe_name + "' target='_blank'>"
                                + "<button type='submit'>Download PDF</button></form>")

    except Exception as exc:
        # Show the cause (e.g. a LaTeX error) instead of hiding it.
        export_message.value = ("<span style='color:red'><b>An error has occurred: "
                                + escape(str(exc)) + "</b></span>")
    finally:
        filename_text.value = ''

In [18]:
def clean_marks(figure):
    """Drop hidden marks and marks without x data from ``figure``.

    The figure's ``marks`` attribute is replaced in place, and the same
    figure object is returned for chaining.
    """
    figure.marks = [m for m in figure.marks if m.visible and len(m.x) > 0]
    return figure

In [3]:
#load the figure from %store
fig_str = None
%store -r fig_str
bqp_fig = clean_marks(eval(fig_str))

#defining the 'blocks' of the interface:
#   main_block = head_block (title, axis labels, grid toggle)
#              + body_block = left_block (bqplot figure) | middle_block (style tabs) | right_block (preview/export)
#              + footer_block (LaTeX hint)
def _padded_centered_layout():
    """Return a new Layout with a 20px margin on every side and centred children."""
    return widgets.Layout(margin='20px 20px 20px 20px', align_items='center')


main_block = widgets.VBox([])

head_block = widgets.VBox([], layout=_padded_centered_layout())
body_block = widgets.HBox([])
footer_block = widgets.VBox([], layout=_padded_centered_layout())

# Body column widths: figure 35%, style tabs 20%, preview/export 45%.
left_block = widgets.VBox([], layout=widgets.Layout(width='35%', align_items='center'))
middle_block = widgets.VBox([], layout=widgets.Layout(width='20%', margin='50px 0 0 0'))
right_block = widgets.VBox([], layout=widgets.Layout(width='45%', align_items='center'))


#head_block's children: figure title, axis labels and grid toggle.
# Every text box re-applies all three values to the figure via rename_labels.
figure_title = widgets.Text(
    value=bqp_fig.title,
    placeholder='',
    description='Title:',
    disabled=False
)
figure_title.observe(rename_labels, 'value')

label_x = widgets.Text(
    value=bqp_fig.axes[0].label,
    placeholder='',
    description='Axis x:',
    disabled=False
)
label_x.observe(rename_labels, 'value')

label_y = widgets.Text(
    value=bqp_fig.axes[1].label,
    placeholder='',
    description='Axis y:',
    disabled=False
)
label_y.observe(rename_labels, 'value')

# Start in sync with the loaded figure, whose axes may already draw grid lines.
# Otherwise the checkbox reads "off" while the grid is shown (and exported).
show_grid_checkbox = widgets.Checkbox(
    value=bqp_fig.axes[0].grid_lines != 'none',
    description='Show grid',
    disabled=False
)
show_grid_checkbox.observe(show_grid, 'value')

head_block.children = (figure_title,
                       widgets.HBox([label_x, label_y], layout=widgets.Layout(margin='10px 0 0 0')),
                       show_grid_checkbox
                       )

#left_block's children: the interactive bqplot figure itself
left_block.children = (bqp_fig,)

#middle_block's children: one style tab per mark.
# Row order inside each tab VBox (the change_data_* handlers find their widget
# by this row index, so keep it stable):
#   0 name, 1 style dropdown, 2 width, 3 opacity, 4 colour,
#   5 legend heading, 6 legend text, 7 show-in-legend checkbox, 8 separator


def _build_style_tab(mark):
    """Build the style-editing VBox for one mark, wired to the change_data_* handlers."""
    data_name_label = widgets.HTML(
        value="<h4>" + mark.name + "</h4>",
    )

    data_style_dropdown = widgets.Dropdown(
        options=['Scatter', 'Lines'],
        description='',
        value=type(mark).__name__,  # e.g. 'Scatter' for bqplot.marks.Scatter
    )
    data_style_dropdown.observe(change_data_style, 'value')

    width_slider = widgets.FloatSlider(
        value=mark.stroke_width,
        min=0.0,
        max=8.0,
        step=0.2,
        description='Width:',
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='.2f',
        layout=widgets.Layout(width='100%')
    )
    width_slider.observe(change_data_width, 'value')

    # Start from the mark's current opacity, taken from the same trait the export
    # reads (Scatter: default_opacities, Lines: opacities); 1.0 when unset.
    opacity_attr = 'default_opacities' if type(mark) is bqm.Scatter else 'opacities'
    current_opacities = getattr(mark, opacity_attr, None)
    if current_opacities is not None and len(current_opacities) > 0:
        initial_opacity = float(current_opacities[0])
    else:
        initial_opacity = 1.0

    opacity_slider = widgets.FloatSlider(
        value=initial_opacity,
        min=0.0,
        max=1.0,
        step=0.1,
        description='Opacity:',
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='.2f',
        layout=widgets.Layout(width='100%')
    )
    opacity_slider.observe(change_data_opacity, 'value')

    data_color = widgets.ColorPicker(
        concise=False,
        description='',
        value=mark.colors[0],
        disabled=False
    )
    data_color.observe(change_data_color, 'value')

    data_legend_label = widgets.HTML(
        value="<h5>Legend:</h5>",
    )

    labels = mark.labels
    label = labels[0] if labels is not None and len(labels) > 0 else ''
    data_legend_text = widgets.Text(value=label)
    data_legend_text.observe(change_data_legend, 'value')

    # Mirror the mark's legend visibility so the first click actually toggles it.
    show_legend_checkbox = widgets.Checkbox(
        value=bool(mark.display_legend),
        description='Show in legend',
        disabled=False
    )
    show_legend_checkbox.observe(change_show_legend, 'value')

    return widgets.VBox([
        widgets.HBox([data_name_label]),
        widgets.HBox([data_style_dropdown]),
        widgets.HBox([width_slider]),
        widgets.HBox([opacity_slider]),
        widgets.HBox([data_color]),
        widgets.HBox([data_legend_label]),
        widgets.HBox([data_legend_text]),
        widgets.HBox([show_legend_checkbox]),
        widgets.HTML(value="<hr>"),
    ])


style_tabs = widgets.Tab()

# clean_marks already applied this filter to bqp_fig, so tab i stays paired with
# bqp_fig.marks[i], which the change_data_* handlers rely on.
tab_marks = [mark for mark in bqp_fig.marks if mark.visible and len(mark.x) > 0]

# Assign all tabs at once, then title them.
style_tabs.children = tuple(_build_style_tab(mark) for mark in tab_marks)
for tab_index, mark in enumerate(tab_marks):
    style_tabs.set_title(tab_index, mark.name)

middle_block.children = (style_tabs,)
        
        
#right_block's children: matplotlib preview, refresh button, file name box and export buttons.
mpl_out = widgets.Output()

refresh_mpl_figure_button = widgets.Button(description='Refresh figure', disabled=False,
                                           button_style='', tooltip='')
refresh_mpl_figure_button.on_click(refresh_mpl_figure)

filename_text = widgets.Text(value='', placeholder="Type file's name (optional)",
                             description='', disabled=False)

export_pdf_button = widgets.Button(description='Export to PDF', disabled=False,
                                   button_style='', tooltip='Click me')
export_pdf_button.on_click(save_PDF)

export_png_button = widgets.Button(description='Export to PNG', disabled=False,
                                   button_style='', tooltip='')
export_png_button.on_click(save_PNG)

# Status line written by save_PDF / save_PNG.
export_message = widgets.HTML(value="")

export_row = widgets.HBox([export_pdf_button, export_png_button])
right_block.children = (mpl_out, refresh_mpl_figure_button, filename_text, export_row, export_message)

body_block.children = (left_block, middle_block, right_block)

#footer_block: hint that LaTeX rendering errors show up in the preview panel
latex_error_message = widgets.HTML(
    value="<b style='color: green; margin-left: 15px;'>If you see an error message in the right panel Latex hasn't been able to parse your text. Please review your title/legend/axes.</b>",
)
footer_block.children = (widgets.HBox([latex_error_message]),)

main_block.children = (head_block, body_block, footer_block)

# Last expression of the cell: display the assembled interface.
main_block

NameError: name 'clean_marks' is not defined