# Submit GW calculations for molecules, followed by an image-charge (IC) correction.

In [None]:
# General imports.
import ipywidgets as ipw
from IPython.display import clear_output
from traitlets import dlink

# AiiDA/AiiDAlab imports.
%aiida
from aiida.plugins import WorkflowFactory
import aiidalab_widgets_base as awb

# Custom imports.
from widgets.empa_viewer import EmpaStructureViewer
from widgets.import_cdxml import CdxmlUpload2GnrWidget
from widgets.inputs import get_nodes

In [None]:
%%javascript
// Override the notebook's scroll check so that it always answers "no":
// presumably this keeps long widget outputs fully expanded instead of
// being collapsed into a scrollable box.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

In [None]:
# Load the work chain class registered under the 'nanotech_empa.cp2k.ads_gw_ic'
# entry point; it is used both to create the builder and by the submit button.
Cp2kAdsorbedGwIcWorkChain = WorkflowFactory('nanotech_empa.cp2k.ads_gw_ic')

In [None]:
# Viewer attached to the structure manager; its `details` dictionary is
# read by the callbacks defined further down in this notebook.
empa_viewer = EmpaStructureViewer()

# Sources a structure can be loaded from, in the order they are listed in the UI.
_structure_importers = [
    awb.StructureBrowserWidget(title="AiiDA database"),
    awb.StructureUploadWidget(title="Import from computer"),
    awb.SmilesWidget(title="From SMILES"),
    CdxmlUpload2GnrWidget(title="CDXML"),
]

structure_selector = awb.StructureManagerWidget(
    viewer=empa_viewer,
    importers=_structure_importers,
    storable=False,
    node_class='StructureData',
)
display(structure_selector)

# Dropdown for choosing a code that uses the 'cp2k' input plugin.
computer_code_dropdown = awb.CodeDropdown(input_plugin='cp2k')


In [None]:
# all widgets

In [None]:
# Style shared by the widgets whose description is longer than the default label width.
_FULL_LABEL = {"description_width": "initial"}

# GW protocol passed to the work chain.
gw_type = ipw.Dropdown(
    description='Protocol:',
    options=['gpw_std', 'gapw_std', 'gapw_hq'],
    value='gpw_std',
    disabled=False,
)

# Spin multiplicity (only shown and used when UKS is ticked).
multiplicity = ipw.IntText(
    description='Multiplicity',
    value=1,
    disabled=False,
)

# Switch for an unrestricted (UKS) calculation.
uks = ipw.Checkbox(
    description='UKS',
    value=False,
    disabled=False,
    indent=False,
)

# Initial spin guess: atom indices given as a range string, e.g. "1 2 10..13".
spins_up = ipw.Text(
    description='Spins U:',
    value='',
    placeholder='1 2 10..13',
    disabled=False,
)
spins_down = ipw.Text(
    description='Spins D:',
    value='',
    placeholder='3 4 14..17',
    disabled=False,
)

# Geometry mode passed to the work chain.
geo_mode = ipw.Dropdown(
    description='Geo mode:',
    options=['ads_geo', 'gas_opt'],
    value='ads_geo',
    disabled=False,
)

# Adsorption height; a value > 0 is forwarded to the work chain, 0 is not.
ads_height = ipw.FloatText(
    description='adsorption height in Å:',
    value=0.0,
    disabled=False,
    style=_FULL_LABEL,
)

# Free-text description stored with the submitted calculation.
description = ipw.Text(
    description='Description:',
    value='',
    placeholder='Calculation description',
    disabled=False,
    style=_FULL_LABEL,
)
def _int_field(label, default=1):
    """Return an enabled integer input box showing its full `label`."""
    return ipw.IntText(
        value=default,
        description=label,
        disabled=False,
        style={"description_width": "initial"},
    )


# Resources for the SCF (DFT) step.
scf_nodes = _int_field('#nodes')
scf_tasks_per_node = _int_field('#tasks/node')
scf_threads_per_task = _int_field('#threads/task')

# Resources for the GW step.
gw_nodes = _int_field('#nodes')
gw_tasks_per_node = _int_field('#tasks/node')
gw_threads_per_task = _int_field('#threads/task')

# Resources for the IC step.
ic_nodes = _int_field('#nodes')
ic_tasks_per_node = _int_field('#tasks/node')
ic_threads_per_task = _int_field('#threads/task')

# Wall-clock limit in seconds.
walltime = _int_field('Walltime seconds', default=3600)
# Help text displayed next to the adsorption-height input: explains that a
# value > 0 overrides the height extracted from the geometry and how the
# substrate surface position is derived from it.
html_ads_height = ipw.HTML(
                    value="""
                <p style="font-weight:400;">If you specify a value >0
                    <font style="font-style:italic;font-weight:400;">(mandatory in case there is no slab)</font> it will override the height extracted from the geometry.
                </p>
                <p>The substrate surface is at <font style="font-style:italic;font-weight:600;">geometric center of the molecule - adsorption height</font></p>
               """
                )

def update_resources(_):
    """Fill the SCF, GW and IC resource fields with estimates from `get_nodes`.

    Intended as a button click handler; the single positional argument
    (the clicked widget) is ignored.

    The estimate uses the elements of the currently displayed structure,
    the default number of MPI processes per machine of the selected code's
    computer, and the state of the UKS checkbox. All three estimates are
    requested for a 'Molecule' system type.
    """
    try:
        max_tasks_per_node = (
            computer_code_dropdown.selected_code.computer.get_default_mpiprocs_per_machine()
        )
    except AttributeError:
        # Raised e.g. when no code is selected (`selected_code` is None).
        max_tasks_per_node = None
    if max_tasks_per_node is None:
        # The computer does not define a default either: fall back to one task.
        max_tasks_per_node = 1

    try:
        element_list = empa_viewer.details["all_elements"]
    except KeyError:
        # No analysed structure available yet.
        element_list = []

    # (calculation type understood by `get_nodes`, widgets receiving the result)
    targets = (
        ('Molecule-DFT', (scf_nodes, scf_tasks_per_node, scf_threads_per_task)),
        ('gw', (gw_nodes, gw_tasks_per_node, gw_threads_per_task)),
        ('gw_ic', (ic_nodes, ic_tasks_per_node, ic_threads_per_task)),
    )
    for calctype, (nodes_w, tasks_w, threads_w) in targets:
        nodes_w.value, tasks_w.value, threads_w.value = get_nodes(
            element_list=element_list,
            calctype=calctype,
            systype='Molecule',
            max_tasks_per_node=max_tasks_per_node,
            uks=uks.value,
        )

# Button that fills the resource fields with the values suggested by `update_resources`.
estimate_nodes_button = ipw.Button(
    description="Estimate resources",
    button_style='warning',
)
estimate_nodes_button.on_click(update_resources)

In [None]:
def get_builder_gw():
    """Build the `Cp2kAdsorbedGwIcWorkChain` inputs from the current widget values.

    Returns:
        The populated process builder; it is passed to the submit button as
        its `input_dictionary_function`.

    The adsorption height is only set when the user entered a value > 0.
    The per-site magnetization is all zeros unless UKS is ticked, in which
    case the atoms listed in "Spins U"/"Spins D" get +1/-1 and the
    multiplicity is set as well.
    """
    builder = Cp2kAdsorbedGwIcWorkChain.get_builder()

    builder.metadata.description = description.value
    builder.code = computer_code_dropdown.selected_code
    builder.geometry_mode = Str(geo_mode.value)

    # Override the automatically determined adsorption height only on request.
    if ads_height.value > 0.0:
        builder.ads_height = Float(ads_height.value)

    # Spin guess: one entry per atom of the selected structure.
    mag_list = [0 for _ in structure_selector.structure]
    if uks.value:
        # `string_range_to_list` returns a tuple whose first item is the index
        # list (presumably already shifted from the 1-based numbering typed by
        # the user to 0-based list indices - confirm against awb.utils).
        for i in awb.utils.string_range_to_list(spins_up.value)[0]:
            mag_list[i] = 1
        for i in awb.utils.string_range_to_list(spins_down.value)[0]:
            mag_list[i] = -1

        builder.multiplicity = Int(multiplicity.value)

    builder.structure = structure_selector.structure_node
    builder.magnetization_per_site = List(list=mag_list)

    builder.protocol = Str(gw_type.value)

    # Scheduler resources for the three steps of the work chain.
    builder.resources_scf = Dict(dict={
        "num_machines": scf_nodes.value,
        "num_mpiprocs_per_machine": scf_tasks_per_node.value,
        "num_cores_per_mpiproc": scf_threads_per_task.value,
    })
    builder.resources_gw = Dict(dict={
        "num_machines": gw_nodes.value,
        "num_mpiprocs_per_machine": gw_tasks_per_node.value,
        "num_cores_per_mpiproc": gw_threads_per_task.value,
    })
    builder.resources_ic = Dict(dict={
        "num_machines": ic_nodes.value,
        "num_mpiprocs_per_machine": ic_tasks_per_node.value,
        "num_cores_per_mpiproc": ic_threads_per_task.value,
    })

    builder.walltime_seconds = Int(walltime.value)

    return builder

In [None]:
def after_submission(_=None):   
    """Callback registered on the submit button; runs after a submission.

    The optional argument is ignored.
    """
    # NOTE(review): the rest of this notebook reads the selector through its
    # `structure` / `structure_node` attributes - confirm that assigning
    # `value` really clears the current selection.
    structure_selector.value = None



    
# Submit button: asks `get_builder_gw` for the inputs when it is pressed.
btn_submit_gw = awb.SubmitButtonWidget(
    Cp2kAdsorbedGwIcWorkChain,
    input_dictionary_function=get_builder_gw,
)

# Start with submission disabled; `update_all` sets this flag once a structure is loaded.
btn_submit_gw.btn_submit.disabled = True

# Run `after_submission` once a work chain has been submitted.
btn_submit_gw.on_submitted(after_submission)

In [None]:
# Output area holding the input form; it is redrawn by `update_all`.
output = ipw.Output()


def update_all(_=None):
    """Re-check the selected structure and redraw the input form.

    Registered as an observer of the structure selector and of the UKS
    checkbox; the optional change argument is ignored.

    Submission is enabled only when the viewer reports a 'SlabXY' or
    'Molecule' system containing exactly one molecule; otherwise the submit
    button is disabled and a "not implemented" message is printed. The spin
    guess fields are refreshed from the viewer's details, and the spin
    related inputs are shown only while UKS is ticked.
    """
    details = empa_viewer.details

    # Only a single molecule, isolated or on a slab, is supported.
    only_one_molecule = (
        details['system_type'] in ('SlabXY', 'Molecule')
        and len(details['all_molecules']) == 1
    )
    # Keep the button state in sync with the check so that an unsupported
    # system cannot be submitted.
    btn_submit_gw.btn_submit.disabled = not only_one_molecule
    msg = '' if only_one_molecule else 'GW for this system not implemented'

    spins_up.value = awb.utils.list_to_string_range(details['spins_up'])
    spins_down.value = awb.utils.list_to_string_range(details['spins_down'])

    with output:
        clear_output()
        print(msg)
        scf_resources = ipw.HBox([scf_nodes, scf_tasks_per_node, scf_threads_per_task])
        gw_resources = ipw.HBox([gw_nodes, gw_tasks_per_node, gw_threads_per_task])
        ic_resources = ipw.HBox([ic_nodes, ic_tasks_per_node, ic_threads_per_task])

        # Spin guess and multiplicity are only meaningful for UKS runs.
        spin_inputs = [spins_up, spins_down, multiplicity] if uks.value else []

        to_display = [
            computer_code_dropdown,
            gw_type,
            geo_mode,
            ipw.HBox([ads_height, html_ads_height]),
            uks,
            *spin_inputs,
            description,
            ipw.HTML("DFT resources"),
            scf_resources,
            ipw.HTML("GW resources"),
            gw_resources,
            ipw.HTML("IC resources"),
            ic_resources,
            estimate_nodes_button,
            walltime,
            btn_submit_gw,
        ]
        display(ipw.VBox(to_display))


# Redraw the form whenever a new structure is selected or UKS is toggled.
structure_selector.observe(update_all, names='structure')
uks.observe(update_all, names='value')

In [None]:
# Show the output area that `update_all` fills with the input form.
display(output)