In [None]:
%%javascript
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

# Submit adsorption energy

In [None]:
# General imports.
import ipywidgets as ipw

# AiiDA imports.
%load_ext aiida
%aiida
from aiida import orm, plugins

# AiiDAlab imports.
import aiidalab_widgets_base as awb

# Custom imports.
from surfaces_tools.widgets import empa_viewer, fragments, inputs
from surfaces_tools.utils import wfn

# Resolve the AiiDA data and workflow classes through their entry points.
StructureData = plugins.DataFactory("core.structure")
Cp2kFragmentSeparationWorkChain = plugins.WorkflowFactory("nanotech_empa.cp2k.fragment_separation")

In [None]:
# Structure selector.
# The widget instance below reuses the name `empa_viewer` (later cells link to it),
# which shadows the module imported at the top. Re-import the module under an
# alias so this cell keeps working when it is re-run.
from surfaces_tools.widgets import empa_viewer as empa_viewer_module

empa_viewer = empa_viewer_module.EmpaStructureViewer()
structure_selector = awb.StructureManagerWidget(
    viewer=empa_viewer,
    importers=[
        awb.StructureBrowserWidget(title="AiiDA database"),
    ],
    editors=[
        awb.BasicStructureEditor(title="Edit structure"),
    ],
    storable=False,
    node_class="StructureData",
)


In [None]:
# Other simulation settings: atoms to keep fixed, accuracy protocol and UKS options.
fixed_atoms = ipw.Text(
    description="Fixed atoms",
    placeholder="3 4 14..17",
    value="",
    disabled=False,
)

# Protocol: (label, value) pairs offered in the dropdown.
protocol = ipw.Dropdown(
    description="Protocol:",
    value="standard",
    disabled=False,
    options=[
        ("Standard", "standard"),
        ("Low accuracy", "low_accuracy"),
        ("Debug", "debug"),
    ],
)

# Spin (UKS) settings. Charge and multiplicity widgets are hidden here because
# they are read per fragment when the builder is assembled.
uks_widget = inputs.UksSectionWidget(
    charge_visibility="hidden",
    multiplicity_visibility="hidden",
)
# Forward the viewer's structure details to the UKS widget.
_ = ipw.dlink((empa_viewer, "details"), (uks_widget, "details"))


In [None]:
# Structure fragments.
fragment_list = fragments.FragmentList()

# Mirror the UKS toggle into the fragment list.
ipw.dlink((uks_widget, "uks"), (fragment_list, "uks"))

# Mirror the viewer's atom selection into the fragment list as a range string.
_ = ipw.dlink(
    (structure_selector.viewer, "selection"),
    (fragment_list, "selection_string"),
    transform=awb.utils.list_to_string_range,
)

In [None]:
# Coordinate structure selector with spin up/down and fragments.
def update_selection(change):
    """Reset the fragment list whenever the selected structure changes.

    A non-empty structure gets a single default fragment named "all" whose
    indices span 1..N, where N is ``len(structure)``; an empty or cleared
    structure empties the list.
    """
    new_structure = change["new"]
    if not new_structure:
        fragment_list.fragments = []
        return

    # Add default "all" fragment.
    all_atoms = f"1..{len(new_structure)}"
    fragment_list.fragments = [fragments.Fragment(indices=all_atoms, name="all")]


structure_selector.observe(update_selection, names="structure")

In [None]:
# Resources estimation.
# NOTE(review): MAX_NODES is not referenced by any visible cell; kept in case
# something else relies on it.
MAX_NODES = 48


def update_resources_for_fragments(_):
    """Estimate computational resources for every fragment.

    Each fragment estimates its resources from the whole selected structure and
    the code currently chosen in `computational_resources`. That widget is
    created in a later cell; it only has to exist by the time the button is clicked.

    Args:
        _: the clicked button passed by ``Button.on_click``; unused.
    """
    for frag in fragment_list.fragments:
        frag.estimate_computational_resources(
            whole_structure=structure_selector.structure,
            selected_code=computational_resources.value,
        )


estimate_resources_button = ipw.Button(
    description="Estimate resources",
    button_style="warning",
)
estimate_resources_button.on_click(update_resources_for_fragments)

In [None]:
# Code selector: the CP2K code used for every fragment calculation.
computational_resources = awb.ComputationalResourcesWidget(
    description="CP2K code:",
    default_calc_job_plugin="cp2k",
)

In [None]:
# Workchain submission.


def _parse_index_string(text, what):
    """Convert a range string such as "3 4 14..17" into a list of indices.

    `awb.utils.string_range_to_list` returns an ``(indices, is_valid)`` pair.
    Previously only the first element was used, so a malformed string silently
    turned into an empty list. Here the flag is checked instead.

    Args:
        text: the user-provided range string.
        what: short description of the field, used in the error message.

    Returns:
        The list of indices produced by `string_range_to_list`.

    Raises:
        ValueError: if the string cannot be parsed.
    """
    indices, is_valid = awb.utils.string_range_to_list(text)
    if not is_valid:
        raise ValueError(f"Could not parse {what}: {text!r}")
    return indices


def get_builder():
    """Get the builder for the adsorption energy calculation.

    Collects the following from the widgets above:
    - the selected structure and CP2K code;
    - per fragment: indices, charges, resources and, with UKS, multiplicities;
    - the fixed atoms and the protocol.

    If the structure node is stored and a matching wavefunction is found on the
    code's computer, that wavefunction is set as ``parent_calc_folder`` for a
    restart.

    Raises:
        ValueError: if no structure or code is selected, fragment names are not
            unique, or an index string cannot be parsed.
    """
    structure = structure_selector.structure_node
    if structure is None:
        raise ValueError("Please select a structure before submitting.")
    if not computational_resources.value:
        raise ValueError("Please select a CP2K code before submitting.")

    # Fragments are keyed by name below; duplicates would silently drop fragments.
    names = [fragment.name.value for fragment in fragment_list.fragments]
    duplicated = sorted({name for name in names if names.count(name) > 1})
    if duplicated:
        raise ValueError(f"Fragment names must be unique; duplicated: {', '.join(duplicated)}")

    builder = Cp2kFragmentSeparationWorkChain.get_builder()
    builder.code = orm.load_code(computational_resources.value)
    builder.structure = structure
    builder.metadata.label = "CP2K_AdsorptionE"

    # Fragments' indices.
    builder.fragments = {
        fragment.name.value: orm.List(
            list=_parse_index_string(
                fragment.indices.value, f"indices of fragment '{fragment.name.value}'"
            )
        )
        for fragment in fragment_list.fragments
    }

    # Fragments' charges.
    charges = {fragment.name.value: fragment.charge.value for fragment in fragment_list.fragments}

    # Resources.
    builder.options = {
        fragment.name.value: {
            "max_wallclock_seconds": fragment.resources.walltime_seconds,
            "resources": {
                "num_machines": fragment.resources.nodes,
                "num_mpiprocs_per_machine": fragment.resources.tasks_per_node,
                "num_cores_per_mpiproc": fragment.resources.threads_per_task,
            },
        }
        for fragment in fragment_list.fragments
    }

    dft_params = {}

    # UKS
    if uks_widget.uks:
        multiplicities = {fragment.name.value: fragment.multiplicity.value for fragment in fragment_list.fragments}
        dft_params.update({
            "uks": True,
            "magnetization_per_site": uks_widget.return_dict()["dft_params"]["magnetization_per_site"],
            "multiplicities": multiplicities,
        })

    builder.fixed_atoms = orm.List(list=_parse_index_string(fixed_atoms.value, "fixed atoms"))
    builder.protocol = orm.Str(protocol.value)

    dft_params.update({
        "charges": charges,
        "vdw": True,
        "periodic": "XYZ",
    })
    builder.dft_params = orm.Dict(dft_params)

    # Check if a restart wfn is available (only possible for stored structures).
    wave_function = None
    if structure.is_stored:
        wave_function = wfn.structure_available_wfn(
            node=structure,
            relative_replica_id=None,
            current_hostname=builder.code.computer.hostname,
            return_path=False,
            dft_params=dft_params,
        )
    if wave_function is not None:
        print(f"Restarting from wfn in folder: {wave_function.pk}")
        builder.parent_calc_folder = wave_function

    return builder


btn_submit_ads = awb.SubmitButtonWidget(Cp2kFragmentSeparationWorkChain, inputs_generator=get_builder)

In [None]:
# Show the structure selector (AiiDA database browser, editor and viewer).
display(structure_selector)

# Fragments

In [None]:
# Show the fragment list; a default "all" fragment appears once a structure is selected.
display(fragment_list)

# DFT settings

In [None]:
# Show the DFT settings: UKS options, protocol and fixed atoms.
display(uks_widget, protocol, fixed_atoms)

# Codes and resources

In [None]:
# Show the code selector next to the resource-estimation button
# (the button defined above is `estimate_resources_button`).
display(ipw.HBox([computational_resources, estimate_resources_button]))

# Submission

In [None]:
# Show the submit button; clicking it calls `get_builder` to assemble the work chain inputs.
display(btn_submit_ads)