In [None]:
# General imports
import numpy as np
import ipywidgets as ipw
from IPython.display import display, clear_output, HTML
import re
import urllib.parse
import zipfile

# AiiDA imports.
%load_ext aiida
%aiida
from aiida import common

# Local imports.
from surfaces_tools.utils import spm
from surfaces_tools.widgets import series_plotter
from surfaces_tools.helpers import HART_2_EV

In [None]:
%%javascript
// Never switch long cell outputs into a scrolled box: the classic-notebook
// OutputArea asks _should_scroll whether to do so, and this override always
// answers false, so widget/plot outputs stay fully expanded.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

In [None]:
local_ref_index = None
cp2k_calc = None
orb_calc = None
# Give every global that load_pk assigns a defined value before the first load.
workcalc = None
new_version = False

# Ionization potential line in the scheduler stdout of the 'orb' calculation.
# Raw string: "\(" and "\d" are not valid escapes in a normal string literal.
# NOTE(review): "POTENIAL" (sic) is kept verbatim; presumably it mirrors the
# spelling printed by the orbital script - verify before "fixing" it.
IONIZATION_POTENTIAL_RE = re.compile(r"IONIZATION POTENIAL \(eV\): ([\d.]+)")


def _print_calc_summary(dft_inp_params, dft_out_params, spm_params):
    """Print spin treatment, charge, electron counts, total energy and tip info.

    load_pk calls this while output is redirected to the ``misc_info`` widget.
    """
    if dft_inp_params['uks']:
        print("UKS multiplicity %d" % dft_inp_params['multiplicity'])
    else:
        print("RKS")

    print('Charge %d' % dft_inp_params.get('charge', 0))

    if 'init_nel_spin1' in dft_out_params:
        print("Number of alpha (s0) electrons: %d" % dft_out_params['init_nel_spin1'])
        print("Number of beta (s1) electrons:  %d" % dft_out_params['init_nel_spin2'])

    print("Energy [au]: %.6f" % (dft_out_params['energy']))
    print("Energy [eV]: %.6f" % (dft_out_params['energy'] * HART_2_EV))

    if '--p_tip_ratios' in dict(spm_params):
        print("Tip p-wave contrib: %.2f" % spm_params['--p_tip_ratios'])


def _read_ionization_potential(calc):
    """Return the ionization potential [eV] from ``calc``'s scheduler stdout, or None.

    Only the first matching line is used.
    """
    with calc.outputs.retrieved.open('_scheduler-stdout.txt') as std_out_file:
        match = IONIZATION_POTENTIAL_RE.search(std_out_file.read())
    return float(match.group(1)) if match else None


def _add_orbital_series(calc):
    """Register the orbital series stored in ``calc``'s ``orb.npz`` with the plotter.

    The spin-0 collection is always added; the spin-1 collection is added when
    present in the file.

    Returns:
        (s0_orb_general_info, ref_index): the spin-0 general-info object and the
        reference orbital index - the spin-0 'homo', or the truncated mean of
        both spins' 'homo' values when spin-1 data exists.
    """
    with calc.outputs.retrieved.open('orb.npz', mode='rb') as npz_handle:
        # NOTE(review): allow_pickle=True executes pickled objects from the file;
        # acceptable only because the npz is produced by our own calculation.
        # The NpzFile is used as a context manager so its file handle is closed;
        # everything needed is read (into memory) inside this block.
        with np.load(npz_handle.name, allow_pickle=True) as loaded_data:
            s0_orb_general_info = loaded_data['s0_orb_general_info'][()]
            series_plotter_inst.add_series_collection(
                s0_orb_general_info,
                loaded_data['s0_orb_series_info'],
                loaded_data['s0_orb_series_data'])

            ref_index = s0_orb_general_info['homo']

            if 's1_orb_general_info' in loaded_data.files:
                s1_orb_general_info = loaded_data['s1_orb_general_info'][()]
                series_plotter_inst.add_series_collection(
                    s1_orb_general_info,
                    loaded_data['s1_orb_series_info'],
                    loaded_data['s1_orb_series_data'])
                ref_index = int(0.5 * (ref_index + s1_orb_general_info['homo']))

    return s0_orb_general_info, ref_index


def load_pk(b):
    """Load the orbital workchain whose pk is in ``pk_select`` and fill the widgets.

    Sets the module globals ``workcalc``, ``orb_calc``, ``new_version``,
    ``local_ref_index`` and (older workchains only) ``cp2k_calc``; prints a
    calculation summary to ``misc_info``; registers the orbital series with
    ``series_plotter_inst`` and enables the cube-kit button.

    Args:
        b: the clicked button (unused; the URL cell passes ``0``).
    """
    global cp2k_calc, orb_calc, new_version, workcalc
    global local_ref_index

    new_version = False
    try:
        workcalc = load_node(pk=pk_select.value)
    except common.NotExistent:
        print("Incorrect pk.")
        return

    orb_calc = spm.get_calc_by_label(workcalc, 'orb')
    try:
        cp2k_calc = spm.get_calc_by_label(workcalc, 'scf_diag')
    except AssertionError:
        # No 'scf_diag' sub-calculation: presumably a newer workchain that
        # exposes the DFT results directly as an output.
        try:
            dft_out_params = workcalc.outputs.dft_output_parameters.get_dict()
        except AttributeError:
            # A missing output raises NotExistentAttributeError, which is an
            # AttributeError subclass. Other errors are no longer swallowed.
            print("Incorrect pk.")
            return
        new_version = True

    geom_info.value = spm.get_slab_calc_info(workcalc.inputs.structure)

    # The SPM parameter input was called 'stm_params' in some workchain versions.
    try:
        spm_params = workcalc.inputs.stm_params
    except common.NotExistentAttributeError:
        spm_params = workcalc.inputs.spm_params

    # Default the plotted window to a couple fewer orbitals than were computed.
    n_homo_inttext.value = max([int(spm_params['--n_homo']) - 2, 1])
    n_lumo_inttext.value = max([int(spm_params['--n_lumo']) - 2, 1])

    # Information about the calculation.
    with misc_info:
        clear_output()

    dft_inp_params = dict(workcalc.inputs['dft_params'])
    if not new_version:
        dft_out_params = dict(cp2k_calc.outputs.output_parameters)

    with misc_info:
        _print_calc_summary(dft_inp_params, dft_out_params, spm_params)

    # Ionization potential, if it's there.
    ionization_potential = _read_ionization_potential(orb_calc)
    if ionization_potential is not None:
        with misc_info:
            print("Ionization potential: %.4f eV" % ionization_potential)

    s0_orb_general_info, ref_index = _add_orbital_series(orb_calc)
    series_plotter_inst.setup_added_collections(workcalc.pk)

    wfn_kit_button.disabled = False

    # Position of the reference orbital within the spin-0 orbital index list.
    local_ref_index = np.where(s0_orb_general_info['orb_indexes'] == ref_index)[0][0]
    

style = {'description_width': '50px'}
layout = {'width': '70%'}

# pk entry field and the button that triggers load_pk.
pk_select = ipw.IntText(
    value=0,
    description='pk',
    style=style,
    layout=layout,
)
load_pk_btn = ipw.Button(
    description='Load pk',
    style=style,
    layout=layout,
)
load_pk_btn.on_click(load_pk)

# Filled by load_pk with a summary of the loaded structure.
geom_info = ipw.HTML()

display(ipw.HBox(children=[ipw.VBox(children=[pk_select, load_pk_btn]), geom_info]))

# Receives the calculation details printed by load_pk.
misc_info = ipw.Output()
display(misc_info)

# Orbital images

In [None]:
def selected_orbital_indexes():
    """Return the local orbital indexes to plot as a numpy range.

    The window is centred on ``local_ref_index`` (set by load_pk from the HOMO
    index): ``n_homo_inttext.value`` orbitals up to and including the reference
    orbital, plus ``n_lumo_inttext.value`` orbitals above it. Both ends of the
    range are clipped at zero.
    """
    first = max(local_ref_index - n_homo_inttext.value + 1, 0)
    stop = max(local_ref_index + n_lumo_inttext.value + 1, 0)
    return np.arange(first, stop)

In [None]:
style = {'description_width': '80px'}
layout = {'width': '40%'}

series_plotter_inst = series_plotter.SeriesPlotter(
    select_indexes_function=selected_orbital_indexes,
    zip_prepend='orbs'
)

# Number of orbitals to plot below/including (HOMO) and above (LUMO) the
# reference orbital; read by selected_orbital_indexes.
# BoundedIntText rather than IntText: IntText has no min/max traits, so the
# 0..100 limits were silently ignored. Values assigned outside the bounds
# (e.g. by load_pk) are clamped into [0, 100].
n_homo_inttext = ipw.BoundedIntText(
                        description='num HOMO',
                        min=0,
                        max=100,
                        value=10,
                        style=style, layout=layout)
n_lumo_inttext = ipw.BoundedIntText(
                        description='num LUMO',
                        min=0,
                        max=100,
                        value=10,
                        style=style, layout=layout)

# 'description_width' only affects widgets with a description, so the
# container gets just a layout.
n_orb_select = ipw.HBox([n_homo_inttext, n_lumo_inttext],
                        layout={'width': '60%'})


display(series_plotter_inst.selector_widget, n_orb_select,
        series_plotter_inst.plot_btn, series_plotter_inst.clear_btn, series_plotter_inst.plot_output)

# Export
**Image zip** exports the currently selected orbital images in PNG, TXT and IGOR Pro formats.

**Cube creation kit** creates an archive containing all necessary ingredients to generate the Kohn-Sham orbital cube files with the `cube_from_wfn.py` script available from https://github.com/nanotech-empa/cp2k-spm-tools.

In [None]:
# Image-zip export: button with its progress bar, then the download-link area.
display(
    ipw.HBox(children=[series_plotter_inst.zip_btn, series_plotter_inst.zip_progress]),
    series_plotter_inst.link_out,
)

In [None]:
def create_wfn_zip(b):
    wfn_kit_button.disabled=True
    ! mkdir -p tmp
    label = "cube-kit-pk%d" % int(pk_select.value)
    cube_kit_name = label + ".zip"
    zipf = zipfile.ZipFile('tmp/%s'%cube_kit_name, 'w', zipfile.ZIP_DEFLATED)
    if new_version:
        fd = workcalc.outputs['retrieved']
    else:
        fd = cp2k_calc.outputs['retrieved']
    for fn in ['BASIS_MOLOPT', 'aiida.inp', 'aiida.out',  'aiida.coords.xyz', 'aiida-RESTART.wfn']:
        zipf.write(fd.open(fn).name, arcname=label + '/' + fn)
    
    run_script_path = "/home/aiida/apps/scanning_probe/orb/misc/run_cube_from_wfn.sh"
    zipf.write(run_script_path, arcname=label + '/' +"run_cube_from_wfn.sh")
    zipf.close()
    with wfn_kit_output:
        display(HTML('<a href="tmp/%s" target="_blank">download zip</a>' %cube_kit_name))
        
wfn_kit_button = ipw.Button(description='Cube creation kit', disabled=True)
wfn_kit_button.on_click(create_wfn_zip)

wfn_kit_output = ipw.Output()

display(wfn_kit_button, wfn_kit_output)

In [None]:
def clear_tmp(b):
    ! rm -rf tmp && mkdir tmp
    with series_plotter_inst.link_out:
        clear_output()
    series_plotter_inst.zip_progress.value = 0.0
    
    with wfn_kit_output:
        clear_output()
        
    if series_plotter_inst.series is not None:
        series_plotter_inst.zip_btn.disabled = False
        wfn_kit_button.disabled = False
    
clear_tmp_btn = ipw.Button(description='clear tmp')
clear_tmp_btn.on_click(clear_tmp)
display(clear_tmp_btn)

In [None]:
# Load the pk given in the notebook URL (e.g. "...?pk=123"), once every widget
# and callback above exists.
try:
    url = urllib.parse.urlsplit(jupyter_notebook_url)
    url_pk = int(urllib.parse.parse_qs(url.query)['pk'][0])
except (NameError, KeyError, ValueError):
    # No notebook URL defined, no 'pk' query parameter, or a non-integer pk:
    # there is nothing to preload.
    url_pk = None

# Called outside the try so genuine errors raised while loading are shown
# instead of being silently swallowed.
if url_pk is not None:
    pk_select.value = url_pk
    load_pk(0)