In [None]:
%aiida

In [None]:
from aiida_cp2k.calculations import Cp2kCalculation

from aiida.orm import StructureData
from aiida.orm import ArrayData
from aiida.engine import submit, run_get_node

from aiidalab_widgets_base import CodeDropdown, SubmitButtonWidget, StructureBrowserWidget
from aiidalab_widgets_base import ComputerDropdown

import ase
import ase.io
import numpy as np
import nglview
from copy import deepcopy
from pprint import pprint

# AiiDAlab imports.
import aiidalab_widgets_base as awb

# Custom imports.
from widgets.empa_viewer import EmpaStructureViewer
from widgets.ANALYZE_structure import StructureAnalyzer
from aiida_nanotech_empa.workflows.cp2k import cp2k_utils

import ipywidgets as ipw
from IPython.display import display, clear_output, HTML

from aiida.common.exceptions import MissingEntryPointError 
# Resolve the PDOS work chain class through its entry point. If the entry
# point is missing, print a hint for the user and re-raise the same error.
# NOTE(review): `WorkflowFactory` is not imported in this cell; presumably it
# is injected by the `%aiida` magic above — confirm.
try:
    Cp2kPdosWorkChain = WorkflowFactory('nanotech_empa.cp2k.pdos')
except MissingEntryPointError:
    print("Entry point not found. Did you perhaps forget to set up the plugins under 'Setup codes'?")
    raise

from apps.scanning_probe import common

# Select structure

In [None]:
# Structure selector: the structure can be uploaded from the local computer or
# picked from the AiiDA database; no editors are offered.
empa_viewer = EmpaStructureViewer()
structure_selector = awb.StructureManagerWidget(
    viewer=empa_viewer,
    importers=[
        awb.StructureUploadWidget(title="Import from computer"),
        awb.StructureBrowserWidget(title="AiiDA database"),
    ],
    editors=[],
    storable=False,
    node_class='StructureData',
)
display(structure_selector)

In [None]:
# Free-text description of the calculation.
# (Reading it from the structure creator is not implemented yet.)
text_calc_description = ipw.Text(
    description='Description:',
    layout={'width': '45%'},
)
display(text_calc_description)

# Select molecule

For the automatic detection to work, the atoms of the molecule must come first in the structure, i.e. before the atoms of the slab.

In [None]:
# PDOS selection addition and removal
def remove_from_tuple(tup, index):
    """Return a copy of ``tup`` without the element at position ``index``.

    Negative indices count from the end, as for lists. An out-of-range
    index raises ``IndexError``.
    """
    items = list(tup)
    items.pop(index)
    return tuple(items)

def rm_pdos_selection(b):
    """Remove the PDOS selection row whose 'remove' button ``b`` was clicked.

    ``sel_text_list`` and ``sel_rm_list`` carry a leading placeholder entry
    for the molecule selection, whereas ``sel_ll_list`` and
    ``sel_vbox.children`` only hold the rows that were added afterwards.
    The latter two are therefore shifted by one with respect to the button
    position.

    Raises:
        ValueError: if ``b`` is not one of the registered remove buttons.
    """
    rm_index = sel_rm_list.index(b)
    # Position of the row among the user-added rows (no placeholder entry).
    row_index = rm_index - 1

    del sel_text_list[rm_index]
    del sel_rm_list[rm_index]
    # The original code deleted `sel_ll_list[rm_index]`, which is off by one:
    # it raised IndexError for the last row and dropped the wrong label otherwise.
    del sel_ll_list[row_index]

    sel_vbox.children = remove_from_tuple(sel_vbox.children, row_index)

def add_pdos_selection(b):
    """Append a new PDOS selection row (selection text, label, remove button).

    The three widgets are registered in ``sel_text_list``, ``sel_ll_list`` and
    ``sel_rm_list`` and shown as one horizontal box at the end of ``sel_vbox``.
    The argument ``b`` (the clicked button) is not used.
    """
    selection_text = ipw.Text(description='pdos selection',
                              style=style, layout={'width': '30%'})
    sel_text_list.append(selection_text)

    remove_button = ipw.Button(description='remove', layout={'width': '10%'})
    remove_button.on_click(rm_pdos_selection)
    label_text = ipw.Text(description='label', layout={'width': '20%'})

    sel_rm_list.append(remove_button)
    sel_ll_list.append(label_text)

    new_row = ipw.HBox([selection_text, label_text, remove_button])
    sel_vbox.children = sel_vbox.children + (new_row,)

In [None]:
def guess_molecule():
    """Pre-fill the molecule selection with the first molecule found by the viewer.

    Reads ``empa_viewer.details['all_molecules'][0]`` and writes it, formatted
    as a range string, into ``mol_selection``. This is best effort: if the
    lookup or the conversion fails, a message is printed and the current
    selection is left untouched.
    """
    try:
        first_molecule = empa_viewer.details['all_molecules'][0]
        # The original wrote to `mol.selection.value`; `mol` does not exist, so
        # the NameError was swallowed and the guess never reached the widget.
        mol_selection.value = awb.utils.list_to_string_range(first_molecule)
    except Exception:
        # `Exception` instead of a bare `except`, so KeyboardInterrupt and
        # SystemExit are no longer swallowed.
        print("Unable to automatically find slab and molecule")

def get_mol_ase():
    """Return the part of the selected structure given by the molecule selection.

    The text of ``mol_selection`` is parsed with
    ``awb.utils.string_range_to_list``; the first element of its result is
    used to index ``structure_selector.structure``.
    """
    parsed_selection = awb.utils.string_range_to_list(mol_selection.value)
    atom_indices = parsed_selection[0]
    return structure_selector.structure[atom_indices]
    
def reset_state():
    """Reset the PDOS selection widgets to their initial state.

    The molecule selection goes back to ``"1..10"``, all additional PDOS
    selection rows are dropped, and the bookkeeping lists are restored to the
    values they have right after creation.

    The module-level objects are modified in place. The original version
    assigned to local variables of the same names, which left the real lists
    and the displayed ``sel_vbox`` untouched.
    """
    mol_selection.value = "1..10"
    sel_text_list[:] = [mol_selection]
    sel_rm_list[:] = [None]
    sel_ll_list.clear()
    # Empty the displayed box rather than creating a new, undisplayed one.
    sel_vbox.children = ()
    
def on_struct_change(c):
    """React to a newly selected structure by guessing the molecule atoms.

    ``c`` is the change notification passed by ``observe``; it is not used.
    Nothing happens while no structure is selected.
    """
    if structure_selector.structure is None:
        return
    guess_molecule()

    #if structure.creator is not None:
    #    text_calc_description.value = structure.creator.description

# The selected structure is read from `structure_selector.structure`
# elsewhere in this notebook, so that is the trait to observe. The former
# `names='value'` does not appear to match any trait of the structure manager,
# so the callback was presumably never triggered.
# NOTE(review): assumes `empa_viewer.details` is already refreshed when this
# observer runs — confirm.
structure_selector.observe(on_struct_change, names='structure')


In [None]:
# Shared widget styling for this cell.
style = {'description_width': '120px'}
layout = {'width': '60%'}

# Text field holding the atom range of the molecule.
mol_selection = ipw.Text(
    description='Molecule selection',
    value="1..10",
    style=style,
    layout={'width': '30%'},
)

# Bookkeeping for the PDOS selection rows. The first entry of the text and
# remove-button lists stands for the molecule selection itself (which has no
# remove button); the label list and the box only hold rows added later.
sel_text_list = [mol_selection]
sel_rm_list = [None]
sel_ll_list = []
sel_vbox = ipw.VBox([])

# Button that appends a further PDOS selection row.
add_sel_btn = ipw.Button(
    description='Add pdos selection',
    layout={'width': '15%'},
)
add_sel_btn.on_click(add_pdos_selection)

display(ipw.HBox([mol_selection]), sel_vbox, add_sel_btn)

# Select computer and codes

In [None]:
# Code selector
computational_resources = awb.ComputationalResourcesWidget(
    description="CP2K code:",
    input_plugin="cp2k",
)

style = {'description_width': '140px'}
layout = {'width': '70%'}


def on_computer_change(c):
    """Refresh the overlap-code dropdown for the computer of the selected CP2K code.

    When a code is selected, the 'spm.overlap' codes of its computer are
    looked up with ``common.comp_plugin_codes``, stored in the global
    ``overlap_codes``, and their labels become the options of ``drop_over``.
    Nothing changes while no code is selected. The argument ``c`` is not used.
    """
    global overlap_codes
    if computational_resources.value is None:
        return

    computer_label = computational_resources.value.computer.label
    overlap_codes = common.comp_plugin_codes(computer_label, 'spm.overlap')
    drop_over.options = [code.label for code in overlap_codes]
    
    
# Dropdown listing the overlap codes; filled by `on_computer_change`.
drop_over = ipw.Dropdown(description="Overlap code")

# Switch for the ELPA library.
elpa_check = ipw.Checkbox(value=True, description='use ELPA', disabled=False)

# Keep the dropdown in sync with the code selector, and fill it once now.
computational_resources.observe(on_computer_change)
on_computer_change(0)

display(computational_resources, drop_over, elpa_check)

# DFT parameters

In [None]:
style = {'description_width': '80px'}
layout = {'width': '70%'}

def enable_spin(b):
    """Enable the spin-related inputs only while the UKS switch is on.

    The argument ``b`` (the change notification) is not used.
    """
    spin_widgets = (
        spin_up_text,
        spin_dw_text,
        multiplicity_text,
        sc_diag_check,
        force_multiplicity_check,
    )
    for widget in spin_widgets:
        widget.disabled = not uks_switch.value

# Master switch for a spin-polarized (UKS) calculation.
uks_switch = ipw.ToggleButton(
    value=False,
    description='Spin-polarized calculation',
    style=style,
    layout={'width': '450px'},
)
uks_switch.observe(enable_spin, names='value')

# Atom selections for the initial spin guess; disabled until UKS is enabled.
spin_up_text = ipw.Text(
    placeholder='1..5 7',
    description='Spin up',
    disabled=True,
    style=style,
    layout={'width': '370px'},
)
spin_dw_text = ipw.Text(
    placeholder='11..15 17',
    description='Spin down',
    disabled=True,
    style=style,
    layout={'width': '370px'},
)

multiplicity_text = ipw.IntText(
    value=1,
    description='Multiplicity',
    disabled=True,
    style=style,
    layout={'width': '20%'},
)
sc_diag_check = ipw.Checkbox(
    value=False,
    description='self-consistent diagonalization',
    disabled=True,
)
# Created here but not displayed in this cell.
force_multiplicity_check = ipw.Checkbox(
    value=True,
    description='Fix multiplicity',
    disabled=True,
)

# Protocol passed on to the work chain.
protocol = ipw.Dropdown(
    value="standard",
    options=[
        ("Standard", "standard"),
        ("Low accuracy", "low_accuracy"),
        ('Debug', 'debug'),
    ],
    description="Protocol:",
    style={"description_width": "120px"},
)

display(
    uks_switch,
    ipw.HBox([ipw.VBox([spin_up_text, spin_dw_text])]),
    ipw.HBox([multiplicity_text, sc_diag_check]),
    protocol,
)

In [None]:
def enable_smearing(b):
    """Enable the temperature and 'Fix multiplicity' inputs only while smearing is on.

    The argument ``b`` (the change notification) is not used.
    """
    for widget in (temperature_text, force_multiplicity_check):
        widget.disabled = not smear_switch.value

# Switch for Fermi-Dirac smearing. `style` is whatever the most recently
# executed cell assigned to that name.
smear_switch = ipw.ToggleButton(
    value=False,
    description='Enable Fermi-Dirac smearing',
    style=style,
    layout={'width': '450px'},
)
smear_switch.observe(enable_smearing, names='value')

# Smearing temperature; disabled until smearing is switched on.
temperature_text = ipw.FloatText(
    value=150.0,
    description='Temperature [K]',
    disabled=True,
    style={'description_width': '100px'},
    layout={'width': '20%'},
)

display(smear_switch, ipw.HBox([temperature_text, force_multiplicity_check]))

# Orbital overlap parameters

In [None]:
style = {'description_width': '140px'}
layout_small = {'width': '30%'}

# Number of gas-phase orbitals below / above the gap used for the overlap.
num_homo_inttext = ipw.IntText(
    value=4, description="num gas HOMO", style=style, layout=layout_small
)
num_lumo_inttext = ipw.IntText(
    value=4, description="num gas LUMO", style=style, layout=layout_small
)

# NOTE(review): this widget is neither displayed nor read anywhere in this
# notebook, and its label duplicates the LUMO one — looks like a copy-paste
# leftover; confirm whether it can be removed or needs its own label.
elim_inttext = ipw.IntText(
    value=4, description="num gas LUMO", style=style, layout=layout_small
)

# Energy window for the slab system, in eV.
emin_floattext = ipw.BoundedFloatText(
    description='Slab system Emin (eV)',
    min=-3.0,
    max=-0.1,
    step=0.1,
    value=-2.0,
    style=style,
    layout=layout_small,
)
emax_floattext = ipw.BoundedFloatText(
    description='Slab system Emax (eV)',
    min=0.1,
    max=3.0,
    step=0.1,
    value=2.0,
    readout_format='%.2f',
    style=style,
    layout=layout_small,
)

display(num_homo_inttext, num_lumo_inttext, emin_floattext, emax_floattext)

# Resources

In [None]:
# Resources estimation.
MAX_NODES=96

structure_analyzer = StructureAnalyzer()

def update_resources(_):
    """Estimate the resources for the slab calculation and fill in the slab widgets.

    Calls ``cp2k_utils.get_nodes`` for the selected structure with calculation
    type ``"slab"``, then writes the first two entries of the result into the
    "# Nodes" and "# CPUs per node" widgets of the slab. If either the
    structure or the code is missing, an error message is set on
    ``node_estimate_message`` and nothing else happens.

    The argument (the clicked button) is not used.

    NOTE(review): only the slab widgets are updated; the molecule resources
    are left as entered by the user.
    """
    if not (structure_selector.structure and computational_resources.value):
        node_estimate_message.message = """<span style="color:red"> Error:</span> Can't estimate resources: both structure and code must be selected."""
        return

    # The original body assigned an undefined `substructure` to
    # `structure_analyzer.structure` and wrote to an undefined `slab` object;
    # both raised NameError. The analyzer result was not used (the calculation
    # type is fixed), so that assignment is dropped.
    resources = cp2k_utils.get_nodes(
        atoms=structure_selector.structure,
        calctype="slab",
        computer=computational_resources.value.computer,
        max_nodes=MAX_NODES,
        uks=uks_switch.value
    )
    slab_nodes_widget.value = resources[0]
    slab_cpus_per_node_widget.value = resources[1]

# Status line used by `update_resources` to report problems.
node_estimate_message = awb.utils.StatusHTML()

# Button that triggers the resource estimate.
estimate_nodes_button = ipw.Button(
    description="Estimate resources",
    button_style='warning',
)
estimate_nodes_button.on_click(update_resources)

In [None]:
# Resources widgets.
STYLE = {'description_width': '100px'}


def _resource_int_widget(description, value, minimum):
    """Create one integer input of the resources panel.

    NOTE(review): `min` is forwarded to `ipw.IntText` exactly as before; it is
    documented for `BoundedIntText` rather than `IntText`, so the lower bound
    may not actually be enforced — confirm.
    """
    return ipw.IntText(
        description=description,
        value=value,
        min=minimum,
        style=STYLE
    )


# Slab calculation.
slab_nodes_widget = _resource_int_widget("# Nodes", 1, 1)
slab_cpus_per_node_widget = _resource_int_widget("# CPUs per node", 1, 1)
# Label typo fixed ("# thredas" -> "# threads").
slab_num_cores_per_mpiproc_widget = _resource_int_widget("# threads", 1, 0)
slab_run_time_widget = _resource_int_widget("Runtime (mins)", 1440, 0)

# Molecule calculation.
molecule_nodes_widget = _resource_int_widget("# Nodes", 1, 1)
molecule_cpus_per_node_widget = _resource_int_widget("# CPUs per node", 1, 1)
molecule_num_cores_per_mpiproc_widget = _resource_int_widget("# threads", 1, 0)
molecule_run_time_widget = _resource_int_widget("Runtime (mins)", 600, 0)


In [None]:
# Resources panel: slab and molecule settings side by side, with the
# estimate button and its status line underneath.
slab_res = ipw.VBox([
    ipw.HTML("<b>Slab</b>"),
    slab_nodes_widget,
    slab_cpus_per_node_widget,
    slab_num_cores_per_mpiproc_widget,
    slab_run_time_widget,
])
mol_res = ipw.VBox([
    ipw.HTML("<b>Molecule</b>"),
    molecule_nodes_widget,
    molecule_cpus_per_node_widget,
    molecule_num_cores_per_mpiproc_widget,
    molecule_run_time_widget,
])
# `node_estimate_message` was never displayed before, so the error set by
# `update_resources` could not be seen by the user.
display(ipw.VBox([
    ipw.HBox([slab_res, mol_res]),
    estimate_nodes_button,
    node_estimate_message,
]))

# Submission

In [None]:
submit_out = ipw.Output()

def get_builder():
    """Collect all widget values into a builder for ``Cp2kPdosWorkChain``.

    Messages for the user are printed into ``submit_out``.

    Returns:
        The populated builder, or ``None`` if no structure, no CP2K code or
        no overlap code is selected.
    """
    with submit_out:
        clear_output()
        if structure_selector.structure is None:
            print("Please select a structure.")
            return
        # `drop_over.index` is None while the overlap dropdown has no selection;
        # indexing `overlap_codes` with it below would raise a TypeError.
        if computational_resources.value is None or drop_over.index is None:
            print("Please select codes.")
            return

        builder = Cp2kPdosWorkChain.get_builder()

        # DFT parameters.
        dft_params_dict = {
            'elpa_switch':       elpa_check.value,
            'uks':               uks_switch.value,
            'sc_diag':           sc_diag_check.value,
            'force_multiplicity': force_multiplicity_check.value,
            'periodic':          'XYZ',
            'protocol':          protocol.value,
        }
        if uks_switch.value:
            # The placeholders of the spin fields advertise range syntax
            # ("1..5 7"), which the previous `int(v) - 1` parsing rejected with
            # a ValueError. Use the same parser as the molecule selection.
            # NOTE(review): assumes the parser shifts to 0-based indices by
            # default, as the molecule selection relies on — confirm.
            dft_params_dict['spin_up_guess'] = awb.utils.string_range_to_list(spin_up_text.value)[0]
            dft_params_dict['spin_dw_guess'] = awb.utils.string_range_to_list(spin_dw_text.value)[0]
            dft_params_dict['multiplicity']  = multiplicity_text.value

        if smear_switch.value:
            dft_params_dict['smear_t'] = temperature_text.value

        dft_params = Dict(dict=dft_params_dict)

        # Command-line options of the overlap code.
        overlap_params = Dict(dict={
            '--cp2k_input_file1':  'parent_slab_folder/aiida.inp',
            '--basis_set_file1':   'parent_slab_folder/BASIS_MOLOPT',
            '--xyz_file1':         'parent_slab_folder/aiida.coords.xyz',
            '--wfn_file1':         'parent_slab_folder/aiida-RESTART.wfn',
            '--emin1':             str(emin_floattext.value),
            '--emax1':             str(emax_floattext.value),
            '--cp2k_input_file2':  'parent_mol_folder/aiida.inp',
            '--basis_set_file2':   'parent_mol_folder/BASIS_MOLOPT',
            '--xyz_file2':         'parent_mol_folder/aiida.coords.xyz',
            '--wfn_file2':         'parent_mol_folder/aiida-RESTART.wfn',
            '--nhomo2':            str(num_homo_inttext.value),
            '--nlumo2':            str(num_lumo_inttext.value),
            '--output_file':       './overlap.npz',
            '--eval_region':       ['G', 'G', 'G', 'G', 'n-3.0_C', 'p2.0'],
            '--dx':                '0.2',
            '--eval_cutoff':       '14.0'
        })

        cp2k_code = computational_resources.value
        overlap_code = overlap_codes[drop_over.index]

        struct = structure_selector.structure

        ## Try to access the restart-wfn file ##
        #selected_comp = cp2k_code.computer
        wfn_file_path = ""
        #try:
        #    wfn_file_path = common.find_struct_wf(struct, selected_comp)
        #except:
        #    wfn_file_path = ""
        #if wfn_file_path == "":
        #    print("Info: didn't find any accessible .wfn file.")

        # Molecule cut out of the slab system according to the molecule selection.
        mol_structure = StructureData(ase=get_mol_ase())
        mol_structure.description = "mol generated from pdos from slab uuid: %s" %  structure_selector.structure_node.uuid

        # Pair each selection text with its label. The first selection is the
        # molecule itself and always carries the label 'molecule'.
        pdos_labels = ['molecule'] + [widget.value for widget in sel_ll_list]
        aiida_pdos_list = List()
        aiida_pdos_list.extend(
            [(text.value, label) for text, label in zip(sel_text_list, pdos_labels)]
        )

        builder.cp2k_code=cp2k_code
        builder.slabsys_structure=structure_selector.structure_node
        builder.mol_structure=mol_structure
        builder.pdos_lists=aiida_pdos_list
        builder.wfn_file_path=Str(wfn_file_path)
        builder.dft_params=dft_params
        builder.overlap_code=overlap_code
        builder.overlap_params=overlap_params

        # Resources.
        builder.options = {
            'slab': {
                "max_wallclock_seconds": slab_run_time_widget.value * 60,
                "resources": {
                    "num_machines": slab_nodes_widget.value,
                    "num_mpiprocs_per_machine": slab_cpus_per_node_widget.value,
                    "num_cores_per_mpiproc": slab_num_cores_per_mpiproc_widget.value,
                }
            },
            'molecule': {
                "max_wallclock_seconds": molecule_run_time_widget.value * 60,
                "resources": {
                    "num_machines": molecule_nodes_widget.value,
                    "num_mpiprocs_per_machine": molecule_cpus_per_node_widget.value,
                    "num_cores_per_mpiproc": molecule_num_cores_per_mpiproc_widget.value,
                }
            }
        }

    return builder

btn_submit = awb.SubmitButtonWidget(Cp2kPdosWorkChain, input_dictionary_function=get_builder)
display(btn_submit, submit_out)