In [None]:
%aiida

In [None]:
from aiida_cp2k.calculations import Cp2kCalculation

from aiida.orm import StructureData
from aiida.orm import ArrayData
from aiida.engine import submit, run_get_node

from aiidalab_widgets_base import CodeDropdown, SubmitButtonWidget, StructureBrowserWidget
from aiidalab_widgets_base import ComputerDropdown

from apps.surfaces.widgets import analyze_structure

import ase
import ase.io
import numpy as np
import nglview
from copy import deepcopy
from pprint import pprint

import ipywidgets as ipw
from IPython.display import display, clear_output, HTML

from apps.scanning_probe.pdos.pdos_workchain import PdosWorkChain
from apps.scanning_probe import common
from apps.scanning_probe.metadata_widget import MetadataWidget

# Select structure

In [None]:
def on_struct_change(c):
    """Refresh the selection widgets and the viewer after the structure changed.

    Registered as an observer on ``struct_browser.results``.

    Args:
        c: change notification passed by the observer machinery (unused).
    """
    # ``global`` is required: without it the assignment below only creates a
    # function-local name and the highlight of the previously shown structure
    # would still be applied by update_view().
    global highlight_indexes
    highlight_indexes = None
    reset_state()
    guess_molecule()
    update_view()
    
    
# Widget used to pick the structure; every change of its result value
# triggers on_struct_change.
struct_browser = StructureBrowserWidget()
struct_browser.results.observe(on_struct_change, names='value')

# 3D viewer; the cell output is cleared right after the widget is created.
viewer = nglview.NGLWidget()
clear_output()

display(ipw.VBox([struct_browser, viewer]))

# 0-based atom indexes to highlight in the viewer (None = nothing highlighted).
# Written by highlight() and read by update_view().
highlight_indexes = None

def update_view():
    """Redraw the viewer for the currently selected structure.

    The structure is split into the molecule (atoms listed in
    ``mol_selection``) and the slab (all remaining atoms). Each part is added
    as its own viewer component, so that they can be drawn with different
    ball-and-stick parameters. Atoms listed in the module-level
    ``highlight_indexes`` get an additional green, semi-transparent
    representation. Finally the camera is set to look at the structure along z.

    Does nothing besides clearing the viewer when no structure is selected.

    Raises:
        ValueError: if ``mol_selection`` holds a string that
            parse_cp2k_selection_string cannot parse.
    """
    # Remove the components left over from the previous call.
    for comp_name in ("component_1", "component_0"):
        if hasattr(viewer, comp_name):
            old_component = getattr(viewer, comp_name)
            old_component.clear_representations()
            old_component.remove_unitcell()
            viewer.remove_component(old_component.id)

    structure = struct_browser.results.value
    if not structure:
        return

    atoms = structure.get_ase()
    # Periodicity has to be set *before* slicing: slices of an Atoms object
    # copy the pbc flags of their parent at the time they are created, so
    # setting it afterwards never reaches the displayed components.
    atoms.pbc = [1, 1, 1]

    mol_atom_indexes = parse_cp2k_selection_string(mol_selection.value)
    mol_atoms = atoms[mol_atom_indexes]
    mol_index_set = set(mol_atom_indexes.tolist())
    slab_atom_indexes = [i for i in range(len(atoms)) if i not in mol_index_set]
    slab_atoms = atoms[slab_atom_indexes]

    # The order matters: molecule becomes component_0, slab component_1.
    # add_component already attaches a default ball+stick representation.
    viewer.add_component(nglview.ASEStructure(mol_atoms))
    viewer.add_component(nglview.ASEStructure(slab_atoms))
    viewer.add_unitcell()
    viewer.center()

    # Replace the default representations by the customised ones.
    viewer.component_0.remove_ball_and_stick()
    viewer.component_0.add_ball_and_stick(aspectRatio=1.5, scale=2.0, opacity=1.0)

    viewer.component_1.remove_ball_and_stick()
    viewer.component_1.add_ball_and_stick(aspectRatio=10.0, opacity=1.0)

    if highlight_indexes is not None:
        # Translate indexes of the full structure into positions inside the
        # molecule / slab component they ended up in.
        mol_position = {idx: pos for pos, idx in enumerate(mol_atom_indexes.tolist())}
        slab_position = {idx: pos for pos, idx in enumerate(slab_atom_indexes)}
        hi_mol_indexes = [mol_position[e] for e in highlight_indexes if e in mol_position]
        hi_slab_indexes = [slab_position[e] for e in highlight_indexes if e in slab_position]

        viewer.component_0.add_ball_and_stick(selection=hi_mol_indexes, color='green', aspectRatio=2.0, scale=2.0, opacity=0.4)
        viewer.component_1.add_ball_and_stick(selection=hi_slab_indexes, color='green', aspectRatio=11.0, opacity=0.4)

    # Orient camera to look from positive z
    cell_z = atoms.cell[2, 2]
    com = atoms.get_center_of_mass()
    top_z_orientation = [1.0, 0.0, 0.0, 0,
                         0.0, 1.0, 0.0, 0,
                         0.0, 0.0, -np.max([cell_z, 30.0]), 0,
                         -com[0], -com[1], -com[2], 1]
    viewer._set_camera_orientation(top_z_orientation)

# Select molecule

The atoms of the molecule must come first in the structure, i.e. before all atoms of the slab.

In [None]:
# PDOS selection addition and removal
def remove_from_tuple(tup, index):
    """Return a copy of ``tup`` that lacks the element at position ``index``.

    Raises:
        IndexError: if ``index`` is out of range for ``tup``.
    """
    remaining = [*tup]
    remaining.pop(index)
    return tuple(remaining)

def rm_pdos_selection(b):
    """Click handler of a 'remove' button: drop the pdos selection row it belongs to.

    Args:
        b: the button that was clicked; its position in ``sel_rm_list``
            identifies the row.
    """
    row = sel_rm_list.index(b)

    # Forget the text field and both buttons of that row.
    for widget_list in (sel_text_list, sel_hl_list, sel_rm_list):
        del widget_list[row]

    # sel_vbox does not contain the first (molecule) row, hence the shift by one.
    sel_vbox.children = remove_from_tuple(sel_vbox.children, row - 1)

def add_pdos_selection(b):
    """Click handler of the 'Add pdos selection' button: append one selection row.

    A row consists of a text field, a 'highlight' button and a 'remove'
    button. The three widgets are recorded in ``sel_text_list``,
    ``sel_hl_list`` and ``sel_rm_list`` and shown as a new child of
    ``sel_vbox``.

    Args:
        b: the clicked button (unused).
    """
    new_text = ipw.Text(description='pdos selection',
                        style=style, layout={'width': '50%'})

    hl_btn = ipw.Button(description='highlight', layout={'width': '10%'})
    rm_btn = ipw.Button(description='remove', layout={'width': '10%'})
    hl_btn.on_click(highlight)
    rm_btn.on_click(rm_pdos_selection)

    sel_text_list.append(new_text)
    sel_hl_list.append(hl_btn)
    sel_rm_list.append(rm_btn)

    new_row = ipw.HBox([new_text, hl_btn, rm_btn])
    sel_vbox.children = sel_vbox.children + (new_row, )

In [None]:
def parse_cp2k_selection_string(sel_str):
    """Convert a selection string into an array of 0-based indexes.

    The string is a whitespace-separated list of 1-based entries, each either
    a single integer (``"7"``) or an inclusive range (``"3..6"``).

    Args:
        sel_str: the selection string, e.g. ``"1..10 15 20..22"``.

    Returns:
        Integer numpy array with the selected indexes shifted to 0-based,
        without duplicates and in ascending order. An empty selection yields
        an empty *integer* array, so that the result can always be used for
        indexing.

    Raises:
        ValueError: if an entry cannot be converted to integers.
    """
    sel_indexes = set()
    for prt in sel_str.split():
        if '..' in prt:
            bounds = prt.split('..')
            i_start = int(bounds[0])
            i_end = int(bounds[1]) + 1  # the upper bound is inclusive
            sel_indexes.update(range(i_start, i_end))
        else:
            sel_indexes.add(int(prt))
    # Explicit dtype: np.array([]) would otherwise be float64, which is not a
    # valid index array. Sorting makes the order independent of set iteration.
    return np.array(sorted(sel_indexes), dtype=int) - 1

In [None]:
def guess_molecule():
    """Pre-fill ``mol_selection`` with all atoms in front of the first slab atom.

    The first atom with atomic number 29, 47 or 79 (Cu, Ag, Au) is taken as
    the start of the slab; everything before it is proposed as the molecule.
    If that fails (e.g. no structure is selected or no such atom exists) a
    message is printed and ``mol_selection`` is left untouched.
    """
    try:
        ase_struct = struct_browser.results.value.get_ase()
        first_slab_atom = np.argwhere( (ase_struct.numbers == 29) |
                                       (ase_struct.numbers == 47) |
                                       (ase_struct.numbers == 79)
                                     )[0, 0] + 1
        mol_selection.value = "1..%d" % (first_slab_atom - 1)
    except Exception:
        # Best effort only; `Exception` (instead of a bare except) keeps
        # KeyboardInterrupt / SystemExit from being swallowed.
        print("Unable to automatically find slab and molecule")

def get_mol_ase():
    """Return the atoms of the selected structure that are listed in ``mol_selection``."""
    selected = parse_cp2k_selection_string(mol_selection.value)
    return struct_browser.results.value.get_ase()[selected]

def highlight(b):
    """Click handler of a 'highlight' button: highlight the atoms of its row.

    The selection string of the text field in the same row as ``b`` is parsed
    and stored in the module-level ``highlight_indexes``; an unparsable string
    prints a message and clears the highlight. The viewer is redrawn in both
    cases.

    Args:
        b: the button that was clicked; its position in ``sel_hl_list``
            identifies the row.
    """
    global highlight_indexes
    sel_index = sel_hl_list.index(b)
    try:
        highlight_indexes = parse_cp2k_selection_string(sel_text_list[sel_index].value)
    except Exception:
        # `Exception` instead of a bare except, so that KeyboardInterrupt /
        # SystemExit are not reported as an invalid selection.
        print('Invalid selection string')
        highlight_indexes = None
    update_view()
    
def reset_state():
    """Reset the selection widgets to their initial state.

    The molecule selection goes back to ``"1..10"`` and every additional pdos
    selection row is removed, leaving only the first (molecule) row.
    """
    mol_selection.value = "1..10"
    # The module-level containers are modified in place. Plain assignments
    # here would only bind function-local names and leave the real state
    # untouched, and a fresh VBox would not be the one that is displayed.
    del sel_text_list[1:]
    del sel_hl_list[1:]
    del sel_rm_list[1:]
    sel_vbox.children = ()

# Shared widget styling for this section.
style = {'description_width': '120px'}
layout = {'width': '60%'}

# Text field holding the selection string of the molecule atoms
# (parsed by parse_cp2k_selection_string).
mol_selection = ipw.Text(description='Molecule selection',
                         value="1..10",
                         style=style, layout={'width': '50%'})

# Highlight button of the first (molecule) row.
highlight_btn = ipw.Button(description='highlight',
                            layout={'width': '10%'})

highlight_btn.on_click(highlight)

# Button appending a further pdos selection row.
add_sel_btn = ipw.Button(description='Add pdos selection',
                            layout={'width': '15%'})
add_sel_btn.on_click(add_pdos_selection)

# Parallel lists with one entry per selection row; index 0 is the molecule
# row, which has no remove button (hence the None placeholder).
sel_text_list = [mol_selection]
sel_hl_list = [highlight_btn]
sel_rm_list = [None]
# Container of the additional rows only; the molecule row is displayed
# separately below, so child k of sel_vbox corresponds to list index k+1.
sel_vbox = ipw.VBox([])

display(ipw.HBox([mol_selection, highlight_btn]), sel_vbox, add_sel_btn)

# Select computer and codes

In [None]:
# Widget styling for the remaining sections; rebinds the names `style` and
# `layout` that the molecule-selection cell defined with other values.
style = {'description_width': '140px'}
layout = {'width': '70%'}

# Dropdown used to choose the computer the codes are looked up on.
computer_drop = ComputerDropdown()

def on_computer_change(c):
    """Reload the code lists and dropdown options for the selected computer.

    Stores the codes returned by ``common.comp_plugin_codes`` for the
    'cp2k' and 'spm.overlap' plugins in the module-level ``cp2k_codes`` and
    ``overlap_codes`` and offers their labels in ``drop_cp2k`` / ``drop_over``.
    Nothing is changed while no computer is selected.

    Args:
        c: change notification of the observed dropdown (unused).
    """
    global cp2k_codes, overlap_codes

    selected = computer_drop.selected_computer
    if selected is None:
        return

    cp2k_codes = common.comp_plugin_codes(selected.name, 'cp2k')
    overlap_codes = common.comp_plugin_codes(selected.name, 'spm.overlap')

    drop_cp2k.options = [code.label for code in cp2k_codes]
    drop_over.options = [code.label for code in overlap_codes]
    
    
# Refresh the code lists whenever the computer dropdown changes.
# NOTE(review): no `names=` is given, so the handler is registered for every
# trait of the inner dropdown, not only its value - confirm this is intended.
# NOTE(review): `_dropdown` is a private attribute of ComputerDropdown.
computer_drop._dropdown.observe(on_computer_change)

# Code dropdowns; their options are filled by on_computer_change.
drop_cp2k = ipw.Dropdown(description="Cp2k code")

drop_over = ipw.Dropdown(description="Overlap code")

# Fill the dropdowns once for the initially selected computer (the argument
# is ignored by the handler).
on_computer_change(0)

# Passed to the work chain as `elpa_switch` on submission.
elpa_check = ipw.Checkbox(
    value=True,
    description='use ELPA',
    disabled=False
)

display(computer_drop, drop_cp2k, drop_over, elpa_check)

# Orbital overlap parameters

In [None]:
layout_small = {'width': '30%'}

# Values submitted as '--nhomo2' / '--nlumo2' of the overlap parameters.
num_homo_inttext = ipw.IntText(
                    value=4,
                    description="num gas HOMO",
                    style=style, layout=layout_small)
num_lumo_inttext = ipw.IntText(
                    value=4,
                    description="num gas LUMO",
                    style=style, layout=layout_small)

# NOTE(review): this widget repeats the description "num gas LUMO", is not
# passed to display() below and is not read anywhere in the visible notebook -
# looks like a copy-paste leftover; confirm before removing.
elim_inttext = ipw.IntText(
                    value=4,
                    description="num gas LUMO",
                    style=style, layout=layout_small)

# Energy window, submitted as '--emin1' / '--emax1' of the overlap parameters.
emin_floattext = ipw.BoundedFloatText(
                        description='Slab system Emin (eV)',
                        min=-3.0,
                        max=-0.1,
                        step=0.1,
                        value=-2.0,
                        style=style, layout=layout_small)

emax_floattext = ipw.BoundedFloatText(
                        description='Slab system Emax (eV)',
                        min=0.1,
                        max=3.0,
                        step=0.1,
                        value=2.0,
                        readout_format='%.2f',
                        style=style, layout=layout_small)

display(num_homo_inttext, num_lumo_inttext, emin_floattext, emax_floattext)

In [None]:
def on_submit(b):
    """Click handler of the 'Submit' button: validate the form and submit a PdosWorkChain.

    All output goes to ``submit_out``. Submission is refused with a message
    when no structure, no computer or not all codes are selected. Otherwise
    the overlap parameters are assembled from the widgets, the molecule is cut
    out of the selected structure, a wave function restart file is looked up
    and the work chain is submitted. The submitted node gets the extra
    ``version = 0``.

    Args:
        b: the clicked button (unused).
    """
    # Imported locally so that the handler does not rely on these node classes
    # having been put into the notebook namespace by some other means.
    from aiida.orm import Bool, Dict, List, Str

    with submit_out:
        clear_output()
        if not struct_browser.results.value:
            print("Please select a structure.")
            return
        if not computer_drop.selected_computer:
            print("Please select a computer.")
            return

        if not drop_cp2k.value or not drop_over.value:
            print("Please select all the codes.")
            return

        # Parameters ending in 1 refer to the slab system, those ending in 2
        # to the molecule (see the parent_*_folder prefixes).
        overlap_params = Dict(dict={
            '--cp2k_input_file1':  'parent_slab_folder/aiida.inp',
            '--basis_set_file1':   'parent_slab_folder/BASIS_MOLOPT',
            '--xyz_file1':         'parent_slab_folder/geom.xyz',
            '--wfn_file1':         'parent_slab_folder/aiida-RESTART.wfn',
            '--emin1':             str(emin_floattext.value),
            '--emax1':             str(emax_floattext.value),
            '--cp2k_input_file2':  'parent_mol_folder/aiida.inp',
            '--basis_set_file2':   'parent_mol_folder/BASIS_MOLOPT',
            '--xyz_file2':         'parent_mol_folder/geom.xyz',
            '--wfn_file2':         'parent_mol_folder/aiida-RESTART.wfn',
            '--nhomo2':            str(num_homo_inttext.value),
            '--nlumo2':            str(num_lumo_inttext.value),
            '--output_file':       './overlap.npz',
            '--eval_region':       ['G', 'G', 'G', 'G', 'n-3.0_C', 'p2.0'],
            '--dx':                '0.2',
            '--eval_cutoff':       '14.0'
        })

        # The dropdowns only hold labels; the index maps back to the code objects.
        cp2k_code = cp2k_codes[drop_cp2k.index]
        overlap_code = overlap_codes[drop_over.index]

        struct = struct_browser.results.value

        ## Try to access the restart-wfn file ##
        selected_comp = cp2k_code.computer
        try:
            wfn_file_path = common.find_struct_wf(struct, selected_comp)
        except Exception:
            # Best effort: a failed lookup is treated like "no file found".
            # `Exception` (instead of a bare except) lets KeyboardInterrupt /
            # SystemExit through.
            wfn_file_path = ""
        if wfn_file_path == "":
            print("Didn't find any accessible .wfn file.")

        mol_structure = StructureData(ase=get_mol_ase())
        mol_structure.description = "mol from slab PK%d" % struct.pk

        # One selection string per row; the first one is the molecule selection.
        aiida_pdos_list = List()
        aiida_pdos_list.extend([sel_text.value for sel_text in sel_text_list])

        node = submit(
            PdosWorkChain,
            cp2k_code=cp2k_code,
            slabsys_structure=struct,
            mol_structure=mol_structure,
            pdos_lists=aiida_pdos_list,
            wfn_file_path=Str(wfn_file_path),
            elpa_switch=Bool(elpa_check.value),
            overlap_code=overlap_code,
            overlap_params=overlap_params
        )

        # set calculation version; also used to determine post-processing
        node.set_extra("version", 0)

        print()
        print("Submitted:")
        print(node)

# Submit button and the output area on_submit prints its messages into.
btn_submit = ipw.Button(description="Submit")
btn_submit.on_click(on_submit)
submit_out = ipw.Output()
display(btn_submit, submit_out)