In [None]:

%load_ext aiida
%aiida

import numpy as np

import ipywidgets as ipw
from IPython.display import display, clear_output, HTML

import urllib.parse
import io
import zipfile
import matplotlib
import matplotlib.pyplot as plt

from utils import spm

In [None]:
# Colormap names offered by every "colormap" dropdown in this notebook.
colormaps = ['afmhot', 'binary', 'gist_gray', 'gist_heat', 'seismic']

# Module-level state filled in by load_pk(); everything stays None until a pk is loaded.
current = None  # |HR-STM data| reshaped to the scan grid
fwhm = heights = voltages = heightOptions = None
extent = figure_xy_ratio = None

def load_pk(b):
    """Load the HR-STM results of the workchain whose pk is in ``pk_select``.

    Fills the module-level state (``current``, ``fwhm``, ``heights``, ``voltages``,
    ``heightOptions``, ``extent``, ``figure_xy_ratio``), updates ``geom_info``,
    re-initialises the series and single-plot widgets and enables the zip
    export buttons.

    Args:
        b: The clicked button (unused; any value works for programmatic calls).
    """
    global current
    global fwhm, heights, voltages
    global extent, figure_xy_ratio
    global heightOptions
    try:
        workcalc = load_node(pk=pk_select.value)
        hrstm_calc = spm.get_calc_by_label(workcalc, 'hrstm')
    except Exception as exc:
        # `Exception` rather than a bare except so KeyboardInterrupt/SystemExit
        # still propagate; show the reason so failures are diagnosable.
        print("Incorrect pk.", exc)
        return

    fwhm = float(hrstm_calc.inputs.parameters.dict['--fwhm_sam'])
    geom_info.value = spm.get_slab_calc_info(workcalc.inputs.structure)
    ase_geom = workcalc.inputs.structure.get_ase()

    # Load data.
    # NOTE: allow_pickle is needed for the dict payload; only load files produced
    # by trusted calculations.
    with hrstm_calc.outputs.retrieved.open('hrstm_meta.npy', mode='rb') as handle:
        meta_data = np.load(handle, allow_pickle=True).item()
    dimGrid = meta_data['dimGrid']
    lVec = meta_data['lVec']
    # Heights are reported relative to the topmost atom of the structure.
    z_top = np.max(ase_geom.get_positions()[:, 2])
    heights = [np.round(lVec[0,2]+lVec[3,2]/dimGrid[-1]*idx-z_top, 1)
               for idx in range(dimGrid[-1])]
    # Dropdown label -> index along the height axis of `current`.
    heightOptions = {"h={:}".format(height): hIdx for hIdx, height in enumerate(heights)}
    voltages = np.array(meta_data['voltages'])
    # tuple() so the concatenation works whether dimGrid is a tuple, list or ndarray.
    dimShape = tuple(dimGrid[:-1]) + (len(heights), len(voltages))
    # The data may be stored compressed (.npz, key 'arr_0') or as a plain .npy.
    try:
        with hrstm_calc.outputs.retrieved.open('hrstm.npz', mode='rb') as handle:
            current = np.abs(np.load(handle)['arr_0'].reshape(dimShape))
    except OSError:
        with hrstm_calc.outputs.retrieved.open('hrstm.npy', mode='rb') as handle:
            current = np.abs(np.load(handle).reshape(dimShape))

    # imshow extent [x0, x1, y0, y1] of the scan plane and its x/y aspect ratio.
    extent = [lVec[0,0], lVec[1,0], lVec[0,1], lVec[2,1]]
    figure_xy_ratio = (lVec[1,0]-lVec[0,0]) / (lVec[2,1]-lVec[0,1])

    setup_hrstm_elements()
    setup_hrstm_single_elements()

    disc_zip_btn.disabled = False
    cont_zip_btn.disabled = False
    

# Widget style/layout for the pk loader controls (both names are re-assigned
# later in the notebook for the series controls).
style = {'description_width': '50px'}
layout = {'width': '70%'}
    
# pk of the workchain to load; the button below runs load_pk().
pk_select = ipw.IntText(value=0, description='pk', style=style, layout=layout)

load_pk_btn = ipw.Button(description='Load pk', style=style, layout=layout)
load_pk_btn.on_click(load_pk)

# load_pk() writes spm.get_slab_calc_info(<structure>) into this widget.
geom_info = ipw.HTML()

display(ipw.HBox([ipw.VBox([pk_select, load_pk_btn]), geom_info]))

# High-resolution scanning tunneling microscopy

In [None]:


def make_plot(fig, ax, data, title=None, title_size=None, center0=False, vmin=None, vmax=None, cmap='gist_heat', noadd=False):
    """Draw the 2D map ``data`` on ``ax`` with ``imshow`` over the global ``extent``.

    ``data`` is transposed before plotting, so its first axis runs along x.

    Args:
        fig: Figure owning ``ax``; only used to attach the colorbar.
        ax: Target axes.
        data: 2D array to show.
        title: Axes title (None gives an empty title).
        title_size: Optional title font size.
        center0: If True, ignore ``vmin``/``vmax`` and use a colour range
            symmetric about zero, ``[-max|data|, +max|data|]``.
        vmin, vmax: Colour limits used when ``center0`` is False.
        cmap: Matplotlib colormap name.
        noadd: If True, hide ticks and skip axis labels and the colorbar.
    """
    if center0:
        half_range = np.max(np.abs(data))
        vmin, vmax = -half_range, half_range
    image = ax.imshow(data.T, origin='lower', cmap=cmap, interpolation='bicubic',
                      extent=extent, vmin=vmin, vmax=vmax)

    if not noadd:
        ax.set_xlabel(r"x ($\AA$)")
        ax.set_ylabel(r"y ($\AA$)")
        colorbar = fig.colorbar(image, ax=ax)
        # Scientific notation for colorbar ticks outside roughly 1e-2 .. 1e2.
        colorbar.formatter.set_powerlimits((-2, 2))
        colorbar.update_ticks()
    else:
        ax.set_xticks([])
        ax.set_yticks([])

    ax.set_title(title)
    if title_size:
        ax.title.set_fontsize(title_size)
    ax.axis('scaled')

# Series

In [None]:
def remove_from_tuple(tup, index):
    """Return a copy of ``tup`` without the item at ``index`` (negative indices allowed).

    Raises:
        IndexError: if ``index`` is out of range.
    """
    items = list(tup)
    items.pop(index)
    return tuple(items)

def remove_line_row(b, elem_list, selections_vbox):
    """Delete the series row whose remove button is ``b``.

    The row is dropped both from ``elem_list`` (rows are ``[height, cmap, remove_btn]``)
    and from the displayed ``selections_vbox.children`` at the same position.
    """
    row_idx = [row[2] for row in elem_list].index(b)
    elem_list.pop(row_idx)
    remaining = list(selections_vbox.children)
    remaining.pop(row_idx)
    selections_vbox.children = tuple(remaining)

def add_selection_row(b, elem_list, selections_vbox):
    """Append one series row: height dropdown, colormap dropdown and remove button.

    The row's widgets are appended to ``elem_list`` as ``[height, colormap, remove]``,
    and a bordered HBox showing them is appended to ``selections_vbox.children``.

    Args:
        b: Clicked button (unused).
        elem_list: List of per-row widget triples; mutated in place.
        selections_vbox: Container whose ``children`` display the rows.
    """
    # Before a pk is loaded there are no heights to offer.
    if heightOptions is None:
        print("Load a pk first.")
        return

    # Order labels by numeric height; a plain string sort would put
    # "h=10.0" before "h=3.0".
    height_labels = sorted(heightOptions, key=lambda label: heights[heightOptions[label]])

    # Series:
    drop_full_series = ipw.Dropdown(description="height", options=height_labels,
        style={'description_width': 'auto'})
    drop_cmap = ipw.Dropdown(description="colormap", options=colormaps,
        style={'description_width': 'auto'})
    rm_btn = ipw.Button(description='x', layout=ipw.Layout(width='30px'))
    rm_btn.on_click(lambda b: remove_line_row(b, elem_list, selections_vbox))

    elements = [drop_full_series, drop_cmap, rm_btn]
    element_widths = ['180px', '240px', '35px']
    boxed_row = ipw.HBox([
        ipw.HBox([row_el], layout=ipw.Layout(border='0.1px solid', width=row_w))
        for row_el, row_w in zip(elements, element_widths)
    ])

    elem_list.append(elements)
    selections_vbox.children += (boxed_row,)

In [None]:
def setup_hrstm_elements():
    """Reset the series controls for freshly loaded data.

    Replaces any existing series rows with a single new row, pre-fills the
    bias text box with the default biases that fall inside the loaded voltage
    range, and fits the energy range slider to the loaded voltages.
    """
    # Drop rows left over from a previously loaded pk: their dropdowns hold
    # height labels that may not exist in the new heightOptions.
    elem_list.clear()
    selections_vbox.children = ()
    add_selection_row(None, elem_list, selections_vbox)

    v_lo, v_hi = float(np.min(voltages)), float(np.max(voltages))

    # Default biases, filtered to the loaded energy limits.
    default_biases = [v for v in (-1.0, -0.5, -0.1, 0.1, 0.5, 1.0) if v_lo <= v <= v_hi]
    biases_text.value = " ".join(str(v) for v in default_biases)

    # ipywidgets rejects min > max at assignment time, so pick the order
    # that never makes the bounds cross transiently.
    if v_lo > energy_range_slider.max:
        energy_range_slider.max = v_hi
        energy_range_slider.min = v_lo
    else:
        energy_range_slider.min = v_lo
        energy_range_slider.max = v_hi
    # A single voltage has no spacing; keep the current step then.
    if len(voltages) > 1:
        energy_range_slider.step = float(voltages[1] - voltages[0])
    energy_range_slider.value = (v_lo, v_hi)

def make_discrete_plot():
    """Plot every series row at each requested bias voltage on one grid figure.

    Biases are parsed from ``biases_text``; values outside the loaded voltage
    range are reported and skipped. Each series row in ``elem_list`` (height +
    colormap) becomes one figure row, and each in-range bias one column, plotted
    at the nearest computed voltage.

    Returns:
        The matplotlib Figure holding the grid.
    """
    biases = np.array(biases_text.value.split(), dtype=float)
    v_min, v_max = np.min(voltages), np.max(voltages)
    filtered_biases = []
    for v in biases:
        if v_min <= v <= v_max:
            filtered_biases.append(v)
        else:
            print("Voltage %.2f out of range, skipping" % v)

    n_rows = len(elem_list)
    n_cols = len(filtered_biases)
    fig_y_size = 5
    # Never request a zero-sized figure when no rows or biases are selected.
    fig = plt.figure(figsize=(fig_y_size*figure_xy_ratio*max(n_cols, 1), fig_y_size*max(n_rows, 1)))
    for i_ser in range(n_rows):
        # heightOptions maps the dropdown label to the height index.
        hIdx = heightOptions[elem_list[i_ser][0].value]
        cmap = elem_list[i_ser][1].value
        data = current[:, :, hIdx]
        # Iterate over the in-range biases only, so the grid matches the figure size.
        for biasIdx, bias in enumerate(filtered_biases):
            ax = fig.add_subplot(n_rows, n_cols, i_ser*n_cols + biasIdx + 1)
            vIdx = np.argmin(np.abs(voltages - bias))
            # Title shows the voltage actually plotted (nearest grid point).
            make_plot(fig, ax, data[:, :, vIdx],
                      title='h=%.1f Ang, E=%.2f eV' % (heights[hIdx], voltages[vIdx]),
                      title_size=22, cmap=cmap, noadd=True)
    return fig
            

def plot_discrete_series(b):
    """Button callback: draw the discrete-series grid figure into ``discrete_output``."""
    with discrete_output:
        make_discrete_plot()
        plt.show()
        

def plot_full_series(b):
    """Button callback: render one column figure per voltage in the selected energy range.

    A horizontally scrolling box is appended to ``continuous_output``. Each grid
    voltage between the slider endpoints (nearest grid points, inclusive) gets
    its own Output widget containing one subplot per series row.
    """
    fig_y = 4
    n_series = len(elem_list)
    row_px = 0.8*fig_y*matplotlib.rcParams['figure.dpi']

    plot_hbox = ipw.Box(layout=ipw.Layout(
        overflow_x='scroll',
        border='3px solid black',
        width='100%',
        height='%dpx' % (row_px*n_series + 70),
        display='inline-flex',
        flex_flow='column wrap',
        align_items='flex-start',
    ))
    continuous_output.children += (plot_hbox,)

    e_lo, e_hi = energy_range_slider.value
    first = np.abs(voltages - e_lo).argmin()
    stop = np.abs(voltages - e_hi).argmin() + 1

    for i_e in range(first, stop):
        column_out = ipw.Output()
        plot_hbox.children += (column_out,)
        with column_out:
            fig = plt.figure(figsize=(fig_y*figure_xy_ratio, fig_y*n_series))
            for row_idx, series in enumerate(elem_list):
                hIdx = heightOptions[series[0].value]
                series_cmap = series[1].value
                label = 'h=%.1f Ang, E=%.2f eV' % (heights[hIdx], voltages[i_e])
                height_slice = current[:, :, hIdx]
                ax = plt.subplot(n_series, 1, row_idx + 1)
                make_plot(fig, ax, height_slice[:, :, i_e], title=label, cmap=series_cmap, noadd=True)
            plt.show()
            
def on_full_clear(b):
    """Button callback: wipe the discrete output and remove all continuous-series plots."""
    with discrete_output:
        clear_output()
    continuous_output.children = ()
    

In [None]:
# Series rows: each entry is [height dropdown, colormap dropdown, remove button]
# (built by add_selection_row), shown as rows inside selections_vbox.
elem_list = []
selections_vbox = ipw.VBox([])

add_row_btn = ipw.Button(description='Add series row')
add_row_btn.on_click(lambda b: add_selection_row(b, elem_list, selections_vbox))

# Re-assigns the module-level style/layout used by the pk loader cell.
# NOTE(review): `layout` is not referenced in this cell (inline dicts are used).
style = {'description_width': '80px'}
layout = {'width': '40%'}

# Plot discrete.
disc_plot_btn = ipw.Button(description='plot discrete')
disc_plot_btn.on_click(plot_discrete_series)

# Space-separated bias list; pre-filled by setup_hrstm_elements() on load.
biases_text = ipw.Text(description='voltages (V)', value='',
                        style=style, layout={'width': '80%'})
disc_plot_hbox = ipw.HBox([biases_text, disc_plot_btn],
                         style=style, layout={'width': '60%'})
discrete_output = ipw.Output()


# Plot continuous.
cont_plot_btn = ipw.Button(description='plot continuous')
cont_plot_btn.on_click(plot_full_series)

# Placeholder 0..0 range; setup_hrstm_elements() sets min/max/step/value on load.
energy_range_slider = ipw.FloatRangeSlider(
    value=[0.0, 0.0],
    min=0.0,
    max=0.0,
    step=0.1,
    description='energy range',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
    style=style, layout={'width': '80%'}
)

cont_plot_hbox = ipw.HBox([energy_range_slider, cont_plot_btn],
                        style=style, layout={'width': '60%'})

# plot_full_series() appends one scrolling box per click; on_full_clear() empties it.
continuous_output = ipw.VBox()
full_clear_btn = ipw.Button(description='clear')
full_clear_btn.on_click(on_full_clear)


display(add_row_btn, selections_vbox, disc_plot_hbox, cont_plot_hbox, full_clear_btn, discrete_output, continuous_output)

## Single

In [None]:
def setup_hrstm_single_elements():
    """Fill the single-plot widgets from the loaded data.

    Sets the height dropdown options (ordered by numeric height) and fits the
    bias slider to the loaded voltages, starting at the lowest voltage.
    """
    # Order labels by numeric height; a plain string sort would put
    # "h=10.0" before "h=3.0".
    drop_hrstm_height_singl.options = sorted(
        heightOptions, key=lambda label: heights[heightOptions[label]])

    v_lo, v_hi = float(np.min(voltages)), float(np.max(voltages))
    # ipywidgets rejects min > max at assignment time, so pick the order
    # that never makes the bounds cross transiently.
    if v_lo > bias_slider.max:
        bias_slider.max = v_hi
        bias_slider.min = v_lo
    else:
        bias_slider.min = v_lo
        bias_slider.max = v_hi
    # A single voltage has no spacing; keep the current step then.
    if len(voltages) > 1:
        bias_slider.step = float(voltages[1] - voltages[0])
    bias_slider.value = v_lo

def make_single_plot(voltage, height, cmap):
    """Plot one HR-STM map with a colorbar and axis labels.

    Args:
        voltage: Requested bias; the nearest computed voltage is plotted.
        height: Height dropdown label (a key of ``heightOptions``).
        cmap: Matplotlib colormap name.

    Returns:
        The new matplotlib Figure.
    """
    vIdx = np.abs(voltages - voltage).argmin()
    # Label with the voltage actually plotted, at the slider's 2-decimal
    # precision, so neighbouring grid voltages get distinct titles.
    title = height + ", v=%.2f" % voltages[vIdx]
    data = current[:, :, heightOptions[height]]
    fig_y_size = 6
    # +1.0 inch of width leaves room for the colorbar.
    fig = plt.figure(figsize=(fig_y_size*figure_xy_ratio + 1.0, fig_y_size))
    ax = fig.add_subplot(111)
    make_plot(fig, ax, data[:, :, vIdx], title=title, cmap=cmap)
    return fig

def plot_hrstm(c):
    """Button callback: redraw the single map for the selected height, colormap and voltage.

    Does nothing while no height is selected (the dropdown starts with no
    options until a pk is loaded).
    """
    if drop_hrstm_height_singl.value is None:
        return
    with hrstm_plot_out:
        clear_output()
        make_single_plot(bias_slider.value, drop_hrstm_height_singl.value, drop_singl_cmap.value)
        plt.show()

# Single-plot controls. Height options and the slider bounds are placeholders
# until setup_hrstm_single_elements() fills them in on load.
drop_hrstm_height_singl = ipw.Dropdown(description="heights", options=[])
drop_singl_cmap = ipw.Dropdown(description="colormap", options=colormaps)

bias_slider = ipw.FloatSlider(
    value=0.0,
    min=0.0,
    max=0.0,
    step=0.1,
    description='voltage (V)',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
)

single_plot_btn = ipw.Button(description='plot')
single_plot_btn.on_click(plot_hrstm)

# plot_hrstm() clears and redraws this output on every click.
hrstm_plot_out = ipw.Output()

display(drop_hrstm_height_singl, drop_singl_cmap, bias_slider, single_plot_btn, hrstm_plot_out)

# Export
Export the currently selected discrete or continuous series as a zip archive of PNG images.

In [None]:
def create_zip_link(figure_method, zip_progress, html_link_out, filename):
    """Build a zip in memory, save it as ``tmp/<filename>`` and show a download link.

    Args:
        figure_method: Callable ``(zip_file, zip_progress)`` that writes the entries.
        zip_progress: Progress widget passed through to ``figure_method``.
        html_link_out: Output widget that receives the download link.
        filename: Name of the zip file inside ``tmp/``.
    """
    import os

    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED, False) as zip_file:
        figure_method(zip_file, zip_progress)

    # Pure-Python instead of an IPython `!mkdir` escape: portable, and it
    # raises instead of failing silently.
    os.makedirs('tmp', exist_ok=True)
    with open(os.path.join('tmp', filename), 'wb') as f:
        f.write(zip_buffer.getvalue())

    with html_link_out:
        display(HTML('<a href="tmp/%s" target="_blank">download zip</a>' % filename))
    

def create_disc_zip_content(zip_file, zip_progress):
    """Write the discrete-series PNGs into ``zip_file``.

    Writes ``all.png`` (the full grid from make_discrete_plot) plus one PNG per
    (series row, in-range bias), advancing ``zip_progress`` from 0 to 1.

    Args:
        zip_file: Open ``zipfile.ZipFile`` to write into.
        zip_progress: Progress widget with range 0..1.
    """
    biases = np.array(biases_text.value.split(), dtype=float)
    # Export only biases inside the computed voltage window.
    biases = biases[(biases >= np.min(voltages)) & (biases <= np.max(voltages))]

    total_pics = len(biases)*len(elem_list) + 1
    # One equal increment per image, so the bar ends at 1.0.
    step = 1.0/float(total_pics)

    # The total image.
    imgdata = io.BytesIO()
    fig = make_discrete_plot()
    fig.savefig(imgdata, format='png', dpi=200, bbox_inches='tight')
    zip_file.writestr("all.png", imgdata.getvalue())
    plt.close(fig)
    zip_progress.value += step

    # Individuals.
    for i_s in range(len(elem_list)):
        height = elem_list[i_s][0].value
        cmap = elem_list[i_s][1].value
        series_name = "hrstm_" + height
        for i_v, bias in enumerate(biases):
            plot_name = series_name + "_%dv%+.2f" % (i_v, bias)
            imgdata = io.BytesIO()
            fig = make_single_plot(bias, height, cmap)
            fig.savefig(imgdata, format='png', dpi=200, bbox_inches='tight')
            zip_file.writestr(plot_name+".png", imgdata.getvalue())
            plt.close(fig)
            zip_progress.value += step
            
def create_cont_zip_content(zip_file, zip_progress):
    """Write one PNG per (series row, grid voltage in the selected energy range).

    Args:
        zip_file: Open ``zipfile.ZipFile`` to write into.
        zip_progress: Progress widget with range 0..1; advanced once per image.
    """
    fig_y = 4

    min_e, max_e = energy_range_slider.value
    ie_1 = np.abs(voltages - min_e).argmin()
    ie_2 = np.abs(voltages - max_e).argmin()+1

    total_pics = len(elem_list)*(ie_2-ie_1)
    # One equal increment per image, so the bar ends at 1.0; max() avoids
    # dividing by zero when nothing is selected.
    step = 1.0/float(max(total_pics, 1))

    for i_ser in range(len(elem_list)):

        height = elem_list[i_ser][0].value
        cmap = elem_list[i_ser][1].value
        series_name = "hrstm_" + height
        # Same height slice for every voltage of this series.
        data = current[:,:,heightOptions[height]]

        for i_e in range(ie_1, ie_2):
            en = voltages[i_e]
            title = '%s, E=%.2f eV'%(height, en)
            plot_name = "%s_%de%.2f" % (series_name, i_e-ie_1, en)
            imgdata = io.BytesIO()
            fig = plt.figure(figsize=(fig_y*figure_xy_ratio, fig_y))
            ax = fig.add_subplot(111)
            make_plot(fig, ax, data[:, :, i_e], title=title, cmap=cmap, noadd=True)
            fig.savefig(imgdata, format='png', dpi=200, bbox_inches='tight')
            zip_file.writestr(plot_name+".png", imgdata.getvalue())
            plt.close(fig)
            zip_progress.value += step
    

def create_disc_zip_link(b):
    """Button callback: disable the button, then build the discrete zip and its link."""
    disc_zip_btn.disabled = True
    zip_name = "hrstm_disc_%d.zip" % pk_select.value
    create_zip_link(create_disc_zip_content, disc_zip_progress, disc_link_out, zip_name)

def create_cont_zip_link(b):
    """Button callback: disable the button, then build the continuous zip and its link.

    The file name records the pk and the energy range selected on the slider.
    """
    cont_zip_btn.disabled = True
    e_lo, e_hi = energy_range_slider.value
    zip_name = "hrstm_cont_%d_e%.1f_%.1f.zip" % (pk_select.value, e_lo, e_hi)
    create_zip_link(create_cont_zip_content, cont_zip_progress, cont_link_out, zip_name)
    
# Export controls: a button, a 0..1 progress bar and a link output for each
# series type. Buttons start disabled; load_pk() enables them once data is loaded.
disc_zip_btn = ipw.Button(description='Discrete zip', disabled=True)
disc_zip_btn.on_click(create_disc_zip_link)

disc_zip_progress = ipw.FloatProgress(
        value=0,
        min=0,
        max=1.0,
        description='progress:',
        bar_style='info',
        orientation='horizontal'
    )

disc_link_out = ipw.Output()
display(ipw.HBox([disc_zip_btn, disc_zip_progress]), disc_link_out)

cont_zip_btn = ipw.Button(description='Continuous zip', disabled=True)
cont_zip_btn.on_click(create_cont_zip_link)

cont_zip_progress = ipw.FloatProgress(
        value=0,
        min=0,
        max=1.0,
        description='progress:',
        bar_style='info',
        orientation='horizontal'
    )

cont_link_out = ipw.Output()
display(ipw.HBox([cont_zip_btn, cont_zip_progress]), cont_link_out)

def clear_tmp(b):
    """Button callback: empty ``tmp/`` and reset the export controls.

    Recreates ``tmp/`` empty, clears both download-link outputs, resets both
    progress bars to 0 and re-enables the zip buttons if data is loaded.
    """
    import os
    import shutil

    # Pure-Python equivalent of `rm -rf tmp && mkdir tmp` (portable, no shell).
    shutil.rmtree('tmp', ignore_errors=True)
    os.makedirs('tmp', exist_ok=True)
    with disc_link_out:
        clear_output()
    with cont_link_out:
        clear_output()
    disc_zip_progress.value = 0.0
    cont_zip_progress.value = 0.0
    if current is not None:
        disc_zip_btn.disabled = False
        cont_zip_btn.disabled = False
    
# Button that runs clear_tmp(): empties tmp/ and resets the export controls.
clear_tmp_btn = ipw.Button(description='clear tmp')
clear_tmp_btn.on_click(clear_tmp)
display(clear_tmp_btn)


In [None]:
# Load the URL after everything is set up.
# If the notebook URL carries a `?pk=<int>` query parameter, preload that pk.
try:
    url = urllib.parse.urlsplit(jupyter_notebook_url)
    url_pk = int(urllib.parse.parse_qs(url.query)['pk'][0])
except (NameError, KeyError, ValueError):
    # `jupyter_notebook_url` not defined in this kernel, no `pk` parameter,
    # or a non-integer pk: nothing to preload.
    url_pk = None

# Kept outside the try so genuine loading errors are shown, not swallowed.
if url_pk is not None:
    pk_select.value = url_pk
    load_pk(0)