# Submit a GW calculation for a molecule, followed by the image-charge (IC) correction

In [None]:
%%javascript
// Never switch tall cell outputs to a scrollable box: always report "no scrolling needed".
IPython.OutputArea.prototype._should_scroll = (lines) => false;

In [None]:
# General imports.
import ipywidgets as ipw
from IPython.display import clear_output

# AiiDA imports.
%load_ext aiida
%aiida
# AiiDAlab imports.
import aiidalab_widgets_base as awb
from aiida import orm, plugins

# Custom imports.
from surfaces_tools.widgets import (
    analyze_structure,
    cdxml,
    computational_resources,
    inputs,
)

In [None]:
# Work chain class, resolved from its AiiDA entry point.
Cp2kAdsorbedGwIcWorkChain = plugins.WorkflowFactory(
    "nanotech_empa.cp2k.ads_gw_ic"
)

In [None]:
# Receives the selected structure through the dlink below; its `details` drive the form.
input_details = inputs.InputDetails()

# Structure selector with four import sources.
structure_selector = awb.StructureManagerWidget(
    importers=[
        awb.StructureBrowserWidget(title="AiiDA database"),
        awb.StructureUploadWidget(title="Import from computer"),
        awb.SmilesWidget(title="From SMILES"),
        cdxml.CdxmlUpload2GnrWidget(title="CDXML"),
    ],
    node_class="StructureData",
    storable=False,
)
display(structure_selector)

# Forward every newly selected structure to the input analysis.
ipw.dlink((structure_selector, "structure"), (input_details, "structure"))

# CP2K code selector.
resources = awb.ComputationalResourcesWidget(
    default_calc_job_plugin="cp2k",
    description="CP2K code:",
)

In [None]:
# GW protocol.
gw_type = ipw.Dropdown(
    description="Protocol:",
    options=["gpw_std", "gapw_std", "gapw_hq"],
    value="gpw_std",
)

# Spin settings: multiplicity, UKS toggle, and the atoms that start spin-up /
# spin-down, given as index ranges such as "1 2 10..13".
multiplicity = ipw.IntText(description="Multiplicity", value=1)
uks = ipw.Checkbox(description="UKS", value=False, indent=False)
spins_up = ipw.Text(description="Spins U:", value="", placeholder="1 2 10..13")
spins_down = ipw.Text(description="Spins D:", value="", placeholder="3 4 14..17")

# Geometry mode.
geo_mode = ipw.Dropdown(
    description="Geo mode:",
    options=["ads_geo", "gas_opt"],
    value="ads_geo",
)

# Adsorption height; get_builder_gw only uses it when it is > 0.
ads_height = ipw.FloatText(
    description="adsorption height in Å:",
    value=0.0,
    style={"description_width": "initial"},
)

# Free-text description, stored as the work chain's metadata description.
description = ipw.Text(
    description="Description:",
    value="",
    placeholder="Calculation description",
    style={"description_width": "initial"},
)


# Help text displayed next to the adsorption-height input.
html_ads_height = ipw.HTML(
    value="""
                <p style="font-weight:400;">If you specify a value >0
                    <font style="font-style:italic;font-weight:400;">(mandatory in case there is no slab)</font> it will override the height extracted from the geometry.
                </p>
                <p>The substrate surface is at <font style="font-style:italic;font-weight:600;">geometric center of the molecule - adsorption height</font></p>
               """
)

In [None]:
def _process_options(resources_widget):
    """Translate a ProcessResourcesWidget into a scheduler ``options`` dict.

    Args:
        resources_widget: widget exposing ``walltime_seconds``, ``nodes``,
            ``tasks_per_node`` and ``threads_per_task``.

    Returns:
        dict with ``max_wallclock_seconds`` and a ``resources`` sub-dict.
    """
    return {
        "max_wallclock_seconds": resources_widget.walltime_seconds,
        "resources": {
            "num_machines": resources_widget.nodes,
            "num_mpiprocs_per_machine": resources_widget.tasks_per_node,
            "num_cores_per_mpiproc": resources_widget.threads_per_task,
        },
    }


def get_builder_gw():
    """Build the Cp2kAdsorbedGwIcWorkChain inputs from the current widget state.

    Used as ``inputs_generator`` of the submit button.

    Returns:
        The populated process builder.
    """
    builder = Cp2kAdsorbedGwIcWorkChain.get_builder()
    builder.metadata.description = description.value
    builder.code = orm.load_code(resources.value)
    builder.structure = structure_selector.structure_node
    builder.protocol = orm.Str(gw_type.value)
    builder.geometry_mode = orm.Str(geo_mode.value)

    # Override the automatic adsorption height only for a positive user value.
    if ads_height.value > 0.0:
        builder.ads_height = orm.Float(ads_height.value)

    # Spin guess: 0 for every atom, +1 / -1 for the atoms listed as spin up / down.
    ase_geom = structure_selector.structure
    mag_list = [0] * len(ase_geom)
    if uks.value:
        # string_range_to_list returns a tuple; only its first item (the indices) is used.
        # The fully qualified name is required: the bare `string_range_to_list`
        # is not defined in this notebook and raised NameError.
        for i in awb.utils.string_range_to_list(spins_up.value)[0]:
            mag_list[i] = 1
        for i in awb.utils.string_range_to_list(spins_down.value)[0]:
            mag_list[i] = -1
        builder.multiplicity = orm.Int(multiplicity.value)
    builder.magnetization_per_site = orm.List(list=mag_list)

    # Scheduler options for the three steps: DFT (SCF), GW and IC.
    builder.options.scf = _process_options(resources_scf)
    builder.options.gw = _process_options(resources_gw)
    builder.options.ic = _process_options(resources_ic)

    return builder

In [None]:
def after_submission(_=None):
    """Post-submission hook: sets ``structure_selector.value`` to None.

    NOTE(review): confirm that StructureManagerWidget exposes a ``value`` trait;
    otherwise this assignment does not reset the selected structure.
    """
    structure_selector.value = None


btn_submit_gw = awb.SubmitButtonWidget(
    Cp2kAdsorbedGwIcWorkChain,
    inputs_generator=get_builder_gw,
)
# Start disabled; update_all re-evaluates this whenever the structure changes.
btn_submit_gw.btn_submit.disabled = True
btn_submit_gw.on_submitted(after_submission)

In [None]:
output = ipw.Output()


def update_all(_=None):
    """Refresh the submit button, spin fields and input form.

    Called whenever the selected structure or the UKS checkbox changes.
    Submission is enabled only when ``system_type`` is "SlabXY" or "Molecule"
    and exactly one molecule was found; otherwise the button stays disabled
    and a "not implemented" message is printed.
    """
    details = input_details.details or {}
    only_one_molecule = (
        details.get("system_type") in ("SlabXY", "Molecule")
        and len(details.get("all_molecules", [])) == 1
    )
    # Previously the button was unconditionally enabled, so unsupported
    # systems could still be submitted.
    btn_submit_gw.btn_submit.disabled = not only_one_molecule
    msg = "" if only_one_molecule else "GW for this system not implemented"

    # Prefill the spin ranges from the structure analysis.
    spins_up.value = awb.utils.list_to_string_range(details.get("spins_up", []))
    spins_down.value = awb.utils.list_to_string_range(details.get("spins_down", []))

    to_display = [
        gw_type,
        geo_mode,
        ipw.HBox([ads_height, html_ads_height]),
        uks,
    ]
    if uks.value:
        # `description` is deliberately not added here: the final display cell
        # already shows it, and adding it would render it twice.
        to_display += [spins_up, spins_down, multiplicity]

    with output:
        clear_output()
        print(msg)
        display(ipw.VBox(to_display))


structure_selector.observe(update_all, names="structure")
uks.observe(update_all, names="value")

In [None]:
# Resources for each step of the work chain: DFT (SCF), GW and IC.
resources_scf = computational_resources.ProcessResourcesWidget()
resources_gw = computational_resources.ProcessResourcesWidget()
resources_ic = computational_resources.ProcessResourcesWidget()

# One resources estimator per step, created in the same order as above.
resources_estimation_scf, resources_estimation_gw, resources_estimation_ic = (
    computational_resources.ResourcesEstimatorWidget(calculation_type=calc_type)
    for calc_type in ("dft", "gw", "gw_ic")
)

# Link each estimator to the resources widget of the same step.
for _estimator, _resources_widget in zip(
    (resources_estimation_scf, resources_estimation_gw, resources_estimation_ic),
    (resources_scf, resources_gw, resources_ic),
):
    _estimator.link_to_resources_widget(_resources_widget)


def extract_details_on_adsorbed_molecule(all_atoms):
    """Return the structure-analysis details of the molecular part only.

    ``all_atoms`` is analysed once. If the analysis reports ``all_molecules``,
    the atoms of all molecules are combined and re-analysed, and those details
    are returned. Otherwise an empty dict is returned.
    """
    analyzer = analyze_structure.StructureAnalyzer()
    analyzer.structure = all_atoms
    if "all_molecules" not in analyzer.details:
        return {}

    # Flatten the per-molecule index lists into one list of atom indices.
    molecule_atom_indices = []
    for molecule in analyzer.details["all_molecules"]:
        molecule_atom_indices.extend(molecule)

    analyzer.structure = all_atoms[molecule_atom_indices]
    return analyzer.details


# Keep every resources estimator in sync with the structure, code and UKS choice.
_estimators = (
    resources_estimation_scf,
    resources_estimation_gw,
    resources_estimation_ic,
)

# Structure: estimators receive the details of the molecular part only.
for _estimator in _estimators:
    ipw.dlink(
        (structure_selector, "structure"),
        (_estimator, "details"),
        transform=extract_details_on_adsorbed_molecule,
    )

# Selected CP2K code.
for _estimator in _estimators:
    _ = ipw.dlink((resources, "value"), (_estimator, "selected_code"))

# UKS checkbox.
for _estimator in _estimators:
    _ = ipw.dlink((uks, "value"), (_estimator, "uks"))

# A single button triggers the estimation for all three steps.
estimate_nodes_button = ipw.Button(
    description="Estimate resources", button_style="warning"
)
for _estimator in _estimators:
    estimate_nodes_button.on_click(_estimator.estimate_resources)

In [None]:
# Assemble the submission form: messages, description, per-step resources,
# estimation button, code selector and submit button.
display(
    output,
    description,
    ipw.HBox(
        [
            ipw.VBox([ipw.HTML(label), widget])
            for label, widget in (
                ("DFT resources", resources_scf),
                ("GW resources", resources_gw),
                ("IC resources", resources_ic),
            )
        ]
    ),
    estimate_nodes_button,
    resources,
    btn_submit_gw,
)