In [2]:
%aiida

In [3]:
from aiida_cp2k.calculations import Cp2kCalculation

from aiida.orm import StructureData
from aiida.orm import ArrayData
from aiida.engine import submit, run_get_node

from aiidalab_widgets_base import CodeDropdown, SubmitButtonWidget, StructureBrowserWidget
from aiidalab_widgets_base import ComputerDropdown

import ase
import ase.io
import numpy as np
import nglview
from copy import deepcopy
from pprint import pprint

import ipywidgets as ipw
from IPython.display import display, clear_output, HTML

from aiida.common.exceptions import MissingEntryPointError 
# Import explicitly instead of relying on the name being injected by the %aiida magic.
from aiida.plugins import WorkflowFactory

# Load the PDOS workchain through its entry point; print a hint and re-raise
# (keeping the original traceback) if the plugin is not installed.
try:
    Cp2kPdosWorkChain = WorkflowFactory('nanotech_empa.cp2k.pdos')
except MissingEntryPointError:
    print("Entry point not found. Did you perhaps forget to set up the plugins under 'Setup codes'?")
    raise

from apps.scanning_probe import common
from apps.scanning_probe.metadata_widget import MetadataWidget


from apps.scanning_probe import analyze_structure
from apps.scanning_probe.viewer_details import ViewerDetails

ModuleNotFoundError: No module named 'scanning_probe'

# Select structure

In [None]:
# ASE structure of the current selection and its analysis; both are set by on_struct_change.
atoms = slab_analyzed = None

def on_struct_change(c):
    """React to a new selection in the structure browser.

    Stores the selected structure as a fully periodic ASE object in the module
    global ``atoms``, analyzes it into the module global ``slab_analyzed``,
    passes both to the viewer, pre-fills the molecule selection and copies the
    creator's description into the calculation description field.
    """
    global atoms, slab_analyzed
    selected = struct_browser.results.value
    if not selected:
        return

    atoms = selected.get_ase()
    atoms.pbc = [1, 1, 1]

    slab_analyzed = analyze_structure.analyze(atoms)
    viewer_widget.setup(atoms, slab_analyzed)

    guess_molecule()

    creator = selected.creator
    if creator is not None:
        text_calc_description.value = creator.description

    
viewer_widget = ViewerDetails()

struct_browser = StructureBrowserWidget()
# Every new selection refreshes the viewer, the molecule guess and the description.
struct_browser.results.observe(on_struct_change, names='value')

display(ipw.VBox(children=[struct_browser, viewer_widget]))

In [None]:
# Free-text description attached to the submitted workchain; on_struct_change
# fills it from the creator of the selected structure when one exists.
text_calc_description = ipw.Text(
    description='Description:',
    layout=dict(width='45%'),
)
display(text_calc_description)

# Select molecule

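Automatic detection only works if the molecule's atoms are listed before the slab atoms in the structure. The first Cu, Ag or Au atom is taken as the start of the slab, and every atom before it is selected as the molecule.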

In [None]:
# PDOS selection addition and removal
def remove_from_tuple(tup, index):
    """Return a copy of ``tup`` without the element at position ``index``.

    Negative indices count from the end; an out-of-range index raises IndexError.
    """
    items = list(tup)
    items.pop(index)
    return tuple(items)

def rm_pdos_selection(b):
    """Remove the PDOS selection row whose 'remove' button ``b`` was clicked.

    The three bookkeeping lists are index-aligned, with the molecule selection
    at position 0. ``sel_vbox`` only holds the user-added rows, so its row
    index is one less than the list index.
    """
    idx = sel_rm_list.index(b)
    for registry in (sel_text_list, sel_hl_list, sel_rm_list):
        del registry[idx]

    sel_vbox.children = remove_from_tuple(sel_vbox.children, idx - 1)

def add_pdos_selection(b):
    """Append a new PDOS selection row (text field + 'highlight' + 'remove' buttons).

    The new widgets are registered, index-aligned, in ``sel_text_list``,
    ``sel_hl_list`` and ``sel_rm_list``, and the row is appended to ``sel_vbox``.
    """
    # Match the description width of the 'Molecule selection' field. The
    # module-level ``style`` must not be used here: later cells rebind it to
    # other dicts before this callback runs, which misaligned the rows.
    sel_style = {'description_width': '120px'}
    sel_text = ipw.Text(description='pdos selection',
                        style=sel_style, layout={'width': '50%'})
    hl_btn = ipw.Button(description='highlight', layout={'width': '10%'})
    hl_btn.on_click(highlight)
    rm_btn = ipw.Button(description='remove', layout={'width': '10%'})
    rm_btn.on_click(rm_pdos_selection)
    sel_text_list.append(sel_text)
    sel_hl_list.append(hl_btn)
    sel_rm_list.append(rm_btn)
    sel_vbox.children += (ipw.HBox([sel_text, hl_btn, rm_btn]), )

In [None]:
def parse_cp2k_selection_string(sel_str):
    """Convert a CP2K-style atom selection string into 0-based atom indices.

    The string is a whitespace-separated list of 1-based atom indices and
    inclusive ranges, e.g. ``"1..5 7 10..12"``.

    Args:
        sel_str: the selection string; an empty or blank string selects nothing.

    Returns:
        A sorted 1-D integer numpy array of unique 0-based indices
        (``[0, 1, 2, 3, 4, 6, 9, 10, 11]`` for the example above).

    Raises:
        ValueError: if a token or a range bound is not an integer.
    """
    sel_indexes = set()
    for token in sel_str.split():
        if '..' in token:
            start, end = token.split('..')[:2]
            sel_indexes.update(range(int(start), int(end) + 1))
        else:
            sel_indexes.add(int(token))
    # Sorting keeps the selected atoms in structure order (set iteration order is
    # arbitrary, e.g. {1, 8} iterates as 8, 1). The explicit integer dtype keeps
    # an empty selection usable as an index array.
    return np.array(sorted(sel_indexes), dtype=int) - 1

In [None]:
def guess_molecule(slab_numbers=(29, 47, 79)):
    """Pre-fill ``mol_selection`` with all atoms that precede the slab.

    The slab is taken to start at the first atom whose atomic number is in
    ``slab_numbers`` (Cu, Ag and Au by default). All atoms before it are
    assumed to form the molecule, so this only works when the molecule atoms
    are listed first. The selection is at least ``"1..1"``.

    If no structure is selected or no slab atom is found, a message is printed
    and ``mol_selection`` is left unchanged.
    """
    structure = struct_browser.results.value
    if not structure:
        print("Unable to automatically find slab and molecule")
        return
    numbers = structure.get_ase().numbers
    slab_indexes = np.flatnonzero(np.isin(numbers, slab_numbers))
    if slab_indexes.size == 0:
        print("Unable to automatically find slab and molecule")
        return
    # 1-based index of the first slab atom; clamp so the range is never empty.
    first_slab_atom = max(int(slab_indexes[0]) + 1, 2)
    mol_selection.value = "1..%d" % (first_slab_atom - 1)

def get_mol_ase():
    """Return the atoms of the selected structure listed in ``mol_selection`` as ASE Atoms.

    Raises ValueError (from the selection parser) if the selection string is invalid.
    """
    mol_indexes = parse_cp2k_selection_string(mol_selection.value)
    return struct_browser.results.value.get_ase()[mol_indexes]

def highlight(b):
    """Highlight the atoms of the selection row whose 'highlight' button ``b`` was clicked.

    The button's position in ``sel_hl_list`` identifies the matching text field
    in ``sel_text_list``. On an unparsable selection a message is printed and
    the viewer is reset with nothing highlighted.
    """
    sel_index = sel_hl_list.index(b)
    try:
        highlight_indexes = parse_cp2k_selection_string(sel_text_list[sel_index].value)
    except ValueError:
        print('Invalid selection string')
        # Same fallback as visualize_spin_guess: an empty selection instead of None.
        highlight_indexes = []
    viewer_widget.reset()
    viewer_widget.highlight_atoms(highlight_indexes, color='green', size=0.3, opacity=0.4)
    
def reset_state():
    """Restore the PDOS selection widgets to their initial state.

    Resets the molecule selection text and drops every user-added PDOS
    selection row. The bookkeeping lists and ``sel_vbox`` are modified in
    place: other callbacks and the displayed output refer to these exact
    objects. (Plain assignment here would only have created unused locals.)
    """
    mol_selection.value = "1..10"
    sel_text_list[:] = [mol_selection]
    sel_hl_list[:] = [highlight_btn]
    sel_rm_list[:] = [None]
    # Clear the displayed box; a new VBox would never be shown.
    sel_vbox.children = ()

# Styling for this cell. NOTE(review): `style` and `layout` are rebound by later
# cells, so callbacks that read them when clicked see the last-executed value,
# not these. `layout` itself is not used in this cell.
style = {'description_width': '120px'}
layout = {'width': '60%'}

# Molecule selection: CP2K-style 1-based selection string (e.g. "1..10");
# guess_molecule() overwrites it when a structure is chosen.
mol_selection = ipw.Text(description='Molecule selection',
                         value="1..10",
                         style=style, layout={'width': '50%'})

highlight_btn = ipw.Button(description='highlight',
                            layout={'width': '10%'})

highlight_btn.on_click(highlight)

add_sel_btn = ipw.Button(description='Add pdos selection',
                            layout={'width': '15%'})
add_sel_btn.on_click(add_pdos_selection)

# Index-aligned bookkeeping for all PDOS selections. Entry 0 is the molecule
# selection, which has no remove button (hence None); rows added by
# add_pdos_selection follow. sel_vbox holds only the added rows.
sel_text_list = [mol_selection]
sel_hl_list = [highlight_btn]
sel_rm_list = [None]
sel_vbox = ipw.VBox([])

display(ipw.HBox([mol_selection, highlight_btn]), sel_vbox, add_sel_btn)

# Select computer and codes

In [None]:
# Styles for the computer/code cell (rebinds the names used by earlier cells).
style = dict(description_width='140px')
layout = dict(width='70%')

computer_drop = ComputerDropdown()

def on_computer_change(c):
    """Refill the code dropdowns with the codes set up on the selected computer.

    Rebinds the module globals ``cp2k_codes`` and ``overlap_codes``. Their order
    matches the dropdown options, because ``on_submit`` indexes them with
    ``drop_*.index``. When no computer is selected, both lists and dropdowns are
    emptied so no stale codes from a previous computer remain.
    """
    global cp2k_codes, overlap_codes
    computer = computer_drop.selected_computer
    if computer is None:
        cp2k_codes, overlap_codes = [], []
    else:
        cp2k_codes = common.comp_plugin_codes(computer.name, 'cp2k')
        overlap_codes = common.comp_plugin_codes(computer.name, 'spm.overlap')

    drop_cp2k.options = [code.label for code in cp2k_codes]
    drop_over.options = [code.label for code in overlap_codes]
    
# React only to changes of the selected value. Without `names`, the handler ran
# (and re-queried the codes) for every trait change, e.g. index/label/options.
computer_drop._dropdown.observe(on_computer_change, names='value')

drop_cp2k = ipw.Dropdown(description="Cp2k code")

drop_over = ipw.Dropdown(description="Overlap code")

# Populate the dropdowns once for the computer selected at start-up.
on_computer_change(0)

elpa_check = ipw.Checkbox(
    value=True,
    description='use ELPA',
    disabled=False
)

display(computer_drop, drop_cp2k, drop_over, elpa_check)

# DFT parameters

In [None]:
# Styles for the DFT-parameter widgets (rebinds the names from earlier cells).
style = dict(description_width='80px')
layout = dict(width='70%')

def enable_spin(b):
    """Enable the spin-related inputs only while the spin-polarized switch is on."""
    spin_widgets = (spin_up_text, spin_dw_text, vis_spin_button,
                    multiplicity_text, sc_diag_check, force_multiplicity_check)
    is_disabled = not uks_switch.value
    for widget in spin_widgets:
        widget.disabled = is_disabled

def visualize_spin_guess(b):
    """Show the initial spin guess in the viewer: spin-up atoms red, spin-down atoms blue.

    If either selection string cannot be parsed, a message is printed and the
    viewer is reset with nothing highlighted.
    """
    try:
        spin_up = parse_cp2k_selection_string(spin_up_text.value)
        spin_dw = parse_cp2k_selection_string(spin_dw_text.value)
    except ValueError:
        print('Invalid selection string')
        spin_up = []
        spin_dw = []
    viewer_widget.reset()
    viewer_widget.highlight_atoms(spin_up, color='red', size=0.3, opacity=0.4)
    viewer_widget.highlight_atoms(spin_dw, color='blue', size=0.3, opacity=0.4)

# Spin polarization: toggling the switch enables/disables the spin widgets below (enable_spin).
uks_switch = ipw.ToggleButton(value=False,
                              description='Spin-polarized calculation',
                              style=style, layout={'width': '450px'})
uks_switch.observe(enable_spin, names='value')

# CP2K-style 1-based atom selections for the initial spin guess; on_submit only
# sends them (and the multiplicity) when the UKS switch is on.
spin_up_text = ipw.Text(placeholder='1..5 7',
                        description='Spin up',
                        disabled=True,
                        style=style, layout={'width': '370px'})
spin_dw_text = ipw.Text(placeholder='11..15 17',
                        description='Spin down',
                        disabled=True,
                        style=style, layout={'width': '370px'})
vis_spin_button = ipw.Button(description="Visualize",
                             disabled=True,
                             style = {'description_width': '0px'}, layout={'width': '75px'})
vis_spin_button.on_click(visualize_spin_guess)

# All remaining spin widgets start disabled until the UKS switch is turned on.
multiplicity_text = ipw.IntText(value=1,
                           description='Multiplicity',
                           disabled=True,
                           style=style, layout={'width': '20%'})
sc_diag_check = ipw.Checkbox(
    value=False,
    description='self-consistent diagonalization',
    disabled=True
)
force_multiplicity_check = ipw.Checkbox(
    value=True,
    description='Fix multiplicity',
    disabled=True
)

# Accuracy protocol, passed to the workchain as dft_params['protocol'].
protocol = ipw.Dropdown( 
            value="standard",
            options=[("Standard", "standard"), ("Low accuracy", "low_accuracy")],
            description="Protocol:",
            style={"description_width": "120px"},
        )
display(uks_switch, ipw.HBox([ipw.VBox([spin_up_text, spin_dw_text]), vis_spin_button]), 
        ipw.HBox([multiplicity_text,sc_diag_check,force_multiplicity_check]),protocol)

In [None]:
def enable_smearing(b):
    """Make the temperature field editable only while smearing is switched on."""
    smearing_on = smear_switch.value
    temperature_text.disabled = not smearing_on

# Fermi-Dirac smearing. When enabled, on_submit passes the temperature as
# dft_params['smear_t']. `style` here is the dict bound in the previous cell.
smear_switch = ipw.ToggleButton(value=False,
                              description='Enable Fermi-Dirac smearing',
                              style=style, layout={'width': '450px'})
smear_switch.observe(enable_smearing, names='value')

# Disabled until the smearing switch is turned on (enable_smearing).
temperature_text = ipw.FloatText(value=300.0,
                           description='Temperature [K]',
                           disabled=True,
                           style={'description_width': '100px'}, layout={'width': '20%'})


display(smear_switch, temperature_text)

# Orbital overlap parameters

In [None]:
style = {'description_width': '140px'}
layout_small = {'width': '30%'}

# Number of occupied / unoccupied orbitals of the isolated (gas-phase) molecule
# to include in the overlap; on_submit passes them as --nhomo2 / --nlumo2.
num_homo_inttext = ipw.IntText(
                    value=4,
                    description="num gas HOMO",
                    style=style, layout=layout_small)
num_lumo_inttext = ipw.IntText(
                    value=4,
                    description="num gas LUMO",
                    style=style, layout=layout_small)

# Energy window (eV) of the slab-system orbitals used in the overlap;
# on_submit passes the bounds as --emin1 / --emax1.
# (Removed: an unused, never-displayed copy of the LUMO widget, and a
# `readout_format` argument that BoundedFloatText does not support.)
emin_floattext = ipw.BoundedFloatText(
                        description='Slab system Emin (eV)',
                        min=-3.0,
                        max=-0.1,
                        step=0.1,
                        value=-2.0,
                        style=style, layout=layout_small)

emax_floattext = ipw.BoundedFloatText(
                        description='Slab system Emax (eV)',
                        min=0.1,
                        max=3.0,
                        step=0.1,
                        value=2.0,
                        style=style, layout=layout_small)

display(num_homo_inttext, num_lumo_inttext, emin_floattext, emax_floattext)

# Submission

In [None]:
def on_submit(b):
    """Validate the form, build the workchain inputs and submit a Cp2kPdosWorkChain.

    All messages (validation errors and the submitted node) are printed into
    ``submit_out``, which is cleared on every click. Invalid input aborts the
    submission with a message instead of a traceback.
    """
    with submit_out:
        clear_output()
        if not struct_browser.results.value:
            print("Please select a structure.")
            return
        if not computer_drop.selected_computer:
            print("Please select a computer.")
            return

        if not drop_cp2k.value or not drop_over.value:
            print("Please select all the codes.")
            return

        try:
            mol_ase = get_mol_ase()
        except ValueError:
            print("Invalid molecule selection string.")
            return

        dft_params_dict = {
            'elpa_switch':       elpa_check.value,
            'uks':               uks_switch.value,
            'sc_diag':           sc_diag_check.value,
            'force_multiplicity': force_multiplicity_check.value,
            'protocol':          protocol.value,
        }
        if uks_switch.value:
            # Selection strings are 1-based; the parser returns 0-based indices.
            try:
                spin_up_guess = parse_cp2k_selection_string(spin_up_text.value)
                spin_dw_guess = parse_cp2k_selection_string(spin_dw_text.value)
            except ValueError:
                print("Invalid spin guess selection string.")
                return
            dft_params_dict['spin_up_guess'] = spin_up_guess
            dft_params_dict['spin_dw_guess'] = spin_dw_guess
            dft_params_dict['multiplicity']  = multiplicity_text.value

        if smear_switch.value:
            dft_params_dict['smear_t'] = temperature_text.value

        dft_params = Dict(dict=dft_params_dict)

        overlap_params = Dict(dict={
            '--cp2k_input_file1':  'parent_slab_folder/aiida.inp',
            '--basis_set_file1':   'parent_slab_folder/BASIS_MOLOPT',
            '--xyz_file1':         'parent_slab_folder/aiida.coords.xyz',
            '--wfn_file1':         'parent_slab_folder/aiida-RESTART.wfn',
            '--emin1':             str(emin_floattext.value),
            '--emax1':             str(emax_floattext.value),
            '--cp2k_input_file2':  'parent_mol_folder/aiida.inp',
            '--basis_set_file2':   'parent_mol_folder/BASIS_MOLOPT',
            '--xyz_file2':         'parent_mol_folder/aiida.coords.xyz',
            '--wfn_file2':         'parent_mol_folder/aiida-RESTART.wfn',
            '--nhomo2':            str(num_homo_inttext.value),
            '--nlumo2':            str(num_lumo_inttext.value),
            '--output_file':       './overlap.npz',
            '--eval_region':       ['G', 'G', 'G', 'G', 'n-3.0_C', 'p2.0'],
            '--dx':                '0.2',
            '--eval_cutoff':       '14.0'
        })

        # The code lists are index-aligned with the dropdown options (on_computer_change).
        cp2k_code = cp2k_codes[drop_cp2k.index]
        overlap_code = overlap_codes[drop_over.index]

        struct = struct_browser.results.value

        # Looking up an existing RESTART .wfn file for this structure is
        # currently disabled, so the workchain always receives an empty path.
        wfn_file_path = ""

        mol_structure = StructureData(ase=mol_ase)
        mol_structure.description = "mol from slab PK%d" %  struct.pk

        # One entry per PDOS selection: the molecule selection first, then the user-added ones.
        aiida_pdos_list = List()
        aiida_pdos_list.extend([sel_text.value for sel_text in sel_text_list])

        node = submit(
            Cp2kPdosWorkChain,
            cp2k_code=cp2k_code,
            slabsys_structure=struct,
            mol_structure=mol_structure,
            pdos_lists=aiida_pdos_list,
            wfn_file_path=Str(wfn_file_path),
            dft_params=dft_params,
            overlap_code=overlap_code,
            overlap_params=overlap_params,
            metadata={'description': text_calc_description.value,
                      'label': 'Cp2kPdosWorkchain',}
        )

        # set calculation version; also used to determine post-processing
        node.set_extra("version", 0)

        print()
        print("Submitted:")
        print(node)

# Output area that on_submit prints into; shown below the submit button.
submit_out = ipw.Output()
btn_submit = ipw.Button(description="Submit")
btn_submit.on_click(on_submit)
display(btn_submit, submit_out)