In [None]:
%%javascript
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

# Submit an adsorption energy calculation

In [None]:
# General imports.
from copy import copy
import ipywidgets as ipw
from IPython.display import clear_output

# AiiDA imports.
%load_ext aiida
from aiida.plugins import  WorkflowFactory , DataFactory
from aiida.orm import  load_code

# AiiDAlab imports.
import aiidalab_widgets_base as awb

# Custom imports.
from widgets.empa_viewer import EmpaStructureViewer
from widgets.fragments import Fragment, FragmentList
from aiida_nanotech_empa.workflows.cp2k import cp2k_utils

# AiiDA data node classes, resolved from their "core.*" entry points.
# Str, Int, List and Bool wrap plain widget values as workchain inputs further below.
# NOTE(review): StructureData, Float and Dict are not referenced by name in this
# notebook — confirm they are unused before removing them.
StructureData = DataFactory("core.structure")
Float = DataFactory("core.float")
Dict = DataFactory("core.dict")
Str = DataFactory("core.str")
Int = DataFactory("core.int")
List = DataFactory("core.list")
Bool = DataFactory("core.bool")


In [None]:
# Workchain class that is submitted by this app, looked up by its entry-point name.
Cp2kFragmentSeparationWorkChain = WorkflowFactory(
    'nanotech_empa.cp2k.fragment_separation'
)

In [None]:
# Structure selector: structures come from the AiiDA database browser, can be
# modified with the basic editor, and are rendered by the Empa viewer.
# storable=False: presumably the manager does not store the node itself — confirm.
empa_viewer = EmpaStructureViewer()

structure_selector = awb.StructureManagerWidget(
    viewer=empa_viewer,
    importers=[awb.StructureBrowserWidget(title="AiiDA database")],
    editors=[awb.BasicStructureEditor(title="Edit structure")],
    storable=False,
    node_class='StructureData',
)

In [None]:
# Other simulation settings.


def _atom_range_text(description, placeholder):
    """Return an empty, enabled text box for an atom range string such as '1 2 10..13'."""
    return ipw.Text(
        value='',
        placeholder=placeholder,
        description=description,
        disabled=False,
    )


# UKS switch; when checked, fragment multiplicities and the per-atom spin guesses
# below are added to the workchain inputs.
uks = ipw.Checkbox(value=False, description='UKS', disabled=False, indent=False)

# Atom selections written as range strings.
spins_up = _atom_range_text('Spins up', '1 2 10..13')
spins_down = _atom_range_text('Spins down', '3 4 14..17')
fixed_atoms = _atom_range_text('Fixed atoms', '3 4 14..17')

# Calculation protocol; its value is passed to the workchain as a string.
protocol = ipw.Dropdown(
    value='standard',
    description='Protocol',
    disabled=False,
    options=[("Standard", "standard"), ("Low accuracy", "low_accuracy"), ("Debug", "debug")],
)

In [None]:
# Structure fragments.
fragment_list = FragmentList()

# One-way links into the fragment list: the UKS checkbox state, and the atoms
# selected in the viewer converted with awb.utils.list_to_string_range.
ipw.dlink(source=(uks, 'value'), target=(fragment_list, 'uks'))
_ = ipw.dlink(
    source=(structure_selector.viewer, 'selection'),
    target=(fragment_list, 'selection_string'),
    transform=awb.utils.list_to_string_range,
)


In [None]:
# Reset the fragment list whenever a different structure is selected.
def update_selection(change):
    """Rebuild ``fragment_list.fragments`` for a newly selected structure.

    A non-empty new structure gets a single default fragment named "all" covering
    atoms 1..N (N = ``len`` of the new structure); an empty or missing structure
    clears the list.
    """
    new_structure = change['new']
    if not new_structure:
        fragment_list.fragments = []
        return
    # Default "all" fragment spanning the whole structure.
    fragment_list.fragments = [Fragment(indices=f"1..{len(new_structure)}", name="all")]

structure_selector.observe(update_selection, names='structure')

In [None]:
# Resources estimation.
# NOTE(review): MAX_NODES is not referenced in this notebook — confirm whether it is still needed.
MAX_NODES=48

def update_resources_for_fragments(_):
    """Estimate computational resources for every fragment in ``fragment_list``.

    Bound to ``estimate_nodes_button``; the argument (the clicked button) is ignored.
    Each fragment is estimated against the whole selected structure and the selected
    CP2K code. If either one is missing, a message is shown in
    ``node_estimate_message`` and no estimation is attempted.
    """
    # ``computational_resources`` is created in a later cell; it is only read here,
    # at click time, once all cells have run.
    structure = structure_selector.structure
    code = computational_resources.value
    if structure is None:
        node_estimate_message.message = "Please select a structure first."
        return
    if code is None:
        node_estimate_message.message = "Please select a CP2K code first."
        return
    for fragment in fragment_list.fragments:
        fragment.estimate_computational_resources(whole_structure=structure, selected_code=code)

estimate_nodes_button = ipw.Button(description="Estimate resources", button_style='warning')
estimate_nodes_button.on_click(update_resources_for_fragments)
node_estimate_message = awb.utils.StatusHTML()

In [None]:
# Code selector: choose the CP2K code used by the workchain
# (the widget is set up for the "cp2k" calculation plugin).
computational_resources = awb.ComputationalResourcesWidget(
    description="CP2K code:",
    default_calc_job_plugin="cp2k",
)

In [None]:
# Workchain submission.

def _parse_atom_indices(range_string, label, n_atoms):
    """Convert a 1-based range string such as '1 2 10..13' into 0-based atom indices.

    Args:
        range_string: text typed by the user (an empty string yields an empty list).
        label: name of the input, used in error messages.
        n_atoms: number of atoms in the selected structure.

    Returns:
        list of int: 0-based indices as produced by ``awb.utils.string_range_to_list``.

    Raises:
        ValueError: if the string cannot be parsed or refers to atoms outside 1..n_atoms.
    """
    # string_range_to_list returns (indices, is_valid) and yields [] for malformed input;
    # checking the flag stops a wrong selection from being submitted silently.
    indices, is_valid = awb.utils.string_range_to_list(range_string)
    if not is_valid:
        raise ValueError(f"{label}: could not parse '{range_string}'.")
    # An entry of 0 becomes index -1, which would otherwise wrap to the last atom.
    out_of_range = sorted({i + 1 for i in indices if not 0 <= i < n_atoms})
    if out_of_range:
        raise ValueError(f"{label}: atom numbers {out_of_range} are outside the range 1..{n_atoms}.")
    return indices


def get_builder():
    """Get the builder for the adsorption energy calculation.

    Collects the widget state into a ``Cp2kFragmentSeparationWorkChain`` builder:
    code, structure, fragment indices/charges/resources, UKS settings (multiplicities
    and per-atom spin guesses), fixed atoms and protocol. Used as the
    ``inputs_generator`` of the submit button.

    Raises:
        ValueError: if no structure or code is selected, if fragment names are not
            unique, or if any atom range string (fragment indices, spins up/down,
            fixed atoms) is malformed or outside the structure. Raising stops the
            submission instead of submitting an empty or wrong selection.
    """
    structure = structure_selector.structure
    if structure is None:
        raise ValueError("No structure selected.")
    if computational_resources.value is None:
        raise ValueError("No CP2K code selected.")
    n_atoms = len(structure)

    fragments = fragment_list.fragments
    # Fragment names are dict keys below; duplicates would silently overwrite each other.
    names = [fragment.name.value for fragment in fragments]
    duplicates = sorted({name for name in names if names.count(name) > 1})
    if duplicates:
        raise ValueError(f"Fragment names must be unique; duplicated: {duplicates}.")

    builder = Cp2kFragmentSeparationWorkChain.get_builder()
    builder.code = load_code(computational_resources.value)
    builder.structure = structure_selector.structure_node

    # Fragments' indices (0-based).
    builder.fragments = {
        fragment.name.value: List(
            list=_parse_atom_indices(fragment.indices.value, f"Fragment '{fragment.name.value}'", n_atoms)
        )
        for fragment in fragments
    }

    # Fragments' charges.
    builder.charges = {fragment.name.value: Int(fragment.charge.value) for fragment in fragments}

    # Resources.
    builder.options = {
        fragment.name.value: {
            "max_wallclock_seconds": fragment.resources.walltime_seconds,
            "resources": {
                "num_machines": fragment.resources.nodes,
                "num_mpiprocs_per_machine": fragment.resources.tasks_per_node,
                "num_cores_per_mpiproc": fragment.resources.threads_per_task,
            },
        }
        for fragment in fragments
    }

    # UKS.
    builder.uks = Bool(uks.value)
    if uks.value:
        builder.multiplicities = {fragment.name.value: Int(fragment.multiplicity.value) for fragment in fragments}

        # Spin guesses: +1 for "Spins up", -1 for "Spins down", 0 otherwise.
        # "Spins down" is applied last, so it wins for atoms listed in both.
        mag_list = [0] * n_atoms
        for i in _parse_atom_indices(spins_up.value, "Spins up", n_atoms):
            mag_list[i] = 1
        for i in _parse_atom_indices(spins_down.value, "Spins down", n_atoms):
            mag_list[i] = -1

        builder.magnetization_per_site = List(list=mag_list)

    builder.fixed_atoms = List(list=_parse_atom_indices(fixed_atoms.value, "Fixed atoms", n_atoms))
    builder.protocol = Str(protocol.value)
    return builder

btn_submit_ads = awb.SubmitButtonWidget(Cp2kFragmentSeparationWorkChain, inputs_generator=get_builder)

In [None]:
# User's interface.

# Holds the spin up/down inputs; populated only while UKS is checked.
spins = ipw.VBox(children=[])


def update_view(_=None):
    """Show the spin up/down inputs when UKS is checked and hide them otherwise.

    Registered as an observer of ``uks.value``; the change argument is ignored,
    so it can also be called directly.
    """
    spins.children = [ipw.HBox([spins_up, spins_down])] if uks.value else []


uks.observe(update_view, names='value')
# Sync once with the current checkbox state (matters if ``uks`` is already checked).
update_view()

# Render the full app: structure, fragments, settings, code/resources, submission.
display(
    structure_selector,
    fragment_list,
    uks,
    spins,
    protocol,
    fixed_atoms,
    ipw.HBox([computational_resources, estimate_nodes_button]),
    node_estimate_message,
    btn_submit_ads,
)