In [None]:
%%javascript
// Override the classic notebook's output-area heuristic so that long cell
// outputs are never collapsed into a scrollable box: the widgets of this
// notebook are always shown at full height.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

In [None]:
# AiiDA imports.
%aiida

# General imports.
import ipywidgets as ipw
from IPython.display import clear_output

# AiiDAlab imports.
import aiidalab_widgets_base as awb

# Custom imports.
from widgets.build_slab import BuildSlab
from widgets.computational_resources import ProcessResourcesWidget, ResourcesEstimatorWidget
from widgets.inputs import InputDetails
from widgets.empa_viewer import EmpaStructureViewer
from widgets.ANALYZE_structure import StructureAnalyzer
from widgets.import_cdxml import CdxmlUpload2GnrWidget

from apps.scanning_probe import common

In [None]:
# Load the SPM workchain classes through their AiiDA entry-point names.
# NOTE(review): only Cp2kStmWorkChain is used for submission in this notebook;
# the HRSTM, AFM and orbitals workchains are loaded but not wired up yet.
Cp2kStmWorkChain = WorkflowFactory('nanotech_empa.cp2k.stm')
Cp2kHrstmWorkChain = WorkflowFactory('nanotech_empa.cp2k.hrstm')
Cp2kAfmWorkChain = WorkflowFactory('nanotech_empa.cp2k.afm')
Cp2kOrbitalsWorkChain = WorkflowFactory('nanotech_empa.cp2k.orbitals')

# Select structure

In [None]:
# Structure selector: import, build or edit the system to simulate.
empa_viewer = EmpaStructureViewer()
build_slab = BuildSlab(title='Build slab')
# Keep the slab builder in sync with the viewer: it receives the viewer's
# structure details and uses the currently shown structure as the molecule.
ipw.dlink((empa_viewer, 'details'), (build_slab, 'details'))
ipw.dlink((empa_viewer, 'structure'), (build_slab, 'molecule'))

structure_selector = awb.StructureManagerWidget(
    viewer=empa_viewer,
    importers=[
        awb.StructureUploadWidget(title="Import from computer"),
        awb.StructureBrowserWidget(title="AiiDA database"),
        awb.OptimadeQueryWidget(embedded=True),
        awb.SmilesWidget(title="From SMILES"),
        CdxmlUpload2GnrWidget(title="CDXML"),
    ],
    editors=[
        awb.BasicStructureEditor(title="Edit structure"),
        build_slab,
    ],
    storable=False,
    node_class='StructureData',
)
display(structure_selector)

# Process resources for the submission; the widget is displayed together with
# the SPM parameters at the end of the notebook. The CP2K code selector is
# created in the code-selection cell further below (creating one here as well
# only produced a widget that was never displayed and was immediately shadowed).
resources = ProcessResourcesWidget()

# Select SPM

In [None]:
# Type of scanning-probe simulation to set up.
drop_spm_type = ipw.Dropdown(
    description="SPM type",
    options=['STM', 'AFM', 'HRSTM', 'ORBITALS'],
    value='STM',
)
display(drop_spm_type)

# DFT parameters

In [None]:
# Shared description styling for the DFT-parameter widgets below.
# NOTE(review): ``layout`` is not used by the widgets of this cell (they pass
# explicit layouts) and is redefined in the SPM-parameters cell.
style = {'description_width': '80px'}
layout = {'width': '70%'}

def enable_spin(b):
    """Observer of ``uks_switch``: spin inputs are editable only for spin-polarized runs."""
    locked = not uks_switch.value
    spin_widgets = (spin_up_text, spin_dw_text, vis_spin_button, multiplicity_text)
    for widget in spin_widgets:
        widget.disabled = locked

def visualize_spin_guess(b):
    """Highlight the spin-up (red) and spin-down (blue) guess atoms in a viewer.

    Click handler of ``vis_spin_button``. The text fields hold space-separated,
    1-based atom indices, which are converted to 0-based indices here.
    Raises ValueError if a field contains a non-integer token.
    """
    # NOTE(review): ``viewer_widget`` is not defined anywhere in this notebook,
    # so clicking "Visualize" raises NameError. Presumably the structure viewer
    # ``empa_viewer`` was meant — confirm it provides ``reset()`` and
    # ``highlight_atoms(..., color=, size=, opacity=)`` before switching.
    spin_up = [int(v)-1 for v in spin_up_text.value.split()]
    spin_dw = [int(v)-1 for v in spin_dw_text.value.split()]
    viewer_widget.reset()
    viewer_widget.highlight_atoms(spin_up, color='red', size=0.3, opacity=0.4)
    viewer_widget.highlight_atoms(spin_dw, color='blue', size=0.3, opacity=0.4)

# Spin-polarized (UKS) toggle; its observer enables/disables the spin inputs.
uks_switch = ipw.ToggleButton(
    value=False,
    description='Spin-polarized calculation',
    style=style,
    layout={'width': '450px'},
)
uks_switch.observe(enable_spin, names='value')

# Initial spin guess: space-separated, 1-based atom indices.
spin_up_text = ipw.Text(
    placeholder='1 2 3',
    description='Spin up',
    disabled=True,
    style=style,
    layout={'width': '370px'},
)
spin_dw_text = ipw.Text(
    placeholder='1 2 3',
    description='Spin down',
    disabled=True,
    style=style,
    layout={'width': '370px'},
)
vis_spin_button = ipw.Button(
    description="Visualize",
    disabled=True,
    style={'description_width': '0px'},
    layout={'width': '75px'},
)
vis_spin_button.on_click(visualize_spin_guess)

# Total charge of the system; starts disabled.
charge_text = ipw.IntText(
    value=0,
    description='charge',
    disabled=True,
    style=style,
    layout={'width': '15%'},
)

multiplicity_text = ipw.IntText(
    value=1,
    description='Multiplicity',
    disabled=True,
    style=style,
    layout={'width': '20%'},
)
# NOTE(review): the two checkboxes below start disabled and enable_spin does not
# include them, so as far as this notebook goes their defaults are always used.
sc_diag_check = ipw.Checkbox(
    value=False,
    description='self-consistent diagonalization',
    disabled=True,
)
force_multiplicity_check = ipw.Checkbox(
    value=True,
    description='Fix multiplicity',
    disabled=True,
)

# DFT protocol and ELPA switch; both are displayed in the final input panel.
protocol = ipw.Dropdown(
    value="standard",
    options=[("Standard", "standard"), ("Low accuracy", "low_accuracy"), ("Debug", "debug")],
    description="Protocol:",
    style={"description_width": "120px"},
)
elpa_check = ipw.Checkbox(
    value=True,
    description='use ELPA',
    disabled=False,
)

display(
    uks_switch,
    ipw.HBox([ipw.VBox([spin_up_text, spin_dw_text]), vis_spin_button]),
    ipw.HBox([multiplicity_text, sc_diag_check, force_multiplicity_check]),
    charge_text,
)

In [None]:
def enable_smearing(b):
    """Observer of ``smear_switch``: the temperature is editable only while smearing is on."""
    smearing_on = smear_switch.value
    temperature_text.disabled = not smearing_on


# Optional Fermi-Dirac smearing with its electronic temperature.
smear_switch = ipw.ToggleButton(
    value=False,
    description='Enable Fermi-Dirac smearing',
    style=style,
    layout={'width': '450px'},
)
smear_switch.observe(enable_smearing, names='value')

temperature_text = ipw.FloatText(
    value=300.0,
    description='Temperature [K]',
    disabled=True,
    style={'description_width': '100px'},
    layout={'width': '20%'},
)

display(smear_switch, temperature_text)

# SPM parameters

In [None]:
# Shared styling for the SPM-parameter widgets.
style = {'description_width': '140px'}
layout = {'width': '50%'}
layout_small = {'width': '25%'}

# Container whose children are filled in according to the selected SPM type.
spm_parameters_box = ipw.VBox([])

# ---- STM parameters ----
# Energy window [Emin, Emax] in eV.
elim_float_slider = ipw.FloatRangeSlider(
    value=[-2.0, 2.0], min=-4.0, max=4.0, step=0.1,
    description='Emin, Emax (eV):',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
    style=style,
    layout=layout,
)

# Energy step in eV.
de_floattext = ipw.BoundedFloatText(
    description='dE (eV)', min=0.01, max=1.00, step=0.01, value=0.04,
    style=style, layout=layout_small,
)

# NOTE(review): the ORBITALS cell below rebinds ``fwhms_text`` (default 0.04)
# and the parameter box picks widgets up by their global names, so this 0.08
# default is effectively never shown.
fwhms_text = ipw.Text(
    description='FWHMs (eV)', value='0.08',
    style=style, layout=layout,
)

extrap_plane_floattext = ipw.BoundedFloatText(
    description='Extrap plane (ang)', min=1.0, max=10.0, step=0.1, value=4.0,
    style=style, layout=layout_small,
)

# Space-separated lists of constant heights and constant-current isovalues.
const_height_text = ipw.Text(
    description='Const. H (ang)', value='4.0 6.0',
    style=style, layout=layout,
)
const_current_text = ipw.Text(
    description='Const. cur. (isoval)', value='1e-7',
    style=style, layout=layout,
)

ptip_floattext = ipw.BoundedFloatText(
    description='p tip ratio', min=0.0, max=1.0, step=0.01, value=0.0,
    style=style, layout=layout_small,
)



In [None]:
# ORBITALS parameters.
# BoundedIntText (not IntText): plain IntText has no min/max traits, so the
# 0..100 limits passed to it were silently dropped.
n_homo_inttext = ipw.BoundedIntText(
    description='num HOMO',
    min=0,
    max=100,
    value=10,
    style=style, layout=layout_small)
n_lumo_inttext = ipw.BoundedIntText(
    description='num LUMO',
    min=0,
    max=100,
    value=10,
    style=style, layout=layout_small)

# Space-separated lists of heights and isovalues.
heights_text = ipw.Text(description='Heights (ang)',
                        value='3.0 5.0',
                        style=style, layout=layout)

isovals_text = ipw.Text(description='Isovalues',
                        value='1e-7',
                        style=style, layout=layout)

# NOTE(review): this rebinds the STM cell's ``fwhms_text``. The parameter box is
# filled from the global names, so STM and ORBITALS share this widget
# (default 0.04 eV). Give the two modes separate widgets if different defaults
# are intended.
fwhms_text = ipw.Text(description='FWHMs (eV)',
                      value='0.04',
                      style=style, layout=layout)

# ``extrap_plane_floattext`` and ``ptip_floattext`` are shared with the STM
# parameters of the previous cell (identical settings), so they are not
# re-created here.


# Submission

In [None]:
# Free-text description for the calculation, entered by the user.
# NOTE(review): nothing in this cell pre-fills it from the structure creator;
# confirm that the submission code forwards it (e.g. as process description).

text_calc_description = ipw.Text(description='Description:', layout={'width': '45%'})
display(text_calc_description)

In [None]:
submit_out = ipw.Output()


def _collect_dft_params():
    """Return the DFT parameter dictionary built from the DFT widgets.

    Raises ValueError if a spin-guess field contains a non-integer token.
    """
    params = {
        'elpa_switch': elpa_check.value,
        'sc_diag': sc_diag_check.value,
        'protocol': protocol.value,
        'force_multiplicity': force_multiplicity_check.value,
        'periodic': 'XYZ',
        'uks': uks_switch.value,
    }
    if uks_switch.value:
        # The text fields hold 1-based atom indices; 0-based ones are passed on.
        params['spin_up_guess'] = [int(v) - 1 for v in spin_up_text.value.split()]
        params['spin_dw_guess'] = [int(v) - 1 for v in spin_dw_text.value.split()]
        params['multiplicity'] = multiplicity_text.value
    if smear_switch.value:
        params['smear_t'] = temperature_text.value
    # NOTE(review): ``charge_text`` is displayed but not forwarded here.
    return params


def _collect_spm_params(struct, max_height):
    """Return the STM post-processing arguments for ``struct``.

    ``max_height`` is the highest requested constant-height plane (ang).
    """
    parent_dir = "parent_calc_folder/"
    extrap_plane = extrap_plane_floattext.value
    # Extrapolate at least 5 ang and far enough to reach the highest plane.
    extrap_extent = max(max_height - extrap_plane, 5.0)

    # Evaluation region in z (region strings are interpreted by the SPM code).
    z_min = 'n-2.0_C' if 'C' in struct.symbols else 'p-4.0'
    z_max = 'p{:.1f}'.format(extrap_plane)

    emin, emax = elim_float_slider.value
    energy_range = ["%.2f" % emin, "%.2f" % emax, "%.3f" % de_floattext.value]

    return {
        '--cp2k_input_file':    parent_dir + 'aiida.inp',
        '--basis_set_file':     parent_dir + 'BASIS_MOLOPT',
        '--xyz_file':           parent_dir + 'aiida.coords.xyz',
        '--wfn_file':           parent_dir + 'aiida-RESTART.wfn',
        '--hartree_file':       parent_dir + 'aiida-HART-v_hartree-1_0.cube',
        '--output_file':        'stm.npz',
        '--eval_region':        ['G', 'G', 'G', 'G', z_min, z_max],
        '--dx':                 '0.15',
        '--eval_cutoff':        '16.0',
        '--extrap_extent':      str(extrap_extent),
        '--energy_range':       energy_range,
        '--heights':            const_height_text.value.split(),
        '--isovalues':          const_current_text.value.split(),
        '--fwhms':              fwhms_text.value.split(),
        '--p_tip_ratios':       ptip_floattext.value,
    }


def get_builder():
    """Build the ``Cp2kStmWorkChain`` inputs from the current widget values.

    Used as ``input_dictionary_function`` of the submit button. Problems with
    the inputs are reported in ``submit_out`` and ``None`` is returned instead
    of a builder.
    """
    with submit_out:
        clear_output()
        if structure_selector.structure is None:
            print("Please select a structure.")
            return None
        if computational_resources.value is None:
            print("Please select a computer.")
            return None
        # Only the STM workchain is wired to the submit button; refuse the other
        # SPM types instead of silently submitting an STM calculation with a
        # stale or mismatched post-processing code.
        if drop_spm_type.value != 'STM':
            print(f"Submission of {drop_spm_type.value} calculations is not supported yet.")
            return None
        if drop_spm.index is None:
            print("Please select an SPM code.")
            return None

        try:
            heights = [float(h) for h in const_height_text.value.split()]
        except ValueError:
            print("Const. H must be a space-separated list of numbers.")
            return None
        if not heights:
            print("Please specify at least one constant height.")
            return None

        try:
            dft_params = Dict(dict=_collect_dft_params())
        except ValueError:
            print("Spin guesses must be space-separated atom indices.")
            return None

        struct = structure_selector.structure
        spm_params = Dict(dict=_collect_spm_params(struct, max(heights)))

        cp2k_code = computational_resources.value
        spm_code = spm_codes[drop_spm.index]

        # TODO: the restart-wavefunction lookup (common.find_struct_wf) is
        # disabled; an empty path is passed.
        wfn_file_path = ""

        builder = Cp2kStmWorkChain.get_builder()
        builder.cp2k_code = cp2k_code
        builder.structure = structure_selector.structure_node
        builder.wfn_file_path = Str(wfn_file_path)
        builder.dft_params = dft_params
        builder.spm_code = spm_code
        builder.spm_params = spm_params
        builder.metadata.description = text_calc_description.value

        # NOTE(review): attribute names assume the ProcessResourcesWidget API
        # (walltime_seconds / nodes / tasks_per_node / threads_per_task) —
        # confirm against widgets/computational_resources.py.
        builder.options = {
            "max_wallclock_seconds": resources.walltime_seconds,
            "resources": {
                "num_machines": resources.nodes,
                "num_mpiprocs_per_machine": resources.tasks_per_node,
                "num_cores_per_mpiproc": resources.threads_per_task,
            },
        }
        return builder

# Submit button for the STM workchain; get_builder supplies its inputs.
btn_submit = awb.SubmitButtonWidget(
    Cp2kStmWorkChain,
    input_dictionary_function=get_builder,
)
display(btn_submit, submit_out)

In [None]:
# Code selector
computational_resources = awb.ComputationalResourcesWidget(description="CP2K code:", input_plugin="cp2k")

style = {'description_width': '140px'}
layout = {'width': '70%'}

# Post-processing code selectors; which ones are shown depends on the SPM type.
drop_spm = ipw.Dropdown(description="SPM code")
drop_pp = ipw.Dropdown(description="AFM PP code")
drop_2pp = ipw.Dropdown(description="AFM 2PP code")
spm_code_box = ipw.HBox([])


def on_computer_change(change=None):
    """Refresh SPM code choices and parameter widgets for the current selection.

    Observer of both the CP2K code selector and the SPM-type dropdown; ``change``
    is ignored. Does nothing until a CP2K code is selected. For STM, ORBITALS
    and HRSTM it also updates the global ``spm_codes`` list that ``drop_spm``
    indexes into.
    """
    global spm_codes
    code = computational_resources.value
    if code is None:
        return
    computer_label = code.computer.label
    spm_type = drop_spm_type.value

    if spm_type == 'STM':
        spm_codes = common.comp_plugin_codes(computer_label, 'spm.stm')
        drop_spm.options = [c.label for c in spm_codes]
        spm_parameters_box.children = [
            elim_float_slider, de_floattext, fwhms_text, extrap_plane_floattext,
            const_height_text, const_current_text, ptip_floattext,
        ]
        spm_code_box.children = [drop_spm]
        # The charge is only editable for ORBITALS; lock it again when leaving.
        charge_text.disabled = True
    elif spm_type == 'ORBITALS':
        charge_text.disabled = False
        spm_parameters_box.children = [
            n_homo_inttext, n_lumo_inttext, heights_text,
            isovals_text, fwhms_text, extrap_plane_floattext, ptip_floattext,
        ]
        spm_codes = common.comp_plugin_codes(computer_label, 'spm.stm')
        drop_spm.options = [c.label for c in spm_codes]
        spm_code_box.children = [drop_spm]
    elif spm_type == 'AFM':
        afm_codes = common.comp_plugin_codes(computer_label, 'spm.afm')
        drop_pp.options = [(c.label, c) for c in afm_codes if "_pp" in c.label]
        drop_2pp.options = [(c.label, c) for c in afm_codes if "_2pp" in c.label]
        spm_code_box.children = [drop_pp, drop_2pp]
        charge_text.disabled = True
    else:  # HRSTM
        spm_codes = common.comp_plugin_codes(computer_label, 'spm.hrstm')
        afm_codes = common.comp_plugin_codes(computer_label, 'spm.afm')
        drop_spm.options = [c.label for c in spm_codes]
        drop_pp.options = [(c.label, c) for c in afm_codes if "_pp" in c.label]
        spm_code_box.children = [drop_spm, drop_pp]
        charge_text.disabled = True


# React to value changes only: observing every trait re-ran the code lookups
# for each index/label/option update of the widgets.
drop_spm_type.observe(on_computer_change, names='value')
computational_resources.observe(on_computer_change, names='value')

on_computer_change()


In [None]:
# Resource estimator, attached to the process-resources widget.
resources_estimation = ResourcesEstimatorWidget()
resources_estimation.link_to_resources_widget(resources)

# One-way links feeding the estimator: structure details, UKS flag, CP2K code.
ipw.dlink(
    (empa_viewer, 'details'),
    (resources_estimation, 'details'),
)
ipw.dlink(
    (uks_switch, 'value'),
    (resources_estimation, 'uks'),
)
_ = ipw.dlink(
    (computational_resources, 'value'),
    (resources_estimation, 'selected_code'),
)

In [None]:
# Main input panel: SPM parameters, resources, CP2K code, DFT protocol/ELPA
# and the SPM code selectors, followed by the submission output.
input_panel = [
    spm_parameters_box,
    resources,
    resources_estimation,
    computational_resources,
    protocol,
    elpa_check,
    spm_code_box,
]
display(ipw.VBox(input_panel), submit_out)