In [None]:
%aiida

In [None]:
from aiida_cp2k.calculations import Cp2kCalculation

from aiida.orm import StructureData
from aiida.orm import ArrayData
from aiida.engine import submit, run_get_node

from aiidalab_widgets_base import CodeDropdown, SubmitButtonWidget, StructureBrowserWidget
from aiidalab_widgets_base import ComputerDropdown

import ase
import ase.io
import numpy as np
import nglview
from copy import deepcopy
from pprint import pprint

# AiiDAlab imports.
import aiidalab_widgets_base as awb

# Custom imports.
from widgets.empa_viewer import EmpaStructureViewer
from widgets.ANALYZE_structure import StructureAnalyzer
from aiida_nanotech_empa.workflows.cp2k import cp2k_utils

import ipywidgets as ipw
from IPython.display import display, clear_output, HTML

from aiida.common.exceptions import MissingEntryPointError 
# Resolve the STM work chain through its entry point. A missing entry point
# means the plugin is not installed, so print a hint and re-raise.
try:
    Cp2kStmWorkChain = WorkflowFactory('nanotech_empa.cp2k.orbitals')
except MissingEntryPointError as missing_entry_point:
    print("Entry point not found. Did you perhaps forget to set up the plugins under 'Setup codes'?")
    raise missing_entry_point

from apps.scanning_probe import common

# Select structure

In [None]:
# Structure selector: structures are picked from the AiiDA database only
# (no editors are offered and the widget is created with storable=False).
empa_viewer = EmpaStructureViewer()
structure_selector = awb.StructureManagerWidget(
    viewer=empa_viewer,
    importers=[awb.StructureBrowserWidget(title="AiiDA database")],
    editors=[],
    storable=False,
    node_class='StructureData',
)
display(structure_selector)

# Select computer and codes

In [None]:
# Code selector
computational_resources = awb.ComputationalResourcesWidget(description="CP2K code:", input_plugin="cp2k")

style = {'description_width': '140px'}
layout = {'width': '70%'}

# Created before the callback is registered so the callback can always reach it.
drop_stm = ipw.Dropdown(description="STM code")


def on_computer_change(c):
    """Refresh the STM code dropdown for the computer of the selected CP2K code.

    The codes returned by ``common.comp_plugin_codes`` for the 'spm.stm' plugin
    are published in the module-level ``overlap_codes`` list, and their labels
    become the options of ``drop_stm`` (same order, so a dropdown index maps
    directly onto ``overlap_codes``).

    When no CP2K code is selected the list and the dropdown are emptied, so
    codes belonging to a previously selected computer cannot be picked.

    Args:
        c: Change notification from the observed widget (unused).
    """
    global overlap_codes
    selected_code = computational_resources.value
    if selected_code is None:
        overlap_codes = []
    else:
        overlap_codes = common.comp_plugin_codes(selected_code.computer.label, 'spm.stm')

    drop_stm.options = [code.label for code in overlap_codes]


# Only react to changes of the selected code, not to every trait of the widget.
computational_resources.observe(on_computer_change, names='value')

# Populate the dropdown once for the initially selected code (if any).
on_computer_change(None)

elpa_check = ipw.Checkbox(
    value=True,
    description='use ELPA',
    disabled=False
)

display(computational_resources, drop_stm, elpa_check)

# DFT parameters

In [None]:
# Shared styling; `style` is reused by the widgets created further down.
style = dict(description_width='80px')
layout = dict(width='70%')


def enable_spin(b):
    """Enable the spin-related inputs only while the spin-polarized toggle is on.

    Args:
        b: Change notification from ``uks_switch`` (unused).
    """
    spin_up_text.disabled = not uks_switch.value
    spin_dw_text.disabled = not uks_switch.value
    vis_spin_button.disabled = not uks_switch.value
    multiplicity_text.disabled = not uks_switch.value

def visualize_spin_guess(b):
    """Highlight the atoms listed in the spin-up (red) and spin-down (blue) fields.

    Each field is read as whitespace-separated integers and every value is
    shifted by -1 before being passed to the viewer (presumably converting
    1-based indices typed by the user to 0-based ones).

    Args:
        b: The clicked button (unused).
    """
    # NOTE(review): `viewer_widget` is not defined anywhere in the visible
    # notebook, so this callback raises NameError unless it is defined elsewhere.
    # Confirm which viewer object is meant here.
    guesses = [
        ([int(token) - 1 for token in field.value.split()], colour)
        for field, colour in ((spin_up_text, 'red'), (spin_dw_text, 'blue'))
    ]
    viewer_widget.reset()
    for atom_indices, colour in guesses:
        viewer_widget.highlight_atoms(atom_indices, color=colour, size=0.3, opacity=0.4)

# Toggle for a spin-polarized calculation; it drives `enable_spin`.
uks_switch = ipw.ToggleButton(
    value=False,
    description='Spin-polarized calculation',
    style=style,
    layout={'width': '450px'},
)
uks_switch.observe(enable_spin, names='value')

# Atom index lists for the initial spin guess (disabled until the toggle is on).
spin_up_text = ipw.Text(
    placeholder='1 2 3',
    description='Spin up',
    disabled=True,
    style=style,
    layout={'width': '370px'},
)
spin_dw_text = ipw.Text(
    placeholder='1 2 3',
    description='Spin down',
    disabled=True,
    style=style,
    layout={'width': '370px'},
)

# Button that shows the current spin guess in the viewer.
vis_spin_button = ipw.Button(
    description="Visualize",
    disabled=True,
    style={'description_width': '0px'},
    layout={'width': '75px'},
)
vis_spin_button.on_click(visualize_spin_guess)

multiplicity_text = ipw.IntText(
    value=1,
    description='Multiplicity',
    disabled=True,
    style=style,
    layout={'width': '20%'},
)

# NOTE(review): the two checkboxes below are created disabled and nothing in
# the visible notebook re-enables them, so their values cannot be changed from
# the UI. Confirm this is intended.
sc_diag_check = ipw.Checkbox(
    value=False,
    description='self-consistent diagonalization',
    disabled=True,
)
force_multiplicity_check = ipw.Checkbox(
    value=True,
    description='Fix multiplicity',
    disabled=True,
)

protocol = ipw.Dropdown(
    value="standard",
    options=[("Standard", "standard"), ("Low accuracy", "low_accuracy")],
    description="Protocol:",
    style={"description_width": "120px"},
)

display(
    uks_switch,
    ipw.HBox([ipw.VBox([spin_up_text, spin_dw_text]), vis_spin_button]),
    ipw.HBox([multiplicity_text, sc_diag_check, force_multiplicity_check]),
    protocol,
)

In [None]:
def enable_smearing(b):
    """Make the temperature field editable only while smearing is switched on.

    Args:
        b: Change notification from ``smear_switch`` (unused).
    """
    smearing_on = smear_switch.value
    temperature_text.disabled = not smearing_on


# Toggle for Fermi-Dirac smearing; it drives `enable_smearing`.
smear_switch = ipw.ToggleButton(
    value=False,
    description='Enable Fermi-Dirac smearing',
    style=style,
    layout={'width': '450px'},
)
smear_switch.observe(enable_smearing, names='value')

# Smearing temperature, disabled until the toggle above is switched on.
temperature_text = ipw.FloatText(
    value=300.0,
    description='Temperature [K]',
    disabled=True,
    style={'description_width': '100px'},
    layout={'width': '20%'},
)

display(smear_switch, temperature_text)

# Scanning tunnelling microscopy parameters

In [None]:
# Styling shared by the STM parameter widgets of this cell.
style = dict(description_width='140px')
layout = dict(width='50%')
layout_small = dict(width='25%')

# Energy window, selected as a (min, max) pair.
elim_float_slider = ipw.FloatRangeSlider(
    description='Emin, Emax (eV):',
    value=[-2.0, 2.0],
    min=-4.0,
    max=4.0,
    step=0.1,
    readout=True,
    readout_format='.1f',
    orientation='horizontal',
    continuous_update=False,
    disabled=False,
    style=style,
    layout=layout,
)

# Energy step.
de_floattext = ipw.BoundedFloatText(
    description='dE (eV)',
    value=0.04,
    min=0.01,
    max=1.00,
    step=0.01,
    style=style,
    layout=layout_small,
)

# Free-text field; it is split on whitespace when the inputs are assembled.
fwhms_text = ipw.Text(
    description='FWHMs (eV)',
    value='0.08',
    style=style,
    layout=layout,
)

extrap_plane_floattext = ipw.BoundedFloatText(
    description='Extrap plane (ang)',
    value=4.0,
    min=1.0,
    max=10.0,
    step=0.1,
    style=style,
    layout=layout_small,
)

# Whitespace-separated lists, entered as free text.
const_height_text = ipw.Text(
    description='Const. H (ang)',
    value='4.0 6.0',
    style=style,
    layout=layout,
)
const_current_text = ipw.Text(
    description='Const. cur. (isoval)',
    value='1e-7',
    style=style,
    layout=layout,
)

ptip_floattext = ipw.BoundedFloatText(
    description='p tip ratio',
    value=0.0,
    min=0.0,
    max=1.0,
    step=0.01,
    style=style,
    layout=layout_small,
)

display(
    elim_float_slider,
    de_floattext,
    fwhms_text,
    extrap_plane_floattext,
    const_height_text,
    const_current_text,
    ptip_floattext,
)

# Resources

In [None]:
# Resources estimation.
MAX_NODES = 96

structure_analyzer = StructureAnalyzer()


def update_resources(_):
    """Estimate the resources for the selected structure and fill the widgets.

    Calls ``cp2k_utils.get_nodes`` for the selected structure and the computer
    of the selected CP2K code. Element 0 of the result is written to the
    number-of-nodes widget and element 1 to the CPUs-per-node widget.

    If the structure or the code is missing, an error is written to
    ``node_estimate_message`` and nothing else is changed.

    Args:
        _: The clicked button (unused).
    """
    if not (structure_selector.structure and computational_resources.value):
        node_estimate_message.message = """<span style="color:red"> Error:</span> Can't estimate resources: both structure and code must be selected."""
        return

    # The notebook has no sub-structure selection, so the whole selected
    # structure is analyzed (the original referenced an undefined `substructure`).
    structure_analyzer.structure = structure_selector.structure

    resources = cp2k_utils.get_nodes(
        atoms=structure_selector.structure,
        calctype="slab",
        computer=computational_resources.value.computer,
        max_nodes=MAX_NODES,
        uks=uks_switch.value
    )

    # Write to this notebook's own resource widgets (the original referenced
    # an undefined `slab` object).
    nodes_widget.value = resources[0]
    cpus_per_node_widget.value = resources[1]


estimate_nodes_button = ipw.Button(description="Estimate resources", button_style='warning')
estimate_nodes_button.on_click(update_resources)
node_estimate_message = awb.utils.StatusHTML()

In [None]:
# Resources widgets.
# NOTE(review): ipw.IntText does not appear to declare a `min` trait
# (BoundedIntText does), so the `min=` arguments below may not be enforced.
# Confirm before relying on them.
STYLE = {'description_width': '100px'}
nodes_widget = ipw.IntText(
    description="# Nodes",
    value=1,
    min=1,
    style=STYLE
    )
cpus_per_node_widget = ipw.IntText(
    description="# CPUs per node",
    value=1,
    min=1,
    style=STYLE
    )
# Label typo fixed ("thredas" -> "threads").
num_cores_per_mpiproc_widget = ipw.IntText(
    description="# threads",
    value=1,
    min=0,
    style=STYLE
)
# Entered in minutes; converted to seconds when the inputs are assembled.
run_time_widget = ipw.IntText(
    description="Runtime (mins)",
    value=1440,
    min=0,
    style=STYLE
)

In [None]:
# Resource inputs grouped under a "System" heading.
res = ipw.VBox([
    ipw.HTML("<b>System</b>"),
    nodes_widget,
    cpus_per_node_widget,
    num_cores_per_mpiproc_widget,
    run_time_widget,
])
# `node_estimate_message` is shown as well: the estimation callback reports
# its errors through it, and it was not displayed anywhere before.
display(ipw.VBox([res, estimate_nodes_button, node_estimate_message]))

# Submission

In [None]:
# Free-text description of the calculation, entered by the user.
text_calc_description = ipw.Text(
    description='Description:',
    layout={'width': '45%'},
)
display(text_calc_description)

In [None]:
submit_out = ipw.Output()


def get_builder():
    """Assemble the ``Cp2kStmWorkChain`` inputs from the current widget state.

    Messages for the user are printed into ``submit_out``. Nothing is submitted
    here; the returned builder is handed to the submit button widget.

    Returns:
        The populated process builder, or ``None`` when the structure, the
        CP2K code, the STM code or the constant-height list is missing.
    """
    with submit_out:
        clear_output()
        if structure_selector.structure is None:
            print("Please select a structure.")
            return None

        # The CP2K code comes from the computational resources widget.
        cp2k_code = computational_resources.value
        if cp2k_code is None:
            print("Please select a CP2K code.")
            return None

        # `drop_stm` lists the labels of `overlap_codes` in the same order,
        # so the dropdown index selects the matching code.
        if drop_stm.index is None:
            print("Please select an STM code.")
            return None
        stm_code = overlap_codes[drop_stm.index]

        heights = const_height_text.value.split()
        if not heights:
            print("Please specify at least one constant height.")
            return None

        dft_params_dict = {
            'elpa_switch':        elpa_check.value,
            'sc_diag':            sc_diag_check.value,
            'protocol':           protocol.value,
            'force_multiplicity': force_multiplicity_check.value,
            'periodic':           'XYZ',
            'uks':                uks_switch.value,
        }
        if uks_switch.value:
            # The text fields hold whitespace-separated integers; each is shifted by -1.
            dft_params_dict['spin_up_guess'] = [int(v) - 1 for v in spin_up_text.value.split()]
            dft_params_dict['spin_dw_guess'] = [int(v) - 1 for v in spin_dw_text.value.split()]
            dft_params_dict['multiplicity'] = multiplicity_text.value

        if smear_switch.value:
            dft_params_dict['smear_t'] = temperature_text.value

        dft_params = Dict(dict=dft_params_dict)

        struct = structure_selector.structure

        # The extrapolation extent reaches the largest requested height,
        # but is never smaller than 5.0.
        extrap_plane = extrap_plane_floattext.value
        max_height = max(float(h) for h in heights)
        extrap_extent = max(max_height - extrap_plane, 5.0)

        # Evaluation region in z
        z_min = 'n-2.0_C' if 'C' in struct.symbols else 'p-4.0'
        z_max = 'p{:.1f}'.format(extrap_plane)

        parent_dir = "parent_calc_folder/"

        energy_min, energy_max = elim_float_slider.value
        energy_range = [
            "%.2f" % energy_min,
            "%.2f" % energy_max,
            "%.3f" % de_floattext.value,
        ]

        stm_params = Dict(dict={
            '--cp2k_input_file':    parent_dir+'aiida.inp',
            '--basis_set_file':     parent_dir+'BASIS_MOLOPT',
            '--xyz_file':           parent_dir+'aiida.coords.xyz',
            '--wfn_file':           parent_dir+'aiida-RESTART.wfn',
            '--hartree_file':       parent_dir+'aiida-HART-v_hartree-1_0.cube',
            '--output_file':        'stm.npz',
            '--eval_region':        ['G', 'G', 'G', 'G', z_min, z_max],
            '--dx':                 '0.15',
            '--eval_cutoff':        '16.0',
            '--extrap_extent':      str(extrap_extent),
            '--energy_range':       energy_range,
            '--heights':            heights,
            '--isovalues':          const_current_text.value.split(),
            '--fwhms':              fwhms_text.value.split(),
            '--p_tip_ratios':       ptip_floattext.value,
        })

        # No restart wavefunction lookup is performed; an empty path is passed.
        wfn_file_path = ""

        # The inputs are only assembled here. Submission is left to the submit
        # button, so the work chain is launched exactly once.
        builder = Cp2kStmWorkChain.get_builder()
        builder.cp2k_code = cp2k_code
        builder.structure = StructureData(ase=struct)
        builder.wfn_file_path = Str(wfn_file_path)
        builder.dft_params = dft_params
        builder.stm_code = stm_code
        builder.stm_params = stm_params
        builder.metadata.description = text_calc_description.value
        builder.metadata.label = 'Cp2kStmWorkChain'

        # Resources.
        builder.options = Dict(dict={
            "max_wallclock_seconds": run_time_widget.value * 60,
            "resources": {
                "num_machines": nodes_widget.value,
                "num_mpiprocs_per_machine": cpus_per_node_widget.value,
                "num_cores_per_mpiproc": num_cores_per_mpiproc_widget.value,
            }
        })

        return builder


btn_submit = awb.SubmitButtonWidget(Cp2kStmWorkChain, inputs_generator=get_builder)
display(btn_submit, submit_out)