In [None]:
# General imports
import shutil
import urllib.parse
from pathlib import Path

import ipywidgets as ipw
import numpy as np
from IPython.display import display, clear_output

# AiiDA imports.
%load_ext aiida
%aiida
from aiida import common

# Local imports.

from surfaces_tools.widgets import series_plotter
from surfaces_tools.utils import spm
from surfaces_tools.helpers import HART_2_EV

In [None]:
%%javascript
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

In [None]:
colormaps = ['gist_heat', 'seismic']

# Energies of the currently loaded STM data (eV, as labelled on the selection
# widgets). Set by `load_pk`; read by the energy-selection helpers.
e_arr = None


def _load_stm_npz(stm_calc):
    """Read the STM series stored in the 'stm.npz' file retrieved by `stm_calc`.

    Returns
    -------
    tuple
        ``(stm_general_info, stm_series_info, stm_series_data)`` where the first
        item is the object stored under 'stm_general_info' (unwrapped from its
        0-d array) and the other two are the arrays stored in the file.
    """
    with stm_calc.outputs.retrieved.open('stm.npz', mode='rb') as handle:
        # Read from the repository handle itself rather than re-opening a path via
        # `handle.name`, and close the NpzFile when done so no file handle leaks.
        # np.load needs a seekable binary handle for .npz archives.
        # NOTE(review): allow_pickle unpickles data stored in the node; only load
        # calculations you trust.
        with np.load(handle, allow_pickle=True) as loaded_data:
            # Materialize everything while the archive is still open.
            general_info = loaded_data['stm_general_info'][()]
            series_info = loaded_data['stm_series_info']
            series_data = loaded_data['stm_series_data']
    return general_info, series_info, series_data


def load_pk(b):
    """Load the STM work chain whose pk is in `pk_select` and display its results.

    Fills `geom_info` with a summary of the input structure, prints DFT details
    into `misc_info`, hands the STM series to `series_plotter_inst` and resets the
    energy-selection widgets. Also sets the global `e_arr`.

    Parameters
    ----------
    b
        The clicked button (ignored); the URL cell calls this with 0.
    """
    global e_arr

    new_version = False
    workcalc = load_node(pk=pk_select.value)
    stm_calc = spm.get_calc_by_label(workcalc, 'stm')

    # Work chains with an 'scf_diag' step keep the DFT results on that step;
    # otherwise ("new version") they are an output of the work chain itself.
    try:
        cp2k_calc = spm.get_calc_by_label(workcalc, 'scf_diag')
    except AssertionError:
        try:
            dft_out_params = workcalc.outputs.dft_output_parameters.get_dict()
            new_version = True
        except Exception as exc:
            print("Incorrect pk.")
            print(exc)
            return

    geom_info.value = spm.get_slab_calc_info(workcalc.inputs.structure)

    # Information about the calculation.
    with misc_info:
        clear_output()

    dft_inp_params = dict(workcalc.inputs['dft_params'])
    if not new_version:
        dft_out_params = dict(cp2k_calc.outputs.output_parameters)

    with misc_info:
        if dft_inp_params['uks']:
            print(f"UKS multiplicity {dft_inp_params['multiplicity']}")
        else:
            print("RKS")

        print(f"Energy [au]: {dft_out_params['energy']:.6f}, [eV]: {dft_out_params['energy'] * HART_2_EV:.6f}")

        # The SPM parameter input is named `stm_params` on some work chains and
        # `spm_params` on others.
        try:
            spm_params = workcalc.inputs.stm_params
        except common.NotExistentAttributeError:
            spm_params = workcalc.inputs.spm_params

        # NOTE(review): assumes the last '--eval_region' entry is a one-character
        # prefix followed by the plane height in angstrom (e.g. 'z1.5') — confirm.
        extrap_plane = float(spm_params['--eval_region'][-1][1:])
        print(f"Extrap. plane [ang]: {extrap_plane:.1f}")

    ### Load data.
    stm_general_info, stm_series_info, stm_series_data = _load_stm_npz(stm_calc)

    e_arr = stm_general_info['energies']

    series_plotter_inst.add_series_collection(stm_general_info, stm_series_info, stm_series_data)

    series_plotter_inst.setup_added_collections(workcalc.pk)

    setup_selection_elements()


# Controls for choosing which work chain to load.
style = {'description_width': '50px'}
layout = {'width': '70%'}

pk_select = ipw.IntText(description='pk', value=0, layout=layout, style=style)

load_pk_btn = ipw.Button(description='Load pk', layout=layout, style=style)
load_pk_btn.on_click(load_pk)

# Summary of the loaded structure; filled in by `load_pk`.
geom_info = ipw.HTML()

display(ipw.HBox(children=[ipw.VBox(children=[pk_select, load_pk_btn]), geom_info]))

# Text output for the calculation details printed by `load_pk`.
misc_info = ipw.Output()
display(misc_info)

# Scanning tunneling microscopy

In [None]:
def selected_orbital_indexes():
    """Return indexes into `e_arr` for the energies chosen in the selection tab.

    Used by `series_plotter_inst` to decide which energies to plot.

    - Tab 0 ("Continuous selection"): every index from the energy closest to the
      lower slider value up to and including the one closest to the upper value
      (returned as a numpy array).
    - Other tab ("Discrete selection"): for each energy typed into
      `voltages_text` (separated by whitespace and/or commas), the index of the
      closest energy in `e_arr` (returned as a list). Values outside the loaded
      energy range are reported and skipped.
    """
    if tab.selected_index == 0:
        # Continuous selection.
        min_e, max_e = energy_range_slider.value
        ie_1 = np.abs(e_arr - min_e).argmin()
        ie_2 = np.abs(e_arr - max_e).argmin() + 1  # +1: include the upper end.
        return np.arange(ie_1, ie_2)

    # Discrete selection. Commas are accepted as separators too, e.g. "-1.0, 0.5".
    voltages = np.array(voltages_text.value.replace(',', ' ').split(), dtype=float)

    # Loop-invariant bounds of the loaded energy range.
    e_min, e_max = np.min(e_arr), np.max(e_arr)

    indexes = []
    for v in voltages:
        if e_min <= v <= e_max:
            indexes.append(np.abs(e_arr - v).argmin())
        else:
            print(f"Voltage {v:.2f} out of range, skipping.")

    return indexes

In [None]:
style = {'description_width': '120px'}
layout = {'width': '40%'}

# Plotter for the STM series; it asks `selected_orbital_indexes` which energies to show.
series_plotter_inst = series_plotter.SeriesPlotter(
    select_indexes_function=selected_orbital_indexes,
    zip_prepend='stm',
)

# Energy selection, either as a continuous range (bounds are set once data is loaded) ...
energy_range_slider = ipw.FloatRangeSlider(
    description='energy range (eV)',
    value=[0.0, 0.0],
    min=0.0,
    max=0.0,
    step=0.1,
    readout_format='.2f',
    continuous_update=False,
    style=style,
    layout={'width': '90%'},
)

# ... or as an explicit list of energies.
voltages_text = ipw.Text(
    description='energies (eV)',
    value='',
    style=style,
    layout={'width': '90%'},
)

# One tab page per selection mode; `selected_orbital_indexes` reads the active page.
tab = ipw.Tab(layout={'width': '60%'})
tab.children = (energy_range_slider, voltages_text)
tab.set_title(0, "Continuous selection")
tab.set_title(1, "Discrete selection")

display(
    series_plotter_inst.selector_widget,
    tab,
    series_plotter_inst.plot_btn,
    series_plotter_inst.clear_btn,
    series_plotter_inst.plot_output,
)

In [None]:
def setup_selection_elements():
    """Adapt the energy-selection widgets to the energies in `e_arr`.

    The discrete-selection text is filled with those default biases that lie
    inside the loaded energy range, and the range slider is made to span (and
    select) the whole range.
    """
    default_voltages = [-2.0, -1.5, -1.0, -0.5, -0.1, 0.1, 0.5, 1.0, 1.5, 2.0]

    e_min, e_max = float(np.min(e_arr)), float(np.max(e_arr))

    # Filter based on energy limits.
    voltages_text.value = " ".join(str(v) for v in default_voltages if e_min <= v <= e_max)

    # ipywidgets' bounded range sliders raise TraitError when `min` is set above
    # the current `max` (or `max` below `min`). Update the bounds in an order that
    # never crosses them, e.g. when the new range lies entirely above the old max.
    if e_min > energy_range_slider.max:
        energy_range_slider.max = e_max
        energy_range_slider.min = e_min
    else:
        energy_range_slider.min = e_min
        energy_range_slider.max = e_max

    # A single energy point has no spacing; keep the current step in that case.
    if len(e_arr) > 1:
        energy_range_slider.step = e_arr[1] - e_arr[0]
    energy_range_slider.value = (e_min, e_max)

# Export
Export the currently selected series to a zip file that contains the raw data in plain-text and IGOR formats.

In [None]:
# Zip export: the button next to its progress bar, followed by the `link_out`
# output area (presumably where the download link appears).
display(
    ipw.HBox(children=[series_plotter_inst.zip_btn, series_plotter_inst.zip_progress]),
    series_plotter_inst.link_out,
)

In [None]:
def clear_tmp(b):
    """Recreate an empty local 'tmp' directory and reset the zip-export widgets.

    Parameters
    ----------
    b
        The clicked button (ignored).
    """
    # Pure-Python equivalent of `rm -rf tmp && mkdir tmp`. It needs no IPython
    # shell escape, and failures raise instead of only printing shell errors.
    tmp_dir = Path('tmp')
    if tmp_dir.is_dir() and not tmp_dir.is_symlink():
        shutil.rmtree(tmp_dir)
    elif tmp_dir.exists() or tmp_dir.is_symlink():
        # Like `rm -rf`, also remove a plain file or symlink with that name.
        tmp_dir.unlink()
    tmp_dir.mkdir()

    with series_plotter_inst.link_out:
        clear_output()
    series_plotter_inst.zip_progress.value = 0.0

    # Re-enable exporting when there is a series to export.
    if series_plotter_inst.series is not None:
        series_plotter_inst.zip_btn.disabled = False


clear_tmp_btn = ipw.Button(description='clear tmp')
clear_tmp_btn.on_click(clear_tmp)
display(clear_tmp_btn)

In [None]:
# Load the URL after everything is set up.
# If the notebook URL carries a `?pk=<pk>` query, load that work chain right away.
try:
    url = urllib.parse.urlsplit(jupyter_notebook_url)
except NameError:
    # NOTE(review): `jupyter_notebook_url` is not defined in this notebook; it is
    # presumably injected by the app serving it. Without it there is no pk to load.
    url = None

url_pk = urllib.parse.parse_qs(url.query).get('pk') if url is not None else None

# Only attempt a load when a pk was given, so opening the notebook without one
# does not print a bare "'pk'" KeyError message. Real load failures are reported.
if url_pk:
    try:
        pk_select.value = url_pk[0]
        load_pk(0)
    except Exception as exc:
        print(exc)