In [None]:
from aiida import load_dbenv, is_dbenv_loaded
from aiida.backends import settings
if not is_dbenv_loaded():
    load_dbenv(profile=settings.AIIDADB_PROFILE)

from aiida.orm import load_node
from aiida.orm.querybuilder import QueryBuilder
from aiida.orm.calculation.work import WorkCalculation
from aiida.orm.calculation.job import JobCalculation

import numpy as np
import scipy.constants as const
import ipywidgets as ipw
from IPython.display import display, clear_output, HTML
import re
import gzip
import matplotlib.pyplot as plt
from collections import OrderedDict
import urlparse
import io
import zipfile
import StringIO

import matplotlib.pyplot as plt

from apps.scanning_probe import common
from apps.scanning_probe import igor

In [None]:
# Colormaps selectable in the series rows and in the single-plot dropdown.
colormaps = ['gist_heat', 'seismic']

# ---- State filled in by load_pk() for the currently loaded calculation ----
# Maps a series label (e.g. "cc-stm, isov=...") to its data array.
stm_series = None

# Energy grid and its spacing, the '--fwhm' input parameter, and the
# isovalues / heights the series were computed for.
e_arr, de, fwhm = None, None, None
isovalues, heights = None, None

# Plot extent [xmin, xmax, ymin, ymax] and the x/y aspect ratio used to size figures.
extent, figure_xy_ratio = None, None

def load_pk(b):
    """Load the STM results of the workchain whose pk is entered in ``pk_select``.

    Callback of the "Load pk" button; ``b`` (the clicked button) is unused.
    Fills the module-level state (``stm_series``, ``e_arr``, ``de``, ``fwhm``,
    ``isovalues``, ``heights``, ``extent``, ``figure_xy_ratio``) from the
    ``stm.npz`` file retrieved by the 'stm' calculation. It then refreshes the
    series and single-plot widgets and enables the zip export buttons.
    Prints "Incorrect pk." and returns if the node or its 'stm' calculation
    cannot be loaded.
    """
    global stm_series
    global e_arr, de, fwhm, isovalues, heights
    global extent, figure_xy_ratio

    # Scale factor applied to the stored x/y grids: 0.529177 Angstrom per Bohr
    # (the grids are presumably stored in Bohr; the plots are labelled in Angstrom).
    bohr_to_ang = 0.529177

    try:
        workcalc = load_node(pk=pk_select.value)
        stm_calc = common.get_calc_by_label(workcalc, 'stm')
    except Exception:
        # Catch ordinary lookup errors only; a bare except would also swallow
        # KeyboardInterrupt / SystemExit.
        print("Incorrect pk.")
        return

    fwhm = float(stm_calc.inp.parameters.dict['--fwhm'])
    geom_info.value = common.get_slab_calc_info(workcalc)

    ### ----------------------------------------------------
    ### Load data
    # An .npz archive keeps its file handle open until it is closed; the
    # context manager releases it once all arrays have been read into memory.
    with np.load(stm_calc.out.retrieved.get_abs_path('stm.npz')) as loaded_data:
        isovalues = loaded_data['isovalues']
        heights = loaded_data['heights']
        e_arr = loaded_data['e_arr']
        x_arr = loaded_data['x_arr'] * bohr_to_ang
        y_arr = loaded_data['y_arr'] * bohr_to_ang

        sts_cc = loaded_data['cc_sts']
        stm_cc = loaded_data['cc_stm'].astype(np.float32)
        sts_ch = loaded_data['ch_sts']
        stm_ch = loaded_data['ch_stm']

    ### ----------------------------------------------------
    ### Create series: "cc-*" maps keyed by isovalue, "ch-*" maps keyed by
    ### height; the last axis of each map is indexed like e_arr.
    stm_series = {}

    for i_iv, iv in enumerate(isovalues):
        stm_series["cc-stm, isov=%.0e" % iv] = stm_cc[i_iv, :, :, :]
        stm_series["cc-sts, isov=%.0e" % iv] = sts_cc[i_iv, :, :, :]
    for i_h, h in enumerate(heights):
        stm_series["ch-stm, h=%.1f" % h] = stm_ch[i_h, :, :, :]
        stm_series["ch-sts, h=%.1f" % h] = sts_ch[i_h, :, :, :]

    extent = [np.min(x_arr), np.max(x_arr), np.min(y_arr), np.max(y_arr)]
    # Mean energy spacing (assumes a uniform e_arr grid); used as slider step.
    de = (np.max(e_arr) - np.min(e_arr)) / (len(e_arr)-1)

    figure_xy_ratio = (np.max(x_arr)-np.min(x_arr)) / (np.max(y_arr)-np.min(y_arr))

    setup_stm_elements()
    setup_stm_single_elements()

    disc_zip_btn.disabled = False
    cont_zip_btn.disabled = False

style = {'description_width': '50px'}
layout = {'width': '70%'}

# pk input and load button stacked on the left, calculation summary on the right.
pk_select = ipw.IntText(description='pk', value=0, layout=layout, style=style)
load_pk_btn = ipw.Button(description='Load pk', layout=layout, style=style)
geom_info = ipw.HTML()

load_pk_btn.on_click(load_pk)

display(ipw.HBox([ipw.VBox([pk_select, load_pk_btn]), geom_info]))

# Scanning tunneling microscopy

In [None]:
import matplotlib

class FormatScalarFormatter(matplotlib.ticker.ScalarFormatter):
    """ScalarFormatter that always uses one fixed printf-style tick format.

    Used by make_plot for colorbars whose values are very small or very large.

    Args:
        fformat: printf-style format applied to every tick label.
        offset: forwarded to ``ScalarFormatter`` as ``useOffset``.
        mathText: forwarded to ``ScalarFormatter`` as ``useMathText``.
    """

    def __init__(self, fformat="%1.1f", offset=True, mathText=True):
        self.fformat = fformat
        matplotlib.ticker.ScalarFormatter.__init__(self, useOffset=offset,
                                                   useMathText=mathText)

    def _set_format(self, vmin=None, vmax=None):
        # Older matplotlib calls _set_format(vmin, vmax); newer versions call
        # _set_format() without arguments, so both are optional here.
        self.format = self.fformat
        if self._useMathText:
            # ticker._mathdefault no longer exists in newer matplotlib; fall back
            # to the equivalent "\mathdefault{...}" wrapping.
            mathdefault = getattr(matplotlib.ticker, '_mathdefault', None)
            if mathdefault is not None:
                self.format = '$%s$' % mathdefault(self.format)
            else:
                self.format = r'$\mathdefault{%s}$' % self.format

def make_plot(fig, ax, data, title=None, title_size=None, center0=False, vmin=None, vmax=None, cmap='gist_heat', noadd=False):
    """Draw the 2D map ``data`` (indexed [x, y]) on ``ax`` over the global ``extent``.

    Args:
        fig: figure owning ``ax``; used to attach the colorbar.
        ax: axes to draw into.
        data: 2D array, transposed so that its first index runs along x.
        title, title_size: axes title and optional font size.
        center0: if True, use a color range symmetric around zero
            (overrides ``vmin``/``vmax``).
        vmin, vmax: explicit color limits (None = automatic).
        cmap: matplotlib colormap name.
        noadd: if True, hide ticks and skip axis labels and colorbar.
    """
    if center0:
        data_amax = np.max(np.abs(data))
        vmin, vmax = -data_amax, data_amax
    im = ax.imshow(data.T, origin='lower', cmap=cmap, interpolation='bicubic',
                   extent=extent, vmin=vmin, vmax=vmax)

    if noadd:
        ax.set_xticks([])
        ax.set_yticks([])
    else:
        ax.set_xlabel(r"x ($\AA$)")
        ax.set_ylabel(r"y ($\AA$)")
        # Keep the default formatter for moderate magnitudes; otherwise force a
        # fixed format. Use the magnitude so negative data is classified correctly.
        if 1e-3 < np.max(np.abs(data)) < 1e3:
            cb = fig.colorbar(im, ax=ax)
        else:
            cb = fig.colorbar(im, ax=ax, format=FormatScalarFormatter("%.1f"))
        cb.formatter.set_powerlimits((-2, 2))
        cb.update_ticks()
    ax.set_title(title)
    if title_size:
        ax.title.set_fontsize(title_size)
    ax.axis('scaled')

    
def make_series_plot(fig, data, voltages):
    """Plot one panel per bias in ``voltages`` side by side on the current figure.

    Each panel shows ``data`` at the energy of ``e_arr`` closest to the bias.
    """
    n_panels = len(voltages)
    for panel_no, bias in enumerate(voltages, 1):
        ax = plt.subplot(1, n_panels, panel_no)
        nearest = np.argmin(np.abs(e_arr - bias))
        make_plot(fig, ax, data[:, :, nearest], title="V=%.2f" % bias,
                  title_size=22, cmap='gist_heat', noadd=True)

# Series

In [None]:
def remove_from_tuple(tup, index):
    """Return a new tuple equal to ``tup`` without the item at ``index``.

    Negative indices count from the end; an out-of-range index raises IndexError.
    """
    items = list(tup)
    items.pop(index)
    return tuple(items)

def remove_line_row(b, elem_list, selections_vbox):
    """Remove the series row whose remove button ``b`` was clicked.

    ``elem_list`` holds one widget list per row (remove button at position 3)
    and is kept index-aligned with ``selections_vbox.children``.
    """
    row_index = [row[3] for row in elem_list].index(b)
    elem_list.pop(row_index)
    selections_vbox.children = remove_from_tuple(selections_vbox.children, row_index)

def add_selection_row(b, elem_list, selections_vbox):
    """Append a (series, colormap, normalize, remove) row to the series selection.

    Args:
        b: the clicked button (unused; None when called programmatically).
        elem_list: list of per-row widget lists; the new row's widgets are appended.
        selections_vbox: VBox showing the rows; the new row is appended to it.
    """
    if stm_series is None:
        # No pk loaded yet, so there are no series to offer in the dropdown.
        print("Load a pk before adding series rows.")
        return

    drop_full_series = ipw.Dropdown(description="series", options=sorted(stm_series.keys()),
        style = {'description_width': 'auto'})
    drop_cmap = ipw.Dropdown(description="colormap", options=colormaps,
        style = {'description_width': 'auto'})
    norm_check = ipw.Checkbox(
        value=False,
        description='normalize',
        disabled=False,
        style = {'description_width': 'auto'},
        layout=ipw.Layout(width='auto')
    )
    rm_btn = ipw.Button(description='x', layout=ipw.Layout(width='30px'))
    rm_btn.on_click(lambda b: remove_line_row(b, elem_list, selections_vbox))

    # The order matters: other code reads row[0..3] by position.
    elements = [drop_full_series, drop_cmap, norm_check, rm_btn]
    element_widths = ['210px', '210px', '120px', '35px']
    boxed_row = ipw.HBox([ipw.HBox([row_el], layout=ipw.Layout(border='0.1px solid', width=row_w))
                          for row_el, row_w in zip(elements, element_widths)])

    elem_list.append(elements)
    selections_vbox.children += (boxed_row, )

In [None]:
def setup_stm_elements():
    """Reset the series-selection widgets for newly loaded data.

    Replaces any rows left over from a previously loaded calculation (their
    dropdowns may name series that no longer exist) with one default row.
    Pre-fills the discrete voltages with the default biases that lie inside
    the energy window, and makes the continuous energy slider span ``e_arr``
    with step ``de``.
    """
    # Mutate in place: the remove-button callbacks hold references to these objects.
    del elem_list[:]
    selections_vbox.children = ()
    add_selection_row(None, elem_list, selections_vbox)

    e_min, e_max = np.min(e_arr), np.max(e_arr)

    default_voltages = [-2.0, -1.5, -1.0, -0.5, -0.1, 0.1, 0.5, 1.0, 1.5, 2.0]
    # keep only biases inside the energy limits
    voltages_text.value = " ".join(str(v) for v in default_voltages if e_min <= v <= e_max)

    # Update the bounds in an order that never makes min > max (ipywidgets
    # raises TraitError for that), e.g. when all energies exceed the old max.
    if e_min > energy_range_slider.max:
        energy_range_slider.max = e_max
        energy_range_slider.min = e_min
    else:
        energy_range_slider.min = e_min
        energy_range_slider.max = e_max
    energy_range_slider.step = de
    energy_range_slider.value = (e_min, e_max)

def make_discrete_plot():
    """Build a figure with one row per selected series and one column per bias.

    Biases are parsed from ``voltages_text``; those outside the range of
    ``e_arr`` are reported and skipped. Each panel shows the series at the
    energy closest to the bias. The rows' normalize option is not applied here.

    Returns:
        The matplotlib figure.
    """
    requested = np.array(voltages_text.value.split(), dtype=float)
    e_lo, e_hi = np.min(e_arr), np.max(e_arr)

    filtered_voltages = []
    for v in requested:
        if not (e_lo <= v <= e_hi):
            print("Voltage %.2f out of range, skipping" % v)
            continue
        filtered_voltages.append(v)

    n_rows = len(elem_list)
    n_cols = len(filtered_voltages)
    fig_y_size = 5
    fig = plt.figure(figsize=(fig_y_size * figure_xy_ratio * n_cols, fig_y_size * n_rows))

    for row, widgets in enumerate(elem_list):
        series_label = widgets[0].value
        cmap = widgets[1].value
        data = stm_series[series_label]
        for col, bias in enumerate(filtered_voltages):
            ax = plt.subplot(n_rows, n_cols, row * n_cols + col + 1)
            i_e = np.abs(e_arr - bias).argmin()
            make_plot(fig, ax, data[:, :, i_e], title="V=%.2f" % bias,
                      title_size=22, cmap=cmap, noadd=True)
    return fig
            

def plot_discrete_series(b):
    """Button callback: draw the discrete-series figure into ``discrete_output``."""
    with discrete_output:
        make_discrete_plot()
        plt.show()
        
def plot_full_series(b):
    """Button callback: show every energy in the slider window as a scrollable strip.

    For each energy index between the slider bounds, a figure is drawn with one
    panel per selected series row. The figures are placed in a new scrollable
    box that is appended to ``continuous_output``.
    """
    fig_y = 4
    # Approximate on-screen height of one panel, used to size the scroll box.
    fig_y_in_px = 0.8*fig_y*matplotlib.rcParams['figure.dpi']

    num_series = len(elem_list)

    box_layout = ipw.Layout(overflow_x='scroll',
                    border='3px solid black',
                    width='100%',
                    height='%dpx' % (fig_y_in_px*num_series + 70),
                    display='inline-flex',
                    flex_flow='column wrap',
                    align_items='flex-start')

    plot_hbox = ipw.Box(layout=box_layout)
    continuous_output.children += (plot_hbox, )

    min_e, max_e = energy_range_slider.value
    ie_1 = np.abs(e_arr - min_e).argmin()
    ie_2 = np.abs(e_arr - max_e).argmin()+1

    # Read each row's settings once. A normalized series shares one color range
    # over the whole energy window, so compute that range once per series
    # instead of once per energy.
    series_settings = []
    for widgets in elem_list:
        series_label = widgets[0].value
        data = stm_series[series_label]
        if widgets[2].value:
            vmin = np.min(data[:, :, ie_1:ie_2])
            vmax = np.max(data[:, :, ie_1:ie_2])
        else:
            vmin = vmax = None
        series_settings.append((series_label, widgets[1].value, data, vmin, vmax))

    for i_e in range(ie_1, ie_2):
        plot_out = ipw.Output()
        plot_hbox.children += (plot_out, )
        with plot_out:
            fig = plt.figure(figsize=(fig_y*figure_xy_ratio, fig_y*num_series))
            for i_ser, (series_label, cmap, data, vmin, vmax) in enumerate(series_settings):
                title = '%s, E=%.2f eV' % (series_label, e_arr[i_e])
                ax = plt.subplot(num_series, 1, i_ser+1)
                make_plot(fig, ax, data[:, :, i_e], title=title,
                          vmin=vmin, vmax=vmax, cmap=cmap, noadd=True)
            plt.show()
            
def on_full_clear(b):
    """Button callback: drop all continuous plots and wipe the discrete output."""
    continuous_output.children = tuple()
    with discrete_output:
        clear_output()
    

In [None]:
# One entry per series row: [series dropdown, colormap dropdown, normalize
# checkbox, remove button], index-aligned with selections_vbox.children.
elem_list = []
selections_vbox = ipw.VBox([])

add_row_btn = ipw.Button(description='Add series row')
add_row_btn.on_click(lambda b: add_selection_row(b, elem_list, selections_vbox))

style = {'description_width': '80px'}
layout = {'width': '40%'}

### -----------------------------------------------
### Plot discrete

disc_plot_btn = ipw.Button(description='plot discrete')
disc_plot_btn.on_click(plot_discrete_series)

# Space-separated biases; pre-filled by setup_stm_elements() after loading.
voltages_text = ipw.Text(description='voltages (V)', value='',
                        style=style, layout={'width': '80%'})

# The HBoxes have no description of their own, so they only get a layout;
# the description style belongs on the widgets they contain.
disc_plot_hbox = ipw.HBox([voltages_text, disc_plot_btn],
                         layout={'width': '60%'})

discrete_output = ipw.Output()
### -----------------------------------------------
### Plot continuous

cont_plot_btn = ipw.Button(description='plot continuous')
cont_plot_btn.on_click(plot_full_series)

# Bounds, step and value are set from the loaded energies by setup_stm_elements().
energy_range_slider = ipw.FloatRangeSlider(
    value=[0.0, 0.0],
    min=0.0,
    max=0.0,
    step=0.1,
    description='energy range',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.2f',
    style=style, layout={'width': '80%'}
)

cont_plot_hbox = ipw.HBox([energy_range_slider, cont_plot_btn],
                        layout={'width': '60%'})

continuous_output = ipw.VBox()
### -----------------------------------------------


full_clear_btn = ipw.Button(description='clear')
full_clear_btn.on_click(on_full_clear)


display(add_row_btn, selections_vbox, disc_plot_hbox, cont_plot_hbox, full_clear_btn, discrete_output, continuous_output)

## Single

In [None]:
def setup_stm_single_elements():
    """Point the single-plot widgets at the newly loaded data.

    Fills the series dropdown with the sorted series labels. Makes the voltage
    slider span ``e_arr`` with step ``de``, starting at the lowest energy.
    """
    drop_stm_series_singl.options = sorted(stm_series.keys())

    e_min, e_max = np.min(e_arr), np.max(e_arr)
    # Update the bounds in an order that never makes min > max (ipywidgets
    # raises TraitError for that), e.g. when all energies exceed the initial max of 0.
    if e_min > voltage_slider.max:
        voltage_slider.max = e_max
        voltage_slider.min = e_min
    else:
        voltage_slider.min = e_min
        voltage_slider.max = e_max
    voltage_slider.step = de
    voltage_slider.value = e_min

def make_single_plot(voltage, data_label, cmap):
    """Create a standalone figure of series ``data_label`` at bias ``voltage``.

    The map shown is the one at the energy of ``e_arr`` closest to ``voltage``,
    drawn with axis labels and a colorbar.

    Args:
        voltage: requested bias (V).
        data_label: key into ``stm_series``.
        cmap: matplotlib colormap name.

    Returns:
        The matplotlib figure.
    """
    # Two decimals, matching the voltage slider readout and the discrete-plot
    # titles; one decimal made nearby biases indistinguishable.
    title = data_label + ", v=%.2f" % voltage
    data = stm_series[data_label]
    i_e = np.abs(e_arr - voltage).argmin()
    fig_y_size = 6
    # The extra unit of width presumably leaves room for the colorbar.
    fig = plt.figure(figsize=(fig_y_size*figure_xy_ratio+1.0, fig_y_size))
    ax = plt.gca()
    make_plot(fig, ax, data[:, :, i_e], title=title, cmap=cmap)
    return fig

def plot_stm(c):
    """Button callback: redraw the single plot for the selected series and voltage."""
    series_label = drop_stm_series_singl.value
    if series_label is None:
        # Nothing loaded yet, so there is nothing to plot.
        return
    with stm_plot_out:
        clear_output()
        make_single_plot(voltage_slider.value, series_label, drop_singl_cmap.value)
        plt.show()

# Widgets for plotting one series at one voltage; the dropdown options and
# slider bounds are filled in by setup_stm_single_elements() once data is loaded.
drop_stm_series_singl = ipw.Dropdown(options=[], description="series")
drop_singl_cmap = ipw.Dropdown(options=colormaps, description="colormap")

voltage_slider = ipw.FloatSlider(
    description='voltage (V)',
    value=0.0, min=0.0, max=0.0, step=0.1,
    orientation='horizontal',
    readout=True, readout_format='.2f',
    continuous_update=False,
    disabled=False,
)

stm_plot_out = ipw.Output()

single_plot_btn = ipw.Button(description='plot')
single_plot_btn.on_click(plot_stm)

display(drop_stm_series_singl, drop_singl_cmap, voltage_slider, single_plot_btn, stm_plot_out)

# Export
Export the currently selected discrete or continuous series as a zip archive. Both exports contain every plot as a PNG together with its raw data as text. The discrete export additionally contains an overview image of all plots and the raw data in IGOR (`.itx`) format.

In [None]:
def create_zip_link(figure_method, zip_progress, html_link_out, filename):
    """Build a zip archive in memory, save it as tmp/<filename> and show a download link.

    Args:
        figure_method: callable ``(zip_file, zip_progress)`` that writes the
            archive content and advances the progress bar.
        zip_progress: progress widget handed to ``figure_method``.
        html_link_out: Output widget in which the download link is displayed.
        filename: name of the zip file inside ``tmp/``.
    """
    import os

    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED, False) as zip_file:
        figure_method(zip_file, zip_progress)

    # Pure-Python equivalent of `mkdir -p tmp`: works outside IPython and on any OS.
    if not os.path.isdir('tmp'):
        os.makedirs('tmp')

    with open(os.path.join('tmp', filename), 'wb') as f:
        f.write(zip_buffer.getvalue())

    with html_link_out:
        display(HTML('<a href="tmp/%s" target="_blank">download zip</a>' % filename))
    

def create_disc_zip_content(zip_file, zip_progress):
    """Write the discrete-series export into ``zip_file``.

    Contents:
    - ``all.png``: the full discrete figure.
    - For every selected series and every in-range bias: a PNG, the raw map as
      text (``txt/``) and the raw map as an IGOR wave (``itx/``).

    ``zip_progress`` advances by ``1/total`` per image written.
    """
    voltages = np.array(voltages_text.value.split(), dtype=float)
    # Keep only biases inside the computed energy window, as make_discrete_plot does.
    voltages = voltages[(voltages >= np.min(e_arr)) & (voltages <= np.max(e_arr))]

    # One overview image plus one image per (series, bias); always >= 1.
    total_pics = len(voltages)*len(elem_list) + 1
    progress_step = 1.0/float(total_pics)

    header = "xlim=(%.2f, %.2f), ylim=(%.2f, %.2f)" % (extent[0], extent[1],
                                                       extent[2], extent[3])

    # the total image
    imgdata = StringIO.StringIO()
    fig = make_discrete_plot()
    fig.savefig(imgdata, format='png', dpi=200, bbox_inches='tight')
    zip_file.writestr("all.png", imgdata.getvalue())
    plt.close(fig)
    zip_progress.value += progress_step

    # individuals
    for widgets in elem_list:
        series_label = widgets[0].value
        cmap = widgets[1].value
        series_name = series_label.lower().replace(" ", '_').replace("=", '')
        for i_v, bias in enumerate(voltages):
            plot_name = series_name + "_%dv%+.2f" % (i_v, bias)
            i_e = np.abs(e_arr - bias).argmin()
            data_slice = stm_series[series_label][:, :, i_e]

            imgdata = StringIO.StringIO()
            fig = make_single_plot(bias, series_label, cmap)
            fig.savefig(imgdata, format='png', dpi=200, bbox_inches='tight')
            zip_file.writestr(plot_name+".png", imgdata.getvalue())
            plt.close(fig)

            # ---------------------------------------------------
            # Add raw data to the zip
            txtdata = StringIO.StringIO()
            np.savetxt(txtdata, data_slice, header=header, fmt="%.3e")
            zip_file.writestr("txt/"+plot_name+".txt", txtdata.getvalue())

            # ---------------------------------------------------
            # Add IGOR format to zip
            igorwave = igor.Wave2d(
                    data=data_slice,
                    xmin=extent[0],
                    xmax=extent[1],
                    xlabel='x [Angstroms]',
                    ymin=extent[2],
                    ymax=extent[3],
                    ylabel='y [Angstroms]',
            )
            zip_file.writestr("itx/"+plot_name+".itx", str(igorwave))
            # ---------------------------------------------------

            zip_progress.value += progress_step
            
def create_cont_zip_content(zip_file, zip_progress):
    """Write the continuous-series export into ``zip_file``.

    For every selected series and every energy index between the bounds of
    ``energy_range_slider``, adds a PNG of the map and the raw map as text
    (``txt/``). A normalized series uses one color range over the whole window.
    ``zip_progress`` advances by ``1/total`` per image written.
    """
    fig_y = 4

    min_e, max_e = energy_range_slider.value
    ie_1 = np.abs(e_arr - min_e).argmin()
    ie_2 = np.abs(e_arr - max_e).argmin()+1

    total_pics = len(elem_list)*(ie_2-ie_1)
    # Guard against division by zero when nothing is exported.
    progress_step = 1.0/float(max(total_pics, 1))

    header = "xlim=(%.2f, %.2f), ylim=(%.2f, %.2f)" % (extent[0], extent[1],
                                                       extent[2], extent[3])

    for widgets in elem_list:
        series_label = widgets[0].value
        cmap = widgets[1].value
        normalize = widgets[2].value
        series_name = series_label.lower().replace(" ", '_').replace("=", '')
        data = stm_series[series_label]

        # The color range of a normalized series does not depend on the energy.
        if normalize:
            min_val = np.min(data[:, :, ie_1:ie_2])
            max_val = np.max(data[:, :, ie_1:ie_2])
        else:
            min_val = max_val = None

        for i_e in range(ie_1, ie_2):
            en = e_arr[i_e]
            title = '%s, E=%.2f eV' % (series_label, en)
            plot_name = "%s_%de%.2f" % (series_name, i_e-ie_1, en)

            imgdata = StringIO.StringIO()
            fig = plt.figure(figsize=(fig_y*figure_xy_ratio, fig_y))
            ax = plt.gca()
            make_plot(fig, ax, data[:, :, i_e], title=title,
                      vmin=min_val, vmax=max_val, cmap=cmap, noadd=True)
            fig.savefig(imgdata, format='png', dpi=200, bbox_inches='tight')
            zip_file.writestr(plot_name+".png", imgdata.getvalue())
            plt.close(fig)

            # ---------------------------------------------------
            # Add raw data to the zip
            txtdata = StringIO.StringIO()
            np.savetxt(txtdata, data[:, :, i_e], header=header, fmt="%.3e")
            zip_file.writestr("txt/"+plot_name+".txt", txtdata.getvalue())
            # ---------------------------------------------------

            zip_progress.value += progress_step
    

def create_disc_zip_link(b):
    """Button callback: export the discrete series as ``tmp/stm_disc_<pk>.zip``.

    The button stays disabled after a successful export ("clear tmp"
    re-enables it). If the export fails, it is re-enabled so it can be retried.
    """
    disc_zip_btn.disabled = True
    try:
        create_zip_link(create_disc_zip_content, disc_zip_progress, disc_link_out,
                        "stm_disc_%d.zip"%pk_select.value)
    except Exception:
        disc_zip_btn.disabled = False
        raise

def create_cont_zip_link(b):
    """Button callback: export the continuous series as ``tmp/stm_cont_<pk>_e<e1>_<e2>.zip``.

    The button stays disabled after a successful export ("clear tmp"
    re-enables it). If the export fails, it is re-enabled so it can be retried.
    """
    cont_zip_btn.disabled = True
    e1, e2 = energy_range_slider.value
    try:
        create_zip_link(create_cont_zip_content, cont_zip_progress, cont_link_out,
                        "stm_cont_%d_e%.1f_%.1f.zip"% (pk_select.value, e1, e2))
    except Exception:
        cont_zip_btn.disabled = False
        raise
    
def _zip_progress_bar():
    """Return a fresh 0..1 progress bar for one zip export."""
    return ipw.FloatProgress(
        value=0, min=0, max=1.0,
        description='progress:',
        bar_style='info',
        orientation='horizontal',
    )

# Discrete export; the button is enabled by load_pk once data is loaded.
disc_zip_btn = ipw.Button(description='Discrete zip', disabled=True)
disc_zip_btn.on_click(create_disc_zip_link)
disc_zip_progress = _zip_progress_bar()
disc_link_out = ipw.Output()
display(ipw.HBox([disc_zip_btn, disc_zip_progress]), disc_link_out)

# Continuous export, laid out the same way.
cont_zip_btn = ipw.Button(description='Continuous zip', disabled=True)
cont_zip_btn.on_click(create_cont_zip_link)
cont_zip_progress = _zip_progress_bar()
cont_link_out = ipw.Output()
display(ipw.HBox([cont_zip_btn, cont_zip_progress]), cont_link_out)

def clear_tmp(b):
    """Button callback: empty ``tmp/`` and reset the export widgets.

    Deletes and recreates ``tmp/``, removes the download links, resets both
    progress bars, and re-enables the export buttons if data is loaded.
    """
    import os
    import shutil

    # Pure-Python equivalent of `rm -rf tmp && mkdir tmp`; best effort like the
    # shell version, so a failed removal does not raise here.
    shutil.rmtree('tmp', ignore_errors=True)
    if not os.path.isdir('tmp'):
        os.makedirs('tmp')

    with disc_link_out:
        clear_output()
    with cont_link_out:
        clear_output()
    disc_zip_progress.value = 0.0
    cont_zip_progress.value = 0.0
    if stm_series is not None:
        disc_zip_btn.disabled = False
        cont_zip_btn.disabled = False


clear_tmp_btn = ipw.Button(description='clear tmp')
clear_tmp_btn.on_click(clear_tmp)
display(clear_tmp_btn)


In [None]:
### Load the URL after everything is set up ###
# If the notebook URL carries a "?pk=<n>" query, load that calculation right
# away. `jupyter_notebook_url` is not defined in this notebook; it is presumably
# injected by the hosting environment, so its absence is not an error.
try:
    url = urlparse.urlsplit(jupyter_notebook_url)
except NameError:
    url = None

pk_from_url = urlparse.parse_qs(url.query).get('pk') if url is not None else None
if pk_from_url:
    try:
        pk_select.value = int(pk_from_url[0])
    except ValueError:
        print("Invalid pk in URL: %s" % pk_from_url[0])
    else:
        # load_pk reports an unknown pk itself; any other error now surfaces
        # instead of being silently swallowed.
        load_pk(0)