# Submit PDOS calculation

In [None]:
%%javascript
// Override the classic-notebook output-area scroll check so it always answers
// "do not scroll": long widget outputs are then shown in full instead of being
// collapsed into a scrolled box.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

In [None]:
# General imports.
import ipywidgets as ipw
from IPython.display import display, clear_output

# AiiDA imports.
%load_ext aiida
%aiida
from aiida import orm, plugins

# AiiDAlab imports.
import aiidalab_widgets_base as awb

# Custom imports.
from surfaces_tools.widgets import empa_viewer, inputs, fragments, stack

# Work chain class registered under the 'nanotech_empa.cp2k.pdos' entry point;
# its builder is filled by `get_builder` and submitted by the button further down.
Cp2kPdosWorkChain = plugins.WorkflowFactory('nanotech_empa.cp2k.pdos')

In [None]:
# Structure selector.

# Import the viewer module under a private alias. The instance below reuses the
# name `empa_viewer` (later cells read it), which would otherwise shadow the
# module imported in the first cell and make this cell fail with AttributeError
# when it is re-run.
from surfaces_tools.widgets import empa_viewer as _empa_viewer_module

# Viewer shared with the structure manager; later cells read its `details`
# trait directly and its `selection` trait via `structure_selector.viewer`.
empa_viewer = _empa_viewer_module.EmpaStructureViewer()
structure_selector = awb.StructureManagerWidget(
    viewer=empa_viewer,
    importers=[
        awb.StructureUploadWidget(title="Import from computer"),
        awb.StructureBrowserWidget(title="AiiDA database"),
    ],
    storable=False,
    node_class='StructureData',
)

In [None]:
# Structure fragments.

# Fragment list (adding fragments is hidden). Its `selection_string` mirrors the
# atoms currently selected in the viewer, rendered as a compact range string.
fragment_list = fragments.FragmentList(add_fragment_visibility='hidden')
_ = ipw.dlink(
    (structure_selector.viewer, 'selection'),
    (fragment_list, 'selection_string'),
    transform=awb.utils.list_to_string_range,
)

In [None]:
# PDOS selection.

class PdosSelectionWidget(stack.HorizontalItemWidget):
    """One PDOS projection row: a free-text label plus an atom-range selection."""

    def __init__(self, label="", indices="1..2"):
        """Build the row.

        Args:
            label: Name of the projection, shown in the "Label:" field.
            indices: Atom selection as a range string, e.g. "1..2".
        """
        def text_field(description, value):
            # Every field gets its own style and layout dicts.
            return ipw.Text(
                description=description,
                value=value,
                style={'description_width': 'initial'},
                layout={'width': '250px'},
            )

        # Created in the original order: selection first, then label.
        self.selection = text_field('Atoms selection:', indices)
        self.label = text_field('Label:', label)
        # Children order: label, then selection.
        super().__init__(children=[self.label, self.selection])

class PdosListWidget(stack.VerticalStackWidget):
    """Stack of PDOS projection rows; new rows are pre-filled from the viewer selection."""

    def add_item(self, _):
        """Append a projection row whose atom selection is the viewer's current selection.

        The single argument (presumably the add-button click event) is unused.
        """
        current_selection = structure_selector.viewer.selection
        new_row = self.item_class(indices=awb.utils.list_to_string_range(current_selection))
        self.items += (new_row,)


projections = PdosListWidget(item_class=PdosSelectionWidget, add_button_text="Add projection")

def update_fragments_and_projection(change):
    """Reset fragments and PDOS projections whenever the selected structure changes.

    For a new structure, two fragments are created: "all" (every atom) and
    "molecule" (the first molecule reported by the viewer). A single projection
    on that molecule is created as well. When the structure is cleared, both
    lists are emptied.
    """
    structure = change['new']
    if not structure:
        fragment_list.fragments = []
        projections.items = []
        return

    # Best effort: the viewer may not have identified any molecule, in which
    # case the molecule selection is left empty.
    try:
        first_molecule = empa_viewer.details['all_molecules'][0]
        molecule_indices = awb.utils.list_to_string_range(first_molecule)
    except Exception:
        molecule_indices = ""
        print("Unable to automatically identify the molecule")

    fragment_list.fragments = [
        fragments.Fragment(indices=f"1..{len(structure)}", name="all"),
        fragments.Fragment(indices=molecule_indices, name="molecule"),
    ]
    projections.items = [PdosSelectionWidget(label="molecule", indices=molecule_indices)]


structure_selector.observe(update_fragments_and_projection, names='structure')

In [None]:
# DFT parameters.

# Shared description style for the DFT widgets in this cell.
style = {'description_width': 'initial'}

# Calculation protocol, forwarded to the work chain as `protocol`.
protocol = ipw.Dropdown(
    value="standard",
    options=[("Standard", "standard"), ("Low accuracy", "low_accuracy"),('Debug','debug')],
    description="Protocol:",
    )

# Spin-polarization (UKS) settings. Charge and multiplicity inputs are hidden
# here because `get_builder` reads them per fragment from the fragment list.
uks_widget = inputs.UksSectionWidget(charge_visibility='hidden', multiplicity_visibility='hidden')
# Feed the viewer's structure details into the UKS widget, and propagate the
# UKS flag to the fragment list (one-directional links).
ipw.dlink((empa_viewer, 'details'), (uks_widget, 'details'))
ipw.dlink((uks_widget, 'uks'), (fragment_list, 'uks'))

# Disabled and not displayed by any cell in this notebook, so it keeps its
# default False unless set programmatically; submitted as `sc_diag`.
sc_diag_check = ipw.Checkbox(
    value=False,
    description='self-consistent diagonalization',
    disabled=True
)
# Starts disabled; `enable_smearing` unlocks it while smearing is enabled.
# Only submitted when UKS is active.
force_multiplicity_check = ipw.Checkbox(
    value=True,
    description='Fix multiplicity',
    disabled=True
)

def enable_smearing(b):
    """Lock or unlock the smearing-dependent widgets to match the smearing checkbox.

    The argument (the traitlets change notification) is ignored; the state is
    read from `smear_switch` directly.
    """
    locked = not smear_switch.value
    for dependent_widget in (temperature_text, force_multiplicity_check):
        dependent_widget.disabled = locked

# Fermi-Dirac smearing toggle. Ticking it unlocks the temperature field and the
# fix-multiplicity checkbox via `enable_smearing`.
smear_switch = ipw.Checkbox(
    value=False,
    description='Enable Fermi-Dirac smearing',
    style=style,
)
smear_switch.observe(enable_smearing, names='value')

# Smearing temperature; `get_builder` submits it as `smear_t` only when
# smearing is enabled. It starts locked to match the unticked switch.
temperature_text = ipw.FloatText(
    value=150.0,
    description='Temperature [K]',
    disabled=True,
    style=style,
)

In [None]:
# Orbital overlap parameters.

# Orbital-overlap settings, passed to the overlap code as command-line arguments.
style = {'description_width': 'initial'}
layout_small = {'width': '300px'}

# Number of molecular HOMOs / LUMOs (--nhomo2 / --nlumo2).
molecule_n_homo = ipw.IntText(value=4, description="Molecule # HOMO:", style=style, layout=layout_small)
molecule_n_lumo = ipw.IntText(value=4, description="Molecule # LUMO:", style=style, layout=layout_small)

# Slab energy window in eV (--emin1 / --emax1); Emin is bounded to [-3, -0.1]
# and Emax to [0.1, 3].
slab_emin = ipw.BoundedFloatText(
    description='Slab Emin (eV):', min=-3.0, max=-0.1, step=0.1, value=-2.0,
    style=style, layout=layout_small,
)
slab_emax = ipw.BoundedFloatText(
    description='Slab Emax (eV):', min=0.1, max=3.0, step=0.1, value=2.0,
    # NOTE(review): BoundedFloatText may not define `readout_format` (it is a
    # slider attribute); kept as-is — confirm against the ipywidgets version.
    readout_format='%.2f',
    style=style, layout=layout_small,
)

In [None]:
# Codes and computational resources.
# Code selectors: CP2K for the DFT runs, and the overlap code for the orbital
# overlap step.
cp2k_code = awb.ComputationalResourcesWidget(
    description="CP2K code:",
    default_calc_job_plugin="cp2k",
)
overlap_code = awb.ComputationalResourcesWidget(
    description="Overlap code:",
    default_calc_job_plugin="nanotech_empa.overlap",
)
# Submitted as `dft_params['elpa_switch']`.
elpa_check = ipw.Checkbox(value=True, description='use ELPA', disabled=False)

def update_resources_for_fragments(_):
    """Ask every fragment to estimate its computational resources.

    Each estimate is based on the full selected structure and the chosen CP2K
    code. The argument (the button-click event) is unused.
    """
    for frag in fragment_list.fragments:
        frag.estimate_computational_resources(
            whole_structure=structure_selector.structure,
            selected_code=cp2k_code.value,
        )


estimate_resources_button = ipw.Button(
    description="Estimate resources",
    button_style='warning',
)
estimate_resources_button.on_click(update_resources_for_fragments)

In [None]:
# Free-text description stored as the work chain's `metadata.description`.
workflow_description = ipw.Text(
    description='Workflow description:',
    placeholder="Provide the description here.",
    style={"description_width": "initial"},
    layout={"width": "70%"},
    )

In [None]:
# Output area for validation messages printed while building the inputs.
submit_out = ipw.Output()


def _fragment_options(resources):
    """Translate a fragment's estimated resources into AiiDA calculation options."""
    return {
        "max_wallclock_seconds": resources.walltime_seconds,
        "resources": {
            "num_machines": resources.nodes,
            "num_mpiprocs_per_machine": resources.tasks_per_node,
            "num_cores_per_mpiproc": resources.threads_per_task,
        },
    }


def _dft_parameters():
    """Collect the DFT parameters from the widgets as a plain dict."""
    params = {
        'elpa_switch': elpa_check.value,
        'uks': uks_widget.uks,
        'sc_diag': sc_diag_check.value,
        'periodic': 'XYZ',
        "charges": {fragment.name.value: fragment.charge.value for fragment in fragment_list.fragments},
    }

    # Spin settings are only meaningful for unrestricted (UKS) calculations.
    if uks_widget.uks:
        params.update({
            'magnetization_per_site': uks_widget.return_dict()["dft_params"]["magnetization_per_site"],
            'multiplicities': {fragment.name.value: fragment.multiplicity.value for fragment in fragment_list.fragments},
            'force_multiplicity': force_multiplicity_check.value,
        })

    if smear_switch.value:
        params['smear_t'] = temperature_text.value

    return params


def _overlap_parameters():
    """Command-line arguments for the overlap code.

    Arguments suffixed `1` point into the parent slab calculation folder, those
    suffixed `2` into the parent molecule calculation folder.
    """
    return {
        '--cp2k_input_file1':  'parent_slab_folder/aiida.inp',
        '--basis_set_file1':   'parent_slab_folder/BASIS_MOLOPT',
        '--xyz_file1':         'parent_slab_folder/aiida.coords.xyz',
        '--wfn_file1':         'parent_slab_folder/aiida-RESTART.wfn',
        '--emin1':             str(slab_emin.value),
        '--emax1':             str(slab_emax.value),
        '--cp2k_input_file2':  'parent_mol_folder/aiida.inp',
        '--basis_set_file2':   'parent_mol_folder/BASIS_MOLOPT',
        '--xyz_file2':         'parent_mol_folder/aiida.coords.xyz',
        '--wfn_file2':         'parent_mol_folder/aiida-RESTART.wfn',
        '--nhomo2':            str(molecule_n_homo.value),
        '--nlumo2':            str(molecule_n_lumo.value),
        '--output_file':       './overlap.npz',
        '--eval_region':       ['G', 'G', 'G', 'G', 'n-3.0_C', 'p2.0'],
        '--dx':                '0.2',
        '--eval_cutoff':       '14.0',
    }


def get_builder():
    """Build the `Cp2kPdosWorkChain` builder from the current widget state.

    `submit_out` is cleared first, and validation messages are printed into it.

    Returns:
        The populated process builder, or None when the inputs are incomplete.
        That happens when no structure or code is selected, when the fragments
        are missing, or when the molecule or a projection selects no atoms.
    """
    with submit_out:
        clear_output()
        if structure_selector.structure is None:
            print("Please select a structure.")
            return None

        if cp2k_code.value is None or overlap_code.value is None:
            print("Please select codes.")
            return None

        # `update_fragments_and_projection` creates the fragments in a fixed
        # order: [0] the whole system (used for the 'slab' options) and
        # [1] the molecule.
        fragment_entries = fragment_list.fragments
        if len(fragment_entries) < 2:
            print("The 'all' and 'molecule' fragments are missing; please reselect the structure.")
            return None
        slab_fragment, molecule_fragment = fragment_entries[0], fragment_entries[1]

        # The first element returned by `string_range_to_list` is the index list.
        molecule_indices = awb.utils.string_range_to_list(molecule_fragment.indices.value)[0]
        if not molecule_indices:
            print("Please specify the atoms of the molecule fragment.")
            return None

        pdos_lists = [(p.selection.value, p.label.value) for p in projections.items]
        if any(not selection.strip() for selection, _label in pdos_lists):
            print("Please specify the atoms for every projection.")
            return None

        builder = Cp2kPdosWorkChain.get_builder()
        builder.metadata.label = "CP2K_PDOS"
        builder.metadata.description = workflow_description.value
        builder.protocol = orm.Str(protocol.value)
        builder.cp2k_code = orm.load_node(cp2k_code.value)
        builder.structure = structure_selector.structure_node
        builder.molecule_indices = orm.List(molecule_indices)
        builder.pdos_lists = orm.List(pdos_lists)
        builder.dft_params = orm.Dict(dict=_dft_parameters())
        builder.overlap_code = orm.load_node(overlap_code.value)
        builder.overlap_params = orm.Dict(dict=_overlap_parameters())

        # TODO: restart from an available wave function once implemented, e.g.
        # wave_function = wfn.structure_available_wfn(
        #     node=structure_selector.structure_node,
        #     relative_replica_id=None,
        #     current_hostname=builder.cp2k_code.computer.hostname,
        #     return_path=False,
        #     dft_params=builder.dft_params.get_dict(),
        # )
        # if wave_function is not None:
        #     print(f"Restarting from wfn in folder: {wave_function.pk}")
        #     builder.parent_calc_folder = wave_function

        builder.options = {
            'slab': _fragment_options(slab_fragment.resources),
            'molecule': _fragment_options(molecule_fragment.resources),
        }

    return builder


btn_submit = awb.SubmitButtonWidget(Cp2kPdosWorkChain, inputs_generator=get_builder)

# Select structure

In [None]:
# Structure manager: import from the computer or the AiiDA database.
display(structure_selector)    

# Select parts of the system for the projection

In [None]:
# PDOS projections first, then the fragments whose charges, multiplicities and
# resources are read by `get_builder`.
display(projections, fragment_list)


# DFT parameters

In [None]:
# Protocol, UKS settings, and the smearing controls on one row
# (`sc_diag_check` is intentionally not shown).
display(protocol, uks_widget, ipw.HBox([smear_switch, temperature_text, force_multiplicity_check]))

# Orbital overlap parameters

In [None]:
# Two columns: molecule HOMO/LUMO counts on the left, slab energy window on the right.
display(ipw.HBox([
    ipw.VBox([molecule_n_homo, molecule_n_lumo], layout={'width': '45%'}),
    ipw.VBox([slab_emin, slab_emax], layout={"width": "45%"})
    ],
    layout={'width': '100%'})
    )

# Codes and resources

In [None]:
# Code selectors, ELPA toggle, and the button that fills in per-fragment resources.
display(cp2k_code, overlap_code, elpa_check, estimate_resources_button)

# Submission

In [None]:
# Description, submit button, and the output area for validation messages.
display(workflow_description, btn_submit, submit_out)