In [None]:
# AiiDA IPython magic: loads the AiiDA profile for this kernel before any ORM/engine use below.
%aiida

In [None]:
from copy import deepcopy
from pprint import pprint

import ase
import ase.io
import ipywidgets as ipw
import nglview
import numpy as np
from IPython.display import display, clear_output, HTML

from aiida.common.exceptions import MissingEntryPointError
from aiida.engine import submit, run_get_node
from aiida.orm import ArrayData, Dict, Str
from aiida_cp2k.calculations import Cp2kCalculation
from aiidalab_widgets_base import CodeDropdown, SubmitButtonWidget, StructureBrowserWidget
from aiidalab_widgets_base import ComputerDropdown
try:
    from apps.scanning_probe.stm.stm_workchain import STMWorkChain
except MissingEntryPointError:
    # Print a setup hint for the user, then re-raise the original exception unchanged.
    print("Entry point not found. Did you perhaps forget to set up the plugins under 'Setup codes'?")
    raise
    

from apps.scanning_probe import common
from apps.scanning_probe.metadata_widget import MetadataWidget

from apps.scanning_probe.viewer_details import ViewerDetails

from apps.scanning_probe import analyze_structure

# Select structure

In [None]:
# Notebook-level state written by on_struct_change: the selected structure as
# ASE atoms and its analysis result. Both stay None until a structure is picked.
atoms = slab_analyzed = None

def on_struct_change(c):
    """Load the structure selected in ``struct_browser`` into the notebook state.

    Observer for ``struct_browser.results`` value changes. Converts the selected
    node to ASE atoms with periodic boundaries in all three directions, stores it
    with its ``analyze_structure.analyze`` result in the globals ``atoms`` /
    ``slab_analyzed``, and shows both in ``viewer_widget``. When the node has a
    creator, the creator's description is copied into ``text_calc_description``.

    Args:
        c: the change notification (unused; the value is re-read from the widget).
    """
    global atoms, slab_analyzed

    structure = struct_browser.results.value
    if not structure:
        return

    atoms = structure.get_ase()
    atoms.pbc = [1, 1, 1]

    slab_analyzed = analyze_structure.analyze(atoms)
    viewer_widget.setup(atoms, slab_analyzed)

    # NOTE: text_calc_description is created in a later cell, so that cell must
    # have run before a structure with a creator is selected.
    creator = structure.creator
    if creator is not None:
        text_calc_description.value = creator.description

    
struct_browser = StructureBrowserWidget()
viewer_widget = ViewerDetails()

# Re-run on_struct_change whenever a different structure is chosen.
struct_browser.results.observe(on_struct_change, names='value')

display(ipw.VBox([struct_browser, viewer_widget]))

In [None]:
# Free-text description passed as the workchain's metadata description.
# on_struct_change pre-fills it from the selected structure's creator, if any.
text_calc_description = ipw.Text(
    description='Description:',
    layout={'width': '45%'},
)
display(text_calc_description)

# Select computer and codes

In [None]:
computer_drop = ComputerDropdown()
drop_cp2k = ipw.Dropdown(description="Cp2k code")
drop_stm = ipw.Dropdown(description="STM code")

# Code nodes backing the two dropdowns, index for index; on_submit picks
# cp2k_codes[drop_cp2k.index] / stm_codes[drop_stm.index].
cp2k_codes = []
stm_codes = []


def on_computer_change(c):
    """Refresh the CP2K and STM code lists for the selected computer.

    Rebuilds the globals ``cp2k_codes`` / ``stm_codes`` from
    ``common.comp_plugin_codes`` (for the ``'cp2k'`` and ``'spm.stm'`` plugins on
    the selected computer's label) and mirrors their labels into ``drop_cp2k`` /
    ``drop_stm``. When no computer is selected, both lists and both dropdowns are
    emptied so that a code from a previously selected computer cannot be picked.

    Args:
        c: the change notification (unused).
    """
    global cp2k_codes, stm_codes
    computer = computer_drop.selected_computer
    if computer is None:
        cp2k_codes, stm_codes = [], []
    else:
        cp2k_codes = common.comp_plugin_codes(computer.label, 'cp2k')
        stm_codes = common.comp_plugin_codes(computer.label, 'spm.stm')

    drop_cp2k.options = [code.label for code in cp2k_codes]
    drop_stm.options = [code.label for code in stm_codes]


# React only to an actual change of selection. Observing every trait would call
# the handler (and repeat the code lookup) for each of index/label/value/... changes.
computer_drop._dropdown.observe(on_computer_change, names='value')

# Populate the code dropdowns for the initially selected computer.
on_computer_change(None)

elpa_check = ipw.Checkbox(
    value=True,
    description='use ELPA',
    disabled=False
)

display(computer_drop, drop_cp2k, drop_stm, elpa_check)

# DFT parameters

In [None]:
# Widget styling shared by the DFT-parameter widgets below.
style = dict(description_width='80px')
layout = dict(width='70%')

def enable_spin(b):
    """Enable the spin-guess widgets only while ``uks_switch`` is pressed.

    Observer for ``uks_switch`` value changes; ``b`` (the notification) is unused.
    """
    spin_off = not uks_switch.value
    for widget in (spin_up_text, spin_dw_text, vis_spin_button, multiplicity_text):
        widget.disabled = spin_off

def visualize_spin_guess(b):
    """Highlight the spin-up (red) and spin-down (blue) guess atoms in the viewer.

    Each text field holds whitespace-separated atom numbers. Each number is
    shifted down by one before being passed to ``viewer_widget.highlight_atoms``.
    Both fields are parsed before the viewer is reset. ``b`` (the clicked button)
    is unused.
    """
    def to_indices(text):
        return [int(token) - 1 for token in text.split()]

    up_atoms = to_indices(spin_up_text.value)
    down_atoms = to_indices(spin_dw_text.value)

    viewer_widget.reset()
    viewer_widget.highlight_atoms(up_atoms, color='red', size=0.3, opacity=0.4)
    viewer_widget.highlight_atoms(down_atoms, color='blue', size=0.3, opacity=0.4)

# Spin-polarisation toggle; enable_spin flips the widgets below on and off.
uks_switch = ipw.ToggleButton(
    value=False,
    description='Spin-polarized calculation',
    style=style,
    layout={'width': '450px'},
)
uks_switch.observe(enable_spin, names='value')

# Space-separated atom numbers for the initial spin-up / spin-down guess.
spin_up_text = ipw.Text(
    placeholder='1 2 3',
    description='Spin up',
    disabled=True,
    style=style,
    layout={'width': '370px'},
)
spin_dw_text = ipw.Text(
    placeholder='1 2 3',
    description='Spin down',
    disabled=True,
    style=style,
    layout={'width': '370px'},
)

vis_spin_button = ipw.Button(
    description="Visualize",
    disabled=True,
    style={'description_width': '0px'},
    layout={'width': '75px'},
)
vis_spin_button.on_click(visualize_spin_guess)

multiplicity_text = ipw.IntText(
    value=1,
    description='Multiplicity',
    disabled=True,
    style=style,
    layout={'width': '20%'},
)

display(
    uks_switch,
    ipw.HBox([ipw.VBox([spin_up_text, spin_dw_text]), vis_spin_button]),
    multiplicity_text,
)

# Scanning tunnelling microscopy parameters

In [None]:
# Widget styling for the STM-parameter widgets (rebinds the DFT cell's names).
style = {'description_width': '140px'}
layout = {'width': '50%'}
layout_small = {'width': '25%'}

# Energy window; the value is only updated when the slider handle is released.
elim_float_slider = ipw.FloatRangeSlider(
    value=[-2.0, 2.0],
    min=-4.0,
    max=4.0,
    step=0.1,
    description='Emin, Emax (eV):',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
    style=style,
    layout=layout,
)

# Energy step, written as the third number of the STM --energy_range option.
de_floattext = ipw.BoundedFloatText(
    description='dE (eV)',
    min=0.01,
    max=1.00,
    step=0.01,
    value=0.04,
    style=style,
    layout=layout_small,
)

# Space-separated list of broadening values.
fwhms_text = ipw.Text(
    description='FWHMs (eV)',
    value='0.08',
    style=style,
    layout=layout,
)

extrap_plane_floattext = ipw.BoundedFloatText(
    description='Extrap plane (ang)',
    min=1.0,
    max=10.0,
    step=0.1,
    value=5.0,
    style=style,
    layout=layout_small,
)

# Space-separated constant-height values and constant-current isovalues.
const_height_text = ipw.Text(
    description='Const. H (ang)',
    value='4.0 6.0',
    style=style,
    layout=layout,
)
const_current_text = ipw.Text(
    description='Const. cur. (isoval)',
    value='1e-7',
    style=style,
    layout=layout,
)

display(
    elim_float_slider,
    de_floattext,
    fwhms_text,
    extrap_plane_floattext,
    const_height_text,
    const_current_text,
)

In [None]:
def _build_dft_params():
    """Collect the DFT settings from the widgets into an AiiDA ``Dict``.

    Returns:
        Dict: the grid cutoff, ELPA flag, diagonal of the viewer's cell and UKS
        flag; plus the spin-up/spin-down guesses and multiplicity when
        ``uks_switch`` is on.

    Raises:
        ValueError: if a spin-guess field contains a non-integer token.
    """
    dft_params_dict = {
        'mgrid_cutoff': 600,
        'elpa_switch':  elpa_check.value,
        # Store plain Python floats rather than numpy scalars.
        'cell':         [float(length) for length in np.diag(viewer_widget.atoms.cell)],
        'uks':          uks_switch.value,
    }
    if uks_switch.value:
        # Atom numbers typed by the user are shifted down by one (as in the viewer).
        dft_params_dict['spin_up_guess'] = [int(v) - 1 for v in spin_up_text.value.split()]
        dft_params_dict['spin_dw_guess'] = [int(v) - 1 for v in spin_dw_text.value.split()]
        dft_params_dict['multiplicity'] = multiplicity_text.value
    return Dict(dict=dft_params_dict)


def _build_stm_params(struct_ase):
    """Collect the STM settings from the widgets into an AiiDA ``Dict``.

    Args:
        struct_ase: the selected structure as ASE atoms. Only its chemical symbols
            are read, to pick the lower z bound of the evaluation region.

    Returns:
        Dict: command-line options for the STM code, keyed by flag name.

    Raises:
        ValueError: if no constant height is given or a height is not a number.
    """
    heights = const_height_text.value.split()
    if not heights:
        raise ValueError("at least one constant height is required.")

    extrap_plane = extrap_plane_floattext.value
    max_height = max(float(h) for h in heights)
    # At least 5.0, or the gap between the highest height and the plane if larger.
    extrap_extent = max(max_height - extrap_plane, 5.0)

    # Evaluation region in z
    z_min = 'n-2.0_C' if 'C' in struct_ase.symbols else 'p-4.0'
    z_max = 'p{:.1f}'.format(extrap_plane)

    parent_dir = "parent_calc_folder/"

    energy_range_str = "%.2f %.2f %.3f" % (
        elim_float_slider.value[0], elim_float_slider.value[1], de_floattext.value
    )

    return Dict(dict={
        '--cp2k_input_file':    parent_dir+'aiida.inp',
        '--basis_set_file':     parent_dir+'BASIS_MOLOPT',
        '--xyz_file':           parent_dir+'geom.xyz',
        '--wfn_file':           parent_dir+'aiida-RESTART.wfn',
        '--hartree_file':       parent_dir+'aiida-HART-v_hartree-1_0.cube',
        '--output_file':        'stm.npz',
        '--eval_region':        ['G', 'G', 'G', 'G', z_min, z_max],
        '--dx':                 '0.15',
        '--eval_cutoff':        '14.0',
        '--extrap_extent':      str(extrap_extent),
        '--energy_range':       energy_range_str.split(),
        '--heights':            heights,
        '--isovalues':          const_current_text.value.split(),
        '--fwhms':              fwhms_text.value.split(),
    })


def _find_wfn_file_path(struct, computer):
    """Best-effort lookup of a restart ``.wfn`` file for ``struct`` on ``computer``.

    Returns:
        str: the path returned by ``common.find_struct_wf``, or ``""`` when the
        lookup raises or returns nothing.
    """
    try:
        wfn_file_path = common.find_struct_wf(struct, computer)
    except Exception as exc:  # optional lookup: report and fall back to no restart file
        print("Info: .wfn lookup failed ({}: {}).".format(type(exc).__name__, exc))
        return ""
    return wfn_file_path or ""


def on_submit(b):
    """Validate the form and submit an ``STMWorkChain`` for the selected structure.

    Click handler for ``btn_submit``; every message is printed into ``submit_out``.
    Nothing is submitted if the structure, computer or a code is missing, or if a
    spin-guess or height field cannot be parsed. The submitted node is given the
    extra ``version = 0``.

    Args:
        b: the clicked button (unused).
    """
    with submit_out:
        clear_output()
        if not struct_browser.results.value:
            print("Please select a structure.")
            return
        if not computer_drop.selected_computer:
            print("Please select a computer.")
            return
        # An empty dropdown has index None, which cannot index the code lists.
        if drop_cp2k.index is None or drop_stm.index is None:
            print("Please select a Cp2k and an STM code.")
            return

        struct = struct_browser.results.value
        struct_ase = struct.get_ase()

        try:
            dft_params = _build_dft_params()
            stm_params = _build_stm_params(struct_ase)
        except ValueError as exc:
            print("Invalid input: {}".format(exc))
            return

        cp2k_code = cp2k_codes[drop_cp2k.index]
        stm_code = stm_codes[drop_stm.index]

        ## Try to access the restart-wfn file ##
        wfn_file_path = _find_wfn_file_path(struct, cp2k_code.get_remote_computer())
        if wfn_file_path == "":
            print("Info: didn't find any accessible .wfn file.")

        node = submit(
            STMWorkChain,
            cp2k_code=cp2k_code,
            structure=struct,
            wfn_file_path=Str(wfn_file_path),
            dft_params=dft_params,
            stm_code=stm_code,
            stm_params=stm_params,
            metadata={'description': text_calc_description.value}
        )

        # set calculation version; also used to determine post-processing
        node.set_extra("version", 0)

        print()
        print("Submitted:")
        print(node)

# Submission controls: on_submit prints its messages into submit_out.
submit_out = ipw.Output()
btn_submit = ipw.Button(description="Submit")
btn_submit.on_click(on_submit)
display(btn_submit, submit_out)