In [None]:
%%javascript
// Override the classic Notebook OutputArea hook so that output areas never switch
// to a scroll box: long widget outputs in this app are always shown in full.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

In [None]:
# AiiDA imports.
%load_ext aiida
%aiida
from aiida import orm, plugins

# General imports.
import ipywidgets as ipw
from IPython.display import clear_output

# AiiDAlab imports.
import aiidalab_widgets_base as awb

# Custom imports.
from surfaces_tools.widgets.computational_resources import ProcessResourcesWidget, ResourcesEstimatorWidget
from surfaces_tools.widgets.empa_viewer import EmpaStructureViewer
from surfaces_tools.utils import spm, wfn

In [None]:
# Resolve the AiiDA work-chain classes of the four supported SPM workflows
# (loaded in this order: STM, HRSTM, AFM, ORBITALS).
Cp2kStmWorkChain, Cp2kHrstmWorkChain, Cp2kAfmWorkChain, Cp2kOrbitalsWorkChain = map(
    plugins.WorkflowFactory,
    (
        'nanotech_empa.cp2k.stm',
        'nanotech_empa.cp2k.hrstm',
        'nanotech_empa.cp2k.afm',
        'nanotech_empa.cp2k.orbitals',
    ),
)

# Select structure

In [None]:
# Structure selector: import a structure from a local file or pick one from the
# AiiDA database; it is displayed in the Empa structure viewer (storable=False).
empa_viewer = EmpaStructureViewer()
structure_selector = awb.StructureManagerWidget(
    viewer=empa_viewer,
    importers=[
        awb.StructureUploadWidget(title="Import from computer"),
        awb.StructureBrowserWidget(title="AiiDA database"),
    ],
    editors=[],
    storable=False,
    node_class='StructureData',
)
display(structure_selector)

# Computational resources (nodes, tasks per node, threads per task, walltime) used
# for the submitted work chain; the widget itself is displayed in the last cell.
resources = ProcessResourcesWidget()

# Select SPM

In [None]:
# Which SPM workflow is configured and submitted in the cells below.
drop_spm_type = ipw.Dropdown(
    description="SPM type",
    options=['STM', 'AFM', 'HRSTM', 'ORBITALS'],
    value='STM',
)

# DFT parameters

In [None]:
# Styling shared by the DFT-parameter widgets of this cell.
style = {'description_width': '80px'}
layout = {'width': '70%'}


def enable_spin(b):
    """Unlock the spin-guess and multiplicity inputs only for spin-polarized runs.

    ``b`` is the change notification from ``uks_switch.observe``; it is unused
    because the toggle's current value is read directly.
    """
    spin_inputs_locked = not uks_switch.value
    spin_up_text.disabled = spin_inputs_locked
    spin_dw_text.disabled = spin_inputs_locked
    multiplicity_text.disabled = spin_inputs_locked

# Spin polarization toggle; it unlocks the spin inputs below via enable_spin.
uks_switch = ipw.ToggleButton(
    value=False,
    description='Spin-polarized calculation',
    style=style,
    layout={'width': '450px'},
)
uks_switch.observe(enable_spin, names='value')

# Whitespace-separated 1-based atom indices used as initial spin-up / spin-down guesses.
spin_up_text, spin_dw_text = (
    ipw.Text(
        placeholder='1 2 3',
        description=spin_label,
        disabled=True,
        style=style,
        layout={'width': '370px'},
    )
    for spin_label in ('Spin up', 'Spin down')
)

# Total charge; editable only when ORBITALS is the selected SPM type.
charge_text = ipw.IntText(
    value=0,
    description='charge',
    disabled=True,
    style=style,
    layout={'width': '15%'},
)

multiplicity_text = ipw.IntText(
    value=1,
    description='Multiplicity',
    disabled=True,
    style=style,
    layout={'width': '20%'},
)
# NOTE(review): nothing in this notebook ever enables this checkbox, so it always
# stays at False -- confirm this is intended.
sc_diag_check = ipw.Checkbox(
    value=False,
    description='self-consistent diagonalization',
    disabled=True,
)
# Unlocked only while Fermi-Dirac smearing is enabled (see enable_smearing).
force_multiplicity_check = ipw.Checkbox(
    value=True,
    description='Fix multiplicity',
    disabled=True,
)
# CP2K protocol; shown in the last cell and passed both to the DFT parameters and
# to the work-chain ``protocol`` input.
protocol = ipw.Dropdown(
    value="standard",
    options=[("Standard", "standard"), ("Low accuracy", "low_accuracy"), ("Debug", "debug")],
    description="Protocol:",
    style={"description_width": "120px"},
)

elpa_check = ipw.Checkbox(
    value=True,
    description='use ELPA',
    disabled=False,
)

spin_guess_column = ipw.VBox([spin_up_text, spin_dw_text])
display(
    uks_switch,
    ipw.HBox([spin_guess_column]),
    ipw.HBox([multiplicity_text, sc_diag_check]),
    charge_text,
)

In [None]:
# Code selectors: one per post-processing step plus the CP2K code for the DFT run.
# The PPAFM and 2PP selectors both default to the `nanotech_empa.afm` plugin.
stm_code, hrstm_code, ppafm_code, twopp_code, cp2k_code = (
    awb.ComputationalResourcesWidget(description=code_label, default_calc_job_plugin=calc_plugin)
    for code_label, calc_plugin in (
        ("STM code:", "nanotech_empa.stm"),
        ("HRSTM code:", "nanotech_empa.hrstm"),
        ("PPAFM code:", "nanotech_empa.afm"),
        ("2PP code:", "nanotech_empa.afm"),
        ("CP2K code:", "cp2k"),
    )
)


In [None]:
def enable_smearing(b):
    """Unlock the smearing temperature and the fixed-multiplicity option while smearing is on.

    ``b`` is the change notification from ``smear_switch.observe`` (unused).
    """
    smearing_off = not smear_switch.value
    temperature_text.disabled = smearing_off
    force_multiplicity_check.disabled = smearing_off

# Optional Fermi-Dirac smearing; when enabled the temperature is sent as `smear_t`.
smear_switch = ipw.ToggleButton(
    value=False,
    description='Enable Fermi-Dirac smearing',
    style=style,
    layout={'width': '450px'},
)
smear_switch.observe(enable_smearing, names='value')

temperature_text = ipw.FloatText(
    value=150.0,
    description='Temperature [K]',
    disabled=True,
    style={'description_width': '100px'},
    layout={'width': '20%'},
)

smearing_options_row = ipw.HBox([temperature_text, force_multiplicity_check])
display(smear_switch, smearing_options_row)

# SPM parameters

In [None]:
# Filled with the parameter widgets of the selected SPM type whenever it changes.
spm_parameters_box = ipw.VBox([])
display(ipw.VBox([drop_spm_type, spm_parameters_box]))

In [None]:
# Styling for the SPM parameter widgets of the following cells. This rebinds the
# `style` / `layout` names used by the DFT cells; those widgets already exist.
style = {'description_width': 'initial'}
layout = {'width': '50%'}
layout_small = {'width': '25%'}

# --- STM parameters (several are shared with ORBITALS) ---

# Energy window [Emin, Emax] in eV.
elim_float_slider = ipw.FloatRangeSlider(
    value=[-2.0, 2.0], min=-4.0, max=4.0, step=0.1,
    description='Emin, Emax (eV):',
    disabled=False, continuous_update=False,
    orientation='horizontal',
    readout=True, readout_format='.1f',
    style=style, layout=layout,
)

de_floattext = ipw.BoundedFloatText(
    description='dE (eV)',
    min=0.01, max=1.00, step=0.01, value=0.04,
    style=style, layout=layout_small,
)

# Whitespace-separated list of broadening values.
fwhms_text = ipw.Text(
    description='FWHMs (eV)',
    value='0.08',
    style=style, layout=layout,
)

extrap_plane_floattext = ipw.BoundedFloatText(
    description='Extrap plane (ang)',
    min=1.0, max=10.0, step=0.1, value=4.0,
    style=style, layout=layout_small,
)

# Whitespace-separated lists; description/value are swapped per SPM type in on_spm_change.
height_text = ipw.Text(
    description='Const. H (ang)',
    value='4.0 6.0',
    style=style, layout=layout,
)
const_current_text = ipw.Text(
    description='Const. cur. (isoval)',
    value='1e-7',
    style=style, layout=layout,
)

ptip_floattext = ipw.BoundedFloatText(
    description='p tip ratio',
    min=0.0, max=1.0, step=0.01, value=0.0,
    style=style, layout=layout_small,
)



In [None]:
# ORBITALS: number of HOMO / LUMO orbitals passed to create_orbitals_parameterdata.
# BoundedIntText is required for the 0..100 bounds to take effect: plain IntText has
# no `min`/`max` traits, so traitlets silently dropped them (only a DeprecationWarning).
n_homo_inttext = ipw.BoundedIntText(
                        description='num HOMO',
                        min=0,
                        max=100,
                        value=10,
                        style=style, layout=layout_small)
n_lumo_inttext = ipw.BoundedIntText(
                        description='num LUMO',
                        min=0,
                        max=100,
                        value=10,
                        style=style, layout=layout_small)

In [None]:
# AFM (scan parameters are also reused by HRSTM). The descriptions contain MathJax
# markup, so they are raw strings: in a normal literal `\(` and `\m` are invalid escape
# sequences (DeprecationWarning; SyntaxWarning since Python 3.12). Rendered text is unchanged.
scanstep_floattxt = ipw.BoundedFloatText(
                        description=r'Scan d\(\mathbf{x}\) (Å)',
                        min=0.05,
                        max=0.5,
                        step=0.05,
                        value=0.1,
                        style=style, layout=layout_small)

# z range of the scan; bounds and defaults are reset per SPM type in on_spm_change.
scanminz_floattxt = ipw.BoundedFloatText(
                        description=r'Scan \(z_{min}\) (Å)',
                        min=0.0,
                        max=5.0,
                        step=0.1,
                        value=3.5,
                        style=style, layout=layout_small)

scanmaxz_floattxt = ipw.BoundedFloatText(
                        description=r'Scan \(z_{max}\) (Å)',
                        min=5.0,
                        max=10.0,
                        step=0.1,
                        value=8.5,
                        style=style, layout=layout_small)

amp_floattxt = ipw.FloatText(
                        description='Amplitude (Å)',
                        step=0.1,
                        value=1.4,
                        style=style, layout=layout_small)

f0_cantilever_floattxt = ipw.FloatText(
                        description=r'Cantilever (\(f_0\)) ',
                        step=0.1,
                        value=22352.5,
                        style=style, layout=layout_small)

# Each option maps to the charges [ChargeCuUp, ChargeCuDown, Ccharge, Ocharge].
drop_2pp_resp = ipw.Dropdown(description="2PP RESP model",
                            style=style, layout=layout_small,
                            options = {
                                'pentacene': [-0.0669933, -0.0627402, 0.212718, -0.11767],
                                'ptcda':     [     -0.05,      -0.07,     0.23,    -0.13]
                            })
                                                   

In [None]:
# HRSTM: bias-voltage range (min, max, step) and sample parameters. MathJax
# descriptions are raw strings so `\(` / `\m` are not parsed as invalid escape
# sequences (SyntaxWarning since Python 3.12); the displayed text is unchanged.
voltext_ipw = ipw.Label(value="Voltage Range:",
                        style=style, layout=layout_small)
volstep_ipw = ipw.BoundedFloatText(description="Step",
                                   value=0.1, min=0.01, max=0.5, step=0.01,
                                   style=style, layout=layout_small)
volmin_ipw = ipw.FloatText(description="Min",
                           value=-0.3,
                           style=style, layout=layout_small)
volmax_ipw = ipw.FloatText(description="Max",
                           value=0.3,
                           style=style, layout=layout_small)

fwhm_ipw = ipw.BoundedFloatText(description=r"FWHM for DOS of Sample (\(e\)V)",
                                value=0.05, min=0.01, max=0.5, step=0.01,
                                style=style, layout=layout)
workfun_ipw = ipw.BoundedFloatText(description=r"Workfunction of Sample (\(e\)V)",
                                  value=5.0, min=1.0, max=10.0, step=0.1,
                                  style=style, layout=layout)
wfnstep_ipw = ipw.BoundedFloatText(description=r"Meshwidth for Grid Orbitals d\(\mathbf{x}\) (Å)",
                                   value=0.2, min=0.05, max=1.0, step=0.05,
                                   style=style, layout=layout)
extrap_ipw = ipw.BoundedFloatText(description="Extrapolation Plane (Å)",
                                  value=4.0, min=1.0, max=10.0, step=0.1,
                                  style=style, layout=layout)
# HRSTM tip model. Initial `disabled` states match para_values for the default
# 'blunt' tip. MathJax descriptions are raw strings (avoids invalid `\(` escapes).
tiptype_ipw = ipw.ToggleButtons(description="Tip Type",
                                value='blunt', options=['parametrized', 'blunt'],
                                style=style, layout=layout)
rotate_ipw = ipw.Checkbox(description="Rotate Tip Coefficients",
                          value=True,
                          style=style, layout=layout)
orbstip_ipw = ipw.BoundedIntText(description="Maximal Tip Orbital",
                                 value=1, min=0, max=1, step=1,
                                 style=style, layout=layout)
fwhmtip_ipw = ipw.BoundedFloatText(description=r"FWHM for DOS of Tip (\(e\)V)",
                                value=0.00, min=0.00, max=1.0, step=0.01,
                                disabled=(tiptype_ipw.value == 'parametrized'),
                                style=style, layout=layout)
# Coefficients of the parametrized tip (only editable for tip type 'parametrized').
stip_ipw = ipw.BoundedFloatText(description=r"\(s\)-Value",
                                value=0.15, min=0.0, max=1.0, step=0.01,
                                disabled=(tiptype_ipw.value != 'parametrized'),
                                style=style, layout=layout)
pytip_ipw = ipw.BoundedFloatText(description=r"\(p_y\)-Value",
                                value=0.5, min=0.0, max=1.0, step=0.01,
                                disabled=(tiptype_ipw.value != 'parametrized'),
                                style=style, layout=layout)
pztip_ipw = ipw.BoundedFloatText(description=r"\(p_z\)-Value",
                                value=0.0, min=0.0, max=1.0, step=0.01,
                                disabled=(tiptype_ipw.value != 'parametrized'),
                                style=style, layout=layout)
pxtip_ipw = ipw.BoundedFloatText(description=r"\(p_x\)-Value",
                                value=0.5, min=0.0, max=1.0, step=0.01,
                                disabled=(tiptype_ipw.value != 'parametrized'),
                                style=style, layout=layout)
# Widgets whose enabled state depends on the tip type.
para_list = [stip_ipw, pytip_ipw, pztip_ipw, pxtip_ipw, fwhmtip_ipw]


def para_values(value):
    """Enable the tip widgets that apply to tip type ``value``.

    For 'parametrized', the s/p coefficients are editable, the tip DOS FWHM is
    locked, and the maximal tip orbital is forced to 1 and locked. For any other
    type, the coefficients are locked and the FWHM and orbital inputs are editable.
    """
    is_parametrized = value == 'parametrized'
    for coefficient_widget in (stip_ipw, pytip_ipw, pztip_ipw, pxtip_ipw):
        coefficient_widget.disabled = not is_parametrized
    fwhmtip_ipw.disabled = is_parametrized
    if is_parametrized:
        # A parametrized tip always uses p-orbitals.
        orbstip_ipw.value = 1
    orbstip_ipw.disabled = is_parametrized


# ipywidgets' `interactive` (non-manual mode) calls para_values with the current
# tip type now and on every change of tiptype_ipw; the returned widget is not displayed.
para_ipw = ipw.interactive(para_values, value=tiptype_ipw)

# Submission

In [None]:
# Free-text description of the calculation, typed in by the user.
text_calc_description = ipw.Text(
    description='Description:',
    layout={'width': '45%'},
)
display(text_calc_description)

In [None]:
submit_out = ipw.Output()


def _parse_atom_indices(text):
    """Convert whitespace-separated 1-based atom indices into a list of 0-based ints.

    Raises:
        ValueError: if a token is not an integer.
    """
    return [int(token) - 1 for token in text.split()]


def _missing_spm_codes(spm_type):
    """Return labels of the post-processing codes required by ``spm_type`` that are not selected."""
    required_codes = {
        'STM': [("STM", stm_code)],
        'ORBITALS': [("STM", stm_code)],
        'AFM': [("PPAFM", ppafm_code), ("2PP", twopp_code)],
        'HRSTM': [("HRSTM", hrstm_code), ("2PP", twopp_code)],
    }
    return [label for label, code_widget in required_codes[spm_type] if code_widget.value is None]


def _dft_parameters():
    """Collect the DFT settings from the widgets into a plain dict.

    Spin guesses are converted from the 1-based indices typed by the user to 0-based.

    Raises:
        ValueError: if a spin guess contains a non-integer token.
    """
    dft_params_dict = {
        'elpa_switch': elpa_check.value,
        'sc_diag': sc_diag_check.value,
        'protocol': protocol.value,
        'force_multiplicity': force_multiplicity_check.value,
        'periodic': 'XYZ',
        'uks': uks_switch.value,
        'charge': charge_text.value,
    }
    if uks_switch.value:
        dft_params_dict['spin_up_guess'] = _parse_atom_indices(spin_up_text.value)
        dft_params_dict['spin_dw_guess'] = _parse_atom_indices(spin_dw_text.value)
        dft_params_dict['multiplicity'] = multiplicity_text.value
    if smear_switch.value:
        dft_params_dict['smear_t'] = temperature_text.value
    return dft_params_dict


def _spm_builder(spm_type, struct):
    """Create the work-chain builder for ``spm_type`` and fill in its SPM-specific inputs.

    ``struct`` is ``structure_selector.structure``; it is passed to the ``spm`` helpers.
    """
    if spm_type == 'STM':
        builder = Cp2kStmWorkChain.get_builder()
        builder.metadata.label = "CP2K_STM"
        builder.spm_code = orm.load_node(stm_code.value)
        builder.spm_params = orm.Dict(dict=spm.create_stm_parameterdata(
            extrap_plane_floattext.value,
            height_text.value.split(),
            struct.symbols,
            "parent_calc_folder/",
            elim_float_slider.value[0],
            elim_float_slider.value[1],
            de_floattext.value,
            const_current_text.value.split(),
            fwhms_text.value.split(),
            ptip_floattext.value,
        ))
    elif spm_type == 'AFM':
        builder = Cp2kAfmWorkChain.get_builder()
        builder.metadata.label = "CP2K_AFM"
        builder.afm_pp_params = orm.Dict(dict=spm.create_pp_parameterdata(
            struct,
            scanstep_floattxt.value,
            scanminz_floattxt.value,
            scanmaxz_floattxt.value,
            amp_floattxt.value,
            f0_cantilever_floattxt.value,
        ))
        builder.afm_2pp_params = orm.Dict(dict=spm.create_2pp_parameterdata(
            struct,
            scanstep_floattxt.value,
            drop_2pp_resp.value,
            scanminz_floattxt.value,
            scanmaxz_floattxt.value,
            amp_floattxt.value,
            f0_cantilever_floattxt.value,
        ))
        builder.afm_pp_code = orm.load_node(ppafm_code.value)
        builder.afm_2pp_code = orm.load_node(twopp_code.value)
    elif spm_type == 'ORBITALS':
        builder = Cp2kOrbitalsWorkChain.get_builder()
        builder.metadata.label = "CP2K_ORBITALS"
        builder.spm_code = orm.load_node(stm_code.value)
        builder.spm_params = orm.Dict(dict=spm.create_orbitals_parameterdata(
            extrap_plane_floattext.value,
            height_text.value.split(),
            "parent_calc_folder/",
            n_homo_inttext.value,
            n_lumo_inttext.value,
            const_current_text.value.split(),
            fwhms_text.value.split(),
            ptip_floattext.value,
        ))
    elif spm_type == 'HRSTM':
        builder = Cp2kHrstmWorkChain.get_builder()
        builder.metadata.label = "CP2K_HRSTM"
        builder.hrstm_code = orm.load_node(hrstm_code.value)
        builder.ppm_code = orm.load_node(twopp_code.value)
        ppm_params = spm.create_2pp_parameterdata(
            struct,
            scanstep_floattxt.value,
            drop_2pp_resp.value,
            scanminz_floattxt.value,
            scanmaxz_floattxt.value,
            amp_floattxt.value,
            f0_cantilever_floattxt.value,
        )
        # Periodic boundary conditions are always switched on for the PPM step.
        ppm_params['PBC'] = 'True'
        builder.ppm_params = orm.Dict(dict=ppm_params)
        builder.hrstm_params = orm.Dict(dict=spm.create_hrstm_parameterdata(
            builder.hrstm_code,
            "parent_calc_folder/",
            "ppm_calc_folder/",
            struct,
            ppm_params,
            tiptype_ipw.value,
            stip_ipw.value,
            pytip_ipw.value,
            pztip_ipw.value,
            pxtip_ipw.value,
            volmin_ipw.value,
            volmax_ipw.value,
            volstep_ipw.value,
            volstep_ipw.min,
            fwhm_ipw.value,
            wfnstep_ipw.value,
            extrap_ipw.value,
            workfun_ipw.value,
            orbstip_ipw.value,
            fwhmtip_ipw.value,
            rotate_ipw.value,
        ))
    else:
        raise ValueError(f"Unsupported SPM type: {spm_type}")
    builder.protocol = orm.Str(protocol.value)
    return builder


def get_builder():
    """Build the inputs of the selected SPM work chain from the current widget values.

    Messages (and any traceback) are written to ``submit_out``.

    Returns:
        The populated process builder, or ``None`` when a required selection is
        missing or an input cannot be parsed (the reason is printed).
    """
    with submit_out:
        clear_output()
        if structure_selector.structure is None:
            print("Please select a structure.")
            return
        if cp2k_code.value is None:
            print("Please select CP2K code first.")
            return
        # Fail with a clear message instead of an exception from orm.load_node(None).
        missing_codes = _missing_spm_codes(drop_spm_type.value)
        if missing_codes:
            print(f"Please select the {', '.join(missing_codes)} code(s) first.")
            return

        try:
            dft_params_dict = _dft_parameters()
        except ValueError:
            print("Spin guesses must be whitespace-separated integer atom indices.")
            return
        dft_params = orm.Dict(dict=dft_params_dict)

        builder = _spm_builder(drop_spm_type.value, structure_selector.structure)
        # Attach the user-entered description to the work chain.
        builder.metadata.description = text_calc_description.value

        cp2k_code_node = orm.load_node(cp2k_code.value)
        builder.cp2k_code = cp2k_code_node
        builder.structure = structure_selector.structure_node
        builder.dft_params = dft_params

        # A restart wave function can only be looked up for a stored structure.
        wave_function = None
        if structure_selector.structure_node.is_stored:
            wave_function = wfn.structure_available_wfn(
                node=structure_selector.structure_node,
                relative_replica_id=None,
                current_hostname=cp2k_code_node.computer.hostname,
                return_path=False,
                dft_params=dft_params_dict,
            )
        if wave_function is not None:
            print(f"Restarting from wfn in folder: {wave_function.pk}")
            builder.parent_calc_folder = wave_function

        # Resources.
        builder.options = {
            "max_wallclock_seconds": resources.walltime_seconds,
            "resources": {
                'num_machines': resources.nodes,
                'num_mpiprocs_per_machine': resources.tasks_per_node,
                'num_cores_per_mpiproc': resources.threads_per_task,
            },
        }
        return builder

In [None]:
# Code selector

# Parameter widgets shown in `spm_parameters_box` for each SPM type.
widgets = {
    'STM': [elim_float_slider, de_floattext, fwhms_text, extrap_plane_floattext,
            height_text, const_current_text, ptip_floattext],
    'AFM': [scanstep_floattxt, scanminz_floattxt, scanmaxz_floattxt, amp_floattxt,
            f0_cantilever_floattxt, drop_2pp_resp],
    'ORBITALS': [n_homo_inttext, n_lumo_inttext, height_text, const_current_text,
                 fwhms_text, extrap_plane_floattext, ptip_floattext],
    'HRSTM': [scanstep_floattxt, scanminz_floattxt, scanmaxz_floattxt, amp_floattxt,
              f0_cantilever_floattxt, drop_2pp_resp,
              ipw.HBox([voltext_ipw, volmin_ipw, volmax_ipw, volstep_ipw], style=style, layout=layout),
              fwhm_ipw, workfun_ipw, wfnstep_ipw, extrap_ipw,
              tiptype_ipw, rotate_ipw, orbstip_ipw, fwhmtip_ipw,
              stip_ipw, pytip_ipw, pztip_ipw, pxtip_ipw],
}

# Holds the post-processing code selector(s) of the selected SPM type.
spm_code_box = ipw.HBox([])


def _reset_scan_z_widgets(zmin_value, zmin_min, zmin_max, zmax_min, zmax_max, zmax_value):
    """Reset the value, bounds and step of the scan z_min / z_max widgets.

    The assignment order is deliberate: bounded widgets clamp (or reject) their value
    whenever min/max change, so reordering could change the resulting values.
    """
    scanminz_floattxt.value = zmin_value
    scanminz_floattxt.min = zmin_min
    scanminz_floattxt.max = zmin_max
    scanminz_floattxt.step = 0.1
    scanmaxz_floattxt.min = zmax_min
    scanmaxz_floattxt.max = zmax_max
    scanmaxz_floattxt.step = 0.1
    scanmaxz_floattxt.value = zmax_value


def on_spm_change(c):
    """Show the parameter widgets, code selectors and submit button of the selected SPM type.

    Also resets type-specific defaults. ``c`` is the change notification from
    ``drop_spm_type.observe`` (unused; the current dropdown value is read directly).
    """
    spm_type = drop_spm_type.value
    spm_parameters_box.children = widgets[spm_type]
    if spm_type == 'STM':
        spm_code_box.children = [stm_code]
        submit_buttons.children = [btn_submit_stm]
        fwhms_text.value = '0.08'
        height_text.description = 'Const. H (ang)'
        height_text.value = '4.0 6.0'
        const_current_text.description = 'Const. cur. (isoval)'
    elif spm_type == 'ORBITALS':
        spm_code_box.children = [stm_code]
        submit_buttons.children = [btn_submit_orb]
        fwhms_text.value = '0.04'
        height_text.description = 'Heights (ang)'
        height_text.value = '3.0 5.0'
        const_current_text.description = 'Isovalues'
    elif spm_type == 'AFM':
        spm_code_box.children = [ppafm_code, twopp_code]
        submit_buttons.children = [btn_submit_afm]
        _reset_scan_z_widgets(3.5, 0.0, 5.0, 5.0, 10.0, 8.5)
    elif spm_type == 'HRSTM':
        spm_code_box.children = [hrstm_code, twopp_code]
        submit_buttons.children = [btn_submit_hrstm]
        _reset_scan_z_widgets(4.5, 3.0, 10.0, 3.0, 10.0, 7.5)
    # The total charge is reset to 0 on every switch and is editable only for ORBITALS.
    charge_text.value = 0
    charge_text.disabled = spm_type != 'ORBITALS'


btn_submit_stm = awb.SubmitButtonWidget(Cp2kStmWorkChain, inputs_generator=get_builder)
btn_submit_afm = awb.SubmitButtonWidget(Cp2kAfmWorkChain, inputs_generator=get_builder)
btn_submit_orb = awb.SubmitButtonWidget(Cp2kOrbitalsWorkChain, inputs_generator=get_builder)
btn_submit_hrstm = awb.SubmitButtonWidget(Cp2kHrstmWorkChain, inputs_generator=get_builder)
submit_buttons = ipw.HBox([])

# Observe only 'value': without `names`, the handler also fired for the dropdown's
# 'index' and 'label' updates, i.e. several times per selection.
drop_spm_type.observe(on_spm_change, names='value')

# Populate the widgets for the initial selection.
on_spm_change(None)


In [None]:
# Resource estimation, linked to the `resources` widget so that (presumably) the
# estimate is applied there.
resources_estimation = ResourcesEstimatorWidget()
resources_estimation.link_to_resources_widget(resources)

# One-way links feeding the estimator: structure details from the viewer, the spin
# toggle and the selected CP2K code. Binding the last link to `_` hides its repr.
ipw.dlink((empa_viewer, 'details'), (resources_estimation, 'details'))
ipw.dlink((uks_switch, 'value'), (resources_estimation, 'uks'))
_ = ipw.dlink(
    (cp2k_code, 'value'),
    (resources_estimation, 'selected_code'),
)

In [None]:
# Submission panel: resources, estimator, CP2K code, protocol/ELPA options and the
# SPM code selectors, then the submit button of the selected SPM type and its messages.
submission_settings = [resources, resources_estimation, cp2k_code, protocol, elpa_check, spm_code_box]
display(
    ipw.VBox(submission_settings),
    submit_buttons,
    submit_out,
)