In [None]:
%%javascript
// Override the classic-Notebook output-area hook so it always returns false,
// i.e. long cell outputs (such as the widget forms below) are never collapsed
// into a scrollable box. NOTE(review): `IPython.OutputArea` only exists in the
// classic Notebook front end; presumably this cell is a no-op elsewhere.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

In [None]:
# General imports.
import ipywidgets as ipw
from IPython.display import display, clear_output

# AiiDA imports.
%load_ext aiida
%aiida
from aiida import orm

# AiiDAlab imports.
import aiidalab_widgets_base as awb

# Custom imports.
from surfaces_tools.widgets.computational_resources import ProcessResourcesWidget, ResourcesEstimatorWidget
from surfaces_tools.widgets.empa_viewer import EmpaStructureViewer
from surfaces_tools.utils import spm, wfn


# Import WorkflowFactory explicitly: otherwise the name is only available
# because the `%aiida` magic injects it into the notebook namespace.
from aiida.plugins import WorkflowFactory

Cp2kPdosWorkChain = WorkflowFactory('nanotech_empa.cp2k.pdos')

# Select structure

In [None]:
# Structure selector.
empa_viewer = EmpaStructureViewer()

# Structures can be uploaded from a local file or picked from the AiiDA
# database; no editors are offered and `storable=False` is passed through.
structure_importers = [
    awb.StructureUploadWidget(title="Import from computer"),
    awb.StructureBrowserWidget(title="AiiDA database"),
]
structure_selector = awb.StructureManagerWidget(
    viewer=empa_viewer,
    importers=structure_importers,
    editors=[],
    storable=False,
    node_class='StructureData',
)
display(structure_selector)

In [None]:
# Free-text description of this calculation, entered by the user.
description_layout = {'width': '45%'}
text_calc_description = ipw.Text(description='Description:', layout=description_layout)
display(text_calc_description)

# Select molecule

For the automatic molecule detection to work, the molecule's atoms must come first in the structure, i.e. before the slab atoms.

In [None]:
# PDOS selection addition and removal
def remove_from_tuple(tup, index):
    """Return a copy of `tup` without the element at position `index`.

    Negative indices count from the end; an out-of-range index raises
    IndexError.
    """
    items = [*tup]
    items.pop(index)
    return (*items,)

def rm_pdos_selection(b):
    """Remove the extra PDOS selection row whose 'remove' button `b` was clicked.

    Bookkeeping layout of the shared containers:
      * sel_text_list[0] / sel_rm_list[0] belong to the molecule selection,
        which has no remove button (``None``), so extra rows start at index 1;
      * sel_ll_list and sel_vbox.children hold only the extra rows, so they
        start at index 0, one position lower than the two lists above.
    """
    rm_index = sel_rm_list.index(b)
    del sel_text_list[rm_index]
    del sel_rm_list[rm_index]
    # sel_ll_list has no placeholder for the molecule selection, so it must be
    # indexed like sel_vbox.children (rm_index - 1). Using rm_index removed the
    # wrong label, or raised IndexError when only one extra row existed.
    del sel_ll_list[rm_index - 1]
    sel_vbox.children = remove_from_tuple(sel_vbox.children, rm_index - 1)

def add_pdos_selection(b):
    """Append a new PDOS selection row (selection text, label, remove button).

    The new widgets are registered in the parallel lists sel_text_list,
    sel_ll_list and sel_rm_list and shown as a new HBox in sel_vbox.
    `b` is the clicked button and is not used.
    """
    # Take the description width from the molecule selection rather than the
    # global `style`, which later cells rebind; with the global, the row
    # layout depended on which cell happened to run last.
    selection_txt = ipw.Text(
        description='pdos selection',
        style={'description_width': mol_selection.style.description_width},
        layout={'width': '30%'},
    )
    ll_txt = ipw.Text(description='label', layout={'width': '20%'})
    rm_btn = ipw.Button(description='remove', layout={'width': '10%'})
    rm_btn.on_click(rm_pdos_selection)

    sel_text_list.append(selection_txt)
    sel_ll_list.append(ll_txt)
    sel_rm_list.append(rm_btn)
    sel_vbox.children += (ipw.HBox([selection_txt, ll_txt, rm_btn]),)

In [None]:
def guess_molecule():
    """Pre-fill `mol_selection` with the first molecule found by the viewer.

    Reads ``empa_viewer.details['all_molecules'][0]``, converts it to a range
    string and writes it into the molecule selection text box. If the viewer
    provides no usable molecule information, a message is printed and the
    current selection is left untouched.
    """
    # The original wrote to the undefined `mol.selection`; the resulting
    # NameError was swallowed by a bare `except:`, so detection never worked.
    # Only the expected "no molecule info" errors are caught now.
    try:
        molecule_indices = empa_viewer.details['all_molecules'][0]
        selection = awb.utils.list_to_string_range(molecule_indices)
    except (KeyError, IndexError, TypeError, ValueError):
        print("Unable to automatically find slab and molecule")
        return
    mol_selection.value = selection

def get_mol_ase():
    """Return the molecule part of the selected structure as an ASE Atoms slice.

    The atom indices come from the range string in `mol_selection`.
    """
    parsed_selection = awb.utils.string_range_to_list(mol_selection.value)
    molecule_indices = parsed_selection[0]
    return structure_selector.structure[molecule_indices]
    
def reset_state():
    """Reset the molecule selection and drop all extra PDOS selection rows.

    NOTE(review): not called from any visible cell.
    """
    mol_selection.value = "1..10"
    # Mutate the shared containers in place. The original rebound the names,
    # which only created locals, and replaced `sel_vbox` with a new VBox that
    # is never displayed, so nothing was actually reset.
    sel_text_list[:] = [mol_selection]
    sel_rm_list[:] = [None]
    sel_ll_list.clear()
    sel_vbox.children = ()
    
def on_struct_change(c):
    """Structure-change observer: re-run the automatic molecule detection."""
    # The change payload `c` is not needed; guess_molecule reads the viewer.
    guess_molecule()


structure_selector.observe(handler=on_struct_change, names='value')


In [None]:
style = {'description_width': '120px'}
layout = {'width': '60%'}

# Main molecule selection, given as an atom range string such as "1..10".
mol_selection = ipw.Text(
    description='Molecule selection',
    value="1..10",
    style=style,
    layout={'width': '30%'},
)

# Button that appends an extra PDOS selection row to `sel_vbox`.
add_sel_btn = ipw.Button(description='Add pdos selection', layout={'width': '15%'})
add_sel_btn.on_click(add_pdos_selection)

# Parallel bookkeeping: index 0 of sel_text_list/sel_rm_list is the molecule
# selection (it has no remove button, hence `None`); sel_ll_list and
# sel_vbox only hold the extra rows.
sel_text_list = [mol_selection]
sel_rm_list = [None]
sel_ll_list = []
sel_vbox = ipw.VBox([])

display(ipw.HBox([mol_selection]), sel_vbox, add_sel_btn)

# Select computer and codes

In [None]:
# Code selector
cp2k_code = awb.ComputationalResourcesWidget(
    description="CP2K code:", default_calc_job_plugin="cp2k"
)
overlap_code = awb.ComputationalResourcesWidget(
    description="Overlap code:", default_calc_job_plugin="nanotech_empa.overlap"
)

# Scheduler resources: `resources` for the slab system, `resources_mol`
# (12 nodes by default) for the molecule calculation.
resources = ProcessResourcesWidget()
resources_mol = ProcessResourcesWidget()
resources_mol.nodes_widget.value = 12

style = {'description_width': '140px'}
layout = {'width': '70%'}

elpa_check = ipw.Checkbox(value=True, description='use ELPA', disabled=False)

display(cp2k_code, overlap_code, elpa_check)

# DFT parameters

In [None]:
# Widget styling shared by the DFT-parameter widgets below.
style, layout = {'description_width': '80px'}, {'width': '70%'}

def enable_spin(b):
    """Enable the spin-related widgets only while `uks_switch` is on.

    `b` is the traitlets change payload and is not used.
    """
    spin_disabled = not uks_switch.value
    spin_widgets = (
        spin_up_text,
        spin_dw_text,
        multiplicity_text,
        sc_diag_check,
        force_multiplicity_check,
    )
    for widget in spin_widgets:
        widget.disabled = spin_disabled
    # NOTE(review): enable_smearing also toggles force_multiplicity_check, so
    # its enabled state follows whichever switch changed last; confirm intent.

uks_switch = ipw.ToggleButton(
    value=False,
    description='Spin-polarized calculation',
    style=style,
    layout={'width': '450px'},
)
uks_switch.observe(enable_spin, names='value')

# Spin guesses, multiplicity and related options; all start disabled and are
# enabled by `enable_spin` (force_multiplicity_check also by the smearing switch).
spin_up_text = ipw.Text(
    placeholder='1..5 7',
    description='Spin up',
    disabled=True,
    style=style,
    layout={'width': '370px'},
)
spin_dw_text = ipw.Text(
    placeholder='11..15 17',
    description='Spin down',
    disabled=True,
    style=style,
    layout={'width': '370px'},
)
multiplicity_text = ipw.IntText(
    value=1,
    description='Multiplicity',
    disabled=True,
    style=style,
    layout={'width': '20%'},
)
sc_diag_check = ipw.Checkbox(value=False, description='self-consistent diagonalization', disabled=True)
force_multiplicity_check = ipw.Checkbox(value=True, description='Fix multiplicity', disabled=True)

protocol = ipw.Dropdown(
    value="standard",
    options=[("Standard", "standard"), ("Low accuracy", "low_accuracy"), ('Debug', 'debug')],
    description="Protocol:",
    style={"description_width": "120px"},
)

spin_column = ipw.VBox([spin_up_text, spin_dw_text])
display(
    uks_switch,
    ipw.HBox([spin_column]),
    ipw.HBox([multiplicity_text, sc_diag_check]),
    protocol,
)

In [None]:
def enable_smearing(b):
    """Enable the smearing widgets only while `smear_switch` is on.

    `b` is the traitlets change payload and is not used.
    """
    smearing_off = not smear_switch.value
    temperature_text.disabled = smearing_off
    force_multiplicity_check.disabled = smearing_off

smear_switch = ipw.ToggleButton(
    value=False,
    description='Enable Fermi-Dirac smearing',
    style=style,
    layout={'width': '450px'},
)
smear_switch.observe(enable_smearing, names='value')

# Smearing temperature; editable only while smearing is enabled.
temperature_text = ipw.FloatText(
    value=150.0,
    description='Temperature [K]',
    disabled=True,
    style={'description_width': '100px'},
    layout={'width': '20%'},
)

smearing_row = ipw.HBox([temperature_text, force_multiplicity_check])
display(smear_switch, smearing_row)

# Orbital overlap parameters

In [None]:
style = {'description_width': '140px'}
layout_small = {'width': '30%'}

# Number of gas-phase HOMOs/LUMOs passed to the overlap code (--nhomo2 / --nlumo2).
num_homo_inttext = ipw.IntText(value=4, description="num gas HOMO", style=style, layout=layout_small)
num_lumo_inttext = ipw.IntText(value=4, description="num gas LUMO", style=style, layout=layout_small)

# NOTE(review): elim_inttext reuses the "num gas LUMO" label and is neither
# displayed nor read anywhere in this notebook; it looks like a leftover.
# Confirm before removing.
elim_inttext = ipw.IntText(value=4, description="num gas LUMO", style=style, layout=layout_small)

# Slab-system energy window passed to the overlap code (--emin1 / --emax1).
emin_floattext = ipw.BoundedFloatText(
    description='Slab system Emin (eV)',
    min=-3.0,
    max=-0.1,
    step=0.1,
    value=-2.0,
    style=style,
    layout=layout_small,
)
emax_floattext = ipw.BoundedFloatText(
    description='Slab system Emax (eV)',
    min=0.1,
    max=3.0,
    step=0.1,
    value=2.0,
    readout_format='%.2f',
    style=style,
    layout=layout_small,
)

display(num_homo_inttext, num_lumo_inttext, emin_floattext, emax_floattext)

# Submission

In [None]:
submit_out = ipw.Output()


def _overlap_params_dict():
    """Return the command-line arguments for the overlap code.

    System 1 is the slab calculation (files in ``parent_slab_folder``) and
    system 2 the molecule calculation (files in ``parent_mol_folder``); the
    energy window and orbital counts are read from the overlap widgets.
    """
    return {
        '--cp2k_input_file1':  'parent_slab_folder/aiida.inp',
        '--basis_set_file1':   'parent_slab_folder/BASIS_MOLOPT',
        '--xyz_file1':         'parent_slab_folder/aiida.coords.xyz',
        '--wfn_file1':         'parent_slab_folder/aiida-RESTART.wfn',
        '--emin1':             str(emin_floattext.value),
        '--emax1':             str(emax_floattext.value),
        '--cp2k_input_file2':  'parent_mol_folder/aiida.inp',
        '--basis_set_file2':   'parent_mol_folder/BASIS_MOLOPT',
        '--xyz_file2':         'parent_mol_folder/aiida.coords.xyz',
        '--wfn_file2':         'parent_mol_folder/aiida-RESTART.wfn',
        '--nhomo2':            str(num_homo_inttext.value),
        '--nlumo2':            str(num_lumo_inttext.value),
        '--output_file':       './overlap.npz',
        '--eval_region':       ['G', 'G', 'G', 'G', 'n-3.0_C', 'p2.0'],
        '--dx':                '0.2',
        '--eval_cutoff':       '14.0',
    }


def _process_options(resources_widget):
    """Build the scheduler options dict from a ProcessResourcesWidget."""
    return {
        "max_wallclock_seconds": resources_widget.walltime_seconds,
        "resources": {
            "num_machines": resources_widget.nodes,
            "num_mpiprocs_per_machine": resources_widget.tasks_per_node,
            "num_cores_per_mpiproc": resources_widget.threads_per_task,
        },
    }


def _parse_atom_range(text_widget):
    """Parse an atom range string like '1..5 7' into a list of atom indices.

    Uses the same helper (and default index shift) as the molecule selection.
    Returns None if the string is not a valid range expression.
    """
    # NOTE(review): assumes string_range_to_list returns (indices, is_valid)
    # with a default shift of -1 (1-based input -> 0-based indices), matching
    # the former `int(v) - 1` conversion.
    indices, is_valid = awb.utils.string_range_to_list(text_widget.value)
    return indices if is_valid else None


def get_builder():
    """Collect all widget values into a Cp2kPdosWorkChain builder.

    Problems with the input are reported in `submit_out`, and None is
    returned in that case (as in the original flow).
    """
    with submit_out:
        clear_output()
        if structure_selector.structure is None:
            print("Please select a structure.")
            return
        # Both codes are loaded below; a missing overlap code previously
        # failed later in `orm.load_node(None)`.
        if cp2k_code.value is None or overlap_code.value is None:
            print("Please select codes.")
            return

        # An invalid or empty selection would otherwise silently produce an
        # empty molecule structure.
        mol_ase = get_mol_ase()
        if len(mol_ase) == 0:
            print("Please provide a valid molecule selection.")
            return

        builder = Cp2kPdosWorkChain.get_builder()
        builder.metadata.label = "CP2K_PDOS"
        builder.metadata.description = text_calc_description.value
        builder.protocol = orm.Str(protocol.value)

        dft_params_dict = {
            'elpa_switch':        elpa_check.value,
            'uks':                uks_switch.value,
            'sc_diag':            sc_diag_check.value,
            'force_multiplicity': force_multiplicity_check.value,
            'periodic':           'XYZ',
        }
        if uks_switch.value:
            # The widgets advertise range syntax ('1..5 7'), which the former
            # per-token `int()` conversion rejected with a ValueError.
            spin_up_guess = _parse_atom_range(spin_up_text)
            spin_dw_guess = _parse_atom_range(spin_dw_text)
            if spin_up_guess is None or spin_dw_guess is None:
                print("Please provide valid spin up/down atom selections.")
                return
            dft_params_dict['spin_up_guess'] = spin_up_guess
            dft_params_dict['spin_dw_guess'] = spin_dw_guess
            dft_params_dict['multiplicity'] = multiplicity_text.value
        if smear_switch.value:
            dft_params_dict['smear_t'] = temperature_text.value
        dft_params = orm.Dict(dict=dft_params_dict)

        overlap_params = orm.Dict(dict=_overlap_params_dict())

        mol_structure = orm.StructureData(ase=mol_ase)
        mol_structure.description = f"Molecule generated from pdos from slab uuid: {structure_selector.structure_node.uuid}"

        # sel_text_list[0] is the molecule selection; sel_ll_list only holds
        # labels of the extra selections, hence the leading 'molecule' label.
        pdos_labels = ['molecule'] + [widget.value for widget in sel_ll_list]
        aiida_pdos_list = orm.List()
        aiida_pdos_list.extend(
            [(text.value, label) for text, label in zip(sel_text_list, pdos_labels)]
        )

        builder.cp2k_code = orm.load_node(cp2k_code.value)
        builder.slabsys_structure = structure_selector.structure_node
        builder.mol_structure = mol_structure
        builder.pdos_lists = aiida_pdos_list
        builder.dft_params = dft_params
        builder.overlap_code = orm.load_node(overlap_code.value)
        builder.overlap_params = overlap_params

        # Check if a restart wfn is available (only for stored structures).
        wave_function = None
        if structure_selector.structure_node.is_stored:
            wave_function = wfn.structure_available_wfn(
                node=structure_selector.structure_node,
                relative_replica_id=None,
                current_hostname=builder.cp2k_code.computer.hostname,
                return_path=False,
                dft_params=dft_params_dict,
            )
        if wave_function is not None:
            print(f"Restarting from wfn in folder: {wave_function.pk}")
            builder.parent_calc_folder = wave_function

        # Resources.
        builder.options = {
            'slab': _process_options(resources),
            'molecule': _process_options(resources_mol),
        }

    return builder


btn_submit = awb.SubmitButtonWidget(Cp2kPdosWorkChain, inputs_generator=get_builder)

In [None]:
# Resources estimation.
resources_estimation = ResourcesEstimatorWidget()
resources_estimation.link_to_resources_widget(resources)

# One-way links feeding the estimator with the structure details, the spin
# setting and the selected CP2K code.
for link_source, estimator_trait in (
    ((empa_viewer, 'details'), 'details'),
    ((uks_switch, 'value'), 'uks'),
    ((cp2k_code, 'value'), 'selected_code'),
):
    ipw.dlink(link_source, (resources_estimation, estimator_trait))

In [None]:
resources_row = ipw.HBox([resources, resources_mol])
display(ipw.VBox([resources_row, resources_estimation, btn_submit, submit_out]))